@paraswap/dex-lib 4.8.27 → 4.8.28-uni-v3-rust.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36):
  1. package/build/dex/idex.d.ts +1 -1
  2. package/build/dex/uniswap-v3/contract-math/native-bridge.d.ts +15 -0
  3. package/build/dex/uniswap-v3/contract-math/native-bridge.js +71 -0
  4. package/build/dex/uniswap-v3/contract-math/native-bridge.js.map +1 -0
  5. package/build/dex/uniswap-v3/scripts/measure-calc-time.js +222 -110
  6. package/build/dex/uniswap-v3/scripts/measure-calc-time.js.map +1 -1
  7. package/build/dex/uniswap-v3/uniswap-v3-pool.d.ts +2 -0
  8. package/build/dex/uniswap-v3/uniswap-v3-pool.js +3 -0
  9. package/build/dex/uniswap-v3/uniswap-v3-pool.js.map +1 -1
  10. package/build/dex/uniswap-v3/uniswap-v3.d.ts +8 -3
  11. package/build/dex/uniswap-v3/uniswap-v3.js +8 -5
  12. package/build/dex/uniswap-v3/uniswap-v3.js.map +1 -1
  13. package/build/pricing-helper.d.ts +1 -1
  14. package/build/pricing-helper.js +2 -2
  15. package/build/pricing-helper.js.map +1 -1
  16. package/native/Cargo.lock +279 -0
  17. package/native/Cargo.toml +21 -0
  18. package/native/build.rs +5 -0
  19. package/native/package-lock.json +32 -0
  20. package/native/package.json +20 -0
  21. package/native/src/config.rs +40 -0
  22. package/native/src/lib.rs +216 -0
  23. package/native/src/math/bit_math.rs +177 -0
  24. package/native/src/math/full_math.rs +217 -0
  25. package/native/src/math/liquidity_math.rs +72 -0
  26. package/native/src/math/mod.rs +10 -0
  27. package/native/src/math/oracle.rs +493 -0
  28. package/native/src/math/sqrt_price_math.rs +272 -0
  29. package/native/src/math/swap_math.rs +306 -0
  30. package/native/src/math/tick.rs +239 -0
  31. package/native/src/math/tick_bitmap.rs +292 -0
  32. package/native/src/math/tick_math.rs +321 -0
  33. package/native/src/math/unsafe_math.rs +67 -0
  34. package/native/src/pool_state.rs +36 -0
  35. package/native/src/query_outputs.rs +379 -0
  36. package/package.json +2 -1
@@ -0,0 +1,32 @@
1
+ {
2
+ "name": "@paraswap/v3-math-native",
3
+ "version": "0.1.0",
4
+ "lockfileVersion": 3,
5
+ "requires": true,
6
+ "packages": {
7
+ "": {
8
+ "name": "@paraswap/v3-math-native",
9
+ "version": "0.1.0",
10
+ "devDependencies": {
11
+ "@napi-rs/cli": "^2.18.0"
12
+ }
13
+ },
14
+ "node_modules/@napi-rs/cli": {
15
+ "version": "2.18.4",
16
+ "resolved": "https://registry.npmjs.org/@napi-rs/cli/-/cli-2.18.4.tgz",
17
+ "integrity": "sha512-SgJeA4df9DE2iAEpr3M2H0OKl/yjtg1BnRI5/JyowS71tUWhrfSu2LT0V3vlHET+g1hBVlrO60PmEXwUEKp8Mg==",
18
+ "dev": true,
19
+ "license": "MIT",
20
+ "bin": {
21
+ "napi": "scripts/index.js"
22
+ },
23
+ "engines": {
24
+ "node": ">= 10"
25
+ },
26
+ "funding": {
27
+ "type": "github",
28
+ "url": "https://github.com/sponsors/Brooooooklyn"
29
+ }
30
+ }
31
+ }
32
+ }
@@ -0,0 +1,20 @@
1
+ {
2
+ "name": "@paraswap/v3-math-native",
3
+ "version": "0.1.0",
4
+ "private": true,
5
+ "main": "index.js",
6
+ "types": "index.d.ts",
7
+ "napi": {
8
+ "name": "v3-math-native",
9
+ "triples": {
10
+ "defaults": true
11
+ }
12
+ },
13
+ "scripts": {
14
+ "build": "napi build --platform --release",
15
+ "build:debug": "napi build --platform"
16
+ },
17
+ "devDependencies": {
18
+ "@napi-rs/cli": "^2.18.0"
19
+ }
20
+ }
@@ -0,0 +1,40 @@
1
/// Selects which Uniswap V3 fork's arithmetic quirks the pool math follows.
///
/// The forks differ mainly in how the packed `feeProtocol` value is split per
/// swap direction (see `MathVariant::fee_protocol`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MathVariant {
    /// Canonical Uniswap V3: `feeProtocol` holds two 4-bit fields
    /// (decoded with `% 16` / `>> 4`).
    UniswapV3,
    /// PancakeSwap V3: `feeProtocol` holds two 16-bit fields
    /// (decoded with `% 65536` / `>> 16`); protocol_fee = feeAmount * fp / 10000.
    PancakeSwapV3,
    /// Solidly V3: no oracle; the fee comes from `slot0.fee` directly.
    SolidlyV3,
}
11
+
12
+ impl MathVariant {
13
+ pub fn from_str(s: &str) -> Self {
14
+ match s {
15
+ "pancakeswap_v3" => MathVariant::PancakeSwapV3,
16
+ "solidly_v3" => MathVariant::SolidlyV3,
17
+ _ => MathVariant::UniswapV3,
18
+ }
19
+ }
20
+
21
+ /// Extract the fee protocol value for the given swap direction.
22
+ pub fn fee_protocol(&self, fee_protocol_raw: ethnum::U256, zero_for_one: bool) -> ethnum::U256 {
23
+ match self {
24
+ MathVariant::UniswapV3 | MathVariant::SolidlyV3 => {
25
+ if zero_for_one {
26
+ fee_protocol_raw % ethnum::U256::from(16u32)
27
+ } else {
28
+ fee_protocol_raw >> 4
29
+ }
30
+ }
31
+ MathVariant::PancakeSwapV3 => {
32
+ if zero_for_one {
33
+ fee_protocol_raw % ethnum::U256::from(65536u32)
34
+ } else {
35
+ fee_protocol_raw >> 16
36
+ }
37
+ }
38
+ }
39
+ }
40
+ }
@@ -0,0 +1,216 @@
1
+ pub mod config;
2
+ pub mod math;
3
+ pub mod pool_state;
4
+ pub mod query_outputs;
5
+
6
+ use ethnum::{I256, U256};
7
+ use napi::bindgen_prelude::*;
8
+ use napi_derive::napi;
9
+ use std::collections::HashMap;
10
+
11
+ use config::MathVariant;
12
+ use math::oracle::OracleObservation;
13
+ use math::tick::TickInfo;
14
+ use pool_state::PoolState;
15
+
16
+ // ---- NAPI type definitions for JS interop ----
17
+
18
+ #[napi(object)]
19
+ pub struct JsTickEntry {
20
+ pub key: i32,
21
+ pub liquidity_gross: BigInt,
22
+ pub liquidity_net: BigInt,
23
+ }
24
+
25
+ #[napi(object)]
26
+ pub struct JsBitmapEntry {
27
+ pub key: i32,
28
+ pub value: BigInt,
29
+ }
30
+
31
+ #[napi(object)]
32
+ pub struct JsObservationEntry {
33
+ pub key: i32,
34
+ pub block_timestamp: BigInt,
35
+ pub tick_cumulative: BigInt,
36
+ pub seconds_per_liquidity_cumulative_x128: BigInt,
37
+ pub initialized: bool,
38
+ }
39
+
40
+ #[napi(object)]
41
+ pub struct JsPoolStateInit {
42
+ pub variant: String,
43
+ pub block_timestamp: BigInt,
44
+ pub tick_spacing: BigInt,
45
+ pub fee: BigInt,
46
+ pub sqrt_price_x96: BigInt,
47
+ pub tick: BigInt,
48
+ pub observation_index: i32,
49
+ pub observation_cardinality: i32,
50
+ pub observation_cardinality_next: i32,
51
+ pub fee_protocol: BigInt,
52
+ pub liquidity: BigInt,
53
+ pub max_liquidity_per_tick: BigInt,
54
+ pub start_tick_bitmap: BigInt,
55
+ pub lowest_known_tick: BigInt,
56
+ pub highest_known_tick: BigInt,
57
+ pub tick_bitmap: Vec<JsBitmapEntry>,
58
+ pub ticks: Vec<JsTickEntry>,
59
+ pub observations: Vec<JsObservationEntry>,
60
+ }
61
+
62
+ #[napi(object)]
63
+ pub struct JsOutputResult {
64
+ pub outputs: Vec<BigInt>,
65
+ pub tick_counts: Vec<i32>,
66
+ }
67
+
68
+ // ---- Conversion helpers ----
69
+ // napi::bindgen_prelude::BigInt stores { sign_bit: bool, words: Vec<u64> }
70
+ // words are little-endian u64 limbs.
71
+
72
+ fn bigint_to_u256(bi: &BigInt) -> U256 {
73
+ let words = &bi.words;
74
+ let low = words.first().copied().unwrap_or(0) as u128
75
+ | (words.get(1).copied().unwrap_or(0) as u128) << 64;
76
+ let high = words.get(2).copied().unwrap_or(0) as u128
77
+ | (words.get(3).copied().unwrap_or(0) as u128) << 64;
78
+ U256::from_words(high, low)
79
+ }
80
+
81
+ fn bigint_to_i256(bi: &BigInt) -> I256 {
82
+ let u = bigint_to_u256(bi);
83
+ let val = u.as_i256();
84
+ if bi.sign_bit {
85
+ -val
86
+ } else {
87
+ val
88
+ }
89
+ }
90
+
91
+ fn u256_to_bigint(val: U256) -> BigInt {
92
+ let (high, low) = val.into_words();
93
+ let mut words = vec![
94
+ low as u64,
95
+ (low >> 64) as u64,
96
+ high as u64,
97
+ (high >> 64) as u64,
98
+ ];
99
+ // Trim trailing zeros for cleaner representation
100
+ while words.len() > 1 && *words.last().unwrap() == 0 {
101
+ words.pop();
102
+ }
103
+ BigInt {
104
+ sign_bit: false,
105
+ words,
106
+ }
107
+ }
108
+
109
+ // ---- The main NAPI class ----
110
+
111
+ #[napi]
112
+ pub struct RustPoolHandle {
113
+ state: PoolState,
114
+ }
115
+
116
+ #[napi]
117
+ impl RustPoolHandle {
118
+ /// Create a new Rust-owned pool state from JS data.
119
+ #[napi(factory)]
120
+ pub fn create(init: JsPoolStateInit) -> Result<Self> {
121
+ let variant = MathVariant::from_str(&init.variant);
122
+
123
+ let mut tick_bitmap = HashMap::with_capacity(init.tick_bitmap.len());
124
+ for entry in &init.tick_bitmap {
125
+ tick_bitmap.insert(entry.key as i16, bigint_to_u256(&entry.value));
126
+ }
127
+
128
+ let mut ticks = HashMap::with_capacity(init.ticks.len());
129
+ for entry in &init.ticks {
130
+ ticks.insert(
131
+ entry.key,
132
+ TickInfo {
133
+ liquidity_gross: bigint_to_u256(&entry.liquidity_gross),
134
+ liquidity_net: bigint_to_i256(&entry.liquidity_net),
135
+ initialized: true,
136
+ },
137
+ );
138
+ }
139
+
140
+ let mut observations = HashMap::with_capacity(init.observations.len());
141
+ for entry in &init.observations {
142
+ observations.insert(
143
+ entry.key as u16,
144
+ OracleObservation {
145
+ block_timestamp: bigint_to_u256(&entry.block_timestamp),
146
+ tick_cumulative: bigint_to_i256(&entry.tick_cumulative),
147
+ seconds_per_liquidity_cumulative_x128: bigint_to_u256(
148
+ &entry.seconds_per_liquidity_cumulative_x128,
149
+ ),
150
+ initialized: entry.initialized,
151
+ },
152
+ );
153
+ }
154
+
155
+ Ok(Self {
156
+ state: PoolState {
157
+ block_timestamp: bigint_to_u256(&init.block_timestamp),
158
+ tick_spacing: bigint_to_i256(&init.tick_spacing),
159
+ fee: bigint_to_u256(&init.fee),
160
+ sqrt_price_x96: bigint_to_u256(&init.sqrt_price_x96),
161
+ tick: bigint_to_i256(&init.tick),
162
+ observation_index: init.observation_index as u16,
163
+ observation_cardinality: init.observation_cardinality as u16,
164
+ observation_cardinality_next: init.observation_cardinality_next as u16,
165
+ fee_protocol: bigint_to_u256(&init.fee_protocol),
166
+ liquidity: bigint_to_u256(&init.liquidity),
167
+ max_liquidity_per_tick: bigint_to_u256(&init.max_liquidity_per_tick),
168
+ tick_bitmap,
169
+ ticks,
170
+ observations,
171
+ start_tick_bitmap: bigint_to_i256(&init.start_tick_bitmap),
172
+ lowest_known_tick: bigint_to_i256(&init.lowest_known_tick),
173
+ highest_known_tick: bigint_to_i256(&init.highest_known_tick),
174
+ variant,
175
+ },
176
+ })
177
+ }
178
+
179
+ /// HOT PATH: Price N amounts in one call.
180
+ /// side: 0 = SELL, 1 = BUY
181
+ #[napi]
182
+ pub fn query_outputs(
183
+ &self,
184
+ amounts: Vec<BigInt>,
185
+ zero_for_one: bool,
186
+ side: u8,
187
+ ) -> Result<JsOutputResult> {
188
+ let amounts_u256: Vec<U256> = amounts.iter().map(|a| bigint_to_u256(a)).collect();
189
+
190
+ let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
191
+ query_outputs::query_outputs(&self.state, &amounts_u256, zero_for_one, side)
192
+ }));
193
+
194
+ match result {
195
+ Ok(output) => {
196
+ let outputs: Vec<BigInt> =
197
+ output.outputs.iter().map(|v| u256_to_bigint(*v)).collect();
198
+
199
+ Ok(JsOutputResult {
200
+ outputs,
201
+ tick_counts: output.tick_counts,
202
+ })
203
+ }
204
+ Err(panic_info) => {
205
+ let msg = if let Some(s) = panic_info.downcast_ref::<&str>() {
206
+ s.to_string()
207
+ } else if let Some(s) = panic_info.downcast_ref::<String>() {
208
+ s.clone()
209
+ } else {
210
+ "Unknown panic in query_outputs".to_string()
211
+ };
212
+ Err(Error::new(Status::GenericFailure, msg))
213
+ }
214
+ }
215
+ }
216
+ }
@@ -0,0 +1,177 @@
1
+ use ethnum::U256;
2
+
3
+ /// Returns the index of the most significant bit of the number,
4
+ /// where the least significant bit is at index 0 and the most significant bit is at index 255.
5
+ ///
6
+ /// Panics if x is zero.
7
+ pub fn most_significant_bit(x: U256) -> u8 {
8
+ assert!(x > U256::ZERO, "x must be > 0");
9
+ let mut x = x;
10
+ let mut r: u8 = 0;
11
+
12
+ if x >= U256::from_words(1, 0) {
13
+ x >>= 128;
14
+ r += 128;
15
+ }
16
+ if x >= U256::from(0x10000000000000000u128) {
17
+ x >>= 64;
18
+ r += 64;
19
+ }
20
+ if x >= U256::from(0x100000000u128) {
21
+ x >>= 32;
22
+ r += 32;
23
+ }
24
+ if x >= U256::from(0x10000u64) {
25
+ x >>= 16;
26
+ r += 16;
27
+ }
28
+ if x >= U256::from(0x100u64) {
29
+ x >>= 8;
30
+ r += 8;
31
+ }
32
+ if x >= U256::from(0x10u64) {
33
+ x >>= 4;
34
+ r += 4;
35
+ }
36
+ if x >= U256::from(0x4u64) {
37
+ x >>= 2;
38
+ r += 2;
39
+ }
40
+ if x >= U256::from(0x2u64) {
41
+ r += 1;
42
+ }
43
+
44
+ r
45
+ }
46
+
47
+ /// Returns the index of the least significant bit of the number,
48
+ /// where the least significant bit is at index 0 and the most significant bit is at index 255.
49
+ ///
50
+ /// Panics if x is zero.
51
+ pub fn least_significant_bit(x: U256) -> u8 {
52
+ assert!(x > U256::ZERO, "x must be > 0");
53
+ let mut x = x;
54
+ let mut r: u8 = 255;
55
+
56
+ let max_uint128: U256 = U256::new(u128::MAX);
57
+ let max_uint64: U256 = U256::from(u64::MAX);
58
+ let max_uint32: U256 = U256::from(u32::MAX as u64);
59
+ let max_uint16: U256 = U256::from(u16::MAX as u64);
60
+ let max_uint8: U256 = U256::from(u8::MAX as u64);
61
+
62
+ if (x & max_uint128) > U256::ZERO {
63
+ r -= 128;
64
+ } else {
65
+ x >>= 128;
66
+ }
67
+ if (x & max_uint64) > U256::ZERO {
68
+ r -= 64;
69
+ } else {
70
+ x >>= 64;
71
+ }
72
+ if (x & max_uint32) > U256::ZERO {
73
+ r -= 32;
74
+ } else {
75
+ x >>= 32;
76
+ }
77
+ if (x & max_uint16) > U256::ZERO {
78
+ r -= 16;
79
+ } else {
80
+ x >>= 16;
81
+ }
82
+ if (x & max_uint8) > U256::ZERO {
83
+ r -= 8;
84
+ } else {
85
+ x >>= 8;
86
+ }
87
+ if (x & U256::from(0xFu64)) > U256::ZERO {
88
+ r -= 4;
89
+ } else {
90
+ x >>= 4;
91
+ }
92
+ if (x & U256::from(0x3u64)) > U256::ZERO {
93
+ r -= 2;
94
+ } else {
95
+ x >>= 2;
96
+ }
97
+ if (x & U256::ONE) > U256::ZERO {
98
+ r -= 1;
99
+ }
100
+
101
+ r
102
+ }
103
+
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for small `U256` constants.
    fn n(v: u64) -> U256 {
        U256::from(v)
    }

    // ---- most_significant_bit ----

    #[test]
    fn test_msb_one() {
        assert_eq!(most_significant_bit(U256::ONE), 0);
    }

    #[test]
    fn test_msb_two() {
        assert_eq!(most_significant_bit(n(2)), 1);
    }

    #[test]
    fn test_msb_powers_of_two() {
        (0..=255u8).for_each(|bit| assert_eq!(most_significant_bit(U256::ONE << bit), bit));
    }

    #[test]
    fn test_msb_max() {
        assert_eq!(most_significant_bit(U256::MAX), 255);
    }

    #[test]
    #[should_panic]
    fn test_msb_zero_panics() {
        most_significant_bit(U256::ZERO);
    }

    #[test]
    fn test_msb_mixed() {
        // 10 = 0b1010, 24 = 0b11000
        assert_eq!(most_significant_bit(n(10)), 3);
        assert_eq!(most_significant_bit(n(24)), 4);
    }

    // ---- least_significant_bit ----

    #[test]
    fn test_lsb_one() {
        assert_eq!(least_significant_bit(U256::ONE), 0);
    }

    #[test]
    fn test_lsb_two() {
        assert_eq!(least_significant_bit(n(2)), 1);
    }

    #[test]
    fn test_lsb_powers_of_two() {
        (0..=255u8).for_each(|bit| assert_eq!(least_significant_bit(U256::ONE << bit), bit));
    }

    #[test]
    fn test_lsb_max() {
        assert_eq!(least_significant_bit(U256::MAX), 0);
    }

    #[test]
    #[should_panic]
    fn test_lsb_zero_panics() {
        least_significant_bit(U256::ZERO);
    }

    #[test]
    fn test_lsb_mixed() {
        // 10 = 0b1010, 24 = 0b11000
        assert_eq!(least_significant_bit(n(10)), 1);
        assert_eq!(least_significant_bit(n(24)), 3);
    }
}
@@ -0,0 +1,217 @@
1
+ use ethnum::U256;
2
+
3
+ /// Calculates floor(a * b / denominator).
4
+ ///
5
+ /// The TS version uses BigInt which has arbitrary precision, so `a * b` never
6
+ /// overflows. We replicate this by widening to 512-bit via two U256 halves.
7
+ ///
8
+ /// Panics if the result exceeds U256::MAX or denominator is zero.
9
+ pub fn mul_div(a: U256, b: U256, denominator: U256) -> U256 {
10
+ assert!(denominator > U256::ZERO, "denominator must be > 0");
11
+ let (lo, hi) = widening_mul(a, b);
12
+ let (quot, _) = div_512_by_256(lo, hi, denominator);
13
+ quot
14
+ }
15
+
16
+ /// Calculates ceil(a * b / denominator).
17
+ ///
18
+ /// Panics if the result exceeds U256::MAX or denominator is zero.
19
+ pub fn mul_div_rounding_up(a: U256, b: U256, denominator: U256) -> U256 {
20
+ assert!(denominator > U256::ZERO, "denominator must be > 0");
21
+ // result = (a * b + denominator - 1) / denominator
22
+ let (lo, hi) = widening_mul(a, b);
23
+ // add (denominator - 1) to the 512-bit product
24
+ let addend = denominator - U256::ONE;
25
+ let (lo2, carry) = lo.overflowing_add(addend);
26
+ let hi2 = if carry { hi + U256::ONE } else { hi };
27
+ let (quot, _) = div_512_by_256(lo2, hi2, denominator);
28
+ quot
29
+ }
30
+
31
+ /// Returns (lo, hi) such that a * b = hi * 2^256 + lo.
32
+ fn widening_mul(a: U256, b: U256) -> (U256, U256) {
33
+ let mask128 = (U256::ONE << 128) - U256::ONE;
34
+
35
+ let a_lo = a & mask128;
36
+ let a_hi = a >> 128;
37
+ let b_lo = b & mask128;
38
+ let b_hi = b >> 128;
39
+
40
+ let p0: U256 = a_lo * b_lo;
41
+ let p1: U256 = a_lo * b_hi;
42
+ let p2: U256 = a_hi * b_lo;
43
+ let p3: U256 = a_hi * b_hi;
44
+
45
+ let lo: U256 = p0;
46
+ let hi: U256 = p3;
47
+
48
+ // Add p1 << 128
49
+ let p1_lo = p1 << 128;
50
+ let p1_hi = p1 >> 128;
51
+ let (lo, c1) = lo.overflowing_add(p1_lo);
52
+ let hi = hi + p1_hi + if c1 { U256::ONE } else { U256::ZERO };
53
+
54
+ // Add p2 << 128
55
+ let p2_lo = p2 << 128;
56
+ let p2_hi = p2 >> 128;
57
+ let (lo, c2) = lo.overflowing_add(p2_lo);
58
+ let hi = hi + p2_hi + if c2 { U256::ONE } else { U256::ZERO };
59
+
60
+ (lo, hi)
61
+ }
62
+
63
+ /// Divides a 512-bit number (lo + hi * 2^256) by a 256-bit denominator.
64
+ /// Returns (quotient, remainder). Panics if quotient overflows U256.
65
+ fn div_512_by_256(lo: U256, hi: U256, d: U256) -> (U256, U256) {
66
+ assert!(d > U256::ZERO, "division by zero");
67
+
68
+ if hi == U256::ZERO {
69
+ return (lo / d, lo % d);
70
+ }
71
+
72
+ assert!(hi < d, "mul_div result overflows U256");
73
+
74
+ // Split lo into two 128-bit halves and do two rounds of division.
75
+ let mask128 = (U256::ONE << 128) - U256::ONE;
76
+ let lo_hi = (lo >> 128) & mask128;
77
+ let lo_lo = lo & mask128;
78
+
79
+ // First round: divide (hi * 2^128 + lo_hi) by d
80
+ let (q_hi, r1) = div_384_by_256(lo_hi, hi, d);
81
+
82
+ // Second round: divide (r1 * 2^128 + lo_lo) by d
83
+ let (q_lo, rem) = div_384_by_256(lo_lo, r1, d);
84
+
85
+ let quotient = (q_hi << 128) + q_lo;
86
+ (quotient, rem)
87
+ }
88
+
89
+ /// Divides (hi * 2^128 + lo_128) by d, where hi < d and lo_128 < 2^128.
90
+ /// Returns (quotient, remainder).
91
+ fn div_384_by_256(lo_128: U256, hi: U256, d: U256) -> (U256, U256) {
92
+ let mask128 = (U256::ONE << 128) - U256::ONE;
93
+ let hi_upper = hi >> 128;
94
+
95
+ if hi_upper == U256::ZERO {
96
+ // hi fits in 128 bits, so hi * 2^128 + lo_128 fits in 256 bits
97
+ let numerator = (hi << 128) | lo_128;
98
+ return (numerator / d, numerator % d);
99
+ }
100
+
101
+ // hi doesn't fit in 128 bits. Use bit-by-bit long division.
102
+ // The quotient fits in at most ~129 bits (since hi < d).
103
+ let _ = mask128;
104
+ let mut remainder = hi;
105
+ let mut quotient = U256::ZERO;
106
+
107
+ for i in (0..128).rev() {
108
+ let bit = (lo_128 >> i) & U256::ONE;
109
+
110
+ // Check if shifting left would overflow
111
+ let overflow = remainder >> 255 != U256::ZERO;
112
+ remainder = (remainder << 1) | bit;
113
+
114
+ if overflow || remainder >= d {
115
+ remainder = remainder.wrapping_sub(d);
116
+ quotient = quotient | (U256::ONE << i);
117
+ }
118
+ }
119
+
120
+ (quotient, remainder)
121
+ }
122
+
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for small `U256` constants.
    fn n(v: u64) -> U256 {
        U256::from(v)
    }

    // ---- mul_div ----

    #[test]
    fn test_mul_div_simple() {
        assert_eq!(mul_div(n(6), n(7), n(3)), n(14));
    }

    #[test]
    fn test_mul_div_large() {
        assert_eq!(mul_div(U256::MAX, U256::ONE, U256::ONE), U256::MAX);
    }

    #[test]
    fn test_mul_div_floor() {
        // 35 / 3 = 11.67, which truncates to 11.
        assert_eq!(mul_div(n(5), n(7), n(3)), n(11));
    }

    #[test]
    fn test_mul_div_large_product() {
        let two_pow_200 = U256::ONE << 200;
        assert_eq!(mul_div(two_pow_200, two_pow_200, two_pow_200), two_pow_200);
    }

    #[test]
    fn test_mul_div_max_times_max() {
        assert_eq!(mul_div(U256::MAX, U256::MAX, U256::MAX), U256::MAX);
    }

    #[test]
    #[should_panic]
    fn test_mul_div_overflow() {
        mul_div(U256::MAX, U256::MAX, U256::ONE);
    }

    #[test]
    fn test_mul_div_uniswap_style() {
        // Test case from Uniswap V3: mulDiv(Q128, Q128, Q128) = Q128
        let q128 = U256::ONE << 128;
        assert_eq!(mul_div(q128, q128, q128), q128);
    }

    // ---- mul_div_rounding_up ----

    #[test]
    fn test_mul_div_rounding_up_exact() {
        assert_eq!(mul_div_rounding_up(n(6), n(7), n(3)), n(14));
    }

    #[test]
    fn test_mul_div_rounding_up_rounds() {
        // 35 / 3 = 11.67, which rounds up to 12.
        assert_eq!(mul_div_rounding_up(n(5), n(7), n(3)), n(12));
    }

    #[test]
    fn test_mul_div_rounding_up_large() {
        let q128 = U256::ONE << 128;
        let result = mul_div_rounding_up(q128, q128, q128 + U256::ONE);
        assert!(result > U256::ZERO);
    }

    // ---- widening_mul ----

    #[test]
    fn test_widening_mul_simple() {
        assert_eq!(widening_mul(n(3), n(7)), (n(21), U256::ZERO));
    }

    #[test]
    fn test_widening_mul_large() {
        // (2^256 - 1) * 2 = 1 * 2^256 + (2^256 - 2)
        assert_eq!(widening_mul(U256::MAX, n(2)), (U256::MAX - U256::ONE, U256::ONE));
    }
}