bartz 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/prepcovars.py CHANGED
@@ -24,40 +24,82 @@
24
24
 
25
25
  """Functions to preprocess data."""
26
26
 
27
- import functools
27
+ from functools import partial
28
28
 
29
- import jax
29
+ from jax import jit, vmap
30
30
  from jax import numpy as jnp
31
+ from jaxtyping import Array, Float, Integer, Real, UInt
31
32
 
32
- from . import jaxext
33
+ from bartz.jaxext import autobatch, minimal_unsigned_dtype, unique
33
34
 
34
35
 
35
- @functools.partial(jax.jit, static_argnums=(1,))
36
- def quantilized_splits_from_matrix(X, max_bins):
36
+ def parse_xinfo(
37
+ xinfo: Float[Array, 'p m'],
38
+ ) -> tuple[Float[Array, 'p m'], UInt[Array, ' p']]:
39
+ """Parse pre-defined splits in the format of the R package BART.
40
+
41
+ Parameters
42
+ ----------
43
+ xinfo
44
+ A matrix with the cutpoints to use to bin each predictor. Each row shall
45
+ contain a sorted list of cutpoints for a predictor. If there are fewer
46
+ cutpoints than the number of columns in the matrix, fill the remaining
47
+ cells with NaN.
48
+
49
+ `xinfo` shall be a matrix even if `x_train` is a dataframe.
50
+
51
+ Returns
52
+ -------
53
+ splits : Float[Array, 'p m']
54
+ `xinfo` modified by replacing nan with a large value.
55
+ max_split : UInt[Array, 'p']
56
+ The number of non-nan elements in each row of `xinfo`.
57
+ """
58
+ is_not_nan = ~jnp.isnan(xinfo)
59
+ max_split = jnp.sum(is_not_nan, axis=1)
60
+ max_split = max_split.astype(minimal_unsigned_dtype(xinfo.shape[1]))
61
+ huge = _huge_value(xinfo)
62
+ splits = jnp.where(is_not_nan, xinfo, huge)
63
+ return splits, max_split
64
+
65
+
66
+ @partial(jit, static_argnums=(1,))
67
+ def quantilized_splits_from_matrix(
68
+ X: Real[Array, 'p n'], max_bins: int
69
+ ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
37
70
  """
38
71
  Determine bins that make the distribution of each predictor uniform.
39
72
 
40
73
  Parameters
41
74
  ----------
42
- X : array (p, n)
75
+ X
43
76
  A matrix with `p` predictors and `n` observations.
44
- max_bins : int
77
+ max_bins
45
78
  The maximum number of bins to produce.
46
79
 
47
80
  Returns
48
81
  -------
49
- splits : array (p, m)
82
+ splits : Real[Array, 'p m']
50
83
  A matrix containing, for each predictor, the boundaries between bins.
51
84
  `m` is ``min(max_bins, n) - 1``, which is an upper bound on the number
52
85
  of splits. Each predictor may have a different number of splits; unused
53
86
  values at the end of each row are filled with the maximum value
54
87
  representable in the type of `X`.
55
- max_split : array (p,)
88
+ max_split : UInt[Array, ' p']
56
89
  The number of actually used values in each row of `splits`.
90
+
91
+ Raises
92
+ ------
93
+ ValueError
94
+ If `X` has no columns or if `max_bins` is less than 1.
57
95
  """
58
96
  out_length = min(max_bins, X.shape[1]) - 1
59
97
 
60
- @functools.partial(jaxext.autobatch, max_io_nbytes=2**29)
98
+ if out_length < 0:
99
+ msg = f'{X.shape[1]=} and {max_bins=}, they should be both at least 1.'
100
+ raise ValueError(msg)
101
+
102
+ @partial(autobatch, max_io_nbytes=2**29)
61
103
  def quantilize(X):
62
104
  # wrap this function because autobatch needs traceable args
63
105
  return _quantilized_splits_from_matrix(X, out_length)
@@ -65,23 +107,28 @@ def quantilized_splits_from_matrix(X, max_bins):
65
107
  return quantilize(X)
66
108
 
67
109
 
68
- @functools.partial(jax.vmap, in_axes=(0, None))
69
- def _quantilized_splits_from_matrix(x, out_length):
70
- huge = jaxext.huge_value(x)
71
- u, actual_length = jaxext.unique(x, size=x.size, fill_value=huge)
72
- actual_length -= 1
110
+ @partial(vmap, in_axes=(0, None))
111
+ def _quantilized_splits_from_matrix(
112
+ x: Real[Array, 'p n'], out_length: int
113
+ ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
114
+ # find the sorted unique values in x
115
+ huge = _huge_value(x)
116
+ u, actual_length = unique(x, size=x.size, fill_value=huge)
117
+
118
+ # compute the midpoints between each unique value
73
119
  if jnp.issubdtype(x.dtype, jnp.integer):
74
- midpoints = u[:-1] + jaxext.ensure_unsigned(u[1:] - u[:-1]) // 2
75
- indices = jnp.arange(
76
- midpoints.size, dtype=jaxext.minimal_unsigned_dtype(midpoints.size - 1)
77
- )
78
- midpoints = jnp.where(indices < actual_length, midpoints, huge)
120
+ midpoints = u[:-1] + _ensure_unsigned(u[1:] - u[:-1]) // 2
79
121
  else:
80
- midpoints = (u[1:] + u[:-1]) / 2
122
+ midpoints = u[:-1] + (u[1:] - u[:-1]) / 2
123
+ # use x_i + (x_{i+1} - x_i) / 2 rather than (x_i + x_{i+1}) / 2 to
124
+ # avoid overflow
125
+ actual_length -= 1
126
+ if midpoints.size:
127
+ midpoints = midpoints.at[actual_length].set(huge)
128
+
129
+ # take a subset of the midpoints if there are more than the requested maximum
81
130
  indices = jnp.linspace(-1, actual_length, out_length + 2)[1:-1]
82
- indices = jnp.around(indices).astype(
83
- jaxext.minimal_unsigned_dtype(midpoints.size - 1)
84
- )
131
+ indices = jnp.around(indices).astype(minimal_unsigned_dtype(midpoints.size - 1))
85
132
  # indices calculation with float rather than int to avoid potential
86
133
  # overflow with int32, and to round to nearest instead of rounding down
87
134
  decimated_midpoints = midpoints[indices]
@@ -90,41 +137,92 @@ def _quantilized_splits_from_matrix(x, out_length):
90
137
  actual_length > out_length, decimated_midpoints, truncated_midpoints
91
138
  )
92
139
  max_split = jnp.minimum(actual_length, out_length)
93
- max_split = max_split.astype(jaxext.minimal_unsigned_dtype(out_length))
140
+ max_split = max_split.astype(minimal_unsigned_dtype(out_length))
94
141
  return splits, max_split
95
142
 
96
143
 
97
- @functools.partial(jax.jit, static_argnums=(1,))
98
- def uniform_splits_from_matrix(X, num_bins):
144
+ def _huge_value(x: Array) -> int | float:
145
+ """
146
+ Return the maximum value that can be stored in `x`.
147
+
148
+ Parameters
149
+ ----------
150
+ x
151
+ A numerical numpy or jax array.
152
+
153
+ Returns
154
+ -------
155
+ The maximum value allowed by `x`'s type (finite for floats).
156
+ """
157
+ if jnp.issubdtype(x.dtype, jnp.integer):
158
+ return jnp.iinfo(x.dtype).max
159
+ else:
160
+ return float(jnp.finfo(x.dtype).max)
161
+
162
+
163
+ def _ensure_unsigned(x: Integer[Array, '*shape']) -> UInt[Array, '*shape']:
164
+ """If x has signed integer type, cast it to the unsigned dtype of the same size."""
165
+ return x.astype(_signed_to_unsigned(x.dtype))
166
+
167
+
168
+ def _signed_to_unsigned(int_dtype: jnp.dtype) -> jnp.dtype:
169
+ """
170
+ Map a signed integer type to its unsigned counterpart.
171
+
172
+ Unsigned types are passed through.
173
+ """
174
+ assert jnp.issubdtype(int_dtype, jnp.integer)
175
+ if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
176
+ return int_dtype
177
+ match int_dtype:
178
+ case jnp.int8:
179
+ return jnp.uint8
180
+ case jnp.int16:
181
+ return jnp.uint16
182
+ case jnp.int32:
183
+ return jnp.uint32
184
+ case jnp.int64:
185
+ return jnp.uint64
186
+ case _:
187
+ msg = f'unexpected integer type {int_dtype}'
188
+ raise TypeError(msg)
189
+
190
+
191
+ @partial(jit, static_argnums=(1,))
192
+ def uniform_splits_from_matrix(
193
+ X: Real[Array, 'p n'], num_bins: int
194
+ ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
99
195
  """
100
196
  Make an evenly spaced binning grid.
101
197
 
102
198
  Parameters
103
199
  ----------
104
- X : array (p, n)
200
+ X
105
201
  A matrix with `p` predictors and `n` observations.
106
- num_bins : int
202
+ num_bins
107
203
  The number of bins to produce.
108
204
 
109
205
  Returns
110
206
  -------
111
- splits : array (p, num_bins - 1)
207
+ splits : Real[Array, 'p m']
112
208
  A matrix containing, for each predictor, the boundaries between bins.
113
209
  The excluded endpoints are the minimum and maximum value in each row of
114
210
  `X`.
115
- max_split : array (p,)
211
+ max_split : UInt[Array, ' p']
116
212
  The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``.
117
213
  """
118
214
  low = jnp.min(X, axis=1)
119
215
  high = jnp.max(X, axis=1)
120
216
  splits = jnp.linspace(low, high, num_bins + 1, axis=1)[:, 1:-1]
121
217
  assert splits.shape == (X.shape[0], num_bins - 1)
122
- max_split = jnp.full(*splits.shape, jaxext.minimal_unsigned_dtype(num_bins - 1))
218
+ max_split = jnp.full(*splits.shape, minimal_unsigned_dtype(num_bins - 1))
123
219
  return splits, max_split
124
220
 
125
221
 
126
- @functools.partial(jax.jit, static_argnames=('method',))
127
- def bin_predictors(X, splits, **kw):
222
+ @partial(jit, static_argnames=('method',))
223
+ def bin_predictors(
224
+ X: Real[Array, 'p n'], splits: Real[Array, 'p m'], **kw
225
+ ) -> UInt[Array, 'p n']:
128
226
  """
129
227
  Bin the predictors according to the given splits.
130
228
 
@@ -132,27 +230,25 @@ def bin_predictors(X, splits, **kw):
132
230
 
133
231
  Parameters
134
232
  ----------
135
- X : array (p, n)
233
+ X
136
234
  A matrix with `p` predictors and `n` observations.
137
- splits : array (p, m)
235
+ splits
138
236
  A matrix containing, for each predictor, the boundaries between bins.
139
237
  `m` is the maximum number of splits; each row may have shorter
140
238
  actual length, marked by padding unused locations at the end of the
141
239
  row with the maximum value allowed by the type.
142
- **kw : dict
240
+ **kw
143
241
  Additional arguments are passed to `jax.numpy.searchsorted`.
144
242
 
145
243
  Returns
146
244
  -------
147
- X_binned : int array (p, n)
148
- A matrix with `p` predictors and `n` observations, where each predictor
149
- has been replaced by the index of the bin it falls into.
245
+ `X` but with each value replaced by the index of the bin it falls into.
150
246
  """
151
247
 
152
- @functools.partial(jaxext.autobatch, max_io_nbytes=2**29)
153
- @jax.vmap
248
+ @partial(autobatch, max_io_nbytes=2**29)
249
+ @vmap
154
250
  def bin_predictors(x, splits):
155
- dtype = jaxext.minimal_unsigned_dtype(splits.size)
251
+ dtype = minimal_unsigned_dtype(splits.size)
156
252
  return jnp.searchsorted(splits, x, **kw).astype(dtype)
157
253
 
158
254
  return bin_predictors(X, splits)
@@ -1,13 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bartz
3
- Version: 0.6.0
3
+ Version: 0.7.0
4
4
  Summary: Super-fast BART (Bayesian Additive Regression Trees) in Python
5
5
  Author: Giacomo Petrillo
6
6
  Author-email: Giacomo Petrillo <info@giacomopetrillo.com>
7
7
  License-Expression: MIT
8
8
  Requires-Dist: equinox>=0.12.2
9
- Requires-Dist: jax>=0.4.35
10
- Requires-Dist: jaxlib>=0.4.35
9
+ Requires-Dist: jax>=0.5.3
11
10
  Requires-Dist: jaxtyping>=0.3.2
12
11
  Requires-Dist: numpy>=1.25.2
13
12
  Requires-Dist: scipy>=1.11.4
@@ -0,0 +1,17 @@
1
+ bartz/.DS_Store,sha256=7191af46d7b8c0d4c03c502f94eb01353bc2e615d75c45b3af0e31ab238034b5,6148
2
+ bartz/BART.py,sha256=6b129e20a258d724e0cba5ffaf377b5c4d62e545c1f3a737602c2a0be5d84b96,29601
3
+ bartz/__init__.py,sha256=98c579136a8755390210ada33e713290749e4d7fca58550c791f77f192f4b4a1,1436
4
+ bartz/_version.py,sha256=d3d868979a2f2fa02b20b248259f8c8ac7273a329ae918da1139ad7602695b67,22
5
+ bartz/debug.py,sha256=5082c2dd07f6d3f8353c57491c857cca4aea913647e24502cc375f9aadc84975,43736
6
+ bartz/grove.py,sha256=f64505623feec7edcec96930909f9a9326f290976d7e9bdb7b9abcb32fe425fe,10559
7
+ bartz/jaxext/__init__.py,sha256=6cd2e7c23ccc4f0399fb3f7989312a500fd8bd3f7f07eb85ce537f2c8873f35a,6705
8
+ bartz/jaxext/_autobatch.py,sha256=b5dbaec52e39b4b32c824fde47eaaf33f496a1574cc336d3a79fa71f4c5e348a,7116
9
+ bartz/jaxext/scipy/__init__.py,sha256=a1f5990a75c1c73908565be4cd5fa1c07278ad3a02b78b21e2f1225b388ab6b5,1227
10
+ bartz/jaxext/scipy/special.py,sha256=f0e777c29a77d46d55ff2f0169b3d8581de24d8580b8d6f1c7c1473910fb62d9,8111
11
+ bartz/jaxext/scipy/stats.py,sha256=703beb9fcfe606a195fe9a3143eceed367c87a31bd75763ecdb6e509c2e87f53,1483
12
+ bartz/mcmcloop.py,sha256=e841c076bf5392b1ad8eab955c70a63d18ac321c249d255c7424b62e52c127c1,22134
13
+ bartz/mcmcstep.py,sha256=3b2a4bdad3bbf836efc7cc6f52adac3ba23dce00a1cca4a99e5a8d79d10c3e68,84092
14
+ bartz/prepcovars.py,sha256=50334621ca6ec7a6e35d21ce5ff8b96d4cbf0e696c9bcd10d86268e301a820a7,8728
15
+ bartz-0.7.0.dist-info/WHEEL,sha256=607c46fee47e440c91332c738096ff0f5e54ca3b0818ee85462dd5172a38e793,79
16
+ bartz-0.7.0.dist-info/METADATA,sha256=758e12296acf815c9bae50601ff502e6dee97c9ac0d817232e45e3dfb39665bc,2815
17
+ bartz-0.7.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: uv 0.7.8
2
+ Generator: uv 0.7.19
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any