bartz 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/jaxext.py CHANGED
@@ -26,11 +26,11 @@ import functools
26
26
  import math
27
27
  import warnings
28
28
 
29
- from scipy import special
30
29
  import jax
30
+ from jax import lax, tree_util
31
31
  from jax import numpy as jnp
32
- from jax import tree_util
33
- from jax import lax
32
+ from scipy import special
33
+
34
34
 
35
35
  def float_type(*args):
36
36
  """
@@ -39,16 +39,17 @@ def float_type(*args):
39
39
  t = jnp.result_type(*args)
40
40
  return jnp.sin(jnp.empty(0, t)).dtype
41
41
 
42
+
42
43
  def castto(func, type):
43
44
  @functools.wraps(func)
44
45
  def newfunc(*args, **kw):
45
46
  return func(*args, **kw).astype(type)
47
+
46
48
  return newfunc
47
49
 
48
- class scipy:
49
50
 
51
+ class scipy:
50
52
  class special:
51
-
52
53
  @functools.wraps(special.gammainccinv)
53
54
  def gammainccinv(a, y):
54
55
  a = jnp.asarray(a)
@@ -60,22 +61,24 @@ class scipy:
60
61
  return jax.pure_callback(ufunc, dummy, a, y, vmap_method='expand_dims')
61
62
 
62
63
  class stats:
63
-
64
64
  class invgamma:
65
-
66
65
  def ppf(q, a):
67
66
  return 1 / scipy.special.gammainccinv(a, q)
68
67
 
69
- @functools.wraps(jax.vmap)
68
+
70
69
  def vmap_nodoc(fun, *args, **kw):
71
70
  """
72
- Version of `jax.vmap` that preserves the docstring of the input function.
71
+ Wrapper of `jax.vmap` that preserves the docstring of the input function.
72
+
73
+ This is useful if the docstring already takes into account that the
74
+ arguments have additional axes due to vmap.
73
75
  """
74
76
  doc = fun.__doc__
75
77
  fun = jax.vmap(fun, *args, **kw)
76
78
  fun.__doc__ = doc
77
79
  return fun
78
80
 
81
+
79
82
  def huge_value(x):
80
83
  """
81
84
  Return the maximum value that can be stored in `x`.
@@ -95,19 +98,21 @@ def huge_value(x):
95
98
  else:
96
99
  return jnp.inf
97
100
 
101
+
98
102
  def minimal_unsigned_dtype(max_value):
99
103
  """
100
104
  Return the smallest unsigned integer dtype that can represent a given
101
105
  maximum value (inclusive).
102
106
  """
103
- if max_value < 2 ** 8:
107
+ if max_value < 2**8:
104
108
  return jnp.uint8
105
- if max_value < 2 ** 16:
109
+ if max_value < 2**16:
106
110
  return jnp.uint16
107
- if max_value < 2 ** 32:
111
+ if max_value < 2**32:
108
112
  return jnp.uint32
109
113
  return jnp.uint64
110
114
 
115
+
111
116
  def signed_to_unsigned(int_dtype):
112
117
  """
113
118
  Map a signed integer type to its unsigned counterpart. Unsigned types are
@@ -125,12 +130,14 @@ def signed_to_unsigned(int_dtype):
125
130
  if int_dtype == jnp.int64:
126
131
  return jnp.uint64
127
132
 
133
+
128
134
  def ensure_unsigned(x):
129
135
  """
130
136
  If x has signed integer type, cast it to the unsigned dtype of the same size.
131
137
  """
132
138
  return x.astype(signed_to_unsigned(x.dtype))
133
139
 
140
+
134
141
  @functools.partial(jax.jit, static_argnums=(1,))
135
142
  def unique(x, size, fill_value):
136
143
  """
@@ -158,15 +165,18 @@ def unique(x, size, fill_value):
158
165
  if size == 0:
159
166
  return jnp.empty(0, x.dtype), 0
160
167
  x = jnp.sort(x)
168
+
161
169
  def loop(carry, x):
162
170
  i_out, i_in, last, out = carry
163
171
  i_out = jnp.where(x == last, i_out, i_out + 1)
164
172
  out = out.at[i_out].set(x)
165
173
  return (i_out, i_in + 1, x, out), None
174
+
166
175
  carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
167
176
  (actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
168
177
  return out, actual_length + 1
169
178
 
179
+
170
180
  def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
171
181
  """
172
182
  Batch a function such that each batch is smaller than a threshold.
@@ -203,6 +213,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
203
213
  def check_no_nones(axes, tree):
204
214
  def check_not_none(_, axis):
205
215
  assert axis is not None
216
+
206
217
  tree_util.tree_map(check_not_none, tree, axes)
207
218
 
208
219
  def extract_size(axes, tree):
@@ -211,6 +222,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
211
222
  return None
212
223
  else:
213
224
  return x.shape[axis]
225
+
214
226
  sizes = tree_util.tree_map(get_size, tree, axes)
215
227
  sizes, _ = tree_util.tree_flatten(sizes)
216
228
  assert all(s == sizes[0] for s in sizes)
@@ -219,6 +231,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
219
231
  def sum_nbytes(tree):
220
232
  def nbytes(x):
221
233
  return math.prod(x.shape) * x.dtype.itemsize
234
+
222
235
  return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)
223
236
 
224
237
  def next_divisor_small(dividend, min_divisor):
@@ -247,6 +260,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
247
260
  return None
248
261
  else:
249
262
  return x
263
+
250
264
  return tree_util.tree_map(pull_nonbatched, tree, axes), tree
251
265
 
252
266
  def push_nonbatched(axes, tree, original_tree):
@@ -255,32 +269,38 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
255
269
  return original_x
256
270
  else:
257
271
  return x
272
+
258
273
  return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)
259
274
 
260
275
  def move_axes_out(axes, tree):
261
276
  def move_axis_out(x, axis):
262
277
  return jnp.moveaxis(x, axis, 0)
278
+
263
279
  return tree_util.tree_map(move_axis_out, tree, axes)
264
280
 
265
281
  def move_axes_in(axes, tree):
266
282
  def move_axis_in(x, axis):
267
283
  return jnp.moveaxis(x, 0, axis)
284
+
268
285
  return tree_util.tree_map(move_axis_in, tree, axes)
269
286
 
270
287
  def batch(tree, nbatches):
271
288
  def batch(x):
272
289
  return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
290
+
273
291
  return tree_util.tree_map(batch, tree)
274
292
 
275
293
  def unbatch(tree):
276
294
  def unbatch(x):
277
295
  return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
296
+
278
297
  return tree_util.tree_map(unbatch, tree)
279
298
 
280
299
  def check_same(tree1, tree2):
281
300
  def check_same(x1, x2):
282
301
  assert x1.shape == x2.shape
283
302
  assert x1.dtype == x2.dtype
303
+
284
304
  tree_util.tree_map(check_same, tree1, tree2)
285
305
 
286
306
  initial_in_axes = in_axes
@@ -300,7 +320,9 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
300
320
  args, nonbatched_args = pull_nonbatched(in_axes, args)
301
321
 
302
322
  total_nbytes = sum_nbytes((args, example_result))
303
- min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
323
+ min_nbatches = total_nbytes // max_io_nbytes + bool(
324
+ total_nbytes % max_io_nbytes
325
+ )
304
326
  min_nbatches = max(1, min_nbatches)
305
327
  nbatches = next_divisor(size, min_nbatches)
306
328
  assert 1 <= nbatches <= max(1, size)
@@ -310,7 +332,9 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
310
332
  batch_nbytes = total_nbytes // nbatches
311
333
  if batch_nbytes > max_io_nbytes:
312
334
  assert size == nbatches
313
- warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')
335
+ warnings.warn(
336
+ f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
337
+ )
314
338
 
315
339
  def loop(_, args):
316
340
  args = move_axes_in(in_axes, args)
@@ -333,10 +357,11 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
333
357
 
334
358
  return batched_func
335
359
 
360
+
336
361
  @tree_util.register_pytree_node_class
337
362
  class LeafDict(dict):
338
- """ dictionary that acts as a leaf in jax pytrees, to store compile-time
339
- values """
363
+ """dictionary that acts as a leaf in jax pytrees, to store compile-time
364
+ values"""
340
365
 
341
366
  def tree_flatten(self):
342
367
  return (), self
bartz/mcmcloop.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # bartz/src/bartz/mcmcloop.py
2
2
  #
3
- # Copyright (c) 2024, Giacomo Petrillo
3
+ # Copyright (c) 2024-2025, Giacomo Petrillo
4
4
  #
5
5
  # This file is part of bartz.
6
6
  #
@@ -29,22 +29,21 @@ Functions that implement the full BART posterior MCMC loop.
29
29
  import functools
30
30
 
31
31
  import jax
32
- from jax import random
33
- from jax import debug
32
+ from jax import debug, lax, random, tree
34
33
  from jax import numpy as jnp
35
- from jax import lax
36
34
 
37
- from . import jaxext
38
- from . import grove
39
- from . import mcmcstep
35
+ from . import grove, jaxext, mcmcstep
40
36
 
41
- @functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
42
- def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
37
+
38
+ @functools.partial(jax.jit, static_argnums=(2, 3, 4, 5))
39
+ def run_mcmc(key, bart, n_burn, n_save, n_skip, callback):
43
40
  """
44
41
  Run the MCMC for the BART posterior.
45
42
 
46
43
  Parameters
47
44
  ----------
45
+ key : jax.dtypes.prng_key array
46
+ The key for random number generation.
48
47
  bart : dict
49
48
  The initial MCMC state, as created and updated by the functions in
50
49
  `bartz.mcmcstep`.
@@ -54,88 +53,123 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
54
53
  The number of iterations to save.
55
54
  n_skip : int
56
55
  The number of iterations to skip between each saved iteration, plus 1.
56
+ The effective burn-in is ``n_burn + n_skip - 1``.
57
57
  callback : callable
58
58
  An arbitrary function run at each iteration, called with the following
59
59
  arguments, passed by keyword:
60
60
 
61
61
  bart : dict
62
- The current MCMC state.
62
+ The MCMC state just after updating it.
63
63
  burnin : bool
64
64
  Whether the last iteration was in the burn-in phase.
65
65
  i_total : int
66
66
  The index of the last iteration (0-based).
67
67
  i_skip : int
68
- The index of the last iteration, starting from the last saved
69
- iteration.
68
+ The number of MCMC updates from the last saved state. The initial
69
+ state counts as saved, even if it's not copied into the trace.
70
70
  n_burn, n_save, n_skip : int
71
71
  The corresponding arguments as-is.
72
72
 
73
73
  Since this function is called under the jax jit, the values are not
74
74
  available at the time the Python code is executed. Use the utilities in
75
75
  `jax.debug` to access the values at actual runtime.
76
- key : jax.dtypes.prng_key array
77
- The key for random number generation.
78
76
 
79
77
  Returns
80
78
  -------
81
79
  bart : dict
82
80
  The final MCMC state.
83
- burnin_trace : dict
81
+ burnin_trace : dict of (n_burn, ...) arrays
84
82
  The trace of the burn-in phase, containing the following subset of
85
83
  fields from the `bart` dictionary, with an additional head index that
86
84
  runs over MCMC iterations: 'sigma2', 'grow_prop_count',
87
85
  'grow_acc_count', 'prune_prop_count', 'prune_acc_count'.
88
- main_trace : dict
86
+ main_trace : dict of (n_save, ...) arrays
89
87
  The trace of the main phase, containing the following subset of fields
90
88
  from the `bart` dictionary, with an additional head index that runs
91
89
  over MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees', plus
92
90
  the fields in `burnin_trace`.
91
+
92
+ Notes
93
+ -----
94
+ The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
95
+ not include the initial state, and include the final state.
93
96
  """
94
97
 
95
- tracelist_burnin = 'sigma2', 'grow_prop_count', 'grow_acc_count', 'prune_prop_count', 'prune_acc_count', 'ratios'
98
+ tracevars_light = (
99
+ 'sigma2',
100
+ 'grow_prop_count',
101
+ 'grow_acc_count',
102
+ 'prune_prop_count',
103
+ 'prune_acc_count',
104
+ 'ratios',
105
+ )
106
+ tracevars_heavy = ('leaf_trees', 'var_trees', 'split_trees')
107
+
108
+ def empty_trace(length, bart, tracelist):
109
+ bart = {k: v for k, v in bart.items() if k in tracelist}
110
+ return jax.vmap(lambda x: x, in_axes=None, out_axes=0, axis_size=length)(bart)
96
111
 
97
- tracelist_main = tracelist_burnin + ('leaf_trees', 'var_trees', 'split_trees')
112
+ trace_light = empty_trace(n_burn + n_save, bart, tracevars_light)
113
+ trace_heavy = empty_trace(n_save, bart, tracevars_heavy)
98
114
 
99
115
  callback_kw = dict(n_burn=n_burn, n_save=n_save, n_skip=n_skip)
100
116
 
101
- def inner_loop(carry, _, tracelist, burnin):
102
- bart, i_total, i_skip, key = carry
117
+ carry = (bart, 0, key, trace_light, trace_heavy)
118
+
119
+ def loop(carry, _):
120
+ bart, i_total, key, trace_light, trace_heavy = carry
121
+
103
122
  key, subkey = random.split(key)
104
- bart = mcmcstep.step(bart, subkey)
105
- callback(bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw)
106
- output = {key: bart[key] for key in tracelist if key in bart}
107
- return (bart, i_total + 1, i_skip + 1, key), output
108
-
109
- def empty_trace(bart, tracelist):
110
- return jax.vmap(lambda x: x, in_axes=None, out_axes=0, axis_size=0)(bart)
111
-
112
- if n_burn > 0:
113
- carry = bart, 0, 0, key
114
- burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
115
- (bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
116
- else:
117
- i_total = 0
118
- burnin_trace = empty_trace(bart, tracelist_burnin)
119
-
120
- def outer_loop(carry, _):
121
- bart, i_total, key = carry
122
- main_loop = functools.partial(inner_loop, tracelist=[], burnin=False)
123
- inner_carry = bart, i_total, 0, key
124
- (bart, i_total, _, key), _ = lax.scan(main_loop, inner_carry, None, n_skip)
125
- output = {key: bart[key] for key in tracelist_main if key in bart}
126
- return (bart, i_total, key), output
127
-
128
- if n_save > 0:
129
- carry = bart, i_total, key
130
- (bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
131
- else:
132
- main_trace = empty_trace(bart, tracelist_main)
123
+ bart = mcmcstep.step(subkey, bart)
124
+
125
+ burnin = i_total < n_burn
126
+ i_skip = jnp.where(
127
+ burnin,
128
+ i_total + 1,
129
+ (i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0),
130
+ )
131
+ callback(
132
+ bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw
133
+ )
134
+
135
+ i_heavy = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
136
+ i_light = jnp.where(burnin, i_total, n_burn + i_heavy)
137
+
138
+ def update_trace(index, trace, bart):
139
+ bart = {k: v for k, v in bart.items() if k in trace}
140
+
141
+ def assign_at_index(trace_array, state_array):
142
+ if trace_array.size:
143
+ return trace_array.at[index, ...].set(state_array)
144
+ else:
145
+ # this handles the case where a trace is empty (e.g.,
146
+ # no burn-in) because jax refuses to index into an array
147
+ # of length 0
148
+ return trace_array
149
+
150
+ return tree.map(assign_at_index, trace, bart)
151
+
152
+ trace_heavy = update_trace(i_heavy, trace_heavy, bart)
153
+ trace_light = update_trace(i_light, trace_light, bart)
154
+
155
+ i_total += 1
156
+ carry = (bart, i_total, key, trace_light, trace_heavy)
157
+ return carry, None
158
+
159
+ carry, _ = lax.scan(loop, carry, None, n_burn + n_skip * n_save)
160
+
161
+ bart, _, _, trace_light, trace_heavy = carry
162
+
163
+ burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_light)
164
+ main_trace = tree.map(lambda x: x[n_burn:, ...], trace_light)
165
+ main_trace.update(trace_heavy)
133
166
 
134
167
  return bart, burnin_trace, main_trace
135
168
 
169
+
136
170
  @functools.lru_cache
137
- # cache to make the callback function object unique, such that the jit
138
- # of run_mcmc recognizes it
171
+ # cache to make the callback function object unique, such that the jit
172
+ # of run_mcmc recognizes it
139
173
  def make_simple_print_callback(printevery):
140
174
  """
141
175
  Create a logging callback function for MCMC iterations.
@@ -150,6 +184,7 @@ def make_simple_print_callback(printevery):
150
184
  callback : callable
151
185
  A function in the format required by `run_mcmc`.
152
186
  """
187
+
153
188
  def callback(*, bart, burnin, i_total, i_skip, n_burn, n_save, n_skip):
154
189
  prop_total = len(bart['leaf_trees'])
155
190
  grow_prop = bart['grow_prop_count'] / prop_total
@@ -158,18 +193,40 @@ def make_simple_print_callback(printevery):
158
193
  prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
159
194
  n_total = n_burn + n_save * n_skip
160
195
  printcond = (i_total + 1) % printevery == 0
161
- debug.callback(_simple_print_callback, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond)
196
+ debug.callback(
197
+ _simple_print_callback,
198
+ burnin,
199
+ i_total,
200
+ n_total,
201
+ grow_prop,
202
+ grow_acc,
203
+ prune_prop,
204
+ prune_acc,
205
+ printcond,
206
+ )
207
+
162
208
  return callback
163
209
 
164
- def _simple_print_callback(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond):
210
+
211
+ def _simple_print_callback(
212
+ burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond
213
+ ):
165
214
  if printcond:
166
215
  burnin_flag = ' (burnin)' if burnin else ''
167
216
  total_str = str(n_total)
168
217
  ndigits = len(total_str)
169
- i_str = str(i_total + 1).rjust(ndigits)
170
- print(f'Iteration {i_str}/{total_str} '
218
+ i_str = str(i_total.item() + 1).rjust(ndigits)
219
+ # I do i_total.item() + 1 instead of just i_total + 1 to solve a bug
220
+ # originating when jax is combined with some outdated dependencies. (I
221
+ # did not track down which dependencies exactly.) Doing .item() makes
222
+ # the + 1 operation be done by Python instead of by jax. The bug is that
223
+ # jax hangs completely, with a secondary thread blocked at this line.
224
+ print(
225
+ f'Iteration {i_str}/{total_str} '
171
226
  f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
172
- f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}')
227
+ f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}'
228
+ )
229
+
173
230
 
174
231
  @jax.jit
175
232
  def evaluate_trace(trace, X):
@@ -189,9 +246,13 @@ def evaluate_trace(trace, X):
189
246
  The predictions for each iteration of the MCMC.
190
247
  """
191
248
  evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
192
- evaluate_trees = jaxext.autobatch(evaluate_trees, 2 ** 29, (None, 0, 0, 0))
249
+ evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0, 0, 0))
250
+
193
251
  def loop(_, state):
194
- values = evaluate_trees(X, state['leaf_trees'], state['var_trees'], state['split_trees'])
252
+ values = evaluate_trees(
253
+ X, state['leaf_trees'], state['var_trees'], state['split_trees']
254
+ )
195
255
  return None, jnp.sum(values, axis=0, dtype=jnp.float32)
256
+
196
257
  _, y = lax.scan(loop, None, trace)
197
258
  return y