max-div 0.0.3__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- max_div/_cli.py +99 -0
- max_div/benchmark/__init__.py +2 -1
- max_div/benchmark/_formatting.py +218 -0
- max_div/benchmark/randint.py +104 -0
- max_div/benchmark/randint_constrained.py +355 -0
- max_div/constraints/__init__.py +2 -0
- max_div/constraints/_numba.py +110 -0
- max_div/constraints/constraint.py +10 -0
- max_div/constraints/constraints.py +47 -0
- max_div/internal/benchmarking/_micro_benchmark.py +48 -7
- max_div/internal/formatting/__init__.py +1 -0
- max_div/internal/formatting/_markdown.py +43 -0
- max_div/internal/math/__init__.py +1 -0
- max_div/internal/math/fast_log.py +167 -0
- max_div/internal/math/random.py +166 -0
- max_div/internal/math/select_k_minmax.py +250 -0
- max_div/sampling/__init__.py +1 -1
- max_div/sampling/con.py +350 -0
- max_div/sampling/uncon.py +269 -0
- {max_div-0.0.3.dist-info → max_div-0.1.1.dist-info}/METADATA +13 -8
- max_div-0.1.1.dist-info/RECORD +32 -0
- max_div-0.1.1.dist-info/entry_points.txt +2 -0
- max_div/benchmark/sample_int.py +0 -85
- max_div/internal/compat/__init__.py +0 -1
- max_div/internal/compat/_numba/__init__.py +0 -14
- max_div/internal/compat/_numba/_dummy_numba.py +0 -94
- max_div/internal/compat/_numba/_helpers.py +0 -14
- max_div/sampling/discrete.py +0 -176
- max_div-0.0.3.dist-info/RECORD +0 -23
- max_div-0.0.3.dist-info/entry_points.txt +0 -2
- {max_div-0.0.3.dist-info → max_div-0.1.1.dist-info}/WHEEL +0 -0
- {max_div-0.0.3.dist-info → max_div-0.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import numba
import numpy as np

# -------------------------------------------------------------------------
# Constants
# -------------------------------------------------------------------------
# Polynomial coefficients (degrees 2..5) consumed by the fast_log2_* helpers
# below: _Dnk is the coefficient of m**k in the degree-n polynomial fitted to
# log2(m) on the reduced-mantissa interval.  The _S* values are the same
# coefficients cast to float32 for the single-precision variants.

# --- float64 ---------------------------------------------
_D_LOG_2 = 0.6931471805599453  # np.log(2)

# degree-2 polynomial coefficients
_D20 = -2.664030016771488
_D21 = 4.018729576130520
_D22 = -1.359007257261774

# degree-3 polynomial coefficients
_D30 = -3.144940630924
_D31 = 6.058956424048
_D32 = -4.157032648692
_D33 = 1.243566840636

# degree-4 polynomial coefficients
_D40 = -3.505614661980
_D41 = 8.099233785172
_D42 = -8.397609124753
_D43 = 5.084088932163
_D44 = -1.280174030509

# degree-5 polynomial coefficients
_D50 = -3.794153676536
_D51 = 10.139512633266
_D52 = -14.080875352582
_D53 = 12.881420375173
_D54 = -6.551609372263
_D55 = 1.405716091134

# --- float32 ---------------------------------------------
_S_LOG_2 = np.float32(_D_LOG_2)

_S20 = np.float32(_D20)
_S21 = np.float32(_D21)
_S22 = np.float32(_D22)

_S30 = np.float32(_D30)
_S31 = np.float32(_D31)
_S32 = np.float32(_D32)
_S33 = np.float32(_D33)

_S40 = np.float32(_D40)
_S41 = np.float32(_D41)
_S42 = np.float32(_D42)
_S43 = np.float32(_D43)
_S44 = np.float32(_D44)

_S50 = np.float32(_D50)
_S51 = np.float32(_D51)
_S52 = np.float32(_D52)
_S53 = np.float32(_D53)
_S54 = np.float32(_D54)
_S55 = np.float32(_D55)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# -------------------------------------------------------------------------
|
|
60
|
+
# Fast approximations for np.log
|
|
61
|
+
# -------------------------------------------------------------------------
|
|
62
|
+
@numba.njit(fastmath=True, inline="always")
def fast_log_f64_poly(x: np.float64, degree: int) -> np.float64:
    """
    Fast natural-log approximation using a polynomial after range reduction.

    Delegates to ``fast_log2_f64_poly`` and rescales the result by log(2),
    since log(x) = log(2) * log2(x).

    Accuracy depends on degree:
      degree=2: max abs error ~0.004 over entire range.
      degree=3: max abs error ~0.0005 over entire range.
      degree=4: max abs error ~0.00007 over entire range.
      degree=5: max abs error ~0.00001 over entire range.
    """
    log2_x = fast_log2_f64_poly(x, degree)
    return _D_LOG_2 * log2_x
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@numba.njit(fastmath=True, inline="always")
def fast_log_f32_poly(x: np.float32, degree: int) -> np.float32:
    """
    Fast natural-log approximation (float32) using a polynomial after range reduction.

    Delegates to ``fast_log2_f32_poly`` and rescales the result by log(2),
    since log(x) = log(2) * log2(x).

    Accuracy depends on degree:
      degree=2: max abs error ~0.004 over entire range.
      degree=3: max abs error ~0.0005 over entire range.
      degree=4: max abs error ~0.00007 over entire range.
      degree=5: max abs error ~0.00001 over entire range.
    """
    log2_x = fast_log2_f32_poly(x, degree)
    return _S_LOG_2 * log2_x
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# -------------------------------------------------------------------------
|
|
89
|
+
# Fast approximations for np.log2
|
|
90
|
+
# -------------------------------------------------------------------------
|
|
91
|
+
@numba.njit(fastmath=True, inline="always")
def fast_log2_f64_poly(x: np.float64, degree: int) -> np.float64:
    """
    Fast log2 approximation: IEEE-754 bit-level range reduction followed by a
    low-degree polynomial evaluated on the reduced mantissa.

    Accuracy depends on degree:
      degree=2: max abs error ~0.006 over entire range.
      degree=3: max abs error ~0.0007 over entire range.
      degree=4: max abs error ~0.0001 over entire range.
      degree=5: max abs error ~0.000015 over entire range.
    """
    # Reinterpret the float64 bit pattern as an int64 so the exponent and
    # mantissa fields can be separated.
    bits = np.int64(np.float64(x).view(np.int64))
    exp_part = ((bits >> 52) & 0x7FF) - 1022
    # Overwrite the exponent field with that of 0.5 (0x3FE), keeping the
    # mantissa bits: the reinterpreted float then lies in [0.5, 1.0).
    bits = (bits & 0x000FFFFFFFFFFFFF) | 0x3FE0000000000000
    mantissa = np.int64(bits).view(np.float64)

    # Polynomial on the mantissa — identical arithmetic per degree as before.
    if degree == 2:
        poly = _D20 + (mantissa * _D21) + (mantissa * mantissa * _D22)
    elif degree == 3:
        poly = _D30 + mantissa * (_D31 + mantissa * (_D32 + mantissa * _D33))
    elif degree == 4:
        poly = _D40 + mantissa * (_D41 + mantissa * (_D42 + mantissa * (_D43 + mantissa * _D44)))
    else:
        poly = _D50 + mantissa * (_D51 + mantissa * (_D52 + mantissa * (_D53 + mantissa * (_D54 + mantissa * _D55))))

    # log2(x) = exponent + log2(mantissa)
    return exp_part + poly
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@numba.njit(fastmath=True, inline="always")
def fast_log2_f32_poly(x: np.float32, degree: int) -> np.float32:
    """
    Fast log2 approximation (float32): IEEE-754 bit-level range reduction
    followed by a low-degree polynomial evaluated on the reduced mantissa.

    Accuracy depends on degree:
      degree=2: max abs error ~0.006 over entire range.
      degree=3: max abs error ~0.0007 over entire range.
      degree=4: max abs error ~0.0001 over entire range.
      degree=5: max abs error ~0.000015 over entire range.
    """
    # Reinterpret the float32 bit pattern as an int32 so the exponent and
    # mantissa fields can be separated.
    bits = np.int32(np.float32(x).view(np.int32))
    exp_part = ((bits >> 23) & 0xFF) - 126
    # Overwrite the exponent field with that of 0.5 (0x3F000000), keeping the
    # mantissa bits: the reinterpreted float then lies in [0.5, 1.0).
    bits = (bits & 0x007FFFFF) | 0x3F000000
    mantissa = np.int32(bits).view(np.float32)

    # Polynomial on the mantissa — identical arithmetic per degree as before.
    if degree == 2:
        poly = _S20 + (mantissa * _S21) + (mantissa * mantissa * _S22)
    elif degree == 3:
        poly = _S30 + mantissa * (_S31 + mantissa * (_S32 + mantissa * _S33))
    elif degree == 4:
        poly = _S40 + mantissa * (_S41 + mantissa * (_S42 + mantissa * (_S43 + mantissa * _S44)))
    else:
        poly = _S50 + mantissa * (_S51 + mantissa * (_S52 + mantissa * (_S53 + mantissa * (_S54 + mantissa * _S55))))

    # log2(x) = exponent + log2(mantissa)
    return exp_part + poly
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# =================================================================================================
|
|
160
|
+
# Public API
|
|
161
|
+
# =================================================================================================
|
|
162
|
+
# NOTE: the canonical spelling is ``__all__`` (lowercase); ``__ALL__`` has no
# special meaning and is ignored by ``from module import *``.  The original
# list also duplicated two names and omitted ``fast_log2_f64_poly`` and
# ``fast_log_f32_poly`` — all four public functions are now listed once.
__all__ = [
    "fast_log2_f32_poly",
    "fast_log2_f64_poly",
    "fast_log_f32_poly",
    "fast_log_f64_poly",
]
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
"""
Custom simple random number generation module for integration in numba oriented code. Faster than numpy.random
even when used inside numba.njit.

This is based on the following:
- https://www.pcg-random.org/posts/bounded-rands.html
- xoroshiro128+ algorithm by David Blackman and Sebastiano Vigna (http://xoroshiro.di.unimi.it/)
- splitmix64 for seed initialization by Sebastiano Vigna (http://xorshift.di.unimi.it/splitmix64.c)
"""

import numba
import numpy as np
from numpy import float32, float64, int32, int64, uint32, uint64

# =================================================================================================
# Constants
# =================================================================================================

# Constant for converting uint64 to float64 in [0.0, 1.0): 1.0 / 2**53
# (53 = number of significand bits a float64 can represent exactly)
_TO_FLOAT64 = float64(1.0 / 9007199254740992.0)


# Constant for converting uint64 to float32 in [0.0, 1.0): 1.0 / 2**24
# (24 = number of significand bits a float32 can represent exactly)
_TO_FLOAT32 = float32(1.0 / 16777216.0)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# =================================================================================================
|
|
28
|
+
# Core
|
|
29
|
+
# =================================================================================================
|
|
30
|
+
@numba.njit(fastmath=True, inline="always")
def rotl(x: uint64, k: uint64) -> uint64:
    """Rotate the 64-bit value ``x`` left by ``k`` bits (bits shifted out re-enter on the right)."""
    upper = x << k
    lower = x >> (uint64(64) - k)
    return upper | lower
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@numba.njit(fastmath=True, inline="always")
def _xoroshiro128plus_next(rng_state: np.ndarray[uint64]) -> uint64:
    """Advance the 2-word xoroshiro128+ state in-place and return the next random uint64."""
    a = rng_state[0]
    b = rng_state[1]
    out = a + b  # the "+" in xoroshiro128+: output is the sum of the state words

    # State transition (xor/shift/rotate), written back in-place.
    b ^= a
    rng_state[0] = rotl(a, uint64(24)) ^ b ^ (b << uint64(16))
    rng_state[1] = rotl(b, uint64(37))

    return out
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@numba.njit(fastmath=True)
def _splitmix64_next(init_state: np.ndarray[uint64]) -> uint64:
    """Used to initialize xoroshiro128+ state from single seed; state is a 1-element array, modified in-place."""
    # Advance the counter by the golden-ratio increment, then scramble it.
    acc = init_state[0] + uint64(0x9E3779B97F4A7C15)
    init_state[0] = acc
    acc = (acc ^ (acc >> uint64(30))) * uint64(0xBF58476D1CE4E5B9)
    acc = (acc ^ (acc >> uint64(27))) * uint64(0x94D049BB133111EB)
    return acc ^ (acc >> uint64(31))
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# =================================================================================================
|
|
61
|
+
# Interface
|
|
62
|
+
# =================================================================================================
|
|
63
|
+
@numba.njit(fastmath=True, inline="always")
def set_seed(seed: np.int64) -> np.ndarray[uint64]:
    """Initialize xoroshiro128+ state from single seed; using splitmix64 algorithm."""
    # splitmix64 runs on its own 1-element counter array, seeded by the caller.
    sm_counter = np.array([seed], dtype=uint64)

    # Draw the two xoroshiro128+ state words from successive splitmix64 outputs.
    state = np.empty(2, dtype=uint64)
    for i in range(2):
        state[i] = _splitmix64_next(sm_counter)

    return state
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@numba.njit("float64(uint64[:])", fastmath=True, inline="always")
def rand_float64(rng_state: np.ndarray[uint64]) -> float64:
    """Generate a random float64 in [0.0, 1.0) using the provided rng_state."""
    # Keep the top 53 bits (a float64's significand width) and scale by 2**-53.
    top_bits = _xoroshiro128plus_next(rng_state) >> uint64(11)
    return float64(top_bits * _TO_FLOAT64)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@numba.njit("float32(uint64[:])", fastmath=True, inline="always")
def rand_float32(rng_state: np.ndarray[uint64]) -> float32:
    """Generate a random float32 in [0.0, 1.0) using the provided rng_state."""
    # Keep the top 24 bits (a float32's significand width) and scale by 2**-24.
    top_bits = _xoroshiro128plus_next(rng_state) >> uint64(40)
    return float32(top_bits * _TO_FLOAT32)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@numba.njit("int64(uint64[:], int64, int64)", fastmath=True, inline="always")
def rand_int64(rng_state: np.ndarray[uint64], low: np.int64, high: np.int64) -> np.int64:
    """
    Generate a random int64 in [low, high) using the provided rng_state.
    There might be a small bias for large (high-low) if the range is not a power of two.
    """
    # Exactly one RNG draw either way, so the call is hoisted above the branch.
    rnd = _xoroshiro128plus_next(rng_state)
    if low == 0:
        # Fast path: no offset addition needed.
        return int64(rnd % uint64(high))
    return low + int64(rnd % uint64(high - low))
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@numba.njit("int32(uint64[:], int32, int32)", fastmath=True, inline="always")
def rand_int32(rng_state: np.ndarray[uint64], low: np.int32, high: np.int32) -> np.int32:
    """
    Generate a random int32 in [low, high) using the provided rng_state.
    There might be a small bias for large (high-low) if the range is not a power of two.
    """
    # Exactly one RNG draw either way, so the call is hoisted above the branch.
    rnd = _xoroshiro128plus_next(rng_state)
    if low == 0:
        # Fast path: no offset addition needed.
        return int32(rnd % uint64(high))
    return low + int32(rnd % uint64(high - low))
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@numba.njit("int32[:](uint64[:], int32, int32, int32)", fastmath=True, inline="always")
def rand_int32_array(
    rng_state: np.ndarray[uint64], low: np.int32, high: np.int32, size: np.int32
) -> np.ndarray[np.int32]:
    """
    Generate an array of random int32 values in [low, high) using the provided rng_state.
    Optimized to generate 2 values per RNG call by using upper and lower 32 bits.
    There might be a small bias for large (high-low) if the range is not a power of two.
    """
    out = np.empty(size, dtype=np.int32)
    # When low == 0 this is uint64(high), so one code path covers both cases
    # (adding a zero ``low`` offset leaves the values unchanged).
    span = uint64(high - low)
    i = 0
    while i < size:
        rnd = _xoroshiro128plus_next(rng_state)
        # First value from the lower 32 bits of the draw.
        out[i] = low + int32((rnd & uint64(0xFFFFFFFF)) % span)
        i += 1
        # Second value from the upper 32 bits, if still needed.
        if i < size:
            out[i] = low + int32((rnd >> uint64(32)) % span)
            i += 1
    return out
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# =================================================================================================
|
|
157
|
+
# Public API
|
|
158
|
+
# =================================================================================================
|
|
159
|
+
# NOTE: the canonical spelling is ``__all__`` (lowercase); ``__ALL__`` has no
# special meaning and is ignored by ``from module import *``.
__all__ = [
    "rand_float32",
    "rand_float64",
    "rand_int32",
    "rand_int32_array",
    "rand_int64",
    "set_seed",
]
|
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
import numba
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# =================================================================================================
|
|
6
|
+
# select_k_min
|
|
7
|
+
# =================================================================================================
|
|
8
|
+
@numba.njit("int32[:](float32[:], int32)", fastmath=True, inline="always")
def select_k_min(arr: np.ndarray[np.float32], k: np.int32) -> np.ndarray[np.int32]:
    """
    Find indices of k smallest elements in a float32 array using Numba.

    This implementation uses a max-heap approach with O(n log k) complexity,
    which is efficient when k << n. The heap maintains the k smallest elements
    seen so far, with the largest of these at the root.

    NOTE(review): assumes 1 <= k <= len(arr); no bounds checks are performed
    (a larger k would read past the end of ``arr``) — confirm at call sites.

    Parameters:
    -----------
    arr : np.ndarray[np.float32]
        Input array with n elements (typically 1000-10000)
    k : int
        Number of smallest elements to find

    Returns:
    --------
    indices : np.ndarray[np.int32]
        Array of k indices pointing to the smallest elements.
        Indices are returned in arbitrary order (not sorted by value).

    Performance:
    ------------
    - 2-8x faster than np.argpartition for small to moderate k (k ~ 10-100)
    - Best when k << n (e.g., k=100, n=10000)
    - Uses fastmath=True for additional SIMD optimizations
    """
    n = len(arr)
    heap_idx = np.empty(k, dtype=np.int32)  # indices (into arr) of elements in the heap
    heap_values = np.empty(k, dtype=np.float32)  # values of elements in the heap; largest at heap_values[0]

    # Build initial heap with first k elements
    for i in range(k):
        heap_idx[i] = i
        heap_values[i] = arr[i]

    # Heapify: Convert initial k elements into a max-heap
    # -----------------------------------------------------------------------------------
    #
    # assuming we want to represent values v0 >= v1 >= v2 >= v3 >= v4 >= v5 >= v6 into a heap:
    #
    #          v0
    #         /  \
    #       v1    v2
    #      / \    / \
    #    v3  v4  v5  v6
    #
    # Invariant relations:
    #  - parents >= leaves  (i.e. v0 >= v1,v2;  v1 >= v3,v4;  v2 >= v5,v6)
    #  - leaves of same parent are not necessarily sorted (!) (i.e. the tree could swap branches v1 & v2)
    #  - if a parent is at index i, then its children are at indices 2*i+1 and 2*i+2
    # -----------------------------------------------------------------------------------
    # Start from last non-leaf node and sift down
    for i in range(k // 2 - 1, -1, -1):
        i_parent = i
        value = heap_values[i_parent]  # element being sifted down (written back only once, at its final slot)
        idx = heap_idx[i_parent]

        # Sift down: move element down until heap property is restored
        while True:
            i_child_left = 2 * i_parent + 1
            i_child_right = i_child_left + 1
            i_child_largest = -1  # -1 means "no child exists" (parent is a leaf)

            # Find the largest child
            if i_child_left < k:
                if i_child_right < k:
                    i_child_largest = (
                        i_child_left if heap_values[i_child_left] > heap_values[i_child_right] else i_child_right
                    )
                else:
                    i_child_largest = i_child_left

            # If no children or value is larger than the largest child, we're done
            if i_child_largest == -1 or value >= heap_values[i_child_largest]:
                heap_values[i_parent] = value
                heap_idx[i_parent] = idx
                break

            # Otherwise, move the larger child up and continue
            heap_values[i_parent] = heap_values[i_child_largest]
            heap_idx[i_parent] = heap_idx[i_child_largest]
            i_parent = i_child_largest

    # Process remaining elements
    # For each element, if it's smaller than heap maximum, replace and sift down
    for i in range(k, n):
        value = arr[i]
        if value < heap_values[0]:  # heap_values[0] is the maximum of k smallest
            i_parent = 0

            # Sift down from root
            while True:
                i_child_left = 2 * i_parent + 1
                i_child_right = i_child_left + 1
                i_child_largest = -1  # -1 means "no child exists"

                # Find the largest child
                if i_child_left < k:
                    if i_child_right < k:
                        i_child_largest = (
                            i_child_left if heap_values[i_child_left] > heap_values[i_child_right] else i_child_right
                        )
                    else:
                        i_child_largest = i_child_left

                # If no children or val is larger than largest child, we're done
                if i_child_largest == -1 or value >= heap_values[i_child_largest]:
                    heap_values[i_parent] = value
                    heap_idx[i_parent] = i  # record the position in arr of the newly admitted element
                    break

                # Otherwise, move the larger child up and continue
                heap_values[i_parent] = heap_values[i_child_largest]
                heap_idx[i_parent] = heap_idx[i_child_largest]
                i_parent = i_child_largest

    return heap_idx
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# =================================================================================================
|
|
130
|
+
# select_k_max
|
|
131
|
+
# =================================================================================================
|
|
132
|
+
@numba.njit("int32[:](float32[:], int32)", fastmath=True, inline="always")
def select_k_max(arr: np.ndarray[np.float32], k: np.int32) -> np.ndarray[np.int32]:
    """
    Find indices of k largest elements in a float32 array using Numba.

    This implementation uses a min-heap approach with O(n log k) complexity,
    which is efficient when k << n. The heap maintains the k largest elements
    seen so far, with the smallest of these at the root.

    NOTE(review): assumes 1 <= k <= len(arr); no bounds checks are performed
    (a larger k would read past the end of ``arr``) — confirm at call sites.

    Parameters:
    -----------
    arr : np.ndarray[np.float32]
        Input array with n elements (typically 1000-10000)
    k : int
        Number of largest elements to find

    Returns:
    --------
    indices : np.ndarray[np.int32]
        Array of k indices pointing to the largest elements.
        Indices are returned in arbitrary order (not sorted by value).

    Performance:
    ------------
    - 2-8x faster than np.argpartition for small to moderate k (k ~ 10-100)
    - Best when k << n (e.g., k=100, n=10000)
    - Uses fastmath=True for additional SIMD optimizations
    """
    n = len(arr)
    heap_idx = np.empty(k, dtype=np.int32)  # indices (into arr) of elements in the heap
    heap_values = np.empty(k, dtype=np.float32)  # values of elements in the heap; smallest at heap_values[0]

    # Build initial heap with first k elements
    for i in range(k):
        heap_idx[i] = i
        heap_values[i] = arr[i]

    # Heapify: Convert initial k elements into a min-heap
    # -----------------------------------------------------------------------------------
    #
    # assuming we want to represent values v0 <= v1 <= v2 <= v3 <= v4 <= v5 <= v6 into a heap:
    #
    #          v0
    #         /  \
    #       v1    v2
    #      / \    / \
    #    v3  v4  v5  v6
    #
    # Invariant relations:
    #  - parents <= leaves  (i.e. v0 <= v1,v2;  v1 <= v3,v4;  v2 <= v5,v6)
    #  - leaves of same parent are not necessarily sorted (!) (i.e. the tree could swap branches v1 & v2)
    #  - if a parent is at index i, then its children are at indices 2*i+1 and 2*i+2
    # -----------------------------------------------------------------------------------
    # Start from last non-leaf node and sift down
    for i in range(k // 2 - 1, -1, -1):
        i_parent = i
        value = heap_values[i_parent]  # element being sifted down (written back only once, at its final slot)
        idx = heap_idx[i_parent]

        # Sift down: move element down until heap property is restored
        while True:
            i_child_left = 2 * i_parent + 1
            i_child_right = i_child_left + 1
            i_child_smallest = -1  # -1 means "no child exists" (parent is a leaf)

            # Find the smallest child
            if i_child_left < k:
                if i_child_right < k:
                    i_child_smallest = (
                        i_child_left if heap_values[i_child_left] < heap_values[i_child_right] else i_child_right
                    )
                else:
                    i_child_smallest = i_child_left

            # If no children or value is smaller than the smallest child, we're done
            if i_child_smallest == -1 or value <= heap_values[i_child_smallest]:
                heap_values[i_parent] = value
                heap_idx[i_parent] = idx
                break

            # Otherwise, move the smaller child up and continue
            heap_values[i_parent] = heap_values[i_child_smallest]
            heap_idx[i_parent] = heap_idx[i_child_smallest]
            i_parent = i_child_smallest

    # Process remaining elements
    # For each element, if it's larger than heap minimum, replace and sift down
    for i in range(k, n):
        value = arr[i]
        if value > heap_values[0]:  # heap_values[0] is the minimum of k largest
            i_parent = 0

            # Sift down from root
            while True:
                i_child_left = 2 * i_parent + 1
                i_child_right = i_child_left + 1
                i_child_smallest = -1  # -1 means "no child exists"

                # Find the smallest child
                if i_child_left < k:
                    if i_child_right < k:
                        i_child_smallest = (
                            i_child_left if heap_values[i_child_left] < heap_values[i_child_right] else i_child_right
                        )
                    else:
                        i_child_smallest = i_child_left

                # If no children or val is smaller than smallest child, we're done
                if i_child_smallest == -1 or value <= heap_values[i_child_smallest]:
                    heap_values[i_parent] = value
                    heap_idx[i_parent] = i  # record the position in arr of the newly admitted element
                    break

                # Otherwise, move the smaller child up and continue
                heap_values[i_parent] = heap_values[i_child_smallest]
                heap_idx[i_parent] = heap_idx[i_child_smallest]
                i_parent = i_child_smallest

    return heap_idx
|
max_div/sampling/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
from .
|
|
1
|
+
from .uncon import randint, randint_numba, randint_numpy
|