aspire-inference 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/flows/base.py +12 -2
- aspire/transforms.py +11 -4
- {aspire_inference-0.1.0a2.dist-info → aspire_inference-0.1.0a3.dist-info}/METADATA +3 -1
- {aspire_inference-0.1.0a2.dist-info → aspire_inference-0.1.0a3.dist-info}/RECORD +7 -7
- {aspire_inference-0.1.0a2.dist-info → aspire_inference-0.1.0a3.dist-info}/WHEEL +0 -0
- {aspire_inference-0.1.0a2.dist-info → aspire_inference-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {aspire_inference-0.1.0a2.dist-info → aspire_inference-0.1.0a3.dist-info}/top_level.txt +0 -0
aspire/flows/base.py
CHANGED
|
@@ -1,18 +1,28 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
from typing import Any
|
|
2
3
|
|
|
3
4
|
from ..history import FlowHistory
|
|
4
|
-
from ..transforms import BaseTransform
|
|
5
|
+
from ..transforms import BaseTransform, IdentityTransform
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger(__name__)
|
|
5
8
|
|
|
6
9
|
|
|
7
10
|
class Flow:
|
|
11
|
+
xp = None # type: Any
|
|
12
|
+
|
|
8
13
|
def __init__(
|
|
9
14
|
self,
|
|
10
15
|
dims: int,
|
|
11
16
|
device: Any,
|
|
12
|
-
data_transform: BaseTransform = None,
|
|
17
|
+
data_transform: BaseTransform | None = None,
|
|
13
18
|
):
|
|
14
19
|
self.dims = dims
|
|
15
20
|
self.device = device
|
|
21
|
+
|
|
22
|
+
if data_transform is None:
|
|
23
|
+
data_transform = IdentityTransform(self.xp)
|
|
24
|
+
logger.info("No data_transform provided, using IdentityTransform.")
|
|
25
|
+
|
|
16
26
|
self.data_transform = data_transform
|
|
17
27
|
|
|
18
28
|
def log_prob(self, x):
|
aspire/transforms.py
CHANGED
|
@@ -4,7 +4,6 @@ from typing import Any
|
|
|
4
4
|
|
|
5
5
|
from array_api_compat import device as get_device
|
|
6
6
|
from array_api_compat import is_torch_namespace
|
|
7
|
-
from scipy.special import erf, erfinv
|
|
8
7
|
|
|
9
8
|
from .flows import get_flow_wrapper
|
|
10
9
|
from .utils import (
|
|
@@ -50,13 +49,17 @@ class IdentityTransform(BaseTransform):
|
|
|
50
49
|
"""Identity transform that does nothing to the data."""
|
|
51
50
|
|
|
52
51
|
def fit(self, x):
|
|
53
|
-
return x
|
|
52
|
+
return copy_array(x, xp=self.xp)
|
|
54
53
|
|
|
55
54
|
def forward(self, x):
|
|
56
|
-
return x, self.xp
|
|
55
|
+
return copy_array(x, xp=self.xp), self.xp.zeros(
|
|
56
|
+
len(x), device=get_device(x)
|
|
57
|
+
)
|
|
57
58
|
|
|
58
59
|
def inverse(self, y):
|
|
59
|
-
return y, self.xp
|
|
60
|
+
return copy_array(y, xp=self.xp), self.xp.zeros(
|
|
61
|
+
len(y), device=get_device(y)
|
|
62
|
+
)
|
|
60
63
|
|
|
61
64
|
|
|
62
65
|
class CompositeTransform(BaseTransform):
|
|
@@ -329,6 +332,8 @@ class ProbitTransform(BaseTransform):
|
|
|
329
332
|
return self.forward(x)[0]
|
|
330
333
|
|
|
331
334
|
def forward(self, x):
|
|
335
|
+
from scipy.special import erfinv
|
|
336
|
+
|
|
332
337
|
y = (x - self.lower) / (self.upper - self.lower)
|
|
333
338
|
y = self.xp.clip(y, self.eps, 1.0 - self.eps)
|
|
334
339
|
y = erfinv(2 * y - 1) * math.sqrt(2)
|
|
@@ -339,6 +344,8 @@ class ProbitTransform(BaseTransform):
|
|
|
339
344
|
return y, log_abs_det_jacobian
|
|
340
345
|
|
|
341
346
|
def inverse(self, y):
|
|
347
|
+
from scipy.special import erf
|
|
348
|
+
|
|
342
349
|
log_abs_det_jacobian = (
|
|
343
350
|
-(0.5 * (math.log(2 * math.pi) + y**2)).sum(-1)
|
|
344
351
|
- self._scale_log_abs_det_jacobian
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: aspire-inference
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.0a3
|
|
4
4
|
Summary: Accelerate Sequential Posterior Inference via REuse
|
|
5
5
|
Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
|
|
6
6
|
License: MIT
|
|
@@ -14,6 +14,8 @@ Requires-Dist: numpy
|
|
|
14
14
|
Requires-Dist: array-api-compat
|
|
15
15
|
Requires-Dist: wrapt
|
|
16
16
|
Requires-Dist: h5py
|
|
17
|
+
Provides-Extra: scipy
|
|
18
|
+
Requires-Dist: scipy; extra == "scipy"
|
|
17
19
|
Provides-Extra: jax
|
|
18
20
|
Requires-Dist: jax; extra == "jax"
|
|
19
21
|
Requires-Dist: jaxlib; extra == "jax"
|
|
@@ -3,10 +3,10 @@ aspire/aspire.py,sha256=AEkFUuOCF4F_iXUqRNst_4mucxozYRK4fG4V2wGrT4Q,15762
|
|
|
3
3
|
aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
|
|
4
4
|
aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
|
|
5
5
|
aspire/samples.py,sha256=hMlONOtSuYE3bU6r_wQCZ8Z1dcc3Ch15bNMLO8fGU8g,16263
|
|
6
|
-
aspire/transforms.py,sha256=
|
|
6
|
+
aspire/transforms.py,sha256=fg2_UELJWjJ6gnqQi6X7s1CgKBhQ5hP7Ipil3tTjCeg,16566
|
|
7
7
|
aspire/utils.py,sha256=fQeLMauCN3vAogKbVTVg9jfjW7nTEFi7V6Ot-BYNfxE,14301
|
|
8
8
|
aspire/flows/__init__.py,sha256=3gGXF4HziMlZSmcEdJ_uHtrP-QEC6RXvylm4vtM-Xnk,1306
|
|
9
|
-
aspire/flows/base.py,sha256=
|
|
9
|
+
aspire/flows/base.py,sha256=oTw2ZkxCsA5RZhnMuIu9M-2FPHvQG2TGFIEJZVK4a2g,1140
|
|
10
10
|
aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
|
|
11
11
|
aspire/flows/jax/flows.py,sha256=jZ93fnc7U7ZhZLVixGUTwyeDb6Vz0UWpYkkVHwirNug,2896
|
|
12
12
|
aspire/flows/jax/utils.py,sha256=UlvXOOqC5fNsmVUnU4LSksliq7pLRm9NhOu0ZvVHqgc,1455
|
|
@@ -21,8 +21,8 @@ aspire/samplers/smc/base.py,sha256=GePA6tm8Dno_AjCeNuRX3KOaKnoKSFHSRAb-QWx9wJE,1
|
|
|
21
21
|
aspire/samplers/smc/blackjax.py,sha256=9w1ORzWTT1viwp99_ttLxnNMdgTO-VqAzsf-NhgG9vY,11722
|
|
22
22
|
aspire/samplers/smc/emcee.py,sha256=ZXXyN2l1Bz5ZsCPEcswg-Kakiw41nNa2jEW1N8zGjuc,2498
|
|
23
23
|
aspire/samplers/smc/minipcn.py,sha256=ZjeP4iHFR67G8WKEfMe0b1McrtPgQMNHyyy4vRx6WNE,2747
|
|
24
|
-
aspire_inference-0.1.
|
|
25
|
-
aspire_inference-0.1.
|
|
26
|
-
aspire_inference-0.1.
|
|
27
|
-
aspire_inference-0.1.
|
|
28
|
-
aspire_inference-0.1.
|
|
24
|
+
aspire_inference-0.1.0a3.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
|
|
25
|
+
aspire_inference-0.1.0a3.dist-info/METADATA,sha256=jwpdVn7ns4Fa6hEfFr3sNB3XXJJ5ApdKEry_H-APddA,1536
|
|
26
|
+
aspire_inference-0.1.0a3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
27
|
+
aspire_inference-0.1.0a3.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
|
|
28
|
+
aspire_inference-0.1.0a3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|