bayinx 0.3.15__py3-none-any.whl → 0.3.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/core/_variational.py +25 -9
- {bayinx-0.3.15.dist-info → bayinx-0.3.17.dist-info}/METADATA +1 -1
- {bayinx-0.3.15.dist-info → bayinx-0.3.17.dist-info}/RECORD +5 -5
- {bayinx-0.3.15.dist-info → bayinx-0.3.17.dist-info}/WHEEL +0 -0
- {bayinx-0.3.15.dist-info → bayinx-0.3.17.dist-info}/licenses/LICENSE +0 -0
bayinx/core/_variational.py
CHANGED
@@ -175,17 +175,33 @@ class Variational(eqx.Module, Generic[M]):
|
|
175
175
|
data: Any = None,
|
176
176
|
key: Key = jr.PRNGKey(0),
|
177
177
|
) -> Array:
|
178
|
-
# Sample
|
179
|
-
|
178
|
+
# Sample a single draw to evaluate shape of output
|
179
|
+
draw: Array = self.sample(1, key)[0]
|
180
|
+
output: Array = func(self._unflatten(draw), data)
|
181
|
+
|
182
|
+
# Allocate space for results
|
183
|
+
results: Array = jnp.zeros((n,) + output.shape, dtype=output.dtype)
|
184
|
+
|
185
|
+
@eqx.filter_jit
|
186
|
+
def body_fun(i: int, state: Tuple[Key, Array]) -> Tuple[Key, Array]:
|
187
|
+
# Unpack state
|
188
|
+
key, results = state
|
189
|
+
|
190
|
+
# Update PRNG key
|
191
|
+
next, key = jr.split(key)
|
192
|
+
|
193
|
+
# Draw from variational
|
194
|
+
draw: Array = self.sample(1, key)
|
180
195
|
|
181
|
-
# Evaluate posterior predictive
|
182
|
-
@jax.jit
|
183
|
-
@jax.vmap
|
184
|
-
def evaluate(draw: Array):
|
185
196
|
# Reconstruct model
|
186
197
|
model: M = self._unflatten(draw)
|
187
198
|
|
188
|
-
#
|
189
|
-
|
199
|
+
# Update results with output
|
200
|
+
results = results.at[i].set(func(model, data))
|
201
|
+
|
202
|
+
return next, results
|
203
|
+
|
204
|
+
# Evaluate draws
|
205
|
+
results: Array = jax.lax.fori_loop(0, n, body_fun, (key, results))[1]
|
190
206
|
|
191
|
-
return
|
207
|
+
return results
|
@@ -8,7 +8,7 @@ bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
|
|
8
8
|
bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
|
9
9
|
bayinx/core/_optimization.py,sha256=mmeVUqfFARz8F7q4LRl-uEwVWzekmzh-9o7PnuvsHZk,2651
|
10
10
|
bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
|
11
|
-
bayinx/core/_variational.py,sha256=
|
11
|
+
bayinx/core/_variational.py,sha256=mLumWRcpATuU-S073qFy-6oYL-OpI3wNrFe7bXEbyRE,6188
|
12
12
|
bayinx/dists/__init__.py,sha256=BIrypqMnTLWK3a_zw8fYKMyuEMxP_qGsLfLeScias0o,118
|
13
13
|
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
14
14
|
bayinx/dists/gamma2.py,sha256=HtB60LUQdPj4yDAHme2jsHNmLfrAKWsSZnDYkxAGaOI,1548
|
@@ -31,7 +31,7 @@ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvAD
|
|
31
31
|
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
32
32
|
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
33
33
|
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
34
|
-
bayinx-0.3.
|
35
|
-
bayinx-0.3.
|
36
|
-
bayinx-0.3.
|
37
|
-
bayinx-0.3.
|
34
|
+
bayinx-0.3.17.dist-info/METADATA,sha256=8YZLse5eUFxxJYvZ8lBHMs-5FbnxxkLTe1hEN4pxuz4,3087
|
35
|
+
bayinx-0.3.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
36
|
+
bayinx-0.3.17.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
|
37
|
+
bayinx-0.3.17.dist-info/RECORD,,
|
File without changes
|
File without changes
|