ximinf 0.0.2__tar.gz → 0.0.16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ximinf-0.0.2/src/ximinf.egg-info → ximinf-0.0.16}/PKG-INFO +2 -1
- {ximinf-0.0.2 → ximinf-0.0.16}/pyproject.toml +2 -2
- ximinf-0.0.16/src/ximinf/__init__.py +7 -0
- ximinf-0.0.16/src/ximinf/generate_sim.py +117 -0
- ximinf-0.0.16/src/ximinf/nn_inference.py +135 -0
- {ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf/nn_test.py +100 -101
- {ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf/nn_train.py +239 -31
- {ximinf-0.0.2 → ximinf-0.0.16/src/ximinf.egg-info}/PKG-INFO +2 -1
- ximinf-0.0.2/src/ximinf/__init__.py +0 -1
- ximinf-0.0.2/src/ximinf/generate_sim.py +0 -132
- ximinf-0.0.2/src/ximinf/nn_inference.py +0 -56
- {ximinf-0.0.2 → ximinf-0.0.16}/LICENSE +0 -0
- {ximinf-0.0.2 → ximinf-0.0.16}/README.md +0 -0
- {ximinf-0.0.2 → ximinf-0.0.16}/setup.cfg +0 -0
- {ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf/selection_effects.py +0 -0
- {ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf.egg-info/SOURCES.txt +0 -0
- {ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf.egg-info/dependency_links.txt +0 -0
- {ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf.egg-info/requires.txt +0 -0
- {ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf.egg-info/top_level.txt +0 -0
{ximinf-0.0.2/src/ximinf.egg-info → ximinf-0.0.16}/PKG-INFO
@@ -1,10 +1,11 @@
 Metadata-Version: 2.4
 Name: ximinf
-Version: 0.0.2
+Version: 0.0.16
 Summary: Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae.
 Author-email: Adam Trigui <a.trigui@ip2i.in2p3.fr>
 License: GPL-3.0-or-later
 Project-URL: Homepage, https://github.com/a-trigui/ximinf
+Project-URL: Documentation, https://ximinf.readthedocs.io
 Keywords: cosmology,supernovae,simulation based inference
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
{ximinf-0.0.2 → ximinf-0.0.16}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "ximinf"
-version = "0.0.2"
+version = "0.0.16"
 description = "Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae. "
 readme = "README.md"
 requires-python = ">=3.10"
@@ -35,7 +35,7 @@ classifiers = [
 
 [project.urls]
 Homepage = "https://github.com/a-trigui/ximinf"
-
+Documentation = "https://ximinf.readthedocs.io"
 
 [project.optional-dependencies]
 notebooks = ["jupyter", "matplotlib"]
ximinf-0.0.16/src/ximinf/generate_sim.py
@@ -0,0 +1,117 @@
+# Simulation libraries
+import skysurvey
+import numpy as np
+from pyDOE import lhs # LHS sampler
+
+def scan_params(ranges, N, dtype=np.float32):
+    """
+    Generate sampled parameter sets using Latin Hypercube Sampling (LHS).
+
+    Parameters
+    ----------
+    ranges : dict
+        Mapping parameter names to (min, max) tuples.
+    N : int
+        Number of samples.
+    dtype : data-type, optional
+        Numeric type for the sampled arrays (default is np.float32).
+
+    Returns
+    -------
+    params_dict : dict
+        Dictionary of parameter arrays of shape (N,).
+    """
+    param_names = list(ranges.keys())
+    n_params = len(param_names)
+
+    # LHS unit samples in [0,1]
+    unit_samples = lhs(n_params, samples=N)
+
+    # Scale unit samples to parameter ranges
+    params_dict = {}
+    for i, p in enumerate(param_names):
+        low, high = ranges[p]
+        params_dict[p] = (unit_samples[:, i] * (high - low) + low).astype(dtype)
+
+    return params_dict
+
+def simulate_one(params_dict, z_max, M, cols, N=None, i=None):
+    """
+    Simulate a single dataset of SNe Ia.
+
+    Parameters
+    ----------
+    params_dict : dict
+        Dictionary of model parameters (alpha, beta, mabs, gamma, sigma_int, etc.).
+    z_max : float
+        Maximum redshift.
+    M : int
+        Number of SNe to simulate.
+    cols : list of str
+        List of columns to include in the output.
+    N : int, optional
+        Total number of simulations (for progress printing).
+    i : int, optional
+        Current simulation index (for progress printing).
+
+    Returns
+    -------
+    data_dict : dict
+        Dictionary of lists (one per column) containing the simulated data.
+    """
+    import ztfidr.simulation as sim
+    import skysurvey_sniapop
+
+    # Print progress
+    if N is not None and i is not None:
+        if (i+1) % max(1, N//10) == 0 or i == N-1:
+            print(f"Simulation {i+1}/{N}", end="\r", flush=True)
+
+    # Define default parameters including sigma_int
+    default_params = {
+        "alpha": 0.0,
+        "beta": 0.0,
+        "mabs": -19.3,
+        # "gamma": 0.0,
+        "sigma_int": 0.0, # default intrinsic scatter
+    }
+
+    # Merge defaults with provided params (params_dict takes priority)
+    params = {**default_params, **params_dict}
+
+    # Ensure all are floats
+    alpha_ = float(params["alpha"])
+    beta_ = float(params["beta"])
+    mabs_ = float(params["mabs"])
+    # gamma_ = float(params["gamma"])
+    sigma_int_ = float(params["sigma_int"])
+
+    # brokenalpha_model = skysurvey_sniapop.brokenalpha_model
+
+    # Generate SNe sample
+    snia = skysurvey.SNeIa.from_draw(
+        size=M,
+        zmax=z_max,
+        # model=brokenalpha_model,
+        magabs={
+            "mabs": mabs_,
+            "sigmaint": sigma_int_,
+            "alpha": alpha_,
+            "beta": beta_,
+        }
+    )
+
+    # Apply noise
+    # errormodel = sim.noise_model
+    # errormodel["localcolor"]["kwargs"]["a"] = 2
+    # errormodel["localcolor"]["kwargs"]["loc"] = 0.005
+    # errormodel["localcolor"]["kwargs"]["scale"] = 0.05
+    # noisy_snia = snia.apply_gaussian_noise(errormodel)
+
+    # df = noisy_snia.data
+    df = snia.data
+
+    # Collect requested columns as lists
+    data_dict = {col: list(df[col]) for col in cols if col in df}
+
+    return data_dict
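Note (not part of the diff): scan_params only rescales Latin Hypercube unit samples into the supplied ranges, so each returned array is an independent (N,) draw per parameter. A minimal sketch, assuming the simulation dependencies (pyDOE, skysurvey) are importable and using made-up parameter ranges:

from ximinf.generate_sim import scan_params

# Hypothetical ranges for two standardisation parameters (illustrative values only)
ranges = {"alpha": (-0.3, 0.3), "beta": (1.0, 5.0)}
params = scan_params(ranges, N=1000)   # dict of float32 arrays, each of shape (1000,)
assert params["alpha"].shape == (1000,)
print({name: (float(v.min()), float(v.max())) for name, v in params.items()})

Each sampled parameter set would then be handed to simulate_one to generate one training simulation.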
ximinf-0.0.16/src/ximinf/nn_inference.py
@@ -0,0 +1,135 @@
+# Standard
+import os
+import json
+
+# Jax
+from flax import nnx
+
+# Checkpointing
+import orbax.checkpoint as ocp # Checkpointing library
+ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
+import pathlib # File path handling library
+
+# Modules
+import ximinf.nn_train as nntr
+
+# def load_nn(path):
+#     """
+#     Load a neural network model from a checkpoint.
+
+#     Parameters
+#     ----------
+#     path : str
+#         Path to the checkpoint directory.
+
+#     Returns
+#     -------
+#     model : nnx.Module
+#         The loaded neural network model.
+
+#     Raises
+#     ------
+#     ValueError
+#         If the checkpoint directory or config file does not exist.
+#     """
+#     # Define the checkpoint directory
+#     ckpt_dir = os.path.abspath(path)
+#     ckpt_dir = pathlib.Path(ckpt_dir).resolve()
+
+#     # Ensure the folder is removed before saving
+#     if ckpt_dir.exists()==False:
+#         # Make an error
+#         raise ValueError(f"Checkpoint directory {ckpt_dir} does not exist. Please check the path.")
+
+#     # Load model configuration
+#     config_path = ckpt_dir / 'config.json'
+#     if not config_path.exists():
+#         raise ValueError("Model config file not found in checkpoint directory.")
+
+#     with open(config_path, 'r') as f:
+#         model_config = json.load(f)
+
+#     Nsize_p = model_config['Nsize_p']
+#     Nsize_r = model_config['Nsize_r']
+#     n_cols = model_config['n_cols']
+#     n_params = model_config['n_params']
+#     N_size_embed = model_config['N_size_embed']
+
+#     # 1. Re-create the checkpointer
+#     checkpointer = ocp.StandardCheckpointer()
+
+#     # Split the model into GraphDef (structure) and State (parameters + buffers)
+#     abstract_model = nnx.eval_shape(lambda: nntr.DeepSetClassifier(0.0, Nsize_p, Nsize_r, N_size_embed, n_cols, n_params, rngs=nnx.Rngs(0)))
+#     abs_graphdef, abs_rngkey, abs_rngcount, _ = nnx.split(abstract_model, nnx.RngKey, nnx.RngCount, ...)
+
+#     # 3. Restore
+#     state_restored = checkpointer.restore(ckpt_dir / 'state')
+#     print('NNX State restored: ')
+
+#     model = nnx.merge(abs_graphdef, abs_rngkey, abs_rngcount, state_restored)
+
+#     nnx.display(model)
+
+#     return model
+
+def load_autoregressive_nn(path):
+    """
+    Load an autoregressive stack of NNX models.
+
+    Parameters
+    ----------
+    path : str
+        Checkpoint directory.
+
+    Returns
+    -------
+    models_per_group : list[nnx.Module]
+        Reconstructed models, one per group.
+    model_config : dict
+        Loaded configuration dictionary.
+    """
+    ckpt_dir = pathlib.Path(path).resolve()
+    if not ckpt_dir.exists():
+        raise ValueError(f"Checkpoint directory {ckpt_dir} does not exist.")
+
+    config_path = ckpt_dir / "config.json"
+    if not config_path.exists():
+        raise ValueError("Model config file not found.")
+
+    with open(config_path, "r") as f:
+        model_config = json.load(f)
+
+    shared = model_config["shared"]
+    group_configs = model_config["groups"]
+
+    checkpointer = ocp.StandardCheckpointer()
+    models_per_group = []
+
+    for gconf in group_configs:
+        n_params_visible = gconf["n_params_visible"]
+
+        # Recreate abstract model (shape-only)
+        abstract_model = nnx.eval_shape(
+            lambda: nntr.DeepSetClassifier( # It should not work, there is no class DeepSetClassifier defined in nntr, check how this should be properly done
+                dropout_rate=0.0,
+                Nsize_p=shared["Nsize_p"],
+                Nsize_r=shared["Nsize_r"],
+                n_cols=shared["n_cols"],
+                n_params=n_params_visible,
+                rngs=nnx.Rngs(0),
+            )
+        )
+
+        graphdef, rngkey, rngcount, _ = nnx.split(
+            abstract_model, nnx.RngKey, nnx.RngCount, ...
+        )
+
+        # Restore parameters
+        state = checkpointer.restore(
+            ckpt_dir / f"state_group_{gconf['group_id']}"
+        )
+
+        model = nnx.merge(graphdef, rngkey, rngcount, state)
+        models_per_group.append(model)
+
+    return models_per_group, model_config
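Usage note (not part of the diff): load_autoregressive_nn expects the directory layout written by save_autoregressive_nn in nn_train.py, i.e. a config.json holding "shared" and "groups" entries plus one state_group_<group_id> checkpoint per group. A minimal sketch with a hypothetical checkpoint path:

from ximinf.nn_inference import load_autoregressive_nn

# "./checkpoints/run_01" is a placeholder; point it at a directory produced by save_autoregressive_nn
models_per_group, model_config = load_autoregressive_nn("./checkpoints/run_01")
print(len(models_per_group), "group models restored;", model_config["shared"])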
{ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf/nn_test.py
@@ -24,56 +24,43 @@ def distance(theta1, theta2):
     diff = theta1 - theta2
     return jnp.linalg.norm(diff)
 
-def log_prior(theta, bounds):
-    """
-    Compute the log-prior probability for the parameter `theta`,
-    assuming uniform prior within given bounds.
-
-    Parameters
-    ----------
-    theta : array-like
-        The parameter values for which the prior is to be calculated.
-    bounds : jnp.ndarray, optional
-        The bounds on each parameter (default is the global `BOUNDS`).
-
-    Returns
-    -------
-    float
-        The log-prior of `theta`, or negative infinity if `theta` is out of bounds.
-    """
+# def log_prior(theta, bounds):
+#     """
+#     Compute the log-prior probability for the parameter `theta`,
+#     assuming uniform prior within given bounds.
+
+#     Parameters
+#     ----------
+#     theta : array-like
+#         The parameter values for which the prior is to be calculated.
+#     bounds : jnp.ndarray, optional
+#         The bounds on each parameter (default is the global `BOUNDS`).
+
+#     Returns
+#     -------
+#     float
+#         The log-prior of `theta`, or negative infinity if `theta` is out of bounds.
+#     """
+
+#     in_bounds = jnp.all((theta >= bounds[:, 0]) & (theta <= bounds[:, 1]))
+#     return jnp.where(in_bounds, 0.0, -jnp.inf)
+
+def log_group_prior(theta, bounds, group_indices):
+    """
+    Log prior for a single parameter group.
+    Uniform within bounds, -inf otherwise.
+    """
+    theta_g = theta[group_indices]
+    bounds_g = bounds[group_indices]
+
+    in_bounds = jnp.all(
+        (theta_g >= bounds_g[:, 0]) &
+        (theta_g <= bounds_g[:, 1])
+    )
 
-    in_bounds = jnp.all((theta >= bounds[:, 0]) & (theta <= bounds[:, 1]))
     return jnp.where(in_bounds, 0.0, -jnp.inf)
 
-def log_prob_fn(theta, model, xy_noise, bounds):
-    """
-    Compute the log-probability for the parameter `theta` using a
-    log-prior and the log-likelihood from the neural likelihood ratio approximation.
 
-    Parameters
-    ----------
-    theta : array-like
-        The parameter values for which the log-probability is computed.
-    model : callable
-        A function that takes `theta` and produces model logits for computing the likelihood.
-    xy_noise : array-like
-        Input data with added noise for evaluating the likelihood.
-
-    Returns
-    -------
-    float
-        The log-probability, which is the sum of the log-prior and the log-likelihood.
-    """
-
-    lp = log_prior(theta, bounds)
-    lp = jnp.where(jnp.isfinite(lp), lp, -jnp.inf)
-    xy_flat = xy_noise.squeeze()
-    inp = jnp.concatenate([xy_flat, theta])[None, :]
-    logits = model(inp)
-    p = jax.nn.sigmoid(logits).squeeze()
-    p = jnp.clip(p, 1e-6, 1 - 1e-6)
-    log_like = jnp.log(p) - jnp.log1p(-p)
-    return lp + log_like
 
 def sample_reference_point(rng_key, bounds):
     """
@@ -126,6 +113,50 @@ def inference_loop(rng_key, kernel, initial_state, num_samples):
     _, states = jax.lax.scan(one_step, initial_state, keys)
     return states
 
+def log_prob_fn_groups(theta, models_per_group, data, bounds,
+                       param_groups, global_param_names):
+
+    log_r_sum = 0.0
+    log_p_group_sum = 0.0
+
+    data = data.reshape(1, -1)
+
+    for g, group in enumerate(param_groups):
+
+        # --- parameter bookkeeping (unchanged) ---
+        prev_groups = [
+            p
+            for i in range(g)
+            for p in (param_groups[i] if isinstance(param_groups[i], list)
+                      else [param_groups[i]])
+        ]
+
+        group_list = [group] if isinstance(group, str) else group
+        visible_param_names = prev_groups + group_list
+
+        visible_idx = jnp.array(
+            [global_param_names.index(name) for name in visible_param_names]
+        )
+
+        theta_visible = theta[visible_idx].reshape(1, -1)
+        input_g = jnp.concatenate([data, theta_visible], axis=-1)
+
+        # --- ratio estimator ---
+        logits = models_per_group[g](input_g)
+        p = jax.nn.sigmoid(logits)
+        log_r_sum += jnp.log(p) - jnp.log1p(-p)
+
+        # --- marginal prior for this group ---
+        group_idx = jnp.array(
+            [global_param_names.index(name) for name in group_list]
+        )
+
+        log_p_group_sum += log_group_prior(theta, bounds, group_idx)
+
+    return jnp.squeeze(log_r_sum + log_p_group_sum)
+
+
+
 @partial(jax.jit, static_argnums=(0, 1, 2))
 def sample_posterior(log_prob, n_warmup, n_samples, init_position, rng_key):
     warmup = blackjax.window_adaptation(blackjax.nuts, log_prob)
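Note on the ratio estimator in log_prob_fn_groups: with p = sigmoid(logits), the quantity log(p) - log1p(-p) is algebraically the logit itself, so each group's contribution to log_r_sum is just the classifier output; the sigmoid/log round-trip only matters when values are clipped. A small standalone check in plain JAX (not package code):

import jax
import jax.numpy as jnp

logits = jnp.array([-3.0, 0.5, 4.0])
p = jax.nn.sigmoid(logits)
log_r = jnp.log(p) - jnp.log1p(-p)              # equals `logits` up to floating-point error
print(jnp.allclose(log_r, logits, atol=1e-5))   # True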
@@ -137,42 +168,17 @@ def sample_posterior(log_prob, n_warmup, n_samples, init_position, rng_key):
     return rng_key, states.position
 
 
-
-
-def one_sample_step(rng_key, xi, theta_star, n_warmup, n_samples, model, bounds):
+def one_sample_step_groups(rng_key, xi, theta_star, n_warmup, n_samples,
+                           models_per_group, bounds, param_groups, param_names):
     """
-    Sample from
-    with NUTS (No-U-Turn Sampler) for a given `log_prob`.
-
-    Parameters
-    ----------
-    log_prob : callable
-        The log-probability function for the model and parameters.
-    n_warmup : int
-        The number of warmup steps to adapt the sampler.
-    n_samples : int
-        The number of samples to generate after warmup.
-    init_position : array-like
-        The initial position for the chain (parameter values).
-    rng_key : jax.random.PRNGKey
-        The random key used for sampling.
-
-    Returns
-    -------
-    jax.numpy.ndarray
-        The sampled positions (parameters) from the posterior distribution.
+    Sample from posterior using sum of log-likelihoods over all groups.
     """
-
-    # Draw a random reference
     rng_key, theta_r0 = sample_reference_point(rng_key, bounds)
 
     def log_post(theta):
-        return
+        return log_prob_fn_groups(theta, models_per_group, xi, bounds, param_groups, param_names)
 
-    # Run MCMC
     rng_key, posterior = sample_posterior(log_post, n_warmup, n_samples, theta_star, rng_key)
-
-    # Compute e-c-p distances
     d_star = distance(theta_star, theta_r0)
     d_samples = jnp.linalg.norm(posterior - theta_r0, axis=1)
     f_val = jnp.mean(d_samples < d_star)
@@ -180,29 +186,28 @@ def one_sample_step(rng_key, xi, theta_star, n_warmup, n_samples, model, bounds)
     return rng_key, f_val, posterior
 
 
-def
-
-    Vectorized wrapper over `one_sample_step` using jax.vmap.
-    Returns proper f_vals for ECP computation.
-    """
+def batched_one_sample_step_groups(rng_keys, x_batch, theta_star_batch,
+                                   n_warmup, n_samples, models_per_group, bounds, param_groups, param_names):
     return jax.vmap(
-        lambda rng, x, theta:
+        lambda rng, x, theta: one_sample_step_groups(rng, x[None, :], theta, n_warmup, n_samples,
+                                                     models_per_group, bounds, param_groups, param_names),
         in_axes=(0, 0, 0)
     )(rng_keys, x_batch, theta_star_batch)
 
-
-
+def compute_ecp_tarp_jitted_groups(models_per_group, x_list, theta_star_list, alpha_list,
+                                   n_warmup, n_samples, rng_key, bounds,
+                                   param_groups, param_names):
     """
-
-    Returns proper f_vals for ECP computation.
+    Batched ECP computation using multiple group models.
     """
     N = x_list.shape[0]
     rng_key, split_key = jax.random.split(rng_key)
     rng_keys = jax.random.split(split_key, N)
 
     # Batched MCMC and distance evaluation
-    _, f_vals, posterior_uns =
-        rng_keys, x_list, theta_star_list, n_warmup, n_samples,
+    _, f_vals, posterior_uns = batched_one_sample_step_groups(
+        rng_keys, x_list, theta_star_list, n_warmup, n_samples,
+        models_per_group, bounds, param_groups, param_names
     )
 
     # Compute ECP values for each alpha
@@ -210,14 +215,9 @@ def compute_ecp_tarp_jitted(model, x_list, theta_star_list, alpha_list, n_warmup
 
     return ecp_vals, f_vals, posterior_uns, rng_key
 
-
-
-
-    batch_size=20):
-    """
-    Compute ECP using JITed MCMC in batches with progress reporting via tqdm.
-    Returns correct f_vals for all simulations.
-    """
+def compute_ecp_tarp_jitted_with_progress_groups(models_per_group, x_list, theta_star_list, alpha_list,
+                                                 n_warmup, n_samples, rng_key, bounds,
+                                                 param_groups, param_names, batch_size=20):
     N = x_list.shape[0]
 
     posterior_list = []
@@ -229,19 +229,18 @@ def compute_ecp_tarp_jitted_with_progress(model, x_list, theta_star_list, alpha_
         theta_batch = theta_star_list[start:end]
 
         # Compute ECP and posterior for batch
-        _, f_vals_batch, posterior_batch, rng_key =
-
-            n_warmup, n_samples, rng_key, bounds
+        _, f_vals_batch, posterior_batch, rng_key = compute_ecp_tarp_jitted_groups(
+            models_per_group, x_batch, theta_batch, alpha_list,
+            n_warmup, n_samples, rng_key, bounds,
+            param_groups, param_names
        )
 
        posterior_list.append(posterior_batch)
        f_vals_list.append(f_vals_batch)
 
-    # Concatenate across batches
     posterior_uns = jnp.concatenate(posterior_list, axis=0)
     f_vals_all = jnp.concatenate(f_vals_list, axis=0)
 
-    # Compute final ECP for each alpha
     ecp_vals = [jnp.mean(f_vals_all < (1 - alpha)) for alpha in alpha_list]
 
-    return ecp_vals, posterior_uns, rng_key
+    return ecp_vals, posterior_uns, rng_key
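Note (not part of the diff): the grouped ECP/TARP helpers reduce each simulated dataset to a coverage fraction f (the share of posterior samples closer to the random reference point than the true parameters), and the expected coverage at credibility level 1 - alpha is the fraction of simulations with f < 1 - alpha. A tiny runnable sketch of that final reduction, with made-up f values:

import jax.numpy as jnp

# f_vals: one coverage fraction per simulated dataset (illustrative values only)
f_vals = jnp.array([0.12, 0.48, 0.73, 0.91])
alpha_list = jnp.linspace(0.05, 0.95, 19)
ecp_vals = [float(jnp.mean(f_vals < (1 - a))) for a in alpha_list]
# For a calibrated posterior, ecp_vals plotted against (1 - alpha_list) should track the diagonal.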
{ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf/nn_train.py
@@ -3,6 +3,8 @@ import os
 import json
 import numpy as np # Numerical Python
 import scipy as sp
+import matplotlib.pyplot as plt
+from IPython.display import clear_output
 
 # JAX and Flax (new NNX API)
 import jax # Automatic differentiation library
@@ -19,7 +21,7 @@ ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
 # Cosmology
 from astropy.cosmology import Planck18
 
-def rm_cosmo(z, magobs, magabs, ref_mag=19.3, z_max=0.1, n_grid=100_000):
+def rm_cosmo(z, magobs, ref_mag=19.3, z_max=0.1, n_grid=100_000):
     """
     Interpolate Planck18 distance modulus and compute residuals to the cosmology
 
@@ -59,14 +61,27 @@ def rm_cosmo(z, magobs, magabs, ref_mag=19.3, z_max=0.1, n_grid=100_000):
     print('... done')
 
     magobs_corr = magobs - mu_planck18 + ref_mag
-    magabs_corr = magabs + ref_mag
 
-    return mu_planck18, magobs_corr
+    return mu_planck18, magobs_corr
 
 
 def gaussian(x, mu, sigma):
     """
     Compute the normalized Gaussian function.
+
+    Parameters
+    ----------
+    x : array-like
+        Input values.
+    mu : float
+        Mean of the Gaussian.
+    sigma : float
+        Standard deviation of the Gaussian.
+
+    Returns
+    -------
+    array-like
+        The values of the Gaussian function evaluated at x.
     """
     prefactor = 1 / (np.sqrt(2 * np.pi * sigma**2))
     exponent = np.exp(-((x - mu)**2) / (2 * sigma**2))
@@ -167,6 +182,46 @@ def train_test_split_jax(X, y, test_size=0.3, shuffle=False, key=None):
 
     return X[:N_train], X[N_train:], y[:N_train], y[N_train:]
 
+def train_test_split_indices_jax(N, test_size=0.3, shuffle=False, key=None, fixed_test_idx=None):
+    """
+    Generate train/test indices in JAX, optionally using a fixed test set.
+
+    Parameters
+    ----------
+    N : int
+        Total number of samples.
+    test_size : float
+        Fraction of the dataset to use as test data.
+    shuffle : bool
+        Whether to shuffle before splitting (ignored if fixed_test_idx is provided).
+    key : jax.random.PRNGKey
+        Random key used for shuffling (required if shuffle=True and fixed_test_idx is None).
+    fixed_test_idx : jax.numpy.ndarray, optional
+        Predefined indices to use as test set (persistent across rounds).
+
+    Returns
+    -------
+    train_idx : jax.numpy.ndarray
+        Indices for the training set.
+    test_idx : jax.numpy.ndarray
+        Indices for the test set.
+    """
+
+    N_test = int(jnp.floor(test_size * N))
+
+    if fixed_test_idx is None:
+        if shuffle:
+            perm = jax.random.permutation(key, N)
+        else:
+            perm = jnp.arange(N)
+        test_idx = perm[:N_test]
+    else:
+        test_idx = fixed_test_idx
+
+    train_idx = jnp.setdiff1d(jnp.arange(N), test_idx)
+    return train_idx, test_idx
+
+
 @nnx.jit
 def l2_loss(model, alpha):
     """
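Note (not part of the diff): train_test_split_indices_jax returns index arrays rather than sliced data, which is what lets the same test set be reused across autoregressive training rounds via fixed_test_idx. A minimal sketch, assuming the package and its heavier dependencies import cleanly:

import jax
import jax.numpy as jnp
from ximinf.nn_train import train_test_split_indices_jax

key = jax.random.PRNGKey(0)
train_idx, test_idx = train_test_split_indices_jax(N=10, test_size=0.3, shuffle=True, key=key)

# Passing the saved indices back in keeps the evaluation set identical in a later round
train_idx2, test_idx2 = train_test_split_indices_jax(N=10, fixed_test_idx=test_idx)
assert bool(jnp.array_equal(test_idx, test_idx2))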
@@ -290,30 +345,50 @@ def pred_step(model, x_batch):
     return logits
 
 class Phi(nnx.Module):
+    """
+    Neural network module for the Phi network in a Deep Set architecture.
+    """
     def __init__(self, Nsize, n_cols, *, rngs):
-        self.linear1 = nnx.Linear(n_cols, Nsize, rngs=rngs)
+        self.linear1 = nnx.Linear(n_cols, Nsize, rngs=rngs) #+n_params
         self.linear2 = nnx.Linear(Nsize, Nsize, rngs=rngs)
+        self.linear3 = nnx.Linear(Nsize, Nsize, rngs=rngs)
 
-    def __call__(self,
-        h =
+    def __call__(self, data):
+        h = data
+
+        h = nnx.relu(self.linear1(h))
         h = nnx.relu(self.linear2(h))
+        h = nnx.relu(self.linear3(h))
         return h
 
 
 class Rho(nnx.Module):
-
-
-
+    """
+    Neural network module for the Rho network in a Deep Set architecture
+    with separate LayerNorm for pooled features and theta.
+    """
+    def __init__(self, Nsize_p, Nsize_r, N_size_params, *, rngs):
+        self.linear1 = nnx.Linear(Nsize_p + N_size_params, Nsize_r, rngs=rngs) #
+        self.linear2 = nnx.Linear(Nsize_r, Nsize_r, rngs=rngs)
+        self.linear3 = nnx.Linear(Nsize_r, 1, rngs=rngs)
+
+    def __call__(self, dropout, pooled_features, params):
+        # Concatenate pooled features and embedding
+        x = jnp.concatenate([pooled_features, params], axis=-1)
 
-    def __call__(self, dropout, pooled_features, theta):
-        x = jnp.concatenate([pooled_features, theta], axis=-1)
         x = nnx.relu(self.linear1(x))
         x = dropout(x)
-        return self.linear2(x)
 
+        x = nnx.relu(self.linear2(x)) #leaky_relu
+        x = dropout(x)
+
+        return self.linear3(x)
 
 
 class DeepSetClassifier(nnx.Module):
+    """
+    Deep Set Classifier model combining Phi and Rho networks.
+    """
     def __init__(self, dropout_rate, Nsize_p, Nsize_r,
                  n_cols, n_params, *, rngs):
 
@@ -325,7 +400,14 @@ class DeepSetClassifier(nnx.Module):
         self.rho = Rho(Nsize_p, Nsize_r, n_params, rngs=rngs)
 
     def __call__(self, input_data):
-
+        # ----------------------------------------------------
+        # Accept both shape (N, D) and (D,) without failing
+        # ----------------------------------------------------
+        if input_data.ndim == 1:
+            input_data = input_data[None, :]
+
+        N = input_data.shape[0]
+        input_dim = input_data.shape[1]
 
         # Compute M first from input size
         # Total input columns = M*n_cols + n_params + M (mask)
@@ -340,40 +422,166 @@ class DeepSetClassifier(nnx.Module):
         # Parameters
         theta = input_data[:, -self.n_params:] # shape (N, n_params)
 
-        # print(theta)
-
         # Apply Phi
-        h = self.phi(
+        h = self.phi(data)
 
         # Apply mask
         h_masked = h * mask[..., None]
 
-        # Pool (masked average)
+        # Pool (masked average)
         mask_sum = jnp.sum(mask, axis=1, keepdims=True)
         mask_sum = jnp.where(mask_sum == 0, 1.0, mask_sum)
-        pooled = jnp.sum(h_masked, axis=1) / mask_sum
+        pooled = jnp.sum(h_masked, axis=1) / mask_sum
+
+        # pooled_N = jnp.concatenate([pooled, mask_sum], axis=-1)
 
         # Apply Rho
         return self.rho(self.dropout, pooled, theta)
 
+def train_loop(model,
+               optimizer,
+               train_data,
+               train_labels,
+               test_data,
+               test_labels,
+               key,
+               epochs,
+               batch_size,
+               patience,
+               metrics_history,
+               M,
+               N,
+               cpu,
+               gpu,
+               group_id,
+               group_params,
+               plot_flag=False):
+    """
+    Train loop with early stopping and optional plotting.
+    """
 
+    # Initialise stopping criteria
+    best_train_loss = jnp.inf
+    best_test_loss = jnp.inf
+    best_train_accuracy = 0.0
+    best_test_accuracy = 0.0
+    strikes = 0
+
+    model.train()
+
+    for epoch in range(epochs):
+
+        epoch_train_loss = 0
+        epoch_train_accuracy = 0
+
+        for i in range(0, len(train_data), batch_size):
+            # Get the current batch of data and labels
+            batch_data = jax.device_put(train_data[i:i+batch_size], gpu)
+            batch_labels = jax.device_put(train_labels[i:i+batch_size], gpu)
+
+            # Perform a training step
+            loss, _ = loss_fn(model, (batch_data, batch_labels))
+            accuracy = accuracy_fn(model, (batch_data, batch_labels))
+            epoch_train_loss += loss
+            # Multiply batch accuracy by batch size to get number of correct predictions
+            epoch_train_accuracy += accuracy * len(batch_data)
+            train_step(model, optimizer, (batch_data, batch_labels))
+
+        # Log the training metrics.
+        current_train_loss = epoch_train_loss / (len(train_data) / batch_size)
+        current_train_accuracy = epoch_train_accuracy / len(train_data)
+        metrics_history['train_loss'].append(current_train_loss)
+        # Compute overall epoch accuracy
+        metrics_history['train_accuracy'].append(current_train_accuracy)
+
+        epoch_test_loss = 0
+        epoch_test_accuracy = 0
+
+        # Compute the metrics on the test set using the same batching as training
+        for i in range(0, len(test_data), batch_size):
+            batch_data = jax.device_put(test_data[i:i+batch_size], gpu)
+            batch_labels = jax.device_put(test_labels[i:i+batch_size], gpu)
+
+            loss, _ = loss_fn(model, (batch_data, batch_labels))
+            accuracy = accuracy_fn(model, (batch_data, batch_labels))
+            epoch_test_loss += loss
+            epoch_test_accuracy += accuracy * len(batch_data)
+
+        # Log the test metrics.
+        current_test_loss = epoch_test_loss / (len(test_data) / batch_size)
+        current_test_accuracy = epoch_test_accuracy / len(test_data)
+        metrics_history['test_loss'].append(current_test_loss)
+        metrics_history['test_accuracy'].append(current_test_accuracy)
+
+        # Early Stopping Check
+        if current_test_loss < best_test_loss:
+            best_test_loss = current_test_loss # Update best test loss
+            strikes = 0
+        # elif current_test_accuracy > best_test_accuracy:
+        #     best_test_accuracy = current_test_accuracy # Update best test accuracy
+        #     strikes = 0
+        elif current_train_loss >= best_train_loss:
+            strikes = 0
+        elif current_test_loss > best_test_loss and current_train_loss < best_train_loss:
+            strikes += 1
+        elif current_train_loss < best_train_loss:
+            best_train_loss = current_train_loss # Update best train loss
+
+        if strikes >= patience:
+            print(f"\n Early stopping at epoch {epoch+1} due to {patience} consecutive increases in loss gap \n")
+            break
+
+        # Plotting (optional)
+        if plot_flag and epoch % 1 == 0:
+            clear_output(wait=True)
+
+            print(f"=== Training model for group {group_id}: {group_params} ===")
+
+            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
+
+            # Loss subplot
+            ax1.set_title(f'Loss for M:{M} and N:{N}')
+            for dataset in ('train', 'test'):
+                ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
+            ax1.legend()
+            ax1.set_yscale("log")
+
+            # Accuracy subplot
+            ax2.set_title('Accuracy')
+            for dataset in ('train', 'test'):
+                ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
+            ax2.legend()
+
+            plt.show()
+
+        if epoch == epochs-1:
+            print(f"\n Reached maximum epochs: {epochs} \n")
+
+    return model, metrics_history, key
+
+def save_autoregressive_nn(models_per_group, path, model_config):
+    """
+    Save an autoregressive stack of NNX models.
 
-
+    Parameters
+    ----------
+    models_per_group : list[nnx.Module]
+        One model per autoregressive group.
+    path : str
+        Checkpoint directory.
+    model_config : dict
+        Full model configuration (shared + per-group).
+    """
     ckpt_dir = os.path.abspath(path)
     ckpt_dir = ocp.test_utils.erase_and_create_empty(ckpt_dir)
 
-    # Split the model into GraphDef (structure) and State (parameters + buffers)
-    _, _, _, state = nnx.split(model, nnx.RngKey, nnx.RngCount, ...)
-
-    # Display for debugging (optional)
-    # nnx.display(state)
-
-    # Initialize the checkpointer
     checkpointer = ocp.StandardCheckpointer()
 
-
-
+    for g, model in enumerate(models_per_group):
+        # Split model into graph-independent state
+        _, _, _, state = nnx.split(model, nnx.RngKey, nnx.RngCount, ...)
+        checkpointer.save(ckpt_dir / f"state_group_{g}", state)
 
-    # Save
-    with open(ckpt_dir /
-        json.dump(model_config, f)
+    # Save configuration
+    with open(ckpt_dir / "config.json", "w") as f:
+        json.dump(model_config, f, indent=2)
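Note (not part of the diff): per the comments in DeepSetClassifier.__call__, each flat input row is M*n_cols data values, then the n_params theta values, then an M-long mask, and the masked average over the Phi embeddings is the permutation-invariant step of the deep set. A small standalone illustration of that pooling in plain JAX, with made-up shapes:

import jax.numpy as jnp

# Two sets of M=4 items, each item already embedded into 3 features by Phi (made-up numbers)
h = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)
mask = jnp.array([[1., 1., 1., 0.],    # last item of the first set is padding
                  [1., 1., 1., 1.]])

h_masked = h * mask[..., None]
mask_sum = jnp.sum(mask, axis=1, keepdims=True)
mask_sum = jnp.where(mask_sum == 0, 1.0, mask_sum)   # guard against empty sets
pooled = jnp.sum(h_masked, axis=1) / mask_sum        # shape (2, 3), independent of item order
print(pooled)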
{ximinf-0.0.2 → ximinf-0.0.16/src/ximinf.egg-info}/PKG-INFO
@@ -1,10 +1,11 @@
 Metadata-Version: 2.4
 Name: ximinf
-Version: 0.0.2
+Version: 0.0.16
 Summary: Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae.
 Author-email: Adam Trigui <a.trigui@ip2i.in2p3.fr>
 License: GPL-3.0-or-later
 Project-URL: Homepage, https://github.com/a-trigui/ximinf
+Project-URL: Documentation, https://ximinf.readthedocs.io
 Keywords: cosmology,supernovae,simulation based inference
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
ximinf-0.0.2/src/ximinf/__init__.py
@@ -1 +0,0 @@
-# src/ximinf/__init__.py
ximinf-0.0.2/src/ximinf/generate_sim.py
@@ -1,132 +0,0 @@
-import numpy as np
-import pandas as pd
-
-# Simulation libraries
-import skysurvey
-import skysurvey_sniapop
-import ztfidr.simulation as sim
-
-
-# def flatten_df(df: pd.DataFrame, columns: list, params: list = None) -> np.ndarray:
-#     """
-#     Flatten selected columns from a DataFrame into a single 1D numpy array.
-
-#     Parameters
-#     ----------
-#     df : pd.DataFrame
-#         Input dataframe containing the data.
-#     columns : list of str
-#         Column names to extract and flatten.
-#     prepend_params : list or None
-#         Optional list of parameters to prepend to the flattened array.
-
-#     Returns
-#     -------
-#     np.ndarray
-#         1D array containing [prepend_params..., col1..., col2..., ...]
-#     """
-#     arrays = [df[col].to_numpy(dtype=np.float32) for col in columns]
-#     flat = np.concatenate(arrays)
-
-#     if params is not None:
-#         flat = np.concatenate([np.array(params, dtype=np.float32), flat])
-
-#     return flat
-
-# def unflatten_array(flat_array: np.ndarray, columns: list, n_points: int = 0):
-#     """
-#     Convert a flattened array back into its original columns and optional prepended parameters.
-
-#     Parameters
-#     ----------
-#     flat_array : np.ndarray
-#         1D array containing the prepended parameters (optional) and column data.
-#     columns : list of str
-#         Original column names in the same order as they were flattened.
-#     n_points : int
-#         Number of rows (SNe) in the data. If > 0, the function will deduce
-#         the number of prepended parameters automatically.
-
-#     Returns
-#     -------
-#     tuple
-#         If prepended_params exist: (prepended_params, df)
-#         Else: df
-#     """
-#     flat_array = flat_array.astype(np.float32)
-
-#     if n_points > 0:
-#         # Deduce number of prepended parameters
-#         n_params = flat_array.size - n_points * len(columns)
-#         if n_params < 0:
-#             raise ValueError("Number of points incompatible with flat array size")
-#         prepended_params = flat_array[:n_params] if n_params > 0 else None
-#         data_array = flat_array[n_params:]
-#     else:
-#         prepended_params = None
-#         data_array = flat_array
-
-#     n_rows = data_array.size // len(columns)
-#     if n_rows * len(columns) != data_array.size:
-#         raise ValueError("Flat array size is not compatible with number of columns")
-
-#     # Split array into columns
-#     split_arrays = np.split(data_array, len(columns))
-#     df = pd.DataFrame({col: arr for col, arr in zip(columns, split_arrays)})
-
-#     if prepended_params is not None:
-#         return prepended_params, df
-#     else:
-#         return df
-
-def simulate_one(params_dict, sigma_int, z_max, M, cols, N=None, i=None):
-    """
-    params_dict: dict of model parameters (alpha, beta, mabs, gamma, etc.)
-    cols: list of columns to include in the output
-    Returns a dict with:
-        'data': dict of lists (one per column)
-        'params': dict of parameter values
-    """
-    # Print progress
-    if N is not None and i is not None:
-        if (i+1) % max(1, N//10) == 0 or i == N-1:
-            print(f"Simulation {i+1}/{N}", end="\r", flush=True)
-
-    # Unpack parameters
-    alpha_ = float(params_dict.get("alpha", 0))
-    beta_ = float(params_dict.get("beta", 0))
-    mabs_ = float(params_dict.get("mabs", 0))
-    gamma_ = float(params_dict.get("gamma", 0))
-
-    brokenalpha_model = skysurvey_sniapop.brokenalpha_model
-
-    # Generate SNe sample
-    snia = skysurvey.SNeIa.from_draw(
-        size=M,
-        zmax=z_max,
-        model=brokenalpha_model,
-        magabs={
-            "x1": "@x1",
-            "c": "@c",
-            "mabs": mabs_,
-            "sigmaint": sigma_int,
-            "alpha_low": alpha_,
-            "alpha_high": alpha_,
-            "beta": beta_,
-            "gamma": gamma_
-        }
-    )
-
-    # Apply noise
-    errormodel = sim.noise_model
-    errormodel["localcolor"]["kwargs"]["a"] = 2
-    errormodel["localcolor"]["kwargs"]["loc"] = 0.005
-    errormodel["localcolor"]["kwargs"]["scale"] = 0.05
-    noisy_snia = snia.apply_gaussian_noise(errormodel)
-
-    df = noisy_snia.data
-
-    # Collect requested columns as lists
-    data_dict = {col: list(df[col]) for col in cols if col in df}
-
-    return data_dict
ximinf-0.0.2/src/ximinf/nn_inference.py
@@ -1,56 +0,0 @@
-# Standard
-import os
-import json
-
-# Jax
-from flax import nnx
-
-# Checkpointing
-import orbax.checkpoint as ocp # Checkpointing library
-ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
-import pathlib # File path handling library
-
-# Modules
-import ximinf.nn_train as nntr
-
-def load_nn(path):
-    # Define the checkpoint directory
-    ckpt_dir = os.path.abspath(path)
-    ckpt_dir = pathlib.Path(ckpt_dir).resolve()
-
-    # Ensure the folder is removed before saving
-    if ckpt_dir.exists()==False:
-        # Make an error
-        raise ValueError(f"Checkpoint directory {ckpt_dir} does not exist. Please check the path.")
-
-    # Load model configuration
-    config_path = ckpt_dir / 'config.json'
-    if not config_path.exists():
-        raise ValueError("Model config file not found in checkpoint directory.")
-
-    with open(config_path, 'r') as f:
-        model_config = json.load(f)
-
-    Nsize_p = model_config['Nsize_p']
-    Nsize_r = model_config['Nsize_r']
-    n_cols = model_config['n_cols']
-    n_params = model_config['n_params']
-
-    # 1. Re-create the checkpointer
-    checkpointer = ocp.StandardCheckpointer()
-
-    # Split the model into GraphDef (structure) and State (parameters + buffers)
-    abstract_model = nnx.eval_shape(lambda: nntr.DeepSetClassifier(0.05, Nsize_p, Nsize_r, n_cols, n_params, rngs=nnx.Rngs(0)))
-    abs_graphdef, abs_rngkey, abs_rngcount, _ = nnx.split(abstract_model, nnx.RngKey, nnx.RngCount, ...)
-
-    # 3. Restore
-    state_restored = checkpointer.restore(ckpt_dir / 'state')
-    #jax.tree.map(np.testing.assert_array_equal, abstract_state, state_restored)
-    print('NNX State restored: ')
-    # nnx.display(state_restored)
-
-    model = nnx.merge(abs_graphdef, abs_rngkey, abs_rngcount, state_restored)
-
-    nnx.display(model)
-
-    return model
Files without changes:
{ximinf-0.0.2 → ximinf-0.0.16}/LICENSE
{ximinf-0.0.2 → ximinf-0.0.16}/README.md
{ximinf-0.0.2 → ximinf-0.0.16}/setup.cfg
{ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf/selection_effects.py
{ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf.egg-info/SOURCES.txt
{ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf.egg-info/dependency_links.txt
{ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf.egg-info/requires.txt
{ximinf-0.0.2 → ximinf-0.0.16}/src/ximinf.egg-info/top_level.txt