bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +6 -4
- bartz/_interface.py +937 -0
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +1217 -82
- bartz/grove.py +205 -103
- bartz/jaxext/__init__.py +287 -0
- bartz/jaxext/_autobatch.py +444 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +239 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +662 -314
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +140 -44
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/METADATA +18 -13
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/BART.py +0 -603
- bartz/jaxext.py +0 -423
- bartz/mcmcstep.py +0 -2335
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/mcmcstep/_state.py
ADDED
@@ -0,0 +1,1114 @@
# bartz/src/bartz/mcmcstep/_state.py
#
# Copyright (c) 2024-2026, The Bartz Contributors
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Module defining the BART MCMC state and initialization."""

from collections.abc import Callable, Hashable
from dataclasses import fields
from functools import partial, wraps
from math import log2
from typing import Any, Literal, TypedDict, TypeVar

import numpy
from equinox import Module, error_if
from equinox import field as eqx_field
from jax import (
    NamedSharding,
    device_put,
    eval_shape,
    jit,
    make_mesh,
    random,
    tree,
    vmap,
)
from jax import numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.sharding import AxisType, Mesh, PartitionSpec
from jax.tree import flatten
from jaxtyping import Array, Bool, Float32, Int32, Integer, PyTree, Shaped, UInt

from bartz.grove import make_tree, tree_depths
from bartz.jaxext import get_default_device, is_key, minimal_unsigned_dtype


def field(*, chains: bool = False, data: bool = False, **kwargs):
    """Extend `equinox.field` with two new parameters.

    Parameters
    ----------
    chains
        Whether the arrays in the field have an optional first axis that
        represents independent Markov chains.
    data
        Whether the last axis of the arrays in the field represents units of
        the data.
    **kwargs
        Other parameters passed to `equinox.field`.

    Returns
    -------
    A dataclass field descriptor with the special attributes in the metadata, unset if False.
    """
    metadata = dict(kwargs.pop('metadata', {}))
    assert 'chains' not in metadata
    assert 'data' not in metadata
    if chains:
        metadata['chains'] = True
    if data:
        metadata['data'] = True
    return eqx_field(metadata=metadata, **kwargs)


def chain_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Determine vmapping axes for chains.

    This function determines the argument to the `in_axes` or `out_axes`
    parameter of `jax.vmap` to vmap over all and only the chain axes found in
    the pytree `x`.

    Parameters
    ----------
    x
        A pytree. Subpytrees that are Module attributes marked with
        ``field(..., chains=True)`` are considered to have a leading chain axis.

    Returns
    -------
    A pytree with the same structure as `x` with 0 or None in the leaves.
    """
    return _find_metadata(x, 'chains', 0, None)


def data_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Determine vmapping axes for data.

    This is analogous to `chain_vmap_axes` but returns -1 for all fields
    marked with ``field(..., data=True)``.
    """
    return _find_metadata(x, 'data', -1, None)


T = TypeVar('T')


def _find_metadata(
    x: PyTree[Any, ' S'], key: Hashable, if_true: T, if_false: T
) -> PyTree[T, ' S']:
    """Replace all subtrees of x marked with a metadata key."""
    if isinstance(x, Module):
        args = []
        for f in fields(x):
            v = getattr(x, f.name)
            if f.metadata.get('static', False):
                args.append(v)
            elif f.metadata.get(key, False):
                subtree = tree.map(lambda _: if_true, v)
                args.append(subtree)
            else:
                args.append(_find_metadata(v, key, if_true, if_false))
        return x.__class__(*args)

    def is_leaf(x) -> bool:
        return isinstance(x, Module)

    def get_axes(x: Module | Any) -> PyTree[T]:
        if isinstance(x, Module):
            return _find_metadata(x, key, if_true, if_false)
        else:
            return tree.map(lambda _: if_false, x)

    return tree.map(get_axes, x, is_leaf=is_leaf)
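
Illustrative sketch (not part of the released file): how the `chains` metadata drives vmapping. The toy `Counters` module below is hypothetical, and the private import path is assumed from this diff.

    from equinox import Module
    from jax import numpy as jnp, vmap
    from jaxtyping import Array, Float32, Int32

    from bartz.mcmcstep._state import chain_vmap_axes, field

    class Counters(Module):  # hypothetical toy module, not part of bartz
        accepted: Int32[Array, '*chains'] = field(chains=True)
        threshold: Float32[Array, '']  # shared across chains

    c = Counters(accepted=jnp.zeros(4, jnp.int32), threshold=jnp.float32(0.5))
    axes = chain_vmap_axes(c)
    assert axes.accepted == 0      # vmap over the leading chain axis
    assert axes.threshold is None  # broadcast to every chain
    frac = vmap(lambda ci: ci.accepted / 10, in_axes=(axes,))(c)
    assert frac.shape == (4,)      # one value per chain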


class Forest(Module):
    """Represents the MCMC state of a sum of trees."""

    leaf_tree: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    """The leaf values."""

    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """The decision axes."""

    split_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """The decision boundaries."""

    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """Marks leaves that can be grown."""

    max_split: UInt[Array, ' p']
    """The maximum split index for each predictor."""

    blocked_vars: UInt[Array, ' q'] | None
    """Indices of variables that are not used. This shall include at least
    the `i` such that ``max_split[i] == 0``, otherwise behavior is
    undefined."""

    p_nonterminal: Float32[Array, ' 2**d']
    """The prior probability of each node being nonterminal, conditional on
    its ancestors. Includes the nodes at maximum depth, which should be set
    to 0."""

    p_propose_grow: Float32[Array, ' 2**(d-1)']
    """The unnormalized probability of picking a leaf for a grow proposal."""

    leaf_indices: UInt[Array, '*chains num_trees n'] = field(chains=True, data=True)
    """The index of the leaf each datapoint falls into, for each tree."""

    min_points_per_decision_node: Int32[Array, ''] | None
    """The minimum number of data points in a decision node."""

    min_points_per_leaf: Int32[Array, ''] | None
    """The minimum number of data points in a leaf node."""

    log_trans_prior: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """The log transition and prior Metropolis-Hastings ratio for the
    proposed move on each tree."""

    log_likelihood: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """The log likelihood ratio."""

    grow_prop_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of grow proposals made during one full MCMC cycle."""

    prune_prop_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of prune proposals made during one full MCMC cycle."""

    grow_acc_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of grow moves accepted during one full MCMC cycle."""

    prune_acc_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of prune moves accepted during one full MCMC cycle."""

    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None
    """The prior precision matrix of a leaf, conditional on the tree structure.
    For the univariate case (k=1), this is a scalar (the inverse variance).
    The prior covariance of the sum of trees is
    ``num_trees * leaf_prior_cov_inv^-1``."""

    log_s: Float32[Array, '*chains p'] | None = field(chains=True)
    """The logarithm of the prior probability for choosing a variable to split
    along in a decision rule, conditional on the ancestors. Not normalized.
    If `None`, use a uniform distribution."""

    theta: Float32[Array, '*chains'] | None = field(chains=True)
    """The concentration parameter for the Dirichlet prior on the variable
    distribution `s`. Required only to update `log_s`."""

    a: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`.
    See `step_theta`."""

    b: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`.
    See `step_theta`."""

    rho: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`.
    See `step_theta`."""

    def num_chains(self) -> int | None:
        """Return the number of chains, or `None` if not multichain."""
        # maybe this should be replaced by chain_shape() -> () | (int,)
        if self.var_tree.ndim == 2:
            return None
        else:
            return self.var_tree.shape[0]


class StepConfig(Module):
    """Options for the MCMC step."""

    steps_done: Int32[Array, '']
    """The number of MCMC steps completed so far."""

    sparse_on_at: Int32[Array, ''] | None
    """After how many steps to turn on variable selection."""

    resid_num_batches: int | None = field(static=True)
    """The number of batches for computing the sum of residuals. If
    `None`, they are computed with no batching."""

    count_num_batches: int | None = field(static=True)
    """The number of batches for computing counts. If
    `None`, they are computed with no batching."""

    prec_num_batches: int | None = field(static=True)
    """The number of batches for computing precision scales. If
    `None`, they are computed with no batching."""

    prec_count_num_trees: int | None = field(static=True)
    """Batch size for processing trees to compute count and prec trees."""

    mesh: Mesh | None = field(static=True)
    """The mesh used to shard data and computation across multiple devices."""


class State(Module):
    """Represents the MCMC state of BART."""

    X: UInt[Array, 'p n'] = field(data=True)
    """The predictors."""

    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'] = field(
        data=True
    )
    """The response. If the data type is `bool`, the model is binary regression."""

    z: None | Float32[Array, '*chains n'] = field(chains=True, data=True)
    """The latent variable for binary regression. `None` in continuous
    regression."""

    offset: Float32[Array, ''] | Float32[Array, ' k']
    """Constant shift added to the sum of trees."""

    resid: Float32[Array, '*chains n'] | Float32[Array, '*chains k n'] = field(
        chains=True, data=True
    )
    """The residuals (`y` or `z` minus sum of trees)."""

    error_cov_inv: Float32[Array, '*chains'] | Float32[Array, '*chains k k'] | None = (
        field(chains=True)
    )
    """The inverse error covariance (scalar for univariate, matrix for multivariate).
    `None` in binary regression."""

    prec_scale: Float32[Array, ' n'] | None = field(data=True)
    """The scale on the error precision, i.e., ``1 / error_scale ** 2``.
    `None` in binary regression."""

    error_cov_df: Float32[Array, ''] | None
    """The df parameter of the inverse Wishart prior on the noise
    covariance. For the univariate case, the relationship to the inverse
    gamma prior parameters is ``alpha = df / 2``.
    `None` in binary regression."""

    error_cov_scale: Float32[Array, ''] | Float32[Array, 'k k'] | None
    """The scale parameter of the inverse Wishart prior on the noise
    covariance. For the univariate case, the relationship to the inverse
    gamma prior parameters is ``beta = scale / 2``.
    `None` in binary regression."""

    forest: Forest
    """The sum of trees model."""

    config: StepConfig
    """Metadata and configurations for the MCMC step."""


def _init_shape_shifting_parameters(
    y: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
    error_scale: Float32[Any, ' n'] | None,
    error_cov_df: float | Float32[Any, ''] | None,
    error_cov_scale: float | Float32[Any, ''] | Float32[Any, 'k k'] | None,
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
) -> tuple[
    bool,
    tuple[()] | tuple[int],
    None | Float32[Array, ''],
    None | Float32[Array, ''],
    None | Float32[Array, ''],
]:
    """
    Check and initialize parameters that change array type/shape based on outcome kind.

    Parameters
    ----------
    y
        The response variable; the outcome type is deduced from `y` and then
        all other parameters are checked against it.
    offset
        The offset to add to the predictions.
    error_scale
        Per-observation error scale (univariate only).
    error_cov_df
        The error covariance degrees of freedom.
    error_cov_scale
        The error covariance scale.
    leaf_prior_cov_inv
        The inverse of the leaf prior covariance.

    Returns
    -------
    is_binary
        Whether the outcome is binary.
    kshape
        The outcome shape, empty for univariate, (k,) for multivariate.
    error_cov_inv
        The initialized error covariance inverse.
    error_cov_df
        The error covariance degrees of freedom (as array).
    error_cov_scale
        The error covariance scale (as array).

    Raises
    ------
    ValueError
        If `y` is binary and multivariate.
    """
    # determine outcome kind, binary/continuous x univariate/multivariate
    is_binary = y.dtype == bool
    kshape = y.shape[:-1]

    # Binary vs continuous
    if is_binary:
        if kshape:
            msg = 'Binary multivariate regression not supported, open an issue at https://github.com/bartz-org/bartz/issues if you need it.'
            raise ValueError(msg)
        assert error_scale is None
        assert error_cov_df is None
        assert error_cov_scale is None
        error_cov_inv = None
    else:
        error_cov_df = jnp.asarray(error_cov_df)
        error_cov_scale = jnp.asarray(error_cov_scale)
        assert error_cov_scale.shape == 2 * kshape

        # Multivariate vs univariate
        if kshape:
            error_cov_inv = error_cov_df * _inv_via_chol_with_gersh(error_cov_scale)
        else:
            # inverse gamma prior: alpha = df / 2, beta = scale / 2
            error_cov_inv = error_cov_df / error_cov_scale

    assert leaf_prior_cov_inv.shape == 2 * kshape
    assert offset.shape == kshape

    return is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale


def _parse_p_nonterminal(
    p_nonterminal: Float32[Any, ' d_minus_1'],
) -> Float32[Array, ' d_minus_1+1']:
    """Check it's in (0, 1) and pad with a 0 at the end."""
    p_nonterminal = jnp.asarray(p_nonterminal)
    ok = (p_nonterminal > 0) & (p_nonterminal < 1)
    p_nonterminal = error_if(p_nonterminal, ~ok, 'p_nonterminal must be in (0, 1)')
    return jnp.pad(p_nonterminal, (0, 1))


def make_p_nonterminal(
    d: int, alpha: float | Float32[Array, ''], beta: float | Float32[Array, '']
) -> Float32[Array, ' {d}-1']:
    """Prepare the `p_nonterminal` argument to `init`.

    It is calculated according to the formula:

        P_nt(depth) = alpha / (1 + depth)^beta, with depth 0-based

    Parameters
    ----------
    d
        The maximum depth of the trees (d=1 means tree with only root node).
    alpha
        The a priori probability of the root node having children, conditional
        on it being possible.
    beta
        The exponent of the power decay of the probability of having children
        with depth.

    Returns
    -------
    An array of probabilities, one per tree level except the last.
    """
    assert d >= 1
    depth = jnp.arange(d - 1)
    return alpha / (1 + depth).astype(float) ** beta
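
Illustrative hand check of the formula (not part of the released file; the private import path is assumed from this diff):

    from bartz.mcmcstep._state import make_p_nonterminal

    pnt = make_p_nonterminal(3, alpha=0.95, beta=2.0)
    # depth 0: 0.95 / (1 + 0)**2 = 0.95
    # depth 1: 0.95 / (1 + 1)**2 = 0.2375
    # pnt ≈ [0.95, 0.2375]; `_parse_p_nonterminal` later appends the 0 for the
    # deepest level, where nodes must stay terminal.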


def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, ' n'] | Float32[Any, ' k n'] | Bool[Any, ' n'],
    offset: float | Float32[Any, ''] | Float32[Any, ' k'],
    max_split: UInt[Any, ' p'],
    num_trees: int,
    p_nonterminal: Float32[Any, ' d_minus_1'],
    leaf_prior_cov_inv: float | Float32[Any, ''] | Float32[Array, 'k k'],
    error_cov_df: float | Float32[Any, ''] | None = None,
    error_cov_scale: float | Float32[Any, ''] | Float32[Array, 'k k'] | None = None,
    error_scale: Float32[Any, ' n'] | None = None,
    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
    resid_num_batches: int | None | Literal['auto'] = 'auto',
    count_num_batches: int | None | Literal['auto'] = 'auto',
    prec_num_batches: int | None | Literal['auto'] = 'auto',
    prec_count_num_trees: int | None | Literal['auto'] = 'auto',
    save_ratios: bool = False,
    filter_splitless_vars: int = 0,
    min_points_per_leaf: int | Integer[Any, ''] | None = None,
    log_s: Float32[Any, ' p'] | None = None,
    theta: float | Float32[Any, ''] | None = None,
    a: float | Float32[Any, ''] | None = None,
    b: float | Float32[Any, ''] | None = None,
    rho: float | Float32[Any, ''] | None = None,
    sparse_on_at: int | Integer[Any, ''] | None = None,
    num_chains: int | None = None,
    mesh: Mesh | dict[str, int] | None = None,
    target_platform: Literal['cpu', 'gpu'] | None = None,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual convention.
    y
        The response. If the data type is `bool`, the regression model is binary
        regression with probit. If two-dimensional, the outcome is multivariate
        with the first axis indicating the component.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array. Use `make_p_nonterminal`
        to set it with the conventional formula.
    leaf_prior_cov_inv
        The prior precision matrix of a leaf, conditional on the tree structure.
        For the univariate case (k=1), this is a scalar (the inverse variance).
        The prior covariance of the sum of trees is
        ``num_trees * leaf_prior_cov_inv^-1``. The prior mean of leaves is
        always zero.
    error_cov_df
    error_cov_scale
        The df and scale parameters of the inverse Wishart prior on the error
        covariance. For the univariate case, the relationship to the inverse
        gamma prior parameters is ``alpha = df / 2``, ``beta = scale / 2``.
        Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1 for
        all points, but potentially skipping calculations.
    min_points_per_decision_node
        The minimum number of data points in a decision node. 0 if not
        specified.
    resid_num_batches
    count_num_batches
    prec_num_batches
        The number of batches, along datapoints, for summing the residuals,
        counting the number of datapoints in each leaf, and computing the
        likelihood precision in each leaf, respectively. `None` for no batching.
        If 'auto', it's chosen automatically based on the target platform; see
        the description of `target_platform` below for how it is determined.
    prec_count_num_trees
        The number of trees to process at a time when counting datapoints or
        computing the likelihood precision. If `None`, do all trees at once,
        which may use too much memory. If 'auto' (default), it's chosen
        automatically.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.
    filter_splitless_vars
        The maximum number of variables without splits that can be ignored. If
        there are more, `init` raises an exception.
    min_points_per_leaf
        The minimum number of datapoints in a leaf node. 0 if not specified.
        Unlike `min_points_per_decision_node`, this constraint is not taken into
        account in the Metropolis-Hastings ratio because it would be expensive
        to compute. Grow moves that would violate this constraint are vetoed.
        This parameter is independent of `min_points_per_decision_node` and
        there is no check that they are coherent. It makes sense to set
        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If not specified, a uniform distribution is used; if `theta` or `rho`,
        `a`, `b` are specified, it's initialized automatically.
    theta
        The concentration parameter for the Dirichlet prior on `s`. Required
        only to update `log_s`. If not specified, and `rho`, `a`, `b` are
        specified, it's initialized automatically.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.
    sparse_on_at
        After how many MCMC steps to turn on variable selection.
    num_chains
        The number of independent MCMC chains to represent in the state. Single
        chain with scalar values if not specified.
    mesh
        A jax mesh used to shard data and computation across multiple devices.
        If it has a 'chains' axis, that axis is used to shard the chains. If it
        has a 'data' axis, that axis is used to shard the datapoints.

        As a shorthand, if a dictionary mapping axis names to axis size is
        passed, the corresponding mesh is created, e.g., ``dict(chains=4,
        data=2)`` will let jax pick 8 devices to split chains (which must be a
        multiple of 4) across 4 pairs of devices, where in each pair the data is
        split in two.

        Note: if a mesh is passed, the arrays are always sharded according to
        it. In particular, even if the mesh has no 'chains' or 'data' axis, the
        arrays will be replicated on all devices in the mesh.
    target_platform
        Platform ('cpu' or 'gpu') used to determine the number of batches
        automatically. If `mesh` is specified, the platform is inferred from the
        devices in the mesh. Otherwise, if `y` is a concrete array (i.e., `init`
        is not invoked in a `jax.jit` context), the platform is set to the
        platform of `y`. Otherwise, use `target_platform`.

        To avoid confusion, in all cases where the `target_platform` argument
        would be ignored, `init` raises an exception if `target_platform` is
        set.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set.

    Notes
    -----
    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
    of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
    child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
    integers in the range ``[0, 1, ..., max_split[i]]``.
    """
    # convert to array all array-like arguments that are used in other
    # configurations but don't need further processing themselves
    X = jnp.asarray(X)
    y = jnp.asarray(y)
    offset = jnp.asarray(offset)
    leaf_prior_cov_inv = jnp.asarray(leaf_prior_cov_inv)
    max_split = jnp.asarray(max_split)

    # check p_nonterminal and pad it with a 0 at the end (still not final shape)
    p_nonterminal = _parse_p_nonterminal(p_nonterminal)

    # process arguments that change depending on outcome type
    is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale = (
        _init_shape_shifting_parameters(
            y, offset, error_scale, error_cov_df, error_cov_scale, leaf_prior_cov_inv
        )
    )

    # extract array sizes from arguments
    (max_depth,) = p_nonterminal.shape
    p, n = X.shape

    # check and initialize sparsity parameters
    if not _all_none_or_not_none(rho, a, b):
        msg = 'rho, a, b are not either all `None` or all set'
        raise ValueError(msg)
    if theta is None and rho is not None:
        theta = rho
    if log_s is None and theta is not None:
        log_s = jnp.zeros(max_split.size)
    if not _all_none_or_not_none(theta, sparse_on_at):
        msg = 'sparsity params (either theta or rho,a,b) and sparse_on_at must be either all None or all set'
        raise ValueError(msg)

    # process multichain settings
    chain_shape = () if num_chains is None else (num_chains,)
    resid_shape = chain_shape + y.shape
    tree_shape = (*chain_shape, num_trees)
    add_chains = partial(_add_chains, chain_shape=chain_shape)

    # determine batch sizes for reductions
    mesh = _parse_mesh(num_chains, mesh)
    target_platform = _parse_target_platform(
        y, mesh, target_platform, resid_num_batches, count_num_batches, prec_num_batches
    )
    red_cfg = _parse_reduction_configs(
        resid_num_batches,
        count_num_batches,
        prec_num_batches,
        prec_count_num_trees,
        y,
        num_trees,
        mesh,
        target_platform,
    )

    # check there aren't too many deactivated predictors
    msg = (
        f'there are more than {filter_splitless_vars=} predictors with no splits, '
        'please increase `filter_splitless_vars` or investigate the missing splits'
    )
    offset = error_if(offset, jnp.sum(max_split == 0) > filter_splitless_vars, msg)

    # initialize all remaining stuff and put it in an unsharded state
    state = State(
        X=X,
        y=y,
        z=jnp.full(resid_shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(resid_shape)
        if is_binary
        else jnp.broadcast_to(y - offset[..., None], resid_shape),
        error_cov_inv=add_chains(error_cov_inv),
        prec_scale=_get_prec_scale(error_scale),
        error_cov_df=error_cov_df,
        error_cov_scale=error_cov_scale,
        forest=Forest(
            leaf_tree=make_tree(max_depth, jnp.float32, tree_shape + kshape),
            var_tree=make_tree(
                max_depth - 1, minimal_unsigned_dtype(p - 1), tree_shape
            ),
            split_tree=make_tree(max_depth - 1, max_split.dtype, tree_shape),
            affluence_tree=(
                make_tree(max_depth - 1, bool, tree_shape)
                .at[..., 1]
                .set(
                    True
                    if min_points_per_decision_node is None
                    else n >= min_points_per_decision_node
                )
            ),
            blocked_vars=_get_blocked_vars(filter_splitless_vars, max_split),
            max_split=max_split,
            grow_prop_count=jnp.zeros(chain_shape, int),
            grow_acc_count=jnp.zeros(chain_shape, int),
            prune_prop_count=jnp.zeros(chain_shape, int),
            prune_acc_count=jnp.zeros(chain_shape, int),
            p_nonterminal=p_nonterminal[tree_depths(2**max_depth)],
            p_propose_grow=p_nonterminal[tree_depths(2 ** (max_depth - 1))],
            leaf_indices=jnp.ones(
                (*tree_shape, n), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
            min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
            log_trans_prior=jnp.zeros((*chain_shape, num_trees))
            if save_ratios
            else None,
            log_likelihood=jnp.zeros((*chain_shape, num_trees))
            if save_ratios
            else None,
            leaf_prior_cov_inv=leaf_prior_cov_inv,
            log_s=add_chains(_asarray_or_none(log_s)),
            theta=add_chains(_asarray_or_none(theta)),
            rho=_asarray_or_none(rho),
            a=_asarray_or_none(a),
            b=_asarray_or_none(b),
        ),
        config=StepConfig(
            steps_done=jnp.int32(0),
            sparse_on_at=_asarray_or_none(sparse_on_at),
            mesh=mesh,
            **red_cfg,
        ),
    )

    # move all arrays to the appropriate device
    return _shard_state(state)
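
Illustrative sketch (not part of the released file): a minimal continuous-regression call of `init` with arbitrary toy data and hyperparameters, assuming the private import path shown in this diff.

    import numpy as np

    from bartz.mcmcstep._state import init, make_p_nonterminal

    p, n, num_trees = 5, 100, 50
    rng = np.random.default_rng(0)
    X = rng.integers(0, 10, size=(p, n)).astype(np.uint8)  # values in [0, max_split]
    y = rng.standard_normal(n).astype(np.float32)

    state = init(
        X=X,
        y=y,
        offset=float(y.mean()),
        max_split=np.full(p, 9, np.uint8),  # cutpoints 1..9 on each predictor
        num_trees=num_trees,
        p_nonterminal=make_p_nonterminal(6, 0.95, 2.0),  # depth-6 trees
        leaf_prior_cov_inv=float(num_trees) / 0.25,  # sum-of-trees prior variance 0.25
        error_cov_df=3.0,  # inverse gamma alpha = 1.5
        error_cov_scale=3.0,  # inverse gamma beta = 1.5
    )
    assert state.resid.shape == (n,)  # single chain: no leading chain axis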
727
|
+
|
|
728
|
+
|
|
729
|
+
@partial(jit, donate_argnums=(0,))
|
|
730
|
+
def _get_prec_scale(
|
|
731
|
+
error_scale: Float32[Array, ' n'] | None,
|
|
732
|
+
) -> Float32[Array, ' n'] | None:
|
|
733
|
+
"""Compute 1 / error_scale**2.
|
|
734
|
+
|
|
735
|
+
This is a separate function to use donate_argnums to avoid intermediate
|
|
736
|
+
copies.
|
|
737
|
+
"""
|
|
738
|
+
if error_scale is None:
|
|
739
|
+
return None
|
|
740
|
+
else:
|
|
741
|
+
return jnp.reciprocal(jnp.square(jnp.asarray(error_scale)))
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
def _get_blocked_vars(
|
|
745
|
+
filter_splitless_vars: int, max_split: UInt[Array, ' p']
|
|
746
|
+
) -> None | UInt[Array, ' q']:
|
|
747
|
+
"""Initialize the `blocked_vars` field."""
|
|
748
|
+
if filter_splitless_vars:
|
|
749
|
+
(p,) = max_split.shape
|
|
750
|
+
(blocked_vars,) = jnp.nonzero(
|
|
751
|
+
max_split == 0, size=filter_splitless_vars, fill_value=p
|
|
752
|
+
)
|
|
753
|
+
return blocked_vars.astype(minimal_unsigned_dtype(p))
|
|
754
|
+
# see `fully_used_variables` for the type cast
|
|
755
|
+
else:
|
|
756
|
+
return None
|
|
757
|
+
|
|
758
|
+
|
|
759
|
+
def _add_chains(
|
|
760
|
+
x: Shaped[Array, '*shape'] | None, chain_shape: tuple[int, ...]
|
|
761
|
+
) -> Shaped[Array, '*shape'] | Shaped[Array, ' num_chains *shape'] | None:
|
|
762
|
+
"""Broadcast `x` to all chains."""
|
|
763
|
+
if x is None:
|
|
764
|
+
return None
|
|
765
|
+
else:
|
|
766
|
+
return jnp.broadcast_to(x, chain_shape + x.shape)
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
def _parse_mesh(
|
|
770
|
+
num_chains: int | None, mesh: Mesh | dict[str, int] | None
|
|
771
|
+
) -> Mesh | None:
|
|
772
|
+
"""Parse the `mesh` argument."""
|
|
773
|
+
if mesh is None:
|
|
774
|
+
return None
|
|
775
|
+
|
|
776
|
+
# convert dict format to actual mesh
|
|
777
|
+
if isinstance(mesh, dict):
|
|
778
|
+
assert set(mesh).issubset({'chains', 'data'})
|
|
779
|
+
mesh = make_mesh(
|
|
780
|
+
tuple(mesh.values()), tuple(mesh), axis_types=(AxisType.Auto,) * len(mesh)
|
|
781
|
+
)
|
|
782
|
+
|
|
783
|
+
# check there's no chain mesh axis if there are no chains
|
|
784
|
+
if num_chains is None:
|
|
785
|
+
assert 'chains' not in mesh.axis_names
|
|
786
|
+
|
|
787
|
+
# check the axes we use are in auto mode
|
|
788
|
+
assert 'chains' not in mesh.axis_names or 'chains' in _auto_axes(mesh)
|
|
789
|
+
assert 'data' not in mesh.axis_names or 'data' in _auto_axes(mesh)
|
|
790
|
+
|
|
791
|
+
return mesh
|
|
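
For reference, a sketch of what the dict shorthand expands to (illustrative, not part of the released file; it needs 8 visible devices, which on CPU can be simulated with the flag `XLA_FLAGS=--xla_force_host_platform_device_count=8`):

    from jax import make_mesh
    from jax.sharding import AxisType

    # what _parse_mesh builds for mesh=dict(chains=4, data=2):
    mesh = make_mesh((4, 2), ('chains', 'data'),
                     axis_types=(AxisType.Auto, AxisType.Auto))
    # 4 device groups for the chains; within each group, data is split in two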


def _parse_target_platform(
    y: Array,
    mesh: Mesh | None,
    target_platform: Literal['cpu', 'gpu'] | None,
    resid_num_batches: int | None | Literal['auto'],
    count_num_batches: int | None | Literal['auto'],
    prec_num_batches: int | None | Literal['auto'],
) -> Literal['cpu', 'gpu'] | None:
    if mesh is not None:
        assert target_platform is None, 'mesh provided, do not set target_platform'
        return mesh.devices.flat[0].platform
    elif hasattr(y, 'platform'):
        assert target_platform is None, 'device inferred from y, unset target_platform'
        return y.platform()
    elif (
        resid_num_batches == 'auto'
        or count_num_batches == 'auto'
        or prec_num_batches == 'auto'
    ):
        assert target_platform in ('cpu', 'gpu')
        return target_platform
    else:
        assert target_platform is None, 'target_platform not used, unset it'
        return target_platform


def _auto_axes(mesh: Mesh) -> list[str]:
    """Re-implement `Mesh.auto_axes` because that's missing in jax v0.5."""
    # Mesh.auto_axes added in jax v0.6.0
    return [
        n
        for n, t in zip(mesh.axis_names, mesh.axis_types, strict=True)
        if t == AxisType.Auto
    ]


def _shard_state(state: State) -> State:
    """Place all fields in the state on the appropriate devices."""
    mesh = state.config.mesh
    if mesh is None:
        return state

    def shard_leaf(
        x: Array | None, chain_axis: int | None, data_axis: int | None
    ) -> Array | None:
        if x is None:
            return None

        spec = [None] * x.ndim
        if chain_axis is not None and 'chains' in mesh.axis_names:
            spec[chain_axis] = 'chains'
        if data_axis is not None and 'data' in mesh.axis_names:
            spec[data_axis] = 'data'

        # remove trailing Nones to be consistent with jax's output, it's useful
        # for comparing shardings during debugging
        while spec and spec[-1] is None:
            spec.pop()

        spec = PartitionSpec(*spec)
        return device_put(x, NamedSharding(mesh, spec), donate=True)

    return tree.map(
        shard_leaf,
        state,
        chain_vmap_axes(state),
        data_vmap_axes(state),
        is_leaf=lambda x: x is None,
    )


def _all_none_or_not_none(*args):
    is_none = [x is None for x in args]
    return all(is_none) or not any(is_none)


def _asarray_or_none(x):
    if x is None:
        return None
    return jnp.asarray(x)


def _get_platform(mesh: Mesh | None) -> str:
    if mesh is None:
        return get_default_device().platform
    else:
        return mesh.devices.flat[0].platform


class _ReductionConfig(TypedDict):
    """Fields of `StepConfig` related to reductions."""

    resid_num_batches: int | None
    count_num_batches: int | None
    prec_num_batches: int | None
    prec_count_num_trees: int | None


def _parse_reduction_configs(
    resid_num_batches: int | None | Literal['auto'],
    count_num_batches: int | None | Literal['auto'],
    prec_num_batches: int | None | Literal['auto'],
    prec_count_num_trees: int | None | Literal['auto'],
    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'],
    num_trees: int,
    mesh: Mesh | None,
    target_platform: Literal['cpu', 'gpu'] | None,
) -> _ReductionConfig:
    """Determine settings for indexed reduces."""
    n = y.shape[-1]
    n //= get_axis_size(mesh, 'data')  # per-device datapoints
    parse_num_batches = partial(_parse_num_batches, target_platform, n)
    return dict(
        resid_num_batches=parse_num_batches(resid_num_batches, 'resid'),
        count_num_batches=parse_num_batches(count_num_batches, 'count'),
        prec_num_batches=parse_num_batches(prec_num_batches, 'prec'),
        prec_count_num_trees=_parse_prec_count_num_trees(
            prec_count_num_trees, num_trees, n
        ),
    )


def _parse_num_batches(
    target_platform: Literal['cpu', 'gpu'] | None,
    n: int,
    num_batches: int | None | Literal['auto'],
    which: Literal['resid', 'count', 'prec'],
) -> int | None:
    """Return the number of batches or determine it automatically."""
    final_round = partial(_final_round, n)
    if num_batches != 'auto':
        nb = num_batches
    elif target_platform == 'cpu':
        nb = final_round(16)
    elif target_platform == 'gpu':
        nb = dict(resid=1024, count=2048, prec=1024)[which]  # on an A4000
        nb = final_round(nb)
    return nb


def _final_round(n: int, num: float) -> int | None:
    """Bound batch size, round number of batches to a power of 2, and disable batching if there's only 1 batch."""
    # at least some elements per batch
    num = min(n // 32, num)

    # round to the nearest power of 2 because I guess XLA and the hardware
    # will like that (not sure about this, maybe just multiple of 32?)
    num = 2 ** round(log2(num)) if num else 0

    # disable batching if the batch is as large as the whole dataset
    return num if num > 1 else None
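
Some hand-worked values for the rounding helper (illustrative, not part of the released file; `n` is the per-device number of datapoints):

    from bartz.mcmcstep._state import _final_round

    assert _final_round(100_000, 16) == 16  # below n // 32 and already a power of 2
    assert _final_round(100, 16) == 4  # capped at 100 // 32 = 3, then 2**round(log2(3)) = 4
    assert _final_round(40, 16) is None  # cap leaves a single batch, so batching is off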


def _parse_prec_count_num_trees(
    prec_count_num_trees: int | None | Literal['auto'], num_trees: int, n: int
) -> int | None:
    """Return the number of trees to process at a time or determine it automatically."""
    if prec_count_num_trees != 'auto':
        return prec_count_num_trees
    max_n_by_ntree = 2**27  # about 100M
    pcnt = max_n_by_ntree // max(1, n)
    pcnt = min(num_trees, pcnt)
    pcnt = max(1, pcnt)
    pcnt = _search_divisor(
        pcnt, num_trees, max(1, pcnt // 2), max(1, min(num_trees, pcnt * 2))
    )
    if pcnt >= num_trees:
        pcnt = None
    return pcnt


def _search_divisor(target_divisor: int, dividend: int, low: int, up: int) -> int:
    """Find the divisor of `dividend` closest to `target_divisor` in [low, up], if `target_divisor` is not one already.

    If there is none, give up and return `target_divisor`.
    """
    assert target_divisor >= 1
    assert 1 <= low <= up <= dividend
    if dividend % target_divisor == 0:
        return target_divisor
    candidates = numpy.arange(low, up + 1)
    divisors = candidates[dividend % candidates == 0]
    if divisors.size == 0:
        return target_divisor
    penalty = numpy.abs(divisors - target_divisor)
    closest = numpy.argmin(penalty)
    return divisors[closest].item()
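
A worked example of the divisor search (illustrative, not part of the released file):

    from bartz.mcmcstep._state import _search_divisor

    # 48 does not divide 200; the divisors of 200 inside [24, 96] are 25, 40, 50,
    # and 50 is the nearest to the target:
    assert _search_divisor(48, 200, 24, 96) == 50
    # a target that already divides the dividend is returned unchanged:
    assert _search_divisor(40, 200, 24, 96) == 40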


def get_axis_size(mesh: Mesh | None, axis_name: str) -> int:
    if mesh is None or axis_name not in mesh.axis_names:
        return 1
    else:
        i = mesh.axis_names.index(axis_name)
        return mesh.axis_sizes[i]


def chol_with_gersh(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool = False
) -> Float32[Array, '*batch_shape k k']:
    """Cholesky with Gershgorin stabilization, supports batching."""
    return _chol_with_gersh_impl(mat, absolute_eps)


@partial(jnp.vectorize, signature='(k,k)->(k,k)', excluded=(1,))
def _chol_with_gersh_impl(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool
) -> Float32[Array, '*batch_shape k k']:
    rho = jnp.max(jnp.sum(jnp.abs(mat), axis=1), initial=0.0)
    eps = jnp.finfo(mat.dtype).eps
    u = mat.shape[0] * rho * eps
    if absolute_eps:
        u += eps
    mat = mat.at[jnp.diag_indices_from(mat)].add(u)
    return jnp.linalg.cholesky(mat)


def _inv_via_chol_with_gersh(mat: Float32[Array, 'k k']) -> Float32[Array, 'k k']:
    """Compute matrix inverse via Cholesky with Gershgorin stabilization.

    DO NOT USE THIS FUNCTION UNLESS YOU REALLY NEED TO.
    """
    L = chol_with_gersh(mat)
    I = jnp.eye(mat.shape[0], dtype=mat.dtype)
    L_inv = solve_triangular(L, I, lower=True)
    return L_inv.T @ L_inv
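
A quick numerical check of the stabilized factorization (illustrative, not part of the released file; for a well-conditioned matrix the Gershgorin shift is tiny, so the factor reproduces the input to float32 accuracy):

    from jax import numpy as jnp

    from bartz.mcmcstep._state import _inv_via_chol_with_gersh, chol_with_gersh

    A = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # small SPD matrix
    L = chol_with_gersh(A)
    assert jnp.allclose(L @ L.T, A, atol=1e-5)
    assert jnp.allclose(A @ _inv_via_chol_with_gersh(A), jnp.eye(2), atol=1e-5)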


def get_num_chains(x: PyTree) -> int | None:
    """Get the number of chains of a pytree.

    Find all nodes in the structure that define `num_chains()`, stopping
    traversal at nodes that define it. Check all values obtained invoking
    `num_chains` are equal, then return it.
    """
    leaves, _ = flatten(x, is_leaf=lambda x: hasattr(x, 'num_chains'))
    num_chains = [x.num_chains() for x in leaves if hasattr(x, 'num_chains')]
    ref = num_chains[0]
    assert all(c == ref for c in num_chains)
    return ref


def _chain_axes_with_keys(x: PyTree) -> PyTree[int | None]:
    """Return `chain_vmap_axes(x)` but also set to 0 for random keys."""
    axes = chain_vmap_axes(x)

    def axis_if_key(x, axis):
        if is_key(x):
            return 0
        else:
            return axis

    return tree.map(axis_if_key, x, axes)


def _get_mc_out_axes(
    fun: Callable[[tuple, dict], PyTree], args: PyTree, in_axes: PyTree[int | None]
) -> PyTree[int | None]:
    """Decide chain vmap axes for outputs."""
    vmapped_fun = vmap(fun, in_axes=in_axes)
    out = eval_shape(vmapped_fun, *args)
    return chain_vmap_axes(out)


def _find_mesh(x: PyTree) -> Mesh | None:
    """Find the mesh used for chains."""

    class MeshFound(Exception):
        pass

    def find_mesh(x: State | Any):
        if isinstance(x, State):
            raise MeshFound(x.config.mesh)

    try:
        tree.map(find_mesh, x, is_leaf=lambda x: isinstance(x, State))
    except MeshFound as e:
        return e.args[0]
    else:
        raise ValueError


def _split_all_keys(x: PyTree, num_chains: int) -> PyTree:
    """Split all random keys into `num_chains` keys."""
    mesh = _find_mesh(x)

    def split_key(x):
        if is_key(x):
            x = random.split(x, num_chains)
            if mesh is not None and 'chains' in mesh.axis_names:
                x = device_put(x, NamedSharding(mesh, PartitionSpec('chains')))
        return x

    return tree.map(split_key, x)


def vmap_chains(
    fun: Callable[..., T], *, auto_split_keys: bool = False
) -> Callable[..., T]:
    """Apply vmap on chain axes automatically if the inputs are multichain."""

    @wraps(fun)
    def auto_vmapped_fun(*args, **kwargs) -> T:
        all_args = args, kwargs
        num_chains = get_num_chains(all_args)
        if num_chains is not None:
            if auto_split_keys:
                all_args = _split_all_keys(all_args, num_chains)

            def wrapped_fun(args, kwargs):
                return fun(*args, **kwargs)

            mc_in_axes = _chain_axes_with_keys(all_args)
            mc_out_axes = _get_mc_out_axes(wrapped_fun, all_args, mc_in_axes)
            vmapped_fun = vmap(wrapped_fun, in_axes=mc_in_axes, out_axes=mc_out_axes)
            return vmapped_fun(*all_args)

        else:
            return fun(*args, **kwargs)

    return auto_vmapped_fun
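
A sketch of the intended use of `vmap_chains` (illustrative, not part of the released file; `half_resid` is a hypothetical single-chain transformation, and `state` stands for any `State` built by `init`):

    import equinox

    from bartz.mcmcstep._state import vmap_chains

    @vmap_chains
    def half_resid(state):  # hypothetical: written as if for a single chain
        return equinox.tree_at(lambda s: s.resid, state, state.resid / 2)

    # With a state built by init(..., num_chains=4), get_num_chains reports 4,
    # the chain axes declared via field(chains=True) are vmapped automatically,
    # and the output chain axes are inferred by _get_mc_out_axes; with a
    # single-chain state the function is called unwrapped.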