ximinf 0.0.2__tar.gz → 0.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ximinf-0.0.2/src/ximinf.egg-info → ximinf-0.0.8}/PKG-INFO +2 -1
- {ximinf-0.0.2 → ximinf-0.0.8}/pyproject.toml +2 -2
- ximinf-0.0.8/src/ximinf/__init__.py +7 -0
- ximinf-0.0.8/src/ximinf/generate_sim.py +120 -0
- ximinf-0.0.8/src/ximinf/nn_inference.py +136 -0
- ximinf-0.0.8/src/ximinf/nn_test.py +453 -0
- {ximinf-0.0.2 → ximinf-0.0.8}/src/ximinf/nn_train.py +185 -84
- {ximinf-0.0.2 → ximinf-0.0.8/src/ximinf.egg-info}/PKG-INFO +2 -1
- ximinf-0.0.2/src/ximinf/__init__.py +0 -1
- ximinf-0.0.2/src/ximinf/generate_sim.py +0 -132
- ximinf-0.0.2/src/ximinf/nn_inference.py +0 -56
- ximinf-0.0.2/src/ximinf/nn_test.py +0 -247
- {ximinf-0.0.2 → ximinf-0.0.8}/LICENSE +0 -0
- {ximinf-0.0.2 → ximinf-0.0.8}/README.md +0 -0
- {ximinf-0.0.2 → ximinf-0.0.8}/setup.cfg +0 -0
- {ximinf-0.0.2 → ximinf-0.0.8}/src/ximinf/selection_effects.py +0 -0
- {ximinf-0.0.2 → ximinf-0.0.8}/src/ximinf.egg-info/SOURCES.txt +0 -0
- {ximinf-0.0.2 → ximinf-0.0.8}/src/ximinf.egg-info/dependency_links.txt +0 -0
- {ximinf-0.0.2 → ximinf-0.0.8}/src/ximinf.egg-info/requires.txt +0 -0
- {ximinf-0.0.2 → ximinf-0.0.8}/src/ximinf.egg-info/top_level.txt +0 -0
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ximinf
|
|
3
|
-
Version: 0.0.2
|
|
3
|
+
Version: 0.0.8
|
|
4
4
|
Summary: Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae.
|
|
5
5
|
Author-email: Adam Trigui <a.trigui@ip2i.in2p3.fr>
|
|
6
6
|
License: GPL-3.0-or-later
|
|
7
7
|
Project-URL: Homepage, https://github.com/a-trigui/ximinf
|
|
8
|
+
Project-URL: Documentation, https://ximinf.readthedocs.io
|
|
8
9
|
Keywords: cosmology,supernovae,simulation based inference
|
|
9
10
|
Classifier: Programming Language :: Python :: 3
|
|
10
11
|
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "ximinf"
|
|
7
|
-
version = "0.0.2"
|
|
7
|
+
version = "0.0.8"
|
|
8
8
|
description = "Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae. "
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.10"
|
|
@@ -35,7 +35,7 @@ classifiers = [
|
|
|
35
35
|
|
|
36
36
|
[project.urls]
|
|
37
37
|
Homepage = "https://github.com/a-trigui/ximinf"
|
|
38
|
-
|
|
38
|
+
Documentation = "https://ximinf.readthedocs.io"
|
|
39
39
|
|
|
40
40
|
[project.optional-dependencies]
|
|
41
41
|
notebooks = ["jupyter", "matplotlib"]
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# Simulation libraries
|
|
2
|
+
import skysurvey
|
|
3
|
+
import numpy as np
|
|
4
|
+
from pyDOE import lhs # LHS sampler
|
|
5
|
+
|
|
6
|
+
def scan_params(ranges, N, dtype=np.float32):
    """
    Draw N parameter sets with Latin Hypercube Sampling (LHS).

    Parameters
    ----------
    ranges : dict
        Maps each parameter name to a ``(low, high)`` pair. The LHS design
        columns are assigned to parameters in the dict's iteration order.
    N : int
        Number of parameter sets to draw.
    dtype : data-type, optional
        Dtype of every returned array (default ``np.float32``).

    Returns
    -------
    params_dict : dict
        Same keys as ``ranges``; each value is an array of shape ``(N,)``
        obtained by mapping the unit-interval LHS samples linearly onto
        ``[low, high]``.
    """
    names = list(ranges)

    # One LHS design over the unit hypercube; column k belongs to names[k].
    design = lhs(len(names), samples=N)

    scaled = {}
    for col, name in enumerate(names):
        lo, hi = ranges[name]
        span = hi - lo
        scaled[name] = (design[:, col] * span + lo).astype(dtype)
    return scaled
|
|
37
|
+
|
|
38
|
+
def simulate_one(params_dict, z_max, M, cols, N=None, i=None):
    """
    Simulate a single dataset of SNe Ia.

    Draws ``M`` supernovae with ``skysurvey.SNeIa.from_draw`` using
    ``skysurvey_sniapop.brokenalpha_model``. Noise is applied with
    ``apply_gaussian_noise`` using a locally modified copy of
    ``ztfidr.simulation.noise_model``. The requested columns are returned.

    Parameters
    ----------
    params_dict : dict
        Model parameters. The keys read are ``alpha``, ``beta``, ``mabs``,
        ``gamma`` and ``sigma_int``. Missing keys fall back to 0.0, 0.0,
        -19.3, 0.0 and 0.0 respectively. Values must be convertible with
        ``float()``. Any other keys are ignored.
    z_max : float
        Maximum redshift.
    M : int
        Number of SNe to simulate.
    cols : list of str
        Columns to include in the output. Requested columns not found in
        the simulated data are skipped silently.
    N : int, optional
        Total number of simulations (for progress printing).
    i : int, optional
        Current simulation index (for progress printing).

    Returns
    -------
    data_dict : dict
        Maps each requested (and available) column name to a list of values.
    """
    import ztfidr.simulation as sim
    import skysurvey_sniapop

    # Print progress every max(1, N//10) simulations and on the last one.
    if N is not None and i is not None:
        if (i+1) % max(1, N//10) == 0 or i == N-1:
            print(f"Simulation {i+1}/{N}", end="\r", flush=True)

    # Default parameters, including the intrinsic scatter sigma_int.
    default_params = {
        "alpha": 0.0,
        "beta": 0.0,
        "mabs": -19.3,
        "gamma": 0.0,
        "sigma_int": 0.0,  # default intrinsic scatter
    }

    # Merge defaults with provided params (params_dict takes priority).
    params = {**default_params, **params_dict}

    # Convert to plain Python floats.
    alpha_ = float(params["alpha"])
    beta_ = float(params["beta"])
    mabs_ = float(params["mabs"])
    gamma_ = float(params["gamma"])
    sigma_int_ = float(params["sigma_int"])

    brokenalpha_model = skysurvey_sniapop.brokenalpha_model

    # Generate the SNe sample. alpha_low and alpha_high are both set to
    # alpha, so the broken-alpha model is effectively used with one alpha.
    snia = skysurvey.SNeIa.from_draw(
        size=M,
        zmax=z_max,
        model=brokenalpha_model,
        magabs={
            "x1": "@x1",
            "c": "@c",
            "mabs": mabs_,
            "sigmaint": sigma_int_,
            "alpha_low": alpha_,
            "alpha_high": alpha_,
            "beta": beta_,
            "gamma": gamma_
        }
    )

    # Build the noise model without touching ``sim.noise_model``. That object
    # is a module-level attribute of ztfidr.simulation, so assigning into it
    # in place would leak these overrides to every other user of the module.
    # Only the entries modified here are copied.
    base_model = sim.noise_model
    localcolor = dict(base_model["localcolor"])
    localcolor["kwargs"] = {
        **base_model["localcolor"]["kwargs"],
        "a": 2,
        "loc": 0.005,
        "scale": 0.05,
    }
    errormodel = {**base_model, "localcolor": localcolor}
    noisy_snia = snia.apply_gaussian_noise(errormodel)

    df = noisy_snia.data

    # Collect requested columns as lists; columns absent from df are skipped.
    data_dict = {col: list(df[col]) for col in cols if col in df}

    return data_dict
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# Standard
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
# Jax
|
|
6
|
+
from flax import nnx
|
|
7
|
+
|
|
8
|
+
# Checkpointing
|
|
9
|
+
import orbax.checkpoint as ocp # Checkpointing library
|
|
10
|
+
# NOTE(review): this statement runs at import time and calls an orbax *test*
# utility. Judging by its name, `erase_and_create_empty` wipes and recreates
# /tmp/my-checkpoints/ every time this module is imported.
# The module-level `ckpt_dir` it binds is not used by `load_autoregressive_nn`,
# which assigns its own local `ckpt_dir`.
# TODO: confirm nothing else relies on this name, then remove it or move it
# behind an explicit function call.
ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
|
|
11
|
+
import pathlib # File path handling library
|
|
12
|
+
|
|
13
|
+
# Modules
|
|
14
|
+
import ximinf.nn_train as nntr
|
|
15
|
+
|
|
16
|
+
# def load_nn(path):
|
|
17
|
+
# """
|
|
18
|
+
# Load a neural network model from a checkpoint.
|
|
19
|
+
|
|
20
|
+
# Parameters
|
|
21
|
+
# ----------
|
|
22
|
+
# path : str
|
|
23
|
+
# Path to the checkpoint directory.
|
|
24
|
+
|
|
25
|
+
# Returns
|
|
26
|
+
# -------
|
|
27
|
+
# model : nnx.Module
|
|
28
|
+
# The loaded neural network model.
|
|
29
|
+
|
|
30
|
+
# Raises
|
|
31
|
+
# ------
|
|
32
|
+
# ValueError
|
|
33
|
+
# If the checkpoint directory or config file does not exist.
|
|
34
|
+
# """
|
|
35
|
+
# # Define the checkpoint directory
|
|
36
|
+
# ckpt_dir = os.path.abspath(path)
|
|
37
|
+
# ckpt_dir = pathlib.Path(ckpt_dir).resolve()
|
|
38
|
+
|
|
39
|
+
# # Ensure the folder is removed before saving
|
|
40
|
+
# if ckpt_dir.exists()==False:
|
|
41
|
+
# # Make an error
|
|
42
|
+
# raise ValueError(f"Checkpoint directory {ckpt_dir} does not exist. Please check the path.")
|
|
43
|
+
|
|
44
|
+
# # Load model configuration
|
|
45
|
+
# config_path = ckpt_dir / 'config.json'
|
|
46
|
+
# if not config_path.exists():
|
|
47
|
+
# raise ValueError("Model config file not found in checkpoint directory.")
|
|
48
|
+
|
|
49
|
+
# with open(config_path, 'r') as f:
|
|
50
|
+
# model_config = json.load(f)
|
|
51
|
+
|
|
52
|
+
# Nsize_p = model_config['Nsize_p']
|
|
53
|
+
# Nsize_r = model_config['Nsize_r']
|
|
54
|
+
# n_cols = model_config['n_cols']
|
|
55
|
+
# n_params = model_config['n_params']
|
|
56
|
+
# N_size_embed = model_config['N_size_embed']
|
|
57
|
+
|
|
58
|
+
# # 1. Re-create the checkpointer
|
|
59
|
+
# checkpointer = ocp.StandardCheckpointer()
|
|
60
|
+
|
|
61
|
+
# # Split the model into GraphDef (structure) and State (parameters + buffers)
|
|
62
|
+
# abstract_model = nnx.eval_shape(lambda: nntr.DeepSetClassifier(0.0, Nsize_p, Nsize_r, N_size_embed, n_cols, n_params, rngs=nnx.Rngs(0)))
|
|
63
|
+
# abs_graphdef, abs_rngkey, abs_rngcount, _ = nnx.split(abstract_model, nnx.RngKey, nnx.RngCount, ...)
|
|
64
|
+
|
|
65
|
+
# # 3. Restore
|
|
66
|
+
# state_restored = checkpointer.restore(ckpt_dir / 'state')
|
|
67
|
+
# print('NNX State restored: ')
|
|
68
|
+
|
|
69
|
+
# model = nnx.merge(abs_graphdef, abs_rngkey, abs_rngcount, state_restored)
|
|
70
|
+
|
|
71
|
+
# nnx.display(model)
|
|
72
|
+
|
|
73
|
+
# return model
|
|
74
|
+
|
|
75
|
+
def load_autoregressive_nn(path):
    """
    Load an autoregressive stack of NNX models from a checkpoint directory.

    The directory must contain a ``config.json`` with two entries:

    * a ``"shared"`` section providing ``Nsize_p``, ``Nsize_r``,
      ``N_size_embed`` and ``n_cols``;
    * a ``"groups"`` list whose entries provide ``group_id`` and
      ``n_params_visible``.

    Each group's state is restored from the sub-directory
    ``state_group_<group_id>``.

    Parameters
    ----------
    path : str or os.PathLike
        Checkpoint directory.

    Returns
    -------
    models_per_group : list[nnx.Module]
        Reconstructed models, one per entry of ``config["groups"]``, in the
        same order.
    model_config : dict
        The parsed contents of ``config.json``.

    Raises
    ------
    ValueError
        If ``path`` is not an existing directory, or if it has no
        ``config.json`` file.
    """
    ckpt_dir = pathlib.Path(path).resolve()
    # is_dir() rather than exists(): a regular file passed by mistake should
    # be reported as such, not as a missing config.json.
    if not ckpt_dir.is_dir():
        raise ValueError(
            f"Checkpoint directory {ckpt_dir} does not exist or is not a directory."
        )

    config_path = ckpt_dir / "config.json"
    if not config_path.is_file():
        raise ValueError(f"Model config file not found: {config_path}")

    with open(config_path, "r", encoding="utf-8") as f:
        model_config = json.load(f)

    shared = model_config["shared"]
    group_configs = model_config["groups"]

    checkpointer = ocp.StandardCheckpointer()
    models_per_group = []

    for gconf in group_configs:
        n_params_visible = gconf["n_params_visible"]

        # Recreate a shape-only (abstract) model whose structure matches the
        # saved state. nnx.eval_shape calls the lambda immediately, so the
        # loop variable is captured with its current value.
        # NOTE(review): the original author noted that `nntr` may not define
        # `DeepSetClassifier`. Confirm the class name and constructor
        # signature against ximinf.nn_train before relying on this loader.
        abstract_model = nnx.eval_shape(
            lambda: nntr.DeepSetClassifier(
                dropout_rate=0.0,
                Nsize_p=shared["Nsize_p"],
                Nsize_r=shared["Nsize_r"],
                N_size_embed=shared["N_size_embed"],
                n_cols=shared["n_cols"],
                n_params=n_params_visible,
                rngs=nnx.Rngs(0),
            )
        )

        # Keep the RNG key/count variables from the abstract model; the rest
        # of the state (matched by `...`) is replaced by the restored one.
        graphdef, rngkey, rngcount, _ = nnx.split(
            abstract_model, nnx.RngKey, nnx.RngCount, ...
        )

        # Restore this group's parameters.
        state = checkpointer.restore(
            ckpt_dir / f"state_group_{gconf['group_id']}"
        )

        model = nnx.merge(graphdef, rngkey, rngcount, state)
        models_per_group.append(model)

    return models_per_group, model_config
|