ximinf 0.0.2__tar.gz → 0.0.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,11 @@
 Metadata-Version: 2.4
 Name: ximinf
-Version: 0.0.2
+Version: 0.0.8
 Summary: Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae.
 Author-email: Adam Trigui <a.trigui@ip2i.in2p3.fr>
 License: GPL-3.0-or-later
 Project-URL: Homepage, https://github.com/a-trigui/ximinf
+Project-URL: Documentation, https://ximinf.readthedocs.io
 Keywords: cosmology,supernovae,simulation based inference
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "ximinf"
-version = "0.0.2"
+version = "0.0.8"
 description = "Simulation Based Inference of Cosmological parameters in Jax using type Ia supernovae. "
 readme = "README.md"
 requires-python = ">=3.10"
@@ -35,7 +35,7 @@ classifiers = [
 
 [project.urls]
 Homepage = "https://github.com/a-trigui/ximinf"
-# Documentation = "https://my_package.readthedocs.io"
+Documentation = "https://ximinf.readthedocs.io"
 
 [project.optional-dependencies]
 notebooks = ["jupyter", "matplotlib"]
@@ -0,0 +1,7 @@
+# src/ximinf/__init__.py
+
+# from .generate_sim import *
+# from .nn_inference import *
+# from .nn_train import *
+# from .nn_test import *
+# from .selection_effects import *
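
Since every wildcard re-export in src/ximinf/__init__.py ships commented out, importing ximinf exposes no top-level names; callers reach into the submodules directly. A minimal sketch of what that implies, with module names taken from the commented imports above:

    import ximinf.generate_sim as gsim    # simulation helpers
    import ximinf.nn_inference as nninf   # checkpoint loading

    # With the re-exports disabled, this would raise ImportError:
    # from ximinf import scan_params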
@@ -0,0 +1,120 @@
+# Simulation libraries
+import skysurvey
+import numpy as np
+from pyDOE import lhs  # LHS sampler
+
+def scan_params(ranges, N, dtype=np.float32):
+    """
+    Generate sampled parameter sets using Latin Hypercube Sampling (LHS).
+
+    Parameters
+    ----------
+    ranges : dict
+        Mapping parameter names to (min, max) tuples.
+    N : int
+        Number of samples.
+    dtype : data-type, optional
+        Numeric type for the sampled arrays (default is np.float32).
+
+    Returns
+    -------
+    params_dict : dict
+        Dictionary of parameter arrays of shape (N,).
+    """
+    param_names = list(ranges.keys())
+    n_params = len(param_names)
+
+    # LHS unit samples in [0, 1]
+    unit_samples = lhs(n_params, samples=N)
+
+    # Scale unit samples to parameter ranges
+    params_dict = {}
+    for i, p in enumerate(param_names):
+        low, high = ranges[p]
+        params_dict[p] = (unit_samples[:, i] * (high - low) + low).astype(dtype)
+
+    return params_dict
+
+def simulate_one(params_dict, z_max, M, cols, N=None, i=None):
+    """
+    Simulate a single dataset of SNe Ia.
+
+    Parameters
+    ----------
+    params_dict : dict
+        Dictionary of model parameters (alpha, beta, mabs, gamma, sigma_int, etc.).
+    z_max : float
+        Maximum redshift.
+    M : int
+        Number of SNe to simulate.
+    cols : list of str
+        List of columns to include in the output.
+    N : int, optional
+        Total number of simulations (for progress printing).
+    i : int, optional
+        Current simulation index (for progress printing).
+
+    Returns
+    -------
+    data_dict : dict
+        Dictionary of lists (one per column) containing the simulated data.
+    """
+    import ztfidr.simulation as sim
+    import skysurvey_sniapop
+
+    # Print progress
+    if N is not None and i is not None:
+        if (i+1) % max(1, N//10) == 0 or i == N-1:
+            print(f"Simulation {i+1}/{N}", end="\r", flush=True)
+
+    # Define default parameters including sigma_int
+    default_params = {
+        "alpha": 0.0,
+        "beta": 0.0,
+        "mabs": -19.3,
+        "gamma": 0.0,
+        "sigma_int": 0.0,  # default intrinsic scatter
+    }
+
+    # Merge defaults with provided params (params_dict takes priority)
+    params = {**default_params, **params_dict}
+
+    # Ensure all are floats
+    alpha_ = float(params["alpha"])
+    beta_ = float(params["beta"])
+    mabs_ = float(params["mabs"])
+    gamma_ = float(params["gamma"])
+    sigma_int_ = float(params["sigma_int"])
+
+    brokenalpha_model = skysurvey_sniapop.brokenalpha_model
+
+    # Generate SNe sample
+    snia = skysurvey.SNeIa.from_draw(
+        size=M,
+        zmax=z_max,
+        model=brokenalpha_model,
+        magabs={
+            "x1": "@x1",
+            "c": "@c",
+            "mabs": mabs_,
+            "sigmaint": sigma_int_,
+            "alpha_low": alpha_,
+            "alpha_high": alpha_,
+            "beta": beta_,
+            "gamma": gamma_
+        }
+    )
+
+    # Apply noise
+    errormodel = sim.noise_model
+    errormodel["localcolor"]["kwargs"]["a"] = 2
+    errormodel["localcolor"]["kwargs"]["loc"] = 0.005
+    errormodel["localcolor"]["kwargs"]["scale"] = 0.05
+    noisy_snia = snia.apply_gaussian_noise(errormodel)
+
+    df = noisy_snia.data
+
+    # Collect requested columns as lists
+    data_dict = {col: list(df[col]) for col in cols if col in df}
+
+    return data_dict
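
Together, the two functions above form a draw-then-simulate pipeline: scan_params draws an N-point Latin hypercube over the parameter ranges, and simulate_one realizes one noisy catalog per draw. A minimal sketch of chaining them, assuming these functions live in ximinf.generate_sim (the module name suggested by __init__.py); the ranges, column names, and numeric values are illustrative, not values shipped by the package:

    from ximinf.generate_sim import scan_params, simulate_one

    # Hypothetical parameter ranges and output columns
    ranges = {"alpha": (0.05, 0.25), "beta": (2.5, 3.5), "sigma_int": (0.05, 0.20)}
    cols = ["z", "x1", "c"]
    N = 10

    grid = scan_params(ranges, N)             # dict of (N,) arrays
    datasets = []
    for i in range(N):
        draw = {p: grid[p][i] for p in grid}  # i-th draw as scalars
        datasets.append(simulate_one(draw, z_max=0.1, M=500, cols=cols, N=N, i=i))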
@@ -0,0 +1,136 @@
+# Standard
+import os
+import json
+
+# Jax
+from flax import nnx
+
+# Checkpointing
+import orbax.checkpoint as ocp  # Checkpointing library
+ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
+import pathlib  # File path handling library
+
+# Modules
+import ximinf.nn_train as nntr
+
+# def load_nn(path):
+#     """
+#     Load a neural network model from a checkpoint.
+
+#     Parameters
+#     ----------
+#     path : str
+#         Path to the checkpoint directory.
+
+#     Returns
+#     -------
+#     model : nnx.Module
+#         The loaded neural network model.
+
+#     Raises
+#     ------
+#     ValueError
+#         If the checkpoint directory or config file does not exist.
+#     """
+#     # Define the checkpoint directory
+#     ckpt_dir = os.path.abspath(path)
+#     ckpt_dir = pathlib.Path(ckpt_dir).resolve()
+
+#     # Ensure the checkpoint directory exists before loading
+#     if not ckpt_dir.exists():
+#         # Raise an error
+#         raise ValueError(f"Checkpoint directory {ckpt_dir} does not exist. Please check the path.")
+
+#     # Load model configuration
+#     config_path = ckpt_dir / 'config.json'
+#     if not config_path.exists():
+#         raise ValueError("Model config file not found in checkpoint directory.")
+
+#     with open(config_path, 'r') as f:
+#         model_config = json.load(f)
+
+#     Nsize_p = model_config['Nsize_p']
+#     Nsize_r = model_config['Nsize_r']
+#     n_cols = model_config['n_cols']
+#     n_params = model_config['n_params']
+#     N_size_embed = model_config['N_size_embed']
+
+#     # 1. Re-create the checkpointer
+#     checkpointer = ocp.StandardCheckpointer()
+
+#     # 2. Split the model into GraphDef (structure) and State (parameters + buffers)
+#     abstract_model = nnx.eval_shape(lambda: nntr.DeepSetClassifier(0.0, Nsize_p, Nsize_r, N_size_embed, n_cols, n_params, rngs=nnx.Rngs(0)))
+#     abs_graphdef, abs_rngkey, abs_rngcount, _ = nnx.split(abstract_model, nnx.RngKey, nnx.RngCount, ...)
+
+#     # 3. Restore
+#     state_restored = checkpointer.restore(ckpt_dir / 'state')
+#     print('NNX State restored: ')
+
+#     model = nnx.merge(abs_graphdef, abs_rngkey, abs_rngcount, state_restored)
+
+#     nnx.display(model)
+
+#     return model
+
+def load_autoregressive_nn(path):
+    """
+    Load an autoregressive stack of NNX models.
+
+    Parameters
+    ----------
+    path : str
+        Checkpoint directory.
+
+    Returns
+    -------
+    models_per_group : list[nnx.Module]
+        Reconstructed models, one per group.
+    model_config : dict
+        Loaded configuration dictionary.
+    """
+    ckpt_dir = pathlib.Path(path).resolve()
+    if not ckpt_dir.exists():
+        raise ValueError(f"Checkpoint directory {ckpt_dir} does not exist.")
+
+    config_path = ckpt_dir / "config.json"
+    if not config_path.exists():
+        raise ValueError("Model config file not found.")
+
+    with open(config_path, "r") as f:
+        model_config = json.load(f)
+
+    shared = model_config["shared"]
+    group_configs = model_config["groups"]
+
+    checkpointer = ocp.StandardCheckpointer()
+    models_per_group = []
+
+    for gconf in group_configs:
+        n_params_visible = gconf["n_params_visible"]
+
+        # Recreate abstract model (shape-only)
+        abstract_model = nnx.eval_shape(
+            lambda: nntr.DeepSetClassifier(  # FIXME: DeepSetClassifier does not appear to be defined in nntr; verify the intended class
+                dropout_rate=0.0,
+                Nsize_p=shared["Nsize_p"],
+                Nsize_r=shared["Nsize_r"],
+                N_size_embed=shared["N_size_embed"],
+                n_cols=shared["n_cols"],
+                n_params=n_params_visible,
+                rngs=nnx.Rngs(0),
+            )
+        )
+
+        graphdef, rngkey, rngcount, _ = nnx.split(
+            abstract_model, nnx.RngKey, nnx.RngCount, ...
+        )
+
+        # Restore parameters
+        state = checkpointer.restore(
+            ckpt_dir / f"state_group_{gconf['group_id']}"
+        )
+
+        model = nnx.merge(graphdef, rngkey, rngcount, state)
+        models_per_group.append(model)
+
+    return models_per_group, model_config
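
A minimal usage sketch for the loader above, assuming a checkpoint directory laid out the way the function expects (a config.json with "shared" and "groups" entries, plus one state_group_<group_id> checkpoint per group); the path is hypothetical:

    import ximinf.nn_inference as nninf

    models, config = nninf.load_autoregressive_nn("/path/to/checkpoints")

    # One reconstructed model per autoregressive group, in config order
    for gconf, model in zip(config["groups"], models):
        print(gconf["group_id"], type(model).__name__)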