effectful 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,562 @@
1
+ try:
2
+ import numpyro.distributions as dist
3
+ except ImportError:
4
+ raise ImportError("Numpyro is required to use effectful.handlers.numpyro")
5
+
6
+
7
+ import functools
8
+ from collections.abc import Collection, Hashable, Mapping
9
+ from typing import Any, cast
10
+
11
+ import jax
12
+ import tree
13
+
14
+ import effectful.handlers.jax.numpy as jnp
15
+ from effectful.handlers.jax import bind_dims, jax_getitem, sizesof, unbind_dims
16
+ from effectful.handlers.jax._handlers import _register_jax_op, is_eager_array
17
+ from effectful.ops.semantics import apply, runner, typeof
18
+ from effectful.ops.syntax import defdata, defop, defterm
19
+ from effectful.ops.types import NotHandled, Operation, Term
20
+
21
+
22
class Naming(dict[Operation[[], jax.Array], int]):
    """A mapping from names to positional dimensions.

    Dimensions are negative, i.e. counted from the right (``-1`` is the last
    dimension of an array).
    """

    def __init__(self, name_to_dim: Mapping[Operation[[], jax.Array], int]):
        assert all(dim < 0 for dim in name_to_dim.values())
        super().__init__(name_to_dim)

    @staticmethod
    def from_shape(
        names: Collection[Operation[[], jax.Array]], event_dims: int
    ) -> "Naming":
        """Build a naming from a collection of names and an event rank.

        ``names`` are assigned, in iteration order, to the ``len(names)``
        dimensions immediately to the left of the last ``event_dims``
        dimensions. The resulting naming converts tensors of shape
        ``| batch_shape | named | event_shape |`` to tensors of shape
        ``| batch_shape | event_shape |, | named |``.
        """
        assert event_dims >= 0
        leftmost = -event_dims - len(names)
        return Naming({name: leftmost + offset for offset, name in enumerate(names)})

    def apply(self, value: jax.Array) -> jax.Array:
        """Index every named dimension of ``value`` by its name; keep the rest."""
        rank = len(value.shape)
        key: list[Any] = [slice(None)] * rank
        for name, dim in self.items():
            key[rank + dim] = name()
        return jax_getitem(value, tuple(key))

    def __repr__(self):
        return f"Naming({super().__repr__()})"
53
+
54
+
55
@unbind_dims.register  # type: ignore
def _unbind_distribution(
    d: dist.Distribution, *names: Operation[[], jax.Array]
) -> dist.Distribution:
    """Turn the leading positional batch dimensions of ``d`` into named ones.

    ``d`` is first reified as a term ``D(*args, **kwargs)`` without running the
    distribution constructor. Then the distribution is rebuilt with ``D``:

    * each array-valued parameter is passed through ``unbind_dims(param, *names)``;
    * each distribution-valued parameter is unbound recursively;
    * the keyword parameter ``total_count`` and all non-array parameters are
      passed through unchanged.

    Raises:
        ValueError: if an array parameter has fewer dimensions than there are
            names, or if array parameters disagree on their first
            ``len(names)`` dimensions.
        NotHandled: if ``d`` cannot be reified as a distribution term.
    """
    n_names = len(names)
    # First ``n_names`` dims of the first array parameter encountered; every
    # later array parameter must agree with it.
    reference_shape = None

    def _check_leading_shape(arr):
        nonlocal reference_shape
        if len(arr.shape) < n_names:
            raise ValueError(
                "All tensors must have at least as many dimensions as names"
            )
        leading = arr.shape[:n_names]
        if reference_shape is None:
            reference_shape = leading
        elif leading != reference_shape:
            raise ValueError("All tensors must have the same batch shape.")

    def _name_param(value):
        if isinstance(value, jax.Array):
            _check_leading_shape(value)
            return unbind_dims(value, *names)
        if isinstance(value, dist.Distribution):
            return unbind_dims(value, *names)
        return value

    def _defer_distributions(op, *args, **kwargs):
        # Keep distribution constructors symbolic; evaluate everything else.
        result_type = op.__type_rule__(*args, **kwargs)
        if not issubclass(result_type, dist.Distribution):
            return op.__default_rule__(*args, **kwargs)
        return defdata(op, *args, **kwargs)

    with runner({apply: _defer_distributions}):
        term = defterm(d)

    if not isinstance(term, Term) or typeof(term) is not dist.Distribution:
        raise NotHandled

    # TODO: this is a hack to avoid mangling arguments that are array-valued,
    # but not batched. Only keyword arguments are exempted.
    unbatched_kwargs = {"total_count"}

    named_args = [_name_param(arg) for arg in term.args]
    named_kwargs = {
        key: val if key in unbatched_kwargs else _name_param(val)
        for key, val in term.kwargs.items()
    }
    return term.op(*named_args, **named_kwargs)
108
+
109
+
110
@bind_dims.register  # type: ignore
def _bind_dims_distribution(
    d: dist.Distribution, *names: Operation[[], jax.Array]
) -> dist.Distribution:
    """Turn the named dimensions ``names`` of ``d`` into leading positional ones.

    ``d`` is reified as a term ``D(*args, **kwargs)`` without running the
    distribution constructor. Then the distribution is rebuilt with ``D``:

    * every array-valued parameter is first broadcast so that it carries all of
      ``names`` (adding any it lacks, with the sizes found on ``d``), and is
      then bound positionally;
    * distribution-valued parameters are bound recursively;
    * everything else passes through unchanged.

    Raises:
        NotHandled: if ``d`` cannot be reified as a distribution term.
    """

    def _defer_distributions(op, *args, **kwargs):
        # Keep distribution constructors symbolic; evaluate everything else.
        result_type = op.__type_rule__(*args, **kwargs)
        if not issubclass(result_type, dist.Distribution):
            return op.__default_rule__(*args, **kwargs)
        return defdata(op, *args, **kwargs)

    with runner({apply: _defer_distributions}):
        term = defterm(d)

    if not isinstance(term, Term) or typeof(term) is not dist.Distribution:
        raise NotHandled

    all_sizes = sizesof(term)
    target = {name: all_sizes[name] for name in names}

    def _bind_param(value):
        value_type = typeof(value)
        if issubclass(value_type, jax.Array):
            # Add the names this parameter lacks as new leading axes, name
            # them, then bind the full set positionally.
            absent = set(target) - set(sizesof(value))
            widened = jnp.broadcast_to(
                value, tuple(target[name] for name in absent) + value.shape
            )
            return bind_dims(unbind_dims(widened, *absent), *target)
        if issubclass(value_type, dist.Distribution):
            # Assumes at most one distribution appears among the parameters,
            # which is sufficient for cases like Independent and
            # TransformedDistribution.
            return bind_dims(value, *target)
        return value

    return term.op(
        *[_bind_param(arg) for arg in term.args],
        **{key: _bind_param(val) for key, val in term.kwargs.items()},
    )
156
+
157
+
158
@functools.cache
def _register_distribution_op(
    dist_constr: type[dist.Distribution],
) -> Operation[Any, dist.Distribution]:
    """Return the :class:`Operation` that constructs ``dist_constr``.

    Memoized per class (``functools.cache``), so repeated calls with the same
    class return the same operation object. The operation's default rule
    builds the numpyro distribution directly. If any leaf of its arguments is
    a :class:`Term`, it raises ``NotHandled`` instead, so the call is not
    evaluated eagerly.
    """

    # A dedicated function lets us control the type annotations seen by defop.
    def _construct(*args, **kwargs) -> dist.Distribution:
        has_term_leaf = any(
            isinstance(leaf, Term) for leaf in tree.flatten((args, kwargs))
        )
        if has_term_leaf:
            raise NotHandled
        return dist_constr(*args, **kwargs)

    return defop(_construct, name=dist_constr.__name__)
169
+
170
+
171
@defdata.register(dist.Distribution)
def _(op, *args, **kwargs):
    """Build distribution-typed data terms.

    When every leaf of the arguments is concrete (a non-term value, an eager
    array, or a distribution), the result is a :class:`_DistributionTerm`.
    Otherwise construction falls back to the generic ``defdata``
    implementation registered for ``object``.
    """
    has_symbolic_leaf = any(
        isinstance(leaf, Term)
        and not is_eager_array(leaf)
        and not isinstance(leaf, dist.Distribution)
        for leaf in tree.flatten((args, kwargs))
    )
    if has_symbolic_leaf:
        return defdata.dispatch(object)(op, *args, **kwargs)
    return _DistributionTerm(op, *args, **kwargs)
180
+
181
+
182
def _broadcast_to_named(t, sizes):
    """Give ``t`` every named dimension in ``sizes`` that it does not already carry.

    Each missing name becomes a new leading axis of size ``sizes[name]``
    (via broadcasting). That axis is then indexed by the name.
    """
    absent = set(sizes) - set(sizesof(t))
    new_leading = tuple(sizes[name] for name in absent)
    widened = jnp.broadcast_to(t, new_leading + t.shape)
    return jax_getitem(widened, tuple(name() for name in absent))
188
+
189
+
190
def expand_to_batch_shape(tensor, batch_ndims, expanded_batch_shape):
    """Broadcast the leading batch dimensions of ``tensor`` to a larger batch shape.

    ``tensor`` is viewed as ``batch_shape + remaining_shape``, where
    ``batch_shape`` is its first ``batch_ndims`` dimensions. ``batch_shape`` is
    broadcast against ``expanded_batch_shape`` (right-aligned, NumPy rules),
    and the trailing ``remaining_shape`` (e.g. event dimensions) is kept
    as-is.

    Args:
        tensor: JAX array of shape ``batch_shape + remaining_shape``.
        batch_ndims: number of leading dimensions of ``tensor`` that are batch
            dimensions.
        expanded_batch_shape: the desired batch shape. It must have at least
            ``batch_ndims`` dimensions and be broadcast-compatible with
            ``batch_shape``.

    Returns:
        A JAX array of shape
        ``broadcast_shapes(batch_shape, expanded_batch_shape) + remaining_shape``.

    Raises:
        ValueError: if ``batch_ndims`` is negative or exceeds the rank of
            ``tensor``, if ``expanded_batch_shape`` has fewer than
            ``batch_ndims`` dimensions, or (from JAX) if the shapes cannot be
            broadcast together.
    """
    # Explicit checks rather than ``assert``, which is stripped under ``-O``.
    if batch_ndims < 0 or len(tensor.shape) < batch_ndims:
        raise ValueError(
            "tensor must have at least batch_ndims dimensions (and batch_ndims >= 0)"
        )
    expanded_batch_shape = tuple(expanded_batch_shape)
    if len(expanded_batch_shape) < batch_ndims:
        raise ValueError(
            "Expanded batch shape must have at least as many dimensions as current batch shape"
        )

    batch_shape = tuple(tensor.shape[:batch_ndims])
    remaining_shape = tuple(tensor.shape[batch_ndims:])

    new_batch_shape = tuple(jnp.broadcast_shapes(batch_shape, expanded_batch_shape))
    return jnp.broadcast_to(tensor, new_batch_shape + remaining_shape)
223
+
224
+
225
class _DistributionTerm(dist.Distribution):
    """A distribution wrapper that satisfies the Term interface.

    Represented as a term of the form ``D(*args, **kwargs)`` where ``D`` is the
    distribution constructor operation.

    The parameters may carry named dimensions. All of them, as reported by
    ``sizesof(self)``, are treated as an outer batch. They are bound to
    leading positional dimensions (via ``bind_dims``) to obtain the positional
    distribution that every computation delegates to.

    ``batch_shape`` excludes those named dimensions. Samples, log-densities,
    moments, entropy and support are returned with them indexed by name again.

    Note: ``expand`` broadcasts every eager-array parameter, whether it was
    passed positionally or by keyword.
    """

    _op: Operation[Any, dist.Distribution]
    _args: tuple
    _kwargs: dict

    def __init__(self, op: Operation[Any, dist.Distribution], *args, **kwargs):
        # ``dist.Distribution.__init__`` is not called; the shape properties
        # below are derived from the positional distribution instead.
        self._op = op
        self._args = args
        self._kwargs = kwargs
        # Lazily computed caches; see ``_indices`` and ``_pos_base_dist``.
        self.__indices = None
        self.__pos_base_dist = None

    @property
    def _indices(self):
        """Sizes of this term's named dimensions (computed once)."""
        if self.__indices is None:
            self.__indices = sizesof(self)
        return self.__indices

    @property
    def _pos_base_dist(self):
        """This distribution with every named dim bound to a leading positional dim."""
        if self.__pos_base_dist is None:
            self.__pos_base_dist = bind_dims(self, *self._indices)
        return self.__pos_base_dist

    @property
    def op(self):
        return self._op

    @property
    def args(self):
        return self._args

    @property
    def kwargs(self):
        return self._kwargs

    @property
    def batch_shape(self):
        # Drop the leading positional dims that stand in for named dims.
        return self._pos_base_dist.batch_shape[len(self._indices) :]

    @property
    def has_rsample(self) -> bool:
        return self._pos_base_dist.has_rsample

    @property
    def event_shape(self):
        return self._pos_base_dist.event_shape

    def rsample(self, key, sample_shape=()):
        return self._reindex_sample(
            self._pos_base_dist.rsample(key, sample_shape), sample_shape
        )

    def sample(self, key, sample_shape=()):
        return self._reindex_sample(
            self._pos_base_dist.sample(key, sample_shape), sample_shape
        )

    def _reindex_sample(self, value, sample_shape):
        """Index the named-batch dims of ``value``, which follow ``sample_shape``, by name."""
        index = (slice(None),) * len(sample_shape) + tuple(i() for i in self._indices)
        return jax_getitem(value, index)

    def log_prob(self, value):
        # ``value`` may carry named dims; positionally it has shape
        # sample_shape + batch_shape + event_shape.
        n_batch_event = len(self.batch_shape) + len(self.event_shape)
        sample_shape = (
            value.shape if n_batch_event == 0 else value.shape[:-n_batch_event]
        )
        # Broadcast to all of this distribution's named dims, then bind them as
        # leading positional dims: named_batch + sample + batch + event.
        value = bind_dims(_broadcast_to_named(value, self._indices), *self._indices)
        dims = list(range(len(value.shape)))
        n_named_batch = len(self._indices)
        # Move the sample dims in front of the named-batch dims.
        perm = (
            dims[n_named_batch : n_named_batch + len(sample_shape)]
            + dims[:n_named_batch]
            + dims[n_named_batch + len(sample_shape) :]
        )
        assert len(perm) == len(value.shape)

        # perm_value has shape sample_shape + named_batch_shape + batch_shape + event_shape
        perm_value = jnp.permute_dims(value, perm)
        pos_log_prob = _register_jax_op(self._pos_base_dist.log_prob)(perm_value)
        return self._reindex_sample(pos_log_prob, sample_shape)

    @property
    def mean(self):
        return self._reindex_sample(self._pos_base_dist.mean, ())

    @property
    def variance(self):
        return self._reindex_sample(self._pos_base_dist.variance, ())

    def enumerate_support(self, expand=True):
        # numpyro puts the support axis first, ahead of the batch dims, so the
        # named-batch dims start at axis 1; skip the support axis when
        # reindexing.
        # NOTE(review): with expand=False the named dims are size 1 here;
        # confirm that named indexing of size-1 axes behaves as intended.
        support = self._pos_base_dist.enumerate_support(expand)
        return self._reindex_sample(support, support.shape[:1])

    def entropy(self):
        # Reindex like ``mean``/``variance`` so the named batch dims come back
        # as named dims instead of leaking as leading positional dims.
        return self._reindex_sample(self._pos_base_dist.entropy(), ())

    def to_event(self, reinterpreted_batch_ndims=None):
        raise NotHandled

    def expand(self, batch_shape):
        # Normalize so that list-valued shapes compare equal to tuple shapes.
        batch_shape = tuple(batch_shape)
        if self.batch_shape == batch_shape:
            return self

        n_batch = len(self.batch_shape)

        def expand_arg(a):
            if is_eager_array(a):
                return expand_to_batch_shape(a, n_batch, batch_shape)
            return a

        return self.op(
            *[expand_arg(a) for a in self.args],
            **{k: expand_arg(v) for k, v in self.kwargs.items()},
        )

    def __repr__(self):
        return Term.__repr__(self)

    def __str__(self):
        return Term.__str__(self)


Term.register(_DistributionTerm)
361
+
362
+
363
@defterm.register(dist.Distribution)
def _embed_distribution(d: dist.Distribution) -> Term[dist.Distribution]:
    """Fallback embedding for distribution types with no registered embedding.

    The parameter is named ``d`` (like every other embedding here) rather than
    ``dist``, which would shadow the ``numpyro.distributions`` module alias.

    Raises:
        ValueError: always.
    """
    raise ValueError(
        f"No embedding provided for distribution of type {type(d).__name__}."
    )
368
+
369
+
370
@defterm.register(dist.Cauchy)
@defterm.register(dist.Gumbel)
@defterm.register(dist.Laplace)
@defterm.register(dist.LogNormal)
@defterm.register(dist.Logistic)
@defterm.register(dist.Normal)
def _embed_loc_scale(d: dist.Distribution) -> Term[dist.Distribution]:
    """Embed a location-scale distribution as ``D(loc, scale)``."""
    return _register_distribution_op(cast(Hashable, type(d)))(d.loc, d.scale)


@defterm.register
def _embed_student_t(d: dist.StudentT) -> Term[dist.Distribution]:
    """Embed a Student-t distribution as ``StudentT(df, loc, scale)``.

    numpyro's ``StudentT`` takes the degrees of freedom as its first
    positional parameter. Sharing the ``(loc, scale)`` embedding would pass
    ``loc`` as ``df`` and ``scale`` as ``loc``, and drop ``df`` entirely.
    """
    return _register_distribution_op(cast(Hashable, type(d)))(d.df, d.loc, d.scale)
379
+
380
+
381
@defterm.register(dist.BernoulliProbs)
@defterm.register(dist.CategoricalProbs)
@defterm.register(dist.GeometricProbs)
def _embed_probs(d: dist.Distribution) -> Term[dist.Distribution]:
    """Embed a probs-parameterized distribution as ``D(probs)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.probs)
386
+
387
+
388
@defterm.register(dist.BernoulliLogits)
@defterm.register(dist.CategoricalLogits)
@defterm.register(dist.GeometricLogits)
def _embed_logits(d: dist.Distribution) -> Term[dist.Distribution]:
    """Embed a logits-parameterized distribution as ``D(logits)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.logits)
393
+
394
+
395
@defterm.register(dist.Beta)
@defterm.register(dist.Kumaraswamy)
def _embed_beta(d: dist.Distribution) -> Term[dist.Distribution]:
    """Embed as ``D(concentration1, concentration0)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.concentration1, d.concentration0)
401
+
402
+
403
@defterm.register(dist.BinomialProbs)
@defterm.register(dist.MultinomialProbs)
def _embed_binomial_probs(d: dist.Distribution) -> Term[dist.Distribution]:
    """Embed a probs-form binomial/multinomial as ``D(probs, total_count)``."""
    return _register_distribution_op(cast(Hashable, type(d)))(d.probs, d.total_count)


@defterm.register
def _embed_negative_binomial_probs(
    d: dist.NegativeBinomialProbs,
) -> Term[dist.Distribution]:
    """Embed as ``NegativeBinomialProbs(total_count, probs)``.

    Unlike Binomial/Multinomial, numpyro's ``NegativeBinomialProbs`` takes
    ``total_count`` first, so the ``(probs, total_count)`` order would swap
    the parameters. ``total_count`` stays positional, as in the binomial
    embedding.
    """
    return _register_distribution_op(cast(Hashable, type(d)))(d.total_count, d.probs)
408
+
409
+
410
@defterm.register(dist.BinomialLogits)
@defterm.register(dist.MultinomialLogits)
def _embed_binomial_logits(d: dist.Distribution) -> Term[dist.Distribution]:
    """Embed a logits-form binomial/multinomial as ``D(logits, total_count)``."""
    return _register_distribution_op(cast(Hashable, type(d)))(d.logits, d.total_count)


@defterm.register
def _embed_negative_binomial_logits(
    d: dist.NegativeBinomialLogits,
) -> Term[dist.Distribution]:
    """Embed as ``NegativeBinomialLogits(total_count, logits)``.

    numpyro's ``NegativeBinomialLogits`` takes ``total_count`` first, so the
    ``(logits, total_count)`` order used for Binomial/Multinomial would swap
    the parameters.
    """
    return _register_distribution_op(cast(Hashable, type(d)))(d.total_count, d.logits)
415
+
416
+
417
@defterm.register
def _embed_chi2(d: dist.Chi2) -> Term[dist.Distribution]:
    """Embed as ``Chi2(df)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.df)
420
+
421
+
422
@defterm.register
def _embed_dirichlet(d: dist.Dirichlet) -> Term[dist.Distribution]:
    """Embed as ``Dirichlet(concentration)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.concentration)
425
+
426
+
427
@defterm.register
def _embed_dirichlet_multinomial(
    d: dist.DirichletMultinomial,
) -> Term[dist.Distribution]:
    """Embed as ``DirichletMultinomial(concentration, total_count=...)``.

    ``total_count`` is passed by keyword.
    """
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.concentration, total_count=d.total_count)
434
+
435
+
436
@defterm.register(dist.Exponential)
@defterm.register(dist.Poisson)
def _embed_exponential(d: dist.Distribution) -> Term[dist.Distribution]:
    """Embed a rate-parameterized distribution as ``D(rate)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.rate)
440
+
441
+
442
@defterm.register
def _embed_gamma(d: dist.Gamma) -> Term[dist.Distribution]:
    """Embed as ``Gamma(concentration, rate)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.concentration, d.rate)
445
+
446
+
447
@defterm.register(dist.HalfCauchy)
@defterm.register(dist.HalfNormal)
def _embed_half_cauchy(d: dist.Distribution) -> Term[dist.Distribution]:
    """Embed a scale-only distribution as ``D(scale)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.scale)
451
+
452
+
453
@defterm.register
def _embed_lkj_cholesky(d: dist.LKJCholesky) -> Term[dist.Distribution]:
    """Embed as ``LKJCholesky(dimension, concentration=..., sample_method=...)``.

    numpyro stores the matrix size as ``dimension``; there is no ``dim``
    attribute. ``sample_method`` is forwarded so that the rebuilt
    distribution samples the same way as the original.
    """
    return _register_distribution_op(cast(Hashable, type(d)))(
        d.dimension, concentration=d.concentration, sample_method=d.sample_method
    )
458
+
459
+
460
@defterm.register
def _embed_multivariate_normal(d: dist.MultivariateNormal) -> Term[dist.Distribution]:
    """Embed as ``MultivariateNormal(loc, scale_tril=...)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.loc, scale_tril=d.scale_tril)
465
+
466
+
467
@defterm.register
def _embed_pareto(d: dist.Pareto) -> Term[dist.Distribution]:
    """Embed as ``Pareto(scale, alpha)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.scale, d.alpha)
470
+
471
+
472
@defterm.register
def _embed_uniform(d: dist.Uniform) -> Term[dist.Distribution]:
    """Embed as ``Uniform(low, high)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.low, d.high)
475
+
476
+
477
@defterm.register
def _embed_von_mises(d: dist.VonMises) -> Term[dist.Distribution]:
    """Embed as ``VonMises(loc, concentration)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.loc, d.concentration)
480
+
481
+
482
@defterm.register
def _embed_weibull(d: dist.Weibull) -> Term[dist.Distribution]:
    """Embed as ``Weibull(scale, concentration)`` using the ``dist.Weibull`` operation."""
    constructor = _register_distribution_op(dist.Weibull)
    return constructor(d.scale, d.concentration)
485
+
486
+
487
@defterm.register
def _embed_wishart(d: dist.Wishart) -> Term[dist.Distribution]:
    """Embed as ``Wishart(concentration, scale_tril=...)``.

    numpyro's signature is
    ``Wishart(concentration, scale_matrix=None, rate_matrix=None, scale_tril=None)``.
    The degrees-of-freedom parameter is therefore named ``concentration``, and
    the second positional slot is ``scale_matrix``. ``scale_tril`` must be
    passed by keyword so it is not misread as a scale matrix.
    """
    return _register_distribution_op(dist.Wishart)(
        d.concentration, scale_tril=d.scale_tril
    )
490
+
491
+
492
@defterm.register
def _embed_delta(d: dist.Delta) -> Term[dist.Distribution]:
    """Embed as ``Delta(v, log_density=..., event_dim=...)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.v, log_density=d.log_density, event_dim=d.event_dim)
497
+
498
+
499
@defterm.register
def _embed_low_rank_multivariate_normal(
    d: dist.LowRankMultivariateNormal,
) -> Term[dist.Distribution]:
    """Embed as ``LowRankMultivariateNormal(loc, cov_factor, cov_diag)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.loc, d.cov_factor, d.cov_diag)
506
+
507
+
508
@defterm.register
def _embed_relaxed_bernoulli_logits(
    d: dist.RelaxedBernoulliLogits,
) -> Term[dist.Distribution]:
    """Embed as ``RelaxedBernoulliLogits(temperature, logits)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.temperature, d.logits)
513
+
514
+
515
@defterm.register
def _embed_independent(d: dist.Independent) -> Term[dist.Distribution]:
    """Embed as ``Independent(base_dist, reinterpreted_batch_ndims)``."""
    constructor = _register_distribution_op(cast(Hashable, type(d)))
    return constructor(d.base_dist, d.reinterpreted_batch_ndims)
520
+
521
+
522
# Public distribution-constructor operations, one per supported numpyro class.
# ``_register_distribution_op`` is memoized per class, so these are the same
# operation objects that the ``defterm`` embeddings obtain for the same class.

# Discrete distributions.
BernoulliLogits = _register_distribution_op(dist.BernoulliLogits)
BernoulliProbs = _register_distribution_op(dist.BernoulliProbs)
BinomialLogits = _register_distribution_op(dist.BinomialLogits)
BinomialProbs = _register_distribution_op(dist.BinomialProbs)
CategoricalLogits = _register_distribution_op(dist.CategoricalLogits)
CategoricalProbs = _register_distribution_op(dist.CategoricalProbs)
DirichletMultinomial = _register_distribution_op(dist.DirichletMultinomial)
GeometricLogits = _register_distribution_op(dist.GeometricLogits)
GeometricProbs = _register_distribution_op(dist.GeometricProbs)
MultinomialLogits = _register_distribution_op(dist.MultinomialLogits)
MultinomialProbs = _register_distribution_op(dist.MultinomialProbs)
NegativeBinomialLogits = _register_distribution_op(dist.NegativeBinomialLogits)
NegativeBinomialProbs = _register_distribution_op(dist.NegativeBinomialProbs)
Poisson = _register_distribution_op(dist.Poisson)

# Univariate continuous distributions.
Beta = _register_distribution_op(dist.Beta)
Cauchy = _register_distribution_op(dist.Cauchy)
Chi2 = _register_distribution_op(dist.Chi2)
Exponential = _register_distribution_op(dist.Exponential)
Gamma = _register_distribution_op(dist.Gamma)
Gumbel = _register_distribution_op(dist.Gumbel)
HalfCauchy = _register_distribution_op(dist.HalfCauchy)
HalfNormal = _register_distribution_op(dist.HalfNormal)
Kumaraswamy = _register_distribution_op(dist.Kumaraswamy)
Laplace = _register_distribution_op(dist.Laplace)
LogNormal = _register_distribution_op(dist.LogNormal)
Logistic = _register_distribution_op(dist.Logistic)
Normal = _register_distribution_op(dist.Normal)
Pareto = _register_distribution_op(dist.Pareto)
RelaxedBernoulliLogits = _register_distribution_op(dist.RelaxedBernoulliLogits)
StudentT = _register_distribution_op(dist.StudentT)
Uniform = _register_distribution_op(dist.Uniform)
VonMises = _register_distribution_op(dist.VonMises)
Weibull = _register_distribution_op(dist.Weibull)

# Multivariate continuous distributions.
Dirichlet = _register_distribution_op(dist.Dirichlet)
LKJCholesky = _register_distribution_op(dist.LKJCholesky)
LowRankMultivariateNormal = _register_distribution_op(dist.LowRankMultivariateNormal)
MultivariateNormal = _register_distribution_op(dist.MultivariateNormal)
Wishart = _register_distribution_op(dist.Wishart)

# Other constructors (point mass, base class, batch-to-event reinterpretation).
Delta = _register_distribution_op(dist.Delta)
Distribution = _register_distribution_op(dist.Distribution)
Independent = _register_distribution_op(dist.Independent)