gensbi-examples 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,343 @@
1
+ import jax
2
+ from jax import numpy as jnp
3
+ import grain
4
+ import numpy as np
5
+
6
+ from datasets import load_dataset
7
+ from huggingface_hub import hf_hub_download
8
+ import json
9
+
10
+ # from .utils import download_artifacts
11
+ from .graph import faithfull_mask, min_faithfull_mask, moralize
12
+
13
+
14
+ def process_joint(batch):
15
+ cond = batch["xs"][..., None]
16
+ obs = batch["thetas"][..., None]
17
+ data = np.concatenate((obs, cond), axis=1)
18
+ return data
19
+
20
+
21
+ def process_conditional(batch):
22
+ cond = batch["xs"][..., None]
23
+ obs = batch["thetas"][..., None]
24
+ return obs, cond
25
+
26
+
27
+ class Task:
28
+ def __init__(self, task_name, kind="joint"):
29
+
30
+ self.repo_name = "aurelio-amerio/SBI-benchmarks"
31
+
32
+ self.task_name = task_name
33
+
34
+ fname = hf_hub_download(
35
+ repo_id=self.repo_name, filename="metadata.json", repo_type="dataset"
36
+ )
37
+ with open(fname, "r") as f:
38
+ metadata = json.load(f)
39
+
40
+ self.dataset = load_dataset(self.repo_name, task_name).with_format("numpy")
41
+ self.dataset_posterior = load_dataset(
42
+ self.repo_name, f"{task_name}_posterior"
43
+ ).with_format("numpy")
44
+
45
+ self.max_samples = self.dataset["train"].num_rows
46
+
47
+ self.observations = self.dataset_posterior["reference_posterior"][
48
+ "observations"
49
+ ]
50
+ self.reference_samples = self.dataset_posterior["reference_posterior"][
51
+ "reference_samples"
52
+ ]
53
+
54
+ self.true_parameters = self.dataset_posterior["reference_posterior"][
55
+ "true_parameters"
56
+ ]
57
+
58
+ self.dim_cond = metadata[task_name]["dim_cond"]
59
+ self.dim_obs = metadata[task_name]["dim_obs"]
60
+
61
+ self.dim_joint = self.dim_cond + self.dim_obs
62
+
63
+ self.num_observations = len(self.observations)
64
+ self.kind = kind
65
+
66
+ if kind == "joint":
67
+ self.process_fn = process_joint
68
+ elif kind == "conditional":
69
+ self.process_fn = process_conditional
70
+ else:
71
+ raise ValueError(f"Unknown kind: {kind}")
72
+
73
+ def get_train_dataset(self, batch_size, nsamples=1e5):
74
+ assert (
75
+ nsamples < self.max_samples
76
+ ), f"nsamples must be less than {self.max_samples}"
77
+
78
+ df = self.dataset["train"].select(range(int(nsamples))) # [:]
79
+
80
+ dataset_grain = (
81
+ grain.MapDataset.source(df).shuffle(42).repeat().to_iter_dataset()
82
+ )
83
+
84
+ performance_config = grain.experimental.pick_performance_config(
85
+ ds=dataset_grain,
86
+ ram_budget_mb=1024 * 4,
87
+ max_workers=None,
88
+ max_buffer_size=None,
89
+ )
90
+
91
+ dataset_batched = (
92
+ dataset_grain.batch(batch_size)
93
+ .map(self.process_fn)
94
+ .mp_prefetch(performance_config.multiprocessing_options)
95
+ )
96
+
97
+ return dataset_batched
98
+
99
+ def get_val_dataset(self, batch_size):
100
+ df = self.dataset["validation"] # [:]
101
+
102
+ val_dataset_grain = (
103
+ grain.MapDataset.source(df).shuffle(42).repeat().to_iter_dataset()
104
+ )
105
+ performance_config = grain.experimental.pick_performance_config(
106
+ ds=val_dataset_grain,
107
+ ram_budget_mb=1024 * 4,
108
+ max_workers=None,
109
+ max_buffer_size=None,
110
+ )
111
+ val_dataset_grain = (
112
+ val_dataset_grain.batch(batch_size)
113
+ .map(self.process_fn)
114
+ .mp_prefetch(performance_config.multiprocessing_options)
115
+ )
116
+
117
+ return val_dataset_grain
118
+
119
+ def get_test_dataset(self, batch_size):
120
+ df = self.dataset["test"] # [:]
121
+
122
+ val_dataset_grain = (
123
+ grain.MapDataset.source(df)
124
+ .shuffle(42)
125
+ .repeat()
126
+ .to_iter_dataset()
127
+ .batch(batch_size)
128
+ .map(self.process_fn)
129
+ )
130
+
131
+ return val_dataset_grain
132
+
133
+ def get_reference(self, num_observation=1):
134
+ """
135
+ Returns the reference posterior samples for a given number of observations.
136
+ """
137
+ if num_observation < 1 or num_observation > self.num_observations:
138
+ raise ValueError(
139
+ f"num_observation must be between 1 and {self.num_observations}"
140
+ )
141
+ obs = self.observations[num_observation - 1]
142
+ samples = self.reference_samples[num_observation - 1]
143
+ return obs, samples
144
+
145
+ def get_true_parameters(self, num_observation=1):
146
+ """
147
+ Returns the true parameters for a given number of observations.
148
+ """
149
+ if num_observation < 1 or num_observation > self.num_observations:
150
+ raise ValueError(
151
+ f"num_observation must be between 1 and {self.num_observations}"
152
+ )
153
+ return self.true_parameters[num_observation - 1]
154
+
155
+ def get_base_mask_fn(self):
156
+ raise NotImplementedError()
157
+
158
+ def get_edge_mask_fn(self, name="undirected"):
159
+ if name.lower() == "faithfull":
160
+ base_mask_fn = self.get_base_mask_fn()
161
+
162
+ def faithfull_edge_mask(node_id, condition_mask, meta_data=None):
163
+ base_mask = base_mask_fn(node_id, meta_data)
164
+ return faithfull_mask(base_mask, condition_mask)
165
+
166
+ return faithfull_edge_mask
167
+ elif name.lower() == "min_faithfull":
168
+ base_mask_fn = self.get_base_mask_fn()
169
+
170
+ def min_faithfull_edge_mask(node_id, condition_mask, meta_data=None):
171
+ base_mask = base_mask_fn(node_id, meta_data)
172
+
173
+ return min_faithfull_mask(base_mask, condition_mask)
174
+
175
+ return min_faithfull_edge_mask
176
+ elif name.lower() == "undirected":
177
+ base_mask_fn = self.get_base_mask_fn()
178
+
179
+ def undirected_edge_mask(node_id, condition_mask, meta_data=None):
180
+ base_mask = base_mask_fn(node_id, meta_data)
181
+ return moralize(base_mask)
182
+
183
+ return undirected_edge_mask
184
+
185
+ elif name.lower() == "directed":
186
+ base_mask_fn = self.get_base_mask_fn()
187
+
188
+ def directed_edge_mask(node_id, condition_mask, meta_data=None):
189
+ base_mask = base_mask_fn(node_id, meta_data)
190
+ return base_mask
191
+
192
+ return directed_edge_mask
193
+ elif name.lower() == "none":
194
+ return lambda node_id, condition_mask, *args, **kwargs: None
195
+ else:
196
+ raise NotImplementedError()
197
+
198
+
199
+ class TwoMoons(Task):
200
+ def __init__(self, kind="joint"):
201
+ task_name = "two_moons"
202
+ super().__init__(task_name, kind=kind)
203
+
204
+ def get_base_mask_fn(self):
205
+ theta_dim = self.dim_obs
206
+ x_dim = self.dim_cond
207
+ thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
208
+ x_mask = jnp.tril(jnp.ones((theta_dim, x_dim), dtype=jnp.bool_))
209
+ base_mask = jnp.block(
210
+ [
211
+ [thetas_mask, jnp.zeros((theta_dim, x_dim))],
212
+ [jnp.ones((x_dim, theta_dim)), x_mask],
213
+ ]
214
+ )
215
+ base_mask = base_mask.astype(jnp.bool_)
216
+
217
+ def base_mask_fn(node_ids, node_meta_data):
218
+ return base_mask[node_ids, :][:, node_ids]
219
+
220
+ return base_mask_fn
221
+
222
+
223
+ class BernoulliGLM(Task):
224
+ def __init__(self, kind="joint"):
225
+ task_name = "bernoulli_glm"
226
+ super().__init__(task_name, kind=kind)
227
+
228
+ def get_base_mask_fn(self):
229
+ raise NotImplementedError()
230
+
231
+
232
+ class GaussianLinear(Task):
233
+ def __init__(self, kind="joint"):
234
+ task_name = "gaussian_linear"
235
+ super().__init__(task_name, kind=kind)
236
+
237
+ def get_base_mask_fn(self):
238
+ theta_dim = self.dim_obs
239
+ x_dim = self.dim_cond
240
+ thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
241
+ x_i_mask = jnp.eye(x_dim, dtype=jnp.bool_)
242
+ base_mask = jnp.block(
243
+ [[thetas_mask, jnp.zeros((theta_dim, x_dim))], [jnp.eye((x_dim)), x_i_mask]]
244
+ )
245
+ base_mask = base_mask.astype(jnp.bool_)
246
+
247
+ def base_mask_fn(node_ids, node_meta_data):
248
+ return base_mask[node_ids, :][:, node_ids]
249
+
250
+ return base_mask_fn
251
+
252
+
253
+ class GaussianLinearUniform(Task):
254
+ def __init__(self, kind="joint"):
255
+ task_name = "gaussian_linear_uniform"
256
+ super().__init__(task_name, kind=kind)
257
+
258
+ def get_base_mask_fn(self):
259
+ theta_dim = self.dim_obs
260
+ x_dim = self.dim_cond
261
+ thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
262
+ x_i_mask = jnp.eye(x_dim, dtype=jnp.bool_)
263
+ base_mask = jnp.block(
264
+ [[thetas_mask, jnp.zeros((theta_dim, x_dim))], [jnp.eye((x_dim)), x_i_mask]]
265
+ )
266
+ base_mask = base_mask.astype(jnp.bool_)
267
+
268
+ def base_mask_fn(node_ids, node_meta_data):
269
+ return base_mask[node_ids, :][:, node_ids]
270
+
271
+ return base_mask_fn
272
+
273
+
274
+ class GaussianMixture(Task):
275
+ def __init__(self, kind="joint"):
276
+ task_name = "gaussian_mixture"
277
+ super().__init__(task_name, kind=kind)
278
+
279
+ def get_base_mask_fn(self):
280
+ theta_dim = self.dim_obs
281
+ x_dim = self.dim_cond
282
+ thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
283
+ x_mask = jnp.tril(jnp.ones((theta_dim, x_dim), dtype=jnp.bool_))
284
+ base_mask = jnp.block(
285
+ [
286
+ [thetas_mask, jnp.zeros((theta_dim, x_dim))],
287
+ [jnp.ones((x_dim, theta_dim)), x_mask],
288
+ ]
289
+ )
290
+ base_mask = base_mask.astype(jnp.bool_)
291
+
292
+ def base_mask_fn(node_ids, node_meta_data):
293
+ return base_mask[node_ids, :][:, node_ids]
294
+
295
+ return base_mask_fn
296
+
297
+
298
+ class SLCP(Task):
299
+ def __init__(self, kind="joint"):
300
+ task_name = "slcp"
301
+ super().__init__(task_name, kind=kind)
302
+
303
+ def get_base_mask_fn(self):
304
+ theta_dim = self.dim_obs
305
+ x_dim = self.dim_cond
306
+ thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
307
+ x_i_dim = x_dim // 4
308
+ x_i_mask = jax.scipy.linalg.block_diag(
309
+ *tuple([jnp.tril(jnp.ones((x_i_dim, x_i_dim), dtype=jnp.bool_))] * 4)
310
+ )
311
+ base_mask = jnp.block(
312
+ [
313
+ [thetas_mask, jnp.zeros((theta_dim, x_dim))],
314
+ [jnp.ones((x_dim, theta_dim)), x_i_mask],
315
+ ]
316
+ )
317
+ base_mask = base_mask.astype(jnp.bool_)
318
+
319
+ def base_mask_fn(node_ids, node_meta_data):
320
+ return base_mask[node_ids, :][:, node_ids]
321
+
322
+ return base_mask_fn
323
+
324
+
325
+ def get_task(task_name, kind="joint"):
326
+ """
327
+ Returns a Task object based on the task name.
328
+ """
329
+ task_name = task_name.lower()
330
+ if task_name == "two_moons":
331
+ return TwoMoons(kind=kind)
332
+ elif task_name == "bernoulli_glm":
333
+ return BernoulliGLM(kind=kind)
334
+ elif task_name == "gaussian_linear":
335
+ return GaussianLinear(kind=kind)
336
+ elif task_name == "gaussian_linear_uniform":
337
+ return GaussianLinearUniform(kind=kind)
338
+ elif task_name == "gaussian_mixture":
339
+ return GaussianMixture(kind=kind)
340
+ elif task_name == "slcp":
341
+ return SLCP(kind=kind)
342
+ else:
343
+ raise ValueError(f"Unknown task: {task_name}")
@@ -0,0 +1,15 @@
1
+ # Utility function to get the checkpoint directory for a specific example
2
+ import os
3
+
4
+ # def get_example_checkpoint_dir(example_name):
5
+ # """
6
+ # Returns the absolute path to the checkpoint directory for a given example.
7
+ # Example:
8
+ # get_example_checkpoint_dir("my_first_model")
9
+ # # returns /path/to/GenSBI-examples/examples/getting_started/checkpoints
10
+ # """
11
+ # base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
+ # if example_name == "my_first_model":
13
+ # return os.path.join(base_dir, "examples", "getting_started", "checkpoints")
14
+ # # Add more mappings as needed
15
+ # raise ValueError(f"Unknown example name: {example_name}")
@@ -0,0 +1,56 @@
1
+ # Utility function to get the checkpoint directory for a specific example
2
+ import os
3
+
4
+ def get_example_checkpoint_dir(example_name):
5
+ """
6
+ Returns the absolute path to the checkpoint directory for a given example.
7
+ Example:
8
+ get_example_checkpoint_dir("my_first_model")
9
+ # returns /path/to/GenSBI-examples/examples/getting_started/checkpoints
10
+ """
11
+ base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
+ if example_name == "my_first_model":
13
+ return os.path.join(base_dir, "examples", "getting_started", "checkpoints")
14
+ # Add more mappings as needed
15
+ raise ValueError(f"Unknown example name: {example_name}")
16
+ import os
17
+ import os
18
+ from IPython import get_ipython
19
+
20
+ def download_artifacts(task=None, dir=None):
21
+ """
22
+ Downloads the artifacts from the GenSBI repository.
23
+ """
24
+ root = "https://github.com/aurelio-amerio/GenSBI-examples/releases/download"
25
+ tag = "data-v0.1"
26
+ if task is not None:
27
+ fnames = [f"data_{task}.npz"]
28
+ else:
29
+ fnames =[
30
+ "data_two_moons.npz",
31
+ "data_bernoulli_glm.npz",
32
+ "data_gaussian_linear.npz",
33
+ "data_gaussian_linear_uniform.npz",
34
+ "data_gaussian_mixture.npz",
35
+ "data_slcp.npz"]
36
+
37
+ fnames = [os.path.join(root, tag, fname) for fname in fnames]
38
+
39
+ if dir is None:
40
+ dir = os.path.join(os.getcwd(), "task_data")
41
+ else:
42
+ dir = os.path.join(dir, "task_data")
43
+ os.makedirs(dir, exist_ok=True)
44
+ for fname in fnames:
45
+ local_fname = os.path.join(dir, os.path.basename(fname))
46
+ if not os.path.exists(local_fname):
47
+ print(f"Downloading {fname} to {local_fname}")
48
+ os.system(f"wget -O {local_fname} {fname}")
49
+ else:
50
+ print(f"{local_fname} already exists, skipping download.")
51
+
52
+
53
+
54
+
55
+
56
+
@@ -0,0 +1,72 @@
1
+ Metadata-Version: 2.4
2
+ Name: gensbi-examples
3
+ Version: 0.0.2
4
+ Summary: Examples for the GenSBI library
5
+ Project-URL: Homepage, https://github.com/aurelio-amerio/GenSBI-examples
6
+ Project-URL: Issues, https://github.com/aurelio-amerio/GenSBI-examples/issues
7
+ Author-email: Aurelio Amerio <aure.amerio@gmail.com>
8
+ License: Copyright 2025 Amerio Aurelio
9
+
10
+ Licensed under the Apache License, Version 2.0 (the "License");
11
+ you may not use this file except in compliance with the License.
12
+ You may obtain a copy of the License at
13
+
14
+ http://www.apache.org/licenses/LICENSE-2.0
15
+
16
+ Unless required by applicable law or agreed to in writing, software
17
+ distributed under the License is distributed on an "AS IS" BASIS,
18
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ See the License for the specific language governing permissions and
20
+ limitations under the License.
21
+ License-File: LICENSE
22
+ Requires-Python: >=3.11
23
+ Requires-Dist: datasets
24
+ Requires-Dist: flax>=0.12.0
25
+ Requires-Dist: grain>=0.2.12
26
+ Requires-Dist: huggingface-hub
27
+ Requires-Dist: jax<=0.8.1,>=0.7.2
28
+ Requires-Dist: matplotlib>=3.10
29
+ Requires-Dist: numpy>=2.0
30
+ Requires-Dist: scikit-learn>=1.7.0
31
+ Description-Content-Type: text/markdown
32
+
33
+ # GenSBI Examples
34
+
35
+ This repository contains a collection of examples, tutorials, and recipes for **GenSBI**, a JAX-based library for Simulation-Based Inference using generative models.
36
+
37
+ These examples demonstrate how to use GenSBI for various tasks, including:
38
+
39
+ - Defining and running inference pipelines.
40
+ - Using different embedding networks (MLP, ResNet, etc.).
41
+ - Handling various data types (1D signals, 2D images).
42
+
43
+ ## Installation
44
+
45
+ ### Prerequisites
46
+
47
+ You need to have **GenSBI** installed.
48
+
49
+ **With CUDA 12 support (Recommended):**
50
+
51
+ ```bash
52
+ pip install "gensbi[cuda12]"
53
+ ```
54
+
55
+ **CPU-only:**
56
+
57
+ ```bash
58
+ pip install gensbi
59
+ ```
60
+
61
+ ### Install Examples Package
62
+
63
+ To run the examples and ensure all dependencies are met, install this package:
64
+
65
+ ```bash
66
+ pip install gensbi-examples
67
+ ```
68
+
69
+ ## Structure
70
+
71
+ - `examples/`: Contains standalone example scripts and notebooks.
72
+ - `src/gensbi_examples`: Helper utilities for the examples.
@@ -0,0 +1,13 @@
1
+ gensbi_examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ gensbi_examples/c2st.py,sha256=PU5fGq6QAItIyc_nscINk2TUrqvXYk-9PCDFVYbambg,2815
3
+ gensbi_examples/c2st_v2.py.bk,sha256=xqkYZkUvmOnylBn0BmAY8n5kzc29OYu73U8EilBFn48,4455
4
+ gensbi_examples/graph.py,sha256=OdJ8fiSP5_pgmzmHr2E9ZWYZzZm3ToJWLho2gRH0w7I,7570
5
+ gensbi_examples/mask.py,sha256=OO7QH_r7SRjCiGLPn0vSS-mtFD2FE4JZFWhbed0MkcI,2997
6
+ gensbi_examples/sbi_tasks.py.bk,sha256=AhJBXxBygXWSxmTzLzYGKrFI9hS9nuCHeLPZ5QgMWNk,14678
7
+ gensbi_examples/tasks.py,sha256=-G20P3tQ4xwmjhJah1f0aHwdJXtNJUSgBlz2sGlcmaQ,10893
8
+ gensbi_examples/utils.py,sha256=mF-wutjqqCX1EEmhHAtAlWvCSBcCiddoB9TBJjgz-xM,709
9
+ gensbi_examples/utils.py.bk,sha256=_Q2hgbNlwo2uNFZ2tHIs_uye8IbLmsJYy6cYrErthT8,1863
10
+ gensbi_examples-0.0.2.dist-info/METADATA,sha256=FsH7l_zy7WyFY5UCRm_aealQ6bvPW34yVpCIByu4Ctc,2233
11
+ gensbi_examples-0.0.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
12
+ gensbi_examples-0.0.2.dist-info/licenses/LICENSE,sha256=D8Mi2-fbemv3oPZgMB-COT0aw2DXajsiebPYWtOMSpg,582
13
+ gensbi_examples-0.0.2.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.28.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,13 @@
1
+ Copyright 2025 Amerio Aurelio
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.