SearchLibrium 0.0.89__tar.gz → 0.0.90__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/PKG-INFO +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/pyproject.toml +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/MixedLogit.py +133 -27
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/_choice_model.py +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/latent_class.py +102 -3
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/mixed_nested.py +2 -2
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/multinomial_logit.py +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/multinomial_nested.py +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/multinomial_probit.py +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/ordered_logit.py +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/sapbil.py +3 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/search.py +22 -10
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/selection_models.py +2 -2
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/siman.py +85 -46
- searchlibrium-0.0.90/src/SearchLibrium/test_mario_searches.py +346 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/test_sapbil_vs_banditsa.py +196 -33
- searchlibrium-0.0.90/src/SearchLibrium/version.txt +1 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium.egg-info/PKG-INFO +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium.egg-info/SOURCES.txt +1 -0
- searchlibrium-0.0.89/src/SearchLibrium/version.txt +0 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/README.md +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/setup.cfg +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/Halton.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/Mode_Activity_Nested.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/RandomP.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/SEARCH_SM_MARIO.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/Two_Level_Nest.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/__init__.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/__main__.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/_device.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/banditsa.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/bhhh/minimize.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/boxcox_functions.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/call_meta.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/constraints_builder.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/harmony.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/main.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/main_debug.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/mdcev.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/misc.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/mixed_logit.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/mixedrrm.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/ordered_logit_mixed.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/rrm.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/setup.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium/threshold.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium.egg-info/dependency_links.txt +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium.egg-info/entry_points.txt +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium.egg-info/requires.txt +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.90}/src/SearchLibrium.egg-info/top_level.txt +0 -0
|
@@ -59,7 +59,7 @@ Homepage = "https://github.com/zahern/HypothesisX"
|
|
|
59
59
|
realpython = "SearchLibrium.__main__:main"
|
|
60
60
|
|
|
61
61
|
[tool.bumpver]
|
|
62
|
-
current_version = "0.0.
|
|
62
|
+
current_version = "0.0.90"
|
|
63
63
|
version_pattern = "MAJOR.MINOR.PATCH"
|
|
64
64
|
commit_message = "[skip ci] Bump version {old_version} -> {new_version}"
|
|
65
65
|
commit = true
|
|
@@ -2,7 +2,7 @@ from typing import Optional
|
|
|
2
2
|
import itertools
|
|
3
3
|
import numpy as np
|
|
4
4
|
import scipy.stats as ss
|
|
5
|
-
from scipy.optimize import minimize
|
|
5
|
+
from scipy.optimize import minimize, differential_evolution
|
|
6
6
|
from typing import Callable, Tuple
|
|
7
7
|
import inspect
|
|
8
8
|
|
|
@@ -34,8 +34,8 @@ max_comp_val, min_comp_val = 1e+20, 1e-200 # or use float('inf')
|
|
|
34
34
|
infinity = float('inf')
|
|
35
35
|
|
|
36
36
|
class MixedLogit(DiscreteChoiceModel):
|
|
37
|
-
def __init__(self, halton_opts=None, distributions=['n', 'ln', 't', 'tn', 'u']):
|
|
38
|
-
super().__init__()
|
|
37
|
+
def __init__(self, halton_opts=None, distributions=['n', 'ln', 't', 'tn', 'u'], _jax=True):
|
|
38
|
+
super().__init__(_jax)
|
|
39
39
|
self.halton_opts = halton_opts
|
|
40
40
|
self.draws_generator = Draws(k=len(distributions), halton_opts=halton_opts, rvdist=distributions)
|
|
41
41
|
self.random_parameters = RandomParameters(distributions or []) # Initialize RandomParameters
|
|
@@ -66,7 +66,9 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
66
66
|
n_draws=1000, halton=True, minimise_func=None,
|
|
67
67
|
batch_size=None, halton_opts=None, ftol=1e-6,
|
|
68
68
|
gtol=1e-6, return_hess=True, return_grad=True, method="bfgs",
|
|
69
|
-
save_fitted_params=True, mnl_init=True
|
|
69
|
+
save_fitted_params=True, mnl_init=True,
|
|
70
|
+
de_init=False, de_popsize=4, de_maxiter=3, de_tol=0.5,
|
|
71
|
+
de_polish=False):
|
|
70
72
|
# {
|
|
71
73
|
self.fit_intercept = fit_intercept
|
|
72
74
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
@@ -104,6 +106,11 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
104
106
|
self.minimise_func = minimise_func
|
|
105
107
|
self.save_fitted_params = save_fitted_params
|
|
106
108
|
self.mnl_init = mnl_init
|
|
109
|
+
self.de_init = de_init
|
|
110
|
+
self.de_popsize = de_popsize
|
|
111
|
+
self.de_maxiter = de_maxiter
|
|
112
|
+
self.de_tol = de_tol
|
|
113
|
+
self.de_polish = de_polish
|
|
107
114
|
self.total_fun_eval = 0
|
|
108
115
|
self.method = method.lower() if hasattr(method, 'lower') else method
|
|
109
116
|
self.jac = self.return_grad # scipy optimize parameter
|
|
@@ -312,6 +319,80 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
312
319
|
if len(self.init_coeff) != n_coeff and not hasattr(self, 'class_params_spec'):
|
|
313
320
|
raise ValueError("The size of init_coeff must be: " + str(n_coeff))
|
|
314
321
|
|
|
322
|
+
positive_bound = (0, infinity)
|
|
323
|
+
any_bound = (-infinity, infinity)
|
|
324
|
+
lmda_bound = (-5, 1)
|
|
325
|
+
bound_dict = {
|
|
326
|
+
"bf": (any_bound, self.Kf),
|
|
327
|
+
"br_b": (any_bound, self.Kr),
|
|
328
|
+
"chol": (any_bound, self.Kchol),
|
|
329
|
+
"br_w": (positive_bound, self.Kr - self.correlationLength),
|
|
330
|
+
"bf_trans": (any_bound, self.Kftrans),
|
|
331
|
+
"flmbda": (lmda_bound, self.Kftrans),
|
|
332
|
+
"br_trans_b": (any_bound, self.Krtrans),
|
|
333
|
+
"br_trans_w": (any_bound, self.Krtrans),
|
|
334
|
+
"rlmbda": (lmda_bound, self.Krtrans)
|
|
335
|
+
}
|
|
336
|
+
bnds = [[bound[1][0]] * bound[1][1] for bound in bound_dict.items() if bound[1][1] > 0]
|
|
337
|
+
bnds = list(itertools.chain.from_iterable(bnds))
|
|
338
|
+
|
|
339
|
+
print(f"[MXL] Starting beta seed length={n_coeff}, first_values={betas[:min(8, len(betas))]!r}")
|
|
340
|
+
|
|
341
|
+
if self.de_init:
|
|
342
|
+
print(f"[MXL] DE init enabled: popsize={self.de_popsize}, maxiter={self.de_maxiter}, tol={self.de_tol}, polish={self.de_polish}")
|
|
343
|
+
try:
|
|
344
|
+
before_de_obj = self.get_loglik_gradient(betas, self.X, self.y, self.panel_info,
|
|
345
|
+
draws, drawstrans, self.weights,
|
|
346
|
+
self.avail, self.batch_size)[0]
|
|
347
|
+
print(f"[MXL] DE before: obj={before_de_obj:.6g}")
|
|
348
|
+
|
|
349
|
+
def _de_obj(x):
|
|
350
|
+
return self.get_loglik_gradient(x, self.X, self.y, self.panel_info,
|
|
351
|
+
draws, drawstrans, self.weights,
|
|
352
|
+
self.avail, self.batch_size)[0]
|
|
353
|
+
|
|
354
|
+
if len(bnds) == n_coeff:
|
|
355
|
+
de_bounds = []
|
|
356
|
+
for idx, bound in enumerate(bnds):
|
|
357
|
+
low, high = bound
|
|
358
|
+
if not np.isfinite(low) or not np.isfinite(high):
|
|
359
|
+
scale = max(1.0, 10.0 * abs(betas[idx]) if idx < len(betas) else 1.0)
|
|
360
|
+
de_bounds.append((-scale, scale))
|
|
361
|
+
else:
|
|
362
|
+
de_bounds.append((low, high))
|
|
363
|
+
else:
|
|
364
|
+
de_bounds = [(-10, 10)] * n_coeff
|
|
365
|
+
|
|
366
|
+
de_result = differential_evolution(
|
|
367
|
+
_de_obj,
|
|
368
|
+
de_bounds,
|
|
369
|
+
strategy='best1bin',
|
|
370
|
+
maxiter=self.de_maxiter,
|
|
371
|
+
popsize=self.de_popsize,
|
|
372
|
+
tol=self.de_tol,
|
|
373
|
+
polish=self.de_polish,
|
|
374
|
+
disp=False,
|
|
375
|
+
workers=1,
|
|
376
|
+
)
|
|
377
|
+
de_improved = False
|
|
378
|
+
print(f"[MXL] DE completed: success={de_result.success}, nit={de_result.nit}, message={de_result.message}")
|
|
379
|
+
if de_result.success:
|
|
380
|
+
de_obj = self.get_loglik_gradient(de_result.x, self.X, self.y, self.panel_info,
|
|
381
|
+
draws, drawstrans, self.weights,
|
|
382
|
+
self.avail, self.batch_size)[0]
|
|
383
|
+
print(f"[MXL] DE after: obj={de_obj:.6g}")
|
|
384
|
+
de_improved = de_obj < before_de_obj
|
|
385
|
+
if de_improved:
|
|
386
|
+
betas = de_result.x
|
|
387
|
+
print(f"[MXL] DE best seed first_values={betas[:min(8, len(betas))]!r}")
|
|
388
|
+
print("[MXL] Differential evolution initialization improved the objective; starting minimise from DE seed.")
|
|
389
|
+
else:
|
|
390
|
+
print(f"[MXL] Differential evolution did not beat the original start ({de_obj:.6g} >= {before_de_obj:.6g}); keeping original start.")
|
|
391
|
+
if not de_improved:
|
|
392
|
+
print(f"[MXL] Differential evolution initialization rejected because it did not improve the starting objective. Using original seed.")
|
|
393
|
+
except Exception as e:
|
|
394
|
+
print(f"[MXL] Differential evolution initialization failed: {e}")
|
|
395
|
+
|
|
315
396
|
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
|
|
316
397
|
if dev.using_gpu: # {
|
|
317
398
|
self.X, self.y = dev.convert_array_gpu(self.X), dev.convert_array_gpu(self.y)
|
|
@@ -323,6 +404,15 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
323
404
|
# }
|
|
324
405
|
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
|
|
325
406
|
|
|
407
|
+
print(f"[MXL] Minimization start with betas first_values={betas[:min(8, len(betas))]!r}")
|
|
408
|
+
try:
|
|
409
|
+
before_fun = self.get_loglik_gradient(betas, self.X, self.y, self.panel_info,
|
|
410
|
+
draws, drawstrans, self.weights,
|
|
411
|
+
self.avail, self.batch_size)[0]
|
|
412
|
+
print(f"[MXL] Minimization obj before={before_fun:.6g}")
|
|
413
|
+
except Exception as e:
|
|
414
|
+
print(f"[MXL] Could not evaluate initial objective before minimization: {e}")
|
|
415
|
+
|
|
326
416
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
327
417
|
# Generate bound for L-BFGS-B method
|
|
328
418
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
@@ -378,6 +468,9 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
378
468
|
options = {'gtol': self.gtol, 'maxiter': self.maxiter, 'disp': False}
|
|
379
469
|
result = minimise_func(self.get_loglik_gradient, betas, jac=self.jac, method=self.method,
|
|
380
470
|
args=args, tol=self.ftol, bounds=bounds, options=options)
|
|
471
|
+
print(f"[MXL] Minimization completed: success={result.get('success', None)}, fun={result.get('fun', float('nan')):.6g}, nit={result.get('nit', '?')}")
|
|
472
|
+
if 'x' in result:
|
|
473
|
+
print(f"[MXL] Minimization final betas first_values={np.asarray(result['x'])[:min(8, len(result['x']))]!r}")
|
|
381
474
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
382
475
|
|
|
383
476
|
if hasattr(self, 'method') and self.method == "L-BFGS-B": # {
|
|
@@ -428,38 +521,37 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
428
521
|
# ------------------------------------------------------------------
|
|
429
522
|
@staticmethod
|
|
430
523
|
def _jax_mxl_negloglik(betas, X_jax, y_jax, panel_info_jax, draws_jax,
|
|
431
|
-
fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names
|
|
524
|
+
fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names,
|
|
525
|
+
correlationLength):
|
|
432
526
|
"""JAX simulation-based log-likelihood for Mixed Logit (standard case).
|
|
433
527
|
|
|
434
528
|
Handles fixed and random (normally/lognormally distributed) parameters.
|
|
435
|
-
Cholesky correlation structure is included via the Kchol segment
|
|
529
|
+
Cholesky correlation structure is included via the Kchol segment;
|
|
530
|
+
uncorrelated random vars use independent std devs in Br_w.
|
|
436
531
|
"""
|
|
437
532
|
import jax.numpy as jnp
|
|
438
533
|
|
|
439
534
|
# ---- split beta vector ----
|
|
440
535
|
Bf = betas[:Kf]
|
|
441
536
|
Br_b = betas[Kf:Kf + Kr]
|
|
442
|
-
chol_v = betas[Kf + Kr:Kf + Kr + Kchol]
|
|
537
|
+
chol_v = betas[Kf + Kr:Kf + Kr + Kchol] # correlated cholesky elements
|
|
443
538
|
Br_w = betas[Kf + Kr + Kchol:Kf + Kr + Kchol + Kbw] # independent std devs
|
|
444
539
|
|
|
445
540
|
# ---- build cholesky matrix ----
|
|
446
|
-
|
|
541
|
+
# Rows 0..correlationLength-1: lower-triangle from chol_v
|
|
542
|
+
# Rows correlationLength..Kr-1: diagonal from Br_w (uncorrelated)
|
|
447
543
|
chol_mat = jnp.zeros((Kr, Kr))
|
|
448
|
-
# place chol values into lower triangle
|
|
449
|
-
# (static indices ok for jit since Kr is compile-time constant)
|
|
450
544
|
idx = 0
|
|
451
|
-
for r in range(
|
|
545
|
+
for r in range(correlationLength):
|
|
452
546
|
for c in range(r + 1):
|
|
453
547
|
chol_mat = chol_mat.at[r, c].set(chol_v[idx])
|
|
454
548
|
idx += 1
|
|
455
|
-
# diagonal of the independent block
|
|
456
549
|
for k in range(Kbw):
|
|
457
|
-
|
|
550
|
+
diag_pos = correlationLength + k
|
|
551
|
+
chol_mat = chol_mat.at[diag_pos, diag_pos].set(jnp.abs(Br_w[k]))
|
|
458
552
|
|
|
459
553
|
# ---- sample random coefficients: (N, Kr, R) ----
|
|
460
|
-
N
|
|
461
|
-
R = draws_jax.shape[2]
|
|
462
|
-
# draws_jax: (N, Kr, R)
|
|
554
|
+
N = X_jax.shape[0]
|
|
463
555
|
Br = Br_b[:, None] + jnp.einsum('kl,nlr->nkr',
|
|
464
556
|
chol_mat,
|
|
465
557
|
draws_jax[:, :Kr, :]) # (N, Kr, R)
|
|
@@ -475,21 +567,24 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
475
567
|
Br_b[k] + Br_w[k] * (draws_jax[:, k, :] - 0.5))
|
|
476
568
|
|
|
477
569
|
# ---- utility ----
|
|
478
|
-
# X_jax: (N, P, J, K)
|
|
479
570
|
P = X_jax.shape[1]
|
|
480
|
-
|
|
571
|
+
J = X_jax.shape[2]
|
|
481
572
|
Xr = X_jax[:, :, :, rvidx] # (N, P, J, Kr)
|
|
482
|
-
|
|
483
|
-
UB = jnp.einsum('npjk,k->npj', Xf, Bf) # (N, P, J)
|
|
484
573
|
UR = jnp.einsum('npjk,nkr->npjr', Xr, Br) # (N, P, J, R)
|
|
485
|
-
|
|
574
|
+
|
|
575
|
+
if Kf > 0:
|
|
576
|
+
Xf = X_jax[:, :, :, fxidx] # (N, P, J, Kf)
|
|
577
|
+
UB = jnp.einsum('npjk,k->npj', Xf, Bf) # (N, P, J)
|
|
578
|
+
U = UB[:, :, :, None] + UR
|
|
579
|
+
else:
|
|
580
|
+
U = UR # (N, P, J, R)
|
|
486
581
|
|
|
487
582
|
U = U - jnp.max(U, axis=2, keepdims=True)
|
|
488
583
|
eU = jnp.exp(U)
|
|
489
584
|
p = eU / jnp.sum(eU, axis=2, keepdims=True) # (N, P, J, R)
|
|
490
585
|
|
|
491
|
-
pch = jnp.sum(y_jax[:, :, :, None] * p, axis=2) # (N, P, R)
|
|
492
|
-
pch = jnp.prod(pch, axis=1) # (N, R)
|
|
586
|
+
pch = jnp.sum(y_jax[:, :, :, None] * p, axis=2) # (N, P, R)
|
|
587
|
+
pch = jnp.prod(pch, axis=1) # (N, R)
|
|
493
588
|
pch = jnp.clip(pch, 1e-300, None)
|
|
494
589
|
|
|
495
590
|
sim_p = jnp.mean(pch, axis=1) # (N,)
|
|
@@ -503,27 +598,38 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
503
598
|
Falls back to standard numpy optimisation on failure.
|
|
504
599
|
"""
|
|
505
600
|
try:
|
|
601
|
+
import os
|
|
602
|
+
os.environ.setdefault("JAX_ENABLE_X64", "True")
|
|
506
603
|
import jax
|
|
604
|
+
jax.config.update("jax_enable_x64", True)
|
|
507
605
|
import jax.numpy as jnp
|
|
508
606
|
from scipy.optimize import minimize as sp_min
|
|
509
607
|
|
|
608
|
+
if int(self.Kr) == 0:
|
|
609
|
+
print("[JAX MXL optimizer] no random coefficients detected (Kr=0); falling back to scipy.")
|
|
610
|
+
return None
|
|
611
|
+
|
|
510
612
|
# Convert static inputs once
|
|
511
613
|
X_jax = jnp.array(self.X, dtype=jnp.float64)
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
614
|
+
_y_np = np.asarray(self.y)
|
|
615
|
+
if _y_np.ndim > 3: # stored as (N, P, J, 1) in some paths
|
|
616
|
+
_y_np = _y_np[..., 0]
|
|
617
|
+
y_jax = jnp.array(_y_np, dtype=jnp.float64)
|
|
618
|
+
pi_jax = jnp.array(self.panel_info, dtype=jnp.float64)
|
|
619
|
+
draws_jax = jnp.array(draws, dtype=jnp.float64)
|
|
515
620
|
|
|
516
621
|
fxidx = jnp.array(self.fxidx, dtype=bool)
|
|
517
622
|
rvidx = jnp.array(self.rvidx, dtype=bool)
|
|
518
623
|
rvdist_names = [d for d in self.rvdist if d is not False]
|
|
519
624
|
|
|
520
625
|
Kf, Kr, Kchol, Kbw = int(self.Kf), int(self.Kr), int(self.Kchol), int(self.Kbw)
|
|
626
|
+
correlationLength = int(self.correlationLength)
|
|
521
627
|
|
|
522
628
|
@jax.jit
|
|
523
629
|
def _neg_ll(b):
|
|
524
630
|
return self._jax_mxl_negloglik(
|
|
525
631
|
b, X_jax, y_jax, pi_jax, draws_jax,
|
|
526
|
-
fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names)
|
|
632
|
+
fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names, correlationLength)
|
|
527
633
|
|
|
528
634
|
_val_grad = jax.jit(jax.value_and_grad(_neg_ll))
|
|
529
635
|
|
|
@@ -141,7 +141,7 @@ class DiscreteChoiceModel(ABC):
|
|
|
141
141
|
''' ---------------------------------------------------------- '''
|
|
142
142
|
''' Function '''
|
|
143
143
|
''' ---------------------------------------------------------- '''
|
|
144
|
-
def __init__(self, jax =
|
|
144
|
+
def __init__(self, jax = True):
|
|
145
145
|
# {
|
|
146
146
|
|
|
147
147
|
self.reset_attributes()
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import numpy as np
|
|
2
|
-
from scipy.optimize import minimize
|
|
2
|
+
from scipy.optimize import minimize, differential_evolution
|
|
3
3
|
from scipy.special import logsumexp
|
|
4
4
|
|
|
5
5
|
|
|
@@ -13,7 +13,7 @@ class LatentClassMixedLogit:
|
|
|
13
13
|
class_maxiter=100,
|
|
14
14
|
tol=1e-6,
|
|
15
15
|
random_state=0,
|
|
16
|
-
_jax=
|
|
16
|
+
_jax=True,
|
|
17
17
|
n_init=1,
|
|
18
18
|
):
|
|
19
19
|
self.n_classes = int(n_classes)
|
|
@@ -190,6 +190,82 @@ class LatentClassMixedLogit:
|
|
|
190
190
|
)
|
|
191
191
|
return result.x
|
|
192
192
|
|
|
193
|
+
def _de_warm_start(
|
|
194
|
+
self,
|
|
195
|
+
popsize=6,
|
|
196
|
+
maxiter=20,
|
|
197
|
+
tol=0.01,
|
|
198
|
+
seed=None,
|
|
199
|
+
bounds_scale=5.0,
|
|
200
|
+
):
|
|
201
|
+
"""Differential Evolution warm-start for EM betas.
|
|
202
|
+
|
|
203
|
+
Minimises the negative marginal log-likelihood over the flattened
|
|
204
|
+
``(n_classes, K)`` beta matrix. Uses JAX for the objective if
|
|
205
|
+
enabled, otherwise falls back to the numpy path.
|
|
206
|
+
|
|
207
|
+
Returns
|
|
208
|
+
-------
|
|
209
|
+
betas0 : ndarray, shape (n_classes, K)
|
|
210
|
+
"""
|
|
211
|
+
n_params = self.n_classes * self.K
|
|
212
|
+
bounds = [(-bounds_scale, bounds_scale)] * n_params
|
|
213
|
+
|
|
214
|
+
if self._jax_enabled:
|
|
215
|
+
jnp = self.jnp
|
|
216
|
+
X_b = self.X_backend # (N, J, K)
|
|
217
|
+
y_b = self.y_backend # (N, J)
|
|
218
|
+
av_b = self.avail_backend # (N, J)
|
|
219
|
+
C = self.n_classes
|
|
220
|
+
|
|
221
|
+
@self.jit
|
|
222
|
+
def _jax_negll(betas_flat):
|
|
223
|
+
betas = betas_flat.reshape(C, self.K)
|
|
224
|
+
utilities = jnp.einsum("ck,njk->ncj", betas, X_b)
|
|
225
|
+
utilities = jnp.where(av_b[:, None, :] > 0, utilities, -1e10)
|
|
226
|
+
utilities = utilities - jnp.max(utilities, axis=2, keepdims=True)
|
|
227
|
+
exp_u = jnp.exp(utilities) * av_b[:, None, :]
|
|
228
|
+
denom = jnp.clip(exp_u.sum(axis=2, keepdims=True), 1e-300)
|
|
229
|
+
probs = exp_u / denom # (N, C, J)
|
|
230
|
+
chosen = jnp.clip((probs * y_b[:, None, :]).sum(axis=2), 1e-300) # (N, C)
|
|
231
|
+
log_chosen = jnp.log(chosen) # (N, C)
|
|
232
|
+
# equal class priors
|
|
233
|
+
log_prior = jnp.log(jnp.full(C, 1.0 / C))
|
|
234
|
+
log_joint = log_chosen + log_prior[None, :]
|
|
235
|
+
log_marg = self.jax_logsumexp(log_joint, axis=1)
|
|
236
|
+
return -jnp.sum(log_marg)
|
|
237
|
+
|
|
238
|
+
def _obj(betas_np):
|
|
239
|
+
return float(_jax_negll(jnp.array(betas_np, dtype=jnp.float64)))
|
|
240
|
+
|
|
241
|
+
else:
|
|
242
|
+
def _obj(betas_np):
|
|
243
|
+
betas = betas_np.reshape(self.n_classes, self.K)
|
|
244
|
+
log_choice, _ = self._log_choice_probs_np(betas) # (N, C)
|
|
245
|
+
log_prior = np.log(np.full(self.n_classes, 1.0 / self.n_classes))
|
|
246
|
+
log_joint = log_choice + log_prior[None, :]
|
|
247
|
+
log_marg = logsumexp(log_joint, axis=1)
|
|
248
|
+
return -float(log_marg.sum())
|
|
249
|
+
|
|
250
|
+
print(
|
|
251
|
+
f"[LC-DE] Running DE: classes={self.n_classes}, K={self.K}, "
|
|
252
|
+
f"popsize={popsize}, maxiter={maxiter}, jax={self._jax_enabled}"
|
|
253
|
+
)
|
|
254
|
+
result = differential_evolution(
|
|
255
|
+
_obj,
|
|
256
|
+
bounds,
|
|
257
|
+
popsize=popsize,
|
|
258
|
+
maxiter=maxiter,
|
|
259
|
+
tol=tol,
|
|
260
|
+
seed=seed,
|
|
261
|
+
polish=False,
|
|
262
|
+
)
|
|
263
|
+
print(
|
|
264
|
+
f"[LC-DE] DE done: success={result.success}, "
|
|
265
|
+
f"negll={result.fun:.4f}, nit={result.nit}"
|
|
266
|
+
)
|
|
267
|
+
return result.x.reshape(self.n_classes, self.K)
|
|
268
|
+
|
|
193
269
|
def _make_initial_betas(self, rng, betas0=None):
|
|
194
270
|
if betas0 is not None:
|
|
195
271
|
betas0 = np.asarray(betas0, dtype=float)
|
|
@@ -237,7 +313,30 @@ class LatentClassMixedLogit:
|
|
|
237
313
|
"iterations": iteration,
|
|
238
314
|
}
|
|
239
315
|
|
|
240
|
-
def fit(self, betas0=None, class_probs0=None
|
|
316
|
+
def fit(self, betas0=None, class_probs0=None,
|
|
317
|
+
de_init=False, de_popsize=6, de_maxiter=20, de_tol=0.01, de_seed=None):
|
|
318
|
+
"""Fit the latent class model via EM.
|
|
319
|
+
|
|
320
|
+
Parameters
|
|
321
|
+
----------
|
|
322
|
+
betas0 : ndarray, optional
|
|
323
|
+
Initial class betas, shape (n_classes, K).
|
|
324
|
+
class_probs0 : ndarray, optional
|
|
325
|
+
Initial class shares, length n_classes.
|
|
326
|
+
de_init : bool
|
|
327
|
+
Use Differential Evolution to warm-start the EM betas (overrides
|
|
328
|
+
``betas0`` when True).
|
|
329
|
+
de_popsize, de_maxiter, de_tol, de_seed
|
|
330
|
+
DE hyper-parameters forwarded to :meth:`_de_warm_start`.
|
|
331
|
+
"""
|
|
332
|
+
if de_init:
|
|
333
|
+
betas0 = self._de_warm_start(
|
|
334
|
+
popsize=de_popsize,
|
|
335
|
+
maxiter=de_maxiter,
|
|
336
|
+
tol=de_tol,
|
|
337
|
+
seed=de_seed,
|
|
338
|
+
)
|
|
339
|
+
|
|
241
340
|
best_result = None
|
|
242
341
|
|
|
243
342
|
for init_idx in range(self.n_init):
|
|
@@ -13,8 +13,8 @@ class MixedNested(MixedLogit, NestedLogit):
|
|
|
13
13
|
Mixed Nested Logit Model.
|
|
14
14
|
"""
|
|
15
15
|
|
|
16
|
-
def __init__(self, _jax=
|
|
17
|
-
MixedLogit.__init__(self)
|
|
16
|
+
def __init__(self, _jax=True):
|
|
17
|
+
MixedLogit.__init__(self, _jax=_jax)
|
|
18
18
|
self._jax = _jax
|
|
19
19
|
self._set_backend(_jax)
|
|
20
20
|
self.descr = "Mixed Nested Logit"
|
|
@@ -181,7 +181,7 @@ class MultinomialLogit(DiscreteChoiceModel):
|
|
|
181
181
|
''' --------------------------------- '''
|
|
182
182
|
''' Function. Constructor '''
|
|
183
183
|
''' --------------------------------- '''
|
|
184
|
-
def __init__(self, _jax =
|
|
184
|
+
def __init__(self, _jax = True): # {
|
|
185
185
|
super(MultinomialLogit, self).__init__(_jax) # Base class initialisations
|
|
186
186
|
self.descr = "MNL"
|
|
187
187
|
# }
|
|
@@ -12,7 +12,7 @@ class NestedLogit(MultinomialLogit):
|
|
|
12
12
|
Handles nested structure of alternatives.
|
|
13
13
|
"""
|
|
14
14
|
|
|
15
|
-
def __init__(self, _jax =
|
|
15
|
+
def __init__(self, _jax = True):
|
|
16
16
|
super(NestedLogit, self).__init__(_jax)
|
|
17
17
|
self.descr = "Nested Logit"
|
|
18
18
|
self.robust = False
|
|
@@ -25,7 +25,7 @@ class MultinomialProbit(MultinomialLogit):
|
|
|
25
25
|
softmax with a probit-style probability construction based on normal CDFs.
|
|
26
26
|
"""
|
|
27
27
|
|
|
28
|
-
def __init__(self, _jax=
|
|
28
|
+
def __init__(self, _jax=True):
|
|
29
29
|
super(MultinomialProbit, self).__init__(_jax)
|
|
30
30
|
self.descr = "Multinomial Probit (pairwise normal approximation)"
|
|
31
31
|
|
|
@@ -146,7 +146,7 @@ class OrderedLogit():
|
|
|
146
146
|
''' ---------------------------------------------------------- '''
|
|
147
147
|
''' Function '''
|
|
148
148
|
''' ---------------------------------------------------------- '''
|
|
149
|
-
def __init__(self, _jax=
|
|
149
|
+
def __init__(self, _jax=True, **kwargs):
|
|
150
150
|
# {
|
|
151
151
|
self.descr = "ORL"
|
|
152
152
|
self.delta_transform = kwargs.get('dt',True)
|
|
@@ -673,8 +673,11 @@ class SAPBIL(SA):
|
|
|
673
673
|
def finalise(self):
|
|
674
674
|
super().finalise()
|
|
675
675
|
|
|
676
|
+
total = self.accepted + self.not_accepted + self.not_converged
|
|
676
677
|
header = (
|
|
677
678
|
f"\nSA+PBIL — probability matrix after {self._pbil_updates} update(s):\n"
|
|
679
|
+
f" perturbations: accepted={self.accepted} rejected={self.not_accepted}"
|
|
680
|
+
f" not_converged={self.not_converged} total={total}\n"
|
|
678
681
|
f" {'Variable':<22s} {'P(incl)':>8s} {'P(rand)':>8s} "
|
|
679
682
|
f"{'P(corr)':>8s} {'P(bc)':>6s} P(distr)"
|
|
680
683
|
)
|
|
@@ -643,7 +643,7 @@ class Parameters:
|
|
|
643
643
|
print(f'inspect choices {choices}')
|
|
644
644
|
|
|
645
645
|
raise ValueError('choice set must be defined and in list format')
|
|
646
|
-
self.verbose = kwargs.get('verbose',
|
|
646
|
+
self.verbose = kwargs.get('verbose', False)
|
|
647
647
|
if self.verbose:
|
|
648
648
|
logging.info('verbose = TRUE, Will print all solutions. SET verbose = False in parameters')
|
|
649
649
|
self.test_choices = test_choices
|
|
@@ -677,6 +677,11 @@ class Parameters:
|
|
|
677
677
|
# reduction: equivalent to ~2x draws for normal-based distributions.
|
|
678
678
|
# shuffled=True applies Owen scrambling to reduce inter-dimension correlation.
|
|
679
679
|
self.halton_opts = kwargs.get('halton_opts', {'antithetic': True})
|
|
680
|
+
self.de_init = kwargs.get('de_init', False)
|
|
681
|
+
self.de_popsize = kwargs.get('de_popsize', 4)
|
|
682
|
+
self.de_maxiter = kwargs.get('de_maxiter', 3)
|
|
683
|
+
self.de_tol = kwargs.get('de_tol', 0.5)
|
|
684
|
+
self.de_polish = kwargs.get('de_polish', False)
|
|
680
685
|
|
|
681
686
|
self.intercept_opts = intercept_opts
|
|
682
687
|
self.base_alt = base_alt
|
|
@@ -774,7 +779,11 @@ class Parameters:
|
|
|
774
779
|
pass
|
|
775
780
|
|
|
776
781
|
# TODO I Think we could initialise it this way more effictively
|
|
777
|
-
acceptable_keys = [
|
|
782
|
+
acceptable_keys = [
|
|
783
|
+
'LCR', 'verbose', 'asc_ind', 'nests', 'lambdas', 'varnest',
|
|
784
|
+
'_jax', 'all_sig', 'de_init', 'de_popsize', 'de_maxiter',
|
|
785
|
+
'de_tol', 'de_polish', 'halton_opts'
|
|
786
|
+
]
|
|
778
787
|
|
|
779
788
|
# Assign all kwargs to self, but only if the key is in the acceptable_keys list
|
|
780
789
|
for key, value in kwargs.items():
|
|
@@ -962,7 +971,7 @@ class Solution(UserDict):
|
|
|
962
971
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
963
972
|
|
|
964
973
|
self.data.setdefault('insig', None) # Insignificant variables
|
|
965
|
-
self.data.setdefault('obj', np.
|
|
974
|
+
self.data.setdefault('obj', np.full(nb_crit, np.inf))
|
|
966
975
|
self.data.setdefault('model', None)
|
|
967
976
|
self.data.setdefault('class_num', None)
|
|
968
977
|
self.data.setdefault('hash', None)
|
|
@@ -1512,11 +1521,9 @@ class Search():
|
|
|
1512
1521
|
label = model_n or sol.get('model_n', '?')
|
|
1513
1522
|
sep = '─' * 62
|
|
1514
1523
|
|
|
1515
|
-
# In quiet mode (default during search)
|
|
1524
|
+
# In quiet mode (default during search) simply count the failure and
|
|
1525
|
+
# return — the totals are written to the results file at the end.
|
|
1516
1526
|
if not getattr(self.param, 'verbose_convergence', False):
|
|
1517
|
-
print(f" [no-converge] model={label} sol#={sol.get('sol_num','?')} "
|
|
1518
|
-
f"vars={all_vars} "
|
|
1519
|
-
f"(set verbose_convergence=True in Parameters for full diagnostic)")
|
|
1520
1527
|
return
|
|
1521
1528
|
|
|
1522
1529
|
print(f"\n{sep}")
|
|
@@ -3488,7 +3495,7 @@ class Search():
|
|
|
3488
3495
|
fit_intercept, init_coeff, n_draws, weights, avail, base_alt, maxiter, ftol, gtol, save_fitted_params,
|
|
3489
3496
|
halton_opts=None):
|
|
3490
3497
|
# {
|
|
3491
|
-
model = MixedLogit()
|
|
3498
|
+
model = MixedLogit(_jax=getattr(self.param, '_jax', True))
|
|
3492
3499
|
#subvarnames = varnames delete itemes in randvaras
|
|
3493
3500
|
|
|
3494
3501
|
|
|
@@ -3498,7 +3505,12 @@ class Search():
|
|
|
3498
3505
|
model.setup(X=X, y=y, varnames=varnames, isvars=isvars, alts=alts, transvars=transvars, ids=ids,
|
|
3499
3506
|
randvars=randvars, panels=panels, fit_intercept=fit_intercept, correlated_vars=corvars, n_draws=n_draws,
|
|
3500
3507
|
init_coeff=init_coeff, weights=weights, avail=avail, base_alt=base_alt, maxiter=maxiter,
|
|
3501
|
-
ftol=ftol, gtol=gtol, save_fitted_params=save_fitted_params, halton_opts=halton_opts
|
|
3508
|
+
ftol=ftol, gtol=gtol, save_fitted_params=save_fitted_params, halton_opts=halton_opts,
|
|
3509
|
+
de_init=getattr(self.param, 'de_init', False),
|
|
3510
|
+
de_popsize=getattr(self.param, 'de_popsize', 4),
|
|
3511
|
+
de_maxiter=getattr(self.param, 'de_maxiter', 3),
|
|
3512
|
+
de_tol=getattr(self.param, 'de_tol', 0.5),
|
|
3513
|
+
de_polish=getattr(self.param, 'de_polish', False))
|
|
3502
3514
|
model.fit()
|
|
3503
3515
|
|
|
3504
3516
|
return model
|
|
@@ -3700,7 +3712,7 @@ class Search():
|
|
|
3700
3712
|
X_nest = self.param.df[nest_vars]
|
|
3701
3713
|
|
|
3702
3714
|
# Fit the Nested Logit model
|
|
3703
|
-
model = NestedLogit(_jax=self.param
|
|
3715
|
+
model = NestedLogit(_jax=getattr(self.param, '_jax', True))
|
|
3704
3716
|
model.setup(X=X, X_nest = X_nest, y=y, varnames=all_vars, isvars=is_vars,
|
|
3705
3717
|
alts=self.param.alt_var, ids=self.param.choice_id,
|
|
3706
3718
|
nests=nests, lambdas=lambdas, fit_intercept=asc_ind, return_grad=self.param.grad, return_hess=self.param.hess)
|
|
@@ -24,7 +24,7 @@ except ImportError:
|
|
|
24
24
|
class BinaryProbit(DiscreteChoiceModel):
|
|
25
25
|
"""Binary probit estimated with JAX autodiff and scipy L-BFGS-B."""
|
|
26
26
|
|
|
27
|
-
def __init__(self, _jax=
|
|
27
|
+
def __init__(self, _jax=True):
|
|
28
28
|
super(BinaryProbit, self).__init__(_jax)
|
|
29
29
|
self.descr = "Binary Probit"
|
|
30
30
|
self.result = None
|
|
@@ -133,7 +133,7 @@ class _OLSResult:
|
|
|
133
133
|
class HeckmanTwoStep(DiscreteChoiceModel):
|
|
134
134
|
"""Heckman selection model using JAX probit + closed-form OLS second stage."""
|
|
135
135
|
|
|
136
|
-
def __init__(self, _jax=
|
|
136
|
+
def __init__(self, _jax=True):
|
|
137
137
|
super(HeckmanTwoStep, self).__init__(_jax)
|
|
138
138
|
self.descr = "Heckman Two-Step"
|
|
139
139
|
self.selection_result = None
|