1
2
3
4
5 package textproto
6
7 import (
8 "bufio"
9 "bytes"
10 "errors"
11 "fmt"
12 "io"
13 "math"
14 "strconv"
15 "strings"
16 "sync"
17 _ "unsafe"
18 )
19
20
21
var (
	// errMessageTooLarge is returned by the internal readers (readLineSlice,
	// readContinuedLineSlice, readMIMEHeader) when a line, header value or
	// header count exceeds the limit they were given.
	errMessageTooLarge = errors.New("message too large")
)
23
24
25
// Reader reads line-oriented text-protocol data from an underlying
// *bufio.Reader: single lines, folded continuation lines, numeric response
// lines, dot-encoded blocks and MIME-style header blocks.
type Reader struct {
	R   *bufio.Reader // underlying buffered source; every read goes through it
	dot *dotReader    // active DotReader, or nil; drained by closeDot before other line reads
	buf []byte        // scratch buffer reused by readContinuedLineSlice for folded lines
}
31
32
33
34
35
36
37 func NewReader(r *bufio.Reader) *Reader {
38 return &Reader{R: r}
39 }
40
41
42
43 func (r *Reader) ReadLine() (string, error) {
44 line, err := r.readLineSlice(-1)
45 return string(line), err
46 }
47
48
49 func (r *Reader) ReadLineBytes() ([]byte, error) {
50 line, err := r.readLineSlice(-1)
51 if line != nil {
52 line = bytes.Clone(line)
53 }
54 return line, err
55 }
56
57
58
59
60 func (r *Reader) readLineSlice(lim int64) ([]byte, error) {
61 r.closeDot()
62 var line []byte
63 for {
64 l, more, err := r.R.ReadLine()
65 if err != nil {
66 return nil, err
67 }
68 if lim >= 0 && int64(len(line))+int64(len(l)) > lim {
69 return nil, errMessageTooLarge
70 }
71
72 if line == nil && !more {
73 return l, nil
74 }
75 line = append(line, l...)
76 if !more {
77 break
78 }
79 }
80 return line, nil
81 }
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101 func (r *Reader) ReadContinuedLine() (string, error) {
102 line, err := r.readContinuedLineSlice(-1, noValidation)
103 return string(line), err
104 }
105
106
107
// trim returns the subslice of s without leading and trailing spaces and
// tabs. Other whitespace such as '\r' is kept. The result shares s's
// backing array.
func trim(s []byte) []byte {
	isBlank := func(c byte) bool { return c == ' ' || c == '\t' }
	start, end := 0, len(s)
	for start < end && isBlank(s[start]) {
		start++
	}
	for end > start && isBlank(s[end-1]) {
		end--
	}
	return s[start:end]
}
119
120
121
122 func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
123 line, err := r.readContinuedLineSlice(-1, noValidation)
124 if line != nil {
125 line = bytes.Clone(line)
126 }
127 return line, err
128 }
129
130
131
132
133
134
135 func (r *Reader) readContinuedLineSlice(lim int64, validateFirstLine func([]byte) error) ([]byte, error) {
136 if validateFirstLine == nil {
137 return nil, fmt.Errorf("missing validateFirstLine func")
138 }
139
140
141 line, err := r.readLineSlice(lim)
142 if err != nil {
143 return nil, err
144 }
145 if len(line) == 0 {
146 return line, nil
147 }
148
149 if err := validateFirstLine(line); err != nil {
150 return nil, err
151 }
152
153
154
155
156
157 if r.R.Buffered() > 1 {
158 peek, _ := r.R.Peek(2)
159 if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
160 len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
161 return trim(line), nil
162 }
163 }
164
165
166
167 r.buf = append(r.buf[:0], trim(line)...)
168
169 if lim < 0 {
170 lim = math.MaxInt64
171 }
172 lim -= int64(len(r.buf))
173
174
175 for r.skipSpace() > 0 {
176 r.buf = append(r.buf, ' ')
177 if int64(len(r.buf)) >= lim {
178 return nil, errMessageTooLarge
179 }
180 line, err := r.readLineSlice(lim - int64(len(r.buf)))
181 if err != nil {
182 break
183 }
184 r.buf = append(r.buf, trim(line)...)
185 }
186 return r.buf, nil
187 }
188
189
190 func (r *Reader) skipSpace() int {
191 n := 0
192 for {
193 c, err := r.R.ReadByte()
194 if err != nil {
195
196 break
197 }
198 if c != ' ' && c != '\t' {
199 r.R.UnreadByte()
200 break
201 }
202 n++
203 }
204 return n
205 }
206
207 func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
208 line, err := r.ReadLine()
209 if err != nil {
210 return
211 }
212 return parseCodeLine(line, expectCode)
213 }
214
215 func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
216 if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
217 err = ProtocolError("short response: " + line)
218 return
219 }
220 continued = line[3] == '-'
221 code, err = strconv.Atoi(line[0:3])
222 if err != nil || code < 100 {
223 err = ProtocolError("invalid response code: " + line)
224 return
225 }
226 message = line[4:]
227 if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
228 10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
229 100 <= expectCode && expectCode < 1000 && code != expectCode {
230 err = &Error{code, message}
231 }
232 return
233 }
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252 func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
253 code, continued, message, err := r.readCodeLine(expectCode)
254 if err == nil && continued {
255 err = ProtocolError("unexpected multi-line response: " + message)
256 }
257 return
258 }
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286 func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
287 code, continued, message, err := r.readCodeLine(expectCode)
288 multi := continued
289 for continued {
290 line, err := r.ReadLine()
291 if err != nil {
292 return 0, "", err
293 }
294
295 var code2 int
296 var moreMessage string
297 code2, continued, moreMessage, err = parseCodeLine(line, 0)
298 if err != nil || code2 != code {
299 message += "\n" + strings.TrimRight(line, "\r\n")
300 continued = true
301 continue
302 }
303 message += "\n" + moreMessage
304 }
305 if err != nil && multi && message != "" {
306
307 err = &Error{code, message}
308 }
309 return
310 }
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328 func (r *Reader) DotReader() io.Reader {
329 r.closeDot()
330 r.dot = &dotReader{r: r}
331 return r.dot
332 }
333
// dotReader is the io.Reader handed out by Reader.DotReader. It decodes a
// dot-encoded block from r.R incrementally. state carries the decoder state
// (one of the state constants declared in Read) across Read calls; its zero
// value is the start-of-line state.
type dotReader struct {
	r     *Reader
	state int
}
338
339
// Read reads the dot-encoded block one byte at a time from d.r.R and copies
// the decoded bytes into b:
//   - "\r\n" is rewritten to "\n"; a '\r' not followed by '\n' is kept;
//   - a '.' at the start of a line is removed (so ".." yields ".");
//   - a line consisting only of "." (ended by "\n" or "\r\n") terminates
//     the block, after which Read returns io.EOF.
//
// Reaching the end of input before the terminating line returns
// io.ErrUnexpectedEOF. Once Read returns any error, it detaches d from its
// Reader (d.r.dot = nil); that is what ends the draining loop in closeDot.
func (d *dotReader) Read(b []byte) (n int, err error) {
	// The decoder state persists in d.state between calls, so a block may be
	// consumed in arbitrarily small reads.
	const (
		stateBeginLine = iota // at the beginning of a line; zero, so also the initial state
		stateDot              // read '.' at the beginning of a line
		stateDotCR            // read ".\r" at the beginning of a line
		stateCR               // read '\r' (possibly at the end of a line)
		stateData             // in the middle of a line
		stateEOF              // read the terminating "." line
	)
	br := d.r.R
	for n < len(b) && d.state != stateEOF {
		var c byte
		c, err = br.ReadByte()
		if err != nil {
			if err == io.EOF {
				// The input ended before the terminating "." line.
				err = io.ErrUnexpectedEOF
			}
			break
		}
		switch d.state {
		case stateBeginLine:
			if c == '.' {
				d.state = stateDot // drop the leading dot
				continue
			}
			if c == '\r' {
				d.state = stateCR // hold the '\r' until we see what follows
				continue
			}
			d.state = stateData

		case stateDot:
			if c == '\r' {
				d.state = stateDotCR
				continue
			}
			if c == '\n' {
				d.state = stateEOF // ".\n" terminates the block
				continue
			}
			d.state = stateData // the dot was an escape; emit c

		case stateDotCR:
			if c == '\n' {
				d.state = stateEOF // ".\r\n" terminates the block
				continue
			}
			// Not part of ".\r\n". The leading dot stays dropped; push c
			// back and emit the held '\r' instead.
			br.UnreadByte()
			c = '\r'
			d.state = stateData

		case stateCR:
			if c == '\n' {
				// "\r\n" becomes "\n": the held '\r' is discarded and the
				// '\n' is emitted below (break only leaves the switch).
				d.state = stateBeginLine
				break
			}
			// Not part of "\r\n". Push c back and emit the held '\r'.
			br.UnreadByte()
			c = '\r'
			d.state = stateData

		case stateData:
			if c == '\r' {
				d.state = stateCR // hold the '\r'; it may start "\r\n"
				continue
			}
			if c == '\n' {
				d.state = stateBeginLine
			}
		}
		b[n] = c
		n++
	}
	if err == nil && d.state == stateEOF {
		err = io.EOF
	}
	// Detach from the Reader on any error so closeDot stops draining.
	if err != nil && d.r.dot == d {
		d.r.dot = nil
	}
	return
}
426
427
428
429 func (r *Reader) closeDot() {
430 if r.dot == nil {
431 return
432 }
433 buf := make([]byte, 128)
434 for r.dot != nil {
435
436
437 r.dot.Read(buf)
438 }
439 }
440
441
442
443
444 func (r *Reader) ReadDotBytes() ([]byte, error) {
445 return io.ReadAll(r.DotReader())
446 }
447
448
449
450
451
452 func (r *Reader) ReadDotLines() ([]string, error) {
453
454
455
456 var v []string
457 var err error
458 for {
459 var line string
460 line, err = r.ReadLine()
461 if err != nil {
462 if err == io.EOF {
463 err = io.ErrUnexpectedEOF
464 }
465 break
466 }
467
468
469 if len(line) > 0 && line[0] == '.' {
470 if len(line) == 1 {
471 break
472 }
473 line = line[1:]
474 }
475 v = append(v, line)
476 }
477 return v, err
478 }
479
// colon separates a MIME header key from its value. readMIMEHeader splits
// each header line at the first occurrence with bytes.Cut.
var colon = []byte{':'}
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501 func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
502 return readMIMEHeader(r, math.MaxInt64, math.MaxInt64)
503 }
504
505
506
507
508
509
// readMIMEHeader reads a MIME-style header block from r. It implements
// ReadMIMEHeader with two limits:
//
//   - maxMemory is an approximate byte budget. It is charged a fixed 400
//     bytes up front, and then, for each header line, the value length plus,
//     the first time a key is seen, the key length and 200 bytes of map-entry
//     overhead. The remaining budget is also passed to readContinuedLineSlice
//     as the line-length limit. Running out returns errMessageTooLarge.
//   - maxHeaders bounds the number of header lines. Exceeding it returns
//     (nil, errMessageTooLarge).
//
// Reading stops at the first blank line, which is consumed. Keys are
// canonicalized with canonicalMIMEHeaderKey, and values have leading blanks
// removed. A malformed line yields a ProtocolError together with the headers
// read so far.
//
// NOTE(review): the file imports "unsafe" for side effects, which usually
// accompanies a go:linkname directive. If other packages link to this
// function, its signature must not change — confirm.
func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) {
	// Allocate one []string up front and cut it into single-element value
	// slices below, instead of allocating one small slice per key. The size
	// comes from a look-ahead over the already-buffered input.
	var strs []string
	hint := r.upcomingHeaderKeys()
	if hint > 0 {
		if hint > 1000 {
			hint = 1000 // cap the preallocation
		}
		strs = make([]string, hint)
	}

	m := make(MIMEHeader, hint)

	// Charge a fixed overhead for the map itself, plus mapEntryOverhead per
	// distinct key (applied below).
	// NOTE(review): 400/200 are estimates of map memory cost; their
	// derivation is not visible here.
	maxMemory -= 400
	const mapEntryOverhead = 200

	// The first line cannot start with a leading space.
	if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
		// Quote at most 80 bytes of the offending line. A longer line makes
		// readLineSlice return errMessageTooLarge instead.
		const errorLimit = 80
		line, err := r.readLineSlice(errorLimit)
		if err != nil {
			return m, err
		}
		return m, ProtocolError("malformed MIME header initial line: " + string(line))
	}

	for {
		kv, err := r.readContinuedLineSlice(maxMemory, mustHaveFieldNameColon)
		if len(kv) == 0 {
			// A blank line ends the header (err == nil), or the read
			// failed without data.
			return m, err
		}

		// Key ends at first colon.
		k, v, ok := bytes.Cut(kv, colon)
		if !ok {
			return m, ProtocolError("malformed MIME header line: " + string(kv))
		}
		key, ok := canonicalMIMEHeaderKey(k)
		if !ok {
			return m, ProtocolError("malformed MIME header line: " + string(kv))
		}
		for _, c := range v {
			if !validHeaderValueByte(c) {
				return m, ProtocolError("malformed MIME header line: " + string(kv))
			}
		}

		maxHeaders--
		if maxHeaders < 0 {
			// NOTE(review): unlike the other error paths this returns a nil
			// map rather than the headers read so far; confirm intended.
			return nil, errMessageTooLarge
		}

		// Skip initial spaces in value. Trailing blanks were already
		// trimmed by readContinuedLineSlice.
		value := string(bytes.TrimLeft(v, " \t"))

		vv := m[key]
		if vv == nil {
			// First value for this key: also charge for the key and the
			// map entry.
			maxMemory -= int64(len(key))
			maxMemory -= mapEntryOverhead
		}
		maxMemory -= int64(len(value))
		if maxMemory < 0 {
			return m, errMessageTooLarge
		}
		if vv == nil && len(strs) > 0 {
			// Hand out one element of the preallocated strs. Capacity is
			// capped at 1, so a later append for a repeated key reallocates
			// instead of overwriting the next element of strs.
			vv, strs = strs[:1:1], strs[1:]
			vv[0] = value
			m[key] = vv
		} else {
			m[key] = append(vv, value)
		}

		// Defensive: store the data first, then report the error.
		// readContinuedLineSlice, as written in this file, returns a nil
		// slice with every error.
		if err != nil {
			return m, err
		}
	}
}
596
597
598
// noValidation is a validateFirstLine callback for readContinuedLineSlice
// that accepts every line.
func noValidation([]byte) error {
	return nil
}
600
601
602
603
604 func mustHaveFieldNameColon(line []byte) error {
605 if bytes.IndexByte(line, ':') < 0 {
606 return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line))
607 }
608 return nil
609 }
610
// nl is the line separator that upcomingHeaderKeys cuts buffered input on.
var nl = []byte{'\n'}
612
613
614
615 func (r *Reader) upcomingHeaderKeys() (n int) {
616
617 r.R.Peek(1)
618 s := r.R.Buffered()
619 if s == 0 {
620 return
621 }
622 peek, _ := r.R.Peek(s)
623 for len(peek) > 0 && n < 1000 {
624 var line []byte
625 line, peek, _ = bytes.Cut(peek, nl)
626 if len(line) == 0 || (len(line) == 1 && line[0] == '\r') {
627
628 break
629 }
630 if line[0] == ' ' || line[0] == '\t' {
631
632 continue
633 }
634 n++
635 }
636 return n
637 }
638
639
640
641
642
643
644
645
646
647 func CanonicalMIMEHeaderKey(s string) string {
648
649 upper := true
650 for i := 0; i < len(s); i++ {
651 c := s[i]
652 if !validHeaderFieldByte(c) {
653 return s
654 }
655 if upper && 'a' <= c && c <= 'z' {
656 s, _ = canonicalMIMEHeaderKey([]byte(s))
657 return s
658 }
659 if !upper && 'A' <= c && c <= 'Z' {
660 s, _ = canonicalMIMEHeaderKey([]byte(s))
661 return s
662 }
663 upper = c == '-'
664 }
665 return s
666 }
667
// toLower is the distance from an upper-case ASCII letter to its lower-case
// counterpart ('a' - 'A' == 32). Adding it lower-cases a letter;
// subtracting it upper-cases one.
const toLower = 'a' - 'A'
669
670
671
672
673
674
675
676
677
// validHeaderFieldByte reports whether c may appear in a header field name.
// The allowed set is the RFC 7230 "tchar" set: ASCII digits, ASCII letters
// and !#$%&'*+-.^_`|~.
//
// The set is a 128-bit bitmap split into two uint64 halves, so the test is
// two shifts and two ands. For c >= 128 both shifts overflow to zero, so
// non-ASCII bytes are rejected. For c < 64, c-64 wraps to at least 192,
// which also shifts to zero.
func validHeaderFieldByte(c byte) bool {
	const (
		digits     = (1<<10 - 1) << '0'
		lowerAlpha = (1<<26 - 1) << 'a'
		upperAlpha = (1<<26 - 1) << 'A'
		punct      = 1<<'!' | 1<<'#' | 1<<'$' | 1<<'%' | 1<<'&' | 1<<'\'' |
			1<<'*' | 1<<'+' | 1<<'-' | 1<<'.' | 1<<'^' | 1<<'_' |
			1<<'`' | 1<<'|' | 1<<'~'
		tchar   = digits | lowerAlpha | upperAlpha | punct
		tcharLo = tchar & (1<<64 - 1) // bits for bytes 0-63
		tcharHi = tchar >> 64         // bits for bytes 64-127
	)
	lo := (uint64(1) << c) & tcharLo
	hi := (uint64(1) << (c - 64)) & tcharHi
	return (lo | hi) != 0
}
705
706
707
708
709
710
711
712
713
714
715
716
717
// validHeaderValueByte reports whether c may appear in a header field value.
// Allowed are HTAB (0x09), SP (0x20), visible ASCII (0x21-0x7E) and every
// byte >= 0x80 (obs-text). Other control characters and DEL (0x7F) are
// rejected.
//
// The allowed ASCII set is a 128-bit bitmap split into two uint64 halves.
// The test asks whether c's bit lies outside that set. Bytes >= 0x80 shift
// to zero in both halves, so they are never "outside" and are accepted.
func validHeaderValueByte(c byte) bool {
	const (
		vchar     = (1<<(0x7f-0x21) - 1) << 0x21 // VCHAR: %x21-7E
		sp        = 1 << 0x20                    // SP: %x20
		htab      = 1 << 0x09                    // HTAB: %x09
		allowed   = vchar | sp | htab
		allowedLo = allowed & (1<<64 - 1) // bits for bytes 0-63
		allowedHi = allowed >> 64         // bits for bytes 64-127
	)
	badLo := (uint64(1) << c) &^ allowedLo
	badHi := (uint64(1) << (c - 64)) &^ allowedHi
	return (badLo | badHi) == 0
}
731
732
733
734
735
736
737
738
739
740
741
742 func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) {
743 if len(a) == 0 {
744 return "", false
745 }
746
747
748 noCanon := false
749 for _, c := range a {
750 if validHeaderFieldByte(c) {
751 continue
752 }
753
754 if c == ' ' {
755
756
757
758 noCanon = true
759 continue
760 }
761 return string(a), false
762 }
763 if noCanon {
764 return string(a), true
765 }
766
767 upper := true
768 for i, c := range a {
769
770
771
772
773 if upper && 'a' <= c && c <= 'z' {
774 c -= toLower
775 } else if !upper && 'A' <= c && c <= 'Z' {
776 c += toLower
777 }
778 a[i] = c
779 upper = c == '-'
780 }
781 commonHeaderOnce.Do(initCommonHeader)
782
783
784
785 if v := commonHeader[string(a)]; v != "" {
786 return v, true
787 }
788 return string(a), true
789 }
790
791
var (
	// commonHeader maps frequently seen canonical header keys to themselves.
	// canonicalMIMEHeaderKey can then return one shared string for them
	// instead of allocating a new one.
	commonHeader map[string]string

	// commonHeaderOnce guards the lazy construction of commonHeader.
	commonHeaderOnce sync.Once
)

// initCommonHeader builds commonHeader. It is run via commonHeaderOnce.
func initCommonHeader() {
	keys := []string{
		"Accept",
		"Accept-Charset",
		"Accept-Encoding",
		"Accept-Language",
		"Accept-Ranges",
		"Cache-Control",
		"Cc",
		"Connection",
		"Content-Id",
		"Content-Language",
		"Content-Length",
		"Content-Transfer-Encoding",
		"Content-Type",
		"Cookie",
		"Date",
		"Dkim-Signature",
		"Etag",
		"Expires",
		"From",
		"Host",
		"If-Modified-Since",
		"If-None-Match",
		"In-Reply-To",
		"Last-Modified",
		"Location",
		"Message-Id",
		"Mime-Version",
		"Pragma",
		"Received",
		"Return-Path",
		"Server",
		"Set-Cookie",
		"Subject",
		"To",
		"User-Agent",
		"Via",
		"X-Forwarded-For",
		"X-Imforwards",
		"X-Powered-By",
	}
	table := make(map[string]string, len(keys))
	for _, k := range keys {
		table[k] = k
	}
	commonHeader = table
}
842
View as plain text