warp-lang 1.8.1-py3-none-macosx_10_13_universal2.whl → 1.9.1-py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +282 -103
- warp/__init__.pyi +1904 -114
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +331 -101
- warp/builtins.py +1244 -160
- warp/codegen.py +317 -206
- warp/config.py +1 -1
- warp/context.py +1465 -789
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/examples/interop/example_jax_kernel.py +2 -1
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +264 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +129 -51
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +25 -2
- warp/jax_experimental/ffi.py +22 -1
- warp/jax_experimental/xla_ffi.py +16 -7
- warp/marching_cubes.py +708 -0
- warp/native/array.h +99 -4
- warp/native/builtin.h +86 -9
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +8 -2
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +41 -10
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +2 -2
- warp/native/mat.h +1910 -116
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +4 -2
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +331 -14
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +40 -31
- warp/native/sort.h +2 -0
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +13 -13
- warp/native/spatial.h +366 -17
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +471 -82
- warp/native/vec.h +328 -14
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +377 -216
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +99 -18
- warp/render/render_usd.py +1 -0
- warp/sim/graph_coloring.py +2 -2
- warp/sparse.py +558 -175
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/interop/test_jax.py +608 -28
- warp/tests/sim/test_coloring.py +6 -6
- warp/tests/test_array.py +58 -5
- warp/tests/test_codegen.py +4 -3
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +49 -6
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +15 -1
- warp/tests/test_mat.py +1518 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +140 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +71 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +61 -20
- warp/tests/test_vec.py +179 -34
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/tile/test_tile.py +245 -18
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_shared_memory.py +5 -5
- warp/tests/unittest_suites.py +6 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +571 -267
- warp/utils.py +68 -86
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/builtins.py
CHANGED
@@ -17,10 +17,12 @@ from __future__ import annotations

 import builtins
 import functools
+import math
 from typing import Any, Callable, Mapping, Sequence

 import warp.build
 import warp.context
+import warp.utils
 from warp.codegen import Reference, Var, get_arg_value, strip_reference
 from warp.types import *

@@ -124,6 +126,7 @@ add_builtin(
     value_func=sametypes_create_value_func(Scalar),
     doc="Return -1 if ``x`` < 0, return 1 otherwise.",
     group="Scalar Math",
+    missing_grad=True,
 )

 add_builtin(
@@ -132,6 +135,7 @@ add_builtin(
     value_func=sametypes_create_value_func(Scalar),
     doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
     group="Scalar Math",
+    missing_grad=True,
 )
 add_builtin(
     "nonzero",
@@ -139,6 +143,7 @@ add_builtin(
     value_func=sametypes_create_value_func(Scalar),
     doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
     group="Scalar Math",
+    missing_grad=True,
 )

 add_builtin(
@@ -290,6 +295,7 @@ add_builtin(

     This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
     Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
+    missing_grad=True,
 )

 add_builtin(
@@ -300,6 +306,7 @@ add_builtin(
     doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.

     It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
+    missing_grad=True,
 )

 add_builtin(
@@ -312,6 +319,7 @@ add_builtin(
     In other words, it discards the fractional part of ``x``.
     It is similar to casting ``float(int(a))``, but preserves the negative sign when ``x`` is in the range [-0.0, -1.0).
     Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
+    missing_grad=True,
 )

 add_builtin(
@@ -320,6 +328,7 @@ add_builtin(
     value_func=sametypes_create_value_func(Float),
     group="Scalar Math",
     doc="""Return the largest integer that is less than or equal to ``x``.""",
+    missing_grad=True,
 )

 add_builtin(
@@ -328,6 +337,7 @@ add_builtin(
     value_func=sametypes_create_value_func(Float),
     group="Scalar Math",
     doc="""Return the smallest integer that is greater than or equal to ``x``.""",
+    missing_grad=True,
 )

 add_builtin(
@@ -338,6 +348,7 @@ add_builtin(
     doc="""Retrieve the fractional part of ``x``.

     In other words, it discards the integer part of ``x`` and is equivalent to ``x - trunc(x)``.""",
+    missing_grad=True,
 )

 add_builtin(
@@ -346,6 +357,7 @@ add_builtin(
     value_type=builtins.bool,
     group="Scalar Math",
     doc="""Return ``True`` if ``a`` is a finite number, otherwise return ``False``.""",
+    missing_grad=True,
 )
 add_builtin(
     "isfinite",
@@ -353,6 +365,7 @@ add_builtin(
     value_type=builtins.bool,
     group="Vector Math",
     doc="Return ``True`` if all elements of the vector ``a`` are finite, otherwise return ``False``.",
+    missing_grad=True,
 )
 add_builtin(
     "isfinite",
@@ -360,6 +373,7 @@ add_builtin(
     value_type=builtins.bool,
     group="Vector Math",
     doc="Return ``True`` if all elements of the quaternion ``a`` are finite, otherwise return ``False``.",
+    missing_grad=True,
 )
 add_builtin(
     "isfinite",
@@ -367,6 +381,7 @@ add_builtin(
     value_type=builtins.bool,
     group="Vector Math",
     doc="Return ``True`` if all elements of the matrix ``a`` are finite, otherwise return ``False``.",
+    missing_grad=True,
 )

 add_builtin(
@@ -375,6 +390,7 @@ add_builtin(
     value_type=builtins.bool,
     doc="Return ``True`` if ``a`` is NaN, otherwise return ``False``.",
     group="Scalar Math",
+    missing_grad=True,
 )
 add_builtin(
     "isnan",
@@ -382,6 +398,7 @@ add_builtin(
     value_type=builtins.bool,
     group="Vector Math",
     doc="Return ``True`` if any element of the vector ``a`` is NaN, otherwise return ``False``.",
+    missing_grad=True,
 )
 add_builtin(
     "isnan",
@@ -389,6 +406,7 @@ add_builtin(
     value_type=builtins.bool,
     group="Vector Math",
     doc="Return ``True`` if any element of the quaternion ``a`` is NaN, otherwise return ``False``.",
+    missing_grad=True,
 )
 add_builtin(
     "isnan",
@@ -396,6 +414,7 @@ add_builtin(
     value_type=builtins.bool,
     group="Vector Math",
     doc="Return ``True`` if any element of the matrix ``a`` is NaN, otherwise return ``False``.",
+    missing_grad=True,
 )

 add_builtin(
@@ -404,6 +423,7 @@ add_builtin(
     value_type=builtins.bool,
     group="Scalar Math",
     doc="""Return ``True`` if ``a`` is positive or negative infinity, otherwise return ``False``.""",
+    missing_grad=True,
 )
 add_builtin(
     "isinf",
@@ -411,6 +431,7 @@ add_builtin(
     value_type=builtins.bool,
     group="Vector Math",
     doc="Return ``True`` if any element of the vector ``a`` is positive or negative infinity, otherwise return ``False``.",
+    missing_grad=True,
 )
 add_builtin(
     "isinf",
@@ -418,6 +439,7 @@ add_builtin(
     value_type=builtins.bool,
     group="Vector Math",
     doc="Return ``True`` if any element of the quaternion ``a`` is positive or negative infinity, otherwise return ``False``.",
+    missing_grad=True,
 )
 add_builtin(
     "isinf",
@@ -425,6 +447,7 @@ add_builtin(
     value_type=builtins.bool,
     group="Vector Math",
     doc="Return ``True`` if any element of the matrix ``a`` is positive or negative infinity, otherwise return ``False``.",
+    missing_grad=True,
 )


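The hunks above only tag existing scalar and vector builtins with ``missing_grad=True``, marking that no adjoint is defined for them. As a rough illustration of what that flag means for users (a minimal sketch; the kernel, array contents, and the zero-gradient expectation are assumptions for illustration, not part of the diff), gradients simply do not propagate through such calls:

    import warp as wp

    @wp.kernel
    def round_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        tid = wp.tid()
        # round() is one of the builtins flagged missing_grad, so it contributes no adjoint
        y[tid] = wp.round(x[tid])

    x = wp.array([0.2, 1.7], dtype=float, requires_grad=True)
    y = wp.zeros(2, dtype=float, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(round_kernel, dim=2, inputs=[x], outputs=[y])

    tape.backward(grads={y: wp.ones(2, dtype=float)})
    print(x.grad.numpy())  # expected to stay [0. 0.]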
@@ -1180,6 +1203,7 @@ add_builtin(
     doc="Create an identity matrix with shape=(n,n) with the type given by ``dtype``.",
     group="Vector Math",
     export=False,
+    missing_grad=True,
 )


@@ -1544,6 +1568,7 @@ add_builtin(
     group="Quaternion Math",
     doc="Construct an identity quaternion with zero imaginary part and real part of 1.0",
     export=True,
+    missing_grad=True,
 )

 add_builtin(
@@ -1759,6 +1784,7 @@ add_builtin(
     doc="Construct a spatial transform vector of given dtype.",
     group="Spatial Math",
     export=False,
+    missing_grad=True,
 )


@@ -1793,6 +1819,7 @@ add_builtin(
     group="Transformations",
     doc="Construct an identity transform with zero translation and identity rotation.",
     export=True,
+    missing_grad=True,
 )

 add_builtin(
@@ -2355,6 +2382,7 @@ def tile_load_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
 def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
     a = args["a"]
     shape = extract_tuple(args["shape"], as_constant=True)
+    bounds_check = args["bounds_check"]

     if None in shape:
         raise ValueError("Tile functions require shape to be a compile time constant.")
@@ -2365,17 +2393,23 @@ def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
         offset = (0,) * a.type.ndim

     func_args = (a, *offset)
-    template_args = shape
+    template_args = (return_type.dtype, bounds_check.constant, *shape)

     return (func_args, template_args)


 add_builtin(
     "tile_load",
-    input_types={"a": array(dtype=Any), "shape": Tuple[int, ...], "offset": Tuple[int, ...], "storage": str},
+    input_types={
+        "a": array(dtype=Any),
+        "shape": Tuple[int, ...],
+        "offset": Tuple[int, ...],
+        "storage": str,
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_load_tuple_value_func,
     dispatch_func=tile_load_tuple_dispatch_func,
-    defaults={"offset": None, "storage": "register"},
+    defaults={"offset": None, "storage": "register", "bounds_check": True},
     variadic=False,
     doc="""Loads a tile from a global memory array.

@@ -2386,6 +2420,7 @@ add_builtin(
     :param offset: Offset in the source array to begin reading from (optional)
     :param storage: The storage location for the tile: ``"register"`` for registers
         (default) or ``"shared"`` for shared memory.
+    :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster load times
     :returns: A tile with shape as specified and data type the same as the source array""",
     group="Tile Primitives",
     export=False,
@@ -2394,16 +2429,160 @@ add_builtin(
 # overload for scalar shape
 add_builtin(
     "tile_load",
-    input_types={"a": array(dtype=Any), "shape": int, "offset": int, "storage": str},
+    input_types={"a": array(dtype=Any), "shape": int, "offset": int, "storage": str, "bounds_check": builtins.bool},
     value_func=tile_load_tuple_value_func,
     dispatch_func=tile_load_tuple_dispatch_func,
-    defaults={"offset": None, "storage": "register"},
+    defaults={"offset": None, "storage": "register", "bounds_check": True},
     group="Tile Primitives",
     hidden=True,
     export=False,
 )


+def tile_load_indexed_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    if arg_types is None:
+        return tile(dtype=Any, shape=Tuple[int, ...])
+
+    a = arg_types["a"]
+
+    indices_tile = arg_types["indices"]
+    indices_tile.storage = "shared"  # force to shared
+
+    axis = arg_values["axis"]
+    if axis >= a.ndim:
+        raise ValueError(f"tile_load_indexed() axis argument must be valid axis of array {a}, got {axis}.")
+
+    indices_tile_dim = len(indices_tile.shape)
+    if indices_tile_dim != 1:
+        raise ValueError(
+            f"tile_load_indexed() indices argument must be a 1D tile, got {indices_tile_dim} dimensions instead."
+        )
+
+    shape = extract_tuple(arg_values["shape"], as_constant=True)
+
+    if None in shape:
+        raise ValueError("Tile functions require shape to be a compile time constant.")
+
+    num_indices = indices_tile.shape[0]
+    if num_indices != shape[axis]:
+        raise ValueError(
+            "The number of elements in the 1D indices tile must match the output tile shape along the specified axis."
+        )
+
+    if "offset" in arg_values:
+        offset = extract_tuple(arg_values["offset"])
+    else:
+        offset = (0,) * a.ndim
+
+    if a.ndim != len(shape):
+        raise ValueError(
+            f"tile_load_indexed() array argument must have same number of dimensions as the tile shape, trying to perform an {len(shape)} dimensional load from an array with {a.ndim} dimensions."
+        )
+
+    if a.ndim != len(offset):
+        raise ValueError(
+            f"tile_load_indexed() offset argument must have the same number of dimensions as the array to load from, got {len(offset)} indices for an array with {a.ndim} dimensions"
+        )
+
+    if arg_values["storage"] not in {"shared", "register"}:
+        raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
+
+    return tile(dtype=a.dtype, shape=shape, storage=arg_values["storage"])
+
+
+def tile_load_indexed_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    a = args["a"]
+    indices_tile = args["indices"]
+    axis = args["axis"]
+
+    shape = extract_tuple(args["shape"], as_constant=True)
+
+    if None in shape:
+        raise ValueError("Tile functions require shape to be a compile time constant.")
+
+    if "offset" in args:
+        offset = extract_tuple(args["offset"])
+    else:
+        offset = (0,) * a.type.ndim
+
+    func_args = (a, indices_tile, axis, *offset)
+    template_args = shape
+
+    return (func_args, template_args)
+
+
+add_builtin(
+    "tile_load_indexed",
+    input_types={
+        "a": array(dtype=Any),
+        "indices": tile(dtype=int, shape=Tuple[int]),
+        "shape": Tuple[int, ...],
+        "offset": Tuple[int, ...],
+        "axis": int,
+        "storage": str,
+    },
+    value_func=tile_load_indexed_tuple_value_func,
+    dispatch_func=tile_load_indexed_tuple_dispatch_func,
+    defaults={"offset": None, "axis": 0, "storage": "register"},
+    variadic=False,
+    doc="""Loads a tile from a global memory array, with loads along a specified axis mapped according to a 1D tile of indices.
+
+    :param a: The source array in global memory
+    :param indices: A 1D tile of integer indices mapping to elements in ``a``.
+    :param shape: Shape of the tile to load, must have the same number of dimensions as ``a``, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
+    :param offset: Offset in the source array to begin reading from (optional)
+    :param axis: Axis of ``a`` that indices refer to
+    :param storage: The storage location for the tile: ``"register"`` for registers (default) or ``"shared"`` for shared memory.
+    :returns: A tile with shape as specified and data type the same as the source array
+
+    This example shows how to select and store the even indexed rows from a 2D array:
+
+    .. code-block:: python
+
+        TILE_M = wp.constant(2)
+        TILE_N = wp.constant(2)
+        HALF_M = wp.constant(TILE_M // 2)
+        HALF_N = wp.constant(TILE_N // 2)
+
+        @wp.kernel
+        def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
+            i, j = wp.tid()
+
+            evens = wp.tile_arange(HALF_M, dtype=int, storage="shared") * 2
+
+            t0 = wp.tile_load_indexed(x, indices=evens, shape=(HALF_M, TILE_N), offset=(i*TILE_M, j*TILE_N), axis=0, storage="register")
+            wp.tile_store(y, t0, offset=(i*HALF_M, j*TILE_N))
+
+        M = TILE_M * 2
+        N = TILE_N * 2
+
+        arr = np.arange(M * N).reshape(M, N)
+
+        x = wp.array(arr, dtype=float)
+        y = wp.zeros((M // 2, N), dtype=float)
+
+        wp.launch_tiled(compute, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)
+
+        print(x.numpy())
+        print(y.numpy())
+
+    Prints:
+
+    .. code-block:: text
+
+        [[ 0. 1. 2. 3.]
+         [ 4. 5. 6. 7.]
+         [ 8. 9. 10. 11.]
+         [12. 13. 14. 15.]]
+
+        [[ 0. 1. 2. 3.]
+         [ 8. 9. 10. 11.]]
+    """,
+    group="Tile Primitives",
+    export=False,
+)
+
+
 def tile_store_value_func(arg_types, arg_values):
     # return generic type (for doc builds)
     if arg_types is None:
@@ -2440,6 +2619,7 @@ def tile_store_value_func(arg_types, arg_values):
 def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
     a = args["a"]
     t = args["t"]
+    bounds_check = args["bounds_check"]

     if "offset" in args:
         offset = extract_tuple(args["offset"])
@@ -2447,17 +2627,22 @@ def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
         offset = (0,) * a.type.ndim

     func_args = (a, *offset, t)
-    template_args = []
+    template_args = (a.type.dtype, bounds_check.constant)

     return (func_args, template_args)


 add_builtin(
     "tile_store",
-    input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...]},
+    input_types={
+        "a": array(dtype=Any),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": Tuple[int, ...],
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_store_value_func,
     dispatch_func=tile_store_dispatch_func,
-    defaults={"offset": None},
+    defaults={"offset": None, "bounds_check": True},
     variadic=False,
     skip_replay=True,
     doc="""Store a tile to a global memory array.
@@ -2466,7 +2651,9 @@ add_builtin(

     :param a: The destination array in global memory
     :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array
-    :param offset: Offset in the destination array (optional)""",
+    :param offset: Offset in the destination array (optional)
+    :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster write times
+    """,
     group="Tile Primitives",
     export=False,
 )
@@ -2474,10 +2661,15 @@ add_builtin(
 # overload for scalar offset
 add_builtin(
     "tile_store",
-    input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": int},
+    input_types={
+        "a": array(dtype=Any),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": int,
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_store_value_func,
     dispatch_func=tile_store_dispatch_func,
-    defaults={"offset": None},
+    defaults={"offset": None, "bounds_check": True},
     variadic=False,
     skip_replay=True,
     group="Tile Primitives",
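The new ``bounds_check`` argument shown above defaults to ``True`` for ``tile_load``, ``tile_store``, and ``tile_atomic_add``. A minimal sketch of opting out when the array extents are exact multiples of the tile shape (the sizes, names, and alignment assumption here are illustrative only, not taken from the diff):

    import warp as wp

    TILE_M = wp.constant(32)
    TILE_N = wp.constant(32)

    @wp.kernel
    def copy_tiles(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
        i, j = wp.tid()
        # the 64x64 arrays are exact multiples of the 32x32 tile,
        # so the per-element bounds checks can be skipped
        t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), bounds_check=False)
        wp.tile_store(y, t, offset=(i * TILE_M, j * TILE_N), bounds_check=False)

    x = wp.full((64, 64), 1.0, dtype=float)
    y = wp.zeros((64, 64), dtype=float)
    wp.launch_tiled(copy_tiles, dim=(2, 2), inputs=[x], outputs=[y], block_dim=64)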
@@ -2486,6 +2678,151 @@ add_builtin(
 )


+def tile_store_indexed_value_func(arg_types, arg_values):
+    # return generic type (for doc builds)
+    if arg_types is None:
+        return None
+
+    a = arg_types["a"]
+    t = arg_types["t"]
+    indices_tile = arg_types["indices"]
+    indices_tile.storage = "shared"  # force to shared
+
+    axis = arg_values["axis"]
+    if axis >= a.ndim:
+        raise ValueError(f"tile_store_indexed() axis argument must be valid axis of array {a}, got {axis}.")
+
+    indices_tile_dim = len(indices_tile.shape)
+    if indices_tile_dim != 1:
+        raise ValueError(
+            f"tile_store_indexed() indices argument must be a 1D tile, got {indices_tile_dim} dimensions instead."
+        )
+
+    num_indices = indices_tile.shape[0]
+    if num_indices != t.shape[axis]:
+        raise ValueError(
+            "The number of elements in the 1D indices tile must match the input tile shape along the specified axis."
+        )
+
+    if "offset" in arg_types:
+        c = extract_tuple(arg_values["offset"])
+    else:
+        c = (0,) * a.ndim
+
+    if len(c) != a.ndim:
+        raise ValueError(
+            f"tile_store_indexed() 'a' argument must have {len(c)} dimensions, "
+            f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
+        )
+
+    if len(t.shape) != a.ndim:
+        raise ValueError(
+            f"tile_store_indexed() 'a' argument must have the same number of dimensions as the 't' argument, "
+            f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
+        )
+
+    if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
+        raise TypeError(
+            f"tile_store_indexed() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
+        )
+
+    return None
+
+
+def tile_store_indexed_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    a = args["a"]
+    indices_tile = args["indices"]
+    axis = args["axis"]
+    t = args["t"]
+
+    if "offset" in args:
+        offset = extract_tuple(args["offset"])
+    else:
+        offset = (0,) * a.type.ndim
+
+    func_args = (a, indices_tile, axis, *offset, t)
+    template_args = []
+
+    return (func_args, template_args)
+
+
+add_builtin(
+    "tile_store_indexed",
+    input_types={
+        "a": array(dtype=Any),
+        "indices": tile(dtype=int, shape=Tuple[int]),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": Tuple[int, ...],
+        "axis": int,
+    },
+    value_func=tile_store_indexed_value_func,
+    dispatch_func=tile_store_indexed_dispatch_func,
+    defaults={"offset": None, "axis": 0},
+    variadic=False,
+    skip_replay=True,
+    doc="""Store a tile to a global memory array, with storage along a specified axis mapped according to a 1D tile of indices.
+
+    :param a: The destination array in global memory
+    :param indices: A 1D tile of integer indices mapping to elements in ``a``.
+    :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
+    :param offset: Offset in the destination array (optional)
+    :param axis: Axis of ``a`` that indices refer to
+
+    This example shows how to map tile rows to the even rows of a 2D array:
+
+    .. code-block:: python
+
+        TILE_M = wp.constant(2)
+        TILE_N = wp.constant(2)
+        TWO_M = wp.constant(TILE_M * 2)
+        TWO_N = wp.constant(TILE_N * 2)
+
+        @wp.kernel
+        def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
+            i, j = wp.tid()
+
+            t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i*TILE_M, j*TILE_N), storage="register")
+
+            evens_M = wp.tile_arange(TILE_M, dtype=int, storage="shared") * 2
+
+            wp.tile_store_indexed(y, indices=evens_M, t=t, offset=(i*TWO_M, j*TILE_N), axis=0)
+
+        M = TILE_M * 2
+        N = TILE_N * 2
+
+        arr = np.arange(M * N, dtype=float).reshape(M, N)
+
+        x = wp.array(arr, dtype=float, requires_grad=True, device=device)
+        y = wp.zeros((M * 2, N), dtype=float, requires_grad=True, device=device)
+
+        wp.launch_tiled(compute, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)
+
+        print(x.numpy())
+        print(y.numpy())
+
+    Prints:
+
+    .. code-block:: text
+
+        [[ 0. 1. 2. 3.]
+         [ 4. 5. 6. 7.]
+         [ 8. 9. 10. 11.]
+         [12. 13. 14. 15.]]
+
+        [[ 0. 1. 2. 3.]
+         [ 0. 0. 0. 0.]
+         [ 4. 5. 6. 7.]
+         [ 0. 0. 0. 0.]
+         [ 8. 9. 10. 11.]
+         [ 0. 0. 0. 0.]
+         [12. 13. 14. 15.]
+         [ 0. 0. 0. 0.]]
+    """,
+    group="Tile Primitives",
+    export=False,
+)
+
+
 def tile_atomic_add_value_func(arg_types, arg_values):
     # return generic type (for doc builds)
     if arg_types is None:
@@ -2526,6 +2863,7 @@ def tile_atomic_add_value_func(arg_types, arg_values):
 def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
     a = args["a"]
     t = args["t"]
+    bounds_check = args["bounds_check"]

     if "offset" in args:
         offset = extract_tuple(args["offset"])
@@ -2533,17 +2871,22 @@ def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
         offset = (0,) * a.type.ndim

     func_args = (a, *offset, t)
-    template_args = []
+    template_args = (a.type.dtype, bounds_check.constant)

     return (func_args, template_args)


 add_builtin(
     "tile_atomic_add",
-    input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...]},
+    input_types={
+        "a": array(dtype=Any),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": Tuple[int, ...],
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_atomic_add_value_func,
     dispatch_func=tile_atomic_add_dispatch_func,
-    defaults={"offset": None},
+    defaults={"offset": None, "bounds_check": True},
     variadic=False,
     skip_replay=True,
     doc="""Atomically add a tile onto the array `a`, each element will be updated atomically.
@@ -2551,6 +2894,7 @@ add_builtin(
     :param a: Array in global memory, should have the same ``dtype`` as the input tile
     :param t: Source tile to add to the destination array
     :param offset: Offset in the destination array (optional)
+    :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster write times
     :returns: A tile with the same dimensions and data type as the source tile, holding the original value of the destination elements""",
     group="Tile Primitives",
     export=False,
@@ -2559,10 +2903,15 @@ add_builtin(
 # overload for scalar offset
 add_builtin(
     "tile_atomic_add",
-    input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": int},
+    input_types={
+        "a": array(dtype=Any),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": int,
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_atomic_add_value_func,
     dispatch_func=tile_atomic_add_dispatch_func,
-    defaults={"offset": None},
+    defaults={"offset": None, "bounds_check": True},
     variadic=False,
     skip_replay=True,
     group="Tile Primitives",
@@ -2571,6 +2920,143 @@ add_builtin(
 )


+def tile_atomic_add_indexed_value_func(arg_types, arg_values):
+    # return generic type (for doc builds)
+    if arg_types is None:
+        return tile(dtype=Any, shape=Tuple[int, ...])
+
+    a = arg_types["a"]
+    t = arg_types["t"]
+    indices_tile = arg_types["indices"]
+    indices_tile.storage = "shared"  # force to shared
+
+    axis = arg_values["axis"]
+    if axis >= a.ndim:
+        raise ValueError(f"tile_atomic_add_indexed() axis argument must be valid axis of array {a}, got {axis}.")
+
+    indices_tile_dim = len(indices_tile.shape)
+    if indices_tile_dim != 1:
+        raise ValueError(
+            f"tile_atomic_add_indexed() indices argument must be a 1D tile, got {indices_tile_dim} dimensions instead."
+        )
+
+    num_indices = indices_tile.shape[0]
+    if num_indices != t.shape[axis]:
+        raise ValueError(
+            "The number of elements in the 1D indices tile must match the input tile shape along the specified axis."
+        )
+
+    if "offset" in arg_types:
+        c = extract_tuple(arg_values["offset"])
+    else:
+        c = (0,) * a.ndim
+
+    if len(c) != a.ndim:
+        raise ValueError(
+            f"tile_atomic_add_indexed() 'a' argument must have {len(c)} dimensions, "
+            f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
+        )
+
+    if len(t.shape) != a.ndim:
+        raise ValueError(
+            f"tile_atomic_add_indexed() 'a' argument must have the same number of dimensions as the 't' argument, "
+            f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
+        )
+
+    if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
+        raise TypeError(
+            f"tile_atomic_add_indexed() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
+        )
+
+    return tile(dtype=t.dtype, shape=t.shape, storage=t.storage)
+
+
+def tile_atomic_add_indexed_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    a = args["a"]
+    indices_tile = args["indices"]
+    axis = args["axis"]
+    t = args["t"]
+
+    if "offset" in args:
+        offset = extract_tuple(args["offset"])
+    else:
+        offset = (0,) * a.type.ndim
+
+    func_args = (a, indices_tile, axis, *offset, t)
+    template_args = []
+
+    return (func_args, template_args)
+
+
+add_builtin(
+    "tile_atomic_add_indexed",
+    input_types={
+        "a": array(dtype=Any),
+        "indices": tile(dtype=int, shape=Tuple[int]),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": Tuple[int, ...],
+        "axis": int,
+    },
+    value_func=tile_atomic_add_indexed_value_func,
+    dispatch_func=tile_atomic_add_indexed_dispatch_func,
+    defaults={"offset": None, "axis": 0},
+    variadic=False,
+    skip_replay=True,
+    doc="""Atomically add a tile to a global memory array, with storage along a specified axis mapped according to a 1D tile of indices.
+
+    :param a: The destination array in global memory
+    :param indices: A 1D tile of integer indices mapping to elements in ``a``.
+    :param t: The source tile to extract data from, must have the same data type and number of dimensions as the destination array, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
+    :param offset: Offset in the destination array (optional)
+    :param axis: Axis of ``a`` that indices refer to
+
+    This example shows how to compute a blocked, row-wise reduction:
+
+    .. code-block:: python
+
+        TILE_M = wp.constant(2)
+        TILE_N = wp.constant(2)
+
+        @wp.kernel
+        def tile_atomic_add_indexed(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
+            i, j = wp.tid()
+
+            t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i*TILE_M, j*TILE_N), storage="register")
+
+            zeros = wp.tile_zeros(TILE_M, dtype=int, storage="shared")
+
+            wp.tile_atomic_add_indexed(y, indices=zeros, t=t, offset=(i, j*TILE_N), axis=0)
+
+        M = TILE_M * 2
+        N = TILE_N * 2
+
+        arr = np.arange(M * N, dtype=float).reshape(M, N)
+
+        x = wp.array(arr, dtype=float, requires_grad=True, device=device)
+        y = wp.zeros((2, N), dtype=float, requires_grad=True, device=device)
+
+        wp.launch_tiled(tile_atomic_add_indexed, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)
+
+        print(x.numpy())
+        print(y.numpy())
+
+    Prints:
+
+    .. code-block:: text
+
+        [[ 0. 1. 2. 3.]
+         [ 4. 5. 6. 7.]
+         [ 8. 9. 10. 11.]
+         [12. 13. 14. 15.]]
+
+        [[ 4. 6. 8. 10.]
+         [20. 22. 24. 26.]]
+    """,
+    group="Tile Primitives",
+    export=False,
+)
+
+
 def tile_view_value_func(arg_types, arg_values):
     # return generic type (for doc builds)
     if arg_types is None:
@@ -3525,6 +4011,7 @@ add_builtin(
     """,
     group="Tile Primitives",
     export=False,
+    missing_grad=True,
 )


@@ -3578,6 +4065,7 @@ add_builtin(
     """,
     group="Tile Primitives",
     export=False,
+    missing_grad=True,
 )


@@ -3631,6 +4119,7 @@ add_builtin(
     """,
     group="Tile Primitives",
     export=False,
+    missing_grad=True,
 )


@@ -3683,6 +4172,7 @@ add_builtin(
     """,
     group="Tile Primitives",
     export=False,
+    missing_grad=True,
 )


@@ -3735,6 +4225,7 @@ add_builtin(
     """,
     group="Tile Primitives",
     export=False,
+    missing_grad=True,
 )


@@ -3792,6 +4283,7 @@ add_builtin(
     """,
     group="Tile Primitives",
     export=False,
+    missing_grad=True,
 )


@@ -3855,6 +4347,7 @@ add_builtin(
     """,
     group="Tile Primitives",
     export=False,
+    missing_grad=True,
 )


@@ -3918,6 +4411,7 @@ add_builtin(
     """,
     group="Tile Primitives",
     export=False,
+    missing_grad=True,
 )


@@ -3934,14 +4428,45 @@ def tile_unary_map_value_func(arg_types, arg_values):
     if not is_tile(a):
         raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")

-    return tile(dtype=a.dtype, shape=a.shape)
+    if "op" in arg_values:
+        op = arg_values["op"]
+        try:
+            overload = op.get_overload([a.dtype], {})
+        except KeyError as exc:
+            raise RuntimeError(f"No overload of {op} found for tile element type {type_repr(a.dtype)}") from exc
+
+        # build the right overload on demand
+        if overload.value_func is None:
+            overload.build(None)
+
+        value_type = overload.value_func(None, None)
+
+        if not type_is_scalar(value_type) and not type_is_vector(value_type) and not type_is_matrix(value_type):
+            raise TypeError(f"Operator {op} returns unsupported type {type_repr(value_type)} for a tile element")
+
+        return tile(dtype=value_type, shape=a.shape)
+
+    else:
+        return tile(dtype=a.dtype, shape=a.shape)
+
+
+def tile_unary_map_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
+    op = arg_values["op"]
+    tile_a = arg_values["a"]
+
+    overload = op.get_overload([tile_a.type.dtype], {})
+
+    # necessary, in case return type is different from input tile types
+    tile_r = Var(label=None, type=return_type)
+
+    return ((overload, tile_a, tile_r), ())


 add_builtin(
     "tile_map",
     input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...])},
     value_func=tile_unary_map_value_func,
-
+    dispatch_func=tile_unary_map_dispatch_func,
     # variadic=True,
     native_func="tile_unary_map",
     doc="""Apply a unary function onto the tile.
@@ -3950,7 +4475,7 @@ add_builtin(

     :param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
     :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
-    :returns: A tile with the same dimensions
+    :returns: A tile with the same dimensions as the input tile. Its datatype is specified by the return type of op

     Example:

@@ -3991,10 +4516,6 @@ def tile_binary_map_value_func(arg_types, arg_values):
     if not is_tile(b):
         raise TypeError(f"tile_map() 'b' argument must be a tile, got {b!r}")

-    # ensure types equal
-    if not types_equal(a.dtype, b.dtype):
-        raise TypeError(f"tile_map() arguments must have the same dtype, got {a.dtype} and {b.dtype}")
-
     if len(a.shape) != len(b.shape):
         raise ValueError(
             f"tile_map() shapes must have the same number of dimensions, got {len(a.shape)} and {len(b.shape)}"
@@ -4004,7 +4525,47 @@ def tile_binary_map_value_func(arg_types, arg_values):
         if a.shape[i] != b.shape[i]:
             raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")

-    return tile(dtype=a.dtype, shape=a.shape)
+    if "op" in arg_values:
+        op = arg_values["op"]
+        try:
+            overload = op.get_overload([a.dtype, b.dtype], {})
+        except KeyError as exc:
+            raise RuntimeError(
+                f"No overload of {op} found for tile element types {type_repr(a.dtype)}, {type_repr(b.dtype)}"
+            ) from exc
+
+        # build the right overload on demand
+        if overload.value_func is None:
+            overload.build(None)
+
+        value_type = overload.value_func(None, None)
+
+        if not type_is_scalar(value_type) and not type_is_vector(value_type) and not type_is_matrix(value_type):
+            raise TypeError(f"Operator {op} returns unsupported type {type_repr(value_type)} for a tile element")
+
+        return tile(dtype=value_type, shape=a.shape)
+
+    else:
+        # ensure types equal
+        if not types_equal(a.dtype, b.dtype):
+            raise TypeError(
+                f"tile_map() arguments must have the same dtype for this operation, got {a.dtype} and {b.dtype}"
+            )
+
+        return tile(dtype=a.dtype, shape=a.shape)
+
+
+def tile_binary_map_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
+    op = arg_values["op"]
+    tile_a = arg_values["a"]
+    tile_b = arg_values["b"]
+
+    overload = op.get_overload([tile_a.type.dtype, tile_b.type.dtype], {})
+
+    # necessary, in case return type is different from input tile types
+    tile_r = Var(label=None, type=return_type)
+
+    return ((overload, tile_a, tile_b, tile_r), ())


 add_builtin(
@@ -4015,18 +4576,18 @@ add_builtin(
         "b": tile(dtype=Scalar, shape=Tuple[int, ...]),
     },
     value_func=tile_binary_map_value_func,
-
+    dispatch_func=tile_binary_map_dispatch_func,
     # variadic=True,
     native_func="tile_binary_map",
     doc="""Apply a binary function onto the tile.

     This function cooperatively applies a binary function to each element of the tiles using all threads in the block.
-    Both input tiles must have the same dimensions and datatypes.
+    Both input tiles must have the same dimensions, and if using a builtin op, the same datatypes.

     :param op: A callable function that accepts two arguments and returns one argument, all of the same type, may be a user function or builtin
     :param a: The first input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
     :param b: The second input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
-    :returns: A tile with the same dimensions
+    :returns: A tile with the same dimensions as the input tiles. Its datatype is specified by the return type of op

     Example:

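With the dispatch functions added above, ``tile_map`` now takes the output tile's dtype from the operator's return type instead of requiring it to match the inputs. A hedged sketch of what that enables (the user function, shapes, and launch parameters are illustrative, not taken from the diff):

    import warp as wp

    @wp.func
    def signed_gap(a: float, b: float):
        # returns a vec2 built from two float inputs; the result tile adopts this dtype
        return wp.vec2(a - b, wp.abs(a - b))

    @wp.kernel
    def map_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float), out: wp.array(dtype=wp.vec2)):
        a = wp.tile_load(x, shape=16)
        b = wp.tile_load(y, shape=16)
        r = wp.tile_map(signed_gap, a, b)  # tile of vec2
        wp.tile_store(out, r)

    x = wp.ones(16, dtype=float)
    y = wp.zeros(16, dtype=float)
    out = wp.zeros(16, dtype=wp.vec2)
    wp.launch_tiled(map_kernel, dim=1, inputs=[x, y], outputs=[out], block_dim=32)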
@@ -4104,6 +4665,7 @@ add_builtin(
     doc="WIP",
     group="Utility",
     hidden=True,
+    missing_grad=True,
 )

 add_builtin(
@@ -4119,6 +4681,7 @@ add_builtin(
     doc="WIP",
     group="Utility",
     hidden=True,
+    missing_grad=True,
 )

 add_builtin(
@@ -4128,6 +4691,7 @@ add_builtin(
     doc="WIP",
     group="Utility",
     hidden=True,
+    missing_grad=True,
 )

 add_builtin(
@@ -4179,6 +4743,7 @@ add_builtin(
     :param low: The lower bound of the bounding box in BVH space
     :param high: The upper bound of the bounding box in BVH space""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4194,6 +4759,7 @@ add_builtin(
     :param start: The start of the ray in BVH space
     :param dir: The direction of the ray in BVH space""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4204,6 +4770,7 @@ add_builtin(
     doc="""Move to the next bound returned by the query.
     The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4538,12 +5105,13 @@ add_builtin(
     group="Geometry",
     doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.

-    This query can be used to iterate over all triangles inside a volume.
+    This query can be used to iterate over all bounding boxes of the triangles inside a volume.

     :param id: The mesh identifier
     :param low: The lower bound of the bounding box in mesh space
     :param high: The upper bound of the bounding box in mesh space""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4551,10 +5119,11 @@ add_builtin(
     input_types={"query": MeshQueryAABB, "index": int},
     value_type=builtins.bool,
     group="Geometry",
-    doc="""Move to the next triangle
+    doc="""Move to the next triangle whose bounding box overlaps the query bounding box.

     The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
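The reworded documentation above makes explicit that the AABB query is a broad phase over triangle bounding boxes. A minimal sketch of how such a query is typically iterated (the mesh id, arrays, and counting logic are illustrative assumptions):

    import warp as wp

    @wp.kernel
    def count_aabb_candidates(
        mesh: wp.uint64,
        lowers: wp.array(dtype=wp.vec3),
        uppers: wp.array(dtype=wp.vec3),
        counts: wp.array(dtype=int),
    ):
        tid = wp.tid()
        # visits triangles whose bounding boxes overlap the query box;
        # an exact triangle/box test would still be needed to confirm overlap
        query = wp.mesh_query_aabb(mesh, lowers[tid], uppers[tid])
        face = int(0)
        n = int(0)
        while wp.mesh_query_aabb_next(query, face):
            n += 1
        counts[tid] = n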
@@ -4584,6 +5153,7 @@ add_builtin(

     This query can be used to iterate over all neighboring point within a fixed radius from the query point.""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4595,6 +5165,7 @@ add_builtin(

     The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4608,6 +5179,7 @@ add_builtin(

     Returns -1 if the :class:`HashGrid` has not been reserved.""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4619,6 +5191,7 @@ add_builtin(

     Returns > 0 if triangles intersect.""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4638,6 +5211,7 @@ add_builtin(
     group="Geometry",
     doc="""Evaluates the face normal the mesh given a face index.""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4647,6 +5221,7 @@ add_builtin(
     group="Geometry",
     doc="""Returns the point of the mesh given a index.""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4656,6 +5231,7 @@ add_builtin(
     group="Geometry",
     doc="""Returns the velocity of the mesh given a index.""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4665,6 +5241,7 @@ add_builtin(
     group="Geometry",
     doc="""Returns the point-index of the mesh given a face-vertex index.""",
     export=False,
+    missing_grad=True,
 )


@@ -4705,12 +5282,32 @@ add_builtin(
 # ---------------------------------
 # Iterators

-add_builtin("iter_next", input_types={"range": range_t}, value_type=int, group="Utility", export=False, hidden=True)
 add_builtin(
-    "iter_next",
+    "iter_next",
+    input_types={"range": range_t},
+    value_type=int,
+    group="Utility",
+    export=False,
+    hidden=True,
+    missing_grad=True,
+)
+add_builtin(
+    "iter_next",
+    input_types={"query": HashGridQuery},
+    value_type=int,
+    group="Utility",
+    export=False,
+    hidden=True,
+    missing_grad=True,
 )
 add_builtin(
-    "iter_next",
+    "iter_next",
+    input_types={"query": MeshQueryAABB},
+    value_type=int,
+    group="Utility",
+    export=False,
+    hidden=True,
+    missing_grad=True,
 )

 add_builtin(
@@ -4721,6 +5318,7 @@ add_builtin(
     group="Utility",
     doc="""Returns the range in reversed order.""",
     export=False,
+    missing_grad=True,
 )

 # ---------------------------------
@@ -4869,6 +5467,7 @@ add_builtin(
     doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type type `dtype`.

     If the voxel at this index does not exist, this function returns the background value.""",
+    missing_grad=True,
 )


@@ -4889,6 +5488,7 @@ add_builtin(
     export=False,
     group="Volumes",
     doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
+    missing_grad=True,
 )

 add_builtin(
@@ -4919,6 +5519,7 @@ add_builtin(
     doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.

     If the voxel at this index does not exist, this function returns the background value""",
+    missing_grad=True,
 )

 add_builtin(
@@ -4927,6 +5528,7 @@ add_builtin(
     group="Volumes",
     doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4947,6 +5549,7 @@ add_builtin(
     doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.

     If the voxel at this index does not exist, this function returns the background value.""",
+    missing_grad=True,
 )

 add_builtin(
@@ -4955,6 +5558,7 @@ add_builtin(
     group="Volumes",
     doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
     export=False,
+    missing_grad=True,
 )

 add_builtin(
@@ -4973,6 +5577,7 @@ add_builtin(
     doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.

     If the voxel at this index does not exist, this function returns the background value.""",
+    missing_grad=True,
 )

 add_builtin(
@@ -4981,6 +5586,7 @@ add_builtin(
     group="Volumes",
     doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
     export=False,
+    missing_grad=True,
 )


@@ -5062,6 +5668,7 @@ add_builtin(
|
|
|
5062
5668
|
If the voxel at this index does not exist, this function returns -1.
|
|
5063
5669
|
This function is available for both index grids and classical volumes.
|
|
5064
5670
|
""",
|
|
5671
|
+
missing_grad=True,
|
|
5065
5672
|
)
|
|
5066
5673
|
|
|
5067
5674
|
add_builtin(
|
|
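The volume hunks above only add `missing_grad=True` to the voxel lookup/store built-ins; their behavior is unchanged. A small sketch of the lookup/store pair described by those doc strings (float variants shown; the vector and int32 variants follow the same pattern):

import warp as wp

@wp.kernel
def scale_voxels(volume: wp.uint64, coords: wp.array(dtype=wp.vec3i), s: float):
    tid = wp.tid()
    ijk = coords[tid]
    # A voxel that does not exist returns the volume's background value.
    v = wp.volume_lookup_f(volume, ijk[0], ijk[1], ijk[2])
    wp.volume_store_f(volume, ijk[0], ijk[1], ijk[2], v * s)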
@@ -5103,6 +5710,7 @@ add_builtin(
     group="Random",
     doc="Initialize a new random number generator given a user-defined seed. Returns a 32-bit integer representing the RNG state.",
+    missing_grad=True,
 )
@@ -5114,6 +5722,7 @@ add_builtin(
     This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
     but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
+    missing_grad=True,
 )
@@ -5122,6 +5731,7 @@ add_builtin(
     doc="Return a random integer in the range [-2^31, 2^31).",
+    missing_grad=True,
 )
@@ -5129,6 +5739,7 @@ add_builtin(
     doc="Return a random integer between [low, high).",
+    missing_grad=True,
 )
@@ -5136,6 +5747,7 @@ add_builtin(
     doc="Return a random unsigned integer in the range [0, 2^32).",
+    missing_grad=True,
 )
@@ -5143,6 +5755,7 @@ add_builtin(
     doc="Return a random unsigned integer between [low, high).",
+    missing_grad=True,
 )
@@ -5150,6 +5763,7 @@ add_builtin(
     doc="Return a random float between [0.0, 1.0).",
+    missing_grad=True,
 )
@@ -5157,6 +5771,7 @@ add_builtin(
     doc="Return a random float between [low, high).",
+    missing_grad=True,
 )
@@ -5164,6 +5779,7 @@ add_builtin(
     doc="Sample a normal (Gaussian) distribution of mean 0 and variance 1. ",
+    missing_grad=True,
 )
@@ -5172,6 +5788,7 @@ add_builtin(
     doc="Inverse-transform sample a cumulative distribution function.",
+    missing_grad=True,
 )
@@ -5179,6 +5796,7 @@ add_builtin(
     doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
+    missing_grad=True,
 )
@@ -5186,6 +5804,7 @@ add_builtin(
     doc="Uniformly sample a ring in the xy plane.",
+    missing_grad=True,
 )
@@ -5193,6 +5812,7 @@ add_builtin(
     doc="Uniformly sample a disk in the xy plane.",
+    missing_grad=True,
 )
@@ -5200,6 +5820,7 @@ add_builtin(
     doc="Uniformly sample a unit sphere surface.",
+    missing_grad=True,
 )
@@ -5207,6 +5828,7 @@ add_builtin(
     doc="Uniformly sample a unit sphere.",
+    missing_grad=True,
 )
@@ -5214,6 +5836,7 @@ add_builtin(
     doc="Uniformly sample a unit hemisphere surface.",
+    missing_grad=True,
 )
@@ -5221,6 +5844,7 @@ add_builtin(
     doc="Uniformly sample a unit hemisphere.",
+    missing_grad=True,
 )
@@ -5228,6 +5852,7 @@ add_builtin(
     doc="Uniformly sample a unit square.",
+    missing_grad=True,
 )
@@ -5235,6 +5860,7 @@ add_builtin(
     doc="Uniformly sample a unit cube.",
+    missing_grad=True,
 )
@@ -5246,6 +5872,7 @@ add_builtin(
     :param state: RNG state
     :param lam: The expected value of the distribution""",
+    missing_grad=True,
 )
 
 add_builtin(
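The doc string above already suggests the per-thread pattern ``r = rand_init(seed, tid)``; spelled out as a kernel:

import warp as wp

@wp.kernel
def jitter(points: wp.array(dtype=wp.vec3), seed: int):
    tid = wp.tid()
    state = wp.rand_init(seed, tid)  # one shared seed, one uncorrelated stream per thread
    points[tid] += wp.sample_unit_sphere(state) * 0.01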
@@ -5363,9 +5990,16 @@ add_builtin(
     dispatch_func=printf_dispatch_func,
     group="Utility",
     doc="Allows printing formatted strings using C-style format specifiers.",
+    missing_grad=True,
 )
 
-add_builtin(
+add_builtin(
+    "print",
+    input_types={"value": Any},
+    doc="Print variable to stdout",
+    export=False,
+    group="Utility",
+)
 
 add_builtin(
     "breakpoint",
@@ -5375,6 +6009,7 @@ add_builtin(
     group="Utility",
     namespace="",
     native_func="__debugbreak",
+    missing_grad=True,
 )
 
 # helpers
@@ -5392,6 +6027,7 @@ add_builtin(
     This function may not be called from user-defined Warp functions.""",
     namespace="",
     native_func="builtin_tid1d",
+    missing_grad=True,
 )
@@ -5402,6 +6038,7 @@ add_builtin(
     doc="Returns the number of threads in the current block.",
     namespace="",
     native_func="builtin_block_dim",
+    missing_grad=True,
 )
@@ -5416,6 +6053,7 @@ add_builtin(
     This function may not be called from user-defined Warp functions.""",
     namespace="",
     native_func="builtin_tid2d",
+    missing_grad=True,
 )
@@ -5430,6 +6068,7 @@ add_builtin(
     This function may not be called from user-defined Warp functions.""",
     namespace="",
     native_func="builtin_tid3d",
+    missing_grad=True,
 )
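The hunk above re-registers the generic `print` built-in alongside `printf`. A short sketch of both from kernel code:

import warp as wp

@wp.kernel
def debug_first(values: wp.array(dtype=float)):
    tid = wp.tid()
    if tid == 0:
        wp.printf("first value: %f\n", values[0])  # C-style format specifiers, as documented above
        print(values[0])                           # generic print of any supported value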
@@ -5444,17 +6083,37 @@ add_builtin(
     This function may not be called from user-defined Warp functions.""",
     namespace="",
     native_func="builtin_tid4d",
+    missing_grad=True,
 )
 
 
+def copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    a = arg_types["a"]
+
+    # if the input is a shared tile, we force a copy
+    if is_tile(a) and a.storage == "shared":
+        return tile(
+            dtype=a.dtype,
+            shape=a.shape,
+            storage=a.storage,
+            strides=a.strides,
+            layout=a.layout,
+            owner=True,
+        )
+
+    return a
+
+
 add_builtin(
     "copy",
     input_types={"a": Any},
-    value_func=
+    value_func=copy_value_func,
     hidden=True,
     export=False,
     group="Utility",
 )
+
+
 add_builtin(
     "assign",
     input_types={"dest": Any, "src": Any},
@@ -5464,6 +6123,37 @@ add_builtin(
 )
 
 
+def select_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    if arg_types is None:
+        return Any
+
+    v_true = arg_types["value_if_true"]
+    v_false = arg_types["value_if_false"]
+
+    if not types_equal(v_true, v_false):
+        raise RuntimeError(
+            f"select() true value type ({v_true}) must be of the same type as the false type ({v_false})"
+        )
+
+    if is_tile(v_false):
+        if v_true.storage == "register":
+            return v_true
+        if v_false.storage == "register":
+            return v_false
+
+        # both v_true and v_false are shared
+        return tile(
+            dtype=v_true.dtype,
+            shape=v_true.shape,
+            storage=v_true.storage,
+            strides=v_true.strides,
+            layout=v_true.layout,
+            owner=True,
+        )
+
+    return v_true
+
+
 def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
     warp.utils.warn(
         "wp.select() is deprecated and will be removed in a future\n"
@@ -5480,7 +6170,7 @@ def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args
 add_builtin(
     "select",
     input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
-    value_func=
+    value_func=select_value_func,
     dispatch_func=select_dispatch_func,
     doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
 
@@ -5493,7 +6183,7 @@ for t in int_types:
     add_builtin(
         "select",
         input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
-        value_func=
+        value_func=select_value_func,
         dispatch_func=select_dispatch_func,
         doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
 
@@ -5505,7 +6195,7 @@ for t in int_types:
 add_builtin(
     "select",
     input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
-    value_func=
+    value_func=select_value_func,
     dispatch_func=select_dispatch_func,
     doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
 
@@ -5515,10 +6205,40 @@ add_builtin(
     group="Utility",
 )
 
+
+def where_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    if arg_types is None:
+        return Any
+
+    v_true = arg_types["value_if_true"]
+    v_false = arg_types["value_if_false"]
+
+    if not types_equal(v_true, v_false):
+        raise RuntimeError(f"where() true value type ({v_true}) must be of the same type as the false type ({v_false})")
+
+    if is_tile(v_false):
+        if v_true.storage == "register":
+            return v_true
+        if v_false.storage == "register":
+            return v_false
+
+        # both v_true and v_false are shared
+        return tile(
+            dtype=v_true.dtype,
+            shape=v_true.shape,
+            storage=v_true.storage,
+            strides=v_true.strides,
+            layout=v_true.layout,
+            owner=True,
+        )
+
+    return v_true
+
+
 add_builtin(
     "where",
     input_types={"cond": builtins.bool, "value_if_true": Any, "value_if_false": Any},
-    value_func=
+    value_func=where_value_func,
     doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
     group="Utility",
 )
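As the deprecation warning above notes, `wp.select()` is superseded by `wp.where()`, and the two take their branch values in opposite order. A minimal sketch:

import warp as wp

@wp.kernel
def relu(x: wp.array(dtype=float), out: wp.array(dtype=float)):
    tid = wp.tid()
    v = x[tid]
    # wp.where(cond, value_if_true, value_if_false); the deprecated
    # wp.select(cond, value_if_false, value_if_true) reverses the last two arguments.
    out[tid] = wp.where(v > 0.0, v, 0.0)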
@@ -5526,14 +6246,14 @@ for t in int_types:
     add_builtin(
         "where",
         input_types={"cond": t, "value_if_true": Any, "value_if_false": Any},
-        value_func=
+        value_func=where_value_func,
         doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
         group="Utility",
     )
 add_builtin(
     "where",
     input_types={"arr": array(dtype=Any), "value_if_true": Any, "value_if_false": Any},
-    value_func=
+    value_func=where_value_func,
     doc="Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``.",
     group="Utility",
 )
@@ -5544,7 +6264,7 @@ def array_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any
         return array(dtype=Scalar)
 
     dtype = arg_values["dtype"]
-    shape = extract_tuple(arg_values["shape"], as_constant=
+    shape = extract_tuple(arg_values["shape"], as_constant=False)
     return array(dtype=dtype, ndim=len(shape))
 
 
@@ -5554,7 +6274,7 @@ def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args
     # to the underlying C++ function's runtime and template params.
 
     dtype = return_type.dtype
-    shape = extract_tuple(args["shape"], as_constant=
+    shape = extract_tuple(args["shape"], as_constant=False)
 
     func_args = (args["ptr"], *shape)
     template_args = (dtype,)
@@ -5563,7 +6283,7 @@ def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args
 
 add_builtin(
     "array",
-    input_types={"ptr": warp.uint64, "shape": Tuple[int, ...], "dtype":
+    input_types={"ptr": warp.uint64, "shape": Tuple[int, ...], "dtype": Any},
     value_func=array_value_func,
     export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
     dispatch_func=array_dispatch_func,
@@ -5575,6 +6295,48 @@ add_builtin(
 )
 
 
+def zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    if arg_types is None:
+        return fixedarray(dtype=Scalar)
+
+    dtype = arg_values["dtype"]
+    shape = extract_tuple(arg_values["shape"], as_constant=True)
+
+    if None in shape:
+        raise RuntimeError("the `shape` argument must be specified as a constant when zero-initializing an array")
+
+    return fixedarray(dtype=dtype, shape=shape)
+
+
+def zeros_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    # We're in the codegen stage where we emit the code calling the built-in.
+    # Further validate the given argument values if needed and map them
+    # to the underlying C++ function's runtime and template params.
+
+    dtype = return_type.dtype
+    shape = extract_tuple(args["shape"], as_constant=True)
+
+    size = math.prod(shape)
+
+    func_args = shape
+    template_args = (size, dtype)
+    return (func_args, template_args)
+
+
+add_builtin(
+    "zeros",
+    input_types={"shape": Tuple[int, ...], "dtype": Any},
+    value_func=zeros_value_func,
+    export_func=lambda input_types: {},
+    dispatch_func=zeros_dispatch_func,
+    native_func="fixedarray_t",
+    group="Utility",
+    export=False,
+    missing_grad=True,
+    hidden=True,  # Unhide once we can document both a built-in and a Python scope function sharing the same name.
+)
+
+
 # does argument checking and type propagation for address()
 def address_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
     arr_type = arg_types["arr"]
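The new `zeros` built-in above creates a fixed-size array in kernel scope; its shape must be a compile-time constant, per the check in `zeros_value_func()`. A sketch under the assumption that the (currently hidden) built-in is reachable as `wp.zeros()` inside kernels:

import warp as wp

@wp.kernel
def local_scratch(out: wp.array(dtype=float)):
    tid = wp.tid()
    tmp = wp.zeros(shape=(4,), dtype=float)  # shape must be a constant; a runtime-sized shape is rejected
    for j in range(4):
        tmp[j] = float(j + tid)
    out[tid] = tmp[0] + tmp[3]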
@@ -5751,6 +6513,7 @@ add_builtin(
     hidden=True,
     skip_replay=True,
     group="Utility",
+    missing_grad=True,
 )
 
 
@@ -5767,6 +6530,7 @@ add_builtin(
     dispatch_func=load_dispatch_func,
     hidden=True,
     group="Utility",
+    missing_grad=True,
 )
 
 
@@ -5864,8 +6628,8 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
 
 
 for array_type in array_types:
-    # don't list indexed array operations explicitly in docs
-    hidden = array_type
+    # don't list fixed or indexed array operations explicitly in docs
+    hidden = array_type in (indexedarray, fixedarray)
 
     add_builtin(
         "atomic_add",
@@ -6083,6 +6847,7 @@ for array_type in array_types:
         The operation is only atomic on a per-component basis for vectors and matrices.""",
         group="Utility",
         skip_replay=True,
+        missing_grad=True,
     )
     add_builtin(
         "atomic_cas",
@@ -6096,6 +6861,7 @@ for array_type in array_types:
         The operation is only atomic on a per-component basis for vectors and matrices.""",
         group="Utility",
         skip_replay=True,
+        missing_grad=True,
     )
     add_builtin(
         "atomic_cas",
@@ -6109,6 +6875,7 @@ for array_type in array_types:
         The operation is only atomic on a per-component basis for vectors and matrices.""",
         group="Utility",
         skip_replay=True,
+        missing_grad=True,
     )
     add_builtin(
         "atomic_cas",
@@ -6130,6 +6897,7 @@ for array_type in array_types:
         The operation is only atomic on a per-component basis for vectors and matrices.""",
         group="Utility",
         skip_replay=True,
+        missing_grad=True,
     )
 
     add_builtin(
@@ -6144,6 +6912,7 @@ for array_type in array_types:
         The operation is only atomic on a per-component basis for vectors and matrices.""",
         group="Utility",
         skip_replay=True,
+        missing_grad=True,
     )
     add_builtin(
         "atomic_exch",
@@ -6157,6 +6926,7 @@ for array_type in array_types:
         The operation is only atomic on a per-component basis for vectors and matrices.""",
         group="Utility",
         skip_replay=True,
+        missing_grad=True,
     )
     add_builtin(
         "atomic_exch",
@@ -6170,6 +6940,7 @@ for array_type in array_types:
         The operation is only atomic on a per-component basis for vectors and matrices.""",
         group="Utility",
         skip_replay=True,
+        missing_grad=True,
     )
     add_builtin(
         "atomic_exch",
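The atomic hunks above add `missing_grad=True` to the `atomic_cas` / `atomic_exch` overloads for every array type. For reference, the established `wp.atomic_add()` pattern, plus an exchange whose exact signature is assumed from the registrations above rather than stated in the diff:

import warp as wp

@wp.kernel
def histogram(samples: wp.array(dtype=wp.int32), bins: wp.array(dtype=wp.int32)):
    tid = wp.tid()
    wp.atomic_add(bins, samples[tid], 1)

@wp.kernel
def clear_flags(flags: wp.array(dtype=wp.int32), old: wp.array(dtype=wp.int32)):
    tid = wp.tid()
    old[tid] = wp.atomic_exch(flags, tid, 0)  # assumed (array, index, value) -> previous value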
@@ -6187,46 +6958,110 @@ for array_type in array_types:
 
 
 # used to index into builtin types, i.e.: y = vec3[1]
-def
-
+def vector_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    vec_type = arg_types["a"]
+    idx_type = arg_types["i"]
+
+    if isinstance(idx_type, slice_t):
+        length = idx_type.get_length(vec_type._length_)
+        return vector(length=length, dtype=vec_type._wp_scalar_type_)
+
+    return vec_type._wp_scalar_type_
+
+
+def vector_extract_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    func_args = tuple(args.values())
+    template_args = getattr(return_type, "_shape_", ())
+    return (func_args, template_args)
 
 
 add_builtin(
     "extract",
-    input_types={"a": vector(length=Any, dtype=Scalar), "i":
-    value_func=
+    input_types={"a": vector(length=Any, dtype=Scalar), "i": Any},
+    value_func=vector_extract_value_func,
+    dispatch_func=vector_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
 add_builtin(
     "extract",
-    input_types={"a": quaternion(dtype=Scalar), "i":
-    value_func=
+    input_types={"a": quaternion(dtype=Scalar), "i": Any},
+    value_func=vector_extract_value_func,
+    dispatch_func=vector_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
-
 add_builtin(
     "extract",
-    input_types={"a":
-    value_func=
-
-
+    input_types={"a": transformation(dtype=Scalar), "i": Any},
+    value_func=vector_extract_value_func,
+    dispatch_func=vector_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
+
+
+def matrix_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    mat_type = arg_types["a"]
+    idx_types = tuple(arg_types[x] for x in "ij" if arg_types.get(x, None) is not None)
+
+    # Compute the resulting shape from the slicing, with -1 being simple indexing.
+    shape = tuple(
+        idx.get_length(mat_type._shape_[i]) if isinstance(idx, slice_t) else -1 for i, idx in enumerate(idx_types)
+    )
+
+    # Append any non indexed slice.
+    for i in range(len(idx_types), len(mat_type._shape_)):
+        shape += (mat_type._shape_[i],)
+
+    # Count how many dimensions the output value will have.
+    ndim = sum(1 for x in shape if x >= 0)
+
+    if ndim == 0:
+        return mat_type._wp_scalar_type_
+
+    assert shape[0] != -1 or shape[1] != -1
+
+    if ndim == 1:
+        length = shape[0] if shape[0] != -1 else shape[1]
+        return vector(length=length, dtype=mat_type._wp_scalar_type_)
+
+    assert ndim == 2
+
+    # When a matrix dimension is 0, all other dimensions are also expected to be 0.
+    if any(x == 0 for x in shape):
+        shape = (0,) * len(shape)
+
+    return matrix(shape=shape, dtype=mat_type._wp_scalar_type_)
+
+
+def matrix_extract_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    idx_types = tuple(args[x].type for x in "ij" if args.get(x, None) is not None)
+    has_slice = any(isinstance(x, slice_t) for x in idx_types)
+
+    func_args = tuple(args.values())
+    template_args = getattr(return_type, "_shape_", ()) if has_slice else ()
+    return (func_args, template_args)
+
+
 add_builtin(
     "extract",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
-    value_func=
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any},
+    value_func=matrix_extract_value_func,
+    dispatch_func=matrix_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
-
 add_builtin(
     "extract",
-    input_types={"a":
-    value_func=
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any},
+    value_func=matrix_extract_value_func,
+    dispatch_func=matrix_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
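The rewritten `extract` value and dispatch funcs above let a single index return a whole row and route slice indices through `slice_t`. A sketch (slice syntax support is inferred from the `slice_t` branches, not stated in the diff):

import warp as wp

@wp.kernel
def extract_demo(out: wp.array(dtype=wp.vec3)):
    tid = wp.tid()
    m = wp.mat33(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)
    row = m[1]      # single index on a matrix -> its row as a vec3
    elem = m[2, 0]  # (i, j) indexing -> scalar
    # e.g. m[0, 0:2] or v[1:3] would take the slice_t paths above
    out[tid] = row * elem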
@@ -6247,6 +7082,19 @@ def vector_index_dispatch_func(input_types: Mapping[str, type], return_type: Any
     return (func_args, template_args)
 
 
+def matrix_ij_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    mat_type = arg_types["a"]
+    value_type = mat_type._wp_scalar_type_
+
+    return Reference(value_type)
+
+
+def matrix_ij_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    func_args = (Reference(args["a"]), args["i"], args["j"])
+    template_args = ()
+    return (func_args, template_args)
+
+
 # implements &vector[index]
 add_builtin(
     "index",
@@ -6256,6 +7104,7 @@ add_builtin(
     hidden=True,
     group="Utility",
     skip_replay=True,
+    missing_grad=True,
 )
 # implements &quaternion[index]
 add_builtin(
@@ -6266,6 +7115,7 @@ add_builtin(
     hidden=True,
     group="Utility",
     skip_replay=True,
+    missing_grad=True,
 )
 # implements &transformation[index]
 add_builtin(
@@ -6276,6 +7126,7 @@ add_builtin(
     hidden=True,
     group="Utility",
     skip_replay=True,
+    missing_grad=True,
 )
 # implements &(*vector)[index]
 add_builtin(
@@ -6286,6 +7137,18 @@ add_builtin(
     hidden=True,
     group="Utility",
     skip_replay=True,
+    missing_grad=True,
+)
+# implements &(*matrix)[i, j]
+add_builtin(
+    "indexref",
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
+    value_func=matrix_ij_value_func,
+    dispatch_func=matrix_ij_dispatch_func,
+    hidden=True,
+    group="Utility",
+    skip_replay=True,
+    missing_grad=True,
 )
 # implements &(*quaternion)[index]
 add_builtin(
@@ -6296,6 +7159,7 @@ add_builtin(
     hidden=True,
     group="Utility",
     skip_replay=True,
+    missing_grad=True,
 )
 # implements &(*transformation)[index]
 add_builtin(
@@ -6306,14 +7170,50 @@ add_builtin(
     hidden=True,
     group="Utility",
     skip_replay=True,
+    missing_grad=True,
 )
 
 
+def vector_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    vec = args["a"].type
+    idx = args["i"].type
+    value_type = strip_reference(args["value"].type)
+
+    if isinstance(idx, slice_t):
+        length = idx.get_length(vec._length_)
+
+        if type_is_vector(value_type):
+            if not types_equal(value_type._wp_scalar_type_, vec._wp_scalar_type_):
+                raise ValueError(
+                    f"The provided vector is expected to be of length {length} with dtype {type_repr(vec._wp_scalar_type_)}."
+                )
+            if value_type._length_ != length:
+                raise ValueError(
+                    f"The length of the provided vector ({args['value'].type._length_}) isn't compatible with the given slice (expected {length})."
+                )
+            template_args = (length,)
+        else:
+            # Disallow broadcasting.
+            raise ValueError(
+                f"The provided value is expected to be a vector of length {length}, with dtype {type_repr(vec._wp_scalar_type_)}."
+            )
+    else:
+        if not types_equal(value_type, vec._wp_scalar_type_):
+            raise ValueError(
+                f"The provided value is expected to be a scalar of type {type_repr(vec._wp_scalar_type_)}."
+            )
+        template_args = ()
+
+    func_args = tuple(args.values())
+    return (func_args, template_args)
+
+
 # implements vector[index] = value
 add_builtin(
     "assign_inplace",
-    input_types={"a": vector(length=Any, dtype=Scalar), "i":
+    input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
@@ -6322,8 +7222,9 @@ add_builtin(
 # implements quaternion[index] = value
 add_builtin(
     "assign_inplace",
-    input_types={"a": quaternion(dtype=Scalar), "i":
+    input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
@@ -6331,15 +7232,16 @@ add_builtin(
 # implements transformation[index] = value
 add_builtin(
     "assign_inplace",
-    input_types={"a": transformation(dtype=Scalar), "i":
+    input_types={"a": transformation(dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
 )
 
 
-def
+def vector_assign_copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
     vec_type = arg_types["a"]
     return vec_type
 
@@ -6347,8 +7249,9 @@ def vector_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[
 # implements vector[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
 add_builtin(
     "assign_copy",
-    input_types={"a": vector(length=Any, dtype=Scalar), "i":
-    value_func=
+    input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
+    value_func=vector_assign_copy_value_func,
+    dispatch_func=vector_assign_dispatch_func,
@@ -6357,8 +7260,9 @@ add_builtin(
 # implements quaternion[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
 add_builtin(
     "assign_copy",
-    input_types={"a": quaternion(dtype=Scalar), "i":
-    value_func=
+    input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
+    value_func=vector_assign_copy_value_func,
+    dispatch_func=vector_assign_dispatch_func,
@@ -6367,8 +7271,9 @@ add_builtin(
 # implements transformation[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
 add_builtin(
     "assign_copy",
-    input_types={"a": transformation(dtype=Scalar), "i":
-    value_func=
+    input_types={"a": transformation(dtype=Scalar), "i": Any, "value": Any},
+    value_func=vector_assign_copy_value_func,
+    dispatch_func=vector_assign_dispatch_func,
@@ -6377,8 +7282,9 @@ add_builtin(
 # implements vector[idx] += scalar
 add_builtin(
     "add_inplace",
-    input_types={"a": vector(length=Any, dtype=Scalar), "i":
+    input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
@@ -6387,8 +7293,9 @@ add_builtin(
 # implements quaternion[idx] += scalar
 add_builtin(
     "add_inplace",
-    input_types={"a": quaternion(dtype=Scalar), "i":
+    input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
@@ -6397,8 +7304,9 @@ add_builtin(
 # implements transformation[idx] += scalar
 add_builtin(
     "add_inplace",
-    input_types={"a": transformation(dtype=Float), "i":
+    input_types={"a": transformation(dtype=Float), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
@@ -6417,8 +7325,9 @@ add_builtin(
 # implements vector[idx] -= scalar
 add_builtin(
     "sub_inplace",
-    input_types={"a": vector(length=Any, dtype=Scalar), "i":
+    input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
@@ -6427,8 +7336,9 @@ add_builtin(
 # implements quaternion[idx] -= scalar
 add_builtin(
     "sub_inplace",
-    input_types={"a": quaternion(dtype=Scalar), "i":
+    input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
@@ -6437,8 +7347,9 @@ add_builtin(
 # implements transformation[idx] -= scalar
 add_builtin(
     "sub_inplace",
-    input_types={"a": transformation(dtype=
+    input_types={"a": transformation(dtype=Float), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
@@ -6470,6 +7381,7 @@ add_builtin(
     hidden=True,
     group="Utility",
     skip_replay=True,
+    missing_grad=True,
 )
@@ -6488,6 +7400,7 @@ add_builtin(
     hidden=True,
     group="Utility",
     skip_replay=True,
+    missing_grad=True,
 )
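The `assign_inplace` / `add_inplace` / `sub_inplace` registrations above are what component writes on vectors, quaternions, and transforms lower to:

import warp as wp

@wp.kernel
def set_components(out: wp.array(dtype=wp.vec3)):
    tid = wp.tid()
    v = wp.vec3()
    v[0] = 1.0    # assign_inplace
    v[1] += 2.0   # add_inplace
    v[2] -= 0.5   # sub_inplace
    # with wp.config.enable_vector_component_overwrites = True the assign_copy
    # variants above are used instead, copying the vector before the write
    out[tid] = v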
@@ -6499,61 +7412,154 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
     return mat_size == vec_size and mat_type == vec_type
 
 
-
+def matrix_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    mat = args["a"].type
+    value_type = strip_reference(args["value"].type)
+
+    idxs = tuple(args[x].type for x in "ij" if args.get(x, None) is not None)
+    has_slice = any(isinstance(x, slice_t) for x in idxs)
+
+    if has_slice:
+        # Compute the resulting shape from the slicing, with -1 being simple indexing.
+        shape = tuple(idx.get_length(mat._shape_[i]) if isinstance(idx, slice_t) else -1 for i, idx in enumerate(idxs))
+
+        # Append any non indexed slice.
+        for i in range(len(idxs), len(mat._shape_)):
+            shape += (mat._shape_[i],)
+
+        # Count how many dimensions the output value will have.
+        ndim = sum(1 for x in shape if x >= 0)
+        assert ndim > 0
+
+        if ndim == 1:
+            length = shape[0] if shape[0] != -1 else shape[1]
+
+            if type_is_vector(value_type):
+                if not types_equal(value_type._wp_scalar_type_, mat._wp_scalar_type_):
+                    raise ValueError(
+                        f"The provided vector is expected to be of length {length} with dtype {type_repr(mat._wp_scalar_type_)}."
+                    )
+
+                if value_type._length_ != length:
+                    raise ValueError(
+                        f"The length of the provided vector ({value_type._length_}) isn't compatible with the given slice (expected {length})."
+                    )
+
+                template_args = (length,)
+            else:
+                # Disallow broadcasting.
+                raise ValueError(
+                    f"The provided value is expected to be a vector of length {length}, with dtype {type_repr(mat._wp_scalar_type_)}."
+                )
+        else:
+            assert ndim == 2
+
+            # When a matrix dimension is 0, all other dimensions are also expected to be 0.
+            if any(x == 0 for x in shape):
+                shape = (0,) * len(shape)
+
+            if type_is_matrix(value_type):
+                if not types_equal(value_type._wp_scalar_type_, mat._wp_scalar_type_):
+                    raise ValueError(
+                        f"The provided matrix is expected to be of shape {shape} with dtype {type_repr(mat._wp_scalar_type_)}."
+                    )
+
+                if value_type._shape_ != shape:
+                    raise ValueError(
+                        f"The shape of the provided matrix ({value_type._shape_}) isn't compatible with the given slice (expected {shape})."
+                    )
+
+                template_args = shape
+            else:
+                # Disallow broadcasting.
+                raise ValueError(
+                    f"The provided value is expected to be a matrix of shape {shape}, with dtype {type_repr(mat._wp_scalar_type_)}."
+                )
+    elif len(idxs) == 1:
+        if not type_is_vector(value_type) or not types_equal(value_type._wp_scalar_type_, mat._wp_scalar_type_):
+            raise ValueError(
+                f"The provided value is expected to be a vector of length {mat._shape_[1]}, with dtype {type_repr(mat._wp_scalar_type_)}."
+            )
+
+        if value_type._length_ != mat._shape_[1]:
+            raise ValueError(
+                f"The length of the provided vector ({value_type._length_}) isn't compatible with the given slice (expected {mat._shape_[1]})."
+            )
+
+        template_args = ()
+    elif len(idxs) == 2:
+        if not types_equal(value_type, mat._wp_scalar_type_):
+            raise ValueError(
+                f"The provided value is expected to be a scalar of type {type_repr(mat._wp_scalar_type_)}."
+            )
+
+        template_args = ()
+    else:
+        raise AssertionError
+
+    func_args = tuple(args.values())
+    return (func_args, template_args)
+
+
+# implements matrix[i] = value
 add_builtin(
     "assign_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
+    constraint=matrix_vector_sametype,
     value_type=None,
+    dispatch_func=matrix_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
 )
 
 
-# implements matrix[i] =
+# implements matrix[i,j] = value
 add_builtin(
     "assign_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
-    constraint=matrix_vector_sametype,
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
     value_type=None,
+    dispatch_func=matrix_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
 )
 
 
-def
+def matrix_assign_copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
     mat_type = arg_types["a"]
     return mat_type
 
 
-# implements matrix[i
+# implements matrix[i] = value
 add_builtin(
     "assign_copy",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
-    value_func=
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
+    value_func=matrix_assign_copy_value_func,
+    dispatch_func=matrix_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
 )
 
 
-# implements matrix[i] =
+# implements matrix[i,j] = value
 add_builtin(
     "assign_copy",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
-
-
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
+    value_func=matrix_assign_copy_value_func,
+    dispatch_func=matrix_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
 )
 
 
-# implements matrix[i
+# implements matrix[i] += value
 add_builtin(
     "add_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
+    constraint=matrix_vector_sametype,
     value_type=None,
     hidden=True,
     export=False,
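matrix_assign_dispatch_func() above validates row, element, and slice writes on matrices; in kernel code those correspond to:

import warp as wp

@wp.kernel
def set_rows(out: wp.array(dtype=wp.mat33)):
    tid = wp.tid()
    m = wp.mat33()
    m[0] = wp.vec3(1.0, 2.0, 3.0)   # matrix[i] = value: must be a length-3 vector of matching dtype
    m[1, 1] = 5.0                   # matrix[i, j] = value: scalar element write
    m[2] += wp.vec3(0.0, 0.0, 1.0)  # matrix[i] += value
    out[tid] = m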
@@ -6561,11 +7567,10 @@ add_builtin(
 )
 
 
-# implements matrix[i] +=
+# implements matrix[i,j] += value
 add_builtin(
     "add_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
-    constraint=matrix_vector_sametype,
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
     value_type=None,
     hidden=True,
     export=False,
@@ -6573,10 +7578,10 @@ add_builtin(
 )
 
 
-# implements matrix[i
+# implements matrix[i] -= value
 add_builtin(
     "sub_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
     hidden=True,
     export=False,
@@ -6584,10 +7589,10 @@ add_builtin(
 )
 
 
-# implements matrix[i] -=
+# implements matrix[i,j] -= value
 add_builtin(
     "sub_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
     value_type=None,
     hidden=True,
     export=False,
@@ -6606,6 +7611,7 @@ for t in scalar_types + vector_types + (bool,):
     doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
     group="Utility",
     hidden=True,
+    missing_grad=True,
 )
 
 add_builtin(
@@ -6616,6 +7622,7 @@ for t in scalar_types + vector_types + (bool,):
     group="Utility",
     hidden=True,
     export=False,
+    missing_grad=True,
 )
 
 
@@ -6634,6 +7641,7 @@ add_builtin(
     doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
     group="Utility",
     hidden=True,
+    missing_grad=True,
 )
 add_builtin(
     "expect_neq",
@@ -6644,6 +7652,7 @@ add_builtin(
     group="Utility",
     hidden=True,
     export=False,
+    missing_grad=True,
 )
@@ -6654,6 +7663,7 @@ add_builtin(
     doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
     group="Utility",
     hidden=True,
+    missing_grad=True,
 )
 add_builtin(
     "expect_neq",
@@ -6664,6 +7674,7 @@ add_builtin(
     group="Utility",
     hidden=True,
     export=False,
+    missing_grad=True,
 )
@@ -6754,6 +7765,7 @@ add_builtin(
     value_type=None,
     doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
     group="Utility",
+    missing_grad=True,
 )
 add_builtin(
     "expect_near",
@@ -6763,6 +7775,7 @@ add_builtin(
     value_type=None,
     doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
     group="Utility",
+    missing_grad=True,
 )
 add_builtin(
     "expect_near",
@@ -6772,6 +7785,7 @@ add_builtin(
     value_type=None,
     doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
     group="Utility",
+    missing_grad=True,
 )
 add_builtin(
     "expect_near",
@@ -6785,6 +7799,7 @@ add_builtin(
     value_type=None,
     doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
     group="Utility",
+    missing_grad=True,
 )
 
 # ---------------------------------
@@ -6795,6 +7810,7 @@ add_builtin(
     input_types={"arr": array(dtype=Scalar), "value": Scalar},
     value_type=int,
     doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
+    missing_grad=True,
 )
 
 add_builtin(
@@ -6802,11 +7818,13 @@ add_builtin(
     input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
     value_type=int,
     doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
+    missing_grad=True,
 )
 
 # ---------------------------------
 # Operators
 
+
 add_builtin(
     "add", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
 )
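The two `lower_bound` overloads above gain `missing_grad=True`; usage is unchanged:

import warp as wp

@wp.kernel
def bucket(edges: wp.array(dtype=float), x: wp.array(dtype=float), out: wp.array(dtype=wp.int32)):
    tid = wp.tid()
    out[tid] = wp.lower_bound(edges, x[tid])  # index of the first element >= x[tid] in the sorted array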
@@ -6876,13 +7894,36 @@ add_builtin(
 )
 
 # bitwise operators
-add_builtin(
-
-
-
-
-
-
+add_builtin(
+    "bit_and",
+    input_types={"a": Int, "b": Int},
+    value_func=sametypes_create_value_func(Int),
+    group="Operators",
+    missing_grad=True,
+)
+add_builtin(
+    "bit_or",
+    input_types={"a": Int, "b": Int},
+    value_func=sametypes_create_value_func(Int),
+    group="Operators",
+    missing_grad=True,
+)
+add_builtin(
+    "bit_xor",
+    input_types={"a": Int, "b": Int},
+    value_func=sametypes_create_value_func(Int),
+    group="Operators",
+    missing_grad=True,
+)
+add_builtin("lshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int), group="Operators")
+add_builtin(
+    "rshift",
+    input_types={"a": Int, "b": Int},
+    value_func=sametypes_create_value_func(Int),
+    group="Operators",
+    missing_grad=True,
+)
+add_builtin("invert", input_types={"a": Int}, value_func=sametypes_create_value_func(Int), group="Operators")
 
 add_builtin(
     "mul", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
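The expanded registrations above back the integer bitwise built-ins (`bit_and`, `bit_or`, `bit_xor`, `lshift`, `rshift`, `invert`), several of them now flagged `missing_grad`. In kernel code they are reached through the usual Python operators:

import warp as wp

@wp.kernel
def pack16(a: wp.array(dtype=wp.uint32), b: wp.array(dtype=wp.uint32), out: wp.array(dtype=wp.uint32)):
    tid = wp.tid()
    lo = a[tid] & wp.uint32(0xFFFF)
    hi = (b[tid] & wp.uint32(0xFFFF)) << wp.uint32(16)
    out[tid] = hi | lo  # ^, >>, and ~ map to bit_xor, rshift, and invert in the same way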
@@ -7079,9 +8120,10 @@ add_builtin(
     "mod",
     input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
     constraint=sametypes,
-    value_func=sametypes_create_value_func(Scalar),
+    value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
     doc="Modulo operation using truncated division.",
     group="Operators",
+    missing_grad=True,
 )
 
 add_builtin(
@@ -7141,6 +8183,7 @@ add_builtin(
     value_func=sametypes_create_value_func(Scalar),
     doc="",
     group="Operators",
+    missing_grad=True,
 )
 
 add_builtin("pos", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
@@ -7188,12 +8231,16 @@ add_builtin(
     group="Operators",
 )
 
-add_builtin(
+add_builtin(
+    "unot", input_types={"a": builtins.bool}, value_type=builtins.bool, doc="", group="Operators", missing_grad=True
+)
 for t in int_types:
-    add_builtin("unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators")
+    add_builtin("unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators", missing_grad=True)
 
 
-add_builtin(
+add_builtin(
+    "unot", input_types={"a": array(dtype=Any)}, value_type=builtins.bool, doc="", group="Operators", missing_grad=True
+)
 
 
 # Tile operators
@@ -7387,6 +8434,7 @@ add_builtin(
|
|
|
7387
8434
|
doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
|
|
7388
8435
|
group="Tile Primitives",
|
|
7389
8436
|
export=False,
|
|
8437
|
+
missing_grad=True,
|
|
7390
8438
|
)
|
|
7391
8439
|
|
|
7392
8440
|
|
|
@@ -7481,7 +8529,7 @@ def tile_matmul_lto_dispatch_func(
     num_threads = options["block_dim"]
     arch = options["output_arch"]
 
-    if arch is None or not warp.context.runtime.core.
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, 0, 0, a, b, out), template_args, [], 0)
     else:
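The tile dispatch functions in this file all gate their MathDx/LTO code path behind the same check on the native core, wp_is_mathdx_enabled(); otherwise they return the generic CPU/no-MathDx dispatch tuple. A simplified sketch of the guard, using a hypothetical helper name for illustration only:

def _use_mathdx(arch, core):
    # take the MathDx LTO path only when a target arch is known
    # and the native runtime reports MathDx support
    return arch is not None and core.wp_is_mathdx_enabled()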
@@ -7671,7 +8719,7 @@ def tile_fft_generic_lto_dispatch_func(
     arch = options["output_arch"]
     ept = size // num_threads
 
-    if arch is None or not warp.context.runtime.core.
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ([], [], [], 0)
     else:

@@ -7714,6 +8762,7 @@ add_builtin
     group="Tile Primitives",
     export=False,
     namespace="",
+    missing_grad=True,
 )
 
 add_builtin(

@@ -7735,6 +8784,7 @@ add_builtin
     group="Tile Primitives",
     export=False,
     namespace="",
+    missing_grad=True,
 )
 
 

@@ -7792,28 +8842,27 @@ def tile_cholesky_generic_lto_dispatch_func(
         raise TypeError("tile_cholesky() returns one output")
     out = return_values[0]
 
-    dtype, precision_enum = cusolver_type_map[a.type.dtype]
-
     # We already ensured a is square in tile_cholesky_generic_value_func()
     M, N = a.type.shape
     if out.type.shape[0] != M or out.type.shape[1] != M:
         raise ValueError("tile_cholesky() output tile must be square")
 
-    solver = "potrf"
-    solver_enum = cusolver_function_map[solver]
-
-    side_enum = cusolver_side_map["-"]
-    diag_enum = cusolver_diag_map["-"]
-    fill_mode = cusolver_fill_mode_map["lower"]
-
     arch = options["output_arch"]
-    num_threads = options["block_dim"]
-    parameter_list = f"({dtype}*, int*)"
 
-    if arch is None or not warp.context.runtime.core.
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, a, out), [], [], 0)
     else:
+        solver = "potrf"
+        solver_enum = cusolver_function_map[solver]
+        side_enum = cusolver_side_map["-"]
+        diag_enum = cusolver_diag_map["-"]
+        fill_mode = cusolver_fill_mode_map["lower"]
+        dtype, precision_enum = cusolver_type_map[a.type.dtype]
+        num_threads = options["block_dim"]
+        parameter_list = f"({dtype}*, int*)"
+        req_smem_bytes = a.type.size * type_size_in_bytes(a.type.dtype)
+
         # generate the LTO
         lto_symbol, lto_code_data = warp.build.build_lto_solver(
             M,

@@ -7831,6 +8880,7 @@ def tile_cholesky_generic_lto_dispatch_func(
             num_threads,
             parameter_list,
             builder,
+            smem_estimate_bytes=req_smem_bytes,
         )
 
         return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
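The new req_smem_bytes value passed to build_lto_solver() as smem_estimate_bytes is plain arithmetic: the number of tile elements involved times the element size. A worked example with illustrative figures only (the 64x64 shape is an assumption, not taken from this diff):

# tile_cholesky on a 64x64 float64 factor tile:
tile_elems = 64 * 64       # a.type.size
elem_bytes = 8             # type_size_in_bytes(float64)
req_smem_bytes = tile_elems * elem_bytes  # 32768 bytes of shared memory estimated for the LTO solver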
@@ -7859,6 +8909,7 @@ add_builtin
     group="Tile Primitives",
     export=False,
     namespace="",
+    missing_grad=True,
 )
 
 

@@ -7918,9 +8969,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
     if any(T not in cusolver_type_map.keys() for T in [y.type.dtype, L.type.dtype]):
         raise TypeError("tile_cholesky_solve() arguments be tiles of float64 or float32")
 
-    dtype, precision_enum = cusolver_type_map[L.type.dtype]
     M, N = L.type.shape
-    NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
 
     if len(x.type.shape) > 2 or len(x.type.shape) < 1:
         raise TypeError(f"tile_cholesky_solve() output vector must be 1D or 2D, got {len(x.type.shape)}-D")

@@ -7931,21 +8980,23 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
             f"got {x.type.shape[0]} elements in output and {M} rows in 'L'"
         )
 
-    solver = "potrs"
-    solver_enum = cusolver_function_map[solver]
-
-    side_enum = cusolver_side_map["-"]
-    diag_enum = cusolver_diag_map["-"]
-    fill_mode = cusolver_fill_mode_map["lower"]
-
     arch = options["output_arch"]
-    num_threads = options["block_dim"]
-    parameter_list = f"({dtype}*, {dtype}*)"
 
-    if arch is None or not warp.context.runtime.core.
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, L, y, x), [], [], 0)
     else:
+        NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
+        solver = "potrs"
+        solver_enum = cusolver_function_map[solver]
+        side_enum = cusolver_side_map["-"]
+        diag_enum = cusolver_diag_map["-"]
+        fill_mode = cusolver_fill_mode_map["lower"]
+        dtype, precision_enum = cusolver_type_map[L.type.dtype]
+        num_threads = options["block_dim"]
+        parameter_list = f"({dtype}*, {dtype}*)"
+        req_smem_bytes = (x.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
+
         # generate the LTO
         lto_symbol, lto_code_data = warp.build.build_lto_solver(
             M,

@@ -7963,6 +9014,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
             num_threads,
             parameter_list,
             builder,
+            smem_estimate_bytes=req_smem_bytes,
         )
 
         return ((Var(lto_symbol, str, False, True, False), L, y, x), [], [lto_code_data], 0)
@@ -7988,6 +9040,7 @@ add_builtin
     group="Tile Primitives",
     export=False,
     namespace="",
+    missing_grad=True,
 )
 
 

@@ -8013,9 +9066,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
 
     z = return_values[0]
 
-    dtype, precision_enum = cusolver_type_map[L.type.dtype]
     M, N = L.type.shape
-    NRHS = z.type.shape[1] if len(z.type.shape) > 1 else 1
 
     if len(z.type.shape) > 2 or len(z.type.shape) < 1:
         raise TypeError(f"tile_lower_solve() output vector must be 1D or 2D, got {len(z.type.shape)}-D")

@@ -8026,21 +9077,23 @@ def tile_lower_solve_generic_lto_dispatch_func(
             f"got {z.type.shape[0]} elements in output and {M} rows in 'L'"
         )
 
-    solver = "trsm"
-    solver_enum = cusolver_function_map[solver]
-
-    side_enum = cusolver_side_map["left"]
-    diag_enum = cusolver_diag_map["nounit"]
-    fill_mode = cusolver_fill_mode_map["lower"]
-
     arch = options["output_arch"]
-    num_threads = options["block_dim"]
-    parameter_list = f"({dtype}*, {dtype}*)"
 
-    if arch is None or not warp.context.runtime.core.
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, L, y, z), [], [], 0)
     else:
+        NRHS = z.type.shape[1] if len(z.type.shape) > 1 else 1
+        solver = "trsm"
+        solver_enum = cusolver_function_map[solver]
+        side_enum = cusolver_side_map["left"]
+        diag_enum = cusolver_diag_map["nounit"]
+        fill_mode = cusolver_fill_mode_map["lower"]
+        dtype, precision_enum = cusolver_type_map[L.type.dtype]
+        num_threads = options["block_dim"]
+        parameter_list = f"({dtype}*, {dtype}*)"
+        req_smem_bytes = (z.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
+
         # generate the LTO
         lto_symbol, lto_code_data = warp.build.build_lto_solver(
             M,

@@ -8058,6 +9111,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
             num_threads,
             parameter_list,
             builder,
+            smem_estimate_bytes=req_smem_bytes,
         )
 
         return ((Var(lto_symbol, str, False, True, False), L, y, z), [], [lto_code_data], 0)
@@ -8119,6 +9173,7 @@ add_builtin
     group="Tile Primitives",
     export=False,
     namespace="",
+    missing_grad=True,
 )
 
 

@@ -8144,9 +9199,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
 
     x = return_values[0]
 
-    dtype, precision_enum = cusolver_type_map[U.type.dtype]
     M, N = U.type.shape
-    NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
 
     if len(z.type.shape) > 2 or len(z.type.shape) < 1:
         raise TypeError(f"tile_upper_solve() output tile must be 1D or 2D, got {len(z.type.shape)}-D")

@@ -8157,21 +9210,23 @@ def tile_upper_solve_generic_lto_dispatch_func(
             f"got {z.type.shape[0]} elements in output and {M} rows in 'U'"
         )
 
-    solver = "trsm"
-    solver_enum = cusolver_function_map[solver]
-
-    side_enum = cusolver_side_map["left"]
-    diag_enum = cusolver_diag_map["nounit"]
-    fill_mode = cusolver_fill_mode_map["upper"]
-
     arch = options["output_arch"]
-    num_threads = options["block_dim"]
-    parameter_list = f"({dtype}*, {dtype}*)"
 
-    if arch is None or not warp.context.runtime.core.
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, U, z, x), [], [], 0)
     else:
+        NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
+        solver = "trsm"
+        solver_enum = cusolver_function_map[solver]
+        side_enum = cusolver_side_map["left"]
+        diag_enum = cusolver_diag_map["nounit"]
+        fill_mode = cusolver_fill_mode_map["upper"]
+        dtype, precision_enum = cusolver_type_map[U.type.dtype]
+        num_threads = options["block_dim"]
+        parameter_list = f"({dtype}*, {dtype}*)"
+        req_smem_bytes = (x.type.size + z.type.size + U.type.size) * type_size_in_bytes(U.type.dtype)
+
         # generate the LTO
         lto_symbol, lto_code_data = warp.build.build_lto_solver(
             M,

@@ -8189,6 +9244,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
             num_threads,
             parameter_list,
             builder,
+            smem_estimate_bytes=req_smem_bytes,
        )
 
         return ((Var(lto_symbol, str, False, True, False), U, z, x), [], [lto_code_data], 0)
@@ -8250,6 +9306,7 @@ add_builtin
     group="Tile Primitives",
     export=False,
     namespace="",
+    missing_grad=True,
 )
 
 

@@ -8269,6 +9326,7 @@ add_builtin
     The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
     (excluding Warp arrays since they cannot be created in a Warp kernel at the moment).""",
     group="Code Generation",
+    missing_grad=True,
 )
 
 

@@ -8293,6 +9351,7 @@ add_builtin
     doc="Return the number of elements in a vector.",
     group="Utility",
     export=False,
+    missing_grad=True,
 )
 
 add_builtin(

@@ -8302,6 +9361,7 @@ add_builtin
     doc="Return the number of elements in a quaternion.",
     group="Utility",
     export=False,
+    missing_grad=True,
 )
 
 add_builtin(

@@ -8311,6 +9371,7 @@ add_builtin
     doc="Return the number of rows in a matrix.",
     group="Utility",
     export=False,
+    missing_grad=True,
 )
 
 add_builtin(

@@ -8320,6 +9381,7 @@ add_builtin
     doc="Return the number of elements in a transformation.",
     group="Utility",
     export=False,
+    missing_grad=True,
 )
 
 add_builtin(

@@ -8329,6 +9391,7 @@ add_builtin
     doc="Return the size of the first dimension in an array.",
     group="Utility",
     export=False,
+    missing_grad=True,
 )
 
 add_builtin(

@@ -8338,6 +9401,7 @@ add_builtin
     doc="Return the number of rows in a tile.",
     group="Utility",
     export=False,
+    missing_grad=True,
 )
 
 

@@ -8412,4 +9476,24 @@ add_builtin
     doc="Return the number of elements in a tuple.",
     group="Utility",
     export=False,
+    missing_grad=True,
+)
+
+
+# ---------------------------------
+# Slicing
+
+
+def slice_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    return slice_t(**arg_values)
+
+
+add_builtin(
+    "slice",
+    input_types={"start": int, "stop": int, "step": int},
+    value_func=slice_value_func,
+    native_func="slice_t",
+    export=False,
+    group="Utility",
+    hidden=True,
+    missing_grad=True,
 )
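The new hidden slice builtin forwards its compile-time start/stop/step arguments into Warp's native slice_t type via slice_value_func(). Semantically this mirrors Python's own slice object; a short sketch in standard Python for illustration only, not taken from this diff:

# a call resolved with start=0, stop=8, step=2 carries the same fields as Python's
s = slice(0, 8, 2)
print(s.start, s.stop, s.step)  # 0 8 2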