bayinx 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bayinx/core/model.py CHANGED
@@ -1,25 +1,25 @@
1
1
  from abc import abstractmethod
2
- from typing import Any, Dict, Tuple
2
+ from typing import Any, Dict, Generic, Tuple, TypeVar
3
3
 
4
4
  import equinox as eqx
5
5
  import jax.numpy as jnp
6
6
  import jax.tree as jt
7
- from jaxtyping import Array, PyTree, Scalar
7
+ from jaxtyping import PyTree, Scalar
8
8
 
9
9
  from bayinx.core.constraint import Constraint
10
10
  from bayinx.core.parameter import Parameter
11
11
 
12
-
13
- class Model(eqx.Module):
12
+ T = TypeVar('T', bound=PyTree)
13
+ class Model(eqx.Module, Generic[T]):
14
14
  """
15
15
  An abstract base class used to define probabilistic models.
16
16
 
17
17
  # Attributes
18
- - `params`: A dictionary of Arrays or PyTrees containing Arrays representing parameters of the model.
18
+ - `params`: A dictionary of parameters.
19
19
  - `constraints`: A dictionary of constraints.
20
20
  """
21
21
 
22
- params: Dict[str, Parameter]
22
+ params: Dict[str, Parameter[T]]
23
23
  constraints: Dict[str, Constraint]
24
24
 
25
25
  @abstractmethod
@@ -47,14 +47,14 @@ class Model(eqx.Module):
47
47
 
48
48
  # Add constrain method
49
49
  @eqx.filter_jit
50
- def constrain_params(self) -> Tuple[Dict[str, Parameter], Scalar]:
50
+ def constrain_params(self) -> Tuple[Dict[str, Parameter[T]], Scalar]:
51
51
  """
52
52
  Constrain `params` to the appropriate domain.
53
53
 
54
54
  # Returns
55
55
  A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
56
56
  """
57
- t_params: Dict[str, Array | PyTree] = self.params
57
+ t_params: Dict[str, Parameter[T]] = self.params
58
58
  target: Scalar = jnp.array(0.0)
59
59
 
60
60
  for par, map in self.constraints.items():
@@ -67,7 +67,8 @@ class Model(eqx.Module):
67
67
  return t_params, target
68
68
 
69
69
  # Add default transform method
70
- def transform_params(self) -> Tuple[Dict[str, Parameter], Scalar]:
70
+ @eqx.filter_jit
71
+ def transform_params(self) -> Tuple[Dict[str, Parameter[T]], Scalar]:
71
72
  """
72
73
  Apply a custom transformation to `params` if needed.
73
74
 
bayinx/core/parameter.py CHANGED
@@ -1,11 +1,11 @@
1
- from typing import Self
1
+ from typing import Generic, Self, TypeVar
2
2
 
3
3
  import equinox as eqx
4
4
  import jax.tree as jt
5
- from jaxtyping import Array, PyTree
5
+ from jaxtyping import PyTree
6
6
 
7
-
8
- class Parameter(eqx.Module):
7
+ T = TypeVar('T', bound=PyTree)
8
+ class Parameter(eqx.Module, Generic[T]):
9
9
  """
10
10
  A container for a parameter of a `Model`.
11
11
 
@@ -14,10 +14,10 @@ class Parameter(eqx.Module):
14
14
  # Attributes
15
15
  - `vals`: The parameter's value(s).
16
16
  """
17
- vals: Array | PyTree
17
+ vals: T
18
18
 
19
19
 
20
- def __init__(self, values: Array | PyTree):
20
+ def __init__(self, values: T):
21
21
  # Insert parameter values
22
22
  self.vals = values
23
23
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -5,8 +5,8 @@ bayinx/constraints/lower.py,sha256=wkYnWjaAEGQeXKfBo_gY0pcK9ElJUMkzGdAmWI8ykCk,1
5
5
  bayinx/core/__init__.py,sha256=jSwEFdXqi-Bj_X8_H-YuaXp5ebEQpZTG2T18zpquzPo,207
6
6
  bayinx/core/constraint.py,sha256=F6-TXQjzt-tcNm8bHkRcGEtyE9bZQf2RbAh_MKDuM20,760
7
7
  bayinx/core/flow.py,sha256=lAPJdQnrIxC3JoowTp77Gvm0p0v_xQn8FMjFJWMnWbc,2340
8
- bayinx/core/model.py,sha256=QnJUKaR6d5RCe_WIxD2oJtI8NJyFKWUWyCRVwOm0j3s,2276
9
- bayinx/core/parameter.py,sha256=fdyzun6TDnXxQT_KlarIJvWzn9n8bQgzfiVjWIIHk6k,998
8
+ bayinx/core/model.py,sha256=RKAtsLXc6xDnXWz5upmx6Vz6JOoorw4WTfxTA7B7Lmg,2294
9
+ bayinx/core/parameter.py,sha256=oxCCZcZ-DDBvfWzexfhQkSJPxNQnE1vYXtBhiEZG2bM,1025
10
10
  bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
11
11
  bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
@@ -25,6 +25,6 @@ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvAD
25
25
  bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
26
26
  bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
27
27
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
28
- bayinx-0.3.0.dist-info/METADATA,sha256=RLbnLgyMmnEh2BJmqex3MMWFFS3HgSU9NEeQEvkyfC0,3057
29
- bayinx-0.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
30
- bayinx-0.3.0.dist-info/RECORD,,
28
+ bayinx-0.3.1.dist-info/METADATA,sha256=Slp2nxR8HISCwCqIXY2El3GgqsO1v9_UbVeJq726w7k,3057
29
+ bayinx-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
30
+ bayinx-0.3.1.dist-info/RECORD,,
File without changes