ximinf 0.0.2__tar.gz → 0.0.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,11 @@
  Metadata-Version: 2.4
  Name: ximinf
- Version: 0.0.2
+ Version: 0.0.16
  Summary: Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae.
  Author-email: Adam Trigui <a.trigui@ip2i.in2p3.fr>
  License: GPL-3.0-or-later
  Project-URL: Homepage, https://github.com/a-trigui/ximinf
+ Project-URL: Documentation, https://ximinf.readthedocs.io
  Keywords: cosmology,supernovae,simulation based inference
  Classifier: Programming Language :: Python :: 3
  Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "ximinf"
- version = "0.0.2"
+ version = "0.0.16"
  description = "Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae. "
  readme = "README.md"
  requires-python = ">=3.10"
@@ -35,7 +35,7 @@ classifiers = [

  [project.urls]
  Homepage = "https://github.com/a-trigui/ximinf"
- # Documentation = "https://my_package.readthedocs.io"
+ Documentation = "https://ximinf.readthedocs.io"

  [project.optional-dependencies]
  notebooks = ["jupyter", "matplotlib"]
@@ -0,0 +1,7 @@
+ # src/ximinf/__init__.py
+
+ # from .generate_sim import *
+ # from .nn_inference import *
+ # from .nn_train import *
+ # from .nn_test import *
+ # from .selection_effects import *
@@ -0,0 +1,117 @@
+ # Simulation libraries
+ import skysurvey
+ import numpy as np
+ from pyDOE import lhs  # LHS sampler
+
+ def scan_params(ranges, N, dtype=np.float32):
+     """
+     Generate sampled parameter sets using Latin Hypercube Sampling (LHS).
+
+     Parameters
+     ----------
+     ranges : dict
+         Mapping parameter names to (min, max) tuples.
+     N : int
+         Number of samples.
+     dtype : data-type, optional
+         Numeric type for the sampled arrays (default is np.float32).
+
+     Returns
+     -------
+     params_dict : dict
+         Dictionary of parameter arrays of shape (N,).
+     """
+     param_names = list(ranges.keys())
+     n_params = len(param_names)
+
+     # LHS unit samples in [0,1]
+     unit_samples = lhs(n_params, samples=N)
+
+     # Scale unit samples to parameter ranges
+     params_dict = {}
+     for i, p in enumerate(param_names):
+         low, high = ranges[p]
+         params_dict[p] = (unit_samples[:, i] * (high - low) + low).astype(dtype)
+
+     return params_dict
+
+ def simulate_one(params_dict, z_max, M, cols, N=None, i=None):
+     """
+     Simulate a single dataset of SNe Ia.
+
+     Parameters
+     ----------
+     params_dict : dict
+         Dictionary of model parameters (alpha, beta, mabs, gamma, sigma_int, etc.).
+     z_max : float
+         Maximum redshift.
+     M : int
+         Number of SNe to simulate.
+     cols : list of str
+         List of columns to include in the output.
+     N : int, optional
+         Total number of simulations (for progress printing).
+     i : int, optional
+         Current simulation index (for progress printing).
+
+     Returns
+     -------
+     data_dict : dict
+         Dictionary of lists (one per column) containing the simulated data.
+     """
+     import ztfidr.simulation as sim
+     import skysurvey_sniapop
+
+     # Print progress
+     if N is not None and i is not None:
+         if (i+1) % max(1, N//10) == 0 or i == N-1:
+             print(f"Simulation {i+1}/{N}", end="\r", flush=True)
+
+     # Define default parameters including sigma_int
+     default_params = {
+         "alpha": 0.0,
+         "beta": 0.0,
+         "mabs": -19.3,
+         # "gamma": 0.0,
+         "sigma_int": 0.0,  # default intrinsic scatter
+     }
+
+     # Merge defaults with provided params (params_dict takes priority)
+     params = {**default_params, **params_dict}
+
+     # Ensure all are floats
+     alpha_ = float(params["alpha"])
+     beta_ = float(params["beta"])
+     mabs_ = float(params["mabs"])
+     # gamma_ = float(params["gamma"])
+     sigma_int_ = float(params["sigma_int"])
+
+     # brokenalpha_model = skysurvey_sniapop.brokenalpha_model
+
+     # Generate SNe sample
+     snia = skysurvey.SNeIa.from_draw(
+         size=M,
+         zmax=z_max,
+         # model=brokenalpha_model,
+         magabs={
+             "mabs": mabs_,
+             "sigmaint": sigma_int_,
+             "alpha": alpha_,
+             "beta": beta_,
+         }
+     )
+
+     # Apply noise
+     # errormodel = sim.noise_model
+     # errormodel["localcolor"]["kwargs"]["a"] = 2
+     # errormodel["localcolor"]["kwargs"]["loc"] = 0.005
+     # errormodel["localcolor"]["kwargs"]["scale"] = 0.05
+     # noisy_snia = snia.apply_gaussian_noise(errormodel)
+
+     # df = noisy_snia.data
+     df = snia.data
+
+     # Collect requested columns as lists
+     data_dict = {col: list(df[col]) for col in cols if col in df}
+
+     return data_dict
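The two functions added above are meant to be composed: scan_params draws N parameter sets with Latin Hypercube Sampling, and simulate_one generates one SN Ia dataset per draw. A minimal usage sketch follows; the parameter ranges, sample sizes, and column names are illustrative only, not values shipped with the package.

ranges = {"alpha": (-0.3, 0.0), "beta": (2.0, 4.0), "mabs": (-19.5, -19.1), "sigma_int": (0.0, 0.2)}
params = scan_params(ranges, N=100)

datasets = [
    simulate_one(
        {name: values[i] for name, values in params.items()},  # i-th LHS draw
        z_max=0.1,
        M=500,
        cols=["z", "magobs", "x1", "c"],  # illustrative column names
        N=100,
        i=i,
    )
    for i in range(100)
]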
@@ -0,0 +1,135 @@
+ # Standard
+ import os
+ import json
+
+ # Jax
+ from flax import nnx
+
+ # Checkpointing
+ import orbax.checkpoint as ocp  # Checkpointing library
+ ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
+ import pathlib  # File path handling library
+
+ # Modules
+ import ximinf.nn_train as nntr
+
+ # def load_nn(path):
+ #     """
+ #     Load a neural network model from a checkpoint.
+
+ #     Parameters
+ #     ----------
+ #     path : str
+ #         Path to the checkpoint directory.
+
+ #     Returns
+ #     -------
+ #     model : nnx.Module
+ #         The loaded neural network model.
+
+ #     Raises
+ #     ------
+ #     ValueError
+ #         If the checkpoint directory or config file does not exist.
+ #     """
+ #     # Define the checkpoint directory
+ #     ckpt_dir = os.path.abspath(path)
+ #     ckpt_dir = pathlib.Path(ckpt_dir).resolve()
+
+ #     # Ensure the folder is removed before saving
+ #     if ckpt_dir.exists()==False:
+ #         # Make an error
+ #         raise ValueError(f"Checkpoint directory {ckpt_dir} does not exist. Please check the path.")
+
+ #     # Load model configuration
+ #     config_path = ckpt_dir / 'config.json'
+ #     if not config_path.exists():
+ #         raise ValueError("Model config file not found in checkpoint directory.")
+
+ #     with open(config_path, 'r') as f:
+ #         model_config = json.load(f)
+
+ #     Nsize_p = model_config['Nsize_p']
+ #     Nsize_r = model_config['Nsize_r']
+ #     n_cols = model_config['n_cols']
+ #     n_params = model_config['n_params']
+ #     N_size_embed = model_config['N_size_embed']
+
+ #     # 1. Re-create the checkpointer
+ #     checkpointer = ocp.StandardCheckpointer()
+
+ #     # Split the model into GraphDef (structure) and State (parameters + buffers)
+ #     abstract_model = nnx.eval_shape(lambda: nntr.DeepSetClassifier(0.0, Nsize_p, Nsize_r, N_size_embed, n_cols, n_params, rngs=nnx.Rngs(0)))
+ #     abs_graphdef, abs_rngkey, abs_rngcount, _ = nnx.split(abstract_model, nnx.RngKey, nnx.RngCount, ...)
+
+ #     # 3. Restore
+ #     state_restored = checkpointer.restore(ckpt_dir / 'state')
+ #     print('NNX State restored: ')
+
+ #     model = nnx.merge(abs_graphdef, abs_rngkey, abs_rngcount, state_restored)
+
+ #     nnx.display(model)
+
+ #     return model
+
+ def load_autoregressive_nn(path):
+     """
+     Load an autoregressive stack of NNX models.
+
+     Parameters
+     ----------
+     path : str
+         Checkpoint directory.
+
+     Returns
+     -------
+     models_per_group : list[nnx.Module]
+         Reconstructed models, one per group.
+     model_config : dict
+         Loaded configuration dictionary.
+     """
+     ckpt_dir = pathlib.Path(path).resolve()
+     if not ckpt_dir.exists():
+         raise ValueError(f"Checkpoint directory {ckpt_dir} does not exist.")
+
+     config_path = ckpt_dir / "config.json"
+     if not config_path.exists():
+         raise ValueError("Model config file not found.")
+
+     with open(config_path, "r") as f:
+         model_config = json.load(f)
+
+     shared = model_config["shared"]
+     group_configs = model_config["groups"]
+
+     checkpointer = ocp.StandardCheckpointer()
+     models_per_group = []
+
+     for gconf in group_configs:
+         n_params_visible = gconf["n_params_visible"]
+
+         # Recreate abstract model (shape-only)
+         abstract_model = nnx.eval_shape(
+             lambda: nntr.DeepSetClassifier(  # NOTE: check that DeepSetClassifier is defined in nntr and that this is the proper way to rebuild the model
+                 dropout_rate=0.0,
+                 Nsize_p=shared["Nsize_p"],
+                 Nsize_r=shared["Nsize_r"],
+                 n_cols=shared["n_cols"],
+                 n_params=n_params_visible,
+                 rngs=nnx.Rngs(0),
+             )
+         )
+
+         graphdef, rngkey, rngcount, _ = nnx.split(
+             abstract_model, nnx.RngKey, nnx.RngCount, ...
+         )
+
+         # Restore parameters
+         state = checkpointer.restore(
+             ckpt_dir / f"state_group_{gconf['group_id']}"
+         )
+
+         model = nnx.merge(graphdef, rngkey, rngcount, state)
+         models_per_group.append(model)
+
+     return models_per_group, model_config
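For reference, load_autoregressive_nn reads the layout written by save_autoregressive_nn (added in nn_train below): a config.json with a "shared" block plus one entry per group, and one Orbax state directory per group. A sketch of the implied layout, with illustrative values and a hypothetical path:

# <ckpt_dir>/config.json  (illustrative values)
# {
#   "shared": {"Nsize_p": 64, "Nsize_r": 64, "n_cols": 4},
#   "groups": [
#     {"group_id": 0, "n_params_visible": 1},
#     {"group_id": 1, "n_params_visible": 2}
#   ]
# }
# <ckpt_dir>/state_group_0/   Orbax state for group 0
# <ckpt_dir>/state_group_1/   Orbax state for group 1

models_per_group, model_config = load_autoregressive_nn("/path/to/ckpt")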
@@ -24,56 +24,43 @@ def distance(theta1, theta2):
      diff = theta1 - theta2
      return jnp.linalg.norm(diff)

- def log_prior(theta, bounds):
-     """
-     Compute the log-prior probability for the parameter `theta`,
-     assuming uniform prior within given bounds.
-
-     Parameters
-     ----------
-     theta : array-like
-         The parameter values for which the prior is to be calculated.
-     bounds : jnp.ndarray, optional
-         The bounds on each parameter (default is the global `BOUNDS`).
-
-     Returns
-     -------
-     float
-         The log-prior of `theta`, or negative infinity if `theta` is out of bounds.
-     """
+ # def log_prior(theta, bounds):
+ #     """
+ #     Compute the log-prior probability for the parameter `theta`,
+ #     assuming uniform prior within given bounds.
+
+ #     Parameters
+ #     ----------
+ #     theta : array-like
+ #         The parameter values for which the prior is to be calculated.
+ #     bounds : jnp.ndarray, optional
+ #         The bounds on each parameter (default is the global `BOUNDS`).
+
+ #     Returns
+ #     -------
+ #     float
+ #         The log-prior of `theta`, or negative infinity if `theta` is out of bounds.
+ #     """
+
+ #     in_bounds = jnp.all((theta >= bounds[:, 0]) & (theta <= bounds[:, 1]))
+ #     return jnp.where(in_bounds, 0.0, -jnp.inf)
+
+ def log_group_prior(theta, bounds, group_indices):
+     """
+     Log prior for a single parameter group.
+     Uniform within bounds, -inf otherwise.
+     """
+     theta_g = theta[group_indices]
+     bounds_g = bounds[group_indices]
+
+     in_bounds = jnp.all(
+         (theta_g >= bounds_g[:, 0]) &
+         (theta_g <= bounds_g[:, 1])
+     )

-     in_bounds = jnp.all((theta >= bounds[:, 0]) & (theta <= bounds[:, 1]))
      return jnp.where(in_bounds, 0.0, -jnp.inf)

- def log_prob_fn(theta, model, xy_noise, bounds):
-     """
-     Compute the log-probability for the parameter `theta` using a
-     log-prior and the log-likelihood from the neural likelihood ratio approximation.

-     Parameters
-     ----------
-     theta : array-like
-         The parameter values for which the log-probability is computed.
-     model : callable
-         A function that takes `theta` and produces model logits for computing the likelihood.
-     xy_noise : array-like
-         Input data with added noise for evaluating the likelihood.
-
-     Returns
-     -------
-     float
-         The log-probability, which is the sum of the log-prior and the log-likelihood.
-     """
-
-     lp = log_prior(theta, bounds)
-     lp = jnp.where(jnp.isfinite(lp), lp, -jnp.inf)
-     xy_flat = xy_noise.squeeze()
-     inp = jnp.concatenate([xy_flat, theta])[None, :]
-     logits = model(inp)
-     p = jax.nn.sigmoid(logits).squeeze()
-     p = jnp.clip(p, 1e-6, 1 - 1e-6)
-     log_like = jnp.log(p) - jnp.log1p(-p)
-     return lp + log_like

  def sample_reference_point(rng_key, bounds):
      """
@@ -126,6 +113,50 @@ def inference_loop(rng_key, kernel, initial_state, num_samples):
      _, states = jax.lax.scan(one_step, initial_state, keys)
      return states

+ def log_prob_fn_groups(theta, models_per_group, data, bounds,
+                        param_groups, global_param_names):
+
+     log_r_sum = 0.0
+     log_p_group_sum = 0.0
+
+     data = data.reshape(1, -1)
+
+     for g, group in enumerate(param_groups):
+
+         # --- parameter bookkeeping (unchanged) ---
+         prev_groups = [
+             p
+             for i in range(g)
+             for p in (param_groups[i] if isinstance(param_groups[i], list)
+                       else [param_groups[i]])
+         ]
+
+         group_list = [group] if isinstance(group, str) else group
+         visible_param_names = prev_groups + group_list
+
+         visible_idx = jnp.array(
+             [global_param_names.index(name) for name in visible_param_names]
+         )
+
+         theta_visible = theta[visible_idx].reshape(1, -1)
+         input_g = jnp.concatenate([data, theta_visible], axis=-1)
+
+         # --- ratio estimator ---
+         logits = models_per_group[g](input_g)
+         p = jax.nn.sigmoid(logits)
+         log_r_sum += jnp.log(p) - jnp.log1p(-p)
+
+         # --- marginal prior for this group ---
+         group_idx = jnp.array(
+             [global_param_names.index(name) for name in group_list]
+         )
+
+         log_p_group_sum += log_group_prior(theta, bounds, group_idx)
+
+     return jnp.squeeze(log_r_sum + log_p_group_sum)
+
+
+
  @partial(jax.jit, static_argnums=(0, 1, 2))
  def sample_posterior(log_prob, n_warmup, n_samples, init_position, rng_key):
      warmup = blackjax.window_adaptation(blackjax.nuts, log_prob)
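In log_prob_fn_groups above, the unnormalized log posterior is assembled autoregressively over parameter groups: classifier g sees the data together with the parameters of groups 1..g, its sigmoid output p_g is turned into a log likelihood ratio log r_g = log p_g - log(1 - p_g), and the per-group uniform priors are added, i.e. log p(theta | x) = sum_g [ log r_g(x, theta_{1..g}) + log p_g(theta_g) ] up to a constant.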
@@ -137,42 +168,17 @@ def sample_posterior(log_prob, n_warmup, n_samples, init_position, rng_key):
      return rng_key, states.position


- # ========== JIT-compiled per-sample step ==========
- @partial(jax.jit, static_argnums=(3, 4, 5))
- def one_sample_step(rng_key, xi, theta_star, n_warmup, n_samples, model, bounds):
+ def one_sample_step_groups(rng_key, xi, theta_star, n_warmup, n_samples,
+                            models_per_group, bounds, param_groups, param_names):
      """
-     Sample from the posterior distribution using Hamiltonian Monte Carlo (HMC)
-     with NUTS (No-U-Turn Sampler) for a given `log_prob`.
-
-     Parameters
-     ----------
-     log_prob : callable
-         The log-probability function for the model and parameters.
-     n_warmup : int
-         The number of warmup steps to adapt the sampler.
-     n_samples : int
-         The number of samples to generate after warmup.
-     init_position : array-like
-         The initial position for the chain (parameter values).
-     rng_key : jax.random.PRNGKey
-         The random key used for sampling.
-
-     Returns
-     -------
-     jax.numpy.ndarray
-         The sampled positions (parameters) from the posterior distribution.
+     Sample from posterior using sum of log-likelihoods over all groups.
      """
-
-     # Draw a random reference
      rng_key, theta_r0 = sample_reference_point(rng_key, bounds)

      def log_post(theta):
-         return log_prob_fn(theta, model, xi, bounds)
+         return log_prob_fn_groups(theta, models_per_group, xi, bounds, param_groups, param_names)

-     # Run MCMC
      rng_key, posterior = sample_posterior(log_post, n_warmup, n_samples, theta_star, rng_key)
-
-     # Compute e-c-p distances
      d_star = distance(theta_star, theta_r0)
      d_samples = jnp.linalg.norm(posterior - theta_r0, axis=1)
      f_val = jnp.mean(d_samples < d_star)
@@ -180,29 +186,28 @@ def one_sample_step(rng_key, xi, theta_star, n_warmup, n_samples, model, bounds)
      return rng_key, f_val, posterior


- def batched_one_sample_step(rng_keys, x_batch, theta_star_batch, n_warmup, n_samples, model, bounds):
-     """
-     Vectorized wrapper over `one_sample_step` using jax.vmap.
-     Returns proper f_vals for ECP computation.
-     """
+ def batched_one_sample_step_groups(rng_keys, x_batch, theta_star_batch,
+                                    n_warmup, n_samples, models_per_group, bounds, param_groups, param_names):
      return jax.vmap(
-         lambda rng, x, theta: one_sample_step(rng, x[None, :], theta, n_warmup, n_samples, model, bounds),
+         lambda rng, x, theta: one_sample_step_groups(rng, x[None, :], theta, n_warmup, n_samples,
+                                                      models_per_group, bounds, param_groups, param_names),
          in_axes=(0, 0, 0)
      )(rng_keys, x_batch, theta_star_batch)

-
- def compute_ecp_tarp_jitted(model, x_list, theta_star_list, alpha_list, n_warmup, n_samples, rng_key, bounds):
+ def compute_ecp_tarp_jitted_groups(models_per_group, x_list, theta_star_list, alpha_list,
+                                    n_warmup, n_samples, rng_key, bounds,
+                                    param_groups, param_names):
      """
-     Compute expected coverage probabilities (ECP) using vectorized sampling.
-     Returns proper f_vals for ECP computation.
+     Batched ECP computation using multiple group models.
      """
      N = x_list.shape[0]
      rng_key, split_key = jax.random.split(rng_key)
      rng_keys = jax.random.split(split_key, N)

      # Batched MCMC and distance evaluation
-     _, f_vals, posterior_uns = batched_one_sample_step(
-         rng_keys, x_list, theta_star_list, n_warmup, n_samples, model, bounds
+     _, f_vals, posterior_uns = batched_one_sample_step_groups(
+         rng_keys, x_list, theta_star_list, n_warmup, n_samples,
+         models_per_group, bounds, param_groups, param_names
      )

      # Compute ECP values for each alpha
@@ -210,14 +215,9 @@ def compute_ecp_tarp_jitted(model, x_list, theta_star_list, alpha_list, n_warmup

      return ecp_vals, f_vals, posterior_uns, rng_key

-
- def compute_ecp_tarp_jitted_with_progress(model, x_list, theta_star_list, alpha_list,
-                                           n_warmup, n_samples, rng_key, bounds,
-                                           batch_size=20):
-     """
-     Compute ECP using JITed MCMC in batches with progress reporting via tqdm.
-     Returns correct f_vals for all simulations.
-     """
+ def compute_ecp_tarp_jitted_with_progress_groups(models_per_group, x_list, theta_star_list, alpha_list,
+                                                  n_warmup, n_samples, rng_key, bounds,
+                                                  param_groups, param_names, batch_size=20):
      N = x_list.shape[0]

      posterior_list = []
@@ -229,19 +229,18 @@ def compute_ecp_tarp_jitted_with_progress(model, x_list, theta_star_list, alpha_
          theta_batch = theta_star_list[start:end]

          # Compute ECP and posterior for batch
-         _, f_vals_batch, posterior_batch, rng_key = compute_ecp_tarp_jitted(
-             model, x_batch, theta_batch, alpha_list,
-             n_warmup, n_samples, rng_key, bounds
+         _, f_vals_batch, posterior_batch, rng_key = compute_ecp_tarp_jitted_groups(
+             models_per_group, x_batch, theta_batch, alpha_list,
+             n_warmup, n_samples, rng_key, bounds,
+             param_groups, param_names
          )

          posterior_list.append(posterior_batch)
          f_vals_list.append(f_vals_batch)

-     # Concatenate across batches
      posterior_uns = jnp.concatenate(posterior_list, axis=0)
      f_vals_all = jnp.concatenate(f_vals_list, axis=0)

-     # Compute final ECP for each alpha
      ecp_vals = [jnp.mean(f_vals_all < (1 - alpha)) for alpha in alpha_list]

-     return ecp_vals, posterior_uns, rng_key
+     return ecp_vals, posterior_uns, rng_key
@@ -3,6 +3,8 @@ import os
  import json
  import numpy as np  # Numerical Python
  import scipy as sp
+ import matplotlib.pyplot as plt
+ from IPython.display import clear_output

  # JAX and Flax (new NNX API)
  import jax  # Automatic differentiation library
@@ -19,7 +21,7 @@ ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
  # Cosmology
  from astropy.cosmology import Planck18

- def rm_cosmo(z, magobs, magabs, ref_mag=19.3, z_max=0.1, n_grid=100_000):
+ def rm_cosmo(z, magobs, ref_mag=19.3, z_max=0.1, n_grid=100_000):
      """
      Interpolate Planck18 distance modulus and compute residuals to the cosmology

@@ -59,14 +61,27 @@ def rm_cosmo(z, magobs, magabs, ref_mag=19.3, z_max=0.1, n_grid=100_000):
      print('... done')

      magobs_corr = magobs - mu_planck18 + ref_mag
-     magabs_corr = magabs + ref_mag

-     return mu_planck18, magobs_corr, magabs_corr
+     return mu_planck18, magobs_corr


  def gaussian(x, mu, sigma):
      """
      Compute the normalized Gaussian function.
+
+     Parameters
+     ----------
+     x : array-like
+         Input values.
+     mu : float
+         Mean of the Gaussian.
+     sigma : float
+         Standard deviation of the Gaussian.
+
+     Returns
+     -------
+     array-like
+         The values of the Gaussian function evaluated at x.
      """
      prefactor = 1 / (np.sqrt(2 * np.pi * sigma**2))
      exponent = np.exp(-((x - mu)**2) / (2 * sigma**2))
@@ -167,6 +182,46 @@ def train_test_split_jax(X, y, test_size=0.3, shuffle=False, key=None):

      return X[:N_train], X[N_train:], y[:N_train], y[N_train:]

+ def train_test_split_indices_jax(N, test_size=0.3, shuffle=False, key=None, fixed_test_idx=None):
+     """
+     Generate train/test indices in JAX, optionally using a fixed test set.
+
+     Parameters
+     ----------
+     N : int
+         Total number of samples.
+     test_size : float
+         Fraction of the dataset to use as test data.
+     shuffle : bool
+         Whether to shuffle before splitting (ignored if fixed_test_idx is provided).
+     key : jax.random.PRNGKey
+         Random key used for shuffling (required if shuffle=True and fixed_test_idx is None).
+     fixed_test_idx : jax.numpy.ndarray, optional
+         Predefined indices to use as test set (persistent across rounds).
+
+     Returns
+     -------
+     train_idx : jax.numpy.ndarray
+         Indices for the training set.
+     test_idx : jax.numpy.ndarray
+         Indices for the test set.
+     """
+
+     N_test = int(jnp.floor(test_size * N))
+
+     if fixed_test_idx is None:
+         if shuffle:
+             perm = jax.random.permutation(key, N)
+         else:
+             perm = jnp.arange(N)
+         test_idx = perm[:N_test]
+     else:
+         test_idx = fixed_test_idx
+
+     train_idx = jnp.setdiff1d(jnp.arange(N), test_idx)
+     return train_idx, test_idx
+
+
  @nnx.jit
  def l2_loss(model, alpha):
      """
@@ -290,30 +345,50 @@ def pred_step(model, x_batch):
      return logits

  class Phi(nnx.Module):
+     """
+     Neural network module for the Phi network in a Deep Set architecture.
+     """
      def __init__(self, Nsize, n_cols, *, rngs):
-         self.linear1 = nnx.Linear(n_cols, Nsize, rngs=rngs)
+         self.linear1 = nnx.Linear(n_cols, Nsize, rngs=rngs)  #+n_params
          self.linear2 = nnx.Linear(Nsize, Nsize, rngs=rngs)
+         self.linear3 = nnx.Linear(Nsize, Nsize, rngs=rngs)

-     def __call__(self, x):
-         h = nnx.relu(self.linear1(x))
+     def __call__(self, data):
+         h = data
+
+         h = nnx.relu(self.linear1(h))
          h = nnx.relu(self.linear2(h))
+         h = nnx.relu(self.linear3(h))
          return h


  class Rho(nnx.Module):
-     def __init__(self, Nsize_p, Nsize_r, n_params, *, rngs):
-         self.linear1 = nnx.Linear(Nsize_p + n_params, Nsize_r, rngs=rngs)
-         self.linear2 = nnx.Linear(Nsize_r, 1, rngs=rngs)
+     """
+     Neural network module for the Rho network in a Deep Set architecture
+     with separate LayerNorm for pooled features and theta.
+     """
+     def __init__(self, Nsize_p, Nsize_r, N_size_params, *, rngs):
+         self.linear1 = nnx.Linear(Nsize_p + N_size_params, Nsize_r, rngs=rngs)  #
+         self.linear2 = nnx.Linear(Nsize_r, Nsize_r, rngs=rngs)
+         self.linear3 = nnx.Linear(Nsize_r, 1, rngs=rngs)
+
+     def __call__(self, dropout, pooled_features, params):
+         # Concatenate pooled features and embedding
+         x = jnp.concatenate([pooled_features, params], axis=-1)

-     def __call__(self, dropout, pooled_features, theta):
-         x = jnp.concatenate([pooled_features, theta], axis=-1)
          x = nnx.relu(self.linear1(x))
          x = dropout(x)
-         return self.linear2(x)

+         x = nnx.relu(self.linear2(x))  #leaky_relu
+         x = dropout(x)
+
+         return self.linear3(x)


  class DeepSetClassifier(nnx.Module):
+     """
+     Deep Set Classifier model combining Phi and Rho networks.
+     """
      def __init__(self, dropout_rate, Nsize_p, Nsize_r,
                   n_cols, n_params, *, rngs):

@@ -325,7 +400,14 @@ class DeepSetClassifier(nnx.Module):
          self.rho = Rho(Nsize_p, Nsize_r, n_params, rngs=rngs)

      def __call__(self, input_data):
-         N, input_dim = input_data.shape
+         # ----------------------------------------------------
+         # Accept both shape (N, D) and (D,) without failing
+         # ----------------------------------------------------
+         if input_data.ndim == 1:
+             input_data = input_data[None, :]
+
+         N = input_data.shape[0]
+         input_dim = input_data.shape[1]

          # Compute M first from input size
          # Total input columns = M*n_cols + n_params + M (mask)
@@ -340,40 +422,166 @@ class DeepSetClassifier(nnx.Module):
          # Parameters
          theta = input_data[:, -self.n_params:]  # shape (N, n_params)

-         # print(theta)
-
          # Apply Phi
-         h = self.phi(self.dropout, data)
+         h = self.phi(data)

          # Apply mask
          h_masked = h * mask[..., None]

-         # Pool (masked average)mask_sum = jnp.sum(mask, axis=1, keepdims=True)
+         # Pool (masked average)
          mask_sum = jnp.sum(mask, axis=1, keepdims=True)
          mask_sum = jnp.where(mask_sum == 0, 1.0, mask_sum)
-         pooled = jnp.sum(h_masked, axis=1) / mask_sum  # Try jnp.sqrt(mask_sum) ?
+         pooled = jnp.sum(h_masked, axis=1) / mask_sum
+
+         # pooled_N = jnp.concatenate([pooled, mask_sum], axis=-1)

          # Apply Rho
          return self.rho(self.dropout, pooled, theta)

+ def train_loop(model,
+                optimizer,
+                train_data,
+                train_labels,
+                test_data,
+                test_labels,
+                key,
+                epochs,
+                batch_size,
+                patience,
+                metrics_history,
+                M,
+                N,
+                cpu,
+                gpu,
+                group_id,
+                group_params,
+                plot_flag=False):
+     """
+     Train loop with early stopping and optional plotting.
+     """

+     # Initialise stopping criteria
+     best_train_loss = jnp.inf
+     best_test_loss = jnp.inf
+     best_train_accuracy = 0.0
+     best_test_accuracy = 0.0
+     strikes = 0
+
+     model.train()
+
+     for epoch in range(epochs):
+
+         epoch_train_loss = 0
+         epoch_train_accuracy = 0
+
+         for i in range(0, len(train_data), batch_size):
+             # Get the current batch of data and labels
+             batch_data = jax.device_put(train_data[i:i+batch_size], gpu)
+             batch_labels = jax.device_put(train_labels[i:i+batch_size], gpu)
+
+             # Perform a training step
+             loss, _ = loss_fn(model, (batch_data, batch_labels))
+             accuracy = accuracy_fn(model, (batch_data, batch_labels))
+             epoch_train_loss += loss
+             # Multiply batch accuracy by batch size to get number of correct predictions
+             epoch_train_accuracy += accuracy * len(batch_data)
+             train_step(model, optimizer, (batch_data, batch_labels))
+
+         # Log the training metrics.
+         current_train_loss = epoch_train_loss / (len(train_data) / batch_size)
+         current_train_accuracy = epoch_train_accuracy / len(train_data)
+         metrics_history['train_loss'].append(current_train_loss)
+         # Compute overall epoch accuracy
+         metrics_history['train_accuracy'].append(current_train_accuracy)
+
+         epoch_test_loss = 0
+         epoch_test_accuracy = 0
+
+         # Compute the metrics on the test set using the same batching as training
+         for i in range(0, len(test_data), batch_size):
+             batch_data = jax.device_put(test_data[i:i+batch_size], gpu)
+             batch_labels = jax.device_put(test_labels[i:i+batch_size], gpu)
+
+             loss, _ = loss_fn(model, (batch_data, batch_labels))
+             accuracy = accuracy_fn(model, (batch_data, batch_labels))
+             epoch_test_loss += loss
+             epoch_test_accuracy += accuracy * len(batch_data)
+
+         # Log the test metrics.
+         current_test_loss = epoch_test_loss / (len(test_data) / batch_size)
+         current_test_accuracy = epoch_test_accuracy / len(test_data)
+         metrics_history['test_loss'].append(current_test_loss)
+         metrics_history['test_accuracy'].append(current_test_accuracy)
+
+         # Early Stopping Check
+         if current_test_loss < best_test_loss:
+             best_test_loss = current_test_loss  # Update best test loss
+             strikes = 0
+         # elif current_test_accuracy > best_test_accuracy:
+         #     best_test_accuracy = current_test_accuracy  # Update best test accuracy
+         #     strikes = 0
+         elif current_train_loss >= best_train_loss:
+             strikes = 0
+         elif current_test_loss > best_test_loss and current_train_loss < best_train_loss:
+             strikes += 1
+         elif current_train_loss < best_train_loss:
+             best_train_loss = current_train_loss  # Update best train loss
+
+         if strikes >= patience:
+             print(f"\n Early stopping at epoch {epoch+1} due to {patience} consecutive increases in loss gap \n")
+             break
+
+         # Plotting (optional)
+         if plot_flag and epoch % 1 == 0:
+             clear_output(wait=True)
+
+             print(f"=== Training model for group {group_id}: {group_params} ===")
+
+             fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
+
+             # Loss subplot
+             ax1.set_title(f'Loss for M:{M} and N:{N}')
+             for dataset in ('train', 'test'):
+                 ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
+             ax1.legend()
+             ax1.set_yscale("log")
+
+             # Accuracy subplot
+             ax2.set_title('Accuracy')
+             for dataset in ('train', 'test'):
+                 ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
+             ax2.legend()
+
+             plt.show()
+
+     if epoch == epochs-1:
+         print(f"\n Reached maximum epochs: {epochs} \n")
+
+     return model, metrics_history, key
+
+ def save_autoregressive_nn(models_per_group, path, model_config):
+     """
+     Save an autoregressive stack of NNX models.

- def save_nn(model, path, model_config):
+     Parameters
+     ----------
+     models_per_group : list[nnx.Module]
+         One model per autoregressive group.
+     path : str
+         Checkpoint directory.
+     model_config : dict
+         Full model configuration (shared + per-group).
+     """
      ckpt_dir = os.path.abspath(path)
      ckpt_dir = ocp.test_utils.erase_and_create_empty(ckpt_dir)

-     # Split the model into GraphDef (structure) and State (parameters + buffers)
-     _, _, _, state = nnx.split(model, nnx.RngKey, nnx.RngCount, ...)
-
-     # Display for debugging (optional)
-     # nnx.display(state)
-
-     # Initialize the checkpointer
      checkpointer = ocp.StandardCheckpointer()

-     # Save State (parameters & non-trainable variables)
-     checkpointer.save(ckpt_dir / 'state', state)
+     for g, model in enumerate(models_per_group):
+         # Split model into graph-independent state
+         _, _, _, state = nnx.split(model, nnx.RngKey, nnx.RngCount, ...)
+         checkpointer.save(ckpt_dir / f"state_group_{g}", state)

-     # Save model configuration for later loading
-     with open(ckpt_dir / 'config.json', 'w') as f:
-         json.dump(model_config, f)
+     # Save configuration
+     with open(ckpt_dir / "config.json", "w") as f:
+         json.dump(model_config, f, indent=2)
@@ -1,10 +1,11 @@
  Metadata-Version: 2.4
  Name: ximinf
- Version: 0.0.2
+ Version: 0.0.16
  Summary: Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae.
  Author-email: Adam Trigui <a.trigui@ip2i.in2p3.fr>
  License: GPL-3.0-or-later
  Project-URL: Homepage, https://github.com/a-trigui/ximinf
+ Project-URL: Documentation, https://ximinf.readthedocs.io
  Keywords: cosmology,supernovae,simulation based inference
  Classifier: Programming Language :: Python :: 3
  Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
@@ -1 +0,0 @@
- # src/ximinf/__init__.py
@@ -1,132 +0,0 @@
- import numpy as np
- import pandas as pd
-
- # Simulation libraries
- import skysurvey
- import skysurvey_sniapop
- import ztfidr.simulation as sim
-
-
- # def flatten_df(df: pd.DataFrame, columns: list, params: list = None) -> np.ndarray:
- #     """
- #     Flatten selected columns from a DataFrame into a single 1D numpy array.
-
- #     Parameters
- #     ----------
- #     df : pd.DataFrame
- #         Input dataframe containing the data.
- #     columns : list of str
- #         Column names to extract and flatten.
- #     prepend_params : list or None
- #         Optional list of parameters to prepend to the flattened array.
-
- #     Returns
- #     -------
- #     np.ndarray
- #         1D array containing [prepend_params..., col1..., col2..., ...]
- #     """
- #     arrays = [df[col].to_numpy(dtype=np.float32) for col in columns]
- #     flat = np.concatenate(arrays)
-
- #     if params is not None:
- #         flat = np.concatenate([np.array(params, dtype=np.float32), flat])
-
- #     return flat
-
- # def unflatten_array(flat_array: np.ndarray, columns: list, n_points: int = 0):
- #     """
- #     Convert a flattened array back into its original columns and optional prepended parameters.
-
- #     Parameters
- #     ----------
- #     flat_array : np.ndarray
- #         1D array containing the prepended parameters (optional) and column data.
- #     columns : list of str
- #         Original column names in the same order as they were flattened.
- #     n_points : int
- #         Number of rows (SNe) in the data. If > 0, the function will deduce
- #         the number of prepended parameters automatically.
-
- #     Returns
- #     -------
- #     tuple
- #         If prepended_params exist: (prepended_params, df)
- #         Else: df
- #     """
- #     flat_array = flat_array.astype(np.float32)
-
- #     if n_points > 0:
- #         # Deduce number of prepended parameters
- #         n_params = flat_array.size - n_points * len(columns)
- #         if n_params < 0:
- #             raise ValueError("Number of points incompatible with flat array size")
- #         prepended_params = flat_array[:n_params] if n_params > 0 else None
- #         data_array = flat_array[n_params:]
- #     else:
- #         prepended_params = None
- #         data_array = flat_array
-
- #     n_rows = data_array.size // len(columns)
- #     if n_rows * len(columns) != data_array.size:
- #         raise ValueError("Flat array size is not compatible with number of columns")
-
- #     # Split array into columns
- #     split_arrays = np.split(data_array, len(columns))
- #     df = pd.DataFrame({col: arr for col, arr in zip(columns, split_arrays)})
-
- #     if prepended_params is not None:
- #         return prepended_params, df
- #     else:
- #         return df
-
- def simulate_one(params_dict, sigma_int, z_max, M, cols, N=None, i=None):
-     """
-     params_dict: dict of model parameters (alpha, beta, mabs, gamma, etc.)
-     cols: list of columns to include in the output
-     Returns a dict with:
-         'data': dict of lists (one per column)
-         'params': dict of parameter values
-     """
-     # Print progress
-     if N is not None and i is not None:
-         if (i+1) % max(1, N//10) == 0 or i == N-1:
-             print(f"Simulation {i+1}/{N}", end="\r", flush=True)
-
-     # Unpack parameters
-     alpha_ = float(params_dict.get("alpha", 0))
-     beta_ = float(params_dict.get("beta", 0))
-     mabs_ = float(params_dict.get("mabs", 0))
-     gamma_ = float(params_dict.get("gamma", 0))
-
-     brokenalpha_model = skysurvey_sniapop.brokenalpha_model
-
-     # Generate SNe sample
-     snia = skysurvey.SNeIa.from_draw(
-         size=M,
-         zmax=z_max,
-         model=brokenalpha_model,
-         magabs={
-             "x1": "@x1",
-             "c": "@c",
-             "mabs": mabs_,
-             "sigmaint": sigma_int,
-             "alpha_low": alpha_,
-             "alpha_high": alpha_,
-             "beta": beta_,
-             "gamma": gamma_
-         }
-     )
-
-     # Apply noise
-     errormodel = sim.noise_model
-     errormodel["localcolor"]["kwargs"]["a"] = 2
-     errormodel["localcolor"]["kwargs"]["loc"] = 0.005
-     errormodel["localcolor"]["kwargs"]["scale"] = 0.05
-     noisy_snia = snia.apply_gaussian_noise(errormodel)
-
-     df = noisy_snia.data
-
-     # Collect requested columns as lists
-     data_dict = {col: list(df[col]) for col in cols if col in df}
-
-     return data_dict
@@ -1,56 +0,0 @@
- # Standard
- import os
- import json
-
- # Jax
- from flax import nnx
-
- # Checkpointing
- import orbax.checkpoint as ocp  # Checkpointing library
- ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
- import pathlib  # File path handling library
-
- # Modules
- import ximinf.nn_train as nntr
-
- def load_nn(path):
-     # Define the checkpoint directory
-     ckpt_dir = os.path.abspath(path)
-     ckpt_dir = pathlib.Path(ckpt_dir).resolve()
-
-     # Ensure the folder is removed before saving
-     if ckpt_dir.exists()==False:
-         # Make an error
-         raise ValueError(f"Checkpoint directory {ckpt_dir} does not exist. Please check the path.")
-
-     # Load model configuration
-     config_path = ckpt_dir / 'config.json'
-     if not config_path.exists():
-         raise ValueError("Model config file not found in checkpoint directory.")
-
-     with open(config_path, 'r') as f:
-         model_config = json.load(f)
-
-     Nsize_p = model_config['Nsize_p']
-     Nsize_r = model_config['Nsize_r']
-     n_cols = model_config['n_cols']
-     n_params = model_config['n_params']
-
-     # 1. Re-create the checkpointer
-     checkpointer = ocp.StandardCheckpointer()
-
-     # Split the model into GraphDef (structure) and State (parameters + buffers)
-     abstract_model = nnx.eval_shape(lambda: nntr.DeepSetClassifier(0.05, Nsize_p, Nsize_r, n_cols, n_params, rngs=nnx.Rngs(0)))
-     abs_graphdef, abs_rngkey, abs_rngcount, _ = nnx.split(abstract_model, nnx.RngKey, nnx.RngCount, ...)
-
-     # 3. Restore
-     state_restored = checkpointer.restore(ckpt_dir / 'state')
-     #jax.tree.map(np.testing.assert_array_equal, abstract_state, state_restored)
-     print('NNX State restored: ')
-     # nnx.display(state_restored)
-
-     model = nnx.merge(abs_graphdef, abs_rngkey, abs_rngcount, state_restored)
-
-     nnx.display(model)
-
-     return model
File without changes
File without changes
File without changes