bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1114 @@
+ # bartz/src/bartz/mcmcstep/_state.py
+ #
+ # Copyright (c) 2024-2026, The Bartz Contributors
+ #
+ # This file is part of bartz.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ """Module defining the BART MCMC state and initialization."""
+
+ from collections.abc import Callable, Hashable
+ from dataclasses import fields
+ from functools import partial, wraps
+ from math import log2
+ from typing import Any, Literal, TypedDict, TypeVar
+
+ import numpy
+ from equinox import Module, error_if
+ from equinox import field as eqx_field
+ from jax import (
+     NamedSharding,
+     device_put,
+     eval_shape,
+     jit,
+     make_mesh,
+     random,
+     tree,
+     vmap,
+ )
+ from jax import numpy as jnp
+ from jax.scipy.linalg import solve_triangular
+ from jax.sharding import AxisType, Mesh, PartitionSpec
+ from jax.tree import flatten
+ from jaxtyping import Array, Bool, Float32, Int32, Integer, PyTree, Shaped, UInt
+
+ from bartz.grove import make_tree, tree_depths
+ from bartz.jaxext import get_default_device, is_key, minimal_unsigned_dtype
+
+
+ def field(*, chains: bool = False, data: bool = False, **kwargs):
+     """Extend `equinox.field` with two new parameters.
+
+     Parameters
+     ----------
+     chains
+         Whether the arrays in the field have an optional first axis that
+         represents independent Markov chains.
+     data
+         Whether the last axis of the arrays in the field represents units of
+         the data.
+     **kwargs
+         Other parameters passed to `equinox.field`.
+
+     Returns
+     -------
+     A dataclass field descriptor with the special attributes stored in the
+     metadata; attributes that are False are left unset.
+     """
+     metadata = dict(kwargs.pop('metadata', {}))
+     assert 'chains' not in metadata
+     assert 'data' not in metadata
+     if chains:
+         metadata['chains'] = True
+     if data:
+         metadata['data'] = True
+     return eqx_field(metadata=metadata, **kwargs)
+
+
+ def chain_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
+     """Determine vmapping axes for chains.
+
+     This function determines the argument to the `in_axes` or `out_axes`
+     parameter of `jax.vmap` to vmap over all and only the chain axes found in
+     the pytree `x`.
+
+     Parameters
+     ----------
+     x
+         A pytree. Subpytrees that are Module attributes marked with
+         ``field(..., chains=True)`` are considered to have a leading chain axis.
+
+     Returns
+     -------
+     A pytree with the same structure as `x` with 0 or None in the leaves.
+     """
+     return _find_metadata(x, 'chains', 0, None)
+
+
+ def data_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
+     """Determine vmapping axes for data.
+
+     This is analogous to `chain_vmap_axes` but returns -1 for all fields
+     marked with ``field(..., data=True)``.
+     """
+     return _find_metadata(x, 'data', -1, None)
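+
+ # Illustrative sketch (the `Walkers` class below is hypothetical, not part of
+ # the module): given a Module with marked fields,
+ #
+ #     class Walkers(Module):
+ #         pos: Float32[Array, 'chains n'] = field(chains=True, data=True)
+ #         scale: Float32[Array, '']
+ #
+ #     w = Walkers(pos=jnp.zeros((4, 100)), scale=jnp.float32(1.0))
+ #
+ # `chain_vmap_axes(w)` returns a Walkers whose `pos` leaf is 0 and whose
+ # `scale` leaf is None, ready to pass as `in_axes`/`out_axes` to `jax.vmap`,
+ # while `data_vmap_axes(w)` marks `pos` with -1 instead.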
+
+
+ T = TypeVar('T')
+
+
+ def _find_metadata(
+     x: PyTree[Any, ' S'], key: Hashable, if_true: T, if_false: T
+ ) -> PyTree[T, ' S']:
+     """Replace all subtrees of x marked with a metadata key."""
+     if isinstance(x, Module):
+         args = []
+         for f in fields(x):
+             v = getattr(x, f.name)
+             if f.metadata.get('static', False):
+                 args.append(v)
+             elif f.metadata.get(key, False):
+                 subtree = tree.map(lambda _: if_true, v)
+                 args.append(subtree)
+             else:
+                 args.append(_find_metadata(v, key, if_true, if_false))
+         return x.__class__(*args)
+
+     def is_leaf(x) -> bool:
+         return isinstance(x, Module)
+
+     def get_axes(x: Module | Any) -> PyTree[T]:
+         if isinstance(x, Module):
+             return _find_metadata(x, key, if_true, if_false)
+         else:
+             return tree.map(lambda _: if_false, x)
+
+     return tree.map(get_axes, x, is_leaf=is_leaf)
+
+
+ class Forest(Module):
+     """Represents the MCMC state of a sum of trees."""
+
+     leaf_tree: (
+         Float32[Array, '*chains num_trees 2**d']
+         | Float32[Array, '*chains num_trees k 2**d']
+     ) = field(chains=True)
+     """The leaf values."""
+
+     var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
+     """The decision axes."""
+
+     split_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
+     """The decision boundaries."""
+
+     affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
+     """Marks leaves that can be grown."""
+
+     max_split: UInt[Array, ' p']
+     """The maximum split index for each predictor."""
+
+     blocked_vars: UInt[Array, ' q'] | None
+     """Indices of variables that are not used. These shall include at least
+     the indices `i` such that ``max_split[i] == 0``, otherwise behavior is
+     undefined."""
+
+     p_nonterminal: Float32[Array, ' 2**d']
+     """The prior probability of each node being nonterminal, conditional on
+     its ancestors. Includes the nodes at maximum depth, which should be set
+     to 0."""
+
+     p_propose_grow: Float32[Array, ' 2**(d-1)']
+     """The unnormalized probability of picking a leaf for a grow proposal."""
+
+     leaf_indices: UInt[Array, '*chains num_trees n'] = field(chains=True, data=True)
+     """The index of the leaf each datapoint falls into, for each tree."""
+
+     min_points_per_decision_node: Int32[Array, ''] | None
+     """The minimum number of data points in a decision node."""
+
+     min_points_per_leaf: Int32[Array, ''] | None
+     """The minimum number of data points in a leaf node."""
+
+     log_trans_prior: Float32[Array, '*chains num_trees'] | None = field(chains=True)
+     """The log transition and prior Metropolis-Hastings ratio for the
+     proposed move on each tree."""
+
+     log_likelihood: Float32[Array, '*chains num_trees'] | None = field(chains=True)
+     """The log likelihood ratio."""
+
+     grow_prop_count: Int32[Array, '*chains'] = field(chains=True)
+     """The number of grow proposals made during one full MCMC cycle."""
+
+     prune_prop_count: Int32[Array, '*chains'] = field(chains=True)
+     """The number of prune proposals made during one full MCMC cycle."""
+
+     grow_acc_count: Int32[Array, '*chains'] = field(chains=True)
+     """The number of grow moves accepted during one full MCMC cycle."""
+
+     prune_acc_count: Int32[Array, '*chains'] = field(chains=True)
+     """The number of prune moves accepted during one full MCMC cycle."""
+
+     leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None
+     """The prior precision matrix of a leaf, conditional on the tree structure.
+     For the univariate case (k=1), this is a scalar (the inverse variance).
+     The prior covariance of the sum of trees is
+     ``num_trees * leaf_prior_cov_inv^-1``."""
+
+     log_s: Float32[Array, '*chains p'] | None = field(chains=True)
+     """The logarithm of the prior probability for choosing a variable to split
+     along in a decision rule, conditional on the ancestors. Not normalized.
+     If `None`, use a uniform distribution."""
+
+     theta: Float32[Array, '*chains'] | None = field(chains=True)
+     """The concentration parameter for the Dirichlet prior on the variable
+     distribution `s`. Required only to update `log_s`."""
+
+     a: Float32[Array, ''] | None
+     """Parameter of the prior on `theta`. Required only to sample `theta`.
+     See `step_theta`."""
+
+     b: Float32[Array, ''] | None
+     """Parameter of the prior on `theta`. Required only to sample `theta`.
+     See `step_theta`."""
+
+     rho: Float32[Array, ''] | None
+     """Parameter of the prior on `theta`. Required only to sample `theta`.
+     See `step_theta`."""
+
+     def num_chains(self) -> int | None:
+         """Return the number of chains, or `None` if not multichain."""
+         # maybe this should be replaced by chain_shape() -> () | (int,)
+         if self.var_tree.ndim == 2:
+             return None
+         else:
+             return self.var_tree.shape[0]
+
+
+ class StepConfig(Module):
+     """Options for the MCMC step."""
+
+     steps_done: Int32[Array, '']
+     """The number of MCMC steps completed so far."""
+
+     sparse_on_at: Int32[Array, ''] | None
+     """After how many steps to turn on variable selection."""
+
+     resid_num_batches: int | None = field(static=True)
+     """The number of batches for computing the sum of residuals. If
+     `None`, they are computed with no batching."""
+
+     count_num_batches: int | None = field(static=True)
+     """The number of batches for computing counts. If
+     `None`, they are computed with no batching."""
+
+     prec_num_batches: int | None = field(static=True)
+     """The number of batches for computing precision scales. If
+     `None`, they are computed with no batching."""
+
+     prec_count_num_trees: int | None = field(static=True)
+     """Batch size for processing trees to compute count and prec trees."""
+
+     mesh: Mesh | None = field(static=True)
+     """The mesh used to shard data and computation across multiple devices."""
+
+
+ class State(Module):
+     """Represents the MCMC state of BART."""
+
+     X: UInt[Array, 'p n'] = field(data=True)
+     """The predictors."""
+
+     y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'] = field(
+         data=True
+     )
+     """The response. If the data type is `bool`, the model is binary regression."""
+
+     z: None | Float32[Array, '*chains n'] = field(chains=True, data=True)
+     """The latent variable for binary regression. `None` in continuous
+     regression."""
+
+     offset: Float32[Array, ''] | Float32[Array, ' k']
+     """Constant shift added to the sum of trees."""
+
+     resid: Float32[Array, '*chains n'] | Float32[Array, '*chains k n'] = field(
+         chains=True, data=True
+     )
+     """The residuals (`y` or `z` minus sum of trees)."""
+
+     error_cov_inv: Float32[Array, '*chains'] | Float32[Array, '*chains k k'] | None = (
+         field(chains=True)
+     )
+     """The inverse error covariance (scalar for univariate, matrix for multivariate).
+     `None` in binary regression."""
+
+     prec_scale: Float32[Array, ' n'] | None = field(data=True)
+     """The scale on the error precision, i.e., ``1 / error_scale ** 2``.
+     `None` in binary regression."""
+
+     error_cov_df: Float32[Array, ''] | None
+     """The df parameter of the inverse Wishart prior on the noise
+     covariance. For the univariate case, the relationship to the inverse
+     gamma prior parameters is ``alpha = df / 2``.
+     `None` in binary regression."""
+
+     error_cov_scale: Float32[Array, ''] | Float32[Array, 'k k'] | None
+     """The scale parameter of the inverse Wishart prior on the noise
+     covariance. For the univariate case, the relationship to the inverse
+     gamma prior parameters is ``beta = scale / 2``.
+     `None` in binary regression."""
+
+     forest: Forest
+     """The sum of trees model."""
+
+     config: StepConfig
+     """Metadata and configurations for the MCMC step."""
+
+
+ def _init_shape_shifting_parameters(
+     y: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
+     offset: Float32[Array, ''] | Float32[Array, ' k'],
+     error_scale: Float32[Any, ' n'] | None,
+     error_cov_df: float | Float32[Any, ''] | None,
+     error_cov_scale: float | Float32[Any, ''] | Float32[Any, 'k k'] | None,
+     leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
+ ) -> tuple[
+     bool,
+     tuple[()] | tuple[int],
+     None | Float32[Array, ''] | Float32[Array, 'k k'],
+     None | Float32[Array, ''],
+     None | Float32[Array, ''] | Float32[Array, 'k k'],
+ ]:
+     """
+     Check and initialize parameters that change array type/shape based on outcome kind.
+
+     Parameters
+     ----------
+     y
+         The response variable; the outcome type is deduced from `y` and then
+         all other parameters are checked against it.
+     offset
+         The offset to add to the predictions.
+     error_scale
+         Per-observation error scale (univariate only).
+     error_cov_df
+         The error covariance degrees of freedom.
+     error_cov_scale
+         The error covariance scale.
+     leaf_prior_cov_inv
+         The inverse of the leaf prior covariance.
+
+     Returns
+     -------
+     is_binary
+         Whether the outcome is binary.
+     kshape
+         The outcome shape, empty for univariate, (k,) for multivariate.
+     error_cov_inv
+         The initialized error covariance inverse.
+     error_cov_df
+         The error covariance degrees of freedom (as array).
+     error_cov_scale
+         The error covariance scale (as array).
+
+     Raises
+     ------
+     ValueError
+         If `y` is binary and multivariate.
+     """
+     # determine outcome kind, binary/continuous x univariate/multivariate
+     is_binary = y.dtype == bool
+     kshape = y.shape[:-1]
+
+     # Binary vs continuous
+     if is_binary:
+         if kshape:
+             msg = 'Binary multivariate regression not supported, open an issue at https://github.com/bartz-org/bartz/issues if you need it.'
+             raise ValueError(msg)
+         assert error_scale is None
+         assert error_cov_df is None
+         assert error_cov_scale is None
+         error_cov_inv = None
+     else:
+         error_cov_df = jnp.asarray(error_cov_df)
+         error_cov_scale = jnp.asarray(error_cov_scale)
+         assert error_cov_scale.shape == 2 * kshape
+
+         # Multivariate vs univariate
+         if kshape:
+             error_cov_inv = error_cov_df * _inv_via_chol_with_gersh(error_cov_scale)
+         else:
+             # inverse gamma prior: alpha = df / 2, beta = scale / 2
+             error_cov_inv = error_cov_df / error_cov_scale
+
+     assert leaf_prior_cov_inv.shape == 2 * kshape
+     assert offset.shape == kshape
+
+     return is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale
+
+
+ def _parse_p_nonterminal(
+     p_nonterminal: Float32[Any, ' d_minus_1'],
+ ) -> Float32[Array, ' d_minus_1+1']:
+     """Check it's in (0, 1) and pad with a 0 at the end."""
+     p_nonterminal = jnp.asarray(p_nonterminal)
+     ok = (p_nonterminal > 0) & (p_nonterminal < 1)
+     p_nonterminal = error_if(p_nonterminal, ~ok, 'p_nonterminal must be in (0, 1)')
+     return jnp.pad(p_nonterminal, (0, 1))
+
+
+ def make_p_nonterminal(
+     d: int, alpha: float | Float32[Array, ''], beta: float | Float32[Array, '']
+ ) -> Float32[Array, ' {d}-1']:
+     """Prepare the `p_nonterminal` argument to `init`.
+
+     It is calculated according to the formula:
+
+         P_nt(depth) = alpha / (1 + depth)^beta, with depth 0-based
+
+     Parameters
+     ----------
+     d
+         The maximum depth of the trees (d=1 means a tree with only the root
+         node).
+     alpha
+         The a priori probability of the root node having children, conditional
+         on it being possible.
+     beta
+         The exponent of the power decay of the probability of having children
+         with depth.
+
+     Returns
+     -------
+     An array of probabilities, one per tree level except the last.
+     """
+     assert d >= 1
+     depth = jnp.arange(d - 1)
+     return alpha / (1 + depth).astype(float) ** beta
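+
+ # Worked example (values computed from the formula above, not taken from the
+ # package's docs): with d=3, alpha=0.95, beta=2.0,
+ #
+ #     make_p_nonterminal(3, 0.95, 2.0)
+ #     # depth 0: 0.95 / 1**2 = 0.95
+ #     # depth 1: 0.95 / 2**2 = 0.2375
+ #     # -> Array([0.95, 0.2375], dtype=float32)
+ #
+ # `init` then pads this with a trailing 0 (see `_parse_p_nonterminal`), so
+ # nodes at maximum depth can never be nonterminal.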
+
+
+ def init(
+     *,
+     X: UInt[Any, 'p n'],
+     y: Float32[Any, ' n'] | Float32[Any, ' k n'] | Bool[Any, ' n'],
+     offset: float | Float32[Any, ''] | Float32[Any, ' k'],
+     max_split: UInt[Any, ' p'],
+     num_trees: int,
+     p_nonterminal: Float32[Any, ' d_minus_1'],
+     leaf_prior_cov_inv: float | Float32[Any, ''] | Float32[Array, 'k k'],
+     error_cov_df: float | Float32[Any, ''] | None = None,
+     error_cov_scale: float | Float32[Any, ''] | Float32[Array, 'k k'] | None = None,
+     error_scale: Float32[Any, ' n'] | None = None,
+     min_points_per_decision_node: int | Integer[Any, ''] | None = None,
+     resid_num_batches: int | None | Literal['auto'] = 'auto',
+     count_num_batches: int | None | Literal['auto'] = 'auto',
+     prec_num_batches: int | None | Literal['auto'] = 'auto',
+     prec_count_num_trees: int | None | Literal['auto'] = 'auto',
+     save_ratios: bool = False,
+     filter_splitless_vars: int = 0,
+     min_points_per_leaf: int | Integer[Any, ''] | None = None,
+     log_s: Float32[Any, ' p'] | None = None,
+     theta: float | Float32[Any, ''] | None = None,
+     a: float | Float32[Any, ''] | None = None,
+     b: float | Float32[Any, ''] | None = None,
+     rho: float | Float32[Any, ''] | None = None,
+     sparse_on_at: int | Integer[Any, ''] | None = None,
+     num_chains: int | None = None,
+     mesh: Mesh | dict[str, int] | None = None,
+     target_platform: Literal['cpu', 'gpu'] | None = None,
+ ) -> State:
+     """
+     Make a BART posterior sampling MCMC initial state.
+
+     Parameters
+     ----------
+     X
+         The predictors. Note this is transposed compared to the usual
+         convention.
+     y
+         The response. If the data type is `bool`, the regression model is
+         binary regression with probit. If two-dimensional, the outcome is
+         multivariate with the first axis indicating the component.
+     offset
+         Constant shift added to the sum of trees. 0 if not specified.
+     max_split
+         The maximum split index for each variable. All split ranges start at 1.
+     num_trees
+         The number of trees in the forest.
+     p_nonterminal
+         The probability of a nonterminal node at each depth. The maximum depth
+         of trees is fixed by the length of this array. Use `make_p_nonterminal`
+         to set it with the conventional formula.
+     leaf_prior_cov_inv
+         The prior precision matrix of a leaf, conditional on the tree
+         structure. For the univariate case (k=1), this is a scalar (the
+         inverse variance). The prior covariance of the sum of trees is
+         ``num_trees * leaf_prior_cov_inv^-1``. The prior mean of leaves is
+         always zero.
+     error_cov_df
+     error_cov_scale
+         The df and scale parameters of the inverse Wishart prior on the error
+         covariance. For the univariate case, the relationship to the inverse
+         gamma prior parameters is ``alpha = df / 2``, ``beta = scale / 2``.
+         Leave unspecified for binary regression.
+     error_scale
+         Each error is scaled by the corresponding factor in `error_scale`, so
+         the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
+         Not supported for binary regression. If not specified, defaults to 1
+         for all points, potentially skipping some calculations.
+     min_points_per_decision_node
+         The minimum number of data points in a decision node. 0 if not
+         specified.
+     resid_num_batches
+     count_num_batches
+     prec_num_batches
+         The number of batches, along datapoints, for summing the residuals,
+         counting the number of datapoints in each leaf, and computing the
+         likelihood precision in each leaf, respectively. `None` for no
+         batching. If 'auto', it's chosen automatically based on the target
+         platform; see the description of `target_platform` below for how it
+         is determined.
+     prec_count_num_trees
+         The number of trees to process at a time when counting datapoints or
+         computing the likelihood precision. If `None`, do all trees at once,
+         which may use too much memory. If 'auto' (default), it's chosen
+         automatically.
+     save_ratios
+         Whether to save the Metropolis-Hastings ratios.
+     filter_splitless_vars
+         The maximum number of variables without splits that can be ignored.
+         If there are more, `init` raises an exception.
+     min_points_per_leaf
+         The minimum number of datapoints in a leaf node. 0 if not specified.
+         Unlike `min_points_per_decision_node`, this constraint is not taken
+         into account in the Metropolis-Hastings ratio because it would be
+         expensive to compute. Grow moves that would violate this constraint
+         are vetoed. This parameter is independent of
+         `min_points_per_decision_node` and there is no check that they are
+         coherent. It makes sense to set
+         ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
+     log_s
+         The logarithm of the prior probability for choosing a variable to
+         split along in a decision rule, conditional on the ancestors. Not
+         normalized. If not specified, use a uniform distribution. If not
+         specified and `theta` (or `rho`, `a`, `b`) is, it's initialized
+         automatically.
+     theta
+         The concentration parameter for the Dirichlet prior on `s`. Required
+         only to update `log_s`. If not specified, and `rho`, `a`, `b` are
+         specified, it's initialized automatically.
+     a
+     b
+     rho
+         Parameters of the prior on `theta`. Required only to sample `theta`.
+     sparse_on_at
+         After how many MCMC steps to turn on variable selection.
+     num_chains
+         The number of independent MCMC chains to represent in the state.
+         Single chain with scalar values if not specified.
+     mesh
+         A jax mesh used to shard data and computation across multiple devices.
+         If it has a 'chains' axis, that axis is used to shard the chains. If
+         it has a 'data' axis, that axis is used to shard the datapoints.
+
+         As a shorthand, if a dictionary mapping axis names to axis sizes is
+         passed, the corresponding mesh is created, e.g., ``dict(chains=4,
+         data=2)`` will let jax pick 8 devices to split the chains (whose
+         number must be a multiple of 4) across 4 pairs of devices, where in
+         each pair the data is split in two.
+
+         Note: if a mesh is passed, the arrays are always sharded according to
+         it. In particular, even if the mesh has no 'chains' or 'data' axis,
+         the arrays will be replicated on all devices in the mesh.
+     target_platform
+         Platform ('cpu' or 'gpu') used to determine the number of batches
+         automatically. If `mesh` is specified, the platform is inferred from
+         the devices in the mesh. Otherwise, if `y` is a concrete array (i.e.,
+         `init` is not invoked in a `jax.jit` context), the platform is set to
+         the platform of `y`. Otherwise, use `target_platform`.
+
+         To avoid confusion, in all cases where the `target_platform` argument
+         would be ignored, `init` raises an exception if `target_platform` is
+         set.
+
+     Returns
+     -------
+     An initialized BART MCMC state.
+
+     Raises
+     ------
+     ValueError
+         If `y` is boolean and arguments unused in binary regression are set.
+
+     Notes
+     -----
+     In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint
+     out of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the
+     left child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]``
+     to be integers in the range ``[0, 1, ..., max_split[i]]``.
+     """
+     # convert to array all array-like arguments that are used in other
+     # configurations but don't need further processing themselves
+     X = jnp.asarray(X)
+     y = jnp.asarray(y)
+     offset = jnp.asarray(offset)
+     leaf_prior_cov_inv = jnp.asarray(leaf_prior_cov_inv)
+     max_split = jnp.asarray(max_split)
+
+     # check p_nonterminal and pad it with a 0 at the end (still not final shape)
+     p_nonterminal = _parse_p_nonterminal(p_nonterminal)
+
+     # process arguments that change depending on outcome type
+     is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale = (
+         _init_shape_shifting_parameters(
+             y, offset, error_scale, error_cov_df, error_cov_scale, leaf_prior_cov_inv
+         )
+     )
+
+     # extract array sizes from arguments
+     (max_depth,) = p_nonterminal.shape
+     p, n = X.shape
+
+     # check and initialize sparsity parameters
+     if not _all_none_or_not_none(rho, a, b):
+         msg = 'rho, a, b are not either all `None` or all set'
+         raise ValueError(msg)
+     if theta is None and rho is not None:
+         theta = rho
+     if log_s is None and theta is not None:
+         log_s = jnp.zeros(max_split.size)
+     if not _all_none_or_not_none(theta, sparse_on_at):
+         msg = 'sparsity params (either theta or rho,a,b) and sparse_on_at must be either all None or all set'
+         raise ValueError(msg)
+
+     # process multichain settings
+     chain_shape = () if num_chains is None else (num_chains,)
+     resid_shape = chain_shape + y.shape
+     tree_shape = (*chain_shape, num_trees)
+     add_chains = partial(_add_chains, chain_shape=chain_shape)
+
+     # determine batch sizes for reductions
+     mesh = _parse_mesh(num_chains, mesh)
+     target_platform = _parse_target_platform(
+         y, mesh, target_platform, resid_num_batches, count_num_batches, prec_num_batches
+     )
+     red_cfg = _parse_reduction_configs(
+         resid_num_batches,
+         count_num_batches,
+         prec_num_batches,
+         prec_count_num_trees,
+         y,
+         num_trees,
+         mesh,
+         target_platform,
+     )
+
+     # check there aren't too many deactivated predictors
+     msg = (
+         f'there are more than {filter_splitless_vars=} predictors with no splits, '
+         'please increase `filter_splitless_vars` or investigate the missing splits'
+     )
+     offset = error_if(offset, jnp.sum(max_split == 0) > filter_splitless_vars, msg)
+
+     # initialize all remaining stuff and put it in an unsharded state
+     state = State(
+         X=X,
+         y=y,
+         z=jnp.full(resid_shape, offset) if is_binary else None,
+         offset=offset,
+         resid=jnp.zeros(resid_shape)
+         if is_binary
+         else jnp.broadcast_to(y - offset[..., None], resid_shape),
+         error_cov_inv=add_chains(error_cov_inv),
+         prec_scale=_get_prec_scale(error_scale),
+         error_cov_df=error_cov_df,
+         error_cov_scale=error_cov_scale,
+         forest=Forest(
+             leaf_tree=make_tree(max_depth, jnp.float32, tree_shape + kshape),
+             var_tree=make_tree(
+                 max_depth - 1, minimal_unsigned_dtype(p - 1), tree_shape
+             ),
+             split_tree=make_tree(max_depth - 1, max_split.dtype, tree_shape),
+             affluence_tree=(
+                 make_tree(max_depth - 1, bool, tree_shape)
+                 .at[..., 1]
+                 .set(
+                     True
+                     if min_points_per_decision_node is None
+                     else n >= min_points_per_decision_node
+                 )
+             ),
+             blocked_vars=_get_blocked_vars(filter_splitless_vars, max_split),
+             max_split=max_split,
+             grow_prop_count=jnp.zeros(chain_shape, int),
+             grow_acc_count=jnp.zeros(chain_shape, int),
+             prune_prop_count=jnp.zeros(chain_shape, int),
+             prune_acc_count=jnp.zeros(chain_shape, int),
+             p_nonterminal=p_nonterminal[tree_depths(2**max_depth)],
+             p_propose_grow=p_nonterminal[tree_depths(2 ** (max_depth - 1))],
+             leaf_indices=jnp.ones(
+                 (*tree_shape, n), minimal_unsigned_dtype(2**max_depth - 1)
+             ),
+             min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
+             min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
+             log_trans_prior=jnp.zeros((*chain_shape, num_trees))
+             if save_ratios
+             else None,
+             log_likelihood=jnp.zeros((*chain_shape, num_trees))
+             if save_ratios
+             else None,
+             leaf_prior_cov_inv=leaf_prior_cov_inv,
+             log_s=add_chains(_asarray_or_none(log_s)),
+             theta=add_chains(_asarray_or_none(theta)),
+             rho=_asarray_or_none(rho),
+             a=_asarray_or_none(a),
+             b=_asarray_or_none(b),
+         ),
+         config=StepConfig(
+             steps_done=jnp.int32(0),
+             sparse_on_at=_asarray_or_none(sparse_on_at),
+             mesh=mesh,
+             **red_cfg,
+         ),
+     )
+
+     # move all arrays to the appropriate device
+     return _shard_state(state)
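+
+ # Minimal usage sketch (hypothetical data and prior settings, not taken from
+ # the package's documentation):
+ #
+ #     import jax.numpy as jnp
+ #     from jax import random
+ #
+ #     kx, ke = random.split(random.key(0))
+ #     p, n, num_trees = 2, 100, 50
+ #     X = random.randint(kx, (p, n), 0, 10).astype(jnp.uint8)  # bins 0..9
+ #     y = jnp.sin(X[0] / 3.0) + 0.1 * random.normal(ke, (n,))
+ #
+ #     state = init(
+ #         X=X,
+ #         y=y,
+ #         offset=y.mean(),
+ #         max_split=jnp.full(p, 9, dtype=jnp.uint8),  # cutpoints 1..9
+ #         num_trees=num_trees,
+ #         p_nonterminal=make_p_nonterminal(6, 0.95, 2.0),
+ #         leaf_prior_cov_inv=float(num_trees),  # prior sd of sum of trees = 1
+ #         error_cov_df=2.0,  # inverse gamma alpha = 1
+ #         error_cov_scale=2.0,  # inverse gamma beta = 1
+ #     )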
+
+
+ @partial(jit, donate_argnums=(0,))
+ def _get_prec_scale(
+     error_scale: Float32[Array, ' n'] | None,
+ ) -> Float32[Array, ' n'] | None:
+     """Compute 1 / error_scale**2.
+
+     This is a separate function to use donate_argnums to avoid intermediate
+     copies.
+     """
+     if error_scale is None:
+         return None
+     else:
+         return jnp.reciprocal(jnp.square(jnp.asarray(error_scale)))
+
+
+ def _get_blocked_vars(
+     filter_splitless_vars: int, max_split: UInt[Array, ' p']
+ ) -> None | UInt[Array, ' q']:
+     """Initialize the `blocked_vars` field."""
+     if filter_splitless_vars:
+         (p,) = max_split.shape
+         (blocked_vars,) = jnp.nonzero(
+             max_split == 0, size=filter_splitless_vars, fill_value=p
+         )
+         return blocked_vars.astype(minimal_unsigned_dtype(p))
+         # see `fully_used_variables` for the type cast
+     else:
+         return None
+
+
+ def _add_chains(
+     x: Shaped[Array, '*shape'] | None, chain_shape: tuple[int, ...]
+ ) -> Shaped[Array, '*shape'] | Shaped[Array, ' num_chains *shape'] | None:
+     """Broadcast `x` to all chains."""
+     if x is None:
+         return None
+     else:
+         return jnp.broadcast_to(x, chain_shape + x.shape)
+
+
+ def _parse_mesh(
+     num_chains: int | None, mesh: Mesh | dict[str, int] | None
+ ) -> Mesh | None:
+     """Parse the `mesh` argument."""
+     if mesh is None:
+         return None
+
+     # convert dict format to actual mesh
+     if isinstance(mesh, dict):
+         assert set(mesh).issubset({'chains', 'data'})
+         mesh = make_mesh(
+             tuple(mesh.values()), tuple(mesh), axis_types=(AxisType.Auto,) * len(mesh)
+         )
+
+     # check there's no chain mesh axis if there are no chains
+     if num_chains is None:
+         assert 'chains' not in mesh.axis_names
+
+     # check the axes we use are in auto mode
+     assert 'chains' not in mesh.axis_names or 'chains' in _auto_axes(mesh)
+     assert 'data' not in mesh.axis_names or 'data' in _auto_axes(mesh)
+
+     return mesh
+
+
+ def _parse_target_platform(
+     y: Array,
+     mesh: Mesh | None,
+     target_platform: Literal['cpu', 'gpu'] | None,
+     resid_num_batches: int | None | Literal['auto'],
+     count_num_batches: int | None | Literal['auto'],
+     prec_num_batches: int | None | Literal['auto'],
+ ) -> Literal['cpu', 'gpu'] | None:
+     """Determine the platform used to pick batch sizes automatically."""
+     if mesh is not None:
+         assert target_platform is None, 'mesh provided, do not set target_platform'
+         return mesh.devices.flat[0].platform
+     elif hasattr(y, 'platform'):
+         assert target_platform is None, 'device inferred from y, unset target_platform'
+         return y.platform()
+     elif (
+         resid_num_batches == 'auto'
+         or count_num_batches == 'auto'
+         or prec_num_batches == 'auto'
+     ):
+         assert target_platform in ('cpu', 'gpu')
+         return target_platform
+     else:
+         assert target_platform is None, 'target_platform not used, unset it'
+         return target_platform
+
+
+ def _auto_axes(mesh: Mesh) -> list[str]:
+     """Re-implement `Mesh.auto_axes` because that's missing in jax v0.5."""
+     # Mesh.auto_axes added in jax v0.6.0
+     return [
+         n
+         for n, t in zip(mesh.axis_names, mesh.axis_types, strict=True)
+         if t == AxisType.Auto
+     ]
+
+
+ def _shard_state(state: State) -> State:
+     """Place all fields in the state on the appropriate devices."""
+     mesh = state.config.mesh
+     if mesh is None:
+         return state
+
+     def shard_leaf(
+         x: Array | None, chain_axis: int | None, data_axis: int | None
+     ) -> Array | None:
+         if x is None:
+             return None
+
+         spec = [None] * x.ndim
+         if chain_axis is not None and 'chains' in mesh.axis_names:
+             spec[chain_axis] = 'chains'
+         if data_axis is not None and 'data' in mesh.axis_names:
+             spec[data_axis] = 'data'
+
+         # remove trailing Nones to be consistent with jax's output; it's
+         # useful for comparing shardings during debugging
+         while spec and spec[-1] is None:
+             spec.pop()
+
+         spec = PartitionSpec(*spec)
+         return device_put(x, NamedSharding(mesh, spec), donate=True)
+
+     return tree.map(
+         shard_leaf,
+         state,
+         chain_vmap_axes(state),
+         data_vmap_axes(state),
+         is_leaf=lambda x: x is None,
+     )
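+
+ # For instance (assuming a mesh with both a 'chains' and a 'data' axis), a
+ # leaf like `leaf_indices`, with shape (num_chains, num_trees, n), chain axis
+ # 0 and data axis -1, gets PartitionSpec('chains', None, 'data'), while
+ # `p_nonterminal`, with neither axis marked, gets PartitionSpec() and is
+ # replicated on every device.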
+
+
+ def _all_none_or_not_none(*args):
+     is_none = [x is None for x in args]
+     return all(is_none) or not any(is_none)
+
+
+ def _asarray_or_none(x):
+     if x is None:
+         return None
+     return jnp.asarray(x)
+
+
+ def _get_platform(mesh: Mesh | None) -> str:
+     if mesh is None:
+         return get_default_device().platform
+     else:
+         return mesh.devices.flat[0].platform
+
+
+ class _ReductionConfig(TypedDict):
+     """Fields of `StepConfig` related to reductions."""
+
+     resid_num_batches: int | None
+     count_num_batches: int | None
+     prec_num_batches: int | None
+     prec_count_num_trees: int | None
+
+
+ def _parse_reduction_configs(
+     resid_num_batches: int | None | Literal['auto'],
+     count_num_batches: int | None | Literal['auto'],
+     prec_num_batches: int | None | Literal['auto'],
+     prec_count_num_trees: int | None | Literal['auto'],
+     y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'],
+     num_trees: int,
+     mesh: Mesh | None,
+     target_platform: Literal['cpu', 'gpu'] | None,
+ ) -> _ReductionConfig:
+     """Determine settings for indexed reduces."""
+     n = y.shape[-1]
+     n //= get_axis_size(mesh, 'data')  # per-device datapoints
+     parse_num_batches = partial(_parse_num_batches, target_platform, n)
+     return dict(
+         resid_num_batches=parse_num_batches(resid_num_batches, 'resid'),
+         count_num_batches=parse_num_batches(count_num_batches, 'count'),
+         prec_num_batches=parse_num_batches(prec_num_batches, 'prec'),
+         prec_count_num_trees=_parse_prec_count_num_trees(
+             prec_count_num_trees, num_trees, n
+         ),
+     )
+
+
+ def _parse_num_batches(
+     target_platform: Literal['cpu', 'gpu'] | None,
+     n: int,
+     num_batches: int | None | Literal['auto'],
+     which: Literal['resid', 'count', 'prec'],
+ ) -> int | None:
+     """Return the number of batches or determine it automatically."""
+     final_round = partial(_final_round, n)
+     if num_batches != 'auto':
+         nb = num_batches
+     elif target_platform == 'cpu':
+         nb = final_round(16)
+     elif target_platform == 'gpu':
+         nb = dict(resid=1024, count=2048, prec=1024)[which]  # on an A4000
+         nb = final_round(nb)
+     else:
+         # only cpu and gpu have auto heuristics
+         msg = f'cannot determine num_batches automatically for platform {target_platform!r}'
+         raise ValueError(msg)
+     return nb
+
+
+ def _final_round(n: int, num: float) -> int | None:
+     """Bound batch size, round the number of batches to a power of 2, and disable batching if there's only 1 batch."""
+     # at least some elements per batch
+     num = min(n // 32, num)
+
+     # round to the nearest power of 2 because I guess XLA and the hardware
+     # will like that (not sure about this, maybe just multiple of 32?)
+     num = 2 ** round(log2(num)) if num else 0
+
+     # disable batching if the batch is as large as the whole dataset
+     return num if num > 1 else None
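+
+ # Numeric illustration (hypothetical values): with n=10_000 datapoints on gpu
+ # and the 'resid' table value 1024, _final_round caps it at n // 32 = 312 and
+ # rounds to 2**round(log2(312)) = 256 batches; with n=40 the cap is
+ # 40 // 32 = 1, so it returns None and batching is disabled.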
+
+
+ def _parse_prec_count_num_trees(
+     prec_count_num_trees: int | None | Literal['auto'], num_trees: int, n: int
+ ) -> int | None:
+     """Return the number of trees to process at a time or determine it automatically."""
+     if prec_count_num_trees != 'auto':
+         return prec_count_num_trees
+     max_n_by_ntree = 2**27  # about 100M
+     pcnt = max_n_by_ntree // max(1, n)
+     pcnt = min(num_trees, pcnt)
+     pcnt = max(1, pcnt)
+     pcnt = _search_divisor(
+         pcnt, num_trees, max(1, pcnt // 2), max(1, min(num_trees, pcnt * 2))
+     )
+     if pcnt >= num_trees:
+         pcnt = None
+     return pcnt
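+
+ # Numeric illustration (hypothetical values): with n=1_000_000 and
+ # num_trees=200, the memory heuristic gives 2**27 // 1_000_000 = 134, clipped
+ # to [1, 200]; _search_divisor then looks for a divisor of 200 in [67, 200]
+ # and picks 100, the closest to 134, so trees are processed 100 at a time.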
+
+
+ def _search_divisor(target_divisor: int, dividend: int, low: int, up: int) -> int:
+     """Find the divisor of `dividend` closest to `target_divisor` in [low, up], if `target_divisor` is not already a divisor.
+
+     If there is none, give up and return `target_divisor`.
+     """
+     assert target_divisor >= 1
+     assert 1 <= low <= up <= dividend
+     if dividend % target_divisor == 0:
+         return target_divisor
+     candidates = numpy.arange(low, up + 1)
+     divisors = candidates[dividend % candidates == 0]
+     if divisors.size == 0:
+         return target_divisor
+     penalty = numpy.abs(divisors - target_divisor)
+     closest = numpy.argmin(penalty)
+     return divisors[closest].item()
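+
+ # Illustrative behavior (hypothetical values): _search_divisor(12, 200, 6, 24)
+ # returns 10, since 200 % 12 != 0 and the divisors of 200 in [6, 24] are
+ # {8, 10, 20}, of which 10 is closest to 12; _search_divisor(13, 200, 11, 19)
+ # returns 13 unchanged because no divisor of 200 falls in [11, 19].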
+
+
+ def get_axis_size(mesh: Mesh | None, axis_name: str) -> int:
+     """Return the size of a mesh axis, or 1 if `mesh` is None or lacks the axis."""
+     if mesh is None or axis_name not in mesh.axis_names:
+         return 1
+     else:
+         i = mesh.axis_names.index(axis_name)
+         return mesh.axis_sizes[i]
+
+
+ def chol_with_gersh(
+     mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool = False
+ ) -> Float32[Array, '*batch_shape k k']:
+     """Cholesky with Gershgorin stabilization, supports batching."""
+     return _chol_with_gersh_impl(mat, absolute_eps)
+
+
+ @partial(jnp.vectorize, signature='(k,k)->(k,k)', excluded=(1,))
+ def _chol_with_gersh_impl(
+     mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool
+ ) -> Float32[Array, '*batch_shape k k']:
+     """Implementation of `chol_with_gersh` for a single matrix."""
+     rho = jnp.max(jnp.sum(jnp.abs(mat), axis=1), initial=0.0)
+     eps = jnp.finfo(mat.dtype).eps
+     u = mat.shape[0] * rho * eps
+     if absolute_eps:
+         u += eps
+     mat = mat.at[jnp.diag_indices_from(mat)].add(u)
+     return jnp.linalg.cholesky(mat)
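+
+ # Note on the stabilization above: by the Gershgorin circle theorem, every
+ # eigenvalue of `mat` has magnitude at most `rho`, the maximum absolute row
+ # sum, so adding `u = k * rho * eps` to the diagonal shifts all eigenvalues
+ # up by a small multiple of the matrix's scale, pushing tiny negative
+ # eigenvalues (from rounding) back above zero before the factorization.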
+
+
+ def _inv_via_chol_with_gersh(mat: Float32[Array, 'k k']) -> Float32[Array, 'k k']:
+     """Compute matrix inverse via Cholesky with Gershgorin stabilization.
+
+     DO NOT USE THIS FUNCTION UNLESS YOU REALLY NEED TO.
+     """
+     L = chol_with_gersh(mat)
+     I = jnp.eye(mat.shape[0], dtype=mat.dtype)
+     L_inv = solve_triangular(L, I, lower=True)
+     return L_inv.T @ L_inv
+
+
+ def get_num_chains(x: PyTree) -> int | None:
+     """Get the number of chains of a pytree.
+
+     Find all nodes in the structure that define `num_chains()`, stopping
+     traversal at them. Check that all the values obtained by invoking
+     `num_chains` are equal, then return that value.
+     """
+     leaves, _ = flatten(x, is_leaf=lambda x: hasattr(x, 'num_chains'))
+     num_chains = [x.num_chains() for x in leaves if hasattr(x, 'num_chains')]
+     ref = num_chains[0]
+     assert all(c == ref for c in num_chains)
+     return ref
+
+
+ def _chain_axes_with_keys(x: PyTree) -> PyTree[int | None]:
+     """Return `chain_vmap_axes(x)`, but also set the axis to 0 for random keys."""
+     axes = chain_vmap_axes(x)
+
+     def axis_if_key(x, axis):
+         if is_key(x):
+             return 0
+         else:
+             return axis
+
+     return tree.map(axis_if_key, x, axes)
+
+
+ def _get_mc_out_axes(
+     fun: Callable[[tuple, dict], PyTree], args: PyTree, in_axes: PyTree[int | None]
+ ) -> PyTree[int | None]:
+     """Decide chain vmap axes for outputs."""
+     vmapped_fun = vmap(fun, in_axes=in_axes)
+     out = eval_shape(vmapped_fun, *args)
+     return chain_vmap_axes(out)
+
+
+ def _find_mesh(x: PyTree) -> Mesh | None:
+     """Find the mesh used for chains."""
+
+     class MeshFound(Exception):
+         pass
+
+     def find_mesh(x: State | Any):
+         if isinstance(x, State):
+             raise MeshFound(x.config.mesh)
+
+     try:
+         tree.map(find_mesh, x, is_leaf=lambda x: isinstance(x, State))
+     except MeshFound as e:
+         return e.args[0]
+     else:
+         msg = 'no State found in the pytree'
+         raise ValueError(msg)
+
+
+ def _split_all_keys(x: PyTree, num_chains: int) -> PyTree:
+     """Split all random keys into `num_chains` keys."""
+     mesh = _find_mesh(x)
+
+     def split_key(x):
+         if is_key(x):
+             x = random.split(x, num_chains)
+             if mesh is not None and 'chains' in mesh.axis_names:
+                 x = device_put(x, NamedSharding(mesh, PartitionSpec('chains')))
+         return x
+
+     return tree.map(split_key, x)
+
+
+ def vmap_chains(
+     fun: Callable[..., T], *, auto_split_keys: bool = False
+ ) -> Callable[..., T]:
+     """Apply vmap on chain axes automatically if the inputs are multichain."""
+
+     @wraps(fun)
+     def auto_vmapped_fun(*args, **kwargs) -> T:
+         all_args = args, kwargs
+         num_chains = get_num_chains(all_args)
+         if num_chains is not None:
+             if auto_split_keys:
+                 all_args = _split_all_keys(all_args, num_chains)
+
+             def wrapped_fun(args, kwargs):
+                 return fun(*args, **kwargs)
+
+             mc_in_axes = _chain_axes_with_keys(all_args)
+             mc_out_axes = _get_mc_out_axes(wrapped_fun, all_args, mc_in_axes)
+             vmapped_fun = vmap(wrapped_fun, in_axes=mc_in_axes, out_axes=mc_out_axes)
+             return vmapped_fun(*all_args)
+
+         else:
+             return fun(*args, **kwargs)
+
+     return auto_vmapped_fun
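+
+ # Usage sketch (the step function below is hypothetical; nothing named `step`
+ # is defined in this module): given a multichain `state` built with
+ # `init(..., num_chains=4)` and a function `step(key, state) -> State`,
+ #
+ #     multi_step = vmap_chains(step, auto_split_keys=True)
+ #     new_state = multi_step(random.key(0), state)
+ #
+ # vmaps `step` over the 4 chains, splitting the single key into one key per
+ # chain; with `num_chains=None` the same call would run `step` unvmapped.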