bartz 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/grove.py CHANGED
@@ -22,93 +22,113 @@
22
22
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
23
  # SOFTWARE.
24
24
 
25
- """
25
+ """Functions to create and manipulate binary decision trees."""
26
26
 
27
- Functions to create and manipulate binary trees.
28
-
29
- A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index :math:`i` are at indices :math:`2i` (left child) and :math:`2i + 1` (right child). The array element at index 0 is unused.
30
-
31
- A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'.
32
-
33
- The 'leaf' array contains the values in the leaves.
27
+ import math
28
+ from functools import partial
29
+ from typing import Protocol
34
30
 
35
- The 'var' array contains the axes along which the decision nodes operate.
31
+ import jax
32
+ from jax import jit, lax
33
+ from jax import numpy as jnp
34
+ from jaxtyping import Array, Bool, DTypeLike, Float32, Int32, Real, Shaped, UInt
36
35
 
37
- The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0.
36
+ from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc
38
37
 
39
- Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.
40
38
 
41
- """
39
+ class TreeHeaps(Protocol):
40
+ """A protocol for dataclasses that represent trees.
42
41
 
43
- import functools
44
- import math
42
+ A tree is represented with arrays as a heap. The root node is at index 1.
43
+ The children nodes of a node at index :math:`i` are at indices :math:`2i`
44
+ (left child) and :math:`2i + 1` (right child). The array element at index 0
45
+ is unused.
45
46
 
46
- from jax import lax
47
- from jax import numpy as jnp
47
+ Parameters
48
+ ----------
49
+ leaf_tree
50
+ The values in the leaves of the trees. This array can be dirty, i.e.,
51
+ unused nodes can have whatever value.
52
+ var_tree
53
+ The axes along which the decision nodes operate. This array can be
54
+ dirty but for the always unused node at index 0 which must be set to 0.
55
+ split_tree
56
+ The decision boundaries of the trees. The boundaries are open on the
57
+ right, i.e., a point belongs to the left child iff x < split. Whether a
58
+ node is a leaf is indicated by the corresponding 'split' element being
59
+ 0. Unused nodes also have split set to 0. This array can't be dirty.
60
+
61
+ Notes
62
+ -----
63
+ Since the nodes at the bottom can only be leaves and not decision nodes,
64
+ `var_tree` and `split_tree` are half as long as `leaf_tree`.
65
+ """
48
66
 
49
- from . import jaxext
67
+ leaf_tree: Float32[Array, '* 2**d']
68
+ var_tree: UInt[Array, '* 2**(d-1)']
69
+ split_tree: UInt[Array, '* 2**(d-1)']
50
70
 
51
71
 
52
- def make_tree(depth, dtype):
72
+ def make_tree(depth: int, dtype: DTypeLike) -> Shaped[Array, ' 2**{depth}']:
53
73
  """
54
74
  Make an array to represent a binary tree.
55
75
 
56
76
  Parameters
57
77
  ----------
58
- depth : int
78
+ depth
59
79
  The maximum depth of the tree. Depth 1 means that there is only a root
60
80
  node.
61
- dtype : dtype
81
+ dtype
62
82
  The dtype of the array.
63
83
 
64
84
  Returns
65
85
  -------
66
- tree : array
67
- An array of zeroes with shape (2 ** depth,).
86
+ An array of zeroes with the appropriate shape.
68
87
  """
69
88
  return jnp.zeros(2**depth, dtype)
70
89
 
71
90
 
72
- def tree_depth(tree):
91
+ def tree_depth(tree: Shaped[Array, '* 2**d']) -> int:
73
92
  """
74
93
  Return the maximum depth of a tree.
75
94
 
76
95
  Parameters
77
96
  ----------
78
- tree : array
97
+ tree
79
98
  A tree created by `make_tree`. If the array is ND, the tree structure is
80
99
  assumed to be along the last axis.
81
100
 
82
101
  Returns
83
102
  -------
84
- depth : int
85
- The maximum depth of the tree.
103
+ The maximum depth of the tree.
86
104
  """
87
- return int(round(math.log2(tree.shape[-1])))
105
+ return round(math.log2(tree.shape[-1]))
88
106
 
89
107
 
90
- def traverse_tree(x, var_tree, split_tree):
108
+ def traverse_tree(
109
+ x: Real[Array, ' p'],
110
+ var_tree: UInt[Array, ' 2**(d-1)'],
111
+ split_tree: UInt[Array, ' 2**(d-1)'],
112
+ ) -> Int32[Array, '']:
91
113
  """
92
114
  Find the leaf where a point falls into.
93
115
 
94
116
  Parameters
95
117
  ----------
96
- x : array (p,)
118
+ x
97
119
  The coordinates to evaluate the tree at.
98
- var_tree : array (2 ** (d - 1),)
120
+ var_tree
99
121
  The decision axes of the tree.
100
- split_tree : array (2 ** (d - 1),)
122
+ split_tree
101
123
  The decision boundaries of the tree.
102
124
 
103
125
  Returns
104
126
  -------
105
- index : int
106
- The index of the leaf.
127
+ The index of the leaf.
107
128
  """
108
-
109
129
  carry = (
110
130
  jnp.zeros((), bool),
111
- jnp.ones((), jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)),
131
+ jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
112
132
  )
113
133
 
114
134
  def loop(carry, _):
@@ -128,111 +148,107 @@ def traverse_tree(x, var_tree, split_tree):
128
148
  return index
129
149
 
130
150
 
131
- @functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
132
- @functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
133
- def traverse_forest(X, var_trees, split_trees):
151
+ @partial(vmap_nodoc, in_axes=(None, 0, 0))
152
+ @partial(vmap_nodoc, in_axes=(1, None, None))
153
+ def traverse_forest(
154
+ X: Real[Array, 'p n'],
155
+ var_trees: UInt[Array, 'm 2**(d-1)'],
156
+ split_trees: UInt[Array, 'm 2**(d-1)'],
157
+ ) -> Int32[Array, 'm n']:
134
158
  """
135
159
  Find the leaves where points fall into.
136
160
 
137
161
  Parameters
138
162
  ----------
139
- X : array (p, n)
163
+ X
140
164
  The coordinates to evaluate the trees at.
141
- var_trees : array (m, 2 ** (d - 1))
165
+ var_trees
142
166
  The decision axes of the trees.
143
- split_trees : array (m, 2 ** (d - 1))
167
+ split_trees
144
168
  The decision boundaries of the trees.
145
169
 
146
170
  Returns
147
171
  -------
148
- indices : array (m, n)
149
- The indices of the leaves.
172
+ The indices of the leaves.
150
173
  """
151
174
  return traverse_tree(X, var_trees, split_trees)
152
175
 
153
176
 
154
- def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees=True):
177
+ def evaluate_forest(
178
+ X: UInt[Array, 'p n'], trees: TreeHeaps, *, sum_trees: bool = True
179
+ ) -> Float32[Array, ' n'] | Float32[Array, 'm n']:
155
180
  """
156
181
  Evaluate a ensemble of trees at an array of points.
157
182
 
158
183
  Parameters
159
184
  ----------
160
- X : array (p, n)
185
+ X
161
186
  The coordinates to evaluate the trees at.
162
- leaf_trees : array (m, 2 ** d)
163
- The leaf values of the tree or forest. If the input is a forest, the
164
- first axis is the tree index, and the values are summed.
165
- var_trees : array (m, 2 ** (d - 1))
166
- The decision axes of the trees.
167
- split_trees : array (m, 2 ** (d - 1))
168
- The decision boundaries of the trees.
169
- dtype : dtype, optional
170
- The dtype of the output. Ignored if `sum_trees` is `False`.
171
- sum_trees : bool, default True
187
+ trees
188
+ The tree heaps, with batch shape (m,).
189
+ sum_trees
172
190
  Whether to sum the values across trees.
173
191
 
174
192
  Returns
175
193
  -------
176
- out : array (n,) or (m, n)
177
- The (sum of) the values of the trees at the points in `X`.
194
+ The (sum of) the values of the trees at the points in `X`.
178
195
  """
179
- indices = traverse_forest(X, var_trees, split_trees)
180
- ntree, _ = leaf_trees.shape
181
- tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))
182
- leaves = leaf_trees[tree_index[:, None], indices]
196
+ indices = traverse_forest(X, trees.var_tree, trees.split_tree)
197
+ ntree, _ = trees.leaf_tree.shape
198
+ tree_index = jnp.arange(ntree, dtype=minimal_unsigned_dtype(ntree - 1))
199
+ leaves = trees.leaf_tree[tree_index[:, None], indices]
183
200
  if sum_trees:
184
- return jnp.sum(leaves, axis=0, dtype=dtype)
201
+ return jnp.sum(leaves, axis=0, dtype=jnp.float32)
185
202
  # this sum suggests to swap the vmaps, but I think it's better for X
186
203
  # copying to keep it that way
187
204
  else:
188
205
  return leaves
189
206
 
190
207
 
191
- def is_actual_leaf(split_tree, *, add_bottom_level=False):
208
+ def is_actual_leaf(
209
+ split_tree: UInt[Array, ' 2**(d-1)'], *, add_bottom_level: bool = False
210
+ ) -> Bool[Array, ' 2**(d-1)'] | Bool[Array, ' 2**d']:
192
211
  """
193
212
  Return a mask indicating the leaf nodes in a tree.
194
213
 
195
214
  Parameters
196
215
  ----------
197
- split_tree : int array (2 ** (d - 1),)
216
+ split_tree
198
217
  The splitting points of the tree.
199
- add_bottom_level : bool, default False
218
+ add_bottom_level
200
219
  If True, the bottom level of the tree is also considered.
201
220
 
202
221
  Returns
203
222
  -------
204
- is_actual_leaf : bool array (2 ** (d - 1) or 2 ** d,)
205
- The mask indicating the leaf nodes. The length is doubled if
206
- `add_bottom_level` is True.
223
+ The mask marking the leaf nodes. Length doubled if `add_bottom_level` is True.
207
224
  """
208
225
  size = split_tree.size
209
226
  is_leaf = split_tree == 0
210
227
  if add_bottom_level:
211
228
  size *= 2
212
229
  is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])
213
- index = jnp.arange(size, dtype=jaxext.minimal_unsigned_dtype(size - 1))
230
+ index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1))
214
231
  parent_index = index >> 1
215
232
  parent_nonleaf = split_tree[parent_index].astype(bool)
216
233
  parent_nonleaf = parent_nonleaf.at[1].set(True)
217
234
  return is_leaf & parent_nonleaf
218
235
 
219
236
 
220
- def is_leaves_parent(split_tree):
237
+ def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
221
238
  """
222
239
  Return a mask indicating the nodes with leaf (and only leaf) children.
223
240
 
224
241
  Parameters
225
242
  ----------
226
- split_tree : int array (2 ** (d - 1),)
243
+ split_tree
227
244
  The decision boundaries of the tree.
228
245
 
229
246
  Returns
230
247
  -------
231
- is_leaves_parent : bool array (2 ** (d - 1),)
232
- The mask indicating which nodes have leaf children.
248
+ The mask indicating which nodes have leaf children.
233
249
  """
234
250
  index = jnp.arange(
235
- split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1)
251
+ split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1)
236
252
  )
237
253
  left_index = index << 1 # left child
238
254
  right_index = left_index + 1 # right child
@@ -243,21 +259,24 @@ def is_leaves_parent(split_tree):
243
259
  # the 0-th item has split == 0, so it's not counted
244
260
 
245
261
 
246
- def tree_depths(tree_length):
262
+ def tree_depths(tree_length: int) -> Int32[Array, ' {tree_length}']:
247
263
  """
248
264
  Return the depth of each node in a binary tree.
249
265
 
250
266
  Parameters
251
267
  ----------
252
- tree_length : int
268
+ tree_length
253
269
  The length of the tree array, i.e., 2 ** d.
254
270
 
255
271
  Returns
256
272
  -------
257
- depth : array (tree_length,)
258
- The depth of each node. The root node (index 1) has depth 0. The depth
259
- is the position of the most significant non-zero bit in the index. The
260
- first element (the unused node) is marked as depth 0.
273
+ The depth of each node.
274
+
275
+ Notes
276
+ -----
277
+ The root node (index 1) has depth 0. The depth is the position of the most
278
+ significant non-zero bit in the index. The first element (the unused node)
279
+ is marked as depth 0.
261
280
  """
262
281
  depths = []
263
282
  depth = 0
@@ -266,4 +285,72 @@ def tree_depths(tree_length):
266
285
  depth += 1
267
286
  depths.append(depth - 1)
268
287
  depths[0] = 0
269
- return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))
288
+ return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
289
+
290
+
291
+ def is_used(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**d']:
292
+ """
293
+ Return a mask indicating the used nodes in a tree.
294
+
295
+ Parameters
296
+ ----------
297
+ split_tree
298
+ The decision boundaries of the tree.
299
+
300
+ Returns
301
+ -------
302
+ A mask indicating which nodes are actually used.
303
+ """
304
+ internal_node = split_tree.astype(bool)
305
+ internal_node = jnp.concatenate([internal_node, jnp.zeros_like(internal_node)])
306
+ actual_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
307
+ return internal_node | actual_leaf
308
+
309
+
310
+ @jit
311
+ def forest_fill(split_tree: UInt[Array, 'num_trees 2**(d-1)']) -> Float32[Array, '']:
312
+ """
313
+ Return the fraction of used nodes in a set of trees.
314
+
315
+ Parameters
316
+ ----------
317
+ split_tree
318
+ The decision boundaries of the trees.
319
+
320
+ Returns
321
+ -------
322
+ Number of tree nodes over the maximum number that could be stored.
323
+ """
324
+ num_trees, _ = split_tree.shape
325
+ used = jax.vmap(is_used)(split_tree)
326
+ count = jnp.count_nonzero(used)
327
+ return count / (used.size - num_trees)
328
+
329
+
330
+ def var_histogram(
331
+ p: int, var_tree: UInt[Array, '* 2**(d-1)'], split_tree: UInt[Array, '* 2**(d-1)']
332
+ ) -> Int32[Array, ' {p}']:
333
+ """
334
+ Count how many times each variable appears in a tree.
335
+
336
+ Parameters
337
+ ----------
338
+ p
339
+ The number of variables (the maximum value that can occur in
340
+ `var_tree` is ``p - 1``).
341
+ var_tree
342
+ The decision axes of the tree.
343
+ split_tree
344
+ The decision boundaries of the tree.
345
+
346
+ Returns
347
+ -------
348
+ The histogram of the variables used in the tree.
349
+
350
+ Notes
351
+ -----
352
+ If there are leading axes in the tree arrays (i.e., multiple trees), the
353
+ returned counts are cumulative over trees.
354
+ """
355
+ is_internal = split_tree.astype(bool)
356
+ return jnp.zeros(p, int).at[var_tree].add(is_internal)
@@ -0,0 +1,213 @@
1
+ # bartz/src/bartz/jaxext/__init__.py
2
+ #
3
+ # Copyright (c) 2024-2025, Giacomo Petrillo
4
+ #
5
+ # This file is part of bartz.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Additions to jax."""
26
+
27
+ import functools
28
+ import math
29
+ from collections.abc import Sequence
30
+
31
+ import jax
32
+ from jax import numpy as jnp
33
+ from jax import random
34
+ from jax.lax import scan
35
+ from jax.scipy.special import ndtr
36
+ from jaxtyping import Array, Bool, Float32, Key, Scalar, Shaped
37
+
38
+ from bartz.jaxext._autobatch import autobatch # noqa: F401
39
+ from bartz.jaxext.scipy.special import ndtri
40
+
41
+
42
+ def vmap_nodoc(fun, *args, **kw):
43
+ """
44
+ Acts like `jax.vmap` but preserves the docstring of the function unchanged.
45
+
46
+ This is useful if the docstring already takes into account that the
47
+ arguments have additional axes due to vmap.
48
+ """
49
+ doc = fun.__doc__
50
+ fun = jax.vmap(fun, *args, **kw)
51
+ fun.__doc__ = doc
52
+ return fun
53
+
54
+
55
+ def minimal_unsigned_dtype(value):
56
+ """Return the smallest unsigned integer dtype that can represent `value`."""
57
+ if value < 2**8:
58
+ return jnp.uint8
59
+ if value < 2**16:
60
+ return jnp.uint16
61
+ if value < 2**32:
62
+ return jnp.uint32
63
+ return jnp.uint64
64
+
65
+
66
+ @functools.partial(jax.jit, static_argnums=(1,))
67
+ def unique(
68
+ x: Shaped[Array, ' _'], size: int, fill_value: Scalar
69
+ ) -> tuple[Shaped[Array, ' {size}'], int]:
70
+ """
71
+ Restricted version of `jax.numpy.unique` that uses less memory.
72
+
73
+ Parameters
74
+ ----------
75
+ x
76
+ The input array.
77
+ size
78
+ The length of the output.
79
+ fill_value
80
+ The value to fill the output with if `size` is greater than the number
81
+ of unique values in `x`.
82
+
83
+ Returns
84
+ -------
85
+ out : Shaped[Array, '{size}']
86
+ The unique values in `x`, sorted, and right-padded with `fill_value`.
87
+ actual_length : int
88
+ The number of used values in `out`.
89
+ """
90
+ if x.size == 0:
91
+ return jnp.full(size, fill_value, x.dtype), 0
92
+ if size == 0:
93
+ return jnp.empty(0, x.dtype), 0
94
+ x = jnp.sort(x)
95
+
96
+ def loop(carry, x):
97
+ i_out, last, out = carry
98
+ i_out = jnp.where(x == last, i_out, i_out + 1)
99
+ out = out.at[i_out].set(x)
100
+ return (i_out, x, out), None
101
+
102
+ carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
103
+ (actual_length, _, out), _ = scan(loop, carry, x[:size])
104
+ return out, actual_length + 1
105
+
106
+
107
+ class split:
108
+ """
109
+ Split a key into `num` keys.
110
+
111
+ Parameters
112
+ ----------
113
+ key
114
+ The key to split.
115
+ num
116
+ The number of keys to split into.
117
+ """
118
+
119
+ def __init__(self, key: Key[Array, ''], num: int = 2):
120
+ self._keys = random.split(key, num)
121
+
122
+ def __len__(self):
123
+ return self._keys.size
124
+
125
+ def pop(self, shape: int | tuple[int, ...] | None = None) -> Key[Array, '*']:
126
+ """
127
+ Pop one or more keys from the list.
128
+
129
+ Parameters
130
+ ----------
131
+ shape
132
+ The shape of the keys to pop. If `None`, a single key is popped.
133
+ If an integer, that many keys are popped. If a tuple, the keys are
134
+ reshaped to that shape.
135
+
136
+ Returns
137
+ -------
138
+ The popped keys as a jax array with the requested shape.
139
+
140
+ Raises
141
+ ------
142
+ IndexError
143
+ If `shape` is larger than the number of keys left in the list.
144
+
145
+ Notes
146
+ -----
147
+ The keys are popped from the beginning of the list, so for example
148
+ ``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
149
+ """
150
+ if shape is None:
151
+ shape = ()
152
+ elif not isinstance(shape, tuple):
153
+ shape = (shape,)
154
+ size_to_pop = math.prod(shape)
155
+ if size_to_pop > self._keys.size:
156
+ msg = f'Cannot pop {size_to_pop} keys from {self._keys.size} keys'
157
+ raise IndexError(msg)
158
+ popped_keys = self._keys[:size_to_pop]
159
+ self._keys = self._keys[size_to_pop:]
160
+ return popped_keys.reshape(shape)
161
+
162
+
163
+ def truncated_normal_onesided(
164
+ key: Key[Array, ''],
165
+ shape: Sequence[int],
166
+ upper: Bool[Array, '*'],
167
+ bound: Float32[Array, '*'],
168
+ ) -> Float32[Array, '*']:
169
+ """
170
+ Sample from a one-sided truncated standard normal distribution.
171
+
172
+ Parameters
173
+ ----------
174
+ key
175
+ JAX random key.
176
+ shape
177
+ Shape of output array, broadcasted with other inputs.
178
+ upper
179
+ True for (-∞, bound], False for [bound, ∞).
180
+ bound
181
+ The truncation boundary.
182
+
183
+ Returns
184
+ -------
185
+ Array of samples from the truncated normal distribution.
186
+ """
187
+ # Pseudocode:
188
+ # | if upper:
189
+ # | if bound < 0:
190
+ # | ndtri(uniform(0, ndtr(bound))) =
191
+ # | ndtri(ndtr(bound) * u)
192
+ # | if bound > 0:
193
+ # | -ndtri(uniform(ndtr(-bound), 1)) =
194
+ # | -ndtri(ndtr(-bound) + ndtr(bound) * (1 - u))
195
+ # | if not upper:
196
+ # | if bound < 0:
197
+ # | ndtri(uniform(ndtr(bound), 1)) =
198
+ # | ndtri(ndtr(bound) + ndtr(-bound) * (1 - u))
199
+ # | if bound > 0:
200
+ # | -ndtri(uniform(0, ndtr(-bound))) =
201
+ # | -ndtri(ndtr(-bound) * u)
202
+ shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
203
+ bound_pos = bound > 0
204
+ ndtr_bound = ndtr(bound)
205
+ ndtr_neg_bound = ndtr(-bound)
206
+ scale = jnp.where(upper, ndtr_bound, ndtr_neg_bound)
207
+ shift = jnp.where(upper, ndtr_neg_bound, ndtr_bound)
208
+ u = random.uniform(key, shape)
209
+ left_u = scale * (1 - u) # ~ uniform in (0, ndtr(±bound)]
210
+ right_u = shift + scale * u # ~ uniform in [ndtr(∓bound), 1)
211
+ truncated_u = jnp.where(upper ^ bound_pos, left_u, right_u)
212
+ truncated_norm = ndtri(truncated_u)
213
+ return jnp.where(bound_pos, -truncated_norm, truncated_norm)