Source file src/crypto/rsa/pss.go

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rsa
     6  
     7  // This file implements the RSASSA-PSS signature scheme according to RFC 8017.
     8  
     9  import (
    10  	"bytes"
    11  	"crypto"
    12  	"crypto/internal/boring"
    13  	"errors"
    14  	"hash"
    15  	"io"
    16  )
    17  
    18  // Per RFC 8017, Section 9.1
    19  //
    20  //     EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc
    21  //
    22  // where
    23  //
    24  //     DB = PS || 0x01 || salt
    25  //
    26  // and PS can be empty so
    27  //
    28  //     emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2
    29  //
    30  
    31  func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
    32  	// See RFC 8017, Section 9.1.1.
    33  
    34  	hLen := hash.Size()
    35  	sLen := len(salt)
    36  	emLen := (emBits + 7) / 8
    37  
    38  	// 1.  If the length of M is greater than the input limitation for the
    39  	//     hash function (2^61 - 1 octets for SHA-1), output "message too
    40  	//     long" and stop.
    41  	//
    42  	// 2.  Let mHash = Hash(M), an octet string of length hLen.
    43  
    44  	if len(mHash) != hLen {
    45  		return nil, errors.New("crypto/rsa: input must be hashed with given hash")
    46  	}
    47  
    48  	// 3.  If emLen < hLen + sLen + 2, output "encoding error" and stop.
    49  
    50  	if emLen < hLen+sLen+2 {
    51  		return nil, ErrMessageTooLong
    52  	}
    53  
    54  	em := make([]byte, emLen)
    55  	psLen := emLen - sLen - hLen - 2
    56  	db := em[:psLen+1+sLen]
    57  	h := em[psLen+1+sLen : emLen-1]
    58  
    59  	// 4.  Generate a random octet string salt of length sLen; if sLen = 0,
    60  	//     then salt is the empty string.
    61  	//
    62  	// 5.  Let
    63  	//       M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
    64  	//
    65  	//     M' is an octet string of length 8 + hLen + sLen with eight
    66  	//     initial zero octets.
    67  	//
    68  	// 6.  Let H = Hash(M'), an octet string of length hLen.
    69  
    70  	var prefix [8]byte
    71  
    72  	hash.Write(prefix[:])
    73  	hash.Write(mHash)
    74  	hash.Write(salt)
    75  
    76  	h = hash.Sum(h[:0])
    77  	hash.Reset()
    78  
    79  	// 7.  Generate an octet string PS consisting of emLen - sLen - hLen - 2
    80  	//     zero octets. The length of PS may be 0.
    81  	//
    82  	// 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
    83  	//     emLen - hLen - 1.
    84  
    85  	db[psLen] = 0x01
    86  	copy(db[psLen+1:], salt)
    87  
    88  	// 9.  Let dbMask = MGF(H, emLen - hLen - 1).
    89  	//
    90  	// 10. Let maskedDB = DB \xor dbMask.
    91  
    92  	mgf1XOR(db, hash, h)
    93  
    94  	// 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
    95  	//     maskedDB to zero.
    96  
    97  	db[0] &= 0xff >> (8*emLen - emBits)
    98  
    99  	// 12. Let EM = maskedDB || H || 0xbc.
   100  	em[emLen-1] = 0xbc
   101  
   102  	// 13. Output EM.
   103  	return em, nil
   104  }
   105  
   106  func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
   107  	// See RFC 8017, Section 9.1.2.
   108  
   109  	hLen := hash.Size()
   110  	if sLen == PSSSaltLengthEqualsHash {
   111  		sLen = hLen
   112  	}
   113  	emLen := (emBits + 7) / 8
   114  	if emLen != len(em) {
   115  		return errors.New("rsa: internal error: inconsistent length")
   116  	}
   117  
   118  	// 1.  If the length of M is greater than the input limitation for the
   119  	//     hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
   120  	//     and stop.
   121  	//
   122  	// 2.  Let mHash = Hash(M), an octet string of length hLen.
   123  	if hLen != len(mHash) {
   124  		return ErrVerification
   125  	}
   126  
   127  	// 3.  If emLen < hLen + sLen + 2, output "inconsistent" and stop.
   128  	if emLen < hLen+sLen+2 {
   129  		return ErrVerification
   130  	}
   131  
   132  	// 4.  If the rightmost octet of EM does not have hexadecimal value
   133  	//     0xbc, output "inconsistent" and stop.
   134  	if em[emLen-1] != 0xbc {
   135  		return ErrVerification
   136  	}
   137  
   138  	// 5.  Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
   139  	//     let H be the next hLen octets.
   140  	db := em[:emLen-hLen-1]
   141  	h := em[emLen-hLen-1 : emLen-1]
   142  
   143  	// 6.  If the leftmost 8 * emLen - emBits bits of the leftmost octet in
   144  	//     maskedDB are not all equal to zero, output "inconsistent" and
   145  	//     stop.
   146  	var bitMask byte = 0xff >> (8*emLen - emBits)
   147  	if em[0] & ^bitMask != 0 {
   148  		return ErrVerification
   149  	}
   150  
   151  	// 7.  Let dbMask = MGF(H, emLen - hLen - 1).
   152  	//
   153  	// 8.  Let DB = maskedDB \xor dbMask.
   154  	mgf1XOR(db, hash, h)
   155  
   156  	// 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
   157  	//     to zero.
   158  	db[0] &= bitMask
   159  
   160  	// If we don't know the salt length, look for the 0x01 delimiter.
   161  	if sLen == PSSSaltLengthAuto {
   162  		psLen := bytes.IndexByte(db, 0x01)
   163  		if psLen < 0 {
   164  			return ErrVerification
   165  		}
   166  		sLen = len(db) - psLen - 1
   167  	}
   168  
   169  	// 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
   170  	//     or if the octet at position emLen - hLen - sLen - 1 (the leftmost
   171  	//     position is "position 1") does not have hexadecimal value 0x01,
   172  	//     output "inconsistent" and stop.
   173  	psLen := emLen - hLen - sLen - 2
   174  	for _, e := range db[:psLen] {
   175  		if e != 0x00 {
   176  			return ErrVerification
   177  		}
   178  	}
   179  	if db[psLen] != 0x01 {
   180  		return ErrVerification
   181  	}
   182  
   183  	// 11.  Let salt be the last sLen octets of DB.
   184  	salt := db[len(db)-sLen:]
   185  
   186  	// 12.  Let
   187  	//          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
   188  	//     M' is an octet string of length 8 + hLen + sLen with eight
   189  	//     initial zero octets.
   190  	//
   191  	// 13. Let H' = Hash(M'), an octet string of length hLen.
   192  	var prefix [8]byte
   193  	hash.Write(prefix[:])
   194  	hash.Write(mHash)
   195  	hash.Write(salt)
   196  
   197  	h0 := hash.Sum(nil)
   198  
   199  	// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
   200  	if !bytes.Equal(h0, h) { // TODO: constant time?
   201  		return ErrVerification
   202  	}
   203  	return nil
   204  }
   205  
   206  // signPSSWithSalt calculates the signature of hashed using PSS with specified salt.
   207  // Note that hashed must be the result of hashing the input message using the
   208  // given hash function. salt is a random sequence of bytes whose length will be
   209  // later used to verify the signature.
   210  func signPSSWithSalt(priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
   211  	emBits := priv.N.BitLen() - 1
   212  	em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  
   217  	if boring.Enabled {
   218  		bkey, err := boringPrivateKey(priv)
   219  		if err != nil {
   220  			return nil, err
   221  		}
   222  		// Note: BoringCrypto always does decrypt "withCheck".
   223  		// (It's not just decrypt.)
   224  		s, err := boring.DecryptRSANoPadding(bkey, em)
   225  		if err != nil {
   226  			return nil, err
   227  		}
   228  		return s, nil
   229  	}
   230  
   231  	// RFC 8017: "Note that the octet length of EM will be one less than k if
   232  	// modBits - 1 is divisible by 8 and equal to k otherwise, where k is the
   233  	// length in octets of the RSA modulus n." 🙄
   234  	//
   235  	// This is extremely annoying, as all other encrypt and decrypt inputs are
   236  	// always the exact same size as the modulus. Since it only happens for
   237  	// weird modulus sizes, fix it by padding inefficiently.
   238  	if emLen, k := len(em), priv.Size(); emLen < k {
   239  		emNew := make([]byte, k)
   240  		copy(emNew[k-emLen:], em)
   241  		em = emNew
   242  	}
   243  
   244  	return decrypt(priv, em, withCheck)
   245  }
   246  
// Special values for PSSOptions.SaltLength.
const (
	// PSSSaltLengthAuto causes the salt in a PSS signature to be as large
	// as possible when signing, and to be auto-detected when verifying.
	// It is also the salt length used when a nil *PSSOptions is supplied.
	PSSSaltLengthAuto = 0
	// PSSSaltLengthEqualsHash causes the salt length to equal the length
	// of the hash used in the signature.
	PSSSaltLengthEqualsHash = -1
)
   255  
// PSSOptions contains options for creating and verifying PSS signatures.
type PSSOptions struct {
	// SaltLength controls the length of the salt used in the PSS signature. It
	// can either be a positive number of bytes, or one of the special
	// PSSSaltLength constants. Any other negative value is rejected by
	// SignPSS and VerifyPSS.
	SaltLength int

	// Hash is the hash function used to generate the message digest. If not
	// zero, it overrides the hash function passed to SignPSS. It's required
	// when using PrivateKey.Sign.
	Hash crypto.Hash
}
   268  
// HashFunc returns opts.Hash so that [PSSOptions] implements [crypto.SignerOpts].
//
// Unlike saltLength, HashFunc does not guard against a nil receiver, so opts
// must be non-nil.
func (opts *PSSOptions) HashFunc() crypto.Hash {
	return opts.Hash
}
   273  
   274  func (opts *PSSOptions) saltLength() int {
   275  	if opts == nil {
   276  		return PSSSaltLengthAuto
   277  	}
   278  	return opts.SaltLength
   279  }
   280  
// invalidSaltLenErr is returned by the non-BoringCrypto paths of SignPSS and
// VerifyPSS when PSSOptions.SaltLength is negative and is not
// PSSSaltLengthEqualsHash.
var invalidSaltLenErr = errors.New("crypto/rsa: PSSOptions.SaltLength cannot be negative")
   282  
   283  // SignPSS calculates the signature of digest using PSS.
   284  //
   285  // digest must be the result of hashing the input message using the given hash
   286  // function. The opts argument may be nil, in which case sensible defaults are
   287  // used. If opts.Hash is set, it overrides hash.
   288  //
   289  // The signature is randomized depending on the message, key, and salt size,
   290  // using bytes from rand. Most applications should use [crypto/rand.Reader] as
   291  // rand.
   292  func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) {
   293  	// Note that while we don't commit to deterministic execution with respect
   294  	// to the rand stream, we also don't apply MaybeReadByte, so per Hyrum's Law
   295  	// it's probably relied upon by some. It's a tolerable promise because a
   296  	// well-specified number of random bytes is included in the signature, in a
   297  	// well-specified way.
   298  
   299  	if boring.Enabled && rand == boring.RandReader {
   300  		bkey, err := boringPrivateKey(priv)
   301  		if err != nil {
   302  			return nil, err
   303  		}
   304  		return boring.SignRSAPSS(bkey, hash, digest, opts.saltLength())
   305  	}
   306  	boring.UnreachableExceptTests()
   307  
   308  	if opts != nil && opts.Hash != 0 {
   309  		hash = opts.Hash
   310  	}
   311  
   312  	saltLength := opts.saltLength()
   313  	switch saltLength {
   314  	case PSSSaltLengthAuto:
   315  		saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
   316  		if saltLength < 0 {
   317  			return nil, ErrMessageTooLong
   318  		}
   319  	case PSSSaltLengthEqualsHash:
   320  		saltLength = hash.Size()
   321  	default:
   322  		// If we get here saltLength is either > 0 or < -1, in the
   323  		// latter case we fail out.
   324  		if saltLength <= 0 {
   325  			return nil, invalidSaltLenErr
   326  		}
   327  	}
   328  	salt := make([]byte, saltLength)
   329  	if _, err := io.ReadFull(rand, salt); err != nil {
   330  		return nil, err
   331  	}
   332  	return signPSSWithSalt(priv, hash, digest, salt)
   333  }
   334  
   335  // VerifyPSS verifies a PSS signature.
   336  //
   337  // A valid signature is indicated by returning a nil error. digest must be the
   338  // result of hashing the input message using the given hash function. The opts
   339  // argument may be nil, in which case sensible defaults are used. opts.Hash is
   340  // ignored.
   341  //
   342  // The inputs are not considered confidential, and may leak through timing side
   343  // channels, or if an attacker has control of part of the inputs.
   344  func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error {
   345  	if boring.Enabled {
   346  		bkey, err := boringPublicKey(pub)
   347  		if err != nil {
   348  			return err
   349  		}
   350  		if err := boring.VerifyRSAPSS(bkey, hash, digest, sig, opts.saltLength()); err != nil {
   351  			return ErrVerification
   352  		}
   353  		return nil
   354  	}
   355  	if len(sig) != pub.Size() {
   356  		return ErrVerification
   357  	}
   358  	// Salt length must be either one of the special constants (-1 or 0)
   359  	// or otherwise positive. If it is < PSSSaltLengthEqualsHash (-1)
   360  	// we return an error.
   361  	if opts.saltLength() < PSSSaltLengthEqualsHash {
   362  		return invalidSaltLenErr
   363  	}
   364  
   365  	emBits := pub.N.BitLen() - 1
   366  	emLen := (emBits + 7) / 8
   367  	em, err := encrypt(pub, sig)
   368  	if err != nil {
   369  		return ErrVerification
   370  	}
   371  
   372  	// Like in signPSSWithSalt, deal with mismatches between emLen and the size
   373  	// of the modulus. The spec would have us wire emLen into the encoding
   374  	// function, but we'd rather always encode to the size of the modulus and
   375  	// then strip leading zeroes if necessary. This only happens for weird
   376  	// modulus sizes anyway.
   377  	for len(em) > emLen && len(em) > 0 {
   378  		if em[0] != 0 {
   379  			return ErrVerification
   380  		}
   381  		em = em[1:]
   382  	}
   383  
   384  	return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
   385  }
   386  

View as plain text