SearchLibrium 0.0.85__tar.gz → 0.0.86__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/PKG-INFO +1 -1
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/pyproject.toml +1 -1
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/mdcev.py +214 -74
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/siman.py +3 -6
- searchlibrium-0.0.86/src/SearchLibrium/version.txt +1 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium.egg-info/PKG-INFO +1 -1
- searchlibrium-0.0.85/src/SearchLibrium/version.txt +0 -1
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/README.md +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/setup.cfg +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/Halton.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/MixedLogit.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/Mode_Activity_Nested.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/RandomP.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/SEARCH_SM_MARIO.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/Two_Level_Nest.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/__init__.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/__main__.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/_choice_model.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/_device.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/bhhh/minimize.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/boxcox_functions.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/call_meta.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/constraints_builder.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/harmony.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/latent_class.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/main.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/main_debug.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/misc.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/mixed_logit.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/mixed_nested.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/mixedrrm.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/multinomial_logit.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/multinomial_nested.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/multinomial_probit.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/ordered_logit.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/ordered_logit_mixed.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/rrm.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/search.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/selection_models.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/setup.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium/threshold.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium.egg-info/SOURCES.txt +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium.egg-info/dependency_links.txt +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium.egg-info/entry_points.txt +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium.egg-info/requires.txt +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium.egg-info/top_level.txt +0 -0
|
@@ -59,7 +59,7 @@ Homepage = "https://github.com/zahern/HypothesisX"
|
|
|
59
59
|
realpython = "SearchLibrium.__main__:main"
|
|
60
60
|
|
|
61
61
|
[tool.bumpver]
|
|
62
|
-
current_version = "0.0.
|
|
62
|
+
current_version = "0.0.86"
|
|
63
63
|
version_pattern = "MAJOR.MINOR.PATCH"
|
|
64
64
|
commit_message = "[skip ci] Bump version {old_version} -> {new_version}"
|
|
65
65
|
commit = true
|
|
@@ -9,11 +9,18 @@ conditions.
|
|
|
9
9
|
|
|
10
10
|
The class is intended as a practical bridge between the current scalar budget
|
|
11
11
|
models and a fuller MDCEV pipeline. It includes both a stable heuristic fit
|
|
12
|
-
and a
|
|
12
|
+
and a JAX-accelerated quasi-MLE refinement with exact automatic differentiation.
|
|
13
|
+
|
|
14
|
+
JAX is used for:
|
|
15
|
+
- Bisection solver: JIT-compiled + vmapped over observations (fast predict/simulate)
|
|
16
|
+
- fit_mle gradients: exact autodiff via jax.grad instead of scipy FD approximations
|
|
17
|
+
|
|
18
|
+
A pure-numpy fallback is provided when JAX is not installed.
|
|
13
19
|
"""
|
|
14
20
|
|
|
15
21
|
from __future__ import annotations
|
|
16
22
|
|
|
23
|
+
import os as _os
|
|
17
24
|
from dataclasses import dataclass
|
|
18
25
|
from typing import Iterable, Optional
|
|
19
26
|
|
|
@@ -21,6 +28,17 @@ import numpy as np
|
|
|
21
28
|
import pandas as pd
|
|
22
29
|
from scipy.optimize import minimize
|
|
23
30
|
|
|
31
|
+
# ── JAX optional import ──────────────────────────────────────────────────────
|
|
32
|
+
_os.environ.setdefault("JAX_PLATFORMS", "cpu")
|
|
33
|
+
try:
|
|
34
|
+
import jax
|
|
35
|
+
import jax.numpy as jnp
|
|
36
|
+
from jax import jit as _jit, vmap as _vmap
|
|
37
|
+
jax.config.update("jax_enable_x64", True)
|
|
38
|
+
_JAX_AVAILABLE = True
|
|
39
|
+
except Exception:
|
|
40
|
+
_JAX_AVAILABLE = False
|
|
41
|
+
|
|
24
42
|
|
|
25
43
|
def _as_2d_float(array_like) -> np.ndarray:
|
|
26
44
|
arr = np.asarray(array_like, dtype=float)
|
|
@@ -42,12 +60,60 @@ class MDCEVFitResult:
|
|
|
42
60
|
mean_budget: float
|
|
43
61
|
|
|
44
62
|
|
|
63
|
+
# ── JAX bisection kernel (module-level, JIT-compiled once) ───────────────────
|
|
64
|
+
if _JAX_AVAILABLE:
|
|
65
|
+
def _jax_bisect_single(budget, util_idx, alpha, gamma):
|
|
66
|
+
"""Translated-utility KKT bisection for one observation.
|
|
67
|
+
|
|
68
|
+
Fully JAX-jittable and differentiable: uses ``jax.lax.scan`` for both
|
|
69
|
+
the initial bracket search and the bisection, so gradients flow through
|
|
70
|
+
the solver and ``fit_mle`` can use exact autodiff.
|
|
71
|
+
"""
|
|
72
|
+
_tol = 1e-9
|
|
73
|
+
weights = jnp.exp(jnp.clip(util_idx, -40.0, 40.0))
|
|
74
|
+
|
|
75
|
+
def _alloc_sum(lam):
|
|
76
|
+
power = 1.0 / jnp.clip(1.0 - alpha, _tol, None)
|
|
77
|
+
raw = jnp.power(weights / jnp.maximum(lam, _tol), power) - gamma
|
|
78
|
+
return jnp.maximum(raw, 0.0).sum()
|
|
79
|
+
|
|
80
|
+
# Find upper bracket: keep doubling hi until alloc_sum(hi) <= budget.
|
|
81
|
+
# scan (max 60 steps) is differentiable; once condition is false hi stays fixed.
|
|
82
|
+
init_hi = jnp.maximum(jnp.max(weights), 1.0)
|
|
83
|
+
|
|
84
|
+
def _double_step(hi, _):
|
|
85
|
+
return jnp.where(_alloc_sum(hi) > budget, hi * 2.0, hi), None
|
|
86
|
+
|
|
87
|
+
hi, _ = jax.lax.scan(_double_step, init_hi, None, length=60)
|
|
88
|
+
|
|
89
|
+
# 80-step bisection
|
|
90
|
+
def _bisect_step(carry, _):
|
|
91
|
+
lo, hi = carry
|
|
92
|
+
mid = 0.5 * (lo + hi)
|
|
93
|
+
go_lo = _alloc_sum(mid) > budget
|
|
94
|
+
return (jnp.where(go_lo, mid, lo), jnp.where(go_lo, hi, mid)), None
|
|
95
|
+
|
|
96
|
+
(_, hi_f), _ = jax.lax.scan(_bisect_step, (_tol, hi), None, length=80)
|
|
97
|
+
|
|
98
|
+
power = 1.0 / jnp.clip(1.0 - alpha, _tol, None)
|
|
99
|
+
raw = jnp.power(weights / jnp.maximum(hi_f, _tol), power) - gamma
|
|
100
|
+
allocation = jnp.maximum(raw, 0.0)
|
|
101
|
+
total = allocation.sum()
|
|
102
|
+
allocation = jnp.where(total > _tol, allocation * (budget / total), allocation)
|
|
103
|
+
# zero-budget guard
|
|
104
|
+
return jnp.where(budget <= _tol, jnp.zeros_like(allocation), allocation)
|
|
105
|
+
|
|
106
|
+
# Batch version: vmap over (budget, util_idx); alpha and gamma are shared
|
|
107
|
+
_jax_bisect_batch = _jit(_vmap(_jax_bisect_single, in_axes=(0, 0, None, None)))
|
|
108
|
+
|
|
109
|
+
|
|
45
110
|
class MDCEVModel:
|
|
46
111
|
"""Translated-utility MDCEV-style allocator.
|
|
47
112
|
|
|
48
113
|
Parameters are learned from observed budget shares using stable moment-based
|
|
49
|
-
heuristics, then
|
|
50
|
-
|
|
114
|
+
heuristics (``fit``), then optionally refined via quasi-MLE with JAX
|
|
115
|
+
automatic differentiation (``fit_mle``). Predictions use a JAX-jitted and
|
|
116
|
+
vmapped bisection solver when JAX is available, with a pure-numpy fallback.
|
|
51
117
|
"""
|
|
52
118
|
|
|
53
119
|
def __init__(
|
|
@@ -152,87 +218,139 @@ class MDCEVModel:
|
|
|
152
218
|
):
|
|
153
219
|
"""Likelihood-based parameter refinement.
|
|
154
220
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
improving fit over pure moments.
|
|
221
|
+
Uses JAX automatic differentiation for exact analytic gradients when
|
|
222
|
+
JAX is available, replacing the slow scipy finite-difference
|
|
223
|
+
approximation. Falls back to scipy FD when JAX is not installed.
|
|
159
224
|
"""
|
|
160
225
|
self.fit(allocations, labels=labels)
|
|
161
226
|
|
|
162
227
|
y = _as_2d_float(allocations)
|
|
163
|
-
|
|
228
|
+
budgets_np = y.sum(axis=1)
|
|
164
229
|
n_alt = y.shape[1]
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
230
|
+
alpha_floor = self.alpha_floor
|
|
231
|
+
alpha_cap = self.alpha_cap
|
|
232
|
+
gamma_floor = self.gamma_floor
|
|
233
|
+
tol = self.tol
|
|
234
|
+
og = self.outside_good
|
|
235
|
+
has_og = og is not None and 0 <= og < n_alt
|
|
236
|
+
free_base_idx = [i for i in range(n_alt) if not has_og or i != og]
|
|
237
|
+
|
|
238
|
+
def _pack_np(base, alpha, gamma, sigma):
|
|
173
239
|
p = []
|
|
174
|
-
p.extend(
|
|
175
|
-
p.extend(np.log(np.clip(
|
|
176
|
-
|
|
240
|
+
p.extend(np.asarray(base)[free_base_idx].tolist())
|
|
241
|
+
p.extend(np.log(np.clip(
|
|
242
|
+
(np.asarray(alpha) - alpha_floor) / np.clip(alpha_cap - np.asarray(alpha), tol, None),
|
|
243
|
+
tol, None)).tolist())
|
|
244
|
+
p.extend(np.log(np.clip(np.asarray(gamma), gamma_floor, None)).tolist())
|
|
177
245
|
p.append(np.log(max(float(sigma), 1e-3)))
|
|
178
246
|
return np.asarray(p, dtype=float)
|
|
179
247
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
248
|
+
theta0 = _pack_np(self.baseline_utility_, self.alpha_, self.gamma_, sigma=0.5)
|
|
249
|
+
|
|
250
|
+
if _JAX_AVAILABLE:
|
|
251
|
+
# ── JAX path: exact gradients via autodiff ────────────────────
|
|
252
|
+
y_jax = jnp.array(y)
|
|
253
|
+
B_jax = jnp.array(budgets_np)
|
|
254
|
+
free_idx_jax = jnp.array(free_base_idx)
|
|
255
|
+
|
|
256
|
+
def _unpack_jax(theta):
|
|
257
|
+
o = 0
|
|
258
|
+
n_free = len(free_base_idx)
|
|
259
|
+
free_base = theta[o:o + n_free]
|
|
260
|
+
o += n_free
|
|
261
|
+
base = jnp.zeros(n_alt).at[free_idx_jax].set(free_base)
|
|
262
|
+
|
|
263
|
+
alpha_raw = theta[o:o + n_alt]
|
|
264
|
+
o += n_alt
|
|
265
|
+
alpha = alpha_floor + (alpha_cap - alpha_floor) * jax.nn.sigmoid(alpha_raw)
|
|
266
|
+
|
|
267
|
+
gamma_raw = theta[o:o + n_alt]
|
|
268
|
+
o += n_alt
|
|
269
|
+
gamma = jnp.maximum(jnp.exp(gamma_raw), gamma_floor)
|
|
270
|
+
|
|
271
|
+
sigma = jnp.maximum(jnp.exp(theta[o]), 1e-3)
|
|
272
|
+
return base, alpha, gamma, sigma
|
|
273
|
+
|
|
274
|
+
@_jit
|
|
275
|
+
def _neg_loglike_jax(theta):
|
|
276
|
+
base, alpha, gamma, sigma = _unpack_jax(theta)
|
|
277
|
+
util_matrix = jnp.broadcast_to(base[None, :], (len(B_jax), n_alt))
|
|
278
|
+
preds = _jax_bisect_batch(B_jax, util_matrix, alpha, gamma)
|
|
279
|
+
log_y = jnp.log(jnp.clip(y_jax, tol, None))
|
|
280
|
+
log_p = jnp.log(jnp.clip(preds, tol, None))
|
|
281
|
+
resid = log_y - log_p
|
|
282
|
+
ll = -0.5 * resid.size * jnp.log(2.0 * jnp.pi * sigma * sigma)
|
|
283
|
+
ll -= 0.5 * jnp.sum((resid / sigma) ** 2)
|
|
284
|
+
ll -= l2_penalty * jnp.sum(theta * theta)
|
|
285
|
+
return -ll
|
|
286
|
+
|
|
287
|
+
val_and_grad = _jit(jax.value_and_grad(_neg_loglike_jax))
|
|
288
|
+
|
|
289
|
+
def _scipy_obj(theta_np):
|
|
290
|
+
val, grad = val_and_grad(jnp.array(theta_np, dtype=jnp.float64))
|
|
291
|
+
return float(val), np.asarray(grad, dtype=np.float64)
|
|
292
|
+
|
|
293
|
+
res = minimize(
|
|
294
|
+
_scipy_obj,
|
|
295
|
+
theta0,
|
|
296
|
+
method="L-BFGS-B",
|
|
297
|
+
jac=True,
|
|
298
|
+
options={"maxiter": int(maxiter), "ftol": 1e-9},
|
|
299
|
+
)
|
|
300
|
+
base_j, alpha_j, gamma_j, sigma_j = _unpack_jax(jnp.array(res.x))
|
|
301
|
+
self.baseline_utility_ = np.asarray(base_j)
|
|
302
|
+
self.alpha_ = np.asarray(alpha_j)
|
|
303
|
+
self.gamma_ = np.asarray(gamma_j)
|
|
304
|
+
self.noise_sigma_ = float(sigma_j)
|
|
305
|
+
|
|
306
|
+
else:
|
|
307
|
+
# ── numpy / scipy FD fallback (original behaviour) ────────────
|
|
308
|
+
def _unpack_np(theta):
|
|
309
|
+
theta = np.asarray(theta, dtype=float)
|
|
310
|
+
o = 0
|
|
311
|
+
base = self.baseline_utility_.copy()
|
|
312
|
+
for idx in free_base_idx:
|
|
313
|
+
base[idx] = theta[o]
|
|
314
|
+
o += 1
|
|
315
|
+
if has_og:
|
|
316
|
+
base[og] = 0.0
|
|
317
|
+
alpha_raw = theta[o:o + n_alt]; o += n_alt
|
|
318
|
+
alpha = alpha_floor + (alpha_cap - alpha_floor) / (1.0 + np.exp(-alpha_raw))
|
|
319
|
+
gamma_raw = theta[o:o + n_alt]; o += n_alt
|
|
320
|
+
gamma = np.maximum(np.exp(gamma_raw), gamma_floor)
|
|
321
|
+
sigma = max(np.exp(theta[o]), 1e-3)
|
|
322
|
+
return base, alpha, gamma, sigma
|
|
323
|
+
|
|
324
|
+
def _neg_loglike_np(theta):
|
|
325
|
+
base, alpha, gamma, sigma = _unpack_np(theta)
|
|
326
|
+
old_b, old_a, old_g = self.baseline_utility_, self.alpha_, self.gamma_
|
|
327
|
+
self.baseline_utility_, self.alpha_, self.gamma_ = base, alpha, gamma
|
|
328
|
+
try:
|
|
329
|
+
mu = np.zeros_like(y)
|
|
330
|
+
for i, b in enumerate(budgets_np):
|
|
331
|
+
mu[i] = self._solve_budget(float(b), base)
|
|
332
|
+
finally:
|
|
333
|
+
self.baseline_utility_, self.alpha_, self.gamma_ = old_b, old_a, old_g
|
|
334
|
+
log_y = np.log(np.clip(y, tol, None))
|
|
335
|
+
log_mu = np.log(np.clip(mu, tol, None))
|
|
336
|
+
resid = log_y - log_mu
|
|
337
|
+
ll = -0.5 * resid.size * np.log(2.0 * np.pi * sigma * sigma)
|
|
338
|
+
ll -= 0.5 * np.sum((resid / sigma) ** 2)
|
|
339
|
+
ll -= l2_penalty * np.sum(theta * theta)
|
|
340
|
+
return -float(ll)
|
|
341
|
+
|
|
342
|
+
res = minimize(
|
|
343
|
+
_neg_loglike_np,
|
|
344
|
+
theta0,
|
|
345
|
+
method="L-BFGS-B",
|
|
346
|
+
options={"maxiter": int(maxiter), "ftol": 1e-9},
|
|
347
|
+
)
|
|
348
|
+
base, alpha, gamma, sigma = _unpack_np(res.x)
|
|
349
|
+
self.baseline_utility_ = base
|
|
350
|
+
self.alpha_ = alpha
|
|
351
|
+
self.gamma_ = gamma
|
|
352
|
+
self.noise_sigma_ = float(sigma)
|
|
230
353
|
|
|
231
|
-
base, alpha, gamma, sigma = _unpack(res.x)
|
|
232
|
-
self.baseline_utility_ = base
|
|
233
|
-
self.alpha_ = alpha
|
|
234
|
-
self.gamma_ = gamma
|
|
235
|
-
self.noise_sigma_ = float(sigma)
|
|
236
354
|
self.mle_success_ = bool(res.success)
|
|
237
355
|
self.mle_message_ = str(res.message)
|
|
238
356
|
return self
|
|
@@ -267,6 +385,16 @@ class MDCEVModel:
|
|
|
267
385
|
budgets_arr = np.asarray(budgets, dtype=float).reshape(-1)
|
|
268
386
|
shifts = self._prepare_utility_shift(utility_shift, len(budgets_arr))
|
|
269
387
|
|
|
388
|
+
if _JAX_AVAILABLE:
|
|
389
|
+
util_matrix = jnp.array(self.baseline_utility_[None, :] + shifts)
|
|
390
|
+
preds = _jax_bisect_batch(
|
|
391
|
+
jnp.array(budgets_arr),
|
|
392
|
+
util_matrix,
|
|
393
|
+
jnp.array(self.alpha_),
|
|
394
|
+
jnp.array(self.gamma_),
|
|
395
|
+
)
|
|
396
|
+
return np.asarray(preds)
|
|
397
|
+
|
|
270
398
|
predictions = np.zeros((len(budgets_arr), len(self.baseline_utility_)), dtype=float)
|
|
271
399
|
for row_idx, budget in enumerate(budgets_arr):
|
|
272
400
|
predictions[row_idx] = self._solve_budget(budget, self.baseline_utility_ + shifts[row_idx])
|
|
@@ -280,6 +408,17 @@ class MDCEVModel:
|
|
|
280
408
|
rng = np.random.default_rng(random_state)
|
|
281
409
|
|
|
282
410
|
sims = np.zeros((n_draws, len(budgets_arr), len(self.baseline_utility_)), dtype=float)
|
|
411
|
+
|
|
412
|
+
if _JAX_AVAILABLE:
|
|
413
|
+
B_jax = jnp.array(budgets_arr)
|
|
414
|
+
alpha_jax = jnp.array(self.alpha_)
|
|
415
|
+
gamma_jax = jnp.array(self.gamma_)
|
|
416
|
+
for draw_idx in range(n_draws):
|
|
417
|
+
shocks = rng.gumbel(loc=0.0, scale=1.0, size=shifts.shape)
|
|
418
|
+
util_matrix = jnp.array(self.baseline_utility_[None, :] + shifts + shocks)
|
|
419
|
+
sims[draw_idx] = np.asarray(_jax_bisect_batch(B_jax, util_matrix, alpha_jax, gamma_jax))
|
|
420
|
+
return sims
|
|
421
|
+
|
|
283
422
|
for draw_idx in range(n_draws):
|
|
284
423
|
shocks = rng.gumbel(loc=0.0, scale=1.0, size=shifts.shape)
|
|
285
424
|
for row_idx, budget in enumerate(budgets_arr):
|
|
@@ -304,6 +443,7 @@ class MDCEVModel:
|
|
|
304
443
|
return shift_arr
|
|
305
444
|
|
|
306
445
|
def _solve_budget(self, budget: float, utility_index: np.ndarray) -> np.ndarray:
|
|
446
|
+
"""Pure-numpy bisection fallback — used when JAX is not available."""
|
|
307
447
|
if budget <= self.tol:
|
|
308
448
|
return np.zeros(len(self.baseline_utility_), dtype=float)
|
|
309
449
|
|
|
@@ -616,7 +616,7 @@ class SA(Search):
|
|
|
616
616
|
def copy_solution(self, sol):
|
|
617
617
|
# {
|
|
618
618
|
logging.info('normal copy')
|
|
619
|
-
copy_sol = sol
|
|
619
|
+
copy_sol = copy.deepcopy(sol) # Deep copy to prevent shared mutable state (lists, dicts inside solution)
|
|
620
620
|
return copy_sol
|
|
621
621
|
# }
|
|
622
622
|
|
|
@@ -998,15 +998,12 @@ class SA(Search):
|
|
|
998
998
|
print("Initialized best_sol with the first solution")
|
|
999
999
|
return
|
|
1000
1000
|
|
|
1001
|
-
self.no_impr = 0
|
|
1002
1001
|
if self.nb_crit == 1:
|
|
1003
1002
|
# {
|
|
1004
1003
|
if is_better(sol.obj(0), self.best_sol.obj(0), self.param.sign_crit(0)):
|
|
1005
1004
|
self.best_sol = self.copy_solution(sol)
|
|
1006
|
-
|
|
1007
|
-
print('new best')
|
|
1008
|
-
|
|
1009
1005
|
self.no_impr = 0
|
|
1006
|
+
print('new best')
|
|
1010
1007
|
# }
|
|
1011
1008
|
else:
|
|
1012
1009
|
self.archive = self.update_archive(sol)
|
|
@@ -1087,7 +1084,7 @@ class SA(Search):
|
|
|
1087
1084
|
# {
|
|
1088
1085
|
print("ARCHIVE STATIC. RESTORE A NON DOMINATED SOLUTION.")
|
|
1089
1086
|
choice = np.random.randint(len(self.archive))
|
|
1090
|
-
self.current_sol = self.archive[choice]
|
|
1087
|
+
self.current_sol = self.copy_solution(self.archive[choice]) # Deep copy to avoid aliasing archive entry
|
|
1091
1088
|
self.no_impr = 0
|
|
1092
1089
|
# }
|
|
1093
1090
|
# }
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
0.0.86
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
0.0.85
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{searchlibrium-0.0.85 → searchlibrium-0.0.86}/src/SearchLibrium.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|