bartz 0.2.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bartz-0.2.0 → bartz-0.2.1}/PKG-INFO +1 -1
- {bartz-0.2.0 → bartz-0.2.1}/pyproject.toml +2 -2
- bartz-0.2.1/src/bartz/_version.py +1 -0
- {bartz-0.2.0 → bartz-0.2.1}/src/bartz/mcmcstep.py +10 -2
- bartz-0.2.0/src/bartz/_version.py +0 -1
- {bartz-0.2.0 → bartz-0.2.1}/LICENSE +0 -0
- {bartz-0.2.0 → bartz-0.2.1}/README.md +0 -0
- {bartz-0.2.0 → bartz-0.2.1}/src/bartz/BART.py +0 -0
- {bartz-0.2.0 → bartz-0.2.1}/src/bartz/__init__.py +0 -0
- {bartz-0.2.0 → bartz-0.2.1}/src/bartz/debug.py +0 -0
- {bartz-0.2.0 → bartz-0.2.1}/src/bartz/grove.py +0 -0
- {bartz-0.2.0 → bartz-0.2.1}/src/bartz/jaxext.py +0 -0
- {bartz-0.2.0 → bartz-0.2.1}/src/bartz/mcmcloop.py +0 -0
- {bartz-0.2.0 → bartz-0.2.1}/src/bartz/prepcovars.py +0 -0
|
@@ -28,7 +28,7 @@ build-backend = "poetry.core.masonry.api"
|
|
|
28
28
|
|
|
29
29
|
[tool.poetry]
|
|
30
30
|
name = "bartz"
|
|
31
|
-
version = "0.2.0"
|
|
31
|
+
version = "0.2.1"
|
|
32
32
|
description = "A JAX implementation of BART"
|
|
33
33
|
authors = ["Giacomo Petrillo <info@giacomopetrillo.com>"]
|
|
34
34
|
license = "MIT"
|
|
@@ -61,7 +61,7 @@ pytest = "^8.1.1"
|
|
|
61
61
|
|
|
62
62
|
[tool.poetry.group.docs.dependencies]
|
|
63
63
|
Sphinx = "^7.2.6"
|
|
64
|
-
numpydoc = "^1.6.0
|
|
64
|
+
numpydoc = "^1.6.0"
|
|
65
65
|
myst-parser = "^2.0.0"
|
|
66
66
|
|
|
67
67
|
[tool.pytest.ini_options]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = '0.2.1'
|
|
@@ -84,7 +84,8 @@ def init(*,
|
|
|
84
84
|
The minimum number of data points in a leaf node. 0 if not specified.
|
|
85
85
|
suffstat_batch_size : int, None, str, default 'auto'
|
|
86
86
|
The batch size for computing sufficient statistics. `None` for no
|
|
87
|
-
batching. If 'auto', pick a value based on the device of `y`.
|
|
87
|
+
batching. If 'auto', pick a value based on the device of `y`, or the
|
|
88
|
+
default device.
|
|
88
89
|
|
|
89
90
|
Returns
|
|
90
91
|
-------
|
|
@@ -188,7 +189,12 @@ def init(*,
|
|
|
188
189
|
|
|
189
190
|
def _choose_suffstat_batch_size(size, y):
|
|
190
191
|
if size == 'auto':
|
|
191
|
-
|
|
192
|
+
try:
|
|
193
|
+
device = y.devices().pop()
|
|
194
|
+
except jax.errors.ConcretizationTypeError:
|
|
195
|
+
device = jax.devices()[0]
|
|
196
|
+
platform = device.platform
|
|
197
|
+
|
|
192
198
|
if platform == 'cpu':
|
|
193
199
|
return None
|
|
194
200
|
# maybe I should batch residuals (not counts) for numerical
|
|
@@ -198,8 +204,10 @@ def _choose_suffstat_batch_size(size, y):
|
|
|
198
204
|
# 512 is good on T4, and V100 at low n
|
|
199
205
|
else:
|
|
200
206
|
raise KeyError(f'Unknown platform: {platform}')
|
|
207
|
+
|
|
201
208
|
elif size is not None:
|
|
202
209
|
return int(size)
|
|
210
|
+
|
|
203
211
|
return size
|
|
204
212
|
|
|
205
213
|
def step(bart, key):
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
__version__ = '0.2.0'
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|