SearchLibrium 0.0.85__tar.gz → 0.0.87__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/PKG-INFO +1 -1
  2. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/pyproject.toml +1 -1
  3. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/__init__.py +2 -0
  4. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/mdcev.py +214 -74
  5. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/multinomial_logit.py +1 -1
  6. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/multinomial_nested.py +1 -1
  7. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/ordered_logit.py +2 -2
  8. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/rrm.py +9 -1
  9. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/siman.py +3 -6
  10. searchlibrium-0.0.87/src/SearchLibrium/version.txt +1 -0
  11. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/PKG-INFO +1 -1
  12. searchlibrium-0.0.85/src/SearchLibrium/version.txt +0 -1
  13. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/README.md +0 -0
  14. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/setup.cfg +0 -0
  15. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/Halton.py +0 -0
  16. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/MixedLogit.py +0 -0
  17. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/Mode_Activity_Nested.py +0 -0
  18. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/RandomP.py +0 -0
  19. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/SEARCH_SM_MARIO.py +0 -0
  20. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/Two_Level_Nest.py +0 -0
  21. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/__main__.py +0 -0
  22. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/_choice_model.py +0 -0
  23. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/_device.py +0 -0
  24. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/bhhh/minimize.py +0 -0
  25. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/boxcox_functions.py +0 -0
  26. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/call_meta.py +0 -0
  27. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/constraints_builder.py +0 -0
  28. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/harmony.py +0 -0
  29. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/latent_class.py +0 -0
  30. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/main.py +0 -0
  31. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/main_debug.py +0 -0
  32. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/misc.py +0 -0
  33. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/mixed_logit.py +0 -0
  34. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/mixed_nested.py +0 -0
  35. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/mixedrrm.py +0 -0
  36. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/multinomial_probit.py +0 -0
  37. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/ordered_logit_mixed.py +0 -0
  38. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/search.py +0 -0
  39. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/selection_models.py +0 -0
  40. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/setup.py +0 -0
  41. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium/threshold.py +0 -0
  42. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/SOURCES.txt +0 -0
  43. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/dependency_links.txt +0 -0
  44. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/entry_points.txt +0 -0
  45. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/requires.txt +0 -0
  46. {searchlibrium-0.0.85 → searchlibrium-0.0.87}/src/SearchLibrium.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SearchLibrium
3
- Version: 0.0.85
3
+ Version: 0.0.87
4
4
  Summary: A Python package for econometric models driven by search
5
5
  Author: Alexander Paz Prithvi Beeramole, Robert Burdett
6
6
  Author-email: Zeke Ahern <z.ahern@qut.edu.au>
@@ -59,7 +59,7 @@ Homepage = "https://github.com/zahern/HypothesisX"
59
59
  realpython = "SearchLibrium.__main__:main"
60
60
 
61
61
  [tool.bumpver]
62
- current_version = "0.0.85"
62
+ current_version = "0.0.87"
63
63
  version_pattern = "MAJOR.MINOR.PATCH"
64
64
  commit_message = "[skip ci] Bump version {old_version} -> {new_version}"
65
65
  commit = true
@@ -88,6 +88,7 @@ except ImportError:
88
88
  try:
89
89
  from ._choice_model import DiscreteChoiceModel
90
90
  from .multinomial_logit import MultinomialLogit
91
+ from .MixedLogit import MixedLogit
91
92
  from .multinomial_nested import NestedLogit, MultiLayerNestedLogit
92
93
  from .Halton import Halton
93
94
  from .rrm import RandomRegret
@@ -107,6 +108,7 @@ try:
107
108
  except ImportError as e:
108
109
  from _choice_model import DiscreteChoiceModel
109
110
  from multinomial_logit import MultinomialLogit
111
+ from MixedLogit import MixedLogit
110
112
  from multinomial_nested import NestedLogit, MultiLayerNestedLogit
111
113
  from Halton import Halton
112
114
  from rrm import RandomRegret
@@ -9,11 +9,18 @@ conditions.
9
9
 
10
10
  The class is intended as a practical bridge between the current scalar budget
11
11
  models and a fuller MDCEV pipeline. It includes both a stable heuristic fit
12
- and a likelihood-based quasi-MLE refinement.
12
+ and a JAX-accelerated quasi-MLE refinement using automatic differentiation.
13
+
14
+ JAX is used for:
15
+ - Bisection solver: JIT-compiled + vmapped over observations (fast predict/simulate)
16
+ - fit_mle gradients: autodiff (jax.value_and_grad) through the unrolled solver instead of scipy FD approximations
17
+
18
+ A pure-numpy fallback is provided when JAX is not installed.
13
19
  """
14
20
 
15
21
  from __future__ import annotations
16
22
 
23
+ import os as _os
17
24
  from dataclasses import dataclass
18
25
  from typing import Iterable, Optional
19
26
 
@@ -21,6 +28,17 @@ import numpy as np
21
28
  import pandas as pd
22
29
  from scipy.optimize import minimize
23
30
 
31
+ # ── JAX optional import (defaults JAX_PLATFORMS=cpu; enables x64 globally) ──
32
+ _os.environ.setdefault("JAX_PLATFORMS", "cpu")
33
+ try:
34
+ import jax
35
+ import jax.numpy as jnp
36
+ from jax import jit as _jit, vmap as _vmap
37
+ jax.config.update("jax_enable_x64", True)
38
+ _JAX_AVAILABLE = True
39
+ except Exception:
40
+ _JAX_AVAILABLE = False
41
+
24
42
 
25
43
  def _as_2d_float(array_like) -> np.ndarray:
26
44
  arr = np.asarray(array_like, dtype=float)
@@ -42,12 +60,60 @@ class MDCEVFitResult:
42
60
  mean_budget: float
43
61
 
44
62
 
63
+ # ── JAX bisection kernel (module-level, JIT-compiled once) ───────────────────
64
+ if _JAX_AVAILABLE:
65
+ def _jax_bisect_single(budget, util_idx, alpha, gamma):
66
+ """Translated-utility KKT bisection for one observation.
67
+
68
+ Fully JAX-jittable and differentiable: uses ``jax.lax.scan`` for both
69
+ the initial bracket search and the bisection, so ``fit_mle`` can autodiff
70
+ through the unrolled solver (approximate: branch decisions carry no derivative).
71
+ """
72
+ _tol = 1e-9
73
+ weights = jnp.exp(jnp.clip(util_idx, -40.0, 40.0))
74
+
75
+ def _alloc_sum(lam):
76
+ power = 1.0 / jnp.clip(1.0 - alpha, _tol, None)
77
+ raw = jnp.power(weights / jnp.maximum(lam, _tol), power) - gamma
78
+ return jnp.maximum(raw, 0.0).sum()
79
+
80
+ # Find upper bracket: keep doubling hi until alloc_sum(hi) <= budget.
81
+ # scan (max 60 steps) is differentiable; once condition is false hi stays fixed.
82
+ init_hi = jnp.maximum(jnp.max(weights), 1.0)
83
+
84
+ def _double_step(hi, _):
85
+ return jnp.where(_alloc_sum(hi) > budget, hi * 2.0, hi), None
86
+
87
+ hi, _ = jax.lax.scan(_double_step, init_hi, None, length=60)
88
+
89
+ # 80-step bisection
90
+ def _bisect_step(carry, _):
91
+ lo, hi = carry
92
+ mid = 0.5 * (lo + hi)
93
+ go_lo = _alloc_sum(mid) > budget
94
+ return (jnp.where(go_lo, mid, lo), jnp.where(go_lo, hi, mid)), None
95
+
96
+ (_, hi_f), _ = jax.lax.scan(_bisect_step, (_tol, hi), None, length=80)
97
+
98
+ power = 1.0 / jnp.clip(1.0 - alpha, _tol, None)
99
+ raw = jnp.power(weights / jnp.maximum(hi_f, _tol), power) - gamma
100
+ allocation = jnp.maximum(raw, 0.0)
101
+ total = allocation.sum()
102
+ allocation = jnp.where(total > _tol, allocation * (budget / total), allocation)
103
+ # zero-budget guard
104
+ return jnp.where(budget <= _tol, jnp.zeros_like(allocation), allocation)
105
+
106
+ # Batch version: vmap over (budget, util_idx); alpha and gamma are shared
107
+ _jax_bisect_batch = _jit(_vmap(_jax_bisect_single, in_axes=(0, 0, None, None)))
108
+
109
+
45
110
  class MDCEVModel:
46
111
  """Translated-utility MDCEV-style allocator.
47
112
 
48
113
  Parameters are learned from observed budget shares using stable moment-based
49
- heuristics, then predictions are produced by solving the translated-utility
50
- KKT system with a bisection search on the shadow price.
114
+ heuristics (``fit``), then optionally refined via quasi-MLE with JAX
115
+ automatic differentiation (``fit_mle``). Predictions use a JAX-jitted and
116
+ vmapped bisection solver when JAX is available, with a pure-numpy fallback.
51
117
  """
52
118
 
53
119
  def __init__(
@@ -152,87 +218,139 @@ class MDCEVModel:
152
218
  ):
153
219
  """Likelihood-based parameter refinement.
154
220
 
155
- The objective is a Gaussian log-likelihood on log allocations around
156
- translated-utility MDCEV deterministic predictions. This is a practical
157
- quasi-MLE refinement that preserves the MDCEV budget constraint while
158
- improving fit over pure moments.
221
+ Uses JAX autodiff through the unrolled bisection for (approximate) gradients when
222
+ JAX is available, replacing the slow scipy finite-difference
223
+ approximation. Falls back to scipy FD when JAX is not installed.
159
224
  """
160
225
  self.fit(allocations, labels=labels)
161
226
 
162
227
  y = _as_2d_float(allocations)
163
- budgets = y.sum(axis=1)
228
+ budgets_np = y.sum(axis=1)
164
229
  n_alt = y.shape[1]
165
-
166
- free_base_idx = [i for i in range(n_alt) if i != self.outside_good]
167
-
168
- def _pack(base, alpha, gamma, sigma):
169
- b = np.asarray(base, dtype=float)
170
- a = np.asarray(alpha, dtype=float)
171
- g = np.asarray(gamma, dtype=float)
172
-
230
+ alpha_floor = self.alpha_floor
231
+ alpha_cap = self.alpha_cap
232
+ gamma_floor = self.gamma_floor
233
+ tol = self.tol
234
+ og = self.outside_good
235
+ has_og = og is not None and 0 <= og < n_alt
236
+ free_base_idx = [i for i in range(n_alt) if not has_og or i != og]
237
+
238
+ def _pack_np(base, alpha, gamma, sigma):
173
239
  p = []
174
- p.extend(b[free_base_idx].tolist())
175
- p.extend(np.log(np.clip((a - self.alpha_floor) / np.clip(self.alpha_cap - a, self.tol, None), self.tol, None)).tolist())
176
- p.extend(np.log(np.clip(g, self.gamma_floor, None)).tolist())
240
+ p.extend(np.asarray(base)[free_base_idx].tolist())
241
+ p.extend(np.log(np.clip(
242
+ (np.asarray(alpha) - alpha_floor) / np.clip(alpha_cap - np.asarray(alpha), tol, None),
243
+ tol, None)).tolist())
244
+ p.extend(np.log(np.clip(np.asarray(gamma), gamma_floor, None)).tolist())
177
245
  p.append(np.log(max(float(sigma), 1e-3)))
178
246
  return np.asarray(p, dtype=float)
179
247
 
180
- def _unpack(theta):
181
- theta = np.asarray(theta, dtype=float)
182
- o = 0
183
-
184
- base = self.baseline_utility_.copy()
185
- for idx in free_base_idx:
186
- base[idx] = theta[o]
187
- o += 1
188
- if self.outside_good is not None and 0 <= self.outside_good < n_alt:
189
- base[self.outside_good] = 0.0
190
-
191
- alpha_raw = theta[o:o + n_alt]
192
- o += n_alt
193
- alpha_sig = 1.0 / (1.0 + np.exp(-alpha_raw))
194
- alpha = self.alpha_floor + (self.alpha_cap - self.alpha_floor) * alpha_sig
195
-
196
- gamma_raw = theta[o:o + n_alt]
197
- o += n_alt
198
- gamma = np.maximum(np.exp(gamma_raw), self.gamma_floor)
199
-
200
- sigma = max(np.exp(theta[o]), 1e-3)
201
- return base, alpha, gamma, sigma
202
-
203
- def _neg_loglike(theta):
204
- base, alpha, gamma, sigma = _unpack(theta)
205
-
206
- old_b, old_a, old_g = self.baseline_utility_, self.alpha_, self.gamma_
207
- self.baseline_utility_, self.alpha_, self.gamma_ = base, alpha, gamma
208
- try:
209
- mu = np.zeros_like(y)
210
- for i, b in enumerate(budgets):
211
- mu[i] = self._solve_budget(float(b), base)
212
- finally:
213
- self.baseline_utility_, self.alpha_, self.gamma_ = old_b, old_a, old_g
214
-
215
- log_y = np.log(np.clip(y, self.tol, None))
216
- log_mu = np.log(np.clip(mu, self.tol, None))
217
- resid = log_y - log_mu
218
- ll = -0.5 * resid.size * np.log(2.0 * np.pi * sigma * sigma)
219
- ll -= 0.5 * np.sum((resid / sigma) ** 2)
220
- ll -= l2_penalty * np.sum(theta * theta)
221
- return -float(ll)
222
-
223
- theta0 = _pack(self.baseline_utility_, self.alpha_, self.gamma_, sigma=0.5)
224
- res = minimize(
225
- _neg_loglike,
226
- theta0,
227
- method="L-BFGS-B",
228
- options={"maxiter": int(maxiter), "ftol": 1e-9},
229
- )
248
+ theta0 = _pack_np(self.baseline_utility_, self.alpha_, self.gamma_, sigma=0.5)
249
+
250
+ if _JAX_AVAILABLE:
251
+ # ── JAX path: exact gradients via autodiff ────────────────────
252
+ y_jax = jnp.array(y)
253
+ B_jax = jnp.array(budgets_np)
254
+ free_idx_jax = jnp.array(free_base_idx)
255
+
256
+ def _unpack_jax(theta):
257
+ o = 0
258
+ n_free = len(free_base_idx)
259
+ free_base = theta[o:o + n_free]
260
+ o += n_free
261
+ base = jnp.zeros(n_alt).at[free_idx_jax].set(free_base)
262
+
263
+ alpha_raw = theta[o:o + n_alt]
264
+ o += n_alt
265
+ alpha = alpha_floor + (alpha_cap - alpha_floor) * jax.nn.sigmoid(alpha_raw)
266
+
267
+ gamma_raw = theta[o:o + n_alt]
268
+ o += n_alt
269
+ gamma = jnp.maximum(jnp.exp(gamma_raw), gamma_floor)
270
+
271
+ sigma = jnp.maximum(jnp.exp(theta[o]), 1e-3)
272
+ return base, alpha, gamma, sigma
273
+
274
+ @_jit
275
+ def _neg_loglike_jax(theta):
276
+ base, alpha, gamma, sigma = _unpack_jax(theta)
277
+ util_matrix = jnp.broadcast_to(base[None, :], (len(B_jax), n_alt))
278
+ preds = _jax_bisect_batch(B_jax, util_matrix, alpha, gamma)
279
+ log_y = jnp.log(jnp.clip(y_jax, tol, None))
280
+ log_p = jnp.log(jnp.clip(preds, tol, None))
281
+ resid = log_y - log_p
282
+ ll = -0.5 * resid.size * jnp.log(2.0 * jnp.pi * sigma * sigma)
283
+ ll -= 0.5 * jnp.sum((resid / sigma) ** 2)
284
+ ll -= l2_penalty * jnp.sum(theta * theta)
285
+ return -ll
286
+
287
+ val_and_grad = _jit(jax.value_and_grad(_neg_loglike_jax))
288
+
289
+ def _scipy_obj(theta_np):
290
+ val, grad = val_and_grad(jnp.array(theta_np, dtype=jnp.float64))
291
+ return float(val), np.asarray(grad, dtype=np.float64)
292
+
293
+ res = minimize(
294
+ _scipy_obj,
295
+ theta0,
296
+ method="L-BFGS-B",
297
+ jac=True,
298
+ options={"maxiter": int(maxiter), "ftol": 1e-9},
299
+ )
300
+ base_j, alpha_j, gamma_j, sigma_j = _unpack_jax(jnp.array(res.x))
301
+ self.baseline_utility_ = np.asarray(base_j)
302
+ self.alpha_ = np.asarray(alpha_j)
303
+ self.gamma_ = np.asarray(gamma_j)
304
+ self.noise_sigma_ = float(sigma_j)
305
+
306
+ else:
307
+ # ── numpy / scipy FD fallback (original behaviour) ────────────
308
+ def _unpack_np(theta):
309
+ theta = np.asarray(theta, dtype=float)
310
+ o = 0
311
+ base = self.baseline_utility_.copy()
312
+ for idx in free_base_idx:
313
+ base[idx] = theta[o]
314
+ o += 1
315
+ if has_og:
316
+ base[og] = 0.0
317
+ alpha_raw = theta[o:o + n_alt]; o += n_alt
318
+ alpha = alpha_floor + (alpha_cap - alpha_floor) / (1.0 + np.exp(-alpha_raw))
319
+ gamma_raw = theta[o:o + n_alt]; o += n_alt
320
+ gamma = np.maximum(np.exp(gamma_raw), gamma_floor)
321
+ sigma = max(np.exp(theta[o]), 1e-3)
322
+ return base, alpha, gamma, sigma
323
+
324
+ def _neg_loglike_np(theta):
325
+ base, alpha, gamma, sigma = _unpack_np(theta)
326
+ old_b, old_a, old_g = self.baseline_utility_, self.alpha_, self.gamma_
327
+ self.baseline_utility_, self.alpha_, self.gamma_ = base, alpha, gamma
328
+ try:
329
+ mu = np.zeros_like(y)
330
+ for i, b in enumerate(budgets_np):
331
+ mu[i] = self._solve_budget(float(b), base)
332
+ finally:
333
+ self.baseline_utility_, self.alpha_, self.gamma_ = old_b, old_a, old_g
334
+ log_y = np.log(np.clip(y, tol, None))
335
+ log_mu = np.log(np.clip(mu, tol, None))
336
+ resid = log_y - log_mu
337
+ ll = -0.5 * resid.size * np.log(2.0 * np.pi * sigma * sigma)
338
+ ll -= 0.5 * np.sum((resid / sigma) ** 2)
339
+ ll -= l2_penalty * np.sum(theta * theta)
340
+ return -float(ll)
341
+
342
+ res = minimize(
343
+ _neg_loglike_np,
344
+ theta0,
345
+ method="L-BFGS-B",
346
+ options={"maxiter": int(maxiter), "ftol": 1e-9},
347
+ )
348
+ base, alpha, gamma, sigma = _unpack_np(res.x)
349
+ self.baseline_utility_ = base
350
+ self.alpha_ = alpha
351
+ self.gamma_ = gamma
352
+ self.noise_sigma_ = float(sigma)
230
353
 
231
- base, alpha, gamma, sigma = _unpack(res.x)
232
- self.baseline_utility_ = base
233
- self.alpha_ = alpha
234
- self.gamma_ = gamma
235
- self.noise_sigma_ = float(sigma)
236
354
  self.mle_success_ = bool(res.success)
237
355
  self.mle_message_ = str(res.message)
238
356
  return self
@@ -267,6 +385,16 @@ class MDCEVModel:
267
385
  budgets_arr = np.asarray(budgets, dtype=float).reshape(-1)
268
386
  shifts = self._prepare_utility_shift(utility_shift, len(budgets_arr))
269
387
 
388
+ if _JAX_AVAILABLE:
389
+ util_matrix = jnp.array(self.baseline_utility_[None, :] + shifts)
390
+ preds = _jax_bisect_batch(
391
+ jnp.array(budgets_arr),
392
+ util_matrix,
393
+ jnp.array(self.alpha_),
394
+ jnp.array(self.gamma_),
395
+ )
396
+ return np.asarray(preds)
397
+
270
398
  predictions = np.zeros((len(budgets_arr), len(self.baseline_utility_)), dtype=float)
271
399
  for row_idx, budget in enumerate(budgets_arr):
272
400
  predictions[row_idx] = self._solve_budget(budget, self.baseline_utility_ + shifts[row_idx])
@@ -280,6 +408,17 @@ class MDCEVModel:
280
408
  rng = np.random.default_rng(random_state)
281
409
 
282
410
  sims = np.zeros((n_draws, len(budgets_arr), len(self.baseline_utility_)), dtype=float)
411
+
412
+ if _JAX_AVAILABLE:
413
+ B_jax = jnp.array(budgets_arr)
414
+ alpha_jax = jnp.array(self.alpha_)
415
+ gamma_jax = jnp.array(self.gamma_)
416
+ for draw_idx in range(n_draws):
417
+ shocks = rng.gumbel(loc=0.0, scale=1.0, size=shifts.shape)
418
+ util_matrix = jnp.array(self.baseline_utility_[None, :] + shifts + shocks)
419
+ sims[draw_idx] = np.asarray(_jax_bisect_batch(B_jax, util_matrix, alpha_jax, gamma_jax))
420
+ return sims
421
+
283
422
  for draw_idx in range(n_draws):
284
423
  shocks = rng.gumbel(loc=0.0, scale=1.0, size=shifts.shape)
285
424
  for row_idx, budget in enumerate(budgets_arr):
@@ -304,6 +443,7 @@ class MDCEVModel:
304
443
  return shift_arr
305
444
 
306
445
  def _solve_budget(self, budget: float, utility_index: np.ndarray) -> np.ndarray:
446
+ """Pure-numpy bisection fallback — used when JAX is not available."""
307
447
  if budget <= self.tol:
308
448
  return np.zeros(len(self.baseline_utility_), dtype=float)
309
449
 
@@ -626,7 +626,7 @@ class MultinomialLogit(DiscreteChoiceModel):
626
626
  ''' ---------------------------------------------------------- '''
627
627
  ''' Function '''
628
628
  ''' ---------------------------------------------------------- '''
629
- def scipy_bfgs_optimization(self, betas, X, y, weights, avail, maxiter, ftol, gtol, jac, return_opg, sklearn=None):
629
+ def scipy_bfgs_optimization(self, betas, X, y, weights, avail, maxiter, ftol, gtol, jac, return_opg=False, sklearn=None):
630
630
  # {
631
631
 
632
632
 
@@ -906,7 +906,7 @@ class NestedLogit(MultinomialLogit):
906
906
  V_j = float(V[i, chosen_alt]) # V_{A1}
907
907
  sumPV_g = float(sumPV_per_nest[chosen_nest_name][i]) # Σ_k P(k|A) V_k
908
908
  logS_g = float(logS_per_nest[chosen_nest_name][i]) # ln S_A
909
- T_i = float(T[i])
909
+ T_i = float(T[i, 0]) # T has shape (N,1) due to keepdims=True
910
910
 
911
911
  # TERM 1: ( sumPV_g - V_j ) / lambda_g^2
912
912
  term1 = (sumPV_g - V_j) / (lambda_g ** 2)
@@ -296,7 +296,7 @@ class OrderedLogit():
296
296
  def find_category(self, values: np.ndarray, thresholds: np.ndarray)->np.ndarray:
297
297
  # {
298
298
  category = np.digitize(values, thresholds)
299
- return category - 1
299
+ return category
300
300
  # Note: np.digitize's index is now returned as-is (the former - 1 offset was removed)
301
301
  # }
302
302
 
@@ -699,7 +699,7 @@ class OrderedLogit():
699
699
  self.method = method
700
700
 
701
701
  if start is None:
702
- start = [0] * self.params
702
+ start = [0] * self.nparams
703
703
  value = [1] * (self.J - 2) # These are the deltas
704
704
  set_last_elements(start, self.J - 2, value)
705
705
 
@@ -206,7 +206,7 @@ class RandomRegret(DiscreteChoiceModel):
206
206
 
207
207
  X, y, varnames, alts, isvars, transvars, ids, weights, panels, avail = \
208
208
  self.set_asarray(X, y, varnames, alts, isvars, transvars, ids, weights, None, avail)
209
- self.nb_samples, self.nb_attr = self.X.shape
209
+ self.nb_samples, self.nb_attr = X.shape
210
210
  self.nb_alt = len(set(alts))
211
211
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
212
212
  # CHECK FOR MISTAKES IN DATA
@@ -247,6 +247,13 @@ class RandomRegret(DiscreteChoiceModel):
247
247
  self.fixedtransvars = self.transvars
248
248
  self.X, self.Xnames = self.setup_design_matrix(self.X)
249
249
 
250
+ # Update nb_samples to reflect per-observation count (after 3D reshape)
251
+ self.nb_samples = self.N
252
+ self.nb_attr = self.X.shape[2] if self.X.ndim == 3 else self.X.shape[1]
253
+ # Convert binary choice indicators (N*J,) to per-obs chosen alt index (N,)
254
+ y_matrix = self.y.reshape(self.N, self.J)
255
+ self.y = np.argmax(y_matrix, axis=1).astype(int)
256
+
250
257
 
251
258
 
252
259
  if self.weights is not None:
@@ -283,6 +290,7 @@ class RandomRegret(DiscreteChoiceModel):
283
290
  '''
284
291
  self.y = np.zeros((self.nb_samples), dtype=int)
285
292
  self.attrs = kwargs.get("varnames", varnames)
293
+ self.normalize = kwargs.get('normalize', False)
286
294
  self.initialise()
287
295
 
288
296
 
@@ -616,7 +616,7 @@ class SA(Search):
616
616
  def copy_solution(self, sol):
617
617
  # {
618
618
  logging.info('normal copy')
619
- copy_sol = sol.copy()
619
+ copy_sol = copy.deepcopy(sol) # Deep copy to prevent shared mutable state (lists, dicts inside solution)
620
620
  return copy_sol
621
621
  # }
622
622
 
@@ -998,15 +998,12 @@ class SA(Search):
998
998
  print("Initialized best_sol with the first solution")
999
999
  return
1000
1000
 
1001
- self.no_impr = 0
1002
1001
  if self.nb_crit == 1:
1003
1002
  # {
1004
1003
  if is_better(sol.obj(0), self.best_sol.obj(0), self.param.sign_crit(0)):
1005
1004
  self.best_sol = self.copy_solution(sol)
1006
-
1007
- print('new best')
1008
-
1009
1005
  self.no_impr = 0
1006
+ print('new best')
1010
1007
  # }
1011
1008
  else:
1012
1009
  self.archive = self.update_archive(sol)
@@ -1087,7 +1084,7 @@ class SA(Search):
1087
1084
  # {
1088
1085
  print("ARCHIVE STATIC. RESTORE A NON DOMINATED SOLUTION.")
1089
1086
  choice = np.random.randint(len(self.archive))
1090
- self.current_sol = self.archive[choice]
1087
+ self.current_sol = self.copy_solution(self.archive[choice]) # Deep copy to avoid aliasing archive entry
1091
1088
  self.no_impr = 0
1092
1089
  # }
1093
1090
  # }
@@ -0,0 +1 @@
1
+ 0.0.87
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SearchLibrium
3
- Version: 0.0.85
3
+ Version: 0.0.87
4
4
  Summary: A Python package for econometric models driven by search
5
5
  Author: Alexander Paz Prithvi Beeramole, Robert Burdett
6
6
  Author-email: Zeke Ahern <z.ahern@qut.edu.au>
@@ -1 +0,0 @@
1
- 0.0.85
File without changes
File without changes