bartz 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
bartz/_version.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.2.0'
1
+ __version__ = '0.2.1'
bartz/mcmcstep.py CHANGED
@@ -84,7 +84,8 @@ def init(*,
84
84
  The minimum number of data points in a leaf node. 0 if not specified.
85
85
  suffstat_batch_size : int, None, str, default 'auto'
86
86
  The batch size for computing sufficient statistics. `None` for no
87
- batching. If 'auto', pick a value based on the device of `y`.
87
+ batching. If 'auto', pick a value based on the device of `y`, or the
88
+ default device.
88
89
 
89
90
  Returns
90
91
  -------
@@ -188,7 +189,12 @@ def init(*,
188
189
 
189
190
  def _choose_suffstat_batch_size(size, y):
190
191
  if size == 'auto':
191
- platform = y.devices().pop().platform
192
+ try:
193
+ device = y.devices().pop()
194
+ except jax.errors.ConcretizationTypeError:
195
+ device = jax.devices()[0]
196
+ platform = device.platform
197
+
192
198
  if platform == 'cpu':
193
199
  return None
194
200
  # maybe I should batch residuals (not counts) for numerical
@@ -198,8 +204,10 @@ def _choose_suffstat_batch_size(size, y):
198
204
  # 512 is good on T4, and V100 at low n
199
205
  else:
200
206
  raise KeyError(f'Unknown platform: {platform}')
207
+
201
208
  elif size is not None:
202
209
  return int(size)
210
+
203
211
  return size
204
212
 
205
213
  def step(bart, key):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: bartz
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: A JAX implementation of BART
5
5
  Home-page: https://github.com/Gattocrucco/bartz
6
6
  License: MIT
@@ -1,13 +1,13 @@
1
1
  bartz/BART.py,sha256=pRG7mALenknX2JHqY-VyhO9-evDgEC6hWBp4jpecBdM,15801
2
2
  bartz/__init__.py,sha256=E96vsP0bZ8brejpZmEmRoXuMsUdinO_B_SKUUl1rLsg,1448
3
- bartz/_version.py,sha256=FVHPBGkfhbQDi_z3v0PiKJrXXqXOx0vGW_1VaqNJi7U,22
3
+ bartz/_version.py,sha256=PmcQ2PI2oP8irnLtJLJby2YfW6sBvLAmL-VpABzTqwc,22
4
4
  bartz/debug.py,sha256=9ZH-JfwZVu5OPhHBEyXQHAU5H9KIu1vxLK7yNv4m4Ew,5314
5
5
  bartz/grove.py,sha256=Wj_7jHl9w3uwuVdH4hoeXowimGpdRE2lGIzr4aDkzsI,8291
6
6
  bartz/jaxext.py,sha256=VYA41D5F7DYcAAVtkcZtEN927HxQGOOQM-uGsgr2CPc,10996
7
7
  bartz/mcmcloop.py,sha256=lheLrjVxmlyQzc_92zeNsFhdkrhEWQEjoAWFbVzknnw,7701
8
- bartz/mcmcstep.py,sha256=3ba94hXBW4UAZ11SFshnwJAgn6bpIqSZdRy_wQjEkrk,39278
8
+ bartz/mcmcstep.py,sha256=6fzNMumXjMe6Fj6zoHLTf1D42JuAiQyGHfr6l1Bwrnk,39450
9
9
  bartz/prepcovars.py,sha256=iiQ0WjSj4--l5DgPW626Qg2SSB6ljnaaUsBz_A8kFrI,4634
10
- bartz-0.2.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
11
- bartz-0.2.0.dist-info/METADATA,sha256=LiYjTAzgoxUM2MAuaKtf0VW-_zciTKBkTX5B7HNvUbI,1490
12
- bartz-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
- bartz-0.2.0.dist-info/RECORD,,
10
+ bartz-0.2.1.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
11
+ bartz-0.2.1.dist-info/METADATA,sha256=eGxicC1iR-Bpjk1uKn50g6FxdFfq9S70nl7m5GmXO14,1490
12
+ bartz-0.2.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
+ bartz-0.2.1.dist-info/RECORD,,
File without changes
File without changes