bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +4 -2
- bartz/{BART.py → _interface.py} +256 -132
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +269 -314
- bartz/grove.py +124 -68
- bartz/jaxext/__init__.py +101 -27
- bartz/jaxext/_autobatch.py +257 -51
- bartz/jaxext/scipy/__init__.py +1 -1
- bartz/jaxext/scipy/special.py +3 -4
- bartz/jaxext/scipy/stats.py +1 -1
- bartz/mcmcloop.py +399 -208
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +1 -1
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.7.0.dist-info → bartz-0.8.0.dist-info}/METADATA +17 -11
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.7.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/mcmcstep.py +0 -2616
- bartz-0.7.0.dist-info/RECORD +0 -17
bartz/jaxext/_autobatch.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/jaxext/_autobatch.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2025,
|
|
3
|
+
# Copyright (c) 2025-2026, The Bartz Contributors
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -26,16 +26,23 @@
|
|
|
26
26
|
|
|
27
27
|
import math
|
|
28
28
|
from collections.abc import Callable
|
|
29
|
-
from functools import wraps
|
|
29
|
+
from functools import partial, wraps
|
|
30
30
|
from warnings import warn
|
|
31
31
|
|
|
32
|
-
from jax import
|
|
32
|
+
from jax.typing import DTypeLike
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
from numpy.lib.array_utils import normalize_axis_index # numpy 2
|
|
36
|
+
except ImportError:
|
|
37
|
+
from numpy.core.numeric import normalize_axis_index # numpy 1
|
|
38
|
+
|
|
39
|
+
from jax import ShapeDtypeStruct, eval_shape, jit
|
|
33
40
|
from jax import numpy as jnp
|
|
34
41
|
from jax.lax import scan
|
|
35
42
|
from jax.tree import flatten as tree_flatten
|
|
36
43
|
from jax.tree import map as tree_map
|
|
37
44
|
from jax.tree import reduce as tree_reduce
|
|
38
|
-
from jaxtyping import PyTree
|
|
45
|
+
from jaxtyping import Array, PyTree, Shaped
|
|
39
46
|
|
|
40
47
|
|
|
41
48
|
def expand_axes(axes, tree):
|
|
@@ -47,14 +54,43 @@ def expand_axes(axes, tree):
|
|
|
47
54
|
return tree_map(expand_axis, axes, tree, is_leaf=lambda x: x is None)
|
|
48
55
|
|
|
49
56
|
|
|
57
|
+
def normalize_axes(
|
|
58
|
+
axes: PyTree[int | None, ' T'], tree: PyTree[Array, ' T']
|
|
59
|
+
) -> PyTree[int | None, ' T']:
|
|
60
|
+
"""Normalize axes to be non-negative and valid for the corresponding arrays in the tree."""
|
|
61
|
+
|
|
62
|
+
def normalize_axis(axis: int | None, x: Array) -> int | None:
|
|
63
|
+
if axis is None:
|
|
64
|
+
return None
|
|
65
|
+
else:
|
|
66
|
+
return normalize_axis_index(axis, len(x.shape))
|
|
67
|
+
|
|
68
|
+
return tree_map(normalize_axis, axes, tree, is_leaf=lambda x: x is None)
|
|
69
|
+
|
|
70
|
+
|
|
50
71
|
def check_no_nones(axes, tree):
|
|
51
72
|
def check_not_none(_, axis):
|
|
52
73
|
assert axis is not None
|
|
53
74
|
|
|
54
|
-
tree_map(check_not_none, tree, axes)
|
|
75
|
+
tree_map(check_not_none, tree, axes, is_leaf=lambda x: x is None)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def remove_axis(
|
|
79
|
+
x: PyTree[ShapeDtypeStruct, ' T'], axis: PyTree[int, ' T'], ufunc: jnp.ufunc
|
|
80
|
+
) -> PyTree[ShapeDtypeStruct, ' T']:
|
|
81
|
+
"""Remove an axis from dummy arrays and change the type to reduction type."""
|
|
82
|
+
|
|
83
|
+
def remove_axis(x: ShapeDtypeStruct, axis: int) -> ShapeDtypeStruct:
|
|
84
|
+
new_shape = x.shape[:axis] + x.shape[axis + 1 :]
|
|
85
|
+
new_dtype = reduction_dtype(ufunc, x.dtype)
|
|
86
|
+
return ShapeDtypeStruct(new_shape, new_dtype)
|
|
87
|
+
|
|
88
|
+
return tree_map(remove_axis, x, axis)
|
|
55
89
|
|
|
56
90
|
|
|
57
91
|
def extract_size(axes, tree):
|
|
92
|
+
"""Get the size of each array in tree at the axis in axes, check they are equal and return it."""
|
|
93
|
+
|
|
58
94
|
def get_size(x, axis):
|
|
59
95
|
if axis is None:
|
|
60
96
|
return None
|
|
@@ -90,6 +126,7 @@ def next_divisor_large(dividend, min_divisor):
|
|
|
90
126
|
|
|
91
127
|
|
|
92
128
|
def next_divisor(dividend, min_divisor):
|
|
129
|
+
"""Return divisor >= min_divisor such that divided % divisor == 0."""
|
|
93
130
|
if dividend == 0:
|
|
94
131
|
return min_divisor
|
|
95
132
|
if min_divisor * min_divisor <= dividend:
|
|
@@ -131,20 +168,73 @@ def move_axes_in(axes, tree):
|
|
|
131
168
|
return tree_map(move_axis_in, tree, axes)
|
|
132
169
|
|
|
133
170
|
|
|
134
|
-
def batch(tree, nbatches):
|
|
171
|
+
def batch(tree: PyTree[Array, ' T'], nbatches: int) -> PyTree[Array, ' T']:
|
|
172
|
+
"""Split the first axis into two axes, the first of size `nbatches`."""
|
|
173
|
+
|
|
135
174
|
def batch(x):
|
|
136
|
-
return x.reshape(
|
|
175
|
+
return x.reshape(nbatches, x.shape[0] // nbatches, *x.shape[1:])
|
|
137
176
|
|
|
138
177
|
return tree_map(batch, tree)
|
|
139
178
|
|
|
140
179
|
|
|
141
|
-
def unbatch(tree):
|
|
180
|
+
def unbatch(tree: PyTree[Array, ' T']) -> PyTree[Array, ' T']:
|
|
181
|
+
"""Merge the first two axes into a single axis."""
|
|
182
|
+
|
|
142
183
|
def unbatch(x):
|
|
143
|
-
return x.reshape(
|
|
184
|
+
return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
|
|
144
185
|
|
|
145
186
|
return tree_map(unbatch, tree)
|
|
146
187
|
|
|
147
188
|
|
|
189
|
+
def reduce(
|
|
190
|
+
ufunc: jnp.ufunc,
|
|
191
|
+
x: PyTree[Array, ' T'],
|
|
192
|
+
axes: PyTree[int, ' T'],
|
|
193
|
+
initial: PyTree[Array, ' T'] | None,
|
|
194
|
+
) -> PyTree[Array, ' T']:
|
|
195
|
+
"""Reduce each array in `x` along the axes in `axes` starting from `initial` using `ufunc.reduce`."""
|
|
196
|
+
if initial is None:
|
|
197
|
+
|
|
198
|
+
def reduce(x: Array, axis: int) -> Array:
|
|
199
|
+
return ufunc.reduce(x, axis=axis)
|
|
200
|
+
|
|
201
|
+
return tree_map(reduce, x, axes)
|
|
202
|
+
|
|
203
|
+
else:
|
|
204
|
+
|
|
205
|
+
def reduce(x: Array, initial: Array, axis: int) -> Array:
|
|
206
|
+
reduced = ufunc.reduce(x, axis=axis)
|
|
207
|
+
return ufunc(initial, reduced)
|
|
208
|
+
|
|
209
|
+
return tree_map(reduce, x, initial, axes)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def identity(
|
|
213
|
+
ufunc: jnp.ufunc, x: PyTree[ShapeDtypeStruct, ' T']
|
|
214
|
+
) -> PyTree[Array, ' T']:
|
|
215
|
+
"""Get the identity element for `ufunc` and each array in `x`."""
|
|
216
|
+
|
|
217
|
+
def identity(x: ShapeDtypeStruct) -> Array:
|
|
218
|
+
identity = identity_for(ufunc, x.dtype)
|
|
219
|
+
return jnp.broadcast_to(identity, x.shape)
|
|
220
|
+
|
|
221
|
+
return tree_map(identity, x)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def reduction_dtype(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> DTypeLike:
|
|
225
|
+
"""Return the output dtype for a reduction with `ufunc` on inputs of type `dtype`."""
|
|
226
|
+
return ufunc.reduce(jnp.empty(1, input_dtype)).dtype
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def identity_for(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> Shaped[Array, '']:
|
|
230
|
+
"""Return the identity for ufunc as an array scalar with the right dtype."""
|
|
231
|
+
# get output type from input type, e.g., int8 is accumulated to int32
|
|
232
|
+
dtype = reduction_dtype(ufunc, input_dtype)
|
|
233
|
+
|
|
234
|
+
# return as explicitly typed array
|
|
235
|
+
return jnp.array(ufunc.identity, dtype)
|
|
236
|
+
|
|
237
|
+
|
|
148
238
|
def check_same(tree1, tree2):
|
|
149
239
|
def check_same(x1, x2):
|
|
150
240
|
assert x1.shape == x2.shape
|
|
@@ -153,12 +243,20 @@ def check_same(tree1, tree2):
|
|
|
153
243
|
tree_map(check_same, tree1, tree2)
|
|
154
244
|
|
|
155
245
|
|
|
246
|
+
class NotDefined:
|
|
247
|
+
pass
|
|
248
|
+
|
|
249
|
+
|
|
156
250
|
def autobatch(
|
|
157
251
|
func: Callable,
|
|
158
252
|
max_io_nbytes: int,
|
|
159
253
|
in_axes: PyTree[int | None] = 0,
|
|
160
254
|
out_axes: PyTree[int] = 0,
|
|
255
|
+
*,
|
|
161
256
|
return_nbatches: bool = False,
|
|
257
|
+
reduce_ufunc: jnp.ufunc | None = None,
|
|
258
|
+
warn_on_overflow: bool = True,
|
|
259
|
+
result_shape_dtype: PyTree[ShapeDtypeStruct] = NotDefined,
|
|
162
260
|
) -> Callable:
|
|
163
261
|
"""
|
|
164
262
|
Batch a function such that each batch is smaller than a threshold.
|
|
@@ -179,60 +277,168 @@ def autobatch(
|
|
|
179
277
|
The same for outputs (but non-batching is not allowed).
|
|
180
278
|
return_nbatches
|
|
181
279
|
If True, the number of batches is returned as a second output.
|
|
280
|
+
reduce_ufunc
|
|
281
|
+
Function used to reduce the output along the batched axis (e.g.,
|
|
282
|
+
`jax.numpy.add`).
|
|
283
|
+
warn_on_overflow
|
|
284
|
+
If True, a warning is raised if the memory limit could not be
|
|
285
|
+
respected.
|
|
286
|
+
result_shape_dtype
|
|
287
|
+
A pytree of dummy arrays matching the expected output. If not provided,
|
|
288
|
+
the function is traced an additional time to determine the output
|
|
289
|
+
structure.
|
|
182
290
|
|
|
183
291
|
Returns
|
|
184
292
|
-------
|
|
185
293
|
A function with the same signature as `func`, save for the return value if `return_nbatches`.
|
|
294
|
+
|
|
295
|
+
Notes
|
|
296
|
+
-----
|
|
297
|
+
Unless `return_nbatches` or `reduce_ufunc` are set, `autobatch` at given
|
|
298
|
+
arguments is idempotent. Furthermore, `autobatch` can be applied multiple
|
|
299
|
+
times over multiple axes with the same `max_io_nbytes` limit to work on
|
|
300
|
+
multiple axes; in this case it won't unnecessarily loop over additional axes
|
|
301
|
+
if one or more outer `autobatch` are already sufficient.
|
|
302
|
+
|
|
303
|
+
To handle memory used in intermediate values: assuming all intermediate
|
|
304
|
+
values have size that scales linearly with the axis batched over, say the
|
|
305
|
+
batched input/output total size is ``batched_size * core_io_size``, and the
|
|
306
|
+
intermediate values have size ``batched_size * core_int_size``, then to take
|
|
307
|
+
them into account divide `max_io_nbytes` by ``(1 + core_int_size /
|
|
308
|
+
core_io_size)``.
|
|
186
309
|
"""
|
|
187
|
-
initial_in_axes = in_axes
|
|
188
|
-
initial_out_axes = out_axes
|
|
189
310
|
|
|
190
311
|
@jit
|
|
191
312
|
@wraps(func)
|
|
192
|
-
def
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
313
|
+
def autobatch_wrapper(*args):
|
|
314
|
+
return batched_func(
|
|
315
|
+
func,
|
|
316
|
+
max_io_nbytes,
|
|
317
|
+
in_axes,
|
|
318
|
+
out_axes,
|
|
319
|
+
return_nbatches,
|
|
320
|
+
reduce_ufunc,
|
|
321
|
+
warn_on_overflow,
|
|
322
|
+
result_shape_dtype,
|
|
323
|
+
args,
|
|
324
|
+
)
|
|
200
325
|
|
|
201
|
-
|
|
326
|
+
return autobatch_wrapper
|
|
202
327
|
|
|
203
|
-
total_nbytes = sum_nbytes((args, example_result))
|
|
204
|
-
min_nbatches = total_nbytes // max_io_nbytes + bool(
|
|
205
|
-
total_nbytes % max_io_nbytes
|
|
206
|
-
)
|
|
207
|
-
min_nbatches = max(1, min_nbatches)
|
|
208
|
-
nbatches = next_divisor(size, min_nbatches)
|
|
209
|
-
assert 1 <= nbatches <= max(1, size)
|
|
210
|
-
assert size % nbatches == 0
|
|
211
|
-
assert total_nbytes % nbatches == 0
|
|
212
|
-
|
|
213
|
-
batch_nbytes = total_nbytes // nbatches
|
|
214
|
-
if batch_nbytes > max_io_nbytes:
|
|
215
|
-
assert size == nbatches
|
|
216
|
-
msg = f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
|
|
217
|
-
warn(msg)
|
|
218
|
-
|
|
219
|
-
def loop(_, args):
|
|
220
|
-
args = move_axes_in(in_axes, args)
|
|
221
|
-
args = push_nonbatched(in_axes, args, nonbatched_args)
|
|
222
|
-
result = func(*args)
|
|
223
|
-
result = move_axes_out(out_axes, result)
|
|
224
|
-
return None, result
|
|
225
328
|
|
|
329
|
+
def batched_func(
|
|
330
|
+
func: Callable,
|
|
331
|
+
max_io_nbytes: int,
|
|
332
|
+
in_axes: PyTree[int | None],
|
|
333
|
+
out_axes: PyTree[int],
|
|
334
|
+
return_nbatches: bool,
|
|
335
|
+
reduce_ufunc: jnp.ufunc | None,
|
|
336
|
+
warn_on_overflow: bool,
|
|
337
|
+
result_shape_dtype: PyTree[ShapeDtypeStruct] | NotDefined,
|
|
338
|
+
args: tuple[PyTree[Array], ...],
|
|
339
|
+
) -> PyTree[Array]:
|
|
340
|
+
"""Implement the wrapper used in `autobatch`."""
|
|
341
|
+
# determine the output structure of the function
|
|
342
|
+
if result_shape_dtype is NotDefined:
|
|
343
|
+
example_result = eval_shape(func, *args)
|
|
344
|
+
else:
|
|
345
|
+
example_result = result_shape_dtype
|
|
346
|
+
|
|
347
|
+
# expand the axes pytrees if they are prefixes
|
|
348
|
+
in_axes = expand_axes(in_axes, args)
|
|
349
|
+
out_axes = expand_axes(out_axes, example_result)
|
|
350
|
+
check_no_nones(out_axes, example_result)
|
|
351
|
+
|
|
352
|
+
# check the axes are valid
|
|
353
|
+
in_axes = normalize_axes(in_axes, args)
|
|
354
|
+
out_axes = normalize_axes(out_axes, example_result)
|
|
355
|
+
|
|
356
|
+
# get the size of the batched axis
|
|
357
|
+
size = extract_size((in_axes, out_axes), (args, example_result))
|
|
358
|
+
|
|
359
|
+
# split arguments in batched and not batched
|
|
360
|
+
original_args = args
|
|
361
|
+
args, nonbatched_args = pull_nonbatched(in_axes, args)
|
|
362
|
+
|
|
363
|
+
# determine the number of batches to respect the memory limit
|
|
364
|
+
total_nbytes = sum_nbytes((args, example_result))
|
|
365
|
+
min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
|
|
366
|
+
min_nbatches = max(1, min_nbatches)
|
|
367
|
+
nbatches = next_divisor(size, min_nbatches)
|
|
368
|
+
assert 1 <= nbatches <= max(1, size)
|
|
369
|
+
assert size % nbatches == 0
|
|
370
|
+
assert total_nbytes % nbatches == 0
|
|
371
|
+
|
|
372
|
+
# warn if the memory limit could not be respected
|
|
373
|
+
batch_nbytes = total_nbytes // nbatches
|
|
374
|
+
if batch_nbytes > max_io_nbytes and warn_on_overflow:
|
|
375
|
+
assert size == nbatches
|
|
376
|
+
msg = f'batch_nbytes = {batch_nbytes:_} > max_io_nbytes = {max_io_nbytes:_}'
|
|
377
|
+
warn(msg)
|
|
378
|
+
|
|
379
|
+
# squeeze out the output dims that will be reduced
|
|
380
|
+
if reduce_ufunc is not None:
|
|
381
|
+
example_result = remove_axis(example_result, out_axes, reduce_ufunc)
|
|
382
|
+
|
|
383
|
+
if nbatches > 1:
|
|
384
|
+
# prepare arguments for looping
|
|
226
385
|
args = move_axes_out(in_axes, args)
|
|
227
386
|
args = batch(args, nbatches)
|
|
228
|
-
_, result = scan(loop, None, args)
|
|
229
|
-
result = unbatch(result)
|
|
230
|
-
result = move_axes_in(out_axes, result)
|
|
231
|
-
|
|
232
|
-
check_same(example_result, result)
|
|
233
387
|
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
388
|
+
# prepare carry for reduction
|
|
389
|
+
if reduce_ufunc is None:
|
|
390
|
+
initial = None
|
|
391
|
+
else:
|
|
392
|
+
initial = identity(reduce_ufunc, example_result)
|
|
393
|
+
|
|
394
|
+
# loop and invoke the function in batches
|
|
395
|
+
loop = partial(
|
|
396
|
+
batching_loop,
|
|
397
|
+
func=func,
|
|
398
|
+
nonbatched_args=nonbatched_args,
|
|
399
|
+
in_axes=in_axes,
|
|
400
|
+
out_axes=out_axes,
|
|
401
|
+
reduce_ufunc=reduce_ufunc,
|
|
402
|
+
)
|
|
403
|
+
reduced_result, result = scan(loop, initial, args)
|
|
237
404
|
|
|
238
|
-
|
|
405
|
+
# remove auxiliary batching axis and reverse transposition
|
|
406
|
+
if reduce_ufunc is None:
|
|
407
|
+
assert reduced_result is None
|
|
408
|
+
result = unbatch(result)
|
|
409
|
+
result = move_axes_in(out_axes, result)
|
|
410
|
+
else:
|
|
411
|
+
assert result is None
|
|
412
|
+
result = reduced_result
|
|
413
|
+
|
|
414
|
+
# trivial case: no batching needed
|
|
415
|
+
else:
|
|
416
|
+
result = func(*original_args)
|
|
417
|
+
if reduce_ufunc is not None:
|
|
418
|
+
result = reduce(reduce_ufunc, result, out_axes, None)
|
|
419
|
+
|
|
420
|
+
check_same(example_result, result)
|
|
421
|
+
|
|
422
|
+
if return_nbatches:
|
|
423
|
+
return result, nbatches
|
|
424
|
+
return result
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def batching_loop(
|
|
428
|
+
initial, args, *, func, nonbatched_args, in_axes, out_axes, reduce_ufunc
|
|
429
|
+
):
|
|
430
|
+
"""Implement the batching loop in `autobatch`."""
|
|
431
|
+
# evaluate the function
|
|
432
|
+
args = move_axes_in(in_axes, args)
|
|
433
|
+
args = push_nonbatched(in_axes, args, nonbatched_args)
|
|
434
|
+
result = func(*args)
|
|
435
|
+
|
|
436
|
+
# unreduced case: transpose for concatenation and return
|
|
437
|
+
if reduce_ufunc is None:
|
|
438
|
+
result = move_axes_out(out_axes, result)
|
|
439
|
+
return None, result
|
|
440
|
+
|
|
441
|
+
# reduced case: reduce starting from initial
|
|
442
|
+
else:
|
|
443
|
+
reduced_result = reduce(reduce_ufunc, result, out_axes, initial)
|
|
444
|
+
return reduced_result, None
|
bartz/jaxext/scipy/__init__.py
CHANGED
bartz/jaxext/scipy/special.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/jaxext/scipy/special.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2025,
|
|
3
|
+
# Copyright (c) 2025, The Bartz Contributors
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -26,7 +26,7 @@
|
|
|
26
26
|
|
|
27
27
|
from functools import wraps
|
|
28
28
|
|
|
29
|
-
from jax import ShapeDtypeStruct, pure_callback
|
|
29
|
+
from jax import ShapeDtypeStruct, jit, pure_callback
|
|
30
30
|
from jax import numpy as jnp
|
|
31
31
|
from scipy.special import gammainccinv as scipy_gammainccinv
|
|
32
32
|
|
|
@@ -45,10 +45,9 @@ def _castto(func, dtype):
|
|
|
45
45
|
return newfunc
|
|
46
46
|
|
|
47
47
|
|
|
48
|
+
@jit
|
|
48
49
|
def gammainccinv(a, y):
|
|
49
50
|
"""Survival function inverse of the Gamma(a, 1) distribution."""
|
|
50
|
-
a = jnp.asarray(a)
|
|
51
|
-
y = jnp.asarray(y)
|
|
52
51
|
shape = jnp.broadcast_shapes(a.shape, y.shape)
|
|
53
52
|
dtype = _float_type(a.dtype, y.dtype)
|
|
54
53
|
dummy = ShapeDtypeStruct(shape, dtype)
|