// Copyright 2024 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package hpke import ( "crypto" "crypto/aes" "crypto/cipher" "crypto/ecdh" "crypto/internal/fips140/hkdf" "crypto/rand" "errors" "internal/byteorder" "math/bits" "golang.org/x/crypto/chacha20poly1305" ) // testingOnlyGenerateKey is only used during testing, to provide // a fixed test key to use when checking the RFC 9180 vectors. var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error) type hkdfKDF struct { hash crypto.Hash } func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) []byte { labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey)) labeledIKM = append(labeledIKM, []byte("HPKE-v1")...) labeledIKM = append(labeledIKM, sid...) labeledIKM = append(labeledIKM, label...) labeledIKM = append(labeledIKM, inputKey...) return hkdf.Extract(kdf.hash.New, labeledIKM, salt) } func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) []byte { labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info)) labeledInfo = byteorder.BEAppendUint16(labeledInfo, length) labeledInfo = append(labeledInfo, []byte("HPKE-v1")...) labeledInfo = append(labeledInfo, suiteID...) labeledInfo = append(labeledInfo, label...) labeledInfo = append(labeledInfo, info...) return hkdf.Expand(kdf.hash.New, randomKey, string(labeledInfo), int(length)) } // dhKEM implements the KEM specified in RFC 9180, Section 4.1. type dhKEM struct { dh ecdh.Curve kdf hkdfKDF suiteID []byte nSecret uint16 } type KemID uint16 const DHKEM_X25519_HKDF_SHA256 = 0x0020 var SupportedKEMs = map[uint16]struct { curve ecdh.Curve hash crypto.Hash nSecret uint16 }{ // RFC 9180 Section 7.1 DHKEM_X25519_HKDF_SHA256: {ecdh.X25519(), crypto.SHA256, 32}, } func newDHKem(kemID uint16) (*dhKEM, error) { suite, ok := SupportedKEMs[kemID] if !ok { return nil, errors.New("unsupported suite ID") } return &dhKEM{ dh: suite.curve, kdf: hkdfKDF{suite.hash}, suiteID: byteorder.BEAppendUint16([]byte("KEM"), kemID), nSecret: suite.nSecret, }, nil } func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) []byte { eaePRK := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey) return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret) } func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) { var privEph *ecdh.PrivateKey if testingOnlyGenerateKey != nil { privEph, err = testingOnlyGenerateKey() } else { privEph, err = dh.dh.GenerateKey(rand.Reader) } if err != nil { return nil, nil, err } dhVal, err := privEph.ECDH(pubRecipient) if err != nil { return nil, nil, err } encPubEph := privEph.PublicKey().Bytes() encPubRecip := pubRecipient.Bytes() kemContext := append(encPubEph, encPubRecip...) return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil } func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) { pubEph, err := dh.dh.NewPublicKey(encPubEph) if err != nil { return nil, err } dhVal, err := secRecipient.ECDH(pubEph) if err != nil { return nil, err } kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...) return dh.ExtractAndExpand(dhVal, kemContext), nil } type context struct { aead cipher.AEAD sharedSecret []byte suiteID []byte key []byte baseNonce []byte exporterSecret []byte seqNum uint128 } type Sender struct { *context } type Receipient struct { *context } var aesGCMNew = func(key []byte) (cipher.AEAD, error) { block, err := aes.NewCipher(key) if err != nil { return nil, err } return cipher.NewGCM(block) } type AEADID uint16 const ( AEAD_AES_128_GCM = 0x0001 AEAD_AES_256_GCM = 0x0002 AEAD_ChaCha20Poly1305 = 0x0003 ) var SupportedAEADs = map[uint16]struct { keySize int nonceSize int aead func([]byte) (cipher.AEAD, error) }{ // RFC 9180, Section 7.3 AEAD_AES_128_GCM: {keySize: 16, nonceSize: 12, aead: aesGCMNew}, AEAD_AES_256_GCM: {keySize: 32, nonceSize: 12, aead: aesGCMNew}, AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New}, } type KDFID uint16 const KDF_HKDF_SHA256 = 0x0001 var SupportedKDFs = map[uint16]func() *hkdfKDF{ // RFC 9180, Section 7.2 KDF_HKDF_SHA256: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} }, } func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (*context, error) { sid := suiteID(kemID, kdfID, aeadID) kdfInit, ok := SupportedKDFs[kdfID] if !ok { return nil, errors.New("unsupported KDF id") } kdf := kdfInit() aeadInfo, ok := SupportedAEADs[aeadID] if !ok { return nil, errors.New("unsupported AEAD id") } pskIDHash := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil) infoHash := kdf.LabeledExtract(sid, nil, "info_hash", info) ksContext := append([]byte{0}, pskIDHash...) ksContext = append(ksContext, infoHash...) secret := kdf.LabeledExtract(sid, sharedSecret, "secret", nil) key := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */) baseNonce := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */) exporterSecret := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/) aead, err := aeadInfo.aead(key) if err != nil { return nil, err } return &context{ aead: aead, sharedSecret: sharedSecret, suiteID: sid, key: key, baseNonce: baseNonce, exporterSecret: exporterSecret, }, nil } func SetupSender(kemID, kdfID, aeadID uint16, pub *ecdh.PublicKey, info []byte) ([]byte, *Sender, error) { kem, err := newDHKem(kemID) if err != nil { return nil, nil, err } sharedSecret, encapsulatedKey, err := kem.Encap(pub) if err != nil { return nil, nil, err } context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info) if err != nil { return nil, nil, err } return encapsulatedKey, &Sender{context}, nil } func SetupReceipient(kemID, kdfID, aeadID uint16, priv *ecdh.PrivateKey, info, encPubEph []byte) (*Receipient, error) { kem, err := newDHKem(kemID) if err != nil { return nil, err } sharedSecret, err := kem.Decap(encPubEph, priv) if err != nil { return nil, err } context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info) if err != nil { return nil, err } return &Receipient{context}, nil } func (ctx *context) nextNonce() []byte { nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():] for i := range ctx.baseNonce { nonce[i] ^= ctx.baseNonce[i] } return nonce } func (ctx *context) incrementNonce() { // Message limit is, according to the RFC, 2^95+1, which // is somewhat confusing, but we do as we're told. if ctx.seqNum.bitLen() >= (ctx.aead.NonceSize()*8)-1 { panic("message limit reached") } ctx.seqNum = ctx.seqNum.addOne() } func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) { ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad) s.incrementNonce() return ciphertext, nil } func (r *Receipient) Open(aad, ciphertext []byte) ([]byte, error) { plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad) if err != nil { return nil, err } r.incrementNonce() return plaintext, nil } func suiteID(kemID, kdfID, aeadID uint16) []byte { suiteID := make([]byte, 0, 4+2+2+2) suiteID = append(suiteID, []byte("HPKE")...) suiteID = byteorder.BEAppendUint16(suiteID, kemID) suiteID = byteorder.BEAppendUint16(suiteID, kdfID) suiteID = byteorder.BEAppendUint16(suiteID, aeadID) return suiteID } func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) { kemInfo, ok := SupportedKEMs[kemID] if !ok { return nil, errors.New("unsupported KEM id") } return kemInfo.curve.NewPublicKey(bytes) } func ParseHPKEPrivateKey(kemID uint16, bytes []byte) (*ecdh.PrivateKey, error) { kemInfo, ok := SupportedKEMs[kemID] if !ok { return nil, errors.New("unsupported KEM id") } return kemInfo.curve.NewPrivateKey(bytes) } type uint128 struct { hi, lo uint64 } func (u uint128) addOne() uint128 { lo, carry := bits.Add64(u.lo, 1, 0) return uint128{u.hi + carry, lo} } func (u uint128) bitLen() int { return bits.Len64(u.hi) + bits.Len64(u.lo) } func (u uint128) bytes() []byte { b := make([]byte, 16) byteorder.BEPutUint64(b[0:], u.hi) byteorder.BEPutUint64(b[8:], u.lo) return b }