bartz 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/prepcovars.py CHANGED
@@ -10,10 +10,10 @@
10
10
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
11
  # copies of the Software, and to permit persons to whom the Software is
12
12
  # furnished to do so, subject to the following conditions:
13
- #
13
+ #
14
14
  # The above copyright notice and this permission notice shall be included in all
15
15
  # copies or substantial portions of the Software.
16
- #
16
+ #
17
17
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
18
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
19
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -55,10 +55,10 @@ def quantilized_splits_from_matrix(X, max_bins):
55
55
  """
56
56
  out_length = min(max_bins, X.shape[1]) - 1
57
57
  # return _quantilized_splits_from_matrix(X, out_length)
58
- @functools.partial(jaxext.autobatch, max_io_nbytes=500_000_000)
59
- def func(X):
58
+ @functools.partial(jaxext.autobatch, max_io_nbytes=2 ** 29)
59
+ def quantilize(X):
60
60
  return _quantilized_splits_from_matrix(X, out_length)
61
- return func(X)
61
+ return quantilize(X)
62
62
 
63
63
  @functools.partial(jax.vmap, in_axes=(0, None))
64
64
  def _quantilized_splits_from_matrix(x, out_length):
@@ -82,8 +82,36 @@ def _quantilized_splits_from_matrix(x, out_length):
82
82
  max_split = max_split.astype(jaxext.minimal_unsigned_dtype(out_length))
83
83
  return splits, max_split
84
84
 
85
- @jax.jit
86
- def bin_predictors(X, splits):
85
+ @functools.partial(jax.jit, static_argnums=(1,))
86
+ def uniform_splits_from_matrix(X, num_bins):
87
+ """
88
+ Make an evenly spaced binning grid.
89
+
90
+ Parameters
91
+ ----------
92
+ X : array (p, n)
93
+ A matrix with `p` predictors and `n` observations.
94
+ num_bins : int
95
+ The number of bins to produce.
96
+
97
+ Returns
98
+ -------
99
+ splits : array (p, num_bins - 1)
100
+ A matrix containing, for each predictor, the boundaries between bins.
101
+ The excluded endpoints are the minimum and maximum value in each row of
102
+ `X`.
103
+ max_split : array (p,)
104
+ The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``.
105
+ """
106
+ low = jnp.min(X, axis=1)
107
+ high = jnp.max(X, axis=1)
108
+ splits = jnp.linspace(low, high, num_bins + 1, axis=1)[:, 1:-1]
109
+ assert splits.shape == (X.shape[0], num_bins - 1)
110
+ max_split = jnp.full(*splits.shape, jaxext.minimal_unsigned_dtype(num_bins - 1))
111
+ return splits, max_split
112
+
113
+ @functools.partial(jax.jit, static_argnames=('method',))
114
+ def bin_predictors(X, splits, **kw):
87
115
  """
88
116
  Bin the predictors according to the given splits.
89
117
 
@@ -98,6 +126,8 @@ def bin_predictors(X, splits):
98
126
  `m` is the maximum number of splits; each row may have shorter
99
127
  actual length, marked by padding unused locations at the end of the
100
128
  row with the maximum value allowed by the type.
129
+ **kw : dict
130
+ Additional arguments are passed to `jax.numpy.searchsorted`.
101
131
 
102
132
  Returns
103
133
  -------
@@ -105,9 +135,9 @@ def bin_predictors(X, splits):
105
135
  A matrix with `p` predictors and `n` observations, where each predictor
106
136
  has been replaced by the index of the bin it falls into.
107
137
  """
108
- return _bin_predictors(X, splits)
109
-
110
- @jax.vmap
111
- def _bin_predictors(x, splits):
112
- dtype = jaxext.minimal_unsigned_dtype(splits.size)
113
- return jnp.searchsorted(splits, x).astype(dtype)
138
+ @functools.partial(jaxext.autobatch, max_io_nbytes=2 ** 29)
139
+ @jax.vmap
140
+ def bin_predictors(x, splits):
141
+ dtype = jaxext.minimal_unsigned_dtype(splits.size)
142
+ return jnp.searchsorted(splits, x, **kw).astype(dtype)
143
+ return bin_predictors(X, splits)
@@ -0,0 +1,77 @@
1
+ Metadata-Version: 2.1
2
+ Name: bartz
3
+ Version: 0.3.0
4
+ Summary: A JAX implementation of BART
5
+ Home-page: https://github.com/Gattocrucco/bartz
6
+ License: MIT
7
+ Author: Giacomo Petrillo
8
+ Author-email: info@giacomopetrillo.com
9
+ Requires-Python: >=3.10,<4.0
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Dist: jax (>=0.4.23,<0.5.0)
16
+ Requires-Dist: jaxlib (>=0.4.23,<0.5.0)
17
+ Requires-Dist: numpy (>=1.25.2,<2.0.0)
18
+ Requires-Dist: scipy (>=1.11.4,<2.0.0)
19
+ Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
20
+ Project-URL: Repository, https://github.com/Gattocrucco/bartz
21
+ Description-Content-Type: text/markdown
22
+
23
+ [![PyPI](https://img.shields.io/pypi/v/bartz)](https://pypi.org/project/bartz/)
24
+
25
+ # BART vectoriZed
26
+
27
+ A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
28
+
29
+ BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
30
+
31
+ This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also good on CPU. Most other implementations of BART are for R, and run on CPU only.
32
+
33
+ On CPU, bartz runs at the speed of dbarts (the fastest implementation I know of), but using half the memory. On GPU, the speed premium depends on sample size; with 50000 datapoints and 5000 trees, on an Nvidia Tesla V100 GPU it's 12 times faster than an Apple M1 CPU, and this factor is linearly proportional to the number of datapoints.
34
+
35
+ ## Links
36
+
37
+ - [Documentation (latest release)](https://gattocrucco.github.io/bartz/docs)
38
+ - [Documentation (development version)](https://gattocrucco.github.io/bartz/docs-dev)
39
+ - [Repository](https://github.com/Gattocrucco/bartz)
40
+ - [Code coverage](https://gattocrucco.github.io/bartz/coverage)
41
+
42
+ ## Other BART packages
43
+
44
+ - [stochtree](https://github.com/StochasticTree) C++ library with R and Python bindings tailored for researchers who want to make their own BART variants
45
+ - [bnptools](https://github.com/rsparapa/bnptools) Feature-rich R packages for BART and some variants
46
+ - [dbarts](https://github.com/vdorie/dbarts) Fast R package
47
+ - [bartMachine](https://github.com/kapelner/bartMachine) Fast R package, supports missing predictors imputation
48
+ - [SoftBART](https://github.com/theodds/SoftBART) R package with a smooth version of BART
49
+ - [bcf](https://github.com/jaredsmurray/bcf) R package for a version of BART for causal inference
50
+ - [flexBART](https://github.com/skdeshpande91/flexBART) Fast R package, supports categorical predictors
51
+ - [flexBCF](https://github.com/skdeshpande91/flexBCF) R package, version of bcf optimized for large datasets
52
+ - [XBART](https://github.com/JingyuHe/XBART) R/Python package, XBART is a faster variant of BART
53
+ - [BART](https://github.com/JingyuHe/BART) R package, BART warm-started with XBART
54
+ - [XBCF](https://github.com/socket778/XBCF)
55
+ - [BayesTree](https://cran.r-project.org/package=BayesTree) R package, original BART implementation
56
+ - [bartCause](https://github.com/vdorie/bartCause) R package, pre-made BART-based workflows for causal inference
57
+ - [stan4bart](https://github.com/vdorie/stan4bart)
58
+ - [VCBART](https://github.com/skdeshpande91/VCBART)
59
+ - [monbart](https://github.com/jaredsmurray/monbart)
60
+ - [mBART](https://github.com/remcc/mBART_shlib)
61
+ - [SequentialBART](https://github.com/mjdaniels/SequentialBART)
62
+ - [sparseBART](https://github.com/cspanbauer/sparseBART)
63
+ - [pymc-bart](https://github.com/pymc-devs/pymc-bart)
64
+ - [semibart](https://github.com/zeldow/semibart)
65
+ - [CSP-BART](https://github.com/ebprado/CSP-BART)
66
+ - [AMBARTI](https://github.com/ebprado/AMBARTI)
67
+ - [MOTR-BART](https://github.com/ebprado/MOTR-BART)
68
+ - [bcfbma](https://github.com/EoghanONeill/bcfbma)
69
+ - [bartBMAnew](https://github.com/EoghanONeill/bartBMAnew)
70
+ - [BART-BMA](https://github.com/BelindaHernandez/BART-BMA) (superseded by bartBMAnew)
71
+ - [gpbart](https://github.com/MateusMaiaDS/gpbart)
72
+ - [GPBART](https://github.com/nchenderson/GPBART)
73
+ - [bartpy](https://github.com/JakeColtman/bartpy)
74
+ - [BayesTreePrior](https://github.com/AlexiaJM/BayesTreePrior)
75
+ - [BayesTree.jl](https://github.com/mathcg/BayesTree.jl)
76
+ - [longbet](https://github.com/google/longbet)
77
+
@@ -0,0 +1,13 @@
1
+ bartz/BART.py,sha256=CbGzFWtYw5u38Z9-Hy3CbDXpKOOvPFAAkSqu2HZl8no,16862
2
+ bartz/__init__.py,sha256=E96vsP0bZ8brejpZmEmRoXuMsUdinO_B_SKUUl1rLsg,1448
3
+ bartz/_version.py,sha256=3wVEs2QD_7OcTlD97cZdCeizd2hUbJJ0GeIO8wQIjrk,22
4
+ bartz/debug.py,sha256=9ZH-JfwZVu5OPhHBEyXQHAU5H9KIu1vxLK7yNv4m4Ew,5314
5
+ bartz/grove.py,sha256=x_6NK_l-hrXfy1PhssYNJkX41-w_WqjDziww0E7YRS8,8500
6
+ bartz/jaxext.py,sha256=RcVWTCGS8lXF7GBsNbKrpuA4MTcokItq0CpWm3s7CGk,12033
7
+ bartz/mcmcloop.py,sha256=lKDszvniNXka99X3e9RCrTgvEAZHA7ZbVXEgxUYvKMY,7634
8
+ bartz/mcmcstep.py,sha256=HPcxfl5f-OESZul-iurn0JmOnUJBe6IYTVaATeR6YBA,54221
9
+ bartz/prepcovars.py,sha256=mMgfL-LGJ_8QpOL6iy7yfkL8A7FrT7Zfn5M3voyNwSQ,5818
10
+ bartz-0.3.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
11
+ bartz-0.3.0.dist-info/METADATA,sha256=ymZNoowDdqQFyAJdeKKj6t7h8_eBXQr2cVPglyoYLDQ,4500
12
+ bartz-0.3.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
+ bartz-0.3.0.dist-info/RECORD,,
@@ -1,32 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: bartz
3
- Version: 0.2.0
4
- Summary: A JAX implementation of BART
5
- Home-page: https://github.com/Gattocrucco/bartz
6
- License: MIT
7
- Author: Giacomo Petrillo
8
- Author-email: info@giacomopetrillo.com
9
- Requires-Python: >=3.10,<4.0
10
- Classifier: License :: OSI Approved :: MIT License
11
- Classifier: Programming Language :: Python :: 3
12
- Classifier: Programming Language :: Python :: 3.10
13
- Classifier: Programming Language :: Python :: 3.11
14
- Classifier: Programming Language :: Python :: 3.12
15
- Requires-Dist: jax (>=0.4.23,<0.5.0)
16
- Requires-Dist: jaxlib (>=0.4.23,<0.5.0)
17
- Requires-Dist: numpy (>=1.25.2,<2.0.0)
18
- Requires-Dist: scipy (>=1.11.4,<2.0.0)
19
- Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
20
- Project-URL: Repository, https://github.com/Gattocrucco/bartz
21
- Description-Content-Type: text/markdown
22
-
23
- [![PyPI](https://img.shields.io/pypi/v/bartz)](https://pypi.org/project/bartz/)
24
-
25
- # BART vectoriZed
26
-
27
- A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
28
-
29
- BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
30
-
31
- This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also a good on CPU. Most other implementations of BART are for R, and run on CPU only.
32
-
@@ -1,13 +0,0 @@
1
- bartz/BART.py,sha256=pRG7mALenknX2JHqY-VyhO9-evDgEC6hWBp4jpecBdM,15801
2
- bartz/__init__.py,sha256=E96vsP0bZ8brejpZmEmRoXuMsUdinO_B_SKUUl1rLsg,1448
3
- bartz/_version.py,sha256=FVHPBGkfhbQDi_z3v0PiKJrXXqXOx0vGW_1VaqNJi7U,22
4
- bartz/debug.py,sha256=9ZH-JfwZVu5OPhHBEyXQHAU5H9KIu1vxLK7yNv4m4Ew,5314
5
- bartz/grove.py,sha256=Wj_7jHl9w3uwuVdH4hoeXowimGpdRE2lGIzr4aDkzsI,8291
6
- bartz/jaxext.py,sha256=VYA41D5F7DYcAAVtkcZtEN927HxQGOOQM-uGsgr2CPc,10996
7
- bartz/mcmcloop.py,sha256=lheLrjVxmlyQzc_92zeNsFhdkrhEWQEjoAWFbVzknnw,7701
8
- bartz/mcmcstep.py,sha256=3ba94hXBW4UAZ11SFshnwJAgn6bpIqSZdRy_wQjEkrk,39278
9
- bartz/prepcovars.py,sha256=iiQ0WjSj4--l5DgPW626Qg2SSB6ljnaaUsBz_A8kFrI,4634
10
- bartz-0.2.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
11
- bartz-0.2.0.dist-info/METADATA,sha256=LiYjTAzgoxUM2MAuaKtf0VW-_zciTKBkTX5B7HNvUbI,1490
12
- bartz-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
- bartz-0.2.0.dist-info/RECORD,,
File without changes
File without changes