aspire-inference 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/flows/base.py CHANGED
@@ -1,18 +1,28 @@
1
+ import logging
1
2
  from typing import Any
2
3
 
3
4
  from ..history import FlowHistory
4
- from ..transforms import BaseTransform
5
+ from ..transforms import BaseTransform, IdentityTransform
6
+
7
+ logger = logging.getLogger(__name__)
5
8
 
6
9
 
7
10
  class Flow:
11
+ xp = None # type: Any
12
+
8
13
  def __init__(
9
14
  self,
10
15
  dims: int,
11
16
  device: Any,
12
- data_transform: BaseTransform = None,
17
+ data_transform: BaseTransform | None = None,
13
18
  ):
14
19
  self.dims = dims
15
20
  self.device = device
21
+
22
+ if data_transform is None:
23
+ data_transform = IdentityTransform(self.xp)
24
+ logger.info("No data_transform provided, using IdentityTransform.")
25
+
16
26
  self.data_transform = data_transform
17
27
 
18
28
  def log_prob(self, x):
aspire/transforms.py CHANGED
@@ -4,7 +4,6 @@ from typing import Any
4
4
 
5
5
  from array_api_compat import device as get_device
6
6
  from array_api_compat import is_torch_namespace
7
- from scipy.special import erf, erfinv
8
7
 
9
8
  from .flows import get_flow_wrapper
10
9
  from .utils import (
@@ -50,13 +49,17 @@ class IdentityTransform(BaseTransform):
50
49
  """Identity transform that does nothing to the data."""
51
50
 
52
51
  def fit(self, x):
53
- return x
52
+ return copy_array(x, xp=self.xp)
54
53
 
55
54
  def forward(self, x):
56
- return x, self.xp.zeros(len(x), device=get_device(x))
55
+ return copy_array(x, xp=self.xp), self.xp.zeros(
56
+ len(x), device=get_device(x)
57
+ )
57
58
 
58
59
  def inverse(self, y):
59
- return y, self.xp.zeros(len(y), device=get_device(y))
60
+ return copy_array(y, xp=self.xp), self.xp.zeros(
61
+ len(y), device=get_device(y)
62
+ )
60
63
 
61
64
 
62
65
  class CompositeTransform(BaseTransform):
@@ -329,6 +332,8 @@ class ProbitTransform(BaseTransform):
329
332
  return self.forward(x)[0]
330
333
 
331
334
  def forward(self, x):
335
+ from scipy.special import erfinv
336
+
332
337
  y = (x - self.lower) / (self.upper - self.lower)
333
338
  y = self.xp.clip(y, self.eps, 1.0 - self.eps)
334
339
  y = erfinv(2 * y - 1) * math.sqrt(2)
@@ -339,6 +344,8 @@ class ProbitTransform(BaseTransform):
339
344
  return y, log_abs_det_jacobian
340
345
 
341
346
  def inverse(self, y):
347
+ from scipy.special import erf
348
+
342
349
  log_abs_det_jacobian = (
343
350
  -(0.5 * (math.log(2 * math.pi) + y**2)).sum(-1)
344
351
  - self._scale_log_abs_det_jacobian
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a2
3
+ Version: 0.1.0a4
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -14,6 +14,8 @@ Requires-Dist: numpy
14
14
  Requires-Dist: array-api-compat
15
15
  Requires-Dist: wrapt
16
16
  Requires-Dist: h5py
17
+ Provides-Extra: scipy
18
+ Requires-Dist: scipy; extra == "scipy"
17
19
  Provides-Extra: jax
18
20
  Requires-Dist: jax; extra == "jax"
19
21
  Requires-Dist: jaxlib; extra == "jax"
@@ -21,6 +23,7 @@ Requires-Dist: flowjax; extra == "jax"
21
23
  Provides-Extra: torch
22
24
  Requires-Dist: torch; extra == "torch"
23
25
  Requires-Dist: zuko; extra == "torch"
26
+ Requires-Dist: tqdm; extra == "torch"
24
27
  Provides-Extra: minipcn
25
28
  Requires-Dist: minipcn; extra == "minipcn"
26
29
  Provides-Extra: emcee
@@ -3,10 +3,10 @@ aspire/aspire.py,sha256=AEkFUuOCF4F_iXUqRNst_4mucxozYRK4fG4V2wGrT4Q,15762
3
3
  aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
4
4
  aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
5
5
  aspire/samples.py,sha256=hMlONOtSuYE3bU6r_wQCZ8Z1dcc3Ch15bNMLO8fGU8g,16263
6
- aspire/transforms.py,sha256=R_BNPlYxK8tvACkZMjgHayr9gUpHxUQiD8148jfTnmg,16407
6
+ aspire/transforms.py,sha256=fg2_UELJWjJ6gnqQi6X7s1CgKBhQ5hP7Ipil3tTjCeg,16566
7
7
  aspire/utils.py,sha256=fQeLMauCN3vAogKbVTVg9jfjW7nTEFi7V6Ot-BYNfxE,14301
8
8
  aspire/flows/__init__.py,sha256=3gGXF4HziMlZSmcEdJ_uHtrP-QEC6RXvylm4vtM-Xnk,1306
9
- aspire/flows/base.py,sha256=scBhYvtaoa1x_gcrWs0nLfOKhWYu2bqivVVqbH4zSI8,860
9
+ aspire/flows/base.py,sha256=oTw2ZkxCsA5RZhnMuIu9M-2FPHvQG2TGFIEJZVK4a2g,1140
10
10
  aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
11
11
  aspire/flows/jax/flows.py,sha256=jZ93fnc7U7ZhZLVixGUTwyeDb6Vz0UWpYkkVHwirNug,2896
12
12
  aspire/flows/jax/utils.py,sha256=UlvXOOqC5fNsmVUnU4LSksliq7pLRm9NhOu0ZvVHqgc,1455
@@ -21,8 +21,8 @@ aspire/samplers/smc/base.py,sha256=GePA6tm8Dno_AjCeNuRX3KOaKnoKSFHSRAb-QWx9wJE,1
21
21
  aspire/samplers/smc/blackjax.py,sha256=9w1ORzWTT1viwp99_ttLxnNMdgTO-VqAzsf-NhgG9vY,11722
22
22
  aspire/samplers/smc/emcee.py,sha256=ZXXyN2l1Bz5ZsCPEcswg-Kakiw41nNa2jEW1N8zGjuc,2498
23
23
  aspire/samplers/smc/minipcn.py,sha256=ZjeP4iHFR67G8WKEfMe0b1McrtPgQMNHyyy4vRx6WNE,2747
24
- aspire_inference-0.1.0a2.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
- aspire_inference-0.1.0a2.dist-info/METADATA,sha256=8s65XoHR6AJmpCDFAA1mqqCWZXd_m8skZgOotgNRO2U,1475
26
- aspire_inference-0.1.0a2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
- aspire_inference-0.1.0a2.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
- aspire_inference-0.1.0a2.dist-info/RECORD,,
24
+ aspire_inference-0.1.0a4.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
+ aspire_inference-0.1.0a4.dist-info/METADATA,sha256=9RBI_xCkn1b0u7iOSjWZb7kCLoeZwYEwQrry--ieuQ8,1574
26
+ aspire_inference-0.1.0a4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
+ aspire_inference-0.1.0a4.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
+ aspire_inference-0.1.0a4.dist-info/RECORD,,