warp-lang 1.8.0-py3-none-macosx_10_13_universal2.whl → 1.9.0-py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (153)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -17,10 +17,12 @@ from __future__ import annotations
 
 import builtins
 import functools
+import math
 from typing import Any, Callable, Mapping, Sequence
 
 import warp.build
 import warp.context
+import warp.utils
 from warp.codegen import Reference, Var, get_arg_value, strip_reference
 from warp.types import *
 
@@ -2355,6 +2357,7 @@ def tile_load_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mappin
 def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
     a = args["a"]
     shape = extract_tuple(args["shape"], as_constant=True)
+    bounds_check = args["bounds_check"]
 
     if None in shape:
         raise ValueError("Tile functions require shape to be a compile time constant.")
@@ -2365,17 +2368,23 @@ def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type:
         offset = (0,) * a.type.ndim
 
     func_args = (a, *offset)
-    template_args = shape
+    template_args = (return_type.dtype, bounds_check.constant, *shape)
 
     return (func_args, template_args)
 
 
 add_builtin(
     "tile_load",
-    input_types={"a": array(dtype=Any), "shape": Tuple[int, ...], "offset": Tuple[int, ...], "storage": str},
+    input_types={
+        "a": array(dtype=Any),
+        "shape": Tuple[int, ...],
+        "offset": Tuple[int, ...],
+        "storage": str,
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_load_tuple_value_func,
     dispatch_func=tile_load_tuple_dispatch_func,
-    defaults={"offset": None, "storage": "register"},
+    defaults={"offset": None, "storage": "register", "bounds_check": True},
     variadic=False,
     doc="""Loads a tile from a global memory array.
 
@@ -2386,6 +2395,7 @@ add_builtin(
     :param offset: Offset in the source array to begin reading from (optional)
     :param storage: The storage location for the tile: ``"register"`` for registers
       (default) or ``"shared"`` for shared memory.
+    :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster load times
    :returns: A tile with shape as specified and data type the same as the source array""",
     group="Tile Primitives",
     export=False,
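The new ``bounds_check`` parameter on ``tile_load`` (and, below, on ``tile_store`` and ``tile_atomic_add``) defaults to ``True``, so existing kernels keep the guarded behavior. A minimal sketch of opting out when tile accesses are known to stay in range; the kernel and sizes below are illustrative, not part of the release:

.. code-block:: python

    import numpy as np
    import warp as wp

    TILE = wp.constant(64)

    @wp.kernel
    def copy_aligned(src: wp.array(dtype=float), dst: wp.array(dtype=float)):
        i = wp.tid()
        # the array lengths are exact multiples of TILE, so the per-element
        # bounds check can be skipped for a faster load/store path
        t = wp.tile_load(src, shape=TILE, offset=i * TILE, bounds_check=False)
        wp.tile_store(dst, t, offset=i * TILE, bounds_check=False)

    src = wp.array(np.arange(256, dtype=np.float32))
    dst = wp.zeros(256, dtype=float)
    wp.launch_tiled(copy_aligned, dim=[4], inputs=[src], outputs=[dst], block_dim=32)

Leaving ``bounds_check=True`` remains the safe default whenever a tile can overhang the array.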
@@ -2394,16 +2404,160 @@ add_builtin(
 # overload for scalar shape
 add_builtin(
     "tile_load",
-    input_types={"a": array(dtype=Any), "shape": int, "offset": int, "storage": str},
+    input_types={"a": array(dtype=Any), "shape": int, "offset": int, "storage": str, "bounds_check": builtins.bool},
     value_func=tile_load_tuple_value_func,
     dispatch_func=tile_load_tuple_dispatch_func,
-    defaults={"offset": None, "storage": "register"},
+    defaults={"offset": None, "storage": "register", "bounds_check": True},
     group="Tile Primitives",
     hidden=True,
     export=False,
 )
 
 
+def tile_load_indexed_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    if arg_types is None:
+        return tile(dtype=Any, shape=Tuple[int, ...])
+
+    a = arg_types["a"]
+
+    indices_tile = arg_types["indices"]
+    indices_tile.storage = "shared"  # force to shared
+
+    axis = arg_values["axis"]
+    if axis >= a.ndim:
+        raise ValueError(f"tile_load_indexed() axis argument must be valid axis of array {a}, got {axis}.")
+
+    indices_tile_dim = len(indices_tile.shape)
+    if indices_tile_dim != 1:
+        raise ValueError(
+            f"tile_load_indexed() indices argument must be a 1D tile, got {indices_tile_dim} dimensions instead."
+        )
+
+    shape = extract_tuple(arg_values["shape"], as_constant=True)
+
+    if None in shape:
+        raise ValueError("Tile functions require shape to be a compile time constant.")
+
+    num_indices = indices_tile.shape[0]
+    if num_indices != shape[axis]:
+        raise ValueError(
+            "The number of elements in the 1D indices tile must match the output tile shape along the specified axis."
+        )
+
+    if "offset" in arg_values:
+        offset = extract_tuple(arg_values["offset"])
+    else:
+        offset = (0,) * a.ndim
+
+    if a.ndim != len(shape):
+        raise ValueError(
+            f"tile_load_indexed() array argument must have same number of dimensions as the tile shape, trying to perform an {len(shape)} dimensional load from an array with {a.ndim} dimensions."
+        )
+
+    if a.ndim != len(offset):
+        raise ValueError(
+            f"tile_load_indexed() offset argument must have the same number of dimensions as the array to load from, got {len(offset)} indices for an array with {a.ndim} dimensions"
+        )
+
+    if arg_values["storage"] not in {"shared", "register"}:
+        raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
+
+    return tile(dtype=a.dtype, shape=shape, storage=arg_values["storage"])
+
+
+def tile_load_indexed_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    a = args["a"]
+    indices_tile = args["indices"]
+    axis = args["axis"]
+
+    shape = extract_tuple(args["shape"], as_constant=True)
+
+    if None in shape:
+        raise ValueError("Tile functions require shape to be a compile time constant.")
+
+    if "offset" in args:
+        offset = extract_tuple(args["offset"])
+    else:
+        offset = (0,) * a.type.ndim
+
+    func_args = (a, indices_tile, axis, *offset)
+    template_args = shape
+
+    return (func_args, template_args)
+
+
+add_builtin(
+    "tile_load_indexed",
+    input_types={
+        "a": array(dtype=Any),
+        "indices": tile(dtype=int, shape=Tuple[int]),
+        "shape": Tuple[int, ...],
+        "offset": Tuple[int, ...],
+        "axis": int,
+        "storage": str,
+    },
+    value_func=tile_load_indexed_tuple_value_func,
+    dispatch_func=tile_load_indexed_tuple_dispatch_func,
+    defaults={"offset": None, "axis": 0, "storage": "register"},
+    variadic=False,
+    doc="""Loads a tile from a global memory array, with loads along a specified axis mapped according to a 1D tile of indices.
+
+    :param a: The source array in global memory
+    :param indices: A 1D tile of integer indices mapping to elements in ``a``.
+    :param shape: Shape of the tile to load, must have the same number of dimensions as ``a``, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
+    :param offset: Offset in the source array to begin reading from (optional)
+    :param axis: Axis of ``a`` that indices refer to
+    :param storage: The storage location for the tile: ``"register"`` for registers (default) or ``"shared"`` for shared memory.
+    :returns: A tile with shape as specified and data type the same as the source array
+
+    This example shows how to select and store the even indexed rows from a 2D array:
+
+    .. code-block:: python
+
+        TILE_M = wp.constant(2)
+        TILE_N = wp.constant(2)
+        HALF_M = wp.constant(TILE_M // 2)
+        HALF_N = wp.constant(TILE_N // 2)
+
+        @wp.kernel
+        def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
+            i, j = wp.tid()
+
+            evens = wp.tile_arange(HALF_M, dtype=int, storage="shared") * 2
+
+            t0 = wp.tile_load_indexed(x, indices=evens, shape=(HALF_M, TILE_N), offset=(i*TILE_M, j*TILE_N), axis=0, storage="register")
+            wp.tile_store(y, t0, offset=(i*HALF_M, j*TILE_N))
+
+        M = TILE_M * 2
+        N = TILE_N * 2
+
+        arr = np.arange(M * N).reshape(M, N)
+
+        x = wp.array(arr, dtype=float)
+        y = wp.zeros((M // 2, N), dtype=float)
+
+        wp.launch_tiled(compute, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)
+
+        print(x.numpy())
+        print(y.numpy())
+
+    Prints:
+
+    .. code-block:: text
+
+        [[ 0.  1.  2.  3.]
+         [ 4.  5.  6.  7.]
+         [ 8.  9. 10. 11.]
+         [12. 13. 14. 15.]]
+
+        [[ 0.  1.  2.  3.]
+         [ 8.  9. 10. 11.]]
+    """,
+    group="Tile Primitives",
+    export=False,
+)
+
+
 def tile_store_value_func(arg_types, arg_values):
     # return generic type (for doc builds)
     if arg_types is None:
@@ -2440,6 +2594,7 @@ def tile_store_value_func(arg_types, arg_values):
 def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
     a = args["a"]
     t = args["t"]
+    bounds_check = args["bounds_check"]
 
     if "offset" in args:
         offset = extract_tuple(args["offset"])
@@ -2447,17 +2602,22 @@ def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any,
         offset = (0,) * a.type.ndim
 
     func_args = (a, *offset, t)
-    template_args = []
+    template_args = (a.type.dtype, bounds_check.constant)
 
     return (func_args, template_args)
 
 
 add_builtin(
     "tile_store",
-    input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...]},
+    input_types={
+        "a": array(dtype=Any),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": Tuple[int, ...],
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_store_value_func,
     dispatch_func=tile_store_dispatch_func,
-    defaults={"offset": None},
+    defaults={"offset": None, "bounds_check": True},
     variadic=False,
     skip_replay=True,
     doc="""Store a tile to a global memory array.
@@ -2466,7 +2626,9 @@ add_builtin(
 
     :param a: The destination array in global memory
     :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array
-    :param offset: Offset in the destination array (optional)""",
+    :param offset: Offset in the destination array (optional)
+    :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster write times
+    """,
     group="Tile Primitives",
     export=False,
 )
@@ -2474,10 +2636,15 @@ add_builtin(
 # overload for scalar offset
 add_builtin(
     "tile_store",
-    input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": int},
+    input_types={
+        "a": array(dtype=Any),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": int,
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_store_value_func,
     dispatch_func=tile_store_dispatch_func,
-    defaults={"offset": None},
+    defaults={"offset": None, "bounds_check": True},
     variadic=False,
     skip_replay=True,
     group="Tile Primitives",
@@ -2486,6 +2653,151 @@ add_builtin(
 )
 
 
+def tile_store_indexed_value_func(arg_types, arg_values):
+    # return generic type (for doc builds)
+    if arg_types is None:
+        return None
+
+    a = arg_types["a"]
+    t = arg_types["t"]
+    indices_tile = arg_types["indices"]
+    indices_tile.storage = "shared"  # force to shared
+
+    axis = arg_values["axis"]
+    if axis >= a.ndim:
+        raise ValueError(f"tile_store_indexed() axis argument must be valid axis of array {a}, got {axis}.")
+
+    indices_tile_dim = len(indices_tile.shape)
+    if indices_tile_dim != 1:
+        raise ValueError(
+            f"tile_store_indexed() indices argument must be a 1D tile, got {indices_tile_dim} dimensions instead."
+        )
+
+    num_indices = indices_tile.shape[0]
+    if num_indices != t.shape[axis]:
+        raise ValueError(
+            "The number of elements in the 1D indices tile must match the input tile shape along the specified axis."
+        )
+
+    if "offset" in arg_types:
+        c = extract_tuple(arg_values["offset"])
+    else:
+        c = (0,) * a.ndim
+
+    if len(c) != a.ndim:
+        raise ValueError(
+            f"tile_store_indexed() 'a' argument must have {len(c)} dimensions, "
+            f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
+        )
+
+    if len(t.shape) != a.ndim:
+        raise ValueError(
+            f"tile_store_indexed() 'a' argument must have the same number of dimensions as the 't' argument, "
+            f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
+        )
+
+    if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
+        raise TypeError(
+            f"tile_store_indexed() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
+        )
+
+    return None
+
+
+def tile_store_indexed_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    a = args["a"]
+    indices_tile = args["indices"]
+    axis = args["axis"]
+    t = args["t"]
+
+    if "offset" in args:
+        offset = extract_tuple(args["offset"])
+    else:
+        offset = (0,) * a.type.ndim
+
+    func_args = (a, indices_tile, axis, *offset, t)
+    template_args = []
+
+    return (func_args, template_args)
+
+
+add_builtin(
+    "tile_store_indexed",
+    input_types={
+        "a": array(dtype=Any),
+        "indices": tile(dtype=int, shape=Tuple[int]),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": Tuple[int, ...],
+        "axis": int,
+    },
+    value_func=tile_store_indexed_value_func,
+    dispatch_func=tile_store_indexed_dispatch_func,
+    defaults={"offset": None, "axis": 0},
+    variadic=False,
+    skip_replay=True,
+    doc="""Store a tile to a global memory array, with storage along a specified axis mapped according to a 1D tile of indices.
+
+    :param a: The destination array in global memory
+    :param indices: A 1D tile of integer indices mapping to elements in ``a``.
+    :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
+    :param offset: Offset in the destination array (optional)
+    :param axis: Axis of ``a`` that indices refer to
+
+    This example shows how to map tile rows to the even rows of a 2D array:
+
+    .. code-block:: python
+
+        TILE_M = wp.constant(2)
+        TILE_N = wp.constant(2)
+        TWO_M = wp.constant(TILE_M * 2)
+        TWO_N = wp.constant(TILE_N * 2)
+
+        @wp.kernel
+        def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
+            i, j = wp.tid()
+
+            t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i*TILE_M, j*TILE_N), storage="register")
+
+            evens_M = wp.tile_arange(TILE_M, dtype=int, storage="shared") * 2
+
+            wp.tile_store_indexed(y, indices=evens_M, t=t, offset=(i*TWO_M, j*TILE_N), axis=0)
+
+        M = TILE_M * 2
+        N = TILE_N * 2
+
+        arr = np.arange(M * N, dtype=float).reshape(M, N)
+
+        x = wp.array(arr, dtype=float, requires_grad=True, device=device)
+        y = wp.zeros((M * 2, N), dtype=float, requires_grad=True, device=device)
+
+        wp.launch_tiled(compute, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)
+
+        print(x.numpy())
+        print(y.numpy())
+
+    Prints:
+
+    .. code-block:: text
+
+        [[ 0.  1.  2.  3.]
+         [ 4.  5.  6.  7.]
+         [ 8.  9. 10. 11.]
+         [12. 13. 14. 15.]]
+
+        [[ 0.  1.  2.  3.]
+         [ 0.  0.  0.  0.]
+         [ 4.  5.  6.  7.]
+         [ 0.  0.  0.  0.]
+         [ 8.  9. 10. 11.]
+         [ 0.  0.  0.  0.]
+         [12. 13. 14. 15.]
+         [ 0.  0.  0.  0.]]
+    """,
+    group="Tile Primitives",
+    export=False,
+)
+
+
 def tile_atomic_add_value_func(arg_types, arg_values):
     # return generic type (for doc builds)
     if arg_types is None:
@@ -2526,6 +2838,7 @@ def tile_atomic_add_value_func(arg_types, arg_values):
 def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
     a = args["a"]
     t = args["t"]
+    bounds_check = args["bounds_check"]
 
     if "offset" in args:
         offset = extract_tuple(args["offset"])
@@ -2533,17 +2846,22 @@ def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type:
         offset = (0,) * a.type.ndim
 
     func_args = (a, *offset, t)
-    template_args = []
+    template_args = (a.type.dtype, bounds_check.constant)
 
     return (func_args, template_args)
 
 
 add_builtin(
     "tile_atomic_add",
-    input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...]},
+    input_types={
+        "a": array(dtype=Any),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": Tuple[int, ...],
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_atomic_add_value_func,
     dispatch_func=tile_atomic_add_dispatch_func,
-    defaults={"offset": None},
+    defaults={"offset": None, "bounds_check": True},
     variadic=False,
     skip_replay=True,
     doc="""Atomically add a tile onto the array `a`, each element will be updated atomically.
@@ -2551,6 +2869,7 @@ add_builtin(
     :param a: Array in global memory, should have the same ``dtype`` as the input tile
     :param t: Source tile to add to the destination array
     :param offset: Offset in the destination array (optional)
+    :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster write times
    :returns: A tile with the same dimensions and data type as the source tile, holding the original value of the destination elements""",
     group="Tile Primitives",
     export=False,
@@ -2559,10 +2878,15 @@ add_builtin(
 # overload for scalar offset
 add_builtin(
     "tile_atomic_add",
-    input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": int},
+    input_types={
+        "a": array(dtype=Any),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": int,
+        "bounds_check": builtins.bool,
+    },
     value_func=tile_atomic_add_value_func,
     dispatch_func=tile_atomic_add_dispatch_func,
-    defaults={"offset": None},
+    defaults={"offset": None, "bounds_check": True},
     variadic=False,
     skip_replay=True,
     group="Tile Primitives",
@@ -2571,6 +2895,143 @@ add_builtin(
 )
 
 
+def tile_atomic_add_indexed_value_func(arg_types, arg_values):
+    # return generic type (for doc builds)
+    if arg_types is None:
+        return tile(dtype=Any, shape=Tuple[int, ...])
+
+    a = arg_types["a"]
+    t = arg_types["t"]
+    indices_tile = arg_types["indices"]
+    indices_tile.storage = "shared"  # force to shared
+
+    axis = arg_values["axis"]
+    if axis >= a.ndim:
+        raise ValueError(f"tile_atomic_add_indexed() axis argument must be valid axis of array {a}, got {axis}.")
+
+    indices_tile_dim = len(indices_tile.shape)
+    if indices_tile_dim != 1:
+        raise ValueError(
+            f"tile_atomic_add_indexed() indices argument must be a 1D tile, got {indices_tile_dim} dimensions instead."
+        )
+
+    num_indices = indices_tile.shape[0]
+    if num_indices != t.shape[axis]:
+        raise ValueError(
+            "The number of elements in the 1D indices tile must match the input tile shape along the specified axis."
+        )
+
+    if "offset" in arg_types:
+        c = extract_tuple(arg_values["offset"])
+    else:
+        c = (0,) * a.ndim
+
+    if len(c) != a.ndim:
+        raise ValueError(
+            f"tile_atomic_add_indexed() 'a' argument must have {len(c)} dimensions, "
+            f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
+        )
+
+    if len(t.shape) != a.ndim:
+        raise ValueError(
+            f"tile_atomic_add_indexed() 'a' argument must have the same number of dimensions as the 't' argument, "
+            f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
+        )
+
+    if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
+        raise TypeError(
+            f"tile_atomic_add_indexed() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
+        )
+
+    return tile(dtype=t.dtype, shape=t.shape, storage=t.storage)
+
+
+def tile_atomic_add_indexed_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    a = args["a"]
+    indices_tile = args["indices"]
+    axis = args["axis"]
+    t = args["t"]
+
+    if "offset" in args:
+        offset = extract_tuple(args["offset"])
+    else:
+        offset = (0,) * a.type.ndim
+
+    func_args = (a, indices_tile, axis, *offset, t)
+    template_args = []
+
+    return (func_args, template_args)
+
+
+add_builtin(
+    "tile_atomic_add_indexed",
+    input_types={
+        "a": array(dtype=Any),
+        "indices": tile(dtype=int, shape=Tuple[int]),
+        "t": tile(dtype=Any, shape=Tuple[int, ...]),
+        "offset": Tuple[int, ...],
+        "axis": int,
+    },
+    value_func=tile_atomic_add_indexed_value_func,
+    dispatch_func=tile_atomic_add_indexed_dispatch_func,
+    defaults={"offset": None, "axis": 0},
+    variadic=False,
+    skip_replay=True,
+    doc="""Atomically add a tile to a global memory array, with storage along a specified axis mapped according to a 1D tile of indices.
+
+    :param a: The destination array in global memory
+    :param indices: A 1D tile of integer indices mapping to elements in ``a``.
+    :param t: The source tile to extract data from, must have the same data type and number of dimensions as the destination array, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
+    :param offset: Offset in the destination array (optional)
+    :param axis: Axis of ``a`` that indices refer to
+
+    This example shows how to compute a blocked, row-wise reduction:
+
+    .. code-block:: python
+
+        TILE_M = wp.constant(2)
+        TILE_N = wp.constant(2)
+
+        @wp.kernel
+        def tile_atomic_add_indexed(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
+            i, j = wp.tid()
+
+            t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i*TILE_M, j*TILE_N), storage="register")
+
+            zeros = wp.tile_zeros(TILE_M, dtype=int, storage="shared")
+
+            wp.tile_atomic_add_indexed(y, indices=zeros, t=t, offset=(i, j*TILE_N), axis=0)
+
+        M = TILE_M * 2
+        N = TILE_N * 2
+
+        arr = np.arange(M * N, dtype=float).reshape(M, N)
+
+        x = wp.array(arr, dtype=float, requires_grad=True, device=device)
+        y = wp.zeros((2, N), dtype=float, requires_grad=True, device=device)
+
+        wp.launch_tiled(tile_atomic_add_indexed, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)
+
+        print(x.numpy())
+        print(y.numpy())
+
+    Prints:
+
+    .. code-block:: text
+
+        [[ 0.  1.  2.  3.]
+         [ 4.  5.  6.  7.]
+         [ 8.  9. 10. 11.]
+         [12. 13. 14. 15.]]
+
+        [[ 4.  6.  8. 10.]
+         [20. 22. 24. 26.]]
+    """,
+    group="Tile Primitives",
+    export=False,
+)
+
+
 def tile_view_value_func(arg_types, arg_values):
     # return generic type (for doc builds)
     if arg_types is None:
@@ -3934,14 +4395,45 @@ def tile_unary_map_value_func(arg_types, arg_values):
     if not is_tile(a):
         raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
 
-    return tile(dtype=a.dtype, shape=a.shape)
+    if "op" in arg_values:
+        op = arg_values["op"]
+        try:
+            overload = op.get_overload([a.dtype], {})
+        except KeyError as exc:
+            raise RuntimeError(f"No overload of {op} found for tile element type {type_repr(a.dtype)}") from exc
+
+        # build the right overload on demand
+        if overload.value_func is None:
+            overload.build(None)
+
+        value_type = overload.value_func(None, None)
+
+        if not type_is_scalar(value_type) and not type_is_vector(value_type) and not type_is_matrix(value_type):
+            raise TypeError(f"Operator {op} returns unsupported type {type_repr(value_type)} for a tile element")
+
+        return tile(dtype=value_type, shape=a.shape)
+
+    else:
+        return tile(dtype=a.dtype, shape=a.shape)
+
+
+def tile_unary_map_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
+    op = arg_values["op"]
+    tile_a = arg_values["a"]
+
+    overload = op.get_overload([tile_a.type.dtype], {})
+
+    # necessary, in case return type is different from input tile types
+    tile_r = Var(label=None, type=return_type)
+
+    return ((overload, tile_a, tile_r), ())
 
 
 add_builtin(
     "tile_map",
     input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...])},
     value_func=tile_unary_map_value_func,
-    # dispatch_func=tile_map_dispatch_func,
+    dispatch_func=tile_unary_map_dispatch_func,
     # variadic=True,
     native_func="tile_unary_map",
     doc="""Apply a unary function onto the tile.
@@ -3950,7 +4442,7 @@ add_builtin(
 
     :param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
     :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
-    :returns: A tile with the same dimensions and data type as the input tile.
+    :returns: A tile with the same dimensions as the input tile. Its datatype is specified by the return type of op
 
     Example:
 
@@ -3991,10 +4483,6 @@ def tile_binary_map_value_func(arg_types, arg_values):
     if not is_tile(b):
         raise TypeError(f"tile_map() 'b' argument must be a tile, got {b!r}")
 
-    # ensure types equal
-    if not types_equal(a.dtype, b.dtype):
-        raise TypeError(f"tile_map() arguments must have the same dtype, got {a.dtype} and {b.dtype}")
-
     if len(a.shape) != len(b.shape):
         raise ValueError(
             f"tile_map() shapes must have the same number of dimensions, got {len(a.shape)} and {len(b.shape)}"
@@ -4004,7 +4492,47 @@ def tile_binary_map_value_func(arg_types, arg_values):
         if a.shape[i] != b.shape[i]:
             raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")
 
-    return tile(dtype=a.dtype, shape=a.shape)
+    if "op" in arg_values:
+        op = arg_values["op"]
+        try:
+            overload = op.get_overload([a.dtype, b.dtype], {})
+        except KeyError as exc:
+            raise RuntimeError(
+                f"No overload of {op} found for tile element types {type_repr(a.dtype)}, {type_repr(b.dtype)}"
+            ) from exc
+
+        # build the right overload on demand
+        if overload.value_func is None:
+            overload.build(None)
+
+        value_type = overload.value_func(None, None)
+
+        if not type_is_scalar(value_type) and not type_is_vector(value_type) and not type_is_matrix(value_type):
+            raise TypeError(f"Operator {op} returns unsupported type {type_repr(value_type)} for a tile element")
+
+        return tile(dtype=value_type, shape=a.shape)
+
+    else:
+        # ensure types equal
+        if not types_equal(a.dtype, b.dtype):
+            raise TypeError(
+                f"tile_map() arguments must have the same dtype for this operation, got {a.dtype} and {b.dtype}"
+            )
+
+        return tile(dtype=a.dtype, shape=a.shape)
+
+
+def tile_binary_map_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
+    op = arg_values["op"]
+    tile_a = arg_values["a"]
+    tile_b = arg_values["b"]
+
+    overload = op.get_overload([tile_a.type.dtype, tile_b.type.dtype], {})
+
+    # necessary, in case return type is different from input tile types
+    tile_r = Var(label=None, type=return_type)
+
+    return ((overload, tile_a, tile_b, tile_r), ())
 
 
 add_builtin(
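With ``tile_unary_map_dispatch_func`` and ``tile_binary_map_dispatch_func`` in place, ``tile_map`` can now apply an op whose return type differs from the input tile's dtype. A minimal sketch; ``sign_bit`` and ``count_negative`` are hypothetical names used for illustration only:

.. code-block:: python

    import warp as wp

    TILE = wp.constant(16)

    @wp.func
    def sign_bit(x: float) -> int:
        # user op whose return dtype (int) differs from its input dtype (float)
        if x < 0.0:
            return 1
        return 0

    @wp.kernel
    def count_negative(x: wp.array(dtype=float), out: wp.array(dtype=int)):
        t = wp.tile_load(x, shape=TILE)
        flags = wp.tile_map(sign_bit, t)  # float tile in, int tile out
        total = wp.tile_sum(flags)        # 1-element int tile
        wp.tile_store(out, total)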
4015
4543
  "b": tile(dtype=Scalar, shape=Tuple[int, ...]),
4016
4544
  },
4017
4545
  value_func=tile_binary_map_value_func,
4018
- # dispatch_func=tile_map_dispatch_func,
4546
+ dispatch_func=tile_binary_map_dispatch_func,
4019
4547
  # variadic=True,
4020
4548
  native_func="tile_binary_map",
4021
4549
  doc="""Apply a binary function onto the tile.
4022
4550
 
4023
4551
  This function cooperatively applies a binary function to each element of the tiles using all threads in the block.
4024
- Both input tiles must have the same dimensions and datatype.
4552
+ Both input tiles must have the same dimensions, and if using a builtin op, the same datatypes.
4025
4553
 
4026
4554
  :param op: A callable function that accepts two arguments and returns one argument, all of the same type, may be a user function or builtin
4027
4555
  :param a: The first input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
4028
4556
  :param b: The second input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
4029
- :returns: A tile with the same dimensions and datatype as the input tiles.
4557
+ :returns: A tile with the same dimensions as the input tiles. Its datatype is specified by the return type of op
4030
4558
 
4031
4559
  Example:
4032
4560
 
@@ -5544,7 +6072,7 @@ def array_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any
         return array(dtype=Scalar)
 
     dtype = arg_values["dtype"]
-    shape = extract_tuple(arg_values["shape"], as_constant=True)
+    shape = extract_tuple(arg_values["shape"], as_constant=False)
     return array(dtype=dtype, ndim=len(shape))
 
 
@@ -5554,7 +6082,7 @@ def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args:
     # to the underlying C++ function's runtime and template params.
 
     dtype = return_type.dtype
-    shape = extract_tuple(args["shape"], as_constant=True)
+    shape = extract_tuple(args["shape"], as_constant=False)
 
     func_args = (args["ptr"], *shape)
     template_args = (dtype,)
@@ -5563,7 +6091,7 @@ def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args:
 
 add_builtin(
     "array",
-    input_types={"ptr": warp.uint64, "shape": Tuple[int, ...], "dtype": Scalar},
+    input_types={"ptr": warp.uint64, "shape": Tuple[int, ...], "dtype": Any},
     value_func=array_value_func,
     export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
     dispatch_func=array_dispatch_func,
@@ -5575,6 +6103,48 @@ add_builtin(
 )
 
 
+def zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    if arg_types is None:
+        return fixedarray(dtype=Scalar)
+
+    dtype = arg_values["dtype"]
+    shape = extract_tuple(arg_values["shape"], as_constant=True)
+
+    if None in shape:
+        raise RuntimeError("the `shape` argument must be specified as a constant when zero-initializing an array")
+
+    return fixedarray(dtype=dtype, shape=shape)
+
+
+def zeros_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    # We're in the codegen stage where we emit the code calling the built-in.
+    # Further validate the given argument values if needed and map them
+    # to the underlying C++ function's runtime and template params.
+
+    dtype = return_type.dtype
+    shape = extract_tuple(args["shape"], as_constant=True)
+
+    size = math.prod(shape)
+
+    func_args = shape
+    template_args = (size, dtype)
+    return (func_args, template_args)
+
+
+add_builtin(
+    "zeros",
+    input_types={"shape": Tuple[int, ...], "dtype": Any},
+    value_func=zeros_value_func,
+    export_func=lambda input_types: {},
+    dispatch_func=zeros_dispatch_func,
+    native_func="fixedarray_t",
+    group="Utility",
+    export=False,
+    missing_grad=True,
+    hidden=True,  # Unhide once we can document both a built-in and a Python scope function sharing the same name.
+)
+
+
 # does argument checking and type propagation for address()
 def address_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
     arr_type = arg_types["arr"]
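The hidden ``zeros`` builtin lowers to the native ``fixedarray_t`` type: the shape must be a compile-time constant, and the element count is passed as a template parameter. A minimal, hedged sketch of the kernel-scope usage this appears to back (the builtin is deliberately undocumented in this release, so the call pattern below is an assumption):

.. code-block:: python

    import warp as wp

    @wp.kernel
    def bucket_counts(values: wp.array(dtype=int), out: wp.array(dtype=int)):
        i = wp.tid()
        # shape must be a compile-time constant; this creates a small
        # fixed-size scratch array private to the thread
        counts = wp.zeros(shape=(4,), dtype=int)
        counts[values[i] % 4] = 1
        for b in range(4):
            wp.atomic_add(out, b, counts[b])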
@@ -5864,8 +6434,8 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
 
 
 for array_type in array_types:
-    # don't list indexed array operations explicitly in docs
-    hidden = array_type == indexedarray
+    # don't list fixed or indexed array operations explicitly in docs
+    hidden = array_type in (indexedarray, fixedarray)
 
     add_builtin(
         "atomic_add",
@@ -6187,46 +6757,110 @@ for array_type in array_types:
 
 
 # used to index into builtin types, i.e.: y = vec3[1]
-def extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
-    return arg_types["a"]._wp_scalar_type_
+def vector_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    vec_type = arg_types["a"]
+    idx_type = arg_types["i"]
+
+    if isinstance(idx_type, slice_t):
+        length = idx_type.get_length(vec_type._length_)
+        return vector(length=length, dtype=vec_type._wp_scalar_type_)
+
+    return vec_type._wp_scalar_type_
+
+
+def vector_extract_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    func_args = tuple(args.values())
+    template_args = getattr(return_type, "_shape_", ())
+    return (func_args, template_args)
 
 
 add_builtin(
     "extract",
-    input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
-    value_func=extract_value_func,
+    input_types={"a": vector(length=Any, dtype=Scalar), "i": Any},
+    value_func=vector_extract_value_func,
+    dispatch_func=vector_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
 add_builtin(
     "extract",
-    input_types={"a": quaternion(dtype=Scalar), "i": int},
-    value_func=extract_value_func,
+    input_types={"a": quaternion(dtype=Scalar), "i": Any},
+    value_func=vector_extract_value_func,
+    dispatch_func=vector_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
-
 add_builtin(
     "extract",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
-    value_func=lambda arg_types, arg_values: vector(
-        length=arg_types["a"]._shape_[1], dtype=arg_types["a"]._wp_scalar_type_
-    ),
+    input_types={"a": transformation(dtype=Scalar), "i": Any},
+    value_func=vector_extract_value_func,
+    dispatch_func=vector_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
+
+
+def matrix_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    mat_type = arg_types["a"]
+    idx_types = tuple(arg_types[x] for x in "ij" if arg_types.get(x, None) is not None)
+
+    # Compute the resulting shape from the slicing, with -1 being simple indexing.
+    shape = tuple(
+        idx.get_length(mat_type._shape_[i]) if isinstance(idx, slice_t) else -1 for i, idx in enumerate(idx_types)
+    )
+
+    # Append any non indexed slice.
+    for i in range(len(idx_types), len(mat_type._shape_)):
+        shape += (mat_type._shape_[i],)
+
+    # Count how many dimensions the output value will have.
+    ndim = sum(1 for x in shape if x >= 0)
+
+    if ndim == 0:
+        return mat_type._wp_scalar_type_
+
+    assert shape[0] != -1 or shape[1] != -1
+
+    if ndim == 1:
+        length = shape[0] if shape[0] != -1 else shape[1]
+        return vector(length=length, dtype=mat_type._wp_scalar_type_)
+
+    assert ndim == 2
+
+    # When a matrix dimension is 0, all other dimensions are also expected to be 0.
+    if any(x == 0 for x in shape):
+        shape = (0,) * len(shape)
+
+    return matrix(shape=shape, dtype=mat_type._wp_scalar_type_)
+
+
+def matrix_extract_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    idx_types = tuple(args[x].type for x in "ij" if args.get(x, None) is not None)
+    has_slice = any(isinstance(x, slice_t) for x in idx_types)
+
+    func_args = tuple(args.values())
+    template_args = getattr(return_type, "_shape_", ()) if has_slice else ()
+    return (func_args, template_args)
+
+
 add_builtin(
     "extract",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
-    value_func=extract_value_func,
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any},
+    value_func=matrix_extract_value_func,
+    dispatch_func=matrix_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
-
 add_builtin(
     "extract",
-    input_types={"a": transformation(dtype=Scalar), "i": int},
-    value_func=extract_value_func,
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any},
+    value_func=matrix_extract_value_func,
+    dispatch_func=matrix_extract_dispatch_func,
+    export=False,
     hidden=True,
     group="Utility",
 )
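Combined with the new ``slice_t`` handling, these overloads let kernel code extract sub-vectors and sub-matrices with Python slice syntax instead of single components only. A minimal illustrative sketch:

.. code-block:: python

    import warp as wp

    @wp.kernel
    def slicing(out: wp.array(dtype=float)):
        v = wp.vec4(1.0, 2.0, 3.0, 4.0)
        xy = v[0:2]        # slice of a vector yields a vec2
        m = wp.mat33(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)
        row = m[1]         # row extraction yields a vec3
        sub = m[0:2, 0:2]  # slicing both axes yields a mat22
        out[0] = xy[0] + row[0] + sub[0, 0]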
@@ -6247,6 +6881,19 @@ def vector_index_dispatch_func(input_types: Mapping[str, type], return_type: Any
     return (func_args, template_args)
 
 
+def matrix_ij_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    mat_type = arg_types["a"]
+    value_type = mat_type._wp_scalar_type_
+
+    return Reference(value_type)
+
+
+def matrix_ij_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    func_args = (Reference(args["a"]), args["i"], args["j"])
+    template_args = ()
+    return (func_args, template_args)
+
+
 # implements &vector[index]
 add_builtin(
     "index",
@@ -6287,6 +6934,16 @@ add_builtin(
     group="Utility",
     skip_replay=True,
 )
+# implements &(*matrix)[i, j]
+add_builtin(
+    "indexref",
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
+    value_func=matrix_ij_value_func,
+    dispatch_func=matrix_ij_dispatch_func,
+    hidden=True,
+    group="Utility",
+    skip_replay=True,
+)
 # implements &(*quaternion)[index]
 add_builtin(
     "indexref",
@@ -6309,11 +6966,46 @@ add_builtin(
 )
 
 
+def vector_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    vec = args["a"].type
+    idx = args["i"].type
+    value_type = strip_reference(args["value"].type)
+
+    if isinstance(idx, slice_t):
+        length = idx.get_length(vec._length_)
+
+        if type_is_vector(value_type):
+            if not types_equal(value_type._wp_scalar_type_, vec._wp_scalar_type_):
+                raise ValueError(
+                    f"The provided vector is expected to be of length {length} with dtype {type_repr(vec._wp_scalar_type_)}."
+                )
+            if value_type._length_ != length:
+                raise ValueError(
+                    f"The length of the provided vector ({args['value'].type._length_}) isn't compatible with the given slice (expected {length})."
+                )
+            template_args = (length,)
+        else:
+            # Disallow broadcasting.
+            raise ValueError(
+                f"The provided value is expected to be a vector of length {length}, with dtype {type_repr(vec._wp_scalar_type_)}."
+            )
+    else:
+        if not types_equal(value_type, vec._wp_scalar_type_):
+            raise ValueError(
+                f"The provided value is expected to be a scalar of type {type_repr(vec._wp_scalar_type_)}."
+            )
+        template_args = ()
+
+    func_args = tuple(args.values())
+    return (func_args, template_args)
+
+
 # implements vector[index] = value
 add_builtin(
     "assign_inplace",
-    input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
+    input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
+    dispatch_func=vector_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
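The same dispatch function validates in-place writes through slices: broadcasting is rejected, so the assigned value must be a vector whose length and dtype match the slice. A minimal illustrative sketch:

.. code-block:: python

    import warp as wp

    @wp.kernel
    def slice_assign(out: wp.array(dtype=float)):
        v = wp.vec4(0.0, 0.0, 0.0, 0.0)
        v[1:3] = wp.vec2(5.0, 6.0)  # slice assignment: length and dtype must match
        v[0] = 9.0                  # plain component assignment is still type-checked
        out[0] = v[0] + v[1]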
@@ -6322,8 +7014,9 @@ add_builtin(
6322
7014
  # implements quaternion[index] = value
6323
7015
  add_builtin(
6324
7016
  "assign_inplace",
6325
- input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
7017
+ input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
6326
7018
  value_type=None,
7019
+ dispatch_func=vector_assign_dispatch_func,
6327
7020
  hidden=True,
6328
7021
  export=False,
6329
7022
  group="Utility",
@@ -6331,15 +7024,16 @@ add_builtin(
6331
7024
  # implements transformation[index] = value
6332
7025
  add_builtin(
6333
7026
  "assign_inplace",
6334
- input_types={"a": transformation(dtype=Scalar), "i": int, "value": Scalar},
7027
+ input_types={"a": transformation(dtype=Scalar), "i": Any, "value": Any},
6335
7028
  value_type=None,
7029
+ dispatch_func=vector_assign_dispatch_func,
6336
7030
  hidden=True,
6337
7031
  export=False,
6338
7032
  group="Utility",
6339
7033
  )
6340
7034
 
6341
7035
 
6342
- def vector_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
7036
+ def vector_assign_copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
6343
7037
  vec_type = arg_types["a"]
6344
7038
  return vec_type
6345
7039
 
@@ -6347,8 +7041,9 @@ def vector_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[
6347
7041
  # implements vector[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
6348
7042
  add_builtin(
6349
7043
  "assign_copy",
6350
- input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
6351
- value_func=vector_assign_value_func,
7044
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
7045
+ value_func=vector_assign_copy_value_func,
7046
+ dispatch_func=vector_assign_dispatch_func,
6352
7047
  hidden=True,
6353
7048
  export=False,
6354
7049
  group="Utility",
@@ -6357,8 +7052,9 @@ add_builtin(
6357
7052
  # implements quaternion[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
6358
7053
  add_builtin(
6359
7054
  "assign_copy",
6360
- input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
6361
- value_func=vector_assign_value_func,
7055
+ input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
7056
+ value_func=vector_assign_copy_value_func,
7057
+ dispatch_func=vector_assign_dispatch_func,
6362
7058
  hidden=True,
6363
7059
  export=False,
6364
7060
  group="Utility",
@@ -6367,8 +7063,9 @@ add_builtin(
6367
7063
  # implements transformation[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
6368
7064
  add_builtin(
6369
7065
  "assign_copy",
6370
- input_types={"a": transformation(dtype=Scalar), "i": int, "value": Scalar},
6371
- value_func=vector_assign_value_func,
7066
+ input_types={"a": transformation(dtype=Scalar), "i": Any, "value": Any},
7067
+ value_func=vector_assign_copy_value_func,
7068
+ dispatch_func=vector_assign_dispatch_func,
6372
7069
  hidden=True,
6373
7070
  export=False,
6374
7071
  group="Utility",
@@ -6377,8 +7074,9 @@ add_builtin(
6377
7074
  # implements vector[idx] += scalar
6378
7075
  add_builtin(
6379
7076
  "add_inplace",
6380
- input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
7077
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
6381
7078
  value_type=None,
7079
+ dispatch_func=vector_assign_dispatch_func,
6382
7080
  hidden=True,
6383
7081
  export=False,
6384
7082
  group="Utility",
@@ -6387,8 +7085,9 @@ add_builtin(
6387
7085
  # implements quaternion[idx] += scalar
6388
7086
  add_builtin(
6389
7087
  "add_inplace",
6390
- input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
7088
+ input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
6391
7089
  value_type=None,
7090
+ dispatch_func=vector_assign_dispatch_func,
6392
7091
  hidden=True,
6393
7092
  export=False,
6394
7093
  group="Utility",
@@ -6397,8 +7096,9 @@ add_builtin(
6397
7096
  # implements transformation[idx] += scalar
6398
7097
  add_builtin(
6399
7098
  "add_inplace",
6400
- input_types={"a": transformation(dtype=Float), "i": int, "value": Float},
7099
+ input_types={"a": transformation(dtype=Float), "i": Any, "value": Any},
6401
7100
  value_type=None,
7101
+ dispatch_func=vector_assign_dispatch_func,
6402
7102
  hidden=True,
6403
7103
  export=False,
6404
7104
  group="Utility",
@@ -6417,8 +7117,9 @@ add_builtin(
6417
7117
  # implements vector[idx] -= scalar
6418
7118
  add_builtin(
6419
7119
  "sub_inplace",
6420
- input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
7120
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
6421
7121
  value_type=None,
7122
+ dispatch_func=vector_assign_dispatch_func,
6422
7123
  hidden=True,
6423
7124
  export=False,
6424
7125
  group="Utility",
@@ -6427,8 +7128,9 @@ add_builtin(
6427
7128
  # implements quaternion[idx] -= scalar
6428
7129
  add_builtin(
6429
7130
  "sub_inplace",
6430
- input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
7131
+ input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
6431
7132
  value_type=None,
7133
+ dispatch_func=vector_assign_dispatch_func,
6432
7134
  hidden=True,
6433
7135
  export=False,
6434
7136
  group="Utility",
@@ -6437,8 +7139,9 @@ add_builtin(
6437
7139
  # implements transformation[idx] -= scalar
6438
7140
  add_builtin(
6439
7141
  "sub_inplace",
6440
- input_types={"a": transformation(dtype=Scalar), "i": int, "value": Scalar},
7142
+ input_types={"a": transformation(dtype=Float), "i": Any, "value": Any},
6441
7143
  value_type=None,
7144
+ dispatch_func=vector_assign_dispatch_func,
6442
7145
  hidden=True,
6443
7146
  export=False,
6444
7147
  group="Utility",
@@ -6499,61 +7202,154 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
     return mat_size == vec_size and mat_type == vec_type


-# implements matrix[i,j] = scalar
+def matrix_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    mat = args["a"].type
+    value_type = strip_reference(args["value"].type)
+
+    idxs = tuple(args[x].type for x in "ij" if args.get(x, None) is not None)
+    has_slice = any(isinstance(x, slice_t) for x in idxs)
+
+    if has_slice:
+        # Compute the resulting shape from the slicing, with -1 being simple indexing.
+        shape = tuple(idx.get_length(mat._shape_[i]) if isinstance(idx, slice_t) else -1 for i, idx in enumerate(idxs))
+
+        # Append any non indexed slice.
+        for i in range(len(idxs), len(mat._shape_)):
+            shape += (mat._shape_[i],)
+
+        # Count how many dimensions the output value will have.
+        ndim = sum(1 for x in shape if x >= 0)
+        assert ndim > 0
+
+        if ndim == 1:
+            length = shape[0] if shape[0] != -1 else shape[1]
+
+            if type_is_vector(value_type):
+                if not types_equal(value_type._wp_scalar_type_, mat._wp_scalar_type_):
+                    raise ValueError(
+                        f"The provided vector is expected to be of length {length} with dtype {type_repr(mat._wp_scalar_type_)}."
+                    )
+
+                if value_type._length_ != length:
+                    raise ValueError(
+                        f"The length of the provided vector ({value_type._length_}) isn't compatible with the given slice (expected {length})."
+                    )
+
+                template_args = (length,)
+            else:
+                # Disallow broadcasting.
+                raise ValueError(
+                    f"The provided value is expected to be a vector of length {length}, with dtype {type_repr(mat._wp_scalar_type_)}."
+                )
+        else:
+            assert ndim == 2
+
+            # When a matrix dimension is 0, all other dimensions are also expected to be 0.
+            if any(x == 0 for x in shape):
+                shape = (0,) * len(shape)
+
+            if type_is_matrix(value_type):
+                if not types_equal(value_type._wp_scalar_type_, mat._wp_scalar_type_):
+                    raise ValueError(
+                        f"The provided matrix is expected to be of shape {shape} with dtype {type_repr(mat._wp_scalar_type_)}."
+                    )
+
+                if value_type._shape_ != shape:
+                    raise ValueError(
+                        f"The shape of the provided matrix ({value_type._shape_}) isn't compatible with the given slice (expected {shape})."
+                    )
+
+                template_args = shape
+            else:
+                # Disallow broadcasting.
+                raise ValueError(
+                    f"The provided value is expected to be a matrix of shape {shape}, with dtype {type_repr(mat._wp_scalar_type_)}."
+                )
+    elif len(idxs) == 1:
+        if not type_is_vector(value_type) or not types_equal(value_type._wp_scalar_type_, mat._wp_scalar_type_):
+            raise ValueError(
+                f"The provided value is expected to be a vector of length {mat._shape_[1]}, with dtype {type_repr(mat._wp_scalar_type_)}."
+            )
+
+        if value_type._length_ != mat._shape_[1]:
+            raise ValueError(
+                f"The length of the provided vector ({value_type._length_}) isn't compatible with the given slice (expected {mat._shape_[1]})."
+            )
+
+        template_args = ()
+    elif len(idxs) == 2:
+        if not types_equal(value_type, mat._wp_scalar_type_):
+            raise ValueError(
+                f"The provided value is expected to be a scalar of type {type_repr(mat._wp_scalar_type_)}."
+            )
+
+        template_args = ()
+    else:
+        raise AssertionError
+
+    func_args = tuple(args.values())
+    return (func_args, template_args)
+
+
+# implements matrix[i] = value
 add_builtin(
     "assign_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
+    constraint=matrix_vector_sametype,
     value_type=None,
+    dispatch_func=matrix_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
 )


-# implements matrix[i] = vector
+# implements matrix[i,j] = value
 add_builtin(
     "assign_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
-    constraint=matrix_vector_sametype,
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
     value_type=None,
+    dispatch_func=matrix_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
 )


-def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+def matrix_assign_copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
     mat_type = arg_types["a"]
     return mat_type


-# implements matrix[i,j] = scalar
+# implements matrix[i] = value
 add_builtin(
     "assign_copy",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
-    value_func=matrix_assign_value_func,
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
+    value_func=matrix_assign_copy_value_func,
+    dispatch_func=matrix_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
 )


-# implements matrix[i] = vector
+# implements matrix[i,j] = value
 add_builtin(
     "assign_copy",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
-    constraint=matrix_vector_sametype,
-    value_func=matrix_assign_value_func,
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
+    value_func=matrix_assign_copy_value_func,
+    dispatch_func=matrix_assign_dispatch_func,
     hidden=True,
     export=False,
     group="Utility",
 )


-# implements matrix[i,j] += scalar
+# implements matrix[i] += value
 add_builtin(
     "add_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
+    constraint=matrix_vector_sametype,
     value_type=None,
     hidden=True,
     export=False,
@@ -6561,11 +7357,10 @@ add_builtin(
 )


-# implements matrix[i] += vector
+# implements matrix[i,j] += value
 add_builtin(
     "add_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
-    constraint=matrix_vector_sametype,
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
     value_type=None,
     hidden=True,
     export=False,
@@ -6573,10 +7368,10 @@ add_builtin(
 )


-# implements matrix[i,j] -= scalar
+# implements matrix[i] -= value
 add_builtin(
     "sub_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
     value_type=None,
     hidden=True,
     export=False,
@@ -6584,10 +7379,10 @@ add_builtin(
 )


-# implements matrix[i] -= vector
+# implements matrix[i,j] -= value
 add_builtin(
     "sub_inplace",
-    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
+    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
     value_type=None,
     hidden=True,
     export=False,
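The matrix family gets the same treatment: `matrix[i] = value` and `matrix[i,j] = value` (plus the `+=`/`-=` variants) now take `Any` for indices and value, with shape and dtype checks centralized in `matrix_assign_dispatch_func` above. A hedged sketch of the kernel-level behavior these builtins back (illustrative, not from the package):

```python
import warp as wp

@wp.kernel
def mat_inplace_demo(out: wp.array(dtype=wp.mat33)):
    m = wp.identity(n=3, dtype=float)
    m[0] = wp.vec3(1.0, 2.0, 3.0)  # matrix[i] = vector -> assign_inplace
    m[1, 1] += 5.0                 # matrix[i,j] += scalar -> add_inplace
    m[2, 0] -= 1.0                 # matrix[i,j] -= scalar -> sub_inplace
    out[0] = m
```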
@@ -6807,6 +7602,7 @@ add_builtin(
 # ---------------------------------
 # Operators

+
 add_builtin(
     "add", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
 )
@@ -7079,7 +7875,7 @@ add_builtin(
     "mod",
     input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
     constraint=sametypes,
-    value_func=sametypes_create_value_func(Scalar),
+    value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
     doc="Modulo operation using truncated division.",
     group="Operators",
 )
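This one-line change fixes the declared return type of vector `%` vector: the old `value_func` reported a scalar even though the operation is component-wise. A small sketch of the corrected typing (illustrative):

```python
import warp as wp

@wp.kernel
def mod_demo(out: wp.array(dtype=wp.vec3)):
    a = wp.vec3(5.0, 7.0, 9.0)
    b = wp.vec3(2.0, 4.0, 5.0)
    out[0] = wp.mod(a, b)  # component-wise truncated modulo: vec3(1.0, 3.0, 4.0)
```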
@@ -7481,7 +8277,7 @@ def tile_matmul_lto_dispatch_func(
     num_threads = options["block_dim"]
     arch = options["output_arch"]

-    if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, 0, 0, a, b, out), template_args, [], 0)
     else:
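The renamed `wp_is_mathdx_enabled` accessor reflects the `wp_` prefix applied to the native core entry points in this release; the dispatch logic itself is unchanged, falling back to the generic path when there is no output arch or MathDx support. For context, a minimal tile-matmul kernel that this dispatcher serves (a sketch assuming a single 16x16 float32 tile; not from the package):

```python
import warp as wp

TILE = 16

@wp.kernel
def tile_matmul_demo(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
    a = wp.tile_load(A, shape=(TILE, TILE))
    b = wp.tile_load(B, shape=(TILE, TILE))
    c = wp.tile_zeros(shape=(TILE, TILE), dtype=float)
    wp.tile_matmul(a, b, c)  # accumulates a @ b into c
    wp.tile_store(C, c)

# e.g. wp.launch_tiled(tile_matmul_demo, dim=[1], inputs=[A, B, C], block_dim=64)
```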
@@ -7671,7 +8467,7 @@ def tile_fft_generic_lto_dispatch_func(
     arch = options["output_arch"]
     ept = size // num_threads

-    if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ([], [], [], 0)
     else:
@@ -7792,28 +8588,27 @@ def tile_cholesky_generic_lto_dispatch_func(
         raise TypeError("tile_cholesky() returns one output")
     out = return_values[0]

-    dtype, precision_enum = cusolver_type_map[a.type.dtype]
-
     # We already ensured a is square in tile_cholesky_generic_value_func()
     M, N = a.type.shape
     if out.type.shape[0] != M or out.type.shape[1] != M:
         raise ValueError("tile_cholesky() output tile must be square")

-    solver = "potrf"
-    solver_enum = cusolver_function_map[solver]
-
-    side_enum = cusolver_side_map["-"]
-    diag_enum = cusolver_diag_map["-"]
-    fill_mode = cusolver_fill_mode_map["lower"]
-
     arch = options["output_arch"]
-    num_threads = options["block_dim"]
-    parameter_list = f"({dtype}*, int*)"

-    if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, a, out), [], [], 0)
     else:
+        solver = "potrf"
+        solver_enum = cusolver_function_map[solver]
+        side_enum = cusolver_side_map["-"]
+        diag_enum = cusolver_diag_map["-"]
+        fill_mode = cusolver_fill_mode_map["lower"]
+        dtype, precision_enum = cusolver_type_map[a.type.dtype]
+        num_threads = options["block_dim"]
+        parameter_list = f"({dtype}*, int*)"
+        req_smem_bytes = a.type.size * type_size_in_bytes(a.type.dtype)
+
         # generate the LTO
         lto_symbol, lto_code_data = warp.build.build_lto_solver(
             M,
@@ -7831,6 +8626,7 @@ def tile_cholesky_generic_lto_dispatch_func(
             num_threads,
             parameter_list,
             builder,
+            smem_estimate_bytes=req_smem_bytes,
         )

         return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
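Beyond the rename, the cusolver setup (solver/side/diag/fill enums, dtype mapping, parameter list) now only runs on the MathDx path, and the new `smem_estimate_bytes` argument passes a shared-memory estimate (one N x N tile of the input dtype) to `build_lto_solver`. A usage sketch of the public entry point this dispatcher serves, assuming a 32x32 float64 input:

```python
import warp as wp

N = 32

@wp.kernel
def cholesky_demo(A: wp.array2d(dtype=wp.float64), L_out: wp.array2d(dtype=wp.float64)):
    a = wp.tile_load(A, shape=(N, N))
    l = wp.tile_cholesky(a)  # lower-triangular factor of a symmetric positive-definite tile
    wp.tile_store(L_out, l)
```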
@@ -7918,9 +8714,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
     if any(T not in cusolver_type_map.keys() for T in [y.type.dtype, L.type.dtype]):
         raise TypeError("tile_cholesky_solve() arguments must be tiles of float64 or float32")

-    dtype, precision_enum = cusolver_type_map[L.type.dtype]
     M, N = L.type.shape
-    NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1

     if len(x.type.shape) > 2 or len(x.type.shape) < 1:
         raise TypeError(f"tile_cholesky_solve() output vector must be 1D or 2D, got {len(x.type.shape)}-D")
@@ -7931,21 +8725,23 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
             f"got {x.type.shape[0]} elements in output and {M} rows in 'L'"
         )

-    solver = "potrs"
-    solver_enum = cusolver_function_map[solver]
-
-    side_enum = cusolver_side_map["-"]
-    diag_enum = cusolver_diag_map["-"]
-    fill_mode = cusolver_fill_mode_map["lower"]
-
     arch = options["output_arch"]
-    num_threads = options["block_dim"]
-    parameter_list = f"({dtype}*, {dtype}*)"

-    if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, L, y, x), [], [], 0)
     else:
+        NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
+        solver = "potrs"
+        solver_enum = cusolver_function_map[solver]
+        side_enum = cusolver_side_map["-"]
+        diag_enum = cusolver_diag_map["-"]
+        fill_mode = cusolver_fill_mode_map["lower"]
+        dtype, precision_enum = cusolver_type_map[L.type.dtype]
+        num_threads = options["block_dim"]
+        parameter_list = f"({dtype}*, {dtype}*)"
+        req_smem_bytes = (x.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
+
         # generate the LTO
         lto_symbol, lto_code_data = warp.build.build_lto_solver(
             M,
@@ -7963,6 +8759,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
             num_threads,
             parameter_list,
             builder,
+            smem_estimate_bytes=req_smem_bytes,
         )

         return ((Var(lto_symbol, str, False, True, False), L, y, x), [], [lto_code_data], 0)
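Same restructuring for the solve step; a sketch pairing it with the factorization above (assuming `b` holds a length-N right-hand side):

```python
import warp as wp

N = 32

@wp.kernel
def cholesky_solve_demo(
    A: wp.array2d(dtype=wp.float64), b: wp.array(dtype=wp.float64), x_out: wp.array(dtype=wp.float64)
):
    a = wp.tile_load(A, shape=(N, N))
    l = wp.tile_cholesky(a)
    y = wp.tile_load(b, shape=(N,))
    x = wp.tile_cholesky_solve(l, y)  # solves L L^T x = y
    wp.tile_store(x_out, x)
```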
@@ -8013,9 +8810,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
 
     z = return_values[0]

-    dtype, precision_enum = cusolver_type_map[L.type.dtype]
     M, N = L.type.shape
-    NRHS = z.type.shape[1] if len(z.type.shape) > 1 else 1

     if len(z.type.shape) > 2 or len(z.type.shape) < 1:
         raise TypeError(f"tile_lower_solve() output vector must be 1D or 2D, got {len(z.type.shape)}-D")
@@ -8026,21 +8821,23 @@ def tile_lower_solve_generic_lto_dispatch_func(
             f"got {z.type.shape[0]} elements in output and {M} rows in 'L'"
         )

-    solver = "trsm"
-    solver_enum = cusolver_function_map[solver]
-
-    side_enum = cusolver_side_map["left"]
-    diag_enum = cusolver_diag_map["nounit"]
-    fill_mode = cusolver_fill_mode_map["lower"]
-
     arch = options["output_arch"]
-    num_threads = options["block_dim"]
-    parameter_list = f"({dtype}*, {dtype}*)"

-    if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, L, y, z), [], [], 0)
     else:
+        NRHS = z.type.shape[1] if len(z.type.shape) > 1 else 1
+        solver = "trsm"
+        solver_enum = cusolver_function_map[solver]
+        side_enum = cusolver_side_map["left"]
+        diag_enum = cusolver_diag_map["nounit"]
+        fill_mode = cusolver_fill_mode_map["lower"]
+        dtype, precision_enum = cusolver_type_map[L.type.dtype]
+        num_threads = options["block_dim"]
+        parameter_list = f"({dtype}*, {dtype}*)"
+        req_smem_bytes = (z.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
+
         # generate the LTO
         lto_symbol, lto_code_data = warp.build.build_lto_solver(
             M,
@@ -8058,6 +8855,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
             num_threads,
             parameter_list,
             builder,
+            smem_estimate_bytes=req_smem_bytes,
         )

         return ((Var(lto_symbol, str, False, True, False), L, y, z), [], [lto_code_data], 0)
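The shared-memory estimate here counts every tile touched by the triangular solve. As a worked example (hypothetical sizes): with a 64x64 float32 `L` and length-64 `y`/`z`, `req_smem_bytes` comes to (64 + 64 + 4096) x 4 = 16,896 bytes.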
@@ -8144,9 +8942,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
 
     x = return_values[0]

-    dtype, precision_enum = cusolver_type_map[U.type.dtype]
     M, N = U.type.shape
-    NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1

     if len(z.type.shape) > 2 or len(z.type.shape) < 1:
         raise TypeError(f"tile_upper_solve() output tile must be 1D or 2D, got {len(z.type.shape)}-D")
@@ -8157,21 +8953,23 @@ def tile_upper_solve_generic_lto_dispatch_func(
             f"got {z.type.shape[0]} elements in output and {M} rows in 'U'"
         )

-    solver = "trsm"
-    solver_enum = cusolver_function_map[solver]
-
-    side_enum = cusolver_side_map["left"]
-    diag_enum = cusolver_diag_map["nounit"]
-    fill_mode = cusolver_fill_mode_map["upper"]
-
     arch = options["output_arch"]
-    num_threads = options["block_dim"]
-    parameter_list = f"({dtype}*, {dtype}*)"

-    if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
+    if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
         # CPU/no-MathDx dispatch
         return ((0, U, z, x), [], [], 0)
     else:
+        NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
+        solver = "trsm"
+        solver_enum = cusolver_function_map[solver]
+        side_enum = cusolver_side_map["left"]
+        diag_enum = cusolver_diag_map["nounit"]
+        fill_mode = cusolver_fill_mode_map["upper"]
+        dtype, precision_enum = cusolver_type_map[U.type.dtype]
+        num_threads = options["block_dim"]
+        parameter_list = f"({dtype}*, {dtype}*)"
+        req_smem_bytes = (x.type.size + z.type.size + U.type.size) * type_size_in_bytes(U.type.dtype)
+
         # generate the LTO
         lto_symbol, lto_code_data = warp.build.build_lto_solver(
             M,
@@ -8189,6 +8987,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
             num_threads,
             parameter_list,
             builder,
+            smem_estimate_bytes=req_smem_bytes,
         )

         return ((Var(lto_symbol, str, False, True, False), U, z, x), [], [lto_code_data], 0)
@@ -8413,3 +9212,22 @@ add_builtin(
     group="Utility",
     export=False,
 )
+
+# ---------------------------------
+# Slicing
+
+
+def slice_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+    return slice_t(**arg_values)
+
+
+add_builtin(
+    "slice",
+    input_types={"start": int, "stop": int, "step": int},
+    value_func=slice_value_func,
+    native_func="slice_t",
+    export=False,
+    group="Utility",
+    hidden=True,
+    missing_grad=True,
+)
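
The new hidden `slice` builtin is what Python slice expressions inside kernels lower to: `slice_t` carries start/stop/step at code-gen time, which is how the assignment dispatchers above can compute static slice lengths. A hedged sketch of the kernel-level syntax this enables, assuming vector and matrix slicing behaves as the dispatch functions in this diff describe:

```python
import warp as wp

@wp.kernel
def slice_demo(out: wp.array(dtype=wp.vec4)):
    v = wp.vec4(1.0, 2.0, 3.0, 4.0)
    v[1:3] = wp.vec2(20.0, 30.0)      # sub-vector write; length checked at code-gen time
    m = wp.identity(n=3, dtype=float)
    m[0, :] = wp.vec3(7.0, 8.0, 9.0)  # row write through a slice
    out[0] = v + wp.vec4(m[0, 0], m[0, 1], m[0, 2], 0.0)
```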