bartz 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/prepcovars.py CHANGED
@@ -27,8 +27,10 @@ import functools
27
27
  import jax
28
28
  from jax import numpy as jnp
29
29
 
30
+ from . import jaxext
30
31
  from . import grove
31
32
 
33
+ @functools.partial(jax.jit, static_argnums=(1,))
32
34
  def quantilized_splits_from_matrix(X, max_bins):
33
35
  """
34
36
  Determine bins that make the distribution of each predictor uniform.
@@ -52,48 +54,41 @@ def quantilized_splits_from_matrix(X, max_bins):
52
54
  The number of actually used values in each row of `splits`.
53
55
  """
54
56
  out_length = min(max_bins, X.shape[1]) - 1
55
- return quantilized_splits_from_matrix_impl(X, out_length)
57
+ # return _quantilized_splits_from_matrix(X, out_length)
58
+ @functools.partial(jaxext.autobatch, max_io_nbytes=500_000_000)
59
+ def func(X):
60
+ return _quantilized_splits_from_matrix(X, out_length)
61
+ return func(X)
56
62
 
57
63
  @functools.partial(jax.vmap, in_axes=(0, None))
58
- def quantilized_splits_from_matrix_impl(x, out_length):
59
- huge = huge_value(x)
60
- u = jnp.unique(x, size=x.size, fill_value=huge)
61
- actual_length = jnp.count_nonzero(u < huge) - 1
62
- midpoints = (u[1:] + u[:-1]) / 2
64
+ def _quantilized_splits_from_matrix(x, out_length):
65
+ huge = jaxext.huge_value(x)
66
+ u, actual_length = jaxext.unique(x, size=x.size, fill_value=huge)
67
+ actual_length -= 1
68
+ if jnp.issubdtype(x.dtype, jnp.integer):
69
+ midpoints = u[:-1] + jaxext.ensure_unsigned(u[1:] - u[:-1]) // 2
70
+ indices = jnp.arange(midpoints.size, dtype=jaxext.minimal_unsigned_dtype(midpoints.size - 1))
71
+ midpoints = jnp.where(indices < actual_length, midpoints, huge)
72
+ else:
73
+ midpoints = (u[1:] + u[:-1]) / 2
63
74
  indices = jnp.linspace(-1, actual_length, out_length + 2)[1:-1]
64
- indices = jnp.around(indices).astype(grove.minimal_unsigned_dtype(midpoints.size - 1))
75
+ indices = jnp.around(indices).astype(jaxext.minimal_unsigned_dtype(midpoints.size - 1))
65
76
  # indices calculation with float rather than int to avoid potential
66
77
  # overflow with int32, and to round to nearest instead of rounding down
67
78
  decimated_midpoints = midpoints[indices]
68
79
  truncated_midpoints = midpoints[:out_length]
69
80
  splits = jnp.where(actual_length > out_length, decimated_midpoints, truncated_midpoints)
70
81
  max_split = jnp.minimum(actual_length, out_length)
71
- max_split = max_split.astype(grove.minimal_unsigned_dtype(out_length))
82
+ max_split = max_split.astype(jaxext.minimal_unsigned_dtype(out_length))
72
83
  return splits, max_split
73
84
 
74
- def huge_value(x):
75
- """
76
- Return the maximum value that can be stored in `x`.
77
-
78
- Parameters
79
- ----------
80
- x : array
81
- A numerical numpy or jax array.
82
-
83
- Returns
84
- -------
85
- maxval : scalar
86
- The maximum value allowed by `x`'s type (+inf for floats).
87
- """
88
- if jnp.issubdtype(x.dtype, jnp.integer):
89
- return jnp.iinfo(x.dtype).max
90
- else:
91
- return jnp.inf
92
-
85
+ @jax.jit
93
86
  def bin_predictors(X, splits):
94
87
  """
95
88
  Bin the predictors according to the given splits.
96
89
 
90
+ A value ``x`` is mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.
91
+
97
92
  Parameters
98
93
  ----------
99
94
  X : array (p, n)
@@ -110,9 +105,9 @@ def bin_predictors(X, splits):
110
105
  A matrix with `p` predictors and `n` observations, where each predictor
111
106
  has been replaced by the index of the bin it falls into.
112
107
  """
113
- return bin_predictors_impl(X, splits)
108
+ return _bin_predictors(X, splits)
114
109
 
115
110
  @jax.vmap
116
- def bin_predictors_impl(x, splits):
117
- dtype = grove.minimal_unsigned_dtype(splits.size)
111
+ def _bin_predictors(x, splits):
112
+ dtype = jaxext.minimal_unsigned_dtype(splits.size)
118
113
  return jnp.searchsorted(splits, x).astype(dtype)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: bartz
3
- Version: 0.0.1
3
+ Version: 0.2.0
4
4
  Summary: A JAX implementation of BART
5
5
  Home-page: https://github.com/Gattocrucco/bartz
6
6
  License: MIT
@@ -20,7 +20,13 @@ Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
20
20
  Project-URL: Repository, https://github.com/Gattocrucco/bartz
21
21
  Description-Content-Type: text/markdown
22
22
 
23
+ [![PyPI](https://img.shields.io/pypi/v/bartz)](https://pypi.org/project/bartz/)
24
+
23
25
  # BART vectoriZed
24
26
 
25
27
  A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
26
28
 
29
+ BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
30
+
31
+ This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also good on CPU. Most other implementations of BART are for R, and run on CPU only.
32
+
@@ -0,0 +1,13 @@
1
+ bartz/BART.py,sha256=pRG7mALenknX2JHqY-VyhO9-evDgEC6hWBp4jpecBdM,15801
2
+ bartz/__init__.py,sha256=E96vsP0bZ8brejpZmEmRoXuMsUdinO_B_SKUUl1rLsg,1448
3
+ bartz/_version.py,sha256=FVHPBGkfhbQDi_z3v0PiKJrXXqXOx0vGW_1VaqNJi7U,22
4
+ bartz/debug.py,sha256=9ZH-JfwZVu5OPhHBEyXQHAU5H9KIu1vxLK7yNv4m4Ew,5314
5
+ bartz/grove.py,sha256=Wj_7jHl9w3uwuVdH4hoeXowimGpdRE2lGIzr4aDkzsI,8291
6
+ bartz/jaxext.py,sha256=VYA41D5F7DYcAAVtkcZtEN927HxQGOOQM-uGsgr2CPc,10996
7
+ bartz/mcmcloop.py,sha256=lheLrjVxmlyQzc_92zeNsFhdkrhEWQEjoAWFbVzknnw,7701
8
+ bartz/mcmcstep.py,sha256=3ba94hXBW4UAZ11SFshnwJAgn6bpIqSZdRy_wQjEkrk,39278
9
+ bartz/prepcovars.py,sha256=iiQ0WjSj4--l5DgPW626Qg2SSB6ljnaaUsBz_A8kFrI,4634
10
+ bartz-0.2.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
11
+ bartz-0.2.0.dist-info/METADATA,sha256=LiYjTAzgoxUM2MAuaKtf0VW-_zciTKBkTX5B7HNvUbI,1490
12
+ bartz-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
+ bartz-0.2.0.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- bartz/__init__.py,sha256=PL-vhhEHoVMWOPLG_M45TIZVkbQia5riJbQboy-BNH8,1333
2
- bartz/debug.py,sha256=FHnCalpK1uO1CN9QQ5DPj70JKR4Thltzp9o0BeYthIo,5741
3
- bartz/grove.py,sha256=v2k10EBjgi2aLCsGvM01z0z--9Xv4ApBOxpke-6gIYM,10309
4
- bartz/interface.py,sha256=GBwLwqEF_6EmeteFtsPw6ANdisnvMoWi_fKBJiQq-Vc,16129
5
- bartz/jaxext.py,sha256=FK5j1zfW1yR4-yPKcD7ZvKSkVQ5--jHjQpVCl4n4gXY,2844
6
- bartz/mcmcloop.py,sha256=N815-eJxsS_X85okXRO2kSOlikw8dPN05_krm0iT9Sg,7321
7
- bartz/mcmcstep.py,sha256=acy_2rSIEXV5BzqLY96aQaqlsxtalxyO3Q4gPvUMRVU,35912
8
- bartz/prepcovars.py,sha256=3ddDOtNNop3Ba2Kgy_dZ6apFydtwaEXH3uXSmmKf9Fs,4421
9
- bartz-0.0.1.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
10
- bartz-0.0.1.dist-info/METADATA,sha256=zDW1dM58gV7c_8ZTjEtTt_tcXabbz5roZBf36EdLxls,933
11
- bartz-0.0.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
12
- bartz-0.0.1.dist-info/RECORD,,
File without changes
File without changes