ximinf 0.0.8__tar.gz → 0.0.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ximinf
3
- Version: 0.0.8
3
+ Version: 0.0.16
4
4
  Summary: Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae.
5
5
  Author-email: Adam Trigui <a.trigui@ip2i.in2p3.fr>
6
6
  License: GPL-3.0-or-later
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "ximinf"
7
- version = "0.0.8"
7
+ version = "0.0.16"
8
8
  description = "Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae. "
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -72,7 +72,7 @@ def simulate_one(params_dict, z_max, M, cols, N=None, i=None):
72
72
  "alpha": 0.0,
73
73
  "beta": 0.0,
74
74
  "mabs": -19.3,
75
- "gamma": 0.0,
75
+ # "gamma": 0.0,
76
76
  "sigma_int": 0.0, # default intrinsic scatter
77
77
  }
78
78
 
@@ -83,36 +83,33 @@ def simulate_one(params_dict, z_max, M, cols, N=None, i=None):
83
83
  alpha_ = float(params["alpha"])
84
84
  beta_ = float(params["beta"])
85
85
  mabs_ = float(params["mabs"])
86
- gamma_ = float(params["gamma"])
86
+ # gamma_ = float(params["gamma"])
87
87
  sigma_int_ = float(params["sigma_int"])
88
88
 
89
- brokenalpha_model = skysurvey_sniapop.brokenalpha_model
89
+ # brokenalpha_model = skysurvey_sniapop.brokenalpha_model
90
90
 
91
91
  # Generate SNe sample
92
92
  snia = skysurvey.SNeIa.from_draw(
93
93
  size=M,
94
94
  zmax=z_max,
95
- model=brokenalpha_model,
95
+ # model=brokenalpha_model,
96
96
  magabs={
97
- "x1": "@x1",
98
- "c": "@c",
99
97
  "mabs": mabs_,
100
98
  "sigmaint": sigma_int_,
101
- "alpha_low": alpha_,
102
- "alpha_high": alpha_,
99
+ "alpha": alpha_,
103
100
  "beta": beta_,
104
- "gamma": gamma_
105
101
  }
106
102
  )
107
103
 
108
104
  # Apply noise
109
- errormodel = sim.noise_model
110
- errormodel["localcolor"]["kwargs"]["a"] = 2
111
- errormodel["localcolor"]["kwargs"]["loc"] = 0.005
112
- errormodel["localcolor"]["kwargs"]["scale"] = 0.05
113
- noisy_snia = snia.apply_gaussian_noise(errormodel)
114
-
115
- df = noisy_snia.data
105
+ # errormodel = sim.noise_model
106
+ # errormodel["localcolor"]["kwargs"]["a"] = 2
107
+ # errormodel["localcolor"]["kwargs"]["loc"] = 0.005
108
+ # errormodel["localcolor"]["kwargs"]["scale"] = 0.05
109
+ # noisy_snia = snia.apply_gaussian_noise(errormodel)
110
+
111
+ # df = noisy_snia.data
112
+ df = snia.data
116
113
 
117
114
  # Collect requested columns as lists
118
115
  data_dict = {col: list(df[col]) for col in cols if col in df}
@@ -114,7 +114,6 @@ def load_autoregressive_nn(path):
114
114
  dropout_rate=0.0,
115
115
  Nsize_p=shared["Nsize_p"],
116
116
  Nsize_r=shared["Nsize_r"],
117
- N_size_embed=shared["N_size_embed"],
118
117
  n_cols=shared["n_cols"],
119
118
  n_params=n_params_visible,
120
119
  rngs=nnx.Rngs(0),
@@ -0,0 +1,246 @@
1
+ # Import libraries
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import blackjax
5
+ from functools import partial
6
+ from tqdm.notebook import tqdm
7
+
8
+ def distance(theta1, theta2):
9
+ """
10
+ Compute the Euclidean distance between two points in NDIM space.
11
+
12
+ Parameters
13
+ ----------
14
+ theta1 : array-like
15
+ First point in NDIM-dimensional space.
16
+ theta2 : array-like
17
+ Second point in NDIM-dimensional space.
18
+
19
+ Returns
20
+ -------
21
+ float
22
+ The Euclidean distance between `theta1` and `theta2`.
23
+ """
24
+ diff = theta1 - theta2
25
+ return jnp.linalg.norm(diff)
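A toy check of the helper above (illustration only; it works for any dimension):

import jax.numpy as jnp

# 3-4-5 triangle: the Euclidean distance is 5.
d = distance(jnp.array([0.0, 0.0]), jnp.array([3.0, 4.0]))   # Array(5.0)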
26
+
27
+ # def log_prior(theta, bounds):
28
+ # """
29
+ # Compute the log-prior probability for the parameter `theta`,
30
+ # assuming uniform prior within given bounds.
31
+
32
+ # Parameters
33
+ # ----------
34
+ # theta : array-like
35
+ # The parameter values for which the prior is to be calculated.
36
+ # bounds : jnp.ndarray, optional
37
+ # The bounds on each parameter (default is the global `BOUNDS`).
38
+
39
+ # Returns
40
+ # -------
41
+ # float
42
+ # The log-prior of `theta`, or negative infinity if `theta` is out of bounds.
43
+ # """
44
+
45
+ # in_bounds = jnp.all((theta >= bounds[:, 0]) & (theta <= bounds[:, 1]))
46
+ # return jnp.where(in_bounds, 0.0, -jnp.inf)
47
+
48
+ def log_group_prior(theta, bounds, group_indices):
49
+ """
50
+ Log prior for a single parameter group.
51
+ Uniform within bounds, -inf otherwise.
52
+ """
53
+ theta_g = theta[group_indices]
54
+ bounds_g = bounds[group_indices]
55
+
56
+ in_bounds = jnp.all(
57
+ (theta_g >= bounds_g[:, 0]) &
58
+ (theta_g <= bounds_g[:, 1])
59
+ )
60
+
61
+ return jnp.where(in_bounds, 0.0, -jnp.inf)
62
+
63
+
64
+
65
+ def sample_reference_point(rng_key, bounds):
66
+ """
67
+ Sample a reference point within the given bounds uniformly.
68
+
69
+ Parameters
70
+ ----------
71
+ rng_key : jax.random.PRNGKey
72
+ The random key used for sampling.
73
+ bounds : jnp.ndarray
74
+ The bounds on each parameter, shape (n_params, 2).
75
+
76
+ Returns
77
+ -------
78
+ tuple
79
+ A tuple containing the updated `rng_key` and the sampled reference point `theta`.
80
+ """
81
+ ndim = bounds.shape[0]
82
+ rng_key, subkey = jax.random.split(rng_key)
83
+ u = jax.random.uniform(subkey, shape=(ndim,))
84
+ span = bounds[:, 1] - bounds[:, 0]
85
+ theta = bounds[:, 0] + u * span
86
+ return rng_key, theta
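A minimal sketch of how the two helpers above compose; the three-parameter bounds and the group indices are hypothetical values chosen only for illustration:

import jax
import jax.numpy as jnp

# Hypothetical bounds for a three-parameter problem, shape (n_params, 2).
bounds = jnp.array([[0.1, 0.5],
                    [-0.3, 0.3],
                    [-4.0, 0.0]])

rng_key = jax.random.PRNGKey(0)

# Draw a uniform reference point inside the box.
rng_key, theta_ref = sample_reference_point(rng_key, bounds)

# Uniform prior restricted to the first two parameters: 0.0 inside, -inf outside.
lp = log_group_prior(theta_ref, bounds, jnp.array([0, 1]))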
87
+
88
+ def inference_loop(rng_key, kernel, initial_state, num_samples):
89
+ """
90
+ Perform an inference loop using a Markov Chain Monte Carlo (MCMC) kernel.
91
+
92
+ Parameters
93
+ ----------
94
+ rng_key : jax.random.PRNGKey
95
+ The random key used for sampling.
96
+ kernel : callable
97
+ The MCMC kernel (e.g., NUTS) used for updating the state.
98
+ initial_state : object
99
+ The initial state of the MCMC chain.
100
+ num_samples : int
101
+ The number of samples to generate in the chain.
102
+
103
+ Returns
104
+ -------
105
+ jax.numpy.ndarray
106
+ The sampled states from the inference loop.
107
+ """
108
+
109
+ def one_step(state, rng):
110
+ state, _ = kernel(rng, state)
111
+ return state, state
112
+ keys = jax.random.split(rng_key, num_samples)
113
+ _, states = jax.lax.scan(one_step, initial_state, keys)
114
+ return states
115
+
116
+ def log_prob_fn_groups(theta, models_per_group, data, bounds,
117
+ param_groups, global_param_names):
118
+
119
+ log_r_sum = 0.0
120
+ log_p_group_sum = 0.0
121
+
122
+ data = data.reshape(1, -1)
123
+
124
+ for g, group in enumerate(param_groups):
125
+
126
+ # --- parameter bookkeeping (unchanged) ---
127
+ prev_groups = [
128
+ p
129
+ for i in range(g)
130
+ for p in (param_groups[i] if isinstance(param_groups[i], list)
131
+ else [param_groups[i]])
132
+ ]
133
+
134
+ group_list = [group] if isinstance(group, str) else group
135
+ visible_param_names = prev_groups + group_list
136
+
137
+ visible_idx = jnp.array(
138
+ [global_param_names.index(name) for name in visible_param_names]
139
+ )
140
+
141
+ theta_visible = theta[visible_idx].reshape(1, -1)
142
+ input_g = jnp.concatenate([data, theta_visible], axis=-1)
143
+
144
+ # --- ratio estimator ---
145
+ logits = models_per_group[g](input_g)
146
+ p = jax.nn.sigmoid(logits)
147
+ log_r_sum += jnp.log(p) - jnp.log1p(-p)
148
+
149
+ # --- marginal prior for this group ---
150
+ group_idx = jnp.array(
151
+ [global_param_names.index(name) for name in group_list]
152
+ )
153
+
154
+ log_p_group_sum += log_group_prior(theta, bounds, group_idx)
155
+
156
+ return jnp.squeeze(log_r_sum + log_p_group_sum)
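To make the group bookkeeping concrete, here is a sketch with stand-in ratio "models" (plain callables returning a zero logit) and a hypothetical two-group parameter layout; with a zero logit each group contributes a log-ratio of 0, so the result reduces to the group priors (0.0 inside the bounds here):

import jax.numpy as jnp

param_names = ["Om", "alpha", "beta"]            # hypothetical global ordering
param_groups = ["Om", ["alpha", "beta"]]         # group 1 is conditioned on group 0
bounds = jnp.array([[0.1, 0.5], [-0.3, 0.3], [-4.0, 0.0]])

# Stand-ins for trained classifiers: logit 0 -> sigmoid 0.5 -> log ratio 0.
models_per_group = [lambda x: jnp.zeros((x.shape[0], 1)),
                    lambda x: jnp.zeros((x.shape[0], 1))]

theta = jnp.array([0.3, 0.1, -2.0])
data = jnp.ones(10)                              # flattened data / summary vector

log_post = log_prob_fn_groups(theta, models_per_group, data, bounds,
                              param_groups, param_names)   # 0.0 in this toy setup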
157
+
158
+
159
+
160
+ @partial(jax.jit, static_argnums=(0, 1, 2))
161
+ def sample_posterior(log_prob, n_warmup, n_samples, init_position, rng_key):
162
+ warmup = blackjax.window_adaptation(blackjax.nuts, log_prob)
163
+ rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3)
164
+ (warmup_state, params), _ = warmup.run(warmup_key, init_position, num_steps=n_warmup)
165
+ kernel = blackjax.nuts(log_prob, **params).step
166
+ rng_key, sample_key = jax.random.split(rng_key)
167
+ states = inference_loop(sample_key, kernel, warmup_state, n_samples)
168
+ return rng_key, states.position
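A minimal usage sketch of sample_posterior on a toy standard-normal target, assuming blackjax is available as imported above; the dimension and step counts are placeholders:

import jax
import jax.numpy as jnp

def log_prob(theta):
    # Unnormalised log-density of a 3D standard normal.
    return -0.5 * jnp.sum(theta ** 2)

rng_key = jax.random.PRNGKey(0)
# 500 window-adaptation steps, then 1000 NUTS draws starting from the origin.
rng_key, samples = sample_posterior(log_prob, 500, 1000, jnp.zeros(3), rng_key)
# samples has shape (1000, 3): one row per post-warmup draw.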
169
+
170
+
171
+ def one_sample_step_groups(rng_key, xi, theta_star, n_warmup, n_samples,
172
+ models_per_group, bounds, param_groups, param_names):
173
+ """
174
+ Sample from the posterior defined by the sum of per-group log likelihood-ratio terms and group priors.
175
+ """
176
+ rng_key, theta_r0 = sample_reference_point(rng_key, bounds)
177
+
178
+ def log_post(theta):
179
+ return log_prob_fn_groups(theta, models_per_group, xi, bounds, param_groups, param_names)
180
+
181
+ rng_key, posterior = sample_posterior(log_post, n_warmup, n_samples, theta_star, rng_key)
182
+ d_star = distance(theta_star, theta_r0)
183
+ d_samples = jnp.linalg.norm(posterior - theta_r0, axis=1)
184
+ f_val = jnp.mean(d_samples < d_star)
185
+
186
+ return rng_key, f_val, posterior
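The returned f_val is the fraction of posterior draws that land closer to the random reference point than theta_star does; a toy version of that last computation, with made-up distances:

import jax.numpy as jnp

d_star = 1.0                                   # |theta_star - theta_ref|
d_samples = jnp.array([0.2, 0.8, 1.5, 2.0])    # |posterior draws - theta_ref|
f_val = jnp.mean(d_samples < d_star)           # 0.5; ~Uniform(0,1) under correct coverage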
187
+
188
+
189
+ def batched_one_sample_step_groups(rng_keys, x_batch, theta_star_batch,
190
+ n_warmup, n_samples, models_per_group, bounds, param_groups, param_names):
191
+ return jax.vmap(
192
+ lambda rng, x, theta: one_sample_step_groups(rng, x[None, :], theta, n_warmup, n_samples,
193
+ models_per_group, bounds, param_groups, param_names),
194
+ in_axes=(0, 0, 0)
195
+ )(rng_keys, x_batch, theta_star_batch)
196
+
197
+ def compute_ecp_tarp_jitted_groups(models_per_group, x_list, theta_star_list, alpha_list,
198
+ n_warmup, n_samples, rng_key, bounds,
199
+ param_groups, param_names):
200
+ """
201
+ Batched ECP computation using multiple group models.
202
+ """
203
+ N = x_list.shape[0]
204
+ rng_key, split_key = jax.random.split(rng_key)
205
+ rng_keys = jax.random.split(split_key, N)
206
+
207
+ # Batched MCMC and distance evaluation
208
+ _, f_vals, posterior_uns = batched_one_sample_step_groups(
209
+ rng_keys, x_list, theta_star_list, n_warmup, n_samples,
210
+ models_per_group, bounds, param_groups, param_names
211
+ )
212
+
213
+ # Compute ECP values for each alpha
214
+ ecp_vals = [jnp.mean(f_vals < (1 - alpha)) for alpha in alpha_list]
215
+
216
+ return ecp_vals, f_vals, posterior_uns, rng_key
217
+
218
+ def compute_ecp_tarp_jitted_with_progress_groups(models_per_group, x_list, theta_star_list, alpha_list,
219
+ n_warmup, n_samples, rng_key, bounds,
220
+ param_groups, param_names, batch_size=20):
221
+ N = x_list.shape[0]
222
+
223
+ posterior_list = []
224
+ f_vals_list = []
225
+
226
+ for start in tqdm(range(0, N, batch_size), desc="Computing ECP batches"):
227
+ end = min(start + batch_size, N)
228
+ x_batch = x_list[start:end]
229
+ theta_batch = theta_star_list[start:end]
230
+
231
+ # Compute ECP and posterior for batch
232
+ _, f_vals_batch, posterior_batch, rng_key = compute_ecp_tarp_jitted_groups(
233
+ models_per_group, x_batch, theta_batch, alpha_list,
234
+ n_warmup, n_samples, rng_key, bounds,
235
+ param_groups, param_names
236
+ )
237
+
238
+ posterior_list.append(posterior_batch)
239
+ f_vals_list.append(f_vals_batch)
240
+
241
+ posterior_uns = jnp.concatenate(posterior_list, axis=0)
242
+ f_vals_all = jnp.concatenate(f_vals_list, axis=0)
243
+
244
+ ecp_vals = [jnp.mean(f_vals_all < (1 - alpha)) for alpha in alpha_list]
245
+
246
+ return ecp_vals, posterior_uns, rng_key
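A hedged sketch of how the batched ECP/TARP check might be driven and plotted, assuming models_per_group, x_test, theta_test, bounds, param_groups and param_names already exist (hypothetical names) and matplotlib is available; a well-calibrated posterior tracks ECP = 1 - alpha:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

alpha_list = jnp.linspace(0.05, 0.95, 19)
rng_key = jax.random.PRNGKey(1)

ecp_vals, posterior_uns, rng_key = compute_ecp_tarp_jitted_with_progress_groups(
    models_per_group, x_test, theta_test, alpha_list,
    n_warmup=500, n_samples=1000, rng_key=rng_key, bounds=bounds,
    param_groups=param_groups, param_names=param_names, batch_size=20,
)

plt.plot(1 - alpha_list, jnp.array(ecp_vals), marker="o", label="ECP")
plt.plot([0, 1], [0, 1], "k--", label="ideal")
plt.xlabel("nominal credibility 1 - alpha")
plt.ylabel("expected coverage probability")
plt.legend()
plt.show()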
@@ -182,6 +182,46 @@ def train_test_split_jax(X, y, test_size=0.3, shuffle=False, key=None):
182
182
 
183
183
  return X[:N_train], X[N_train:], y[:N_train], y[N_train:]
184
184
 
185
+ def train_test_split_indices_jax(N, test_size=0.3, shuffle=False, key=None, fixed_test_idx=None):
186
+ """
187
+ Generate train/test indices in JAX, optionally using a fixed test set.
188
+
189
+ Parameters
190
+ ----------
191
+ N : int
192
+ Total number of samples.
193
+ test_size : float
194
+ Fraction of the dataset to use as test data.
195
+ shuffle : bool
196
+ Whether to shuffle before splitting (ignored if fixed_test_idx is provided).
197
+ key : jax.random.PRNGKey
198
+ Random key used for shuffling (required if shuffle=True and fixed_test_idx is None).
199
+ fixed_test_idx : jax.numpy.ndarray, optional
200
+ Predefined indices to use as test set (persistent across rounds).
201
+
202
+ Returns
203
+ -------
204
+ train_idx : jax.numpy.ndarray
205
+ Indices for the training set.
206
+ test_idx : jax.numpy.ndarray
207
+ Indices for the test set.
208
+ """
209
+
210
+ N_test = int(jnp.floor(test_size * N))
211
+
212
+ if fixed_test_idx is None:
213
+ if shuffle:
214
+ perm = jax.random.permutation(key, N)
215
+ else:
216
+ perm = jnp.arange(N)
217
+ test_idx = perm[:N_test]
218
+ else:
219
+ test_idx = fixed_test_idx
220
+
221
+ train_idx = jnp.setdiff1d(jnp.arange(N), test_idx)
222
+ return train_idx, test_idx
223
+
224
+
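A small usage sketch of the index-based split, reusing a persistent test set across rounds (the data array X is hypothetical):

import jax
import jax.numpy as jnp

N = 100
key = jax.random.PRNGKey(0)

# First round: draw a shuffled split and keep the test indices.
train_idx, test_idx = train_test_split_indices_jax(N, test_size=0.3, shuffle=True, key=key)

# Later rounds: reuse the same held-out indices.
train_idx2, _ = train_test_split_indices_jax(N, test_size=0.3, fixed_test_idx=test_idx)

# X_train, X_test = X[train_idx], X[test_idx]   # hypothetical data array X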
185
225
  @nnx.jit
186
226
  def l2_loss(model, alpha):
187
227
  """
@@ -205,7 +245,7 @@ def l2_loss(model, alpha):
205
245
  return alpha * sum((param ** 2).sum() for param in params)
206
246
 
207
247
  @nnx.jit
208
- def loss_fn(model, batch, l2_reg=1e-7):
248
+ def loss_fn(model, batch, l2_reg=1e-5):
209
249
  """
210
250
  Compute the total loss, which is the sum of the data loss and L2 regularization.
211
251
 
@@ -304,6 +344,100 @@ def pred_step(model, x_batch):
304
344
  logits = model(x_batch)
305
345
  return logits
306
346
 
347
+ class Phi(nnx.Module):
348
+ """
349
+ Neural network module for the Phi network in a Deep Set architecture.
350
+ """
351
+ def __init__(self, Nsize, n_cols, *, rngs):
352
+ self.linear1 = nnx.Linear(n_cols, Nsize, rngs=rngs) #+n_params
353
+ self.linear2 = nnx.Linear(Nsize, Nsize, rngs=rngs)
354
+ self.linear3 = nnx.Linear(Nsize, Nsize, rngs=rngs)
355
+
356
+ def __call__(self, data):
357
+ h = data
358
+
359
+ h = nnx.relu(self.linear1(h))
360
+ h = nnx.relu(self.linear2(h))
361
+ h = nnx.relu(self.linear3(h))
362
+ return h
363
+
364
+
365
+ class Rho(nnx.Module):
366
+ """
367
+ Neural network module for the Rho network in a Deep Set architecture
368
+ that combines the pooled set features with the parameter vector theta.
369
+ """
370
+ def __init__(self, Nsize_p, Nsize_r, N_size_params, *, rngs):
371
+ self.linear1 = nnx.Linear(Nsize_p + N_size_params, Nsize_r, rngs=rngs) #
372
+ self.linear2 = nnx.Linear(Nsize_r, Nsize_r, rngs=rngs)
373
+ self.linear3 = nnx.Linear(Nsize_r, 1, rngs=rngs)
374
+
375
+ def __call__(self, dropout, pooled_features, params):
376
+ # Concatenate pooled features and embedding
377
+ x = jnp.concatenate([pooled_features, params], axis=-1)
378
+
379
+ x = nnx.relu(self.linear1(x))
380
+ x = dropout(x)
381
+
382
+ x = nnx.relu(self.linear2(x)) #leaky_relu
383
+ x = dropout(x)
384
+
385
+ return self.linear3(x)
386
+
387
+
388
+ class DeepSetClassifier(nnx.Module):
389
+ """
390
+ Deep Set Classifier model combining Phi and Rho networks.
391
+ """
392
+ def __init__(self, dropout_rate, Nsize_p, Nsize_r,
393
+ n_cols, n_params, *, rngs):
394
+
395
+ self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
396
+ self.n_cols = n_cols
397
+ self.n_params = n_params
398
+
399
+ self.phi = Phi(Nsize_p, n_cols, rngs=rngs)
400
+ self.rho = Rho(Nsize_p, Nsize_r, n_params, rngs=rngs)
401
+
402
+ def __call__(self, input_data):
403
+ # ----------------------------------------------------
404
+ # Accept both shape (N, D) and (D,) without failing
405
+ # ----------------------------------------------------
406
+ if input_data.ndim == 1:
407
+ input_data = input_data[None, :]
408
+
409
+ N = input_data.shape[0]
410
+ input_dim = input_data.shape[1]
411
+
412
+ # Compute M first from input size
413
+ # Total input columns = M*n_cols + n_params + M (mask)
414
+ M = (input_dim - self.n_params) // (self.n_cols + 1)
415
+
416
+ # Reshape data columns
417
+ data = input_data[:, :M*self.n_cols].reshape(N, M, self.n_cols)
418
+
419
+ # Slice mask (the M columns just before the parameter columns)
420
+ mask = input_data[:, -M-self.n_params:-self.n_params] # shape (N, M)
421
+
422
+ # Parameters
423
+ theta = input_data[:, -self.n_params:] # shape (N, n_params)
424
+
425
+ # Apply Phi
426
+ h = self.phi(data)
427
+
428
+ # Apply mask
429
+ h_masked = h * mask[..., None]
430
+
431
+ # Pool (masked average)
432
+ mask_sum = jnp.sum(mask, axis=1, keepdims=True)
433
+ mask_sum = jnp.where(mask_sum == 0, 1.0, mask_sum)
434
+ pooled = jnp.sum(h_masked, axis=1) / mask_sum
435
+
436
+ # pooled_N = jnp.concatenate([pooled, mask_sum], axis=-1)
437
+
438
+ # Apply Rho
439
+ return self.rho(self.dropout, pooled, theta)
440
+
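A minimal sketch of how an input batch might be packed and fed to the classifier, assuming flax.nnx is available; the sizes and the [flattened data | mask | theta] layout follow the slicing above, and all concrete numbers are placeholders:

import jax
import jax.numpy as jnp
from flax import nnx

M, n_cols, n_params = 50, 3, 2        # objects per set, features per object, parameters
batch = 4

model = DeepSetClassifier(dropout_rate=0.1, Nsize_p=64, Nsize_r=64,
                          n_cols=n_cols, n_params=n_params, rngs=nnx.Rngs(0))

key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (batch, M, n_cols))
mask = jnp.ones((batch, M))           # 1 = real object, 0 = padding
theta = jnp.zeros((batch, n_params))

# Pack each example as [flattened data | mask | theta].
x = jnp.concatenate([data.reshape(batch, -1), mask, theta], axis=-1)

model.eval()                          # deterministic dropout for this sketch
logits = model(x)                     # shape (batch, 1)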
307
441
  def train_loop(model,
308
442
  optimizer,
309
443
  train_data,
@@ -317,6 +451,10 @@ def train_loop(model,
317
451
  metrics_history,
318
452
  M,
319
453
  N,
454
+ cpu,
455
+ gpu,
456
+ group_id,
457
+ group_params,
320
458
  plot_flag=False):
321
459
  """
322
460
  Train loop with early stopping and optional plotting.
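The new cpu and gpu arguments are device handles consumed by jax.device_put inside the loop; one hedged way such handles might be obtained before calling train_loop (not shown in the package) is:

import jax

cpu = jax.devices("cpu")[0]
try:
    gpu = jax.devices("gpu")[0]        # first accelerator, if a GPU backend is present
except RuntimeError:
    gpu = cpu                          # fall back to CPU-only execution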
@@ -325,66 +463,63 @@ def train_loop(model,
325
463
  # Initialise stopping criteria
326
464
  best_train_loss = jnp.inf
327
465
  best_test_loss = jnp.inf
466
+ best_train_accuracy = 0.0
467
+ best_test_accuracy = 0.0
328
468
  strikes = 0
329
469
 
330
470
  model.train()
331
471
 
332
472
  for epoch in range(epochs):
333
- # Shuffle the training data using JAX.
334
- # key, subkey = jax.random.split(key)
335
- # perm = jax.random.permutation(subkey, len(train_data))
336
- # train_data = train_data[perm]
337
- # train_labels = train_labels[perm]
338
- # del perm
339
473
 
340
474
  epoch_train_loss = 0
341
- epoch_train_correct = 0
342
- epoch_train_total = 0
475
+ epoch_train_accuracy = 0
343
476
 
344
477
  for i in range(0, len(train_data), batch_size):
345
478
  # Get the current batch of data and labels
346
- batch_data = train_data[i:i+batch_size]
347
- batch_labels = train_labels[i:i+batch_size]
479
+ batch_data = jax.device_put(train_data[i:i+batch_size], gpu)
480
+ batch_labels = jax.device_put(train_labels[i:i+batch_size], gpu)
348
481
 
349
482
  # Perform a training step
350
483
  loss, _ = loss_fn(model, (batch_data, batch_labels))
351
484
  accuracy = accuracy_fn(model, (batch_data, batch_labels))
352
485
  epoch_train_loss += loss
353
486
  # Multiply batch accuracy by batch size to get number of correct predictions
354
- epoch_train_correct += accuracy * len(batch_data)
355
- epoch_train_total += len(batch_data)
487
+ epoch_train_accuracy += accuracy * len(batch_data)
356
488
  train_step(model, optimizer, (batch_data, batch_labels))
357
489
 
358
490
  # Log the training metrics.
359
491
  current_train_loss = epoch_train_loss / (len(train_data) / batch_size)
492
+ current_train_accuracy = epoch_train_accuracy / len(train_data)
360
493
  metrics_history['train_loss'].append(current_train_loss)
361
494
  # Compute overall epoch accuracy
362
- metrics_history['train_accuracy'].append(epoch_train_correct / epoch_train_total)
495
+ metrics_history['train_accuracy'].append(current_train_accuracy)
363
496
 
364
497
  epoch_test_loss = 0
365
- epoch_test_correct = 0
366
- epoch_test_total = 0
498
+ epoch_test_accuracy = 0
367
499
 
368
500
  # Compute the metrics on the test set using the same batching as training
369
501
  for i in range(0, len(test_data), batch_size):
370
- batch_data = test_data[i:i+batch_size]
371
- batch_labels = test_labels[i:i+batch_size]
502
+ batch_data = jax.device_put(test_data[i:i+batch_size], gpu)
503
+ batch_labels = jax.device_put(test_labels[i:i+batch_size], gpu)
372
504
 
373
505
  loss, _ = loss_fn(model, (batch_data, batch_labels))
374
506
  accuracy = accuracy_fn(model, (batch_data, batch_labels))
375
507
  epoch_test_loss += loss
376
- epoch_test_correct += accuracy * len(batch_data)
377
- epoch_test_total += len(batch_data)
508
+ epoch_test_accuracy += accuracy * len(batch_data)
378
509
 
379
510
  # Log the test metrics.
380
511
  current_test_loss = epoch_test_loss / (len(test_data) / batch_size)
512
+ current_test_accuracy = epoch_test_accuracy / len(test_data)
381
513
  metrics_history['test_loss'].append(current_test_loss)
382
- metrics_history['test_accuracy'].append(epoch_test_correct / epoch_test_total)
514
+ metrics_history['test_accuracy'].append(current_test_accuracy)
383
515
 
384
516
  # Early Stopping Check
385
517
  if current_test_loss < best_test_loss:
386
518
  best_test_loss = current_test_loss # Update best test loss
387
519
  strikes = 0
520
+ # elif current_test_accuracy > best_test_accuracy:
521
+ # best_test_accuracy = current_test_accuracy # Update best test accuracy
522
+ # strikes = 0
388
523
  elif current_train_loss >= best_train_loss:
389
524
  strikes = 0
390
525
  elif current_test_loss > best_test_loss and current_train_loss < best_train_loss:
@@ -400,6 +535,8 @@ def train_loop(model,
400
535
  if plot_flag and epoch % 1 == 0:
401
536
  clear_output(wait=True)
402
537
 
538
+ print(f"=== Training model for group {group_id}: {group_params} ===")
539
+
403
540
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
404
541
 
405
542
  # Loss subplot
@@ -417,40 +554,10 @@ def train_loop(model,
417
554
 
418
555
  plt.show()
419
556
 
420
- return model, metrics_history, key
421
-
422
-
423
- # def save_nn(model, path, model_config):
424
- # """
425
- # Save a neural network model to a checkpoint.
557
+ if epoch == epochs-1:
558
+ print(f"\n Reached maximum epochs: {epochs} \n")
426
559
 
427
- # Parameters
428
- # ----------
429
- # model : nnx.Module
430
- # The model to save.
431
- # path : str
432
- # Path to the checkpoint directory.
433
- # model_config : dict
434
- # Configuration dictionary for the model.
435
- # """
436
- # ckpt_dir = os.path.abspath(path)
437
- # ckpt_dir = ocp.test_utils.erase_and_create_empty(ckpt_dir)
438
-
439
- # # Split the model into GraphDef (structure) and State (parameters + buffers)
440
- # _, _, _, state = nnx.split(model, nnx.RngKey, nnx.RngCount, ...)
441
-
442
- # # Display for debugging (optional)
443
- # # nnx.display(state)
444
-
445
- # # Initialize the checkpointer
446
- # checkpointer = ocp.StandardCheckpointer()
447
-
448
- # # Save State (parameters & non-trainable variables)
449
- # checkpointer.save(ckpt_dir / 'state', state)
450
-
451
- # # Save model configuration for later loading
452
- # with open(ckpt_dir / 'config.json', 'w') as f:
453
- # json.dump(model_config, f)
560
+ return model, metrics_history, key
454
561
 
455
562
  def save_autoregressive_nn(models_per_group, path, model_config):
456
563
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ximinf
3
- Version: 0.0.8
3
+ Version: 0.0.16
4
4
  Summary: Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae.
5
5
  Author-email: Adam Trigui <a.trigui@ip2i.in2p3.fr>
6
6
  License: GPL-3.0-or-later
@@ -1,453 +0,0 @@
1
- # Import libraries
2
- import jax
3
- import jax.numpy as jnp
4
- import blackjax
5
- from functools import partial
6
- from tqdm.notebook import tqdm
7
-
8
- def distance(theta1, theta2):
9
- """
10
- Compute the Euclidean distance between two points in NDIM space.
11
-
12
- Parameters
13
- ----------
14
- theta1 : array-like
15
- First point in NDIM-dimensional space.
16
- theta2 : array-like
17
- Second point in NDIM-dimensional space.
18
-
19
- Returns
20
- -------
21
- float
22
- The Euclidean distance between `theta1` and `theta2`.
23
- """
24
- diff = theta1 - theta2
25
- return jnp.linalg.norm(diff)
26
-
27
- def log_prior(theta, bounds):
28
- """
29
- Compute the log-prior probability for the parameter `theta`,
30
- assuming uniform prior within given bounds.
31
-
32
- Parameters
33
- ----------
34
- theta : array-like
35
- The parameter values for which the prior is to be calculated.
36
- bounds : jnp.ndarray, optional
37
- The bounds on each parameter (default is the global `BOUNDS`).
38
-
39
- Returns
40
- -------
41
- float
42
- The log-prior of `theta`, or negative infinity if `theta` is out of bounds.
43
- """
44
-
45
- in_bounds = jnp.all((theta >= bounds[:, 0]) & (theta <= bounds[:, 1]))
46
- return jnp.where(in_bounds, 0.0, -jnp.inf)
47
-
48
- # def log_prob_fn(theta, model, xy_noise, bounds):
49
- # """
50
- # Compute the log-probability for the parameter `theta` using a
51
- # log-prior and the log-likelihood from the neural likelihood ratio approximation.
52
-
53
- # Parameters
54
- # ----------
55
- # theta : array-like
56
- # The parameter values for which the log-probability is computed.
57
- # model : callable
58
- # A function that takes `theta` and produces model logits for computing the likelihood.
59
- # xy_noise : array-like
60
- # Input data with added noise for evaluating the likelihood.
61
-
62
- # Returns
63
- # -------
64
- # float
65
- # The log-probability, which is the sum of the log-prior and the log-likelihood.
66
- # """
67
-
68
- # lp = log_prior(theta, bounds)
69
- # lp = jnp.where(jnp.isfinite(lp), lp, -jnp.inf)
70
- # xy_flat = xy_noise.squeeze()
71
- # inp = jnp.concatenate([xy_flat, theta])[None, :]
72
- # logits = model(inp)
73
- # p = jax.nn.sigmoid(logits).squeeze()
74
- # p = jnp.clip(p, 1e-6, 1 - 1e-6)
75
- # log_like = jnp.log(p) - jnp.log1p(-p)
76
- # return lp + log_like
77
-
78
- def sample_reference_point(rng_key, bounds):
79
- """
80
- Sample a reference point within the given bounds uniformly.
81
-
82
- Parameters
83
- ----------
84
- rng_key : jax.random.PRNGKey
85
- The random key used for sampling.
86
- bounds : jnp.ndarray, optional
87
- The bounds for each parameter (default is the global `BOUNDS`).
88
-
89
- Returns
90
- -------
91
- tuple
92
- A tuple containing the updated `rng_key` and the sampled reference point `theta`.
93
- """
94
- ndim = bounds.shape[0]
95
- rng_key, subkey = jax.random.split(rng_key)
96
- u = jax.random.uniform(subkey, shape=(ndim,))
97
- span = bounds[:, 1] - bounds[:, 0]
98
- theta = bounds[:, 0] + u * span
99
- return rng_key, theta
100
-
101
- def inference_loop(rng_key, kernel, initial_state, num_samples):
102
- """
103
- Perform an inference loop using a Markov Chain Monte Carlo (MCMC) kernel.
104
-
105
- Parameters
106
- ----------
107
- rng_key : jax.random.PRNGKey
108
- The random key used for sampling.
109
- kernel : callable
110
- The MCMC kernel (e.g., NUTS) used for updating the state.
111
- initial_state : object
112
- The initial state of the MCMC chain.
113
- num_samples : int
114
- The number of samples to generate in the chain.
115
-
116
- Returns
117
- -------
118
- jax.numpy.ndarray
119
- The sampled states from the inference loop.
120
- """
121
-
122
- def one_step(state, rng):
123
- state, _ = kernel(rng, state)
124
- return state, state
125
- keys = jax.random.split(rng_key, num_samples)
126
- _, states = jax.lax.scan(one_step, initial_state, keys)
127
- return states
128
-
129
- def log_prob_fn_groups(theta, models_per_group, x, bounds, param_groups, param_names):
130
- """
131
- Compute the sum of log-likelihoods for all groups given full theta.
132
-
133
- Parameters
134
- ----------
135
- theta : jnp.ndarray, shape (n_params,)
136
- Full parameter vector.
137
- models_per_group : list
138
- List of DeepSetClassifier models, one per group.
139
- x : jnp.ndarray
140
- Input data sample (shape: (data_features + ... + n_params))
141
- bounds : jnp.ndarray
142
- Parameter bounds.
143
- param_groups : list
144
- List of parameter groups.
145
- param_names : list
146
- List of all parameter names in order.
147
-
148
- Returns
149
- -------
150
- float
151
- Sum of log-likelihoods over all groups.
152
- """
153
- log_lik_sum = 0.0
154
-
155
- n_params = len(param_names)
156
-
157
- # Use everything except the last n_params entries as data
158
- data_part = x[:-n_params].reshape(1, -1) # 2D
159
- # If mask is required, you can extract it similarly here
160
-
161
- for g, group in enumerate(param_groups):
162
- # Determine visible parameters for this group
163
- prev_groups = [
164
- p
165
- for i in range(g)
166
- for p in (param_groups[i] if isinstance(param_groups[i], list) else [param_groups[i]])
167
- ]
168
- group_list = [group] if isinstance(group, str) else group
169
- visible_param_names = prev_groups + group_list
170
-
171
- # Get visible theta values
172
- visible_idx = jnp.array([param_names.index(name) for name in visible_param_names])
173
- theta_visible = theta[visible_idx].reshape(1, -1) # make 2D
174
-
175
- # Concatenate data with visible parameters
176
- input_g = jnp.concatenate([data_part, theta_visible], axis=-1)
177
-
178
- # Forward pass through the model
179
- logits = models_per_group[g](input_g)
180
- p = jax.nn.sigmoid(logits)
181
-
182
- log_lik_sum += jnp.log(p) + jnp.log(1 - p)
183
-
184
- return jnp.squeeze(log_lik_sum) + log_prior(theta, bounds)
185
-
186
-
187
-
188
-
189
- @partial(jax.jit, static_argnums=(0, 1, 2))
190
- def sample_posterior(log_prob, n_warmup, n_samples, init_position, rng_key):
191
- warmup = blackjax.window_adaptation(blackjax.nuts, log_prob)
192
- rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3)
193
- (warmup_state, params), _ = warmup.run(warmup_key, init_position, num_steps=n_warmup)
194
- kernel = blackjax.nuts(log_prob, **params).step
195
- rng_key, sample_key = jax.random.split(rng_key)
196
- states = inference_loop(sample_key, kernel, warmup_state, n_samples)
197
- return rng_key, states.position
198
-
199
-
200
- # # ========== JIT‐compiled per‐sample step ==========
201
- # @partial(jax.jit, static_argnums=(3, 4, 5))
202
- # def one_sample_step(rng_key, xi, theta_star, n_warmup, n_samples, model, bounds):
203
- # """
204
- # Sample from the posterior distribution using Hamiltonian Monte Carlo (HMC)
205
- # with NUTS (No-U-Turn Sampler) for a given `log_prob`.
206
-
207
- # Parameters
208
- # ----------
209
- # log_prob : callable
210
- # The log-probability function for the model and parameters.
211
- # n_warmup : int
212
- # The number of warmup steps to adapt the sampler.
213
- # n_samples : int
214
- # The number of samples to generate after warmup.
215
- # init_position : array-like
216
- # The initial position for the chain (parameter values).
217
- # rng_key : jax.random.PRNGKey
218
- # The random key used for sampling.
219
-
220
- # Returns
221
- # -------
222
- # jax.numpy.ndarray
223
- # The sampled positions (parameters) from the posterior distribution.
224
- # """
225
-
226
- # # Draw a random reference
227
- # rng_key, theta_r0 = sample_reference_point(rng_key, bounds)
228
-
229
- # def log_post(theta):
230
- # return log_prob_fn(theta, model, xi, bounds)
231
-
232
- # # Run MCMC
233
- # rng_key, posterior = sample_posterior(log_post, n_warmup, n_samples, theta_star, rng_key)
234
-
235
- # # Compute e-c-p distances
236
- # d_star = distance(theta_star, theta_r0)
237
- # d_samples = jnp.linalg.norm(posterior - theta_r0, axis=1)
238
- # f_val = jnp.mean(d_samples < d_star)
239
-
240
- # return rng_key, f_val, posterior
241
-
242
- def one_sample_step_groups(rng_key, xi, theta_star, n_warmup, n_samples,
243
- models_per_group, bounds, param_groups, param_names):
244
- """
245
- Sample from posterior using sum of log-likelihoods over all groups.
246
- """
247
- rng_key, theta_r0 = sample_reference_point(rng_key, bounds)
248
-
249
- def log_post(theta):
250
- return log_prob_fn_groups(theta, models_per_group, xi, bounds, param_groups, param_names)
251
-
252
- rng_key, posterior = sample_posterior(log_post, n_warmup, n_samples, theta_star, rng_key)
253
- d_star = distance(theta_star, theta_r0)
254
- d_samples = jnp.linalg.norm(posterior - theta_r0, axis=1)
255
- f_val = jnp.mean(d_samples < d_star)
256
-
257
- return rng_key, f_val, posterior
258
-
259
-
260
- # def batched_one_sample_step(rng_keys, x_batch, theta_star_batch, n_warmup, n_samples, model, bounds):
261
- # """
262
- # Vectorized wrapper over `one_sample_step` using jax.vmap.
263
-
264
- # Parameters
265
- # ----------
266
- # rng_keys : jax.random.PRNGKey
267
- # Batch of random keys.
268
- # x_batch : array-like
269
- # Batch of input data.
270
- # theta_star_batch : array-like
271
- # Batch of true parameter values.
272
- # n_warmup : int
273
- # Number of warmup steps.
274
- # n_samples : int
275
- # Number of samples.
276
- # model : callable
277
- # The model function.
278
- # bounds : array-like
279
- # Parameter bounds.
280
-
281
- # Returns
282
- # -------
283
- # tuple
284
- # (rng_keys, f_vals, posterior_samples)
285
- # """
286
- # return jax.vmap(
287
- # lambda rng, x, theta: one_sample_step(rng, x[None, :], theta, n_warmup, n_samples, model, bounds),
288
- # in_axes=(0, 0, 0)
289
- # )(rng_keys, x_batch, theta_star_batch)
290
-
291
- def batched_one_sample_step_groups(rng_keys, x_batch, theta_star_batch,
292
- n_warmup, n_samples, models_per_group, bounds, param_groups, param_names):
293
- return jax.vmap(
294
- lambda rng, x, theta: one_sample_step_groups(rng, x[None, :], theta, n_warmup, n_samples,
295
- models_per_group, bounds, param_groups, param_names),
296
- in_axes=(0, 0, 0)
297
- )(rng_keys, x_batch, theta_star_batch)
298
-
299
-
300
-
301
- # def compute_ecp_tarp_jitted(model, x_list, theta_star_list, alpha_list, n_warmup, n_samples, rng_key, bounds):
302
- # """
303
- # Compute expected coverage probabilities (ECP) using vectorized sampling.
304
-
305
- # Parameters
306
- # ----------
307
- # model : callable
308
- # The model function.
309
- # x_list : array-like
310
- # List of input data.
311
- # theta_star_list : array-like
312
- # List of true parameter values.
313
- # alpha_list : list of float
314
- # List of alpha values for ECP computation.
315
- # n_warmup : int
316
- # Number of warmup steps.
317
- # n_samples : int
318
- # Number of samples.
319
- # rng_key : jax.random.PRNGKey
320
- # Random key.
321
- # bounds : array-like
322
- # Parameter bounds.
323
-
324
- # Returns
325
- # -------
326
- # tuple
327
- # (ecp_vals, f_vals, posterior_uns, rng_key)
328
- # """
329
- # N = x_list.shape[0]
330
- # rng_key, split_key = jax.random.split(rng_key)
331
- # rng_keys = jax.random.split(split_key, N)
332
-
333
- # # Batched MCMC and distance evaluation
334
- # _, f_vals, posterior_uns = batched_one_sample_step_groups(
335
- # rng_keys, x_list, theta_star_list, n_warmup, n_samples, model, bounds
336
- # )
337
-
338
- # # Compute ECP values for each alpha
339
- # ecp_vals = [jnp.mean(f_vals < (1 - alpha)) for alpha in alpha_list]
340
-
341
- # return ecp_vals, f_vals, posterior_uns, rng_key
342
-
343
-
344
- # def compute_ecp_tarp_jitted_with_progress(model, x_list, theta_star_list, alpha_list,
345
- # n_warmup, n_samples, rng_key, bounds,
346
- # batch_size=20):
347
- # """
348
- # Compute ECP using JITed MCMC in batches with progress reporting via tqdm.
349
-
350
- # Parameters
351
- # ----------
352
- # model : callable
353
- # The model function.
354
- # x_list : array-like
355
- # List of input data.
356
- # theta_star_list : array-like
357
- # List of true parameter values.
358
- # alpha_list : list of float
359
- # List of alpha values for ECP computation.
360
- # n_warmup : int
361
- # Number of warmup steps.
362
- # n_samples : int
363
- # Number of samples.
364
- # rng_key : jax.random.PRNGKey
365
- # Random key.
366
- # bounds : array-like
367
- # Parameter bounds.
368
- # batch_size : int, optional
369
- # Batch size for processing (default is 20).
370
-
371
- # Returns
372
- # -------
373
- # tuple
374
- # (ecp_vals, posterior_uns, rng_key)
375
- # """
376
- # N = x_list.shape[0]
377
-
378
- # posterior_list = []
379
- # f_vals_list = []
380
-
381
- # for start in tqdm(range(0, N, batch_size), desc="Computing ECP batches"):
382
- # end = min(start + batch_size, N)
383
- # x_batch = x_list[start:end]
384
- # theta_batch = theta_star_list[start:end]
385
-
386
- # # Compute ECP and posterior for batch
387
- # _, f_vals_batch, posterior_batch, rng_key = compute_ecp_tarp_jitted(
388
- # model, x_batch, theta_batch, alpha_list,
389
- # n_warmup, n_samples, rng_key, bounds
390
- # )
391
-
392
- # posterior_list.append(posterior_batch)
393
- # f_vals_list.append(f_vals_batch)
394
-
395
- # # Concatenate across batches
396
- # posterior_uns = jnp.concatenate(posterior_list, axis=0)
397
- # f_vals_all = jnp.concatenate(f_vals_list, axis=0)
398
-
399
- # # Compute final ECP for each alpha
400
- # ecp_vals = [jnp.mean(f_vals_all < (1 - alpha)) for alpha in alpha_list]
401
-
402
- # return ecp_vals, posterior_uns, rng_key
403
-
404
- def compute_ecp_tarp_jitted_groups(models_per_group, x_list, theta_star_list, alpha_list,
405
- n_warmup, n_samples, rng_key, bounds,
406
- param_groups, param_names):
407
- """
408
- Batched ECP computation using multiple group models.
409
- """
410
- N = x_list.shape[0]
411
- rng_key, split_key = jax.random.split(rng_key)
412
- rng_keys = jax.random.split(split_key, N)
413
-
414
- # Batched MCMC and distance evaluation
415
- _, f_vals, posterior_uns = batched_one_sample_step_groups(
416
- rng_keys, x_list, theta_star_list, n_warmup, n_samples,
417
- models_per_group, bounds, param_groups, param_names
418
- )
419
-
420
- # Compute ECP values for each alpha
421
- ecp_vals = [jnp.mean(f_vals < (1 - alpha)) for alpha in alpha_list]
422
-
423
- return ecp_vals, f_vals, posterior_uns, rng_key
424
-
425
- def compute_ecp_tarp_jitted_with_progress_groups(models_per_group, x_list, theta_star_list, alpha_list,
426
- n_warmup, n_samples, rng_key, bounds,
427
- param_groups, param_names, batch_size=20):
428
- N = x_list.shape[0]
429
-
430
- posterior_list = []
431
- f_vals_list = []
432
-
433
- for start in tqdm(range(0, N, batch_size), desc="Computing ECP batches"):
434
- end = min(start + batch_size, N)
435
- x_batch = x_list[start:end]
436
- theta_batch = theta_star_list[start:end]
437
-
438
- # Compute ECP and posterior for batch
439
- _, f_vals_batch, posterior_batch, rng_key = compute_ecp_tarp_jitted_groups(
440
- models_per_group, x_batch, theta_batch, alpha_list,
441
- n_warmup, n_samples, rng_key, bounds,
442
- param_groups, param_names
443
- )
444
-
445
- posterior_list.append(posterior_batch)
446
- f_vals_list.append(f_vals_batch)
447
-
448
- posterior_uns = jnp.concatenate(posterior_list, axis=0)
449
- f_vals_all = jnp.concatenate(f_vals_list, axis=0)
450
-
451
- ecp_vals = [jnp.mean(f_vals_all < (1 - alpha)) for alpha in alpha_list]
452
-
453
- return ecp_vals, posterior_uns, rng_key