v / vlib / math / big / exponentiation.v
327 lines · 275 sloc · 7.46 KB · a1b131c99aa9f3c0a96827d471c35bb3cead6eed
Raw
1module big
2
3/*
4for a detailed explanation on these internal functions and the algorithms they
5are based on refer to https://github.com/vlang/v/pull/18461
6*/
7
// MontgomeryContext bundles the values precomputed by `montgomery()` that are needed
// for Montgomery reduction modulo an odd modulus n. The Montgomery radix is
// R = 2^k, where k = bit_len(n), so R > n.
struct MontgomeryContext {
	n  Integer // |modulus|; the odd modulus all reductions are performed against
	ni Integer // n' = -n^(-1) (mod R), i.e. n * n' == -1 (mod R); consumed by from_mont
	rr Integer // R^2 (mod n); mont_mul by this value converts a residue into Montgomery form
}
14
// montgomery builds the Montgomery context for the modulus `m`.
// The radix is R = 2^k with k = bit_len(|m|), which makes R the smallest power of
// two greater than |m|.
// -----
// assumes m is odd and m != 0 (only checked in debug builds)
fn (m Integer) montgomery() MontgomeryContext {
	$if debug {
		assert m != zero_int
		assert m.is_odd()
	}

	modulus := m.abs()
	k := u32(modulus.bit_len())
	radix := one_int.left_shift(k)

	// radix * radix_inv == 1 (mod modulus), so radix * radix_inv - 1 is an exact
	// multiple of modulus; the quotient n' satisfies modulus * n' == -1 (mod radix)
	radix_inv := radix.mod_inv(modulus)
	neg_inv := (radix_inv.left_shift(k) - one_int) / modulus

	// R^2 (mod modulus), used by to_mont to move values into Montgomery space
	radix_sq := one_int.left_shift(k * 2) % modulus

	return MontgomeryContext{
		n: modulus
		ni: neg_inv
		rr: radix_sq
	}
}
36
// mont_odd calculates `a^x (mod m)`, where `m` is odd, by converting `a` into
// Montgomery space and exponentiating it with the left-to-right sliding window
// method on top of Montgomery multiplication.
// -----
// assumes a, x > 1 and m is odd
@[direct_array_access]
fn (a Integer) mont_odd(x Integer, m Integer) Integer {
	$if debug {
		assert a > one_int && x > one_int
		assert m.is_odd()
	}

	window := get_window_size(u32(x.bit_len()))

	// table[i] holds a^(2i + 1) in Montgomery form. Only odd powers are needed,
	// because every window is trimmed so that it ends on a set bit.
	mut table := []Integer{len: 1 << window}

	ctx := m.montgomery()
	aa := if a.signum < 0 || a.abs_cmp(m) >= 0 {
		a % m
	} else {
		a
	}

	table[0] = aa.to_mont(ctx)

	{
		d := table[0].mont_mul(table[0], ctx) // a^2 in Montgomery form
		for i := 1; i < table.len; i++ {
			table[i] = table[i - 1].mont_mul(d, ctx)
		}
	}

	// `r` is only meaningful once `start` is false: the first window is loaded
	// straight from the table instead of being multiplied into a Montgomery-form 1.
	// Every table entry is already reduced below n, so this yields the same value
	// while saving one Montgomery multiplication.
	mut r := zero_int
	mut start := true
	mut wstart := x.bit_len() - 1

	for {
		if !x.get_bit(u32(wstart)) {
			// zero bits between windows only square the accumulator
			if !start {
				r = r.mont_mul(r, ctx)
			}
			if wstart == 0 {
				break
			}
			wstart--
			continue
		}

		// x[wstart] is set: gather up to `window` bits, ending the window on the
		// last set bit seen; `wend` is that bit's offset below `wstart`
		mut wvalue := 1
		mut wend := 0
		for i := 1; i < window; i++ {
			if wstart - i < 0 {
				break
			}
			if x.get_bit(u32(wstart - i)) {
				wvalue <<= (i - wend) // i - wend = number of bits consumed since the last set bit
				wvalue |= 1
				wend = i
			}
		}

		j := wend + 1
		if start {
			r = table[wvalue >> 1]
			start = false
		} else {
			for _ in 0 .. j {
				r = r.mont_mul(r, ctx)
			}
			r = r.mont_mul(table[wvalue >> 1], ctx)
		}

		wstart -= j
		if wstart < 0 {
			break
		}
	}

	return r.from_mont(ctx)
}
132
// mont_even calculates `a^x (mod m)` where `m` is even. The modulus is factored as
// m = m1 * m2, where m1 is odd and m2 = 2^j is the largest power of two dividing m.
// We then calculate:
// x1 = a^x (mod m1)
// x2 = a^x (mod m2)
//
// Exponentiation with the modulus `m1` can be done using the traditional montgomery method,
// whereas `m2` is done using binary exponentiation modulo `m2`, which is fast seeing as we
// simply mask the low bits.
//
// The result `y` then satisfies (where `==` denotes congruence):
// y == x1 (mod m1)
// y == x2 (mod m2)
//
// We then use the Chinese Remainder Theorem (mixed-radix conversion algorithm) to calculate `y`.
// y = x1 + m1 * t
// where
// t = (x2 - x1) * m1^(-1) (mod m2)
//
// The multiplicative inverse of m1 in Z/m2Z exists since m1 is odd, therefore we can
// safely use the unchecked internal function.
//
// See Montgomery Reduction with Even Modulus by Çetin Kaya Koç
// (https://cetinkayakoc.net/docs/j34.pdf)
// -----
// assumes a, x > 1 and m is even
@[direct_array_access]
fn (a Integer) mont_even(x Integer, m Integer) Integer {
	$if debug {
		assert a > one_int && x > one_int
		assert !m.is_odd()
	}

	m1, j := m.rsh_to_set_bit()
	m2 := one_int.left_shift(j)

	$if debug {
		assert m1 * m2 == m
		assert m1.is_odd() && !m2.is_odd()
	}

	x1 := a.mont_odd(x, m1)
	x2 := a.exp_binary(x, m2)

	// m2 = 2^m2n, so reducing modulo m2 is just masking the low m2n bits
	m2n := u32(m2.bit_len()) - 1

	m1i := m1.mod_inv(m2)

	$if debug {
		assert (m1i * m1).mask_bits(m2n) == one_int
	}

	t1 := x1.mask_bits(m2n)
	t2 := x2.mask_bits(m2n)

	// diff = (x2 - x1) (mod m2), kept non-negative
	diff := if t2.abs_cmp(t1) >= 0 {
		t2 - t1
	} else {
		// 0 < t1 - t2 < m2, so m2 - (t1 - t2) is the canonical residue of t2 - t1.
		// Unlike a bitwise-not based complement, this does not depend on the
		// operand's digit width reaching m2n bits.
		m2 - (t1 - t2)
	}
	t := (diff * m1i).mask_bits(m2n)

	return x1 + m1 * t
}
197
// exp_binary calculates `a^x (mod m)`, where m is a power of 2.
// Because m = 2^n, every reduction is a mask of the low n bits, and the exponent is
// processed left to right with the sliding window method.
// -----
// assumes a, x > 1 and m = 2^n
@[direct_array_access]
fn (a Integer) exp_binary(x Integer, m Integer) Integer {
	$if debug {
		assert a > one_int && x > one_int
		assert m.is_power_of_2()
	}

	n := u32(m.bit_len()) - 1

	window := get_window_size(u32(x.bit_len()))

	mut table := []Integer{len: 1 << window}

	// table[i] = a^(2i + 1) (mod m). Only odd powers are stored: every window is
	// trimmed so that it ends on a set bit, so its value is always odd and is looked
	// up as table[wvalue >> 1].
	table[0] = a.mask_bits(n)

	d := (table[0] * table[0]).mask_bits(n) // a^2 (mod m)
	for i := 1; i < table.len; i++ {
		table[i] = (table[i - 1] * d).mask_bits(n)
	}

	mut r := one_int

	// `start` stays true until the first window has been multiplied in; while r is
	// still 1, squaring it would be wasted work
	mut start := true
	mut wstart := x.bit_len() - 1
	mut wend := 0
	mut wvalue := 1

	for wstart >= 0 {
		if !x.get_bit(u32(wstart)) {
			// no point squaring while r = 1
			if !start {
				r = (r * r).mask_bits(n)
			}
			if wstart == 0 {
				break
			}
			wstart--
			continue
		}

		// the bit x[wstart] is now known to be 1, so no reason to check it again
		for i := 1; i < window; i++ {
			if wstart - i < 0 {
				break
			}
			if x.get_bit(u32(wstart - i)) {
				wvalue <<= (i - wend) // i - wend is the amount of bits read since the last set bit
				wvalue |= 1
				wend = i
			}
		}

		j := wend + 1
		// same as before; r has not been populated yet, so squaring wouldn't do anything
		if !start {
			for i := 0; i < j; i++ {
				r = (r * r).mask_bits(n)
			}
		}

		r = (r * table[wvalue >> 1]).mask_bits(n)

		wstart -= j
		wvalue = 1
		wend = 0
		start = false
	}

	// r is normally already reduced; the final mask is kept as a safeguard
	return r.mask_bits(n)
}
274
// get_window_size picks the sliding-window width for an exponent of `n` bits.
//
// A window of 4 is a good general default: the precomputed table stays small and
// the blocks stay short. Larger exponents benefit from wider windows, but the table
// doubles with every extra bit, so anything above 6 (already 64 table entries)
// would cost more than it saves; 6 is therefore reserved for the largest exponents.
//
// According to Shay Gueron's paper on Montgomery multiplication, a window of 5 is
// considered optimal for 512-bit exponents.
@[inline]
fn get_window_size(n u32) int {
	if n <= 32 {
		return 3
	}
	if n <= 256 {
		return 4
	}
	if n <= 768 {
		return 5
	}
	return 6
}
297
// mont_mul multiplies two values that live in Montgomery space and reduces the
// product back into Montgomery space, i.e. it returns a * b * R^(-1) (mod n).
//
// REDC (from_mont) only produces a fully reduced result when its input is below
// n * R. Operands in Montgomery space are already < n, which keeps a * b < n^2 < n * R.
// If the operands are longer than that invariant allows, they are reduced mod n
// first instead of silently returning a wrong value.
// NOTE(review): assumes non-negative operands, as produced everywhere in this file.
fn (a Integer) mont_mul(b Integer, ctx MontgomeryContext) Integer {
	if (a.digits.len + b.digits.len) > 2 * ctx.n.digits.len {
		return ((a % ctx.n) * (b % ctx.n)).from_mont(ctx)
	}

	return (a * b).from_mont(ctx)
}
309
// to_mont converts `a` into Montgomery form, a * R (mod n).
// mont_mul(u, v) yields u * v * R^(-1) (mod n), and ctx.rr = R^2 (mod n), so
// Montgomery-multiplying by rr leaves exactly one factor of R.
fn (a Integer) to_mont(ctx MontgomeryContext) Integer {
	return ctx.rr.mont_mul(a, ctx)
}
313
// from_mont performs Montgomery reduction (REDC): it returns a * R^(-1) (mod n),
// where R = 2^bit_len(n). A single conditional subtraction suffices when
// 0 <= a < n * R, which is assumed here.
//
// See Fig. 1. "A Montgomery Reduction lemma" in
// Efficient Software Implementations of Modular Exponentiation by Shay Gueron
// (https://eprint.iacr.org/2011/239.pdf)
fn (a Integer) from_mont(ctx MontgomeryContext) Integer {
	k := u32(ctx.n.bit_len())

	// q = (a mod R) * n' (mod R); since n * n' == -1 (mod R), a + q * n is divisible by R
	q := (a.mask_bits(k) * ctx.ni).mask_bits(k)
	reduced := (a + q * ctx.n).right_shift(k)

	if reduced.abs_cmp(ctx.n) < 0 {
		return reduced
	}
	return reduced - ctx.n
}
328