bartz 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/jaxext.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # bartz/src/bartz/jaxext.py
2
2
  #
3
- # Copyright (c) 2024, Giacomo Petrillo
3
+ # Copyright (c) 2024-2025, Giacomo Petrillo
4
4
  #
5
5
  # This file is part of bartz.
6
6
  #
@@ -26,11 +26,11 @@ import functools
26
26
  import math
27
27
  import warnings
28
28
 
29
- from scipy import special
30
29
  import jax
30
+ from jax import lax, tree_util
31
31
  from jax import numpy as jnp
32
- from jax import tree_util
33
- from jax import lax
32
+ from scipy import special
33
+
34
34
 
35
35
  def float_type(*args):
36
36
  """
@@ -39,62 +39,46 @@ def float_type(*args):
39
39
  t = jnp.result_type(*args)
40
40
  return jnp.sin(jnp.empty(0, t)).dtype
41
41
 
42
+
42
43
def castto(func, type):
    """
    Wrap `func` so that its result is converted with ``.astype(type)``.

    The wrapper carries over the metadata of `func` (name, docstring, ...)
    through `functools.wraps`.

    Parameters
    ----------
    func : callable
        Function whose return value has an ``astype`` method.
    type : dtype
        Target dtype. The parameter name shadows the builtin, but it is kept
        so that keyword callers keep working.

    Returns
    -------
    newfunc : callable
        Function taking the same arguments as `func`.
    """
    target_dtype = type

    def newfunc(*args, **kw):
        raw = func(*args, **kw)
        return raw.astype(target_dtype)

    return functools.wraps(func)(newfunc)
47
49
 
48
- def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
49
- """ version of `jax.pure_callback` that deals correctly with ufuncs,
50
- see `<https://github.com/google/jax/issues/17187>`_ """
51
- if excluded is None:
52
- excluded = ()
53
- shape = jnp.broadcast_shapes(*(
54
- a.shape
55
- for i, a in enumerate(args)
56
- if i not in excluded
57
- ))
58
- ndim = len(shape)
59
- padded_args = [
60
- a if i in excluded
61
- else jnp.expand_dims(a, tuple(range(ndim - a.ndim)))
62
- for i, a in enumerate(args)
63
- ]
64
- result = jax.ShapeDtypeStruct(shape, dtype)
65
- return jax.pure_callback(callback, result, *padded_args, vectorized=True, **kwargs)
66
-
67
- # TODO when jax solves this, check version and piggyback on original if new
68
50
 
69
51
class scipy:
    """
    Jax-compatible replacements for a few `scipy` functions, laid out like
    ``scipy.special`` and ``scipy.stats``. The special function is evaluated
    on the host through `jax.pure_callback`.
    """

    class special:
        # functools.wraps copies scipy's own docstring onto the wrapper, so a
        # docstring written here would be overwritten anyway.
        @functools.wraps(special.gammainccinv)
        def gammainccinv(a, y):
            a, y = jnp.asarray(a), jnp.asarray(y)
            out_shape = jnp.broadcast_shapes(a.shape, y.shape)
            out_dtype = float_type(a.dtype, y.dtype)
            out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)
            # scipy's ufunc may return a different float width; cast it so the
            # host result matches `out_spec`
            host_func = castto(special.gammainccinv, out_dtype)
            return jax.pure_callback(
                host_func, out_spec, a, y, vmap_method='expand_dims'
            )

    class stats:
        class invgamma:
            def ppf(q, a):
                # inverse-gamma quantile with unit scale: if X ~ InvGamma(a),
                # then P(X <= x) = Q(a, 1/x), so x = 1 / Q^{-1}(a, q)
                return 1 / scipy.special.gammainccinv(a, q)
87
67
 
88
def vmap_nodoc(fun, *args, **kw):
    """
    Apply `jax.vmap` to `fun`, then restore the docstring of `fun` on the result.

    This is useful when the docstring of `fun` is already written for the
    vectorized version, i.e., it accounts for the extra axes added by vmap.

    Parameters
    ----------
    fun : callable
        The function to vectorize.
    *args, **kw
        Forwarded unchanged to `jax.vmap`.

    Returns
    -------
    callable
        The vmapped function, whose ``__doc__`` is ``fun.__doc__``.
    """
    original_doc = fun.__doc__
    vectorized = jax.vmap(fun, *args, **kw)
    setattr(vectorized, '__doc__', original_doc)
    return vectorized
97
80
 
81
+
98
82
  def huge_value(x):
99
83
  """
100
84
  Return the maximum value that can be stored in `x`.
@@ -114,19 +98,21 @@ def huge_value(x):
114
98
  else:
115
99
  return jnp.inf
116
100
 
101
+
117
102
def minimal_unsigned_dtype(max_value):
    """
    Pick the narrowest unsigned integer dtype whose range includes
    `max_value` (the bound is inclusive). Falls back to uint64 for anything
    that does not fit in 32 bits.
    """
    candidates = ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32))
    for nbits, dtype in candidates:
        if max_value < 2**nbits:
            return dtype
    return jnp.uint64
129
114
 
115
+
130
116
  def signed_to_unsigned(int_dtype):
131
117
  """
132
118
  Map a signed integer type to its unsigned counterpart. Unsigned types are
@@ -144,12 +130,14 @@ def signed_to_unsigned(int_dtype):
144
130
  if int_dtype == jnp.int64:
145
131
  return jnp.uint64
146
132
 
133
+
147
134
def ensure_unsigned(x):
    """
    Return `x` cast to the dtype that `signed_to_unsigned` maps its dtype to,
    i.e., signed integers become the unsigned type of the same size.
    """
    target_dtype = signed_to_unsigned(x.dtype)
    return x.astype(target_dtype)
152
139
 
140
+
153
141
  @functools.partial(jax.jit, static_argnums=(1,))
154
142
  def unique(x, size, fill_value):
155
143
  """
@@ -177,15 +165,18 @@ def unique(x, size, fill_value):
177
165
  if size == 0:
178
166
  return jnp.empty(0, x.dtype), 0
179
167
  x = jnp.sort(x)
168
+
180
169
def loop(carry, x):
    """
    Scan step over sorted values: advance the output slot whenever `x`
    differs from the previous value, then write `x` into the current slot.

    The carry is ``(i_out, i_in, last, out)``: index of the slot being
    written, count of inputs consumed, previous input, and output buffer.
    """
    slot, n_read, previous, buffer = carry
    is_repeat = x == previous
    slot = jnp.where(is_repeat, slot, slot + 1)
    buffer = buffer.at[slot].set(x)
    new_carry = (slot, n_read + 1, x, buffer)
    return new_carry, None
174
+
185
175
  carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
186
176
  (actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
187
177
  return out, actual_length + 1
188
178
 
179
+
189
180
  def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
190
181
  """
191
182
  Batch a function such that each batch is smaller than a threshold.
@@ -222,6 +213,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
222
213
def check_no_nones(axes, tree):
    """Assert that every leaf of `tree` has a non-None axis in `axes`."""

    def assert_axis_given(_leaf, axis):
        assert axis is not None

    tree_util.tree_map(assert_axis_given, tree, axes)
226
218
 
227
219
  def extract_size(axes, tree):
@@ -230,6 +222,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
230
222
  return None
231
223
  else:
232
224
  return x.shape[axis]
225
+
233
226
  sizes = tree_util.tree_map(get_size, tree, axes)
234
227
  sizes, _ = tree_util.tree_flatten(sizes)
235
228
  assert all(s == sizes[0] for s in sizes)
@@ -238,6 +231,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
238
231
def sum_nbytes(tree):
    """
    Total size in bytes of the leaves of `tree`, computed from each leaf's
    ``shape`` and ``dtype.itemsize``. Returns 0 for a tree without leaves.
    """
    total = 0
    for leaf in tree_util.tree_leaves(tree):
        total = total + math.prod(leaf.shape) * leaf.dtype.itemsize
    return total
242
236
 
243
237
  def next_divisor_small(dividend, min_divisor):
@@ -266,6 +260,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
266
260
  return None
267
261
  else:
268
262
  return x
263
+
269
264
  return tree_util.tree_map(pull_nonbatched, tree, axes), tree
270
265
 
271
266
  def push_nonbatched(axes, tree, original_tree):
@@ -274,32 +269,38 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
274
269
  return original_x
275
270
  else:
276
271
  return x
272
+
277
273
  return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)
278
274
 
279
275
def move_axes_out(axes, tree):
    """For each leaf of `tree`, move its axis given in `axes` to position 0."""
    return tree_util.tree_map(
        lambda leaf, axis: jnp.moveaxis(leaf, axis, 0), tree, axes
    )
283
280
 
284
281
def move_axes_in(axes, tree):
    """For each leaf of `tree`, move axis 0 to the position given in `axes`."""
    return tree_util.tree_map(
        lambda leaf, axis: jnp.moveaxis(leaf, 0, axis), tree, axes
    )
288
286
 
289
287
def batch(tree, nbatches):
    """
    Split the leading axis of every leaf of `tree` into two axes of sizes
    ``(nbatches, size // nbatches)``.
    """

    def split_leading(x):
        per_batch = x.shape[0] // nbatches
        return x.reshape((nbatches, per_batch, *x.shape[1:]))

    return tree_util.tree_map(split_leading, tree)
293
292
 
294
293
def unbatch(tree):
    """Merge the first two axes of every leaf of `tree` into a single axis."""

    def merge_leading(x):
        merged = x.shape[0] * x.shape[1]
        return x.reshape((merged, *x.shape[2:]))

    return tree_util.tree_map(merge_leading, tree)
298
298
 
299
299
def check_same(tree1, tree2):
    """Assert that corresponding leaves of the two trees match in shape and dtype."""

    def compare_leaves(left, right):
        assert left.shape == right.shape
        assert left.dtype == right.dtype

    tree_util.tree_map(compare_leaves, tree1, tree2)
304
305
 
305
306
  initial_in_axes = in_axes
@@ -319,7 +320,9 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
319
320
  args, nonbatched_args = pull_nonbatched(in_axes, args)
320
321
 
321
322
  total_nbytes = sum_nbytes((args, example_result))
322
- min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
323
+ min_nbatches = total_nbytes // max_io_nbytes + bool(
324
+ total_nbytes % max_io_nbytes
325
+ )
323
326
  min_nbatches = max(1, min_nbatches)
324
327
  nbatches = next_divisor(size, min_nbatches)
325
328
  assert 1 <= nbatches <= max(1, size)
@@ -329,7 +332,9 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
329
332
  batch_nbytes = total_nbytes // nbatches
330
333
  if batch_nbytes > max_io_nbytes:
331
334
  assert size == nbatches
332
- warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')
335
+ warnings.warn(
336
+ f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
337
+ )
333
338
 
334
339
  def loop(_, args):
335
340
  args = move_axes_in(in_axes, args)
@@ -352,10 +357,11 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
352
357
 
353
358
  return batched_func
354
359
 
360
+
355
361
  @tree_util.register_pytree_node_class
356
362
  class LeafDict(dict):
357
- """ dictionary that acts as a leaf in jax pytrees, to store compile-time
358
- values """
363
+ """dictionary that acts as a leaf in jax pytrees, to store compile-time
364
+ values"""
359
365
 
360
366
  def tree_flatten(self):
361
367
  return (), self
bartz/mcmcloop.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # bartz/src/bartz/mcmcloop.py
2
2
  #
3
- # Copyright (c) 2024, Giacomo Petrillo
3
+ # Copyright (c) 2024-2025, Giacomo Petrillo
4
4
  #
5
5
  # This file is part of bartz.
6
6
  #
@@ -29,22 +29,21 @@ Functions that implement the full BART posterior MCMC loop.
29
29
  import functools
30
30
 
31
31
  import jax
32
- from jax import random
33
- from jax import debug
32
+ from jax import debug, lax, random, tree
34
33
  from jax import numpy as jnp
35
- from jax import lax
36
34
 
37
- from . import jaxext
38
- from . import grove
39
- from . import mcmcstep
35
+ from . import grove, jaxext, mcmcstep
40
36
 
41
- @functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
42
- def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
37
+
38
+ @functools.partial(jax.jit, static_argnums=(2, 3, 4, 5))
39
+ def run_mcmc(key, bart, n_burn, n_save, n_skip, callback):
43
40
  """
44
41
  Run the MCMC for the BART posterior.
45
42
 
46
43
  Parameters
47
44
  ----------
45
+ key : jax.dtypes.prng_key array
46
+ The key for random number generation.
48
47
  bart : dict
49
48
  The initial MCMC state, as created and updated by the functions in
50
49
  `bartz.mcmcstep`.
@@ -54,88 +53,123 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
54
53
  The number of iterations to save.
55
54
  n_skip : int
56
55
  The number of iterations to skip between each saved iteration, plus 1.
56
+ The effective burn-in is ``n_burn + n_skip - 1``.
57
57
  callback : callable
58
58
  An arbitrary function run at each iteration, called with the following
59
59
  arguments, passed by keyword:
60
60
 
61
61
  bart : dict
62
- The current MCMC state.
62
+ The MCMC state just after updating it.
63
63
  burnin : bool
64
64
  Whether the last iteration was in the burn-in phase.
65
65
  i_total : int
66
66
  The index of the last iteration (0-based).
67
67
  i_skip : int
68
- The index of the last iteration, starting from the last saved
69
- iteration.
68
+ The number of MCMC updates from the last saved state. The initial
69
+ state counts as saved, even if it's not copied into the trace.
70
70
  n_burn, n_save, n_skip : int
71
71
  The corresponding arguments as-is.
72
72
 
73
73
  Since this function is called under the jax jit, the values are not
74
74
  available at the time the Python code is executed. Use the utilities in
75
75
  `jax.debug` to access the values at actual runtime.
76
- key : jax.dtypes.prng_key array
77
- The key for random number generation.
78
76
 
79
77
  Returns
80
78
  -------
81
79
  bart : dict
82
80
  The final MCMC state.
83
- burnin_trace : dict
81
+ burnin_trace : dict of (n_burn, ...) arrays
84
82
  The trace of the burn-in phase, containing the following subset of
85
83
  fields from the `bart` dictionary, with an additional head index that
86
84
  runs over MCMC iterations: 'sigma2', 'grow_prop_count',
87
85
  'grow_acc_count', 'prune_prop_count', 'prune_acc_count'.
88
- main_trace : dict
86
+ main_trace : dict of (n_save, ...) arrays
89
87
  The trace of the main phase, containing the following subset of fields
90
88
  from the `bart` dictionary, with an additional head index that runs
91
89
  over MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees', plus
92
90
  the fields in `burnin_trace`.
91
+
92
+ Notes
93
+ -----
94
+ The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
95
+ not include the initial state, and include the final state.
93
96
  """
94
97
 
95
- tracelist_burnin = 'sigma2', 'grow_prop_count', 'grow_acc_count', 'prune_prop_count', 'prune_acc_count', 'ratios'
98
+ tracevars_light = (
99
+ 'sigma2',
100
+ 'grow_prop_count',
101
+ 'grow_acc_count',
102
+ 'prune_prop_count',
103
+ 'prune_acc_count',
104
+ 'ratios',
105
+ )
106
+ tracevars_heavy = ('leaf_trees', 'var_trees', 'split_trees')
107
+
108
def empty_trace(length, bart, tracelist):
    """
    Allocate a trace: the entries of `bart` whose key is in `tracelist`, each
    repeated `length` times along a new leading axis. The contents are copies
    of the given state and serve only as placeholders to be overwritten.
    """
    selected = {name: bart[name] for name in bart if name in tracelist}

    def identity(state):
        return state

    broadcast = jax.vmap(identity, in_axes=None, out_axes=0, axis_size=length)
    return broadcast(selected)
96
111
 
97
- tracelist_main = tracelist_burnin + ('leaf_trees', 'var_trees', 'split_trees')
112
+ trace_light = empty_trace(n_burn + n_save, bart, tracevars_light)
113
+ trace_heavy = empty_trace(n_save, bart, tracevars_heavy)
98
114
 
99
115
  callback_kw = dict(n_burn=n_burn, n_save=n_save, n_skip=n_skip)
100
116
 
101
- def inner_loop(carry, _, tracelist, burnin):
102
- bart, i_total, i_skip, key = carry
117
+ carry = (bart, 0, key, trace_light, trace_heavy)
118
+
119
def loop(carry, _):
    """
    Body of the `lax.scan` over MCMC iterations: one update of the state,
    the user callback, and the writes into the traces.

    Parameters
    ----------
    carry : tuple
        ``(bart, i_total, key, trace_light, trace_heavy)``: the MCMC state,
        the 0-based index of the current iteration, the random key, and the
        two traces being filled in.
    _ : None
        Unused scan input.

    Returns
    -------
    carry : tuple
        The updated carry, with ``i_total`` incremented.
    None
        No per-iteration scan output; the traces are kept in the carry.
    """
    bart, i_total, key, trace_light, trace_heavy = carry

    key, subkey = random.split(key)
    bart = mcmcstep.step(subkey, bart)

    burnin = i_total < n_burn

    # i_skip = number of updates since the last saved state, where the
    # initial state counts as saved. Main-phase states are saved every n_skip
    # updates counted from the end of the burn-in, so the phase must be taken
    # relative to n_burn. Until the first main-phase save, the last saved
    # state is the initial one, so the count is the total i_total + 1.
    i_main = i_total - n_burn + 1  # number of main-phase updates done so far
    i_skip = jnp.where(
        burnin,
        i_total + 1,
        i_main % n_skip + jnp.where(i_main < n_skip, n_burn, 0),
    )
    callback(
        bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw
    )

    # slot in the heavy trace (main phase) and in the light trace (burn-in
    # followed by main phase); all n_skip updates of a main-phase group write
    # to the same slot, so the last one of the group is what remains
    i_heavy = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
    i_light = jnp.where(burnin, i_total, n_burn + i_heavy)

    def update_trace(index, trace, bart):
        bart = {k: v for k, v in bart.items() if k in trace}

        def assign_at_index(trace_array, state_array):
            if trace_array.size:
                return trace_array.at[index, ...].set(state_array)
            else:
                # this handles the case where a trace is empty (e.g.,
                # no burn-in) because jax refuses to index into an array
                # of length 0
                return trace_array

        return tree.map(assign_at_index, trace, bart)

    trace_heavy = update_trace(i_heavy, trace_heavy, bart)
    trace_light = update_trace(i_light, trace_light, bart)

    i_total += 1
    carry = (bart, i_total, key, trace_light, trace_heavy)
    return carry, None
158
+
159
+ carry, _ = lax.scan(loop, carry, None, n_burn + n_skip * n_save)
160
+
161
+ bart, _, _, trace_light, trace_heavy = carry
162
+
163
+ burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_light)
164
+ main_trace = tree.map(lambda x: x[n_burn:, ...], trace_light)
165
+ main_trace.update(trace_heavy)
133
166
 
134
167
  return bart, burnin_trace, main_trace
135
168
 
169
+
136
170
  @functools.lru_cache
137
- # cache to make the callback function object unique, such that the jit
138
- # of run_mcmc recognizes it
171
+ # cache to make the callback function object unique, such that the jit
172
+ # of run_mcmc recognizes it
139
173
  def make_simple_print_callback(printevery):
140
174
  """
141
175
  Create a logging callback function for MCMC iterations.
@@ -150,6 +184,7 @@ def make_simple_print_callback(printevery):
150
184
  callback : callable
151
185
  A function in the format required by `run_mcmc`.
152
186
  """
187
+
153
188
  def callback(*, bart, burnin, i_total, i_skip, n_burn, n_save, n_skip):
154
189
  prop_total = len(bart['leaf_trees'])
155
190
  grow_prop = bart['grow_prop_count'] / prop_total
@@ -158,18 +193,40 @@ def make_simple_print_callback(printevery):
158
193
  prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
159
194
  n_total = n_burn + n_save * n_skip
160
195
  printcond = (i_total + 1) % printevery == 0
161
- debug.callback(_simple_print_callback, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond)
196
+ debug.callback(
197
+ _simple_print_callback,
198
+ burnin,
199
+ i_total,
200
+ n_total,
201
+ grow_prop,
202
+ grow_acc,
203
+ prune_prop,
204
+ prune_acc,
205
+ printcond,
206
+ )
207
+
162
208
  return callback
163
209
 
164
def _simple_print_callback(
    burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond
):
    """
    Print one progress line for the MCMC when `printcond` is true.

    This is the host-side function passed to `jax.debug.callback` by the
    callback built in `make_simple_print_callback`. It prints the 1-based
    iteration number out of `n_total`, the grow/prune proposal and acceptance
    fractions, and a ``(burnin)`` marker during burn-in.
    """
    if not printcond:
        return
    total_str = str(n_total)
    # .item() makes the + 1 be done by Python instead of by jax: doing it with
    # jax was observed to hang (a secondary thread blocked here) when jax is
    # combined with some outdated dependencies.
    i_str = str(i_total.item() + 1).rjust(len(total_str))
    suffix = ' (burnin)' if burnin else ''
    print(
        f'Iteration {i_str}/{total_str} '
        f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
        f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{suffix}'
    )
229
+
173
230
 
174
231
  @jax.jit
175
232
  def evaluate_trace(trace, X):
@@ -189,9 +246,13 @@ def evaluate_trace(trace, X):
189
246
  The predictions for each iteration of the MCMC.
190
247
  """
191
248
  evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
192
- evaluate_trees = jaxext.autobatch(evaluate_trees, 2 ** 29, (None, 0, 0, 0))
249
+ evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0, 0, 0))
250
+
193
251
def loop(_, state):
    """
    Scan step: evaluate the trees of one trace entry on `X` and sum the
    per-tree values over the first axis, accumulating in float32.
    """
    leaf_trees = state['leaf_trees']
    var_trees = state['var_trees']
    split_trees = state['split_trees']
    per_tree = evaluate_trees(X, leaf_trees, var_trees, split_trees)
    summed = jnp.sum(per_tree, axis=0, dtype=jnp.float32)
    return None, summed
256
+
196
257
  _, y = lax.scan(loop, None, trace)
197
258
  return y