1
2
3
4
5 package rsa
6
7 import (
8 "bufio"
9 "crypto/internal/fips140/bigmod"
10 "encoding/hex"
11 "fmt"
12 "math/big"
13 "os"
14 "strings"
15 "testing"
16 )
17
18 func TestMillerRabin(t *testing.T) {
19 f, err := os.Open("testdata/miller_rabin_tests.txt")
20 if err != nil {
21 t.Fatal(err)
22 }
23
24 var expected bool
25 var W, B string
26 var lineNum int
27 scanner := bufio.NewScanner(f)
28 for scanner.Scan() {
29 lineNum++
30 line := scanner.Text()
31 if len(line) == 0 || line[0] == '#' {
32 continue
33 }
34
35 k, v, _ := strings.Cut(line, " = ")
36 switch k {
37 case "Result":
38 switch v {
39 case "Composite":
40 expected = millerRabinCOMPOSITE
41 case "PossiblyPrime":
42 expected = millerRabinPOSSIBLYPRIME
43 default:
44 t.Fatalf("unknown result %q on line %d", v, lineNum)
45 }
46 case "W":
47 W = v
48 case "B":
49 B = v
50
51 t.Run(fmt.Sprintf("line %d", lineNum), func(t *testing.T) {
52 if len(W)%2 != 0 {
53 W = "0" + W
54 }
55 for len(B) < len(W) {
56 B = "0" + B
57 }
58
59 mr, err := millerRabinSetup(decodeHex(t, W))
60 if err != nil {
61 t.Logf("W = %s", W)
62 t.Logf("B = %s", B)
63 t.Fatalf("failed to set up Miller-Rabin test: %v", err)
64 }
65
66 result, err := millerRabinIteration(mr, decodeHex(t, B))
67 if err != nil {
68 t.Logf("W = %s", W)
69 t.Logf("B = %s", B)
70 t.Fatalf("failed to run Miller-Rabin test: %v", err)
71 }
72
73 if result != expected {
74 t.Logf("W = %s", W)
75 t.Logf("B = %s", B)
76 t.Fatalf("unexpected result: got %v, want %v", result, expected)
77 }
78 })
79 default:
80 t.Fatalf("unknown key %q on line %d", k, lineNum)
81 }
82 }
83 if err := scanner.Err(); err != nil {
84 t.Fatal(err)
85 }
86 }
87
88 func TestTotient(t *testing.T) {
89 f, err := os.Open("testdata/gcd_lcm_tests.txt")
90 if err != nil {
91 t.Fatal(err)
92 }
93
94 var GCD, A, B, LCM string
95 var lineNum int
96 scanner := bufio.NewScanner(f)
97 for scanner.Scan() {
98 lineNum++
99 line := scanner.Text()
100 if len(line) == 0 || line[0] == '#' {
101 continue
102 }
103
104 k, v, _ := strings.Cut(line, " = ")
105 switch k {
106 case "GCD":
107 GCD = v
108 case "A":
109 A = v
110 case "B":
111 B = v
112 case "LCM":
113 LCM = v
114
115 t.Run(fmt.Sprintf("line %d", lineNum), func(t *testing.T) {
116 if A == "0" || B == "0" {
117 t.Skip("skipping test with zero input")
118 }
119 if LCM == "1" {
120 t.Skip("skipping test with LCM=1")
121 }
122
123 p, _ := bigmod.NewModulus(addOne(decodeHex(t, A)))
124 a, _ := bigmod.NewNat().SetBytes(decodeHex(t, A), p)
125 q, _ := bigmod.NewModulus(addOne(decodeHex(t, B)))
126 b, _ := bigmod.NewNat().SetBytes(decodeHex(t, B), q)
127
128 gcd, err := bigmod.NewNat().GCDVarTime(a, b)
129
130 if err == nil {
131 if got := strings.TrimLeft(hex.EncodeToString(gcd.Bytes(p)), "0"); got != GCD {
132 t.Fatalf("unexpected GCD: got %s, want %s", got, GCD)
133 }
134 }
135
136 lcm, err := totient(p, q)
137 if oddDivisorLargerThan32Bits(decodeHex(t, GCD)) {
138 if err != errDivisorTooLarge {
139 t.Fatalf("expected divisor too large error, got %v", err)
140 }
141 t.Skip("GCD too large")
142 }
143 if err != nil {
144 t.Fatalf("failed to calculate totient: %v", err)
145 }
146 if got := strings.TrimLeft(hex.EncodeToString(lcm.Nat().Bytes(lcm)), "0"); got != LCM {
147 t.Fatalf("unexpected LCM: got %s, want %s", got, LCM)
148 }
149 })
150 default:
151 t.Fatalf("unknown key %q on line %d", k, lineNum)
152 }
153 }
154 if err := scanner.Err(); err != nil {
155 t.Fatal(err)
156 }
157 }
158
// oddDivisorLargerThan32Bits interprets b as an unsigned big-endian integer,
// strips every factor of two from it, and reports whether the remaining odd
// part needs more than 32 bits. A zero input reports false.
func oddDivisorLargerThan32Bits(b []byte) bool {
	n := new(big.Int).SetBytes(b)
	oddPart := new(big.Int).Rsh(n, n.TrailingZeroBits())
	return oddPart.BitLen() > 32
}
164
// addOne interprets b as an unsigned big-endian integer and returns the
// big-endian encoding of b+1, without leading zero bytes.
func addOne(b []byte) []byte {
	one := big.NewInt(1)
	sum := new(big.Int).Add(new(big.Int).SetBytes(b), one)
	return sum.Bytes()
}
170
// decodeHex decodes the hex string s into bytes and fails the test on
// invalid input. An odd-length s is first left-padded with a single '0', so
// "abc" decodes as 0x0a 0xbc.
func decodeHex(t *testing.T, s string) []byte {
	t.Helper()
	padded := s
	if len(padded)%2 == 1 {
		padded = "0" + padded
	}
	out, err := hex.DecodeString(padded)
	if err != nil {
		t.Fatalf("failed to decode hex %q: %v", padded, err)
	}
	return out
}
182
View as plain text