bigdecimal 3.2.2 → 4.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/bigdecimal.gemspec +6 -1
- data/ext/bigdecimal/bigdecimal.c +972 -2552
- data/ext/bigdecimal/bigdecimal.h +45 -60
- data/ext/bigdecimal/bits.h +3 -0
- data/ext/bigdecimal/div.h +192 -0
- data/ext/bigdecimal/extconf.rb +7 -8
- data/ext/bigdecimal/missing.h +5 -95
- data/ext/bigdecimal/ntt.h +191 -0
- data/lib/bigdecimal/jacobian.rb +2 -0
- data/lib/bigdecimal/ludcmp.rb +2 -0
- data/lib/bigdecimal/math.rb +828 -132
- data/lib/bigdecimal/newton.rb +2 -0
- data/lib/bigdecimal/util.rb +16 -15
- data/lib/bigdecimal.rb +391 -0
- data/sample/linear.rb +73 -37
- data/sample/nlsolve.rb +47 -30
- data/sample/pi.rb +2 -7
- data/sig/big_decimal.rbs +1502 -0
- data/sig/big_decimal_util.rbs +158 -0
- data/sig/big_math.rbs +423 -0
- metadata +8 -3
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
// NTT (Number Theoretic Transform) implementation for BigDecimal multiplication

/* Value passed as r_base to ntt() for every modulus below. */
#define NTT_PRIMITIVE_ROOT 17

/* Each NTT modulus has the form (base << NTT_PRIME_SHIFT) | 1. */
#define NTT_PRIME_SHIFT 27
#define NTT_PRIME_BASE1 24
#define NTT_PRIME_BASE2 26
#define NTT_PRIME_BASE3 29

/* Builds the uint32_t modulus that belongs to one of the bases above. */
#define NTT_PRIME_FROM_BASE(base) (((uint32_t)(base) << NTT_PRIME_SHIFT) | 1)
#define NTT_PRIME1 NTT_PRIME_FROM_BASE(NTT_PRIME_BASE1)
#define NTT_PRIME2 NTT_PRIME_FROM_BASE(NTT_PRIME_BASE2)
#define NTT_PRIME3 NTT_PRIME_FROM_BASE(NTT_PRIME_BASE3)

/* Largest log2(transform size) accepted by ntt_multiply(). */
#define MAX_NTT32_BITS 27

/* Radix of one decimal digit word (10**9). */
#define NTT_DECDIG_BASE 1000000000
|
|
13
|
+
|
|
14
|
+
/* Calculates base**ex % mod by binary (square-and-multiply) exponentiation.
 * base: value to raise; it does not need to be reduced modulo mod beforehand
 * ex:   exponent; ex == 0 yields 1 % mod
 * mod:  modulus, must be non-zero (a zero modulus divides by zero)
 * Returns a value in the range 0...mod.
 */
static uint32_t
mod_pow(uint32_t base, uint32_t ex, uint32_t mod) {
    // Start from 1 % mod (not a bare 1) so the result is reduced even when
    // ex == 0; this only makes a difference for mod == 1.
    uint32_t res = 1 % mod;
    while (ex) {
        if (ex & 1) {
            // 64-bit intermediate: the product of two 32-bit values cannot overflow.
            res = (uint32_t)(((uint64_t)res * base) % mod);
        }
        ex >>= 1;
        // Skip the final squaring once no exponent bits remain.
        if (ex) base = (uint32_t)(((uint64_t)base * base) % mod);
    }
    return res;
}
|
|
30
|
+
|
|
31
|
+
/* Recursively performs the butterfly operations of an NTT.
 * size_bits: log2 of the transform size
 * input:     source array; it is only read, and only by the depth == 0 level
 * output:    receives the result of this level
 * tmp:       scratch array; receives the result of the level below
 * depth:     level to compute; the recursion runs levels depth-1 ... 0 first
 * r:         root of unity used by this level (the level below uses r*r % prime)
 * prime:     modulus
 * The recursive call swaps output and tmp, so consecutive levels alternate
 * between the two buffers.
 */
static void
ntt_recursive(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int depth, uint32_t r, uint32_t prime) {
    const uint32_t *src = input;
    if (depth > 0) {
        uint32_t r_squared = (uint32_t)(((uint64_t)r * r) % prime);
        ntt_recursive(size_bits, input, tmp, output, depth - 1, r_squared, prime);
        src = tmp;
    }

    const uint32_t half = (uint32_t)1 << (size_bits - 1);
    const uint32_t stride = (uint32_t)1 << (size_bits - depth - 1);
    const uint32_t groups = half / stride;

    uint32_t twiddle = 1;             // r**g % prime
    uint32_t twiddle_neg = prime - 1; // -(r**g) % prime
    for (uint32_t g = 0; g < groups; g++) {
        const uint32_t *lhs = src + g * 2 * stride;
        const uint32_t *rhs = lhs + stride;
        uint32_t *sum = output + stride * g;
        uint32_t *diff = sum + half;
        for (uint32_t k = 0; k < stride; k++) {
            // Read both operands before storing either result.
            const uint64_t x = lhs[k];
            const uint64_t y = rhs[k];
            sum[k] = (uint32_t)((x + (uint64_t)twiddle * y) % prime);
            diff[k] = (uint32_t)((x + (uint64_t)twiddle_neg * y) % prime);
        }
        twiddle = (uint32_t)(((uint64_t)twiddle * r) % prime);
        twiddle_neg = (uint32_t)(((uint64_t)twiddle_neg * r) % prime);
    }
}
|
|
57
|
+
|
|
58
|
+
/* Perform NTT on input array.
 * size_bits: log2 of the size of the input array. Should be less than or equal to shift
 * input: input array of size (1 << size_bits)
 * output: result array of size (1 << size_bits)
 * tmp: scratch array of size (1 << size_bits)
 * r_base: primitive root modulo the prime
 * base, shift: represent the prime number as (base << shift | 1)
 * dir: negative for the inverse transform (result is also scaled by 1/size),
 *      any other value for the forward transform
 */
static void
ntt(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int r_base, int base, int shift, int dir) {
    const uint32_t prime = ((uint32_t)base << shift) | 1;
    const uint32_t size = (uint32_t)1 << size_bits;
    const int inverse = dir < 0;

    // prime - 1 == base << shift, hence root_max**(1 << shift) % prime == 1
    uint32_t root_max = mod_pow(r_base, base, prime);
    // root**size % prime == 1
    uint32_t root = mod_pow(root_max, (uint32_t)1 << (shift - size_bits), prime);
    // x**(prime - 2) is the modular inverse of x for a prime modulus
    if (inverse) root = mod_pow(root, prime - 2, prime);

    ntt_recursive(size_bits, input, output, tmp, size_bits - 1, root, prime);

    if (!inverse) return;

    // The inverse transform additionally divides every element by size.
    uint32_t size_inv = mod_pow((uint32_t)size, prime - 2, prime);
    for (uint32_t *p = output, *end = output + size; p != end; p++) {
        *p = (uint32_t)(((uint64_t)*p * size_inv) % prime);
    }
}
|
|
83
|
+
|
|
84
|
+
/* Calculate c that satisfies: c % PRIME1 == mod1 && c % PRIME2 == mod2 && c % PRIME3 == mod3
|
|
85
|
+
* c = (mod1 * 35002755423056150739595925972 + mod2 * 14584479687667766215746868453 + mod3 * 37919651490985126265126719818) % (PRIME1 * PRIME2 * PRIME3)
|
|
86
|
+
* Assume c <= 999999999**2*(1<<27)
|
|
87
|
+
*/
|
|
88
|
+
static inline void
|
|
89
|
+
mod_restore_prime_24_26_29_shift_27(uint32_t mod1, uint32_t mod2, uint32_t mod3, uint32_t *digits) {
|
|
90
|
+
// Use mixed radix notation to eliminate modulo by PRIME1 * PRIME2 * PRIME3
|
|
91
|
+
// [DIG0, DIG1, DIG2] = DIG0 + DIG1 * PRIME1 + DIG2 * PRIME1 * PRIME2
|
|
92
|
+
// DIG0: 0...PRIME1, DIG1: 0...PRIME2, DIG2: 0...PRIME3
|
|
93
|
+
// 35002755423056150739595925972 = [1, 3489660916, 3113851359]
|
|
94
|
+
// 14584479687667766215746868453 = [0, 13, 1297437912]
|
|
95
|
+
// 37919651490985126265126719818 = [0, 0, 3373338954]
|
|
96
|
+
uint64_t c0 = mod1;
|
|
97
|
+
uint64_t c1 = (uint64_t)mod2 * 13 + (uint64_t)mod1 * 3489660916;
|
|
98
|
+
uint64_t c2 = (uint64_t)mod3 * 3373338954 % NTT_PRIME3 + (uint64_t)mod2 * 1297437912 % NTT_PRIME3 + (uint64_t)mod1 * 3113851359 % NTT_PRIME3;
|
|
99
|
+
c2 += c1 / NTT_PRIME2;
|
|
100
|
+
c1 %= NTT_PRIME2;
|
|
101
|
+
c2 %= NTT_PRIME3;
|
|
102
|
+
// Base conversion. c fits in 3 digits.
|
|
103
|
+
c1 += c2 % NTT_DECDIG_BASE * NTT_PRIME2;
|
|
104
|
+
c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1;
|
|
105
|
+
c1 /= NTT_DECDIG_BASE;
|
|
106
|
+
digits[0] = c0 % NTT_DECDIG_BASE;
|
|
107
|
+
c0 /= NTT_DECDIG_BASE;
|
|
108
|
+
c1 += c2 / NTT_DECDIG_BASE % NTT_DECDIG_BASE * NTT_PRIME2;
|
|
109
|
+
c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1;
|
|
110
|
+
c1 /= NTT_DECDIG_BASE;
|
|
111
|
+
digits[1] = c0 % NTT_DECDIG_BASE;
|
|
112
|
+
digits[2] = (uint32_t)(c0 / NTT_DECDIG_BASE + c1 % NTT_DECDIG_BASE * NTT_PRIME1);
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
/*
|
|
116
|
+
* NTT multiplication
|
|
117
|
+
* Uses three NTTs with mod (24 << 27 | 1), (26 << 27 | 1), and (29 << 27 | 1)
|
|
118
|
+
*/
|
|
119
|
+
static void
|
|
120
|
+
ntt_multiply(size_t a_size, size_t b_size, uint32_t *a, uint32_t *b, uint32_t *c) {
|
|
121
|
+
if (a_size < b_size) {
|
|
122
|
+
ntt_multiply(b_size, a_size, b, a, c);
|
|
123
|
+
return;
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
int ntt_size_bits = bit_length(b_size - 1) + 1;
|
|
127
|
+
if (ntt_size_bits > MAX_NTT32_BITS) {
|
|
128
|
+
rb_raise(rb_eArgError, "Multiply size too large");
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
// To calculate large_a * small_b faster, split into several batches.
|
|
132
|
+
uint32_t ntt_size = (uint32_t)1 << ntt_size_bits;
|
|
133
|
+
uint32_t batch_size = ntt_size - (uint32_t)b_size;
|
|
134
|
+
uint32_t batch_count = (uint32_t)((a_size + batch_size - 1) / batch_size);
|
|
135
|
+
|
|
136
|
+
uint32_t *mem = ruby_xcalloc(sizeof(uint32_t), ntt_size * 9);
|
|
137
|
+
uint32_t *ntt1 = mem;
|
|
138
|
+
uint32_t *ntt2 = mem + ntt_size;
|
|
139
|
+
uint32_t *ntt3 = mem + ntt_size * 2;
|
|
140
|
+
uint32_t *tmp1 = mem + ntt_size * 3;
|
|
141
|
+
uint32_t *tmp2 = mem + ntt_size * 4;
|
|
142
|
+
uint32_t *tmp3 = mem + ntt_size * 5;
|
|
143
|
+
uint32_t *conv1 = mem + ntt_size * 6;
|
|
144
|
+
uint32_t *conv2 = mem + ntt_size * 7;
|
|
145
|
+
uint32_t *conv3 = mem + ntt_size * 8;
|
|
146
|
+
|
|
147
|
+
// Calculate NTT for b in three primes. Result is reused for each batch of a.
|
|
148
|
+
memcpy(tmp1, b, b_size * sizeof(uint32_t));
|
|
149
|
+
memset(tmp1 + b_size, 0, (ntt_size - b_size) * sizeof(uint32_t));
|
|
150
|
+
ntt(ntt_size_bits, tmp1, ntt1, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1);
|
|
151
|
+
ntt(ntt_size_bits, tmp1, ntt2, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1);
|
|
152
|
+
ntt(ntt_size_bits, tmp1, ntt3, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1);
|
|
153
|
+
|
|
154
|
+
memset(c, 0, (a_size + b_size) * sizeof(uint32_t));
|
|
155
|
+
for (uint32_t idx = 0; idx < batch_count; idx++) {
|
|
156
|
+
uint32_t len = idx == batch_count - 1 ? (uint32_t)a_size - idx * batch_size : batch_size;
|
|
157
|
+
memcpy(tmp1, a + idx * batch_size, len * sizeof(uint32_t));
|
|
158
|
+
memset(tmp1 + len, 0, (ntt_size - len) * sizeof(uint32_t));
|
|
159
|
+
// Calculate convolution for this batch in three primes
|
|
160
|
+
ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1);
|
|
161
|
+
for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt1[i]) % NTT_PRIME1;
|
|
162
|
+
ntt(ntt_size_bits, tmp2, conv1, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, -1);
|
|
163
|
+
ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1);
|
|
164
|
+
for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt2[i]) % NTT_PRIME2;
|
|
165
|
+
ntt(ntt_size_bits, tmp2, conv2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, -1);
|
|
166
|
+
ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1);
|
|
167
|
+
for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt3[i]) % NTT_PRIME3;
|
|
168
|
+
ntt(ntt_size_bits, tmp2, conv3, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, -1);
|
|
169
|
+
|
|
170
|
+
// Restore the original convolution value from three convolutions calculated in three primes.
|
|
171
|
+
// Each convolution value is maximum 999999999**2*(1<<27)/2
|
|
172
|
+
for (uint32_t i = 0; i < ntt_size; i++) {
|
|
173
|
+
uint32_t dig[3];
|
|
174
|
+
mod_restore_prime_24_26_29_shift_27(conv1[i], conv2[i], conv3[i], dig);
|
|
175
|
+
// Maximum values of dig[0], dig[1], and dig[2] are 999999999, 999999999 and 67108863 respectively
|
|
176
|
+
// Maximum overlapped sum (considering overlaps between 2 batches) is less than 4134217722
|
|
177
|
+
// so this sum doesn't overflow uint32_t.
|
|
178
|
+
for (int j = 0; j < 3; j++) {
|
|
179
|
+
// Index check: if dig[j] is non-zero, assign index is within valid range.
|
|
180
|
+
if (dig[j]) c[idx * batch_size + i + 1 - j] += dig[j];
|
|
181
|
+
}
|
|
182
|
+
}
|
|
183
|
+
}
|
|
184
|
+
uint32_t carry = 0;
|
|
185
|
+
for (int32_t i = (uint32_t)(a_size + b_size - 1); i >= 0; i--) {
|
|
186
|
+
uint32_t v = c[i] + carry;
|
|
187
|
+
c[i] = v % NTT_DECDIG_BASE;
|
|
188
|
+
carry = v / NTT_DECDIG_BASE;
|
|
189
|
+
}
|
|
190
|
+
ruby_xfree(mem);
|
|
191
|
+
}
|
data/lib/bigdecimal/jacobian.rb
CHANGED