| 1 | // vtest build: has_modern_openssl? |
| 2 | // regenerate go test vecs: v run testdata/gen.vsh [go-source-path] |
| 3 | |
| 4 | module mldsa |
| 5 | |
| 6 | import crypto.sha256 |
| 7 | import encoding.hex |
| 8 | import json |
| 9 | import os |
| 10 | |
// TestVec is one entry decoded from testdata/vectors.json.
// seed, msg, pk_sha256 and sig_sha256 are hex strings; context is used verbatim.
struct TestVec {
	// --- inputs ---
	kind    string // parameter set name: 'ml_dsa_44', 'ml_dsa_65' or 'ml_dsa_87'
	seed    string // hex-encoded private key seed
	msg     string // hex-encoded message to sign
	context string // context string passed unchanged to sign and verify (not hex)
	// --- expected outputs ---
	pk_sha256  string // hex SHA-256 digest of the serialized public key
	sig_sha256 string // hex SHA-256 digest of the deterministic signature over msg
}
| 19 | |
// Location of the JSON test vectors, resolved relative to this source file.
const test_vecs_path = os.real_path(os.join_path(os.dir(@FILE), 'testdata', 'vectors.json'))

// Raw text of the vector file; panics if it cannot be read.
const vecs_json = os.read_file(test_vecs_path) or { panic(err) }

// Parsed vectors replayed by test_keygen_sign_verify; panics on malformed JSON.
const test_vecs = json.decode([]TestVec, vecs_json) or { panic(err) }
| 24 | |
// parse_kind maps the `kind` string of a test vector to the Kind enum.
// Any name other than ml_dsa_44, ml_dsa_65 or ml_dsa_87 panics.
fn parse_kind(s string) Kind {
	if s == 'ml_dsa_44' {
		return Kind.ml_dsa_44
	}
	if s == 'ml_dsa_65' {
		return Kind.ml_dsa_65
	}
	if s == 'ml_dsa_87' {
		return Kind.ml_dsa_87
	}
	panic('unknown kind: ${s}')
}
| 33 | |
// test_keygen_sign_verify replays every loaded test vector. For each one it:
// derives a key pair from the seed, compares the SHA-256 digest of the public
// key and of a deterministic signature against the recorded digests, and then
// checks that the public key accepts that signature under the same context.
fn test_keygen_sign_verify() {
	assert test_vecs.len > 0, 'no test vectors loaded'

	for vec in test_vecs {
		kind := parse_kind(vec.kind)
		seed := hex.decode(vec.seed) or { panic(err) }
		msg := hex.decode(vec.msg) or { panic(err) }
		want_pk_digest := hex.decode(vec.pk_sha256) or { panic(err) }
		want_sig_digest := hex.decode(vec.sig_sha256) or { panic(err) }

		sk := PrivateKey.from_seed(seed, kind) or { panic(err) }
		pk := sk.public_key()

		got_pk_digest := sha256.sum(pk.bytes())
		assert got_pk_digest[..] == want_pk_digest, 'pk hash mismatch for ${vec.kind} seed=${vec.seed[..16]}...'

		// Deterministic signing makes the signature bytes reproducible, so
		// comparing their digest against the recorded one is meaningful.
		sig := sk.sign(msg, deterministic: true, context: vec.context) or { panic(err) }
		got_sig_digest := sha256.sum(sig)
		assert got_sig_digest[..] == want_sig_digest, 'sig hash mismatch for ${vec.kind} seed=${vec.seed[..16]}...'

		ok := pk.verify(msg, sig, context: vec.context) or { panic(err) }
		assert ok, 'verify returned false for ${vec.kind} seed=${vec.seed[..16]}...'
	}
}
| 58 | |
// test_verify_rejects_bad_signature checks, for every parameter set, that
// inverting one byte of a valid signature makes verification fail. A verify
// error is also treated as a rejection.
fn test_verify_rejects_bad_signature() {
	seed := []u8{len: 32, init: 0x00}
	msg := 'deadbeef'.bytes()

	for k in [Kind.ml_dsa_44, .ml_dsa_65, .ml_dsa_87] {
		sk := PrivateKey.from_seed(seed, k) or { panic(err) }
		pk := sk.public_key()
		good := sk.sign(msg, deterministic: true) or { panic(err) }

		// Tamper with a copy so the signature returned by sign stays untouched.
		mut tampered := good.clone()
		tampered[10] ^= 0xff

		accepted := pk.verify(msg, tampered) or { false }
		assert !accepted, 'verify should reject tampered sig for ${k}'
	}
}
| 75 | |
// test_verify_rejects_wrong_message checks, for every parameter set, that a
// signature over one message does not verify against a different message.
// A verify error is also treated as a rejection.
fn test_verify_rejects_wrong_message() {
	seed := []u8{len: 32, init: 0x01}
	signed_msg := 'the beef is alive'.bytes()

	for k in [Kind.ml_dsa_44, .ml_dsa_65, .ml_dsa_87] {
		sk := PrivateKey.from_seed(seed, k) or { panic(err) }
		pk := sk.public_key()
		sig := sk.sign(signed_msg, deterministic: true) or { panic(err) }

		other_msg := 'I love strawberries'.bytes()
		accepted := pk.verify(other_msg, sig) or { false }
		assert !accepted, 'verify should reject wrong message for ${k}'
	}
}
| 89 | |
// test_verify_rejects_wrong_context checks, for every parameter set, that a
// signature produced under one context string fails to verify under another.
// A verify error is also treated as a rejection.
fn test_verify_rejects_wrong_context() {
	seed := []u8{len: 32, init: 0x02}
	msg := 'very cool message'.bytes()

	for k in [Kind.ml_dsa_44, .ml_dsa_65, .ml_dsa_87] {
		sk := PrivateKey.from_seed(seed, k) or { panic(err) }
		pk := sk.public_key()
		sig := sk.sign(msg, deterministic: true, context: 'some context a') or { panic(err) }

		accepted := pk.verify(msg, sig, context: 'another context b') or { false }
		assert !accepted, 'verify should reject wrong context for ${k}'
	}
}
| 103 | |
// test_public_key_roundtrip serializes each public key with bytes(), parses it
// back with PublicKey.from_bytes, and checks three things:
//   - the parsed key compares equal to the original,
//   - re-serializing it reproduces the same bytes,
//   - it still verifies a signature made by the matching private key.
fn test_public_key_roundtrip() {
	for kind in [Kind.ml_dsa_44, .ml_dsa_65, .ml_dsa_87] {
		seed := []u8{len: 32, init: 0x03}
		sk := PrivateKey.from_seed(seed, kind) or { panic(err) }
		pk := sk.public_key()
		msg := 'pk roundtrip'.bytes()

		sig := sk.sign(msg, deterministic: true) or { panic(err) }

		encoded := pk.bytes()
		pk2 := PublicKey.from_bytes(encoded, kind) or { panic(err) }
		// Name the parameter set, like every other per-kind assertion in this file.
		assert pk.equal(&pk2), 'roundtripped public key not equal for ${kind}'
		assert pk2.bytes() == encoded, 're-encoded public key bytes differ for ${kind}'

		verified := pk2.verify(msg, sig) or { panic(err) }
		assert verified, 'verify failed after public key roundtrip for ${kind}'
	}
}
| 120 | |
// test_prehash_sign_verify signs with each supported prehash for every
// parameter set. It checks that:
//   - verification with the same prehash succeeds,
//   - plain (non-prehash) verification of the same signature fails or errors.
fn test_prehash_sign_verify() {
	seed := []u8{len: 32, init: 0x05}
	msg := 'prehash test message'.bytes()

	for k in [Kind.ml_dsa_44, .ml_dsa_65, .ml_dsa_87] {
		sk := PrivateKey.from_seed(seed, k) or { panic(err) }
		pk := sk.public_key()

		for hash_alg in [PreHash.sha2_256, .sha2_384, .sha2_512, .sha3_256, .sha3_512, .shake_128,
			.shake_256] {
			sig := sk.sign(msg, deterministic: true, prehash: hash_alg) or { panic(err) }
			ok := pk.verify(msg, sig, prehash: hash_alg) or { panic(err) }
			assert ok, 'prehash verify failed for ${k} ${hash_alg}'

			// A prehashed signature must not be accepted by the pure verifier.
			pure_ok := pk.verify(msg, sig) or { false }
			assert !pure_ok, 'pure verify should reject prehash signature for ${k} ${hash_alg}'
		}
	}
}
| 148 | |
// test_field_to_montgomery_roundtrip checks that converting a value into
// Montgomery form and back yields the original, including the edge value q - 1.
fn test_field_to_montgomery_roundtrip() {
	samples := [u32(0), 1, 2, 100, 1000, q - 1]
	for x in samples {
		mont := field_to_montgomery(x) or { panic(err) }
		got := field_from_montgomery(mont)
		assert got == x, 'roundtrip failed for ${x}: got ${got}'
	}
}
| 156 | |
// test_field_add_sub checks field_add and field_sub on Montgomery-form values.
// It covers both the in-range case and the case where the result must wrap
// modulo q.
fn test_field_add_sub() {
	a := field_to_montgomery(100) or { panic(err) }
	b := field_to_montgomery(200) or { panic(err) }
	sum := field_add(a, b)
	assert field_from_montgomery(sum) == 300

	diff := field_sub(sum, b)
	assert field_from_montgomery(diff) == 100

	// The cases above never leave [0, q), so they cannot catch a missing
	// reduction. Exercise wraparound in both directions.
	top := field_to_montgomery(q - 1) or { panic(err) }
	one := field_to_montgomery(1) or { panic(err) }
	assert field_from_montgomery(field_add(top, one)) == 0, '(q - 1) + 1 should wrap to 0'
	assert field_from_montgomery(field_sub(a, b)) == q - 100, '100 - 200 should wrap to q - 100'
}
| 166 | |
// test_field_mul checks that field_montgomery_mul multiplies field elements
// modulo q when both operands are in Montgomery form.
fn test_field_mul() {
	a := field_to_montgomery(1000) or { panic(err) }
	b := field_to_montgomery(2000) or { panic(err) }
	prod := field_montgomery_mul(a, b)
	assert field_from_montgomery(prod) == (1000 * 2000) % q

	// 1000 * 2000 is presumably below q (8380417 for ML-DSA), so the case above
	// does not need any reduction. (q - 1)^2 = (-1)^2 must reduce to 1 mod q.
	minus_one := field_to_montgomery(q - 1) or { panic(err) }
	assert field_from_montgomery(field_montgomery_mul(minus_one, minus_one)) == 1, '(q - 1)^2 should reduce to 1'
}
| 173 | |
// test_ntt_inverse_ntt_roundtrip fills a ring element with coefficients i % 100.
// It applies ntt and then inverse_ntt, and checks that every coefficient comes
// back unchanged (compared in normal, non-Montgomery form).
fn test_ntt_inverse_ntt_roundtrip() {
	mut poly := RingElement{}
	for idx in 0 .. n {
		poly[idx] = field_to_montgomery(u32(idx % 100)) or { panic(err) }
	}

	restored := inverse_ntt(ntt(poly))

	for idx in 0 .. n {
		want := field_from_montgomery(poly[idx])
		got := field_from_montgomery(restored[idx])
		assert got == want, 'NTT roundtrip failed at index ${idx}'
	}
}
| 185 | |
// test_ntt_mul_is_polynomial_product squares 1 + x by multiplying in the NTT
// domain with ntt_mul and transforming back. It expects 1 + 2x + x^2, with every
// higher coefficient zero.
fn test_ntt_mul_is_polynomial_product() {
	// (1 + x)^2 ?= 1 + 2x + x^2
	mut a := RingElement{}
	a[0] = field_to_montgomery(1) or { panic(err) }
	a[1] = field_to_montgomery(1) or { panic(err) }

	a_ntt := ntt(a)
	prod_ntt := ntt_mul(a_ntt, a_ntt)
	prod := inverse_ntt(prod_ntt)

	// prod[i] holds the coefficient of x^i (a[0] was the constant term above).
	assert field_from_montgomery(prod[0]) == 1, 'expected constant term 1'
	assert field_from_montgomery(prod[1]) == 2, 'expected 2x'
	assert field_from_montgomery(prod[2]) == 1, 'expected x^2'

	for i in 3 .. n {
		assert field_from_montgomery(prod[i]) == 0, 'expected 0 at index ${i}, got ${field_from_montgomery(prod[i])}'
	}
}
| 204 | |
// test_power2_round splits sample values with power2_round into (hi, lo).
// It checks that hi * 2^d + lo, computed in the field, reconstructs the input.
fn test_power2_round() {
	inputs := [u32(0), 1, 100, 1000, q / 2, q - 1]
	for x in inputs {
		r := field_to_montgomery(x) or { panic(err) }
		hi, lo := power2_round(r)

		hi_scaled := field_to_montgomery(u32(hi) << d) or { panic(err) }
		rebuilt := field_add(hi_scaled, lo)
		assert field_from_montgomery(rebuilt) == x, 'power2_round failed for ${x}'
	}
}
| 213 | |
// test_private_key_roundtrip checks, for every parameter set, that rebuilding a
// private key from the seed reported by seed() gives an equal key.
fn test_private_key_roundtrip() {
	seed := []u8{len: 32, init: 0x04}
	for k in [Kind.ml_dsa_44, .ml_dsa_65, .ml_dsa_87] {
		original := PrivateKey.from_seed(seed, k) or { panic(err) }
		restored := PrivateKey.from_seed(original.seed(), k) or { panic(err) }
		assert original.equal(&restored), 'roundtripped private key not equal for ${k}'
	}
}
| 222 | |