warp-lang 1.8.1-py3-none-win_amd64.whl → 1.9.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +282 -103
- warp/__init__.pyi +482 -110
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +47 -67
- warp/builtins.py +955 -137
- warp/codegen.py +312 -206
- warp/config.py +1 -1
- warp/context.py +1249 -784
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +264 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +129 -51
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +1 -1
- warp/jax_experimental/ffi.py +2 -1
- warp/marching_cubes.py +708 -0
- warp/native/array.h +99 -4
- warp/native/builtin.h +82 -5
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +8 -2
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +41 -10
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +2 -2
- warp/native/mat.h +1910 -116
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +4 -2
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +331 -14
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +22 -22
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +13 -13
- warp/native/spatial.h +366 -17
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +283 -69
- warp/native/vec.h +381 -14
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +323 -192
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +85 -6
- warp/sim/graph_coloring.py +2 -2
- warp/sparse.py +558 -175
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/sim/test_coloring.py +6 -6
- warp/tests/test_array.py +56 -5
- warp/tests/test_codegen.py +3 -2
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +45 -2
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +1 -1
- warp/tests/test_mat.py +1518 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +140 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +71 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_types.py +0 -20
- warp/tests/test_vec.py +179 -34
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/tile/test_tile.py +184 -18
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_shared_memory.py +5 -5
- warp/tests/unittest_suites.py +6 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +554 -264
- warp/utils.py +68 -86
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/RECORD +131 -121
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/builtins.py
CHANGED
@@ -17,10 +17,12 @@ from __future__ import annotations
 
 import builtins
 import functools
+import math
 from typing import Any, Callable, Mapping, Sequence
 
 import warp.build
 import warp.context
+import warp.utils
 from warp.codegen import Reference, Var, get_arg_value, strip_reference
 from warp.types import *
 
@@ -2355,6 +2357,7 @@ def tile_load_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mappin
 def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
     a = args["a"]
     shape = extract_tuple(args["shape"], as_constant=True)
+    bounds_check = args["bounds_check"]
 
     if None in shape:
         raise ValueError("Tile functions require shape to be a compile time constant.")
@@ -2365,17 +2368,23 @@ def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type:
         offset = (0,) * a.type.ndim
 
     func_args = (a, *offset)
-    template_args = shape
+    template_args = (return_type.dtype, bounds_check.constant, *shape)
 
     return (func_args, template_args)
 
 
 add_builtin(
     "tile_load",
-    input_types={"a": array(dtype=Any), "shape": Tuple[int, ...], "offset": Tuple[int, ...], "storage": str},
+    input_types={
+        "a": array(dtype=Any),
+        "shape": Tuple[int, ...],
+        "offset": Tuple[int, ...],
+        "storage": str,
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_load_tuple_value_func,
     dispatch_func=tile_load_tuple_dispatch_func,
-    defaults={"offset": None, "storage": "register"},
+    defaults={"offset": None, "storage": "register", "bounds_check": True},
     variadic=False,
     doc="""Loads a tile from a global memory array.
 
@@ -2386,6 +2395,7 @@ add_builtin(
     :param offset: Offset in the source array to begin reading from (optional)
     :param storage: The storage location for the tile: ``"register"`` for registers
         (default) or ``"shared"`` for shared memory.
+    :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster load times
     :returns: A tile with shape as specified and data type the same as the source array""",
     group="Tile Primitives",
     export=False,
@@ -2394,16 +2404,160 @@ add_builtin(
 # overload for scalar shape
 add_builtin(
     "tile_load",
-    input_types={"a": array(dtype=Any), "shape": int, "offset": int, "storage": str},
+    input_types={"a": array(dtype=Any), "shape": int, "offset": int, "storage": str, "bounds_check": builtins.bool},
     value_func=tile_load_tuple_value_func,
     dispatch_func=tile_load_tuple_dispatch_func,
-    defaults={"offset": None, "storage": "register"},
+    defaults={"offset": None, "storage": "register", "bounds_check": True},
     group="Tile Primitives",
     hidden=True,
     export=False,
 )
 
 
+def tile_load_indexed_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    if arg_types is None:
+        return tile(dtype=Any, shape=Tuple[int, ...])
+
+    a = arg_types["a"]
+
+    indices_tile = arg_types["indices"]
+    indices_tile.storage = "shared"  # force to shared
+
+    axis = arg_values["axis"]
+    if axis >= a.ndim:
+        raise ValueError(f"tile_load_indexed() axis argument must be valid axis of array {a}, got {axis}.")
+
+    indices_tile_dim = len(indices_tile.shape)
+    if indices_tile_dim != 1:
+        raise ValueError(
+            f"tile_load_indexed() indices argument must be a 1D tile, got {indices_tile_dim} dimensions instead."
+        )
+
+    shape = extract_tuple(arg_values["shape"], as_constant=True)
+
+    if None in shape:
+        raise ValueError("Tile functions require shape to be a compile time constant.")
+
+    num_indices = indices_tile.shape[0]
+    if num_indices != shape[axis]:
+        raise ValueError(
+            "The number of elements in the 1D indices tile must match the output tile shape along the specified axis."
+        )
+
+    if "offset" in arg_values:
+        offset = extract_tuple(arg_values["offset"])
+    else:
+        offset = (0,) * a.ndim
+
+    if a.ndim != len(shape):
+        raise ValueError(
+            f"tile_load_indexed() array argument must have same number of dimensions as the tile shape, trying to perform an {len(shape)} dimensional load from an array with {a.ndim} dimensions."
+        )
+
+    if a.ndim != len(offset):
+        raise ValueError(
+            f"tile_load_indexed() offset argument must have the same number of dimensions as the array to load from, got {len(offset)} indices for an array with {a.ndim} dimensions"
+        )
+
+    if arg_values["storage"] not in {"shared", "register"}:
+        raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
+
+    return tile(dtype=a.dtype, shape=shape, storage=arg_values["storage"])
+
+
+def tile_load_indexed_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    a = args["a"]
+    indices_tile = args["indices"]
+    axis = args["axis"]
+
+    shape = extract_tuple(args["shape"], as_constant=True)
+
+    if None in shape:
+        raise ValueError("Tile functions require shape to be a compile time constant.")
+
+    if "offset" in args:
+        offset = extract_tuple(args["offset"])
+    else:
+        offset = (0,) * a.type.ndim
+
+    func_args = (a, indices_tile, axis, *offset)
+    template_args = shape
+
+    return (func_args, template_args)
+
+
+add_builtin(
+    "tile_load_indexed",
+    input_types={
+        "a": array(dtype=Any),
+        "indices": tile(dtype=int, shape=Tuple[int]),
+        "shape": Tuple[int, ...],
+        "offset": Tuple[int, ...],
+        "axis": int,
+        "storage": str,
+    },
+    value_func=tile_load_indexed_tuple_value_func,
+    dispatch_func=tile_load_indexed_tuple_dispatch_func,
+    defaults={"offset": None, "axis": 0, "storage": "register"},
+    variadic=False,
+    doc="""Loads a tile from a global memory array, with loads along a specified axis mapped according to a 1D tile of indices.
+
+    :param a: The source array in global memory
+    :param indices: A 1D tile of integer indices mapping to elements in ``a``.
+    :param shape: Shape of the tile to load, must have the same number of dimensions as ``a``, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
+    :param offset: Offset in the source array to begin reading from (optional)
+    :param axis: Axis of ``a`` that indices refer to
+    :param storage: The storage location for the tile: ``"register"`` for registers (default) or ``"shared"`` for shared memory.
+    :returns: A tile with shape as specified and data type the same as the source array
+
+    This example shows how to select and store the even indexed rows from a 2D array:
+
+    .. code-block:: python
+
+        TILE_M = wp.constant(2)
+        TILE_N = wp.constant(2)
+        HALF_M = wp.constant(TILE_M // 2)
+        HALF_N = wp.constant(TILE_N // 2)
+
+        @wp.kernel
+        def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
+            i, j = wp.tid()
+
+            evens = wp.tile_arange(HALF_M, dtype=int, storage="shared") * 2
+
+            t0 = wp.tile_load_indexed(x, indices=evens, shape=(HALF_M, TILE_N), offset=(i*TILE_M, j*TILE_N), axis=0, storage="register")
+            wp.tile_store(y, t0, offset=(i*HALF_M, j*TILE_N))
+
+        M = TILE_M * 2
+        N = TILE_N * 2
+
+        arr = np.arange(M * N).reshape(M, N)
+
+        x = wp.array(arr, dtype=float)
+        y = wp.zeros((M // 2, N), dtype=float)
+
+        wp.launch_tiled(compute, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)
+
+        print(x.numpy())
+        print(y.numpy())
+
+    Prints:
+
+    .. code-block:: text
+
+        [[ 0.  1.  2.  3.]
+         [ 4.  5.  6.  7.]
+         [ 8.  9. 10. 11.]
+         [12. 13. 14. 15.]]
+
+        [[ 0.  1.  2.  3.]
+         [ 8.  9. 10. 11.]]
+    """,
+    group="Tile Primitives",
+    export=False,
+)
+
+
 def tile_store_value_func(arg_types, arg_values):
     # return generic type (for doc builds)
     if arg_types is None:
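The `tile_load_indexed` builtin added above is, in effect, a cooperative gather along one axis: tile slices along ``axis`` are read from the source positions selected by the indices tile. A minimal NumPy sketch of the read pattern, with illustrative names only (not part of the diff):

    import numpy as np

    x = np.arange(16, dtype=np.float32).reshape(4, 4)
    idx = np.array([0, 2])  # source rows to gather

    # tile_load_indexed(x, indices=idx, shape=(2, 4), axis=0) reads rows like:
    gathered = x[idx, :]  # shape (2, 4): rows 0 and 2 of x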
@@ -2440,6 +2594,7 @@ def tile_store_value_func(arg_types, arg_values):
 def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
     a = args["a"]
     t = args["t"]
+    bounds_check = args["bounds_check"]
 
     if "offset" in args:
         offset = extract_tuple(args["offset"])
@@ -2447,17 +2602,22 @@ def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any,
         offset = (0,) * a.type.ndim
 
     func_args = (a, *offset, t)
-    template_args = []
+    template_args = (a.type.dtype, bounds_check.constant)
 
     return (func_args, template_args)
 
 
 add_builtin(
     "tile_store",
-    input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...]},
+    input_types={
+        "a": array(dtype=Any),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": Tuple[int, ...],
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_store_value_func,
     dispatch_func=tile_store_dispatch_func,
-    defaults={"offset": None},
+    defaults={"offset": None, "bounds_check": True},
     variadic=False,
     skip_replay=True,
     doc="""Store a tile to a global memory array.
@@ -2466,7 +2626,9 @@ add_builtin(
 
     :param a: The destination array in global memory
     :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array
-    :param offset: Offset in the destination array (optional)""",
+    :param offset: Offset in the destination array (optional)
+    :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster write times
+    """,
     group="Tile Primitives",
     export=False,
 )
@@ -2474,10 +2636,15 @@ add_builtin(
 # overload for scalar offset
 add_builtin(
     "tile_store",
-    input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": int},
+    input_types={
+        "a": array(dtype=Any),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": int,
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_store_value_func,
     dispatch_func=tile_store_dispatch_func,
-    defaults={"offset": None},
+    defaults={"offset": None, "bounds_check": True},
     variadic=False,
     skip_replay=True,
     group="Tile Primitives",
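The new ``bounds_check`` flag on ``tile_load`` and ``tile_store`` above (and on ``tile_atomic_add`` below) defaults to ``True``; per the docstrings it exists for unaligned tiles and can be disabled when tiles are known to line up exactly with the array extents. A minimal sketch, assuming array sizes that are exact multiples of the tile shape (the kernel and sizes are illustrative, not part of the diff):

    import warp as wp

    TILE = 32  # the arrays below are exact multiples of this length

    @wp.kernel
    def copy_tiles(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        i = wp.tid()
        # Tiles are aligned with the array bounds, so the per-element
        # bounds check can be skipped on both the load and the store.
        t = wp.tile_load(x, shape=TILE, offset=i * TILE, bounds_check=False)
        wp.tile_store(y, t, offset=i * TILE, bounds_check=False)

    x = wp.full(4 * TILE, 1.0, dtype=float)
    y = wp.zeros(4 * TILE, dtype=float)
    wp.launch_tiled(copy_tiles, dim=[4], inputs=[x], outputs=[y], block_dim=32)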
@@ -2486,6 +2653,151 @@ add_builtin(
 )
 
 
+def tile_store_indexed_value_func(arg_types, arg_values):
+    # return generic type (for doc builds)
+    if arg_types is None:
+        return None
+
+    a = arg_types["a"]
+    t = arg_types["t"]
+    indices_tile = arg_types["indices"]
+    indices_tile.storage = "shared"  # force to shared
+
+    axis = arg_values["axis"]
+    if axis >= a.ndim:
+        raise ValueError(f"tile_store_indexed() axis argument must be valid axis of array {a}, got {axis}.")
+
+    indices_tile_dim = len(indices_tile.shape)
+    if indices_tile_dim != 1:
+        raise ValueError(
+            f"tile_store_indexed() indices argument must be a 1D tile, got {indices_tile_dim} dimensions instead."
+        )
+
+    num_indices = indices_tile.shape[0]
+    if num_indices != t.shape[axis]:
+        raise ValueError(
+            "The number of elements in the 1D indices tile must match the input tile shape along the specified axis."
+        )
+
+    if "offset" in arg_types:
+        c = extract_tuple(arg_values["offset"])
+    else:
+        c = (0,) * a.ndim
+
+    if len(c) != a.ndim:
+        raise ValueError(
+            f"tile_store_indexed() 'a' argument must have {len(c)} dimensions, "
+            f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
+        )
+
+    if len(t.shape) != a.ndim:
+        raise ValueError(
+            f"tile_store_indexed() 'a' argument must have the same number of dimensions as the 't' argument, "
+            f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
+        )
+
+    if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
+        raise TypeError(
+            f"tile_store_indexed() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
+        )
+
+    return None
+
+
+def tile_store_indexed_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    a = args["a"]
+    indices_tile = args["indices"]
+    axis = args["axis"]
+    t = args["t"]
+
+    if "offset" in args:
+        offset = extract_tuple(args["offset"])
+    else:
+        offset = (0,) * a.type.ndim
+
+    func_args = (a, indices_tile, axis, *offset, t)
+    template_args = []
+
+    return (func_args, template_args)
+
+
+add_builtin(
+    "tile_store_indexed",
+    input_types={
+        "a": array(dtype=Any),
+        "indices": tile(dtype=int, shape=Tuple[int]),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": Tuple[int, ...],
+        "axis": int,
+    },
+    value_func=tile_store_indexed_value_func,
+    dispatch_func=tile_store_indexed_dispatch_func,
+    defaults={"offset": None, "axis": 0},
+    variadic=False,
+    skip_replay=True,
+    doc="""Store a tile to a global memory array, with storage along a specified axis mapped according to a 1D tile of indices.
+
+    :param a: The destination array in global memory
+    :param indices: A 1D tile of integer indices mapping to elements in ``a``.
+    :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
+    :param offset: Offset in the destination array (optional)
+    :param axis: Axis of ``a`` that indices refer to
+
+    This example shows how to map tile rows to the even rows of a 2D array:
+
+    .. code-block:: python
+
+        TILE_M = wp.constant(2)
+        TILE_N = wp.constant(2)
+        TWO_M = wp.constant(TILE_M * 2)
+        TWO_N = wp.constant(TILE_N * 2)
+
+        @wp.kernel
+        def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
+            i, j = wp.tid()
+
+            t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i*TILE_M, j*TILE_N), storage="register")
+
+            evens_M = wp.tile_arange(TILE_M, dtype=int, storage="shared") * 2
+
+            wp.tile_store_indexed(y, indices=evens_M, t=t, offset=(i*TWO_M, j*TILE_N), axis=0)
+
+        M = TILE_M * 2
+        N = TILE_N * 2
+
+        arr = np.arange(M * N, dtype=float).reshape(M, N)
+
+        x = wp.array(arr, dtype=float, requires_grad=True, device=device)
+        y = wp.zeros((M * 2, N), dtype=float, requires_grad=True, device=device)
+
+        wp.launch_tiled(compute, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)
+
+        print(x.numpy())
+        print(y.numpy())
+
+    Prints:
+
+    .. code-block:: text
+
+        [[ 0.  1.  2.  3.]
+         [ 4.  5.  6.  7.]
+         [ 8.  9. 10. 11.]
+         [12. 13. 14. 15.]]
+
+        [[ 0.  1.  2.  3.]
+         [ 0.  0.  0.  0.]
+         [ 4.  5.  6.  7.]
+         [ 0.  0.  0.  0.]
+         [ 8.  9. 10. 11.]
+         [ 0.  0.  0.  0.]
+         [12. 13. 14. 15.]
+         [ 0.  0.  0.  0.]]
+    """,
+    group="Tile Primitives",
+    export=False,
+)
+
+
 def tile_atomic_add_value_func(arg_types, arg_values):
     # return generic type (for doc builds)
     if arg_types is None:
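``tile_store_indexed`` is the scatter counterpart of the gather above: tile slices along ``axis`` are written to the destination positions selected by the indices tile. A minimal NumPy sketch of the write pattern, with illustrative names only:

    import numpy as np

    t = np.arange(8, dtype=np.float32).reshape(2, 4)
    y = np.zeros((4, 4), dtype=np.float32)
    idx = np.array([0, 2])  # destination rows

    # tile_store_indexed(y, indices=idx, t=t, axis=0) writes rows like:
    y[idx, :] = t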
@@ -2526,6 +2838,7 @@ def tile_atomic_add_value_func(arg_types, arg_values):
 def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
     a = args["a"]
     t = args["t"]
+    bounds_check = args["bounds_check"]
 
     if "offset" in args:
         offset = extract_tuple(args["offset"])
@@ -2533,17 +2846,22 @@ def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type:
         offset = (0,) * a.type.ndim
 
     func_args = (a, *offset, t)
-    template_args = []
+    template_args = (a.type.dtype, bounds_check.constant)
 
     return (func_args, template_args)
 
 
 add_builtin(
     "tile_atomic_add",
-    input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...]},
+    input_types={
+        "a": array(dtype=Any),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": Tuple[int, ...],
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_atomic_add_value_func,
     dispatch_func=tile_atomic_add_dispatch_func,
-    defaults={"offset": None},
+    defaults={"offset": None, "bounds_check": True},
     variadic=False,
     skip_replay=True,
     doc="""Atomically add a tile onto the array `a`, each element will be updated atomically.
@@ -2551,6 +2869,7 @@ add_builtin(
     :param a: Array in global memory, should have the same ``dtype`` as the input tile
     :param t: Source tile to add to the destination array
     :param offset: Offset in the destination array (optional)
+    :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster write times
     :returns: A tile with the same dimensions and data type as the source tile, holding the original value of the destination elements""",
     group="Tile Primitives",
     export=False,
@@ -2559,10 +2878,15 @@ add_builtin(
 # overload for scalar offset
 add_builtin(
     "tile_atomic_add",
-    input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": int},
+    input_types={
+        "a": array(dtype=Any),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": int,
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_atomic_add_value_func,
     dispatch_func=tile_atomic_add_dispatch_func,
-    defaults={"offset": None},
+    defaults={"offset": None, "bounds_check": True},
     variadic=False,
     skip_replay=True,
     group="Tile Primitives",
@@ -2571,6 +2895,143 @@ add_builtin(
 )
 
 
+def tile_atomic_add_indexed_value_func(arg_types, arg_values):
+    # return generic type (for doc builds)
+    if arg_types is None:
+        return tile(dtype=Any, shape=Tuple[int, ...])
+
+    a = arg_types["a"]
+    t = arg_types["t"]
+    indices_tile = arg_types["indices"]
+    indices_tile.storage = "shared"  # force to shared
+
+    axis = arg_values["axis"]
+    if axis >= a.ndim:
+        raise ValueError(f"tile_atomic_add_indexed() axis argument must be valid axis of array {a}, got {axis}.")
+
+    indices_tile_dim = len(indices_tile.shape)
+    if indices_tile_dim != 1:
+        raise ValueError(
+            f"tile_atomic_add_indexed() indices argument must be a 1D tile, got {indices_tile_dim} dimensions instead."
+        )
+
+    num_indices = indices_tile.shape[0]
+    if num_indices != t.shape[axis]:
+        raise ValueError(
+            "The number of elements in the 1D indices tile must match the input tile shape along the specified axis."
+        )
+
+    if "offset" in arg_types:
+        c = extract_tuple(arg_values["offset"])
+    else:
+        c = (0,) * a.ndim
+
+    if len(c) != a.ndim:
+        raise ValueError(
+            f"tile_atomic_add_indexed() 'a' argument must have {len(c)} dimensions, "
+            f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
+        )
+
+    if len(t.shape) != a.ndim:
+        raise ValueError(
+            f"tile_atomic_add_indexed() 'a' argument must have the same number of dimensions as the 't' argument, "
+            f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
+        )
+
+    if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
+        raise TypeError(
+            f"tile_atomic_add_indexed() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
+        )
+
+    return tile(dtype=t.dtype, shape=t.shape, storage=t.storage)
+
+
+def tile_atomic_add_indexed_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    a = args["a"]
+    indices_tile = args["indices"]
+    axis = args["axis"]
+    t = args["t"]
+
+    if "offset" in args:
+        offset = extract_tuple(args["offset"])
+    else:
+        offset = (0,) * a.type.ndim
+
+    func_args = (a, indices_tile, axis, *offset, t)
+    template_args = []
+
+    return (func_args, template_args)
+
+
+add_builtin(
+    "tile_atomic_add_indexed",
+    input_types={
+        "a": array(dtype=Any),
+        "indices": tile(dtype=int, shape=Tuple[int]),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": Tuple[int, ...],
+        "axis": int,
+    },
+    value_func=tile_atomic_add_indexed_value_func,
+    dispatch_func=tile_atomic_add_indexed_dispatch_func,
+    defaults={"offset": None, "axis": 0},
+    variadic=False,
+    skip_replay=True,
+    doc="""Atomically add a tile to a global memory array, with storage along a specified axis mapped according to a 1D tile of indices.
+
+    :param a: The destination array in global memory
+    :param indices: A 1D tile of integer indices mapping to elements in ``a``.
+    :param t: The source tile to extract data from, must have the same data type and number of dimensions as the destination array, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
+    :param offset: Offset in the destination array (optional)
+    :param axis: Axis of ``a`` that indices refer to
+
+    This example shows how to compute a blocked, row-wise reduction:
+
+    .. code-block:: python
+
+        TILE_M = wp.constant(2)
+        TILE_N = wp.constant(2)
+
+        @wp.kernel
+        def tile_atomic_add_indexed(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
+            i, j = wp.tid()
+
+            t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i*TILE_M, j*TILE_N), storage="register")
+
+            zeros = wp.tile_zeros(TILE_M, dtype=int, storage="shared")
+
+            wp.tile_atomic_add_indexed(y, indices=zeros, t=t, offset=(i, j*TILE_N), axis=0)
+
+        M = TILE_M * 2
+        N = TILE_N * 2
+
+        arr = np.arange(M * N, dtype=float).reshape(M, N)
+
+        x = wp.array(arr, dtype=float, requires_grad=True, device=device)
+        y = wp.zeros((2, N), dtype=float, requires_grad=True, device=device)
+
+        wp.launch_tiled(tile_atomic_add_indexed, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)
+
+        print(x.numpy())
+        print(y.numpy())
+
+    Prints:
+
+    .. code-block:: text
+
+        [[ 0.  1.  2.  3.]
+         [ 4.  5.  6.  7.]
+         [ 8.  9. 10. 11.]
+         [12. 13. 14. 15.]]
+
+        [[ 4.  6.  8. 10.]
+         [20. 22. 24. 26.]]
+    """,
+    group="Tile Primitives",
+    export=False,
+)
+
+
 def tile_view_value_func(arg_types, arg_values):
     # return generic type (for doc builds)
     if arg_types is None:
@@ -3934,14 +4395,45 @@ def tile_unary_map_value_func(arg_types, arg_values):
     if not is_tile(a):
         raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
 
-    return tile(dtype=a.dtype, shape=a.shape)
+    if "op" in arg_values:
+        op = arg_values["op"]
+        try:
+            overload = op.get_overload([a.dtype], {})
+        except KeyError as exc:
+            raise RuntimeError(f"No overload of {op} found for tile element type {type_repr(a.dtype)}") from exc
+
+        # build the right overload on demand
+        if overload.value_func is None:
+            overload.build(None)
+
+        value_type = overload.value_func(None, None)
+
+        if not type_is_scalar(value_type) and not type_is_vector(value_type) and not type_is_matrix(value_type):
+            raise TypeError(f"Operator {op} returns unsupported type {type_repr(value_type)} for a tile element")
+
+        return tile(dtype=value_type, shape=a.shape)
+
+    else:
+        return tile(dtype=a.dtype, shape=a.shape)
+
+
+def tile_unary_map_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
+    op = arg_values["op"]
+    tile_a = arg_values["a"]
+
+    overload = op.get_overload([tile_a.type.dtype], {})
+
+    # necessary, in case return type is different from input tile types
+    tile_r = Var(label=None, type=return_type)
+
+    return ((overload, tile_a, tile_r), ())
 
 
 add_builtin(
     "tile_map",
     input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...])},
     value_func=tile_unary_map_value_func,
-
+    dispatch_func=tile_unary_map_dispatch_func,
     # variadic=True,
     native_func="tile_unary_map",
     doc="""Apply a unary function onto the tile.
@@ -3950,7 +4442,7 @@ add_builtin(
 
     :param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
     :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
-    :returns: A tile with the same dimensions and data type as the input tile
+    :returns: A tile with the same dimensions as the input tile. Its datatype is specified by the return type of op
 
     Example:
 
@@ -3991,10 +4483,6 @@ def tile_binary_map_value_func(arg_types, arg_values):
     if not is_tile(b):
         raise TypeError(f"tile_map() 'b' argument must be a tile, got {b!r}")
 
-    # ensure types equal
-    if not types_equal(a.dtype, b.dtype):
-        raise TypeError(f"tile_map() arguments must have the same dtype, got {a.dtype} and {b.dtype}")
-
     if len(a.shape) != len(b.shape):
         raise ValueError(
             f"tile_map() shapes must have the same number of dimensions, got {len(a.shape)} and {len(b.shape)}"
@@ -4004,7 +4492,47 @@ def tile_binary_map_value_func(arg_types, arg_values):
         if a.shape[i] != b.shape[i]:
             raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")
 
-    return tile(dtype=a.dtype, shape=a.shape)
+    if "op" in arg_values:
+        op = arg_values["op"]
+        try:
+            overload = op.get_overload([a.dtype, b.dtype], {})
+        except KeyError as exc:
+            raise RuntimeError(
+                f"No overload of {op} found for tile element types {type_repr(a.dtype)}, {type_repr(b.dtype)}"
+            ) from exc
+
+        # build the right overload on demand
+        if overload.value_func is None:
+            overload.build(None)
+
+        value_type = overload.value_func(None, None)
+
+        if not type_is_scalar(value_type) and not type_is_vector(value_type) and not type_is_matrix(value_type):
+            raise TypeError(f"Operator {op} returns unsupported type {type_repr(value_type)} for a tile element")
+
+        return tile(dtype=value_type, shape=a.shape)
+
+    else:
+        # ensure types equal
+        if not types_equal(a.dtype, b.dtype):
+            raise TypeError(
+                f"tile_map() arguments must have the same dtype for this operation, got {a.dtype} and {b.dtype}"
+            )
+
+        return tile(dtype=a.dtype, shape=a.shape)
+
+
+def tile_binary_map_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
+    op = arg_values["op"]
+    tile_a = arg_values["a"]
+    tile_b = arg_values["b"]
+
+    overload = op.get_overload([tile_a.type.dtype, tile_b.type.dtype], {})
+
+    # necessary, in case return type is different from input tile types
+    tile_r = Var(label=None, type=return_type)
+
+    return ((overload, tile_a, tile_b, tile_r), ())
 
 
 add_builtin(
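With the reworked value/dispatch functions above, ``tile_map`` resolves the operator's overload for the input dtypes and takes the output tile's dtype from the operator's return type, so the result no longer has to match the input tiles. A minimal sketch, assuming a user function that maps an integer tile to a float tile (the function and sizes are illustrative):

    import warp as wp

    @wp.func
    def to_float(x: int):
        return float(x) * 0.5

    @wp.kernel
    def scale(src: wp.array(dtype=int), dst: wp.array(dtype=float)):
        t = wp.tile_load(src, shape=8)
        # The result tile's dtype follows the op's return type (float),
        # not the input tile's dtype (int).
        r = wp.tile_map(to_float, t)
        wp.tile_store(dst, r)

    src = wp.array(list(range(8)), dtype=int)
    dst = wp.zeros(8, dtype=float)
    wp.launch_tiled(scale, dim=[1], inputs=[src], outputs=[dst], block_dim=8)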
@@ -4015,18 +4543,18 @@ add_builtin(
         "b": tile(dtype=Scalar, shape=Tuple[int, ...]),
     },
     value_func=tile_binary_map_value_func,
-
+    dispatch_func=tile_binary_map_dispatch_func,
     # variadic=True,
     native_func="tile_binary_map",
     doc="""Apply a binary function onto the tile.
 
     This function cooperatively applies a binary function to each element of the tiles using all threads in the block.
-    Both input tiles must have the same dimensions and datatypes.
+    Both input tiles must have the same dimensions, and if using a builtin op, the same datatypes.
 
     :param op: A callable function that accepts two arguments and returns one argument, all of the same type, may be a user function or builtin
     :param a: The first input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
     :param b: The second input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
-    :returns: A tile with the same dimensions and data type as the input tiles
+    :returns: A tile with the same dimensions as the input tiles. Its datatype is specified by the return type of op
 
     Example:
 
@@ -5544,7 +6072,7 @@ def array_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any
         return array(dtype=Scalar)
 
     dtype = arg_values["dtype"]
-    shape = extract_tuple(arg_values["shape"], as_constant=True)
+    shape = extract_tuple(arg_values["shape"], as_constant=False)
     return array(dtype=dtype, ndim=len(shape))
 
 
@@ -5554,7 +6082,7 @@ def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args:
     # to the underlying C++ function's runtime and template params.
 
     dtype = return_type.dtype
-    shape = extract_tuple(args["shape"], as_constant=True)
+    shape = extract_tuple(args["shape"], as_constant=False)
 
     func_args = (args["ptr"], *shape)
     template_args = (dtype,)
@@ -5563,7 +6091,7 @@ def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args:
 
 add_builtin(
     "array",
-    input_types={"ptr": warp.uint64, "shape": Tuple[int, ...], "dtype": Scalar},
+    input_types={"ptr": warp.uint64, "shape": Tuple[int, ...], "dtype": Any},
     value_func=array_value_func,
     export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
    dispatch_func=array_dispatch_func,
@@ -5575,6 +6103,48 @@ add_builtin(
 )
 
 
+def zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    if arg_types is None:
+        return fixedarray(dtype=Scalar)
+
+    dtype = arg_values["dtype"]
+    shape = extract_tuple(arg_values["shape"], as_constant=True)
+
+    if None in shape:
+        raise RuntimeError("the `shape` argument must be specified as a constant when zero-initializing an array")
+
+    return fixedarray(dtype=dtype, shape=shape)
+
+
+def zeros_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    # We're in the codegen stage where we emit the code calling the built-in.
+    # Further validate the given argument values if needed and map them
+    # to the underlying C++ function's runtime and template params.
+
+    dtype = return_type.dtype
+    shape = extract_tuple(args["shape"], as_constant=True)
+
+    size = math.prod(shape)
+
+    func_args = shape
+    template_args = (size, dtype)
+    return (func_args, template_args)
+
+
+add_builtin(
+    "zeros",
+    input_types={"shape": Tuple[int, ...], "dtype": Any},
+    value_func=zeros_value_func,
+    export_func=lambda input_types: {},
+    dispatch_func=zeros_dispatch_func,
+    native_func="fixedarray_t",
+    group="Utility",
+    export=False,
+    missing_grad=True,
+    hidden=True,  # Unhide once we can document both a built-in and a Python scope function sharing the same name.
+)
+
+
 # does argument checking and type propagation for address()
 def address_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
     arr_type = arg_types["arr"]
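The hidden ``zeros`` builtin above backs fixed-shape, kernel-local arrays (``fixedarray``), mirroring the Python-scope ``wp.zeros()``; the shape must be a compile-time constant, as enforced in ``zeros_value_func``. A minimal sketch of in-kernel usage, assuming the builtin is reachable as ``wp.zeros`` in kernel scope (names and sizes are illustrative):

    import warp as wp

    @wp.kernel
    def bucket(values: wp.array(dtype=int), out: wp.array(dtype=int)):
        i = wp.tid()
        # Fixed-size scratch storage local to the thread; the shape is a
        # compile-time constant, so it lowers to a fixedarray_t.
        counts = wp.zeros(shape=(4,), dtype=int)
        counts[values[i] % 4] = 1
        for j in range(4):
            wp.atomic_add(out, j, counts[j])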
@@ -5864,8 +6434,8 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
 
 
 for array_type in array_types:
-    # don't list indexed array operations explicitly in docs
-    hidden = array_type == indexedarray
+    # don't list fixed or indexed array operations explicitly in docs
+    hidden = array_type in (indexedarray, fixedarray)
 
     add_builtin(
         "atomic_add",
@@ -6187,46 +6757,110 @@ for array_type in array_types:
 
 
 # used to index into builtin types, i.e.: y = vec3[1]
-def
-
+def vector_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    vec_type = arg_types["a"]
+    idx_type = arg_types["i"]
+
+    if isinstance(idx_type, slice_t):
+        length = idx_type.get_length(vec_type._length_)
+        return vector(length=length, dtype=vec_type._wp_scalar_type_)
+
+    return vec_type._wp_scalar_type_
+
+
+def vector_extract_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    func_args = tuple(args.values())
+    template_args = getattr(return_type, "_shape_", ())
+    return (func_args, template_args)
 
 
 add_builtin(
     "extract",
-    input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
-    value_func=
+    input_types={"a": vector(length=Any, dtype=Scalar), "i": Any},
+    value_func=vector_extract_value_func,
+    dispatch_func=vector_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
 add_builtin(
     "extract",
-    input_types={"a": quaternion(dtype=Scalar), "i": int},
-    value_func=
+    input_types={"a": quaternion(dtype=Scalar), "i": Any},
+    value_func=vector_extract_value_func,
+    dispatch_func=vector_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
-
 add_builtin(
     "extract",
-    input_types={"a":
-    value_func=
-
-
+    input_types={"a": transformation(dtype=Scalar), "i": Any},
+    value_func=vector_extract_value_func,
+    dispatch_func=vector_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
+
+
+def matrix_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    mat_type = arg_types["a"]
+    idx_types = tuple(arg_types[x] for x in "ij" if arg_types.get(x, None) is not None)
+
+    # Compute the resulting shape from the slicing, with -1 being simple indexing.
+    shape = tuple(
+        idx.get_length(mat_type._shape_[i]) if isinstance(idx, slice_t) else -1 for i, idx in enumerate(idx_types)
+    )
+
+    # Append any non indexed slice.
+    for i in range(len(idx_types), len(mat_type._shape_)):
+        shape += (mat_type._shape_[i],)
+
+    # Count how many dimensions the output value will have.
+    ndim = sum(1 for x in shape if x >= 0)
+
+    if ndim == 0:
+        return mat_type._wp_scalar_type_
+
+    assert shape[0] != -1 or shape[1] != -1
+
+    if ndim == 1:
+        length = shape[0] if shape[0] != -1 else shape[1]
+        return vector(length=length, dtype=mat_type._wp_scalar_type_)
+
+    assert ndim == 2
+
+    # When a matrix dimension is 0, all other dimensions are also expected to be 0.
+    if any(x == 0 for x in shape):
+        shape = (0,) * len(shape)
+
+    return matrix(shape=shape, dtype=mat_type._wp_scalar_type_)
+
+
+def matrix_extract_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    idx_types = tuple(args[x].type for x in "ij" if args.get(x, None) is not None)
+    has_slice = any(isinstance(x, slice_t) for x in idx_types)
+
+    func_args = tuple(args.values())
+    template_args = getattr(return_type, "_shape_", ()) if has_slice else ()
+    return (func_args, template_args)
+
+
 add_builtin(
     "extract",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
-    value_func=
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any},
+    value_func=matrix_extract_value_func,
+    dispatch_func=matrix_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
-
 add_builtin(
     "extract",
-    input_types={"a":
-    value_func=
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any},
+    value_func=matrix_extract_value_func,
+    dispatch_func=matrix_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
|
|
|
6247
6881
|
return (func_args, template_args)
|
|
6248
6882
|
|
|
6249
6883
|
|
|
6884
|
+
def matrix_ij_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
6885
|
+
mat_type = arg_types["a"]
|
|
6886
|
+
value_type = mat_type._wp_scalar_type_
|
|
6887
|
+
|
|
6888
|
+
return Reference(value_type)
|
|
6889
|
+
|
|
6890
|
+
|
|
6891
|
+
def matrix_ij_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
6892
|
+
func_args = (Reference(args["a"]), args["i"], args["j"])
|
|
6893
|
+
template_args = ()
|
|
6894
|
+
return (func_args, template_args)
|
|
6895
|
+
|
|
6896
|
+
|
|
6250
6897
|
# implements &vector[index]
|
|
6251
6898
|
add_builtin(
|
|
6252
6899
|
"index",
|
|
@@ -6287,6 +6934,16 @@ add_builtin(
|
|
|
6287
6934
|
group="Utility",
|
|
6288
6935
|
skip_replay=True,
|
|
6289
6936
|
)
|
|
6937
|
+
# implements &(*matrix)[i, j]
|
|
6938
|
+
add_builtin(
|
|
6939
|
+
"indexref",
|
|
6940
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
|
|
6941
|
+
value_func=matrix_ij_value_func,
|
|
6942
|
+
dispatch_func=matrix_ij_dispatch_func,
|
|
6943
|
+
hidden=True,
|
|
6944
|
+
group="Utility",
|
|
6945
|
+
skip_replay=True,
|
|
6946
|
+
)
|
|
6290
6947
|
# implements &(*quaternion)[index]
|
|
6291
6948
|
add_builtin(
|
|
6292
6949
|
"indexref",
|
|
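The new matrix ``indexref`` overload above returns a reference to a single component, which is what two-index in-place updates lower to. A minimal sketch, assuming ``m[i, j] op= value`` resolves to this builtin:

    import warp as wp

    @wp.kernel
    def accumulate(out: wp.array(dtype=wp.mat22)):
        m = wp.mat22(0.0)
        # &(*matrix)[i, j]: read-modify-write one component in place
        m[0, 1] += 2.0
        m[1, 0] -= 1.0
        out[0] = m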
@@ -6309,11 +6966,46 @@ add_builtin(
 )
 
 
+def vector_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    vec = args["a"].type
+    idx = args["i"].type
+    value_type = strip_reference(args["value"].type)
+
+    if isinstance(idx, slice_t):
+        length = idx.get_length(vec._length_)
+
+        if type_is_vector(value_type):
+            if not types_equal(value_type._wp_scalar_type_, vec._wp_scalar_type_):
+                raise ValueError(
+                    f"The provided vector is expected to be of length {length} with dtype {type_repr(vec._wp_scalar_type_)}."
+                )
+            if value_type._length_ != length:
+                raise ValueError(
+                    f"The length of the provided vector ({args['value'].type._length_}) isn't compatible with the given slice (expected {length})."
+                )
+            template_args = (length,)
+        else:
+            # Disallow broadcasting.
+            raise ValueError(
+                f"The provided value is expected to be a vector of length {length}, with dtype {type_repr(vec._wp_scalar_type_)}."
+            )
+    else:
+        if not types_equal(value_type, vec._wp_scalar_type_):
+            raise ValueError(
+                f"The provided value is expected to be a scalar of type {type_repr(vec._wp_scalar_type_)}."
+            )
+        template_args = ()
+
+    func_args = tuple(args.values())
+    return (func_args, template_args)
+
+
 # implements vector[index] = value
 add_builtin(
     "assign_inplace",
-    input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
+    input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
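``vector_assign_dispatch_func`` above validates both accepted store forms: a scalar store into one component, and a slice store from a vector of matching length and dtype (broadcasting a scalar into a slice is explicitly rejected). A minimal sketch of the two forms:

    import warp as wp

    @wp.kernel
    def assign(out: wp.array(dtype=wp.vec4)):
        v = wp.vec4(0.0)
        v[0] = 1.0                  # scalar store: dtype must match
        v[1:3] = wp.vec2(2.0, 3.0)  # slice store: length and dtype must match
        out[0] = v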
@@ -6322,8 +7014,9 @@ add_builtin(
|
|
|
6322
7014
|
# implements quaternion[index] = value
|
|
6323
7015
|
add_builtin(
|
|
6324
7016
|
"assign_inplace",
|
|
6325
|
-
input_types={"a": quaternion(dtype=Scalar), "i":
|
|
7017
|
+
input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
|
|
6326
7018
|
value_type=None,
|
|
7019
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
6327
7020
|
hidden=True,
|
|
6328
7021
|
export=False,
|
|
6329
7022
|
group="Utility",
|
|
@@ -6331,15 +7024,16 @@ add_builtin(
|
|
|
6331
7024
|
# implements transformation[index] = value
|
|
6332
7025
|
add_builtin(
|
|
6333
7026
|
"assign_inplace",
|
|
6334
|
-
input_types={"a": transformation(dtype=Scalar), "i":
|
|
7027
|
+
input_types={"a": transformation(dtype=Scalar), "i": Any, "value": Any},
|
|
6335
7028
|
value_type=None,
|
|
7029
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
6336
7030
|
hidden=True,
|
|
6337
7031
|
export=False,
|
|
6338
7032
|
group="Utility",
|
|
6339
7033
|
)
|
|
6340
7034
|
|
|
6341
7035
|
|
|
6342
|
-
def
|
|
7036
|
+
def vector_assign_copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
6343
7037
|
vec_type = arg_types["a"]
|
|
6344
7038
|
return vec_type
|
|
6345
7039
|
|
|
@@ -6347,8 +7041,9 @@ def vector_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[
|
|
|
6347
7041
|
# implements vector[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
|
|
6348
7042
|
add_builtin(
|
|
6349
7043
|
"assign_copy",
|
|
6350
|
-
input_types={"a": vector(length=Any, dtype=Scalar), "i":
|
|
6351
|
-
value_func=
|
|
7044
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
7045
|
+
value_func=vector_assign_copy_value_func,
|
|
7046
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
6352
7047
|
hidden=True,
|
|
6353
7048
|
export=False,
|
|
6354
7049
|
group="Utility",
|
|
@@ -6357,8 +7052,9 @@ add_builtin(
|
|
|
6357
7052
|
# implements quaternion[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
|
|
6358
7053
|
add_builtin(
|
|
6359
7054
|
"assign_copy",
|
|
6360
|
-
input_types={"a": quaternion(dtype=Scalar), "i":
|
|
6361
|
-
value_func=
|
|
7055
|
+
input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
|
|
7056
|
+
value_func=vector_assign_copy_value_func,
|
|
7057
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
6362
7058
|
hidden=True,
|
|
6363
7059
|
export=False,
|
|
6364
7060
|
group="Utility",
|
|
@@ -6367,8 +7063,9 @@ add_builtin(
|
|
|
6367
7063
|
# implements transformation[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
|
|
6368
7064
|
add_builtin(
|
|
6369
7065
|
"assign_copy",
|
|
6370
|
-
input_types={"a": transformation(dtype=Scalar), "i":
|
|
6371
|
-
value_func=
|
|
7066
|
+
input_types={"a": transformation(dtype=Scalar), "i": Any, "value": Any},
|
|
7067
|
+
value_func=vector_assign_copy_value_func,
|
|
7068
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
6372
7069
|
hidden=True,
|
|
6373
7070
|
export=False,
|
|
6374
7071
|
group="Utility",
|
|
@@ -6377,8 +7074,9 @@ add_builtin(
 # implements vector[idx] += scalar
 add_builtin(
     "add_inplace",
-    input_types={"a": vector(length=Any, dtype=Scalar), "i":
+    input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
@@ -6387,8 +7085,9 @@ add_builtin(
 # implements quaternion[idx] += scalar
 add_builtin(
     "add_inplace",
-    input_types={"a": quaternion(dtype=Scalar), "i":
+    input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
@@ -6397,8 +7096,9 @@ add_builtin(
 # implements transformation[idx] += scalar
 add_builtin(
     "add_inplace",
-    input_types={"a": transformation(dtype=Float), "i":
+    input_types={"a": transformation(dtype=Float), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
@@ -6417,8 +7117,9 @@ add_builtin(
 # implements vector[idx] -= scalar
 add_builtin(
     "sub_inplace",
-    input_types={"a": vector(length=Any, dtype=Scalar), "i":
+    input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
@@ -6427,8 +7128,9 @@ add_builtin(
 # implements quaternion[idx] -= scalar
 add_builtin(
     "sub_inplace",
-    input_types={"a": quaternion(dtype=Scalar), "i":
+    input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
@@ -6437,8 +7139,9 @@ add_builtin(
 # implements transformation[idx] -= scalar
 add_builtin(
     "sub_inplace",
-    input_types={"a": transformation(dtype=
+    input_types={"a": transformation(dtype=Float), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
@@ -6499,61 +7202,154 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
     return mat_size == vec_size and mat_type == vec_type


-
+def matrix_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    mat = args["a"].type
+    value_type = strip_reference(args["value"].type)
+
+    idxs = tuple(args[x].type for x in "ij" if args.get(x, None) is not None)
+    has_slice = any(isinstance(x, slice_t) for x in idxs)
+
+    if has_slice:
+        # Compute the resulting shape from the slicing, with -1 being simple indexing.
+        shape = tuple(idx.get_length(mat._shape_[i]) if isinstance(idx, slice_t) else -1 for i, idx in enumerate(idxs))
+
+        # Append any non indexed slice.
+        for i in range(len(idxs), len(mat._shape_)):
+            shape += (mat._shape_[i],)
+
+        # Count how many dimensions the output value will have.
+        ndim = sum(1 for x in shape if x >= 0)
+        assert ndim > 0
+
+        if ndim == 1:
+            length = shape[0] if shape[0] != -1 else shape[1]
+
+            if type_is_vector(value_type):
+                if not types_equal(value_type._wp_scalar_type_, mat._wp_scalar_type_):
+                    raise ValueError(
+                        f"The provided vector is expected to be of length {length} with dtype {type_repr(mat._wp_scalar_type_)}."
+                    )
+
+                if value_type._length_ != length:
+                    raise ValueError(
+                        f"The length of the provided vector ({value_type._length_}) isn't compatible with the given slice (expected {length})."
+                    )
+
+                template_args = (length,)
+            else:
+                # Disallow broadcasting.
+                raise ValueError(
+                    f"The provided value is expected to be a vector of length {length}, with dtype {type_repr(mat._wp_scalar_type_)}."
+                )
+        else:
+            assert ndim == 2
+
+            # When a matrix dimension is 0, all other dimensions are also expected to be 0.
+            if any(x == 0 for x in shape):
+                shape = (0,) * len(shape)
+
+            if type_is_matrix(value_type):
+                if not types_equal(value_type._wp_scalar_type_, mat._wp_scalar_type_):
+                    raise ValueError(
+                        f"The provided matrix is expected to be of shape {shape} with dtype {type_repr(mat._wp_scalar_type_)}."
+                    )
+
+                if value_type._shape_ != shape:
+                    raise ValueError(
+                        f"The shape of the provided matrix ({value_type._shape_}) isn't compatible with the given slice (expected {shape})."
+                    )
+
+                template_args = shape
+            else:
+                # Disallow broadcasting.
+                raise ValueError(
+                    f"The provided value is expected to be a matrix of shape {shape}, with dtype {type_repr(mat._wp_scalar_type_)}."
+                )
+    elif len(idxs) == 1:
+        if not type_is_vector(value_type) or not types_equal(value_type._wp_scalar_type_, mat._wp_scalar_type_):
+            raise ValueError(
+                f"The provided value is expected to be a vector of length {mat._shape_[1]}, with dtype {type_repr(mat._wp_scalar_type_)}."
+            )
+
+        if value_type._length_ != mat._shape_[1]:
+            raise ValueError(
+                f"The length of the provided vector ({value_type._length_}) isn't compatible with the given slice (expected {mat._shape_[1]})."
+            )
+
+        template_args = ()
+    elif len(idxs) == 2:
+        if not types_equal(value_type, mat._wp_scalar_type_):
+            raise ValueError(
+                f"The provided value is expected to be a scalar of type {type_repr(mat._wp_scalar_type_)}."
+            )
+
+        template_args = ()
+    else:
+        raise AssertionError
+
+    func_args = tuple(args.values())
+    return (func_args, template_args)
+
+
+# implements matrix[i] = value
 add_builtin(
     "assign_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
+    constraint=matrix_vector_sametype,
     value_type=None,
+    dispatch_func=matrix_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
 )


-# implements matrix[i] =
+# implements matrix[i,j] = value
 add_builtin(
     "assign_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
-    constraint=matrix_vector_sametype,
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
     value_type=None,
+    dispatch_func=matrix_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
 )


-def
+def matrix_assign_copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
     mat_type = arg_types["a"]
     return mat_type


-# implements matrix[i
+# implements matrix[i] = value
 add_builtin(
     "assign_copy",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
-    value_func=
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
+    value_func=matrix_assign_copy_value_func,
+    dispatch_func=matrix_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
 )


-# implements matrix[i] =
+# implements matrix[i,j] = value
 add_builtin(
     "assign_copy",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
-
-
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
+    value_func=matrix_assign_copy_value_func,
+    dispatch_func=matrix_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
 )


-# implements matrix[i
+# implements matrix[i] += value
 add_builtin(
     "add_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
+    constraint=matrix_vector_sametype,
     value_type=None,
     hidden=True,
     export=False,
@@ -6561,11 +7357,10 @@ add_builtin(
 )


-# implements matrix[i] +=
+# implements matrix[i,j] += value
 add_builtin(
     "add_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
-    constraint=matrix_vector_sametype,
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
     value_type=None,
     hidden=True,
     export=False,
@@ -6573,10 +7368,10 @@ add_builtin(
 )


-# implements matrix[i
+# implements matrix[i] -= value
 add_builtin(
     "sub_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
     hidden=True,
     export=False,
@@ -6584,10 +7379,10 @@ add_builtin(
 )


-# implements matrix[i] -=
+# implements matrix[i,j] -= value
 add_builtin(
     "sub_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i":
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
     value_type=None,
     hidden=True,
     export=False,
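
Together, `matrix_assign_dispatch_func` and the builtins above let kernels assign a whole row, a single element, or (via slices) a sub-block of a matrix, with shape and dtype validated at compile time. A hedged sketch of the kernel-level syntax they back (the kernel and names are ours, not part of the diff):

    import warp as wp

    @wp.kernel
    def matrix_writes(out: wp.array(dtype=wp.mat33)):
        tid = wp.tid()
        m = out[tid]
        m[0] = wp.vec3(1.0, 2.0, 3.0)   # matrix[i] = vector (row assignment)
        m[1, 2] = 4.0                   # matrix[i,j] = scalar
        m[2] += wp.vec3(0.1, 0.1, 0.1)  # matrix[i] += vector (add_inplace)
        out[tid] = m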
@@ -6807,6 +7602,7 @@ add_builtin(
 # ---------------------------------
 # Operators

+
 add_builtin(
     "add", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
 )
@@ -7079,7 +7875,7 @@ add_builtin(
     "mod",
     input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
     constraint=sametypes,
-    value_func=sametypes_create_value_func(Scalar),
+    value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
     doc="Modulo operation using truncated division.",
     group="Operators",
 )
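
The `mod` change above fixes the declared result type of element-wise vector modulo: it previously resolved to a scalar and now resolves to a vector of the same length and dtype as the operands. A small sketch of the behavior this corrects (kernel and names are ours):

    import warp as wp

    @wp.kernel
    def mod_demo(out: wp.array(dtype=wp.vec3)):
        tid = wp.tid()
        a = wp.vec3(5.0, 7.0, 9.0)
        b = wp.vec3(2.0, 4.0, 5.0)
        # Element-wise truncated-division modulo; with this fix the
        # expression is typed vec3 rather than float.
        out[tid] = wp.mod(a, b)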
@@ -7481,7 +8277,7 @@ def tile_matmul_lto_dispatch_func(
     num_threads = options["block_dim"]
     arch = options["output_arch"]

-    if arch is None or not warp.context.runtime.core.
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, 0, 0, a, b, out), template_args, [], 0)
     else:
@@ -7671,7 +8467,7 @@ def tile_fft_generic_lto_dispatch_func(
     arch = options["output_arch"]
     ept = size // num_threads

-    if arch is None or not warp.context.runtime.core.
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ([], [], [], 0)
     else:
@@ -7792,28 +8588,27 @@ def tile_cholesky_generic_lto_dispatch_func(
         raise TypeError("tile_cholesky() returns one output")
     out = return_values[0]

-    dtype, precision_enum = cusolver_type_map[a.type.dtype]
-
     # We already ensured a is square in tile_cholesky_generic_value_func()
     M, N = a.type.shape
     if out.type.shape[0] != M or out.type.shape[1] != M:
         raise ValueError("tile_cholesky() output tile must be square")

-    solver = "potrf"
-    solver_enum = cusolver_function_map[solver]
-
-    side_enum = cusolver_side_map["-"]
-    diag_enum = cusolver_diag_map["-"]
-    fill_mode = cusolver_fill_mode_map["lower"]
-
     arch = options["output_arch"]
-    num_threads = options["block_dim"]
-    parameter_list = f"({dtype}*, int*)"

-    if arch is None or not warp.context.runtime.core.
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, a, out), [], [], 0)
     else:
+        solver = "potrf"
+        solver_enum = cusolver_function_map[solver]
+        side_enum = cusolver_side_map["-"]
+        diag_enum = cusolver_diag_map["-"]
+        fill_mode = cusolver_fill_mode_map["lower"]
+        dtype, precision_enum = cusolver_type_map[a.type.dtype]
+        num_threads = options["block_dim"]
+        parameter_list = f"({dtype}*, int*)"
+        req_smem_bytes = a.type.size * type_size_in_bytes(a.type.dtype)
+
         # generate the LTO
         lto_symbol, lto_code_data = warp.build.build_lto_solver(
             M,
@@ -7831,6 +8626,7 @@ def tile_cholesky_generic_lto_dispatch_func(
             num_threads,
             parameter_list,
             builder,
+            smem_estimate_bytes=req_smem_bytes,
         )

         return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
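
The new `req_smem_bytes` value gives `build_lto_solver` an estimate of the shared memory the factorization tile will occupy. The arithmetic is simply element count times element size; for example (the tile shape here is our assumption, not from the diff):

    # A 64x64 float64 Cholesky tile:
    # a.type.size == 64 * 64 elements, type_size_in_bytes(float64) == 8
    req_smem_bytes = 64 * 64 * 8  # 32768 bytes, i.e. 32 KiB of shared memory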
@@ -7918,9 +8714,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
     if any(T not in cusolver_type_map.keys() for T in [y.type.dtype, L.type.dtype]):
         raise TypeError("tile_cholesky_solve() arguments be tiles of float64 or float32")

-    dtype, precision_enum = cusolver_type_map[L.type.dtype]
     M, N = L.type.shape
-    NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1

     if len(x.type.shape) > 2 or len(x.type.shape) < 1:
         raise TypeError(f"tile_cholesky_solve() output vector must be 1D or 2D, got {len(x.type.shape)}-D")
@@ -7931,21 +8725,23 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
             f"got {x.type.shape[0]} elements in output and {M} rows in 'L'"
         )

-    solver = "potrs"
-    solver_enum = cusolver_function_map[solver]
-
-    side_enum = cusolver_side_map["-"]
-    diag_enum = cusolver_diag_map["-"]
-    fill_mode = cusolver_fill_mode_map["lower"]
-
     arch = options["output_arch"]
-    num_threads = options["block_dim"]
-    parameter_list = f"({dtype}*, {dtype}*)"

-    if arch is None or not warp.context.runtime.core.
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, L, y, x), [], [], 0)
     else:
+        NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
+        solver = "potrs"
+        solver_enum = cusolver_function_map[solver]
+        side_enum = cusolver_side_map["-"]
+        diag_enum = cusolver_diag_map["-"]
+        fill_mode = cusolver_fill_mode_map["lower"]
+        dtype, precision_enum = cusolver_type_map[L.type.dtype]
+        num_threads = options["block_dim"]
+        parameter_list = f"({dtype}*, {dtype}*)"
+        req_smem_bytes = (x.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
+
         # generate the LTO
         lto_symbol, lto_code_data = warp.build.build_lto_solver(
             M,
@@ -7963,6 +8759,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
             num_threads,
             parameter_list,
             builder,
+            smem_estimate_bytes=req_smem_bytes,
         )

         return ((Var(lto_symbol, str, False, True, False), L, y, x), [], [lto_code_data], 0)
@@ -8013,9 +8810,7 @@ def tile_lower_solve_generic_lto_dispatch_func(

     z = return_values[0]

-    dtype, precision_enum = cusolver_type_map[L.type.dtype]
     M, N = L.type.shape
-    NRHS = z.type.shape[1] if len(z.type.shape) > 1 else 1

     if len(z.type.shape) > 2 or len(z.type.shape) < 1:
         raise TypeError(f"tile_lower_solve() output vector must be 1D or 2D, got {len(z.type.shape)}-D")
@@ -8026,21 +8821,23 @@ def tile_lower_solve_generic_lto_dispatch_func(
             f"got {z.type.shape[0]} elements in output and {M} rows in 'L'"
         )

-    solver = "trsm"
-    solver_enum = cusolver_function_map[solver]
-
-    side_enum = cusolver_side_map["left"]
-    diag_enum = cusolver_diag_map["nounit"]
-    fill_mode = cusolver_fill_mode_map["lower"]
-
     arch = options["output_arch"]
-    num_threads = options["block_dim"]
-    parameter_list = f"({dtype}*, {dtype}*)"

-    if arch is None or not warp.context.runtime.core.
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, L, y, z), [], [], 0)
     else:
+        NRHS = z.type.shape[1] if len(z.type.shape) > 1 else 1
+        solver = "trsm"
+        solver_enum = cusolver_function_map[solver]
+        side_enum = cusolver_side_map["left"]
+        diag_enum = cusolver_diag_map["nounit"]
+        fill_mode = cusolver_fill_mode_map["lower"]
+        dtype, precision_enum = cusolver_type_map[L.type.dtype]
+        num_threads = options["block_dim"]
+        parameter_list = f"({dtype}*, {dtype}*)"
+        req_smem_bytes = (z.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
+
         # generate the LTO
         lto_symbol, lto_code_data = warp.build.build_lto_solver(
             M,
@@ -8058,6 +8855,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
             num_threads,
             parameter_list,
             builder,
+            smem_estimate_bytes=req_smem_bytes,
         )

         return ((Var(lto_symbol, str, False, True, False), L, y, z), [], [lto_code_data], 0)
@@ -8144,9 +8942,7 @@ def tile_upper_solve_generic_lto_dispatch_func(

     x = return_values[0]

-    dtype, precision_enum = cusolver_type_map[U.type.dtype]
     M, N = U.type.shape
-    NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1

     if len(z.type.shape) > 2 or len(z.type.shape) < 1:
         raise TypeError(f"tile_upper_solve() output tile must be 1D or 2D, got {len(z.type.shape)}-D")
@@ -8157,21 +8953,23 @@ def tile_upper_solve_generic_lto_dispatch_func(
             f"got {z.type.shape[0]} elements in output and {M} rows in 'U'"
         )

-    solver = "trsm"
-    solver_enum = cusolver_function_map[solver]
-
-    side_enum = cusolver_side_map["left"]
-    diag_enum = cusolver_diag_map["nounit"]
-    fill_mode = cusolver_fill_mode_map["upper"]
-
     arch = options["output_arch"]
-    num_threads = options["block_dim"]
-    parameter_list = f"({dtype}*, {dtype}*)"

-    if arch is None or not warp.context.runtime.core.
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, U, z, x), [], [], 0)
     else:
+        NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
+        solver = "trsm"
+        solver_enum = cusolver_function_map[solver]
+        side_enum = cusolver_side_map["left"]
+        diag_enum = cusolver_diag_map["nounit"]
+        fill_mode = cusolver_fill_mode_map["upper"]
+        dtype, precision_enum = cusolver_type_map[U.type.dtype]
+        num_threads = options["block_dim"]
+        parameter_list = f"({dtype}*, {dtype}*)"
+        req_smem_bytes = (x.type.size + z.type.size + U.type.size) * type_size_in_bytes(U.type.dtype)
+
         # generate the LTO
         lto_symbol, lto_code_data = warp.build.build_lto_solver(
             M,
@@ -8189,6 +8987,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
             num_threads,
             parameter_list,
             builder,
+            smem_estimate_bytes=req_smem_bytes,
         )

         return ((Var(lto_symbol, str, False, True, False), U, z, x), [], [lto_code_data], 0)
@@ -8413,3 +9212,22 @@ add_builtin(
     group="Utility",
     export=False,
 )
+
+# ---------------------------------
+# Slicing
+
+
+def slice_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    return slice_t(**arg_values)
+
+
+add_builtin(
+    "slice",
+    input_types={"start": int, "stop": int, "step": int},
+    value_func=slice_value_func,
+    native_func="slice_t",
+    export=False,
+    group="Utility",
+    hidden=True,
+    missing_grad=True,
+)
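
The new hidden `slice` builtin maps Python slice expressions inside kernels onto the native `slice_t` type, which backs the slice-aware assignment paths added earlier in this diff. A hedged usage sketch, assuming 1.9's slicing support on vector values (the kernel and names are ours):

    import warp as wp

    @wp.kernel
    def slice_demo(out: wp.array(dtype=wp.vec2)):
        tid = wp.tid()
        v = wp.vec4(1.0, 2.0, 3.0, 4.0)
        # v[1:3] constructs a slice_t via the hidden "slice" builtin
        # and reads components 1 and 2 as a length-2 vector.
        out[tid] = v[1:3]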