gensbi_examples-0.0.2-py3-none-any.whl

@@ -0,0 +1,417 @@
+ import jax
+ import jax.numpy as jnp
+ import torch
+
+ from abc import ABC, abstractmethod
+
+ from sbibm import get_task as _get_torch_task
+
+ from .graph import faithfull_mask, min_faithfull_mask, moralize
+
+
+ class Task(ABC):
+     """Abstract base class for a benchmark task with parameters theta and observations x."""
+
+     def __init__(self, name: str, backend: str = "torch") -> None:
+         self.name = name
+         self.backend = backend
+
+     @property
+     def theta_dim(self):
+         return self.get_theta_dim()
+
+     @property
+     def x_dim(self):
+         return self.get_x_dim()
+
+     def get_theta_dim(self):
+         raise NotImplementedError()
+
+     def get_x_dim(self):
+         raise NotImplementedError()
+
+     def get_data(self, num_samples: int, key=None):
+         raise NotImplementedError()
+
+     def get_node_id(self):
+         raise NotImplementedError()
+
+     def get_base_mask_fn(self):
+         raise NotImplementedError()
+
+
+ class InferenceTask(Task):
+
+     observations = range(1, 11)
+
+     def __init__(self, name: str, backend: str = "jax") -> None:
+         super().__init__(name, backend)
+
+     def get_prior(self):
+         raise NotImplementedError()
+
+     def get_simulator(self):
+         raise NotImplementedError()
+
+     def get_data(self, num_samples: int, key=None):
+         raise NotImplementedError()
+
+     def get_observation(self, index: int):
+         raise NotImplementedError()
+
+     def get_reference_posterior_samples(self, index: int):
+         raise NotImplementedError()
+
+     def get_true_parameters(self, index: int):
+         raise NotImplementedError()
+
+     def get_edge_mask_fn(self, name="undirected"):
+         # Wraps the task's base adjacency mask in one of the edge-mask
+         # transformations from .graph.
+         if name.lower() == "faithfull":
+             base_mask_fn = self.get_base_mask_fn()
+
+             def faithfull_edge_mask(node_id, condition_mask, meta_data=None):
+                 base_mask = base_mask_fn(node_id, meta_data)
+                 return faithfull_mask(base_mask, condition_mask)
+
+             return faithfull_edge_mask
+         elif name.lower() == "min_faithfull":
+             base_mask_fn = self.get_base_mask_fn()
+
+             def min_faithfull_edge_mask(node_id, condition_mask, meta_data=None):
+                 base_mask = base_mask_fn(node_id, meta_data)
+                 return min_faithfull_mask(base_mask, condition_mask)
+
+             return min_faithfull_edge_mask
+         elif name.lower() == "undirected":
+             base_mask_fn = self.get_base_mask_fn()
+
+             def undirected_edge_mask(node_id, condition_mask, meta_data=None):
+                 base_mask = base_mask_fn(node_id, meta_data)
+                 return moralize(base_mask)
+
+             return undirected_edge_mask
+         elif name.lower() == "directed":
+             base_mask_fn = self.get_base_mask_fn()
+
+             def directed_edge_mask(node_id, condition_mask, meta_data=None):
+                 return base_mask_fn(node_id, meta_data)
+
+             return directed_edge_mask
+         elif name.lower() == "none":
+             return lambda node_id, condition_mask, *args, **kwargs: None
+         else:
+             raise NotImplementedError()
+
+
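+ # Illustrative usage sketch (not part of the module): every variant returned
+ # by get_edge_mask_fn shares the signature (node_id, condition_mask,
+ # meta_data=None). The task and variable names below are hypothetical.
+ #
+ #     task = TwoMoons(backend="jax")
+ #     edge_mask_fn = task.get_edge_mask_fn("undirected")
+ #     node_id = task.get_node_id()
+ #     condition_mask = node_id >= task.theta_dim  # condition on the x nodes
+ #     mask = edge_mask_fn(node_id, condition_mask)
+
+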
+ class SBIBMTask(InferenceTask):
+     observations = range(1, 11)
+
+     def __init__(self, name: str, backend: str = "jax") -> None:
+         super().__init__(name, backend)
+         self.task = _get_torch_task(self.name)
+
+     def get_theta_dim(self):
+         return self.task.dim_parameters
+
+     def get_x_dim(self):
+         return self.task.dim_data
+
+     def get_prior(self):
+         if self.backend == "torch":
+             return self.task.get_prior_dist()
+         else:
+             raise NotImplementedError()
+
+     def get_simulator(self):
+         if self.backend == "torch":
+             return self.task.get_simulator()
+         else:
+             raise NotImplementedError()
+
+     def get_node_id(self):
+         dim = self.get_theta_dim() + self.get_x_dim()
+         if self.backend == "torch":
+             return torch.arange(dim)
+         else:
+             return jnp.arange(dim)
+
+     def get_data(self, num_samples: int, **kwargs):
+         try:
+             prior = self.get_prior()
+             simulator = self.get_simulator()
+             thetas = prior.sample((num_samples,))
+             xs = simulator(thetas)
+             return {"theta": thetas, "x": xs}
+         except NotImplementedError:
+             # Not implemented for this backend: sample with PyTorch, then
+             # convert the samples back to the requested backend.
+             old_backend = self.backend
+             self.backend = "torch"
+             prior = self.get_prior()
+             simulator = self.get_simulator()
+             thetas = prior.sample((num_samples,))
+             xs = simulator(thetas)
+             self.backend = old_backend
+             if self.backend == "numpy":
+                 thetas = thetas.numpy()
+                 xs = xs.numpy()
+             elif self.backend == "jax":
+                 thetas = jnp.array(thetas)
+                 xs = jnp.array(xs)
+             return {"theta": thetas, "x": xs}
+
+     def get_observation(self, index: int):
+         if self.backend == "torch":
+             return self.task.get_observation(index)
+         else:
+             out = self.task.get_observation(index)
+             if self.backend == "numpy":
+                 return out.numpy()
+             elif self.backend == "jax":
+                 return jnp.array(out)
+
+     def get_reference_posterior_samples(self, index: int):
+         if self.backend == "torch":
+             return self.task.get_reference_posterior_samples(index)
+         else:
+             out = self.task.get_reference_posterior_samples(index)
+             if self.backend == "numpy":
+                 return out.numpy()
+             elif self.backend == "jax":
+                 return jnp.array(out)
+
+     def get_true_parameters(self, index: int):
+         if self.backend == "torch":
+             return self.task.get_true_parameters(index)
+         else:
+             out = self.task.get_true_parameters(index)
+             if self.backend == "numpy":
+                 return out.numpy()
+             elif self.backend == "jax":
+                 return jnp.array(out)
+
+
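+ # Illustrative sketch (hypothetical variable names): draw a joint training set
+ # and fetch the reference quantities for observation 1. With backend="jax" the
+ # prior and simulator are not implemented, so get_data falls back to the
+ # PyTorch implementations and converts the samples to jnp arrays.
+ #
+ #     task = TwoMoons(backend="jax")
+ #     data = task.get_data(1000)   # {"theta": (1000, 2), "x": (1000, 2)}
+ #     x_o = task.get_observation(1)
+ #     theta_ref = task.get_reference_posterior_samples(1)
+
+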
+ class LinearGaussian(SBIBMTask):
+     def __init__(self, backend: str = "torch") -> None:
+         super().__init__(name="gaussian_linear", backend=backend)
+
+     def get_base_mask_fn(self):
+         task = _get_torch_task(self.name)
+         theta_dim = task.dim_parameters
+         x_dim = task.dim_data
+         thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
+         x_i_mask = jnp.eye(x_dim, dtype=jnp.bool_)
+         # theta_dim == x_dim for this task, so the lower-left block is a
+         # square identity: each x_i depends only on its own theta_i.
+         base_mask = jnp.block(
+             [[thetas_mask, jnp.zeros((theta_dim, x_dim))], [jnp.eye(x_dim), x_i_mask]]
+         )
+         base_mask = base_mask.astype(jnp.bool_)
+
+         def base_mask_fn(node_ids, node_meta_data):
+             return base_mask[node_ids, :][:, node_ids]
+
+         return base_mask_fn
+
+
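+ # For intuition (illustrative, with theta_dim = x_dim = 2 instead of the
+ # task's real dimensionality), the construction above yields
+ #
+ #     [[1, 0, 0, 0],   # theta_0
+ #      [0, 1, 0, 0],   # theta_1
+ #      [1, 0, 1, 0],   # x_0 <- theta_0
+ #      [0, 1, 0, 1]]   # x_1 <- theta_1
+
+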
+ class BernoulliGLM(SBIBMTask):
+     def __init__(self, backend: str = "torch") -> None:
+         super().__init__(name="bernoulli_glm", backend=backend)
+
+     def get_base_mask_fn(self):
+         raise NotImplementedError()
+
+
+ class BernoulliGLMRaw(SBIBMTask):
+     def __init__(self, backend: str = "torch") -> None:
+         super().__init__(name="bernoulli_glm_raw", backend=backend)
+
+     def get_base_mask_fn(self):
+         raise NotImplementedError()
+
+
+ class MixtureGaussian(SBIBMTask):
+     def __init__(self, backend: str = "torch") -> None:
+         super().__init__(name="gaussian_mixture", backend=backend)
+
+     def get_base_mask_fn(self):
+         task = _get_torch_task(self.name)
+         theta_dim = task.dim_parameters
+         x_dim = task.dim_data
+         thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
+         x_mask = jnp.tril(jnp.ones((theta_dim, x_dim), dtype=jnp.bool_))
+         base_mask = jnp.block(
+             [[thetas_mask, jnp.zeros((theta_dim, x_dim))],
+              [jnp.ones((x_dim, theta_dim)), x_mask]]
+         )
+         base_mask = base_mask.astype(jnp.bool_)
+
+         def base_mask_fn(node_ids, node_meta_data):
+             return base_mask[node_ids, :][:, node_ids]
+
+         return base_mask_fn
+
+
+ class TwoMoons(SBIBMTask):
+     def __init__(self, backend: str = "torch") -> None:
+         super().__init__(name="two_moons", backend=backend)
+
+     def get_base_mask_fn(self):
+         task = self.task
+         theta_dim = task.dim_parameters
+         x_dim = task.dim_data
+         thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
+         x_mask = jnp.tril(jnp.ones((theta_dim, x_dim), dtype=jnp.bool_))
+         base_mask = jnp.block(
+             [[thetas_mask, jnp.zeros((theta_dim, x_dim))],
+              [jnp.ones((x_dim, theta_dim)), x_mask]]
+         )
+         base_mask = base_mask.astype(jnp.bool_)
+
+         def base_mask_fn(node_ids, node_meta_data):
+             return base_mask[node_ids, :][:, node_ids]
+
+         return base_mask_fn
+
+
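+ # Illustrative: for two_moons (theta_dim = x_dim = 2) the mask built above is
+ #
+ #     [[1, 0, 0, 0],   # theta_0
+ #      [0, 1, 0, 0],   # theta_1
+ #      [1, 1, 1, 0],   # x_0 <- theta_0, theta_1
+ #      [1, 1, 1, 1]]   # x_1 <- theta_0, theta_1, x_0
+
+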
+ class SLCP(SBIBMTask):
+     def __init__(self, backend: str = "torch") -> None:
+         super().__init__(name="slcp", backend=backend)
+
+     def get_base_mask_fn(self):
+         task = _get_torch_task(self.name)
+         theta_dim = task.dim_parameters
+         x_dim = task.dim_data
+         thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
+         # TODO This could be triangular -> DAG
+         # x consists of four i.i.d. 2d points; each point gets its own
+         # lower-triangular block.
+         x_i_dim = x_dim // 4
+         x_i_mask = jax.scipy.linalg.block_diag(
+             *([jnp.tril(jnp.ones((x_i_dim, x_i_dim), dtype=jnp.bool_))] * 4)
+         )
+         base_mask = jnp.block(
+             [[thetas_mask, jnp.zeros((theta_dim, x_dim))],
+              [jnp.ones((x_dim, theta_dim)), x_i_mask]]
+         )
+         base_mask = base_mask.astype(jnp.bool_)
+
+         def base_mask_fn(node_ids, node_meta_data):
+             # If node_ids are permuted, we need to permute the base_mask.
+             return base_mask[node_ids, :][:, node_ids]
+
+         return base_mask_fn
+
+
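+ # Illustrative property check (hypothetical names): indexing with permuted
+ # node_ids permutes rows and columns of the returned mask consistently.
+ #
+ #     base_mask_fn = SLCP().get_base_mask_fn()
+ #     node_ids = jnp.arange(13)  # 5 theta nodes + 8 x nodes
+ #     perm = jnp.flip(node_ids)
+ #     assert (base_mask_fn(perm, None)
+ #             == base_mask_fn(node_ids, None)[perm, :][:, perm]).all()
+
+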
+ # class AllConditionalBMTask(ABC):
+ #     def __init__(self, task_name: str) -> None:
+ #         self.task_name = task_name
+ #         self.task = sbibm.get_task(task_name)
+ #         self.base_mask_fn = self.get_base_mask_fn()
+ #         self.simulator = self.task.get_simulator()
+ #         self.prior = self.task.get_prior()
+
+ #     @abstractmethod
+ #     def get_base_mask_fn(self):
+ #         """
+ #         Returns a function that takes in node_ids and node_meta_data and
+ #         returns the base mask for the given node_ids.
+ #         """
+ #         pass
+
+
+ # class TwoMoonsAllConditionalTask(AllConditionalBMTask):
+ #     def __init__(self) -> None:
+ #         super().__init__(task_name="two_moons")
+ #         return
+
+ #     def get_base_mask_fn(self):
+ #         thetas_mask = jnp.eye(2, dtype=jnp.bool_)
+ #         x_mask = jnp.tril(jnp.ones((2, 2), dtype=jnp.bool_))
+ #         base_mask = jnp.block(
+ #             [[thetas_mask, jnp.zeros((2, 2))], [jnp.ones((2, 2)), x_mask]]
+ #         )
+ #         base_mask = base_mask.astype(jnp.bool_)
+
+ #         def base_mask_fn(node_ids, node_meta_data):
+ #             return base_mask[node_ids, :][:, node_ids]
+
+ #         return base_mask_fn
+
+
+ # class SLCPAllConditionalTask(AllConditionalBMTask):
+ #     def __init__(self) -> None:
+ #         super().__init__(task_name="slcp")
+
+ #         # def ravel_condition_mask(condition_mask):
+ #         #     thetas_cond, x1_cond, x2_cond, x3_cond, x4_cond = jnp.split(condition_mask, [5, 7, 9, 11], axis=-1)
+ #         #     x1_cond = jnp.any(x1_cond, axis=-1)[None]
+ #         #     x2_cond = jnp.any(x2_cond, axis=-1)[None]
+ #         #     x3_cond = jnp.any(x3_cond, axis=-1)[None]
+ #         #     x4_cond = jnp.any(x4_cond, axis=-1)[None]
+ #         #     return jnp.hstack([thetas_cond, x1_cond, x2_cond, x3_cond, x4_cond])
+
+ #         # def unravel_condition_mask(condition_mask):
+ #         #     thetas_cond, x1_cond, x2_cond, x3_cond, x4_cond = jnp.split(condition_mask, [5, 6, 7, 8], axis=-1)
+ #         #     x1_cond = jnp.repeat(x1_cond, 2, axis=-1)
+ #         #     x2_cond = jnp.repeat(x2_cond, 2, axis=-1)
+ #         #     x3_cond = jnp.repeat(x3_cond, 2, axis=-1)
+ #         #     x4_cond = jnp.repeat(x4_cond, 2, axis=-1)
+ #         #     return jnp.hstack([thetas_cond, x1_cond, x2_cond, x3_cond, x4_cond])
+
+ #         # self.ravel_condition_mask = ravel_condition_mask
+ #         # self.unravel_condition_mask = unravel_condition_mask
+ #         return
+
+ #     def get_base_mask_fn(self):
+ #         theta_dim = 5
+ #         x_dim = 8
+ #         thetas_mask = jnp.eye(theta_dim, dtype=jnp.bool_)
+ #         # TODO This could be triangular -> DAG
+ #         x_i_dim = x_dim // 4
+ #         x_i_mask = jax.scipy.linalg.block_diag(
+ #             *tuple([jnp.tril(jnp.ones((x_i_dim, x_i_dim), dtype=jnp.bool_))] * 4)
+ #         )
+ #         base_mask = jnp.block(
+ #             [
+ #                 [thetas_mask, jnp.zeros((theta_dim, x_dim))],
+ #                 [jnp.ones((x_dim, theta_dim)), x_i_mask],
+ #             ]
+ #         )
+ #         base_mask = base_mask.astype(jnp.bool_)
+
+ #         def base_mask_fn(node_ids, node_meta_data):
+ #             # If node_ids are permuted, we need to permute the base_mask
+ #             return base_mask[node_ids, :][:, node_ids]
+
+ #         return base_mask_fn
+
+
+ # class NonlinearGaussianTreeAllConditionalTask(AllConditionalBMTask):
+ #     def __init__(self) -> None:
+ #         super().__init__(task_name="two_moons")
+ #         return
+
+ #     def get_base_mask_fn(self):
+ #         base_mask = jnp.array(
+ #             [
+ #                 [True, False, False, False, False, False, False],
+ #                 [True, True, False, False, False, False, False],
+ #                 [True, False, True, False, False, False, False],
+ #                 [False, True, False, True, False, False, False],
+ #                 [False, True, False, False, True, False, False],
+ #                 [False, False, True, False, False, True, False],
+ #                 [False, False, True, False, False, False, True],
+ #             ]
+ #         )
+
+ #         def base_mask_fn(node_ids, node_meta_data):
+ #             return base_mask[node_ids, :][:, node_ids]
+
+ #         return base_mask_fn
+
+
+ # class NonlinearMarcovChainAllConditionalTask(AllConditionalBMTask):
+ #     def __init__(self) -> None:
+ #         super().__init__(task_name="two_moons")
+ #         return
+
+ #     def get_base_mask_fn(self):
+ #         # Markovian structure
+ #         theta_mask = jnp.eye(10, dtype=jnp.bool_) | jnp.eye(10, k=-1, dtype=jnp.bool_)
+ #         xs_mask = jnp.eye(10, dtype=jnp.bool_)
+ #         theta_xs_mask = jnp.eye(10, dtype=jnp.bool_)
+ #         fill_mask = jnp.zeros((10, 10), dtype=jnp.bool_)
+ #         base_mask = jnp.block([[theta_mask, fill_mask], [theta_xs_mask, xs_mask]])
+
+ #         def base_mask_fn(node_ids, node_meta_data):
+ #             return base_mask[node_ids, :][:, node_ids]
+
+ #         return base_mask_fn