1
2
3 package mlkem
4
5 import (
6 "bytes"
7 "crypto/internal/fips140"
8 "crypto/internal/fips140/drbg"
9 "crypto/internal/fips140/sha3"
10 "crypto/internal/fips140/subtle"
11 "errors"
12 )
13
14
15
// DecapsulationKey1024 is the secret key used to decapsulate a shared key
// from a ciphertext. Besides the seeds it is derived from, it caches the
// expanded encryption and decryption keys and H(ek).
type DecapsulationKey1024 struct {
	d [32]byte // key generation seed, returned first by Bytes
	z [32]byte // implicit rejection seed, hashed with the ciphertext in kemDecaps1024

	ρ [32]byte // seed from which the matrix A is sampled
	h [32]byte // SHA3-256 of the serialized encapsulation key

	encryptionKey1024
	decryptionKey1024
}
26
27
28
29
30 func (dk *DecapsulationKey1024) Bytes() []byte {
31 var b [SeedSize]byte
32 copy(b[:], dk.d[:])
33 copy(b[32:], dk.z[:])
34 return b[:]
35 }
36
37
38
39
40
41
42 func TestingOnlyExpandedBytes1024(dk *DecapsulationKey1024) []byte {
43 b := make([]byte, 0, decapsulationKeySize1024)
44
45
46 for i := range dk.s {
47 b = polyByteEncode(b, dk.s[i])
48 }
49
50
51 for i := range dk.t {
52 b = polyByteEncode(b, dk.t[i])
53 }
54 b = append(b, dk.ρ[:]...)
55
56
57 b = append(b, dk.h[:]...)
58 b = append(b, dk.z[:]...)
59
60 return b
61 }
62
63
64
65 func (dk *DecapsulationKey1024) EncapsulationKey() *EncapsulationKey1024 {
66 return &EncapsulationKey1024{
67 ρ: dk.ρ,
68 h: dk.h,
69 encryptionKey1024: dk.encryptionKey1024,
70 }
71 }
72
73
74
// EncapsulationKey1024 is the public key used to produce ciphertexts that the
// matching DecapsulationKey1024 can decapsulate.
type EncapsulationKey1024 struct {
	ρ [32]byte // seed from which the matrix A is sampled
	h [32]byte // SHA3-256 of the encoded key, mixed into G in kemEncaps1024
	encryptionKey1024
}
80
81
82 func (ek *EncapsulationKey1024) Bytes() []byte {
83
84 b := make([]byte, 0, EncapsulationKeySize1024)
85 return ek.bytes(b)
86 }
87
88 func (ek *EncapsulationKey1024) bytes(b []byte) []byte {
89 for i := range ek.t {
90 b = polyByteEncode(b, ek.t[i])
91 }
92 b = append(b, ek.ρ[:]...)
93 return b
94 }
95
96
// encryptionKey1024 is the parsed and expanded form of a K-PKE encryption key.
type encryptionKey1024 struct {
	t [k1024]nttElement          // public vector t, decoded from the encapsulation key
	a [k1024 * k1024]nttElement // matrix A, stored row-major: a[i*k1024+j] = sampleNTT(ρ, j, i)
}
101
102
// decryptionKey1024 is the parsed and expanded form of a K-PKE decryption key.
type decryptionKey1024 struct {
	s [k1024]nttElement // secret vector s, kept in the NTT domain
}
106
107
108
109 func GenerateKey1024() (*DecapsulationKey1024, error) {
110
111 dk := &DecapsulationKey1024{}
112 return generateKey1024(dk)
113 }
114
115 func generateKey1024(dk *DecapsulationKey1024) (*DecapsulationKey1024, error) {
116 var d [32]byte
117 drbg.Read(d[:])
118 var z [32]byte
119 drbg.Read(z[:])
120 kemKeyGen1024(dk, &d, &z)
121 if err := fips140.PCT("ML-KEM PCT", func() error { return kemPCT1024(dk) }); err != nil {
122
123 panic(err)
124 }
125 fips140.RecordApproved()
126 return dk, nil
127 }
128
129
130
131 func GenerateKeyInternal1024(d, z *[32]byte) *DecapsulationKey1024 {
132 dk := &DecapsulationKey1024{}
133 kemKeyGen1024(dk, d, z)
134 return dk
135 }
136
137
138
139 func NewDecapsulationKey1024(seed []byte) (*DecapsulationKey1024, error) {
140
141 dk := &DecapsulationKey1024{}
142 return newKeyFromSeed1024(dk, seed)
143 }
144
145 func newKeyFromSeed1024(dk *DecapsulationKey1024, seed []byte) (*DecapsulationKey1024, error) {
146 if len(seed) != SeedSize {
147 return nil, errors.New("mlkem: invalid seed length")
148 }
149 d := (*[32]byte)(seed[:32])
150 z := (*[32]byte)(seed[32:])
151 kemKeyGen1024(dk, d, z)
152 if err := fips140.PCT("ML-KEM PCT", func() error { return kemPCT1024(dk) }); err != nil {
153
154 panic(err)
155 }
156 fips140.RecordApproved()
157 return dk, nil
158 }
159
160
161
162
163
164
165
166
167 func TestingOnlyNewDecapsulationKey1024(b []byte) (*DecapsulationKey1024, error) {
168 if len(b) != decapsulationKeySize1024 {
169 return nil, errors.New("mlkem: invalid NIST decapsulation key length")
170 }
171
172 dk := &DecapsulationKey1024{}
173 for i := range dk.s {
174 var err error
175 dk.s[i], err = polyByteDecode[nttElement](b[:encodingSize12])
176 if err != nil {
177 return nil, errors.New("mlkem: invalid secret key encoding")
178 }
179 b = b[encodingSize12:]
180 }
181
182 ek, err := NewEncapsulationKey1024(b[:EncapsulationKeySize1024])
183 if err != nil {
184 return nil, err
185 }
186 dk.ρ = ek.ρ
187 dk.h = ek.h
188 dk.encryptionKey1024 = ek.encryptionKey1024
189 b = b[EncapsulationKeySize1024:]
190
191 if !bytes.Equal(dk.h[:], b[:32]) {
192 return nil, errors.New("mlkem: inconsistent H(ek) in encoded bytes")
193 }
194 b = b[32:]
195
196 copy(dk.z[:], b)
197
198
199
200
201
202 drbg.Read(dk.d[:])
203
204 return dk, nil
205 }
206
207
208
209
210
211
// kemKeyGen1024 deterministically derives the full key pair from the seeds d
// and z and stores it in dk: seeds, ρ, A, s, t and h.
//
// It combines ML-KEM key generation with the underlying K-PKE key generation
// (FIPS 203) in a single function, so the expanded values are written
// straight into dk.
func kemKeyGen1024(dk *DecapsulationKey1024, d, z *[32]byte) {
	dk.d = *d
	dk.z = *z

	// (ρ, σ) = SHA3-512(d ‖ k1024): the first half seeds A, the second seeds
	// the noise polynomials.
	g := sha3.New512()
	g.Write(d[:])
	g.Write([]byte{k1024})
	G := g.Sum(make([]byte, 0, 64))
	ρ, σ := G[:32], G[32:]
	dk.ρ = [32]byte(ρ)

	// Expand A from ρ. Entry A[i*k1024+j] is sampled with the indices passed
	// as (j, i); parseEK1024 uses the same order.
	A := &dk.a
	for i := byte(0); i < k1024; i++ {
		for j := byte(0); j < k1024; j++ {
			A[i*k1024+j] = sampleNTT(ρ, j, i)
		}
	}

	// Sample s, then e, from σ. One running counter N gives every polynomial a
	// distinct input to samplePolyCBD. Both vectors are moved to the NTT domain.
	var N byte
	s := &dk.s
	for i := range s {
		s[i] = ntt(samplePolyCBD(σ, N))
		N++
	}
	e := make([]nttElement, k1024)
	for i := range e {
		e[i] = ntt(samplePolyCBD(σ, N))
		N++
	}

	// t = A ∘ s + e, computed entirely in the NTT domain.
	t := &dk.t
	for i := range t {
		t[i] = e[i]
		for j := range s {
			t[i] = polyAdd(t[i], nttMul(A[i*k1024+j], s[j]))
		}
	}

	// h = SHA3-256 of the serialized encapsulation key (t ‖ ρ).
	H := sha3.New256()
	ek := dk.EncapsulationKey().Bytes()
	H.Write(ek)
	H.Sum(dk.h[:0])
}
255
256
257
258
259
260
261
262
263
264 func kemPCT1024(dk *DecapsulationKey1024) error {
265 ek := dk.EncapsulationKey()
266 K, c := ek.Encapsulate()
267 K1, err := dk.Decapsulate(c)
268 if err != nil {
269 return err
270 }
271 if subtle.ConstantTimeCompare(K, K1) != 1 {
272 return errors.New("mlkem: PCT failed")
273 }
274 return nil
275 }
276
277
278
279
280
281 func (ek *EncapsulationKey1024) Encapsulate() (sharedKey, ciphertext []byte) {
282
283 var cc [CiphertextSize1024]byte
284 return ek.encapsulate(&cc)
285 }
286
287 func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (sharedKey, ciphertext []byte) {
288 var m [messageSize]byte
289 drbg.Read(m[:])
290
291
292 fips140.RecordApproved()
293 return kemEncaps1024(cc, ek, &m)
294 }
295
296
297
298 func (ek *EncapsulationKey1024) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
299 cc := &[CiphertextSize1024]byte{}
300 return kemEncaps1024(cc, ek, m)
301 }
302
303
304
305
306 func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (K, c []byte) {
307 g := sha3.New512()
308 g.Write(m[:])
309 g.Write(ek.h[:])
310 G := g.Sum(nil)
311 K, r := G[:SharedKeySize], G[SharedKeySize:]
312 c = pkeEncrypt1024(cc, &ek.encryptionKey1024, m, r)
313 return K, c
314 }
315
316
317
318 func NewEncapsulationKey1024(encapsulationKey []byte) (*EncapsulationKey1024, error) {
319
320 ek := &EncapsulationKey1024{}
321 return parseEK1024(ek, encapsulationKey)
322 }
323
324
325
326
327
328 func parseEK1024(ek *EncapsulationKey1024, ekPKE []byte) (*EncapsulationKey1024, error) {
329 if len(ekPKE) != EncapsulationKeySize1024 {
330 return nil, errors.New("mlkem: invalid encapsulation key length")
331 }
332
333 h := sha3.New256()
334 h.Write(ekPKE)
335 h.Sum(ek.h[:0])
336
337 for i := range ek.t {
338 var err error
339 ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
340 if err != nil {
341 return nil, err
342 }
343 ekPKE = ekPKE[encodingSize12:]
344 }
345 copy(ek.ρ[:], ekPKE)
346
347 for i := byte(0); i < k1024; i++ {
348 for j := byte(0); j < k1024; j++ {
349 ek.a[i*k1024+j] = sampleNTT(ek.ρ[:], j, i)
350 }
351 }
352
353 return ek, nil
354 }
355
356
357
358
359
// pkeEncrypt1024 encrypts the message m under the K-PKE encryption key ex,
// deriving all noise from the coins rnd. It uses the matrix A already expanded
// in ex rather than resampling it from ρ.
//
// The ciphertext u ‖ v is written into cc and returned as a slice of it.
func pkeEncrypt1024(cc *[CiphertextSize1024]byte, ex *encryptionKey1024, m *[messageSize]byte, rnd []byte) []byte {
	// Sample r (moved to the NTT domain), e1 and e2 from rnd. One running
	// counter N gives every polynomial a distinct input to samplePolyCBD.
	var N byte
	r, e1 := make([]nttElement, k1024), make([]ringElement, k1024)
	for i := range r {
		r[i] = ntt(samplePolyCBD(rnd, N))
		N++
	}
	for i := range e1 {
		e1[i] = samplePolyCBD(rnd, N)
		N++
	}
	e2 := samplePolyCBD(rnd, N)

	// u = NTT⁻¹(Aᵀ ∘ r) + e1.
	u := make([]ringElement, k1024)
	for i := range u {
		u[i] = e1[i]
		for j := range r {
			// i and j are swapped relative to the storage order of ex.a, which
			// yields the transpose of A.
			u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k1024+i], r[j])))
		}
	}

	// The message bits, lifted to a ring element.
	μ := ringDecodeAndDecompress1(m)

	// v = NTT⁻¹(tᵀ ∘ r) + e2 + μ.
	var vNTT nttElement
	for i := range ex.t {
		vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
	}
	v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)

	// Compress and encode: each u[i] with the 11-bit encoder, then v with the
	// 5-bit encoder, appended into cc's storage.
	c := cc[:0]
	for _, f := range u {
		c = ringCompressAndEncode11(c, f)
	}
	c = ringCompressAndEncode5(c, v)

	return c
}
398
399
400
401
402
403 func (dk *DecapsulationKey1024) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
404 if len(ciphertext) != CiphertextSize1024 {
405 return nil, errors.New("mlkem: invalid ciphertext length")
406 }
407 c := (*[CiphertextSize1024]byte)(ciphertext)
408
409
410
411 return kemDecaps1024(dk, c), nil
412 }
413
414
415
416
// kemDecaps1024 produces the shared key for ciphertext c under dk, using
// implicit rejection:
//  1. Decrypt c to a candidate message m'.
//  2. Derive (K', r') = SHA3-512(m' ‖ h).
//  3. Re-encrypt m' with r' and compare the result to c in constant time.
//
// It returns K' on a match. Otherwise it returns the rejection key
// SHAKE256(z ‖ c), truncated to SharedKeySize bytes.
func kemDecaps1024(dk *DecapsulationKey1024, c *[CiphertextSize1024]byte) (K []byte) {
	fips140.RecordApproved()
	m := pkeDecrypt1024(&dk.decryptionKey1024, c)
	g := sha3.New512()
	g.Write(m[:])
	g.Write(dk.h[:])
	G := g.Sum(make([]byte, 0, 64))
	Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
	// Precompute the rejection key unconditionally, so the work done does not
	// depend on whether the ciphertext is valid.
	J := sha3.NewShake256()
	J.Write(dk.z[:])
	J.Write(c[:])
	Kout := make([]byte, SharedKeySize)
	J.Read(Kout)
	// Re-encrypt the candidate message; the (*[32]byte) conversion expects
	// pkeDecrypt1024 to return 32 bytes.
	var cc [CiphertextSize1024]byte
	c1 := pkeEncrypt1024(&cc, &dk.encryptionKey1024, (*[32]byte)(m), r)

	// Overwrite Kout with K' only when the ciphertexts match, without
	// branching on the comparison result.
	subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
	return Kout
}
436
437
438
439
440
// pkeDecrypt1024 decrypts the K-PKE ciphertext c with the decryption key dx
// and returns the encoded message bytes.
//
// Because this is a KEM, the return value is not expected to be exactly the
// encrypted message. kemDecaps1024 re-encrypts it and compares the result
// with c.
func pkeDecrypt1024(dx *decryptionKey1024, c *[CiphertextSize1024]byte) []byte {
	// The first k1024 chunks of encodingSize11 bytes each hold u.
	u := make([]ringElement, k1024)
	for i := range u {
		b := (*[encodingSize11]byte)(c[encodingSize11*i : encodingSize11*(i+1)])
		u[i] = ringDecodeAndDecompress11(b)
	}

	// The remaining encodingSize5 bytes hold v.
	b := (*[encodingSize5]byte)(c[encodingSize11*k1024:])
	v := ringDecodeAndDecompress5(b)

	// w = v − NTT⁻¹(sᵀ ∘ NTT(u)).
	var mask nttElement
	for i := range dx.s {
		mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
	}
	w := polySub(v, inverseNTT(mask))

	// Compress w to one bit per coefficient and encode it as bytes.
	return ringCompressAndEncode1(nil, w)
}
459
View as plain text