gensbi-examples 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gensbi_examples/__init__.py +0 -0
- gensbi_examples/c2st.py +111 -0
- gensbi_examples/c2st_v2.py.bk +147 -0
- gensbi_examples/graph.py +211 -0
- gensbi_examples/mask.py +80 -0
- gensbi_examples/sbi_tasks.py.bk +417 -0
- gensbi_examples/tasks.py +343 -0
- gensbi_examples/utils.py +15 -0
- gensbi_examples/utils.py.bk +56 -0
- gensbi_examples-0.0.2.dist-info/METADATA +72 -0
- gensbi_examples-0.0.2.dist-info/RECORD +13 -0
- gensbi_examples-0.0.2.dist-info/WHEEL +4 -0
- gensbi_examples-0.0.2.dist-info/licenses/LICENSE +13 -0
gensbi_examples/tasks.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
from jax import numpy as jnp
|
|
3
|
+
import grain
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from datasets import load_dataset
|
|
7
|
+
from huggingface_hub import hf_hub_download
|
|
8
|
+
import json
|
|
9
|
+
|
|
10
|
+
# from .utils import download_artifacts
|
|
11
|
+
from .graph import faithfull_mask, min_faithfull_mask, moralize
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def process_joint(batch):
    """Pack one batch into a single joint array, parameters first.

    ``batch["thetas"]`` and ``batch["xs"]`` each get a trailing singleton
    axis and are then concatenated along axis 1 as ``[thetas, xs]``.

    Args:
        batch: Mapping with ``"thetas"`` and ``"xs"`` array entries.

    Returns:
        The concatenated ``np.ndarray``.
    """
    x_col = batch["xs"][..., None]
    theta_col = batch["thetas"][..., None]
    return np.concatenate([theta_col, x_col], axis=1)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def process_conditional(batch):
    """Split one batch into ``(thetas, xs)``, each with a trailing unit axis.

    Args:
        batch: Mapping with ``"thetas"`` and ``"xs"`` array entries.

    Returns:
        Tuple ``(batch["thetas"][..., None], batch["xs"][..., None])``.
    """
    x_col = batch["xs"][..., None]
    return batch["thetas"][..., None], x_col
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Task:
    """Base class for an SBI benchmark task hosted on the Hugging Face Hub.

    Construction downloads ``metadata.json`` and the ``task_name`` and
    ``f"{task_name}_posterior"`` configurations of the
    ``aurelio-amerio/SBI-benchmarks`` dataset repository. The splits are
    exposed as grain pipelines, alongside the reference posterior data.

    Subclasses describe the task's dependency structure by overriding
    :meth:`get_base_mask_fn`.

    Args:
        task_name: Dataset configuration name on the Hub; also the key looked
            up in ``metadata.json``.
        kind: ``"joint"`` (batches processed by ``process_joint``) or
            ``"conditional"`` (batches processed by ``process_conditional``).

    Raises:
        ValueError: If ``kind`` is not ``"joint"`` or ``"conditional"``. The
            check runs before anything is downloaded.
    """

    def __init__(self, task_name, kind="joint"):
        # Validate the cheap argument before touching the network.
        if kind == "joint":
            process_fn = process_joint
        elif kind == "conditional":
            process_fn = process_conditional
        else:
            raise ValueError(f"Unknown kind: {kind}")

        self.repo_name = "aurelio-amerio/SBI-benchmarks"
        self.task_name = task_name
        self.kind = kind
        self.process_fn = process_fn

        fname = hf_hub_download(
            repo_id=self.repo_name, filename="metadata.json", repo_type="dataset"
        )
        with open(fname, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        self.dataset = load_dataset(self.repo_name, task_name).with_format("numpy")
        self.dataset_posterior = load_dataset(
            self.repo_name, f"{task_name}_posterior"
        ).with_format("numpy")

        self.max_samples = self.dataset["train"].num_rows

        reference = self.dataset_posterior["reference_posterior"]
        self.observations = reference["observations"]
        self.reference_samples = reference["reference_samples"]
        self.true_parameters = reference["true_parameters"]

        # Naming convention of this module: "obs" are the parameters (thetas)
        # and "cond" the simulated data (xs); see process_joint.
        self.dim_cond = metadata[task_name]["dim_cond"]
        self.dim_obs = metadata[task_name]["dim_obs"]
        self.dim_joint = self.dim_cond + self.dim_obs

        self.num_observations = len(self.observations)

    def _build_pipeline(self, df, batch_size, prefetch=True):
        """Shuffle (seed 42), ``repeat()``, batch and map ``df`` with ``self.process_fn``.

        With ``prefetch=True`` the result is wrapped in ``mp_prefetch``. Its
        options come from ``grain.experimental.pick_performance_config``
        (``ram_budget_mb=4096``), evaluated on the unbatched iterator.
        """
        ds = grain.MapDataset.source(df).shuffle(42).repeat().to_iter_dataset()
        if not prefetch:
            return ds.batch(batch_size).map(self.process_fn)

        performance_config = grain.experimental.pick_performance_config(
            ds=ds,
            ram_budget_mb=1024 * 4,
            max_workers=None,
            max_buffer_size=None,
        )
        return (
            ds.batch(batch_size)
            .map(self.process_fn)
            .mp_prefetch(performance_config.multiprocessing_options)
        )

    def get_train_dataset(self, batch_size, nsamples=1e5):
        """Return a batched, prefetched pipeline over the first ``nsamples`` train rows.

        Args:
            batch_size: Number of rows per batch.
            nsamples: Number of leading train rows to use (truncated to int).

        Raises:
            ValueError: If ``nsamples`` is not in ``[1, self.max_samples]``.
        """
        nsamples = int(nsamples)
        # Explicit check instead of ``assert`` (stripped under ``python -O``);
        # using the full train split (nsamples == max_samples) is allowed.
        if nsamples < 1 or nsamples > self.max_samples:
            raise ValueError(
                f"nsamples must be between 1 and {self.max_samples}, got {nsamples}"
            )
        df = self.dataset["train"].select(range(nsamples))
        return self._build_pipeline(df, batch_size)

    def get_val_dataset(self, batch_size):
        """Return a batched, prefetched pipeline over the validation split."""
        return self._build_pipeline(self.dataset["validation"], batch_size)

    def get_test_dataset(self, batch_size):
        """Return a batched pipeline over the test split (no multiprocess prefetch)."""
        return self._build_pipeline(self.dataset["test"], batch_size, prefetch=False)

    def _check_observation_index(self, num_observation):
        """Raise ``ValueError`` unless ``1 <= num_observation <= self.num_observations``."""
        if not 1 <= num_observation <= self.num_observations:
            raise ValueError(
                f"num_observation must be between 1 and {self.num_observations}"
            )

    def get_reference(self, num_observation=1):
        """Return the observation and reference posterior samples for one index.

        Args:
            num_observation: 1-based index of the reference observation.

        Returns:
            Tuple ``(observation, reference_samples)`` for that index.

        Raises:
            ValueError: If the index is outside ``[1, self.num_observations]``.
        """
        self._check_observation_index(num_observation)
        idx = num_observation - 1
        return self.observations[idx], self.reference_samples[idx]

    def get_true_parameters(self, num_observation=1):
        """Return the true parameters behind the 1-based observation index.

        Raises:
            ValueError: If the index is outside ``[1, self.num_observations]``.
        """
        self._check_observation_index(num_observation)
        return self.true_parameters[num_observation - 1]

    def get_base_mask_fn(self):
        """Return ``base_mask_fn(node_ids, node_meta_data)``; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} does not define a base mask"
        )

    def get_edge_mask_fn(self, name="undirected"):
        """Return an edge-mask function ``fn(node_id, condition_mask, meta_data=None)``.

        Args:
            name: Case-insensitive mask type:
                ``"directed"`` returns the base mask unchanged;
                ``"undirected"`` applies ``moralize`` to the base mask;
                ``"faithfull"`` / ``"min_faithfull"`` combine the base mask
                with ``condition_mask`` via ``faithfull_mask`` /
                ``min_faithfull_mask``;
                ``"none"`` returns a function that always yields ``None``.

        Raises:
            NotImplementedError: For an unknown ``name``, or (for every name
                other than ``"none"``) when the task defines no base mask.
        """
        key = name.lower()
        if key == "none":
            return lambda node_id, condition_mask, *args, **kwargs: None
        if key not in ("faithfull", "min_faithfull", "undirected", "directed"):
            raise NotImplementedError(f"Unknown edge mask type: {name!r}")

        base_mask_fn = self.get_base_mask_fn()

        if key == "faithfull":

            def faithfull_edge_mask(node_id, condition_mask, meta_data=None):
                base_mask = base_mask_fn(node_id, meta_data)
                return faithfull_mask(base_mask, condition_mask)

            return faithfull_edge_mask

        if key == "min_faithfull":

            def min_faithfull_edge_mask(node_id, condition_mask, meta_data=None):
                base_mask = base_mask_fn(node_id, meta_data)
                return min_faithfull_mask(base_mask, condition_mask)

            return min_faithfull_edge_mask

        if key == "undirected":

            def undirected_edge_mask(node_id, condition_mask, meta_data=None):
                base_mask = base_mask_fn(node_id, meta_data)
                return moralize(base_mask)

            return undirected_edge_mask

        # Only "directed" remains after the validation above.
        def directed_edge_mask(node_id, condition_mask, meta_data=None):
            return base_mask_fn(node_id, meta_data)

        return directed_edge_mask
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class TwoMoons(Task):
    """Two-moons benchmark task (``two_moons`` configuration on the Hub)."""

    def __init__(self, kind="joint"):
        task_name = "two_moons"
        super().__init__(task_name, kind=kind)

    def get_base_mask_fn(self):
        """Return ``base_mask_fn(node_ids, node_meta_data)`` for this task.

        The base mask covers ``dim_obs`` parameter nodes followed by
        ``dim_cond`` data nodes:

        * parameter/parameter block: identity;
        * parameter/data block: all ``False``;
        * data/parameter block: all ``True``;
        * data/data block: lower-triangular (diagonal included).

        ``base_mask_fn`` returns the sub-mask whose rows and columns are
        selected by ``node_ids``; ``node_meta_data`` is ignored.
        """
        theta_dim = self.dim_obs
        x_dim = self.dim_cond
        thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
        # The data/data block must be square (x_dim, x_dim); sizing it as
        # (theta_dim, x_dim) only fits when both dimensions coincide.
        x_mask = jnp.tril(jnp.ones((x_dim, x_dim), dtype=jnp.bool_))
        base_mask = jnp.block(
            [
                [thetas_mask, jnp.zeros((theta_dim, x_dim), dtype=jnp.bool_)],
                [jnp.ones((x_dim, theta_dim), dtype=jnp.bool_), x_mask],
            ]
        )

        def base_mask_fn(node_ids, node_meta_data):
            return base_mask[node_ids, :][:, node_ids]

        return base_mask_fn
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class BernoulliGLM(Task):
    """Bernoulli GLM benchmark task (``bernoulli_glm`` configuration on the Hub).

    No base mask is defined for this task. ``get_edge_mask_fn`` therefore
    only works with ``name="none"``.
    """

    def __init__(self, kind="joint"):
        super().__init__("bernoulli_glm", kind=kind)

    def get_base_mask_fn(self):
        """Not available for this task; always raises ``NotImplementedError``."""
        raise NotImplementedError()
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class GaussianLinear(Task):
    """Gaussian-linear benchmark task (``gaussian_linear`` configuration on the Hub)."""

    def __init__(self, kind="joint"):
        task_name = "gaussian_linear"
        super().__init__(task_name, kind=kind)

    def get_base_mask_fn(self):
        """Return ``base_mask_fn(node_ids, node_meta_data)`` for this task.

        The base mask covers ``dim_obs`` parameter nodes followed by
        ``dim_cond`` data nodes:

        * parameter/parameter block: identity;
        * parameter/data block: all ``False``;
        * data/parameter block: ``True`` only where data index == parameter index;
        * data/data block: identity.

        ``base_mask_fn`` returns the sub-mask whose rows and columns are
        selected by ``node_ids``; ``node_meta_data`` is ignored.
        """
        theta_dim = self.dim_obs
        x_dim = self.dim_cond
        thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
        x_i_mask = jnp.eye(x_dim, dtype=jnp.bool_)
        # Rectangular (x_dim, theta_dim) diagonal block; a square
        # jnp.eye(x_dim) only fits when theta_dim == x_dim.
        x_theta_mask = jnp.eye(x_dim, theta_dim, dtype=jnp.bool_)
        base_mask = jnp.block(
            [
                [thetas_mask, jnp.zeros((theta_dim, x_dim), dtype=jnp.bool_)],
                [x_theta_mask, x_i_mask],
            ]
        )

        def base_mask_fn(node_ids, node_meta_data):
            return base_mask[node_ids, :][:, node_ids]

        return base_mask_fn
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
class GaussianLinearUniform(Task):
    """Gaussian-linear-uniform benchmark task (``gaussian_linear_uniform`` on the Hub)."""

    def __init__(self, kind="joint"):
        task_name = "gaussian_linear_uniform"
        super().__init__(task_name, kind=kind)

    def get_base_mask_fn(self):
        """Return ``base_mask_fn(node_ids, node_meta_data)`` for this task.

        The base mask covers ``dim_obs`` parameter nodes followed by
        ``dim_cond`` data nodes:

        * parameter/parameter block: identity;
        * parameter/data block: all ``False``;
        * data/parameter block: ``True`` only where data index == parameter index;
        * data/data block: identity.

        ``base_mask_fn`` returns the sub-mask whose rows and columns are
        selected by ``node_ids``; ``node_meta_data`` is ignored.
        """
        theta_dim = self.dim_obs
        x_dim = self.dim_cond
        thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
        x_i_mask = jnp.eye(x_dim, dtype=jnp.bool_)
        # Rectangular (x_dim, theta_dim) diagonal block; a square
        # jnp.eye(x_dim) only fits when theta_dim == x_dim.
        x_theta_mask = jnp.eye(x_dim, theta_dim, dtype=jnp.bool_)
        base_mask = jnp.block(
            [
                [thetas_mask, jnp.zeros((theta_dim, x_dim), dtype=jnp.bool_)],
                [x_theta_mask, x_i_mask],
            ]
        )

        def base_mask_fn(node_ids, node_meta_data):
            return base_mask[node_ids, :][:, node_ids]

        return base_mask_fn
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class GaussianMixture(Task):
    """Gaussian-mixture benchmark task (``gaussian_mixture`` configuration on the Hub)."""

    def __init__(self, kind="joint"):
        task_name = "gaussian_mixture"
        super().__init__(task_name, kind=kind)

    def get_base_mask_fn(self):
        """Return ``base_mask_fn(node_ids, node_meta_data)`` for this task.

        The base mask covers ``dim_obs`` parameter nodes followed by
        ``dim_cond`` data nodes:

        * parameter/parameter block: identity;
        * parameter/data block: all ``False``;
        * data/parameter block: all ``True``;
        * data/data block: lower-triangular (diagonal included).

        ``base_mask_fn`` returns the sub-mask whose rows and columns are
        selected by ``node_ids``; ``node_meta_data`` is ignored.
        """
        theta_dim = self.dim_obs
        x_dim = self.dim_cond
        thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
        # The data/data block must be square (x_dim, x_dim); sizing it as
        # (theta_dim, x_dim) only fits when both dimensions coincide.
        x_mask = jnp.tril(jnp.ones((x_dim, x_dim), dtype=jnp.bool_))
        base_mask = jnp.block(
            [
                [thetas_mask, jnp.zeros((theta_dim, x_dim), dtype=jnp.bool_)],
                [jnp.ones((x_dim, theta_dim), dtype=jnp.bool_), x_mask],
            ]
        )

        def base_mask_fn(node_ids, node_meta_data):
            return base_mask[node_ids, :][:, node_ids]

        return base_mask_fn
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class SLCP(Task):
    """SLCP benchmark task (``slcp`` configuration on the Hub)."""

    def __init__(self, kind="joint"):
        super().__init__("slcp", kind=kind)

    def get_base_mask_fn(self):
        """Return ``base_mask_fn(node_ids, node_meta_data)`` for this task.

        The data nodes are split into 4 equal chunks of ``dim_cond // 4``.
        The base mask (parameters first, then data) has these blocks:

        * parameter/parameter block: identity;
        * parameter/data block: all ``False``;
        * data/parameter block: all ``True``;
        * data/data block: block-diagonal, with a lower-triangular block
          (diagonal included) per chunk.
        """
        n_theta, n_x = self.dim_obs, self.dim_cond
        chunk = n_x // 4
        lower_tri = jnp.tril(jnp.ones((chunk, chunk), dtype=jnp.bool_))
        data_block = jax.scipy.linalg.block_diag(*([lower_tri] * 4))

        top_row = [jnp.eye(n_theta, dtype=jnp.bool_), jnp.zeros((n_theta, n_x))]
        bottom_row = [jnp.ones((n_x, n_theta)), data_block]
        mask = jnp.block([top_row, bottom_row]).astype(jnp.bool_)

        def base_mask_fn(node_ids, node_meta_data):
            return mask[node_ids, :][:, node_ids]

        return base_mask_fn
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def get_task(task_name, kind="joint"):
|
|
326
|
+
"""
|
|
327
|
+
Returns a Task object based on the task name.
|
|
328
|
+
"""
|
|
329
|
+
task_name = task_name.lower()
|
|
330
|
+
if task_name == "two_moons":
|
|
331
|
+
return TwoMoons(kind=kind)
|
|
332
|
+
elif task_name == "bernoulli_glm":
|
|
333
|
+
return BernoulliGLM(kind=kind)
|
|
334
|
+
elif task_name == "gaussian_linear":
|
|
335
|
+
return GaussianLinear(kind=kind)
|
|
336
|
+
elif task_name == "gaussian_linear_uniform":
|
|
337
|
+
return GaussianLinearUniform(kind=kind)
|
|
338
|
+
elif task_name == "gaussian_mixture":
|
|
339
|
+
return GaussianMixture(kind=kind)
|
|
340
|
+
elif task_name == "slcp":
|
|
341
|
+
return SLCP(kind=kind)
|
|
342
|
+
else:
|
|
343
|
+
raise ValueError(f"Unknown task: {task_name}")
|
gensbi_examples/utils.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Utility function to get the checkpoint directory for a specific example
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
# def get_example_checkpoint_dir(example_name):
|
|
5
|
+
# """
|
|
6
|
+
# Returns the absolute path to the checkpoint directory for a given example.
|
|
7
|
+
# Example:
|
|
8
|
+
# get_example_checkpoint_dir("my_first_model")
|
|
9
|
+
# # returns /path/to/GenSBI-examples/examples/getting_started/checkpoints
|
|
10
|
+
# """
|
|
11
|
+
# base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
12
|
+
# if example_name == "my_first_model":
|
|
13
|
+
# return os.path.join(base_dir, "examples", "getting_started", "checkpoints")
|
|
14
|
+
# # Add more mappings as needed
|
|
15
|
+
# raise ValueError(f"Unknown example name: {example_name}")
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# Utility function to get the checkpoint directory for a specific example
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
def get_example_checkpoint_dir(example_name):
    """Return the absolute checkpoint directory for a known example.

    Paths are resolved against the directory three levels above this file
    (``os.path.dirname`` applied three times to its absolute path).

    Example:
        get_example_checkpoint_dir("my_first_model")
        # returns /path/to/GenSBI-examples/examples/getting_started/checkpoints

    Raises:
        ValueError: If ``example_name`` has no known mapping.
    """
    here = os.path.abspath(__file__)
    repo_root = os.path.dirname(os.path.dirname(os.path.dirname(here)))
    # Add more mappings as needed
    if example_name != "my_first_model":
        raise ValueError(f"Unknown example name: {example_name}")
    return os.path.join(repo_root, "examples", "getting_started", "checkpoints")
|
|
16
|
+
import os
|
|
17
|
+
import os
|
|
18
|
+
from IPython import get_ipython
|
|
19
|
+
|
|
20
|
+
def download_artifacts(task=None, dir=None):
    """Download task data archives from the GenSBI-examples GitHub release.

    Files are stored in ``<dir>/task_data`` (``dir`` defaults to the current
    working directory). Files that already exist are skipped.

    Args:
        task: Task name; downloads only ``data_{task}.npz``. When ``None``,
            all known task archives are downloaded.
        dir: Parent directory for ``task_data``. The name shadows the builtin
            but is kept for backward compatibility with keyword callers.

    Raises:
        OSError: (including ``urllib.error.URLError``/``HTTPError``) if a
            download fails. No partial file is left behind, so a later call
            retries instead of skipping a broken file.
    """
    import urllib.request

    root = "https://github.com/aurelio-amerio/GenSBI-examples/releases/download"
    tag = "data-v0.1"
    if task is not None:
        names = [f"data_{task}.npz"]
    else:
        names = [
            "data_two_moons.npz",
            "data_bernoulli_glm.npz",
            "data_gaussian_linear.npz",
            "data_gaussian_linear_uniform.npz",
            "data_gaussian_mixture.npz",
            "data_slcp.npz",
        ]

    dir = os.path.join(os.getcwd() if dir is None else dir, "task_data")
    os.makedirs(dir, exist_ok=True)

    for name in names:
        # URLs always use "/", unlike os.path.join on Windows.
        url = f"{root}/{tag}/{name}"
        local_fname = os.path.join(dir, name)
        if os.path.exists(local_fname):
            print(f"{local_fname} already exists, skipping download.")
            continue

        print(f"Downloading {url} to {local_fname}")
        # Download in-process (no shell, no wget dependency) to a temporary
        # name, then move it into place only once the transfer succeeded.
        tmp_fname = local_fname + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_fname)
            os.replace(tmp_fname, local_fname)
        finally:
            if os.path.exists(tmp_fname):
                os.remove(tmp_fname)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: gensbi-examples
|
|
3
|
+
Version: 0.0.2
|
|
4
|
+
Summary: Examples for the GenSBI library
|
|
5
|
+
Project-URL: Homepage, https://github.com/aurelio-amerio/GenSBI-examples
|
|
6
|
+
Project-URL: Issues, https://github.com/aurelio-amerio/GenSBI-examples/issues
|
|
7
|
+
Author-email: Aurelio Amerio <aure.amerio@gmail.com>
|
|
8
|
+
License: Copyright 2025 Amerio Aurelio
|
|
9
|
+
|
|
10
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
11
|
+
you may not use this file except in compliance with the License.
|
|
12
|
+
You may obtain a copy of the License at
|
|
13
|
+
|
|
14
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
15
|
+
|
|
16
|
+
Unless required by applicable law or agreed to in writing, software
|
|
17
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
18
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
19
|
+
See the License for the specific language governing permissions and
|
|
20
|
+
limitations under the License.
|
|
21
|
+
License-File: LICENSE
|
|
22
|
+
Requires-Python: >=3.11
|
|
23
|
+
Requires-Dist: datasets
|
|
24
|
+
Requires-Dist: flax>=0.12.0
|
|
25
|
+
Requires-Dist: grain>=0.2.12
|
|
26
|
+
Requires-Dist: huggingface-hub
|
|
27
|
+
Requires-Dist: jax<=0.8.1,>=0.7.2
|
|
28
|
+
Requires-Dist: matplotlib>=3.10
|
|
29
|
+
Requires-Dist: numpy>=2.0
|
|
30
|
+
Requires-Dist: scikit-learn>=1.7.0
|
|
31
|
+
Description-Content-Type: text/markdown
|
|
32
|
+
|
|
33
|
+
# GenSBI Examples
|
|
34
|
+
|
|
35
|
+
This repository contains a collection of examples, tutorials, and recipes for **GenSBI**, a JAX-based library for Simulation-Based Inference using generative models.
|
|
36
|
+
|
|
37
|
+
These examples demonstrate how to use GenSBI for various tasks, including:
|
|
38
|
+
|
|
39
|
+
- Defining and running inference pipelines.
|
|
40
|
+
- Using different embedding networks (MLP, ResNet, etc.).
|
|
41
|
+
- Handling various data types (1D signals, 2D images).
|
|
42
|
+
|
|
43
|
+
## Installation
|
|
44
|
+
|
|
45
|
+
### Prerequisites
|
|
46
|
+
|
|
47
|
+
You need to have **GenSBI** installed.
|
|
48
|
+
|
|
49
|
+
**With CUDA 12 support (Recommended):**
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
pip install gensbi[cuda12]
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
**CPU-only:**
|
|
56
|
+
|
|
57
|
+
```bash
|
|
58
|
+
pip install gensbi
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
### Install Examples Package
|
|
62
|
+
|
|
63
|
+
To run the examples and ensure all dependencies are met, install this package:
|
|
64
|
+
|
|
65
|
+
```bash
|
|
66
|
+
pip install gensbi-examples
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
## Structure
|
|
70
|
+
|
|
71
|
+
- `examples/`: Contains standalone example scripts and notebooks.
|
|
72
|
+
- `src/gensbi_examples`: Helper utilities for the examples.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
gensbi_examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
gensbi_examples/c2st.py,sha256=PU5fGq6QAItIyc_nscINk2TUrqvXYk-9PCDFVYbambg,2815
|
|
3
|
+
gensbi_examples/c2st_v2.py.bk,sha256=xqkYZkUvmOnylBn0BmAY8n5kzc29OYu73U8EilBFn48,4455
|
|
4
|
+
gensbi_examples/graph.py,sha256=OdJ8fiSP5_pgmzmHr2E9ZWYZzZm3ToJWLho2gRH0w7I,7570
|
|
5
|
+
gensbi_examples/mask.py,sha256=OO7QH_r7SRjCiGLPn0vSS-mtFD2FE4JZFWhbed0MkcI,2997
|
|
6
|
+
gensbi_examples/sbi_tasks.py.bk,sha256=AhJBXxBygXWSxmTzLzYGKrFI9hS9nuCHeLPZ5QgMWNk,14678
|
|
7
|
+
gensbi_examples/tasks.py,sha256=-G20P3tQ4xwmjhJah1f0aHwdJXtNJUSgBlz2sGlcmaQ,10893
|
|
8
|
+
gensbi_examples/utils.py,sha256=mF-wutjqqCX1EEmhHAtAlWvCSBcCiddoB9TBJjgz-xM,709
|
|
9
|
+
gensbi_examples/utils.py.bk,sha256=_Q2hgbNlwo2uNFZ2tHIs_uye8IbLmsJYy6cYrErthT8,1863
|
|
10
|
+
gensbi_examples-0.0.2.dist-info/METADATA,sha256=FsH7l_zy7WyFY5UCRm_aealQ6bvPW34yVpCIByu4Ctc,2233
|
|
11
|
+
gensbi_examples-0.0.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
12
|
+
gensbi_examples-0.0.2.dist-info/licenses/LICENSE,sha256=D8Mi2-fbemv3oPZgMB-COT0aw2DXajsiebPYWtOMSpg,582
|
|
13
|
+
gensbi_examples-0.0.2.dist-info/RECORD,,
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
Copyright 2025 Amerio Aurelio
|
|
2
|
+
|
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
you may not use this file except in compliance with the License.
|
|
5
|
+
You may obtain a copy of the License at
|
|
6
|
+
|
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
See the License for the specific language governing permissions and
|
|
13
|
+
limitations under the License.
|