warp-lang 1.8.1__py3-none-win_amd64.whl → 1.9.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic (the original page links to more details).

Files changed (141)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +1904 -114
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +331 -101
  7. warp/builtins.py +1244 -160
  8. warp/codegen.py +317 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1465 -789
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_kernel.py +2 -1
  18. warp/fabric.py +1 -1
  19. warp/fem/cache.py +27 -19
  20. warp/fem/domain.py +2 -2
  21. warp/fem/field/nodal_field.py +2 -2
  22. warp/fem/field/virtual.py +264 -166
  23. warp/fem/geometry/geometry.py +5 -5
  24. warp/fem/integrate.py +129 -51
  25. warp/fem/space/restriction.py +4 -0
  26. warp/fem/space/shape/tet_shape_function.py +3 -10
  27. warp/jax_experimental/custom_call.py +25 -2
  28. warp/jax_experimental/ffi.py +22 -1
  29. warp/jax_experimental/xla_ffi.py +16 -7
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +99 -4
  32. warp/native/builtin.h +86 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +8 -2
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +41 -10
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +2 -2
  48. warp/native/mat.h +1910 -116
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +4 -2
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +331 -14
  59. warp/native/range.h +7 -1
  60. warp/native/reduce.cpp +10 -10
  61. warp/native/reduce.cu +13 -14
  62. warp/native/runlength_encode.cpp +2 -2
  63. warp/native/runlength_encode.cu +5 -5
  64. warp/native/scan.cpp +3 -3
  65. warp/native/scan.cu +4 -4
  66. warp/native/sort.cpp +10 -10
  67. warp/native/sort.cu +40 -31
  68. warp/native/sort.h +2 -0
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +13 -13
  71. warp/native/spatial.h +366 -17
  72. warp/native/temp_buffer.h +2 -2
  73. warp/native/tile.h +471 -82
  74. warp/native/vec.h +328 -14
  75. warp/native/volume.cpp +54 -54
  76. warp/native/volume.cu +1 -1
  77. warp/native/volume.h +2 -1
  78. warp/native/volume_builder.cu +30 -37
  79. warp/native/warp.cpp +150 -149
  80. warp/native/warp.cu +377 -216
  81. warp/native/warp.h +227 -226
  82. warp/optim/linear.py +736 -271
  83. warp/render/imgui_manager.py +289 -0
  84. warp/render/render_opengl.py +99 -18
  85. warp/render/render_usd.py +1 -0
  86. warp/sim/graph_coloring.py +2 -2
  87. warp/sparse.py +558 -175
  88. warp/tests/aux_test_module_aot.py +7 -0
  89. warp/tests/cuda/test_async.py +3 -3
  90. warp/tests/cuda/test_conditional_captures.py +101 -0
  91. warp/tests/geometry/test_hash_grid.py +38 -0
  92. warp/tests/geometry/test_marching_cubes.py +233 -12
  93. warp/tests/interop/test_jax.py +608 -28
  94. warp/tests/sim/test_coloring.py +6 -6
  95. warp/tests/test_array.py +58 -5
  96. warp/tests/test_codegen.py +4 -3
  97. warp/tests/test_context.py +8 -15
  98. warp/tests/test_enum.py +136 -0
  99. warp/tests/test_examples.py +2 -2
  100. warp/tests/test_fem.py +49 -6
  101. warp/tests/test_fixedarray.py +229 -0
  102. warp/tests/test_func.py +18 -15
  103. warp/tests/test_future_annotations.py +7 -5
  104. warp/tests/test_linear_solvers.py +30 -0
  105. warp/tests/test_map.py +15 -1
  106. warp/tests/test_mat.py +1518 -378
  107. warp/tests/test_mat_assign_copy.py +178 -0
  108. warp/tests/test_mat_constructors.py +574 -0
  109. warp/tests/test_module_aot.py +287 -0
  110. warp/tests/test_print.py +69 -0
  111. warp/tests/test_quat.py +140 -34
  112. warp/tests/test_quat_assign_copy.py +145 -0
  113. warp/tests/test_reload.py +2 -1
  114. warp/tests/test_sparse.py +71 -0
  115. warp/tests/test_spatial.py +140 -34
  116. warp/tests/test_spatial_assign_copy.py +160 -0
  117. warp/tests/test_struct.py +43 -3
  118. warp/tests/test_tuple.py +96 -0
  119. warp/tests/test_types.py +61 -20
  120. warp/tests/test_vec.py +179 -34
  121. warp/tests/test_vec_assign_copy.py +143 -0
  122. warp/tests/tile/test_tile.py +245 -18
  123. warp/tests/tile/test_tile_cholesky.py +605 -0
  124. warp/tests/tile/test_tile_load.py +169 -0
  125. warp/tests/tile/test_tile_mathdx.py +2 -558
  126. warp/tests/tile/test_tile_matmul.py +1 -1
  127. warp/tests/tile/test_tile_mlp.py +1 -1
  128. warp/tests/tile/test_tile_shared_memory.py +5 -5
  129. warp/tests/unittest_suites.py +6 -0
  130. warp/tests/walkthrough_debug.py +1 -1
  131. warp/thirdparty/unittest_parallel.py +108 -9
  132. warp/types.py +571 -267
  133. warp/utils.py +68 -86
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
  135. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
  136. warp/native/marching.cpp +0 -19
  137. warp/native/marching.cu +0 -514
  138. warp/native/marching.h +0 -19
  139. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
  140. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
  141. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -17,10 +17,12 @@ from __future__ import annotations
17
17
 
18
18
  import builtins
19
19
  import functools
20
+ import math
20
21
  from typing import Any, Callable, Mapping, Sequence
21
22
 
22
23
  import warp.build
23
24
  import warp.context
25
+ import warp.utils
24
26
  from warp.codegen import Reference, Var, get_arg_value, strip_reference
25
27
  from warp.types import *
26
28
 
@@ -124,6 +126,7 @@ add_builtin(
124
126
  value_func=sametypes_create_value_func(Scalar),
125
127
  doc="Return -1 if ``x`` < 0, return 1 otherwise.",
126
128
  group="Scalar Math",
129
+ missing_grad=True,
127
130
  )
128
131
 
129
132
  add_builtin(
@@ -132,6 +135,7 @@ add_builtin(
132
135
  value_func=sametypes_create_value_func(Scalar),
133
136
  doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
134
137
  group="Scalar Math",
138
+ missing_grad=True,
135
139
  )
136
140
  add_builtin(
137
141
  "nonzero",
@@ -139,6 +143,7 @@ add_builtin(
139
143
  value_func=sametypes_create_value_func(Scalar),
140
144
  doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
141
145
  group="Scalar Math",
146
+ missing_grad=True,
142
147
  )
143
148
 
144
149
  add_builtin(
@@ -290,6 +295,7 @@ add_builtin(
290
295
 
291
296
  This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
292
297
  Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
298
+ missing_grad=True,
293
299
  )
294
300
 
295
301
  add_builtin(
@@ -300,6 +306,7 @@ add_builtin(
300
306
  doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
301
307
 
302
308
  It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
309
+ missing_grad=True,
303
310
  )
304
311
 
305
312
  add_builtin(
@@ -312,6 +319,7 @@ add_builtin(
312
319
  In other words, it discards the fractional part of ``x``.
313
320
  It is similar to casting ``float(int(a))``, but preserves the negative sign when ``x`` is in the range [-0.0, -1.0).
314
321
  Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
322
+ missing_grad=True,
315
323
  )
316
324
 
317
325
  add_builtin(
@@ -320,6 +328,7 @@ add_builtin(
320
328
  value_func=sametypes_create_value_func(Float),
321
329
  group="Scalar Math",
322
330
  doc="""Return the largest integer that is less than or equal to ``x``.""",
331
+ missing_grad=True,
323
332
  )
324
333
 
325
334
  add_builtin(
@@ -328,6 +337,7 @@ add_builtin(
328
337
  value_func=sametypes_create_value_func(Float),
329
338
  group="Scalar Math",
330
339
  doc="""Return the smallest integer that is greater than or equal to ``x``.""",
340
+ missing_grad=True,
331
341
  )
332
342
 
333
343
  add_builtin(
@@ -338,6 +348,7 @@ add_builtin(
338
348
  doc="""Retrieve the fractional part of ``x``.
339
349
 
340
350
  In other words, it discards the integer part of ``x`` and is equivalent to ``x - trunc(x)``.""",
351
+ missing_grad=True,
341
352
  )
342
353
 
343
354
  add_builtin(
@@ -346,6 +357,7 @@ add_builtin(
346
357
  value_type=builtins.bool,
347
358
  group="Scalar Math",
348
359
  doc="""Return ``True`` if ``a`` is a finite number, otherwise return ``False``.""",
360
+ missing_grad=True,
349
361
  )
350
362
  add_builtin(
351
363
  "isfinite",
@@ -353,6 +365,7 @@ add_builtin(
353
365
  value_type=builtins.bool,
354
366
  group="Vector Math",
355
367
  doc="Return ``True`` if all elements of the vector ``a`` are finite, otherwise return ``False``.",
368
+ missing_grad=True,
356
369
  )
357
370
  add_builtin(
358
371
  "isfinite",
@@ -360,6 +373,7 @@ add_builtin(
360
373
  value_type=builtins.bool,
361
374
  group="Vector Math",
362
375
  doc="Return ``True`` if all elements of the quaternion ``a`` are finite, otherwise return ``False``.",
376
+ missing_grad=True,
363
377
  )
364
378
  add_builtin(
365
379
  "isfinite",
@@ -367,6 +381,7 @@ add_builtin(
367
381
  value_type=builtins.bool,
368
382
  group="Vector Math",
369
383
  doc="Return ``True`` if all elements of the matrix ``a`` are finite, otherwise return ``False``.",
384
+ missing_grad=True,
370
385
  )
371
386
 
372
387
  add_builtin(
@@ -375,6 +390,7 @@ add_builtin(
375
390
  value_type=builtins.bool,
376
391
  doc="Return ``True`` if ``a`` is NaN, otherwise return ``False``.",
377
392
  group="Scalar Math",
393
+ missing_grad=True,
378
394
  )
379
395
  add_builtin(
380
396
  "isnan",
@@ -382,6 +398,7 @@ add_builtin(
382
398
  value_type=builtins.bool,
383
399
  group="Vector Math",
384
400
  doc="Return ``True`` if any element of the vector ``a`` is NaN, otherwise return ``False``.",
401
+ missing_grad=True,
385
402
  )
386
403
  add_builtin(
387
404
  "isnan",
@@ -389,6 +406,7 @@ add_builtin(
389
406
  value_type=builtins.bool,
390
407
  group="Vector Math",
391
408
  doc="Return ``True`` if any element of the quaternion ``a`` is NaN, otherwise return ``False``.",
409
+ missing_grad=True,
392
410
  )
393
411
  add_builtin(
394
412
  "isnan",
@@ -396,6 +414,7 @@ add_builtin(
396
414
  value_type=builtins.bool,
397
415
  group="Vector Math",
398
416
  doc="Return ``True`` if any element of the matrix ``a`` is NaN, otherwise return ``False``.",
417
+ missing_grad=True,
399
418
  )
400
419
 
401
420
  add_builtin(
@@ -404,6 +423,7 @@ add_builtin(
404
423
  value_type=builtins.bool,
405
424
  group="Scalar Math",
406
425
  doc="""Return ``True`` if ``a`` is positive or negative infinity, otherwise return ``False``.""",
426
+ missing_grad=True,
407
427
  )
408
428
  add_builtin(
409
429
  "isinf",
@@ -411,6 +431,7 @@ add_builtin(
411
431
  value_type=builtins.bool,
412
432
  group="Vector Math",
413
433
  doc="Return ``True`` if any element of the vector ``a`` is positive or negative infinity, otherwise return ``False``.",
434
+ missing_grad=True,
414
435
  )
415
436
  add_builtin(
416
437
  "isinf",
@@ -418,6 +439,7 @@ add_builtin(
418
439
  value_type=builtins.bool,
419
440
  group="Vector Math",
420
441
  doc="Return ``True`` if any element of the quaternion ``a`` is positive or negative infinity, otherwise return ``False``.",
442
+ missing_grad=True,
421
443
  )
422
444
  add_builtin(
423
445
  "isinf",
@@ -425,6 +447,7 @@ add_builtin(
425
447
  value_type=builtins.bool,
426
448
  group="Vector Math",
427
449
  doc="Return ``True`` if any element of the matrix ``a`` is positive or negative infinity, otherwise return ``False``.",
450
+ missing_grad=True,
428
451
  )
429
452
 
430
453
 
@@ -1180,6 +1203,7 @@ add_builtin(
1180
1203
  doc="Create an identity matrix with shape=(n,n) with the type given by ``dtype``.",
1181
1204
  group="Vector Math",
1182
1205
  export=False,
1206
+ missing_grad=True,
1183
1207
  )
1184
1208
 
1185
1209
 
@@ -1544,6 +1568,7 @@ add_builtin(
1544
1568
  group="Quaternion Math",
1545
1569
  doc="Construct an identity quaternion with zero imaginary part and real part of 1.0",
1546
1570
  export=True,
1571
+ missing_grad=True,
1547
1572
  )
1548
1573
 
1549
1574
  add_builtin(
@@ -1759,6 +1784,7 @@ add_builtin(
1759
1784
  doc="Construct a spatial transform vector of given dtype.",
1760
1785
  group="Spatial Math",
1761
1786
  export=False,
1787
+ missing_grad=True,
1762
1788
  )
1763
1789
 
1764
1790
 
@@ -1793,6 +1819,7 @@ add_builtin(
1793
1819
  group="Transformations",
1794
1820
  doc="Construct an identity transform with zero translation and identity rotation.",
1795
1821
  export=True,
1822
+ missing_grad=True,
1796
1823
  )
1797
1824
 
1798
1825
  add_builtin(
@@ -2355,6 +2382,7 @@ def tile_load_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mappin
2355
2382
  def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2356
2383
  a = args["a"]
2357
2384
  shape = extract_tuple(args["shape"], as_constant=True)
2385
+ bounds_check = args["bounds_check"]
2358
2386
 
2359
2387
  if None in shape:
2360
2388
  raise ValueError("Tile functions require shape to be a compile time constant.")
@@ -2365,17 +2393,23 @@ def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type:
2365
2393
  offset = (0,) * a.type.ndim
2366
2394
 
2367
2395
  func_args = (a, *offset)
2368
- template_args = shape
2396
+ template_args = (return_type.dtype, bounds_check.constant, *shape)
2369
2397
 
2370
2398
  return (func_args, template_args)
2371
2399
 
2372
2400
 
2373
2401
  add_builtin(
2374
2402
  "tile_load",
2375
- input_types={"a": array(dtype=Any), "shape": Tuple[int, ...], "offset": Tuple[int, ...], "storage": str},
2403
+ input_types={
2404
+ "a": array(dtype=Any),
2405
+ "shape": Tuple[int, ...],
2406
+ "offset": Tuple[int, ...],
2407
+ "storage": str,
2408
+ "bounds_check": builtins.bool,
2409
+ },
2376
2410
  value_func=tile_load_tuple_value_func,
2377
2411
  dispatch_func=tile_load_tuple_dispatch_func,
2378
- defaults={"offset": None, "storage": "register"},
2412
+ defaults={"offset": None, "storage": "register", "bounds_check": True},
2379
2413
  variadic=False,
2380
2414
  doc="""Loads a tile from a global memory array.
2381
2415
 
@@ -2386,6 +2420,7 @@ add_builtin(
2386
2420
  :param offset: Offset in the source array to begin reading from (optional)
2387
2421
  :param storage: The storage location for the tile: ``"register"`` for registers
2388
2422
  (default) or ``"shared"`` for shared memory.
2423
+ :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster load times
2389
2424
  :returns: A tile with shape as specified and data type the same as the source array""",
2390
2425
  group="Tile Primitives",
2391
2426
  export=False,
@@ -2394,16 +2429,160 @@ add_builtin(
2394
2429
  # overload for scalar shape
2395
2430
  add_builtin(
2396
2431
  "tile_load",
2397
- input_types={"a": array(dtype=Any), "shape": int, "offset": int, "storage": str},
2432
+ input_types={"a": array(dtype=Any), "shape": int, "offset": int, "storage": str, "bounds_check": builtins.bool},
2398
2433
  value_func=tile_load_tuple_value_func,
2399
2434
  dispatch_func=tile_load_tuple_dispatch_func,
2400
- defaults={"offset": None, "storage": "register"},
2435
+ defaults={"offset": None, "storage": "register", "bounds_check": True},
2401
2436
  group="Tile Primitives",
2402
2437
  hidden=True,
2403
2438
  export=False,
2404
2439
  )
2405
2440
 
2406
2441
 
2442
def tile_load_indexed_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Validate the arguments of ``tile_load_indexed()`` and return the resulting tile type.

    :param arg_types: Argument types keyed by parameter name, or ``None`` when only the generic
        return type is requested.
    :param arg_values: Argument values keyed by parameter name.
    :returns: A ``tile`` type using the dtype of ``a``, the requested ``shape`` and ``storage``.
    :raises ValueError: If ``axis`` is not a valid axis of ``a``, ``indices`` is not a 1D tile,
        ``shape`` is not a compile time constant, the ranks of ``a``, ``shape`` and ``offset``
        disagree, the number of indices differs from ``shape[axis]``, or ``storage`` is neither
        ``"shared"`` nor ``"register"``.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return tile(dtype=Any, shape=Tuple[int, ...])

    a = arg_types["a"]

    indices_tile = arg_types["indices"]
    indices_tile.storage = "shared"  # force to shared

    axis = arg_values["axis"]
    # Negative axes are rejected too: the dispatch function forwards ``axis`` unchanged, and a
    # negative value would otherwise silently wrap around when indexing ``shape`` below.
    if axis < 0 or axis >= a.ndim:
        raise ValueError(f"tile_load_indexed() axis argument must be valid axis of array {a}, got {axis}.")

    indices_tile_dim = len(indices_tile.shape)
    if indices_tile_dim != 1:
        raise ValueError(
            f"tile_load_indexed() indices argument must be a 1D tile, got {indices_tile_dim} dimensions instead."
        )

    shape = extract_tuple(arg_values["shape"], as_constant=True)

    if None in shape:
        raise ValueError("Tile functions require shape to be a compile time constant.")

    # Check the rank of ``shape`` before indexing it with ``axis``; otherwise a shape with fewer
    # entries than the array has dimensions fails with a bare IndexError.
    if a.ndim != len(shape):
        raise ValueError(
            f"tile_load_indexed() array argument must have same number of dimensions as the tile shape, trying to perform an {len(shape)} dimensional load from an array with {a.ndim} dimensions."
        )

    num_indices = indices_tile.shape[0]
    if num_indices != shape[axis]:
        raise ValueError(
            "The number of elements in the 1D indices tile must match the output tile shape along the specified axis."
        )

    if "offset" in arg_values:
        offset = extract_tuple(arg_values["offset"])
    else:
        # no offset given: start reading at the origin of every dimension
        offset = (0,) * a.ndim

    if a.ndim != len(offset):
        raise ValueError(
            f"tile_load_indexed() offset argument must have the same number of dimensions as the array to load from, got {len(offset)} indices for an array with {a.ndim} dimensions"
        )

    if arg_values["storage"] not in {"shared", "register"}:
        raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")

    return tile(dtype=a.dtype, shape=shape, storage=arg_values["storage"])
2491
+
2492
+
2493
def tile_load_indexed_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Assemble the call arguments for ``tile_load_indexed()``.

    :param input_types: Declared input types (unused here).
    :param return_type: Resolved return type (unused here).
    :param args: Call arguments keyed by parameter name.
    :returns: A ``(func_args, template_args)`` pair. ``func_args`` holds the source array, the
        indices tile and the axis, followed by the unpacked offset (zeros, one per array
        dimension, when no offset was supplied). ``template_args`` is the constant tile shape.
    :raises ValueError: If any entry of ``shape`` is not a compile time constant.
    """
    source = args["a"]
    index_tile = args["indices"]
    axis_arg = args["axis"]

    tile_shape = extract_tuple(args["shape"], as_constant=True)
    if None in tile_shape:
        raise ValueError("Tile functions require shape to be a compile time constant.")

    offsets = extract_tuple(args["offset"]) if "offset" in args else (0,) * source.type.ndim

    return ((source, index_tile, axis_arg, *offsets), tile_shape)
2512
+
2513
+
2514
# Register the ``tile_load_indexed`` builtin. ``offset`` is optional, ``axis`` defaults to 0 and
# ``storage`` defaults to ``"register"``.
add_builtin(
    "tile_load_indexed",
    input_types={
        "a": array(dtype=Any),
        "indices": tile(dtype=int, shape=Tuple[int]),
        "shape": Tuple[int, ...],
        "offset": Tuple[int, ...],
        "axis": int,
        "storage": str,
    },
    value_func=tile_load_indexed_tuple_value_func,
    dispatch_func=tile_load_indexed_tuple_dispatch_func,
    defaults={"offset": None, "axis": 0, "storage": "register"},
    variadic=False,
    doc="""Loads a tile from a global memory array, with loads along a specified axis mapped according to a 1D tile of indices.

    :param a: The source array in global memory
    :param indices: A 1D tile of integer indices mapping to elements in ``a``.
    :param shape: Shape of the tile to load, must have the same number of dimensions as ``a``, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
    :param offset: Offset in the source array to begin reading from (optional)
    :param axis: Axis of ``a`` that indices refer to
    :param storage: The storage location for the tile: ``"register"`` for registers (default) or ``"shared"`` for shared memory.
    :returns: A tile with shape as specified and data type the same as the source array

    This example shows how to select and store the even indexed rows from a 2D array:

    .. code-block:: python

        TILE_M = wp.constant(2)
        TILE_N = wp.constant(2)
        HALF_M = wp.constant(TILE_M // 2)

        @wp.kernel
        def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
            i, j = wp.tid()

            evens = wp.tile_arange(HALF_M, dtype=int, storage="shared") * 2

            t0 = wp.tile_load_indexed(x, indices=evens, shape=(HALF_M, TILE_N), offset=(i*TILE_M, j*TILE_N), axis=0, storage="register")
            wp.tile_store(y, t0, offset=(i*HALF_M, j*TILE_N))

        M = TILE_M * 2
        N = TILE_N * 2

        arr = np.arange(M * N).reshape(M, N)

        x = wp.array(arr, dtype=float, device=device)
        y = wp.zeros((M // 2, N), dtype=float, device=device)

        wp.launch_tiled(compute, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)

        print(x.numpy())
        print(y.numpy())

    Prints:

    .. code-block:: text

        [[ 0.  1.  2.  3.]
         [ 4.  5.  6.  7.]
         [ 8.  9. 10. 11.]
         [12. 13. 14. 15.]]

        [[ 0.  1.  2.  3.]
         [ 8.  9. 10. 11.]]
    """,
    group="Tile Primitives",
    export=False,
)
2584
+
2585
+
2407
2586
  def tile_store_value_func(arg_types, arg_values):
2408
2587
  # return generic type (for doc builds)
2409
2588
  if arg_types is None:
@@ -2440,6 +2619,7 @@ def tile_store_value_func(arg_types, arg_values):
2440
2619
  def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2441
2620
  a = args["a"]
2442
2621
  t = args["t"]
2622
+ bounds_check = args["bounds_check"]
2443
2623
 
2444
2624
  if "offset" in args:
2445
2625
  offset = extract_tuple(args["offset"])
@@ -2447,17 +2627,22 @@ def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any,
2447
2627
  offset = (0,) * a.type.ndim
2448
2628
 
2449
2629
  func_args = (a, *offset, t)
2450
- template_args = []
2630
+ template_args = (a.type.dtype, bounds_check.constant)
2451
2631
 
2452
2632
  return (func_args, template_args)
2453
2633
 
2454
2634
 
2455
2635
  add_builtin(
2456
2636
  "tile_store",
2457
- input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...]},
2637
+ input_types={
2638
+ "a": array(dtype=Any),
2639
+ "t": tile(dtype=Any, shape=Tuple[int, ...]),
2640
+ "offset": Tuple[int, ...],
2641
+ "bounds_check": builtins.bool,
2642
+ },
2458
2643
  value_func=tile_store_value_func,
2459
2644
  dispatch_func=tile_store_dispatch_func,
2460
- defaults={"offset": None},
2645
+ defaults={"offset": None, "bounds_check": True},
2461
2646
  variadic=False,
2462
2647
  skip_replay=True,
2463
2648
  doc="""Store a tile to a global memory array.
@@ -2466,7 +2651,9 @@ add_builtin(
2466
2651
 
2467
2652
  :param a: The destination array in global memory
2468
2653
  :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array
2469
- :param offset: Offset in the destination array (optional)""",
2654
+ :param offset: Offset in the destination array (optional)
2655
+ :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster write times
2656
+ """,
2470
2657
  group="Tile Primitives",
2471
2658
  export=False,
2472
2659
  )
@@ -2474,10 +2661,15 @@ add_builtin(
2474
2661
  # overload for scalar offset
2475
2662
  add_builtin(
2476
2663
  "tile_store",
2477
- input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": int},
2664
+ input_types={
2665
+ "a": array(dtype=Any),
2666
+ "t": tile(dtype=Any, shape=Tuple[int, ...]),
2667
+ "offset": int,
2668
+ "bounds_check": builtins.bool,
2669
+ },
2478
2670
  value_func=tile_store_value_func,
2479
2671
  dispatch_func=tile_store_dispatch_func,
2480
- defaults={"offset": None},
2672
+ defaults={"offset": None, "bounds_check": True},
2481
2673
  variadic=False,
2482
2674
  skip_replay=True,
2483
2675
  group="Tile Primitives",
@@ -2486,6 +2678,151 @@ add_builtin(
2486
2678
  )
2487
2679
 
2488
2680
 
2681
def tile_store_indexed_value_func(arg_types, arg_values):
    """Validate the arguments of ``tile_store_indexed()``.

    :param arg_types: Argument types keyed by parameter name, or ``None`` when only the generic
        return type is requested.
    :param arg_values: Argument values keyed by parameter name.
    :returns: Always ``None``.
    :raises ValueError: If ``axis`` is not a valid axis of ``a``, ``indices`` is not a 1D tile,
        the ranks of ``a``, ``t`` and ``offset`` disagree, or the number of indices differs from
        the extent of ``t`` along ``axis``.
    :raises TypeError: If ``a`` and ``t`` have different dtypes.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return None

    a = arg_types["a"]
    t = arg_types["t"]
    indices_tile = arg_types["indices"]
    indices_tile.storage = "shared"  # force to shared

    axis = arg_values["axis"]
    # Negative axes are rejected too: the dispatch function forwards ``axis`` unchanged, and a
    # negative value would otherwise silently wrap around when indexing ``t.shape`` below.
    if axis < 0 or axis >= a.ndim:
        raise ValueError(f"tile_store_indexed() axis argument must be valid axis of array {a}, got {axis}.")

    indices_tile_dim = len(indices_tile.shape)
    if indices_tile_dim != 1:
        raise ValueError(
            f"tile_store_indexed() indices argument must be a 1D tile, got {indices_tile_dim} dimensions instead."
        )

    # Check the rank of ``t`` before indexing its shape with ``axis``; otherwise a tile with fewer
    # dimensions than the array fails with a bare IndexError.
    if len(t.shape) != a.ndim:
        raise ValueError(
            f"tile_store_indexed() 'a' argument must have the same number of dimensions as the 't' argument, "
            f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
        )

    num_indices = indices_tile.shape[0]
    if num_indices != t.shape[axis]:
        raise ValueError(
            "The number of elements in the 1D indices tile must match the input tile shape along the specified axis."
        )

    if "offset" in arg_types:
        c = extract_tuple(arg_values["offset"])
    else:
        # no offset given: start writing at the origin of every dimension
        c = (0,) * a.ndim

    if len(c) != a.ndim:
        raise ValueError(
            f"tile_store_indexed() 'a' argument must have {len(c)} dimensions, "
            f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
        )

    if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
        raise TypeError(
            f"tile_store_indexed() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
        )

    return None
2730
+
2731
+
2732
def tile_store_indexed_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Assemble the call arguments for ``tile_store_indexed()``.

    :param input_types: Declared input types (unused here).
    :param return_type: Resolved return type (unused here).
    :param args: Call arguments keyed by parameter name.
    :returns: A ``(func_args, template_args)`` pair. ``func_args`` holds the destination array,
        the indices tile and the axis, followed by the unpacked offset (zeros, one per array
        dimension, when no offset was supplied) and finally the source tile. ``template_args``
        is an empty list.
    """
    dest = args["a"]
    index_tile = args["indices"]
    axis_arg = args["axis"]
    src_tile = args["t"]

    offsets = extract_tuple(args["offset"]) if "offset" in args else (0,) * dest.type.ndim

    return ((dest, index_tile, axis_arg, *offsets, src_tile), [])
2747
+
2748
+
2749
# Register the ``tile_store_indexed`` builtin. ``offset`` is optional and ``axis`` defaults to 0.
add_builtin(
    "tile_store_indexed",
    input_types={
        "a": array(dtype=Any),
        "indices": tile(dtype=int, shape=Tuple[int]),
        "t": tile(dtype=Any, shape=Tuple[int, ...]),
        "offset": Tuple[int, ...],
        "axis": int,
    },
    value_func=tile_store_indexed_value_func,
    dispatch_func=tile_store_indexed_dispatch_func,
    defaults={"offset": None, "axis": 0},
    variadic=False,
    skip_replay=True,
    doc="""Store a tile to a global memory array, with storage along a specified axis mapped according to a 1D tile of indices.

    :param a: The destination array in global memory
    :param indices: A 1D tile of integer indices mapping to elements in ``a``.
    :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
    :param offset: Offset in the destination array (optional)
    :param axis: Axis of ``a`` that indices refer to

    This example shows how to map tile rows to the even rows of a 2D array:

    .. code-block:: python

        TILE_M = wp.constant(2)
        TILE_N = wp.constant(2)
        TWO_M = wp.constant(TILE_M * 2)

        @wp.kernel
        def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
            i, j = wp.tid()

            t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i*TILE_M, j*TILE_N), storage="register")

            evens_M = wp.tile_arange(TILE_M, dtype=int, storage="shared") * 2

            wp.tile_store_indexed(y, indices=evens_M, t=t, offset=(i*TWO_M, j*TILE_N), axis=0)

        M = TILE_M * 2
        N = TILE_N * 2

        arr = np.arange(M * N, dtype=float).reshape(M, N)

        x = wp.array(arr, dtype=float, requires_grad=True, device=device)
        y = wp.zeros((M * 2, N), dtype=float, requires_grad=True, device=device)

        wp.launch_tiled(compute, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)

        print(x.numpy())
        print(y.numpy())

    Prints:

    .. code-block:: text

        [[ 0.  1.  2.  3.]
         [ 4.  5.  6.  7.]
         [ 8.  9. 10. 11.]
         [12. 13. 14. 15.]]

        [[ 0.  1.  2.  3.]
         [ 0.  0.  0.  0.]
         [ 4.  5.  6.  7.]
         [ 0.  0.  0.  0.]
         [ 8.  9. 10. 11.]
         [ 0.  0.  0.  0.]
         [12. 13. 14. 15.]
         [ 0.  0.  0.  0.]]
    """,
    group="Tile Primitives",
    export=False,
)
2824
+
2825
+
2489
2826
  def tile_atomic_add_value_func(arg_types, arg_values):
2490
2827
  # return generic type (for doc builds)
2491
2828
  if arg_types is None:
@@ -2526,6 +2863,7 @@ def tile_atomic_add_value_func(arg_types, arg_values):
2526
2863
  def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2527
2864
  a = args["a"]
2528
2865
  t = args["t"]
2866
+ bounds_check = args["bounds_check"]
2529
2867
 
2530
2868
  if "offset" in args:
2531
2869
  offset = extract_tuple(args["offset"])
@@ -2533,17 +2871,22 @@ def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type:
2533
2871
  offset = (0,) * a.type.ndim
2534
2872
 
2535
2873
  func_args = (a, *offset, t)
2536
- template_args = []
2874
+ template_args = (a.type.dtype, bounds_check.constant)
2537
2875
 
2538
2876
  return (func_args, template_args)
2539
2877
 
2540
2878
 
2541
2879
  add_builtin(
2542
2880
  "tile_atomic_add",
2543
- input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...]},
2881
+ input_types={
2882
+ "a": array(dtype=Any),
2883
+ "t": tile(dtype=Any, shape=Tuple[int, ...]),
2884
+ "offset": Tuple[int, ...],
2885
+ "bounds_check": builtins.bool,
2886
+ },
2544
2887
  value_func=tile_atomic_add_value_func,
2545
2888
  dispatch_func=tile_atomic_add_dispatch_func,
2546
- defaults={"offset": None},
2889
+ defaults={"offset": None, "bounds_check": True},
2547
2890
  variadic=False,
2548
2891
  skip_replay=True,
2549
2892
  doc="""Atomically add a tile onto the array `a`, each element will be updated atomically.
@@ -2551,6 +2894,7 @@ add_builtin(
2551
2894
  :param a: Array in global memory, should have the same ``dtype`` as the input tile
2552
2895
  :param t: Source tile to add to the destination array
2553
2896
  :param offset: Offset in the destination array (optional)
2897
+ :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster write times
2554
2898
  :returns: A tile with the same dimensions and data type as the source tile, holding the original value of the destination elements""",
2555
2899
  group="Tile Primitives",
2556
2900
  export=False,
@@ -2559,10 +2903,15 @@ add_builtin(
2559
2903
  # overload for scalar offset
2560
2904
  add_builtin(
2561
2905
  "tile_atomic_add",
2562
- input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": int},
2906
+ input_types={
2907
+ "a": array(dtype=Any),
2908
+ "t": tile(dtype=Any, shape=Tuple[int, ...]),
2909
+ "offset": int,
2910
+ "bounds_check": builtins.bool,
2911
+ },
2563
2912
  value_func=tile_atomic_add_value_func,
2564
2913
  dispatch_func=tile_atomic_add_dispatch_func,
2565
- defaults={"offset": None},
2914
+ defaults={"offset": None, "bounds_check": True},
2566
2915
  variadic=False,
2567
2916
  skip_replay=True,
2568
2917
  group="Tile Primitives",
@@ -2571,6 +2920,143 @@ add_builtin(
2571
2920
  )
2572
2921
 
2573
2922
 
2923
def tile_atomic_add_indexed_value_func(arg_types, arg_values):
    """Validate a ``tile_atomic_add_indexed()`` call and compute its return type.

    :param arg_types: Mapping from argument name to argument type (reads ``a``, ``t``,
        ``indices`` and, when present, ``offset``), or ``None`` to request the generic
        signature used for doc builds
    :param arg_values: Mapping from argument name to value; ``axis`` is always read and
        ``offset`` is read when it appears in ``arg_types``
    :returns: A tile type with the dtype, shape and storage of ``t``
    :raises ValueError: If ``axis`` is not below ``a.ndim``, if ``a`` and ``t`` differ in
        number of dimensions, if ``indices`` is not 1D or its length differs from
        ``t.shape[axis]``, or if the offset length differs from ``a.ndim``
    :raises TypeError: If ``a`` and ``t`` have different dtypes
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return tile(dtype=Any, shape=Tuple[int, ...])

    a = arg_types["a"]
    t = arg_types["t"]
    indices_tile = arg_types["indices"]
    indices_tile.storage = "shared"  # force to shared

    # NOTE(review): only the upper bound is checked; negative axes are not rejected here —
    # confirm whether the native implementation accepts them.
    axis = arg_values["axis"]
    if axis >= a.ndim:
        raise ValueError(f"tile_atomic_add_indexed() axis argument must be valid axis of array {a}, got {axis}.")

    # Compare ranks before indexing ``t.shape[axis]`` below: ``axis`` was validated against
    # ``a`` only, so a lower-rank ``t`` would otherwise fail with a bare IndexError instead
    # of this descriptive message.
    if len(t.shape) != a.ndim:
        raise ValueError(
            f"tile_atomic_add_indexed() 'a' argument must have the same number of dimensions as the 't' argument, "
            f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
        )

    indices_tile_dim = len(indices_tile.shape)
    if indices_tile_dim != 1:
        raise ValueError(
            f"tile_atomic_add_indexed() indices argument must be a 1D tile, got {indices_tile_dim} dimensions instead."
        )

    num_indices = indices_tile.shape[0]
    if num_indices != t.shape[axis]:
        raise ValueError(
            "The number of elements in the 1D indices tile must match the input tile shape along the specified axis."
        )

    # a missing offset defaults to one zero per array dimension
    if "offset" in arg_types:
        c = extract_tuple(arg_values["offset"])
    else:
        c = (0,) * a.ndim

    if len(c) != a.ndim:
        raise ValueError(
            f"tile_atomic_add_indexed() 'a' argument must have {len(c)} dimensions, "
            f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
        )

    if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
        raise TypeError(
            f"tile_atomic_add_indexed() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
        )

    return tile(dtype=t.dtype, shape=t.shape, storage=t.storage)
2972
+
2973
+
2974
def tile_atomic_add_indexed_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Assemble the call arguments for ``tile_atomic_add_indexed()``.

    :param input_types: Not read by this function
    :param return_type: Not read by this function
    :param args: Mapping from argument name to variable; reads ``a``, ``indices``, ``axis``,
        ``t`` and, when present, ``offset``
    :returns: ``(func_args, template_args)`` where ``func_args`` is
        ``(a, indices, axis, *offset, t)`` and ``template_args`` is an empty list
    """
    dest, index_tile, axis, src = (args[key] for key in ("a", "indices", "axis", "t"))

    # a missing offset defaults to one zero per dimension of the destination array
    offset = extract_tuple(args["offset"]) if "offset" in args else (0,) * dest.type.ndim

    return ((dest, index_tile, axis, *offset, src), [])
2989
+
2990
+
2991
+ add_builtin(
2992
+ "tile_atomic_add_indexed",
2993
+ input_types={
2994
+ "a": array(dtype=Any),
2995
+ "indices": tile(dtype=int, shape=Tuple[int]),
2996
+ "t": tile(dtype=Any, shape=Tuple[int, ...]),
2997
+ "offset": Tuple[int, ...],
2998
+ "axis": int,
2999
+ },
3000
+ value_func=tile_atomic_add_indexed_value_func,
3001
+ dispatch_func=tile_atomic_add_indexed_dispatch_func,
3002
+ defaults={"offset": None, "axis": 0},
3003
+ variadic=False,
3004
+ skip_replay=True,
3005
+ doc="""Atomically add a tile to a global memory array, with storage along a specified axis mapped according to a 1D tile of indices.
3006
+
3007
+ :param a: The destination array in global memory
3008
+ :param indices: A 1D tile of integer indices mapping to elements in ``a``.
3009
+ :param t: The source tile to extract data from, must have the same data type and number of dimensions as the destination array, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
3010
+ :param offset: Offset in the destination array (optional)
3011
+ :param axis: Axis of ``a`` that indices refer to
3012
+
3013
+ This example shows how to compute a blocked, row-wise reduction:
3014
+
3015
+ .. code-block:: python
3016
+
3017
+ TILE_M = wp.constant(2)
3018
+ TILE_N = wp.constant(2)
3019
+
3020
+ @wp.kernel
3021
+ def tile_atomic_add_indexed(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
3022
+ i, j = wp.tid()
3023
+
3024
+ t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i*TILE_M, j*TILE_N), storage="register")
3025
+
3026
+ zeros = wp.tile_zeros(TILE_M, dtype=int, storage="shared")
3027
+
3028
+ wp.tile_atomic_add_indexed(y, indices=zeros, t=t, offset=(i, j*TILE_N), axis=0)
3029
+
3030
+ M = TILE_M * 2
3031
+ N = TILE_N * 2
3032
+
3033
+ arr = np.arange(M * N, dtype=float).reshape(M, N)
3034
+
3035
+ x = wp.array(arr, dtype=float, requires_grad=True, device=device)
3036
+ y = wp.zeros((2, N), dtype=float, requires_grad=True, device=device)
3037
+
3038
+ wp.launch_tiled(tile_atomic_add_indexed, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)
3039
+
3040
+ print(x.numpy())
3041
+ print(y.numpy())
3042
+
3043
+ Prints:
3044
+
3045
+ .. code-block:: text
3046
+
3047
+ [[ 0. 1. 2. 3.]
3048
+ [ 4. 5. 6. 7.]
3049
+ [ 8. 9. 10. 11.]
3050
+ [12. 13. 14. 15.]]
3051
+
3052
+ [[ 4. 6. 8. 10.]
3053
+ [20. 22. 24. 26.]]
3054
+ """,
3055
+ group="Tile Primitives",
3056
+ export=False,
3057
+ )
3058
+
3059
+
2574
3060
  def tile_view_value_func(arg_types, arg_values):
2575
3061
  # return generic type (for doc builds)
2576
3062
  if arg_types is None:
@@ -3525,6 +4011,7 @@ add_builtin(
3525
4011
  """,
3526
4012
  group="Tile Primitives",
3527
4013
  export=False,
4014
+ missing_grad=True,
3528
4015
  )
3529
4016
 
3530
4017
 
@@ -3578,6 +4065,7 @@ add_builtin(
3578
4065
  """,
3579
4066
  group="Tile Primitives",
3580
4067
  export=False,
4068
+ missing_grad=True,
3581
4069
  )
3582
4070
 
3583
4071
 
@@ -3631,6 +4119,7 @@ add_builtin(
3631
4119
  """,
3632
4120
  group="Tile Primitives",
3633
4121
  export=False,
4122
+ missing_grad=True,
3634
4123
  )
3635
4124
 
3636
4125
 
@@ -3683,6 +4172,7 @@ add_builtin(
3683
4172
  """,
3684
4173
  group="Tile Primitives",
3685
4174
  export=False,
4175
+ missing_grad=True,
3686
4176
  )
3687
4177
 
3688
4178
 
@@ -3735,6 +4225,7 @@ add_builtin(
3735
4225
  """,
3736
4226
  group="Tile Primitives",
3737
4227
  export=False,
4228
+ missing_grad=True,
3738
4229
  )
3739
4230
 
3740
4231
 
@@ -3792,6 +4283,7 @@ add_builtin(
3792
4283
  """,
3793
4284
  group="Tile Primitives",
3794
4285
  export=False,
4286
+ missing_grad=True,
3795
4287
  )
3796
4288
 
3797
4289
 
@@ -3855,6 +4347,7 @@ add_builtin(
3855
4347
  """,
3856
4348
  group="Tile Primitives",
3857
4349
  export=False,
4350
+ missing_grad=True,
3858
4351
  )
3859
4352
 
3860
4353
 
@@ -3918,6 +4411,7 @@ add_builtin(
3918
4411
  """,
3919
4412
  group="Tile Primitives",
3920
4413
  export=False,
4414
+ missing_grad=True,
3921
4415
  )
3922
4416
 
3923
4417
 
@@ -3934,14 +4428,45 @@ def tile_unary_map_value_func(arg_types, arg_values):
3934
4428
  if not is_tile(a):
3935
4429
  raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
3936
4430
 
3937
- return tile(dtype=a.dtype, shape=a.shape)
4431
+ if "op" in arg_values:
4432
+ op = arg_values["op"]
4433
+ try:
4434
+ overload = op.get_overload([a.dtype], {})
4435
+ except KeyError as exc:
4436
+ raise RuntimeError(f"No overload of {op} found for tile element type {type_repr(a.dtype)}") from exc
4437
+
4438
+ # build the right overload on demand
4439
+ if overload.value_func is None:
4440
+ overload.build(None)
4441
+
4442
+ value_type = overload.value_func(None, None)
4443
+
4444
+ if not type_is_scalar(value_type) and not type_is_vector(value_type) and not type_is_matrix(value_type):
4445
+ raise TypeError(f"Operator {op} returns unsupported type {type_repr(value_type)} for a tile element")
4446
+
4447
+ return tile(dtype=value_type, shape=a.shape)
4448
+
4449
+ else:
4450
+ return tile(dtype=a.dtype, shape=a.shape)
4451
+
4452
+
4453
def tile_unary_map_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Assemble the call arguments for the unary ``tile_map()`` overload.

    :param arg_types: Not read by this function
    :param return_type: Type given to the output tile variable
    :param arg_values: Mapping from argument name to variable; reads ``op`` and ``a``
    :returns: ``((overload, a, result), ())`` where ``overload`` is the overload of ``op``
        looked up for the dtype of ``a`` and ``result`` is a new unlabeled ``Var`` of
        ``return_type``
    """
    operator = arg_values["op"]
    source = arg_values["a"]

    resolved = operator.get_overload([source.type.dtype], {})

    # an explicit output variable is necessary, in case the return type is different
    # from the input tile type
    result = Var(label=None, type=return_type)

    func_args = (resolved, source, result)
    return (func_args, ())
3938
4463
 
3939
4464
 
3940
4465
  add_builtin(
3941
4466
  "tile_map",
3942
4467
  input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...])},
3943
4468
  value_func=tile_unary_map_value_func,
3944
- # dispatch_func=tile_map_dispatch_func,
4469
+ dispatch_func=tile_unary_map_dispatch_func,
3945
4470
  # variadic=True,
3946
4471
  native_func="tile_unary_map",
3947
4472
  doc="""Apply a unary function onto the tile.
@@ -3950,7 +4475,7 @@ add_builtin(
3950
4475
 
3951
4476
  :param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
3952
4477
  :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
3953
- :returns: A tile with the same dimensions and data type as the input tile.
4478
+ :returns: A tile with the same dimensions as the input tile. Its datatype is specified by the return type of op
3954
4479
 
3955
4480
  Example:
3956
4481
 
@@ -3991,10 +4516,6 @@ def tile_binary_map_value_func(arg_types, arg_values):
3991
4516
  if not is_tile(b):
3992
4517
  raise TypeError(f"tile_map() 'b' argument must be a tile, got {b!r}")
3993
4518
 
3994
- # ensure types equal
3995
- if not types_equal(a.dtype, b.dtype):
3996
- raise TypeError(f"tile_map() arguments must have the same dtype, got {a.dtype} and {b.dtype}")
3997
-
3998
4519
  if len(a.shape) != len(b.shape):
3999
4520
  raise ValueError(
4000
4521
  f"tile_map() shapes must have the same number of dimensions, got {len(a.shape)} and {len(b.shape)}"
@@ -4004,7 +4525,47 @@ def tile_binary_map_value_func(arg_types, arg_values):
4004
4525
  if a.shape[i] != b.shape[i]:
4005
4526
  raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")
4006
4527
 
4007
- return tile(dtype=a.dtype, shape=a.shape)
4528
+ if "op" in arg_values:
4529
+ op = arg_values["op"]
4530
+ try:
4531
+ overload = op.get_overload([a.dtype, b.dtype], {})
4532
+ except KeyError as exc:
4533
+ raise RuntimeError(
4534
+ f"No overload of {op} found for tile element types {type_repr(a.dtype)}, {type_repr(b.dtype)}"
4535
+ ) from exc
4536
+
4537
+ # build the right overload on demand
4538
+ if overload.value_func is None:
4539
+ overload.build(None)
4540
+
4541
+ value_type = overload.value_func(None, None)
4542
+
4543
+ if not type_is_scalar(value_type) and not type_is_vector(value_type) and not type_is_matrix(value_type):
4544
+ raise TypeError(f"Operator {op} returns unsupported type {type_repr(value_type)} for a tile element")
4545
+
4546
+ return tile(dtype=value_type, shape=a.shape)
4547
+
4548
+ else:
4549
+ # ensure types equal
4550
+ if not types_equal(a.dtype, b.dtype):
4551
+ raise TypeError(
4552
+ f"tile_map() arguments must have the same dtype for this operation, got {a.dtype} and {b.dtype}"
4553
+ )
4554
+
4555
+ return tile(dtype=a.dtype, shape=a.shape)
4556
+
4557
+
4558
def tile_binary_map_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Assemble the call arguments for the binary ``tile_map()`` overload.

    :param arg_types: Not read by this function
    :param return_type: Type given to the output tile variable
    :param arg_values: Mapping from argument name to variable; reads ``op``, ``a`` and ``b``
    :returns: ``((overload, a, b, result), ())`` where ``overload`` is the overload of ``op``
        looked up for the dtypes of ``a`` and ``b`` and ``result`` is a new unlabeled
        ``Var`` of ``return_type``
    """
    operator = arg_values["op"]
    lhs = arg_values["a"]
    rhs = arg_values["b"]

    element_types = [lhs.type.dtype, rhs.type.dtype]
    resolved = operator.get_overload(element_types, {})

    # an explicit output variable is necessary, in case the return type is different
    # from the input tile types
    result = Var(label=None, type=return_type)

    func_args = (resolved, lhs, rhs, result)
    return (func_args, ())
4008
4569
 
4009
4570
 
4010
4571
  add_builtin(
@@ -4015,18 +4576,18 @@ add_builtin(
4015
4576
  "b": tile(dtype=Scalar, shape=Tuple[int, ...]),
4016
4577
  },
4017
4578
  value_func=tile_binary_map_value_func,
4018
- # dispatch_func=tile_map_dispatch_func,
4579
+ dispatch_func=tile_binary_map_dispatch_func,
4019
4580
  # variadic=True,
4020
4581
  native_func="tile_binary_map",
4021
4582
  doc="""Apply a binary function onto the tile.
4022
4583
 
4023
4584
  This function cooperatively applies a binary function to each element of the tiles using all threads in the block.
4024
- Both input tiles must have the same dimensions and datatype.
4585
+ Both input tiles must have the same dimensions, and if using a builtin op, the same datatypes.
4025
4586
 
4026
4587
  :param op: A callable function that accepts two arguments and returns one argument, all of the same type, may be a user function or builtin
4027
4588
  :param a: The first input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
4028
4589
  :param b: The second input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
4029
- :returns: A tile with the same dimensions and datatype as the input tiles.
4590
+ :returns: A tile with the same dimensions as the input tiles. Its datatype is specified by the return type of op
4030
4591
 
4031
4592
  Example:
4032
4593
 
@@ -4104,6 +4665,7 @@ add_builtin(
4104
4665
  doc="WIP",
4105
4666
  group="Utility",
4106
4667
  hidden=True,
4668
+ missing_grad=True,
4107
4669
  )
4108
4670
 
4109
4671
  add_builtin(
@@ -4119,6 +4681,7 @@ add_builtin(
4119
4681
  doc="WIP",
4120
4682
  group="Utility",
4121
4683
  hidden=True,
4684
+ missing_grad=True,
4122
4685
  )
4123
4686
 
4124
4687
  add_builtin(
@@ -4128,6 +4691,7 @@ add_builtin(
4128
4691
  doc="WIP",
4129
4692
  group="Utility",
4130
4693
  hidden=True,
4694
+ missing_grad=True,
4131
4695
  )
4132
4696
 
4133
4697
  add_builtin(
@@ -4179,6 +4743,7 @@ add_builtin(
4179
4743
  :param low: The lower bound of the bounding box in BVH space
4180
4744
  :param high: The upper bound of the bounding box in BVH space""",
4181
4745
  export=False,
4746
+ missing_grad=True,
4182
4747
  )
4183
4748
 
4184
4749
  add_builtin(
@@ -4194,6 +4759,7 @@ add_builtin(
4194
4759
  :param start: The start of the ray in BVH space
4195
4760
  :param dir: The direction of the ray in BVH space""",
4196
4761
  export=False,
4762
+ missing_grad=True,
4197
4763
  )
4198
4764
 
4199
4765
  add_builtin(
@@ -4204,6 +4770,7 @@ add_builtin(
4204
4770
  doc="""Move to the next bound returned by the query.
4205
4771
  The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
4206
4772
  export=False,
4773
+ missing_grad=True,
4207
4774
  )
4208
4775
 
4209
4776
  add_builtin(
@@ -4538,12 +5105,13 @@ add_builtin(
4538
5105
  group="Geometry",
4539
5106
  doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.
4540
5107
 
4541
- This query can be used to iterate over all triangles inside a volume.
5108
+ This query can be used to iterate over all bounding boxes of the triangles inside a volume.
4542
5109
 
4543
5110
  :param id: The mesh identifier
4544
5111
  :param low: The lower bound of the bounding box in mesh space
4545
5112
  :param high: The upper bound of the bounding box in mesh space""",
4546
5113
  export=False,
5114
+ missing_grad=True,
4547
5115
  )
4548
5116
 
4549
5117
  add_builtin(
@@ -4551,10 +5119,11 @@ add_builtin(
4551
5119
  input_types={"query": MeshQueryAABB, "index": int},
4552
5120
  value_type=builtins.bool,
4553
5121
  group="Geometry",
4554
- doc="""Move to the next triangle overlapping the query bounding box.
5122
+ doc="""Move to the next triangle whose bounding box overlaps the query bounding box.
4555
5123
 
4556
5124
  The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
4557
5125
  export=False,
5126
+ missing_grad=True,
4558
5127
  )
4559
5128
 
4560
5129
  add_builtin(
@@ -4584,6 +5153,7 @@ add_builtin(
4584
5153
 
4585
5154
  This query can be used to iterate over all neighboring point within a fixed radius from the query point.""",
4586
5155
  export=False,
5156
+ missing_grad=True,
4587
5157
  )
4588
5158
 
4589
5159
  add_builtin(
@@ -4595,6 +5165,7 @@ add_builtin(
4595
5165
 
4596
5166
  The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.""",
4597
5167
  export=False,
5168
+ missing_grad=True,
4598
5169
  )
4599
5170
 
4600
5171
  add_builtin(
@@ -4608,6 +5179,7 @@ add_builtin(
4608
5179
 
4609
5180
  Returns -1 if the :class:`HashGrid` has not been reserved.""",
4610
5181
  export=False,
5182
+ missing_grad=True,
4611
5183
  )
4612
5184
 
4613
5185
  add_builtin(
@@ -4619,6 +5191,7 @@ add_builtin(
4619
5191
 
4620
5192
  Returns > 0 if triangles intersect.""",
4621
5193
  export=False,
5194
+ missing_grad=True,
4622
5195
  )
4623
5196
 
4624
5197
  add_builtin(
@@ -4638,6 +5211,7 @@ add_builtin(
4638
5211
  group="Geometry",
4639
5212
  doc="""Evaluates the face normal the mesh given a face index.""",
4640
5213
  export=False,
5214
+ missing_grad=True,
4641
5215
  )
4642
5216
 
4643
5217
  add_builtin(
@@ -4647,6 +5221,7 @@ add_builtin(
4647
5221
  group="Geometry",
4648
5222
  doc="""Returns the point of the mesh given a index.""",
4649
5223
  export=False,
5224
+ missing_grad=True,
4650
5225
  )
4651
5226
 
4652
5227
  add_builtin(
@@ -4656,6 +5231,7 @@ add_builtin(
4656
5231
  group="Geometry",
4657
5232
  doc="""Returns the velocity of the mesh given a index.""",
4658
5233
  export=False,
5234
+ missing_grad=True,
4659
5235
  )
4660
5236
 
4661
5237
  add_builtin(
@@ -4665,6 +5241,7 @@ add_builtin(
4665
5241
  group="Geometry",
4666
5242
  doc="""Returns the point-index of the mesh given a face-vertex index.""",
4667
5243
  export=False,
5244
+ missing_grad=True,
4668
5245
  )
4669
5246
 
4670
5247
 
@@ -4705,12 +5282,32 @@ add_builtin(
4705
5282
  # ---------------------------------
4706
5283
  # Iterators
4707
5284
 
4708
- add_builtin("iter_next", input_types={"range": range_t}, value_type=int, group="Utility", export=False, hidden=True)
4709
5285
  add_builtin(
4710
- "iter_next", input_types={"query": HashGridQuery}, value_type=int, group="Utility", export=False, hidden=True
5286
+ "iter_next",
5287
+ input_types={"range": range_t},
5288
+ value_type=int,
5289
+ group="Utility",
5290
+ export=False,
5291
+ hidden=True,
5292
+ missing_grad=True,
5293
+ )
5294
+ add_builtin(
5295
+ "iter_next",
5296
+ input_types={"query": HashGridQuery},
5297
+ value_type=int,
5298
+ group="Utility",
5299
+ export=False,
5300
+ hidden=True,
5301
+ missing_grad=True,
4711
5302
  )
4712
5303
  add_builtin(
4713
- "iter_next", input_types={"query": MeshQueryAABB}, value_type=int, group="Utility", export=False, hidden=True
5304
+ "iter_next",
5305
+ input_types={"query": MeshQueryAABB},
5306
+ value_type=int,
5307
+ group="Utility",
5308
+ export=False,
5309
+ hidden=True,
5310
+ missing_grad=True,
4714
5311
  )
4715
5312
 
4716
5313
  add_builtin(
@@ -4721,6 +5318,7 @@ add_builtin(
4721
5318
  group="Utility",
4722
5319
  doc="""Returns the range in reversed order.""",
4723
5320
  export=False,
5321
+ missing_grad=True,
4724
5322
  )
4725
5323
 
4726
5324
  # ---------------------------------
@@ -4869,6 +5467,7 @@ add_builtin(
4869
5467
  doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type type `dtype`.
4870
5468
 
4871
5469
  If the voxel at this index does not exist, this function returns the background value.""",
5470
+ missing_grad=True,
4872
5471
  )
4873
5472
 
4874
5473
 
@@ -4889,6 +5488,7 @@ add_builtin(
4889
5488
  export=False,
4890
5489
  group="Volumes",
4891
5490
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5491
+ missing_grad=True,
4892
5492
  )
4893
5493
 
4894
5494
  add_builtin(
@@ -4919,6 +5519,7 @@ add_builtin(
4919
5519
  doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.
4920
5520
 
4921
5521
  If the voxel at this index does not exist, this function returns the background value""",
5522
+ missing_grad=True,
4922
5523
  )
4923
5524
 
4924
5525
  add_builtin(
@@ -4927,6 +5528,7 @@ add_builtin(
4927
5528
  group="Volumes",
4928
5529
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
4929
5530
  export=False,
5531
+ missing_grad=True,
4930
5532
  )
4931
5533
 
4932
5534
  add_builtin(
@@ -4947,6 +5549,7 @@ add_builtin(
4947
5549
  doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.
4948
5550
 
4949
5551
  If the voxel at this index does not exist, this function returns the background value.""",
5552
+ missing_grad=True,
4950
5553
  )
4951
5554
 
4952
5555
  add_builtin(
@@ -4955,6 +5558,7 @@ add_builtin(
4955
5558
  group="Volumes",
4956
5559
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
4957
5560
  export=False,
5561
+ missing_grad=True,
4958
5562
  )
4959
5563
 
4960
5564
  add_builtin(
@@ -4973,6 +5577,7 @@ add_builtin(
4973
5577
  doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.
4974
5578
 
4975
5579
  If the voxel at this index does not exist, this function returns the background value.""",
5580
+ missing_grad=True,
4976
5581
  )
4977
5582
 
4978
5583
  add_builtin(
@@ -4981,6 +5586,7 @@ add_builtin(
4981
5586
  group="Volumes",
4982
5587
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
4983
5588
  export=False,
5589
+ missing_grad=True,
4984
5590
  )
4985
5591
 
4986
5592
 
@@ -5062,6 +5668,7 @@ add_builtin(
5062
5668
  If the voxel at this index does not exist, this function returns -1.
5063
5669
  This function is available for both index grids and classical volumes.
5064
5670
  """,
5671
+ missing_grad=True,
5065
5672
  )
5066
5673
 
5067
5674
  add_builtin(
@@ -5103,6 +5710,7 @@ add_builtin(
5103
5710
  value_type=uint32,
5104
5711
  group="Random",
5105
5712
  doc="Initialize a new random number generator given a user-defined seed. Returns a 32-bit integer representing the RNG state.",
5713
+ missing_grad=True,
5106
5714
  )
5107
5715
 
5108
5716
  add_builtin(
@@ -5114,6 +5722,7 @@ add_builtin(
5114
5722
 
5115
5723
  This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
5116
5724
  but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
5725
+ missing_grad=True,
5117
5726
  )
5118
5727
 
5119
5728
  add_builtin(
@@ -5122,6 +5731,7 @@ add_builtin(
5122
5731
  value_type=int,
5123
5732
  group="Random",
5124
5733
  doc="Return a random integer in the range [-2^31, 2^31).",
5734
+ missing_grad=True,
5125
5735
  )
5126
5736
  add_builtin(
5127
5737
  "randi",
@@ -5129,6 +5739,7 @@ add_builtin(
5129
5739
  value_type=int,
5130
5740
  group="Random",
5131
5741
  doc="Return a random integer between [low, high).",
5742
+ missing_grad=True,
5132
5743
  )
5133
5744
  add_builtin(
5134
5745
  "randu",
@@ -5136,6 +5747,7 @@ add_builtin(
5136
5747
  value_type=uint32,
5137
5748
  group="Random",
5138
5749
  doc="Return a random unsigned integer in the range [0, 2^32).",
5750
+ missing_grad=True,
5139
5751
  )
5140
5752
  add_builtin(
5141
5753
  "randu",
@@ -5143,6 +5755,7 @@ add_builtin(
5143
5755
  value_type=uint32,
5144
5756
  group="Random",
5145
5757
  doc="Return a random unsigned integer between [low, high).",
5758
+ missing_grad=True,
5146
5759
  )
5147
5760
  add_builtin(
5148
5761
  "randf",
@@ -5150,6 +5763,7 @@ add_builtin(
5150
5763
  value_type=float,
5151
5764
  group="Random",
5152
5765
  doc="Return a random float between [0.0, 1.0).",
5766
+ missing_grad=True,
5153
5767
  )
5154
5768
  add_builtin(
5155
5769
  "randf",
@@ -5157,6 +5771,7 @@ add_builtin(
5157
5771
  value_type=float,
5158
5772
  group="Random",
5159
5773
  doc="Return a random float between [low, high).",
5774
+ missing_grad=True,
5160
5775
  )
5161
5776
  add_builtin(
5162
5777
  "randn",
@@ -5164,6 +5779,7 @@ add_builtin(
5164
5779
  value_type=float,
5165
5780
  group="Random",
5166
5781
  doc="Sample a normal (Gaussian) distribution of mean 0 and variance 1. ",
5782
+ missing_grad=True,
5167
5783
  )
5168
5784
 
5169
5785
  add_builtin(
@@ -5172,6 +5788,7 @@ add_builtin(
5172
5788
  value_type=int,
5173
5789
  group="Random",
5174
5790
  doc="Inverse-transform sample a cumulative distribution function.",
5791
+ missing_grad=True,
5175
5792
  )
5176
5793
  add_builtin(
5177
5794
  "sample_triangle",
@@ -5179,6 +5796,7 @@ add_builtin(
5179
5796
  value_type=vec2,
5180
5797
  group="Random",
5181
5798
  doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
5799
+ missing_grad=True,
5182
5800
  )
5183
5801
  add_builtin(
5184
5802
  "sample_unit_ring",
@@ -5186,6 +5804,7 @@ add_builtin(
5186
5804
  value_type=vec2,
5187
5805
  group="Random",
5188
5806
  doc="Uniformly sample a ring in the xy plane.",
5807
+ missing_grad=True,
5189
5808
  )
5190
5809
  add_builtin(
5191
5810
  "sample_unit_disk",
@@ -5193,6 +5812,7 @@ add_builtin(
5193
5812
  value_type=vec2,
5194
5813
  group="Random",
5195
5814
  doc="Uniformly sample a disk in the xy plane.",
5815
+ missing_grad=True,
5196
5816
  )
5197
5817
  add_builtin(
5198
5818
  "sample_unit_sphere_surface",
@@ -5200,6 +5820,7 @@ add_builtin(
5200
5820
  value_type=vec3,
5201
5821
  group="Random",
5202
5822
  doc="Uniformly sample a unit sphere surface.",
5823
+ missing_grad=True,
5203
5824
  )
5204
5825
  add_builtin(
5205
5826
  "sample_unit_sphere",
@@ -5207,6 +5828,7 @@ add_builtin(
5207
5828
  value_type=vec3,
5208
5829
  group="Random",
5209
5830
  doc="Uniformly sample a unit sphere.",
5831
+ missing_grad=True,
5210
5832
  )
5211
5833
  add_builtin(
5212
5834
  "sample_unit_hemisphere_surface",
@@ -5214,6 +5836,7 @@ add_builtin(
5214
5836
  value_type=vec3,
5215
5837
  group="Random",
5216
5838
  doc="Uniformly sample a unit hemisphere surface.",
5839
+ missing_grad=True,
5217
5840
  )
5218
5841
  add_builtin(
5219
5842
  "sample_unit_hemisphere",
@@ -5221,6 +5844,7 @@ add_builtin(
5221
5844
  value_type=vec3,
5222
5845
  group="Random",
5223
5846
  doc="Uniformly sample a unit hemisphere.",
5847
+ missing_grad=True,
5224
5848
  )
5225
5849
  add_builtin(
5226
5850
  "sample_unit_square",
@@ -5228,6 +5852,7 @@ add_builtin(
5228
5852
  value_type=vec2,
5229
5853
  group="Random",
5230
5854
  doc="Uniformly sample a unit square.",
5855
+ missing_grad=True,
5231
5856
  )
5232
5857
  add_builtin(
5233
5858
  "sample_unit_cube",
@@ -5235,6 +5860,7 @@ add_builtin(
5235
5860
  value_type=vec3,
5236
5861
  group="Random",
5237
5862
  doc="Uniformly sample a unit cube.",
5863
+ missing_grad=True,
5238
5864
  )
5239
5865
 
5240
5866
  add_builtin(
@@ -5246,6 +5872,7 @@ add_builtin(
5246
5872
 
5247
5873
  :param state: RNG state
5248
5874
  :param lam: The expected value of the distribution""",
5875
+ missing_grad=True,
5249
5876
  )
5250
5877
 
5251
5878
  add_builtin(
@@ -5363,9 +5990,16 @@ add_builtin(
5363
5990
  dispatch_func=printf_dispatch_func,
5364
5991
  group="Utility",
5365
5992
  doc="Allows printing formatted strings using C-style format specifiers.",
5993
+ missing_grad=True,
5366
5994
  )
5367
5995
 
5368
- add_builtin("print", input_types={"value": Any}, doc="Print variable to stdout", export=False, group="Utility")
5996
+ add_builtin(
5997
+ "print",
5998
+ input_types={"value": Any},
5999
+ doc="Print variable to stdout",
6000
+ export=False,
6001
+ group="Utility",
6002
+ )
5369
6003
 
5370
6004
  add_builtin(
5371
6005
  "breakpoint",
@@ -5375,6 +6009,7 @@ add_builtin(
5375
6009
  group="Utility",
5376
6010
  namespace="",
5377
6011
  native_func="__debugbreak",
6012
+ missing_grad=True,
5378
6013
  )
5379
6014
 
5380
6015
  # helpers
@@ -5392,6 +6027,7 @@ add_builtin(
5392
6027
  This function may not be called from user-defined Warp functions.""",
5393
6028
  namespace="",
5394
6029
  native_func="builtin_tid1d",
6030
+ missing_grad=True,
5395
6031
  )
5396
6032
 
5397
6033
  add_builtin(
@@ -5402,6 +6038,7 @@ add_builtin(
5402
6038
  doc="Returns the number of threads in the current block.",
5403
6039
  namespace="",
5404
6040
  native_func="builtin_block_dim",
6041
+ missing_grad=True,
5405
6042
  )
5406
6043
 
5407
6044
  add_builtin(
@@ -5416,6 +6053,7 @@ add_builtin(
5416
6053
  This function may not be called from user-defined Warp functions.""",
5417
6054
  namespace="",
5418
6055
  native_func="builtin_tid2d",
6056
+ missing_grad=True,
5419
6057
  )
5420
6058
 
5421
6059
  add_builtin(
@@ -5430,6 +6068,7 @@ add_builtin(
5430
6068
  This function may not be called from user-defined Warp functions.""",
5431
6069
  namespace="",
5432
6070
  native_func="builtin_tid3d",
6071
+ missing_grad=True,
5433
6072
  )
5434
6073
 
5435
6074
  add_builtin(
@@ -5444,17 +6083,37 @@ add_builtin(
5444
6083
  This function may not be called from user-defined Warp functions.""",
5445
6084
  namespace="",
5446
6085
  native_func="builtin_tid4d",
6086
+ missing_grad=True,
5447
6087
  )
5448
6088
 
5449
6089
 
6090
def copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Compute the return type of ``copy()``.

    :param arg_types: Mapping from argument name to type; reads ``a``
    :param arg_values: Not read by this function
    :returns: The type of ``a`` unchanged, unless ``a`` is a tile with ``"shared"`` storage,
        in which case a new tile type with the same dtype, shape, storage, strides and
        layout and ``owner=True`` is returned
    """
    source = arg_types["a"]

    # anything that is not a shared tile keeps its type as-is
    if not is_tile(source) or source.storage != "shared":
        return source

    # if the input is a shared tile, we force a copy
    return tile(
        dtype=source.dtype,
        shape=source.shape,
        storage=source.storage,
        strides=source.strides,
        layout=source.layout,
        owner=True,
    )
6105
+
6106
+
5450
6107
  add_builtin(
5451
6108
  "copy",
5452
6109
  input_types={"a": Any},
5453
- value_func=lambda arg_types, arg_values: arg_types["a"],
6110
+ value_func=copy_value_func,
5454
6111
  hidden=True,
5455
6112
  export=False,
5456
6113
  group="Utility",
5457
6114
  )
6115
+
6116
+
5458
6117
  add_builtin(
5459
6118
  "assign",
5460
6119
  input_types={"dest": Any, "src": Any},
@@ -5464,6 +6123,37 @@ add_builtin(
5464
6123
  )
5465
6124
 
5466
6125
 
6126
def select_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Compute the return type of ``select()``.

    :param arg_types: Mapping from argument name to type (reads ``value_if_true`` and
        ``value_if_false``), or ``None`` to request the generic signature
    :param arg_values: Not read by this function
    :returns: ``Any`` when ``arg_types`` is ``None``. For non-tile values, the type of
        ``value_if_true``. For tiles, the first of ``value_if_true`` / ``value_if_false``
        whose storage is ``"register"``; otherwise a new tile type built from
        ``value_if_true`` with ``owner=True``
    :raises RuntimeError: If the two value types are not equal
    """
    if arg_types is None:
        return Any

    on_true = arg_types["value_if_true"]
    on_false = arg_types["value_if_false"]

    if not types_equal(on_true, on_false):
        raise RuntimeError(
            f"select() true value type ({on_true}) must be of the same type as the false type ({on_false})"
        )

    if not is_tile(on_false):
        return on_true

    # prefer whichever operand is a register tile, checking the true value first
    for candidate in (on_true, on_false):
        if candidate.storage == "register":
            return candidate

    # both operands are shared tiles
    return tile(
        dtype=on_true.dtype,
        shape=on_true.shape,
        storage=on_true.storage,
        strides=on_true.strides,
        layout=on_true.layout,
        owner=True,
    )
6155
+
6156
+
5467
6157
  def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
5468
6158
  warp.utils.warn(
5469
6159
  "wp.select() is deprecated and will be removed in a future\n"
@@ -5480,7 +6170,7 @@ def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args
5480
6170
  add_builtin(
5481
6171
  "select",
5482
6172
  input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
5483
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6173
+ value_func=select_value_func,
5484
6174
  dispatch_func=select_dispatch_func,
5485
6175
  doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
5486
6176
 
@@ -5493,7 +6183,7 @@ for t in int_types:
5493
6183
  add_builtin(
5494
6184
  "select",
5495
6185
  input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
5496
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6186
+ value_func=select_value_func,
5497
6187
  dispatch_func=select_dispatch_func,
5498
6188
  doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
5499
6189
 
@@ -5505,7 +6195,7 @@ for t in int_types:
5505
6195
  add_builtin(
5506
6196
  "select",
5507
6197
  input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
5508
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6198
+ value_func=select_value_func,
5509
6199
  dispatch_func=select_dispatch_func,
5510
6200
  doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
5511
6201
 
@@ -5515,10 +6205,40 @@ add_builtin(
5515
6205
  group="Utility",
5516
6206
  )
5517
6207
 
6208
+
6209
def where_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Resolve the return type of the ``where`` built-in.

    Returns ``Any`` when no argument types are supplied. The true and false
    value types must compare equal under ``types_equal``, otherwise a
    ``RuntimeError`` is raised. Non-tile inputs yield the true value type.
    For tiles, a register-storage operand is returned directly (true operand
    first); if both are non-register, a new tile type copying the true
    operand's dtype, shape, storage, strides and layout is built with
    ``owner=True``.
    """
    if arg_types is None:
        return Any

    v_true = arg_types["value_if_true"]
    v_false = arg_types["value_if_false"]

    if not types_equal(v_true, v_false):
        raise RuntimeError(f"where() true value type ({v_true}) must be of the same type as the false type ({v_false})")

    if is_tile(v_false):
        # lazily pick the first operand that lives in registers, if any
        in_registers = next((t for t in (v_true, v_false) if t.storage == "register"), None)
        if in_registers is not None:
            return in_registers

        # both v_true and v_false are shared
        return tile(
            dtype=v_true.dtype,
            shape=v_true.shape,
            storage=v_true.storage,
            strides=v_true.strides,
            layout=v_true.layout,
            owner=True,
        )

    return v_true
6236
+
6237
+
5518
6238
  add_builtin(
5519
6239
  "where",
5520
6240
  input_types={"cond": builtins.bool, "value_if_true": Any, "value_if_false": Any},
5521
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6241
+ value_func=where_value_func,
5522
6242
  doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
5523
6243
  group="Utility",
5524
6244
  )
@@ -5526,14 +6246,14 @@ for t in int_types:
5526
6246
  add_builtin(
5527
6247
  "where",
5528
6248
  input_types={"cond": t, "value_if_true": Any, "value_if_false": Any},
5529
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6249
+ value_func=where_value_func,
5530
6250
  doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
5531
6251
  group="Utility",
5532
6252
  )
5533
6253
  add_builtin(
5534
6254
  "where",
5535
6255
  input_types={"arr": array(dtype=Any), "value_if_true": Any, "value_if_false": Any},
5536
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6256
+ value_func=where_value_func,
5537
6257
  doc="Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``.",
5538
6258
  group="Utility",
5539
6259
  )
@@ -5544,7 +6264,7 @@ def array_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any
5544
6264
  return array(dtype=Scalar)
5545
6265
 
5546
6266
  dtype = arg_values["dtype"]
5547
- shape = extract_tuple(arg_values["shape"], as_constant=True)
6267
+ shape = extract_tuple(arg_values["shape"], as_constant=False)
5548
6268
  return array(dtype=dtype, ndim=len(shape))
5549
6269
 
5550
6270
 
@@ -5554,7 +6274,7 @@ def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args:
5554
6274
  # to the underlying C++ function's runtime and template params.
5555
6275
 
5556
6276
  dtype = return_type.dtype
5557
- shape = extract_tuple(args["shape"], as_constant=True)
6277
+ shape = extract_tuple(args["shape"], as_constant=False)
5558
6278
 
5559
6279
  func_args = (args["ptr"], *shape)
5560
6280
  template_args = (dtype,)
@@ -5563,7 +6283,7 @@ def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args:
5563
6283
 
5564
6284
  add_builtin(
5565
6285
  "array",
5566
- input_types={"ptr": warp.uint64, "shape": Tuple[int, ...], "dtype": Scalar},
6286
+ input_types={"ptr": warp.uint64, "shape": Tuple[int, ...], "dtype": Any},
5567
6287
  value_func=array_value_func,
5568
6288
  export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
5569
6289
  dispatch_func=array_dispatch_func,
@@ -5575,6 +6295,48 @@ add_builtin(
5575
6295
  )
5576
6296
 
5577
6297
 
6298
def zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Resolve the return type of the kernel-scope ``zeros`` built-in.

    Returns a generic ``fixedarray(dtype=Scalar)`` when no argument types are
    supplied. Otherwise returns a ``fixedarray`` type with the requested dtype
    and shape. Raises ``RuntimeError`` when the extracted shape contains
    ``None``, which the error message attributes to a non-constant ``shape``.
    """
    if arg_types is None:
        return fixedarray(dtype=Scalar)

    elem_type = arg_values["dtype"]
    dims = extract_tuple(arg_values["shape"], as_constant=True)

    if None not in dims:
        return fixedarray(dtype=elem_type, shape=dims)

    raise RuntimeError("the `shape` argument must be specified as a constant when zero-initializing an array")
6309
+
6310
+
6311
def zeros_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Map a ``zeros`` call onto its native runtime and template arguments.

    Runs at the codegen stage, when the code calling the built-in is emitted.
    The runtime arguments are the shape entries; the template arguments are
    the total element count (product of the shape) followed by the dtype.
    """
    elem_type = return_type.dtype
    dims = extract_tuple(args["shape"], as_constant=True)

    # total number of elements across all dimensions
    count = math.prod(dims)

    return (dims, (count, elem_type))
6324
+
6325
+
6326
# Registers the ``zeros`` built-in: type resolution is done by ``zeros_value_func``
# and argument mapping by ``zeros_dispatch_func``.
add_builtin(
    "zeros",
    input_types={"shape": Tuple[int, ...], "dtype": Any},
    value_func=zeros_value_func,
    dispatch_func=zeros_dispatch_func,
    # the export mapping for this built-in is always empty
    export_func=lambda input_types: {},
    native_func="fixedarray_t",
    missing_grad=True,
    export=False,
    group="Utility",
    hidden=True,  # Unhide once we can document both a built-in and a Python scope function sharing the same name.
)
6338
+
6339
+
5578
6340
  # does argument checking and type propagation for address()
5579
6341
  def address_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
5580
6342
  arr_type = arg_types["arr"]
@@ -5751,6 +6513,7 @@ add_builtin(
5751
6513
  hidden=True,
5752
6514
  skip_replay=True,
5753
6515
  group="Utility",
6516
+ missing_grad=True,
5754
6517
  )
5755
6518
 
5756
6519
 
@@ -5767,6 +6530,7 @@ add_builtin(
5767
6530
  dispatch_func=load_dispatch_func,
5768
6531
  hidden=True,
5769
6532
  group="Utility",
6533
+ missing_grad=True,
5770
6534
  )
5771
6535
 
5772
6536
 
@@ -5864,8 +6628,8 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
5864
6628
 
5865
6629
 
5866
6630
  for array_type in array_types:
5867
- # don't list indexed array operations explicitly in docs
5868
- hidden = array_type == indexedarray
6631
+ # don't list fixed or indexed array operations explicitly in docs
6632
+ hidden = array_type in (indexedarray, fixedarray)
5869
6633
 
5870
6634
  add_builtin(
5871
6635
  "atomic_add",
@@ -6083,6 +6847,7 @@ for array_type in array_types:
6083
6847
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6084
6848
  group="Utility",
6085
6849
  skip_replay=True,
6850
+ missing_grad=True,
6086
6851
  )
6087
6852
  add_builtin(
6088
6853
  "atomic_cas",
@@ -6096,6 +6861,7 @@ for array_type in array_types:
6096
6861
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6097
6862
  group="Utility",
6098
6863
  skip_replay=True,
6864
+ missing_grad=True,
6099
6865
  )
6100
6866
  add_builtin(
6101
6867
  "atomic_cas",
@@ -6109,6 +6875,7 @@ for array_type in array_types:
6109
6875
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6110
6876
  group="Utility",
6111
6877
  skip_replay=True,
6878
+ missing_grad=True,
6112
6879
  )
6113
6880
  add_builtin(
6114
6881
  "atomic_cas",
@@ -6130,6 +6897,7 @@ for array_type in array_types:
6130
6897
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6131
6898
  group="Utility",
6132
6899
  skip_replay=True,
6900
+ missing_grad=True,
6133
6901
  )
6134
6902
 
6135
6903
  add_builtin(
@@ -6144,6 +6912,7 @@ for array_type in array_types:
6144
6912
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6145
6913
  group="Utility",
6146
6914
  skip_replay=True,
6915
+ missing_grad=True,
6147
6916
  )
6148
6917
  add_builtin(
6149
6918
  "atomic_exch",
@@ -6157,6 +6926,7 @@ for array_type in array_types:
6157
6926
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6158
6927
  group="Utility",
6159
6928
  skip_replay=True,
6929
+ missing_grad=True,
6160
6930
  )
6161
6931
  add_builtin(
6162
6932
  "atomic_exch",
@@ -6170,6 +6940,7 @@ for array_type in array_types:
6170
6940
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6171
6941
  group="Utility",
6172
6942
  skip_replay=True,
6943
+ missing_grad=True,
6173
6944
  )
6174
6945
  add_builtin(
6175
6946
  "atomic_exch",
@@ -6187,46 +6958,110 @@ for array_type in array_types:
6187
6958
 
6188
6959
 
6189
6960
  # used to index into builtin types, i.e.: y = vec3[1]
6190
- def extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
6191
- return arg_types["a"]._wp_scalar_type_
6961
def vector_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Resolve the return type of ``extract`` on a vector-like value.

    A plain index yields the container's scalar type. A ``slice_t`` index
    yields a vector type whose length is the slice length computed against
    the container's ``_length_``.
    """
    container = arg_types["a"]
    index = arg_types["i"]

    if not isinstance(index, slice_t):
        return container._wp_scalar_type_

    sliced_length = index.get_length(container._length_)
    return vector(length=sliced_length, dtype=container._wp_scalar_type_)
6970
+
6971
+
6972
def vector_extract_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Build the native call arguments for ``extract`` on a vector-like value.

    All call arguments are forwarded in order. The template arguments are the
    return type's ``_shape_`` when it has one, and empty otherwise.
    """
    return (tuple(args.values()), getattr(return_type, "_shape_", ()))
6192
6976
 
6193
6977
 
6194
6978
  add_builtin(
6195
6979
  "extract",
6196
- input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
6197
- value_func=extract_value_func,
6980
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any},
6981
+ value_func=vector_extract_value_func,
6982
+ dispatch_func=vector_extract_dispatch_func,
6983
+ export=False,
6198
6984
  hidden=True,
6199
6985
  group="Utility",
6200
6986
  )
6201
6987
  add_builtin(
6202
6988
  "extract",
6203
- input_types={"a": quaternion(dtype=Scalar), "i": int},
6204
- value_func=extract_value_func,
6989
+ input_types={"a": quaternion(dtype=Scalar), "i": Any},
6990
+ value_func=vector_extract_value_func,
6991
+ dispatch_func=vector_extract_dispatch_func,
6992
+ export=False,
6205
6993
  hidden=True,
6206
6994
  group="Utility",
6207
6995
  )
6208
-
6209
6996
  add_builtin(
6210
6997
  "extract",
6211
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
6212
- value_func=lambda arg_types, arg_values: vector(
6213
- length=arg_types["a"]._shape_[1], dtype=arg_types["a"]._wp_scalar_type_
6214
- ),
6998
+ input_types={"a": transformation(dtype=Scalar), "i": Any},
6999
+ value_func=vector_extract_value_func,
7000
+ dispatch_func=vector_extract_dispatch_func,
7001
+ export=False,
6215
7002
  hidden=True,
6216
7003
  group="Utility",
6217
7004
  )
7005
+
7006
+
7007
def matrix_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Resolve the return type of ``extract`` on a matrix.

    Each supplied index (``i``, then ``j``) is either a ``slice_t``, which
    keeps that axis with the sliced length, or a plain index, which drops it.
    Axes without an index are kept whole. The result is the scalar type when
    no axis remains, a vector for one remaining axis, and a matrix for two.
    """
    source = arg_types["a"]
    indices = tuple(arg_types[key] for key in "ij" if arg_types.get(key, None) is not None)

    # Compute the resulting shape from the slicing, with -1 being simple indexing.
    dims = []
    for axis, index in enumerate(indices):
        if isinstance(index, slice_t):
            dims.append(index.get_length(source._shape_[axis]))
        else:
            dims.append(-1)

    # Append any non indexed slice.
    for axis in range(len(indices), len(source._shape_)):
        dims.append(source._shape_[axis])

    dims = tuple(dims)

    # Count how many dimensions the output value will have.
    kept = sum(1 for extent in dims if extent >= 0)

    if kept == 0:
        return source._wp_scalar_type_

    assert dims[0] != -1 or dims[1] != -1

    if kept == 1:
        length = dims[1] if dims[0] == -1 else dims[0]
        return vector(length=length, dtype=source._wp_scalar_type_)

    assert kept == 2

    # When a matrix dimension is 0, all other dimensions are also expected to be 0.
    if any(extent == 0 for extent in dims):
        dims = (0,) * len(dims)

    return matrix(shape=dims, dtype=source._wp_scalar_type_)
7039
+
7040
+
7041
def matrix_extract_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Build the native call arguments for ``extract`` on a matrix.

    All call arguments are forwarded in order. The template arguments are
    empty unless at least one of the ``i``/``j`` indices is a ``slice_t``,
    in which case they are the return type's ``_shape_`` (or empty if the
    return type has none).
    """
    index_types = tuple(args[key].type for key in "ij" if args.get(key, None) is not None)
    sliced = any(isinstance(index_type, slice_t) for index_type in index_types)

    forwarded = tuple(args.values())
    if not sliced:
        return (forwarded, ())

    return (forwarded, getattr(return_type, "_shape_", ()))
7048
+
7049
+
6218
7050
  add_builtin(
6219
7051
  "extract",
6220
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
6221
- value_func=extract_value_func,
7052
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any},
7053
+ value_func=matrix_extract_value_func,
7054
+ dispatch_func=matrix_extract_dispatch_func,
7055
+ export=False,
6222
7056
  hidden=True,
6223
7057
  group="Utility",
6224
7058
  )
6225
-
6226
7059
  add_builtin(
6227
7060
  "extract",
6228
- input_types={"a": transformation(dtype=Scalar), "i": int},
6229
- value_func=extract_value_func,
7061
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any},
7062
+ value_func=matrix_extract_value_func,
7063
+ dispatch_func=matrix_extract_dispatch_func,
7064
+ export=False,
6230
7065
  hidden=True,
6231
7066
  group="Utility",
6232
7067
  )
@@ -6247,6 +7082,19 @@ def vector_index_dispatch_func(input_types: Mapping[str, type], return_type: Any
6247
7082
  return (func_args, template_args)
6248
7083
 
6249
7084
 
7085
def matrix_ij_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Return a ``Reference`` to the scalar type of the matrix argument ``a``."""
    return Reference(arg_types["a"]._wp_scalar_type_)
7090
+
7091
+
7092
def matrix_ij_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Build the native call arguments for two-index matrix element access.

    The matrix argument is wrapped in a ``Reference``; the ``i`` and ``j``
    arguments are forwarded as given. No template arguments are produced.
    """
    matrix_ref = Reference(args["a"])
    return ((matrix_ref, args["i"], args["j"]), ())
7096
+
7097
+
6250
7098
  # implements &vector[index]
6251
7099
  add_builtin(
6252
7100
  "index",
@@ -6256,6 +7104,7 @@ add_builtin(
6256
7104
  hidden=True,
6257
7105
  group="Utility",
6258
7106
  skip_replay=True,
7107
+ missing_grad=True,
6259
7108
  )
6260
7109
  # implements &quaternion[index]
6261
7110
  add_builtin(
@@ -6266,6 +7115,7 @@ add_builtin(
6266
7115
  hidden=True,
6267
7116
  group="Utility",
6268
7117
  skip_replay=True,
7118
+ missing_grad=True,
6269
7119
  )
6270
7120
  # implements &transformation[index]
6271
7121
  add_builtin(
@@ -6276,6 +7126,7 @@ add_builtin(
6276
7126
  hidden=True,
6277
7127
  group="Utility",
6278
7128
  skip_replay=True,
7129
+ missing_grad=True,
6279
7130
  )
6280
7131
  # implements &(*vector)[index]
6281
7132
  add_builtin(
@@ -6286,6 +7137,18 @@ add_builtin(
6286
7137
  hidden=True,
6287
7138
  group="Utility",
6288
7139
  skip_replay=True,
7140
+ missing_grad=True,
7141
+ )
7142
# implements &(*matrix)[i, j]
# The return type comes from ``matrix_ij_value_func`` and the call arguments
# from ``matrix_ij_dispatch_func``.
add_builtin(
    "indexref",
    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
    value_func=matrix_ij_value_func,
    dispatch_func=matrix_ij_dispatch_func,
    group="Utility",
    hidden=True,
    skip_replay=True,
    missing_grad=True,
)
6290
7153
  # implements &(*quaternion)[index]
6291
7154
  add_builtin(
@@ -6296,6 +7159,7 @@ add_builtin(
6296
7159
  hidden=True,
6297
7160
  group="Utility",
6298
7161
  skip_replay=True,
7162
+ missing_grad=True,
6299
7163
  )
6300
7164
  # implements &(*transformation)[index]
6301
7165
  add_builtin(
@@ -6306,14 +7170,50 @@ add_builtin(
6306
7170
  hidden=True,
6307
7171
  group="Utility",
6308
7172
  skip_replay=True,
7173
+ missing_grad=True,
6309
7174
  )
6310
7175
 
6311
7176
 
7177
def vector_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Validate an indexed or sliced assignment on a vector-like value and build its native call arguments.

    Args:
        input_types: Not used by this function.
        return_type: Not used by this function.
        args: Call arguments. ``a`` is the destination, ``i`` the index or
            slice, and ``value`` the value being assigned.

    Returns:
        A ``(func_args, template_args)`` pair. ``func_args`` holds every entry
        of ``args`` in order. ``template_args`` is ``(length,)`` for a slice
        index, where ``length`` is the slice length, and ``()`` for a plain index.

    Raises:
        ValueError: For a slice index, if the value is not a vector, its
            scalar type differs from the destination's, or its length differs
            from the slice length. For a plain index, if the value's type
            differs from the destination's scalar type.
    """
    vec = args["a"].type
    idx = args["i"].type
    # Validate against the underlying type in case the value is wrapped in a reference.
    value_type = strip_reference(args["value"].type)

    if isinstance(idx, slice_t):
        length = idx.get_length(vec._length_)

        if not type_is_vector(value_type):
            # Disallow broadcasting.
            raise ValueError(
                f"The provided value is expected to be a vector of length {length}, with dtype {type_repr(vec._wp_scalar_type_)}."
            )

        if not types_equal(value_type._wp_scalar_type_, vec._wp_scalar_type_):
            raise ValueError(
                f"The provided vector is expected to be of length {length} with dtype {type_repr(vec._wp_scalar_type_)}."
            )

        if value_type._length_ != length:
            # Report the length from the stripped type: the raw argument type
            # may be a reference wrapper that has no `_length_` attribute.
            raise ValueError(
                f"The length of the provided vector ({value_type._length_}) isn't compatible with the given slice (expected {length})."
            )

        template_args = (length,)
    else:
        if not types_equal(value_type, vec._wp_scalar_type_):
            raise ValueError(
                f"The provided value is expected to be a scalar of type {type_repr(vec._wp_scalar_type_)}."
            )
        template_args = ()

    func_args = tuple(args.values())
    return (func_args, template_args)
7209
+
7210
+
6312
7211
  # implements vector[index] = value
6313
7212
  add_builtin(
6314
7213
  "assign_inplace",
6315
- input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
7214
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
6316
7215
  value_type=None,
7216
+ dispatch_func=vector_assign_dispatch_func,
6317
7217
  hidden=True,
6318
7218
  export=False,
6319
7219
  group="Utility",
@@ -6322,8 +7222,9 @@ add_builtin(
6322
7222
  # implements quaternion[index] = value
6323
7223
  add_builtin(
6324
7224
  "assign_inplace",
6325
- input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
7225
+ input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
6326
7226
  value_type=None,
7227
+ dispatch_func=vector_assign_dispatch_func,
6327
7228
  hidden=True,
6328
7229
  export=False,
6329
7230
  group="Utility",
@@ -6331,15 +7232,16 @@ add_builtin(
6331
7232
  # implements transformation[index] = value
6332
7233
  add_builtin(
6333
7234
  "assign_inplace",
6334
- input_types={"a": transformation(dtype=Scalar), "i": int, "value": Scalar},
7235
+ input_types={"a": transformation(dtype=Scalar), "i": Any, "value": Any},
6335
7236
  value_type=None,
7237
+ dispatch_func=vector_assign_dispatch_func,
6336
7238
  hidden=True,
6337
7239
  export=False,
6338
7240
  group="Utility",
6339
7241
  )
6340
7242
 
6341
7243
 
6342
- def vector_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
7244
+ def vector_assign_copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
6343
7245
  vec_type = arg_types["a"]
6344
7246
  return vec_type
6345
7247
 
@@ -6347,8 +7249,9 @@ def vector_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[
6347
7249
  # implements vector[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
6348
7250
  add_builtin(
6349
7251
  "assign_copy",
6350
- input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
6351
- value_func=vector_assign_value_func,
7252
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
7253
+ value_func=vector_assign_copy_value_func,
7254
+ dispatch_func=vector_assign_dispatch_func,
6352
7255
  hidden=True,
6353
7256
  export=False,
6354
7257
  group="Utility",
@@ -6357,8 +7260,9 @@ add_builtin(
6357
7260
  # implements quaternion[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
6358
7261
  add_builtin(
6359
7262
  "assign_copy",
6360
- input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
6361
- value_func=vector_assign_value_func,
7263
+ input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
7264
+ value_func=vector_assign_copy_value_func,
7265
+ dispatch_func=vector_assign_dispatch_func,
6362
7266
  hidden=True,
6363
7267
  export=False,
6364
7268
  group="Utility",
@@ -6367,8 +7271,9 @@ add_builtin(
6367
7271
  # implements transformation[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
6368
7272
  add_builtin(
6369
7273
  "assign_copy",
6370
- input_types={"a": transformation(dtype=Scalar), "i": int, "value": Scalar},
6371
- value_func=vector_assign_value_func,
7274
+ input_types={"a": transformation(dtype=Scalar), "i": Any, "value": Any},
7275
+ value_func=vector_assign_copy_value_func,
7276
+ dispatch_func=vector_assign_dispatch_func,
6372
7277
  hidden=True,
6373
7278
  export=False,
6374
7279
  group="Utility",
@@ -6377,8 +7282,9 @@ add_builtin(
6377
7282
  # implements vector[idx] += scalar
6378
7283
  add_builtin(
6379
7284
  "add_inplace",
6380
- input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
7285
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
6381
7286
  value_type=None,
7287
+ dispatch_func=vector_assign_dispatch_func,
6382
7288
  hidden=True,
6383
7289
  export=False,
6384
7290
  group="Utility",
@@ -6387,8 +7293,9 @@ add_builtin(
6387
7293
  # implements quaternion[idx] += scalar
6388
7294
  add_builtin(
6389
7295
  "add_inplace",
6390
- input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
7296
+ input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
6391
7297
  value_type=None,
7298
+ dispatch_func=vector_assign_dispatch_func,
6392
7299
  hidden=True,
6393
7300
  export=False,
6394
7301
  group="Utility",
@@ -6397,8 +7304,9 @@ add_builtin(
6397
7304
  # implements transformation[idx] += scalar
6398
7305
  add_builtin(
6399
7306
  "add_inplace",
6400
- input_types={"a": transformation(dtype=Float), "i": int, "value": Float},
7307
+ input_types={"a": transformation(dtype=Float), "i": Any, "value": Any},
6401
7308
  value_type=None,
7309
+ dispatch_func=vector_assign_dispatch_func,
6402
7310
  hidden=True,
6403
7311
  export=False,
6404
7312
  group="Utility",
@@ -6417,8 +7325,9 @@ add_builtin(
6417
7325
  # implements vector[idx] -= scalar
6418
7326
  add_builtin(
6419
7327
  "sub_inplace",
6420
- input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
7328
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
6421
7329
  value_type=None,
7330
+ dispatch_func=vector_assign_dispatch_func,
6422
7331
  hidden=True,
6423
7332
  export=False,
6424
7333
  group="Utility",
@@ -6427,8 +7336,9 @@ add_builtin(
6427
7336
  # implements quaternion[idx] -= scalar
6428
7337
  add_builtin(
6429
7338
  "sub_inplace",
6430
- input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
7339
+ input_types={"a": quaternion(dtype=Scalar), "i": Any, "value": Any},
6431
7340
  value_type=None,
7341
+ dispatch_func=vector_assign_dispatch_func,
6432
7342
  hidden=True,
6433
7343
  export=False,
6434
7344
  group="Utility",
@@ -6437,8 +7347,9 @@ add_builtin(
6437
7347
  # implements transformation[idx] -= scalar
6438
7348
  add_builtin(
6439
7349
  "sub_inplace",
6440
- input_types={"a": transformation(dtype=Scalar), "i": int, "value": Scalar},
7350
+ input_types={"a": transformation(dtype=Float), "i": Any, "value": Any},
6441
7351
  value_type=None,
7352
+ dispatch_func=vector_assign_dispatch_func,
6442
7353
  hidden=True,
6443
7354
  export=False,
6444
7355
  group="Utility",
@@ -6470,6 +7381,7 @@ add_builtin(
6470
7381
  hidden=True,
6471
7382
  group="Utility",
6472
7383
  skip_replay=True,
7384
+ missing_grad=True,
6473
7385
  )
6474
7386
 
6475
7387
 
@@ -6488,6 +7400,7 @@ add_builtin(
6488
7400
  hidden=True,
6489
7401
  group="Utility",
6490
7402
  skip_replay=True,
7403
+ missing_grad=True,
6491
7404
  )
6492
7405
 
6493
7406
 
@@ -6499,61 +7412,154 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
6499
7412
  return mat_size == vec_size and mat_type == vec_type
6500
7413
 
6501
7414
 
6502
- # implements matrix[i,j] = scalar
7415
def matrix_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Validate an indexed or sliced assignment on a matrix and build its native call arguments.

    The supplied indices (``i`` and optionally ``j``) determine what kind of
    value is accepted:

    * at least one slice, one remaining axis: a vector of the sliced length;
    * at least one slice, two remaining axes: a matrix of the sliced shape;
    * a single plain index: a vector whose length is the matrix's second dimension;
    * two plain indices: a scalar of the matrix's scalar type.

    Returns a ``(func_args, template_args)`` pair. ``func_args`` holds every
    entry of ``args`` in order; ``template_args`` is the sliced length or
    shape when slicing and empty otherwise. Raises ``ValueError`` when the
    value does not match the expectation above.
    """
    mat = args["a"].type
    value_type = strip_reference(args["value"].type)

    index_types = tuple(args[key].type for key in "ij" if args.get(key, None) is not None)
    sliced = any(isinstance(index_type, slice_t) for index_type in index_types)

    if not sliced:
        if len(index_types) == 1:
            # matrix[i] = value
            if not type_is_vector(value_type) or not types_equal(value_type._wp_scalar_type_, mat._wp_scalar_type_):
                raise ValueError(
                    f"The provided value is expected to be a vector of length {mat._shape_[1]}, with dtype {type_repr(mat._wp_scalar_type_)}."
                )

            if value_type._length_ != mat._shape_[1]:
                raise ValueError(
                    f"The length of the provided vector ({value_type._length_}) isn't compatible with the given slice (expected {mat._shape_[1]})."
                )
        elif len(index_types) == 2:
            # matrix[i, j] = value
            if not types_equal(value_type, mat._wp_scalar_type_):
                raise ValueError(
                    f"The provided value is expected to be a scalar of type {type_repr(mat._wp_scalar_type_)}."
                )
        else:
            raise AssertionError

        return (tuple(args.values()), ())

    # Compute the resulting shape from the slicing, with -1 being simple indexing.
    shape = tuple(
        index_type.get_length(mat._shape_[axis]) if isinstance(index_type, slice_t) else -1
        for axis, index_type in enumerate(index_types)
    )

    # Append any non indexed slice.
    for axis in range(len(index_types), len(mat._shape_)):
        shape += (mat._shape_[axis],)

    # Count how many dimensions the output value will have.
    rank = sum(1 for extent in shape if extent >= 0)
    assert rank > 0

    if rank == 1:
        length = shape[0] if shape[0] != -1 else shape[1]

        if not type_is_vector(value_type):
            # Disallow broadcasting.
            raise ValueError(
                f"The provided value is expected to be a vector of length {length}, with dtype {type_repr(mat._wp_scalar_type_)}."
            )

        if not types_equal(value_type._wp_scalar_type_, mat._wp_scalar_type_):
            raise ValueError(
                f"The provided vector is expected to be of length {length} with dtype {type_repr(mat._wp_scalar_type_)}."
            )

        if value_type._length_ != length:
            raise ValueError(
                f"The length of the provided vector ({value_type._length_}) isn't compatible with the given slice (expected {length})."
            )

        return (tuple(args.values()), (length,))

    assert rank == 2

    # When a matrix dimension is 0, all other dimensions are also expected to be 0.
    if any(extent == 0 for extent in shape):
        shape = (0,) * len(shape)

    if not type_is_matrix(value_type):
        # Disallow broadcasting.
        raise ValueError(
            f"The provided value is expected to be a matrix of shape {shape}, with dtype {type_repr(mat._wp_scalar_type_)}."
        )

    if not types_equal(value_type._wp_scalar_type_, mat._wp_scalar_type_):
        raise ValueError(
            f"The provided matrix is expected to be of shape {shape} with dtype {type_repr(mat._wp_scalar_type_)}."
        )

    if value_type._shape_ != shape:
        raise ValueError(
            f"The shape of the provided matrix ({value_type._shape_}) isn't compatible with the given slice (expected {shape})."
        )

    return (tuple(args.values()), shape)
7502
+
7503
+
7504
+ # implements matrix[i] = value
6503
7505
  add_builtin(
6504
7506
  "assign_inplace",
6505
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
7507
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
7508
+ constraint=matrix_vector_sametype,
6506
7509
  value_type=None,
7510
+ dispatch_func=matrix_assign_dispatch_func,
6507
7511
  hidden=True,
6508
7512
  export=False,
6509
7513
  group="Utility",
6510
7514
  )
6511
7515
 
6512
7516
 
6513
- # implements matrix[i] = vector
7517
+ # implements matrix[i,j] = value
6514
7518
  add_builtin(
6515
7519
  "assign_inplace",
6516
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
6517
- constraint=matrix_vector_sametype,
7520
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
6518
7521
  value_type=None,
7522
+ dispatch_func=matrix_assign_dispatch_func,
6519
7523
  hidden=True,
6520
7524
  export=False,
6521
7525
  group="Utility",
6522
7526
  )
6523
7527
 
6524
7528
 
6525
- def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
7529
+ def matrix_assign_copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
6526
7530
  mat_type = arg_types["a"]
6527
7531
  return mat_type
6528
7532
 
6529
7533
 
6530
- # implements matrix[i,j] = scalar
7534
+ # implements matrix[i] = value
6531
7535
  add_builtin(
6532
7536
  "assign_copy",
6533
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
6534
- value_func=matrix_assign_value_func,
7537
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
7538
+ value_func=matrix_assign_copy_value_func,
7539
+ dispatch_func=matrix_assign_dispatch_func,
6535
7540
  hidden=True,
6536
7541
  export=False,
6537
7542
  group="Utility",
6538
7543
  )
6539
7544
 
6540
7545
 
6541
- # implements matrix[i] = vector
7546
+ # implements matrix[i,j] = value
6542
7547
  add_builtin(
6543
7548
  "assign_copy",
6544
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
6545
- constraint=matrix_vector_sametype,
6546
- value_func=matrix_assign_value_func,
7549
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
7550
+ value_func=matrix_assign_copy_value_func,
7551
+ dispatch_func=matrix_assign_dispatch_func,
6547
7552
  hidden=True,
6548
7553
  export=False,
6549
7554
  group="Utility",
6550
7555
  )
6551
7556
 
6552
7557
 
6553
- # implements matrix[i,j] += scalar
7558
+ # implements matrix[i] += value
6554
7559
  add_builtin(
6555
7560
  "add_inplace",
6556
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
7561
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
7562
+ constraint=matrix_vector_sametype,
6557
7563
  value_type=None,
6558
7564
  hidden=True,
6559
7565
  export=False,
@@ -6561,11 +7567,10 @@ add_builtin(
6561
7567
  )
6562
7568
 
6563
7569
 
6564
- # implements matrix[i] += vector
7570
+ # implements matrix[i,j] += value
6565
7571
  add_builtin(
6566
7572
  "add_inplace",
6567
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
6568
- constraint=matrix_vector_sametype,
7573
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
6569
7574
  value_type=None,
6570
7575
  hidden=True,
6571
7576
  export=False,
@@ -6573,10 +7578,10 @@ add_builtin(
6573
7578
  )
6574
7579
 
6575
7580
 
6576
- # implements matrix[i,j] -= scalar
7581
+ # implements matrix[i] -= value
6577
7582
  add_builtin(
6578
7583
  "sub_inplace",
6579
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
7584
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
6580
7585
  value_type=None,
6581
7586
  hidden=True,
6582
7587
  export=False,
@@ -6584,10 +7589,10 @@ add_builtin(
6584
7589
  )
6585
7590
 
6586
7591
 
6587
- # implements matrix[i] -= vector
7592
+ # implements matrix[i,j] -= value
6588
7593
  add_builtin(
6589
7594
  "sub_inplace",
6590
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
7595
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
6591
7596
  value_type=None,
6592
7597
  hidden=True,
6593
7598
  export=False,
@@ -6606,6 +7611,7 @@ for t in scalar_types + vector_types + (bool,):
6606
7611
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
6607
7612
  group="Utility",
6608
7613
  hidden=True,
7614
+ missing_grad=True,
6609
7615
  )
6610
7616
 
6611
7617
  add_builtin(
@@ -6616,6 +7622,7 @@ for t in scalar_types + vector_types + (bool,):
6616
7622
  group="Utility",
6617
7623
  hidden=True,
6618
7624
  export=False,
7625
+ missing_grad=True,
6619
7626
  )
6620
7627
 
6621
7628
 
@@ -6634,6 +7641,7 @@ add_builtin(
6634
7641
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
6635
7642
  group="Utility",
6636
7643
  hidden=True,
7644
+ missing_grad=True,
6637
7645
  )
6638
7646
  add_builtin(
6639
7647
  "expect_neq",
@@ -6644,6 +7652,7 @@ add_builtin(
6644
7652
  group="Utility",
6645
7653
  hidden=True,
6646
7654
  export=False,
7655
+ missing_grad=True,
6647
7656
  )
6648
7657
 
6649
7658
  add_builtin(
@@ -6654,6 +7663,7 @@ add_builtin(
6654
7663
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
6655
7664
  group="Utility",
6656
7665
  hidden=True,
7666
+ missing_grad=True,
6657
7667
  )
6658
7668
  add_builtin(
6659
7669
  "expect_neq",
@@ -6664,6 +7674,7 @@ add_builtin(
6664
7674
  group="Utility",
6665
7675
  hidden=True,
6666
7676
  export=False,
7677
+ missing_grad=True,
6667
7678
  )
6668
7679
 
6669
7680
  add_builtin(
@@ -6754,6 +7765,7 @@ add_builtin(
6754
7765
  value_type=None,
6755
7766
  doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
6756
7767
  group="Utility",
7768
+ missing_grad=True,
6757
7769
  )
6758
7770
  add_builtin(
6759
7771
  "expect_near",
@@ -6763,6 +7775,7 @@ add_builtin(
6763
7775
  value_type=None,
6764
7776
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
6765
7777
  group="Utility",
7778
+ missing_grad=True,
6766
7779
  )
6767
7780
  add_builtin(
6768
7781
  "expect_near",
@@ -6772,6 +7785,7 @@ add_builtin(
6772
7785
  value_type=None,
6773
7786
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
6774
7787
  group="Utility",
7788
+ missing_grad=True,
6775
7789
  )
6776
7790
  add_builtin(
6777
7791
  "expect_near",
@@ -6785,6 +7799,7 @@ add_builtin(
6785
7799
  value_type=None,
6786
7800
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
6787
7801
  group="Utility",
7802
+ missing_grad=True,
6788
7803
  )
6789
7804
 
6790
7805
  # ---------------------------------
@@ -6795,6 +7810,7 @@ add_builtin(
6795
7810
  input_types={"arr": array(dtype=Scalar), "value": Scalar},
6796
7811
  value_type=int,
6797
7812
  doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
7813
+ missing_grad=True,
6798
7814
  )
6799
7815
 
6800
7816
  add_builtin(
@@ -6802,11 +7818,13 @@ add_builtin(
6802
7818
  input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
6803
7819
  value_type=int,
6804
7820
  doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
7821
+ missing_grad=True,
6805
7822
  )
6806
7823
 
6807
7824
  # ---------------------------------
6808
7825
  # Operators
6809
7826
 
7827
+
6810
7828
  add_builtin(
6811
7829
  "add", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
6812
7830
  )
@@ -6876,13 +7894,36 @@ add_builtin(
6876
7894
  )
6877
7895
 
6878
7896
  # bitwise operators
6879
- add_builtin("bit_and", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
6880
- add_builtin("bit_or", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
6881
- add_builtin("bit_xor", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
6882
- add_builtin("lshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
6883
- add_builtin("rshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
6884
- add_builtin("invert", input_types={"a": Int}, value_func=sametypes_create_value_func(Int))
6885
-
7897
+ add_builtin(
7898
+ "bit_and",
7899
+ input_types={"a": Int, "b": Int},
7900
+ value_func=sametypes_create_value_func(Int),
7901
+ group="Operators",
7902
+ missing_grad=True,
7903
+ )
7904
+ add_builtin(
7905
+ "bit_or",
7906
+ input_types={"a": Int, "b": Int},
7907
+ value_func=sametypes_create_value_func(Int),
7908
+ group="Operators",
7909
+ missing_grad=True,
7910
+ )
7911
+ add_builtin(
7912
+ "bit_xor",
7913
+ input_types={"a": Int, "b": Int},
7914
+ value_func=sametypes_create_value_func(Int),
7915
+ group="Operators",
7916
+ missing_grad=True,
7917
+ )
7918
+ add_builtin("lshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int), group="Operators")
7919
+ add_builtin(
7920
+ "rshift",
7921
+ input_types={"a": Int, "b": Int},
7922
+ value_func=sametypes_create_value_func(Int),
7923
+ group="Operators",
7924
+ missing_grad=True,
7925
+ )
7926
+ add_builtin("invert", input_types={"a": Int}, value_func=sametypes_create_value_func(Int), group="Operators")
6886
7927
 
6887
7928
  add_builtin(
6888
7929
  "mul", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
@@ -7079,9 +8120,10 @@ add_builtin(
7079
8120
  "mod",
7080
8121
  input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
7081
8122
  constraint=sametypes,
7082
- value_func=sametypes_create_value_func(Scalar),
8123
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
7083
8124
  doc="Modulo operation using truncated division.",
7084
8125
  group="Operators",
8126
+ missing_grad=True,
7085
8127
  )
7086
8128
 
7087
8129
  add_builtin(
@@ -7141,6 +8183,7 @@ add_builtin(
7141
8183
  value_func=sametypes_create_value_func(Scalar),
7142
8184
  doc="",
7143
8185
  group="Operators",
8186
+ missing_grad=True,
7144
8187
  )
7145
8188
 
7146
8189
  add_builtin("pos", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
@@ -7188,12 +8231,16 @@ add_builtin(
7188
8231
  group="Operators",
7189
8232
  )
7190
8233
 
7191
- add_builtin("unot", input_types={"a": builtins.bool}, value_type=builtins.bool, doc="", group="Operators")
8234
+ add_builtin(
8235
+ "unot", input_types={"a": builtins.bool}, value_type=builtins.bool, doc="", group="Operators", missing_grad=True
8236
+ )
7192
8237
  for t in int_types:
7193
- add_builtin("unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators")
8238
+ add_builtin("unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators", missing_grad=True)
7194
8239
 
7195
8240
 
7196
- add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.bool, doc="", group="Operators")
8241
+ add_builtin(
8242
+ "unot", input_types={"a": array(dtype=Any)}, value_type=builtins.bool, doc="", group="Operators", missing_grad=True
8243
+ )
7197
8244
 
7198
8245
 
7199
8246
  # Tile operators
@@ -7387,6 +8434,7 @@ add_builtin(
7387
8434
  doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
7388
8435
  group="Tile Primitives",
7389
8436
  export=False,
8437
+ missing_grad=True,
7390
8438
  )
7391
8439
 
7392
8440
 
@@ -7481,7 +8529,7 @@ def tile_matmul_lto_dispatch_func(
7481
8529
  num_threads = options["block_dim"]
7482
8530
  arch = options["output_arch"]
7483
8531
 
7484
- if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
8532
+ if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
7485
8533
  # CPU/no-MathDx dispatch
7486
8534
  return ((0, 0, 0, a, b, out), template_args, [], 0)
7487
8535
  else:
@@ -7671,7 +8719,7 @@ def tile_fft_generic_lto_dispatch_func(
7671
8719
  arch = options["output_arch"]
7672
8720
  ept = size // num_threads
7673
8721
 
7674
- if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
8722
+ if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
7675
8723
  # CPU/no-MathDx dispatch
7676
8724
  return ([], [], [], 0)
7677
8725
  else:
@@ -7714,6 +8762,7 @@ add_builtin(
7714
8762
  group="Tile Primitives",
7715
8763
  export=False,
7716
8764
  namespace="",
8765
+ missing_grad=True,
7717
8766
  )
7718
8767
 
7719
8768
  add_builtin(
@@ -7735,6 +8784,7 @@ add_builtin(
7735
8784
  group="Tile Primitives",
7736
8785
  export=False,
7737
8786
  namespace="",
8787
+ missing_grad=True,
7738
8788
  )
7739
8789
 
7740
8790
 
@@ -7792,28 +8842,27 @@ def tile_cholesky_generic_lto_dispatch_func(
7792
8842
  raise TypeError("tile_cholesky() returns one output")
7793
8843
  out = return_values[0]
7794
8844
 
7795
- dtype, precision_enum = cusolver_type_map[a.type.dtype]
7796
-
7797
8845
  # We already ensured a is square in tile_cholesky_generic_value_func()
7798
8846
  M, N = a.type.shape
7799
8847
  if out.type.shape[0] != M or out.type.shape[1] != M:
7800
8848
  raise ValueError("tile_cholesky() output tile must be square")
7801
8849
 
7802
- solver = "potrf"
7803
- solver_enum = cusolver_function_map[solver]
7804
-
7805
- side_enum = cusolver_side_map["-"]
7806
- diag_enum = cusolver_diag_map["-"]
7807
- fill_mode = cusolver_fill_mode_map["lower"]
7808
-
7809
8850
  arch = options["output_arch"]
7810
- num_threads = options["block_dim"]
7811
- parameter_list = f"({dtype}*, int*)"
7812
8851
 
7813
- if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
8852
+ if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
7814
8853
  # CPU/no-MathDx dispatch
7815
8854
  return ((0, a, out), [], [], 0)
7816
8855
  else:
8856
+ solver = "potrf"
8857
+ solver_enum = cusolver_function_map[solver]
8858
+ side_enum = cusolver_side_map["-"]
8859
+ diag_enum = cusolver_diag_map["-"]
8860
+ fill_mode = cusolver_fill_mode_map["lower"]
8861
+ dtype, precision_enum = cusolver_type_map[a.type.dtype]
8862
+ num_threads = options["block_dim"]
8863
+ parameter_list = f"({dtype}*, int*)"
8864
+ req_smem_bytes = a.type.size * type_size_in_bytes(a.type.dtype)
8865
+
7817
8866
  # generate the LTO
7818
8867
  lto_symbol, lto_code_data = warp.build.build_lto_solver(
7819
8868
  M,
@@ -7831,6 +8880,7 @@ def tile_cholesky_generic_lto_dispatch_func(
7831
8880
  num_threads,
7832
8881
  parameter_list,
7833
8882
  builder,
8883
+ smem_estimate_bytes=req_smem_bytes,
7834
8884
  )
7835
8885
 
7836
8886
  return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
@@ -7859,6 +8909,7 @@ add_builtin(
7859
8909
  group="Tile Primitives",
7860
8910
  export=False,
7861
8911
  namespace="",
8912
+ missing_grad=True,
7862
8913
  )
7863
8914
 
7864
8915
 
@@ -7918,9 +8969,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
7918
8969
  if any(T not in cusolver_type_map.keys() for T in [y.type.dtype, L.type.dtype]):
7919
8970
  raise TypeError("tile_cholesky_solve() arguments be tiles of float64 or float32")
7920
8971
 
7921
- dtype, precision_enum = cusolver_type_map[L.type.dtype]
7922
8972
  M, N = L.type.shape
7923
- NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
7924
8973
 
7925
8974
  if len(x.type.shape) > 2 or len(x.type.shape) < 1:
7926
8975
  raise TypeError(f"tile_cholesky_solve() output vector must be 1D or 2D, got {len(x.type.shape)}-D")
@@ -7931,21 +8980,23 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
7931
8980
  f"got {x.type.shape[0]} elements in output and {M} rows in 'L'"
7932
8981
  )
7933
8982
 
7934
- solver = "potrs"
7935
- solver_enum = cusolver_function_map[solver]
7936
-
7937
- side_enum = cusolver_side_map["-"]
7938
- diag_enum = cusolver_diag_map["-"]
7939
- fill_mode = cusolver_fill_mode_map["lower"]
7940
-
7941
8983
  arch = options["output_arch"]
7942
- num_threads = options["block_dim"]
7943
- parameter_list = f"({dtype}*, {dtype}*)"
7944
8984
 
7945
- if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
8985
+ if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
7946
8986
  # CPU/no-MathDx dispatch
7947
8987
  return ((0, L, y, x), [], [], 0)
7948
8988
  else:
8989
+ NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
8990
+ solver = "potrs"
8991
+ solver_enum = cusolver_function_map[solver]
8992
+ side_enum = cusolver_side_map["-"]
8993
+ diag_enum = cusolver_diag_map["-"]
8994
+ fill_mode = cusolver_fill_mode_map["lower"]
8995
+ dtype, precision_enum = cusolver_type_map[L.type.dtype]
8996
+ num_threads = options["block_dim"]
8997
+ parameter_list = f"({dtype}*, {dtype}*)"
8998
+ req_smem_bytes = (x.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
8999
+
7949
9000
  # generate the LTO
7950
9001
  lto_symbol, lto_code_data = warp.build.build_lto_solver(
7951
9002
  M,
@@ -7963,6 +9014,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
7963
9014
  num_threads,
7964
9015
  parameter_list,
7965
9016
  builder,
9017
+ smem_estimate_bytes=req_smem_bytes,
7966
9018
  )
7967
9019
 
7968
9020
  return ((Var(lto_symbol, str, False, True, False), L, y, x), [], [lto_code_data], 0)
@@ -7988,6 +9040,7 @@ add_builtin(
7988
9040
  group="Tile Primitives",
7989
9041
  export=False,
7990
9042
  namespace="",
9043
+ missing_grad=True,
7991
9044
  )
7992
9045
 
7993
9046
 
@@ -8013,9 +9066,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
8013
9066
 
8014
9067
  z = return_values[0]
8015
9068
 
8016
- dtype, precision_enum = cusolver_type_map[L.type.dtype]
8017
9069
  M, N = L.type.shape
8018
- NRHS = z.type.shape[1] if len(z.type.shape) > 1 else 1
8019
9070
 
8020
9071
  if len(z.type.shape) > 2 or len(z.type.shape) < 1:
8021
9072
  raise TypeError(f"tile_lower_solve() output vector must be 1D or 2D, got {len(z.type.shape)}-D")
@@ -8026,21 +9077,23 @@ def tile_lower_solve_generic_lto_dispatch_func(
8026
9077
  f"got {z.type.shape[0]} elements in output and {M} rows in 'L'"
8027
9078
  )
8028
9079
 
8029
- solver = "trsm"
8030
- solver_enum = cusolver_function_map[solver]
8031
-
8032
- side_enum = cusolver_side_map["left"]
8033
- diag_enum = cusolver_diag_map["nounit"]
8034
- fill_mode = cusolver_fill_mode_map["lower"]
8035
-
8036
9080
  arch = options["output_arch"]
8037
- num_threads = options["block_dim"]
8038
- parameter_list = f"({dtype}*, {dtype}*)"
8039
9081
 
8040
- if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
9082
+ if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
8041
9083
  # CPU/no-MathDx dispatch
8042
9084
  return ((0, L, y, z), [], [], 0)
8043
9085
  else:
9086
+ NRHS = z.type.shape[1] if len(z.type.shape) > 1 else 1
9087
+ solver = "trsm"
9088
+ solver_enum = cusolver_function_map[solver]
9089
+ side_enum = cusolver_side_map["left"]
9090
+ diag_enum = cusolver_diag_map["nounit"]
9091
+ fill_mode = cusolver_fill_mode_map["lower"]
9092
+ dtype, precision_enum = cusolver_type_map[L.type.dtype]
9093
+ num_threads = options["block_dim"]
9094
+ parameter_list = f"({dtype}*, {dtype}*)"
9095
+ req_smem_bytes = (z.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
9096
+
8044
9097
  # generate the LTO
8045
9098
  lto_symbol, lto_code_data = warp.build.build_lto_solver(
8046
9099
  M,
@@ -8058,6 +9111,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
8058
9111
  num_threads,
8059
9112
  parameter_list,
8060
9113
  builder,
9114
+ smem_estimate_bytes=req_smem_bytes,
8061
9115
  )
8062
9116
 
8063
9117
  return ((Var(lto_symbol, str, False, True, False), L, y, z), [], [lto_code_data], 0)
@@ -8119,6 +9173,7 @@ add_builtin(
8119
9173
  group="Tile Primitives",
8120
9174
  export=False,
8121
9175
  namespace="",
9176
+ missing_grad=True,
8122
9177
  )
8123
9178
 
8124
9179
 
@@ -8144,9 +9199,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
8144
9199
 
8145
9200
  x = return_values[0]
8146
9201
 
8147
- dtype, precision_enum = cusolver_type_map[U.type.dtype]
8148
9202
  M, N = U.type.shape
8149
- NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
8150
9203
 
8151
9204
  if len(z.type.shape) > 2 or len(z.type.shape) < 1:
8152
9205
  raise TypeError(f"tile_upper_solve() output tile must be 1D or 2D, got {len(z.type.shape)}-D")
@@ -8157,21 +9210,23 @@ def tile_upper_solve_generic_lto_dispatch_func(
8157
9210
  f"got {z.type.shape[0]} elements in output and {M} rows in 'U'"
8158
9211
  )
8159
9212
 
8160
- solver = "trsm"
8161
- solver_enum = cusolver_function_map[solver]
8162
-
8163
- side_enum = cusolver_side_map["left"]
8164
- diag_enum = cusolver_diag_map["nounit"]
8165
- fill_mode = cusolver_fill_mode_map["upper"]
8166
-
8167
9213
  arch = options["output_arch"]
8168
- num_threads = options["block_dim"]
8169
- parameter_list = f"({dtype}*, {dtype}*)"
8170
9214
 
8171
- if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
9215
+ if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
8172
9216
  # CPU/no-MathDx dispatch
8173
9217
  return ((0, U, z, x), [], [], 0)
8174
9218
  else:
9219
+ NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
9220
+ solver = "trsm"
9221
+ solver_enum = cusolver_function_map[solver]
9222
+ side_enum = cusolver_side_map["left"]
9223
+ diag_enum = cusolver_diag_map["nounit"]
9224
+ fill_mode = cusolver_fill_mode_map["upper"]
9225
+ dtype, precision_enum = cusolver_type_map[U.type.dtype]
9226
+ num_threads = options["block_dim"]
9227
+ parameter_list = f"({dtype}*, {dtype}*)"
9228
+ req_smem_bytes = (x.type.size + z.type.size + U.type.size) * type_size_in_bytes(U.type.dtype)
9229
+
8175
9230
  # generate the LTO
8176
9231
  lto_symbol, lto_code_data = warp.build.build_lto_solver(
8177
9232
  M,
@@ -8189,6 +9244,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
8189
9244
  num_threads,
8190
9245
  parameter_list,
8191
9246
  builder,
9247
+ smem_estimate_bytes=req_smem_bytes,
8192
9248
  )
8193
9249
 
8194
9250
  return ((Var(lto_symbol, str, False, True, False), U, z, x), [], [lto_code_data], 0)
@@ -8250,6 +9306,7 @@ add_builtin(
8250
9306
  group="Tile Primitives",
8251
9307
  export=False,
8252
9308
  namespace="",
9309
+ missing_grad=True,
8253
9310
  )
8254
9311
 
8255
9312
 
@@ -8269,6 +9326,7 @@ add_builtin(
8269
9326
  The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
8270
9327
  (excluding Warp arrays since they cannot be created in a Warp kernel at the moment).""",
8271
9328
  group="Code Generation",
9329
+ missing_grad=True,
8272
9330
  )
8273
9331
 
8274
9332
 
@@ -8293,6 +9351,7 @@ add_builtin(
8293
9351
  doc="Return the number of elements in a vector.",
8294
9352
  group="Utility",
8295
9353
  export=False,
9354
+ missing_grad=True,
8296
9355
  )
8297
9356
 
8298
9357
  add_builtin(
@@ -8302,6 +9361,7 @@ add_builtin(
8302
9361
  doc="Return the number of elements in a quaternion.",
8303
9362
  group="Utility",
8304
9363
  export=False,
9364
+ missing_grad=True,
8305
9365
  )
8306
9366
 
8307
9367
  add_builtin(
@@ -8311,6 +9371,7 @@ add_builtin(
8311
9371
  doc="Return the number of rows in a matrix.",
8312
9372
  group="Utility",
8313
9373
  export=False,
9374
+ missing_grad=True,
8314
9375
  )
8315
9376
 
8316
9377
  add_builtin(
@@ -8320,6 +9381,7 @@ add_builtin(
8320
9381
  doc="Return the number of elements in a transformation.",
8321
9382
  group="Utility",
8322
9383
  export=False,
9384
+ missing_grad=True,
8323
9385
  )
8324
9386
 
8325
9387
  add_builtin(
@@ -8329,6 +9391,7 @@ add_builtin(
8329
9391
  doc="Return the size of the first dimension in an array.",
8330
9392
  group="Utility",
8331
9393
  export=False,
9394
+ missing_grad=True,
8332
9395
  )
8333
9396
 
8334
9397
  add_builtin(
@@ -8338,6 +9401,7 @@ add_builtin(
8338
9401
  doc="Return the number of rows in a tile.",
8339
9402
  group="Utility",
8340
9403
  export=False,
9404
+ missing_grad=True,
8341
9405
  )
8342
9406
 
8343
9407
 
@@ -8412,4 +9476,24 @@ add_builtin(
8412
9476
  doc="Return the number of elements in a tuple.",
8413
9477
  group="Utility",
8414
9478
  export=False,
9479
+ missing_grad=True,
9480
+ )
9481
+
9482
+ # ---------------------------------
9483
+ # Slicing
9484
+
9485
+
9486
+ def slice_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
9487
+ return slice_t(**arg_values)
9488
+
9489
+
9490
+ add_builtin(
9491
+ "slice",
9492
+ input_types={"start": int, "stop": int, "step": int},
9493
+ value_func=slice_value_func,
9494
+ native_func="slice_t",
9495
+ export=False,
9496
+ group="Utility",
9497
+ hidden=True,
9498
+ missing_grad=True,
8415
9499
  )