itlog 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
itlog/__init__.py ADDED
@@ -0,0 +1,56 @@
1
+ """itlog: JAX-backed discrete choice modeling."""
2
+
3
+ from itlog.benchmarks.runner import (
4
+ BenchmarkResult,
5
+ available_engines,
6
+ benchmark_fit,
7
+ benchmark_objective_gradient,
8
+ )
9
+ from itlog.data import ChoiceDataset, Field, TensorData, wide_to_long_pylogit
10
+ from itlog.expr import Parameter, Var
11
+ from itlog.latex import expr_to_latex, model_to_latex, result_to_latex
12
+ from itlog.models import (
13
+ BaseChoiceModel,
14
+ ClassMembership,
15
+ CrossNestedLogit,
16
+ LatentClass,
17
+ MixedLogit,
18
+ MultinomialLogit,
19
+ Nest,
20
+ NestedLogit,
21
+ OrderedLogit,
22
+ OrderedProbit,
23
+ )
24
+ from itlog.results import FitResult
25
+ from itlog.suite import ModelReport, ModelSuite
26
+
27
# Public API of the ``itlog`` package, grouped by the module each name is
# imported from above (order preserved).
__all__ = [
    # itlog.data
    "ChoiceDataset",
    "Field",
    "TensorData",
    # itlog.expr
    "Parameter",
    "Var",
    # itlog.models
    "MultinomialLogit",
    "MixedLogit",
    "NestedLogit",
    "CrossNestedLogit",
    "LatentClass",
    "ClassMembership",
    "Nest",
    "OrderedLogit",
    "OrderedProbit",
    "BaseChoiceModel",
    # itlog.results / itlog.suite
    "FitResult",
    "ModelSuite",
    "ModelReport",
    # itlog.latex
    "expr_to_latex",
    "model_to_latex",
    "result_to_latex",
    # itlog.data (format conversion helper)
    "wide_to_long_pylogit",
    # itlog.benchmarks.runner
    "BenchmarkResult",
    "available_engines",
    "benchmark_fit",
    "benchmark_objective_gradient",
]

# Package version string.
__version__ = "0.1.0"
@@ -0,0 +1,5 @@
1
+ """NumPy and JAX estimation backends."""
2
+
3
+ from itlog.backends.jax_backend import build_feature_arrays, make_jax_objective
4
+
5
# Names re-exported from itlog.backends.jax_backend.
__all__ = [
    "make_jax_objective",
    "build_feature_arrays",
]
@@ -0,0 +1,67 @@
1
+ """JAX JIT objective/gradient factory."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Callable, Dict, Tuple
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+
11
+ from itlog.core.objective import make_cached_val_grad
12
+ from itlog.data import TensorData
13
+
14
+
15
def make_jax_objective(
    neg_ll_fn: Callable[[jnp.ndarray], jnp.ndarray],
) -> Tuple[Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray], Dict[str, int]]:
    """Wrap a JAX negative log-likelihood into NumPy-facing objective callables.

    ``neg_ll_fn`` is differentiated with ``jax.value_and_grad`` and JIT-compiled
    once. Each evaluation converts the NumPy parameter vector to a JAX array. It
    returns the value as a Python ``float`` and the gradient as a float64 NumPy
    array.

    That evaluator is handed to ``make_cached_val_grad``, and its result is
    returned unchanged. It is annotated as (objective, gradient,
    ``Dict[str, int]``); the dict is presumably call counters, to be confirmed
    in ``itlog.core.objective``.
    """
    compiled_vg = jax.jit(jax.value_and_grad(neg_ll_fn))

    def _value_and_gradient(theta_np: np.ndarray):
        value, gradient = compiled_vg(jnp.asarray(theta_np))
        return float(value), np.asarray(gradient, dtype=np.float64)

    return make_cached_val_grad(_value_and_gradient)
25
+
26
+
27
def build_feature_arrays(tensors: TensorData) -> Dict[str, jnp.ndarray]:
    """Return ``tensors.features`` with every value converted to a JAX array.

    Keys and their iteration order are kept as in ``tensors.features``.
    """
    converted = {}
    for name, values in tensors.features.items():
        converted[name] = jnp.asarray(values)
    return converted
29
+
30
+
31
def make_mnl_objective(compiled, tensors: TensorData):
    """JAX MNL objective/gradient builder (used for NumPy<->JAX parity checks).

    Features, chosen-alternative indices and the availability mask are
    converted to JAX arrays once. They are captured by a closure over
    ``mnl_neg_loglik`` that depends only on ``theta``, and that closure is
    wrapped with ``make_jax_objective``.
    """
    from itlog.backends.likelihoods import mnl_neg_loglik

    feature_arrays = build_feature_arrays(tensors)
    chosen = jnp.asarray(tensors.choice_idx)
    avail = jnp.asarray(tensors.availability)

    def _objective(theta):
        return mnl_neg_loglik(theta, compiled, feature_arrays, chosen, avail)

    return make_jax_objective(_objective)
43
+
44
+
45
def make_mxl_objective(compiled, tensors: TensorData, *, num_draws: int, seed: int):
    """JAX MXL objective/gradient builder (used for NumPy<->JAX parity checks).

    Standard-normal draws are generated once per panel from ``seed`` and held
    fixed, so every evaluation of the returned objective sees the same
    simulated sample.

    Args:
        compiled: Compiled model; ``len(compiled.random_param_names)`` sets the
            number of random coefficients.
        tensors: Must carry ``panel_ids`` with one id per observation row.
        num_draws: Number of draws per panel; must be >= 1.
        seed: Seed for ``numpy.random.default_rng``.

    Returns:
        Whatever ``make_jax_objective`` returns for the simulated objective.

    Raises:
        ValueError: If ``tensors.panel_ids`` is None or ``num_draws < 1``.
    """
    # Check cheap preconditions before importing the model module.
    if tensors.panel_ids is None:
        raise ValueError("Mixed logit requires panel_id")
    if num_draws < 1:
        # Zero draws would make any simulated average empty (NaN likelihood).
        raise ValueError(f"num_draws must be >= 1, got {num_draws}")

    from itlog.models.mixed import _mxl_neg_loglik_jax

    features = build_feature_arrays(tensors)
    choice_idx = jnp.asarray(tensors.choice_idx)
    availability = jnp.asarray(tensors.availability)

    # Dense 0..n_panels-1 index per observation, in sorted-unique order, in one
    # vectorized pass (replaces a per-row Python dict lookup).
    unique_panels, panel_inverse = np.unique(
        np.asarray(tensors.panel_ids), return_inverse=True
    )
    n_panels = len(unique_panels)
    panel_index = jnp.asarray(panel_inverse.reshape(-1))

    n_random = len(compiled.random_param_names)
    rng = np.random.default_rng(seed)
    # At least one draw column is generated even when there are no random params.
    draws = jnp.asarray(rng.standard_normal((n_panels, num_draws, max(n_random, 1))))

    def neg_ll(theta):
        return _mxl_neg_loglik_jax(
            theta, compiled, features, choice_idx, availability, panel_index, draws, n_random
        )

    return make_jax_objective(neg_ll)
@@ -0,0 +1,220 @@
1
+ """JAX log-likelihood kernels (backend-neutral array API)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, Optional
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+
10
+ from itlog.expr.compile import CompiledModel, CompiledTerm
11
+
12
# Finite stand-in for log(0), written into the utilities of unavailable
# alternatives before logsumexp. NOTE(review): presumably finite rather than
# -inf to keep arithmetic on masked entries free of inf/NaN; confirm.
NEG_INF = -1.0e20
13
+
14
+
15
def compute_utilities(
    theta: Any,
    compiled: CompiledModel,
    features: Dict[str, Any],
    *,
    beta_override: Optional[Any] = None,
) -> Any:
    """Build the (n_obs, n_alt) systematic-utility matrix.

    The row count is taken from the first array in ``features``. Column order
    follows ``compiled.alt_labels``. Each alternative listed in
    ``compiled.utilities`` gets the sum of its terms evaluated by
    ``_eval_term``. Alternatives without a utility entry keep a column of zeros.
    ``beta_override`` is forwarded to ``_eval_term`` unchanged.
    """
    n_rows = next(iter(features.values())).shape[0]
    labels = compiled.alt_labels
    column_of = {label: col for col, label in enumerate(labels)}

    utilities = jnp.zeros((n_rows, len(labels)))
    for label, spec in compiled.utilities.items():
        col = column_of[label]
        total = jnp.zeros(n_rows)
        for term in spec.terms:
            total = total + _eval_term(theta, term, features, col, beta_override)
        utilities = utilities.at[:, col].set(total)
    return utilities
34
+
35
+
36
+ def _eval_term(
37
+ theta: Any,
38
+ term: CompiledTerm,
39
+ features: Dict[str, Any],
40
+ alt_idx: int,
41
+ beta_override: Optional[Any],
42
+ ) -> Any:
43
+ n_obs = features[next(iter(features))].shape[0]
44
+
45
+ if term.feature_name is None:
46
+ if term.param_idx is not None:
47
+ if beta_override is not None:
48
+ return beta_override[:, term.param_idx]
49
+ return jnp.full(n_obs, theta[term.param_idx])
50
+ return jnp.full(n_obs, term.const_value)
51
+
52
+ x = features[term.feature_name][:, alt_idx]
53
+ if term.param_idx is not None:
54
+ if beta_override is not None:
55
+ beta = beta_override[:, term.param_idx]
56
+ else:
57
+ beta = jnp.full(n_obs, theta[term.param_idx])
58
+ return beta * x
59
+ return term.const_value * x
60
+
61
+
62
def mnl_neg_loglik(
    theta: Any,
    compiled: CompiledModel,
    features: Dict[str, Any],
    choice_idx: Any,
    availability: Any,
) -> Any:
    """Multinomial logit negative log-likelihood.

    Utilities of unavailable alternatives (where ``availability`` is falsy) are
    replaced by ``NEG_INF``. Each row contributes
    ``V[chosen] - logsumexp(V)``, and the negated sum over rows is returned.
    """
    utils = compute_utilities(theta, compiled, features)
    masked = jnp.where(availability, utils, NEG_INF)
    log_norm = jax.scipy.special.logsumexp(masked, axis=1)
    picked = masked[jnp.arange(masked.shape[0]), choice_idx]
    log_prob = picked - log_norm
    return -jnp.sum(log_prob)
74
+
75
+
76
def nested_neg_loglik(
    theta: Any,
    compiled: CompiledModel,
    features: Dict[str, Any],
    choice_idx: Any,
    availability: Any,
    nest_for_alt: Any,
    nest_scales: Any,
    scale_idx: Any,
) -> Any:
    """Nested logit negative log-likelihood.

    Nest scales are ``sigmoid(nest_scales) * 0.999 + 0.001``, i.e. in
    (0.001, 1), where ``scale_idx >= 0``; they are fixed at 1.0 elsewhere.
    NOTE(review): ``nest_scales`` is a separate argument (raw, pre-sigmoid),
    presumably gathered from theta by the caller; confirm.

    With IV_m = logsumexp over alternatives in nest m of V_j / mu_m:
        log P_i = V_i / mu_c + (mu_c - 1) * IV_c - logsumexp_m(mu_m * IV_m)
    where c = nest_for_alt[i]. Alternatives outside nest m are masked with
    NEG_INF when forming IV_m.
    """
    utils = jnp.where(availability, compute_utilities(theta, compiled, features), NEG_INF)
    n_obs = utils.shape[0]
    n_nests = nest_scales.shape[0]
    rows = jnp.arange(n_obs)

    mu = jnp.where(scale_idx >= 0, jax.nn.sigmoid(nest_scales) * 0.999 + 0.001, 1.0)

    # Inclusive value of each nest, one column at a time.
    inclusive = jnp.full((n_obs, n_nests), -jnp.inf)
    for nest in range(n_nests):
        member = nest_for_alt == nest
        scaled = jnp.where(member[None, :], utils / mu[nest], NEG_INF)
        inclusive = inclusive.at[:, nest].set(jax.scipy.special.logsumexp(scaled, axis=1))

    nest_of_choice = nest_for_alt[choice_idx]
    mu_choice = mu[nest_of_choice]
    util_choice = utils[rows, choice_idx]
    iv_choice = inclusive[rows, nest_of_choice]
    within = (util_choice / mu_choice) + (mu_choice - 1.0) * iv_choice
    upper = jax.scipy.special.logsumexp(mu[None, :] * inclusive, axis=1)
    return -jnp.sum(within - upper)
110
+
111
+
112
def cnl_neg_loglik(
    theta: Any,
    compiled: CompiledModel,
    features: Dict[str, Any],
    choice_idx: Any,
    availability: Any,
    alpha: Any,
    nest_scales: Any,
    scale_idx: Any,
) -> Any:
    """Cross-nested logit negative log-likelihood.

    Args:
        theta: Parameter vector forwarded to ``compute_utilities``.
        alpha: Allocation weights indexed [alternative, nest]; each row is
            renormalized to sum to 1.
        nest_scales: Raw nest scales mapped to (0.001, 1) via
            ``sigmoid * 0.999 + 0.001`` where ``scale_idx >= 0``; scales are
            fixed at 1.0 elsewhere.
        scale_idx: Per-nest array; negative entries fix the scale to 1.

    With IV_m = log sum_j alpha_jm * exp(V_j / mu_m):
        log P_i = logsumexp_m[log alpha_im + V_i / mu_m + (mu_m - 1) * IV_m]
                  - logsumexp_m(mu_m * IV_m)

    Returns:
        Scalar ``-sum_i log P(chosen_i)``.
    """
    V = compute_utilities(theta, compiled, features)
    V = jnp.where(availability, V, NEG_INF)
    n_obs = V.shape[0]
    n_nests = nest_scales.shape[0]

    # A literal 1e-300 rounds to 0.0 in float32, so zero weights would give
    # log(0) = -inf (and NaN via (mu - 1) * -inf or 0/0). Use the dtype's
    # smallest normal number when it is larger; float64 keeps 1e-300 exactly.
    floor = max(1e-300, float(jnp.finfo(V.dtype).tiny))

    alpha_norm = alpha / jnp.maximum(jnp.sum(alpha, axis=1, keepdims=True), floor)
    log_alpha = jnp.log(jnp.maximum(alpha_norm, floor))
    mu = jnp.where(
        scale_idx >= 0,
        jax.nn.sigmoid(nest_scales) * 0.999 + 0.001,
        1.0,
    )

    log_iv = jnp.full((n_obs, n_nests), -jnp.inf)
    for m in range(n_nests):
        Vm = V / mu[m] + log_alpha[:, m][None, :]
        log_iv = log_iv.at[:, m].set(jax.scipy.special.logsumexp(Vm, axis=1))

    c = choice_idx
    Vc = V[jnp.arange(n_obs), c]
    log_terms = jnp.full((n_obs, n_nests), -jnp.inf)
    for m in range(n_nests):
        log_terms = log_terms.at[:, m].set(
            log_alpha[c, m]
            + Vc / mu[m]
            + (mu[m] - 1.0) * log_iv[:, m]
        )
    log_p = jax.scipy.special.logsumexp(log_terms, axis=1) - jax.scipy.special.logsumexp(
        mu[None, :] * log_iv, axis=1
    )
    return -jnp.sum(log_p)
153
+
154
+
155
def latent_class_neg_loglik(
    theta: Any,
    compiled_list: list[CompiledModel],
    features: Dict[str, Any],
    choice_idx: Any,
    availability: Any,
    membership_X: Any,
    membership_param_idx: Any,
) -> Any:
    """Latent class MNL negative log-likelihood with softmax class membership.

    Each compiled model in ``compiled_list`` yields a per-row MNL
    log-probability of the chosen alternative.

    Membership logits use class 0 as the reference (logit 0). Class k+1 has
    logit ``sum_j theta[membership_param_idx[k][j]] * membership_X[:, j]``.
    When ``membership_param_idx`` is empty, every class has logit 0.

    The logits are always normalized with a log-softmax, so the mixing weights
    sum to 1 in both cases.

    Returns:
        Scalar ``-sum_n logsumexp_k(log pi_nk + log P_nk)``.
    """
    n_classes = len(compiled_list)
    n_obs = choice_idx.shape[0]
    rows = jnp.arange(n_obs)

    class_ll = []
    for compiled in compiled_list:
        V = jnp.where(availability, compute_utilities(theta, compiled, features), NEG_INF)
        class_ll.append(V[rows, choice_idx] - jax.scipy.special.logsumexp(V, axis=1))
    class_ll = jnp.stack(class_ll, axis=1)

    if membership_param_idx.shape[0] == 0:
        # No membership parameters: equal logits (uniform shares after normalizing).
        log_pi = jnp.zeros((n_obs, n_classes))
    else:
        logits = jnp.zeros((n_obs, n_classes - 1))
        for k, idxs in enumerate(membership_param_idx):
            for j, pidx in enumerate(idxs):
                logits = logits.at[:, k].add(theta[pidx] * membership_X[:, j])
        # Reference class 0 has a fixed logit of 0.
        log_pi = jnp.concatenate([jnp.zeros((n_obs, 1)), logits], axis=1)

    # Normalize in BOTH branches so the mixture weights sum to 1.
    log_pi = log_pi - jax.scipy.special.logsumexp(log_pi, axis=1, keepdims=True)

    log_mix = jax.scipy.special.logsumexp(log_pi + class_ll, axis=1)
    return -jnp.sum(log_mix)
191
+
192
+
193
def ordered_neg_loglik(
    theta: Any,
    X: Any,
    y_idx: Any,
    n_thresholds: int,
    link: str = "logit",
) -> Any:
    """Ordered logit/probit negative log-likelihood.

    Thresholds are kept strictly increasing by parameterizing them as
    ``tau = cumsum(softplus(raw_tau))``.

    With eta = X @ beta and F the link CDF:
        P(y = k) = F(tau_k - eta) - F(tau_{k-1} - eta)
    where +/-1e20 stand in for the open outer bounds.

    Args:
        theta: ``[beta (X.shape[1] entries), raw_tau (n_thresholds entries)]``.
            Entries beyond ``n_beta + n_thresholds`` are not read.
        X: Design matrix; ``X @ beta`` is the latent index per row.
        y_idx: Observed category per row (0..n_thresholds); cast to int32, so
            float-coded categories are accepted.
        n_thresholds: Number of cut points.
        link: ``"logit"`` or ``"probit"``.

    Returns:
        Scalar negative log-likelihood summed over rows.

    Raises:
        ValueError: If ``link`` is neither ``"logit"`` nor ``"probit"``.
    """
    if link == "logit":
        cdf = jax.scipy.special.expit
    elif link == "probit":
        cdf = jax.scipy.stats.norm.cdf
    else:
        # Previously any unrecognized link silently fell back to probit.
        raise ValueError(f"Unknown link {link!r}; expected 'logit' or 'probit'")

    n_beta = X.shape[1]
    beta = theta[:n_beta]
    raw_tau = theta[n_beta:n_beta + n_thresholds]
    tau = jnp.cumsum(jax.nn.softplus(raw_tau))
    eta = X @ beta
    # Integer indices are required by take_along_axis.
    k = jnp.asarray(y_idx).astype(jnp.int32)

    big = 1e20
    bounds_low = jnp.concatenate([jnp.array([-big]), tau])
    bounds_high = jnp.concatenate([tau, jnp.array([big])])
    cdf_high = cdf(bounds_high[None, :] - eta[:, None])
    cdf_low = cdf(bounds_low[None, :] - eta[:, None])
    # 1e-300 rounds to 0.0 in float32, which would give log(0) = -inf.
    # float64 keeps 1e-300; narrower dtypes use their smallest normal number.
    floor = max(1e-300, float(jnp.finfo(cdf_high.dtype).tiny))
    cat_probs = jnp.maximum(cdf_high - cdf_low, floor)
    log_p = jnp.log(jnp.take_along_axis(cat_probs, k[:, None], axis=1)).squeeze(1)
    return -jnp.sum(log_p)