ximinf 0.0.8__tar.gz → 0.0.16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ximinf-0.0.8/src/ximinf.egg-info → ximinf-0.0.16}/PKG-INFO +1 -1
- {ximinf-0.0.8 → ximinf-0.0.16}/pyproject.toml +1 -1
- {ximinf-0.0.8 → ximinf-0.0.16}/src/ximinf/generate_sim.py +13 -16
- {ximinf-0.0.8 → ximinf-0.0.16}/src/ximinf/nn_inference.py +0 -1
- ximinf-0.0.16/src/ximinf/nn_test.py +246 -0
- {ximinf-0.0.8 → ximinf-0.0.16}/src/ximinf/nn_train.py +161 -54
- {ximinf-0.0.8 → ximinf-0.0.16/src/ximinf.egg-info}/PKG-INFO +1 -1
- ximinf-0.0.8/src/ximinf/nn_test.py +0 -453
- {ximinf-0.0.8 → ximinf-0.0.16}/LICENSE +0 -0
- {ximinf-0.0.8 → ximinf-0.0.16}/README.md +0 -0
- {ximinf-0.0.8 → ximinf-0.0.16}/setup.cfg +0 -0
- {ximinf-0.0.8 → ximinf-0.0.16}/src/ximinf/__init__.py +0 -0
- {ximinf-0.0.8 → ximinf-0.0.16}/src/ximinf/selection_effects.py +0 -0
- {ximinf-0.0.8 → ximinf-0.0.16}/src/ximinf.egg-info/SOURCES.txt +0 -0
- {ximinf-0.0.8 → ximinf-0.0.16}/src/ximinf.egg-info/dependency_links.txt +0 -0
- {ximinf-0.0.8 → ximinf-0.0.16}/src/ximinf.egg-info/requires.txt +0 -0
- {ximinf-0.0.8 → ximinf-0.0.16}/src/ximinf.egg-info/top_level.txt +0 -0
{ximinf-0.0.8 → ximinf-0.0.16}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "ximinf"
-version = "0.0.8"
+version = "0.0.16"
 description = "Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae. "
 readme = "README.md"
 requires-python = ">=3.10"
{ximinf-0.0.8 → ximinf-0.0.16}/src/ximinf/generate_sim.py

@@ -72,7 +72,7 @@ def simulate_one(params_dict, z_max, M, cols, N=None, i=None):
         "alpha": 0.0,
         "beta": 0.0,
         "mabs": -19.3,
-        "gamma": 0.0,
+        # "gamma": 0.0,
         "sigma_int": 0.0, # default intrinsic scatter
     }
 
@@ -83,36 +83,33 @@ def simulate_one(params_dict, z_max, M, cols, N=None, i=None):
     alpha_ = float(params["alpha"])
     beta_ = float(params["beta"])
     mabs_ = float(params["mabs"])
-    gamma_ = float(params["gamma"])
+    # gamma_ = float(params["gamma"])
     sigma_int_ = float(params["sigma_int"])
 
-    brokenalpha_model = skysurvey_sniapop.brokenalpha_model
+    # brokenalpha_model = skysurvey_sniapop.brokenalpha_model
 
     # Generate SNe sample
     snia = skysurvey.SNeIa.from_draw(
         size=M,
         zmax=z_max,
-        model=brokenalpha_model,
+        # model=brokenalpha_model,
         magabs={
-            "x1": "@x1",
-            "c": "@c",
             "mabs": mabs_,
             "sigmaint": sigma_int_,
-            "
-            "alpha_high": alpha_,
+            "alpha": alpha_,
             "beta": beta_,
-            "gamma": gamma_
         }
     )
 
     # Apply noise
-    errormodel = sim.noise_model
-    errormodel["localcolor"]["kwargs"]["a"] = 2
-    errormodel["localcolor"]["kwargs"]["loc"] = 0.005
-    errormodel["localcolor"]["kwargs"]["scale"] = 0.05
-    noisy_snia = snia.apply_gaussian_noise(errormodel)
-
-    df = noisy_snia.data
+    # errormodel = sim.noise_model
+    # errormodel["localcolor"]["kwargs"]["a"] = 2
+    # errormodel["localcolor"]["kwargs"]["loc"] = 0.005
+    # errormodel["localcolor"]["kwargs"]["scale"] = 0.05
+    # noisy_snia = snia.apply_gaussian_noise(errormodel)
+
+    # df = noisy_snia.data
+    df = snia.data
 
     # Collect requested columns as lists
     data_dict = {col: list(df[col]) for col in cols if col in df}
ximinf-0.0.16/src/ximinf/nn_test.py (new file)

@@ -0,0 +1,246 @@
+# Import libraries
+import jax
+import jax.numpy as jnp
+import blackjax
+from functools import partial
+from tqdm.notebook import tqdm
+
+def distance(theta1, theta2):
+    """
+    Compute the Euclidean distance between two points in NDIM space.
+
+    Parameters
+    ----------
+    theta1 : array-like
+        First point in NDIM-dimensional space.
+    theta2 : array-like
+        Second point in NDIM-dimensional space.
+
+    Returns
+    -------
+    float
+        The Euclidean distance between `theta1` and `theta2`.
+    """
+    diff = theta1 - theta2
+    return jnp.linalg.norm(diff)
+
+# def log_prior(theta, bounds):
+#     """
+#     Compute the log-prior probability for the parameter `theta`,
+#     assuming uniform prior within given bounds.
+
+#     Parameters
+#     ----------
+#     theta : array-like
+#         The parameter values for which the prior is to be calculated.
+#     bounds : jnp.ndarray, optional
+#         The bounds on each parameter (default is the global `BOUNDS`).
+
+#     Returns
+#     -------
+#     float
+#         The log-prior of `theta`, or negative infinity if `theta` is out of bounds.
+#     """
+
+#     in_bounds = jnp.all((theta >= bounds[:, 0]) & (theta <= bounds[:, 1]))
+#     return jnp.where(in_bounds, 0.0, -jnp.inf)
+
+def log_group_prior(theta, bounds, group_indices):
+    """
+    Log prior for a single parameter group.
+    Uniform within bounds, -inf otherwise.
+    """
+    theta_g = theta[group_indices]
+    bounds_g = bounds[group_indices]
+
+    in_bounds = jnp.all(
+        (theta_g >= bounds_g[:, 0]) &
+        (theta_g <= bounds_g[:, 1])
+    )
+
+    return jnp.where(in_bounds, 0.0, -jnp.inf)
+
+
+
+def sample_reference_point(rng_key, bounds):
+    """
+    Sample a reference point within the given bounds uniformly.
+
+    Parameters
+    ----------
+    rng_key : jax.random.PRNGKey
+        The random key used for sampling.
+    bounds : jnp.ndarray, optional
+        The bounds for each parameter (default is the global `BOUNDS`).
+
+    Returns
+    -------
+    tuple
+        A tuple containing the updated `rng_key` and the sampled reference point `theta`.
+    """
+    ndim = bounds.shape[0]
+    rng_key, subkey = jax.random.split(rng_key)
+    u = jax.random.uniform(subkey, shape=(ndim,))
+    span = bounds[:, 1] - bounds[:, 0]
+    theta = bounds[:, 0] + u * span
+    return rng_key, theta
+
+def inference_loop(rng_key, kernel, initial_state, num_samples):
+    """
+    Perform an inference loop using a Markov Chain Monte Carlo (MCMC) kernel.
+
+    Parameters
+    ----------
+    rng_key : jax.random.PRNGKey
+        The random key used for sampling.
+    kernel : callable
+        The MCMC kernel (e.g., NUTS) used for updating the state.
+    initial_state : object
+        The initial state of the MCMC chain.
+    num_samples : int
+        The number of samples to generate in the chain.
+
+    Returns
+    -------
+    jax.numpy.ndarray
+        The sampled states from the inference loop.
+    """
+
+    def one_step(state, rng):
+        state, _ = kernel(rng, state)
+        return state, state
+    keys = jax.random.split(rng_key, num_samples)
+    _, states = jax.lax.scan(one_step, initial_state, keys)
+    return states
+
+def log_prob_fn_groups(theta, models_per_group, data, bounds,
+                       param_groups, global_param_names):
+
+    log_r_sum = 0.0
+    log_p_group_sum = 0.0
+
+    data = data.reshape(1, -1)
+
+    for g, group in enumerate(param_groups):
+
+        # --- parameter bookkeeping (unchanged) ---
+        prev_groups = [
+            p
+            for i in range(g)
+            for p in (param_groups[i] if isinstance(param_groups[i], list)
+                      else [param_groups[i]])
+        ]
+
+        group_list = [group] if isinstance(group, str) else group
+        visible_param_names = prev_groups + group_list
+
+        visible_idx = jnp.array(
+            [global_param_names.index(name) for name in visible_param_names]
+        )
+
+        theta_visible = theta[visible_idx].reshape(1, -1)
+        input_g = jnp.concatenate([data, theta_visible], axis=-1)
+
+        # --- ratio estimator ---
+        logits = models_per_group[g](input_g)
+        p = jax.nn.sigmoid(logits)
+        log_r_sum += jnp.log(p) - jnp.log1p(-p)
+
+        # --- marginal prior for this group ---
+        group_idx = jnp.array(
+            [global_param_names.index(name) for name in group_list]
+        )
+
+        log_p_group_sum += log_group_prior(theta, bounds, group_idx)
+
+    return jnp.squeeze(log_r_sum + log_p_group_sum)
+
+
+
+@partial(jax.jit, static_argnums=(0, 1, 2))
+def sample_posterior(log_prob, n_warmup, n_samples, init_position, rng_key):
+    warmup = blackjax.window_adaptation(blackjax.nuts, log_prob)
+    rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3)
+    (warmup_state, params), _ = warmup.run(warmup_key, init_position, num_steps=n_warmup)
+    kernel = blackjax.nuts(log_prob, **params).step
+    rng_key, sample_key = jax.random.split(rng_key)
+    states = inference_loop(sample_key, kernel, warmup_state, n_samples)
+    return rng_key, states.position
+
+
+def one_sample_step_groups(rng_key, xi, theta_star, n_warmup, n_samples,
+                           models_per_group, bounds, param_groups, param_names):
+    """
+    Sample from posterior using sum of log-likelihoods over all groups.
+    """
+    rng_key, theta_r0 = sample_reference_point(rng_key, bounds)
+
+    def log_post(theta):
+        return log_prob_fn_groups(theta, models_per_group, xi, bounds, param_groups, param_names)
+
+    rng_key, posterior = sample_posterior(log_post, n_warmup, n_samples, theta_star, rng_key)
+    d_star = distance(theta_star, theta_r0)
+    d_samples = jnp.linalg.norm(posterior - theta_r0, axis=1)
+    f_val = jnp.mean(d_samples < d_star)
+
+    return rng_key, f_val, posterior
+
+
+def batched_one_sample_step_groups(rng_keys, x_batch, theta_star_batch,
+                                   n_warmup, n_samples, models_per_group, bounds, param_groups, param_names):
+    return jax.vmap(
+        lambda rng, x, theta: one_sample_step_groups(rng, x[None, :], theta, n_warmup, n_samples,
+                                                     models_per_group, bounds, param_groups, param_names),
+        in_axes=(0, 0, 0)
+    )(rng_keys, x_batch, theta_star_batch)
+
+def compute_ecp_tarp_jitted_groups(models_per_group, x_list, theta_star_list, alpha_list,
+                                   n_warmup, n_samples, rng_key, bounds,
+                                   param_groups, param_names):
+    """
+    Batched ECP computation using multiple group models.
+    """
+    N = x_list.shape[0]
+    rng_key, split_key = jax.random.split(rng_key)
+    rng_keys = jax.random.split(split_key, N)
+
+    # Batched MCMC and distance evaluation
+    _, f_vals, posterior_uns = batched_one_sample_step_groups(
+        rng_keys, x_list, theta_star_list, n_warmup, n_samples,
+        models_per_group, bounds, param_groups, param_names
+    )
+
+    # Compute ECP values for each alpha
+    ecp_vals = [jnp.mean(f_vals < (1 - alpha)) for alpha in alpha_list]
+
+    return ecp_vals, f_vals, posterior_uns, rng_key
+
+def compute_ecp_tarp_jitted_with_progress_groups(models_per_group, x_list, theta_star_list, alpha_list,
+                                                 n_warmup, n_samples, rng_key, bounds,
+                                                 param_groups, param_names, batch_size=20):
+    N = x_list.shape[0]
+
+    posterior_list = []
+    f_vals_list = []
+
+    for start in tqdm(range(0, N, batch_size), desc="Computing ECP batches"):
+        end = min(start + batch_size, N)
+        x_batch = x_list[start:end]
+        theta_batch = theta_star_list[start:end]
+
+        # Compute ECP and posterior for batch
+        _, f_vals_batch, posterior_batch, rng_key = compute_ecp_tarp_jitted_groups(
+            models_per_group, x_batch, theta_batch, alpha_list,
+            n_warmup, n_samples, rng_key, bounds,
+            param_groups, param_names
+        )
+
+        posterior_list.append(posterior_batch)
+        f_vals_list.append(f_vals_batch)
+
+    posterior_uns = jnp.concatenate(posterior_list, axis=0)
+    f_vals_all = jnp.concatenate(f_vals_list, axis=0)
+
+    ecp_vals = [jnp.mean(f_vals_all < (1 - alpha)) for alpha in alpha_list]
+
+    return ecp_vals, posterior_uns, rng_key
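The rewritten nn_test.py replaces the single global prior with per-group priors: `log_prob_fn_groups` now sums sigmoid log-ratios and `log_group_prior` terms over the autoregressive parameter groups before NUTS sampling and the TARP-style coverage check. A small sketch of the two new helper functions, with illustrative bounds and group indices that are not taken from the package:

    # Hedged sketch of the new prior/reference helpers (bounds and grouping are illustrative).
    import jax
    import jax.numpy as jnp
    from ximinf.nn_test import log_group_prior, sample_reference_point  # assumed import path

    bounds = jnp.array([[-0.3, 0.0], [2.0, 4.0], [-20.0, -19.0]])  # one row of (low, high) per parameter
    rng_key = jax.random.PRNGKey(0)
    rng_key, theta_ref = sample_reference_point(rng_key, bounds)   # uniform draw inside the box

    # Uniform log-prior restricted to the first parameter group (indices 0 and 1):
    lp_in = log_group_prior(theta_ref, bounds, jnp.array([0, 1]))                     # 0.0 (inside bounds)
    lp_out = log_group_prior(jnp.array([1.0, 3.0, -19.5]), bounds, jnp.array([0, 1])) # -inf (index 0 out of bounds)

The same `bounds`, `param_groups` and `param_names` structures, together with one trained classifier per group, would then be passed to `compute_ecp_tarp_jitted_with_progress_groups` to produce the ECP curve over `alpha_list`.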
{ximinf-0.0.8 → ximinf-0.0.16}/src/ximinf/nn_train.py

@@ -182,6 +182,46 @@ def train_test_split_jax(X, y, test_size=0.3, shuffle=False, key=None):
 
     return X[:N_train], X[N_train:], y[:N_train], y[N_train:]
 
+def train_test_split_indices_jax(N, test_size=0.3, shuffle=False, key=None, fixed_test_idx=None):
+    """
+    Generate train/test indices in JAX, optionally using a fixed test set.
+
+    Parameters
+    ----------
+    N : int
+        Total number of samples.
+    test_size : float
+        Fraction of the dataset to use as test data.
+    shuffle : bool
+        Whether to shuffle before splitting (ignored if fixed_test_idx is provided).
+    key : jax.random.PRNGKey
+        Random key used for shuffling (required if shuffle=True and fixed_test_idx is None).
+    fixed_test_idx : jax.numpy.ndarray, optional
+        Predefined indices to use as test set (persistent across rounds).
+
+    Returns
+    -------
+    train_idx : jax.numpy.ndarray
+        Indices for the training set.
+    test_idx : jax.numpy.ndarray
+        Indices for the test set.
+    """
+
+    N_test = int(jnp.floor(test_size * N))
+
+    if fixed_test_idx is None:
+        if shuffle:
+            perm = jax.random.permutation(key, N)
+        else:
+            perm = jnp.arange(N)
+        test_idx = perm[:N_test]
+    else:
+        test_idx = fixed_test_idx
+
+    train_idx = jnp.setdiff1d(jnp.arange(N), test_idx)
+    return train_idx, test_idx
+
+
 @nnx.jit
 def l2_loss(model, alpha):
     """
@@ -205,7 +245,7 @@ def l2_loss(model, alpha):
     return alpha * sum((param ** 2).sum() for param in params)
 
 @nnx.jit
-def loss_fn(model, batch, l2_reg=1e-
+def loss_fn(model, batch, l2_reg=1e-5):
     """
     Compute the total loss, which is the sum of the data loss and L2 regularization.
 
@@ -304,6 +344,100 @@ def pred_step(model, x_batch):
     logits = model(x_batch)
     return logits
 
+class Phi(nnx.Module):
+    """
+    Neural network module for the Phi network in a Deep Set architecture.
+    """
+    def __init__(self, Nsize, n_cols, *, rngs):
+        self.linear1 = nnx.Linear(n_cols, Nsize, rngs=rngs) #+n_params
+        self.linear2 = nnx.Linear(Nsize, Nsize, rngs=rngs)
+        self.linear3 = nnx.Linear(Nsize, Nsize, rngs=rngs)
+
+    def __call__(self, data):
+        h = data
+
+        h = nnx.relu(self.linear1(h))
+        h = nnx.relu(self.linear2(h))
+        h = nnx.relu(self.linear3(h))
+        return h
+
+
+class Rho(nnx.Module):
+    """
+    Neural network module for the Rho network in a Deep Set architecture
+    with separate LayerNorm for pooled features and theta.
+    """
+    def __init__(self, Nsize_p, Nsize_r, N_size_params, *, rngs):
+        self.linear1 = nnx.Linear(Nsize_p + N_size_params, Nsize_r, rngs=rngs) #
+        self.linear2 = nnx.Linear(Nsize_r, Nsize_r, rngs=rngs)
+        self.linear3 = nnx.Linear(Nsize_r, 1, rngs=rngs)
+
+    def __call__(self, dropout, pooled_features, params):
+        # Concatenate pooled features and embedding
+        x = jnp.concatenate([pooled_features, params], axis=-1)
+
+        x = nnx.relu(self.linear1(x))
+        x = dropout(x)
+
+        x = nnx.relu(self.linear2(x)) #leaky_relu
+        x = dropout(x)
+
+        return self.linear3(x)
+
+
+class DeepSetClassifier(nnx.Module):
+    """
+    Deep Set Classifier model combining Phi and Rho networks.
+    """
+    def __init__(self, dropout_rate, Nsize_p, Nsize_r,
+                 n_cols, n_params, *, rngs):
+
+        self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
+        self.n_cols = n_cols
+        self.n_params = n_params
+
+        self.phi = Phi(Nsize_p, n_cols, rngs=rngs)
+        self.rho = Rho(Nsize_p, Nsize_r, n_params, rngs=rngs)
+
+    def __call__(self, input_data):
+        # ----------------------------------------------------
+        # Accept both shape (N, D) and (D,) without failing
+        # ----------------------------------------------------
+        if input_data.ndim == 1:
+            input_data = input_data[None, :]
+
+        N = input_data.shape[0]
+        input_dim = input_data.shape[1]
+
+        # Compute M first from input size
+        # Total input columns = M*n_cols + n_params + M (mask)
+        M = (input_dim - self.n_params) // (self.n_cols + 1)
+
+        # Reshape data columns
+        data = input_data[:, :M*self.n_cols].reshape(N, M, self.n_cols)
+
+        # Slice mask (last M columns)
+        mask = input_data[:, -M-self.n_params:-self.n_params] # shape (N, M)
+
+        # Parameters
+        theta = input_data[:, -self.n_params:] # shape (N, n_params)
+
+        # Apply Phi
+        h = self.phi(data)
+
+        # Apply mask
+        h_masked = h * mask[..., None]
+
+        # Pool (masked average)
+        mask_sum = jnp.sum(mask, axis=1, keepdims=True)
+        mask_sum = jnp.where(mask_sum == 0, 1.0, mask_sum)
+        pooled = jnp.sum(h_masked, axis=1) / mask_sum
+
+        # pooled_N = jnp.concatenate([pooled, mask_sum], axis=-1)
+
+        # Apply Rho
+        return self.rho(self.dropout, pooled, theta)
+
 def train_loop(model,
                optimizer,
                train_data,
@@ -317,6 +451,10 @@ def train_loop(model,
                metrics_history,
                M,
                N,
+               cpu,
+               gpu,
+               group_id,
+               group_params,
                plot_flag=False):
     """
     Train loop with early stopping and optional plotting.
@@ -325,66 +463,63 @@ def train_loop(model,
     # Initialise stopping criteria
     best_train_loss = jnp.inf
     best_test_loss = jnp.inf
+    best_train_accuracy = 0.0
+    best_test_accuracy = 0.0
     strikes = 0
 
     model.train()
 
     for epoch in range(epochs):
-        # Shuffle the training data using JAX.
-        # key, subkey = jax.random.split(key)
-        # perm = jax.random.permutation(subkey, len(train_data))
-        # train_data = train_data[perm]
-        # train_labels = train_labels[perm]
-        # del perm
 
         epoch_train_loss = 0
-
-        epoch_train_total = 0
+        epoch_train_accuracy = 0
 
         for i in range(0, len(train_data), batch_size):
             # Get the current batch of data and labels
-            batch_data = train_data[i:i+batch_size]
-            batch_labels = train_labels[i:i+batch_size]
+            batch_data = jax.device_put(train_data[i:i+batch_size], gpu)
+            batch_labels = jax.device_put(train_labels[i:i+batch_size], gpu)
 
             # Perform a training step
             loss, _ = loss_fn(model, (batch_data, batch_labels))
            accuracy = accuracy_fn(model, (batch_data, batch_labels))
             epoch_train_loss += loss
             # Multiply batch accuracy by batch size to get number of correct predictions
-
-            epoch_train_total += len(batch_data)
+            epoch_train_accuracy += accuracy * len(batch_data)
             train_step(model, optimizer, (batch_data, batch_labels))
 
         # Log the training metrics.
         current_train_loss = epoch_train_loss / (len(train_data) / batch_size)
+        current_train_accuracy = epoch_train_accuracy / len(train_data)
         metrics_history['train_loss'].append(current_train_loss)
         # Compute overall epoch accuracy
-        metrics_history['train_accuracy'].append(
+        metrics_history['train_accuracy'].append(current_train_accuracy)
 
         epoch_test_loss = 0
-
-        epoch_test_total = 0
+        epoch_test_accuracy = 0
 
         # Compute the metrics on the test set using the same batching as training
         for i in range(0, len(test_data), batch_size):
-            batch_data = test_data[i:i+batch_size]
-            batch_labels = test_labels[i:i+batch_size]
+            batch_data = jax.device_put(test_data[i:i+batch_size], gpu)
+            batch_labels = jax.device_put(test_labels[i:i+batch_size], gpu)
 
             loss, _ = loss_fn(model, (batch_data, batch_labels))
             accuracy = accuracy_fn(model, (batch_data, batch_labels))
             epoch_test_loss += loss
-
-            epoch_test_total += len(batch_data)
+            epoch_test_accuracy += accuracy * len(batch_data)
 
         # Log the test metrics.
         current_test_loss = epoch_test_loss / (len(test_data) / batch_size)
+        current_test_accuracy = epoch_test_accuracy / len(test_data)
         metrics_history['test_loss'].append(current_test_loss)
-        metrics_history['test_accuracy'].append(
+        metrics_history['test_accuracy'].append(current_test_accuracy)
 
         # Early Stopping Check
         if current_test_loss < best_test_loss:
             best_test_loss = current_test_loss # Update best test loss
             strikes = 0
+        # elif current_test_accuracy > best_test_accuracy:
+        #     best_test_accuracy = current_test_accuracy # Update best test accuracy
+        #     strikes = 0
         elif current_train_loss >= best_train_loss:
             strikes = 0
         elif current_test_loss > best_test_loss and current_train_loss < best_train_loss:
@@ -400,6 +535,8 @@ def train_loop(model,
         if plot_flag and epoch % 1 == 0:
             clear_output(wait=True)
 
+            print(f"=== Training model for group {group_id}: {group_params} ===")
+
             fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
 
             # Loss subplot
@@ -417,40 +554,10 @@ def train_loop(model,
 
             plt.show()
 
-
-
-
-# def save_nn(model, path, model_config):
-#     """
-#     Save a neural network model to a checkpoint.
+        if epoch == epochs-1:
+            print(f"\n Reached maximum epochs: {epochs} \n")
 
-
-#     ----------
-#     model : nnx.Module
-#         The model to save.
-#     path : str
-#         Path to the checkpoint directory.
-#     model_config : dict
-#         Configuration dictionary for the model.
-#     """
-#     ckpt_dir = os.path.abspath(path)
-#     ckpt_dir = ocp.test_utils.erase_and_create_empty(ckpt_dir)
-
-#     # Split the model into GraphDef (structure) and State (parameters + buffers)
-#     _, _, _, state = nnx.split(model, nnx.RngKey, nnx.RngCount, ...)
-
-#     # Display for debugging (optional)
-#     # nnx.display(state)
-
-#     # Initialize the checkpointer
-#     checkpointer = ocp.StandardCheckpointer()
-
-#     # Save State (parameters & non-trainable variables)
-#     checkpointer.save(ckpt_dir / 'state', state)
-
-#     # Save model configuration for later loading
-#     with open(ckpt_dir / 'config.json', 'w') as f:
-#         json.dump(model_config, f)
+    return model, metrics_history, key
 
 def save_autoregressive_nn(models_per_group, path, model_config):
     """
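The main addition to nn_train.py is the Deep Set classifier (Phi/Rho), whose `__call__` expects one flattened vector per simulation: M rows of n_cols features, then an M-long presence mask, then the n_params conditioning parameters. The sketch below assembles such an input and instantiates the model with Flax NNX; the sizes, values and import path are illustrative assumptions, not the package defaults:

    # Hedged sketch of the DeepSetClassifier input layout (all sizes are assumptions).
    import jax.numpy as jnp
    from flax import nnx
    from ximinf.nn_train import DeepSetClassifier  # assumed import path

    M, n_cols, n_params = 100, 3, 4
    model = DeepSetClassifier(dropout_rate=0.1, Nsize_p=64, Nsize_r=64,
                              n_cols=n_cols, n_params=n_params, rngs=nnx.Rngs(0))

    data  = jnp.zeros((M, n_cols))   # per-SN features
    mask  = jnp.ones(M)              # 1 where an SN is present, 0 for padding
    theta = jnp.zeros(n_params)      # conditioning parameters

    # Layout expected by __call__: [M*n_cols data | M mask | n_params theta]
    x = jnp.concatenate([data.reshape(-1), mask, theta])
    logits = model(x)                # shape (1, 1): classifier logit for (data, theta)

Because M is recovered from the input length as (input_dim - n_params) // (n_cols + 1), the same trained model can be evaluated on simulations with different numbers of supernovae, with the mask handling padding in the pooled average.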
ximinf-0.0.8/src/ximinf/nn_test.py (removed)

@@ -1,453 +0,0 @@
-# Import libraries
-import jax
-import jax.numpy as jnp
-import blackjax
-from functools import partial
-from tqdm.notebook import tqdm
-
-def distance(theta1, theta2):
-    """
-    Compute the Euclidean distance between two points in NDIM space.
-
-    Parameters
-    ----------
-    theta1 : array-like
-        First point in NDIM-dimensional space.
-    theta2 : array-like
-        Second point in NDIM-dimensional space.
-
-    Returns
-    -------
-    float
-        The Euclidean distance between `theta1` and `theta2`.
-    """
-    diff = theta1 - theta2
-    return jnp.linalg.norm(diff)
-
-def log_prior(theta, bounds):
-    """
-    Compute the log-prior probability for the parameter `theta`,
-    assuming uniform prior within given bounds.
-
-    Parameters
-    ----------
-    theta : array-like
-        The parameter values for which the prior is to be calculated.
-    bounds : jnp.ndarray, optional
-        The bounds on each parameter (default is the global `BOUNDS`).
-
-    Returns
-    -------
-    float
-        The log-prior of `theta`, or negative infinity if `theta` is out of bounds.
-    """
-
-    in_bounds = jnp.all((theta >= bounds[:, 0]) & (theta <= bounds[:, 1]))
-    return jnp.where(in_bounds, 0.0, -jnp.inf)
-
-# def log_prob_fn(theta, model, xy_noise, bounds):
-#     """
-#     Compute the log-probability for the parameter `theta` using a
-#     log-prior and the log-likelihood from the neural likelihood ratio approximation.
-
-#     Parameters
-#     ----------
-#     theta : array-like
-#         The parameter values for which the log-probability is computed.
-#     model : callable
-#         A function that takes `theta` and produces model logits for computing the likelihood.
-#     xy_noise : array-like
-#         Input data with added noise for evaluating the likelihood.
-
-#     Returns
-#     -------
-#     float
-#         The log-probability, which is the sum of the log-prior and the log-likelihood.
-#     """
-
-#     lp = log_prior(theta, bounds)
-#     lp = jnp.where(jnp.isfinite(lp), lp, -jnp.inf)
-#     xy_flat = xy_noise.squeeze()
-#     inp = jnp.concatenate([xy_flat, theta])[None, :]
-#     logits = model(inp)
-#     p = jax.nn.sigmoid(logits).squeeze()
-#     p = jnp.clip(p, 1e-6, 1 - 1e-6)
-#     log_like = jnp.log(p) - jnp.log1p(-p)
-#     return lp + log_like
-
-def sample_reference_point(rng_key, bounds):
-    """
-    Sample a reference point within the given bounds uniformly.
-
-    Parameters
-    ----------
-    rng_key : jax.random.PRNGKey
-        The random key used for sampling.
-    bounds : jnp.ndarray, optional
-        The bounds for each parameter (default is the global `BOUNDS`).
-
-    Returns
-    -------
-    tuple
-        A tuple containing the updated `rng_key` and the sampled reference point `theta`.
-    """
-    ndim = bounds.shape[0]
-    rng_key, subkey = jax.random.split(rng_key)
-    u = jax.random.uniform(subkey, shape=(ndim,))
-    span = bounds[:, 1] - bounds[:, 0]
-    theta = bounds[:, 0] + u * span
-    return rng_key, theta
-
-def inference_loop(rng_key, kernel, initial_state, num_samples):
-    """
-    Perform an inference loop using a Markov Chain Monte Carlo (MCMC) kernel.
-
-    Parameters
-    ----------
-    rng_key : jax.random.PRNGKey
-        The random key used for sampling.
-    kernel : callable
-        The MCMC kernel (e.g., NUTS) used for updating the state.
-    initial_state : object
-        The initial state of the MCMC chain.
-    num_samples : int
-        The number of samples to generate in the chain.
-
-    Returns
-    -------
-    jax.numpy.ndarray
-        The sampled states from the inference loop.
-    """
-
-    def one_step(state, rng):
-        state, _ = kernel(rng, state)
-        return state, state
-    keys = jax.random.split(rng_key, num_samples)
-    _, states = jax.lax.scan(one_step, initial_state, keys)
-    return states
-
-def log_prob_fn_groups(theta, models_per_group, x, bounds, param_groups, param_names):
-    """
-    Compute the sum of log-likelihoods for all groups given full theta.
-
-    Parameters
-    ----------
-    theta : jnp.ndarray, shape (n_params,)
-        Full parameter vector.
-    models_per_group : list
-        List of DeepSetClassifier models, one per group.
-    x : jnp.ndarray
-        Input data sample (shape: (data_features + ... + n_params))
-    bounds : jnp.ndarray
-        Parameter bounds.
-    param_groups : list
-        List of parameter groups.
-    param_names : list
-        List of all parameter names in order.
-
-    Returns
-    -------
-    float
-        Sum of log-likelihoods over all groups.
-    """
-    log_lik_sum = 0.0
-
-    n_params = len(param_names)
-
-    # Use everything except the last n_params entries as data
-    data_part = x[:-n_params].reshape(1, -1) # 2D
-    # If mask is required, you can extract it similarly here
-
-    for g, group in enumerate(param_groups):
-        # Determine visible parameters for this group
-        prev_groups = [
-            p
-            for i in range(g)
-            for p in (param_groups[i] if isinstance(param_groups[i], list) else [param_groups[i]])
-        ]
-        group_list = [group] if isinstance(group, str) else group
-        visible_param_names = prev_groups + group_list
-
-        # Get visible theta values
-        visible_idx = jnp.array([param_names.index(name) for name in visible_param_names])
-        theta_visible = theta[visible_idx].reshape(1, -1) # make 2D
-
-        # Concatenate data with visible parameters
-        input_g = jnp.concatenate([data_part, theta_visible], axis=-1)
-
-        # Forward pass through the model
-        logits = models_per_group[g](input_g)
-        p = jax.nn.sigmoid(logits)
-
-        log_lik_sum += jnp.log(p) + jnp.log(1 - p)
-
-    return jnp.squeeze(log_lik_sum) + log_prior(theta, bounds)
-
-
-
-
-@partial(jax.jit, static_argnums=(0, 1, 2))
-def sample_posterior(log_prob, n_warmup, n_samples, init_position, rng_key):
-    warmup = blackjax.window_adaptation(blackjax.nuts, log_prob)
-    rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3)
-    (warmup_state, params), _ = warmup.run(warmup_key, init_position, num_steps=n_warmup)
-    kernel = blackjax.nuts(log_prob, **params).step
-    rng_key, sample_key = jax.random.split(rng_key)
-    states = inference_loop(sample_key, kernel, warmup_state, n_samples)
-    return rng_key, states.position
-
-
-# # ========== JIT‐compiled per‐sample step ==========
-# @partial(jax.jit, static_argnums=(3, 4, 5))
-# def one_sample_step(rng_key, xi, theta_star, n_warmup, n_samples, model, bounds):
-#     """
-#     Sample from the posterior distribution using Hamiltonian Monte Carlo (HMC)
-#     with NUTS (No-U-Turn Sampler) for a given `log_prob`.
-
-#     Parameters
-#     ----------
-#     log_prob : callable
-#         The log-probability function for the model and parameters.
-#     n_warmup : int
-#         The number of warmup steps to adapt the sampler.
-#     n_samples : int
-#         The number of samples to generate after warmup.
-#     init_position : array-like
-#         The initial position for the chain (parameter values).
-#     rng_key : jax.random.PRNGKey
-#         The random key used for sampling.
-
-#     Returns
-#     -------
-#     jax.numpy.ndarray
-#         The sampled positions (parameters) from the posterior distribution.
-#     """
-
-#     # Draw a random reference
-#     rng_key, theta_r0 = sample_reference_point(rng_key, bounds)
-
-#     def log_post(theta):
-#         return log_prob_fn(theta, model, xi, bounds)
-
-#     # Run MCMC
-#     rng_key, posterior = sample_posterior(log_post, n_warmup, n_samples, theta_star, rng_key)
-
-#     # Compute e-c-p distances
-#     d_star = distance(theta_star, theta_r0)
-#     d_samples = jnp.linalg.norm(posterior - theta_r0, axis=1)
-#     f_val = jnp.mean(d_samples < d_star)
-
-#     return rng_key, f_val, posterior
-
-def one_sample_step_groups(rng_key, xi, theta_star, n_warmup, n_samples,
-                           models_per_group, bounds, param_groups, param_names):
-    """
-    Sample from posterior using sum of log-likelihoods over all groups.
-    """
-    rng_key, theta_r0 = sample_reference_point(rng_key, bounds)
-
-    def log_post(theta):
-        return log_prob_fn_groups(theta, models_per_group, xi, bounds, param_groups, param_names)
-
-    rng_key, posterior = sample_posterior(log_post, n_warmup, n_samples, theta_star, rng_key)
-    d_star = distance(theta_star, theta_r0)
-    d_samples = jnp.linalg.norm(posterior - theta_r0, axis=1)
-    f_val = jnp.mean(d_samples < d_star)
-
-    return rng_key, f_val, posterior
-
-
-# def batched_one_sample_step(rng_keys, x_batch, theta_star_batch, n_warmup, n_samples, model, bounds):
-#     """
-#     Vectorized wrapper over `one_sample_step` using jax.vmap.
-
-#     Parameters
-#     ----------
-#     rng_keys : jax.random.PRNGKey
-#         Batch of random keys.
-#     x_batch : array-like
-#         Batch of input data.
-#     theta_star_batch : array-like
-#         Batch of true parameter values.
-#     n_warmup : int
-#         Number of warmup steps.
-#     n_samples : int
-#         Number of samples.
-#     model : callable
-#         The model function.
-#     bounds : array-like
-#         Parameter bounds.
-
-#     Returns
-#     -------
-#     tuple
-#         (rng_keys, f_vals, posterior_samples)
-#     """
-#     return jax.vmap(
-#         lambda rng, x, theta: one_sample_step(rng, x[None, :], theta, n_warmup, n_samples, model, bounds),
-#         in_axes=(0, 0, 0)
-#     )(rng_keys, x_batch, theta_star_batch)
-
-def batched_one_sample_step_groups(rng_keys, x_batch, theta_star_batch,
-                                   n_warmup, n_samples, models_per_group, bounds, param_groups, param_names):
-    return jax.vmap(
-        lambda rng, x, theta: one_sample_step_groups(rng, x[None, :], theta, n_warmup, n_samples,
-                                                     models_per_group, bounds, param_groups, param_names),
-        in_axes=(0, 0, 0)
-    )(rng_keys, x_batch, theta_star_batch)
-
-
-
-# def compute_ecp_tarp_jitted(model, x_list, theta_star_list, alpha_list, n_warmup, n_samples, rng_key, bounds):
-#     """
-#     Compute expected coverage probabilities (ECP) using vectorized sampling.
-
-#     Parameters
-#     ----------
-#     model : callable
-#         The model function.
-#     x_list : array-like
-#         List of input data.
-#     theta_star_list : array-like
-#         List of true parameter values.
-#     alpha_list : list of float
-#         List of alpha values for ECP computation.
-#     n_warmup : int
-#         Number of warmup steps.
-#     n_samples : int
-#         Number of samples.
-#     rng_key : jax.random.PRNGKey
-#         Random key.
-#     bounds : array-like
-#         Parameter bounds.
-
-#     Returns
-#     -------
-#     tuple
-#         (ecp_vals, f_vals, posterior_uns, rng_key)
-#     """
-#     N = x_list.shape[0]
-#     rng_key, split_key = jax.random.split(rng_key)
-#     rng_keys = jax.random.split(split_key, N)
-
-#     # Batched MCMC and distance evaluation
-#     _, f_vals, posterior_uns = batched_one_sample_step_groups(
-#         rng_keys, x_list, theta_star_list, n_warmup, n_samples, model, bounds
-#     )
-
-#     # Compute ECP values for each alpha
-#     ecp_vals = [jnp.mean(f_vals < (1 - alpha)) for alpha in alpha_list]
-
-#     return ecp_vals, f_vals, posterior_uns, rng_key
-
-
-# def compute_ecp_tarp_jitted_with_progress(model, x_list, theta_star_list, alpha_list,
-#                                           n_warmup, n_samples, rng_key, bounds,
-#                                           batch_size=20):
-#     """
-#     Compute ECP using JITed MCMC in batches with progress reporting via tqdm.
-
-#     Parameters
-#     ----------
-#     model : callable
-#         The model function.
-#     x_list : array-like
-#         List of input data.
-#     theta_star_list : array-like
-#         List of true parameter values.
-#     alpha_list : list of float
-#         List of alpha values for ECP computation.
-#     n_warmup : int
-#         Number of warmup steps.
-#     n_samples : int
-#         Number of samples.
-#     rng_key : jax.random.PRNGKey
-#         Random key.
-#     bounds : array-like
-#         Parameter bounds.
-#     batch_size : int, optional
-#         Batch size for processing (default is 20).
-
-#     Returns
-#     -------
-#     tuple
-#         (ecp_vals, posterior_uns, rng_key)
-#     """
-#     N = x_list.shape[0]
-
-#     posterior_list = []
-#     f_vals_list = []
-
-#     for start in tqdm(range(0, N, batch_size), desc="Computing ECP batches"):
-#         end = min(start + batch_size, N)
-#         x_batch = x_list[start:end]
-#         theta_batch = theta_star_list[start:end]
-
-#         # Compute ECP and posterior for batch
-#         _, f_vals_batch, posterior_batch, rng_key = compute_ecp_tarp_jitted(
-#             model, x_batch, theta_batch, alpha_list,
-#             n_warmup, n_samples, rng_key, bounds
-#         )
-
-#         posterior_list.append(posterior_batch)
-#         f_vals_list.append(f_vals_batch)
-
-#     # Concatenate across batches
-#     posterior_uns = jnp.concatenate(posterior_list, axis=0)
-#     f_vals_all = jnp.concatenate(f_vals_list, axis=0)
-
-#     # Compute final ECP for each alpha
-#     ecp_vals = [jnp.mean(f_vals_all < (1 - alpha)) for alpha in alpha_list]
-
-#     return ecp_vals, posterior_uns, rng_key
-
-def compute_ecp_tarp_jitted_groups(models_per_group, x_list, theta_star_list, alpha_list,
-                                   n_warmup, n_samples, rng_key, bounds,
-                                   param_groups, param_names):
-    """
-    Batched ECP computation using multiple group models.
-    """
-    N = x_list.shape[0]
-    rng_key, split_key = jax.random.split(rng_key)
-    rng_keys = jax.random.split(split_key, N)
-
-    # Batched MCMC and distance evaluation
-    _, f_vals, posterior_uns = batched_one_sample_step_groups(
-        rng_keys, x_list, theta_star_list, n_warmup, n_samples,
-        models_per_group, bounds, param_groups, param_names
-    )
-
-    # Compute ECP values for each alpha
-    ecp_vals = [jnp.mean(f_vals < (1 - alpha)) for alpha in alpha_list]
-
-    return ecp_vals, f_vals, posterior_uns, rng_key
-
-def compute_ecp_tarp_jitted_with_progress_groups(models_per_group, x_list, theta_star_list, alpha_list,
-                                                 n_warmup, n_samples, rng_key, bounds,
-                                                 param_groups, param_names, batch_size=20):
-    N = x_list.shape[0]
-
-    posterior_list = []
-    f_vals_list = []
-
-    for start in tqdm(range(0, N, batch_size), desc="Computing ECP batches"):
-        end = min(start + batch_size, N)
-        x_batch = x_list[start:end]
-        theta_batch = theta_star_list[start:end]
-
-        # Compute ECP and posterior for batch
-        _, f_vals_batch, posterior_batch, rng_key = compute_ecp_tarp_jitted_groups(
-            models_per_group, x_batch, theta_batch, alpha_list,
-            n_warmup, n_samples, rng_key, bounds,
-            param_groups, param_names
-        )
-
-        posterior_list.append(posterior_batch)
-        f_vals_list.append(f_vals_batch)
-
-    posterior_uns = jnp.concatenate(posterior_list, axis=0)
-    f_vals_all = jnp.concatenate(f_vals_list, axis=0)
-
-    ecp_vals = [jnp.mean(f_vals_all < (1 - alpha)) for alpha in alpha_list]
-
-    return ecp_vals, posterior_uns, rng_key
The remaining files listed above (LICENSE, README.md, setup.cfg, src/ximinf/__init__.py, src/ximinf/selection_effects.py, and the src/ximinf.egg-info metadata files) are unchanged between 0.0.8 and 0.0.16.