max-div 0.0.3__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,167 @@
1
+ import numba
2
+ import numpy as np
3
+
4
+ # -------------------------------------------------------------------------
5
+ # Constants
6
+ # -------------------------------------------------------------------------
7
+
8
+ # --- float64 ---------------------------------------------
9
_D_LOG_2 = 0.6931471805599453  # natural logarithm of 2, i.e. np.log(2)

# Polynomial coefficients for log2(m), with the reduced mantissa m in [0.5, 1).
# Naming: _D<n><k> is the coefficient of m**k in the degree-<n> polynomial.
_D20, _D21, _D22 = (
    -2.664030016771488,
    4.018729576130520,
    -1.359007257261774,
)

_D30, _D31, _D32, _D33 = (
    -3.144940630924,
    6.058956424048,
    -4.157032648692,
    1.243566840636,
)

_D40, _D41, _D42, _D43, _D44 = (
    -3.505614661980,
    8.099233785172,
    -8.397609124753,
    5.084088932163,
    -1.280174030509,
)

_D50, _D51, _D52, _D53, _D54, _D55 = (
    -3.794153676536,
    10.139512633266,
    -14.080875352582,
    12.881420375173,
    -6.551609372263,
    1.405716091134,
)

# --- float32 ---------------------------------------------
# Single-precision copies of the constants above (_S<n><k> mirrors _D<n><k>).
_S_LOG_2 = np.float32(_D_LOG_2)

_S20, _S21, _S22 = np.float32(_D20), np.float32(_D21), np.float32(_D22)

_S30, _S31, _S32, _S33 = (
    np.float32(_D30),
    np.float32(_D31),
    np.float32(_D32),
    np.float32(_D33),
)

_S40, _S41, _S42, _S43, _S44 = (
    np.float32(_D40),
    np.float32(_D41),
    np.float32(_D42),
    np.float32(_D43),
    np.float32(_D44),
)

_S50, _S51, _S52, _S53, _S54, _S55 = (
    np.float32(_D50),
    np.float32(_D51),
    np.float32(_D52),
    np.float32(_D53),
    np.float32(_D54),
    np.float32(_D55),
)
57
+
58
+
59
+ # -------------------------------------------------------------------------
60
+ # Fast approximations for np.log
61
+ # -------------------------------------------------------------------------
62
@numba.njit(fastmath=True, inline="always")
def fast_log_f64_poly(x: np.float64, degree: int) -> np.float64:
    """
    Approximate the natural logarithm of a float64 value.

    The work is delegated to fast_log2_f64_poly and rescaled with the identity
    ln(x) = ln(2) * log2(x), so the input limitations of that function apply here too.

    Max abs error over the entire range, as reported for each polynomial degree:
        degree=2: ~0.004
        degree=3: ~0.0005
        degree=4: ~0.00007
        degree=5: ~0.00001
    """
    log2_of_x = fast_log2_f64_poly(x, degree)
    return _D_LOG_2 * log2_of_x
73
+
74
+
75
@numba.njit(fastmath=True, inline="always")
def fast_log_f32_poly(x: np.float32, degree: int) -> np.float32:
    """
    Approximate the natural logarithm of a float32 value.

    The work is delegated to fast_log2_f32_poly and rescaled with the identity
    ln(x) = ln(2) * log2(x), so the input limitations of that function apply here too.

    Max abs error over the entire range, as reported for each polynomial degree:
        degree=2: ~0.004
        degree=3: ~0.0005
        degree=4: ~0.00007
        degree=5: ~0.00001
    """
    log2_of_x = fast_log2_f32_poly(x, degree)
    return _S_LOG_2 * log2_of_x
86
+
87
+
88
+ # -------------------------------------------------------------------------
89
+ # Fast approximations for np.log2
90
+ # -------------------------------------------------------------------------
91
@numba.njit(fastmath=True, inline="always")
def fast_log2_f64_poly(x: np.float64, degree: int) -> np.float64:
    """
    Approximate log2(x) for a float64 value.

    The IEEE-754 bit pattern of x is split into its exponent and a mantissa m
    rescaled into [0.5, 1); log2(m) is then approximated by a polynomial and
    added to the exponent.

    Max abs error over the entire range, as reported for each polynomial degree:
        degree=2: ~0.006
        degree=3: ~0.0007
        degree=4: ~0.0001
        degree=5: ~0.000015
    Any degree other than 2, 3 or 4 selects the degree-5 polynomial.

    Note: the sign bit is masked away and zero / subnormal / inf / NaN inputs get
    no special treatment, so the result is only meaningful for positive, finite,
    normal x.
    """
    # Reinterpret the float as a 64-bit integer to reach its bit fields.
    bits = np.int64(np.float64(x).view(np.int64))

    # 11-bit biased exponent; the bias is 1022 (not 1023) because the mantissa
    # below is normalised to [0.5, 1) rather than [1, 2).
    biased_exponent = (bits >> 52) & 0x7FF
    exponent = biased_exponent - 1022

    # Keep the 52 fraction bits and force the exponent field to 0x3FE -> m in [0.5, 1).
    mantissa_bits = (bits & 0x000FFFFFFFFFFFFF) | 0x3FE0000000000000
    m = np.int64(mantissa_bits).view(np.float64)

    # Polynomial approximation of log2(m).
    if degree == 2:
        # Expanded (non-Horner) form; Horner equivalent: _D20 + m * (_D21 + m * _D22)
        log2_m = _D20 + (m * _D21) + (m * m * _D22)
    elif degree == 3:
        log2_m = _D30 + m * (_D31 + m * (_D32 + m * _D33))
    elif degree == 4:
        log2_m = _D40 + m * (_D41 + m * (_D42 + m * (_D43 + m * _D44)))
    else:
        log2_m = _D50 + m * (_D51 + m * (_D52 + m * (_D53 + m * (_D54 + m * _D55))))

    # log2(x) = exponent + log2(m)
    return exponent + log2_m
123
+
124
+
125
@numba.njit(fastmath=True, inline="always")
def fast_log2_f32_poly(x: np.float32, degree: int) -> np.float32:
    """
    Approximate log2(x) for a float32 value, computed entirely in float32.

    The IEEE-754 bit pattern of x is split into its exponent and a mantissa m
    rescaled into [0.5, 1); log2(m) is then approximated by a polynomial and
    added to the exponent.

    Max abs error over the entire range, as reported for each polynomial degree:
        degree=2: ~0.006
        degree=3: ~0.0007
        degree=4: ~0.0001
        degree=5: ~0.000015
    Any degree other than 2, 3 or 4 selects the degree-5 polynomial.

    Note: the sign bit is masked away and zero / subnormal / inf / NaN inputs get
    no special treatment, so the result is only meaningful for positive, finite,
    normal x.
    """

    # --- extract mantissa & exponent ---------------------
    # Reinterpret the float as a 32-bit integer to reach its bit fields.
    xi = np.int32(np.float32(x).view(np.int32))

    # 8-bit biased exponent; the bias is 126 (not 127) because the mantissa
    # below is normalised to [0.5, 1) rather than [1, 2).
    # The explicit float32 cast matters: adding an integer exponent to a float32
    # would promote the sum (and everything derived from it) to float64.
    exponent = np.float32(((xi >> 23) & 0xFF) - 126)

    # Keep the 23 fraction bits and force the exponent field to 0x7E -> m in [0.5, 1).
    xi = (xi & 0x007FFFFF) | 0x3F000000
    m = np.int32(xi).view(np.float32)

    # --- polynomial approximation ------------------------
    if degree == 2:
        # Expanded (non-Horner) form; Horner equivalent: _S20 + m * (_S21 + m * _S22)
        log2_mantissa = _S20 + (m * _S21) + (m * m * _S22)
    elif degree == 3:
        log2_mantissa = _S30 + m * (_S31 + m * (_S32 + m * _S33))
    elif degree == 4:
        log2_mantissa = _S40 + m * (_S41 + m * (_S42 + m * (_S43 + m * _S44)))
    else:
        log2_mantissa = _S50 + m * (_S51 + m * (_S52 + m * (_S53 + m * (_S54 + m * _S55))))

    # Return log2(x) = exponent + log2(mantissa), both float32.
    return exponent + log2_mantissa
157
+
158
+
159
+ # =================================================================================================
160
+ # Public API
161
+ # =================================================================================================
162
# Python only honours the lowercase name ``__all__``. Each of the four public
# functions is listed exactly once.
__all__ = [
    "fast_log2_f32_poly",
    "fast_log2_f64_poly",
    "fast_log_f32_poly",
    "fast_log_f64_poly",
]
@@ -0,0 +1,166 @@
1
+ """
2
+ Custom simple random number generation module for integration in numba oriented code. Faster than numpy.random
3
+ even when used inside numba.njit.
4
+
5
+ This is based on the following:
6
+ - https://www.pcg-random.org/posts/bounded-rands.html
7
+ - xoroshiro128+ algorithm by David Blackman and Sebastiano Vigna (http://xoroshiro.di.unimi.it/)
8
+ - splitmix64 for seed initialization by Sebastiano Vigna (http://xorshift.di.unimi.it/splitmix64.c)
9
+ """
10
+
11
+ import numba
12
+ import numpy as np
13
+ from numpy import float32, float64, int32, int64, uint32, uint64
14
+
15
+ # =================================================================================================
16
+ # Constants
17
+ # =================================================================================================
18
+
19
# Scale factor 2**-53: multiplying a 53-bit unsigned integer by it yields a float64 in [0.0, 1.0).
_TO_FLOAT64 = float64(2.0**-53)


# Scale factor 2**-24: multiplying a 24-bit unsigned integer by it yields a float32 in [0.0, 1.0).
_TO_FLOAT32 = float32(2.0**-24)
25
+
26
+
27
+ # =================================================================================================
28
+ # Core
29
+ # =================================================================================================
30
@numba.njit(fastmath=True, inline="always")
def rotl(x: uint64, k: uint64) -> uint64:
    """Rotate the 64-bit value ``x`` left by ``k`` bit positions."""
    shifted_out = x << k
    wrapped_in = x >> (uint64(64) - k)  # bits pushed past bit 63 re-enter at the bottom
    return shifted_out | wrapped_in
34
+
35
+
36
@numba.njit(fastmath=True, inline="always")
def _xoroshiro128plus_next(rng_state: np.ndarray[uint64]) -> uint64:
    """
    Advance the xoroshiro128+ generator by one step.

    ``rng_state`` holds the two 64-bit state words and is updated in-place.
    Returns the sum of the two state words as they were before the update.
    """
    word0 = rng_state[0]
    word1 = rng_state[1]
    output = word0 + word1

    # State transition with rotation/shift constants 24, 16 and 37.
    mixed = word1 ^ word0
    rng_state[0] = rotl(word0, uint64(24)) ^ mixed ^ (mixed << uint64(16))
    rng_state[1] = rotl(mixed, uint64(37))

    return output
48
+
49
+
50
@numba.njit(fastmath=True)
def _splitmix64_next(init_state: np.ndarray[uint64]) -> uint64:
    """
    Produce the next splitmix64 output; used to expand a single seed into xoroshiro128+ state.

    ``init_state`` is a 1-element uint64 array holding the splitmix64 counter; it is advanced in-place.
    """
    # Advance the counter by the splitmix64 increment and store it back.
    counter = init_state[0] + uint64(0x9E3779B97F4A7C15)
    init_state[0] = counter

    # Two xor-shift-multiply rounds followed by a final xor-shift.
    mixed = (counter ^ (counter >> uint64(30))) * uint64(0xBF58476D1CE4E5B9)
    mixed = (mixed ^ (mixed >> uint64(27))) * uint64(0x94D049BB133111EB)
    return mixed ^ (mixed >> uint64(31))
58
+
59
+
60
+ # =================================================================================================
61
+ # Interface
62
+ # =================================================================================================
63
@numba.njit(fastmath=True, inline="always")
def set_seed(seed: np.int64) -> np.ndarray[uint64]:
    """
    Build a fresh xoroshiro128+ state (2-element uint64 array) from a single seed.

    The two state words are the first two outputs of splitmix64 started from ``seed``.
    """
    seeder = np.array([seed], dtype=uint64)  # 1-element splitmix64 counter, advanced by each call below

    word0 = _splitmix64_next(seeder)
    word1 = _splitmix64_next(seeder)

    rng_state = np.empty(2, dtype=uint64)
    rng_state[0] = word0
    rng_state[1] = word1
    return rng_state
73
+
74
+
75
@numba.njit("float64(uint64[:])", fastmath=True, inline="always")
def rand_float64(rng_state: np.ndarray[uint64]) -> float64:
    """
    Draw a float64 in [0.0, 1.0) and advance ``rng_state`` in-place.

    Only the top 53 bits of the 64-bit draw are used, scaled by 2**-53.
    """
    top_53_bits = _xoroshiro128plus_next(rng_state) >> uint64(11)
    return float64(top_53_bits * _TO_FLOAT64)
80
+
81
+
82
@numba.njit("float32(uint64[:])", fastmath=True, inline="always")
def rand_float32(rng_state: np.ndarray[uint64]) -> float32:
    """
    Draw a float32 in [0.0, 1.0) and advance ``rng_state`` in-place.

    Only the top 24 bits of the 64-bit draw are used, scaled by 2**-24.
    """
    top_24_bits = _xoroshiro128plus_next(rng_state) >> uint64(40)
    return float32(top_24_bits * _TO_FLOAT32)
87
+
88
+
89
@numba.njit("int64(uint64[:], int64, int64)", fastmath=True, inline="always")
def rand_int64(rng_state: np.ndarray[uint64], low: np.int64, high: np.int64) -> np.int64:
    """
    Draw an int64 in [low, high) and advance ``rng_state`` in-place.

    The 64-bit draw is reduced with a plain modulo, so there might be a small bias
    for large (high - low) when the range is not a power of two.
    No validation is performed: the caller must ensure high > low.
    """
    draw = _xoroshiro128plus_next(rng_state)

    # Fast path: no offset to add when the interval starts at zero.
    if low == 0:
        return int64(draw % uint64(high))

    span = high - low
    return low + int64(draw % uint64(span))
102
+
103
+
104
@numba.njit("int32(uint64[:], int32, int32)", fastmath=True, inline="always")
def rand_int32(rng_state: np.ndarray[uint64], low: np.int32, high: np.int32) -> np.int32:
    """
    Draw an int32 in [low, high) and advance ``rng_state`` in-place.

    The 64-bit draw is reduced with a plain modulo, so there might be a small bias
    for large (high - low) when the range is not a power of two.
    No validation is performed: the caller must ensure high > low.
    """
    draw = _xoroshiro128plus_next(rng_state)

    # Fast path: no offset to add when the interval starts at zero.
    if low == 0:
        return int32(draw % uint64(high))

    span = high - low
    return low + int32(draw % uint64(span))
117
+
118
+
119
@numba.njit("int32[:](uint64[:], int32, int32, int32)", fastmath=True, inline="always")
def rand_int32_array(
    rng_state: np.ndarray[uint64], low: np.int32, high: np.int32, size: np.int32
) -> np.ndarray[np.int32]:
    """
    Fill a new int32 array of length ``size`` with random values in [low, high).

    Each 64-bit draw supplies two outputs: its lower 32 bits first, then its upper
    32 bits, so ``rng_state`` is advanced in-place once per pair of elements.
    Values are reduced with a plain modulo, so there might be a small bias for
    large (high - low) when the range is not a power of two.
    No validation is performed: the caller must ensure high > low.
    """
    out = np.empty(size, dtype=np.int32)

    if low == 0:
        # Fast path: no offset to add when the interval starts at zero.
        span = uint64(high)
        for pos in range(0, size, 2):
            draw = _xoroshiro128plus_next(rng_state)
            # Lower half of the draw -> element at the even position.
            out[pos] = int32((draw & uint64(0xFFFFFFFF)) % span)
            # Upper half of the draw -> next element, unless the array is already full.
            if pos + 1 < size:
                out[pos + 1] = int32((draw >> uint64(32)) % span)
    else:
        span = uint64(high - low)
        for pos in range(0, size, 2):
            draw = _xoroshiro128plus_next(rng_state)
            # Lower half of the draw -> element at the even position.
            out[pos] = low + int32((draw & uint64(0xFFFFFFFF)) % span)
            # Upper half of the draw -> next element, unless the array is already full.
            if pos + 1 < size:
                out[pos + 1] = low + int32((draw >> uint64(32)) % span)

    return out
154
+
155
+
156
+ # =================================================================================================
157
+ # Public API
158
+ # =================================================================================================
159
# Python only honours the lowercase name ``__all__``.
__all__ = [
    "rand_float32",
    "rand_float64",
    "rand_int32",
    "rand_int32_array",
    "rand_int64",
    "set_seed",
]
@@ -0,0 +1,250 @@
1
+ import numba
2
+ import numpy as np
3
+
4
+
5
+ # =================================================================================================
6
+ # select_k_min
7
+ # =================================================================================================
8
@numba.njit("int32[:](float32[:], int32)", fastmath=True, inline="always")
def select_k_min(arr: np.ndarray[np.float32], k: np.int32) -> np.ndarray[np.int32]:
    """
    Find indices of the k smallest elements in a float32 array using Numba.

    This implementation uses a max-heap approach with O(n log k) complexity,
    which is efficient when k << n. The heap maintains the k smallest elements
    seen so far, with the largest of these at the root.

    Parameters:
    -----------
    arr : np.ndarray[np.float32]
        Input array with n elements (typically 1000-10000)
    k : int
        Number of smallest elements to find. Values larger than n are clamped
        to n; k == 0 yields an empty result.

    Returns:
    --------
    indices : np.ndarray[np.int32]
        Array of min(k, n) indices pointing to the smallest elements.
        Indices are returned in arbitrary order (not sorted by value).

    Notes:
    ------
    - Reported by the author as 2-8x faster than np.argpartition for small to
      moderate k (k ~ 10-100); best when k << n (e.g., k=100, n=10000).
    - Compiled with fastmath=True, so NaN values in arr are not supported.
    """
    n = len(arr)

    # Never keep more heap slots than there are elements: the initial fill below
    # would otherwise read past the end of arr.
    heap_size = min(k, n)
    heap_idx = np.empty(heap_size, dtype=np.int32)  # indices (into arr) of elements in the heap
    if heap_size == 0:
        # Nothing to select; also avoids peeking at the root of an empty heap.
        return heap_idx
    heap_values = np.empty(heap_size, dtype=np.float32)  # heap element values; largest at heap_values[0]

    # Build initial heap with the first heap_size elements
    for i in range(heap_size):
        heap_idx[i] = i
        heap_values[i] = arr[i]

    # Heapify: Convert the initial elements into a max-heap
    # -----------------------------------------------------------------------------------
    #
    # assuming we want to represent values v0 >= v1 >= v2 >= v3 >= v4 >= v5 >= v6 into a heap:
    #
    #            v0
    #          /    \
    #        v1      v2
    #       /  \    /  \
    #     v3   v4  v5   v6
    #
    # Invariant relations:
    #   - parents >= leaves (i.e. v0 >= v1,v2; v1 >= v3,v4; v2 >= v5,v6)
    #   - leaves of same parent are not necessarily sorted (!) (i.e. the tree could swap branches v1 & v2)
    #   - if a parent is at index i, then its children are at indices 2*i+1 and 2*i+2
    # -----------------------------------------------------------------------------------
    # Start from the last non-leaf node and sift each node down
    for start in range(heap_size // 2 - 1, -1, -1):
        i_parent = start
        value = heap_values[start]
        idx = heap_idx[start]

        # Sift down: move the element down until the heap property is restored
        while True:
            i_child = 2 * i_parent + 1  # left child
            if i_child >= heap_size:
                break  # no children: i_parent is a leaf
            i_right = i_child + 1
            # Pick the larger child (the right one wins ties)
            if i_right < heap_size and not (heap_values[i_child] > heap_values[i_right]):
                i_child = i_right
            if value >= heap_values[i_child]:
                break  # heap property holds here
            # Move the larger child up and continue one level down
            heap_values[i_parent] = heap_values[i_child]
            heap_idx[i_parent] = heap_idx[i_child]
            i_parent = i_child

        heap_values[i_parent] = value
        heap_idx[i_parent] = idx

    # Process remaining elements:
    # each element smaller than the heap maximum replaces the root and is sifted down
    for i in range(heap_size, n):
        value = arr[i]
        if value < heap_values[0]:  # heap_values[0] is the maximum of the smallest elements so far
            i_parent = 0

            while True:
                i_child = 2 * i_parent + 1  # left child
                if i_child >= heap_size:
                    break  # no children: i_parent is a leaf
                i_right = i_child + 1
                # Pick the larger child (the right one wins ties)
                if i_right < heap_size and not (heap_values[i_child] > heap_values[i_right]):
                    i_child = i_right
                if value >= heap_values[i_child]:
                    break  # heap property holds here
                # Move the larger child up and continue one level down
                heap_values[i_parent] = heap_values[i_child]
                heap_idx[i_parent] = heap_idx[i_child]
                i_parent = i_child

            heap_values[i_parent] = value
            heap_idx[i_parent] = i

    return heap_idx
127
+
128
+
129
+ # =================================================================================================
130
+ # select_k_max
131
+ # =================================================================================================
132
@numba.njit("int32[:](float32[:], int32)", fastmath=True, inline="always")
def select_k_max(arr: np.ndarray[np.float32], k: np.int32) -> np.ndarray[np.int32]:
    """
    Find indices of the k largest elements in a float32 array using Numba.

    This implementation uses a min-heap approach with O(n log k) complexity,
    which is efficient when k << n. The heap maintains the k largest elements
    seen so far, with the smallest of these at the root.

    Parameters:
    -----------
    arr : np.ndarray[np.float32]
        Input array with n elements (typically 1000-10000)
    k : int
        Number of largest elements to find. Values larger than n are clamped
        to n; k == 0 yields an empty result.

    Returns:
    --------
    indices : np.ndarray[np.int32]
        Array of min(k, n) indices pointing to the largest elements.
        Indices are returned in arbitrary order (not sorted by value).

    Notes:
    ------
    - Reported by the author as 2-8x faster than np.argpartition for small to
      moderate k (k ~ 10-100); best when k << n (e.g., k=100, n=10000).
    - Compiled with fastmath=True, so NaN values in arr are not supported.
    """
    n = len(arr)

    # Never keep more heap slots than there are elements: the initial fill below
    # would otherwise read past the end of arr.
    heap_size = min(k, n)
    heap_idx = np.empty(heap_size, dtype=np.int32)  # indices (into arr) of elements in the heap
    if heap_size == 0:
        # Nothing to select; also avoids peeking at the root of an empty heap.
        return heap_idx
    heap_values = np.empty(heap_size, dtype=np.float32)  # heap element values; smallest at heap_values[0]

    # Build initial heap with the first heap_size elements
    for i in range(heap_size):
        heap_idx[i] = i
        heap_values[i] = arr[i]

    # Heapify: Convert the initial elements into a min-heap
    # -----------------------------------------------------------------------------------
    #
    # assuming we want to represent values v0 <= v1 <= v2 <= v3 <= v4 <= v5 <= v6 into a heap:
    #
    #            v0
    #          /    \
    #        v1      v2
    #       /  \    /  \
    #     v3   v4  v5   v6
    #
    # Invariant relations:
    #   - parents <= leaves (i.e. v0 <= v1,v2; v1 <= v3,v4; v2 <= v5,v6)
    #   - leaves of same parent are not necessarily sorted (!) (i.e. the tree could swap branches v1 & v2)
    #   - if a parent is at index i, then its children are at indices 2*i+1 and 2*i+2
    # -----------------------------------------------------------------------------------
    # Start from the last non-leaf node and sift each node down
    for start in range(heap_size // 2 - 1, -1, -1):
        i_parent = start
        value = heap_values[start]
        idx = heap_idx[start]

        # Sift down: move the element down until the heap property is restored
        while True:
            i_child = 2 * i_parent + 1  # left child
            if i_child >= heap_size:
                break  # no children: i_parent is a leaf
            i_right = i_child + 1
            # Pick the smaller child (the right one wins ties)
            if i_right < heap_size and not (heap_values[i_child] < heap_values[i_right]):
                i_child = i_right
            if value <= heap_values[i_child]:
                break  # heap property holds here
            # Move the smaller child up and continue one level down
            heap_values[i_parent] = heap_values[i_child]
            heap_idx[i_parent] = heap_idx[i_child]
            i_parent = i_child

        heap_values[i_parent] = value
        heap_idx[i_parent] = idx

    # Process remaining elements:
    # each element larger than the heap minimum replaces the root and is sifted down
    for i in range(heap_size, n):
        value = arr[i]
        if value > heap_values[0]:  # heap_values[0] is the minimum of the largest elements so far
            i_parent = 0

            while True:
                i_child = 2 * i_parent + 1  # left child
                if i_child >= heap_size:
                    break  # no children: i_parent is a leaf
                i_right = i_child + 1
                # Pick the smaller child (the right one wins ties)
                if i_right < heap_size and not (heap_values[i_child] < heap_values[i_right]):
                    i_child = i_right
                if value <= heap_values[i_child]:
                    break  # heap property holds here
                # Move the smaller child up and continue one level down
                heap_values[i_parent] = heap_values[i_child]
                heap_idx[i_parent] = heap_idx[i_child]
                i_parent = i_child

            heap_values[i_parent] = value
            heap_idx[i_parent] = i

    return heap_idx
@@ -1 +1 @@
1
- from .discrete import sample_int
1
+ from .uncon import randint, randint_numba, randint_numpy