bartz 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

bartz/.DS_Store ADDED
Binary file
bartz/BART.py CHANGED
@@ -1,6 +1,6 @@
 # bartz/src/bartz/BART.py
 #
-# Copyright (c) 2024, Giacomo Petrillo
+# Copyright (c) 2024-2025, Giacomo Petrillo
 #
 # This file is part of bartz.
 #
@@ -27,11 +27,8 @@ import functools
 import jax
 import jax.numpy as jnp
 
-from . import jaxext
-from . import grove
-from . import mcmcstep
-from . import mcmcloop
-from . import prepcovars
+from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars
+
 
 class gbart:
     """
@@ -53,10 +50,11 @@ class gbart:
         Whether to use predictors quantiles instead of a uniform grid to bin
         predictors.
     sigest : float, optional
-        An estimate of the residual standard deviation on `y_train`, used to
-        set `lamda`. If not specified, it is estimated by linear regression.
-        If `y_train` has less than two elements, it is set to 1. If n <= p, it
-        is set to the variance of `y_train`. Ignored if `lamda` is specified.
+        An estimate of the residual standard deviation on `y_train`, used to set
+        `lamda`. If not specified, it is estimated by linear regression (with
+        intercept, and without taking into account `w`). If `y_train` has less
+        than two elements, it is set to 1. If n <= p, it is set to the standard
+        deviation of `y_train`. Ignored if `lamda` is specified.
     sigdf : int, default 3
         The degrees of freedom of the scaled inverse-chisquared prior on the
         noise variance.
@@ -82,6 +80,12 @@ class gbart:
     offset : float, optional
         The prior mean of the latent mean function. If not specified, it is set
         to the mean of `y_train`. If `y_train` is empty, it is set to 0.
+    w : array (n,), optional
+        Coefficients that rescale the error standard deviation on each
+        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
+        datapoints. Note: `w` is ignored in the automatic determination of
+        `sigest`, so either the weights should be O(1), or `sigest` should be
+        specified by the user.
     ntree : int, default 200
         The number of trees used to represent the latent mean function.
     numcut : int, default 255
@@ -108,6 +112,8 @@ class gbart:
         The number of iterations (including skipped ones) between each log.
     seed : int or jax random key, default 0
         The seed for the random number generator.
+    initkw : dict
+        Additional arguments passed to `mcmcstep.init`.
 
     Attributes
     ----------
@@ -135,8 +141,6 @@ class gbart:
         The number of trees.
     maxdepth : int
        The maximum depth of the trees.
-    initkw : dict
-        Additional arguments passed to `mcmcstep.init`.
 
     Methods
     -------
@@ -158,10 +162,13 @@ class gbart:
     - A lot of functionality is missing (variable selection, discrete response).
     - There are some additional attributes, and some missing.
 
-    The linear regression used to set `sigest` adds an intercept.
     """
 
-    def __init__(self, x_train, y_train, *,
+    def __init__(
+        self,
+        x_train,
+        y_train,
+        *,
         x_test=None,
         usequants=False,
         sigest=None,
@@ -173,6 +180,7 @@ class gbart:
         maxdepth=6,
         lamda=None,
         offset=None,
+        w=None,
         ntree=200,
         numcut=255,
         ndpost=1000,
@@ -180,26 +188,41 @@ class gbart:
         keepevery=1,
         printevery=100,
         seed=0,
-        initkw={},
-    ):
-
+        initkw=None,
+    ):
        x_train, x_train_fmt = self._process_predictor_input(x_train)
-
-        y_train, y_train_fmt = self._process_response_input(y_train)
+        y_train, _ = self._process_response_input(y_train)
        self._check_same_length(x_train, y_train)
+        if w is not None:
+            w, _ = self._process_response_input(w)
+            self._check_same_length(x_train, w)
 
        offset = self._process_offset_settings(y_train, offset)
        scale = self._process_scale_settings(y_train, k)
-        lamda, sigest = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset)
+        lamda, sigest = self._process_noise_variance_settings(
+            x_train, y_train, sigest, sigdf, sigquant, lamda, offset
+        )
 
        splits, max_split = self._determine_splits(x_train, usequants, numcut)
        x_train = self._bin_predictors(x_train, splits)
-
-        y_train = self._transform_input(y_train, offset, scale)
-        lamda_scaled = lamda / (scale * scale)
-
-        mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda_scaled, sigdf, power, base, maxdepth, ntree, initkw)
-        final_state, burnin_trace, main_trace = self._run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed)
+        y_train, lamda_scaled = self._transform_input(y_train, lamda, offset, scale)
+
+        mcmc_state = self._setup_mcmc(
+            x_train,
+            y_train,
+            w,
+            max_split,
+            lamda_scaled,
+            sigdf,
+            power,
+            base,
+            maxdepth,
+            ntree,
+            initkw,
+        )
+        final_state, burnin_trace, main_trace = self._run_mcmc(
+            mcmc_state, ndpost, nskip, keepevery, printevery, seed
+        )
 
        sigma = self._extract_sigma(main_trace, scale)
        first_sigma = self._extract_sigma(burnin_trace, scale)
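
Note on the constructor hunk above: `initkw={}` was replaced by `initkw=None` because a mutable default argument is created once, at function definition time, and then shared by every call. A minimal, self-contained sketch of the pitfall (hypothetical names, not bartz code):

    def bad(key, value, store={}):
        # the same dict object is reused on every call
        store[key] = value
        return store

    a = bad('x', 1)
    b = bad('y', 2)
    assert a is b and a == {'x': 1, 'y': 2}  # state leaked between calls

    def good(key, value, store=None):
        # the 0.5.0 pattern: default to None, handle it inside
        store = {} if store is None else store
        store[key] = value
        return store

Correspondingly, `_setup_mcmc` (further down) now applies `initkw` only when it is not None.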
@@ -239,7 +262,7 @@ class gbart:
 
         Parameters
         ----------
-        x_test : array (m, p) or DataFrame
+        x_test : array (p, m) or DataFrame
             The test predictors.
 
         Returns
@@ -285,7 +308,9 @@ class gbart:
         assert get_length(x1) == get_length(x2)
 
     @staticmethod
-    def _process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset):
+    def _process_noise_variance_settings(
+        x_train, y_train, sigest, sigdf, sigquant, lamda, offset
+    ):
         if lamda is not None:
             return lamda, None
         else:
@@ -298,7 +323,7 @@ class gbart:
             else:
                 x_centered = x_train.T - x_train.mean(axis=1)
                 y_centered = y_train - y_train.mean()
-                # centering is equivalent to adding an intercept column
+                # centering is equivalent to adding an intercept column
                 _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
                 chisq = chisq.squeeze(0)
                 dof = len(y_train) - rank
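
The hunk above contains the core of the automatic `sigest`: an ordinary least-squares fit on centered data, where centering stands in for an intercept column. A sketch of how the estimate plausibly comes together from these quantities; the final `sqrt(chisq / dof)` step is an assumption consistent with the docstring, not shown in this diff:

    import jax.numpy as jnp

    def sigest_sketch(x_train, y_train):
        # x_train is (p, n) as elsewhere in gbart; transpose to (n, p)
        x_centered = x_train.T - x_train.mean(axis=1)
        y_centered = y_train - y_train.mean()  # centering ~ intercept
        _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
        chisq = chisq.squeeze(0)      # residual sum of squares
        dof = len(y_train) - rank     # residual degrees of freedom
        return jnp.sqrt(chisq / dof)  # residual standard deviation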
@@ -336,11 +361,25 @@ class gbart:
         return prepcovars.bin_predictors(x, splits)
 
     @staticmethod
-    def _transform_input(y, offset, scale):
-        return (y - offset) / scale
+    def _transform_input(y, lamda, offset, scale):
+        y = (y - offset) / scale
+        lamda = lamda / (scale * scale)
+        return y, lamda
 
     @staticmethod
-    def _setup_mcmc(x_train, y_train, max_split, lamda, sigdf, power, base, maxdepth, ntree, initkw):
+    def _setup_mcmc(
+        x_train,
+        y_train,
+        w,
+        max_split,
+        lamda,
+        sigdf,
+        power,
+        base,
+        maxdepth,
+        ntree,
+        initkw,
+    ):
         depth = jnp.arange(maxdepth - 1)
         p_nonterminal = base / (1 + depth).astype(float) ** power
         sigma2_alpha = sigdf / 2
@@ -348,6 +387,7 @@ class gbart:
         kw = dict(
             X=x_train,
             y=y_train,
+            error_scale=w,
             max_split=max_split,
             num_trees=ntree,
             p_nonterminal=p_nonterminal,
@@ -355,17 +395,20 @@ class gbart:
             sigma2_beta=sigma2_beta,
             min_points_per_leaf=5,
         )
-        kw.update(initkw)
+        if initkw is not None:
+            kw.update(initkw)
         return mcmcstep.init(**kw)
 
     @staticmethod
     def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed):
-        if isinstance(seed, jax.Array) and jnp.issubdtype(seed.dtype, jax.dtypes.prng_key):
+        if isinstance(seed, jax.Array) and jnp.issubdtype(
+            seed.dtype, jax.dtypes.prng_key
+        ):
             key = seed
         else:
             key = jax.random.key(seed)
         callback = mcmcloop.make_simple_print_callback(printevery)
-        return mcmcloop.run_mcmc(mcmc_state, nskip, ndpost, keepevery, callback, key)
+        return mcmcloop.run_mcmc(key, mcmc_state, nskip, ndpost, keepevery, callback)
 
     @staticmethod
     def _predict(trace, x):
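
Two caller-visible details in the hunks above: the seed may be either an integer or a typed JAX PRNG key, and `mcmcloop.run_mcmc` now takes the key as its first argument, matching the common key-first JAX convention. A small runnable sketch of the seed normalization, using the same idiom as `_run_mcmc`:

    import jax
    import jax.numpy as jnp

    def as_key(seed):
        # accept a typed PRNG key as-is, otherwise treat seed as an int
        if isinstance(seed, jax.Array) and jnp.issubdtype(
            seed.dtype, jax.dtypes.prng_key
        ):
            return seed
        return jax.random.key(seed)

    assert as_key(0).dtype == as_key(jax.random.key(1)).dtype

Code that called run_mcmc(mcmc_state, nskip, ndpost, keepevery, callback, key) against 0.4.1 must move the key to the first position for 0.5.0.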
@@ -379,9 +422,9 @@ class gbart:
     def _extract_sigma(trace, scale):
         return scale * jnp.sqrt(trace['sigma2'])
 
-
     def _show_tree(self, i_sample, i_tree, print_all=False):
         from . import debug
+
         trace = self._main_trace
         leaf_tree = trace['leaf_trees'][i_sample, i_tree]
         var_tree = trace['var_trees'][i_sample, i_tree]
@@ -396,7 +439,9 @@ class gbart:
         else:
             resid = bart['resid']
         alpha = bart['sigma2_alpha'] + resid.size / 2
-        norm2 = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
+        norm2 = jnp.dot(
+            resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
+        )
         beta = bart['sigma2_beta'] + norm2 / 2
         sigma2 = beta / alpha
         return jnp.sqrt(sigma2) * self.scale
@@ -404,22 +449,32 @@ class gbart:
     def _compare_resid(self):
         bart = self._mcmc_state
         resid1 = bart['resid']
-        yhat = grove.evaluate_forest(bart['X'], bart['leaf_trees'], bart['var_trees'], bart['split_trees'], jnp.float32)
+        yhat = grove.evaluate_forest(
+            bart['X'],
+            bart['leaf_trees'],
+            bart['var_trees'],
+            bart['split_trees'],
+            jnp.float32,
+        )
         resid2 = bart['y'] - yhat
         return resid1, resid2
 
     def _avg_acc(self):
         trace = self._main_trace
+
         def acc(prefix):
             acc = trace[f'{prefix}_acc_count']
             prop = trace[f'{prefix}_prop_count']
             return acc.sum() / prop.sum()
+
         return acc('grow'), acc('prune')
 
     def _avg_prop(self):
         trace = self._main_trace
+
         def prop(prefix):
             return trace[f'{prefix}_prop_count'].sum()
+
         pgrow = prop('grow')
         pprune = prop('prune')
         total = pgrow + pprune
@@ -432,16 +487,21 @@ class gbart:
 
     def _depth_distr(self):
         from . import debug
+
         trace = self._main_trace
         split_trees = trace['split_trees']
         return debug.trace_depth_distr(split_trees)
 
     def _points_per_leaf_distr(self):
         from . import debug
-        return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state['X'])
+
+        return debug.trace_points_per_leaf_distr(
+            self._main_trace, self._mcmc_state['X']
+        )
 
     def _check_trees(self):
         from . import debug
+
         return debug.check_trace(self._main_trace, self._mcmc_state)
 
     def _tree_goes_bad(self):
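
Putting the BART.py changes together, a hedged usage sketch of the new `w` argument (synthetic data; predictors laid out as (p, n) per the corrected `x_test` docstring, and `sigest` given explicitly because `w` is ignored when `sigest` is estimated automatically):

    import numpy as np
    import bartz

    p, n = 5, 100
    rng = np.random.default_rng(0)
    X = rng.standard_normal((p, n))   # (p, n) predictor layout
    y = X[0] + 0.1 * rng.standard_normal(n)
    w = np.full(n, 2.0)               # doubles the error sd on every datapoint

    fit = bartz.BART.gbart(X, y, w=w, sigest=0.1, ndpost=100, nskip=100)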
bartz/__init__.py CHANGED
@@ -1,6 +1,6 @@
 # bartz/src/bartz/__init__.py
 #
-# Copyright (c) 2024, Giacomo Petrillo
+# Copyright (c) 2024-2025, Giacomo Petrillo
 #
 # This file is part of bartz.
 #
@@ -28,13 +28,5 @@ Super-fast BART (Bayesian Additive Regression Trees) in Python
 See the manual at https://gattocrucco.github.io/bartz/docs
 """
 
-from ._version import __version__
-
-from . import BART
-
-from . import debug
-from . import grove
-from . import mcmcstep
-from . import mcmcloop
-from . import prepcovars
-from . import jaxext
+from . import BART, debug, grove, jaxext, mcmcloop, mcmcstep, prepcovars  # noqa: F401
+from ._version import __version__  # noqa: F401
bartz/_version.py CHANGED
@@ -1 +1 @@
-__version__ = '0.4.1'
+__version__ = '0.5.0'
bartz/debug.py CHANGED
@@ -1,21 +1,19 @@
 import functools
 
 import jax
-from jax import numpy as jnp
 from jax import lax
+from jax import numpy as jnp
 
-from . import grove
-from . import mcmcstep
-from . import jaxext
+from . import grove, jaxext
 
-def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
 
+def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
     tee = '├──'
     corner = '└──'
     join = '│ '
     space = ' '
     down = '┐'
-    bottom = '╢' # '┨' #
+    bottom = '╢'  # '┨' #
 
     def traverse_tree(index, depth, indent, first_indent, next_indent, unused):
         if index >= len(leaf_tree):
@@ -58,7 +56,7 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
 
         indent += next_indent
         unused = unused or is_leaf
-
+
         if unused and not print_all:
             return
@@ -67,58 +65,80 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
 
     traverse_tree(1, 0, '', '', '', False)
 
+
 def tree_actual_depth(split_tree):
     is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True)
     depth = grove.tree_depths(is_leaf.size)
     depth = jnp.where(is_leaf, depth, 0)
     return jnp.max(depth)
 
+
 def forest_depth_distr(split_trees):
     depth = grove.tree_depth(split_trees) + 1
     depths = jax.vmap(tree_actual_depth)(split_trees)
     return jnp.bincount(depths, length=depth)
 
+
 def trace_depth_distr(split_trees_trace):
     return jax.vmap(forest_depth_distr)(split_trees_trace)
 
+
 def points_per_leaf_distr(var_tree, split_tree, X):
     traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
     indices = traverse_tree(X, var_tree, split_tree)
-    count_tree = jnp.zeros(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size))
+    count_tree = jnp.zeros(
+        2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size)
+    )
     count_tree = count_tree.at[indices].add(1)
     is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
     return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
 
+
 def forest_points_per_leaf_distr(bart, X):
     distr = jnp.zeros(X.shape[1] + 1, int)
     trees = bart['var_trees'], bart['split_trees']
+
     def loop(distr, tree):
         return distr + points_per_leaf_distr(*tree, X), None
+
     distr, _ = lax.scan(loop, distr, trees)
     return distr
 
+
 def trace_points_per_leaf_distr(bart, X):
     def loop(_, bart):
         return None, forest_points_per_leaf_distr(bart, X)
+
     _, distr = lax.scan(loop, None, bart)
     return distr
 
+
 def check_types(leaf_tree, var_tree, split_tree, max_split):
     expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
     expected_split_dtype = max_split.dtype
-    return var_tree.dtype == expected_var_dtype and split_tree.dtype == expected_split_dtype
+    return (
+        var_tree.dtype == expected_var_dtype
+        and split_tree.dtype == expected_split_dtype
+    )
+
 
 def check_sizes(leaf_tree, var_tree, split_tree, max_split):
     return leaf_tree.size == 2 * var_tree.size == 2 * split_tree.size
 
+
 def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
     return (var_tree[0] == 0) & (split_tree[0] == 0)
 
+
 def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
     return jnp.all(jnp.isfinite(leaf_tree))
 
+
 def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
-    index = jnp.arange(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
+    index = jnp.arange(
+        2 * split_tree.size,
+        dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1),
+    )
     parent_index = index >> 1
     is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
     parent_is_leaf = split_tree[parent_index] == 0
@@ -126,6 +146,7 @@ def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
     stray = stray.at[1].set(False)
     return ~jnp.any(stray)
 
+
 check_functions = [
     check_types,
     check_sizes,
@@ -134,6 +155,7 @@ check_functions = [
     check_stray_nodes,
 ]
 
+
 def check_tree(leaf_tree, var_tree, split_tree, max_split):
     error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
     error = error_type(0)
@@ -144,15 +166,19 @@ def check_tree(leaf_tree, var_tree, split_tree, max_split):
         error |= bit
     return error
 
+
 def describe_error(error):
-    return [
-        func.__name__
-        for i, func in enumerate(check_functions)
-        if error & (1 << i)
-    ]
+    return [func.__name__ for i, func in enumerate(check_functions) if error & (1 << i)]
+
 
 check_forest = jax.vmap(check_tree, in_axes=(0, 0, 0, None))
 
+
 @functools.partial(jax.vmap, in_axes=(0, None))
 def check_trace(trace, state):
-    return check_forest(trace['leaf_trees'], trace['var_trees'], trace['split_trees'], state['max_split'])
+    return check_forest(
+        trace['leaf_trees'],
+        trace['var_trees'],
+        trace['split_trees'],
+        state['max_split'],
+    )
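
A note on the checking code reshuffled above: `check_tree` packs one bit per entry of `check_functions` into an unsigned integer, and `describe_error` decodes it. A self-contained sketch of the decoding, with the names copied from the list in this file:

    names = [
        'check_types',
        'check_sizes',
        'check_unused_node',
        'check_leaf_values',
        'check_stray_nodes',
    ]
    error = 0b01001  # bits 0 and 3 set
    failed = [name for i, name in enumerate(names) if error & (1 << i)]
    assert failed == ['check_types', 'check_leaf_values']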
bartz/grove.py CHANGED
@@ -1,6 +1,6 @@
 # bartz/src/bartz/grove.py
 #
-# Copyright (c) 2024, Giacomo Petrillo
+# Copyright (c) 2024-2025, Giacomo Petrillo
 #
 # This file is part of bartz.
 #
@@ -43,12 +43,12 @@ Since the nodes at the bottom can only be leaves and not decision nodes, the 'va
 import functools
 import math
 
-import jax
-from jax import numpy as jnp
 from jax import lax
+from jax import numpy as jnp
 
 from . import jaxext
 
+
 def make_tree(depth, dtype):
     """
     Make an array to represent a binary tree.
@@ -66,7 +66,8 @@ def make_tree(depth, dtype):
     tree : array
         An array of zeroes with shape (2 ** depth,).
     """
-    return jnp.zeros(2 ** depth, dtype)
+    return jnp.zeros(2**depth, dtype)
+
 
 def tree_depth(tree):
     """
@@ -85,6 +86,7 @@ def tree_depth(tree):
     """
     return int(round(math.log2(tree.shape[-1])))
 
+
 def traverse_tree(x, var_tree, split_tree):
     """
     Find the leaf where a point falls into.
@@ -125,6 +127,7 @@ def traverse_tree(x, var_tree, split_tree):
     (_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
     return index
 
+
 @functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
 @functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
 def traverse_forest(X, var_trees, split_trees):
@@ -147,6 +150,7 @@ def traverse_forest(X, var_trees, split_trees):
     """
     return traverse_tree(X, var_trees, split_trees)
 
+
 def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees=True):
     """
     Evaluate a ensemble of trees at an array of points.
@@ -178,11 +182,12 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees
     leaves = leaf_trees[tree_index[:, None], indices]
     if sum_trees:
         return jnp.sum(leaves, axis=0, dtype=dtype)
-        # this sum suggests to swap the vmaps, but I think it's better for X
-        # copying to keep it that way
+        # this sum suggests to swap the vmaps, but I think it's better for X
+        # copying to keep it that way
     else:
         return leaves
 
+
 def is_actual_leaf(split_tree, *, add_bottom_level=False):
     """
     Return a mask indicating the leaf nodes in a tree.
@@ -211,6 +216,7 @@ def is_actual_leaf(split_tree, *, add_bottom_level=False):
     parent_nonleaf = parent_nonleaf.at[1].set(True)
     return is_leaf & parent_nonleaf
 
+
 def is_leaves_parent(split_tree):
     """
     Return a mask indicating the nodes with leaf (and only leaf) children.
@@ -225,14 +231,17 @@ def is_leaves_parent(split_tree):
     is_leaves_parent : bool array (2 ** (d - 1),)
         The mask indicating which nodes have leaf children.
     """
-    index = jnp.arange(split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
-    left_index = index << 1 # left child
-    right_index = left_index + 1 # right child
+    index = jnp.arange(
+        split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1)
+    )
+    left_index = index << 1  # left child
+    right_index = left_index + 1  # right child
     left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
     right_leaf = split_tree.at[right_index].get(mode='fill', fill_value=0) == 0
     is_not_leaf = split_tree.astype(bool)
     return is_not_leaf & left_leaf & right_leaf
-    # the 0-th item has split == 0, so it's not counted
+    # the 0-th item has split == 0, so it's not counted
+
 
 def tree_depths(tree_length):
     """
@@ -253,7 +262,7 @@ def tree_depths(tree_length):
     depths = []
     depth = 0
     for i in range(tree_length):
-        if i == 2 ** depth:
+        if i == 2**depth:
             depth += 1
         depths.append(depth - 1)
     depths[0] = 0
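
The bit manipulations throughout grove.py (`index << 1` for the left child, `index >> 1` for the parent) rely on the heap layout described in the module docstring: a depth-d tree is a flat array of 2**d entries with the root at index 1, node i's children at 2i and 2i+1, and split value 0 marking a leaf. A pure-Python sketch of the traversal under those conventions (toy arrays; the split rule `x[var] >= split` going right is an assumption for illustration):

    # decision arrays have half the length of the leaf array, since the
    # bottom level can only hold leaves
    var_tree = [0, 0, 1, 0]    # variable tested at each decision node
    split_tree = [0, 3, 2, 0]  # 0 marks a leaf; index 0 is unused
    leaf_tree = [0.0, 0.0, 0.0, 1.5, 0.0, -0.5, 2.0, 0.0]

    def traverse_sketch(x, var_tree, split_tree):
        index = 1
        while index < len(split_tree) and split_tree[index]:
            go_right = x[var_tree[index]] >= split_tree[index]
            index = 2 * index + go_right  # shift left, add the branch bit
        return index

    assert traverse_sketch([1, 5], var_tree, split_tree) == 5  # left, then right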