bartz 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/{interface.py → BART.py} +10 -18
- bartz/__init__.py +7 -2
- bartz/_version.py +1 -0
- bartz/debug.py +9 -22
- bartz/grove.py +73 -120
- bartz/jaxext.py +261 -5
- bartz/mcmcloop.py +27 -13
- bartz/mcmcstep.py +510 -439
- bartz/prepcovars.py +25 -30
- {bartz-0.0.1.dist-info → bartz-0.2.0.dist-info}/METADATA +7 -1
- bartz-0.2.0.dist-info/RECORD +13 -0
- bartz-0.0.1.dist-info/RECORD +0 -12
- {bartz-0.0.1.dist-info → bartz-0.2.0.dist-info}/LICENSE +0 -0
- {bartz-0.0.1.dist-info → bartz-0.2.0.dist-info}/WHEEL +0 -0
bartz/prepcovars.py
CHANGED
|
@@ -27,8 +27,10 @@ import functools
|
|
|
27
27
|
import jax
|
|
28
28
|
from jax import numpy as jnp
|
|
29
29
|
|
|
30
|
+
from . import jaxext
|
|
30
31
|
from . import grove
|
|
31
32
|
|
|
33
|
+
@functools.partial(jax.jit, static_argnums=(1,))
|
|
32
34
|
def quantilized_splits_from_matrix(X, max_bins):
|
|
33
35
|
"""
|
|
34
36
|
Determine bins that make the distribution of each predictor uniform.
|
|
@@ -52,48 +54,41 @@ def quantilized_splits_from_matrix(X, max_bins):
|
|
|
52
54
|
The number of actually used values in each row of `splits`.
|
|
53
55
|
"""
|
|
54
56
|
out_length = min(max_bins, X.shape[1]) - 1
|
|
55
|
-
return
|
|
57
|
+
# return _quantilized_splits_from_matrix(X, out_length)
|
|
58
|
+
@functools.partial(jaxext.autobatch, max_io_nbytes=500_000_000)
|
|
59
|
+
def func(X):
|
|
60
|
+
return _quantilized_splits_from_matrix(X, out_length)
|
|
61
|
+
return func(X)
|
|
56
62
|
|
|
57
63
|
@functools.partial(jax.vmap, in_axes=(0, None))
|
|
58
|
-
def
|
|
59
|
-
huge = huge_value(x)
|
|
60
|
-
u =
|
|
61
|
-
actual_length
|
|
62
|
-
|
|
64
|
+
def _quantilized_splits_from_matrix(x, out_length):
|
|
65
|
+
huge = jaxext.huge_value(x)
|
|
66
|
+
u, actual_length = jaxext.unique(x, size=x.size, fill_value=huge)
|
|
67
|
+
actual_length -= 1
|
|
68
|
+
if jnp.issubdtype(x.dtype, jnp.integer):
|
|
69
|
+
midpoints = u[:-1] + jaxext.ensure_unsigned(u[1:] - u[:-1]) // 2
|
|
70
|
+
indices = jnp.arange(midpoints.size, dtype=jaxext.minimal_unsigned_dtype(midpoints.size - 1))
|
|
71
|
+
midpoints = jnp.where(indices < actual_length, midpoints, huge)
|
|
72
|
+
else:
|
|
73
|
+
midpoints = (u[1:] + u[:-1]) / 2
|
|
63
74
|
indices = jnp.linspace(-1, actual_length, out_length + 2)[1:-1]
|
|
64
|
-
indices = jnp.around(indices).astype(
|
|
75
|
+
indices = jnp.around(indices).astype(jaxext.minimal_unsigned_dtype(midpoints.size - 1))
|
|
65
76
|
# indices calculation with float rather than int to avoid potential
|
|
66
77
|
# overflow with int32, and to round to nearest instead of rounding down
|
|
67
78
|
decimated_midpoints = midpoints[indices]
|
|
68
79
|
truncated_midpoints = midpoints[:out_length]
|
|
69
80
|
splits = jnp.where(actual_length > out_length, decimated_midpoints, truncated_midpoints)
|
|
70
81
|
max_split = jnp.minimum(actual_length, out_length)
|
|
71
|
-
max_split = max_split.astype(
|
|
82
|
+
max_split = max_split.astype(jaxext.minimal_unsigned_dtype(out_length))
|
|
72
83
|
return splits, max_split
|
|
73
84
|
|
|
74
|
-
|
|
75
|
-
"""
|
|
76
|
-
Return the maximum value that can be stored in `x`.
|
|
77
|
-
|
|
78
|
-
Parameters
|
|
79
|
-
----------
|
|
80
|
-
x : array
|
|
81
|
-
A numerical numpy or jax array.
|
|
82
|
-
|
|
83
|
-
Returns
|
|
84
|
-
-------
|
|
85
|
-
maxval : scalar
|
|
86
|
-
The maximum value allowed by `x`'s type (+inf for floats).
|
|
87
|
-
"""
|
|
88
|
-
if jnp.issubdtype(x.dtype, jnp.integer):
|
|
89
|
-
return jnp.iinfo(x.dtype).max
|
|
90
|
-
else:
|
|
91
|
-
return jnp.inf
|
|
92
|
-
|
|
85
|
+
@jax.jit
|
|
93
86
|
def bin_predictors(X, splits):
|
|
94
87
|
"""
|
|
95
88
|
Bin the predictors according to the given splits.
|
|
96
89
|
|
|
90
|
+
A value ``x`` is mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.
|
|
91
|
+
|
|
97
92
|
Parameters
|
|
98
93
|
----------
|
|
99
94
|
X : array (p, n)
|
|
@@ -110,9 +105,9 @@ def bin_predictors(X, splits):
|
|
|
110
105
|
A matrix with `p` predictors and `n` observations, where each predictor
|
|
111
106
|
has been replaced by the index of the bin it falls into.
|
|
112
107
|
"""
|
|
113
|
-
return
|
|
108
|
+
return _bin_predictors(X, splits)
|
|
114
109
|
|
|
115
110
|
@jax.vmap
|
|
116
|
-
def
|
|
117
|
-
dtype =
|
|
111
|
+
def _bin_predictors(x, splits):
|
|
112
|
+
dtype = jaxext.minimal_unsigned_dtype(splits.size)
|
|
118
113
|
return jnp.searchsorted(splits, x).astype(dtype)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: bartz
|
|
3
|
-
Version: 0.0
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: A JAX implementation of BART
|
|
5
5
|
Home-page: https://github.com/Gattocrucco/bartz
|
|
6
6
|
License: MIT
|
|
@@ -20,7 +20,13 @@ Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
|
|
|
20
20
|
Project-URL: Repository, https://github.com/Gattocrucco/bartz
|
|
21
21
|
Description-Content-Type: text/markdown
|
|
22
22
|
|
|
23
|
+
[](https://pypi.org/project/bartz/)
|
|
24
|
+
|
|
23
25
|
# BART vectoriZed
|
|
24
26
|
|
|
25
27
|
A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
|
|
26
28
|
|
|
29
|
+
BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
|
|
30
|
+
|
|
31
|
+
This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also a good on CPU. Most other implementations of BART are for R, and run on CPU only.
|
|
32
|
+
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
bartz/BART.py,sha256=pRG7mALenknX2JHqY-VyhO9-evDgEC6hWBp4jpecBdM,15801
|
|
2
|
+
bartz/__init__.py,sha256=E96vsP0bZ8brejpZmEmRoXuMsUdinO_B_SKUUl1rLsg,1448
|
|
3
|
+
bartz/_version.py,sha256=FVHPBGkfhbQDi_z3v0PiKJrXXqXOx0vGW_1VaqNJi7U,22
|
|
4
|
+
bartz/debug.py,sha256=9ZH-JfwZVu5OPhHBEyXQHAU5H9KIu1vxLK7yNv4m4Ew,5314
|
|
5
|
+
bartz/grove.py,sha256=Wj_7jHl9w3uwuVdH4hoeXowimGpdRE2lGIzr4aDkzsI,8291
|
|
6
|
+
bartz/jaxext.py,sha256=VYA41D5F7DYcAAVtkcZtEN927HxQGOOQM-uGsgr2CPc,10996
|
|
7
|
+
bartz/mcmcloop.py,sha256=lheLrjVxmlyQzc_92zeNsFhdkrhEWQEjoAWFbVzknnw,7701
|
|
8
|
+
bartz/mcmcstep.py,sha256=3ba94hXBW4UAZ11SFshnwJAgn6bpIqSZdRy_wQjEkrk,39278
|
|
9
|
+
bartz/prepcovars.py,sha256=iiQ0WjSj4--l5DgPW626Qg2SSB6ljnaaUsBz_A8kFrI,4634
|
|
10
|
+
bartz-0.2.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
|
|
11
|
+
bartz-0.2.0.dist-info/METADATA,sha256=LiYjTAzgoxUM2MAuaKtf0VW-_zciTKBkTX5B7HNvUbI,1490
|
|
12
|
+
bartz-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
13
|
+
bartz-0.2.0.dist-info/RECORD,,
|
bartz-0.0.1.dist-info/RECORD
DELETED
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
bartz/__init__.py,sha256=PL-vhhEHoVMWOPLG_M45TIZVkbQia5riJbQboy-BNH8,1333
|
|
2
|
-
bartz/debug.py,sha256=FHnCalpK1uO1CN9QQ5DPj70JKR4Thltzp9o0BeYthIo,5741
|
|
3
|
-
bartz/grove.py,sha256=v2k10EBjgi2aLCsGvM01z0z--9Xv4ApBOxpke-6gIYM,10309
|
|
4
|
-
bartz/interface.py,sha256=GBwLwqEF_6EmeteFtsPw6ANdisnvMoWi_fKBJiQq-Vc,16129
|
|
5
|
-
bartz/jaxext.py,sha256=FK5j1zfW1yR4-yPKcD7ZvKSkVQ5--jHjQpVCl4n4gXY,2844
|
|
6
|
-
bartz/mcmcloop.py,sha256=N815-eJxsS_X85okXRO2kSOlikw8dPN05_krm0iT9Sg,7321
|
|
7
|
-
bartz/mcmcstep.py,sha256=acy_2rSIEXV5BzqLY96aQaqlsxtalxyO3Q4gPvUMRVU,35912
|
|
8
|
-
bartz/prepcovars.py,sha256=3ddDOtNNop3Ba2Kgy_dZ6apFydtwaEXH3uXSmmKf9Fs,4421
|
|
9
|
-
bartz-0.0.1.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
|
|
10
|
-
bartz-0.0.1.dist-info/METADATA,sha256=zDW1dM58gV7c_8ZTjEtTt_tcXabbz5roZBf36EdLxls,933
|
|
11
|
-
bartz-0.0.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
12
|
-
bartz-0.0.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|