1
2
3
4
5 package textproto
6
7 import (
8 "bufio"
9 "bytes"
10 "errors"
11 "fmt"
12 "io"
13 "math"
14 "strconv"
15 "strings"
16 "sync"
17 _ "unsafe"
18 )
19
20
21
// errMessageTooLarge is returned by readLineSlice, readContinuedLineSlice
// and readMIMEHeader when the input exceeds a caller-supplied limit.
var errMessageTooLarge = errors.New("message too large")
23
24
25
// A Reader wraps a bufio.Reader and provides the line, response-code,
// dot-encoded-block and MIME-header reading methods defined in this file.
type Reader struct {
	R   *bufio.Reader // underlying buffered input that every method reads from
	dot *dotReader    // dotReader handed out by DotReader and not yet finished, if any; drained by closeDot
	buf []byte        // scratch buffer reused by readContinuedLineSlice to assemble folded lines
}
31
32
33
34
35
36
37 func NewReader(r *bufio.Reader) *Reader {
38 return &Reader{R: r}
39 }
40
41
42
43 func (r *Reader) ReadLine() (string, error) {
44 line, err := r.readLineSlice(-1)
45 return string(line), err
46 }
47
48
49 func (r *Reader) ReadLineBytes() ([]byte, error) {
50 line, err := r.readLineSlice(-1)
51 if line != nil {
52 line = bytes.Clone(line)
53 }
54 return line, err
55 }
56
57
58
59
60 func (r *Reader) readLineSlice(lim int64) ([]byte, error) {
61 r.closeDot()
62 var line []byte
63 for {
64 l, more, err := r.R.ReadLine()
65 if err != nil {
66 return nil, err
67 }
68 if lim >= 0 && int64(len(line))+int64(len(l)) > lim {
69 return nil, errMessageTooLarge
70 }
71
72 if line == nil && !more {
73 return l, nil
74 }
75 line = append(line, l...)
76 if !more {
77 break
78 }
79 }
80 return line, nil
81 }
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101 func (r *Reader) ReadContinuedLine() (string, error) {
102 line, err := r.readContinuedLineSlice(-1, noValidation)
103 return string(line), err
104 }
105
106
107
// trim returns the sub-slice of s with leading and trailing spaces and
// tabs removed. No other bytes are treated as blank, and the result
// shares its backing array with s.
func trim(s []byte) []byte {
	start, end := 0, len(s)
	for ; start < end; start++ {
		if c := s[start]; c != ' ' && c != '\t' {
			break
		}
	}
	for ; end > start; end-- {
		if c := s[end-1]; c != ' ' && c != '\t' {
			break
		}
	}
	return s[start:end]
}
120
121
122
123 func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
124 line, err := r.readContinuedLineSlice(-1, noValidation)
125 if line != nil {
126 line = bytes.Clone(line)
127 }
128 return line, err
129 }
130
131
132
133
134
135
136 func (r *Reader) readContinuedLineSlice(lim int64, validateFirstLine func([]byte) error) ([]byte, error) {
137 if validateFirstLine == nil {
138 return nil, fmt.Errorf("missing validateFirstLine func")
139 }
140
141
142 line, err := r.readLineSlice(lim)
143 if err != nil {
144 return nil, err
145 }
146 if len(line) == 0 {
147 return line, nil
148 }
149
150 if err := validateFirstLine(line); err != nil {
151 return nil, err
152 }
153
154
155
156
157
158 if r.R.Buffered() > 1 {
159 peek, _ := r.R.Peek(2)
160 if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
161 len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
162 return trim(line), nil
163 }
164 }
165
166
167
168 r.buf = append(r.buf[:0], trim(line)...)
169
170 if lim < 0 {
171 lim = math.MaxInt64
172 }
173 lim -= int64(len(r.buf))
174
175
176 for r.skipSpace() > 0 {
177 r.buf = append(r.buf, ' ')
178 if int64(len(r.buf)) >= lim {
179 return nil, errMessageTooLarge
180 }
181 line, err := r.readLineSlice(lim - int64(len(r.buf)))
182 if err != nil {
183 break
184 }
185 r.buf = append(r.buf, trim(line)...)
186 }
187 return r.buf, nil
188 }
189
190
191 func (r *Reader) skipSpace() int {
192 n := 0
193 for {
194 c, err := r.R.ReadByte()
195 if err != nil {
196
197 break
198 }
199 if c != ' ' && c != '\t' {
200 r.R.UnreadByte()
201 break
202 }
203 n++
204 }
205 return n
206 }
207
208 func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
209 line, err := r.ReadLine()
210 if err != nil {
211 return
212 }
213 return parseCodeLine(line, expectCode)
214 }
215
216 func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
217 if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
218 err = ProtocolError("short response: " + line)
219 return
220 }
221 continued = line[3] == '-'
222 code, err = strconv.Atoi(line[0:3])
223 if err != nil || code < 100 {
224 err = ProtocolError("invalid response code: " + line)
225 return
226 }
227 message = line[4:]
228 if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
229 10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
230 100 <= expectCode && expectCode < 1000 && code != expectCode {
231 err = &Error{code, message}
232 }
233 return
234 }
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253 func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
254 code, continued, message, err := r.readCodeLine(expectCode)
255 if err == nil && continued {
256 err = ProtocolError("unexpected multi-line response: " + message)
257 }
258 return
259 }
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287 func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
288 code, continued, first, err := r.readCodeLine(expectCode)
289 multi := continued
290 var messageBuilder strings.Builder
291 messageBuilder.WriteString(first)
292 for continued {
293 line, err := r.ReadLine()
294 if err != nil {
295 return 0, "", err
296 }
297
298 var code2 int
299 var moreMessage string
300 code2, continued, moreMessage, err = parseCodeLine(line, 0)
301 if err != nil || code2 != code {
302 messageBuilder.WriteByte('\n')
303 messageBuilder.WriteString(strings.TrimRight(line, "\r\n"))
304 continued = true
305 continue
306 }
307 messageBuilder.WriteByte('\n')
308 messageBuilder.WriteString(moreMessage)
309 }
310 message = messageBuilder.String()
311 if err != nil && multi && message != "" {
312
313 err = &Error{code, message}
314 }
315 return
316 }
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334 func (r *Reader) DotReader() io.Reader {
335 r.closeDot()
336 r.dot = &dotReader{r: r}
337 return r.dot
338 }
339
// dotReader is the io.Reader returned by Reader.DotReader. It decodes a
// dot-encoded block from the bytes supplied by r.R.
type dotReader struct {
	r     *Reader // owning Reader; its R field supplies the raw bytes
	state int     // decoder state used by Read; the zero value is stateBeginLine
}
344
345
// Read implements io.Reader. It decodes bytes from d.r.R into b, dropping
// a dot at the start of a line, turning "\r\n" into "\n", and stopping at
// a line that holds only a dot. It returns io.EOF once that end marker
// has been consumed, and io.ErrUnexpectedEOF if the underlying reader
// ends before the marker is seen. On any error (including io.EOF) the
// dotReader detaches itself from its Reader.
func (d *dotReader) Read(b []byte) (n int, err error) {
	// Run the input through a small state machine that elides leading
	// dots, rewrites "\r\n" as "\n", and detects the terminating dot line.
	// In every state, `continue` consumes the byte without emitting it,
	// while falling out of the switch emits c into b.
	const (
		stateBeginLine = iota // at the start of a line; initial state (must be zero)
		stateDot              // read "." at the start of a line
		stateDotCR            // read ".\r" at the start of a line
		stateCR               // read "\r", possibly the start of a line ending
		stateData             // in the middle of a line
		stateEOF              // consumed the end-of-block marker line
	)
	br := d.r.R
	for n < len(b) && d.state != stateEOF {
		var c byte
		c, err = br.ReadByte()
		if err != nil {
			// The block must end with the dot line, so a plain EOF here is premature.
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			break
		}
		switch d.state {
		case stateBeginLine:
			if c == '.' {
				// Possible end marker or escaped dot; decide on the next byte.
				d.state = stateDot
				continue
			}
			if c == '\r' {
				d.state = stateCR
				continue
			}
			d.state = stateData

		case stateDot:
			if c == '\r' {
				d.state = stateDotCR
				continue
			}
			if c == '\n' {
				// ".\n" is accepted as the end marker too.
				d.state = stateEOF
				continue
			}
			// The leading dot was an escape: it is dropped and c is emitted.
			d.state = stateData

		case stateDotCR:
			if c == '\n' {
				d.state = stateEOF
				continue
			}
			// Not ".\r\n" after all: the leading dot is dropped, the held
			// "\r" is emitted now, and c is pushed back to be read again.
			br.UnreadByte()
			c = '\r'
			d.state = stateData

		case stateCR:
			if c == '\n' {
				d.state = stateBeginLine
				// Leaves the switch (not the loop): emit "\n" in place of "\r\n".
				break
			}
			// Lone "\r": emit it and push c back to be read again.
			br.UnreadByte()
			c = '\r'
			d.state = stateData

		case stateData:
			if c == '\r' {
				// Hold the "\r" until the next byte shows whether it starts "\r\n".
				d.state = stateCR
				continue
			}
			if c == '\n' {
				d.state = stateBeginLine
			}
		}
		b[n] = c
		n++
	}
	if err == nil && d.state == stateEOF {
		err = io.EOF
	}
	// Once this reader has finished or failed, stop closeDot from draining it.
	if err != nil && d.r.dot == d {
		d.r.dot = nil
	}
	return
}
432
433
434
435 func (r *Reader) closeDot() {
436 if r.dot == nil {
437 return
438 }
439 buf := make([]byte, 128)
440 for r.dot != nil {
441
442
443 r.dot.Read(buf)
444 }
445 }
446
447
448
449
450 func (r *Reader) ReadDotBytes() ([]byte, error) {
451 return io.ReadAll(r.DotReader())
452 }
453
454
455
456
457
458 func (r *Reader) ReadDotLines() ([]string, error) {
459
460
461
462 var v []string
463 var err error
464 for {
465 var line string
466 line, err = r.ReadLine()
467 if err != nil {
468 if err == io.EOF {
469 err = io.ErrUnexpectedEOF
470 }
471 break
472 }
473
474
475 if len(line) > 0 && line[0] == '.' {
476 if len(line) == 1 {
477 break
478 }
479 line = line[1:]
480 }
481 v = append(v, line)
482 }
483 return v, err
484 }
485
// colon is the separator passed to bytes.Cut by readMIMEHeader to split a
// header line into its key and value.
var colon = []byte(":")
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507 func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
508 return readMIMEHeader(r, math.MaxInt64, math.MaxInt64)
509 }
510
511
512
513
514
515
// readMIMEHeader reads "Key: Value" header lines from r until a blank
// line or an error, and returns them as a MIMEHeader.
//
// maxMemory bounds the estimated memory used by the header (key and value
// lengths plus fixed overhead constants) and maxHeaders bounds the number
// of header lines. Exceeding either yields errMessageTooLarge. Malformed
// lines yield a ProtocolError. On most errors the headers parsed so far
// are returned alongside the error; when maxHeaders is exceeded the map
// returned is nil.
func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) {
	// Allocate one slice of strings up front and carve single-element
	// value slices out of it below, instead of allocating a small slice
	// per key. When it runs out, values fall back to plain append.
	var strs []string
	hint := r.upcomingHeaderKeys()
	if hint > 0 {
		if hint > 1000 {
			hint = 1000 // cap the preallocation
		}
		strs = make([]string, hint)
	}

	m := make(MIMEHeader, hint)

	// Charge a fixed amount against maxMemory before any entries are added,
	// and mapEntryOverhead more for each distinct key added below.
	// NOTE(review): presumably these approximate the map's own footprint — confirm.
	maxMemory -= 400
	const mapEntryOverhead = 200

	// A header must not begin with a continuation line (leading space or tab).
	if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
		const errorLimit = 80 // upper bound on how much of the bad line is read for the error text
		line, err := r.readLineSlice(errorLimit)
		if err != nil {
			return m, err
		}
		return m, ProtocolError("malformed MIME header initial line: " + string(line))
	}

	for {
		kv, err := r.readContinuedLineSlice(maxMemory, mustHaveFieldNameColon)
		if len(kv) == 0 {
			// Blank line (end of header) or a read/validation error.
			return m, err
		}

		// The key is everything before the first colon.
		k, v, ok := bytes.Cut(kv, colon)
		if !ok {
			return m, ProtocolError("malformed MIME header line: " + string(kv))
		}
		key, ok := canonicalMIMEHeaderKey(k)
		if !ok {
			return m, ProtocolError("malformed MIME header line: " + string(kv))
		}
		for _, c := range v {
			if !validHeaderValueByte(c) {
				return m, ProtocolError("malformed MIME header line: " + string(kv))
			}
		}

		maxHeaders--
		if maxHeaders < 0 {
			return nil, errMessageTooLarge
		}

		// Drop leading spaces and tabs from the value.
		value := string(bytes.TrimLeft(v, " \t"))

		vv := m[key]
		if vv == nil {
			// First occurrence of this key: charge for the key and a map entry.
			maxMemory -= int64(len(key))
			maxMemory -= mapEntryOverhead
		}
		maxMemory -= int64(len(value))
		if maxMemory < 0 {
			return m, errMessageTooLarge
		}
		if vv == nil && len(strs) > 0 {
			// Take the next preallocated slot as a one-element slice. The
			// capacity is pinned to 1 so that a later append for this key
			// reallocates instead of overwriting the following slots.
			vv, strs = strs[:1:1], strs[1:]
			vv[0] = value
			m[key] = vv
		} else {
			m[key] = append(vv, value)
		}

		// readContinuedLineSlice returned data together with an error:
		// the line has been stored, now report the error.
		if err != nil {
			return m, err
		}
	}
}
602
603
604
// noValidation is a first-line validator for readContinuedLineSlice that
// accepts every line.
func noValidation([]byte) error {
	return nil
}
606
607
608
609
610 func mustHaveFieldNameColon(line []byte) error {
611 if bytes.IndexByte(line, ':') < 0 {
612 return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line))
613 }
614 return nil
615 }
616
// nl is the line separator passed to bytes.Cut by upcomingHeaderKeys.
var nl = []byte("\n")
618
619
620
621 func (r *Reader) upcomingHeaderKeys() (n int) {
622
623 r.R.Peek(1)
624 s := r.R.Buffered()
625 if s == 0 {
626 return
627 }
628 peek, _ := r.R.Peek(s)
629 for len(peek) > 0 && n < 1000 {
630 var line []byte
631 line, peek, _ = bytes.Cut(peek, nl)
632 if len(line) == 0 || (len(line) == 1 && line[0] == '\r') {
633
634 break
635 }
636 if line[0] == ' ' || line[0] == '\t' {
637
638 continue
639 }
640 n++
641 }
642 return n
643 }
644
645
646
647
648
649
650
651
652
653 func CanonicalMIMEHeaderKey(s string) string {
654
655 upper := true
656 for i := 0; i < len(s); i++ {
657 c := s[i]
658 if !validHeaderFieldByte(c) {
659 return s
660 }
661 if upper && 'a' <= c && c <= 'z' {
662 s, _ = canonicalMIMEHeaderKey([]byte(s))
663 return s
664 }
665 if !upper && 'A' <= c && c <= 'Z' {
666 s, _ = canonicalMIMEHeaderKey([]byte(s))
667 return s
668 }
669 upper = c == '-'
670 }
671 return s
672 }
673
// toLower is the distance between an ASCII lowercase letter and its
// uppercase form; canonicalMIMEHeaderKey adds or subtracts it to change
// the case of a letter.
const toLower = 'a' - 'A'
675
676
677
678
679
680
681
682
683
// validHeaderFieldByte reports whether c may appear in a header field
// name. The accepted bytes are the ASCII letters and digits plus
//
//	! # $ % & ' * + - . ^ _ ` | ~
//
// Every other byte, including all bytes >= 0x80, is rejected.
func validHeaderFieldByte(c byte) bool {
	// tokenSet has bit i set for each accepted byte value i. It is an
	// untyped constant spanning bits 0-127, so it is split into two
	// 64-bit halves before being tested at run time.
	const tokenSet = 0 |
		(1<<10-1)<<'0' |
		(1<<26-1)<<'a' |
		(1<<26-1)<<'A' |
		1<<'!' |
		1<<'#' |
		1<<'$' |
		1<<'%' |
		1<<'&' |
		1<<'\'' |
		1<<'*' |
		1<<'+' |
		1<<'-' |
		1<<'.' |
		1<<'^' |
		1<<'_' |
		1<<'`' |
		1<<'|' |
		1<<'~'
	const (
		lowHalf  uint64 = tokenSet & (1<<64 - 1) // byte values 0-63
		highHalf uint64 = tokenSet >> 64         // byte values 64-127
	)
	switch {
	case c < 64:
		return lowHalf&(uint64(1)<<c) != 0
	case c < 128:
		return highHalf&(uint64(1)<<(c-64)) != 0
	}
	return false
}
711
712
713
714
715
716
717
718
719
720
721
722
723
// validHeaderValueByte reports whether c may appear in a header field
// value. Tab, space, the visible ASCII characters 0x21-0x7e and every
// byte >= 0x80 are accepted; the remaining control characters and DEL
// (0x7f) are rejected.
func validHeaderValueByte(c byte) bool {
	// allowed has bit i set for each accepted byte value i below 0x80.
	// It is an untyped constant spanning bits 0-127, so it is split into
	// two 64-bit halves before being tested at run time.
	const allowed = 0 |
		(1<<(0x7f-0x21)-1)<<0x21 |
		1<<0x20 |
		1<<0x09
	const (
		lowHalf  uint64 = allowed & (1<<64 - 1) // byte values 0-63
		highHalf uint64 = allowed >> 64         // byte values 64-127
	)
	switch {
	case c < 64:
		return lowHalf&(uint64(1)<<c) != 0
	case c < 128:
		return highHalf&(uint64(1)<<(c-64)) != 0
	}
	// Bytes outside ASCII are never rejected.
	return true
}
737
738
739
740
741
742
743
744
745
746
747
748 func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) {
749 if len(a) == 0 {
750 return "", false
751 }
752
753
754 noCanon := false
755 for _, c := range a {
756 if validHeaderFieldByte(c) {
757 continue
758 }
759
760 if c == ' ' {
761
762
763
764 noCanon = true
765 continue
766 }
767 return string(a), false
768 }
769 if noCanon {
770 return string(a), true
771 }
772
773 upper := true
774 for i, c := range a {
775
776
777
778
779 if upper && 'a' <= c && c <= 'z' {
780 c -= toLower
781 } else if !upper && 'A' <= c && c <= 'Z' {
782 c += toLower
783 }
784 a[i] = c
785 upper = c == '-'
786 }
787 commonHeaderOnce.Do(initCommonHeader)
788
789
790
791 if v := commonHeader[string(a)]; v != "" {
792 return v, true
793 }
794 return string(a), true
795 }
796
797
// commonHeader maps each header name listed in initCommonHeader to itself.
// canonicalMIMEHeaderKey looks keys up here and returns the stored string
// when one is found.
var commonHeader map[string]string

// commonHeaderOnce guards the one-time call to initCommonHeader.
var commonHeaderOnce sync.Once
801
802 func initCommonHeader() {
803 commonHeader = make(map[string]string)
804 for _, v := range []string{
805 "Accept",
806 "Accept-Charset",
807 "Accept-Encoding",
808 "Accept-Language",
809 "Accept-Ranges",
810 "Cache-Control",
811 "Cc",
812 "Connection",
813 "Content-Id",
814 "Content-Language",
815 "Content-Length",
816 "Content-Transfer-Encoding",
817 "Content-Type",
818 "Cookie",
819 "Date",
820 "Dkim-Signature",
821 "Etag",
822 "Expires",
823 "From",
824 "Host",
825 "If-Modified-Since",
826 "If-None-Match",
827 "In-Reply-To",
828 "Last-Modified",
829 "Location",
830 "Message-Id",
831 "Mime-Version",
832 "Pragma",
833 "Received",
834 "Return-Path",
835 "Server",
836 "Set-Cookie",
837 "Subject",
838 "To",
839 "User-Agent",
840 "Via",
841 "X-Forwarded-For",
842 "X-Imforwards",
843 "X-Powered-By",
844 } {
845 commonHeader[v] = v
846 }
847 }
848
View as plain text