SearchLibrium 0.0.89__tar.gz → 0.0.91__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/PKG-INFO +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/pyproject.toml +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/MixedLogit.py +133 -27
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/_choice_model.py +1 -1
- searchlibrium-0.0.91/src/SearchLibrium/latent_class.py +850 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/mixed_nested.py +2 -2
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/multinomial_logit.py +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/multinomial_nested.py +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/multinomial_probit.py +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/ordered_logit.py +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/sapbil.py +3 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/search.py +22 -10
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/selection_models.py +2 -2
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/siman.py +85 -46
- searchlibrium-0.0.91/src/SearchLibrium/test_lc_de.py +423 -0
- searchlibrium-0.0.91/src/SearchLibrium/test_mario_searches.py +346 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/test_sapbil_vs_banditsa.py +196 -33
- searchlibrium-0.0.91/src/SearchLibrium/version.txt +1 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium.egg-info/PKG-INFO +1 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium.egg-info/SOURCES.txt +2 -0
- searchlibrium-0.0.89/src/SearchLibrium/latent_class.py +0 -353
- searchlibrium-0.0.89/src/SearchLibrium/version.txt +0 -1
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/README.md +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/setup.cfg +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/Halton.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/Mode_Activity_Nested.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/RandomP.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/SEARCH_SM_MARIO.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/Two_Level_Nest.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/__init__.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/__main__.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/_device.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/banditsa.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/bhhh/minimize.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/boxcox_functions.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/call_meta.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/constraints_builder.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/harmony.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/main.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/main_debug.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/mdcev.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/misc.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/mixed_logit.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/mixedrrm.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/ordered_logit_mixed.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/rrm.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/setup.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium/threshold.py +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium.egg-info/dependency_links.txt +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium.egg-info/entry_points.txt +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium.egg-info/requires.txt +0 -0
- {searchlibrium-0.0.89 → searchlibrium-0.0.91}/src/SearchLibrium.egg-info/top_level.txt +0 -0
|
@@ -59,7 +59,7 @@ Homepage = "https://github.com/zahern/HypothesisX"
|
|
|
59
59
|
realpython = "SearchLibrium.__main__:main"
|
|
60
60
|
|
|
61
61
|
[tool.bumpver]
|
|
62
|
-
current_version = "0.0.
|
|
62
|
+
current_version = "0.0.91"
|
|
63
63
|
version_pattern = "MAJOR.MINOR.PATCH"
|
|
64
64
|
commit_message = "[skip ci] Bump version {old_version} -> {new_version}"
|
|
65
65
|
commit = true
|
|
@@ -2,7 +2,7 @@ from typing import Optional
|
|
|
2
2
|
import itertools
|
|
3
3
|
import numpy as np
|
|
4
4
|
import scipy.stats as ss
|
|
5
|
-
from scipy.optimize import minimize
|
|
5
|
+
from scipy.optimize import minimize, differential_evolution
|
|
6
6
|
from typing import Callable, Tuple
|
|
7
7
|
import inspect
|
|
8
8
|
|
|
@@ -34,8 +34,8 @@ max_comp_val, min_comp_val = 1e+20, 1e-200 # or use float('inf')
|
|
|
34
34
|
infinity = float('inf')
|
|
35
35
|
|
|
36
36
|
class MixedLogit(DiscreteChoiceModel):
|
|
37
|
-
def __init__(self, halton_opts=None, distributions=['n', 'ln', 't', 'tn', 'u']):
|
|
38
|
-
super().__init__()
|
|
37
|
+
def __init__(self, halton_opts=None, distributions=['n', 'ln', 't', 'tn', 'u'], _jax=True):
|
|
38
|
+
super().__init__(_jax)
|
|
39
39
|
self.halton_opts = halton_opts
|
|
40
40
|
self.draws_generator = Draws(k=len(distributions), halton_opts=halton_opts, rvdist=distributions)
|
|
41
41
|
self.random_parameters = RandomParameters(distributions or []) # Initialize RandomParameters
|
|
@@ -66,7 +66,9 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
66
66
|
n_draws=1000, halton=True, minimise_func=None,
|
|
67
67
|
batch_size=None, halton_opts=None, ftol=1e-6,
|
|
68
68
|
gtol=1e-6, return_hess=True, return_grad=True, method="bfgs",
|
|
69
|
-
save_fitted_params=True, mnl_init=True
|
|
69
|
+
save_fitted_params=True, mnl_init=True,
|
|
70
|
+
de_init=False, de_popsize=4, de_maxiter=3, de_tol=0.5,
|
|
71
|
+
de_polish=False):
|
|
70
72
|
# {
|
|
71
73
|
self.fit_intercept = fit_intercept
|
|
72
74
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
@@ -104,6 +106,11 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
104
106
|
self.minimise_func = minimise_func
|
|
105
107
|
self.save_fitted_params = save_fitted_params
|
|
106
108
|
self.mnl_init = mnl_init
|
|
109
|
+
self.de_init = de_init
|
|
110
|
+
self.de_popsize = de_popsize
|
|
111
|
+
self.de_maxiter = de_maxiter
|
|
112
|
+
self.de_tol = de_tol
|
|
113
|
+
self.de_polish = de_polish
|
|
107
114
|
self.total_fun_eval = 0
|
|
108
115
|
self.method = method.lower() if hasattr(method, 'lower') else method
|
|
109
116
|
self.jac = self.return_grad # scipy optimize parameter
|
|
@@ -312,6 +319,80 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
312
319
|
if len(self.init_coeff) != n_coeff and not hasattr(self, 'class_params_spec'):
|
|
313
320
|
raise ValueError("The size of init_coeff must be: " + str(n_coeff))
|
|
314
321
|
|
|
322
|
+
positive_bound = (0, infinity)
|
|
323
|
+
any_bound = (-infinity, infinity)
|
|
324
|
+
lmda_bound = (-5, 1)
|
|
325
|
+
bound_dict = {
|
|
326
|
+
"bf": (any_bound, self.Kf),
|
|
327
|
+
"br_b": (any_bound, self.Kr),
|
|
328
|
+
"chol": (any_bound, self.Kchol),
|
|
329
|
+
"br_w": (positive_bound, self.Kr - self.correlationLength),
|
|
330
|
+
"bf_trans": (any_bound, self.Kftrans),
|
|
331
|
+
"flmbda": (lmda_bound, self.Kftrans),
|
|
332
|
+
"br_trans_b": (any_bound, self.Krtrans),
|
|
333
|
+
"br_trans_w": (any_bound, self.Krtrans),
|
|
334
|
+
"rlmbda": (lmda_bound, self.Krtrans)
|
|
335
|
+
}
|
|
336
|
+
bnds = [[bound[1][0]] * bound[1][1] for bound in bound_dict.items() if bound[1][1] > 0]
|
|
337
|
+
bnds = list(itertools.chain.from_iterable(bnds))
|
|
338
|
+
|
|
339
|
+
print(f"[MXL] Starting beta seed length={n_coeff}, first_values={betas[:min(8, len(betas))]!r}")
|
|
340
|
+
|
|
341
|
+
if self.de_init:
|
|
342
|
+
print(f"[MXL] DE init enabled: popsize={self.de_popsize}, maxiter={self.de_maxiter}, tol={self.de_tol}, polish={self.de_polish}")
|
|
343
|
+
try:
|
|
344
|
+
before_de_obj = self.get_loglik_gradient(betas, self.X, self.y, self.panel_info,
|
|
345
|
+
draws, drawstrans, self.weights,
|
|
346
|
+
self.avail, self.batch_size)[0]
|
|
347
|
+
print(f"[MXL] DE before: obj={before_de_obj:.6g}")
|
|
348
|
+
|
|
349
|
+
def _de_obj(x):
|
|
350
|
+
return self.get_loglik_gradient(x, self.X, self.y, self.panel_info,
|
|
351
|
+
draws, drawstrans, self.weights,
|
|
352
|
+
self.avail, self.batch_size)[0]
|
|
353
|
+
|
|
354
|
+
if len(bnds) == n_coeff:
|
|
355
|
+
de_bounds = []
|
|
356
|
+
for idx, bound in enumerate(bnds):
|
|
357
|
+
low, high = bound
|
|
358
|
+
if not np.isfinite(low) or not np.isfinite(high):
|
|
359
|
+
scale = max(1.0, 10.0 * abs(betas[idx]) if idx < len(betas) else 1.0)
|
|
360
|
+
de_bounds.append((-scale, scale))
|
|
361
|
+
else:
|
|
362
|
+
de_bounds.append((low, high))
|
|
363
|
+
else:
|
|
364
|
+
de_bounds = [(-10, 10)] * n_coeff
|
|
365
|
+
|
|
366
|
+
de_result = differential_evolution(
|
|
367
|
+
_de_obj,
|
|
368
|
+
de_bounds,
|
|
369
|
+
strategy='best1bin',
|
|
370
|
+
maxiter=self.de_maxiter,
|
|
371
|
+
popsize=self.de_popsize,
|
|
372
|
+
tol=self.de_tol,
|
|
373
|
+
polish=self.de_polish,
|
|
374
|
+
disp=False,
|
|
375
|
+
workers=1,
|
|
376
|
+
)
|
|
377
|
+
de_improved = False
|
|
378
|
+
print(f"[MXL] DE completed: success={de_result.success}, nit={de_result.nit}, message={de_result.message}")
|
|
379
|
+
if de_result.success:
|
|
380
|
+
de_obj = self.get_loglik_gradient(de_result.x, self.X, self.y, self.panel_info,
|
|
381
|
+
draws, drawstrans, self.weights,
|
|
382
|
+
self.avail, self.batch_size)[0]
|
|
383
|
+
print(f"[MXL] DE after: obj={de_obj:.6g}")
|
|
384
|
+
de_improved = de_obj < before_de_obj
|
|
385
|
+
if de_improved:
|
|
386
|
+
betas = de_result.x
|
|
387
|
+
print(f"[MXL] DE best seed first_values={betas[:min(8, len(betas))]!r}")
|
|
388
|
+
print("[MXL] Differential evolution initialization improved the objective; starting minimise from DE seed.")
|
|
389
|
+
else:
|
|
390
|
+
print(f"[MXL] Differential evolution did not beat the original start ({de_obj:.6g} >= {before_de_obj:.6g}); keeping original start.")
|
|
391
|
+
if not de_improved:
|
|
392
|
+
print(f"[MXL] Differential evolution initialization rejected because it did not improve the starting objective. Using original seed.")
|
|
393
|
+
except Exception as e:
|
|
394
|
+
print(f"[MXL] Differential evolution initialization failed: {e}")
|
|
395
|
+
|
|
315
396
|
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
|
|
316
397
|
if dev.using_gpu: # {
|
|
317
398
|
self.X, self.y = dev.convert_array_gpu(self.X), dev.convert_array_gpu(self.y)
|
|
@@ -323,6 +404,15 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
323
404
|
# }
|
|
324
405
|
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
|
|
325
406
|
|
|
407
|
+
print(f"[MXL] Minimization start with betas first_values={betas[:min(8, len(betas))]!r}")
|
|
408
|
+
try:
|
|
409
|
+
before_fun = self.get_loglik_gradient(betas, self.X, self.y, self.panel_info,
|
|
410
|
+
draws, drawstrans, self.weights,
|
|
411
|
+
self.avail, self.batch_size)[0]
|
|
412
|
+
print(f"[MXL] Minimization obj before={before_fun:.6g}")
|
|
413
|
+
except Exception as e:
|
|
414
|
+
print(f"[MXL] Could not evaluate initial objective before minimization: {e}")
|
|
415
|
+
|
|
326
416
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
327
417
|
# Generate bound for L-BFGS-B method
|
|
328
418
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
@@ -378,6 +468,9 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
378
468
|
options = {'gtol': self.gtol, 'maxiter': self.maxiter, 'disp': False}
|
|
379
469
|
result = minimise_func(self.get_loglik_gradient, betas, jac=self.jac, method=self.method,
|
|
380
470
|
args=args, tol=self.ftol, bounds=bounds, options=options)
|
|
471
|
+
print(f"[MXL] Minimization completed: success={result.get('success', None)}, fun={result.get('fun', float('nan')):.6g}, nit={result.get('nit', '?')}")
|
|
472
|
+
if 'x' in result:
|
|
473
|
+
print(f"[MXL] Minimization final betas first_values={np.asarray(result['x'])[:min(8, len(result['x']))]!r}")
|
|
381
474
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
382
475
|
|
|
383
476
|
if hasattr(self, 'method') and self.method == "L-BFGS-B": # {
|
|
@@ -428,38 +521,37 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
428
521
|
# ------------------------------------------------------------------
|
|
429
522
|
@staticmethod
|
|
430
523
|
def _jax_mxl_negloglik(betas, X_jax, y_jax, panel_info_jax, draws_jax,
|
|
431
|
-
fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names
|
|
524
|
+
fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names,
|
|
525
|
+
correlationLength):
|
|
432
526
|
"""JAX simulation-based log-likelihood for Mixed Logit (standard case).
|
|
433
527
|
|
|
434
528
|
Handles fixed and random (normally/lognormally distributed) parameters.
|
|
435
|
-
Cholesky correlation structure is included via the Kchol segment
|
|
529
|
+
Cholesky correlation structure is included via the Kchol segment;
|
|
530
|
+
uncorrelated random vars use independent std devs in Br_w.
|
|
436
531
|
"""
|
|
437
532
|
import jax.numpy as jnp
|
|
438
533
|
|
|
439
534
|
# ---- split beta vector ----
|
|
440
535
|
Bf = betas[:Kf]
|
|
441
536
|
Br_b = betas[Kf:Kf + Kr]
|
|
442
|
-
chol_v = betas[Kf + Kr:Kf + Kr + Kchol]
|
|
537
|
+
chol_v = betas[Kf + Kr:Kf + Kr + Kchol] # correlated cholesky elements
|
|
443
538
|
Br_w = betas[Kf + Kr + Kchol:Kf + Kr + Kchol + Kbw] # independent std devs
|
|
444
539
|
|
|
445
540
|
# ---- build cholesky matrix ----
|
|
446
|
-
|
|
541
|
+
# Rows 0..correlationLength-1: lower-triangle from chol_v
|
|
542
|
+
# Rows correlationLength..Kr-1: diagonal from Br_w (uncorrelated)
|
|
447
543
|
chol_mat = jnp.zeros((Kr, Kr))
|
|
448
|
-
# place chol values into lower triangle
|
|
449
|
-
# (static indices ok for jit since Kr is compile-time constant)
|
|
450
544
|
idx = 0
|
|
451
|
-
for r in range(
|
|
545
|
+
for r in range(correlationLength):
|
|
452
546
|
for c in range(r + 1):
|
|
453
547
|
chol_mat = chol_mat.at[r, c].set(chol_v[idx])
|
|
454
548
|
idx += 1
|
|
455
|
-
# diagonal of the independent block
|
|
456
549
|
for k in range(Kbw):
|
|
457
|
-
|
|
550
|
+
diag_pos = correlationLength + k
|
|
551
|
+
chol_mat = chol_mat.at[diag_pos, diag_pos].set(jnp.abs(Br_w[k]))
|
|
458
552
|
|
|
459
553
|
# ---- sample random coefficients: (N, Kr, R) ----
|
|
460
|
-
N
|
|
461
|
-
R = draws_jax.shape[2]
|
|
462
|
-
# draws_jax: (N, Kr, R)
|
|
554
|
+
N = X_jax.shape[0]
|
|
463
555
|
Br = Br_b[:, None] + jnp.einsum('kl,nlr->nkr',
|
|
464
556
|
chol_mat,
|
|
465
557
|
draws_jax[:, :Kr, :]) # (N, Kr, R)
|
|
@@ -475,21 +567,24 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
475
567
|
Br_b[k] + Br_w[k] * (draws_jax[:, k, :] - 0.5))
|
|
476
568
|
|
|
477
569
|
# ---- utility ----
|
|
478
|
-
# X_jax: (N, P, J, K)
|
|
479
570
|
P = X_jax.shape[1]
|
|
480
|
-
|
|
571
|
+
J = X_jax.shape[2]
|
|
481
572
|
Xr = X_jax[:, :, :, rvidx] # (N, P, J, Kr)
|
|
482
|
-
|
|
483
|
-
UB = jnp.einsum('npjk,k->npj', Xf, Bf) # (N, P, J)
|
|
484
573
|
UR = jnp.einsum('npjk,nkr->npjr', Xr, Br) # (N, P, J, R)
|
|
485
|
-
|
|
574
|
+
|
|
575
|
+
if Kf > 0:
|
|
576
|
+
Xf = X_jax[:, :, :, fxidx] # (N, P, J, Kf)
|
|
577
|
+
UB = jnp.einsum('npjk,k->npj', Xf, Bf) # (N, P, J)
|
|
578
|
+
U = UB[:, :, :, None] + UR
|
|
579
|
+
else:
|
|
580
|
+
U = UR # (N, P, J, R)
|
|
486
581
|
|
|
487
582
|
U = U - jnp.max(U, axis=2, keepdims=True)
|
|
488
583
|
eU = jnp.exp(U)
|
|
489
584
|
p = eU / jnp.sum(eU, axis=2, keepdims=True) # (N, P, J, R)
|
|
490
585
|
|
|
491
|
-
pch = jnp.sum(y_jax[:, :, :, None] * p, axis=2) # (N, P, R)
|
|
492
|
-
pch = jnp.prod(pch, axis=1) # (N, R)
|
|
586
|
+
pch = jnp.sum(y_jax[:, :, :, None] * p, axis=2) # (N, P, R)
|
|
587
|
+
pch = jnp.prod(pch, axis=1) # (N, R)
|
|
493
588
|
pch = jnp.clip(pch, 1e-300, None)
|
|
494
589
|
|
|
495
590
|
sim_p = jnp.mean(pch, axis=1) # (N,)
|
|
@@ -503,27 +598,38 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
503
598
|
Falls back to standard numpy optimisation on failure.
|
|
504
599
|
"""
|
|
505
600
|
try:
|
|
601
|
+
import os
|
|
602
|
+
os.environ.setdefault("JAX_ENABLE_X64", "True")
|
|
506
603
|
import jax
|
|
604
|
+
jax.config.update("jax_enable_x64", True)
|
|
507
605
|
import jax.numpy as jnp
|
|
508
606
|
from scipy.optimize import minimize as sp_min
|
|
509
607
|
|
|
608
|
+
if int(self.Kr) == 0:
|
|
609
|
+
print("[JAX MXL optimizer] no random coefficients detected (Kr=0); falling back to scipy.")
|
|
610
|
+
return None
|
|
611
|
+
|
|
510
612
|
# Convert static inputs once
|
|
511
613
|
X_jax = jnp.array(self.X, dtype=jnp.float64)
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
614
|
+
_y_np = np.asarray(self.y)
|
|
615
|
+
if _y_np.ndim > 3: # stored as (N, P, J, 1) in some paths
|
|
616
|
+
_y_np = _y_np[..., 0]
|
|
617
|
+
y_jax = jnp.array(_y_np, dtype=jnp.float64)
|
|
618
|
+
pi_jax = jnp.array(self.panel_info, dtype=jnp.float64)
|
|
619
|
+
draws_jax = jnp.array(draws, dtype=jnp.float64)
|
|
515
620
|
|
|
516
621
|
fxidx = jnp.array(self.fxidx, dtype=bool)
|
|
517
622
|
rvidx = jnp.array(self.rvidx, dtype=bool)
|
|
518
623
|
rvdist_names = [d for d in self.rvdist if d is not False]
|
|
519
624
|
|
|
520
625
|
Kf, Kr, Kchol, Kbw = int(self.Kf), int(self.Kr), int(self.Kchol), int(self.Kbw)
|
|
626
|
+
correlationLength = int(self.correlationLength)
|
|
521
627
|
|
|
522
628
|
@jax.jit
|
|
523
629
|
def _neg_ll(b):
|
|
524
630
|
return self._jax_mxl_negloglik(
|
|
525
631
|
b, X_jax, y_jax, pi_jax, draws_jax,
|
|
526
|
-
fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names)
|
|
632
|
+
fxidx, rvidx, Kf, Kr, Kchol, Kbw, rvdist_names, correlationLength)
|
|
527
633
|
|
|
528
634
|
_val_grad = jax.jit(jax.value_and_grad(_neg_ll))
|
|
529
635
|
|
|
@@ -141,7 +141,7 @@ class DiscreteChoiceModel(ABC):
|
|
|
141
141
|
''' ---------------------------------------------------------- '''
|
|
142
142
|
''' Function '''
|
|
143
143
|
''' ---------------------------------------------------------- '''
|
|
144
|
-
def __init__(self, jax =
|
|
144
|
+
def __init__(self, jax = True):
|
|
145
145
|
# {
|
|
146
146
|
|
|
147
147
|
self.reset_attributes()
|