gensbi-examples 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gensbi_examples/__init__.py +0 -0
- gensbi_examples/c2st.py +111 -0
- gensbi_examples/c2st_v2.py.bk +147 -0
- gensbi_examples/graph.py +211 -0
- gensbi_examples/mask.py +80 -0
- gensbi_examples/sbi_tasks.py.bk +417 -0
- gensbi_examples/tasks.py +343 -0
- gensbi_examples/utils.py +15 -0
- gensbi_examples/utils.py.bk +56 -0
- gensbi_examples-0.0.2.dist-info/METADATA +72 -0
- gensbi_examples-0.0.2.dist-info/RECORD +13 -0
- gensbi_examples-0.0.2.dist-info/WHEEL +4 -0
- gensbi_examples-0.0.2.dist-info/licenses/LICENSE +13 -0
|
@@ -0,0 +1,417 @@
|
|
|
1
|
+
|
|
2
|
+
import jax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
|
|
5
|
+
from functools import partial
|
|
6
|
+
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
|
|
9
|
+
from .graph import faithfull_mask, min_faithfull_mask, moralize
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
from sbibm import get_task as _get_torch_task
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
import jax
|
|
16
|
+
import jax.numpy as jnp
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Task(ABC):
    """Minimal interface shared by every task in this module.

    A task has a ``name`` and a ``backend`` string. All ``get_*`` hooks below
    raise ``NotImplementedError`` here and are meant to be overridden.
    """

    def __init__(self, name: str, backend: str = "torch") -> None:
        self.name, self.backend = name, backend

    # Read-only conveniences that always defer to the overridable hooks.
    theta_dim = property(
        lambda self: self.get_theta_dim(),
        doc="Number of parameters, as reported by ``get_theta_dim``.",
    )
    x_dim = property(
        lambda self: self.get_x_dim(),
        doc="Dimensionality of the data, as reported by ``get_x_dim``.",
    )

    def get_theta_dim(self):
        """Return the number of parameters (override in subclasses)."""
        raise NotImplementedError

    def get_x_dim(self):
        """Return the data dimensionality (override in subclasses)."""
        raise NotImplementedError

    def get_data(self, num_samples: int, key=None):
        """Return ``num_samples`` samples; ``key`` is presumably an RNG key (unused here)."""
        raise NotImplementedError

    def get_node_id(self):
        """Return the node identifiers of the task graph (override in subclasses)."""
        raise NotImplementedError

    def get_base_mask_fn(self):
        """Return the base-mask factory for the task graph (override in subclasses)."""
        raise NotImplementedError
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class InferenceTask(Task):
    """Task interface for simulation-based inference benchmarks.

    Adds prior / simulator / observation accessors (all unimplemented here)
    and ``get_edge_mask_fn``, which wraps the subclass's ``get_base_mask_fn``
    into an edge-mask function for a chosen masking strategy.
    """

    # Observation indices 1..10 (inclusive); the accessors do not validate against it.
    observations = range(1, 11)

    def __init__(self, name: str, backend: str = "jax") -> None:
        super().__init__(name, backend)

    def get_prior(self):
        """Return the prior (must be overridden)."""
        raise NotImplementedError()

    def get_simulator(self):
        """Return the simulator (must be overridden)."""
        raise NotImplementedError()

    def get_data(self, num_samples: int, key=None):
        """Return ``num_samples`` joint samples (must be overridden)."""
        raise NotImplementedError()

    def get_observation(self, index: int):
        """Return observation number ``index`` (must be overridden)."""
        raise NotImplementedError()

    def get_reference_posterior_samples(self, index: int):
        """Return reference posterior samples for observation ``index`` (must be overridden)."""
        raise NotImplementedError()

    def get_true_parameters(self, index: int):
        """Return the parameters behind observation ``index`` (must be overridden)."""
        raise NotImplementedError()

    def get_edge_mask_fn(self, name="undirected"):
        """Return ``edge_mask(node_id, condition_mask, meta_data=None)`` for a strategy.

        With ``base_mask = self.get_base_mask_fn()(node_id, meta_data)``, the
        returned function computes, depending on ``name`` (case-insensitive):

        * ``"faithfull"``     -> ``faithfull_mask(base_mask, condition_mask)``
        * ``"min_faithfull"`` -> ``min_faithfull_mask(base_mask, condition_mask)``
        * ``"undirected"``    -> ``moralize(base_mask)`` (condition_mask ignored)
        * ``"directed"``      -> ``base_mask`` unchanged (condition_mask ignored)
        * ``"none"``          -> always ``None``; ``get_base_mask_fn`` is not called.

        Raises:
            NotImplementedError: if ``name`` is not one of the strategies above.
        """
        kind = name.lower()
        if kind == "none":
            return lambda node_id, condition_mask, *args, **kwargs: None
        if kind not in ("faithfull", "min_faithfull", "undirected", "directed"):
            raise NotImplementedError(f"Unknown edge mask type: {name!r}")

        # Resolved once, outside the returned closures.
        base_mask_fn = self.get_base_mask_fn()

        if kind == "faithfull":
            def faithfull_edge_mask(node_id, condition_mask, meta_data=None):
                return faithfull_mask(base_mask_fn(node_id, meta_data), condition_mask)

            return faithfull_edge_mask

        if kind == "min_faithfull":
            def min_faithfull_edge_mask(node_id, condition_mask, meta_data=None):
                return min_faithfull_mask(base_mask_fn(node_id, meta_data), condition_mask)

            return min_faithfull_edge_mask

        if kind == "undirected":
            def undirected_edge_mask(node_id, condition_mask, meta_data=None):
                return moralize(base_mask_fn(node_id, meta_data))

            return undirected_edge_mask

        def directed_edge_mask(node_id, condition_mask, meta_data=None):
            return base_mask_fn(node_id, meta_data)

        return directed_edge_mask
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class SBIBMTask(InferenceTask):
    """Inference task backed by an ``sbibm`` benchmark task.

    The wrapped task is created with ``sbibm.get_task`` and returns torch
    tensors. Results are converted to ``self.backend`` (``"torch"``,
    ``"numpy"`` or ``"jax"``) before being handed back.
    """

    # Observation indices 1..10 (inclusive); the accessors do not validate against it.
    observations = range(1, 11)

    def __init__(self, name: str, backend: str = "jax") -> None:
        super().__init__(name, backend)
        # Underlying sbibm task; every sample / observation ultimately comes from it.
        self.task = _get_torch_task(self.name)

    def get_theta_dim(self):
        """Return the number of parameters of the wrapped task."""
        return self.task.dim_parameters

    def get_x_dim(self):
        """Return the dimensionality of the simulator output."""
        # NOTE(review): stock sbibm tasks expose ``dim_data``; ``dim_cond`` is
        # tried first so tasks that define it keep their previous behavior.
        task = self.task
        if hasattr(task, "dim_cond"):
            return task.dim_cond
        return task.dim_data

    def _to_backend(self, value):
        """Convert a torch tensor from the sbibm task to ``self.backend``.

        Raises:
            ValueError: if ``self.backend`` is not "torch", "numpy" or "jax".
        """
        if self.backend == "torch":
            return value
        if self.backend == "numpy":
            return value.numpy()
        if self.backend == "jax":
            return jnp.array(value)
        raise ValueError(
            f"Unsupported backend {self.backend!r}; expected 'torch', 'numpy' or 'jax'."
        )

    def get_prior(self):
        """Return the sbibm prior distribution (torch backend only)."""
        if self.backend == "torch":
            return self.task.get_prior_dist()
        raise NotImplementedError()

    def get_simulator(self):
        """Return the sbibm simulator (torch backend only)."""
        if self.backend == "torch":
            return self.task.get_simulator()
        raise NotImplementedError()

    def get_node_id(self):
        """Return ``arange(theta_dim + x_dim)``.

        The result is a torch tensor for the torch backend and a JAX array for
        every other backend (including "numpy").
        """
        dim = self.get_theta_dim() + self.get_x_dim()
        if self.backend == "torch":
            return torch.arange(dim)
        return jnp.arange(dim)

    def get_data(self, num_samples: int, key=None, **kwargs):
        """Sample ``num_samples`` parameters from the prior and simulate data.

        If the current backend has no native prior/simulator
        (``NotImplementedError``), sampling is done with the torch
        implementation and the results are converted to ``self.backend``.
        ``key`` (and any extra keyword arguments) are accepted for signature
        compatibility and are not used.

        Returns:
            dict with keys ``"theta"`` and ``"x"``.
        """
        try:
            prior = self.get_prior()
            simulator = self.get_simulator()
        except NotImplementedError:
            old_backend = self.backend
            self.backend = "torch"
            try:
                prior = self.get_prior()
                simulator = self.get_simulator()
            finally:
                # Always restore, so a failure cannot leave the task stuck on "torch".
                self.backend = old_backend
            thetas = prior.sample((num_samples,))
            xs = simulator(thetas)
            return {"theta": self._to_backend(thetas), "x": self._to_backend(xs)}

        thetas = prior.sample((num_samples,))
        xs = simulator(thetas)
        return {"theta": thetas, "x": xs}

    def get_observation(self, index: int):
        """Return observation ``index`` of the wrapped task in ``self.backend`` format."""
        return self._to_backend(self.task.get_observation(index))

    def get_reference_posterior_samples(self, index: int):
        """Return the reference posterior samples for observation ``index``."""
        return self._to_backend(self.task.get_reference_posterior_samples(index))

    def get_true_parameters(self, index: int):
        """Return the true parameters behind observation ``index``."""
        return self._to_backend(self.task.get_true_parameters(index))
|
|
200
|
+
|
|
201
|
+
class LinearGaussian(SBIBMTask):
    """sbibm ``gaussian_linear`` task."""

    def __init__(self, backend: str = "torch") -> None:
        super().__init__(name="gaussian_linear", backend=backend)

    def get_base_mask_fn(self):
        """Return ``base_mask_fn(node_ids, node_meta_data)`` for this task's graph.

        Nodes are ordered ``[theta_1..theta_d, x_1..x_m]``. Every node has
        ``True`` on the diagonal. Row ``x_i`` is additionally ``True`` in
        column ``theta_i``. Theta rows are never ``True`` in x columns.
        """
        theta_dim = self.get_theta_dim()
        x_dim = self.get_x_dim()

        thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
        x_i_mask = jnp.eye(x_dim, dtype=jnp.bool_)
        # Rectangular identity: shape (x_dim, theta_dim), valid even when the dims differ.
        x_from_theta = jnp.eye(x_dim, theta_dim, dtype=jnp.bool_)
        theta_from_x = jnp.zeros((theta_dim, x_dim), dtype=jnp.bool_)
        base_mask = jnp.block(
            [[thetas_mask, theta_from_x], [x_from_theta, x_i_mask]]
        ).astype(jnp.bool_)

        def base_mask_fn(node_ids, node_meta_data):
            # Reorder/select rows and columns by node_ids; node_meta_data is unused.
            return base_mask[node_ids, :][:, node_ids]

        return base_mask_fn
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
class BernoulliGLM(SBIBMTask):
    """sbibm ``bernoulli_glm`` task; no graph structure is defined for it yet."""

    def __init__(self, backend: str = "torch") -> None:
        super().__init__("bernoulli_glm", backend)

    def get_base_mask_fn(self):
        """Unavailable: no base mask has been specified for this task."""
        raise NotImplementedError
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
class BernoulliGLMRaw(SBIBMTask):
    """sbibm ``bernoulli_glm_raw`` task; no graph structure is defined for it yet."""

    def __init__(self, backend: str = "torch") -> None:
        super().__init__("bernoulli_glm_raw", backend)

    def get_base_mask_fn(self):
        """Unavailable: no base mask has been specified for this task."""
        raise NotImplementedError
|
|
234
|
+
|
|
235
|
+
class MixtureGaussian(SBIBMTask):
    """sbibm ``gaussian_mixture`` task."""

    def __init__(self, backend: str = "torch") -> None:
        super().__init__(name="gaussian_mixture", backend=backend)

    def get_base_mask_fn(self):
        """Return ``base_mask_fn(node_ids, node_meta_data)`` for this task's graph.

        Nodes are ordered ``[theta_1..theta_d, x_1..x_m]``:

        * theta rows: identity only;
        * x rows: ``True`` in every theta column;
        * x block: lower-triangular, so row ``x_j`` is ``True`` for ``x_k``
          with ``k <= j``.
        """
        theta_dim = self.get_theta_dim()
        x_dim = self.get_x_dim()

        thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
        # The x block must be square (x_dim, x_dim) to sit under the x columns.
        x_mask = jnp.tril(jnp.ones((x_dim, x_dim), dtype=jnp.bool_))
        theta_from_x = jnp.zeros((theta_dim, x_dim), dtype=jnp.bool_)
        x_from_theta = jnp.ones((x_dim, theta_dim), dtype=jnp.bool_)
        base_mask = jnp.block(
            [[thetas_mask, theta_from_x], [x_from_theta, x_mask]]
        ).astype(jnp.bool_)

        def base_mask_fn(node_ids, node_meta_data):
            # Reorder/select rows and columns by node_ids; node_meta_data is unused.
            return base_mask[node_ids, :][:, node_ids]

        return base_mask_fn
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class TwoMoons(SBIBMTask):
    """sbibm ``two_moons`` task."""

    def __init__(self, backend: str = "torch") -> None:
        super().__init__(name="two_moons", backend=backend)

    def get_base_mask_fn(self):
        """Return ``base_mask_fn(node_ids, node_meta_data)`` for this task's graph.

        Nodes are ordered ``[theta_1..theta_d, x_1..x_m]``:

        * theta rows: identity only;
        * x rows: ``True`` in every theta column;
        * x block: lower-triangular, so row ``x_j`` is ``True`` for ``x_k``
          with ``k <= j``.
        """
        theta_dim = self.get_theta_dim()
        x_dim = self.get_x_dim()

        thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
        # The x block must be square (x_dim, x_dim) to sit under the x columns.
        x_mask = jnp.tril(jnp.ones((x_dim, x_dim), dtype=jnp.bool_))
        theta_from_x = jnp.zeros((theta_dim, x_dim), dtype=jnp.bool_)
        x_from_theta = jnp.ones((x_dim, theta_dim), dtype=jnp.bool_)
        base_mask = jnp.block(
            [[thetas_mask, theta_from_x], [x_from_theta, x_mask]]
        ).astype(jnp.bool_)

        def base_mask_fn(node_ids, node_meta_data):
            # Reorder/select rows and columns by node_ids; node_meta_data is unused.
            return base_mask[node_ids, :][:, node_ids]

        return base_mask_fn
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class SLCP(SBIBMTask):
    """sbibm ``slcp`` task."""

    def __init__(self, backend: str = "torch") -> None:
        super().__init__(name="slcp", backend=backend)

    def get_base_mask_fn(self):
        """Return ``base_mask_fn(node_ids, node_meta_data)`` for this task's graph.

        Nodes are ordered ``[theta_1..theta_d, x_1..x_m]``:

        * theta rows: identity only;
        * x rows: ``True`` in every theta column;
        * x block: 4 equal diagonal groups, each lower-triangular within itself
          and with no links between groups.

        Raises:
            ValueError: if the data dimensionality is not divisible by 4.
        """
        theta_dim = self.get_theta_dim()
        x_dim = self.get_x_dim()
        num_groups = 4
        if x_dim % num_groups:
            raise ValueError(
                f"SLCP mask expects x_dim divisible by {num_groups}, got {x_dim}."
            )

        thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
        # TODO: this could be made triangular -> DAG
        x_i_dim = x_dim // num_groups
        group_mask = jnp.tril(jnp.ones((x_i_dim, x_i_dim), dtype=jnp.bool_))
        x_i_mask = jax.scipy.linalg.block_diag(*[group_mask] * num_groups)
        theta_from_x = jnp.zeros((theta_dim, x_dim), dtype=jnp.bool_)
        x_from_theta = jnp.ones((x_dim, theta_dim), dtype=jnp.bool_)
        base_mask = jnp.block(
            [[thetas_mask, theta_from_x], [x_from_theta, x_i_mask]]
        ).astype(jnp.bool_)

        def base_mask_fn(node_ids, node_meta_data):
            # If node_ids are permuted, the mask is permuted accordingly;
            # node_meta_data is unused.
            return base_mask[node_ids, :][:, node_ids]

        return base_mask_fn
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
# class AllConditionalBMTask(ABC):
|
|
295
|
+
|
|
296
|
+
# def __init__(self, task_name: str) -> None:
|
|
297
|
+
# self.task_name = task_name
|
|
298
|
+
# self.task = sbibm.get_task(task_name)
|
|
299
|
+
# self.base_mask_fn = self.get_base_mask_fn()
|
|
300
|
+
# self.simulator = self.task.get_simulator()
|
|
301
|
+
# self.prior = self.task.get_prior()
|
|
302
|
+
|
|
303
|
+
# @abstractmethod
|
|
304
|
+
# def get_base_mask_fn(self):
|
|
305
|
+
# """
|
|
306
|
+
# Returns a function that takes in node_ids and node_meta_data and returns the base mask
|
|
307
|
+
# for the given node_ids.
|
|
308
|
+
# """
|
|
309
|
+
# pass
|
|
310
|
+
|
|
311
|
+
# class TwoMoonsAllConditionalTask(AllConditionalBMTask):
|
|
312
|
+
# def __init__(self) -> None:
|
|
313
|
+
# super().__init__(task_name="two_moons")
|
|
314
|
+
# return
|
|
315
|
+
|
|
316
|
+
# def get_base_mask_fn(self):
|
|
317
|
+
# thetas_mask = jnp.eye(2, dtype=jnp.bool_)
|
|
318
|
+
# x_mask = jnp.tril(jnp.ones((2, 2), dtype=jnp.bool_))
|
|
319
|
+
# base_mask = jnp.block(
|
|
320
|
+
# [[thetas_mask, jnp.zeros((2, 2))], [jnp.ones((2, 2)), x_mask]]
|
|
321
|
+
# )
|
|
322
|
+
# base_mask = base_mask.astype(jnp.bool_)
|
|
323
|
+
|
|
324
|
+
# def base_mask_fn(node_ids, node_meta_data):
|
|
325
|
+
# return base_mask[node_ids, :][:, node_ids]
|
|
326
|
+
|
|
327
|
+
# return base_mask_fn
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
# class SLCPAllConditionalTask(AllConditionalBMTask):
|
|
331
|
+
# def __init__(self) -> None:
|
|
332
|
+
# super().__init__(task_name="slcp")
|
|
333
|
+
|
|
334
|
+
# # def ravel_condition_mask(condition_mask):
|
|
335
|
+
# # thetas_cond, x1_cond, x2_cond, x3_cond, x4_cond = jnp.split(condition_mask, [5,7,9,11], axis=-1)
|
|
336
|
+
# # x1_cond = jnp.any(x1_cond, axis=-1)[None]
|
|
337
|
+
# # x2_cond = jnp.any(x2_cond, axis=-1)[None]
|
|
338
|
+
# # x3_cond = jnp.any(x3_cond, axis=-1)[None]
|
|
339
|
+
# # x4_cond = jnp.any(x4_cond, axis=-1)[None]
|
|
340
|
+
# # return jnp.hstack([thetas_cond, x1_cond, x2_cond, x3_cond, x4_cond])
|
|
341
|
+
# # def unravel_condition_mask(condition_mask):
|
|
342
|
+
# # thetas_cond, x1_cond, x2_cond, x3_cond, x4_cond = jnp.split(condition_mask, [5,6,7,8], axis=-1)
|
|
343
|
+
# # x1_cond = jnp.repeat(x1_cond, 2, axis=-1)
|
|
344
|
+
# # x2_cond = jnp.repeat(x2_cond, 2, axis=-1)
|
|
345
|
+
# # x3_cond = jnp.repeat(x3_cond, 2, axis=-1)
|
|
346
|
+
# # x4_cond = jnp.repeat(x4_cond, 2, axis=-1)
|
|
347
|
+
# # return jnp.hstack([thetas_cond, x1_cond, x2_cond, x3_cond, x4_cond])
|
|
348
|
+
|
|
349
|
+
# # self.ravel_condition_mask = ravel_condition_mask
|
|
350
|
+
# # self.unravel_condition_mask = unravel_condition_mask
|
|
351
|
+
# return
|
|
352
|
+
|
|
353
|
+
# def get_base_mask_fn(self):
|
|
354
|
+
# theta_dim = 5
|
|
355
|
+
# x_dim = 8
|
|
356
|
+
# thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
|
|
357
|
+
# # TODO This could be triangular -> DAG
|
|
358
|
+
# x_i_dim = x_dim // 4
|
|
359
|
+
# x_i_mask = jax.scipy.linalg.block_diag(
|
|
360
|
+
# *tuple([jnp.tril(jnp.ones((x_i_dim, x_i_dim), dtype=jnp.bool_))] * 4)
|
|
361
|
+
# )
|
|
362
|
+
# base_mask = jnp.block(
|
|
363
|
+
# [
|
|
364
|
+
# [thetas_mask, jnp.zeros((theta_dim, x_dim))],
|
|
365
|
+
# [jnp.ones((x_dim, theta_dim)), x_i_mask],
|
|
366
|
+
# ]
|
|
367
|
+
# )
|
|
368
|
+
# base_mask = base_mask.astype(jnp.bool_)
|
|
369
|
+
|
|
370
|
+
# def base_mask_fn(node_ids, node_meta_data):
|
|
371
|
+
# # If node_ids are permuted, we need to permute the base_mask
|
|
372
|
+
# return base_mask[node_ids, :][:, node_ids]
|
|
373
|
+
|
|
374
|
+
# return base_mask_fn
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
# class NonlinearGaussianTreeAllConditionalTask(AllConditionalBMTask):
|
|
378
|
+
# def __init__(self) -> None:
|
|
379
|
+
# super().__init__(task_name="two_moons")
|
|
380
|
+
# return
|
|
381
|
+
|
|
382
|
+
# def get_base_mask_fn(self):
|
|
383
|
+
# base_mask = jnp.array(
|
|
384
|
+
# [
|
|
385
|
+
# [True, False, False, False, False, False, False],
|
|
386
|
+
# [True, True, False, False, False, False, False],
|
|
387
|
+
# [True, False, True, False, False, False, False],
|
|
388
|
+
# [False, True, False, True, False, False, False],
|
|
389
|
+
# [False, True, False, False, True, False, False],
|
|
390
|
+
# [False, False, True, False, False, True, False],
|
|
391
|
+
# [False, False, True, False, False, False, True],
|
|
392
|
+
# ]
|
|
393
|
+
# )
|
|
394
|
+
|
|
395
|
+
# def base_mask_fn(node_ids, node_meta_data):
|
|
396
|
+
# return base_mask[node_ids, :][:, node_ids]
|
|
397
|
+
|
|
398
|
+
# return base_mask_fn
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
# class NonlinearMarcovChainAllConditionalTask(AllConditionalBMTask):
|
|
402
|
+
# def __init__(self) -> None:
|
|
403
|
+
# super().__init__(task_name="two_moons")
|
|
404
|
+
# return
|
|
405
|
+
|
|
406
|
+
# def get_base_mask_fn(self):
|
|
407
|
+
# # Markovian structure
|
|
408
|
+
# theta_mask = jnp.eye(10, dtype=jnp.bool_) | jnp.eye(10, k=-1, dtype=jnp.bool_)
|
|
409
|
+
# xs_mask = jnp.eye(10, dtype=jnp.bool_)
|
|
410
|
+
# theta_xs_mask = jnp.eye(10, dtype=jnp.bool_)
|
|
411
|
+
# fill_mask = jnp.zeros((10, 10), dtype=jnp.bool_)
|
|
412
|
+
# base_mask = jnp.block([[theta_mask, fill_mask], [theta_xs_mask, xs_mask]])
|
|
413
|
+
|
|
414
|
+
# def base_mask_fn(node_ids, node_meta_data):
|
|
415
|
+
# return base_mask[node_ids, :][:, node_ids]
|
|
416
|
+
|
|
417
|
+
# return base_mask_fn
|