bartz 0.0__py3-none-any.whl → 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/__init__.py +1 -1
- bartz/debug.py +4 -8
- bartz/interface.py +23 -17
- bartz/mcmcloop.py +5 -2
- {bartz-0.0.dist-info → bartz-0.0.1.dist-info}/METADATA +8 -7
- bartz-0.0.1.dist-info/RECORD +12 -0
- bartz-0.0.dist-info/RECORD +0 -12
- {bartz-0.0.dist-info → bartz-0.0.1.dist-info}/LICENSE +0 -0
- {bartz-0.0.dist-info → bartz-0.0.1.dist-info}/WHEEL +0 -0
bartz/__init__.py
CHANGED
bartz/debug.py
CHANGED
|
@@ -65,15 +65,11 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
|
65
65
|
else:
|
|
66
66
|
link = ' '
|
|
67
67
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
number = str(index).rjust(ndigits)
|
|
72
|
-
number = f' {number} '
|
|
73
|
-
else:
|
|
74
|
-
number = ''
|
|
68
|
+
max_number = len(leaf_tree) - 1
|
|
69
|
+
ndigits = len(str(max_number))
|
|
70
|
+
number = str(index).rjust(ndigits)
|
|
75
71
|
|
|
76
|
-
print(f'{number}{indent}{first_indent}{link}{node_str}')
|
|
72
|
+
print(f' {number} {indent}{first_indent}{link}{node_str}')
|
|
77
73
|
|
|
78
74
|
indent += next_indent
|
|
79
75
|
unused = unused or is_leaf
|
bartz/interface.py
CHANGED
|
@@ -102,14 +102,6 @@ class BART:
|
|
|
102
102
|
|
|
103
103
|
Attributes
|
|
104
104
|
----------
|
|
105
|
-
offset : float
|
|
106
|
-
The prior mean of the latent mean function.
|
|
107
|
-
scale : float
|
|
108
|
-
The prior standard deviation of the latent mean function.
|
|
109
|
-
lamda : float
|
|
110
|
-
The prior harmonic mean of the error variance.
|
|
111
|
-
ntree : int
|
|
112
|
-
The number of trees.
|
|
113
105
|
yhat_train : array (ndpost, n)
|
|
114
106
|
The conditional posterior mean at `x_train` for each MCMC iteration.
|
|
115
107
|
yhat_train_mean : array (n,)
|
|
@@ -122,6 +114,18 @@ class BART:
|
|
|
122
114
|
The standard deviation of the error.
|
|
123
115
|
first_sigma : array (nskip,)
|
|
124
116
|
The standard deviation of the error in the burn-in phase.
|
|
117
|
+
offset : float
|
|
118
|
+
The prior mean of the latent mean function.
|
|
119
|
+
scale : float
|
|
120
|
+
The prior standard deviation of the latent mean function.
|
|
121
|
+
lamda : float
|
|
122
|
+
The prior harmonic mean of the error variance.
|
|
123
|
+
sigest : float or None
|
|
124
|
+
The estimated standard deviation of the error used to set `lamda`.
|
|
125
|
+
ntree : int
|
|
126
|
+
The number of trees.
|
|
127
|
+
maxdepth : int
|
|
128
|
+
The maximum depth of the trees.
|
|
125
129
|
|
|
126
130
|
Methods
|
|
127
131
|
-------
|
|
@@ -166,17 +170,17 @@ class BART:
|
|
|
166
170
|
y_train, y_train_fmt = self._process_response_input(y_train)
|
|
167
171
|
self._check_same_length(x_train, y_train)
|
|
168
172
|
|
|
169
|
-
lamda = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda)
|
|
170
173
|
offset = self._process_offset_settings(y_train, offset)
|
|
171
174
|
scale = self._process_scale_settings(y_train, k)
|
|
175
|
+
lamda, sigest = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset)
|
|
172
176
|
|
|
173
177
|
splits, max_split = self._determine_splits(x_train, numcut)
|
|
174
178
|
x_train = self._bin_predictors(x_train, splits)
|
|
175
179
|
|
|
176
180
|
y_train = self._transform_input(y_train, offset, scale)
|
|
177
|
-
|
|
181
|
+
lamda_scaled = lamda / (scale * scale)
|
|
178
182
|
|
|
179
|
-
mcmc_state = self._setup_mcmc(x_train, y_train, max_split,
|
|
183
|
+
mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda_scaled, sigdf, power, base, maxdepth, ntree)
|
|
180
184
|
final_state, burnin_trace, main_trace = self._run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed)
|
|
181
185
|
|
|
182
186
|
sigma = self._extract_sigma(main_trace, scale)
|
|
@@ -184,8 +188,10 @@ class BART:
|
|
|
184
188
|
|
|
185
189
|
self.offset = offset
|
|
186
190
|
self.scale = scale
|
|
187
|
-
self.lamda = lamda
|
|
191
|
+
self.lamda = lamda
|
|
192
|
+
self.sigest = sigest
|
|
188
193
|
self.ntree = ntree
|
|
194
|
+
self.maxdepth = maxdepth
|
|
189
195
|
self.sigma = sigma
|
|
190
196
|
self.first_sigma = first_sigma
|
|
191
197
|
|
|
@@ -261,25 +267,25 @@ class BART:
|
|
|
261
267
|
assert get_length(x1) == get_length(x2)
|
|
262
268
|
|
|
263
269
|
@staticmethod
|
|
264
|
-
def _process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda):
|
|
270
|
+
def _process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset):
|
|
265
271
|
if lamda is not None:
|
|
266
|
-
return lamda
|
|
272
|
+
return lamda, None
|
|
267
273
|
else:
|
|
268
274
|
if sigest is not None:
|
|
269
275
|
sigest2 = sigest * sigest
|
|
270
276
|
elif y_train.size < 2:
|
|
271
277
|
sigest2 = 1
|
|
272
278
|
elif y_train.size <= x_train.shape[0]:
|
|
273
|
-
sigest2 = jnp.var(y_train)
|
|
279
|
+
sigest2 = jnp.var(y_train - offset)
|
|
274
280
|
else:
|
|
275
|
-
_, chisq, rank, _ = jnp.linalg.lstsq(x_train.T, y_train)
|
|
281
|
+
_, chisq, rank, _ = jnp.linalg.lstsq(x_train.T, y_train - offset)
|
|
276
282
|
chisq = chisq.squeeze(0)
|
|
277
283
|
dof = len(y_train) - rank
|
|
278
284
|
sigest2 = chisq / dof
|
|
279
285
|
alpha = sigdf / 2
|
|
280
286
|
invchi2 = jaxext.scipy.stats.invgamma.ppf(sigquant, alpha) / 2
|
|
281
287
|
invchi2rid = invchi2 * sigdf
|
|
282
|
-
return sigest2 / invchi2rid
|
|
288
|
+
return sigest2 / invchi2rid, jnp.sqrt(sigest2)
|
|
283
289
|
|
|
284
290
|
@staticmethod
|
|
285
291
|
def _process_offset_settings(y_train, offset):
|
bartz/mcmcloop.py
CHANGED
|
@@ -148,14 +148,17 @@ def make_simple_print_callback(printevery):
|
|
|
148
148
|
prune_prop = bart['prune_prop_count'] / prop_total
|
|
149
149
|
grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
|
|
150
150
|
prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
|
|
151
|
-
n_total = n_burn + n_save
|
|
151
|
+
n_total = n_burn + n_save * n_skip
|
|
152
152
|
debug.callback(simple_print_callback_impl, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery)
|
|
153
153
|
return callback
|
|
154
154
|
|
|
155
155
|
def simple_print_callback_impl(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery):
|
|
156
156
|
if (i_total + 1) % printevery == 0:
|
|
157
157
|
burnin_flag = ' (burnin)' if burnin else ''
|
|
158
|
-
|
|
158
|
+
total_str = str(n_total)
|
|
159
|
+
ndigits = len(total_str)
|
|
160
|
+
i_str = str(i_total + 1).rjust(ndigits)
|
|
161
|
+
print(f'Iteration {i_str}/{total_str} '
|
|
159
162
|
f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
|
|
160
163
|
f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}')
|
|
161
164
|
|
|
@@ -1,20 +1,21 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: bartz
|
|
3
|
-
Version: 0.0
|
|
3
|
+
Version: 0.0.1
|
|
4
4
|
Summary: A JAX implementation of BART
|
|
5
5
|
Home-page: https://github.com/Gattocrucco/bartz
|
|
6
6
|
License: MIT
|
|
7
7
|
Author: Giacomo Petrillo
|
|
8
8
|
Author-email: info@giacomopetrillo.com
|
|
9
|
-
Requires-Python: >=3.
|
|
9
|
+
Requires-Python: >=3.10,<4.0
|
|
10
10
|
Classifier: License :: OSI Approved :: MIT License
|
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
12
14
|
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
-
Requires-Dist:
|
|
14
|
-
Requires-Dist:
|
|
15
|
-
Requires-Dist:
|
|
16
|
-
Requires-Dist:
|
|
17
|
-
Requires-Dist: scipy (>=1.12.0,<2.0.0)
|
|
15
|
+
Requires-Dist: jax (>=0.4.23,<0.5.0)
|
|
16
|
+
Requires-Dist: jaxlib (>=0.4.23,<0.5.0)
|
|
17
|
+
Requires-Dist: numpy (>=1.25.2,<2.0.0)
|
|
18
|
+
Requires-Dist: scipy (>=1.11.4,<2.0.0)
|
|
18
19
|
Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
|
|
19
20
|
Project-URL: Repository, https://github.com/Gattocrucco/bartz
|
|
20
21
|
Description-Content-Type: text/markdown
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
bartz/__init__.py,sha256=PL-vhhEHoVMWOPLG_M45TIZVkbQia5riJbQboy-BNH8,1333
|
|
2
|
+
bartz/debug.py,sha256=FHnCalpK1uO1CN9QQ5DPj70JKR4Thltzp9o0BeYthIo,5741
|
|
3
|
+
bartz/grove.py,sha256=v2k10EBjgi2aLCsGvM01z0z--9Xv4ApBOxpke-6gIYM,10309
|
|
4
|
+
bartz/interface.py,sha256=GBwLwqEF_6EmeteFtsPw6ANdisnvMoWi_fKBJiQq-Vc,16129
|
|
5
|
+
bartz/jaxext.py,sha256=FK5j1zfW1yR4-yPKcD7ZvKSkVQ5--jHjQpVCl4n4gXY,2844
|
|
6
|
+
bartz/mcmcloop.py,sha256=N815-eJxsS_X85okXRO2kSOlikw8dPN05_krm0iT9Sg,7321
|
|
7
|
+
bartz/mcmcstep.py,sha256=acy_2rSIEXV5BzqLY96aQaqlsxtalxyO3Q4gPvUMRVU,35912
|
|
8
|
+
bartz/prepcovars.py,sha256=3ddDOtNNop3Ba2Kgy_dZ6apFydtwaEXH3uXSmmKf9Fs,4421
|
|
9
|
+
bartz-0.0.1.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
|
|
10
|
+
bartz-0.0.1.dist-info/METADATA,sha256=zDW1dM58gV7c_8ZTjEtTt_tcXabbz5roZBf36EdLxls,933
|
|
11
|
+
bartz-0.0.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
12
|
+
bartz-0.0.1.dist-info/RECORD,,
|
bartz-0.0.dist-info/RECORD
DELETED
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
bartz/__init__.py,sha256=KpeWEIKkXq7YyooRdLZK6atfLJimi0hIa8-xqBPnDWQ,1331
|
|
2
|
-
bartz/debug.py,sha256=SBdFJd9gXtBEw2b3R3pv14ARwKCDkHiycOuDm0iVEhA,5846
|
|
3
|
-
bartz/grove.py,sha256=v2k10EBjgi2aLCsGvM01z0z--9Xv4ApBOxpke-6gIYM,10309
|
|
4
|
-
bartz/interface.py,sha256=pPUOYpHCci_NADF58jrbee2tcnaKZYzwgK0FF5XuwJU,15823
|
|
5
|
-
bartz/jaxext.py,sha256=FK5j1zfW1yR4-yPKcD7ZvKSkVQ5--jHjQpVCl4n4gXY,2844
|
|
6
|
-
bartz/mcmcloop.py,sha256=ZxasSfWZzuYT_rOnSL8iSzIGQQS3ZtcBUhdXis2K-II,7207
|
|
7
|
-
bartz/mcmcstep.py,sha256=acy_2rSIEXV5BzqLY96aQaqlsxtalxyO3Q4gPvUMRVU,35912
|
|
8
|
-
bartz/prepcovars.py,sha256=3ddDOtNNop3Ba2Kgy_dZ6apFydtwaEXH3uXSmmKf9Fs,4421
|
|
9
|
-
bartz-0.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
|
|
10
|
-
bartz-0.0.dist-info/METADATA,sha256=kKCGPdD_bOxJ3cLePc1R19kS3AzwGtFK_aR9ljBX6HY,869
|
|
11
|
-
bartz-0.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
12
|
-
bartz-0.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|