@theqrl/qrl-contracts 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,749 @@
1
+ // SPDX-License-Identifier: MIT
2
+ // QRL Contracts (last updated v0.1.0) (utils/math/Math.hyp)
3
+
4
+ pragma hyperion >=0.0;
5
+
6
+ import {Panic} from "../Panic.hyp";
7
+ import {SafeCast} from "./SafeCast.hyp";
8
+
9
+ /**
10
+ * @dev Standard math utilities missing in the Hyperion language.
11
+ */
12
+ library Math {
13
+ enum Rounding {
14
+ Floor, // Toward negative infinity (ordinal 0: even, so unsignedRoundsUp returns false)
15
+ Ceil, // Toward positive infinity (ordinal 1: odd, so unsignedRoundsUp returns true)
16
+ Trunc, // Toward zero (ordinal 2: even, so unsignedRoundsUp returns false)
17
+ Expand // Away from zero (ordinal 3: odd, so unsignedRoundsUp returns true)
18
+ }
19
+
20
+ /**
21
+ * @dev Return the 512-bit addition of two uint256.
22
+ *
23
+ * The result is stored in two 256-bit variables such that sum = high * 2²⁵⁶ + low.
24
+ */
25
+ function add512(uint256 a, uint256 b) internal pure returns (uint256 high, uint256 low) {
26
+ assembly ("memory-safe") {
27
+ low := add(b, a) // wrapping 256-bit sum
28
+ high := gt(a, low) // carry bit: the sum wrapped exactly when it fell below an operand
29
+ }
30
+ }
31
+
32
+ /**
33
+ * @dev Return the 512-bit multiplication of two uint256.
34
+ *
35
+ * The result is stored in two 256-bit variables such that product = high * 2²⁵⁶ + low.
36
+ */
37
+ function mul512(uint256 a, uint256 b) internal pure returns (uint256 high, uint256 low) {
38
+ // Full-width product [high low] = a * b. The product is taken modulo 2²⁵⁶ (plain `mul`) and modulo
39
+ // 2²⁵⁶ - 1 (`mulmod` against `not(0)`); the Chinese Remainder Theorem then yields the upper word,
40
+ // so that a * b = high * 2²⁵⁶ + low.
41
+ assembly ("memory-safe") {
42
+ low := mul(b, a)
43
+ let residue := mulmod(b, a, not(0))
44
+ high := sub(sub(residue, low), lt(residue, low))
45
+ }
46
+ }
47
+
48
+ /**
49
+ * @dev Returns the addition of two unsigned integers, with a success flag (no overflow).
50
+ */
51
+ function tryAdd(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
52
+ unchecked {
53
+ uint256 sum = b + a; // wrapping addition
54
+ success = !(sum < a); // a wrapped sum drops below `a`
55
+ result = SafeCast.toUint(success) * sum; // zeroed out on overflow
56
+ }
57
+ }
58
+
59
+ /**
60
+ * @dev Returns the subtraction of two unsigned integers, with a success flag (no overflow).
61
+ */
62
+ function trySub(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
63
+ unchecked {
64
+ uint256 difference = a - b; // wrapping subtraction
65
+ success = !(difference > a); // a wrapped difference exceeds `a`
66
+ result = SafeCast.toUint(success) * difference; // zeroed out when b > a
67
+ }
68
+ }
69
+
70
+ /**
71
+ * @dev Returns the multiplication of two unsigned integers, with a success flag (no overflow).
72
+ */
73
+ function tryMul(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
74
+ unchecked {
75
+ uint256 product = a * b;
76
+ assembly ("memory-safe") {
77
+ // No overflow happened exactly when a == 0 or dividing the wrapped product by `a` gives back `b`.
78
+ // The `DIV` opcode yields 0 for a zero divisor, so no separate guard is needed before dividing.
79
+ success := or(iszero(a), eq(b, div(product, a)))
80
+ }
81
+ // Same as `success ? product : 0`, without branching.
82
+ result = SafeCast.toUint(success) * product;
83
+ }
84
+ }
85
+
86
+ /**
87
+ * @dev Returns the division of two unsigned integers, with a success flag (no division by zero).
88
+ */
89
+ function tryDiv(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
90
+ unchecked {
91
+ assembly ("memory-safe") {
92
+ // Raw `DIV` yields 0 for a zero divisor instead of reverting, so `result` is 0 on failure.
93
+ result := div(a, b)
94
+ }
95
+ success = b != 0;
96
+ }
97
+ }
98
+
99
+ /**
100
+ * @dev Returns the remainder of dividing two unsigned integers, with a success flag (no division by zero).
101
+ */
102
+ function tryMod(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
103
+ unchecked {
104
+ assembly ("memory-safe") {
105
+ // Raw `MOD` yields 0 for a zero modulus instead of reverting, so `result` is 0 on failure.
106
+ result := mod(a, b)
107
+ }
108
+ success = b != 0;
109
+ }
110
+ }
111
+
112
+ /**
113
+ * @dev Unsigned saturating addition, bounds to `2²⁵⁶ - 1` instead of overflowing.
114
+ */
115
+ function saturatingAdd(uint256 a, uint256 b) internal pure returns (uint256) {
116
+ (bool fits, uint256 sum) = tryAdd(a, b);
117
+ return ternary(fits, sum, type(uint256).max); // clamp to the maximum when the sum overflowed
118
+ }
119
+
120
+ /**
121
+ * @dev Unsigned saturating subtraction, bounds to zero instead of overflowing.
122
+ */
123
+ function saturatingSub(uint256 a, uint256 b) internal pure returns (uint256) {
124
+ (, uint256 clamped) = trySub(a, b); // trySub already reports 0 when b > a
125
+ return clamped;
126
+ }
127
+
128
+ /**
129
+ * @dev Unsigned saturating multiplication, bounds to `2²⁵⁶ - 1` instead of overflowing.
130
+ */
131
+ function saturatingMul(uint256 a, uint256 b) internal pure returns (uint256) {
132
+ (bool fits, uint256 product) = tryMul(a, b);
133
+ return ternary(fits, product, type(uint256).max); // clamp to the maximum when the product overflowed
134
+ }
135
+
136
+ /**
137
+ * @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
138
+ *
139
+ * IMPORTANT: This function may reduce bytecode size and consume less gas when used standalone.
140
+ * However, the compiler may optimize Hyperion ternary operations (i.e. `a ? b : c`) to only compute
141
+ * one branch when needed, making this function more expensive.
142
+ */
143
+ function ternary(bool condition, uint256 a, uint256 b) internal pure returns (uint256) {
144
+ unchecked {
145
+ // `delta` is a ^ b when `condition` holds and 0 otherwise (the flag is converted to 1 or 0).
146
+ // XOR-ing it onto b therefore gives b ^ (a ^ b) == a, or b ^ 0 == b, without any jump.
147
+ uint256 delta = SafeCast.toUint(condition) * (a ^ b);
148
+ return delta ^ b;
149
+ }
150
+ }
151
+
152
+ /**
153
+ * @dev Returns the largest of two numbers.
154
+ */
155
+ function max(uint256 a, uint256 b) internal pure returns (uint256) {
156
+ return ternary(b < a, a, b); // pick `a` only when it is strictly larger
157
+ }
158
+
159
+ /**
160
+ * @dev Returns the smallest of two numbers.
161
+ */
162
+ function min(uint256 a, uint256 b) internal pure returns (uint256) {
163
+ return ternary(b > a, a, b); // pick `a` only when it is strictly smaller
164
+ }
165
+
166
+ /**
167
+ * @dev Returns the average of two numbers. The result is rounded towards
168
+ * zero.
169
+ */
170
+ function average(uint256 a, uint256 b) internal pure returns (uint256) {
171
+ // Shared bits (a & b) count in full and differing bits (a ^ b) count half; unlike (a + b) / 2 this cannot overflow.
172
+ return ((a ^ b) >> 1) + (a & b);
173
+ }
174
+
175
+ /**
176
+ * @dev Returns the ceiling of the division of two numbers.
177
+ *
178
+ * This differs from standard division with `/` in that it rounds towards infinity instead
179
+ * of rounding towards zero.
180
+ */
181
+ function ceilDiv(uint256 a, uint256 b) internal pure returns (uint256) {
182
+ if (b == 0) {
183
+ // Mirror the panic raised by a native Hyperion division by zero.
184
+ Panic.panic(Panic.DIVISION_BY_ZERO);
185
+ }
186
+
187
+ // For a > 0 the ceiling equals (a - 1) / b + 1. The quotient is at most
188
+ // type(uint256).max - 1 (reached for a = type(uint256).max and b = 1), so adding
189
+ // one cannot overflow.
190
+ // For a == 0 the subtraction wraps inside the unchecked block, but the
191
+ // SafeCast.toUint(a != 0) factor forces the result to zero.
192
+ unchecked {
193
+ return ((a - 1) / b + 1) * SafeCast.toUint(a != 0);
194
+ }
195
+ }
196
+
197
+ /**
198
+ * @dev Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or
199
+ * denominator == 0.
200
+ *
201
+ * Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) with further edits by
202
+ * Uniswap Labs also under MIT license.
203
+ */
204
+ function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) {
205
+ unchecked { // wrapping arithmetic is intentional below: the algorithm works modulo 2²⁵⁶
206
+ (uint256 high, uint256 low) = mul512(x, y);
207
+
208
+ // Handle non-overflow cases, 256 by 256 division.
209
+ if (high == 0) {
210
+ // Hyperion will revert if denominator == 0, unlike the div opcode on its own.
211
+ // The surrounding unchecked block does not change this fact.
212
+ // See https://docs.hyperionlang.org/en/latest/control-structures.html#checked-or-unchecked-arithmetic.
213
+ return low / denominator;
214
+ }
215
+
216
+ // Make sure the result is less than 2²⁵⁶. Also prevents denominator == 0 (high > 0 at this point, so 0 <= high).
217
+ if (denominator <= high) {
218
+ Panic.panic(ternary(denominator == 0, Panic.DIVISION_BY_ZERO, Panic.UNDER_OVERFLOW));
219
+ }
220
+
221
+ ///////////////////////////////////////////////
222
+ // 512 by 256 division.
223
+ ///////////////////////////////////////////////
224
+
225
+ // Make division exact by subtracting the remainder from [high low].
226
+ uint256 remainder;
227
+ assembly ("memory-safe") {
228
+ // Compute remainder using mulmod.
229
+ remainder := mulmod(x, y, denominator)
230
+
231
+ // Subtract 256 bit number from 512 bit number (borrow from high when remainder > low).
232
+ high := sub(high, gt(remainder, low))
233
+ low := sub(low, remainder)
234
+ }
235
+
236
+ // Factor powers of two out of denominator and compute largest power of two divisor of denominator.
237
+ // Always >= 1. See https://cs.stackexchange.com/q/138556/92363.
238
+
239
+ uint256 twos = denominator & (0 - denominator); // isolates the lowest set bit of denominator
240
+ assembly ("memory-safe") {
241
+ // Divide denominator by twos.
242
+ denominator := div(denominator, twos)
243
+
244
+ // Divide [high low] by twos.
245
+ low := div(low, twos)
246
+
247
+ // Flip twos such that it is 2²⁵⁶ / twos. If twos is zero, then it becomes one.
248
+ twos := add(div(sub(0, twos), twos), 1)
249
+ }
250
+
251
+ // Shift in bits from high into low.
252
+ low |= high * twos;
253
+
254
+ // Invert denominator mod 2²⁵⁶. Now that denominator is an odd number, it has an inverse modulo 2²⁵⁶ such
255
+ // that denominator * inv ≡ 1 mod 2²⁵⁶. Compute the inverse by starting with a seed that is correct for
256
+ // four bits. That is, denominator * inv ≡ 1 mod 2⁴.
257
+ uint256 inverse = (3 * denominator) ^ 2; // `^` is bitwise XOR here, not exponentiation
258
+
259
+ // Use the Newton-Raphson iteration to improve the precision. Thanks to Hensel's lifting lemma, this also
260
+ // works in modular arithmetic, doubling the correct bits in each step.
261
+ inverse *= 2 - denominator * inverse; // inverse mod 2⁸
262
+ inverse *= 2 - denominator * inverse; // inverse mod 2¹⁶
263
+ inverse *= 2 - denominator * inverse; // inverse mod 2³²
264
+ inverse *= 2 - denominator * inverse; // inverse mod 2⁶⁴
265
+ inverse *= 2 - denominator * inverse; // inverse mod 2¹²⁸
266
+ inverse *= 2 - denominator * inverse; // inverse mod 2²⁵⁶
267
+
268
+ // Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
269
+ // This will give us the correct result modulo 2²⁵⁶. Since the preconditions guarantee that the outcome is
270
+ // less than 2²⁵⁶, this is the final result. We don't need to compute the high bits of the result and high
271
+ // is no longer required.
272
+ result = low * inverse;
273
+ return result;
274
+ }
275
+ }
276
+
277
+ /**
278
+ * @dev Calculates x * y / denominator with full precision, following the selected rounding direction.
279
+ */
280
+ function mulDiv(uint256 x, uint256 y, uint256 denominator, Rounding rounding) internal pure returns (uint256) {
281
+ return mulDiv(x, y, denominator) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) != 0);
282
+ }
283
+
284
+ /**
285
+ * @dev Calculates floor(x * y >> n) with full precision. Throws if result overflows a uint256.
286
+ */
287
+ function mulShr(uint256 x, uint256 y, uint8 n) internal pure returns (uint256 result) {
288
+ unchecked {
289
+ (uint256 upper, uint256 lower) = mul512(x, y);
290
+ if ((upper >> n) != 0) {
291
+ Panic.panic(Panic.UNDER_OVERFLOW); // bits of the upper word at position n or above would not fit in 256 bits
292
+ }
293
+ return (lower >> n) | (upper << (256 - n));
294
+ }
295
+ }
296
+
297
+ /**
298
+ * @dev Calculates x * y >> n with full precision, following the selected rounding direction.
299
+ */
300
+ function mulShr(uint256 x, uint256 y, uint8 n, Rounding rounding) internal pure returns (uint256) {
301
+ return mulShr(x, y, n) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, 1 << n) != 0);
302
+ }
303
+
304
+ /**
305
+ * @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
306
+ *
307
+ * If n is a prime, then Z/nZ is a field. In that case all elements are invertible, except 0.
308
+ * If n is not a prime, then Z/nZ is not a field, and some elements might not be invertible.
309
+ *
310
+ * If the input value is not invertible, 0 is returned.
311
+ *
312
+ * NOTE: If you know for sure that n is a (big) prime, it may be cheaper to use Fermat's little theorem and get the
313
+ * inverse using `Math.modExp(a, n - 2, n)`. See {invModPrime}.
314
+ */
315
+ function invMod(uint256 a, uint256 n) internal pure returns (uint256) {
316
+ unchecked {
317
+ if (n == 0) return 0; // modulus 0: report "no inverse" instead of dividing by zero below
318
+
319
+ // The inverse modulo is calculated using the Extended Euclidean Algorithm (iterative version).
320
+ // Used to compute integers x and y such that: ax + ny = gcd(a, n).
321
+ // When the gcd is 1, then the inverse of a modulo n exists and it's x.
322
+ // ax + ny = 1
323
+ // ax = 1 + (-y)n
324
+ // ax ≡ 1 (mod n) # x is the inverse of a modulo n
325
+
326
+ // If the remainder is 0 the gcd is n right away.
327
+ uint256 remainder = a % n;
328
+ uint256 gcd = n;
329
+
330
+ // Therefore the initial coefficients are:
331
+ // ax + ny = gcd(a, n) = n
332
+ // 0a + 1n = n
333
+ int256 x = 0;
334
+ int256 y = 1;
335
+
336
+ while (remainder != 0) {
337
+ uint256 quotient = gcd / remainder;
338
+
339
+ (gcd, remainder) = (
340
+ // The old remainder is the next gcd to try.
341
+ remainder,
342
+ // Compute the next remainder.
343
+ // Can't overflow given that (a % gcd) * (gcd // (a % gcd)) <= gcd
344
+ // where gcd is at most n (capped to type(uint256).max)
345
+ gcd - remainder * quotient
346
+ );
347
+
348
+ (x, y) = (
349
+ // Increment the coefficient of a.
350
+ y,
351
+ // Decrement the coefficient of n.
352
+ // Can overflow, but the result is cast to uint256 so that the
353
+ // next value of y is "wrapped around" to a value between 0 and n - 1.
354
+ x - y * int256(quotient)
355
+ );
356
+ }
357
+
358
+ if (gcd != 1) return 0; // No inverse exists.
359
+ return ternary(x < 0, n - uint256(-x), uint256(x)); // Wrap the result if it's negative.
360
+ }
361
+ }
362
+
363
+ /**
364
+ * @dev Variant of {invMod}. More efficient, but only works if `p` is known to be a prime greater than `2`.
365
+ *
366
+ * From https://en.wikipedia.org/wiki/Fermat%27s_little_theorem[Fermat's little theorem], we know that if p is
367
+ * prime, then `a**(p-1) ≡ 1 mod p`. As a consequence, we have `a * a**(p-2) ≡ 1 mod p`, which means that
368
+ * `a**(p-2)` is the modular multiplicative inverse of a in Fp.
369
+ *
370
+ * NOTE: this function does NOT check that `p` is a prime greater than `2`.
371
+ */
372
+ function invModPrime(uint256 a, uint256 p) internal view returns (uint256) {
373
+ unchecked {
374
+ return modExp(a, p - 2, p); // p - 2 wraps for p < 2; primality of p is not verified here
375
+ }
376
+ }
377
+
378
+ /**
379
+ * @dev Returns the modular exponentiation of the specified base, exponent and modulus (b ** e % m)
380
+ *
381
+ * Requirements:
382
+ * - modulus can't be zero
383
+ * - underlying staticcall to precompile must succeed
384
+ *
385
+ * IMPORTANT: The result is only valid if the underlying call succeeds. When using this function, make
386
+ * sure the chain you're using it on supports the precompiled contract for modular exponentiation
387
+ * at address 0x05 as specified in https://eips.ethereum.org/EIPS/eip-198[EIP-198]. Otherwise,
388
+ * the underlying function will succeed given the lack of a revert, but the result may be incorrectly
389
+ * interpreted as 0.
390
+ */
391
+ function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) {
392
+ (bool ok, uint256 power) = tryModExp(b, e, m);
393
+ if (!ok) {
394
+ Panic.panic(Panic.DIVISION_BY_ZERO); // raised for a zero modulus as well as a failed precompile call
395
+ }
396
+ return power;
397
+ }
398
+
399
+ /**
400
+ * @dev Returns the modular exponentiation of the specified base, exponent and modulus (b ** e % m).
401
+ * It includes a success flag indicating if the operation succeeded. Operation will be marked as failed if trying
402
+ * to operate modulo 0 or if the underlying precompile reverted.
403
+ *
404
+ * IMPORTANT: The result is only valid if the success flag is true. When using this function, make sure the chain
405
+ * you're using it on supports the precompiled contract for modular exponentiation at address 0x05 as specified in
406
+ * https://eips.ethereum.org/EIPS/eip-198[EIP-198]. Otherwise, the underlying function will succeed given the lack
407
+ * of a revert, but the result may be incorrectly interpreted as 0.
408
+ */
409
+ function tryModExp(uint256 b, uint256 e, uint256 m) internal view returns (bool success, uint256 result) {
410
+ if (m == 0) return (false, 0); // zero modulus: report failure without calling the precompile
411
+ assembly ("memory-safe") {
412
+ let ptr := mload(0x40) // free memory pointer; the 0xc0 bytes of call input are staged there without moving it
413
+ // | Offset | Content | Content (Hex) |
414
+ // |-----------|------------|--------------------------------------------------------------------|
415
+ // | 0x00:0x1f | size of b | 0x0000000000000000000000000000000000000000000000000000000000000020 |
416
+ // | 0x20:0x3f | size of e | 0x0000000000000000000000000000000000000000000000000000000000000020 |
417
+ // | 0x40:0x5f | size of m | 0x0000000000000000000000000000000000000000000000000000000000000020 |
418
+ // | 0x60:0x7f | value of b | 0x<.............................................................b> |
419
+ // | 0x80:0x9f | value of e | 0x<.............................................................e> |
420
+ // | 0xa0:0xbf | value of m | 0x<.............................................................m> |
421
+ mstore(ptr, 0x20)
422
+ mstore(add(ptr, 0x20), 0x20)
423
+ mstore(add(ptr, 0x40), 0x20)
424
+ mstore(add(ptr, 0x60), b)
425
+ mstore(add(ptr, 0x80), e)
426
+ mstore(add(ptr, 0xa0), m)
427
+
428
+ // Given the result < m, it's guaranteed to fit in 32 bytes,
429
+ // so we can use the memory scratch space located at offset 0.
430
+ success := staticcall(gas(), 0x05, ptr, 0xc0, 0x00, 0x20)
431
+ result := mload(0x00)
432
+ }
433
+ }
434
+
435
+ /**
436
+ * @dev Variant of {modExp} that supports inputs of arbitrary length.
437
+ */
438
+ function modExp(bytes memory b, bytes memory e, bytes memory m) internal view returns (bytes memory) {
439
+ (bool ok, bytes memory power) = tryModExp(b, e, m);
440
+ if (!ok) {
441
+ Panic.panic(Panic.DIVISION_BY_ZERO); // raised for an all-zero modulus as well as a failed precompile call
442
+ }
443
+ return power;
444
+ }
445
+
446
+ /**
447
+ * @dev Variant of {tryModExp} that supports inputs of arbitrary length.
448
+ */
449
+ function tryModExp(
450
+ bytes memory b,
451
+ bytes memory e,
452
+ bytes memory m
453
+ ) internal view returns (bool success, bytes memory result) {
454
+ if (_zeroBytes(m)) return (false, new bytes(0)); // all-zero (or empty) modulus: fail without calling the precompile
455
+
456
+ uint256 mLen = m.length;
457
+
458
+ // Encode call args (the three lengths followed by the raw operands) in result and move the free memory pointer
459
+ result = abi.encodePacked(b.length, e.length, mLen, b, e, m);
460
+
461
+ assembly ("memory-safe") {
462
+ let dataPtr := add(result, 0x20) // skip the length word of `result`
463
+ // Write result on top of args to avoid allocating extra memory.
464
+ success := staticcall(gas(), 0x05, dataPtr, mload(result), dataPtr, mLen)
465
+ // Overwrite the length.
466
+ // result.length > returndatasize() is guaranteed because returndatasize() == m.length
467
+ mstore(result, mLen)
468
+ // Set the memory pointer after the returned data.
469
+ mstore(0x40, add(dataPtr, mLen))
470
+ }
471
+ }
472
+
473
+ /**
474
+ * @dev Returns whether the provided byte array is zero.
475
+ */
476
+ function _zeroBytes(bytes memory byteArray) private pure returns (bool) {
477
+ uint256 length = byteArray.length;
478
+ for (uint256 index = 0; index != length; ++index) {
479
+ if (byteArray[index] != 0) return false; // stop at the first non-zero byte
480
+ }
481
+ // Reaching this point means no non-zero byte was found,
482
+ // which is also the outcome for an empty array.
483
+ return true;
484
+ }
484
+
485
+ /**
486
+ * @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
487
+ * towards zero.
488
+ *
489
+ * This method is based on Newton's method for computing square roots; the algorithm is restricted to only
490
+ * using integer operations.
491
+ */
492
+ function sqrt(uint256 a) internal pure returns (uint256) {
493
+ unchecked {
494
+ // Take care of easy edge cases when a == 0 or a == 1
495
+ if (a <= 1) {
496
+ return a;
497
+ }
498
+
499
+ // In this function, we use Newton's method to get a root of `f(x) := x² - a`. It involves building a
500
+ // sequence x_n that converges toward sqrt(a). For each iteration x_n, we also define the error between
501
+ // the current value as `ε_n = | x_n - sqrt(a) |`.
502
+ //
503
+ // For our first estimation, we consider `e` such that `2**e` is the smallest power of 2 bigger than the square root
504
+ // of the target. (i.e. `2**(e-1) ≤ sqrt(a) < 2**e`). We know that `e ≤ 128` because `(2¹²⁸)² = 2²⁵⁶` is
505
+ // bigger than any uint256.
506
+ //
507
+ // By noticing that
508
+ // `2**(e-1) ≤ sqrt(a) < 2**e → (2**(e-1))² ≤ a < (2**e)² → 2**(2*e-2) ≤ a < 2**(2*e)`
509
+ // we can deduce that `e - 1` is `log2(a) / 2`. We can thus compute `x_n = 2**(e-1)` using a method similar
510
+ // to the msb function.
511
+ uint256 aa = a;
512
+ uint256 xn = 1;
513
+
514
+ if (aa >= (1 << 128)) {
515
+ aa >>= 128;
516
+ xn <<= 64;
517
+ }
518
+ if (aa >= (1 << 64)) {
519
+ aa >>= 64;
520
+ xn <<= 32;
521
+ }
522
+ if (aa >= (1 << 32)) {
523
+ aa >>= 32;
524
+ xn <<= 16;
525
+ }
526
+ if (aa >= (1 << 16)) {
527
+ aa >>= 16;
528
+ xn <<= 8;
529
+ }
530
+ if (aa >= (1 << 8)) {
531
+ aa >>= 8;
532
+ xn <<= 4;
533
+ }
534
+ if (aa >= (1 << 4)) {
535
+ aa >>= 4;
536
+ xn <<= 2;
537
+ }
538
+ if (aa >= (1 << 2)) {
539
+ xn <<= 1;
540
+ }
541
+
542
+ // We now have x_n such that `x_n = 2**(e-1) ≤ sqrt(a) < 2**e = 2 * x_n`. This implies ε_n ≤ 2**(e-1).
543
+ //
544
+ // We can refine our estimation by noticing that the middle of that interval minimizes the error.
545
+ // If we move x_n to equal 2**(e-1) + 2**(e-2), then we reduce the error to ε_n ≤ 2**(e-2).
546
+ // This is going to be our x_0 (and ε_0)
547
+ xn = (3 * xn) >> 1; // ε_0 := | x_0 - sqrt(a) | ≤ 2**(e-2)
548
+
549
+ // From here, Newton's method gives us:
550
+ // x_{n+1} = (x_n + a / x_n) / 2
551
+ //
552
+ // One should note that:
553
+ // x_{n+1}² - a = ((x_n + a / x_n) / 2)² - a
554
+ // = ((x_n² + a) / (2 * x_n))² - a
555
+ // = (x_n⁴ + 2 * a * x_n² + a²) / (4 * x_n²) - a
556
+ // = (x_n⁴ + 2 * a * x_n² + a² - 4 * a * x_n²) / (4 * x_n²)
557
+ // = (x_n⁴ - 2 * a * x_n² + a²) / (4 * x_n²)
558
+ // = (x_n² - a)² / (2 * x_n)²
559
+ // = ((x_n² - a) / (2 * x_n))²
560
+ // ≥ 0
561
+ // Which proves that for all n ≥ 1, sqrt(a) ≤ x_n
562
+ //
563
+ // This gives us the proof of quadratic convergence of the sequence:
564
+ // ε_{n+1} = | x_{n+1} - sqrt(a) |
565
+ // = | (x_n + a / x_n) / 2 - sqrt(a) |
566
+ // = | (x_n² + a - 2*x_n*sqrt(a)) / (2 * x_n) |
567
+ // = | (x_n - sqrt(a))² / (2 * x_n) |
568
+ // = | ε_n² / (2 * x_n) |
569
+ // = ε_n² / | (2 * x_n) |
570
+ //
571
+ // For the first iteration, we have a special case where x_0 is known:
572
+ // ε_1 = ε_0² / | (2 * x_0) |
573
+ // ≤ (2**(e-2))² / (2 * (2**(e-1) + 2**(e-2)))
574
+ // ≤ 2**(2*e-4) / (3 * 2**(e-1))
575
+ // ≤ 2**(e-3) / 3
576
+ // ≤ 2**(e-3-log2(3))
577
+ // ≤ 2**(e-4.5)
578
+ //
579
+ // For the following iterations, we use the fact that, 2**(e-1) ≤ sqrt(a) ≤ x_n:
580
+ // ε_{n+1} = ε_n² / | (2 * x_n) |
581
+ // ≤ (2**(e-k))² / (2 * 2**(e-1))
582
+ // ≤ 2**(2*e-2*k) / 2**e
583
+ // ≤ 2**(e-2*k)
584
+ xn = (xn + a / xn) >> 1; // ε_1 := | x_1 - sqrt(a) | ≤ 2**(e-4.5) -- special case, see above
585
+ xn = (xn + a / xn) >> 1; // ε_2 := | x_2 - sqrt(a) | ≤ 2**(e-9) -- general case with k = 4.5
586
+ xn = (xn + a / xn) >> 1; // ε_3 := | x_3 - sqrt(a) | ≤ 2**(e-18) -- general case with k = 9
587
+ xn = (xn + a / xn) >> 1; // ε_4 := | x_4 - sqrt(a) | ≤ 2**(e-36) -- general case with k = 18
588
+ xn = (xn + a / xn) >> 1; // ε_5 := | x_5 - sqrt(a) | ≤ 2**(e-72) -- general case with k = 36
589
+ xn = (xn + a / xn) >> 1; // ε_6 := | x_6 - sqrt(a) | ≤ 2**(e-144) -- general case with k = 72
590
+
591
+ // Because e ≤ 128 (as discussed during the first estimation phase), we now have reached a precision
592
+ // ε_6 ≤ 2**(e-144) < 1. Given we're operating on integers, then we can ensure that xn is now either
593
+ // sqrt(a) or sqrt(a) + 1.
594
+ return xn - SafeCast.toUint(xn > a / xn);
595
+ }
596
+ }
597
+
598
+ /**
599
+ * @dev Calculates sqrt(a), following the selected rounding direction.
600
+ */
601
+ function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) {
602
+ unchecked {
603
+ uint256 root = sqrt(a); // floor of the square root
604
+ return SafeCast.toUint(unsignedRoundsUp(rounding) && root * root < a) + root;
605
+ }
606
+ }
607
+
608
+ /**
609
+ * @dev Return the log in base 2 of a positive value rounded towards zero.
610
+ * Returns 0 if given 0.
611
+ */
612
+ function log2(uint256 x) internal pure returns (uint256 r) {
613
+ // If value has upper 128 bits set, log2 result is at least 128
614
+ r = SafeCast.toUint(x > 0xffffffffffffffffffffffffffffffff) << 7;
615
+ // If upper 64 bits of 128-bit half set, add 64 to result
616
+ r |= SafeCast.toUint((x >> r) > 0xffffffffffffffff) << 6;
617
+ // If upper 32 bits of 64-bit half set, add 32 to result
618
+ r |= SafeCast.toUint((x >> r) > 0xffffffff) << 5;
619
+ // If upper 16 bits of 32-bit half set, add 16 to result
620
+ r |= SafeCast.toUint((x >> r) > 0xffff) << 4;
621
+ // If upper 8 bits of 16-bit half set, add 8 to result
622
+ r |= SafeCast.toUint((x >> r) > 0xff) << 3;
623
+ // If upper 4 bits of 8-bit half set, add 4 to result
624
+ r |= SafeCast.toUint((x >> r) > 0xf) << 2;
625
+
626
+ // Shifts value right by the current result (leaving at most 4 bits) and use it as an index into this lookup table:
627
+ //
628
+ // | x (4 bits) | index | table[index] = MSB position |
629
+ // |------------|---------|-----------------------------|
630
+ // | 0000 | 0 | table[0] = 0 |
631
+ // | 0001 | 1 | table[1] = 0 |
632
+ // | 0010 | 2 | table[2] = 1 |
633
+ // | 0011 | 3 | table[3] = 1 |
634
+ // | 0100 | 4 | table[4] = 2 |
635
+ // | 0101 | 5 | table[5] = 2 |
636
+ // | 0110 | 6 | table[6] = 2 |
637
+ // | 0111 | 7 | table[7] = 2 |
638
+ // | 1000 | 8 | table[8] = 3 |
639
+ // | 1001 | 9 | table[9] = 3 |
640
+ // | 1010 | 10 | table[10] = 3 |
641
+ // | 1011 | 11 | table[11] = 3 |
642
+ // | 1100 | 12 | table[12] = 3 |
643
+ // | 1101 | 13 | table[13] = 3 |
644
+ // | 1110 | 14 | table[14] = 3 |
645
+ // | 1111 | 15 | table[15] = 3 |
646
+ //
647
+ // The lookup table is a 32-byte value with the MSB positions for 0-15 in its first (most significant) 16 bytes, because `byte(i, w)` counts i from the most significant byte.
648
+ assembly ("memory-safe") {
649
+ r := or(r, byte(shr(r, x), 0x0000010102020202030303030303030300000000000000000000000000000000))
650
+ }
651
+ }
652
+
653
+ /**
654
+ * @dev Return the log in base 2, following the selected rounding direction, of a positive value.
655
+ * Returns 0 if given 0.
656
+ */
657
+ function log2(uint256 value, Rounding rounding) internal pure returns (uint256) {
658
+ unchecked {
659
+ uint256 floorLog = log2(value); // rounded towards zero
660
+ return floorLog + SafeCast.toUint(unsignedRoundsUp(rounding) && value > (1 << floorLog));
661
+ }
662
+ }
663
+
664
+ /**
665
+ * @dev Return the log in base 10 of a positive value rounded towards zero.
666
+ * Returns 0 if given 0.
667
+ */
668
+ function log10(uint256 value) internal pure returns (uint256) {
669
+ uint256 exponent = 0; // powers of ten divided out of `value` so far
670
+ unchecked {
671
+ if (value >= 10 ** 64) {
672
+ exponent = 64;
673
+ value = value / 10 ** 64;
674
+ }
675
+ if (value >= 10 ** 32) {
676
+ exponent = exponent + 32;
677
+ value = value / 10 ** 32;
678
+ }
679
+ if (value >= 10 ** 16) {
680
+ exponent = exponent + 16;
681
+ value = value / 10 ** 16;
682
+ }
683
+ if (value >= 10 ** 8) {
684
+ exponent = exponent + 8;
685
+ value = value / 10 ** 8;
686
+ }
687
+ if (value >= 10 ** 4) {
688
+ exponent = exponent + 4;
689
+ value = value / 10 ** 4;
690
+ }
691
+ if (value >= 10 ** 2) {
692
+ exponent = exponent + 2;
693
+ value = value / 10 ** 2;
694
+ }
695
+ if (value >= 10 ** 1) {
696
+ exponent = exponent + 1;
697
+ }
698
+ }
699
+ return exponent;
700
+ }
701
+
702
+ /**
703
+ * @dev Return the log in base 10, following the selected rounding direction, of a positive value.
704
+ * Returns 0 if given 0.
705
+ */
706
+ function log10(uint256 value, Rounding rounding) internal pure returns (uint256) {
707
+ unchecked {
708
+ uint256 floorLog = log10(value); // rounded towards zero
709
+ return floorLog + SafeCast.toUint(unsignedRoundsUp(rounding) && value > 10 ** floorLog);
710
+ }
711
+ }
712
+
713
+ /**
714
+ * @dev Return the log in base 256 of a positive value rounded towards zero.
715
+ * Returns 0 if given 0.
716
+ *
717
+ * Adding one to the result gives the number of pairs of hex symbols needed to represent `value` as a hex string.
718
+ */
719
+ function log256(uint256 x) internal pure returns (uint256 r) {
720
+ // r first accumulates a bit offset. If value has upper 128 bits set, that offset is at least 128
721
+ r = SafeCast.toUint(x > 0xffffffffffffffffffffffffffffffff) << 7;
722
+ // If upper 64 bits of 128-bit half set, add 64 to the bit offset
723
+ r |= SafeCast.toUint((x >> r) > 0xffffffffffffffff) << 6;
724
+ // If upper 32 bits of 64-bit half set, add 32 to the bit offset
725
+ r |= SafeCast.toUint((x >> r) > 0xffffffff) << 5;
726
+ // If upper 16 bits of 32-bit half set, add 16 to the bit offset
727
+ r |= SafeCast.toUint((x >> r) > 0xffff) << 4;
728
+ // Divide the bit offset by 8 to count bytes (r is a multiple of 16, so r >> 3 is even), and add 1 if upper 8 bits of 16-bit half set
729
+ return (r >> 3) | SafeCast.toUint((x >> r) > 0xff);
730
+ }
731
+
732
+ /**
733
+ * @dev Return the log in base 256, following the selected rounding direction, of a positive value.
734
+ * Returns 0 if given 0.
735
+ */
736
+ function log256(uint256 value, Rounding rounding) internal pure returns (uint256) {
737
+ unchecked {
738
+ uint256 floorLog = log256(value); // rounded towards zero
739
+ return floorLog + SafeCast.toUint(unsignedRoundsUp(rounding) && value > (1 << (floorLog << 3)));
740
+ }
741
+ }
742
+
743
+ /**
744
+ * @dev Returns whether a provided rounding mode is considered rounding up for unsigned integers.
745
+ */
746
+ function unsignedRoundsUp(Rounding rounding) internal pure returns (bool) {
747
+ return (uint8(rounding) & 1) != 0; // odd ordinals (Ceil, Expand) round up
748
+ }
749
+ }