SearchLibrium 0.0.85__tar.gz → 0.0.87__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/PKG-INFO +1 -1
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/pyproject.toml +1 -1
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/__init__.py +2 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/mdcev.py +214 -74
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/multinomial_logit.py +1 -1
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/multinomial_nested.py +1 -1
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/ordered_logit.py +2 -2
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/rrm.py +9 -1
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/siman.py +3 -6
- searchlibrium-0.0.87/src/SearchLibrium/version.txt +1 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/PKG-INFO +1 -1
- searchlibrium-0.0.85/src/SearchLibrium/version.txt +0 -1
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/README.md +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/setup.cfg +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/Halton.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/MixedLogit.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/Mode_Activity_Nested.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/RandomP.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/SEARCH_SM_MARIO.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/Two_Level_Nest.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/__main__.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/_choice_model.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/_device.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/bhhh/minimize.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/boxcox_functions.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/call_meta.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/constraints_builder.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/harmony.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/latent_class.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/main.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/main_debug.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/misc.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/mixed_logit.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/mixed_nested.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/mixedrrm.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/multinomial_probit.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/ordered_logit_mixed.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/search.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/selection_models.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/setup.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/threshold.py +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/SOURCES.txt +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/dependency_links.txt +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/entry_points.txt +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/requires.txt +0 -0
- {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/top_level.txt +0 -0
|
@@ -59,7 +59,7 @@ Homepage = "https://github.com/zahern/HypothesisX"
|
|
|
59
59
|
realpython = "SearchLibrium.__main__:main"
|
|
60
60
|
|
|
61
61
|
[tool.bumpver]
|
|
62
|
-
current_version = "0.0.
|
|
62
|
+
current_version = "0.0.87"
|
|
63
63
|
version_pattern = "MAJOR.MINOR.PATCH"
|
|
64
64
|
commit_message = "[skip ci] Bump version {old_version} -> {new_version}"
|
|
65
65
|
commit = true
|
|
@@ -88,6 +88,7 @@ except ImportError:
|
|
|
88
88
|
try:
|
|
89
89
|
from ._choice_model import DiscreteChoiceModel
|
|
90
90
|
from .multinomial_logit import MultinomialLogit
|
|
91
|
+
from .MixedLogit import MixedLogit
|
|
91
92
|
from .multinomial_nested import NestedLogit, MultiLayerNestedLogit
|
|
92
93
|
from .Halton import Halton
|
|
93
94
|
from .rrm import RandomRegret
|
|
@@ -107,6 +108,7 @@ try:
|
|
|
107
108
|
except ImportError as e:
|
|
108
109
|
from _choice_model import DiscreteChoiceModel
|
|
109
110
|
from multinomial_logit import MultinomialLogit
|
|
111
|
+
from MixedLogit import MixedLogit
|
|
110
112
|
from multinomial_nested import NestedLogit, MultiLayerNestedLogit
|
|
111
113
|
from Halton import Halton
|
|
112
114
|
from rrm import RandomRegret
|
|
@@ -9,11 +9,18 @@ conditions.
|
|
|
9
9
|
|
|
10
10
|
The class is intended as a practical bridge between the current scalar budget
|
|
11
11
|
models and a fuller MDCEV pipeline. It includes both a stable heuristic fit
|
|
12
|
-
and a
|
|
12
|
+
and a JAX-accelerated quasi-MLE refinement with exact automatic differentiation.
|
|
13
|
+
|
|
14
|
+
JAX is used for:
|
|
15
|
+
- Bisection solver: JIT-compiled + vmapped over observations (fast predict/simulate)
|
|
16
|
+
- fit_mle gradients: exact autodiff via jax.grad instead of scipy FD approximations
|
|
17
|
+
|
|
18
|
+
A pure-numpy fallback is provided when JAX is not installed.
|
|
13
19
|
"""
|
|
14
20
|
|
|
15
21
|
from __future__ import annotations
|
|
16
22
|
|
|
23
|
+
import os as _os
|
|
17
24
|
from dataclasses import dataclass
|
|
18
25
|
from typing import Iterable, Optional
|
|
19
26
|
|
|
@@ -21,6 +28,17 @@ import numpy as np
|
|
|
21
28
|
import pandas as pd
|
|
22
29
|
from scipy.optimize import minimize
|
|
23
30
|
|
|
31
|
+
# ── JAX optional import ──────────────────────────────────────────────────────
|
|
32
|
+
_os.environ.setdefault("JAX_PLATFORMS", "cpu")
|
|
33
|
+
try:
|
|
34
|
+
import jax
|
|
35
|
+
import jax.numpy as jnp
|
|
36
|
+
from jax import jit as _jit, vmap as _vmap
|
|
37
|
+
jax.config.update("jax_enable_x64", True)
|
|
38
|
+
_JAX_AVAILABLE = True
|
|
39
|
+
except Exception:
|
|
40
|
+
_JAX_AVAILABLE = False
|
|
41
|
+
|
|
24
42
|
|
|
25
43
|
def _as_2d_float(array_like) -> np.ndarray:
|
|
26
44
|
arr = np.asarray(array_like, dtype=float)
|
|
@@ -42,12 +60,60 @@ class MDCEVFitResult:
|
|
|
42
60
|
mean_budget: float
|
|
43
61
|
|
|
44
62
|
|
|
63
|
+
# ── JAX bisection kernel (module-level, JIT-compiled once) ───────────────────
|
|
64
|
+
if _JAX_AVAILABLE:
|
|
65
|
+
def _jax_bisect_single(budget, util_idx, alpha, gamma):
|
|
66
|
+
"""Translated-utility KKT bisection for one observation.
|
|
67
|
+
|
|
68
|
+
Fully JAX-jittable and differentiable: uses ``jax.lax.scan`` for both
|
|
69
|
+
the initial bracket search and the bisection, so gradients flow through
|
|
70
|
+
the solver and ``fit_mle`` can use exact autodiff.
|
|
71
|
+
"""
|
|
72
|
+
_tol = 1e-9
|
|
73
|
+
weights = jnp.exp(jnp.clip(util_idx, -40.0, 40.0))
|
|
74
|
+
|
|
75
|
+
def _alloc_sum(lam):
|
|
76
|
+
power = 1.0 / jnp.clip(1.0 - alpha, _tol, None)
|
|
77
|
+
raw = jnp.power(weights / jnp.maximum(lam, _tol), power) - gamma
|
|
78
|
+
return jnp.maximum(raw, 0.0).sum()
|
|
79
|
+
|
|
80
|
+
# Find upper bracket: keep doubling hi until alloc_sum(hi) <= budget.
|
|
81
|
+
# scan (max 60 steps) is differentiable; once condition is false hi stays fixed.
|
|
82
|
+
init_hi = jnp.maximum(jnp.max(weights), 1.0)
|
|
83
|
+
|
|
84
|
+
def _double_step(hi, _):
|
|
85
|
+
return jnp.where(_alloc_sum(hi) > budget, hi * 2.0, hi), None
|
|
86
|
+
|
|
87
|
+
hi, _ = jax.lax.scan(_double_step, init_hi, None, length=60)
|
|
88
|
+
|
|
89
|
+
# 80-step bisection
|
|
90
|
+
def _bisect_step(carry, _):
|
|
91
|
+
lo, hi = carry
|
|
92
|
+
mid = 0.5 * (lo + hi)
|
|
93
|
+
go_lo = _alloc_sum(mid) > budget
|
|
94
|
+
return (jnp.where(go_lo, mid, lo), jnp.where(go_lo, hi, mid)), None
|
|
95
|
+
|
|
96
|
+
(_, hi_f), _ = jax.lax.scan(_bisect_step, (_tol, hi), None, length=80)
|
|
97
|
+
|
|
98
|
+
power = 1.0 / jnp.clip(1.0 - alpha, _tol, None)
|
|
99
|
+
raw = jnp.power(weights / jnp.maximum(hi_f, _tol), power) - gamma
|
|
100
|
+
allocation = jnp.maximum(raw, 0.0)
|
|
101
|
+
total = allocation.sum()
|
|
102
|
+
allocation = jnp.where(total > _tol, allocation * (budget / total), allocation)
|
|
103
|
+
# zero-budget guard
|
|
104
|
+
return jnp.where(budget <= _tol, jnp.zeros_like(allocation), allocation)
|
|
105
|
+
|
|
106
|
+
# Batch version: vmap over (budget, util_idx); alpha and gamma are shared
|
|
107
|
+
_jax_bisect_batch = _jit(_vmap(_jax_bisect_single, in_axes=(0, 0, None, None)))
|
|
108
|
+
|
|
109
|
+
|
|
45
110
|
class MDCEVModel:
|
|
46
111
|
"""Translated-utility MDCEV-style allocator.
|
|
47
112
|
|
|
48
113
|
Parameters are learned from observed budget shares using stable moment-based
|
|
49
|
-
heuristics, then
|
|
50
|
-
|
|
114
|
+
heuristics (``fit``), then optionally refined via quasi-MLE with JAX
|
|
115
|
+
automatic differentiation (``fit_mle``). Predictions use a JAX-jitted and
|
|
116
|
+
vmapped bisection solver when JAX is available, with a pure-numpy fallback.
|
|
51
117
|
"""
|
|
52
118
|
|
|
53
119
|
def __init__(
|
|
@@ -152,87 +218,139 @@ class MDCEVModel:
|
|
|
152
218
|
):
|
|
153
219
|
"""Likelihood-based parameter refinement.
|
|
154
220
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
improving fit over pure moments.
|
|
221
|
+
Uses JAX automatic differentiation for exact analytic gradients when
|
|
222
|
+
JAX is available, replacing the slow scipy finite-difference
|
|
223
|
+
approximation. Falls back to scipy FD when JAX is not installed.
|
|
159
224
|
"""
|
|
160
225
|
self.fit(allocations, labels=labels)
|
|
161
226
|
|
|
162
227
|
y = _as_2d_float(allocations)
|
|
163
|
-
|
|
228
|
+
budgets_np = y.sum(axis=1)
|
|
164
229
|
n_alt = y.shape[1]
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
230
|
+
alpha_floor = self.alpha_floor
|
|
231
|
+
alpha_cap = self.alpha_cap
|
|
232
|
+
gamma_floor = self.gamma_floor
|
|
233
|
+
tol = self.tol
|
|
234
|
+
og = self.outside_good
|
|
235
|
+
has_og = og is not None and 0 <= og < n_alt
|
|
236
|
+
free_base_idx = [i for i in range(n_alt) if not has_og or i != og]
|
|
237
|
+
|
|
238
|
+
def _pack_np(base, alpha, gamma, sigma):
|
|
173
239
|
p = []
|
|
174
|
-
p.extend(
|
|
175
|
-
p.extend(np.log(np.clip(
|
|
176
|
-
|
|
240
|
+
p.extend(np.asarray(base)[free_base_idx].tolist())
|
|
241
|
+
p.extend(np.log(np.clip(
|
|
242
|
+
(np.asarray(alpha) - alpha_floor) / np.clip(alpha_cap - np.asarray(alpha), tol, None),
|
|
243
|
+
tol, None)).tolist())
|
|
244
|
+
p.extend(np.log(np.clip(np.asarray(gamma), gamma_floor, None)).tolist())
|
|
177
245
|
p.append(np.log(max(float(sigma), 1e-3)))
|
|
178
246
|
return np.asarray(p, dtype=float)
|
|
179
247
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
248
|
+
theta0 = _pack_np(self.baseline_utility_, self.alpha_, self.gamma_, sigma=0.5)
|
|
249
|
+
|
|
250
|
+
if _JAX_AVAILABLE:
|
|
251
|
+
# ── JAX path: exact gradients via autodiff ────────────────────
|
|
252
|
+
y_jax = jnp.array(y)
|
|
253
|
+
B_jax = jnp.array(budgets_np)
|
|
254
|
+
free_idx_jax = jnp.array(free_base_idx)
|
|
255
|
+
|
|
256
|
+
def _unpack_jax(theta):
|
|
257
|
+
o = 0
|
|
258
|
+
n_free = len(free_base_idx)
|
|
259
|
+
free_base = theta[o:o + n_free]
|
|
260
|
+
o += n_free
|
|
261
|
+
base = jnp.zeros(n_alt).at[free_idx_jax].set(free_base)
|
|
262
|
+
|
|
263
|
+
alpha_raw = theta[o:o + n_alt]
|
|
264
|
+
o += n_alt
|
|
265
|
+
alpha = alpha_floor + (alpha_cap - alpha_floor) * jax.nn.sigmoid(alpha_raw)
|
|
266
|
+
|
|
267
|
+
gamma_raw = theta[o:o + n_alt]
|
|
268
|
+
o += n_alt
|
|
269
|
+
gamma = jnp.maximum(jnp.exp(gamma_raw), gamma_floor)
|
|
270
|
+
|
|
271
|
+
sigma = jnp.maximum(jnp.exp(theta[o]), 1e-3)
|
|
272
|
+
return base, alpha, gamma, sigma
|
|
273
|
+
|
|
274
|
+
@_jit
|
|
275
|
+
def _neg_loglike_jax(theta):
|
|
276
|
+
base, alpha, gamma, sigma = _unpack_jax(theta)
|
|
277
|
+
util_matrix = jnp.broadcast_to(base[None, :], (len(B_jax), n_alt))
|
|
278
|
+
preds = _jax_bisect_batch(B_jax, util_matrix, alpha, gamma)
|
|
279
|
+
log_y = jnp.log(jnp.clip(y_jax, tol, None))
|
|
280
|
+
log_p = jnp.log(jnp.clip(preds, tol, None))
|
|
281
|
+
resid = log_y - log_p
|
|
282
|
+
ll = -0.5 * resid.size * jnp.log(2.0 * jnp.pi * sigma * sigma)
|
|
283
|
+
ll -= 0.5 * jnp.sum((resid / sigma) ** 2)
|
|
284
|
+
ll -= l2_penalty * jnp.sum(theta * theta)
|
|
285
|
+
return -ll
|
|
286
|
+
|
|
287
|
+
val_and_grad = _jit(jax.value_and_grad(_neg_loglike_jax))
|
|
288
|
+
|
|
289
|
+
def _scipy_obj(theta_np):
|
|
290
|
+
val, grad = val_and_grad(jnp.array(theta_np, dtype=jnp.float64))
|
|
291
|
+
return float(val), np.asarray(grad, dtype=np.float64)
|
|
292
|
+
|
|
293
|
+
res = minimize(
|
|
294
|
+
_scipy_obj,
|
|
295
|
+
theta0,
|
|
296
|
+
method="L-BFGS-B",
|
|
297
|
+
jac=True,
|
|
298
|
+
options={"maxiter": int(maxiter), "ftol": 1e-9},
|
|
299
|
+
)
|
|
300
|
+
base_j, alpha_j, gamma_j, sigma_j = _unpack_jax(jnp.array(res.x))
|
|
301
|
+
self.baseline_utility_ = np.asarray(base_j)
|
|
302
|
+
self.alpha_ = np.asarray(alpha_j)
|
|
303
|
+
self.gamma_ = np.asarray(gamma_j)
|
|
304
|
+
self.noise_sigma_ = float(sigma_j)
|
|
305
|
+
|
|
306
|
+
else:
|
|
307
|
+
# ── numpy / scipy FD fallback (original behaviour) ────────────
|
|
308
|
+
def _unpack_np(theta):
|
|
309
|
+
theta = np.asarray(theta, dtype=float)
|
|
310
|
+
o = 0
|
|
311
|
+
base = self.baseline_utility_.copy()
|
|
312
|
+
for idx in free_base_idx:
|
|
313
|
+
base[idx] = theta[o]
|
|
314
|
+
o += 1
|
|
315
|
+
if has_og:
|
|
316
|
+
base[og] = 0.0
|
|
317
|
+
alpha_raw = theta[o:o + n_alt]; o += n_alt
|
|
318
|
+
alpha = alpha_floor + (alpha_cap - alpha_floor) / (1.0 + np.exp(-alpha_raw))
|
|
319
|
+
gamma_raw = theta[o:o + n_alt]; o += n_alt
|
|
320
|
+
gamma = np.maximum(np.exp(gamma_raw), gamma_floor)
|
|
321
|
+
sigma = max(np.exp(theta[o]), 1e-3)
|
|
322
|
+
return base, alpha, gamma, sigma
|
|
323
|
+
|
|
324
|
+
def _neg_loglike_np(theta):
|
|
325
|
+
base, alpha, gamma, sigma = _unpack_np(theta)
|
|
326
|
+
old_b, old_a, old_g = self.baseline_utility_, self.alpha_, self.gamma_
|
|
327
|
+
self.baseline_utility_, self.alpha_, self.gamma_ = base, alpha, gamma
|
|
328
|
+
try:
|
|
329
|
+
mu = np.zeros_like(y)
|
|
330
|
+
for i, b in enumerate(budgets_np):
|
|
331
|
+
mu[i] = self._solve_budget(float(b), base)
|
|
332
|
+
finally:
|
|
333
|
+
self.baseline_utility_, self.alpha_, self.gamma_ = old_b, old_a, old_g
|
|
334
|
+
log_y = np.log(np.clip(y, tol, None))
|
|
335
|
+
log_mu = np.log(np.clip(mu, tol, None))
|
|
336
|
+
resid = log_y - log_mu
|
|
337
|
+
ll = -0.5 * resid.size * np.log(2.0 * np.pi * sigma * sigma)
|
|
338
|
+
ll -= 0.5 * np.sum((resid / sigma) ** 2)
|
|
339
|
+
ll -= l2_penalty * np.sum(theta * theta)
|
|
340
|
+
return -float(ll)
|
|
341
|
+
|
|
342
|
+
res = minimize(
|
|
343
|
+
_neg_loglike_np,
|
|
344
|
+
theta0,
|
|
345
|
+
method="L-BFGS-B",
|
|
346
|
+
options={"maxiter": int(maxiter), "ftol": 1e-9},
|
|
347
|
+
)
|
|
348
|
+
base, alpha, gamma, sigma = _unpack_np(res.x)
|
|
349
|
+
self.baseline_utility_ = base
|
|
350
|
+
self.alpha_ = alpha
|
|
351
|
+
self.gamma_ = gamma
|
|
352
|
+
self.noise_sigma_ = float(sigma)
|
|
230
353
|
|
|
231
|
-
base, alpha, gamma, sigma = _unpack(res.x)
|
|
232
|
-
self.baseline_utility_ = base
|
|
233
|
-
self.alpha_ = alpha
|
|
234
|
-
self.gamma_ = gamma
|
|
235
|
-
self.noise_sigma_ = float(sigma)
|
|
236
354
|
self.mle_success_ = bool(res.success)
|
|
237
355
|
self.mle_message_ = str(res.message)
|
|
238
356
|
return self
|
|
@@ -267,6 +385,16 @@ class MDCEVModel:
|
|
|
267
385
|
budgets_arr = np.asarray(budgets, dtype=float).reshape(-1)
|
|
268
386
|
shifts = self._prepare_utility_shift(utility_shift, len(budgets_arr))
|
|
269
387
|
|
|
388
|
+
if _JAX_AVAILABLE:
|
|
389
|
+
util_matrix = jnp.array(self.baseline_utility_[None, :] + shifts)
|
|
390
|
+
preds = _jax_bisect_batch(
|
|
391
|
+
jnp.array(budgets_arr),
|
|
392
|
+
util_matrix,
|
|
393
|
+
jnp.array(self.alpha_),
|
|
394
|
+
jnp.array(self.gamma_),
|
|
395
|
+
)
|
|
396
|
+
return np.asarray(preds)
|
|
397
|
+
|
|
270
398
|
predictions = np.zeros((len(budgets_arr), len(self.baseline_utility_)), dtype=float)
|
|
271
399
|
for row_idx, budget in enumerate(budgets_arr):
|
|
272
400
|
predictions[row_idx] = self._solve_budget(budget, self.baseline_utility_ + shifts[row_idx])
|
|
@@ -280,6 +408,17 @@ class MDCEVModel:
|
|
|
280
408
|
rng = np.random.default_rng(random_state)
|
|
281
409
|
|
|
282
410
|
sims = np.zeros((n_draws, len(budgets_arr), len(self.baseline_utility_)), dtype=float)
|
|
411
|
+
|
|
412
|
+
if _JAX_AVAILABLE:
|
|
413
|
+
B_jax = jnp.array(budgets_arr)
|
|
414
|
+
alpha_jax = jnp.array(self.alpha_)
|
|
415
|
+
gamma_jax = jnp.array(self.gamma_)
|
|
416
|
+
for draw_idx in range(n_draws):
|
|
417
|
+
shocks = rng.gumbel(loc=0.0, scale=1.0, size=shifts.shape)
|
|
418
|
+
util_matrix = jnp.array(self.baseline_utility_[None, :] + shifts + shocks)
|
|
419
|
+
sims[draw_idx] = np.asarray(_jax_bisect_batch(B_jax, util_matrix, alpha_jax, gamma_jax))
|
|
420
|
+
return sims
|
|
421
|
+
|
|
283
422
|
for draw_idx in range(n_draws):
|
|
284
423
|
shocks = rng.gumbel(loc=0.0, scale=1.0, size=shifts.shape)
|
|
285
424
|
for row_idx, budget in enumerate(budgets_arr):
|
|
@@ -304,6 +443,7 @@ class MDCEVModel:
|
|
|
304
443
|
return shift_arr
|
|
305
444
|
|
|
306
445
|
def _solve_budget(self, budget: float, utility_index: np.ndarray) -> np.ndarray:
|
|
446
|
+
"""Pure-numpy bisection fallback — used when JAX is not available."""
|
|
307
447
|
if budget <= self.tol:
|
|
308
448
|
return np.zeros(len(self.baseline_utility_), dtype=float)
|
|
309
449
|
|
|
@@ -626,7 +626,7 @@ class MultinomialLogit(DiscreteChoiceModel):
|
|
|
626
626
|
''' ---------------------------------------------------------- '''
|
|
627
627
|
''' Function '''
|
|
628
628
|
''' ---------------------------------------------------------- '''
|
|
629
|
-
def scipy_bfgs_optimization(self, betas, X, y, weights, avail, maxiter, ftol, gtol, jac, return_opg, sklearn=None):
|
|
629
|
+
def scipy_bfgs_optimization(self, betas, X, y, weights, avail, maxiter, ftol, gtol, jac, return_opg=False, sklearn=None):
|
|
630
630
|
# {
|
|
631
631
|
|
|
632
632
|
|
|
@@ -906,7 +906,7 @@ class NestedLogit(MultinomialLogit):
|
|
|
906
906
|
V_j = float(V[i, chosen_alt]) # V_{A1}
|
|
907
907
|
sumPV_g = float(sumPV_per_nest[chosen_nest_name][i]) # Σ_k P(k|A) V_k
|
|
908
908
|
logS_g = float(logS_per_nest[chosen_nest_name][i]) # ln S_A
|
|
909
|
-
T_i = float(T[i])
|
|
909
|
+
T_i = float(T[i, 0]) # T has shape (N,1) due to keepdims=True
|
|
910
910
|
|
|
911
911
|
# TERM 1: ( sumPV_g - V_j ) / lambda_g^2
|
|
912
912
|
term1 = (sumPV_g - V_j) / (lambda_g ** 2)
|
|
@@ -296,7 +296,7 @@ class OrderedLogit():
|
|
|
296
296
|
def find_category(self, values: np.ndarray, thresholds: np.ndarray)->np.ndarray:
|
|
297
297
|
# {
|
|
298
298
|
category = np.digitize(values, thresholds)
|
|
299
|
-
return category
|
|
299
|
+
return category
|
|
300
300
|
# Note: Subtract - 1 to convert indexing starting from 1 to indexing starting from 0
|
|
301
301
|
# }
|
|
302
302
|
|
|
@@ -699,7 +699,7 @@ class OrderedLogit():
|
|
|
699
699
|
self.method = method
|
|
700
700
|
|
|
701
701
|
if start is None:
|
|
702
|
-
start = [0] * self.
|
|
702
|
+
start = [0] * self.nparams
|
|
703
703
|
value = [1] * (self.J - 2) # These are the deltas
|
|
704
704
|
set_last_elements(start, self.J - 2, value)
|
|
705
705
|
|
|
@@ -206,7 +206,7 @@ class RandomRegret(DiscreteChoiceModel):
|
|
|
206
206
|
|
|
207
207
|
X, y, varnames, alts, isvars, transvars, ids, weights, panels, avail = \
|
|
208
208
|
self.set_asarray(X, y, varnames, alts, isvars, transvars, ids, weights, None, avail)
|
|
209
|
-
self.nb_samples, self.nb_attr =
|
|
209
|
+
self.nb_samples, self.nb_attr = X.shape
|
|
210
210
|
self.nb_alt = len(set(alts))
|
|
211
211
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
212
212
|
# CHECK FOR MISTAKES IN DATA
|
|
@@ -247,6 +247,13 @@ class RandomRegret(DiscreteChoiceModel):
|
|
|
247
247
|
self.fixedtransvars = self.transvars
|
|
248
248
|
self.X, self.Xnames = self.setup_design_matrix(self.X)
|
|
249
249
|
|
|
250
|
+
# Update nb_samples to reflect per-observation count (after 3D reshape)
|
|
251
|
+
self.nb_samples = self.N
|
|
252
|
+
self.nb_attr = self.X.shape[2] if self.X.ndim == 3 else self.X.shape[1]
|
|
253
|
+
# Convert binary choice indicators (N*J,) to per-obs chosen alt index (N,)
|
|
254
|
+
y_matrix = self.y.reshape(self.N, self.J)
|
|
255
|
+
self.y = np.argmax(y_matrix, axis=1).astype(int)
|
|
256
|
+
|
|
250
257
|
|
|
251
258
|
|
|
252
259
|
if self.weights is not None:
|
|
@@ -283,6 +290,7 @@ class RandomRegret(DiscreteChoiceModel):
|
|
|
283
290
|
'''
|
|
284
291
|
self.y = np.zeros((self.nb_samples), dtype=int)
|
|
285
292
|
self.attrs = kwargs.get("varnames", varnames)
|
|
293
|
+
self.normalize = kwargs.get('normalize', False)
|
|
286
294
|
self.initialise()
|
|
287
295
|
|
|
288
296
|
|
|
@@ -616,7 +616,7 @@ class SA(Search):
|
|
|
616
616
|
def copy_solution(self, sol):
|
|
617
617
|
# {
|
|
618
618
|
logging.info('normal copy')
|
|
619
|
-
copy_sol = sol
|
|
619
|
+
copy_sol = copy.deepcopy(sol) # Deep copy to prevent shared mutable state (lists, dicts inside solution)
|
|
620
620
|
return copy_sol
|
|
621
621
|
# }
|
|
622
622
|
|
|
@@ -998,15 +998,12 @@ class SA(Search):
|
|
|
998
998
|
print("Initialized best_sol with the first solution")
|
|
999
999
|
return
|
|
1000
1000
|
|
|
1001
|
-
self.no_impr = 0
|
|
1002
1001
|
if self.nb_crit == 1:
|
|
1003
1002
|
# {
|
|
1004
1003
|
if is_better(sol.obj(0), self.best_sol.obj(0), self.param.sign_crit(0)):
|
|
1005
1004
|
self.best_sol = self.copy_solution(sol)
|
|
1006
|
-
|
|
1007
|
-
print('new best')
|
|
1008
|
-
|
|
1009
1005
|
self.no_impr = 0
|
|
1006
|
+
print('new best')
|
|
1010
1007
|
# }
|
|
1011
1008
|
else:
|
|
1012
1009
|
self.archive = self.update_archive(sol)
|
|
@@ -1087,7 +1084,7 @@ class SA(Search):
|
|
|
1087
1084
|
# {
|
|
1088
1085
|
print("ARCHIVE STATIC. RESTORE A NON DOMINATED SOLUTION.")
|
|
1089
1086
|
choice = np.random.randint(len(self.archive))
|
|
1090
|
-
self.current_sol = self.archive[choice]
|
|
1087
|
+
self.current_sol = self.copy_solution(self.archive[choice]) # Deep copy to avoid aliasing archive entry
|
|
1091
1088
|
self.no_impr = 0
|
|
1092
1089
|
# }
|
|
1093
1090
|
# }
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
0.0.87
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
0.0.85
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|