SearchLibrium 0.0.89__tar.gz → 0.0.90__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/PKG-INFO +1 -1
  2. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/pyproject.toml +1 -1
  3. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/MixedLogit.py +133 -27
  4. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/_choice_model.py +1 -1
  5. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/latent_class.py +102 -3
  6. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/mixed_nested.py +2 -2
  7. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/multinomial_logit.py +1 -1
  8. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/multinomial_nested.py +1 -1
  9. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/multinomial_probit.py +1 -1
  10. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/ordered_logit.py +1 -1
  11. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/sapbil.py +3 -0
  12. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/search.py +22 -10
  13. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/selection_models.py +2 -2
  14. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/siman.py +85 -46
  15. searchlibrium-0.0.90/src/SearchLibrium/test_mario_searches.py +346 -0
  16. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/test_sapbil_vs_banditsa.py +196 -33
  17. searchlibrium-0.0.90/src/SearchLibrium/version.txt +1 -0
  18. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium.egg-info/PKG-INFO +1 -1
  19. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium.egg-info/SOURCES.txt +1 -0
  20. searchlibrium-0.0.89/src/SearchLibrium/version.txt +0 -1
  21. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/README.md +0 -0
  22. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/setup.cfg +0 -0
  23. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/Halton.py +0 -0
  24. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/Mode_Activity_Nested.py +0 -0
  25. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/RandomP.py +0 -0
  26. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/SEARCH_SM_MARIO.py +0 -0
  27. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/Two_Level_Nest.py +0 -0
  28. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/__init__.py +0 -0
  29. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/__main__.py +0 -0
  30. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/_device.py +0 -0
  31. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/banditsa.py +0 -0
  32. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/bhhh/minimize.py +0 -0
  33. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/boxcox_functions.py +0 -0
  34. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/call_meta.py +0 -0
  35. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/constraints_builder.py +0 -0
  36. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/harmony.py +0 -0
  37. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/main.py +0 -0
  38. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/main_debug.py +0 -0
  39. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/mdcev.py +0 -0
  40. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/misc.py +0 -0
  41. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/mixed_logit.py +0 -0
  42. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/mixedrrm.py +0 -0
  43. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/ordered_logit_mixed.py +0 -0
  44. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/rrm.py +0 -0
  45. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/setup.py +0 -0
  46. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/threshold.py +0 -0
  47. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium.egg-info/dependency_links.txt +0 -0
  48. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium.egg-info/entry_points.txt +0 -0
  49. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium.egg-info/requires.txt +0 -0
  50. {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SearchLibrium
3
- Version: 0.0.89
3
+ Version: 0.0.90
4
4
  Summary: A Python package for econometric models driven by search
5
5
  Author: Alexander Paz Prithvi Beeramole, Robert Burdett
6
6
  Author-email: Zeke Ahern <z.ahern@qut.edu.au>
@@ -59,7 +59,7 @@ Homepage = "https://github.com/zahern/HypothesisX"
59
59
  realpython = "SearchLibrium.__main__:main"
60
60
 
61
61
  [tool.bumpver]
62
- current_version = "0.0.89"
62
+ current_version = "0.0.90"
63
63
  version_pattern = "MAJOR.MINOR.PATCH"
64
64
  commit_message = "[skip ci] Bump version {old_version} -> {new_version}"
65
65
  commit = true
@@ -2,7 +2,7 @@ from typing import Optional
2
2
  import itertools
3
3
  import numpy as np
4
4
  import scipy.stats as ss
5
- from scipy.optimize import minimize
5
+ from scipy.optimize import minimize, differential_evolution
6
6
  from typing import Callable, Tuple
7
7
  import inspect
8
8
 
@@ -34,8 +34,8 @@ max_comp_val, min_comp_val = 1e+20, 1e-200 # or use float('inf')
34
34
  infinity = float('inf')
35
35
 
36
36
  class MixedLogit(DiscreteChoiceModel):
37
- def __init__(self, halton_opts=None, distributions=['n', 'ln', 't', 'tn', 'u']):
38
- super().__init__()
37
+ def __init__(self, halton_opts=None, distributions=['n', 'ln', 't', 'tn', 'u'], _jax=True):
38
+ super().__init__(_jax)
39
39
  self.halton_opts = halton_opts
40
40
  self.draws_generator = Draws(k=len(distributions), halton_opts=halton_opts, rvdist=distributions)
41
41
  self.random_parameters = RandomParameters(distributions or []) # Initialize RandomParameters
@@ -66,7 +66,9 @@ class MixedLogit(DiscreteChoiceModel):
66
66
  n_draws=1000, halton=True, minimise_func=None,
67
67
  batch_size=None, halton_opts=None, ftol=1e-6,
68
68
  gtol=1e-6, return_hess=True, return_grad=True, method="bfgs",
69
- save_fitted_params=True, mnl_init=True):
69
+ save_fitted_params=True, mnl_init=True,
70
+ de_init=False, de_popsize=4, de_maxiter=3, de_tol=0.5,
71
+ de_polish=False):
70
72
  # {
71
73
  self.fit_intercept = fit_intercept
72
74
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -104,6 +106,11 @@ class MixedLogit(DiscreteChoiceModel):
104
106
  self.minimise_func = minimise_func
105
107
  self.save_fitted_params = save_fitted_params
106
108
  self.mnl_init = mnl_init
109
+ self.de_init = de_init
110
+ self.de_popsize = de_popsize
111
+ self.de_maxiter = de_maxiter
112
+ self.de_tol = de_tol
113
+ self.de_polish = de_polish
107
114
  self.total_fun_eval = 0
108
115
  self.method = method.lower() if hasattr(method, 'lower') else method
109
116
  self.jac = self.return_grad # scipy optimize parameter
@@ -312,6 +319,80 @@ class MixedLogit(DiscreteChoiceModel):
312
319
  if len(self.init_coeff) != n_coeff and not hasattr(self, 'class_params_spec'):
313
320
  raise ValueError("The size of init_coeff must be: " + str(n_coeff))
314
321
 
322
+ positive_bound = (0, infinity)
323
+ any_bound = (-infinity, infinity)
324
+ lmda_bound = (-5, 1)
325
+ bound_dict = {
326
+ "bf": (any_bound, self.Kf),
327
+ "br_b": (any_bound, self.Kr),
328
+ "chol": (any_bound, self.Kchol),
329
+ "br_w": (positive_bound, self.Kr - self.correlationLength),
330
+ "bf_trans": (any_bound, self.Kftrans),
331
+ "flmbda": (lmda_bound, self.Kftrans),
332
+ "br_trans_b": (any_bound, self.Krtrans),
333
+ "br_trans_w": (any_bound, self.Krtrans),
334
+ "rlmbda": (lmda_bound, self.Krtrans)
335
+ }
336
+ bnds = [[bound[1][0]] * bound[1][1] for bound in bound_dict.items() if bound[1][1] > 0]
337
+ bnds = list(itertools.chain.from_iterable(bnds))
338
+
339
+ print(f"[MXL] Starting beta seed length={n_coeff}, first_values={betas[:min(8, len(betas))]!r}")
340
+
341
+ if self.de_init:
342
+ print(f"[MXL] DE init enabled: popsize={self.de_popsize}, maxiter={self.de_maxiter}, tol={self.de_tol}, polish={self.de_polish}")
343
+ try:
344
+ before_de_obj = self.get_loglik_gradient(betas, self.X, self.y, self.panel_info,
345
+ draws, drawstrans, self.weights,
346
+ self.avail, self.batch_size)[0]
347
+ print(f"[MXL] DE before: obj={before_de_obj:.6g}")
348
+
349
+ def _de_obj(x):
350
+ return self.get_loglik_gradient(x, self.X, self.y, self.panel_info,
351
+ draws, drawstrans, self.weights,
352
+ self.avail, self.batch_size)[0]
353
+
354
+ if len(bnds) == n_coeff:
355
+ de_bounds = []
356
+ for idx, bound in enumerate(bnds):
357
+ low, high = bound
358
+ if not np.isfinite(low) or not np.isfinite(high):
359
+ scale = max(1.0, 10.0 * abs(betas[idx]) if idx < len(betas) else 1.0)
360
+ de_bounds.append((-scale, scale))
361
+ else:
362
+ de_bounds.append((low, high))
363
+ else:
364
+ de_bounds = [(-10, 10)] * n_coeff
365
+
366
+ de_result = differential_evolution(
367
+ _de_obj,
368
+ de_bounds,
369
+ strategy='best1bin',
370
+ maxiter=self.de_maxiter,
371
+ popsize=self.de_popsize,
372
+ tol=self.de_tol,
373
+ polish=self.de_polish,
374
+ disp=False,
375
+ workers=1,
376
+ )
377
+ de_improved = False
378
+ print(f"[MXL] DE completed: success={de_result.success}, nit={de_result.nit}, message={de_result.message}")
379
+ if de_result.success:
380
+ de_obj = self.get_loglik_gradient(de_result.x, self.X, self.y, self.panel_info,
381
+ draws, drawstrans, self.weights,
382
+ self.avail, self.batch_size)[0]
383
+ print(f"[MXL] DE after: obj={de_obj:.6g}")
384
+ de_improved = de_obj < before_de_obj
385
+ if de_improved:
386
+ betas = de_result.x
387
+ print(f"[MXL] DE best seed first_values={betas[:min(8, len(betas))]!r}")
388
+ print("[MXL] Differential evolution initialization improved the objective; starting minimise from DE seed.")
389
+ else:
390
+ print(f"[MXL] Differential evolution did not beat the original start ({de_obj:.6g} >= {before_de_obj:.6g}); keeping original start.")
391
+ if not de_improved:
392
+ print(f"[MXL] Differential evolution initialization rejected because it did not improve the starting objective. Using original seed.")
393
+ except Exception as e:
394
+ print(f"[MXL] Differential evolution initialization failed: {e}")
395
+
315
396
  # '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
316
397
  if dev.using_gpu: # {
317
398
  self.X, self.y = dev.convert_array_gpu(self.X), dev.convert_array_gpu(self.y)
@@ -323,6 +404,15 @@ class MixedLogit(DiscreteChoiceModel):
323
404
  # }
324
405
  # '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
325
406
 
407
+ print(f"[MXL] Minimization start with betas first_values={betas[:min(8, len(betas))]!r}")
408
+ try:
409
+ before_fun = self.get_loglik_gradient(betas, self.X, self.y, self.panel_info,
410
+ draws, drawstrans, self.weights,
411
+ self.avail, self.batch_size)[0]
412
+ print(f"[MXL] Minimization obj before={before_fun:.6g}")
413
+ except Exception as e:
414
+ print(f"[MXL] Could not evaluate initial objective before minimization: {e}")
415
+
326
416
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
327
417
  # Generate bound for L-BFGS-B method
328
418
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -378,6 +468,9 @@ class MixedLogit(DiscreteChoiceModel):
378
468
  options = {'gtol': self.gtol, 'maxiter': self.maxiter, 'disp': False}
379
469
  result = minimise_func(self.get_loglik_gradient, betas, jac=self.jac, method=self.method,
380
470
  args=args, tol=self.ftol, bounds=bounds, options=options)
471
+ print(f"[MXL] Minimization completed: success={result.get('success', None)}, fun={result.get('fun', float('nan')):.6g}, nit={result.get('nit', '?')}")
472
+ if 'x' in result:
473
+ print(f"[MXL] Minimization final betas first_values={np.asarray(result['x'])[:min(8, len(result['x']))]!r}")
381
474
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
382
475
 
383
476
  if hasattr(self, 'method') and self.method == "L-BFGS-B": # {
@@ -428,38 +521,37 @@ class MixedLogit(DiscreteChoiceModel):
428
521
  # ------------------------------------------------------------------
429
522
  @staticmethod
430
523
  def _jax_mxl_negloglik(betas, X_jax, y_jax, panel_info_jax, draws_jax,
431
- fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names):
524
+ fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names,
525
+ correlationLength):
432
526
  """JAX simulation-based log-likelihood for Mixed Logit (standard case).
433
527
 
434
528
  Handles fixed and random (normally/lognormally distributed) parameters.
435
- Cholesky correlation structure is included via the Kchol segment.
529
+ Cholesky correlation structure is included via the Kchol segment;
530
+ uncorrelated random vars use independent std devs in Br_w.
436
531
  """
437
532
  import jax.numpy as jnp
438
533
 
439
534
  # ---- split beta vector ----
440
535
  Bf = betas[:Kf]
441
536
  Br_b = betas[Kf:Kf + Kr]
442
- chol_v = betas[Kf + Kr:Kf + Kr + Kchol] # cholesky lower-triangle values
537
+ chol_v = betas[Kf + Kr:Kf + Kr + Kchol] # correlated cholesky elements
443
538
  Br_w = betas[Kf + Kr + Kchol:Kf + Kr + Kchol + Kbw] # independent std devs
444
539
 
445
540
  # ---- build cholesky matrix ----
446
- tril_r, tril_c = jnp.tril_indices(Kr)
541
+ # Rows 0..correlationLength-1: lower-triangle from chol_v
542
+ # Rows correlationLength..Kr-1: diagonal from Br_w (uncorrelated)
447
543
  chol_mat = jnp.zeros((Kr, Kr))
448
- # place chol values into lower triangle
449
- # (static indices ok for jit since Kr is compile-time constant)
450
544
  idx = 0
451
- for r in range(Kr):
545
+ for r in range(correlationLength):
452
546
  for c in range(r + 1):
453
547
  chol_mat = chol_mat.at[r, c].set(chol_v[idx])
454
548
  idx += 1
455
- # diagonal of the independent block
456
549
  for k in range(Kbw):
457
- chol_mat = chol_mat.at[Kr - Kbw + k, Kr - Kbw + k].set(jnp.abs(Br_w[k]))
550
+ diag_pos = correlationLength + k
551
+ chol_mat = chol_mat.at[diag_pos, diag_pos].set(jnp.abs(Br_w[k]))
458
552
 
459
553
  # ---- sample random coefficients: (N, Kr, R) ----
460
- N = X_jax.shape[0]
461
- R = draws_jax.shape[2]
462
- # draws_jax: (N, Kr, R)
554
+ N = X_jax.shape[0]
463
555
  Br = Br_b[:, None] + jnp.einsum('kl,nlr->nkr',
464
556
  chol_mat,
465
557
  draws_jax[:, :Kr, :]) # (N, Kr, R)
@@ -475,21 +567,24 @@ class MixedLogit(DiscreteChoiceModel):
475
567
  Br_b[k] + Br_w[k] * (draws_jax[:, k, :] - 0.5))
476
568
 
477
569
  # ---- utility ----
478
- # X_jax: (N, P, J, K)
479
570
  P = X_jax.shape[1]
480
- Xf = X_jax[:, :, :, fxidx] # (N, P, J, Kf)
571
+ J = X_jax.shape[2]
481
572
  Xr = X_jax[:, :, :, rvidx] # (N, P, J, Kr)
482
-
483
- UB = jnp.einsum('npjk,k->npj', Xf, Bf) # (N, P, J)
484
573
  UR = jnp.einsum('npjk,nkr->npjr', Xr, Br) # (N, P, J, R)
485
- U = UB[:, :, :, None] + UR # (N, P, J, R)
574
+
575
+ if Kf > 0:
576
+ Xf = X_jax[:, :, :, fxidx] # (N, P, J, Kf)
577
+ UB = jnp.einsum('npjk,k->npj', Xf, Bf) # (N, P, J)
578
+ U = UB[:, :, :, None] + UR
579
+ else:
580
+ U = UR # (N, P, J, R)
486
581
 
487
582
  U = U - jnp.max(U, axis=2, keepdims=True)
488
583
  eU = jnp.exp(U)
489
584
  p = eU / jnp.sum(eU, axis=2, keepdims=True) # (N, P, J, R)
490
585
 
491
- pch = jnp.sum(y_jax[:, :, :, None] * p, axis=2) # (N, P, R) – chosen probs
492
- pch = jnp.prod(pch, axis=1) # (N, R) – product across panels
586
+ pch = jnp.sum(y_jax[:, :, :, None] * p, axis=2) # (N, P, R)
587
+ pch = jnp.prod(pch, axis=1) # (N, R)
493
588
  pch = jnp.clip(pch, 1e-300, None)
494
589
 
495
590
  sim_p = jnp.mean(pch, axis=1) # (N,)
@@ -503,27 +598,38 @@ class MixedLogit(DiscreteChoiceModel):
503
598
  Falls back to standard numpy optimisation on failure.
504
599
  """
505
600
  try:
601
+ import os
602
+ os.environ.setdefault("JAX_ENABLE_X64", "True")
506
603
  import jax
604
+ jax.config.update("jax_enable_x64", True)
507
605
  import jax.numpy as jnp
508
606
  from scipy.optimize import minimize as sp_min
509
607
 
608
+ if int(self.Kr) == 0:
609
+ print("[JAX MXL optimizer] no random coefficients detected (Kr=0); falling back to scipy.")
610
+ return None
611
+
510
612
  # Convert static inputs once
511
613
  X_jax = jnp.array(self.X, dtype=jnp.float64)
512
- y_jax = jnp.array(self.y, dtype=jnp.float64)
513
- pi_jax = jnp.array(self.panel_info, dtype=jnp.float64)
514
- draws_jax = jnp.array(draws, dtype=jnp.float64)
614
+ _y_np = np.asarray(self.y)
615
+ if _y_np.ndim > 3: # stored as (N, P, J, 1) in some paths
616
+ _y_np = _y_np[..., 0]
617
+ y_jax = jnp.array(_y_np, dtype=jnp.float64)
618
+ pi_jax = jnp.array(self.panel_info, dtype=jnp.float64)
619
+ draws_jax = jnp.array(draws, dtype=jnp.float64)
515
620
 
516
621
  fxidx = jnp.array(self.fxidx, dtype=bool)
517
622
  rvidx = jnp.array(self.rvidx, dtype=bool)
518
623
  rvdist_names = [d for d in self.rvdist if d is not False]
519
624
 
520
625
  Kf, Kr, Kchol, Kbw = int(self.Kf), int(self.Kr), int(self.Kchol), int(self.Kbw)
626
+ correlationLength = int(self.correlationLength)
521
627
 
522
628
  @jax.jit
523
629
  def _neg_ll(b):
524
630
  return self._jax_mxl_negloglik(
525
631
  b, X_jax, y_jax, pi_jax, draws_jax,
526
- fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names)
632
+ fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names, correlationLength)
527
633
 
528
634
  _val_grad = jax.jit(jax.value_and_grad(_neg_ll))
529
635
 
@@ -141,7 +141,7 @@ class DiscreteChoiceModel(ABC):
141
141
  ''' ---------------------------------------------------------- '''
142
142
  ''' Function '''
143
143
  ''' ---------------------------------------------------------- '''
144
- def __init__(self, jax = False):
144
+ def __init__(self, jax = True):
145
145
  # {
146
146
 
147
147
  self.reset_attributes()
@@ -1,5 +1,5 @@
1
1
  import numpy as np
2
- from scipy.optimize import minimize
2
+ from scipy.optimize import minimize, differential_evolution
3
3
  from scipy.special import logsumexp
4
4
 
5
5
 
@@ -13,7 +13,7 @@ class LatentClassMixedLogit:
13
13
  class_maxiter=100,
14
14
  tol=1e-6,
15
15
  random_state=0,
16
- _jax=False,
16
+ _jax=True,
17
17
  n_init=1,
18
18
  ):
19
19
  self.n_classes = int(n_classes)
@@ -190,6 +190,82 @@ class LatentClassMixedLogit:
190
190
  )
191
191
  return result.x
192
192
 
193
+ def _de_warm_start(
194
+ self,
195
+ popsize=6,
196
+ maxiter=20,
197
+ tol=0.01,
198
+ seed=None,
199
+ bounds_scale=5.0,
200
+ ):
201
+ """Differential Evolution warm-start for EM betas.
202
+
203
+ Minimises the negative marginal log-likelihood over the flattened
204
+ ``(n_classes, K)`` beta matrix. Uses JAX for the objective if
205
+ enabled, otherwise falls back to the numpy path.
206
+
207
+ Returns
208
+ -------
209
+ betas0 : ndarray, shape (n_classes, K)
210
+ """
211
+ n_params = self.n_classes * self.K
212
+ bounds = [(-bounds_scale, bounds_scale)] * n_params
213
+
214
+ if self._jax_enabled:
215
+ jnp = self.jnp
216
+ X_b = self.X_backend # (N, J, K)
217
+ y_b = self.y_backend # (N, J)
218
+ av_b = self.avail_backend # (N, J)
219
+ C = self.n_classes
220
+
221
+ @self.jit
222
+ def _jax_negll(betas_flat):
223
+ betas = betas_flat.reshape(C, self.K)
224
+ utilities = jnp.einsum("ck,njk->ncj", betas, X_b)
225
+ utilities = jnp.where(av_b[:, None, :] > 0, utilities, -1e10)
226
+ utilities = utilities - jnp.max(utilities, axis=2, keepdims=True)
227
+ exp_u = jnp.exp(utilities) * av_b[:, None, :]
228
+ denom = jnp.clip(exp_u.sum(axis=2, keepdims=True), 1e-300)
229
+ probs = exp_u / denom # (N, C, J)
230
+ chosen = jnp.clip((probs * y_b[:, None, :]).sum(axis=2), 1e-300) # (N, C)
231
+ log_chosen = jnp.log(chosen) # (N, C)
232
+ # equal class priors
233
+ log_prior = jnp.log(jnp.full(C, 1.0 / C))
234
+ log_joint = log_chosen + log_prior[None, :]
235
+ log_marg = self.jax_logsumexp(log_joint, axis=1)
236
+ return -jnp.sum(log_marg)
237
+
238
+ def _obj(betas_np):
239
+ return float(_jax_negll(jnp.array(betas_np, dtype=jnp.float64)))
240
+
241
+ else:
242
+ def _obj(betas_np):
243
+ betas = betas_np.reshape(self.n_classes, self.K)
244
+ log_choice, _ = self._log_choice_probs_np(betas) # (N, C)
245
+ log_prior = np.log(np.full(self.n_classes, 1.0 / self.n_classes))
246
+ log_joint = log_choice + log_prior[None, :]
247
+ log_marg = logsumexp(log_joint, axis=1)
248
+ return -float(log_marg.sum())
249
+
250
+ print(
251
+ f"[LC-DE] Running DE: classes={self.n_classes}, K={self.K}, "
252
+ f"popsize={popsize}, maxiter={maxiter}, jax={self._jax_enabled}"
253
+ )
254
+ result = differential_evolution(
255
+ _obj,
256
+ bounds,
257
+ popsize=popsize,
258
+ maxiter=maxiter,
259
+ tol=tol,
260
+ seed=seed,
261
+ polish=False,
262
+ )
263
+ print(
264
+ f"[LC-DE] DE done: success={result.success}, "
265
+ f"negll={result.fun:.4f}, nit={result.nit}"
266
+ )
267
+ return result.x.reshape(self.n_classes, self.K)
268
+
193
269
  def _make_initial_betas(self, rng, betas0=None):
194
270
  if betas0 is not None:
195
271
  betas0 = np.asarray(betas0, dtype=float)
@@ -237,7 +313,30 @@ class LatentClassMixedLogit:
237
313
  "iterations": iteration,
238
314
  }
239
315
 
240
- def fit(self, betas0=None, class_probs0=None):
316
+ def fit(self, betas0=None, class_probs0=None,
317
+ de_init=False, de_popsize=6, de_maxiter=20, de_tol=0.01, de_seed=None):
318
+ """Fit the latent class model via EM.
319
+
320
+ Parameters
321
+ ----------
322
+ betas0 : ndarray, optional
323
+ Initial class betas, shape (n_classes, K).
324
+ class_probs0 : ndarray, optional
325
+ Initial class shares, length n_classes.
326
+ de_init : bool
327
+ Use Differential Evolution to warm-start the EM betas (overrides
328
+ ``betas0`` when True).
329
+ de_popsize, de_maxiter, de_tol, de_seed
330
+ DE hyper-parameters forwarded to :meth:`_de_warm_start`.
331
+ """
332
+ if de_init:
333
+ betas0 = self._de_warm_start(
334
+ popsize=de_popsize,
335
+ maxiter=de_maxiter,
336
+ tol=de_tol,
337
+ seed=de_seed,
338
+ )
339
+
241
340
  best_result = None
242
341
 
243
342
  for init_idx in range(self.n_init):
@@ -13,8 +13,8 @@ class MixedNested(MixedLogit, NestedLogit):
13
13
  Mixed Nested Logit Model.
14
14
  """
15
15
 
16
- def __init__(self, _jax=False):
17
- MixedLogit.__init__(self)
16
+ def __init__(self, _jax=True):
17
+ MixedLogit.__init__(self, _jax=_jax)
18
18
  self._jax = _jax
19
19
  self._set_backend(_jax)
20
20
  self.descr = "Mixed Nested Logit"
@@ -181,7 +181,7 @@ class MultinomialLogit(DiscreteChoiceModel):
181
181
  ''' --------------------------------- '''
182
182
  ''' Function. Constructor '''
183
183
  ''' --------------------------------- '''
184
- def __init__(self, _jax = False): # {
184
+ def __init__(self, _jax = True): # {
185
185
  super(MultinomialLogit, self).__init__(_jax) # Base class initialisations
186
186
  self.descr = "MNL"
187
187
  # }
@@ -12,7 +12,7 @@ class NestedLogit(MultinomialLogit):
12
12
  Handles nested structure of alternatives.
13
13
  """
14
14
 
15
- def __init__(self, _jax = False):
15
+ def __init__(self, _jax = True):
16
16
  super(NestedLogit, self).__init__(_jax)
17
17
  self.descr = "Nested Logit"
18
18
  self.robust = False
@@ -25,7 +25,7 @@ class MultinomialProbit(MultinomialLogit):
25
25
  softmax with a probit-style probability construction based on normal CDFs.
26
26
  """
27
27
 
28
- def __init__(self, _jax=False):
28
+ def __init__(self, _jax=True):
29
29
  super(MultinomialProbit, self).__init__(_jax)
30
30
  self.descr = "Multinomial Probit (pairwise normal approximation)"
31
31
 
@@ -146,7 +146,7 @@ class OrderedLogit():
146
146
  ''' ---------------------------------------------------------- '''
147
147
  ''' Function '''
148
148
  ''' ---------------------------------------------------------- '''
149
- def __init__(self, _jax=False, **kwargs):
149
+ def __init__(self, _jax=True, **kwargs):
150
150
  # {
151
151
  self.descr = "ORL"
152
152
  self.delta_transform = kwargs.get('dt',True)
@@ -673,8 +673,11 @@ class SAPBIL(SA):
673
673
  def finalise(self):
674
674
  super().finalise()
675
675
 
676
+ total = self.accepted + self.not_accepted + self.not_converged
676
677
  header = (
677
678
  f"\nSA+PBIL — probability matrix after {self._pbil_updates} update(s):\n"
679
+ f" perturbations: accepted={self.accepted} rejected={self.not_accepted}"
680
+ f" not_converged={self.not_converged} total={total}\n"
678
681
  f" {'Variable':<22s} {'P(incl)':>8s} {'P(rand)':>8s} "
679
682
  f"{'P(corr)':>8s} {'P(bc)':>6s} P(distr)"
680
683
  )
@@ -643,7 +643,7 @@ class Parameters:
643
643
  print(f'inspect choices {choices}')
644
644
 
645
645
  raise ValueError('choice set must be defined and in list format')
646
- self.verbose = kwargs.get('verbose', True)
646
+ self.verbose = kwargs.get('verbose', False)
647
647
  if self.verbose:
648
648
  logging.info('verbose = TRUE, Will print all solutions. SET verbose = False in parameters')
649
649
  self.test_choices = test_choices
@@ -677,6 +677,11 @@ class Parameters:
677
677
  # reduction: equivalent to ~2x draws for normal-based distributions.
678
678
  # shuffled=True applies Owen scrambling to reduce inter-dimension correlation.
679
679
  self.halton_opts = kwargs.get('halton_opts', {'antithetic': True})
680
+ self.de_init = kwargs.get('de_init', False)
681
+ self.de_popsize = kwargs.get('de_popsize', 4)
682
+ self.de_maxiter = kwargs.get('de_maxiter', 3)
683
+ self.de_tol = kwargs.get('de_tol', 0.5)
684
+ self.de_polish = kwargs.get('de_polish', False)
680
685
 
681
686
  self.intercept_opts = intercept_opts
682
687
  self.base_alt = base_alt
@@ -774,7 +779,11 @@ class Parameters:
774
779
  pass
775
780
 
776
781
  # TODO I Think we could initialise it this way more effictively
777
- acceptable_keys = ['LCR', 'verbose', 'asc_ind', 'nests', 'lambdas', 'varnest', '_jax', 'all_sig']
782
+ acceptable_keys = [
783
+ 'LCR', 'verbose', 'asc_ind', 'nests', 'lambdas', 'varnest',
784
+ '_jax', 'all_sig', 'de_init', 'de_popsize', 'de_maxiter',
785
+ 'de_tol', 'de_polish', 'halton_opts'
786
+ ]
778
787
 
779
788
  # Assign all kwargs to self, but only if the key is in the acceptable_keys list
780
789
  for key, value in kwargs.items():
@@ -962,7 +971,7 @@ class Solution(UserDict):
962
971
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
963
972
 
964
973
  self.data.setdefault('insig', None) # Insignificant variables
965
- self.data.setdefault('obj', np.zeros(nb_crit) )
974
+ self.data.setdefault('obj', np.full(nb_crit, np.inf))
966
975
  self.data.setdefault('model', None)
967
976
  self.data.setdefault('class_num', None)
968
977
  self.data.setdefault('hash', None)
@@ -1512,11 +1521,9 @@ class Search():
1512
1521
  label = model_n or sol.get('model_n', '?')
1513
1522
  sep = '─' * 62
1514
1523
 
1515
- # In quiet mode (default during search) only print a one-liner.
1524
+ # In quiet mode (default during search) simply count the failure and
1525
+ # return — the totals are written to the results file at the end.
1516
1526
  if not getattr(self.param, 'verbose_convergence', False):
1517
- print(f" [no-converge] model={label} sol#={sol.get('sol_num','?')} "
1518
- f"vars={all_vars} "
1519
- f"(set verbose_convergence=True in Parameters for full diagnostic)")
1520
1527
  return
1521
1528
 
1522
1529
  print(f"\n{sep}")
@@ -3488,7 +3495,7 @@ class Search():
3488
3495
  fit_intercept, init_coeff, n_draws, weights, avail, base_alt, maxiter, ftol, gtol, save_fitted_params,
3489
3496
  halton_opts=None):
3490
3497
  # {
3491
- model = MixedLogit()
3498
+ model = MixedLogit(_jax=getattr(self.param, '_jax', True))
3492
3499
  #subvarnames = varnames delete itemes in randvaras
3493
3500
 
3494
3501
 
@@ -3498,7 +3505,12 @@ class Search():
3498
3505
  model.setup(X=X, y=y, varnames=varnames, isvars=isvars, alts=alts, transvars=transvars, ids=ids,
3499
3506
  randvars=randvars, panels=panels, fit_intercept=fit_intercept, correlated_vars=corvars, n_draws=n_draws,
3500
3507
  init_coeff=init_coeff, weights=weights, avail=avail, base_alt=base_alt, maxiter=maxiter,
3501
- ftol=ftol, gtol=gtol, save_fitted_params=save_fitted_params, halton_opts=halton_opts)
3508
+ ftol=ftol, gtol=gtol, save_fitted_params=save_fitted_params, halton_opts=halton_opts,
3509
+ de_init=getattr(self.param, 'de_init', False),
3510
+ de_popsize=getattr(self.param, 'de_popsize', 4),
3511
+ de_maxiter=getattr(self.param, 'de_maxiter', 3),
3512
+ de_tol=getattr(self.param, 'de_tol', 0.5),
3513
+ de_polish=getattr(self.param, 'de_polish', False))
3502
3514
  model.fit()
3503
3515
 
3504
3516
  return model
@@ -3700,7 +3712,7 @@ class Search():
3700
3712
  X_nest = self.param.df[nest_vars]
3701
3713
 
3702
3714
  # Fit the Nested Logit model
3703
- model = NestedLogit(_jax=self.param._jax)
3715
+ model = NestedLogit(_jax=getattr(self.param, '_jax', True))
3704
3716
  model.setup(X=X, X_nest = X_nest, y=y, varnames=all_vars, isvars=is_vars,
3705
3717
  alts=self.param.alt_var, ids=self.param.choice_id,
3706
3718
  nests=nests, lambdas=lambdas, fit_intercept=asc_ind, return_grad=self.param.grad, return_hess=self.param.hess)
@@ -24,7 +24,7 @@ except ImportError:
24
24
  class BinaryProbit(DiscreteChoiceModel):
25
25
  """Binary probit estimated with JAX autodiff and scipy L-BFGS-B."""
26
26
 
27
- def __init__(self, _jax=False):
27
+ def __init__(self, _jax=True):
28
28
  super(BinaryProbit, self).__init__(_jax)
29
29
  self.descr = "Binary Probit"
30
30
  self.result = None
@@ -133,7 +133,7 @@ class _OLSResult:
133
133
  class HeckmanTwoStep(DiscreteChoiceModel):
134
134
  """Heckman selection model using JAX probit + closed-form OLS second stage."""
135
135
 
136
- def __init__(self, _jax=False):
136
+ def __init__(self, _jax=True):
137
137
  super(HeckmanTwoStep, self).__init__(_jax)
138
138
  self.descr = "Heckman Two-Step"
139
139
  self.selection_result = None