bartz 0.7.0__tar.gz → 0.8.0__tar.gz

Files changed (29)
  1. {bartz-0.7.0 → bartz-0.8.0}/PKG-INFO +17 -11
  2. {bartz-0.7.0 → bartz-0.8.0}/README.md +9 -7
  3. {bartz-0.7.0 → bartz-0.8.0}/pyproject.toml +92 -92
  4. bartz-0.8.0/src/bartz/BART/__init__.py +27 -0
  5. bartz-0.8.0/src/bartz/BART/_gbart.py +522 -0
  6. {bartz-0.7.0 → bartz-0.8.0}/src/bartz/__init__.py +4 -2
  7. bartz-0.7.0/src/bartz/BART.py → bartz-0.8.0/src/bartz/_interface.py +256 -132
  8. bartz-0.8.0/src/bartz/_profiler.py +318 -0
  9. bartz-0.8.0/src/bartz/_version.py +1 -0
  10. {bartz-0.7.0 → bartz-0.8.0}/src/bartz/debug.py +269 -314
  11. {bartz-0.7.0 → bartz-0.8.0}/src/bartz/grove.py +124 -68
  12. {bartz-0.7.0 → bartz-0.8.0}/src/bartz/jaxext/__init__.py +101 -27
  13. bartz-0.8.0/src/bartz/jaxext/_autobatch.py +444 -0
  14. {bartz-0.7.0 → bartz-0.8.0}/src/bartz/jaxext/scipy/__init__.py +1 -1
  15. {bartz-0.7.0 → bartz-0.8.0}/src/bartz/jaxext/scipy/special.py +3 -4
  16. {bartz-0.7.0 → bartz-0.8.0}/src/bartz/jaxext/scipy/stats.py +1 -1
  17. bartz-0.8.0/src/bartz/mcmcloop.py +859 -0
  18. bartz-0.8.0/src/bartz/mcmcstep/__init__.py +35 -0
  19. bartz-0.8.0/src/bartz/mcmcstep/_moves.py +904 -0
  20. bartz-0.8.0/src/bartz/mcmcstep/_state.py +1114 -0
  21. bartz-0.8.0/src/bartz/mcmcstep/_step.py +1603 -0
  22. {bartz-0.7.0 → bartz-0.8.0}/src/bartz/prepcovars.py +1 -1
  23. bartz-0.8.0/src/bartz/testing/__init__.py +29 -0
  24. bartz-0.8.0/src/bartz/testing/_dgp.py +442 -0
  25. bartz-0.7.0/src/bartz/_version.py +0 -1
  26. bartz-0.7.0/src/bartz/jaxext/_autobatch.py +0 -238
  27. bartz-0.7.0/src/bartz/mcmcloop.py +0 -668
  28. bartz-0.7.0/src/bartz/mcmcstep.py +0 -2616
  29. {bartz-0.7.0 → bartz-0.8.0}/src/bartz/.DS_Store +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bartz
- Version: 0.7.0
+ Version: 0.8.0
  Summary: Super-fast BART (Bayesian Additive Regression Trees) in Python
  Author: Giacomo Petrillo
  Author-email: Giacomo Petrillo <info@giacomopetrillo.com>
@@ -10,10 +10,14 @@ Requires-Dist: jax>=0.5.3
  Requires-Dist: jaxtyping>=0.3.2
  Requires-Dist: numpy>=1.25.2
  Requires-Dist: scipy>=1.11.4
+ Requires-Dist: jax[cuda12] ; extra == 'cuda12'
+ Requires-Dist: jax[cuda13] ; extra == 'cuda13'
  Requires-Python: >=3.10
- Project-URL: Documentation, https://gattocrucco.github.io/bartz/docs-dev
- Project-URL: Homepage, https://github.com/Gattocrucco/bartz
- Project-URL: Issues, https://github.com/Gattocrucco/bartz/issues
+ Project-URL: Documentation, https://bartz-org.github.io/bartz/docs-dev
+ Project-URL: Homepage, https://github.com/bartz-org/bartz
+ Project-URL: Issues, https://github.com/bartz-org/bartz/issues
+ Provides-Extra: cuda12
+ Provides-Extra: cuda13
  Description-Content-Type: text/markdown
 
  [![PyPI](https://img.shields.io/pypi/v/bartz)](https://pypi.org/project/bartz/)
@@ -31,16 +35,18 @@ This Python module provides an implementation of BART that runs on GPU, to proce
 
  On CPU, bartz runs at the speed of dbarts (the fastest implementation I know of) if n > 20,000, but using 1/20 of the memory. On GPU, the speed premium depends on sample size; it is convenient over CPU only for n > 10,000. The maximum speedup is currently 200x, on an Nvidia A100 and with at least 2,000,000 observations.
 
- [This Colab notebook](https://colab.research.google.com/github/Gattocrucco/bartz/blob/main/docs/examples/basic_simdata.ipynb) runs bartz with n = 100,000 observations, p = 1000 predictors, 10,000 trees, for 1000 MCMC iterations, in 5 minutes.
+ [This Colab notebook](https://colab.research.google.com/github/bartz-org/bartz/blob/main/docs/examples/basic_simdata.ipynb) runs bartz with n = 100,000 observations, p = 1000 predictors, 10,000 trees, for 1000 MCMC iterations, in 10 minutes.
+
+ BART is a very flexible method with many variants. This implementation provides only a small subset of the possible features. If you need a feature from [another BART implementation](https://bartz-org.github.io/bartz/docs-dev/pkglist.html) or from the BART literature, please [open an issue on github](https://github.com/bartz-org/bartz/issues).
 
  ## Links
 
- - [Documentation (latest release)](https://gattocrucco.github.io/bartz/docs)
- - [Documentation (development version)](https://gattocrucco.github.io/bartz/docs-dev)
- - [Repository](https://github.com/Gattocrucco/bartz)
- - [Code coverage](https://gattocrucco.github.io/bartz/coverage)
- - [Benchmarks](https://gattocrucco.github.io/bartz/benchmarks)
- - [List of BART packages](https://gattocrucco.github.io/bartz/docs-dev/pkglist.html)
+ - [Documentation (latest release)](https://bartz-org.github.io/bartz/docs)
+ - [Documentation (development version)](https://bartz-org.github.io/bartz/docs-dev)
+ - [Repository](https://github.com/bartz-org/bartz)
+ - [Code coverage](https://bartz-org.github.io/bartz/coverage)
+ - [Benchmarks](https://bartz-org.github.io/bartz/benchmarks)
+ - [List of BART packages](https://bartz-org.github.io/bartz/docs-dev/pkglist.html)
 
  ## Citing bartz
 
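Note on the metadata change above: the new `Requires-Dist` lines carry an `extra == '...'` marker, so they apply only when the matching extra is requested. The sketch below is not part of the package. It reads an excerpt of the 0.8.0 fields with the standard-library email parser and groups requirements by extra. The helper name `requirements_by_extra` is hypothetical, and the marker handling is deliberately naive (it only recognises the single `extra == 'name'` form seen in this diff).

```python
from email.parser import Parser

# Excerpt of the 0.8.0 PKG-INFO fields visible in the hunk above.
PKG_INFO = """\
Metadata-Version: 2.4
Name: bartz
Version: 0.8.0
Requires-Dist: jax>=0.5.3
Requires-Dist: jaxtyping>=0.3.2
Requires-Dist: numpy>=1.25.2
Requires-Dist: scipy>=1.11.4
Requires-Dist: jax[cuda12] ; extra == 'cuda12'
Requires-Dist: jax[cuda13] ; extra == 'cuda13'
Provides-Extra: cuda12
Provides-Extra: cuda13
"""


def requirements_by_extra(pkg_info_text):
    """Group Requires-Dist entries by extra name (None means always required).

    Hypothetical helper: it only understands markers of the form
    `extra == 'name'`, which is all that appears in this diff.
    """
    message = Parser().parsestr(pkg_info_text)
    grouped = {}
    for entry in message.get_all("Requires-Dist", []):
        requirement, _, marker = entry.partition(";")
        extra = None
        if "extra ==" in marker:
            # Take the quoted name on the right-hand side of `==`.
            extra = marker.split("==", 1)[1].strip().strip("'\"")
        grouped.setdefault(extra, []).append(requirement.strip())
    return grouped


extras = requirements_by_extra(PKG_INFO)
print(extras)
```

With standard pip extras syntax, requesting one of these would look like `pip install "bartz[cuda12]"`.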
@@ -13,16 +13,18 @@ This Python module provides an implementation of BART that runs on GPU, to proce
 
  On CPU, bartz runs at the speed of dbarts (the fastest implementation I know of) if n > 20,000, but using 1/20 of the memory. On GPU, the speed premium depends on sample size; it is convenient over CPU only for n > 10,000. The maximum speedup is currently 200x, on an Nvidia A100 and with at least 2,000,000 observations.
 
- [This Colab notebook](https://colab.research.google.com/github/Gattocrucco/bartz/blob/main/docs/examples/basic_simdata.ipynb) runs bartz with n = 100,000 observations, p = 1000 predictors, 10,000 trees, for 1000 MCMC iterations, in 5 minutes.
+ [This Colab notebook](https://colab.research.google.com/github/bartz-org/bartz/blob/main/docs/examples/basic_simdata.ipynb) runs bartz with n = 100,000 observations, p = 1000 predictors, 10,000 trees, for 1000 MCMC iterations, in 10 minutes.
+
+ BART is a very flexible method with many variants. This implementation provides only a small subset of the possible features. If you need a feature from [another BART implementation](https://bartz-org.github.io/bartz/docs-dev/pkglist.html) or from the BART literature, please [open an issue on github](https://github.com/bartz-org/bartz/issues).
 
  ## Links
 
- - [Documentation (latest release)](https://gattocrucco.github.io/bartz/docs)
- - [Documentation (development version)](https://gattocrucco.github.io/bartz/docs-dev)
- - [Repository](https://github.com/Gattocrucco/bartz)
- - [Code coverage](https://gattocrucco.github.io/bartz/coverage)
- - [Benchmarks](https://gattocrucco.github.io/bartz/benchmarks)
- - [List of BART packages](https://gattocrucco.github.io/bartz/docs-dev/pkglist.html)
+ - [Documentation (latest release)](https://bartz-org.github.io/bartz/docs)
+ - [Documentation (development version)](https://bartz-org.github.io/bartz/docs-dev)
+ - [Repository](https://github.com/bartz-org/bartz)
+ - [Code coverage](https://bartz-org.github.io/bartz/coverage)
+ - [Benchmarks](https://bartz-org.github.io/bartz/benchmarks)
+ - [List of BART packages](https://bartz-org.github.io/bartz/docs-dev/pkglist.html)
 
  ## Citing bartz
 
@@ -1,6 +1,6 @@
  # bartz/pyproject.toml
  #
- # Copyright (c) 2024-2025, Giacomo Petrillo
+ # Copyright (c) 2024-2026, The Bartz Contributors
  #
  # This file is part of bartz.
  #
@@ -23,16 +23,14 @@
  # SOFTWARE.
 
  [build-system]
- requires = ["uv_build>=0.7.3,<0.8.0"]
+ requires = ["uv_build>=0.9.5,<0.10.0"]
  build-backend = "uv_build"
 
  [project]
  name = "bartz"
- version = "0.7.0"
+ version = "0.8.0"
  description = "Super-fast BART (Bayesian Additive Regression Trees) in Python"
- authors = [
- {name = "Giacomo Petrillo", email = "info@giacomopetrillo.com"},
- ]
+ authors = [{ name = "Giacomo Petrillo", email = "info@giacomopetrillo.com" }]
  license = "MIT"
  readme = "README.md"
  requires-python = ">=3.10"
@@ -44,35 +42,43 @@ dependencies = [
  "scipy>=1.11.4",
  ]
 
+ [project.optional-dependencies]
+ cuda12 = ["jax[cuda12]"]
+ cuda13 = ["jax[cuda13]"]
+
  [project.urls]
- Homepage = "https://github.com/Gattocrucco/bartz"
- Documentation = "https://gattocrucco.github.io/bartz/docs-dev"
- Issues = "https://github.com/Gattocrucco/bartz/issues"
+ Homepage = "https://github.com/bartz-org/bartz"
+ Documentation = "https://bartz-org.github.io/bartz/docs-dev"
+ Issues = "https://github.com/bartz-org/bartz/issues"
 
  [dependency-groups]
- only-local = [
+ dev = [
  "appnope>=0.1.4",
  "ipython>=8.36.0",
  "matplotlib>=3.10.3",
  "matplotlib-label-lines>=0.8.1",
  "pre-commit>=4.2.0",
- "pydoclint>=0.6.6",
- "ruff>=0.11.9",
- "scikit-learn>=1.6.1",
- "tomli>=2.2.1",
+ "rich>=13.9.4",
+ "snakeviz>=2.2.2",
+ "tqdm>=4.67.1",
  "virtualenv>=20.31.2",
- "xgboost>=3.0.0",
- ]
- ci = [
  "asv>=0.6.4",
- "coverage>=7.8.0",
- "myst-parser>=4.0.1",
+ "beartype>=0.20.2",
+ "flaky>=3.8.1",
+ "gitpython>=3.1.43",
  "packaging>=25.0",
  "polars[pandas,pyarrow]>=1.29.0",
  "pytest>=8.3.5",
+ "pytest-cov>=6.1.1",
+ "pytest-subtests>=0.14.1",
  "pytest-timeout>=2.4.0",
+ "pytest-xdist>=3.6.1",
+ "rpy2>=3.5.17",
  "sphinx>=8.1.3",
  "sphinx-autodoc-typehints>=3.0.1",
+ "tomli>=2.2.1",
+ "myst-nb>=1.2.0",
+ "jupyterlab>=4.4.2",
  ]
 
  [tool.pytest.ini_options]
@@ -85,12 +91,15 @@ addopts = [
  "--verbose",
  "--import-mode=importlib",
  ]
- timeout = 256
- timeout_method = "thread" # when jax hangs, signals do not work
+ filterwarnings = [
+ "ignore:unclosed database:ResourceWarning",
+ ]
+ timeout = 512
+ timeout_method = "thread" # when jax hangs, signals do not work
 
  [tool.coverage.run]
  branch = true
- source = ["bartz"]
+ source_pkgs = ["bartz", "tests"]
 
  [tool.coverage.report]
  show_missing = true
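Note on the `filterwarnings` entry added in this hunk: pytest filter strings follow the `action:message:category` layout, where the message is a regular expression matched against the start of the warning text. As a rough stand-alone illustration (not taken from the bartz test suite, and with made-up warning texts), the same filter can be expressed with the standard `warnings` module:

```python
import warnings


def emit_and_collect():
    """Emit two ResourceWarnings and return the messages that survive the filter."""
    with warnings.catch_warnings(record=True) as caught:
        # Record everything by default...
        warnings.simplefilter("always")
        # ...except ResourceWarnings whose text starts with "unclosed database",
        # which mirrors "ignore:unclosed database:ResourceWarning".
        warnings.filterwarnings(
            "ignore", message="unclosed database", category=ResourceWarning
        )
        # Illustrative warning texts, assumed for this sketch.
        warnings.warn("unclosed database in <sqlite3.Connection object>", ResourceWarning)
        warnings.warn("unclosed file <_io.TextIOWrapper>", ResourceWarning)
    return [str(w.message) for w in caught]


remaining = emit_and_collect()
print(remaining)
```

Only the second warning is recorded, so unrelated `ResourceWarning`s remain visible under this filter.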
@@ -125,7 +134,7 @@ local = [
  ]
 
  [tool.ruff]
- exclude = [".asv", "*.ipynb"]
+ exclude = [".asv", "*.ipynb", "benchmarks/latest_bartz"]
  cache-dir = "config/ruff_cache"
 
  [tool.ruff.format]
@@ -138,84 +147,84 @@ split-on-trailing-comma = false
  [tool.ruff.lint]
  select = [
  "ERA", # eradicate
- "S", # flake8-bandit
- "BLE", # flake8-blind-except
- "B", # bugbear
- "A", # flake8-builtins
- "C4", # flake8-comprehensions
- "CPY", # flake8-copyright
- "DTZ", # flake8-datetimez
- "T10", # flake8-debugger
- "EM", # flake8-errmsg
- "EXE", # flake8-executable
- "FIX", # flake8-fixme
- "ISC", # flake8-implicit-str-concat
- "INP", # flake8-no-pep420
- "PIE", # flake8-pie
- "T20", # flake8-print
- "PT", # flake8-pytest-style
- "RSE", # flake8-raise
- "RET", # flake8-return
- "SLF", # flake8-self
- "SIM", # flake8-simplify
- "TID", # flake8-tidy-imports
- "ARG", # flake8-unused-arguments
- "PTH", # flake8-use-pathlib
- "FLY", # flynt
- "I", # isort
- "C90", # mccabe
- "NPY", # NumPy-specific rules
+ "S", # flake8-bandit
+ "BLE", # flake8-blind-except
+ "B", # bugbear
+ "A", # flake8-builtins
+ "C4", # flake8-comprehensions
+ "CPY", # flake8-copyright
+ "DTZ", # flake8-datetimez
+ "T10", # flake8-debugger
+ "EM", # flake8-errmsg
+ "EXE", # flake8-executable
+ "FIX", # flake8-fixme
+ "ISC", # flake8-implicit-str-concat
+ "INP", # flake8-no-pep420
+ "PIE", # flake8-pie
+ "T20", # flake8-print
+ "PT", # flake8-pytest-style
+ "RSE", # flake8-raise
+ "RET", # flake8-return
+ "SLF", # flake8-self
+ "SIM", # flake8-simplify
+ "TID", # flake8-tidy-imports
+ "ARG", # flake8-unused-arguments
+ "PTH", # flake8-use-pathlib
+ "FLY", # flynt
+ "I", # isort
+ "C90", # mccabe
+ "NPY", # NumPy-specific rules
  "PERF", # Perflint
- "W", # pycodestyle Warning
- "F", # pyflakes
- "D", # pydocstyle
- "PGH", # pygrep-hooks
- "PLC", # Pylint Convention
- "PLE", # Pylint Error
- "PLR", # Pylint Refactor
- "PLW", # Pyling Warning
- "UP", # pyupgrade
+ "W", # pycodestyle Warning
+ "F", # pyflakes
+ "D", # pydocstyle
+ "PGH", # pygrep-hooks
+ "PLC", # Pylint Convention
+ "PLE", # Pylint Error
+ "PLR", # Pylint Refactor
+ "PLW", # Pyling Warning
+ "UP", # pyupgrade
  "FURB", # refurb
- "RUF", # Ruff-specific rules
- "TRY", # tryceratops
+ "RUF", # Ruff-specific rules
+ "TRY", # tryceratops
  ]
  ignore = [
- "B028", # warn with stacklevel = 2
- "C408", # Unnecessary `dict()` call (rewrite as a literal), it's too convenient for kwargs
- "D105", # Missing docstring in magic method
- "F722", # Syntax error in forward annotation. I ignore this because jaxtyping uses strings for shapes instead of for deferred annotations.
+ "B028", # warn with stacklevel = 2
+ "C408", # Unnecessary `dict()` call (rewrite as a literal), it's too convenient for kwargs
+ "D105", # Missing docstring in magic method
+ "F722", # Syntax error in forward annotation. I ignore this because jaxtyping uses strings for shapes instead of for deferred annotations.
  "PIE790", # Unnecessary ... or pass. Ignored because sometimes I use ... as sentinel to tell the rest of ruff and pyright that an implementation is a stub.
- "PLR0913", # Too many arguments in function definition. Maybe I should do something about this?
- "PLR2004", # Magic value used in comparison, consider replacing `*` with a constant variable
+ "PLR0912", # Too many branches; ignore bc C901 already handles this
+ "PLR0913", # Too many arguments in function definition. Maybe I should do something about this?
+ "PLR2004", # Magic value used in comparison, consider replacing `*` with a constant variable
  "RET505", # Unnecessary `{branch}` after `return` statement. I ignore this because I like to keep branches for readability.
  "RET506", # Unnecessary `else` after `raise` statement. I ignore this because I like to keep branches for readability.
- "S101", # Use of `assert` detected. Too annoying.
- "S603", # `subprocess` call: check for execution of untrusted input. Too many false positives.
+ "S101", # Use of `assert` detected. Too annoying.
  "SIM108", # SIM108 Use ternary operator `*` instead of `if`-`else`-block, I find blocks more readable
- "UP037", # Remove quotes from type annotation. Ignore because jaxtyping.
+ "UP037", # Remove quotes from type annotation. Ignore because jaxtyping.
  ]
 
  [tool.ruff.lint.per-file-ignores]
  "{config/*,docs/*}" = [
- "D100", # Missing docstring in public module
- "D101", # Missing docstring in public class
- "D102", # Missing docstring in public method
- "D103", # Missing docstring in public function
- "D104", # Missing docstring in public package
- "INP001", # File * is part of an implicit namespace package. Add an `__init__.py`.
+ "D100", # Missing docstring in public module
+ "D101", # Missing docstring in public class
+ "D102", # Missing docstring in public method
+ "D103", # Missing docstring in public function
+ "D104", # Missing docstring in public package
+ "INP001", # File * is part of an implicit namespace package. Add an `__init__.py`.
  ]
  "src/bartz/_version.py" = [
- "CPY001", # Missing copyright notice at top of file
+ "CPY001", # Missing copyright notice at top of file
  ]
- "{config/*,docs/*,tests/*}" = [
- "T201", # `print` found
+ "{config/*,docs/*,tests/*,benchmarks/_vendor_latest_bartz.py,scripts/*}" = [
+ "T201", # `print` found
  ]
  "{tests/*,benchmarks/*}" = [
- "SLF001", # Private member accessed: `*`
- "TID253", # `{module}` is banned at the module level
+ "SLF001", # Private member accessed: `*`
+ "TID253", # `{module}` is banned at the module level
  ]
  "docs/conf.py" = [
- "S607", # Starting a process with a partial executable path. Ignored because for a build script it makes more sense to use PATH.
+ "S607", # Starting a process with a partial executable path. Ignored because for a build script it makes more sense to use PATH.
  ]
 
  [tool.ruff.lint.pydocstyle]
@@ -229,11 +238,6 @@ banned-module-level-imports = ["bartz.debug"]
  ban-relative-imports = "all"
 
  [tool.pydoclint]
- baseline = "config/pydoclint-baseline.txt"
- auto-regenerate-baseline = false
- # auto-regenerate = false because in pre-commit pydoclint does not see the
- # unmodified files, thinks those errors are gone, and removes them from the
- # baseline.
  arg-type-hints-in-signature = true
  arg-type-hints-in-docstring = false
  check-return-types = false
@@ -243,8 +247,4 @@ check-style-mismatch = true
  show-filenames-in-every-violation-message = true
  check-class-attributes = false
  # do not check class attributes because in dataclasses I document them as
- # init parameters because they are duplicated otherwise.
-
- [tool.uv]
- python-downloads = "never"
- python-preference = "only-system"
+ # init parameters because they are duplicated in the html docs otherwise.
@@ -0,0 +1,27 @@
+ # bartz/src/bartz/BART/__init__.py
+ #
+ # Copyright (c) 2026, The Bartz Contributors
+ #
+ # This file is part of bartz.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ """Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package."""
+
+ from bartz.BART._gbart import gbart, mc_gbart # noqa: F401
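Note on the new `BART/__init__.py`: the re-export on its last line makes `gbart` and `mc_gbart` importable from the `bartz.BART` package even though they are defined in the private `_gbart` module. The sketch below reproduces that layout with a throwaway package so it runs without bartz installed. The package name `demo_pkg` and the empty placeholder classes are assumptions for illustration only and say nothing about the real class definitions.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def build_demo_package(root):
    """Write demo_pkg/BART/{__init__,_gbart}.py under root (hypothetical stand-ins)."""
    bart_dir = Path(root) / "demo_pkg" / "BART"
    bart_dir.mkdir(parents=True)
    (bart_dir.parent / "__init__.py").write_text("")
    # Private module holding the implementations (empty placeholders here).
    (bart_dir / "_gbart.py").write_text(
        "class mc_gbart:\n    pass\n\n\nclass gbart:\n    pass\n"
    )
    # Public package module that re-exports them, as in the diff above.
    (bart_dir / "__init__.py").write_text(
        "from demo_pkg.BART._gbart import gbart, mc_gbart  # noqa: F401\n"
    )


with tempfile.TemporaryDirectory() as root:
    build_demo_package(root)
    sys.path.insert(0, root)
    importlib.invalidate_caches()
    try:
        bart = importlib.import_module("demo_pkg.BART")
        # Public names are those without a leading underscore.
        exported = sorted(name for name in vars(bart) if not name.startswith("_"))
    finally:
        sys.path.remove(root)

print(exported)
```

The private submodule stays out of the public listing because its name starts with an underscore, which is the usual reason for this split.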