Source file
src/crypto/tls/handshake_messages_test.go
1
2
3
4
5 package tls
6
7 import (
8 "bytes"
9 "crypto/x509"
10 "encoding/hex"
11 "math"
12 "math/rand"
13 "reflect"
14 "strings"
15 "testing"
16 "testing/quick"
17 "time"
18 )
19
// tests holds one instance of every message type exercised by
// TestMarshalUnmarshal and TestFuzz. Each entry is used both as the type that
// testing/quick generates random values for (through the Generate methods in
// this file) and as the target that unmarshal writes into. The entries are
// therefore mutated as the tests run.
//
// Order matters: TestMarshalUnmarshal only runs its "no parsable prefix"
// check for indices >= 3. The first three entries (clientHelloMsg,
// serverHelloMsg, finishedMsg) must stay first, and new types should be
// appended after them.
var tests = []handshakeMessage{
	&clientHelloMsg{},
	&serverHelloMsg{},
	&finishedMsg{},

	&certificateMsg{},
	&certificateRequestMsg{},
	&certificateVerifyMsg{
		// Preset to true, matching what certificateVerifyMsg.Generate
		// produces. Presumably unmarshal consults this flag to decide
		// whether a signature algorithm precedes the signature — confirm.
		hasSignatureAlgorithm: true,
	},
	&certificateStatusMsg{},
	&clientKeyExchangeMsg{},
	&newSessionTicketMsg{},
	&encryptedExtensionsMsg{},
	&endOfEarlyDataMsg{},
	&keyUpdateMsg{},
	&newSessionTicketMsgTLS13{},
	&certificateRequestMsgTLS13{},
	&certificateMsgTLS13{},
	&SessionState{},
}
41
42 func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
43 t.Helper()
44 b, err := msg.marshal()
45 if err != nil {
46 t.Fatal(err)
47 }
48 return b
49 }
50
51 func TestMarshalUnmarshal(t *testing.T) {
52 rand := rand.New(rand.NewSource(time.Now().UnixNano()))
53
54 for i, m := range tests {
55 ty := reflect.ValueOf(m).Type()
56 t.Run(ty.String(), func(t *testing.T) {
57 n := 100
58 if testing.Short() {
59 n = 5
60 }
61 for j := 0; j < n; j++ {
62 v, ok := quick.Value(ty, rand)
63 if !ok {
64 t.Errorf("#%d: failed to create value", i)
65 break
66 }
67
68 m1 := v.Interface().(handshakeMessage)
69 marshaled := mustMarshal(t, m1)
70 if !m.unmarshal(marshaled) {
71 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
72 break
73 }
74
75 if m, ok := m.(*SessionState); ok {
76 m.activeCertHandles = nil
77 }
78
79 if ch, ok := m.(*clientHelloMsg); ok {
80
81
82
83
84
85 if len(ch.extensions) == 0 {
86 t.Errorf("expected ch.extensions to be populated on unmarshal")
87 }
88 ch.extensions = nil
89 }
90
91
92
93
94
95
96 switch t := m.(type) {
97 case *clientHelloMsg:
98 t.original = nil
99 case *serverHelloMsg:
100 t.original = nil
101 }
102
103 if !reflect.DeepEqual(m1, m) {
104 t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
105 break
106 }
107
108 if i >= 3 {
109
110
111
112
113
114 for j := 0; j < len(marshaled); j++ {
115 if m.unmarshal(marshaled[0:j]) {
116 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
117 break
118 }
119 }
120 }
121 }
122 })
123 }
124 }
125
126 func TestFuzz(t *testing.T) {
127 rand := rand.New(rand.NewSource(0))
128 for _, m := range tests {
129 for j := 0; j < 1000; j++ {
130 len := rand.Intn(1000)
131 bytes := randomBytes(len, rand)
132
133 m.unmarshal(bytes)
134 }
135 }
136 }
137
// randomBytes allocates a slice of n bytes and fills it from rand. It panics
// if rand.Read reports an error.
func randomBytes(n int, rand *rand.Rand) []byte {
	buf := make([]byte, n)
	_, err := rand.Read(buf)
	if err != nil {
		panic("rand.Read failed: " + err.Error())
	}
	return buf
}
145
146 func randomString(n int, rand *rand.Rand) string {
147 b := randomBytes(n, rand)
148 return string(b)
149 }
150
151 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
152 m := &clientHelloMsg{}
153 m.vers = uint16(rand.Intn(65536))
154 m.random = randomBytes(32, rand)
155 m.sessionId = randomBytes(rand.Intn(32), rand)
156 m.cipherSuites = make([]uint16, rand.Intn(63)+1)
157 for i := 0; i < len(m.cipherSuites); i++ {
158 cs := uint16(rand.Int31())
159 if cs == scsvRenegotiation {
160 cs += 1
161 }
162 m.cipherSuites[i] = cs
163 }
164 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
165 if rand.Intn(10) > 5 {
166 m.serverName = randomString(rand.Intn(255), rand)
167 for strings.HasSuffix(m.serverName, ".") {
168 m.serverName = m.serverName[:len(m.serverName)-1]
169 }
170 }
171 m.ocspStapling = rand.Intn(10) > 5
172 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
173 m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
174 for i := range m.supportedCurves {
175 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
176 }
177 if rand.Intn(10) > 5 {
178 m.ticketSupported = true
179 if rand.Intn(10) > 5 {
180 m.sessionTicket = randomBytes(rand.Intn(300), rand)
181 } else {
182 m.sessionTicket = make([]byte, 0)
183 }
184 }
185 if rand.Intn(10) > 5 {
186 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
187 }
188 if rand.Intn(10) > 5 {
189 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
190 }
191 for i := 0; i < rand.Intn(5); i++ {
192 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
193 }
194 if rand.Intn(10) > 5 {
195 m.scts = true
196 }
197 if rand.Intn(10) > 5 {
198 m.secureRenegotiationSupported = true
199 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
200 }
201 if rand.Intn(10) > 5 {
202 m.extendedMasterSecret = true
203 }
204 for i := 0; i < rand.Intn(5); i++ {
205 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
206 }
207 if rand.Intn(10) > 5 {
208 m.cookie = randomBytes(rand.Intn(500)+1, rand)
209 }
210 for i := 0; i < rand.Intn(5); i++ {
211 var ks keyShare
212 ks.group = CurveID(rand.Intn(30000) + 1)
213 ks.data = randomBytes(rand.Intn(200)+1, rand)
214 m.keyShares = append(m.keyShares, ks)
215 }
216 switch rand.Intn(3) {
217 case 1:
218 m.pskModes = []uint8{pskModeDHE}
219 case 2:
220 m.pskModes = []uint8{pskModeDHE, pskModePlain}
221 }
222 for i := 0; i < rand.Intn(5); i++ {
223 var psk pskIdentity
224 psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
225 psk.label = randomBytes(rand.Intn(500)+1, rand)
226 m.pskIdentities = append(m.pskIdentities, psk)
227 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
228 }
229 if rand.Intn(10) > 5 {
230 m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
231 }
232 if rand.Intn(10) > 5 {
233 m.earlyData = true
234 }
235 if rand.Intn(10) > 5 {
236 m.encryptedClientHello = randomBytes(rand.Intn(50)+1, rand)
237 }
238
239 return reflect.ValueOf(m)
240 }
241
242 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
243 m := &serverHelloMsg{}
244 m.vers = uint16(rand.Intn(65536))
245 m.random = randomBytes(32, rand)
246 m.sessionId = randomBytes(rand.Intn(32), rand)
247 m.cipherSuite = uint16(rand.Int31())
248 m.compressionMethod = uint8(rand.Intn(256))
249 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
250
251 if rand.Intn(10) > 5 {
252 m.ocspStapling = true
253 }
254 if rand.Intn(10) > 5 {
255 m.ticketSupported = true
256 }
257 if rand.Intn(10) > 5 {
258 m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
259 }
260
261 for i := 0; i < rand.Intn(4); i++ {
262 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
263 }
264
265 if rand.Intn(10) > 5 {
266 m.secureRenegotiationSupported = true
267 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
268 }
269 if rand.Intn(10) > 5 {
270 m.extendedMasterSecret = true
271 }
272 if rand.Intn(10) > 5 {
273 m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
274 }
275 if rand.Intn(10) > 5 {
276 m.cookie = randomBytes(rand.Intn(500)+1, rand)
277 }
278 if rand.Intn(10) > 5 {
279 for i := 0; i < rand.Intn(5); i++ {
280 m.serverShare.group = CurveID(rand.Intn(30000) + 1)
281 m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
282 }
283 } else if rand.Intn(10) > 5 {
284 m.selectedGroup = CurveID(rand.Intn(30000) + 1)
285 }
286 if rand.Intn(10) > 5 {
287 m.selectedIdentityPresent = true
288 m.selectedIdentity = uint16(rand.Intn(0xffff))
289 }
290 if rand.Intn(10) > 5 {
291 m.encryptedClientHello = randomBytes(rand.Intn(50)+1, rand)
292 }
293 if rand.Intn(10) > 5 {
294 m.serverNameAck = rand.Intn(2) == 1
295 }
296
297 return reflect.ValueOf(m)
298 }
299
300 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
301 m := &encryptedExtensionsMsg{}
302
303 if rand.Intn(10) > 5 {
304 m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
305 }
306 if rand.Intn(10) > 5 {
307 m.earlyData = true
308 }
309
310 return reflect.ValueOf(m)
311 }
312
313 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
314 m := &certificateMsg{}
315 numCerts := rand.Intn(20)
316 m.certificates = make([][]byte, numCerts)
317 for i := 0; i < numCerts; i++ {
318 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
319 }
320 return reflect.ValueOf(m)
321 }
322
323 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
324 m := &certificateRequestMsg{}
325 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
326 for i := 0; i < rand.Intn(100); i++ {
327 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
328 }
329 return reflect.ValueOf(m)
330 }
331
332 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
333 m := &certificateVerifyMsg{}
334 m.hasSignatureAlgorithm = true
335 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
336 m.signature = randomBytes(rand.Intn(15)+1, rand)
337 return reflect.ValueOf(m)
338 }
339
340 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
341 m := &certificateStatusMsg{}
342 m.response = randomBytes(rand.Intn(10)+1, rand)
343 return reflect.ValueOf(m)
344 }
345
346 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
347 m := &clientKeyExchangeMsg{}
348 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
349 return reflect.ValueOf(m)
350 }
351
352 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
353 m := &finishedMsg{}
354 m.verifyData = randomBytes(12, rand)
355 return reflect.ValueOf(m)
356 }
357
358 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
359 m := &newSessionTicketMsg{}
360 m.ticket = randomBytes(rand.Intn(4), rand)
361 return reflect.ValueOf(m)
362 }
363
364 var sessionTestCerts []*x509.Certificate
365
366 func init() {
367 cert, err := x509.ParseCertificate(testRSACertificate)
368 if err != nil {
369 panic(err)
370 }
371 sessionTestCerts = append(sessionTestCerts, cert)
372 cert, err = x509.ParseCertificate(testRSACertificateIssuer)
373 if err != nil {
374 panic(err)
375 }
376 sessionTestCerts = append(sessionTestCerts, cert)
377 }
378
379 func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
380 s := &SessionState{}
381 isTLS13 := rand.Intn(10) > 5
382 if isTLS13 {
383 s.version = VersionTLS13
384 } else {
385 s.version = uint16(rand.Intn(VersionTLS13))
386 }
387 s.isClient = rand.Intn(10) > 5
388 s.cipherSuite = uint16(rand.Intn(math.MaxUint16))
389 s.createdAt = uint64(rand.Int63())
390 s.secret = randomBytes(rand.Intn(100)+1, rand)
391 for n, i := rand.Intn(3), 0; i < n; i++ {
392 s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand))
393 }
394 if rand.Intn(10) > 5 {
395 s.EarlyData = true
396 }
397 if rand.Intn(10) > 5 {
398 s.extMasterSecret = true
399 }
400 if s.isClient || rand.Intn(10) > 5 {
401 if rand.Intn(10) > 5 {
402 s.peerCertificates = sessionTestCerts
403 } else {
404 s.peerCertificates = sessionTestCerts[:1]
405 }
406 }
407 if rand.Intn(10) > 5 && s.peerCertificates != nil {
408 s.ocspResponse = randomBytes(rand.Intn(100)+1, rand)
409 }
410 if rand.Intn(10) > 5 && s.peerCertificates != nil {
411 for i := 0; i < rand.Intn(2)+1; i++ {
412 s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand))
413 }
414 }
415 if len(s.peerCertificates) > 0 {
416 for i := 0; i < rand.Intn(3); i++ {
417 if rand.Intn(10) > 5 {
418 s.verifiedChains = append(s.verifiedChains, s.peerCertificates)
419 } else {
420 s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1])
421 }
422 }
423 }
424 if rand.Intn(10) > 5 && s.EarlyData {
425 s.alpnProtocol = string(randomBytes(rand.Intn(10), rand))
426 }
427 if s.isClient {
428 if isTLS13 {
429 s.useBy = uint64(rand.Int63())
430 s.ageAdd = uint32(rand.Int63() & math.MaxUint32)
431 }
432 }
433 return reflect.ValueOf(s)
434 }
435
436 func (s *SessionState) marshal() ([]byte, error) { return s.Bytes() }
437 func (s *SessionState) unmarshal(b []byte) bool {
438 ss, err := ParseSessionState(b)
439 if err != nil {
440 return false
441 }
442 *s = *ss
443 return true
444 }
445
446 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
447 m := &endOfEarlyDataMsg{}
448 return reflect.ValueOf(m)
449 }
450
451 func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
452 m := &keyUpdateMsg{}
453 m.updateRequested = rand.Intn(10) > 5
454 return reflect.ValueOf(m)
455 }
456
457 func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
458 m := &newSessionTicketMsgTLS13{}
459 m.lifetime = uint32(rand.Intn(500000))
460 m.ageAdd = uint32(rand.Intn(500000))
461 m.nonce = randomBytes(rand.Intn(100), rand)
462 m.label = randomBytes(rand.Intn(1000), rand)
463 if rand.Intn(10) > 5 {
464 m.maxEarlyData = uint32(rand.Intn(500000))
465 }
466 return reflect.ValueOf(m)
467 }
468
469 func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
470 m := &certificateRequestMsgTLS13{}
471 if rand.Intn(10) > 5 {
472 m.ocspStapling = true
473 }
474 if rand.Intn(10) > 5 {
475 m.scts = true
476 }
477 if rand.Intn(10) > 5 {
478 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
479 }
480 if rand.Intn(10) > 5 {
481 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
482 }
483 if rand.Intn(10) > 5 {
484 m.certificateAuthorities = make([][]byte, 3)
485 for i := 0; i < 3; i++ {
486 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
487 }
488 }
489 return reflect.ValueOf(m)
490 }
491
492 func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
493 m := &certificateMsgTLS13{}
494 for i := 0; i < rand.Intn(2)+1; i++ {
495 m.certificate.Certificate = append(
496 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
497 }
498 if rand.Intn(10) > 5 {
499 m.ocspStapling = true
500 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
501 }
502 if rand.Intn(10) > 5 {
503 m.scts = true
504 for i := 0; i < rand.Intn(2)+1; i++ {
505 m.certificate.SignedCertificateTimestamps = append(
506 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
507 }
508 }
509 return reflect.ValueOf(m)
510 }
511
// TestRejectEmptySCTList checks that a ServerHello whose
// signed_certificate_timestamp extension carries an empty SCT list fails to
// unmarshal. RFC 6962, Section 3.3.1 does not allow an empty list.
//
// The test marshals a valid message with one SCT, then splices the encoding
// by hand to empty the list and patches the affected length fields.
func TestRejectEmptySCTList(t *testing.T) {
	// A minimal TLS 1.2 ServerHello with a single 4-byte SCT whose contents
	// (0x42 repeated) are easy to locate in the encoding.
	var random [32]byte
	sct := []byte{0x42, 0x42, 0x42, 0x42}
	serverHello := &serverHelloMsg{
		vers:   VersionTLS12,
		random: random[:],
		scts:   [][]byte{sct},
	}
	serverHelloBytes := mustMarshal(t, serverHello)

	// Sanity check: the unmodified message must parse, so the later failure
	// can be attributed to the empty list.
	var serverHelloCopy serverHelloMsg
	if !serverHelloCopy.unmarshal(serverHelloBytes) {
		t.Fatal("Failed to unmarshal initial message")
	}

	// Find the SCT payload so the extension around it can be rewritten.
	i := bytes.Index(serverHelloBytes, sct)
	if i < 0 {
		t.Fatal("Cannot find SCT in ServerHello")
	}

	// Keep everything before the six length bytes that precede the SCT
	// payload: extension data length, SCT list length and SCT length, two
	// bytes each.
	var serverHelloEmptySCT []byte
	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
	// Append an extension data length of 2 followed by an SCT list length of
	// 0, i.e. an empty list.
	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
	// Resume after the 4-byte SCT payload.
	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)

	// Update the 3-byte handshake message length (bytes 1-3). It excludes
	// the 4-byte handshake header.
	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)

	// Update the 2-byte extensions block length at offset 42:
	// 4 (header) + 2 (version) + 32 (random) + 1 (empty session ID length)
	// + 2 (cipher suite) + 1 (compression method) = 42. The block itself
	// starts at 44.
	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))

	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
		t.Fatal("Unmarshaled ServerHello with empty SCT list")
	}
}
554
555 func TestRejectEmptySCT(t *testing.T) {
556
557
558
559 var random [32]byte
560 serverHello := &serverHelloMsg{
561 vers: VersionTLS12,
562 random: random[:],
563 scts: [][]byte{nil},
564 }
565 serverHelloBytes := mustMarshal(t, serverHello)
566
567 var serverHelloCopy serverHelloMsg
568 if serverHelloCopy.unmarshal(serverHelloBytes) {
569 t.Fatal("Unmarshaled ServerHello with zero-length SCT")
570 }
571 }
572
573 func TestRejectDuplicateExtensions(t *testing.T) {
574 clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f")
575 if err != nil {
576 t.Fatalf("failed to decode test ClientHello: %s", err)
577 }
578 var clientHelloCopy clientHelloMsg
579 if clientHelloCopy.unmarshal(clientHelloBytes) {
580 t.Error("Unmarshaled ClientHello with duplicate extensions")
581 }
582
583 serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000")
584 if err != nil {
585 t.Fatalf("failed to decode test ServerHello: %s", err)
586 }
587 var serverHelloCopy serverHelloMsg
588 if serverHelloCopy.unmarshal(serverHelloBytes) {
589 t.Fatal("Unmarshaled ServerHello with duplicate extensions")
590 }
591 }
592
View as plain text