| 1 | module big |
| 2 | |
/*
For a detailed explanation of these internal functions and the algorithms they
are based on, refer to https://github.com/vlang/v/pull/18461
*/
| 7 | |
// internal struct to make passing montgomery values simpler; built by `montgomery()`.
// Throughout, R denotes the montgomery radix 2^bit_len(n).
struct MontgomeryContext {
	n  Integer // |modulus|; assumed odd so that R is invertible modulo n
	ni Integer // n' with n * n' == -1 (mod R), i.e. -n^(-1) mod R; used by from_mont
	rr Integer // R^2 mod n; multiplied in by to_mont to convert into montgomery space
}
| 14 | |
// montgomery builds the context used to move values into and out of montgomery space
// for the modulus `m`. The sign of `m` is ignored (the context stores |m|).
// -----
// assumes m is odd and m != 0
fn (m Integer) montgomery() MontgomeryContext {
	$if debug {
		assert m != zero_int
		assert m.is_odd()
	}

	n := m.abs()
	bits := u32(n.bit_len())

	// R = 2^bits is the montgomery radix; it is coprime to n because n is odd
	radix := one_int.left_shift(bits)
	// R^(-1) in Z/nZ, so that R * R^(-1) == 1 (mod n)
	radix_inv := radix.mod_inv(n)

	return MontgomeryContext{
		n:  n
		// ni = (R * R^(-1) - 1) / n is exact and satisfies n * ni == -1 (mod R)
		ni: (radix_inv.left_shift(bits) - one_int) / n
		// R^2 mod n = 2^(2 * bits) mod n, used to convert values into montgomery space
		rr: radix.left_shift(bits) % n
	}
}
| 36 | |
// mont_odd calculates `a^x (mod m)`, where `m` is odd, by reducing `a` to montgomery space
// and then exponentiating the value using the sliding window method and montgomery multiplication
// -----
// assumes a, x > 1 and m is odd
@[direct_array_access]
fn (a Integer) mont_odd(x Integer, m Integer) Integer {
	$if debug {
		assert a > one_int && x > one_int
		assert m.is_odd()
	}

	window := get_window_size(u32(x.bit_len()))

	// table[i] holds a^(2i + 1) in montgomery form. Every window value read below is odd
	// and smaller than 2^window, so its index (wvalue >> 1) is below 2^(window - 1);
	// a larger table would only contain entries that are never read.
	mut table := []Integer{len: 1 << (window - 1)}

	ctx := m.montgomery()
	// bring the base below the modulus so it is a valid montgomery operand
	aa := if a.signum < 0 || a.abs_cmp(m) >= 0 {
		a % m
	} else {
		a
	}

	table[0] = aa.to_mont(ctx)

	{
		d := table[0].mont_mul(table[0], ctx) // a^2 in montgomery form
		for i := 1; i < table.len; i++ {
			table[i] = table[i - 1].mont_mul(d, ctx)
		}
	}

	// r starts as 1 in montgomery form, i.e. R mod n
	mut r := one_int.to_mont(ctx)

	mut start := true
	mut wstart := x.bit_len() - 1

	for {
		if !x.get_bit(u32(wstart)) {
			// a zero bit between windows only needs a squaring (pointless while r == 1)
			if !start {
				r = r.mont_mul(r, ctx)
			}
			if wstart == 0 {
				break
			}
			wstart--
			continue
		}

		// the bit x[wstart] is set; extend the window over at most `window` bits,
		// ending it on the lowest set bit found so that wvalue is always odd
		mut wvalue := 1
		mut wend := 0
		for i := 1; i < window; i++ {
			if wstart - i < 0 {
				break
			}
			if x.get_bit(u32(wstart - i)) {
				wvalue <<= (i - wend)
				wvalue |= 1
				wend = i
			}
		}

		j := wend + 1
		// r is still 1 before the first window, so squaring it would be a no-op
		if !start {
			for _ in 0 .. j {
				r = r.mont_mul(r, ctx)
			}
		}

		r = r.mont_mul(table[wvalue >> 1], ctx)

		wstart -= j
		start = false
		if wstart < 0 {
			break
		}
	}

	return r.from_mont(ctx)
}
| 132 | |
// mont_even calculates `a^x (mod m)` where `m` is even. This is done by factoring the modulus
// `m` into an odd integer `m1` and a power of two, which we'll call m2 (m = m1 * m2).
// We then calculate:
// x1 = a^x (mod m1)
// x2 = a^x (mod m2)
//
// Exponentiation with the modulus `m1` can be done using the traditional montgomery method,
// whereas `m2` is done using sliding window exponentiation modulo `m2` (`exp_binary`), which
// is fast seeing as reducing modulo a power of two simply masks the low bits.
//
// The result `y` then satisfies (where `==` denotes congruence):
// y == x1 (mod m1)
// y == x2 (mod m2)
//
// We then use the Chinese Remainder Theorem (mixed-radix conversion algorithm) to calculate `y`.
// y = x1 + m1 * t
// where
// t = (x2 - x1) * m1^(-1) (mod m2)
//
// The multiplicative inverse of m1 in Z/m2Z exists since m1 is odd, therefore we can safely
// use the unchecked internal function.
//
// See Montgomery Reduction with Even Modulus by Çetin Kaya Koç
// (https://cetinkayakoc.net/docs/j34.pdf)
// -----
// assumes a, x > 1 and m is even
@[direct_array_access]
fn (a Integer) mont_even(x Integer, m Integer) Integer {
	$if debug {
		assert a > one_int && x > one_int
		assert !m.is_odd()
	}

	m1, j := m.rsh_to_set_bit()
	m2 := one_int.left_shift(j)

	$if debug {
		assert m1 * m2 == m
		assert m1.is_odd() && !m2.is_odd()
	}

	x2 := a.exp_binary(x, m2)
	// m is itself a power of two: there is no odd factor to combine with, and a
	// montgomery context for the modulus 1 cannot be built
	if m1 == one_int {
		return x2
	}
	x1 := a.mont_odd(x, m1)

	m2n := u32(m2.bit_len()) - 1 // m2 = 2^m2n

	m1i := m1.mod_inv(m2)

	$if debug {
		assert (m1i * m1).mask_bits(m2n) == one_int
	}

	t1 := x1.mask_bits(m2n) // x1 mod m2
	t2 := x2.mask_bits(m2n) // x2 mod m2

	// (x2 - x1) mod m2, kept in [0, m2). Both t1 and t2 lie in [0, m2), so when t2 < t1
	// the difference t1 - t2 lies in (0, m2) and m2 - (t1 - t2) is the exact residue;
	// this does not depend on how many bits `bitwise_not` would flip.
	diff := if t2.abs_cmp(t1) >= 0 {
		t2 - t1
	} else {
		m2 - (t1 - t2)
	}

	t := (diff * m1i).mask_bits(m2n)

	return x1 + m1 * t
}
| 197 | |
// exp_binary calculates `a^x (mod m)`, where m is a power of 2, using sliding window
// exponentiation; every reduction modulo m = 2^n is a mask of the low n bits
// -----
// assumes a, x > 1 and m = 2^n
@[direct_array_access]
fn (a Integer) exp_binary(x Integer, m Integer) Integer {
	$if debug {
		assert a > one_int && x > one_int
		assert m.is_power_of_2()
	}

	n := u32(m.bit_len()) - 1

	window := get_window_size(u32(x.bit_len()))

	// table[i] = a^(2i + 1) mod m. Windows always end on a set bit, so every window
	// value is odd and below 2^window; its index (wvalue >> 1) is below 2^(window - 1).
	// Since a^0 = 1 there is no point in multiplying by it, so the main loop below
	// only starts a window at a set bit.
	mut table := []Integer{len: 1 << (window - 1)}
	table[0] = a.mask_bits(n)

	d := (table[0] * table[0]).mask_bits(n) // a^2 mod m
	for i := 1; i < table.len; i++ {
		table[i] = (table[i - 1] * d).mask_bits(n)
	}

	mut r := one_int
	mut start := true
	mut wstart := x.bit_len() - 1

	for wstart >= 0 {
		if !x.get_bit(u32(wstart)) {
			// no point squaring while r = 1
			if !start {
				r = (r * r).mask_bits(n)
			}
			if wstart == 0 {
				break
			}
			wstart--
			continue
		}

		// the bit x[wstart] is now known to be 1, so no reason to check it again;
		// grow the window and end it on the lowest set bit so that wvalue stays odd
		mut wvalue := 1
		mut wend := 0
		for i := 1; i < window; i++ {
			if wstart - i < 0 {
				break
			}
			if x.get_bit(u32(wstart - i)) {
				wvalue <<= (i - wend) // bits consumed since the previous set bit in the window
				wvalue |= 1
				wend = i
			}
		}

		j := wend + 1
		// r has not been populated yet on the first window, so squaring wouldn't do anything
		if !start {
			for _ in 0 .. j {
				r = (r * r).mask_bits(n)
			}
		}

		r = (r * table[wvalue >> 1]).mask_bits(n)

		wstart -= j
		start = false
	}

	return r.mask_bits(n)
}
| 274 | |
// get_window_size returns the sliding window width to use for an exponent that is
// `n` bits long.
//
// A window of 4 bits is a good general choice: the precomputed table stays small and
// the blocks are not too large. Larger exponents benefit from larger windows, but the
// table grows exponentially with the window width, so going beyond 6 bits would hurt
// performance; 6 is therefore reserved for only the largest exponents.
//
// According to the paper on montgomery multiplication by Shay Gueron, a window size
// of 5 is considered optimal for 512-bit exponents.
@[inline]
fn get_window_size(n u32) int {
	if n <= 32 {
		return 3
	}
	if n <= 256 {
		return 4
	}
	if n <= 768 {
		return 5
	}
	return 6
}
| 297 | |
// mont_mul performs multiplication of two variables in montgomery space and reduces
// the result back to montgomery space, i.e. it returns a * b * R^(-1) mod n for the
// radix R = 2^bit_len(n) of `ctx`.
//
// `from_mont` finishes with a single conditional subtraction, which only yields a fully
// reduced value when a * b < n * R. Operands that are not already below n are therefore
// reduced modulo n first. This keeps the product congruent, so the result stays correct,
// instead of silently returning zero for oversized operands.
// NOTE(review): operands are assumed to be non-negative — confirm at call sites.
fn (a Integer) mont_mul(b Integer, ctx MontgomeryContext) Integer {
	lhs := if a.abs_cmp(ctx.n) >= 0 { a % ctx.n } else { a }
	rhs := if b.abs_cmp(ctx.n) >= 0 { b % ctx.n } else { b }

	return (lhs * rhs).from_mont(ctx)
}
| 309 | |
// to_mont converts `a` into montgomery space by montgomery-multiplying it with
// R^2 mod n (ctx.rr): the single R^(-1) introduced by the multiplication cancels one
// factor of R, leaving a * R mod n.
fn (a Integer) to_mont(ctx MontgomeryContext) Integer {
	r_squared := ctx.rr
	return a.mont_mul(r_squared, ctx)
}
| 313 | |
// from_mont performs a montgomery reduction of `a` with respect to ctx.n, computing
// a * R^(-1) mod n for R = 2^bit_len(n). The final step is a single conditional
// subtraction of n.
//
// See Fig. 1. "A Montgomery Reduction lemma" in
// Efficient Software Implementations of Modular Exponentiation by Shay Gueron
// (https://eprint.iacr.org/2011/239.pdf)
fn (a Integer) from_mont(ctx MontgomeryContext) Integer {
	k := u32(ctx.n.bit_len())

	// q = (a mod R) * ni mod R; since n * ni == -1 (mod R), a + q * n is a multiple of R
	q := (a.mask_bits(k) * ctx.ni).mask_bits(k)
	reduced := (a + q * ctx.n).right_shift(k)

	if reduced.abs_cmp(ctx.n) < 0 {
		return reduced
	}
	return reduced - ctx.n
}
| 328 | |