v / vlib / math / big / special_array_ops.v
368 lines · 331 sloc · 8.47 KB · 38519eca27657e6e678b26fe4cd56c12691eb200
Raw
1module big
2
3import strings
4
// shrink_tail_zeros drops every trailing zero element of `a` in place.
// Only the length field is lowered; the underlying buffer is left as it is.
// An array consisting solely of zeros ends up with length 0.
@[direct_array_access; inline]
fn shrink_tail_zeros(mut a []u64) {
	mut kept := a.len
	for kept > 0 {
		if a[kept - 1] != 0 {
			break
		}
		kept--
	}
	unsafe {
		a.len = kept
	}
}
15
// shrink_tail_zeros drops every trailing zero element of `i.digits` by lowering
// the length of the digit array. The receiver is not declared `mut`, so the
// length is written inside an `unsafe` block.
@[direct_array_access; inline]
fn (i &Integer) shrink_tail_zeros() {
	mut kept := i.digits.len
	for kept > 0 {
		if i.digits[kept - 1] != 0 {
			break
		}
		kept--
	}
	unsafe {
		i.digits.len = kept
	}
}
26
// debug_u64_str renders a `[]u64` as `[0x..., 0x...]`, one hex literal per
// element (formatted with width 16), separated by `, `.
@[direct_array_access]
fn debug_u64_str(a []u64) string {
	mut sb := strings.new_builder(30)
	sb.write_string('[')
	for idx, digit in a {
		// separator goes before every element except the first one
		if idx > 0 {
			sb.write_string(', ')
		}
		sb.write_string('0x${digit.hex():016}')
	}
	sb.write_string(']')
	return sb.str()
}
43
// debug_u32_str is for the 32bit bignum tests only. It repacks the low
// `digit_bits` bits of every `u64` in `a` into a stream of `u32` words,
// drops trailing zero words, and renders the result as `[0x..., 0x...]`
// with each word formatted with width 8.
@[direct_array_access]
fn debug_u32_str(a []u64) string {
	mut words := []u32{cap: a.len * 2}
	mut pending := u32(0)
	mut filled := 0
	for digit in a {
		for shift in 0 .. digit_bits {
			// move bit `shift` of the digit to the next free position of the pending word
			pending |= u32((digit >> shift) & 1) << filled
			filled++
			if filled == 32 {
				words << pending
				pending = 0
				filled = 0
			}
		}
	}
	// flush a partially filled last word
	if filled > 0 {
		words << pending
	}

	// drop trailing zero words
	mut kept := words.len
	for kept > 0 {
		if words[kept - 1] != 0 {
			break
		}
		kept--
	}
	unsafe {
		words.len = kept
	}

	mut sb := strings.new_builder(30)
	sb.write_string('[')
	for idx in 0 .. words.len {
		if idx > 0 {
			sb.write_string(', ')
		}
		sb.write_string('0x${words[idx].hex():08}')
	}
	sb.write_string(']')
	return sb.str()
}
86
// found_multiplication_base_case handles the inputs that end the recursion of
// the divide-and-conquer multiplications. It returns true when it has already
// produced the result in `storage`:
// - either operand is empty: `storage` is cleared
// - `operand_a` is shorter than `operand_b`: the operands are swapped and
//   passed to `multiply_digit_array`
// - `operand_b` is a single digit: `multiply_array_by_digit` is used
// Otherwise it returns false and leaves `storage` untouched.
@[direct_array_access; inline]
fn found_multiplication_base_case(operand_a []u64, operand_b []u64, mut storage []u64) bool {
	len_a := operand_a.len
	len_b := operand_b.len

	if len_a == 0 || len_b == 0 {
		storage.clear()
		return true
	}

	if len_a < len_b {
		multiply_digit_array(operand_b, operand_a, mut storage)
		return true
	}

	if len_b != 1 {
		return false
	}
	multiply_array_by_digit(operand_a, operand_b[0], mut storage)
	return true
}
106
// karatsuba_multiply_digit_array multiplies `operand_a` by `operand_b` into
// `storage` using the karatsuba scheme: both operands are cut at `cut` digits,
// three smaller products are formed (high*high, low*low and
// (high+low)*(high+low)), and the pieces are recombined with digit shifts.
// possible optimisations:
// - transform one or all the recurrences in loops
@[direct_array_access]
fn karatsuba_multiply_digit_array(operand_a []u64, operand_b []u64, mut storage []u64) {
	if found_multiplication_base_case(operand_a, operand_b, mut storage) {
		return
	}

	// thanks to the base cases, zero-length slices can be handed to the mult func
	cut := imax(operand_a.len, operand_b.len) / 2
	mut a_low := unsafe { operand_a[0..cut] }
	mut a_high := unsafe { operand_a[cut..] }
	mut b_low := []u64{}
	mut b_high := []u64{}
	if cut > operand_b.len {
		// operand_b is shorter than the cut: it is all "low", the high part stays empty
		b_low = unsafe { operand_b }
	} else {
		b_low = unsafe { operand_b[0..cut] }
		b_high = unsafe { operand_b[cut..] }
	}
	shrink_tail_zeros(mut a_low)
	shrink_tail_zeros(mut a_high)
	shrink_tail_zeros(mut b_low)
	shrink_tail_zeros(mut b_high)

	// the high product is written straight into storage to avoid an allocation and a copy later
	multiply_digit_array(a_high, b_high, mut storage)

	mut low_product := []u64{len: a_low.len + b_low.len + 1}
	multiply_digit_array(a_low, b_low, mut low_product)

	mut a_sum := []u64{len: imax(a_high.len, a_low.len) + 1}
	add_digit_array(a_high, a_low, mut a_sum)
	mut b_sum := []u64{len: imax(b_high.len, b_low.len) + 1}
	add_digit_array(b_high, b_low, mut b_sum)

	// middle = (a_high + a_low) * (b_high + b_low) - high product - low product
	mut middle := []u64{len: operand_a.len + operand_b.len + 1}
	multiply_digit_array(a_sum, b_sum, mut middle)
	subtract_in_place(mut middle, storage)
	subtract_in_place(mut middle, low_product)

	// result = high product shifted by 2 * cut digits + middle shifted by cut digits + low product
	left_shift_digits_in_place(mut storage, 2 * cut)
	left_shift_digits_in_place(mut middle, cut)
	add_in_place(mut storage, middle)
	add_in_place(mut storage, low_product)

	shrink_tail_zeros(mut storage)
}
158
// toom3_multiply_digit_array multiplies `operand_a` by `operand_b` into
// `storage` with a 3-way Toom-Cook split: each operand is cut into three
// slices of `part` digits, five pointwise products are computed, and the
// coefficients of the product are interpolated from them and recombined.
// TODO: the manualfree tag here is a workaround for compilation with -autofree. Remove it, when the -autofree bug is fixed.
@[direct_array_access; manualfree]
fn toom3_multiply_digit_array(operand_a []u64, operand_b []u64, mut storage []u64) {
	if found_multiplication_base_case(operand_a, operand_b, mut storage) {
		return
	}

	// Past the base cases, operand_a is at least as long (in digits) as operand_b.

	// part is the length (in u64 digits) of the lower order slices
	part := (operand_a.len + 2) / 3
	part2 := 2 * part

	// The slices are handled as proper big.Integers, because the intermediate
	// results can be negative. After recombination, the final result is positive.

	// Slices of a: low, middle, high
	a0 := Integer{
		digits: unsafe { operand_a[..part] }
		signum: if operand_a[..part].any(it != 0) {
			1
		} else {
			0
		}
	}
	a0.shrink_tail_zeros()
	a1 := Integer{
		digits: unsafe { operand_a[part..part2] }
		signum: if operand_a[part..part2].any(it != 0) {
			1
		} else {
			0
		}
	}
	a1.shrink_tail_zeros()
	a2 := Integer{
		digits: unsafe { operand_a[part2..] }
		signum: 1
	}

	// Slices of b start out as zero and are filled in as far as operand_b reaches
	mut b0 := zero_int.clone()
	mut b1 := zero_int.clone()
	mut b2 := zero_int.clone()

	if operand_b.len < part {
		// operand_b fits entirely into the lowest slice
		b0 = Integer{
			digits: operand_b
			signum: 1
		}
	} else {
		if operand_b[..part].any(it != 0) {
			b0 = Integer{
				digits: operand_b[..part].clone()
				signum: 1
			}
		}
		b0.shrink_tail_zeros()
		if operand_b.len < part2 {
			// operand_b ends inside the middle slice
			b1 = Integer{
				digits: operand_b[part..].clone()
				signum: 1
			}
		} else {
			if operand_b[part..part2].any(it != 0) {
				b1 = Integer{
					digits: operand_b[part..part2].clone()
					signum: 1
				}
			}
			b1.shrink_tail_zeros()
			b2 = Integer{
				digits: operand_b[part2..].clone()
				signum: 1
			}
		}
	}

	// https://en.wikipedia.org/wiki/Toom%E2%80%93Cook_multiplication#Details
	// DOI: 10.1007/978-3-540-73074-3_10

	// Pointwise products at 0, -1, 1, 2 and infinity
	p0 := a0 * b0
	mut a_eval := a2 + a0
	mut b_eval := b2 + b0
	vm1 := (a_eval - a1) * (b_eval - b1)
	a_eval += a1
	b_eval += b1
	p1 := a_eval * b_eval
	a_at_2 := (a_eval + a2).left_shift(1) - a0
	b_at_2 := (b_eval + b2).left_shift(1) - b0
	p2 := a_at_2 * b_at_2
	pinf := a2 * b2

	// Interpolation: c1, c2 and c3 end up as the coefficients that are shifted
	// by one, two and three slices in the recombination below
	mut c3, _ := (p2 - vm1).div_mod_internal(three_int)
	mut c1 := (p1 - vm1).right_shift(1)
	mut c2 := p1 - p0
	c3 = (c3 - c2).right_shift(1)
	c2 = c2 - c1 - pinf
	c3 = c3 - pinf.left_shift(1)
	c1 = c1 - c3

	// shift amount (in bits) of one slice
	shift := u32(part) * digit_bits

	// Horner style recombination, from the highest coefficient down
	mut acc := pinf.left_shift(shift) + c3
	acc = acc.left_shift(shift) + c2
	acc = acc.left_shift(shift) + c1
	acc = acc.left_shift(shift) + p0

	storage = acc.digits.clone()
}
272
// pow2 builds a positive Integer from a zeroed digit array of
// `k / digit_bits + 1` digits, on which `bit_set` is called for position `k`.
@[inline]
fn pow2(k int) Integer {
	digit_count := (k / digit_bits) + 1
	mut digits := []u64{len: digit_count}
	bit_set(mut digits, k)
	return Integer{
		signum: 1
		digits: digits
	}
}
282
// left_shift_digits_in_place inserts `amount` zero digits at the front of `a`,
// in place. `amount` must be positive.
// The array is grown by `amount` elements, the old content is moved up with
// one memmove, and the vacated front is zeroed.
fn left_shift_digits_in_place(mut a []u64, amount int) {
	// this is what builtin/array.v prepend_many (private fn) does:
	// x := []u64{ len : amount }
	// a.prepend_many(&x[0], amount)
	shift_bytes := u64(amount) * u64(a.element_size)
	// size of the content before growing
	kept_bytes := u64(a.len) * u64(a.element_size)
	unsafe {
		a.grow_len(amount)
		// take the pointer only after growing, the buffer may have been reallocated
		base := &u8(a.data)
		vmemmove(base + shift_bytes, base, kept_bytes)
		vmemset(base, 0, shift_bytes)
	}
}
298
// right_shift_digits_in_place removes the first `amount` digits of `a`, in place,
// by delegating to the builtin `array.drop`. `amount` must be positive.
fn right_shift_digits_in_place(mut a []u64, amount int) {
	a.drop(amount)
}
303
// add_in_place computes a := a + b, digit by digit with carry.
// operand b can be longer than operand a, in which case the extra digits are
// appended to a. A final carry is appended as one more digit.
// the capacity of both arrays is supposed to be sufficient
@[direct_array_access; inline]
fn add_in_place(mut a []u64, b []u64) {
	len_a := a.len
	len_b := b.len
	shared := imin(len_a, len_b)
	mut carry := u64(0)

	// digits present in both operands
	for idx in 0 .. shared {
		sum := carry + a[idx] + b[idx]
		a[idx] = u64(sum) & max_digit
		carry = u64(sum >> digit_bits)
	}

	if len_a >= len_b {
		// only a has digits left: propagate the carry through them
		for idx in shared .. len_a {
			sum := carry + a[idx]
			a[idx] = u64(sum) & max_digit
			carry = u64(sum >> digit_bits)
		}
	} else {
		// only b has digits left: append them (plus carry) to a
		for idx in shared .. len_b {
			sum := carry + b[idx]
			digit := u64(sum) & max_digit
			a << digit
			carry = u64(sum >> digit_bits)
		}
	}

	if carry > 0 {
		a << carry
	}
}
335
// subtract_in_place computes a := a - b, digit by digit with borrow.
// It is supposed that a >= b. When b has more digits than a, a is cleared
// (i.e. the result is zero).
@[direct_array_access; inline]
fn subtract_in_place(mut a []u64, b []u64) {
	len_a := a.len
	len_b := b.len
	shared := imin(len_a, len_b)

	mut borrow := false
	// digits present in both operands
	for idx in 0 .. shared {
		mut minuend := a[idx]
		subtrahend := b[idx] + if borrow { u64(1) } else { u64(0) }
		borrow = minuend < subtrahend
		if borrow {
			// borrow one unit from the next digit
			minuend = minuend | (u64(1) << digit_bits)
		}
		a[idx] = minuend - subtrahend
	}

	if len_a < len_b {
		// if len.b > len.a return zero
		a.clear()
		return
	}

	// only a has digits left: propagate the borrow through them
	for idx in shared .. len_a {
		mut minuend := a[idx]
		subtrahend := if borrow { u64(1) } else { u64(0) }
		borrow = minuend < subtrahend
		if borrow {
			minuend = minuend | (u64(1) << digit_bits)
		}
		a[idx] = minuend - subtrahend
	}
}
369