SearchLibrium 0.0.89__tar.gz → 0.0.91__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52):
  1. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/PKG-INFO +1 -1
  2. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/pyproject.toml +1 -1
  3. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/MixedLogit.py +133 -27
  4. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/_choice_model.py +1 -1
  5. searchlibrium-0.0.91/src/SearchLibrium/latent_class.py +850 -0
  6. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/mixed_nested.py +2 -2
  7. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/multinomial_logit.py +1 -1
  8. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/multinomial_nested.py +1 -1
  9. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/multinomial_probit.py +1 -1
  10. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/ordered_logit.py +1 -1
  11. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/sapbil.py +3 -0
  12. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/search.py +22 -10
  13. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/selection_models.py +2 -2
  14. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/siman.py +85 -46
  15. searchlibrium-0.0.91/src/SearchLibrium/test_lc_de.py +423 -0
  16. searchlibrium-0.0.91/src/SearchLibrium/test_mario_searches.py +346 -0
  17. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/test_sapbil_vs_banditsa.py +196 -33
  18. searchlibrium-0.0.91/src/SearchLibrium/version.txt +1 -0
  19. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium.egg-info/PKG-INFO +1 -1
  20. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium.egg-info/SOURCES.txt +2 -0
  21. searchlibrium-0.0.89/src/SearchLibrium/latent_class.py +0 -353
  22. searchlibrium-0.0.89/src/SearchLibrium/version.txt +0 -1
  23. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/README.md +0 -0
  24. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/setup.cfg +0 -0
  25. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/Halton.py +0 -0
  26. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/Mode_Activity_Nested.py +0 -0
  27. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/RandomP.py +0 -0
  28. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/SEARCH_SM_MARIO.py +0 -0
  29. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/Two_Level_Nest.py +0 -0
  30. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/__init__.py +0 -0
  31. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/__main__.py +0 -0
  32. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/_device.py +0 -0
  33. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/banditsa.py +0 -0
  34. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/bhhh/minimize.py +0 -0
  35. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/boxcox_functions.py +0 -0
  36. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/call_meta.py +0 -0
  37. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/constraints_builder.py +0 -0
  38. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/harmony.py +0 -0
  39. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/main.py +0 -0
  40. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/main_debug.py +0 -0
  41. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/mdcev.py +0 -0
  42. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/misc.py +0 -0
  43. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/mixed_logit.py +0 -0
  44. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/mixedrrm.py +0 -0
  45. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/ordered_logit_mixed.py +0 -0
  46. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/rrm.py +0 -0
  47. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/setup.py +0 -0
  48. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/threshold.py +0 -0
  49. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium.egg-info/dependency_links.txt +0 -0
  50. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium.egg-info/entry_points.txt +0 -0
  51. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium.egg-info/requires.txt +0 -0
  52. {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SearchLibrium
3
- Version: 0.0.89
3
+ Version: 0.0.91
4
4
  Summary: A Python package for econometric models driven by search
5
5
  Author: Alexander Paz Prithvi Beeramole, Robert Burdett
6
6
  Author-email: Zeke Ahern <z.ahern@qut.edu.au>
@@ -59,7 +59,7 @@ Homepage = "https://github.com/zahern/HypothesisX"
59
59
  realpython = "SearchLibrium.__main__:main"
60
60
 
61
61
  [tool.bumpver]
62
- current_version = "0.0.89"
62
+ current_version = "0.0.91"
63
63
  version_pattern = "MAJOR.MINOR.PATCH"
64
64
  commit_message = "[skip ci] Bump version {old_version} -> {new_version}"
65
65
  commit = true
@@ -2,7 +2,7 @@ from typing import Optional
2
2
  import itertools
3
3
  import numpy as np
4
4
  import scipy.stats as ss
5
- from scipy.optimize import minimize
5
+ from scipy.optimize import minimize, differential_evolution
6
6
  from typing import Callable, Tuple
7
7
  import inspect
8
8
 
@@ -34,8 +34,8 @@ max_comp_val, min_comp_val = 1e+20, 1e-200 # or use float('inf')
34
34
  infinity = float('inf')
35
35
 
36
36
  class MixedLogit(DiscreteChoiceModel):
37
- def __init__(self, halton_opts=None, distributions=['n', 'ln', 't', 'tn', 'u']):
38
- super().__init__()
37
+ def __init__(self, halton_opts=None, distributions=['n', 'ln', 't', 'tn', 'u'], _jax=True):
38
+ super().__init__(_jax)
39
39
  self.halton_opts = halton_opts
40
40
  self.draws_generator = Draws(k=len(distributions), halton_opts=halton_opts, rvdist=distributions)
41
41
  self.random_parameters = RandomParameters(distributions or []) # Initialize RandomParameters
@@ -66,7 +66,9 @@ class MixedLogit(DiscreteChoiceModel):
66
66
  n_draws=1000, halton=True, minimise_func=None,
67
67
  batch_size=None, halton_opts=None, ftol=1e-6,
68
68
  gtol=1e-6, return_hess=True, return_grad=True, method="bfgs",
69
- save_fitted_params=True, mnl_init=True):
69
+ save_fitted_params=True, mnl_init=True,
70
+ de_init=False, de_popsize=4, de_maxiter=3, de_tol=0.5,
71
+ de_polish=False):
70
72
  # {
71
73
  self.fit_intercept = fit_intercept
72
74
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -104,6 +106,11 @@ class MixedLogit(DiscreteChoiceModel):
104
106
  self.minimise_func = minimise_func
105
107
  self.save_fitted_params = save_fitted_params
106
108
  self.mnl_init = mnl_init
109
+ self.de_init = de_init
110
+ self.de_popsize = de_popsize
111
+ self.de_maxiter = de_maxiter
112
+ self.de_tol = de_tol
113
+ self.de_polish = de_polish
107
114
  self.total_fun_eval = 0
108
115
  self.method = method.lower() if hasattr(method, 'lower') else method
109
116
  self.jac = self.return_grad # scipy optimize parameter
@@ -312,6 +319,80 @@ class MixedLogit(DiscreteChoiceModel):
312
319
  if len(self.init_coeff) != n_coeff and not hasattr(self, 'class_params_spec'):
313
320
  raise ValueError("The size of init_coeff must be: " + str(n_coeff))
314
321
 
322
+ positive_bound = (0, infinity)
323
+ any_bound = (-infinity, infinity)
324
+ lmda_bound = (-5, 1)
325
+ bound_dict = {
326
+ "bf": (any_bound, self.Kf),
327
+ "br_b": (any_bound, self.Kr),
328
+ "chol": (any_bound, self.Kchol),
329
+ "br_w": (positive_bound, self.Kr - self.correlationLength),
330
+ "bf_trans": (any_bound, self.Kftrans),
331
+ "flmbda": (lmda_bound, self.Kftrans),
332
+ "br_trans_b": (any_bound, self.Krtrans),
333
+ "br_trans_w": (any_bound, self.Krtrans),
334
+ "rlmbda": (lmda_bound, self.Krtrans)
335
+ }
336
+ bnds = [[bound[1][0]] * bound[1][1] for bound in bound_dict.items() if bound[1][1] > 0]
337
+ bnds = list(itertools.chain.from_iterable(bnds))
338
+
339
+ print(f"[MXL] Starting beta seed length={n_coeff}, first_values={betas[:min(8, len(betas))]!r}")
340
+
341
+ if self.de_init:
342
+ print(f"[MXL] DE init enabled: popsize={self.de_popsize}, maxiter={self.de_maxiter}, tol={self.de_tol}, polish={self.de_polish}")
343
+ try:
344
+ before_de_obj = self.get_loglik_gradient(betas, self.X, self.y, self.panel_info,
345
+ draws, drawstrans, self.weights,
346
+ self.avail, self.batch_size)[0]
347
+ print(f"[MXL] DE before: obj={before_de_obj:.6g}")
348
+
349
+ def _de_obj(x):
350
+ return self.get_loglik_gradient(x, self.X, self.y, self.panel_info,
351
+ draws, drawstrans, self.weights,
352
+ self.avail, self.batch_size)[0]
353
+
354
+ if len(bnds) == n_coeff:
355
+ de_bounds = []
356
+ for idx, bound in enumerate(bnds):
357
+ low, high = bound
358
+ if not np.isfinite(low) or not np.isfinite(high):
359
+ scale = max(1.0, 10.0 * abs(betas[idx]) if idx < len(betas) else 1.0)
360
+ de_bounds.append((-scale, scale))
361
+ else:
362
+ de_bounds.append((low, high))
363
+ else:
364
+ de_bounds = [(-10, 10)] * n_coeff
365
+
366
+ de_result = differential_evolution(
367
+ _de_obj,
368
+ de_bounds,
369
+ strategy='best1bin',
370
+ maxiter=self.de_maxiter,
371
+ popsize=self.de_popsize,
372
+ tol=self.de_tol,
373
+ polish=self.de_polish,
374
+ disp=False,
375
+ workers=1,
376
+ )
377
+ de_improved = False
378
+ print(f"[MXL] DE completed: success={de_result.success}, nit={de_result.nit}, message={de_result.message}")
379
+ if de_result.success:
380
+ de_obj = self.get_loglik_gradient(de_result.x, self.X, self.y, self.panel_info,
381
+ draws, drawstrans, self.weights,
382
+ self.avail, self.batch_size)[0]
383
+ print(f"[MXL] DE after: obj={de_obj:.6g}")
384
+ de_improved = de_obj < before_de_obj
385
+ if de_improved:
386
+ betas = de_result.x
387
+ print(f"[MXL] DE best seed first_values={betas[:min(8, len(betas))]!r}")
388
+ print("[MXL] Differential evolution initialization improved the objective; starting minimise from DE seed.")
389
+ else:
390
+ print(f"[MXL] Differential evolution did not beat the original start ({de_obj:.6g} >= {before_de_obj:.6g}); keeping original start.")
391
+ if not de_improved:
392
+ print(f"[MXL] Differential evolution initialization rejected because it did not improve the starting objective. Using original seed.")
393
+ except Exception as e:
394
+ print(f"[MXL] Differential evolution initialization failed: {e}")
395
+
315
396
  # '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
316
397
  if dev.using_gpu: # {
317
398
  self.X, self.y = dev.convert_array_gpu(self.X), dev.convert_array_gpu(self.y)
@@ -323,6 +404,15 @@ class MixedLogit(DiscreteChoiceModel):
323
404
  # }
324
405
  # '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
325
406
 
407
+ print(f"[MXL] Minimization start with betas first_values={betas[:min(8, len(betas))]!r}")
408
+ try:
409
+ before_fun = self.get_loglik_gradient(betas, self.X, self.y, self.panel_info,
410
+ draws, drawstrans, self.weights,
411
+ self.avail, self.batch_size)[0]
412
+ print(f"[MXL] Minimization obj before={before_fun:.6g}")
413
+ except Exception as e:
414
+ print(f"[MXL] Could not evaluate initial objective before minimization: {e}")
415
+
326
416
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
327
417
  # Generate bound for L-BFGS-B method
328
418
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -378,6 +468,9 @@ class MixedLogit(DiscreteChoiceModel):
378
468
  options = {'gtol': self.gtol, 'maxiter': self.maxiter, 'disp': False}
379
469
  result = minimise_func(self.get_loglik_gradient, betas, jac=self.jac, method=self.method,
380
470
  args=args, tol=self.ftol, bounds=bounds, options=options)
471
+ print(f"[MXL] Minimization completed: success={result.get('success', None)}, fun={result.get('fun', float('nan')):.6g}, nit={result.get('nit', '?')}")
472
+ if 'x' in result:
473
+ print(f"[MXL] Minimization final betas first_values={np.asarray(result['x'])[:min(8, len(result['x']))]!r}")
381
474
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
382
475
 
383
476
  if hasattr(self, 'method') and self.method == "L-BFGS-B": # {
@@ -428,38 +521,37 @@ class MixedLogit(DiscreteChoiceModel):
428
521
  # ------------------------------------------------------------------
429
522
  @staticmethod
430
523
  def _jax_mxl_negloglik(betas, X_jax, y_jax, panel_info_jax, draws_jax,
431
- fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names):
524
+ fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names,
525
+ correlationLength):
432
526
  """JAX simulation-based log-likelihood for Mixed Logit (standard case).
433
527
 
434
528
  Handles fixed and random (normally/lognormally distributed) parameters.
435
- Cholesky correlation structure is included via the Kchol segment.
529
+ Cholesky correlation structure is included via the Kchol segment;
530
+ uncorrelated random vars use independent std devs in Br_w.
436
531
  """
437
532
  import jax.numpy as jnp
438
533
 
439
534
  # ---- split beta vector ----
440
535
  Bf = betas[:Kf]
441
536
  Br_b = betas[Kf:Kf + Kr]
442
- chol_v = betas[Kf + Kr:Kf + Kr + Kchol] # cholesky lower-triangle values
537
+ chol_v = betas[Kf + Kr:Kf + Kr + Kchol] # correlated cholesky elements
443
538
  Br_w = betas[Kf + Kr + Kchol:Kf + Kr + Kchol + Kbw] # independent std devs
444
539
 
445
540
  # ---- build cholesky matrix ----
446
- tril_r, tril_c = jnp.tril_indices(Kr)
541
+ # Rows 0..correlationLength-1: lower-triangle from chol_v
542
+ # Rows correlationLength..Kr-1: diagonal from Br_w (uncorrelated)
447
543
  chol_mat = jnp.zeros((Kr, Kr))
448
- # place chol values into lower triangle
449
- # (static indices ok for jit since Kr is compile-time constant)
450
544
  idx = 0
451
- for r in range(Kr):
545
+ for r in range(correlationLength):
452
546
  for c in range(r + 1):
453
547
  chol_mat = chol_mat.at[r, c].set(chol_v[idx])
454
548
  idx += 1
455
- # diagonal of the independent block
456
549
  for k in range(Kbw):
457
- chol_mat = chol_mat.at[Kr - Kbw + k, Kr - Kbw + k].set(jnp.abs(Br_w[k]))
550
+ diag_pos = correlationLength + k
551
+ chol_mat = chol_mat.at[diag_pos, diag_pos].set(jnp.abs(Br_w[k]))
458
552
 
459
553
  # ---- sample random coefficients: (N, Kr, R) ----
460
- N = X_jax.shape[0]
461
- R = draws_jax.shape[2]
462
- # draws_jax: (N, Kr, R)
554
+ N = X_jax.shape[0]
463
555
  Br = Br_b[:, None] + jnp.einsum('kl,nlr->nkr',
464
556
  chol_mat,
465
557
  draws_jax[:, :Kr, :]) # (N, Kr, R)
@@ -475,21 +567,24 @@ class MixedLogit(DiscreteChoiceModel):
475
567
  Br_b[k] + Br_w[k] * (draws_jax[:, k, :] - 0.5))
476
568
 
477
569
  # ---- utility ----
478
- # X_jax: (N, P, J, K)
479
570
  P = X_jax.shape[1]
480
- Xf = X_jax[:, :, :, fxidx] # (N, P, J, Kf)
571
+ J = X_jax.shape[2]
481
572
  Xr = X_jax[:, :, :, rvidx] # (N, P, J, Kr)
482
-
483
- UB = jnp.einsum('npjk,k->npj', Xf, Bf) # (N, P, J)
484
573
  UR = jnp.einsum('npjk,nkr->npjr', Xr, Br) # (N, P, J, R)
485
- U = UB[:, :, :, None] + UR # (N, P, J, R)
574
+
575
+ if Kf > 0:
576
+ Xf = X_jax[:, :, :, fxidx] # (N, P, J, Kf)
577
+ UB = jnp.einsum('npjk,k->npj', Xf, Bf) # (N, P, J)
578
+ U = UB[:, :, :, None] + UR
579
+ else:
580
+ U = UR # (N, P, J, R)
486
581
 
487
582
  U = U - jnp.max(U, axis=2, keepdims=True)
488
583
  eU = jnp.exp(U)
489
584
  p = eU / jnp.sum(eU, axis=2, keepdims=True) # (N, P, J, R)
490
585
 
491
- pch = jnp.sum(y_jax[:, :, :, None] * p, axis=2) # (N, P, R) – chosen probs
492
- pch = jnp.prod(pch, axis=1) # (N, R) – product across panels
586
+ pch = jnp.sum(y_jax[:, :, :, None] * p, axis=2) # (N, P, R)
587
+ pch = jnp.prod(pch, axis=1) # (N, R)
493
588
  pch = jnp.clip(pch, 1e-300, None)
494
589
 
495
590
  sim_p = jnp.mean(pch, axis=1) # (N,)
@@ -503,27 +598,38 @@ class MixedLogit(DiscreteChoiceModel):
503
598
  Falls back to standard numpy optimisation on failure.
504
599
  """
505
600
  try:
601
+ import os
602
+ os.environ.setdefault("JAX_ENABLE_X64", "True")
506
603
  import jax
604
+ jax.config.update("jax_enable_x64", True)
507
605
  import jax.numpy as jnp
508
606
  from scipy.optimize import minimize as sp_min
509
607
 
608
+ if int(self.Kr) == 0:
609
+ print("[JAX MXL optimizer] no random coefficients detected (Kr=0); falling back to scipy.")
610
+ return None
611
+
510
612
  # Convert static inputs once
511
613
  X_jax = jnp.array(self.X, dtype=jnp.float64)
512
- y_jax = jnp.array(self.y, dtype=jnp.float64)
513
- pi_jax = jnp.array(self.panel_info, dtype=jnp.float64)
514
- draws_jax = jnp.array(draws, dtype=jnp.float64)
614
+ _y_np = np.asarray(self.y)
615
+ if _y_np.ndim > 3: # stored as (N, P, J, 1) in some paths
616
+ _y_np = _y_np[..., 0]
617
+ y_jax = jnp.array(_y_np, dtype=jnp.float64)
618
+ pi_jax = jnp.array(self.panel_info, dtype=jnp.float64)
619
+ draws_jax = jnp.array(draws, dtype=jnp.float64)
515
620
 
516
621
  fxidx = jnp.array(self.fxidx, dtype=bool)
517
622
  rvidx = jnp.array(self.rvidx, dtype=bool)
518
623
  rvdist_names = [d for d in self.rvdist if d is not False]
519
624
 
520
625
  Kf, Kr, Kchol, Kbw = int(self.Kf), int(self.Kr), int(self.Kchol), int(self.Kbw)
626
+ correlationLength = int(self.correlationLength)
521
627
 
522
628
  @jax.jit
523
629
  def _neg_ll(b):
524
630
  return self._jax_mxl_negloglik(
525
631
  b, X_jax, y_jax, pi_jax, draws_jax,
526
- fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names)
632
+ fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names, correlationLength)
527
633
 
528
634
  _val_grad = jax.jit(jax.value_and_grad(_neg_ll))
529
635
 
@@ -141,7 +141,7 @@ class DiscreteChoiceModel(ABC):
141
141
  ''' ---------------------------------------------------------- '''
142
142
  ''' Function '''
143
143
  ''' ---------------------------------------------------------- '''
144
- def __init__(self, jax = False):
144
+ def __init__(self, jax = True):
145
145
  # {
146
146
 
147
147
  self.reset_attributes()