bartz 0.0__tar.gz → 0.0.1__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,20 +1,21 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: bartz
3
- Version: 0.0
3
+ Version: 0.0.1
4
4
  Summary: A JAX implementation of BART
5
5
  Home-page: https://github.com/Gattocrucco/bartz
6
6
  License: MIT
7
7
  Author: Giacomo Petrillo
8
8
  Author-email: info@giacomopetrillo.com
9
- Requires-Python: >=3.12,<4.0
9
+ Requires-Python: >=3.10,<4.0
10
10
  Classifier: License :: OSI Approved :: MIT License
11
11
  Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
12
14
  Classifier: Programming Language :: Python :: 3.12
13
- Requires-Dist: appnope (>=0.1.4,<0.2.0)
14
- Requires-Dist: jax (>=0.4.25,<0.5.0)
15
- Requires-Dist: jaxlib (>=0.4.25,<0.5.0)
16
- Requires-Dist: numpy (>=1.26.4,<2.0.0)
17
- Requires-Dist: scipy (>=1.12.0,<2.0.0)
15
+ Requires-Dist: jax (>=0.4.23,<0.5.0)
16
+ Requires-Dist: jaxlib (>=0.4.23,<0.5.0)
17
+ Requires-Dist: numpy (>=1.25.2,<2.0.0)
18
+ Requires-Dist: scipy (>=1.11.4,<2.0.0)
18
19
  Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
19
20
  Project-URL: Repository, https://github.com/Gattocrucco/bartz
20
21
  Description-Content-Type: text/markdown
@@ -28,7 +28,7 @@ build-backend = "poetry.core.masonry.api"
28
28
 
29
29
  [tool.poetry]
30
30
  name = "bartz"
31
- version = "0.0"
31
+ version = "0.0.1"
32
32
  description = "A JAX implementation of BART"
33
33
  authors = ["Giacomo Petrillo <info@giacomopetrillo.com>"]
34
34
  license = "MIT"
@@ -42,16 +42,16 @@ packages = [
42
42
  "Bug Tracker" = "https://github.com/Gattocrucco/bartz/issues"
43
43
 
44
44
  [tool.poetry.dependencies]
45
- python = "^3.12"
46
- jax = "^0.4.25"
47
- jaxlib = "^0.4.25"
48
- numpy = "^1.26.4"
49
- scipy = "^1.12.0"
50
- appnope = "^0.1.4"
45
+ python = "^3.10"
46
+ jax = "^0.4.23"
47
+ jaxlib = "^0.4.23"
48
+ numpy = "^1.25.2"
49
+ scipy = "^1.11.4"
51
50
 
52
51
  [tool.poetry.group.dev.dependencies]
53
52
  ipython = "^8.22.2"
54
53
  matplotlib = "^3.8.3"
54
+ appnope = "^0.1.4"
55
55
 
56
56
  [tool.poetry.group.test.dependencies]
57
57
  coverage = "^7.4.3"
@@ -28,7 +28,7 @@ A jax implementation of BART
28
28
  See the manual at https://gattocrucco.github.io/bartz/docs
29
29
  """
30
30
 
31
- __version__ = '0.0'
31
+ __version__ = '0.0.1'
32
32
 
33
33
  from .interface import BART
34
34
 
@@ -65,15 +65,11 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
65
65
  else:
66
66
  link = ' '
67
67
 
68
- if print_all:
69
- max_number = len(leaf_tree) - 1
70
- ndigits = len(str(max_number))
71
- number = str(index).rjust(ndigits)
72
- number = f' {number} '
73
- else:
74
- number = ''
68
+ max_number = len(leaf_tree) - 1
69
+ ndigits = len(str(max_number))
70
+ number = str(index).rjust(ndigits)
75
71
 
76
- print(f'{number}{indent}{first_indent}{link}{node_str}')
72
+ print(f' {number} {indent}{first_indent}{link}{node_str}')
77
73
 
78
74
  indent += next_indent
79
75
  unused = unused or is_leaf
@@ -102,14 +102,6 @@ class BART:
102
102
 
103
103
  Attributes
104
104
  ----------
105
- offset : float
106
- The prior mean of the latent mean function.
107
- scale : float
108
- The prior standard deviation of the latent mean function.
109
- lamda : float
110
- The prior harmonic mean of the error variance.
111
- ntree : int
112
- The number of trees.
113
105
  yhat_train : array (ndpost, n)
114
106
  The conditional posterior mean at `x_train` for each MCMC iteration.
115
107
  yhat_train_mean : array (n,)
@@ -122,6 +114,18 @@ class BART:
122
114
  The standard deviation of the error.
123
115
  first_sigma : array (nskip,)
124
116
  The standard deviation of the error in the burn-in phase.
117
+ offset : float
118
+ The prior mean of the latent mean function.
119
+ scale : float
120
+ The prior standard deviation of the latent mean function.
121
+ lamda : float
122
+ The prior harmonic mean of the error variance.
123
+ sigest : float or None
124
+ The estimated standard deviation of the error used to set `lamda`.
125
+ ntree : int
126
+ The number of trees.
127
+ maxdepth : int
128
+ The maximum depth of the trees.
125
129
 
126
130
  Methods
127
131
  -------
@@ -166,17 +170,17 @@ class BART:
166
170
  y_train, y_train_fmt = self._process_response_input(y_train)
167
171
  self._check_same_length(x_train, y_train)
168
172
 
169
- lamda = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda)
170
173
  offset = self._process_offset_settings(y_train, offset)
171
174
  scale = self._process_scale_settings(y_train, k)
175
+ lamda, sigest = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset)
172
176
 
173
177
  splits, max_split = self._determine_splits(x_train, numcut)
174
178
  x_train = self._bin_predictors(x_train, splits)
175
179
 
176
180
  y_train = self._transform_input(y_train, offset, scale)
177
- lamda = lamda / scale
181
+ lamda_scaled = lamda / (scale * scale)
178
182
 
179
- mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda, sigdf, power, base, maxdepth, ntree)
183
+ mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda_scaled, sigdf, power, base, maxdepth, ntree)
180
184
  final_state, burnin_trace, main_trace = self._run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed)
181
185
 
182
186
  sigma = self._extract_sigma(main_trace, scale)
@@ -184,8 +188,10 @@ class BART:
184
188
 
185
189
  self.offset = offset
186
190
  self.scale = scale
187
- self.lamda = lamda * scale
191
+ self.lamda = lamda
192
+ self.sigest = sigest
188
193
  self.ntree = ntree
194
+ self.maxdepth = maxdepth
189
195
  self.sigma = sigma
190
196
  self.first_sigma = first_sigma
191
197
 
@@ -261,25 +267,25 @@ class BART:
261
267
  assert get_length(x1) == get_length(x2)
262
268
 
263
269
  @staticmethod
264
- def _process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda):
270
+ def _process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset):
265
271
  if lamda is not None:
266
- return lamda
272
+ return lamda, None
267
273
  else:
268
274
  if sigest is not None:
269
275
  sigest2 = sigest * sigest
270
276
  elif y_train.size < 2:
271
277
  sigest2 = 1
272
278
  elif y_train.size <= x_train.shape[0]:
273
- sigest2 = jnp.var(y_train)
279
+ sigest2 = jnp.var(y_train - offset)
274
280
  else:
275
- _, chisq, rank, _ = jnp.linalg.lstsq(x_train.T, y_train)
281
+ _, chisq, rank, _ = jnp.linalg.lstsq(x_train.T, y_train - offset)
276
282
  chisq = chisq.squeeze(0)
277
283
  dof = len(y_train) - rank
278
284
  sigest2 = chisq / dof
279
285
  alpha = sigdf / 2
280
286
  invchi2 = jaxext.scipy.stats.invgamma.ppf(sigquant, alpha) / 2
281
287
  invchi2rid = invchi2 * sigdf
282
- return sigest2 / invchi2rid
288
+ return sigest2 / invchi2rid, jnp.sqrt(sigest2)
283
289
 
284
290
  @staticmethod
285
291
  def _process_offset_settings(y_train, offset):
@@ -148,14 +148,17 @@ def make_simple_print_callback(printevery):
148
148
  prune_prop = bart['prune_prop_count'] / prop_total
149
149
  grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
150
150
  prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
151
- n_total = n_burn + n_save
151
+ n_total = n_burn + n_save * n_skip
152
152
  debug.callback(simple_print_callback_impl, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery)
153
153
  return callback
154
154
 
155
155
  def simple_print_callback_impl(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery):
156
156
  if (i_total + 1) % printevery == 0:
157
157
  burnin_flag = ' (burnin)' if burnin else ''
158
- print(f'Iteration {i_total + 1:4d}/{n_total:d} '
158
+ total_str = str(n_total)
159
+ ndigits = len(total_str)
160
+ i_str = str(i_total + 1).rjust(ndigits)
161
+ print(f'Iteration {i_str}/{total_str} '
159
162
  f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
160
163
  f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}')
161
164
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes