Source file src/crypto/internal/fips140/rsa/keygen_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rsa
     6  
     7  import (
     8  	"bufio"
     9  	"crypto/internal/fips140/bigmod"
    10  	"encoding/hex"
    11  	"fmt"
    12  	"math/big"
    13  	"os"
    14  	"strings"
    15  	"testing"
    16  )
    17  
    18  func TestMillerRabin(t *testing.T) {
    19  	f, err := os.Open("testdata/miller_rabin_tests.txt")
    20  	if err != nil {
    21  		t.Fatal(err)
    22  	}
    23  
    24  	var expected bool
    25  	var W, B string
    26  	var lineNum int
    27  	scanner := bufio.NewScanner(f)
    28  	for scanner.Scan() {
    29  		lineNum++
    30  		line := scanner.Text()
    31  		if len(line) == 0 || line[0] == '#' {
    32  			continue
    33  		}
    34  
    35  		k, v, _ := strings.Cut(line, " = ")
    36  		switch k {
    37  		case "Result":
    38  			switch v {
    39  			case "Composite":
    40  				expected = millerRabinCOMPOSITE
    41  			case "PossiblyPrime":
    42  				expected = millerRabinPOSSIBLYPRIME
    43  			default:
    44  				t.Fatalf("unknown result %q on line %d", v, lineNum)
    45  			}
    46  		case "W":
    47  			W = v
    48  		case "B":
    49  			B = v
    50  
    51  			t.Run(fmt.Sprintf("line %d", lineNum), func(t *testing.T) {
    52  				if len(W)%2 != 0 {
    53  					W = "0" + W
    54  				}
    55  				for len(B) < len(W) {
    56  					B = "0" + B
    57  				}
    58  
    59  				mr, err := millerRabinSetup(decodeHex(t, W))
    60  				if err != nil {
    61  					t.Logf("W = %s", W)
    62  					t.Logf("B = %s", B)
    63  					t.Fatalf("failed to set up Miller-Rabin test: %v", err)
    64  				}
    65  
    66  				result, err := millerRabinIteration(mr, decodeHex(t, B))
    67  				if err != nil {
    68  					t.Logf("W = %s", W)
    69  					t.Logf("B = %s", B)
    70  					t.Fatalf("failed to run Miller-Rabin test: %v", err)
    71  				}
    72  
    73  				if result != expected {
    74  					t.Logf("W = %s", W)
    75  					t.Logf("B = %s", B)
    76  					t.Fatalf("unexpected result: got %v, want %v", result, expected)
    77  				}
    78  			})
    79  		default:
    80  			t.Fatalf("unknown key %q on line %d", k, lineNum)
    81  		}
    82  	}
    83  	if err := scanner.Err(); err != nil {
    84  		t.Fatal(err)
    85  	}
    86  }
    87  
    88  func TestTotient(t *testing.T) {
    89  	f, err := os.Open("testdata/gcd_lcm_tests.txt")
    90  	if err != nil {
    91  		t.Fatal(err)
    92  	}
    93  
    94  	var GCD, A, B, LCM string
    95  	var lineNum int
    96  	scanner := bufio.NewScanner(f)
    97  	for scanner.Scan() {
    98  		lineNum++
    99  		line := scanner.Text()
   100  		if len(line) == 0 || line[0] == '#' {
   101  			continue
   102  		}
   103  
   104  		k, v, _ := strings.Cut(line, " = ")
   105  		switch k {
   106  		case "GCD":
   107  			GCD = v
   108  		case "A":
   109  			A = v
   110  		case "B":
   111  			B = v
   112  		case "LCM":
   113  			LCM = v
   114  
   115  			t.Run(fmt.Sprintf("line %d", lineNum), func(t *testing.T) {
   116  				if A == "0" || B == "0" {
   117  					t.Skip("skipping test with zero input")
   118  				}
   119  				if LCM == "1" {
   120  					t.Skip("skipping test with LCM=1")
   121  				}
   122  
   123  				p, _ := bigmod.NewModulus(addOne(decodeHex(t, A)))
   124  				a, _ := bigmod.NewNat().SetBytes(decodeHex(t, A), p)
   125  				q, _ := bigmod.NewModulus(addOne(decodeHex(t, B)))
   126  				b, _ := bigmod.NewNat().SetBytes(decodeHex(t, B), q)
   127  
   128  				gcd, err := bigmod.NewNat().GCDVarTime(a, b)
   129  				// GCD doesn't work if a and b are both even, but LCM handles it.
   130  				if err == nil {
   131  					if got := strings.TrimLeft(hex.EncodeToString(gcd.Bytes(p)), "0"); got != GCD {
   132  						t.Fatalf("unexpected GCD: got %s, want %s", got, GCD)
   133  					}
   134  				}
   135  
   136  				lcm, err := totient(p, q)
   137  				if oddDivisorLargerThan32Bits(decodeHex(t, GCD)) {
   138  					if err != errDivisorTooLarge {
   139  						t.Fatalf("expected divisor too large error, got %v", err)
   140  					}
   141  					t.Skip("GCD too large")
   142  				}
   143  				if err != nil {
   144  					t.Fatalf("failed to calculate totient: %v", err)
   145  				}
   146  				if got := strings.TrimLeft(hex.EncodeToString(lcm.Nat().Bytes(lcm)), "0"); got != LCM {
   147  					t.Fatalf("unexpected LCM: got %s, want %s", got, LCM)
   148  				}
   149  			})
   150  		default:
   151  			t.Fatalf("unknown key %q on line %d", k, lineNum)
   152  		}
   153  	}
   154  	if err := scanner.Err(); err != nil {
   155  		t.Fatal(err)
   156  	}
   157  }
   158  
   159  func oddDivisorLargerThan32Bits(b []byte) bool {
   160  	x := new(big.Int).SetBytes(b)
   161  	x.Rsh(x, x.TrailingZeroBits())
   162  	return x.BitLen() > 32
   163  }
   164  
   165  func addOne(b []byte) []byte {
   166  	x := new(big.Int).SetBytes(b)
   167  	x.Add(x, big.NewInt(1))
   168  	return x.Bytes()
   169  }
   170  
   171  func decodeHex(t *testing.T, s string) []byte {
   172  	t.Helper()
   173  	if len(s)%2 != 0 {
   174  		s = "0" + s
   175  	}
   176  	b, err := hex.DecodeString(s)
   177  	if err != nil {
   178  		t.Fatalf("failed to decode hex %q: %v", s, err)
   179  	}
   180  	return b
   181  }
   182  

View as plain text