itlog 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- itlog/__init__.py +56 -0
- itlog/backends/__init__.py +5 -0
- itlog/backends/jax_backend.py +67 -0
- itlog/backends/likelihoods.py +220 -0
- itlog/backends/numpy_kernels.py +488 -0
- itlog/benchmarks/__init__.py +0 -0
- itlog/benchmarks/electricity.py +31 -0
- itlog/benchmarks/runner.py +96 -0
- itlog/benchmarks/swissmetro.py +118 -0
- itlog/core/__init__.py +11 -0
- itlog/core/objective.py +36 -0
- itlog/core/optimizer.py +98 -0
- itlog/data.py +362 -0
- itlog/distributions.py +76 -0
- itlog/draws.py +60 -0
- itlog/expr/__init__.py +27 -0
- itlog/expr/compile.py +164 -0
- itlog/expr/nodes.py +122 -0
- itlog/latex.py +131 -0
- itlog/metrics.py +35 -0
- itlog/models/__init__.py +24 -0
- itlog/models/base.py +140 -0
- itlog/models/cross_nested.py +108 -0
- itlog/models/latent_class.py +98 -0
- itlog/models/membership.py +17 -0
- itlog/models/mixed.py +175 -0
- itlog/models/mnl.py +59 -0
- itlog/models/nest.py +18 -0
- itlog/models/nested.py +128 -0
- itlog/models/ordered.py +106 -0
- itlog/results.py +90 -0
- itlog/suite.py +187 -0
- itlog-0.1.0.dist-info/METADATA +441 -0
- itlog-0.1.0.dist-info/RECORD +37 -0
- itlog-0.1.0.dist-info/WHEEL +5 -0
- itlog-0.1.0.dist-info/licenses/LICENSE +21 -0
- itlog-0.1.0.dist-info/top_level.txt +1 -0
itlog/__init__.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""itlog: JAX-backed discrete choice modeling."""
|
|
2
|
+
|
|
3
|
+
from itlog.benchmarks.runner import (
|
|
4
|
+
BenchmarkResult,
|
|
5
|
+
available_engines,
|
|
6
|
+
benchmark_fit,
|
|
7
|
+
benchmark_objective_gradient,
|
|
8
|
+
)
|
|
9
|
+
from itlog.data import ChoiceDataset, Field, TensorData, wide_to_long_pylogit
|
|
10
|
+
from itlog.expr import Parameter, Var
|
|
11
|
+
from itlog.latex import expr_to_latex, model_to_latex, result_to_latex
|
|
12
|
+
from itlog.models import (
|
|
13
|
+
BaseChoiceModel,
|
|
14
|
+
ClassMembership,
|
|
15
|
+
CrossNestedLogit,
|
|
16
|
+
LatentClass,
|
|
17
|
+
MixedLogit,
|
|
18
|
+
MultinomialLogit,
|
|
19
|
+
Nest,
|
|
20
|
+
NestedLogit,
|
|
21
|
+
OrderedLogit,
|
|
22
|
+
OrderedProbit,
|
|
23
|
+
)
|
|
24
|
+
from itlog.results import FitResult
|
|
25
|
+
from itlog.suite import ModelReport, ModelSuite
|
|
26
|
+
|
|
27
|
+
# Public API re-exported at package level. The order below is the canonical
# listing; grouping comments only describe where each name is imported from.
__all__ = [
    # itlog.data containers
    "ChoiceDataset",
    "Field",
    "TensorData",
    # itlog.expr building blocks
    "Parameter",
    "Var",
    # itlog.models estimators and helpers
    "MultinomialLogit",
    "MixedLogit",
    "NestedLogit",
    "CrossNestedLogit",
    "LatentClass",
    "ClassMembership",
    "Nest",
    "OrderedLogit",
    "OrderedProbit",
    "BaseChoiceModel",
    # itlog.results / itlog.suite
    "FitResult",
    "ModelSuite",
    "ModelReport",
    # itlog.latex export helpers
    "expr_to_latex",
    "model_to_latex",
    "result_to_latex",
    # itlog.data conversion helper
    "wide_to_long_pylogit",
    # itlog.benchmarks.runner
    "BenchmarkResult",
    "available_engines",
    "benchmark_fit",
    "benchmark_objective_gradient",
]

# Package version string (matches the 0.1.0 wheel this module ships in).
__version__ = "0.1.0"
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""JAX JIT objective/gradient factory."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Callable, Dict, Tuple
|
|
6
|
+
|
|
7
|
+
import jax
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from itlog.core.objective import make_cached_val_grad
|
|
12
|
+
from itlog.data import TensorData
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def make_jax_objective(
    neg_ll_fn: Callable[[jnp.ndarray], jnp.ndarray],
) -> Tuple[Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray], Dict[str, int]]:
    """Wrap a JAX negative log-likelihood as NumPy-facing objective callables.

    ``neg_ll_fn`` is differentiated with ``jax.value_and_grad`` and the result
    is JIT-compiled once. Each evaluation converts the incoming NumPy parameter
    vector to a JAX array and hands back a Python ``float`` value together with
    a float64 NumPy gradient.

    Args:
        neg_ll_fn: Scalar-valued function of the parameter vector.

    Returns:
        Whatever ``make_cached_val_grad`` builds from the evaluation callback;
        the annotation describes it as (objective, gradient, counters).
    """
    compiled_value_and_grad = jax.jit(jax.value_and_grad(neg_ll_fn))

    def _evaluate(theta_np: np.ndarray):
        loss, gradient = compiled_value_and_grad(jnp.asarray(theta_np))
        loss_value = float(loss)
        # Optimisers on the NumPy side expect float64 gradients.
        gradient_np = np.asarray(gradient, dtype=np.float64)
        return loss_value, gradient_np

    return make_cached_val_grad(_evaluate)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def build_feature_arrays(tensors: TensorData) -> Dict[str, jnp.ndarray]:
    """Return a new dict mapping each feature name in ``tensors.features`` to a JAX array."""
    converted = {}
    for name, values in tensors.features.items():
        converted[name] = jnp.asarray(values)
    return converted
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def make_mnl_objective(compiled, tensors: TensorData):
    """Build JAX objective/gradient callables for a multinomial logit model.

    Intended for NumPy<->JAX parity checks. Features, choices and availability
    are converted to JAX arrays once and closed over by the objective.

    Args:
        compiled: Compiled model passed through to ``mnl_neg_loglik``.
        tensors: Tensorised data providing ``features``, ``choice_idx`` and
            ``availability``.

    Returns:
        The result of ``make_jax_objective`` for the MNL negative log-likelihood.
    """
    from itlog.backends.likelihoods import mnl_neg_loglik

    feats = build_feature_arrays(tensors)
    chosen = jnp.asarray(tensors.choice_idx)
    avail = jnp.asarray(tensors.availability)

    return make_jax_objective(
        lambda theta: mnl_neg_loglik(theta, compiled, feats, chosen, avail)
    )
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def make_mxl_objective(compiled, tensors: TensorData, *, num_draws: int, seed: int):
    """Build JAX objective/gradient callables for a panel mixed logit model.

    Intended for NumPy<->JAX parity checks. Standard-normal draws of shape
    ``(n_panels, num_draws, max(n_random, 1))`` are generated once from
    ``np.random.default_rng(seed)``. Panels are ordered by sorted panel id, and
    the draws stay fixed across objective evaluations.

    Args:
        compiled: Compiled model; ``compiled.random_param_names`` gives the
            number of random coefficients.
        tensors: Tensorised data; ``panel_ids`` must not be None.
        num_draws: Number of simulation draws per panel.
        seed: Seed for the draw generator.

    Returns:
        The result of ``make_jax_objective`` for the simulated negative
        log-likelihood.

    Raises:
        ValueError: If ``tensors.panel_ids`` is None.
    """
    from itlog.models.mixed import _mxl_neg_loglik_jax

    if tensors.panel_ids is None:
        raise ValueError("Mixed logit requires panel_id")
    features = build_feature_arrays(tensors)
    choice_idx = jnp.asarray(tensors.choice_idx)
    availability = jnp.asarray(tensors.availability)

    # ``return_inverse`` yields, for every observation, the position of its
    # panel id in the sorted unique array. This is the same index a per-row
    # dict lookup produces, computed in one vectorised call, and it stays an
    # integer array even when there are no observations.
    unique_panels, panel_pos = np.unique(tensors.panel_ids, return_inverse=True)
    n_panels = len(unique_panels)
    n_random = len(compiled.random_param_names)
    rng = np.random.default_rng(seed)
    # max(..., 1) keeps the draw array non-empty when no coefficient is random.
    draws = jnp.asarray(rng.standard_normal((n_panels, num_draws, max(n_random, 1))))
    panel_index = jnp.asarray(panel_pos)

    def neg_ll(theta):
        return _mxl_neg_loglik_jax(
            theta, compiled, features, choice_idx, availability, panel_index, draws, n_random
        )

    return make_jax_objective(neg_ll)
|
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
"""JAX log-likelihood kernels (backend-neutral array API)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, Optional
|
|
6
|
+
|
|
7
|
+
import jax
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
|
|
10
|
+
from itlog.expr.compile import CompiledModel, CompiledTerm
|
|
11
|
+
|
|
12
|
+
# Large finite negative utility used to mask unavailable alternatives via
# ``jnp.where(availability, V, NEG_INF)``. It is presumably kept finite rather
# than -inf so that later arithmetic on masked entries stays NaN-free.
NEG_INF = -1.0e20
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def compute_utilities(
    theta: Any,
    compiled: CompiledModel,
    features: Dict[str, Any],
    *,
    beta_override: Optional[Any] = None,
) -> Any:
    """Build the ``(n_obs, n_alt)`` systematic-utility matrix.

    Column order follows ``compiled.alt_labels``. Each alternative listed in
    ``compiled.utilities`` receives the sum of its terms as evaluated by
    ``_eval_term``; alternatives without an entry keep utility 0. The row count
    is read from the first array in ``features``, so ``features`` must be
    non-empty.

    Args:
        theta: Parameter vector.
        compiled: Compiled model supplying ``alt_labels`` and ``utilities``.
        features: Feature arrays indexed as ``[obs, alt]``.
        beta_override: Optional per-observation coefficient matrix, forwarded
            to ``_eval_term``.

    Returns:
        Utility matrix with one column per alternative label.
    """
    num_obs = next(iter(features.values())).shape[0]
    labels = list(compiled.alt_labels)
    column_of = dict(zip(labels, range(len(labels))))
    utilities = jnp.zeros((num_obs, len(labels)))

    for label, spec in compiled.utilities.items():
        col = column_of[label]
        column = jnp.zeros(num_obs)
        for term in spec.terms:
            column = column + _eval_term(theta, term, features, col, beta_override)
        utilities = utilities.at[:, col].set(column)

    return utilities
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _eval_term(
|
|
37
|
+
theta: Any,
|
|
38
|
+
term: CompiledTerm,
|
|
39
|
+
features: Dict[str, Any],
|
|
40
|
+
alt_idx: int,
|
|
41
|
+
beta_override: Optional[Any],
|
|
42
|
+
) -> Any:
|
|
43
|
+
n_obs = features[next(iter(features))].shape[0]
|
|
44
|
+
|
|
45
|
+
if term.feature_name is None:
|
|
46
|
+
if term.param_idx is not None:
|
|
47
|
+
if beta_override is not None:
|
|
48
|
+
return beta_override[:, term.param_idx]
|
|
49
|
+
return jnp.full(n_obs, theta[term.param_idx])
|
|
50
|
+
return jnp.full(n_obs, term.const_value)
|
|
51
|
+
|
|
52
|
+
x = features[term.feature_name][:, alt_idx]
|
|
53
|
+
if term.param_idx is not None:
|
|
54
|
+
if beta_override is not None:
|
|
55
|
+
beta = beta_override[:, term.param_idx]
|
|
56
|
+
else:
|
|
57
|
+
beta = jnp.full(n_obs, theta[term.param_idx])
|
|
58
|
+
return beta * x
|
|
59
|
+
return term.const_value * x
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def mnl_neg_loglik(
    theta: Any,
    compiled: CompiledModel,
    features: Dict[str, Any],
    choice_idx: Any,
    availability: Any,
) -> Any:
    """Negative log-likelihood of a multinomial logit model.

    Unavailable alternatives are masked with ``NEG_INF`` before the softmax
    normalisation, so they receive (numerically) zero probability.

    Returns:
        ``-sum_n [V_{n, chosen} - logsumexp_j V_{n, j}]`` as a scalar.
    """
    utilities = jnp.where(availability, compute_utilities(theta, compiled, features), NEG_INF)
    rows = jnp.arange(utilities.shape[0])
    log_prob_chosen = utilities[rows, choice_idx] - jax.scipy.special.logsumexp(utilities, axis=1)
    return -jnp.sum(log_prob_chosen)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def nested_neg_loglik(
    theta: Any,
    compiled: CompiledModel,
    features: Dict[str, Any],
    choice_idx: Any,
    availability: Any,
    nest_for_alt: Any,
    nest_scales: Any,
    scale_idx: Any,
) -> Any:
    """Negative log-likelihood of a two-level nested logit model.

    Nests with ``scale_idx >= 0`` map their raw scale through
    ``0.999 * sigmoid(raw) + 0.001``, giving mu in (0.001, 1). All other nests
    use mu = 1. With inclusive value ``I_m = logsumexp_{j in m}(V_j / mu_m)``,
    the log-probability of choosing alternative i in nest m is::

        V_i / mu_m + (mu_m - 1) * I_m - logsumexp_l(mu_l * I_l)

    Args:
        nest_for_alt: Nest number of each alternative (compared with 0..n_nests-1).
        nest_scales: Raw (pre-sigmoid) scale per nest.
        scale_idx: Per-nest marker; entries ``>= 0`` use the estimated scale.

    Returns:
        Scalar negative log-likelihood summed over observations.
    """
    lse = jax.scipy.special.logsumexp
    util = jnp.where(availability, compute_utilities(theta, compiled, features), NEG_INF)
    num_obs = util.shape[0]
    num_nests = nest_scales.shape[0]

    estimated_mu = jax.nn.sigmoid(nest_scales) * 0.999 + 0.001
    mu = jnp.where(scale_idx >= 0, estimated_mu, 1.0)

    # Inclusive value of each nest: logsumexp over its members of V / mu.
    inclusive = jnp.full((num_obs, num_nests), -jnp.inf)
    for nest in range(num_nests):
        members = (nest_for_alt == nest)[None, :]
        scaled = jnp.where(members, util / mu[nest], NEG_INF)
        inclusive = inclusive.at[:, nest].set(lse(scaled, axis=1))

    rows = jnp.arange(num_obs)
    nest_of_choice = nest_for_alt[choice_idx]
    mu_chosen = mu[nest_of_choice]
    within = (util[rows, choice_idx] / mu_chosen) + (mu_chosen - 1.0) * inclusive[rows, nest_of_choice]
    log_p = within - lse(mu[None, :] * inclusive, axis=1)
    return -jnp.sum(log_p)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def cnl_neg_loglik(
    theta: Any,
    compiled: CompiledModel,
    features: Dict[str, Any],
    choice_idx: Any,
    availability: Any,
    alpha: Any,
    nest_scales: Any,
    scale_idx: Any,
) -> Any:
    """Negative log-likelihood of a cross-nested logit model.

    Allocation weights ``alpha[n_alt, n_nests]`` are renormalised so that each
    alternative's row sums to 1. Nests with ``scale_idx >= 0`` map their raw
    scale through ``0.999 * sigmoid(raw) + 0.001``; other nests use mu = 1. The
    inclusive value of nest m is ``I_m = logsumexp_j(log alpha_jm + V_j / mu_m)``
    and::

        log P(i) = logsumexp_m(log alpha_im + V_i / mu_m + (mu_m - 1) I_m)
                   - logsumexp_m(mu_m I_m)

    Args:
        alpha: Non-negative allocation weights indexed ``[alt, nest]``.
        nest_scales: Raw (pre-sigmoid) scale per nest.
        scale_idx: Per-nest marker; entries ``>= 0`` use the estimated scale.

    Returns:
        Scalar negative log-likelihood summed over observations.
    """
    lse = jax.scipy.special.logsumexp
    V = compute_utilities(theta, compiled, features)
    V = jnp.where(availability, V, NEG_INF)
    n_obs = V.shape[0]
    n_nests = nest_scales.shape[0]

    alpha = jnp.asarray(alpha)
    # Floor that keeps the row normalisation and log(alpha) finite. A literal
    # 1e-300 rounds to 0 in float32 (JAX's default precision), which turns an
    # all-zero allocation row into 0/0 = NaN. Use the dtype's smallest normal
    # number when it is larger; under float64 this is still exactly 1e-300.
    work_dtype = jnp.result_type(alpha, 1.0)
    floor = max(1e-300, float(jnp.finfo(work_dtype).tiny))
    alpha_norm = alpha / jnp.maximum(jnp.sum(alpha, axis=1, keepdims=True), floor)
    # Computed once and reused by both nest loops below.
    log_alpha = jnp.log(jnp.maximum(alpha_norm, floor))

    mu = jnp.where(
        scale_idx >= 0,
        jax.nn.sigmoid(nest_scales) * 0.999 + 0.001,
        1.0,
    )

    log_iv = jnp.full((n_obs, n_nests), -jnp.inf)
    for m in range(n_nests):
        Vm = V / mu[m] + log_alpha[:, m][None, :]
        log_iv = log_iv.at[:, m].set(lse(Vm, axis=1))

    rows = jnp.arange(n_obs)
    Vc = V[rows, choice_idx]
    log_alpha_c = log_alpha[choice_idx]  # allocations of each chosen alternative
    log_terms = jnp.full((n_obs, n_nests), -jnp.inf)
    for m in range(n_nests):
        log_terms = log_terms.at[:, m].set(
            log_alpha_c[:, m] + Vc / mu[m] + (mu[m] - 1.0) * log_iv[:, m]
        )
    log_p = lse(log_terms, axis=1) - lse(mu[None, :] * log_iv, axis=1)
    return -jnp.sum(log_p)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def latent_class_neg_loglik(
    theta: Any,
    compiled_list: list[CompiledModel],
    features: Dict[str, Any],
    choice_idx: Any,
    availability: Any,
    membership_X: Any,
    membership_param_idx: Any,
) -> Any:
    """Negative log-likelihood of a latent class MNL with softmax membership.

    Each compiled model defines one class's MNL utilities over the shared
    ``theta``. Class membership logits use class 0 as the reference (logit 0).
    For class k >= 1 the logit is ``sum_j theta[membership_param_idx[k-1][j]] *
    membership_X[:, j]``. When ``membership_param_idx`` has no rows, all
    classes get equal shares.

    Returns:
        ``-sum_n log sum_k pi_nk P_n(choice | class k)`` as a scalar.
    """
    lse = jax.scipy.special.logsumexp
    n_classes = len(compiled_list)
    n_obs = choice_idx.shape[0]
    rows = jnp.arange(n_obs)

    class_ll = []
    for compiled in compiled_list:
        V = compute_utilities(theta, compiled, features)
        V = jnp.where(availability, V, NEG_INF)
        class_ll.append(V[rows, choice_idx] - lse(V, axis=1))
    class_ll = jnp.stack(class_ll, axis=1)

    # membership logits: reference class 0
    if membership_param_idx.shape[0] == 0:
        log_pi = jnp.zeros((n_obs, n_classes))
    else:
        logits = jnp.zeros((n_obs, n_classes - 1))
        for k, idxs in enumerate(membership_param_idx):
            for j, pidx in enumerate(idxs):
                logits = logits.at[:, k].add(theta[pidx] * membership_X[:, j])
        log_pi = jnp.concatenate([jnp.zeros((n_obs, 1)), logits], axis=1)
    # Normalise in BOTH branches. Without it, the all-zero logits of the
    # no-covariate case would sum (not average) the class likelihoods and
    # overstate the log-likelihood by log(n_classes) per observation.
    log_pi = log_pi - lse(log_pi, axis=1, keepdims=True)

    log_mix = lse(log_pi + class_ll, axis=1)
    return -jnp.sum(log_mix)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def ordered_neg_loglik(
    theta: Any,
    X: Any,
    y_idx: Any,
    n_thresholds: int,
    link: str = "logit",
) -> Any:
    """Negative log-likelihood of an ordered logit/probit model.

    Parameters are laid out as ``theta = [beta (X.shape[1] entries), raw_tau...]``.
    Thresholds are made strictly increasing via ``tau = cumsum(softplus(raw_tau))``.
    Category k has probability ``F(tau_k - eta) - F(tau_{k-1} - eta)``, with
    ``eta = X @ beta`` and outer bounds of +/-1e20.

    Args:
        theta: Coefficients followed by raw threshold parameters.
        X: Covariate matrix ``(n_obs, n_beta)``.
        y_idx: Observed category index per row. Float-typed labels are accepted
            and cast to int32 before indexing.
        n_thresholds: Currently not read; the threshold count is implied by
            ``len(theta) - X.shape[1]``.
        link: ``"logit"`` uses the logistic CDF; any other value uses the
            standard normal CDF.

    Returns:
        Scalar negative log-likelihood summed over observations.
    """
    n_beta = X.shape[1]
    beta = theta[:n_beta]
    raw_tau = theta[n_beta:]
    tau = jnp.cumsum(jax.nn.softplus(raw_tau))
    eta = X @ beta
    k = y_idx.astype(jnp.int32)

    if link == "logit":
        cdf = jax.scipy.special.expit
    else:
        cdf = jax.scipy.stats.norm.cdf

    big = 1e20
    bounds_low = jnp.concatenate([jnp.array([-big]), tau])
    bounds_high = jnp.concatenate([tau, jnp.array([big])])
    cdf_high = cdf(bounds_high[None, :] - eta[:, None])
    cdf_low = cdf(bounds_low[None, :] - eta[:, None])
    cat_probs = cdf_high - cdf_low
    # Floor probabilities before the log. A literal 1e-300 rounds to 0 in
    # float32 (JAX's default), so an underflowing category would give log(0).
    # Use the dtype's smallest normal number when larger; float64 keeps 1e-300.
    floor = max(1e-300, float(jnp.finfo(cat_probs.dtype).tiny))
    cat_probs = jnp.maximum(cat_probs, floor)
    # Index with the int32-cast labels: take_along_axis rejects float indices.
    log_p = jnp.log(jnp.take_along_axis(cat_probs, k[:, None], axis=1)).squeeze(1)
    return -jnp.sum(log_p)
|