1
2
3
4
5 package textproto
6
7 import (
8 "bufio"
9 "bytes"
10 "errors"
11 "fmt"
12 "io"
13 "math"
14 "strconv"
15 "strings"
16 "sync"
17 _ "unsafe"
18 )
19
20
21
var (
	// errMessageTooLarge is returned when a line, a folded header line, or a
	// whole MIME header exceeds the size or count limit imposed by the caller.
	errMessageTooLarge = errors.New("message too large")
)
23
24
25
// A Reader implements convenience methods for reading requests or responses
// from a line-oriented text protocol connection.
type Reader struct {
	R   *bufio.Reader // underlying buffered input; every read goes through it
	dot *dotReader    // active DotReader, if any; closeDot drains it before other reads
	buf []byte        // scratch reused by readContinuedLineSlice; its results may alias it
}
31
32
33
34
35
36
37 func NewReader(r *bufio.Reader) *Reader {
38 return &Reader{R: r}
39 }
40
41
42
43 func (r *Reader) ReadLine() (string, error) {
44 line, err := r.readLineSlice(-1)
45 return string(line), err
46 }
47
48
49 func (r *Reader) ReadLineBytes() ([]byte, error) {
50 line, err := r.readLineSlice(-1)
51 if line != nil {
52 line = bytes.Clone(line)
53 }
54 return line, err
55 }
56
57
58
59
60 func (r *Reader) readLineSlice(lim int64) ([]byte, error) {
61 r.closeDot()
62 var line []byte
63 for {
64 l, more, err := r.R.ReadLine()
65 if err != nil {
66 return nil, err
67 }
68 if lim >= 0 && int64(len(line))+int64(len(l)) > lim {
69 return nil, errMessageTooLarge
70 }
71
72 if line == nil && !more {
73 return l, nil
74 }
75 line = append(line, l...)
76 if !more {
77 break
78 }
79 }
80 return line, nil
81 }
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101 func (r *Reader) ReadContinuedLine() (string, error) {
102 line, err := r.readContinuedLineSlice(-1, noValidation)
103 return string(line), err
104 }
105
106
107
// trim returns the subslice of s without leading and trailing spaces and
// tabs. Other whitespace (such as '\r') is kept. The result shares s's
// backing array; an all-blank non-nil s yields a non-nil empty slice.
func trim(s []byte) []byte {
	isBlank := func(b byte) bool { return b == ' ' || b == '\t' }
	lo, hi := 0, len(s)
	for lo < hi && isBlank(s[lo]) {
		lo++
	}
	for hi > lo && isBlank(s[hi-1]) {
		hi--
	}
	return s[lo:hi]
}
119
120
121
122 func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
123 line, err := r.readContinuedLineSlice(-1, noValidation)
124 if line != nil {
125 line = bytes.Clone(line)
126 }
127 return line, err
128 }
129
130
131
132
133
134
// readContinuedLineSlice reads one logical line that may be folded over
// several physical lines: the first line plus every following line that
// begins with a space or tab. Each piece is trimmed of surrounding spaces and
// tabs, and the pieces are joined with a single space. A blank first line is
// returned as-is and never continued.
//
// validateFirstLine is applied to a non-blank first line, and any error it
// returns is passed through. A nil validateFirstLine is itself an error.
//
// lim bounds the size of the result (lim < 0 means no limit); exceeding it
// yields errMessageTooLarge. Every error return has a nil slice. The result
// may alias bufio's internal buffer or r.buf, so it is valid only until the
// next read from r.
func (r *Reader) readContinuedLineSlice(lim int64, validateFirstLine func([]byte) error) ([]byte, error) {
	if validateFirstLine == nil {
		return nil, fmt.Errorf("missing validateFirstLine func")
	}

	// Read the first physical line.
	line, err := r.readLineSlice(lim)
	if err != nil {
		return nil, err
	}
	if len(line) == 0 {
		// Blank line: nothing to continue.
		return line, nil
	}

	if err := validateFirstLine(line); err != nil {
		return nil, err
	}

	// Fast path: the next line may already be buffered and start with an
	// ASCII letter (likely the next header key), "\n" or "\r\n" (a blank
	// line). Then it cannot be a continuation, so return the first line
	// without copying it and without probing for leading whitespace.
	if r.R.Buffered() > 1 {
		peek, _ := r.R.Peek(2)
		if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
			len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
			return trim(line), nil
		}
	}

	// line may point into bufio's buffer, which skipSpace and the next
	// readLineSlice can overwrite, so copy it into the reusable r.buf.
	r.buf = append(r.buf[:0], trim(line)...)

	// Turn lim into a budget for the continuation lines.
	// NOTE(review): the first line's length is subtracted here and counted
	// again through len(r.buf) in the loop below. The effective limit is
	// therefore tighter than lim; confirm this conservatism is intended.
	if lim < 0 {
		lim = math.MaxInt64
	}
	lim -= int64(len(r.buf))

	// Read continuation lines: each run of leading spaces/tabs is replaced by
	// a single space.
	for r.skipSpace() > 0 {
		r.buf = append(r.buf, ' ')
		if int64(len(r.buf)) >= lim {
			return nil, errMessageTooLarge
		}
		line, err := r.readLineSlice(lim - int64(len(r.buf)))
		if err != nil {
			// NOTE(review): this error (including errMessageTooLarge) is
			// dropped, and the text gathered so far is returned with a nil
			// error; confirm this is intended.
			break
		}
		r.buf = append(r.buf, trim(line)...)
	}
	return r.buf, nil
}
188
189
190 func (r *Reader) skipSpace() int {
191 n := 0
192 for {
193 c, err := r.R.ReadByte()
194 if err != nil {
195
196 break
197 }
198 if c != ' ' && c != '\t' {
199 r.R.UnreadByte()
200 break
201 }
202 n++
203 }
204 return n
205 }
206
207 func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
208 line, err := r.ReadLine()
209 if err != nil {
210 return
211 }
212 return parseCodeLine(line, expectCode)
213 }
214
215 func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
216 if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
217 err = ProtocolError("short response: " + line)
218 return
219 }
220 continued = line[3] == '-'
221 code, err = strconv.Atoi(line[0:3])
222 if err != nil || code < 100 {
223 err = ProtocolError("invalid response code: " + line)
224 return
225 }
226 message = line[4:]
227 if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
228 10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
229 100 <= expectCode && expectCode < 1000 && code != expectCode {
230 err = &Error{code, message}
231 }
232 return
233 }
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252 func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
253 code, continued, message, err := r.readCodeLine(expectCode)
254 if err == nil && continued {
255 err = ProtocolError("unexpected multi-line response: " + message)
256 }
257 return
258 }
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
// ReadResponse reads a possibly multi-line response of the form
//
//	code-message line 1
//	code-message line 2
//	...
//	code message line n
//
// where code is a three-digit status code. A first line with a space after
// the code is a complete single-line response. After a first line with a
// hyphen, lines are read until one carries the same code followed by a space.
// The message lines are joined with "\n". Lines that do not parse as code
// lines, or that carry a different code, are appended verbatim and do not
// end the response.
//
// expectCode is interpreted as in ReadCodeLine. An error from the first line
// (a code mismatch or malformed line) is returned. If the response is
// multi-line, it is replaced by an *Error holding code and the full message.
// A read error in the middle of a multi-line response is returned as
// (0, "", err).
func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
	code, continued, first, err := r.readCodeLine(expectCode)
	multi := continued
	var messageBuilder strings.Builder
	messageBuilder.WriteString(first)
	for continued {
		// This err shadows the named result. Parse errors of later lines
		// therefore never reach the caller; the final err is the first line's.
		line, err := r.ReadLine()
		if err != nil {
			return 0, "", err
		}

		var code2 int
		var moreMessage string
		code2, continued, moreMessage, err = parseCodeLine(line, 0)
		if err != nil || code2 != code {
			// Not a line of this response's code: keep it raw (minus any
			// leftover trailing \r/\n) and keep reading, since only a
			// same-code line with a space separator ends the response.
			messageBuilder.WriteByte('\n')
			messageBuilder.WriteString(strings.TrimRight(line, "\r\n"))
			continued = true
			continue
		}
		messageBuilder.WriteByte('\n')
		messageBuilder.WriteString(moreMessage)
	}
	message = messageBuilder.String()
	if err != nil && multi && message != "" {
		// Replace the first line's error with one carrying the full message.
		err = &Error{code, message}
	}
	return
}
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333 func (r *Reader) DotReader() io.Reader {
334 r.closeDot()
335 r.dot = &dotReader{r: r}
336 return r.dot
337 }
338
// dotReader is the io.Reader returned by Reader.DotReader. It decodes a
// dot-encoded block from its Reader's underlying bufio.Reader.
type dotReader struct {
	r     *Reader // owning Reader; r.dot is cleared when this reader finishes
	state int     // current state of the decoding state machine in Read
}
343
344
// Read implements io.Reader. It decodes the dot-encoded block from d.r.R one
// byte at a time with a small state machine, applying these rules:
//   - A leading '.' on each line is dropped.
//   - "\r\n" is rewritten to "\n".
//   - A '\r' not followed by '\n' is passed through unchanged.
//   - The line "." (ending in "\r\n" or "\n") ends the block, after which
//     Read returns io.EOF.
//
// If the underlying reader reaches EOF before the "." line, Read returns
// io.ErrUnexpectedEOF; other errors are returned as-is. Once Read returns any
// error, d detaches itself from its Reader (if it is still the active one),
// which is how closeDot knows to stop.
func (d *dotReader) Read(b []byte) (n int, err error) {
	// The state persists in d.state between calls, so the block can be read
	// across any number of Read calls.
	const (
		stateBeginLine = iota // at the start of a line; the initial (zero) state
		stateDot              // read '.' at the start of a line
		stateDotCR            // read ".\r" at the start of a line
		stateCR               // read '\r', possibly the end of a line
		stateData             // inside a line
		stateEOF              // read the terminating "." line
	)
	br := d.r.R
	for n < len(b) && d.state != stateEOF {
		var c byte
		c, err = br.ReadByte()
		if err != nil {
			if err == io.EOF {
				// The block must end with a "." line; plain EOF is premature.
				err = io.ErrUnexpectedEOF
			}
			break
		}
		switch d.state {
		case stateBeginLine:
			if c == '.' {
				// Hold the dot back: it is either the escape or the end marker.
				d.state = stateDot
				continue
			}
			if c == '\r' {
				d.state = stateCR
				continue
			}
			d.state = stateData

		case stateDot:
			if c == '\r' {
				d.state = stateDotCR
				continue
			}
			if c == '\n' {
				d.state = stateEOF
				continue
			}
			// Not the end marker: the leading dot is dropped and c is data.
			d.state = stateData

		case stateDotCR:
			if c == '\n' {
				d.state = stateEOF
				continue
			}
			// ".\r" not followed by '\n': drop the leading dot, emit the
			// held '\r' now, and push c back to be processed as data.
			br.UnreadByte()
			c = '\r'
			d.state = stateData

		case stateCR:
			if c == '\n' {
				// "\r\n" becomes "\n". This break leaves only the switch, so
				// the '\n' is still emitted below.
				d.state = stateBeginLine
				break
			}
			// Lone '\r': emit it and push c back to be processed as data.
			br.UnreadByte()
			c = '\r'
			d.state = stateData

		case stateData:
			if c == '\r' {
				d.state = stateCR
				continue
			}
			if c == '\n' {
				d.state = stateBeginLine
			}
		}
		b[n] = c
		n++
	}
	if err == nil && d.state == stateEOF {
		err = io.EOF
	}
	if err != nil && d.r.dot == d {
		d.r.dot = nil
	}
	return
}
431
432
433
434 func (r *Reader) closeDot() {
435 if r.dot == nil {
436 return
437 }
438 buf := make([]byte, 128)
439 for r.dot != nil {
440
441
442 r.dot.Read(buf)
443 }
444 }
445
446
447
448
449 func (r *Reader) ReadDotBytes() ([]byte, error) {
450 return io.ReadAll(r.DotReader())
451 }
452
453
454
455
456
457 func (r *Reader) ReadDotLines() ([]string, error) {
458
459
460
461 var v []string
462 var err error
463 for {
464 var line string
465 line, err = r.ReadLine()
466 if err != nil {
467 if err == io.EOF {
468 err = io.ErrUnexpectedEOF
469 }
470 break
471 }
472
473
474 if len(line) > 0 && line[0] == '.' {
475 if len(line) == 1 {
476 break
477 }
478 line = line[1:]
479 }
480 v = append(v, line)
481 }
482 return v, err
483 }
484
var (
	// colon separates a MIME header key from its value.
	colon = []byte{':'}
)
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506 func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
507 return readMIMEHeader(r, math.MaxInt64, math.MaxInt64)
508 }
509
510
511
512
513
514
// readMIMEHeader implements ReadMIMEHeader with explicit limits.
//
// maxMemory is an approximate byte budget for the returned map. 400 bytes are
// charged up front, each new key costs its length plus 200 bytes, and every
// value costs its length. maxMemory is also passed down as the size limit for
// each (possibly folded) header line. maxHeaders caps the number of header
// lines accepted. Exceeding either limit yields errMessageTooLarge.
//
// The map parsed so far is returned with most errors. The exception is
// exceeding maxHeaders, which returns a nil map.
//
// NOTE(review): this file blank-imports "unsafe", which Go only needs for
// //go:linkname directives, yet no such directive is visible here. If another
// package reaches this unexported function through linkname, confirm the
// directive was not lost.
func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) {
	// Preallocate one backing slice of value strings, sized from a peek at the
	// buffered input. Single-element pieces of it are handed out below to
	// avoid many tiny allocations.
	var strs []string
	hint := r.upcomingHeaderKeys()
	if hint > 0 {
		if hint > 1000 {
			hint = 1000 // cap the preallocation
		}
		strs = make([]string, hint)
	}

	m := make(MIMEHeader, hint)

	// Charge a fixed overhead for the map itself, plus mapEntryOverhead per
	// distinct key below (approximate costs).
	maxMemory -= 400
	const mapEntryOverhead = 200

	// The first line cannot start with a space or tab: there is nothing for it
	// to continue. Quote at most errorLimit bytes of it in the error.
	if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
		const errorLimit = 80
		line, err := r.readLineSlice(errorLimit)
		if err != nil {
			return m, err
		}
		return m, ProtocolError("malformed MIME header initial line: " + string(line))
	}

	for {
		kv, err := r.readContinuedLineSlice(maxMemory, mustHaveFieldNameColon)
		if len(kv) == 0 {
			// Blank line (end of header) or an error with no data.
			return m, err
		}

		// The key ends at the first colon.
		k, v, ok := bytes.Cut(kv, colon)
		if !ok {
			return m, ProtocolError("malformed MIME header line: " + string(kv))
		}
		// canonicalMIMEHeaderKey may rewrite k in place. k aliases kv, which
		// may point into bufio's buffer or r.buf.
		key, ok := canonicalMIMEHeaderKey(k)
		if !ok {
			return m, ProtocolError("malformed MIME header line: " + string(kv))
		}
		for _, c := range v {
			if !validHeaderValueByte(c) {
				return m, ProtocolError("malformed MIME header line: " + string(kv))
			}
		}

		maxHeaders--
		if maxHeaders < 0 {
			return nil, errMessageTooLarge
		}

		// Skip initial spaces and tabs in the value.
		value := string(bytes.TrimLeft(v, " \t"))

		vv := m[key]
		if vv == nil {
			// First value for this key: charge for the key and the map entry.
			maxMemory -= int64(len(key))
			maxMemory -= mapEntryOverhead
		}
		maxMemory -= int64(len(value))
		if maxMemory < 0 {
			return m, errMessageTooLarge
		}
		if vv == nil && len(strs) > 0 {
			// Most headers have a single value. Take one slot from strs with
			// capacity 1, so a later append for a repeated key reallocates
			// instead of overwriting the next slot of strs.
			vv, strs = strs[:1:1], strs[1:]
			vv[0] = value
			m[key] = vv
		} else {
			m[key] = append(vv, value)
		}

		// Defensive: readContinuedLineSlice as written never returns data
		// together with an error, but report one if it ever does.
		if err != nil {
			return m, err
		}
	}
}
601
602
603
// noValidation is a validateFirstLine function that accepts every line.
func noValidation([]byte) error {
	return nil
}
605
606
607
608
609 func mustHaveFieldNameColon(line []byte) error {
610 if bytes.IndexByte(line, ':') < 0 {
611 return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line))
612 }
613 return nil
614 }
615
var (
	// nl is the line separator used when scanning buffered header data.
	nl = []byte{'\n'}
)
617
618
619
620 func (r *Reader) upcomingHeaderKeys() (n int) {
621
622 r.R.Peek(1)
623 s := r.R.Buffered()
624 if s == 0 {
625 return
626 }
627 peek, _ := r.R.Peek(s)
628 for len(peek) > 0 && n < 1000 {
629 var line []byte
630 line, peek, _ = bytes.Cut(peek, nl)
631 if len(line) == 0 || (len(line) == 1 && line[0] == '\r') {
632
633 break
634 }
635 if line[0] == ' ' || line[0] == '\t' {
636
637 continue
638 }
639 n++
640 }
641 return n
642 }
643
644
645
646
647
648
649
650
651
652 func CanonicalMIMEHeaderKey(s string) string {
653
654 upper := true
655 for i := 0; i < len(s); i++ {
656 c := s[i]
657 if !validHeaderFieldByte(c) {
658 return s
659 }
660 if upper && 'a' <= c && c <= 'z' {
661 s, _ = canonicalMIMEHeaderKey([]byte(s))
662 return s
663 }
664 if !upper && 'A' <= c && c <= 'Z' {
665 s, _ = canonicalMIMEHeaderKey([]byte(s))
666 return s
667 }
668 upper = c == '-'
669 }
670 return s
671 }
672
// toLower is the distance between an ASCII upper-case letter and its
// lower-case counterpart ('a' - 'A'). canonicalMIMEHeaderKey adds or
// subtracts it to change letter case.
const toLower = 'a' - 'A'
674
675
676
677
678
679
680
681
682
// validHeaderFieldByte reports whether c may appear in a header field name.
// The accepted bytes are the RFC 7230 token characters: ASCII letters,
// digits, and !#$%&'*+-.^_`|~. Bytes >= 128 are rejected.
func validHeaderFieldByte(c byte) bool {
	// allowed is a 128-bit bitmap with a 1 for every accepted byte. It is
	// split into two 64-bit words so c can be tested with shifts and masks.
	const allowed = 0 |
		(1<<10-1)<<'0' | // DIGIT
		(1<<26-1)<<'a' | // lower-case ALPHA
		(1<<26-1)<<'A' | // upper-case ALPHA
		1<<'!' | 1<<'#' | 1<<'$' | 1<<'%' | 1<<'&' | 1<<'\'' | 1<<'*' |
		1<<'+' | 1<<'-' | 1<<'.' | 1<<'^' | 1<<'_' | 1<<'`' | 1<<'|' | 1<<'~'
	const (
		allowedLo uint64 = allowed & (1<<64 - 1) // bytes 0-63
		allowedHi uint64 = allowed >> 64         // bytes 64-127
	)
	// For c >= 64 the first shift is >= 64 and yields 0. For c < 64, c-64
	// wraps (byte arithmetic) to >= 192, so the second shift yields 0. Bytes
	// >= 128 therefore miss both words.
	lo := (uint64(1) << c) & allowedLo
	hi := (uint64(1) << (c - 64)) & allowedHi
	return lo|hi != 0
}
710
711
712
713
714
715
716
717
718
719
720
721
722
// validHeaderValueByte reports whether c may appear in a header field value.
// RFC 7230 permits HTAB (0x09), SP (0x20), VCHAR (0x21-0x7E) and obs-text
// (0x80-0xFF). Other control bytes and DEL (0x7F) are rejected.
func validHeaderValueByte(c byte) bool {
	// permitted is a 128-bit bitmap of accepted bytes below 0x80, split into
	// two 64-bit words.
	const permitted = 0 |
		(1<<(0x7f-0x21)-1)<<0x21 | // VCHAR: %x21-7E
		1<<0x20 | // SP
		1<<0x09 // HTAB
	const (
		permittedLo uint64 = permitted & (1<<64 - 1) // bytes 0-63
		permittedHi uint64 = permitted >> 64         // bytes 64-127
	)
	// Any bit left after masking out the permitted bytes marks a forbidden
	// byte. Bytes >= 128 shift out of both words (see the wrap-around of c-64
	// for c < 64), leave no bit, and are accepted as obs-text.
	badLo := (uint64(1) << c) &^ permittedLo
	badHi := (uint64(1) << (c - 64)) &^ permittedHi
	return badLo|badHi == 0
}
736
737
738
739
740
741
742
743
744
745
746
747 func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) {
748 if len(a) == 0 {
749 return "", false
750 }
751
752
753 noCanon := false
754 for _, c := range a {
755 if validHeaderFieldByte(c) {
756 continue
757 }
758
759 if c == ' ' {
760
761
762
763 noCanon = true
764 continue
765 }
766 return string(a), false
767 }
768 if noCanon {
769 return string(a), true
770 }
771
772 upper := true
773 for i, c := range a {
774
775
776
777
778 if upper && 'a' <= c && c <= 'z' {
779 c -= toLower
780 } else if !upper && 'A' <= c && c <= 'Z' {
781 c += toLower
782 }
783 a[i] = c
784 upper = c == '-'
785 }
786 commonHeaderOnce.Do(initCommonHeader)
787
788
789
790 if v := commonHeader[string(a)]; v != "" {
791 return v, true
792 }
793 return string(a), true
794 }
795
796
// commonHeader maps frequently seen canonical header keys to themselves.
// canonicalMIMEHeaderKey returns these shared strings instead of allocating
// new ones. The map is filled lazily by initCommonHeader, run once through
// commonHeaderOnce.
var (
	commonHeader     map[string]string
	commonHeaderOnce sync.Once
)

// initCommonHeader populates commonHeader; each key maps to itself.
func initCommonHeader() {
	keys := []string{
		"Accept",
		"Accept-Charset",
		"Accept-Encoding",
		"Accept-Language",
		"Accept-Ranges",
		"Cache-Control",
		"Cc",
		"Connection",
		"Content-Id",
		"Content-Language",
		"Content-Length",
		"Content-Transfer-Encoding",
		"Content-Type",
		"Cookie",
		"Date",
		"Dkim-Signature",
		"Etag",
		"Expires",
		"From",
		"Host",
		"If-Modified-Since",
		"If-None-Match",
		"In-Reply-To",
		"Last-Modified",
		"Location",
		"Message-Id",
		"Mime-Version",
		"Pragma",
		"Received",
		"Return-Path",
		"Server",
		"Set-Cookie",
		"Subject",
		"To",
		"User-Agent",
		"Via",
		"X-Forwarded-For",
		"X-Imforwards",
		"X-Powered-By",
	}
	commonHeader = make(map[string]string, len(keys))
	for _, k := range keys {
		commonHeader[k] = k
	}
}
847
View as plain text