v2 / vlib / x / crypto / mldsa / mldsa.v
518 lines · 448 sloc · 12.16 KB · b615cd08d134956354a72dcc42a6a6ad4e39cb64
Raw
1// Copyright 2025 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4//
5// Ported to V from Go's crypto/internal/fips140/mldsa.
6
7// ML-DSA (Module-Lattice-Based Digital Signature Algorithm) per FIPS 204
8// https://nvlpubs.nist.gov/nistpubs/fips/nist.fips.204.pdf
9
10module mldsa
11
12import crypto.rand
13import crypto.sha3
14import crypto.internal.subtle
15
// slice_to_32 copies the first 32 bytes of s into a fixed-size array; any
// bytes after index 31 are ignored. Because @[direct_array_access] disables
// bounds checking in this function, the length is validated explicitly:
// without the guard a short slice would be read past its end.
@[direct_array_access]
fn slice_to_32(s []u8) [32]u8 {
	if s.len < 32 {
		panic('mldsa: slice_to_32 needs at least 32 bytes, got ${s.len}')
	}
	mut a := [32]u8{}
	for i in 0 .. 32 {
		a[i] = s[i]
	}
	return a
}
24
// slice_to_64 copies the first 64 bytes of s into a fixed-size array; any
// bytes after index 63 are ignored. Because @[direct_array_access] disables
// bounds checking in this function, the length is validated explicitly:
// without the guard a short slice would be read past its end.
@[direct_array_access]
fn slice_to_64(s []u8) [64]u8 {
	if s.len < 64 {
		panic('mldsa: slice_to_64 needs at least 64 bytes, got ${s.len}')
	}
	mut a := [64]u8{}
	for i in 0 .. 64 {
		a[i] = s[i]
	}
	return a
}
33
// PrivateKey is an ML-DSA private key. It holds its public half and the
// secret vectors in NTT form, as consumed by sign_internal.
pub struct PrivateKey {
	seed [32]u8 // KeyGen seed (new_private_key); random filler for keys built by from_bytes
	pk PublicKey // the matching public key
	s1 []NttElement // len = l; NTT(s1)
	s2 []NttElement // len = k; NTT(s2)
	t0 []NttElement // len = k; NTT(t0)
	k [32]u8 // private value K, hashed into the per-signature nonce by sign_internal
}
42
// PublicKey is an ML-DSA public key. It holds the encoded bytes plus values
// derived from them once at construction, so verification need not redo them.
pub struct PublicKey {
	raw []u8 // encoded public key (pk_encode output)
	p Params // parameter set
	a []NttElement // k*l matrix in NTT domain; entry (i, j) is a[i * l + j]
	t1 []NttElement // len = k, NTT(t1 * 2^d)
	tr [64]u8 // 64-byte hash of raw (compute_pk_hash), used when computing mu
}
50
// generate creates a new private key for the chosen parameter set from a
// 32-byte seed read from crypto.rand (FIPS 204 algo. 1, ML-DSA.KeyGen, s. 5.1).
// It fails only if the random source fails.
pub fn PrivateKey.generate(kind Kind) !PrivateKey {
	fresh_seed := rand.read(32)!
	params := kind.params()
	return new_private_key(slice_to_32(fresh_seed), params)
}
55
// from_seed deterministically rebuilds a private key from its 32-byte seed,
// which is the value returned by seed() for keys made by generate or from_seed.
// Seeds of any other length are rejected.
pub fn PrivateKey.from_seed(seed []u8, kind Kind) !PrivateKey {
	if seed.len == 32 {
		return new_private_key(slice_to_32(seed), kind.params())
	}
	return error('invalid seed length')
}
62
// from_bytes parses a private key in the FIPS 204 semi-expanded encoding.
// That encoding carries no seed, so the result gets a random placeholder:
// seed() on it does not reproduce the key via from_seed, and equal() compares
// those placeholders, which makes it meaningless here. Use equal_bytes() to
// compare such keys, and prefer from_seed when the seed is available.
pub fn PrivateKey.from_bytes(raw []u8, kind Kind) !PrivateKey {
	params := kind.params()
	return new_private_key_from_bytes(raw, params)
}
68
// from_bytes decodes an encoded public key for the chosen parameter set,
// precomputing the values verification needs.
pub fn PublicKey.from_bytes(raw []u8, kind Kind) !PublicKey {
	params := kind.params()
	return new_public_key(raw, params)
}
72
// public_key returns a reference to the public key embedded in sk.
// No copy is made: the result points at sk's own pk field.
pub fn (sk &PrivateKey) public_key() &PublicKey {
	embedded := &sk.pk
	return embedded
}
76
// seed returns a fresh 32-byte copy of the key's seed. For keys created by
// from_bytes this is random filler and does not regenerate the key.
pub fn (sk &PrivateKey) seed() []u8 {
	mut out := []u8{cap: 32}
	for b in sk.seed {
		out << b
	}
	return out
}
84
// bytes returns the semi-expanded private key encoding produced by sk_encode.
// rho is taken from the first 32 bytes of the encoded public key.
pub fn (sk &PrivateKey) bytes() []u8 {
	rho := sk.pk.raw[..32]
	return sk_encode(rho, sk.k, sk.pk.tr, sk.s1, sk.s2, sk.t0, sk.pk.p)
}
88
// equal reports whether both keys use the same parameter set and have equal
// seeds; the seeds are compared in constant time. It is not meaningful for
// keys created by from_bytes, whose seeds are random filler. Use equal_bytes
// for those.
pub fn (sk &PrivateKey) equal(other &PrivateKey) bool {
	if sk.pk.p != other.pk.p {
		return false
	}
	mut mine := []u8{cap: 32}
	mut theirs := []u8{cap: 32}
	for i in 0 .. 32 {
		mine << sk.seed[i]
		theirs << other.seed[i]
	}
	return subtle.constant_time_compare(mine, theirs) == 1
}
99
// equal_bytes reports whether both keys use the same parameter set and have
// identical serialized key material, compared in constant time. It is slower
// than equal() because both keys are re-encoded, but it also works for keys
// created by from_bytes.
pub fn (sk &PrivateKey) equal_bytes(other &PrivateKey) bool {
	if sk.pk.p != other.pk.p {
		return false
	}
	return subtle.constant_time_compare(sk.bytes(), other.bytes()) == 1
}
104
// sign signs msg with sk using ML-DSA.Sign (FIPS 204 algo. 2, s. 5.2). When
// opts.prehash is set it uses HashML-DSA.Sign instead (algo. 4, s. 5.4.1).
// opts.context may be at most 255 bytes. Signing is hedged with 32 bytes from
// crypto.rand unless opts.deterministic is set, in which case the randomness
// is all zeros.
pub fn (sk &PrivateKey) sign(msg []u8, opts SignerOpts) ![]u8 {
	if opts.context.len > 255 {
		return error('context too long')
	}
	mu := if opts.prehash == .none {
		compute_mu(sk.pk.tr[..], msg, opts.context)
	} else {
		compute_mu_prehash(sk.pk.tr[..], msg, opts.context, opts.prehash)
	}
	mut rnd := [32]u8{}
	if !opts.deterministic {
		rnd = slice_to_32(rand.read(32)!)
	}
	return sign_internal(sk, mu, rnd)
}
120
// sign_mu signs an already computed 64-byte mu (see compute_mu) using the
// caller's 32 bytes of randomness rnd. Pass all zeros for deterministic
// signing. Inputs of the wrong length are rejected.
pub fn (sk &PrivateKey) sign_mu(mu []u8, rnd []u8) ![]u8 {
	if mu.len != 64 {
		return error('mu must be 64 bytes')
	}
	if rnd.len != 32 {
		return error('rnd must be 32 bytes')
	}
	fixed_mu := slice_to_64(mu)
	fixed_rnd := slice_to_32(rnd)
	return sign_internal(sk, fixed_mu, fixed_rnd)
}
132
// bytes returns a copy of the encoded public key. Mutating the result does
// not affect pk.
pub fn (pk &PublicKey) bytes() []u8 {
	mut out := []u8{len: pk.raw.len}
	copy(mut out, pk.raw)
	return out
}
136
// tr returns the 64-byte public key hash tr = H(pk) that compute_mu mixes
// into mu. As with bytes() and seed(), the result is an explicitly fresh copy,
// so callers can never alter the hash cached inside pk.
pub fn (pk &PublicKey) tr() []u8 {
	mut out := []u8{len: 64}
	for i, b in pk.tr {
		out[i] = b
	}
	return out
}
141
// equal reports whether both public keys use the same parameter set and have
// byte-identical encodings. The encodings are compared in constant time.
pub fn (pk &PublicKey) equal(other &PublicKey) bool {
	same_params := pk.p == other.p
	return same_params && subtle.constant_time_compare(pk.raw, other.raw) == 1
}
145
// verify checks sig over msg using ML-DSA.Verify (FIPS 204 algo. 3, s. 5.3).
// When opts.prehash is set it uses HashML-DSA.Verify instead (algo. 5,
// s. 5.4.1). A context longer than 255 bytes is an error.
pub fn (pk &PublicKey) verify(msg []u8, sig []u8, opts SignerOpts) !bool {
	if opts.context.len > 255 {
		return error('context too long')
	}
	mu := if opts.prehash == .none {
		compute_mu(pk.tr[..], msg, opts.context)
	} else {
		compute_mu_prehash(pk.tr[..], msg, opts.context, opts.prehash)
	}
	return verify_internal(pk, mu, sig)
}
158
// verify_mu checks sig against an already computed 64-byte mu (see
// compute_mu), skipping the message hashing done by verify. A mu of any other
// length is an error.
pub fn (pk &PublicKey) verify_mu(mu []u8, sig []u8) !bool {
	if mu.len == 64 {
		return verify_internal(pk, slice_to_64(mu), sig)
	}
	return error('mu must be exactly 64 bytes')
}
165
// algo. 6: ML-DSA.KeyGen_internal (s. 6.1)
// new_private_key deterministically expands a 32-byte seed into a private key
// for parameter set p. The public half is computed here too and embedded in
// the result: the encoded bytes, the expanded matrix A, NTT(t1) and tr.
fn new_private_key(seed [32]u8, p Params) PrivateKey {
	k, l := p.k, p.l

	// expand seed into rho, rho', K
	// (SHAKE256 over seed || k || l, read as 32 + 64 + 32 bytes)
	mut xi := sha3.new_shake256()
	xi.write(seed[..])
	xi.write([u8(k), u8(l)])
	rho := xi.read(32) // seed for matrix A (ends up in the public key encoding)
	rho_s := xi.read(64) // seed for the secret vectors s1 and s2
	k_bytes := xi.read(32) // K, later hashed into signing nonces

	a := compute_matrix_a(rho, p)

	// s1 uses nonces 0 .. l-1 and s2 continues with l .. l+k-1, so no nonce
	// is reused. Both vectors are kept in the NTT domain.
	mut s1 := []NttElement{len: l}
	for r in 0 .. l {
		s1[r] = ntt(sample_bounded_poly(rho_s, u8(r), p))
	}
	mut s2 := []NttElement{len: k}
	for r in 0 .. k {
		s2[r] = ntt(sample_bounded_poly(rho_s, u8(l + r), p))
	}

	// t_hat = A_hat * s1_hat + s2_hat
	// (A is stored row-major: entry (i, j) is a[i * l + j])
	mut t_hat := []NttElement{len: k}
	for i in 0 .. k {
		t_hat[i] = s2[i]
		for j in 0 .. l {
			t_hat[i] = poly_add_ntt(t_hat[i], ntt_mul(a[i * l + j], s1[j]))
		}
	}

	// Split every coefficient of t with power2_round. The high part goes into
	// t1 (public key encoding); the low part forms t0, which the private key
	// keeps in NTT form.
	mut t1 := [][]u16{len: k, init: []u16{len: n}}
	mut t0 := []NttElement{len: k}
	for i in 0 .. k {
		t_i := inverse_ntt(t_hat[i])
		mut w := RingElement{}
		for j in 0 .. n {
			t1[i][j], w[j] = power2_round(t_i[j])
		}
		t0[i] = ntt(w)
	}

	pk_bytes := pk_encode(rho, t1, p)
	tr := compute_pk_hash(pk_bytes) // 64-byte hash of the encoded public key
	t1_hat := compute_t1_hat(t1)

	k_arr := slice_to_32(k_bytes)

	return PrivateKey{
		seed: seed
		pk: PublicKey{
			raw: pk_bytes
			p: p
			a: a
			t1: t1_hat
			tr: tr
		}
		s1: s1
		s2: s2
		t0: t0
		k: k_arr
	}
}
230
// new_private_key_from_bytes decodes a semi-expanded private key (sk_decode)
// and rebuilds the in-memory key for parameter set p. It also checks that the
// decoded parts are consistent with each other:
//   - t = A*s1 + s2 is recomputed and split with power2_round, and the low
//     part must equal the decoded t0 coefficient by coefficient;
//   - the public key is re-encoded from rho and the recomputed t1, and its
//     hash must equal the decoded tr.
// Either mismatch, a decoding failure, or a crypto.rand failure returns an error.
fn new_private_key_from_bytes(sk []u8, p Params) !PrivateKey {
	k, l := p.k, p.l

	rho, capital_k, tr, s1_ring, s2_ring, t0_ring := sk_decode(sk, p)!

	a := compute_matrix_a(rho, p)

	// PrivateKey stores s1, s2 and t0 in the NTT domain
	mut s1 := []NttElement{len: l}
	for r in 0 .. l {
		s1[r] = ntt(s1_ring[r])
	}
	mut s2 := []NttElement{len: k}
	for r in 0 .. k {
		s2[r] = ntt(s2_ring[r])
	}
	mut t0 := []NttElement{len: k}
	for r in 0 .. k {
		t0[r] = ntt(t0_ring[r])
	}

	// recompute t1 from rho, s1, s2 to verify consistency
	// NOTE(review): this returns on the first mismatching coefficient, so its
	// timing depends on where the mismatch is. Only inconsistent (invalid)
	// inputs take that path.
	mut t1 := [][]u16{len: k, init: []u16{len: n}}
	for i in 0 .. k {
		mut t_hat := s2[i]
		for j in 0 .. l {
			t_hat = poly_add_ntt(t_hat, ntt_mul(a[i * l + j], s1[j]))
		}
		t_i := inverse_ntt(t_hat)
		for j in 0 .. n {
			r1, r0 := power2_round(t_i[j])
			t1[i][j] = r1
			if r0 != t0_ring[i][j] {
				return error('mldsa: private key inconsistent with t0')
			}
		}
	}

	pk_bytes := pk_encode(rho, t1, p)
	computed_tr := compute_pk_hash(pk_bytes)
	if computed_tr != tr {
		return error('mldsa: private key inconsistent with public key hash')
	}
	t1_hat := compute_t1_hat(t1)

	// use random bytes for seed since the semi-expanded format doesn't contain it
	// (so seed() on the result does not regenerate this key, and equal()
	// compares random values; equal_bytes() is the meaningful comparison)
	seed := slice_to_32(rand.read(32)!)

	return PrivateKey{
		seed: seed
		pk: PublicKey{
			raw: pk_bytes
			p: p
			a: a
			t1: t1_hat
			tr: tr
		}
		s1: s1
		s2: s2
		t0: t0
		k: capital_k
	}
}
293
// new_public_key decodes an encoded public key for parameter set p. It
// precomputes the expanded matrix A, NTT(t1) and the 64-byte hash tr.
// The caller's buffer is copied, so later changes to raw do not affect the key.
fn new_public_key(raw []u8, p Params) !PublicKey {
	rho, t1 := pk_decode(raw, p)!
	matrix := compute_matrix_a(rho, p)
	digest := compute_pk_hash(raw)
	t1_ntt := compute_t1_hat(t1)

	rows := p.k
	entries := rows * p.l
	return PublicKey{
		raw: raw.clone()
		p: p
		a: matrix[..entries].clone()
		t1: t1_ntt[..rows].clone()
		tr: digest
	}
}
310
// compute_mu computes mu = H(tr || M', 64) for pure ML-DSA, where
// M' = 0x00 || |ctx| || ctx || msg (FIPS 204 algo. 2, lines 10-11) and H is
// SHAKE256. tr is expected to be the 64-byte public key hash (PublicKey.tr()).
//
// |ctx| is encoded as a single byte, so context must be at most 255 bytes.
// A longer context triggers a panic instead of being silently truncated.
// Truncation would break domain separation: a 256-byte context would encode
// as length 0 and collide with an empty context whose message starts with
// those same 256 bytes.
pub fn compute_mu(tr []u8, msg []u8, context string) [64]u8 {
	if context.len > 255 {
		panic('mldsa: context too long (${context.len} bytes, max 255)')
	}
	mut h := sha3.new_shake256()
	h.write(tr)
	h.write([u8(0)]) // pure mode domain sep
	h.write([u8(context.len)])
	h.write(context.bytes())
	h.write(msg)
	return slice_to_64(h.read(64))
}
322
// algo. 7: ML-DSA.Sign_internal (s. 6.2)
// sign_internal signs the 64-byte message representative mu with sk, using
// the 32-byte randomness `random` (all zeros for deterministic signing), and
// returns the encoded signature (c_tilde, z, h). It retries rejection sampling
// and returns an error after max_sign_attempts failed attempts.
@[direct_array_access]
fn sign_internal(sk &PrivateKey, mu [64]u8, random [32]u8) ![]u8 {
	p := sk.pk.p
	k, l := p.k, p.l
	a := sk.pk.a
	s1 := sk.s1
	s2 := sk.s2
	t0 := sk.t0

	// Rejection bounds. p.gamma1 holds log2(gamma1). p.gamma2 holds the
	// divisor d in gamma2 = (q - 1) / d.
	beta := u32(p.tau * p.eta)
	gamma1 := u32(1) << p.gamma1
	gamma1_beta := gamma1 - beta
	gamma2 := (q - 1) / u32(p.gamma2)
	gamma2_beta := gamma2 - beta

	// line 7: rho'' = H(K || rnd || mu, 64)
	mut h_nonce := sha3.new_shake256()
	h_nonce.write(sk.k[..])
	h_nonce.write(random[..])
	h_nonce.write(mu[..])
	nonce := h_nonce.read(64)

	// ExpandMask counter. It advances once per polynomial of y and is never
	// reset, so every attempt draws a fresh mask.
	mut kappa := 0

	// scratch buffers, reused across rejection-sampling attempts
	mut y := []RingElement{len: l}
	mut y_hat := []NttElement{len: l}
	mut w := []RingElement{len: k}
	mut cs1 := []RingElement{len: l}
	mut cs2 := []RingElement{len: k}
	mut z := []RingElement{len: l}
	mut ct0 := []RingElement{len: k}
	mut h := [][256]u8{len: k, init: [256]u8{}} // hint vector
	mut w1_buf := []u8{len: w1_encode_len(p)} // w1Encode output for one polynomial

	// lines 10-32: rejection sampling loop (bounded by max_sign_attempts)
	for _ in 0 .. max_sign_attempts {
		// line 11: y = ExpandMask(rho'', kappa) (algo. 34)
		for r in 0 .. l {
			// 2-byte little-endian counter.
			// NOTE(review): the high byte wraps once kappa reaches 65536;
			// presumably max_sign_attempts * l stays below that. Confirm.
			counter := [u8(kappa & 0xff), u8(kappa >> 8)]
			kappa++

			mut h_y := sha3.new_shake256()
			h_y.write(nonce)
			h_y.write(counter)
			// (log2(gamma1) + 1) bits for each of the n coefficients
			v_bytes := h_y.read((p.gamma1 + 1) * n / 8)
			y[r] = bit_unpack(v_bytes, p)
		}

		// line 12: w = NTT^-1(A_hat * NTT(y))
		for i in 0 .. l {
			y_hat[i] = ntt(y[i])
		}
		for i in 0 .. k {
			mut w_hat := NttElement{}
			for j in 0 .. l {
				w_hat = poly_add_ntt(w_hat, ntt_mul(a[i * l + j], y_hat[j]))
			}
			w[i] = inverse_ntt(w_hat)
		}

		// line 13-14: w1 = HighBits(w); c_tilde = H(mu || w1Encode(w1), lambda/4)
		mut h_ch := sha3.new_shake256()
		h_ch.write(mu[..])
		for i in 0 .. k {
			w1_encode(high_bits(w[i], p), p, mut w1_buf)
			h_ch.write(w1_buf)
		}
		ch := h_ch.read(p.lambda / 4)

		// line 15-16: c = SampleInBall(c_tilde); c_hat = NTT(c)
		c := ntt(sample_in_ball(ch, p))

		// lines 17-20: cs1 = NTT^-1(c_hat * s1_hat); z = y + cs1
		// (cs2 = NTT^-1(c_hat * s2_hat) is computed here too)
		for i in 0 .. l {
			cs1[i] = inverse_ntt(ntt_mul(c, s1[i]))
		}
		for i in 0 .. k {
			cs2[i] = inverse_ntt(ntt_mul(c, s2[i]))
		}

		// line 23: reject if ||z||_inf >= gamma1 - beta
		mut reject := false
		for i in 0 .. l {
			z[i] = poly_add_ring(y[i], cs1[i])
			if coefficients_exceed_bound(z[i], gamma1_beta) {
				reject = true
				break
			}
		}
		if reject {
			continue
		}

		// line 23: reject if ||r0||_inf >= gamma2 - beta, where r0 = LowBits(w - cs2)
		reject = false
		for i in 0 .. k {
			r0 := poly_sub_ring(w[i], cs2[i])
			if low_bits_exceed_bound(r0, gamma2_beta, p) {
				reject = true
				break
			}
		}
		if reject {
			continue
		}

		// line 25, 28: ct0 = NTT^-1(c_hat * t0_hat); reject if ||ct0||_inf >= gamma2
		reject = false
		for i in 0 .. k {
			ct0[i] = inverse_ntt(ntt_mul(c, t0[i]))
			if coefficients_exceed_bound(ct0[i], gamma2) {
				reject = true
				break
			}
		}
		if reject {
			continue
		}

		// line 26, 28: h = MakeHint(-ct0, w - cs2 + ct0); reject if count(h) > omega
		mut count1s := 0
		for i in 0 .. k {
			hint_result, count := make_hint(ct0[i], w[i], cs2[i], p)
			h[i] = hint_result
			count1s += count
		}
		if count1s > p.omega {
			continue
		}

		return sig_encode(ch, z, h, p) // line 33: sigEncode(c_tilde, z, h)
	}
	return error('signing failed: rejection sampling did not converge after ${max_sign_attempts} attempts')
}
458
// algo. 8: ML-DSA.Verify_internal (s. 6.3)
// verify_internal reports whether sig is a valid signature on the 64-byte
// message representative mu under pk. A signature that fails to decode is
// reported as invalid (false), not as an error.
//
// The ||z||_inf bound (line 13) is checked right after decoding rather than
// at the end. z is public signature data, so checking it first reveals
// nothing secret, and out-of-range signatures skip the NTT/matrix work. The
// result is the same as checking it last.
@[direct_array_access]
fn verify_internal(pk &PublicKey, mu [64]u8, sig []u8) !bool {
	p := pk.p
	k, l := p.k, p.l
	t1 := pk.t1
	a := pk.a

	beta := u32(p.tau * p.eta)
	gamma1 := u32(1) << p.gamma1
	gamma1_beta := gamma1 - beta

	// (c_tilde, z, h) = sigDecode(sig); a malformed encoding is just invalid
	ch, z, h := sig_decode(sig, p) or { return false }

	// line 13, first condition: ||z||_inf < gamma1 - beta
	for i in 0 .. l {
		if coefficients_exceed_bound(z[i], gamma1_beta) {
			return false
		}
	}

	// c = SampleInBall(c_tilde), used in the NTT domain
	c := ntt(sample_in_ball(ch, p))

	// line 9: w'_approx = NTT^-1(A_hat * NTT(z) - NTT(c) * NTT(t1 * 2^d))
	mut z_hat := []NttElement{len: l}
	for i in 0 .. l {
		z_hat[i] = ntt(z[i])
	}
	mut w := []RingElement{len: k}
	for i in 0 .. k {
		mut w_hat := NttElement{}
		for j in 0 .. l {
			w_hat = poly_add_ntt(w_hat, ntt_mul(a[i * l + j], z_hat[j]))
		}
		w_hat = poly_sub_ntt(w_hat, ntt_mul(c, t1[i]))
		w[i] = inverse_ntt(w_hat)
	}

	// line 10: w'1 = UseHint(h, w'_approx)
	mut w1 := [][256]u8{len: k, init: [256]u8{}}
	for i in 0 .. k {
		w1[i] = use_hint(w[i], h[i], p)
	}

	// line 12: c_tilde' = H(mu || w1Encode(w'1), lambda/4)
	mut h_ch := sha3.new_shake256()
	h_ch.write(mu[..])
	mut w1_buf := []u8{len: w1_encode_len(p)}
	for i in 0 .. k {
		w1_encode(w1[i], p, mut w1_buf)
		h_ch.write(w1_buf)
	}
	computed_ch := h_ch.read(p.lambda / 4)

	// line 13, second condition: c_tilde == c_tilde' (constant-time compare)
	return subtle.constant_time_compare(ch, computed_ch) == 1
}
519