warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (181) hide show
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -13,13 +13,15 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from __future__ import annotations
17
+
16
18
  import builtins
17
19
  import functools
18
20
  from typing import Any, Callable, Mapping, Sequence
19
21
 
20
22
  import warp.build
21
23
  import warp.context
22
- from warp.codegen import Reference, Var, strip_reference
24
+ from warp.codegen import Reference, Var, get_arg_value, strip_reference
23
25
  from warp.types import *
24
26
 
25
27
  from .context import add_builtin
@@ -55,6 +57,33 @@ def sametypes_create_value_func(default: TypeVar):
55
57
  return fn
56
58
 
57
59
 
60
+ def extract_tuple(arg, as_constant=False):
61
+ if isinstance(arg, Var):
62
+ if isinstance(arg.type, warp.types.tuple_t):
63
+ out = arg.type.values
64
+ else:
65
+ out = (arg,)
66
+ elif isinstance(arg, warp.types.tuple_t):
67
+ out = arg.values
68
+ elif not isinstance(arg, Sequence):
69
+ out = (arg,)
70
+ else:
71
+ out = arg
72
+
73
+ if as_constant:
74
+ return tuple(x.constant if isinstance(x, Var) else x for x in out)
75
+
76
+ return out
77
+
78
+
79
+ def static_len_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
80
+ if arg_types is None:
81
+ return int
82
+
83
+ length = warp.types.type_length(arg_types["a"])
84
+ return Var(None, type=int, constant=length)
85
+
86
+
58
87
  # ---------------------------------
59
88
  # Scalar Math
60
89
 
@@ -399,7 +428,7 @@ add_builtin(
399
428
  )
400
429
 
401
430
 
402
- def scalar_infer_type(arg_types: Union[Mapping[str, type], Tuple[type, ...], None]):
431
+ def scalar_infer_type(arg_types: Mapping[str, type] | tuple[type, ...] | None):
403
432
  if arg_types is None:
404
433
  return Scalar
405
434
 
@@ -1155,6 +1184,11 @@ add_builtin(
1155
1184
 
1156
1185
 
1157
1186
  def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1187
+ warp.utils.warn(
1188
+ "the built-in `wp.matrix()` function to construct a 4x4 matrix from a 3D position, quaternion, "
1189
+ "and 3D scale vector will be deprecated in favor of `wp.transform_compose()`.",
1190
+ DeprecationWarning,
1191
+ )
1158
1192
  if arg_types is None:
1159
1193
  return matrix(shape=(4, 4), dtype=Float)
1160
1194
 
@@ -1204,21 +1238,47 @@ add_builtin(
1204
1238
  dispatch_func=matrix_transform_dispatch_func,
1205
1239
  native_func="mat_t",
1206
1240
  doc="""Construct a 4x4 transformation matrix that applies the transformations as
1207
- Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x""",
1241
+ Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x
1242
+
1243
+ .. warning::
1244
+ This function has been deprecated in favor of :func:`warp.math.transform_compose()`.""",
1208
1245
  group="Vector Math",
1209
1246
  export=False,
1210
1247
  )
1211
1248
 
1212
1249
 
1213
- # not making these functions available outside kernels (export=False) as they
1214
- # return data via references, which we don't currently support:
1250
+ def svd3_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1251
+ if arg_types is None:
1252
+ return (
1253
+ matrix(shape=(3, 3), dtype=Float),
1254
+ vector(length=3, dtype=Float),
1255
+ matrix(shape=(3, 3), dtype=Float),
1256
+ )
1257
+
1258
+ dtype = arg_types["A"]._wp_scalar_type_
1259
+ return (
1260
+ matrix(shape=(3, 3), dtype=dtype),
1261
+ vector(length=3, dtype=dtype),
1262
+ matrix(shape=(3, 3), dtype=dtype),
1263
+ )
1264
+
1265
+
1266
+ add_builtin(
1267
+ "svd3",
1268
+ input_types={"A": matrix(shape=(3, 3), dtype=Float)},
1269
+ value_func=svd3_value_func,
1270
+ group="Vector Math",
1271
+ doc="""Compute the SVD of a 3x3 matrix ``A``. The singular values are returned in ``sigma``,
1272
+ while the left and right basis vectors are returned in ``U`` and ``V``.""",
1273
+ )
1274
+
1215
1275
  add_builtin(
1216
1276
  "svd3",
1217
1277
  input_types={
1218
1278
  "A": matrix(shape=(3, 3), dtype=Float),
1219
1279
  "U": matrix(shape=(3, 3), dtype=Float),
1220
1280
  "sigma": vector(length=3, dtype=Float),
1221
- "V": matrix(shape=(3, 3), dtype=Scalar),
1281
+ "V": matrix(shape=(3, 3), dtype=Float),
1222
1282
  },
1223
1283
  value_type=None,
1224
1284
  group="Vector Math",
@@ -1227,13 +1287,39 @@ add_builtin(
1227
1287
  while the left and right basis vectors are returned in ``U`` and ``V``.""",
1228
1288
  )
1229
1289
 
1290
+
1291
+ def svd2_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1292
+ if arg_types is None:
1293
+ return (
1294
+ matrix(shape=(2, 2), dtype=Float),
1295
+ vector(length=2, dtype=Float),
1296
+ matrix(shape=(2, 2), dtype=Float),
1297
+ )
1298
+
1299
+ dtype = arg_types["A"]._wp_scalar_type_
1300
+ return (
1301
+ matrix(shape=(2, 2), dtype=dtype),
1302
+ vector(length=2, dtype=dtype),
1303
+ matrix(shape=(2, 2), dtype=dtype),
1304
+ )
1305
+
1306
+
1307
+ add_builtin(
1308
+ "svd2",
1309
+ input_types={"A": matrix(shape=(2, 2), dtype=Float)},
1310
+ value_func=svd2_value_func,
1311
+ group="Vector Math",
1312
+ doc="""Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
1313
+ while the left and right basis vectors are returned in ``U`` and ``V``.""",
1314
+ )
1315
+
1230
1316
  add_builtin(
1231
1317
  "svd2",
1232
1318
  input_types={
1233
1319
  "A": matrix(shape=(2, 2), dtype=Float),
1234
1320
  "U": matrix(shape=(2, 2), dtype=Float),
1235
1321
  "sigma": vector(length=2, dtype=Float),
1236
- "V": matrix(shape=(2, 2), dtype=Scalar),
1322
+ "V": matrix(shape=(2, 2), dtype=Float),
1237
1323
  },
1238
1324
  value_type=None,
1239
1325
  group="Vector Math",
@@ -1242,6 +1328,30 @@ add_builtin(
1242
1328
  while the left and right basis vectors are returned in ``U`` and ``V``.""",
1243
1329
  )
1244
1330
 
1331
+
1332
+ def qr3_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1333
+ if arg_types is None:
1334
+ return (
1335
+ matrix(shape=(3, 3), dtype=Float),
1336
+ matrix(shape=(3, 3), dtype=Float),
1337
+ )
1338
+
1339
+ dtype = arg_types["A"]._wp_scalar_type_
1340
+ return (
1341
+ matrix(shape=(3, 3), dtype=dtype),
1342
+ matrix(shape=(3, 3), dtype=dtype),
1343
+ )
1344
+
1345
+
1346
+ add_builtin(
1347
+ "qr3",
1348
+ input_types={"A": matrix(shape=(3, 3), dtype=Float)},
1349
+ value_func=qr3_value_func,
1350
+ group="Vector Math",
1351
+ doc="""Compute the QR decomposition of a 3x3 matrix ``A``. The orthogonal matrix is returned in ``Q``,
1352
+ while the upper triangular matrix is returned in ``R``.""",
1353
+ )
1354
+
1245
1355
  add_builtin(
1246
1356
  "qr3",
1247
1357
  input_types={
@@ -1256,6 +1366,27 @@ add_builtin(
1256
1366
  while the upper triangular matrix is returned in ``R``.""",
1257
1367
  )
1258
1368
 
1369
+
1370
+ def eig3_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1371
+ if arg_types is None:
1372
+ return (matrix(shape=(3, 3), dtype=Float), vector(length=3, dtype=Float))
1373
+
1374
+ dtype = arg_types["A"]._wp_scalar_type_
1375
+ return (
1376
+ matrix(shape=(3, 3), dtype=dtype),
1377
+ vector(length=3, dtype=dtype),
1378
+ )
1379
+
1380
+
1381
+ add_builtin(
1382
+ "eig3",
1383
+ input_types={"A": matrix(shape=(3, 3), dtype=Float)},
1384
+ value_func=eig3_value_func,
1385
+ group="Vector Math",
1386
+ doc="""Compute the eigendecomposition of a 3x3 matrix ``A``. The eigenvectors are returned as the columns of ``Q``,
1387
+ while the corresponding eigenvalues are returned in ``d``.""",
1388
+ )
1389
+
1259
1390
  add_builtin(
1260
1391
  "eig3",
1261
1392
  input_types={
@@ -1422,13 +1553,34 @@ add_builtin(
1422
1553
  group="Quaternion Math",
1423
1554
  doc="Construct a quaternion representing a rotation of angle radians around the given axis.",
1424
1555
  )
1556
+
1557
+
1558
+ def quat_to_axis_angle_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1559
+ if arg_types is None:
1560
+ return (vector(length=3, dtype=Float), Float)
1561
+
1562
+ dtype = arg_types["quat"]._wp_scalar_type_
1563
+ return (vector(length=3, dtype=dtype), dtype)
1564
+
1565
+
1566
+ add_builtin(
1567
+ "quat_to_axis_angle",
1568
+ input_types={"quat": quaternion(dtype=Float)},
1569
+ value_func=quat_to_axis_angle_value_func,
1570
+ group="Quaternion Math",
1571
+ doc="Extract the rotation axis and angle radians a quaternion represents.",
1572
+ )
1573
+
1425
1574
  add_builtin(
1426
1575
  "quat_to_axis_angle",
1427
1576
  input_types={"quat": quaternion(dtype=Float), "axis": vector(length=3, dtype=Float), "angle": Float},
1428
1577
  value_type=None,
1429
1578
  group="Quaternion Math",
1430
1579
  doc="Extract the rotation axis and angle radians a quaternion represents.",
1580
+ export=False,
1431
1581
  )
1582
+
1583
+
1432
1584
  add_builtin(
1433
1585
  "quat_from_matrix",
1434
1586
  input_types={"mat": matrix(shape=(3, 3), dtype=Float)},
@@ -1506,6 +1658,48 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1506
1658
  if arg_types is None:
1507
1659
  return transformation(dtype=Float)
1508
1660
 
1661
+ dtype = arg_values.get("dtype", None)
1662
+
1663
+ variadic_arg_types = arg_types.get("args", ())
1664
+ variadic_arg_count = len(variadic_arg_types)
1665
+ if variadic_arg_count == 0:
1666
+ # Zero-initialization, e.g.: `wp.transform()`, `wp.transformation(dtype=wp.float16)`.
1667
+ if dtype is None:
1668
+ dtype = float32
1669
+ elif variadic_arg_count == 1:
1670
+ # Initialization by filling a value, e.g.: `wp.transform(123)`,
1671
+ # `wp.transformation(123)`.
1672
+ value_type = strip_reference(variadic_arg_types[0])
1673
+ if dtype is None:
1674
+ dtype = value_type
1675
+ elif not warp.types.scalars_equal(value_type, dtype):
1676
+ raise RuntimeError(
1677
+ f"the value used to fill this transform is expected to be of the type `{dtype.__name__}`"
1678
+ )
1679
+ elif variadic_arg_count == 7:
1680
+ # Initializing by value, e.g.: `wp.transform(1, 2, 3, 4, 5, 6, 7)`.
1681
+ try:
1682
+ value_type = scalar_infer_type(variadic_arg_types)
1683
+ except RuntimeError:
1684
+ raise RuntimeError("all values given when constructing a transform must have the same type") from None
1685
+
1686
+ if dtype is None:
1687
+ dtype = value_type
1688
+ elif not warp.types.scalars_equal(value_type, dtype):
1689
+ raise RuntimeError(
1690
+ f"all values used to initialize this transform are expected to be of the type `{dtype.__name__}`"
1691
+ )
1692
+
1693
+ if dtype is None:
1694
+ raise RuntimeError("could not infer the `dtype` argument when calling the `wp.transform()` function")
1695
+
1696
+ return transformation(dtype=dtype)
1697
+
1698
+
1699
+ def transformation_pq_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1700
+ if arg_types is None:
1701
+ return transformation(dtype=Float)
1702
+
1509
1703
  try:
1510
1704
  value_type = float_infer_type(arg_types)
1511
1705
  except RuntimeError:
@@ -1540,20 +1734,35 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
1540
1734
 
1541
1735
  add_builtin(
1542
1736
  "transformation",
1543
- input_types={"pos": vector(length=3, dtype=Float), "rot": quaternion(dtype=Float), "dtype": Float},
1737
+ input_types={"p": vector(length=3, dtype=Float), "q": quaternion(dtype=Float), "dtype": Float},
1544
1738
  defaults={"dtype": None},
1545
- value_func=transformation_value_func,
1739
+ value_func=transformation_pq_value_func,
1546
1740
  export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1547
1741
  dispatch_func=transformation_dispatch_func,
1548
1742
  native_func="transform_t",
1549
1743
  group="Transformations",
1550
- doc="Construct a rigid-body transformation with translation part ``pos`` and rotation ``rot``.",
1744
+ doc="Construct a rigid-body transformation with translation part ``p`` and rotation ``q``.",
1745
+ export=False,
1746
+ )
1747
+
1748
+
1749
+ add_builtin(
1750
+ "transformation",
1751
+ input_types={"*args": Float, "dtype": Float},
1752
+ defaults={"dtype": None},
1753
+ variadic=True,
1754
+ initializer_list_func=lambda arg_types, arg_values: len(arg_types.get("args", ())) > 1,
1755
+ value_func=transformation_value_func,
1756
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k not in ("dtype")},
1757
+ dispatch_func=transformation_dispatch_func,
1758
+ native_func="transform_t",
1759
+ doc="Construct a spatial transform vector of given dtype.",
1760
+ group="Spatial Math",
1551
1761
  export=False,
1552
1762
  )
1553
1763
 
1554
1764
 
1555
1765
  def transform_identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1556
- # if arg_types is None then we are in 'export' mode
1557
1766
  if arg_types is None:
1558
1767
  # return transformation(dtype=Float)
1559
1768
  return transformf
@@ -1600,6 +1809,40 @@ add_builtin(
1600
1809
  group="Transformations",
1601
1810
  doc="Return the rotational part of a transform ``xform``.",
1602
1811
  )
1812
+ add_builtin(
1813
+ "transform_set_translation",
1814
+ input_types={"xform": transformation(dtype=Float), "p": vector(length=3, dtype=Float)},
1815
+ value_type=None,
1816
+ group="Transformations",
1817
+ doc="Set the translational part of a transform ``xform``.",
1818
+ )
1819
+ add_builtin(
1820
+ "transform_set_rotation",
1821
+ input_types={"xform": transformation(dtype=Float), "q": quaternion(dtype=Float)},
1822
+ value_type=None,
1823
+ group="Transformations",
1824
+ doc="Set the rotational part of a transform ``xform``.",
1825
+ )
1826
+ # performs a copy internally if wp.config.enable_vector_component_overwrites is True
1827
+ add_builtin(
1828
+ "transform_set_translation_copy",
1829
+ input_types={"xform": transformation(dtype=Float), "p": vector(length=3, dtype=Float)},
1830
+ value_type=transformation(dtype=Float),
1831
+ group="Transformations",
1832
+ doc="Set the translational part of a transform ``xform``.",
1833
+ hidden=True,
1834
+ export=False,
1835
+ )
1836
+ # performs a copy internally if wp.config.enable_vector_component_overwrites is True
1837
+ add_builtin(
1838
+ "transform_set_rotation_copy",
1839
+ input_types={"xform": transformation(dtype=Float), "q": quaternion(dtype=Float)},
1840
+ value_type=transformation(dtype=Float),
1841
+ group="Transformations",
1842
+ doc="Set the rotational part of a transform ``xform``.",
1843
+ hidden=True,
1844
+ export=False,
1845
+ )
1603
1846
  add_builtin(
1604
1847
  "transform_multiply",
1605
1848
  input_types={"a": transformation(dtype=Float), "b": transformation(dtype=Float)},
@@ -1831,40 +2074,15 @@ add_builtin(
1831
2074
  # Tile-based primitives
1832
2075
 
1833
2076
 
1834
- def tile_unpack_shape(arg_values):
1835
- shape = arg_values["shape"]
1836
-
1837
- if not isinstance(shape, tuple):
1838
- # promote to tuple
1839
- shape = (shape,)
1840
-
1841
- # check that components are constants
1842
- for d in shape:
1843
- if d is None:
1844
- raise ValueError("Tile functions require shape to be a compile time constant.")
1845
-
1846
- return shape
1847
-
1848
-
1849
- def tile_unpack_offset(arg_values, ndim=0):
1850
- if "offset" in arg_values:
1851
- offset = arg_values["offset"]
1852
- else:
1853
- offset = (0,) * ndim
1854
-
1855
- if isinstance(offset, tuple):
1856
- return offset
1857
- else:
1858
- # promote to tuple
1859
- return (offset,)
1860
-
1861
-
1862
2077
  def tile_zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1863
2078
  # return generic type (for doc builds)
1864
2079
  if arg_types is None:
1865
- return Tile(dtype=Any, shape=Any)
2080
+ return tile(dtype=Any, shape=Tuple[int, ...])
2081
+
2082
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
1866
2083
 
1867
- shape = tile_unpack_shape(arg_values)
2084
+ if None in shape:
2085
+ raise ValueError("Tile functions require shape to be a compile time constant.")
1868
2086
 
1869
2087
  if "dtype" not in arg_values:
1870
2088
  raise TypeError("tile_zeros() missing required keyword argument 'dtype'")
@@ -1877,17 +2095,20 @@ def tile_zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
1877
2095
 
1878
2096
  dtype = arg_values["dtype"]
1879
2097
 
1880
- return TileZeros(dtype=dtype, shape=shape, storage=arg_values["storage"])
2098
+ return tile(dtype=dtype, shape=shape, storage=arg_values["storage"])
1881
2099
 
1882
2100
 
1883
2101
  def tile_zeros_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1884
- shape = tile_unpack_shape(arg_values)
2102
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2103
+
2104
+ if None in shape:
2105
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2106
+
1885
2107
  dtype = arg_values["dtype"]
1886
2108
 
1887
2109
  template_args = []
1888
2110
  template_args.append(dtype)
1889
- for d in shape:
1890
- template_args.append(d.constant)
2111
+ template_args.extend(shape)
1891
2112
 
1892
2113
  return ([], template_args)
1893
2114
 
@@ -1929,9 +2150,12 @@ add_builtin(
1929
2150
  def tile_ones_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1930
2151
  # return generic type (for doc builds)
1931
2152
  if arg_types is None:
1932
- return Tile(dtype=Any, shape=Any)
2153
+ return tile(dtype=Any, shape=Tuple[int, ...])
1933
2154
 
1934
- shape = tile_unpack_shape(arg_values)
2155
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2156
+
2157
+ if None in shape:
2158
+ raise ValueError("Tile functions require shape to be a compile time constant.")
1935
2159
 
1936
2160
  if "dtype" not in arg_values:
1937
2161
  raise TypeError("tile_ones() missing required keyword argument 'dtype'")
@@ -1944,17 +2168,20 @@ def tile_ones_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str,
1944
2168
 
1945
2169
  dtype = arg_values["dtype"]
1946
2170
 
1947
- return TileOnes(dtype=dtype, shape=shape, storage=arg_values["storage"])
2171
+ return tile(dtype=dtype, shape=shape, storage=arg_values["storage"])
1948
2172
 
1949
2173
 
1950
2174
  def tile_ones_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1951
- shape = tile_unpack_shape(arg_values)
2175
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2176
+
2177
+ if None in shape:
2178
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2179
+
1952
2180
  dtype = arg_values["dtype"]
1953
2181
 
1954
2182
  template_args = []
1955
2183
  template_args.append(dtype)
1956
- for d in shape:
1957
- template_args.append(d.constant)
2184
+ template_args.extend(shape)
1958
2185
 
1959
2186
  return ([], template_args)
1960
2187
 
@@ -1994,7 +2221,7 @@ add_builtin(
1994
2221
  def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1995
2222
  # return generic type (for doc builds)
1996
2223
  if arg_types is None:
1997
- return Tile(dtype=Any, shape=Any)
2224
+ return tile(dtype=Scalar, shape=Tuple[int])
1998
2225
 
1999
2226
  if "args" not in arg_values:
2000
2227
  raise TypeError("tile_arange() requires at least one positional argument specifying the range")
@@ -2029,7 +2256,8 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
2029
2256
  if arg_values["storage"] not in {"shared", "register"}:
2030
2257
  raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
2031
2258
 
2032
- return TileRange(dtype=dtype, start=start, stop=stop, step=step, storage=arg_values["storage"])
2259
+ n = int((stop - start) / step)
2260
+ return tile(dtype=dtype, shape=(n,), storage=arg_values["storage"])
2033
2261
 
2034
2262
 
2035
2263
  def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
@@ -2045,13 +2273,13 @@ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, a
2045
2273
  args = arg_values["args"]
2046
2274
 
2047
2275
  if len(args) == 1:
2048
- start = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.start)
2276
+ start = warp.codegen.Var(label=None, type=return_type.dtype, constant=0)
2049
2277
  stop = args[0]
2050
- step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
2278
+ step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
2051
2279
  elif len(args) == 2:
2052
2280
  start = args[0]
2053
2281
  stop = args[1]
2054
- step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
2282
+ step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
2055
2283
  elif len(args) == 3:
2056
2284
  start = args[0]
2057
2285
  stop = args[1]
@@ -2069,7 +2297,7 @@ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, a
2069
2297
 
2070
2298
  add_builtin(
2071
2299
  "tile_arange",
2072
- input_types={"*args": Scalar, "dtype": Any, "storage": str},
2300
+ input_types={"*args": Scalar, "dtype": Scalar, "storage": str},
2073
2301
  defaults={"dtype": None, "storage": "register"},
2074
2302
  value_func=tile_arange_value_func,
2075
2303
  dispatch_func=tile_arange_dispatch_func,
@@ -2094,12 +2322,19 @@ add_builtin(
2094
2322
 
2095
2323
  def tile_load_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
2096
2324
  if arg_types is None:
2097
- return array(dtype=Scalar)
2325
+ return tile(dtype=Any, shape=Tuple[int, ...])
2098
2326
 
2099
2327
  a = arg_types["a"]
2100
2328
 
2101
- shape = tile_unpack_shape(arg_values)
2102
- offset = tile_unpack_offset(arg_values, a.ndim)
2329
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2330
+
2331
+ if None in shape:
2332
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2333
+
2334
+ if "offset" in arg_values:
2335
+ offset = extract_tuple(arg_values["offset"])
2336
+ else:
2337
+ offset = (0,) * a.ndim
2103
2338
 
2104
2339
  if a.ndim != len(shape):
2105
2340
  raise ValueError(
@@ -2114,16 +2349,23 @@ def tile_load_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mappin
2114
2349
  if arg_values["storage"] not in {"shared", "register"}:
2115
2350
  raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
2116
2351
 
2117
- return Tile(dtype=a.dtype, shape=shape, storage=arg_values["storage"])
2352
+ return tile(dtype=a.dtype, shape=shape, storage=arg_values["storage"])
2118
2353
 
2119
2354
 
2120
2355
  def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2121
2356
  a = args["a"]
2122
- shape = tile_unpack_shape(args)
2123
- offset = tile_unpack_offset(args, a.type.ndim)
2357
+ shape = extract_tuple(args["shape"], as_constant=True)
2358
+
2359
+ if None in shape:
2360
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2361
+
2362
+ if "offset" in args:
2363
+ offset = extract_tuple(args["offset"])
2364
+ else:
2365
+ offset = (0,) * a.type.ndim
2124
2366
 
2125
2367
  func_args = (a, *offset)
2126
- template_args = (d.constant for d in shape)
2368
+ template_args = shape
2127
2369
 
2128
2370
  return (func_args, template_args)
2129
2371
 
@@ -2170,7 +2412,10 @@ def tile_store_value_func(arg_types, arg_values):
2170
2412
  a = arg_types["a"]
2171
2413
  t = arg_types["t"]
2172
2414
 
2173
- c = tile_unpack_offset(arg_types, a.ndim)
2415
+ if "offset" in arg_types:
2416
+ c = extract_tuple(arg_values["offset"])
2417
+ else:
2418
+ c = (0,) * a.ndim
2174
2419
 
2175
2420
  if len(c) != a.ndim:
2176
2421
  raise ValueError(
@@ -2196,7 +2441,10 @@ def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any,
2196
2441
  a = args["a"]
2197
2442
  t = args["t"]
2198
2443
 
2199
- offset = tile_unpack_offset(args, a.type.ndim)
2444
+ if "offset" in args:
2445
+ offset = extract_tuple(args["offset"])
2446
+ else:
2447
+ offset = (0,) * a.type.ndim
2200
2448
 
2201
2449
  func_args = (a, *offset, t)
2202
2450
  template_args = []
@@ -2206,7 +2454,7 @@ def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any,
2206
2454
 
2207
2455
  add_builtin(
2208
2456
  "tile_store",
2209
- input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
2457
+ input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...]},
2210
2458
  value_func=tile_store_value_func,
2211
2459
  dispatch_func=tile_store_dispatch_func,
2212
2460
  defaults={"offset": None},
@@ -2226,7 +2474,7 @@ add_builtin(
2226
2474
  # overload for scalar offset
2227
2475
  add_builtin(
2228
2476
  "tile_store",
2229
- input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": int},
2477
+ input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": int},
2230
2478
  value_func=tile_store_value_func,
2231
2479
  dispatch_func=tile_store_dispatch_func,
2232
2480
  defaults={"offset": None},
@@ -2241,12 +2489,16 @@ add_builtin(
2241
2489
  def tile_atomic_add_value_func(arg_types, arg_values):
2242
2490
  # return generic type (for doc builds)
2243
2491
  if arg_types is None:
2244
- return Tile(dtype=Any, shape=Any)
2492
+ return tile(dtype=Any, shape=Tuple[int, ...])
2245
2493
 
2246
2494
  a = arg_types["a"]
2247
2495
  t = arg_types["t"]
2248
2496
 
2249
- c = tile_unpack_offset(arg_types, a.ndim)
2497
+ if "offset" in arg_types:
2498
+ c = extract_tuple(arg_values["offset"])
2499
+ else:
2500
+ c = (0,) * a.ndim
2501
+
2250
2502
  if len(c) != a.ndim:
2251
2503
  raise ValueError(
2252
2504
  f"tile_atomic_add() 'a' argument must have {len(c)} dimensions, "
@@ -2264,14 +2516,21 @@ def tile_atomic_add_value_func(arg_types, arg_values):
2264
2516
  f"tile_atomic_add() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
2265
2517
  )
2266
2518
 
2267
- return Tile(dtype=arg_types["t"].dtype, shape=arg_types["t"].shape, storage=arg_types["t"].storage)
2519
+ return tile(
2520
+ dtype=arg_types["t"].dtype,
2521
+ shape=arg_types["t"].shape,
2522
+ storage=arg_types["t"].storage,
2523
+ )
2268
2524
 
2269
2525
 
2270
2526
  def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2271
2527
  a = args["a"]
2272
2528
  t = args["t"]
2273
2529
 
2274
- offset = tile_unpack_offset(args, a.type.ndim)
2530
+ if "offset" in args:
2531
+ offset = extract_tuple(args["offset"])
2532
+ else:
2533
+ offset = (0,) * a.type.ndim
2275
2534
 
2276
2535
  func_args = (a, *offset, t)
2277
2536
  template_args = []
@@ -2281,13 +2540,13 @@ def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type:
2281
2540
 
2282
2541
  add_builtin(
2283
2542
  "tile_atomic_add",
2284
- input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
2543
+ input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...]},
2285
2544
  value_func=tile_atomic_add_value_func,
2286
2545
  dispatch_func=tile_atomic_add_dispatch_func,
2287
2546
  defaults={"offset": None},
2288
2547
  variadic=False,
2289
2548
  skip_replay=True,
2290
- doc="""Atomically add a 1D tile to the array `a`, each element will be updated atomically.
2549
+ doc="""Atomically add a tile onto the array `a`, each element will be updated atomically.
2291
2550
 
2292
2551
  :param a: Array in global memory, should have the same ``dtype`` as the input tile
2293
2552
  :param t: Source tile to add to the destination array
@@ -2300,7 +2559,7 @@ add_builtin(
2300
2559
  # overload for scalar offset
2301
2560
  add_builtin(
2302
2561
  "tile_atomic_add",
2303
- input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": int},
2562
+ input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": int},
2304
2563
  value_func=tile_atomic_add_value_func,
2305
2564
  dispatch_func=tile_atomic_add_dispatch_func,
2306
2565
  defaults={"offset": None},
@@ -2315,54 +2574,59 @@ add_builtin(
2315
2574
  def tile_view_value_func(arg_types, arg_values):
2316
2575
  # return generic type (for doc builds)
2317
2576
  if arg_types is None:
2318
- return Tile(dtype=Any, shape=Any)
2577
+ return tile(dtype=Any, shape=Tuple[int, ...])
2319
2578
 
2320
- tile = arg_types["t"]
2321
- offset = arg_types["offset"]
2579
+ tile_type = arg_types["t"]
2580
+ offset = extract_tuple(arg_values["offset"])
2322
2581
 
2323
- if len(offset) > len(tile.shape):
2324
- raise ValueError(f"tile_view() specified too many offset coordinates {len(offset)} > {len(tile.shape)}")
2582
+ if len(offset) > len(tile_type.shape):
2583
+ raise ValueError(f"tile_view() specified too many offset coordinates {len(offset)} > {len(tile_type.shape)}")
2325
2584
 
2326
2585
  if "shape" in arg_values:
2327
2586
  # if shape is specified take it directly, e.g.:
2328
2587
  # tile_view(t, offset=(i,j), shape=(m,n))
2329
- shape = arg_values["shape"]
2330
- strides = tile.strides
2588
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2589
+ strides = tile_type.strides
2331
2590
 
2332
- if len(shape) != len(tile.shape):
2591
+ if len(shape) != len(tile_type.shape):
2333
2592
  raise ValueError(
2334
- f"tile_view() if shape is specified it must have same number of dimensions as source tile, expected {len(tile.shape)}, got {len(shape)}"
2593
+ f"tile_view() if shape is specified it must have same number of dimensions as source tile, expected {len(tile_type.shape)}, got {len(shape)}"
2335
2594
  )
2336
2595
  else:
2337
2596
  # if not specified, then take output shape from unspecified src dimensions
2338
2597
  # e.g.: tile[i] will return a whole row of a 2D tile
2339
- shape = tile.shape[len(offset) :]
2340
- strides = tile.strides[len(offset) :]
2598
+ shape = tile_type.shape[len(offset) :]
2599
+ strides = tile_type.strides[len(offset) :]
2341
2600
 
2342
2601
  assert len(shape) == len(strides)
2343
2602
 
2344
2603
  # force source tile to shared memory
2345
- tile.storage = "shared"
2604
+ tile_type.storage = "shared"
2346
2605
 
2347
- output = Tile(dtype=tile.dtype, shape=shape, strides=strides, layout=tile.layout, storage="shared", owner=False)
2606
+ output = tile(
2607
+ dtype=tile_type.dtype, shape=shape, strides=strides, layout=tile_type.layout, storage="shared", owner=False
2608
+ )
2348
2609
  return output
2349
2610
 
2350
2611
 
2351
2612
  def tile_view_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2352
2613
  tile = arg_values["t"]
2353
- coord = arg_values["offset"]
2614
+ coord = extract_tuple(arg_values["offset"])
2354
2615
 
2355
2616
  # zero-pad coord to match source array
2356
2617
  view_coord = [0] * len(tile.type.shape)
2357
2618
  for i in range(len(coord)):
2358
2619
  view_coord[i] = coord[i]
2359
2620
 
2360
- return ((tile, *view_coord), (return_type,))
2621
+ func_args = (tile, *view_coord)
2622
+ template_args = (return_type,)
2623
+
2624
+ return (func_args, template_args)
2361
2625
 
2362
2626
 
2363
2627
  add_builtin(
2364
2628
  "tile_view",
2365
- input_types={"t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...], "shape": Tuple[int, ...]},
2629
+ input_types={"t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...], "shape": Tuple[int, ...]},
2366
2630
  value_func=tile_view_value_func,
2367
2631
  dispatch_func=tile_view_dispatch_func,
2368
2632
  defaults={"shape": None},
@@ -2379,116 +2643,363 @@ add_builtin(
2379
2643
  )
2380
2644
 
2381
2645
 
2382
- def tile_assign_value_func(arg_types, arg_values):
2646
+ def tile_squeeze_value_func(arg_types, arg_values):
2647
+ # return generic type (for doc builds)
2383
2648
  if arg_types is None:
2384
- return None
2649
+ return tile(dtype=Any, shape=Tuple[int, ...])
2385
2650
 
2386
- # force the destination tile to shared memory
2387
- arg_types["dst"].storage = "shared"
2388
- return None
2651
+ tile_type = arg_types["t"]
2652
+ shape = tile_type.shape
2653
+ strides = tile_type.strides
2654
+ ndim = len(shape)
2389
2655
 
2656
+ if "axis" in arg_values:
2657
+ axis = arg_values["axis"]
2390
2658
 
2391
- def tile_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2392
- dst = args["dst"]
2393
- src = args["src"]
2659
+ if not isinstance(axis, Sequence):
2660
+ # promote to tuple
2661
+ axis = (axis,)
2394
2662
 
2395
- offset = tile_unpack_offset(args, len(dst.type.shape))
2663
+ # promote negative indices to their positive equivalents
2664
+ axis = tuple([a if a >= 0 else a + ndim for a in axis])
2396
2665
 
2397
- func_args = (dst, src, *offset)
2398
- template_args = []
2666
+ # validate that specified axes are size 1
2667
+ for a in axis:
2668
+ if shape[a] != 1:
2669
+ raise ValueError(
2670
+ f"Cannot select an axis to squeeze out which has size not equal to one, axis={a}, size={shape[a]}"
2671
+ )
2399
2672
 
2400
- return (func_args, template_args)
2673
+ # build new shape by skipping specified axes (if size is 1)
2674
+ new_shape = tuple(dim for i, dim in enumerate(shape) if i not in axis)
2675
+ new_strides = tuple(stride for i, stride in enumerate(strides) if i not in axis)
2401
2676
 
2677
+ else:
2678
+ # no axis specified: remove all singleton dimensions
2679
+ new_shape = tuple(dim for dim in shape if dim != 1)
2680
+ new_strides = tuple(stride for i, stride in enumerate(strides) if shape[i] != 1)
2402
2681
 
2403
- add_builtin(
2404
- "tile_assign",
2405
- input_types={"dst": Tile(dtype=Any, shape=Any), "src": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
2406
- value_func=tile_assign_value_func,
2407
- dispatch_func=tile_assign_dispatch_func,
2408
- defaults={"offset": None},
2409
- doc="""Assign a tile to a subrange of a destination tile.
2682
+ # force source tile to shared memory
2683
+ tile_type.storage = "shared"
2684
+
2685
+ output = tile(
2686
+ dtype=tile_type.dtype,
2687
+ shape=new_shape,
2688
+ strides=new_strides,
2689
+ layout=tile_type.layout,
2690
+ storage="shared",
2691
+ owner=False,
2692
+ )
2693
+ return output
2410
2694
 
2411
- :param dst: The destination tile to assign to
2412
- :param src: The source tile to read values from
2413
- :param offset: Offset in the destination tile to write to""",
2414
- group="Tile Primitives",
2415
- export=False,
2416
- )
2417
2695
 
2418
- # handles expressions like tile[i,j] = 1.0
2419
- add_builtin(
2420
- "assign",
2421
- input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "src": Scalar},
2422
- value_func=tile_assign_value_func,
2423
- group="Tile Primitives",
2424
- export=False,
2425
- hidden=True,
2426
- )
2696
+ def tile_squeeze_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2697
+ source_tile = arg_values["t"]
2427
2698
 
2428
- add_builtin(
2429
- "assign",
2430
- input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "src": Scalar},
2431
- value_func=tile_assign_value_func,
2432
- group="Tile Primitives",
2433
- export=False,
2434
- hidden=True,
2435
- )
2699
+ return ((source_tile,), (return_type,))
2436
2700
 
2437
- add_builtin(
2438
- "assign",
2439
- input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "src": Scalar},
2440
- value_func=tile_assign_value_func,
2441
- group="Tile Primitives",
2442
- export=False,
2443
- hidden=True,
2444
- )
2445
2701
 
2446
2702
  add_builtin(
2447
- "assign",
2448
- input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "l": int, "src": Scalar},
2449
- value_func=tile_assign_value_func,
2703
+ "tile_squeeze",
2704
+ input_types={"t": tile(dtype=Any, shape=Tuple[int, ...]), "axis": Tuple[int, ...]},
2705
+ defaults={"axis": None},
2706
+ value_func=tile_squeeze_value_func,
2707
+ dispatch_func=tile_squeeze_dispatch_func,
2708
+ variadic=False,
2709
+ doc="""Return a squeezed view of a tile with the same data.
2710
+
2711
+ :param t: Input tile to squeeze
2712
+ :param axis: A subset of the entries of length one in the shape (optional)
2713
+ :returns: The input tile but with all or a subset of the dimensions of length one removed.""",
2450
2714
  group="Tile Primitives",
2451
2715
  export=False,
2452
- hidden=True,
2453
2716
  )
2454
2717
 
2455
2718
 
2456
- def tile_value_func(arg_types, arg_values):
2719
+ def tile_reshape_value_func(arg_types, arg_values):
2457
2720
  # return generic type (for doc builds)
2458
2721
  if arg_types is None:
2459
- return Tile
2460
-
2461
- if len(arg_types) != 1:
2462
- raise TypeError(f"tile() takes exactly 1 positional argument but {len(arg_types)} were given")
2722
+ return tile(dtype=Any, shape=Tuple[int, ...])
2463
2723
 
2464
- dtype = None
2465
- length = None
2724
+ tile_type = arg_types["t"]
2466
2725
 
2467
- if type_is_vector(arg_types["x"]):
2468
- dtype = arg_types["x"]._wp_scalar_type_
2469
- length = arg_types["x"]._shape_[0]
2470
- shape = (length, warp.codegen.options["block_dim"])
2471
- else:
2472
- dtype = arg_types["x"]
2473
- shape = (warp.codegen.options["block_dim"],)
2726
+ # calculate total size of tile_type
2727
+ size = 1
2728
+ for s in tile_type.shape:
2729
+ size *= int(s)
2474
2730
 
2475
- return Tile(dtype=dtype, shape=shape, op="tile")
2731
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2476
2732
 
2733
+ if None in shape:
2734
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2477
2735
 
2478
- add_builtin(
2479
- "tile",
2480
- input_types={"x": Any},
2481
- value_func=tile_value_func,
2482
- variadic=True,
2483
- doc="""Construct a new tile from per-thread kernel values.
2736
+ # check for -1 dimension and reformat
2737
+ if -1 in shape:
2738
+ idx = size
2739
+ denom = 1
2740
+ minus_one_count = 0
2741
+ for i, d in enumerate(shape):
2742
+ if d == -1:
2743
+ idx = i
2744
+ minus_one_count += 1
2745
+ else:
2746
+ denom *= d
2747
+ if minus_one_count > 1:
2748
+ raise RuntimeError("Cannot infer shape if more than one index is -1.")
2749
+ new_shape = list(shape)
2750
+ new_shape[idx] = int(size / denom)
2751
+ shape = tuple(new_shape)
2752
+
2753
+ # calculate total size of new shape
2754
+ new_size = 1
2755
+ for s in shape:
2756
+ new_size *= int(s)
2757
+
2758
+ if new_size != size:
2759
+ raise ValueError(f"New shape {shape} has total size {new_size} which does not match original size {size}")
2760
+
2761
+ # compute new strides matching shape
2762
+ strides = []
2763
+ stride = 1
2764
+ for s in reversed(shape):
2765
+ strides.append(stride)
2766
+ stride *= s
2767
+ strides = tuple(reversed(strides))
2484
2768
 
2485
- This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
2769
+ # force source tile to shared memory
2770
+ tile_type.storage = "shared"
2771
+
2772
+ output = tile(
2773
+ dtype=tile_type.dtype, shape=shape, strides=strides, layout=tile_type.layout, storage="shared", owner=False
2774
+ )
2775
+ return output
2776
+
2777
+
2778
+ def tile_reshape_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2779
+ tile = arg_values["t"]
2780
+
2781
+ return ((tile,), (return_type,))
2782
+
2783
+
2784
+ add_builtin(
2785
+ "tile_reshape",
2786
+ input_types={"t": tile(dtype=Any, shape=Tuple[int, ...]), "shape": Tuple[int, ...]},
2787
+ value_func=tile_reshape_value_func,
2788
+ dispatch_func=tile_reshape_dispatch_func,
2789
+ variadic=False,
2790
+ doc="""Return a reshaped view of a tile with the same data.
2791
+
2792
+ :param t: Input tile to reshape
2793
+ :param shape: New shape for the tile
2794
+ :returns: A tile containing the same data as the input tile, but arranged in a new shape.""",
2795
+ group="Tile Primitives",
2796
+ export=False,
2797
+ )
2798
+
2799
+
2800
+ def tile_astype_value_func(arg_types, arg_values):
2801
+ # return generic type (for doc builds)
2802
+ if arg_types is None:
2803
+ return tile(dtype=Any, shape=Tuple[int, ...])
2804
+
2805
+ tile_type = arg_types["t"]
2806
+ dtype = arg_values["dtype"]
2807
+
2808
+ return tile(dtype=dtype, shape=tile_type.shape)
2809
+
2810
+
2811
+ def tile_astype_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2812
+ tile = arg_values["t"]
2813
+
2814
+ return ((tile,), (return_type,))
2815
+
2816
+
2817
+ add_builtin(
2818
+ "tile_astype",
2819
+ input_types={"t": tile(dtype=Scalar, shape=Tuple[int, ...]), "dtype": Scalar},
2820
+ value_func=tile_astype_value_func,
2821
+ dispatch_func=tile_astype_dispatch_func,
2822
+ variadic=False,
2823
+ doc="""Return a new tile with the same data as the input tile, but with a different data type.
2824
+
2825
+ :param t: Input tile
2826
+ :param dtype: New data type for the tile
2827
+ :returns: A tile with the same data as the input tile, but with a different data type""",
2828
+ group="Tile Primitives",
2829
+ export=False,
2830
+ )
2831
+
2832
+
2833
+ def tile_assign_value_func(arg_types, arg_values):
2834
+ if arg_types is None:
2835
+ return None
2836
+
2837
+ # force the destination tile to shared memory
2838
+ arg_types["dst"].storage = "shared"
2839
+ return None
2840
+
2841
+
2842
+ def tile_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2843
+ dst = args["dst"]
2844
+ src = args["src"]
2845
+
2846
+ if "offset" in args:
2847
+ offset = extract_tuple(args["offset"])
2848
+ else:
2849
+ offset = (0,) * len(dst.type.shape)
2850
+
2851
+ func_args = (dst, src, *offset)
2852
+ template_args = []
2853
+
2854
+ return (func_args, template_args)
2855
+
2856
+
2857
+ add_builtin(
2858
+ "tile_assign",
2859
+ input_types={
2860
+ "dst": tile(dtype=Any, shape=Tuple[int, ...]),
2861
+ "src": tile(dtype=Any, shape=Tuple[int, ...]),
2862
+ "offset": Tuple[int, ...],
2863
+ },
2864
+ value_func=tile_assign_value_func,
2865
+ dispatch_func=tile_assign_dispatch_func,
2866
+ defaults={"offset": None},
2867
+ doc="""Assign a tile to a subrange of a destination tile.
2868
+
2869
+ :param dst: The destination tile to assign to
2870
+ :param src: The source tile to read values from
2871
+ :param offset: Offset in the destination tile to write to""",
2872
+ group="Tile Primitives",
2873
+ export=False,
2874
+ )
2875
+
2876
+ # handles expressions like tile[i,j] = 1.0
2877
+ add_builtin(
2878
+ "assign",
2879
+ input_types={"dst": tile(dtype=Any, shape=Tuple[int]), "i": int, "src": Any},
2880
+ value_func=tile_assign_value_func,
2881
+ group="Tile Primitives",
2882
+ export=False,
2883
+ hidden=True,
2884
+ )
2885
+
2886
+ add_builtin(
2887
+ "assign",
2888
+ input_types={"dst": tile(dtype=Any, shape=Tuple[int, int]), "i": int, "j": int, "src": Any},
2889
+ value_func=tile_assign_value_func,
2890
+ group="Tile Primitives",
2891
+ export=False,
2892
+ hidden=True,
2893
+ )
2894
+
2895
+ add_builtin(
2896
+ "assign",
2897
+ input_types={"dst": tile(dtype=Any, shape=Tuple[int, int, int]), "i": int, "j": int, "k": int, "src": Any},
2898
+ value_func=tile_assign_value_func,
2899
+ group="Tile Primitives",
2900
+ export=False,
2901
+ hidden=True,
2902
+ )
2903
+
2904
+ add_builtin(
2905
+ "assign",
2906
+ input_types={
2907
+ "dst": tile(dtype=Any, shape=Tuple[int, int, int, int]),
2908
+ "i": int,
2909
+ "j": int,
2910
+ "k": int,
2911
+ "l": int,
2912
+ "src": Any,
2913
+ },
2914
+ value_func=tile_assign_value_func,
2915
+ group="Tile Primitives",
2916
+ export=False,
2917
+ hidden=True,
2918
+ )
2919
+
2920
+
2921
+ def tile_value_func(arg_types, arg_values):
2922
+ # return generic type (for doc builds)
2923
+ if arg_types is None:
2924
+ return tile(dtype=Any, shape=Tuple)
2925
+
2926
+ if len(arg_types) > 2:
2927
+ raise TypeError(f"tile() takes 1 positional argument and 1 optional argument but {len(arg_types)} were given")
2928
+
2929
+ preserve_type = arg_values["preserve_type"]
2930
+
2931
+ if preserve_type:
2932
+ dtype = arg_types["x"]
2933
+ shape = (warp.codegen.options["block_dim"],)
2934
+
2935
+ return tile(dtype=dtype, shape=shape)
2936
+
2937
+ else:
2938
+ if type_is_vector(arg_types["x"]):
2939
+ dtype = arg_types["x"]._wp_scalar_type_
2940
+ length = arg_types["x"]._shape_[0]
2941
+ shape = (length, warp.codegen.options["block_dim"])
2942
+ elif type_is_quaternion(arg_types["x"]):
2943
+ dtype = arg_types["x"]._wp_scalar_type_
2944
+ shape = (4, warp.codegen.options["block_dim"])
2945
+ elif type_is_matrix(arg_types["x"]):
2946
+ dtype = arg_types["x"]._wp_scalar_type_
2947
+ rows = arg_types["x"]._shape_[0]
2948
+ cols = arg_types["x"]._shape_[1]
2949
+ shape = (rows, cols, warp.codegen.options["block_dim"])
2950
+ else:
2951
+ dtype = arg_types["x"]
2952
+ shape = (warp.codegen.options["block_dim"],)
2953
+
2954
+ return tile(dtype=dtype, shape=shape)
2955
+
2956
+
2957
+ def tile_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2958
+ x = arg_values["x"]
2959
+ preserve_type = get_arg_value(arg_values["preserve_type"])
2960
+
2961
+ if preserve_type:
2962
+ dtype = x.type
2963
+ return ((x,), (dtype,))
2964
+
2965
+ else:
2966
+ if type_is_vector(x.type):
2967
+ dtype = x.type._wp_scalar_type_
2968
+ length = x.type._shape_[0]
2969
+ return ((x,), (dtype, length))
2970
+ elif type_is_quaternion(x.type):
2971
+ dtype = x.type._wp_scalar_type_
2972
+ return ((x,), (dtype, 4))
2973
+ elif type_is_matrix(x.type):
2974
+ dtype = x.type._wp_scalar_type_
2975
+ rows = x.type._shape_[0]
2976
+ cols = x.type._shape_[1]
2977
+ return ((x,), (rows, cols, dtype))
2978
+ else:
2979
+ dtype = x.type
2980
+ return ((x,), (dtype,))
2981
+
2982
+
2983
+ add_builtin(
2984
+ "tile",
2985
+ input_types={"x": Any, "preserve_type": bool},
2986
+ value_func=tile_value_func,
2987
+ dispatch_func=tile_dispatch_func,
2988
+ variadic=True,
2989
+ defaults={"preserve_type": False},
2990
+ doc="""Construct a new tile from per-thread kernel values.
2991
+
2992
+ This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
2486
2993
 
2487
2994
  * If the input value is a scalar, then the resulting tile has ``shape=(block_dim,)``
2488
2995
  * If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
2996
+ * If the input value is a vector, and ``preserve_type=True``, then the resulting tile has ``dtype=vector`` and ``shape=(block_dim,)``
2997
+ * If the input value is a matrix, then the resulting tile has ``shape=(rows, cols, block_dim)``
2998
+ * If the input value is a matrix, and ``preserve_type=True``, then the resulting tile has ``dtype=matrix`` and ``shape=(block_dim,)``
2489
2999
 
2490
3000
  :param x: A per-thread local value, e.g. scalar, vector, or matrix.
2491
- :returns: A tile with first dimension according to the value type length and a second dimension equal to ``block_dim``
3001
+ :param preserve_type: If true, the tile will have the same data type as the input value.
3002
+ :returns: If ``preserve_type=True``, a tile of type ``x.type`` of length ``block_dim``. Otherwise, an N-dimensional tile such that the first N-1 dimensions match the shape of ``x`` and the final dimension is of size ``block_dim``.
2492
3003
 
2493
3004
  This example shows how to create a linear sequence from thread variables:
2494
3005
 
@@ -2511,13 +3022,14 @@ add_builtin(
2511
3022
  """,
2512
3023
  group="Tile Primitives",
2513
3024
  export=False,
3025
+ hidden=True,
2514
3026
  )
2515
3027
 
2516
3028
 
2517
3029
  def untile_value_func(arg_types, arg_values):
2518
3030
  # return generic type (for doc builds)
2519
3031
  if arg_types is None:
2520
- return Scalar
3032
+ return Any
2521
3033
 
2522
3034
  if len(arg_types) != 1:
2523
3035
  raise TypeError(f"untile() takes exactly 1 positional argument but {len(arg_types)} were given")
@@ -2536,13 +3048,15 @@ def untile_value_func(arg_types, arg_values):
2536
3048
  return t.dtype
2537
3049
  elif len(t.shape) == 2:
2538
3050
  return warp.types.vector(t.shape[0], t.dtype)
3051
+ elif len(t.shape) == 3:
3052
+ return warp.types.matrix((t.shape[0], t.shape[1]), t.dtype)
2539
3053
  else:
2540
3054
  raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
2541
3055
 
2542
3056
 
2543
3057
  add_builtin(
2544
3058
  "untile",
2545
- input_types={"a": Tile(dtype=Any, shape=Any)},
3059
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...])},
2546
3060
  value_func=untile_value_func,
2547
3061
  variadic=True,
2548
3062
  doc="""Convert a tile back to per-thread values.
@@ -2592,7 +3106,7 @@ add_builtin(
2592
3106
  def tile_extract_value_func(arg_types, arg_values):
2593
3107
  # return generic type (for doc builds)
2594
3108
  if arg_types is None:
2595
- return Scalar
3109
+ return Any
2596
3110
 
2597
3111
  # force the input tile to shared memory
2598
3112
  arg_types["a"].storage = "shared"
@@ -2602,10 +3116,10 @@ def tile_extract_value_func(arg_types, arg_values):
2602
3116
 
2603
3117
  add_builtin(
2604
3118
  "tile_extract",
2605
- input_types={"a": Tile(dtype=Any, shape=Any), "i": int},
3119
+ input_types={"a": tile(dtype=Any, shape=Tuple[int]), "i": int},
2606
3120
  value_func=tile_extract_value_func,
2607
3121
  variadic=False,
2608
- doc="""Extract a single element from the tile and return it as a scalar type.
3122
+ doc="""Extract a single element from the tile.
2609
3123
 
2610
3124
  This function will extract an element from the tile and broadcast its value to all threads in the block.
2611
3125
 
@@ -2619,13 +3133,12 @@ add_builtin(
2619
3133
  export=False,
2620
3134
  )
2621
3135
 
2622
-
2623
3136
  add_builtin(
2624
3137
  "tile_extract",
2625
- input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int},
3138
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, int]), "i": int, "j": int},
2626
3139
  value_func=tile_extract_value_func,
2627
3140
  variadic=False,
2628
- doc="""Extract a single element from the tile and return it as a scalar type.
3141
+ doc="""Extract a single element from the tile.
2629
3142
 
2630
3143
  This function will extract an element from the tile and broadcast its value to all threads in the block.
2631
3144
 
@@ -2642,10 +3155,10 @@ add_builtin(
2642
3155
 
2643
3156
  add_builtin(
2644
3157
  "tile_extract",
2645
- input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int},
3158
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, int, int]), "i": int, "j": int, "k": int},
2646
3159
  value_func=tile_extract_value_func,
2647
3160
  variadic=False,
2648
- doc="""Extract a single element from the tile and return it as a scalar type.
3161
+ doc="""Extract a single element from the tile.
2649
3162
 
2650
3163
  This function will extract an element from the tile and broadcast its value to all threads in the block.
2651
3164
 
@@ -2663,10 +3176,10 @@ add_builtin(
2663
3176
 
2664
3177
  add_builtin(
2665
3178
  "tile_extract",
2666
- input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "l": int},
3179
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, int, int, int]), "i": int, "j": int, "k": int, "l": int},
2667
3180
  value_func=tile_extract_value_func,
2668
3181
  variadic=False,
2669
- doc="""Extract a single element from the tile and return it as a scalar type.
3182
+ doc="""Extract a single element from the tile.
2670
3183
 
2671
3184
  This function will extract an element from the tile and broadcast its value to all threads in the block.
2672
3185
 
@@ -2684,10 +3197,90 @@ add_builtin(
2684
3197
  )
2685
3198
 
2686
3199
 
3200
+ def tile_inplace_value_func(arg_types, arg_values):
3201
+ if not types_equal(arg_types["a"].dtype, arg_types["value"]):
3202
+ raise TypeError(
3203
+ f"'value' must have the same dtype as target tile for inplace ops, got {arg_types['a'].dtype} and {arg_types['value']}"
3204
+ )
3205
+
3206
+ # force the input tile to shared memory
3207
+ # as inplace addition/subtraction relies on shared memory atomics
3208
+ arg_types["a"].storage = "shared"
3209
+
3210
+ return None
3211
+
3212
+
3213
+ add_builtin(
3214
+ "tile_add_inplace",
3215
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
3216
+ value_func=tile_inplace_value_func,
3217
+ group="Tile Primitives",
3218
+ hidden=True,
3219
+ export=False,
3220
+ )
3221
+ add_builtin(
3222
+ "tile_add_inplace",
3223
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
3224
+ value_func=tile_inplace_value_func,
3225
+ group="Tile Primitives",
3226
+ hidden=True,
3227
+ export=False,
3228
+ )
3229
+ add_builtin(
3230
+ "tile_add_inplace",
3231
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
3232
+ value_func=tile_inplace_value_func,
3233
+ group="Tile Primitives",
3234
+ hidden=True,
3235
+ export=False,
3236
+ )
3237
+ add_builtin(
3238
+ "tile_add_inplace",
3239
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
3240
+ value_func=tile_inplace_value_func,
3241
+ group="Tile Primitives",
3242
+ hidden=True,
3243
+ export=False,
3244
+ )
3245
+
3246
+ add_builtin(
3247
+ "tile_sub_inplace",
3248
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
3249
+ value_func=tile_inplace_value_func,
3250
+ group="Tile Primitives",
3251
+ hidden=True,
3252
+ export=False,
3253
+ )
3254
+ add_builtin(
3255
+ "tile_sub_inplace",
3256
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
3257
+ value_func=tile_inplace_value_func,
3258
+ group="Tile Primitives",
3259
+ hidden=True,
3260
+ export=False,
3261
+ )
3262
+ add_builtin(
3263
+ "tile_sub_inplace",
3264
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
3265
+ value_func=tile_inplace_value_func,
3266
+ group="Tile Primitives",
3267
+ hidden=True,
3268
+ export=False,
3269
+ )
3270
+ add_builtin(
3271
+ "tile_sub_inplace",
3272
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
3273
+ value_func=tile_inplace_value_func,
3274
+ group="Tile Primitives",
3275
+ hidden=True,
3276
+ export=False,
3277
+ )
3278
+
3279
+
2687
3280
  def tile_transpose_value_func(arg_types, arg_values):
2688
3281
  # return generic type (for doc builds)
2689
3282
  if arg_types is None:
2690
- return Tile
3283
+ return tile(dtype=Any, shape=Tuple[int, int])
2691
3284
 
2692
3285
  if len(arg_types) != 1:
2693
3286
  raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
@@ -2708,10 +3301,9 @@ def tile_transpose_value_func(arg_types, arg_values):
2708
3301
  # force the input tile to shared memory
2709
3302
  t.storage = "shared"
2710
3303
 
2711
- return Tile(
3304
+ return tile(
2712
3305
  dtype=t.dtype,
2713
3306
  shape=t.shape[::-1],
2714
- op="transpose",
2715
3307
  storage=t.storage,
2716
3308
  strides=t.strides[::-1],
2717
3309
  layout=layout,
@@ -2721,7 +3313,7 @@ def tile_transpose_value_func(arg_types, arg_values):
2721
3313
 
2722
3314
  add_builtin(
2723
3315
  "tile_transpose",
2724
- input_types={"a": Tile(dtype=Any, shape=Any)},
3316
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
2725
3317
  value_func=tile_transpose_value_func,
2726
3318
  variadic=True,
2727
3319
  doc="""Transpose a tile.
@@ -2739,12 +3331,16 @@ add_builtin(
2739
3331
  def tile_broadcast_value_func(arg_types, arg_values):
2740
3332
  # return generic type (for doc builds)
2741
3333
  if arg_types is None:
2742
- return Tile
3334
+ return tile(dtype=Any, shape=Tuple[int, ...])
2743
3335
 
2744
3336
  t = arg_types["a"]
2745
3337
 
2746
3338
  # target shape and strides
2747
- target_shape = tile_unpack_shape(arg_values)
3339
+ target_shape = extract_tuple(arg_values["shape"], as_constant=True)
3340
+
3341
+ if None in target_shape:
3342
+ raise ValueError("Tile functions require shape to be a compile time constant.")
3343
+
2748
3344
  target_strides = [0] * len(target_shape)
2749
3345
 
2750
3346
  offset = len(target_shape) - len(t.shape)
@@ -2769,10 +3365,7 @@ def tile_broadcast_value_func(arg_types, arg_values):
2769
3365
  # force the input tile to shared memory
2770
3366
  t.storage = "shared"
2771
3367
 
2772
- tile_type = Tile(
2773
- dtype=t.dtype, shape=target_shape, op="broadcast", storage=t.storage, strides=target_strides, owner=False
2774
- )
2775
- return tile_type
3368
+ return tile(dtype=t.dtype, shape=target_shape, storage=t.storage, strides=target_strides, owner=False)
2776
3369
 
2777
3370
 
2778
3371
  def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
@@ -2787,7 +3380,7 @@ def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any
2787
3380
 
2788
3381
  add_builtin(
2789
3382
  "tile_broadcast",
2790
- input_types={"a": Tile(dtype=Any, shape=Any), "shape": Tuple[int, ...]},
3383
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "shape": Tuple[int, ...]},
2791
3384
  value_func=tile_broadcast_value_func,
2792
3385
  dispatch_func=tile_broadcast_dispatch_func,
2793
3386
  variadic=False,
@@ -2807,7 +3400,7 @@ add_builtin(
2807
3400
  def tile_sum_value_func(arg_types, arg_values):
2808
3401
  # return generic type (for doc builds)
2809
3402
  if arg_types is None:
2810
- return Tile(dtype=Any, shape=(1,))
3403
+ return tile(dtype=Scalar, shape=(1,))
2811
3404
 
2812
3405
  if len(arg_types) != 1:
2813
3406
  raise TypeError(f"tile_sum() takes exactly 1 positional argument but {len(arg_types)} were given")
@@ -2817,12 +3410,12 @@ def tile_sum_value_func(arg_types, arg_values):
2817
3410
  if not is_tile(a):
2818
3411
  raise TypeError(f"tile_sum() argument must be a tile, got {a!r}")
2819
3412
 
2820
- return Tile(dtype=a.dtype, shape=(1,), op="sum")
3413
+ return tile(dtype=a.dtype, shape=(1,))
2821
3414
 
2822
3415
 
2823
3416
  add_builtin(
2824
3417
  "tile_sum",
2825
- input_types={"a": Tile},
3418
+ input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
2826
3419
  value_func=tile_sum_value_func,
2827
3420
  variadic=True,
2828
3421
  doc="""Cooperatively compute the sum of the tile elements using all threads in the block.
@@ -2856,10 +3449,89 @@ add_builtin(
2856
3449
  )
2857
3450
 
2858
3451
 
3452
+ def tile_sort_value_func(arg_types, arg_values):
3453
+ # return generic type (for doc builds)
3454
+ if arg_types is None:
3455
+ return None
3456
+
3457
+ if len(arg_types) != 2:
3458
+ raise TypeError(
3459
+ f"tile_sort() takes exactly 2 positional arguments (keys and values) but {len(arg_types)} were given"
3460
+ )
3461
+
3462
+ a = arg_types["keys"]
3463
+ b = arg_types["values"]
3464
+
3465
+ if not is_tile(a):
3466
+ raise TypeError(f"First tile_sort() argument must be a tile, got {a!r}")
3467
+
3468
+ if not is_tile(b):
3469
+ raise TypeError(f"Second tile_sort() argument must be a tile, got {b!r}")
3470
+
3471
+ if not (a.dtype is warp.float32 or a.dtype is warp.int32 or a.dtype is warp.uint32):
3472
+ raise TypeError(f"First tile_sort() argument must be a tile of type float or int, got {a.dtype}")
3473
+
3474
+ # set the storage type to the inputs to shared
3475
+ a.storage = "shared"
3476
+ b.storage = "shared"
3477
+
3478
+ if len(a.shape) != len(b.shape):
3479
+ raise ValueError(
3480
+ f"tile_sort() shapes must have the same number of dimensions, got {len(a.shape)} and {len(b.shape)}"
3481
+ )
3482
+
3483
+ for i in range(len(a.shape)):
3484
+ if a.shape[i] != b.shape[i]:
3485
+ raise ValueError(f"tile_sort() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")
3486
+
3487
+ return None
3488
+
3489
+
3490
+ add_builtin(
3491
+ "tile_sort",
3492
+ input_types={"keys": tile(dtype=Any, shape=Tuple[int]), "values": tile(dtype=Any, shape=Tuple[int])},
3493
+ value_func=tile_sort_value_func,
3494
+ variadic=True,
3495
+ doc="""Cooperatively sort the elements of two tiles in ascending order based on the keys, using all threads in the block.
3496
+
3497
+ :param keys: Keys to sort by. Supported key types: :class:`float32`, :class:`int32`, :class:`uint32`. Must be in shared memory.
3498
+ :param values: Values to sort along with keys. No type restrictions. Must be in shared memory.
3499
+ :returns: No return value. Sorts both tiles in-place.
3500
+
3501
+ Example:
3502
+
3503
+ .. code-block:: python
3504
+
3505
+ @wp.kernel
3506
+ def compute():
3507
+
3508
+ keys = wp.tile_arange(32, 0, -1, dtype=int, storage="shared")
3509
+ values = wp.tile_arange(0, 32, 1, dtype=int, storage="shared")
3510
+ wp.tile_sort(keys, values)
3511
+
3512
+ print(keys)
3513
+ print(values)
3514
+
3515
+
3516
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
3517
+
3518
+ Prints:
3519
+
3520
+ .. code-block:: text
3521
+
3522
+ [1, 2, ..., 32] = tile(shape=(32), storage=shared)
3523
+ [31, 30, 29, ..., 0] = tile(shape=(32), storage=shared)
3524
+
3525
+ """,
3526
+ group="Tile Primitives",
3527
+ export=False,
3528
+ )
3529
+
3530
+
2859
3531
  def tile_min_value_func(arg_types, arg_values):
2860
3532
  # return generic type (for doc builds)
2861
3533
  if arg_types is None:
2862
- return Tile(dtype=Any, shape=(1,))
3534
+ return tile(dtype=Scalar, shape=(1,))
2863
3535
 
2864
3536
  if len(arg_types) != 1:
2865
3537
  raise TypeError(f"tile_min() takes exactly 1 positional argument but {len(arg_types)} were given")
@@ -2869,12 +3541,12 @@ def tile_min_value_func(arg_types, arg_values):
2869
3541
  if not is_tile(a):
2870
3542
  raise TypeError(f"tile_min() argument must be a tile, got {a!r}")
2871
3543
 
2872
- return Tile(dtype=a.dtype, shape=(1,), op="min")
3544
+ return tile(dtype=a.dtype, shape=(1,))
2873
3545
 
2874
3546
 
2875
3547
  add_builtin(
2876
3548
  "tile_min",
2877
- input_types={"a": Tile},
3549
+ input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
2878
3550
  value_func=tile_min_value_func,
2879
3551
  variadic=True,
2880
3552
  doc="""Cooperatively compute the minimum of the tile elements using all threads in the block.
@@ -2909,10 +3581,63 @@ add_builtin(
2909
3581
  )
2910
3582
 
2911
3583
 
3584
+ def tile_argmin_value_func(arg_types, arg_values):
3585
+ # return generic type (for doc builds)
3586
+ if arg_types is None:
3587
+ return tile(dtype=Int, shape=(1,))
3588
+
3589
+ if len(arg_types) != 1:
3590
+ raise TypeError(f"tile_argmin() takes exactly 1 positional argument but {len(arg_types)} were given")
3591
+
3592
+ a = arg_types["a"]
3593
+
3594
+ if not is_tile(a):
3595
+ raise TypeError(f"tile_argmin() argument must be a tile, got {a!r}")
3596
+
3597
+ return tile(dtype=warp.int32, shape=(1,))
3598
+
3599
+
3600
+ add_builtin(
3601
+ "tile_argmin",
3602
+ input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
3603
+ value_func=tile_argmin_value_func,
3604
+ variadic=True,
3605
+ doc="""Cooperatively compute the index of the minimum element in the tile using all threads in the block.
3606
+
3607
+ :param a: The tile to compute the argmin from
3608
+ :returns: A single-element tile holding the index of the minimum value
3609
+
3610
+ Example:
3611
+
3612
+ .. code-block:: python
3613
+
3614
+ @wp.kernel
3615
+ def compute():
3616
+
3617
+ t = wp.tile_arange(64, 128)
3618
+ s = wp.tile_argmin(t)
3619
+
3620
+ print(s)
3621
+
3622
+
3623
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
3624
+
3625
+ Prints:
3626
+
3627
+ .. code-block:: text
3628
+
3629
+ [0] = tile(shape=(1), storage=register)
3630
+
3631
+ """,
3632
+ group="Tile Primitives",
3633
+ export=False,
3634
+ )
3635
+
3636
+
2912
3637
  def tile_max_value_func(arg_types, arg_values):
2913
3638
  # return generic type (for doc builds)
2914
3639
  if arg_types is None:
2915
- return Tile(dtype=Any, shape=(1,))
3640
+ return tile(dtype=Scalar, shape=(1,))
2916
3641
 
2917
3642
  if len(arg_types) != 1:
2918
3643
  raise TypeError(f"tile_max() takes exactly 1 positional argument but {len(arg_types)} were given")
@@ -2922,12 +3647,12 @@ def tile_max_value_func(arg_types, arg_values):
2922
3647
  if not is_tile(a):
2923
3648
  raise TypeError(f"tile_max() argument must be a tile, got {a!r}")
2924
3649
 
2925
- return Tile(dtype=a.dtype, shape=(1,), op="min")
3650
+ return tile(dtype=a.dtype, shape=(1,))
2926
3651
 
2927
3652
 
2928
3653
  add_builtin(
2929
3654
  "tile_max",
2930
- input_types={"a": Tile(dtype=Any, shape=Any)},
3655
+ input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
2931
3656
  value_func=tile_max_value_func,
2932
3657
  variadic=False,
2933
3658
  doc="""Cooperatively compute the maximum of the tile elements using all threads in the block.
@@ -2961,17 +3686,69 @@ add_builtin(
2961
3686
  )
2962
3687
 
2963
3688
 
3689
+ def tile_argmax_value_func(arg_types, arg_values):
3690
+ # return generic type (for doc builds)
3691
+ if arg_types is None:
3692
+ return tile(dtype=Int, shape=(1,))
3693
+
3694
+ if len(arg_types) != 1:
3695
+ raise TypeError(f"tile_argmax() takes exactly 1 positional argument but {len(arg_types)} were given")
3696
+
3697
+ a = arg_types["a"]
3698
+
3699
+ if not is_tile(a):
3700
+ raise TypeError(f"tile_argmax() argument must be a tile, got {a!r}")
3701
+
3702
+ return tile(dtype=warp.int32, shape=(1,))
3703
+
3704
+
3705
+ add_builtin(
3706
+ "tile_argmax",
3707
+ input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
3708
+ value_func=tile_argmax_value_func,
3709
+ variadic=False,
3710
+ doc="""Cooperatively compute the index of the maximum element in the tile using all threads in the block.
3711
+
3712
+ :param a: The tile to compute the argmax from
3713
+ :returns: A single-element tile holding the index of the maximum value
3714
+
3715
+ Example:
3716
+
3717
+ .. code-block:: python
3718
+
3719
+ @wp.kernel
3720
+ def compute():
3721
+
3722
+ t = wp.tile_arange(64, 128)
3723
+ s = wp.tile_argmax(t)
3724
+
3725
+ print(s)
3726
+
3727
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
3728
+
3729
+ Prints:
3730
+
3731
+ .. code-block:: text
3732
+
3733
+ [63] = tile(shape=(1), storage=register)
3734
+
3735
+ """,
3736
+ group="Tile Primitives",
3737
+ export=False,
3738
+ )
3739
+
3740
+
2964
3741
  # does type propagation for load()
2965
3742
  def tile_reduce_value_func(arg_types, arg_values):
2966
3743
  if arg_types is None:
2967
- return Tile(dtype=Any, shape=(1,))
3744
+ return tile(dtype=Scalar, shape=(1,))
2968
3745
 
2969
3746
  a = arg_types["a"]
2970
3747
 
2971
3748
  if not is_tile(a):
2972
3749
  raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")
2973
3750
 
2974
- return Tile(dtype=a.dtype, shape=(1,), op="reduce")
3751
+ return tile(dtype=a.dtype, shape=(1,))
2975
3752
 
2976
3753
 
2977
3754
  def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
@@ -2982,7 +3759,7 @@ def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any,
2982
3759
 
2983
3760
  add_builtin(
2984
3761
  "tile_reduce",
2985
- input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any)},
3762
+ input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...])},
2986
3763
  value_func=tile_reduce_value_func,
2987
3764
  native_func="tile_reduce",
2988
3765
  doc="""Apply a custom reduction operator across the tile.
@@ -3005,37 +3782,164 @@ add_builtin(
3005
3782
 
3006
3783
  print(s)
3007
3784
 
3008
- wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)
3785
+ wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)
3786
+
3787
+ Prints:
3788
+
3789
+ .. code-block:: text
3790
+
3791
+ [362880] = tile(shape=(1), storage=register)
3792
+ """,
3793
+ group="Tile Primitives",
3794
+ export=False,
3795
+ )
3796
+
3797
+
3798
+ def tile_scan_inclusive_value_func(arg_types, arg_values):
3799
+ # Return type is the same as input type
3800
+ if arg_types is None:
3801
+ return tile(dtype=Scalar, shape=Tuple[int, ...])
3802
+
3803
+ if len(arg_types) != 1:
3804
+ raise TypeError(f"tile_scan_inclusive() takes exactly 1 positional argument but {len(arg_types)} were given")
3805
+
3806
+ a = arg_types["a"]
3807
+
3808
+ if not is_tile(a):
3809
+ raise TypeError(f"tile_scan_inclusive() argument must be a tile, got {a!r}")
3810
+
3811
+ # Only allow float32, int32, or uint32 for scan (like tile_sort)
3812
+ if not (a.dtype is warp.float32 or a.dtype is warp.int32 or a.dtype is warp.uint32):
3813
+ raise TypeError(
3814
+ f"tile_scan_inclusive() argument must be a tile of type float32, int32, or uint32, got {a.dtype}"
3815
+ )
3816
+
3817
+ return tile(dtype=a.dtype, shape=a.shape)
3818
+
3819
+
3820
+ def tile_scan_inclusive_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
3821
+ func_args = (args["a"],)
3822
+ template_args = ()
3823
+ return (func_args, template_args)
3824
+
3825
+
3826
+ add_builtin(
3827
+ "tile_scan_inclusive",
3828
+ input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
3829
+ value_func=tile_scan_inclusive_value_func,
3830
+ native_func="tile_scan_inclusive",
3831
+ doc="""Inclusive scan (prefix sum) across the tile.
3832
+
3833
+ This function cooperatively performs an inclusive scan (cumulative sum) across the tile.
3834
+
3835
+ :param a: The input tile. Must be a tile of type float32, int32, or uint32.
3836
+ :returns: A new tile containing the inclusive scan result.
3837
+
3838
+ Example:
3839
+
3840
+ .. code-block:: python
3841
+
3842
+ @wp.kernel
3843
+ def scan_example():
3844
+ t = wp.tile_arange(1, 5, dtype=int)
3845
+ s = wp.tile_scan_inclusive(t)
3846
+ print(s)
3847
+
3848
+ wp.launch_tiled(scan_example, dim=[1], inputs=[], block_dim=16)
3849
+
3850
+ Prints:
3851
+
3852
+ .. code-block:: text
3853
+
3854
+ [1, 3, 6, 10] = tile(shape=(4), storage=register)
3855
+ """,
3856
+ group="Tile Primitives",
3857
+ export=False,
3858
+ )
3859
+
3860
+
3861
+ def tile_scan_exclusive_value_func(arg_types, arg_values):
3862
+ # return generic type (for doc builds)
3863
+ if arg_types is None:
3864
+ return tile(dtype=Scalar, shape=Tuple[int, ...])
3865
+
3866
+ if len(arg_types) != 1:
3867
+ raise TypeError(f"tile_scan_exclusive() takes exactly 1 positional argument but {len(arg_types)} were given")
3868
+
3869
+ a = arg_types["a"]
3870
+
3871
+ if not is_tile(a):
3872
+ raise TypeError(f"tile_scan_exclusive() argument must be a tile, got {a!r}")
3873
+
3874
+ # Only allow float32, int32, or uint32 for scan (like tile_sort)
3875
+ if not (a.dtype is warp.float32 or a.dtype is warp.int32 or a.dtype is warp.uint32):
3876
+ raise TypeError(
3877
+ f"tile_scan_exclusive() argument must be a tile of type float32, int32, or uint32, got {a.dtype}"
3878
+ )
3879
+
3880
+ return tile(dtype=a.dtype, shape=a.shape)
3881
+
3882
+
3883
+ def tile_scan_exclusive_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
3884
+ func_args = (args["a"],)
3885
+ template_args = ()
3886
+ return (func_args, template_args)
3887
+
3888
+
3889
+ add_builtin(
3890
+ "tile_scan_exclusive",
3891
+ input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
3892
+ value_func=tile_scan_exclusive_value_func,
3893
+ native_func="tile_scan_exclusive",
3894
+ doc="""Exclusive scan (prefix sum) across the tile.
3895
+
3896
+ This function cooperatively performs an exclusive scan (cumulative sum) across the tile.
3897
+
3898
+ :param a: The input tile. Must be a tile of type float32, int32, or uint32.
3899
+ :returns: A new tile containing the exclusive scan result.
3900
+
3901
+ Example:
3902
+
3903
+ .. code-block:: python
3904
+
3905
+ @wp.kernel
3906
+ def scan_example():
3907
+ t = wp.tile_arange(1, 5, dtype=int)
3908
+ s = wp.tile_scan_exclusive(t)
3909
+ print(s)
3910
+
3911
+ wp.launch_tiled(scan_example, dim=[1], inputs=[], block_dim=16)
3009
3912
 
3010
3913
  Prints:
3011
3914
 
3012
3915
  .. code-block:: text
3013
3916
 
3014
- [362880] = tile(shape=(1), storage=register)
3917
+ [0, 1, 3, 6] = tile(shape=(4), storage=register)
3015
3918
  """,
3016
3919
  group="Tile Primitives",
3017
3920
  export=False,
3018
3921
  )
3019
3922
 
3923
+
3020
3924
  # maps
3021
3925
 
3022
3926
 
3023
3927
  # does type propagation for load()
3024
3928
  def tile_unary_map_value_func(arg_types, arg_values):
3025
3929
  if arg_types is None:
3026
- return Tile(dtype=Any, shape=Any)
3930
+ return tile(dtype=Scalar, shape=Tuple[int, ...])
3027
3931
 
3028
3932
  a = arg_types["a"]
3029
3933
 
3030
3934
  if not is_tile(a):
3031
3935
  raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
3032
3936
 
3033
- return TileUnaryMap(a)
3937
+ return tile(dtype=a.dtype, shape=a.shape)
3034
3938
 
3035
3939
 
3036
3940
  add_builtin(
3037
3941
  "tile_map",
3038
- input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any)},
3942
+ input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...])},
3039
3943
  value_func=tile_unary_map_value_func,
3040
3944
  # dispatch_func=tile_map_dispatch_func,
3041
3945
  # variadic=True,
@@ -3075,7 +3979,7 @@ add_builtin(
3075
3979
 
3076
3980
  def tile_binary_map_value_func(arg_types, arg_values):
3077
3981
  if arg_types is None:
3078
- return Tile(dtype=Any, shape=Any)
3982
+ return tile(dtype=Scalar, shape=Tuple[int, ...])
3079
3983
 
3080
3984
  a = arg_types["a"]
3081
3985
  b = arg_types["b"]
@@ -3100,12 +4004,16 @@ def tile_binary_map_value_func(arg_types, arg_values):
3100
4004
  if a.shape[i] != b.shape[i]:
3101
4005
  raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")
3102
4006
 
3103
- return TileBinaryMap(a, b)
4007
+ return tile(dtype=a.dtype, shape=a.shape)
3104
4008
 
3105
4009
 
3106
4010
  add_builtin(
3107
4011
  "tile_map",
3108
- input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
4012
+ input_types={
4013
+ "op": Callable,
4014
+ "a": tile(dtype=Scalar, shape=Tuple[int, ...]),
4015
+ "b": tile(dtype=Scalar, shape=Tuple[int, ...]),
4016
+ },
3109
4017
  value_func=tile_binary_map_value_func,
3110
4018
  # dispatch_func=tile_map_dispatch_func,
3111
4019
  # variadic=True,
@@ -3255,57 +4163,13 @@ add_builtin(
3255
4163
  )
3256
4164
 
3257
4165
 
3258
- def mlp_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
3259
- warp.utils.warn(
3260
- "wp.mlp() is deprecated and will be removed in a future\nversion. Use tile primitives instead.",
3261
- category=DeprecationWarning,
3262
- )
3263
-
3264
- func_args = tuple(args.values())
3265
- template_args = ()
3266
-
3267
- return (func_args, template_args)
3268
-
3269
-
3270
- add_builtin(
3271
- "mlp",
3272
- input_types={
3273
- "weights": array(dtype=float, ndim=2),
3274
- "bias": array(dtype=float, ndim=1),
3275
- "activation": Callable,
3276
- "index": int,
3277
- "x": array(dtype=float, ndim=2),
3278
- "out": array(dtype=float, ndim=2),
3279
- },
3280
- value_type=None,
3281
- dispatch_func=mlp_dispatch_func,
3282
- skip_replay=True,
3283
- doc="""Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
3284
-
3285
- .. deprecated:: 1.6
3286
- Use :doc:`tile primitives </modules/tiles>` instead.
3287
-
3288
- :param weights: A layer's network weights with dimensions ``(m, n)``.
3289
- :param bias: An array with dimensions ``(n)``.
3290
- :param activation: A ``wp.func`` function that takes a single scalar float as input and returns a scalar float as output
3291
- :param index: The batch item to process, typically each thread will process one item in the batch, in which case
3292
- index should be ``wp.tid()``
3293
- :param x: The feature matrix with dimensions ``(n, b)``
3294
- :param out: The network output with dimensions ``(m, b)``
3295
-
3296
- :note: Feature and output matrices are transposed compared to some other frameworks such as PyTorch.
3297
- All matrices are assumed to be stored in flattened row-major memory layout (NumPy default).""",
3298
- group="Utility",
3299
- )
3300
-
3301
-
3302
4166
  # ---------------------------------
3303
4167
  # Geometry
3304
4168
 
3305
4169
  add_builtin(
3306
4170
  "bvh_query_aabb",
3307
4171
  input_types={"id": uint64, "low": vec3, "high": vec3},
3308
- value_func=lambda arg_types, _: BvhQuery if arg_types is None else bvh_query_t,
4172
+ value_type=BvhQuery,
3309
4173
  group="Geometry",
3310
4174
  doc="""Construct an axis-aligned bounding box query against a BVH object.
3311
4175
 
@@ -3320,7 +4184,7 @@ add_builtin(
3320
4184
  add_builtin(
3321
4185
  "bvh_query_ray",
3322
4186
  input_types={"id": uint64, "start": vec3, "dir": vec3},
3323
- value_func=lambda arg_types, _: BvhQuery if arg_types is None else bvh_query_t,
4187
+ value_type=BvhQuery,
3324
4188
  group="Geometry",
3325
4189
  doc="""Construct a ray query against a BVH object.
3326
4190
 
@@ -3380,7 +4244,7 @@ add_builtin(
3380
4244
  "point": vec3,
3381
4245
  "max_dist": float,
3382
4246
  },
3383
- value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
4247
+ value_type=MeshQueryPoint,
3384
4248
  group="Geometry",
3385
4249
  doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
3386
4250
 
@@ -3428,7 +4292,7 @@ add_builtin(
3428
4292
  "point": vec3,
3429
4293
  "max_dist": float,
3430
4294
  },
3431
- value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
4295
+ value_type=MeshQueryPoint,
3432
4296
  group="Geometry",
3433
4297
  doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
3434
4298
 
@@ -3474,7 +4338,7 @@ add_builtin(
3474
4338
  "point": vec3,
3475
4339
  "min_dist": float,
3476
4340
  },
3477
- value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
4341
+ value_type=MeshQueryPoint,
3478
4342
  group="Geometry",
3479
4343
  doc="""Computes the furthest point on the mesh with identifier `id` to the given point in space.
3480
4344
 
@@ -3531,7 +4395,7 @@ add_builtin(
3531
4395
  "epsilon": float,
3532
4396
  },
3533
4397
  defaults={"epsilon": 1.0e-3},
3534
- value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
4398
+ value_type=MeshQueryPoint,
3535
4399
  group="Geometry",
3536
4400
  doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
3537
4401
 
@@ -3596,7 +4460,7 @@ add_builtin(
3596
4460
  "threshold": float,
3597
4461
  },
3598
4462
  defaults={"accuracy": 2.0, "threshold": 0.5},
3599
- value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
4463
+ value_type=MeshQueryPoint,
3600
4464
  group="Geometry",
3601
4465
  doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given point in space.
3602
4466
 
@@ -3655,7 +4519,7 @@ add_builtin(
3655
4519
  "dir": vec3,
3656
4520
  "max_t": float,
3657
4521
  },
3658
- value_func=lambda arg_types, _: MeshQueryRay if arg_types is None else mesh_query_ray_t,
4522
+ value_type=MeshQueryRay,
3659
4523
  group="Geometry",
3660
4524
  doc="""Computes the closest ray hit on the :class:`Mesh` with identifier ``id``.
3661
4525
 
@@ -3670,7 +4534,7 @@ add_builtin(
3670
4534
  add_builtin(
3671
4535
  "mesh_query_aabb",
3672
4536
  input_types={"id": uint64, "low": vec3, "high": vec3},
3673
- value_func=lambda arg_types, _: MeshQueryAABB if arg_types is None else mesh_query_aabb_t,
4537
+ value_type=MeshQueryAABB,
3674
4538
  group="Geometry",
3675
4539
  doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.
3676
4540
 
@@ -3714,7 +4578,7 @@ add_builtin(
3714
4578
  add_builtin(
3715
4579
  "hash_grid_query",
3716
4580
  input_types={"id": uint64, "point": vec3, "max_dist": float},
3717
- value_func=lambda arg_types, _: HashGridQuery if arg_types is None else hash_grid_query_t,
4581
+ value_type=HashGridQuery,
3718
4582
  group="Geometry",
3719
4583
  doc="""Construct a point query against a :class:`HashGrid`.
3720
4584
 
@@ -3843,10 +4707,10 @@ add_builtin(
3843
4707
 
3844
4708
  add_builtin("iter_next", input_types={"range": range_t}, value_type=int, group="Utility", export=False, hidden=True)
3845
4709
  add_builtin(
3846
- "iter_next", input_types={"query": hash_grid_query_t}, value_type=int, group="Utility", export=False, hidden=True
4710
+ "iter_next", input_types={"query": HashGridQuery}, value_type=int, group="Utility", export=False, hidden=True
3847
4711
  )
3848
4712
  add_builtin(
3849
- "iter_next", input_types={"query": mesh_query_aabb_t}, value_type=int, group="Utility", export=False, hidden=True
4713
+ "iter_next", input_types={"query": MeshQueryAABB}, value_type=int, group="Utility", export=False, hidden=True
3850
4714
  )
3851
4715
 
3852
4716
  add_builtin(
@@ -3889,7 +4753,7 @@ def _check_volume_type_is_supported(dtype):
3889
4753
 
3890
4754
  def check_volume_value_grad_compatibility(dtype, grad_dtype):
3891
4755
  if type_is_vector(dtype):
3892
- expected = matrix(shape=(type_length(dtype), 3), dtype=type_scalar_type(dtype))
4756
+ expected = matrix(shape=(type_size(dtype), 3), dtype=type_scalar_type(dtype))
3893
4757
  else:
3894
4758
  expected = vector(length=3, dtype=dtype)
3895
4759
 
@@ -4062,6 +4926,7 @@ add_builtin(
4062
4926
  input_types={"id": uint64, "i": int, "j": int, "k": int, "value": float},
4063
4927
  group="Volumes",
4064
4928
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
4929
+ export=False,
4065
4930
  )
4066
4931
 
4067
4932
  add_builtin(
@@ -4089,6 +4954,7 @@ add_builtin(
4089
4954
  input_types={"id": uint64, "i": int, "j": int, "k": int, "value": vec3},
4090
4955
  group="Volumes",
4091
4956
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
4957
+ export=False,
4092
4958
  )
4093
4959
 
4094
4960
  add_builtin(
@@ -4114,6 +4980,7 @@ add_builtin(
4114
4980
  input_types={"id": uint64, "i": int, "j": int, "k": int, "value": int},
4115
4981
  group="Volumes",
4116
4982
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
4983
+ export=False,
4117
4984
  )
4118
4985
 
4119
4986
 
@@ -4527,6 +5394,16 @@ add_builtin(
4527
5394
  native_func="builtin_tid1d",
4528
5395
  )
4529
5396
 
5397
+ add_builtin(
5398
+ "block_dim",
5399
+ input_types={},
5400
+ value_type=int,
5401
+ group="Utility",
5402
+ doc="Returns the number of threads in the current block.",
5403
+ namespace="",
5404
+ native_func="builtin_block_dim",
5405
+ )
5406
+
4530
5407
  add_builtin(
4531
5408
  "tid",
4532
5409
  input_types={},
@@ -4667,7 +5544,7 @@ def array_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any
4667
5544
  return array(dtype=Scalar)
4668
5545
 
4669
5546
  dtype = arg_values["dtype"]
4670
- shape = arg_values["shape"]
5547
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
4671
5548
  return array(dtype=dtype, ndim=len(shape))
4672
5549
 
4673
5550
 
@@ -4677,8 +5554,9 @@ def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args:
4677
5554
  # to the underlying C++ function's runtime and template params.
4678
5555
 
4679
5556
  dtype = return_type.dtype
5557
+ shape = extract_tuple(args["shape"], as_constant=True)
4680
5558
 
4681
- func_args = (args["ptr"], *args["shape"])
5559
+ func_args = (args["ptr"], *shape)
4682
5560
  template_args = (dtype,)
4683
5561
  return (func_args, template_args)
4684
5562
 
@@ -4958,6 +5836,12 @@ def create_atomic_op_value_func(op: str):
4958
5836
  f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
4959
5837
  f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
4960
5838
  )
5839
+ elif op in ("cas", "exch"):
5840
+ if not any(types_equal(scalar_type, x, match_generic=True) for x in SUPPORTED_ATOMIC_TYPES):
5841
+ raise RuntimeError(
5842
+ f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
5843
+ f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
5844
+ )
4961
5845
  else:
4962
5846
  raise NotImplementedError
4963
5847
 
@@ -5187,6 +6071,120 @@ for array_type in array_types:
5187
6071
  skip_replay=True,
5188
6072
  )
5189
6073
 
6074
+ add_builtin(
6075
+ "atomic_cas",
6076
+ hidden=hidden,
6077
+ input_types={"arr": array_type(dtype=Any), "i": Int, "compare": Any, "value": Any},
6078
+ constraint=atomic_op_constraint,
6079
+ value_func=create_atomic_op_value_func("cas"),
6080
+ dispatch_func=atomic_op_dispatch_func,
6081
+ doc="""Atomically compare and swap ``value`` with ``arr[i]`` if ``arr[i]`` equals ``compare``, and return the old value.
6082
+
6083
+ The operation is only atomic on a per-component basis for vectors and matrices.""",
6084
+ group="Utility",
6085
+ skip_replay=True,
6086
+ )
6087
+ add_builtin(
6088
+ "atomic_cas",
6089
+ hidden=hidden,
6090
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "compare": Any, "value": Any},
6091
+ constraint=atomic_op_constraint,
6092
+ value_func=create_atomic_op_value_func("cas"),
6093
+ dispatch_func=atomic_op_dispatch_func,
6094
+ doc="""Atomically compare and swap ``value`` with ``arr[i,j]`` if ``arr[i,j]`` equals ``compare``, and return the old value.
6095
+
6096
+ The operation is only atomic on a per-component basis for vectors and matrices.""",
6097
+ group="Utility",
6098
+ skip_replay=True,
6099
+ )
6100
+ add_builtin(
6101
+ "atomic_cas",
6102
+ hidden=hidden,
6103
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "compare": Any, "value": Any},
6104
+ constraint=atomic_op_constraint,
6105
+ value_func=create_atomic_op_value_func("cas"),
6106
+ dispatch_func=atomic_op_dispatch_func,
6107
+ doc="""Atomically compare and swap ``value`` with ``arr[i,j,k]`` if ``arr[i,j,k]`` equals ``compare``, and return the old value.
6108
+
6109
+ The operation is only atomic on a per-component basis for vectors and matrices.""",
6110
+ group="Utility",
6111
+ skip_replay=True,
6112
+ )
6113
+ add_builtin(
6114
+ "atomic_cas",
6115
+ hidden=hidden,
6116
+ input_types={
6117
+ "arr": array_type(dtype=Any),
6118
+ "i": Int,
6119
+ "j": Int,
6120
+ "k": Int,
6121
+ "l": Int,
6122
+ "compare": Any,
6123
+ "value": Any,
6124
+ },
6125
+ constraint=atomic_op_constraint,
6126
+ value_func=create_atomic_op_value_func("cas"),
6127
+ dispatch_func=atomic_op_dispatch_func,
6128
+ doc="""Atomically compare and swap ``value`` with ``arr[i,j,k,l]`` if ``arr[i,j,k,l]`` equals ``compare``, and return the old value.
6129
+
6130
+ The operation is only atomic on a per-component basis for vectors and matrices.""",
6131
+ group="Utility",
6132
+ skip_replay=True,
6133
+ )
6134
+
6135
+ add_builtin(
6136
+ "atomic_exch",
6137
+ hidden=hidden,
6138
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
6139
+ constraint=atomic_op_constraint,
6140
+ value_func=create_atomic_op_value_func("exch"),
6141
+ dispatch_func=atomic_op_dispatch_func,
6142
+ doc="""Atomically exchange ``value`` with ``arr[i]`` and return the old value.
6143
+
6144
+ The operation is only atomic on a per-component basis for vectors and matrices.""",
6145
+ group="Utility",
6146
+ skip_replay=True,
6147
+ )
6148
+ add_builtin(
6149
+ "atomic_exch",
6150
+ hidden=hidden,
6151
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
6152
+ constraint=atomic_op_constraint,
6153
+ value_func=create_atomic_op_value_func("exch"),
6154
+ dispatch_func=atomic_op_dispatch_func,
6155
+ doc="""Atomically exchange ``value`` with ``arr[i,j]`` and return the old value.
6156
+
6157
+ The operation is only atomic on a per-component basis for vectors and matrices.""",
6158
+ group="Utility",
6159
+ skip_replay=True,
6160
+ )
6161
+ add_builtin(
6162
+ "atomic_exch",
6163
+ hidden=hidden,
6164
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
6165
+ constraint=atomic_op_constraint,
6166
+ value_func=create_atomic_op_value_func("exch"),
6167
+ dispatch_func=atomic_op_dispatch_func,
6168
+ doc="""Atomically exchange ``value`` with ``arr[i,j,k]`` and return the old value.
6169
+
6170
+ The operation is only atomic on a per-component basis for vectors and matrices.""",
6171
+ group="Utility",
6172
+ skip_replay=True,
6173
+ )
6174
+ add_builtin(
6175
+ "atomic_exch",
6176
+ hidden=hidden,
6177
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
6178
+ constraint=atomic_op_constraint,
6179
+ value_func=create_atomic_op_value_func("exch"),
6180
+ dispatch_func=atomic_op_dispatch_func,
6181
+ doc="""Atomically exchange ``value`` with ``arr[i,j,k,l]`` and return the old value.
6182
+
6183
+ The operation is only atomic on a per-component basis for vectors and matrices.""",
6184
+ group="Utility",
6185
+ skip_replay=True,
6186
+ )
6187
+
5190
6188
 
5191
6189
  # used to index into builtin types, i.e.: y = vec3[1]
5192
6190
  def extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
@@ -5269,6 +6267,16 @@ add_builtin(
5269
6267
  group="Utility",
5270
6268
  skip_replay=True,
5271
6269
  )
6270
# implements &transformation[index]
# Hidden ``index`` overload that lets a transformation's components be read by index.
# It reuses the vector index value/dispatch helpers.
add_builtin(
    "index",
    input_types={"a": transformation(dtype=Float), "i": int},
    value_func=vector_index_value_func,
    dispatch_func=vector_index_dispatch_func,
    hidden=True,
    group="Utility",
    skip_replay=True,
)
5272
6280
  # implements &(*vector)[index]
5273
6281
  add_builtin(
5274
6282
  "indexref",
@@ -5289,6 +6297,16 @@ add_builtin(
5289
6297
  group="Utility",
5290
6298
  skip_replay=True,
5291
6299
  )
6300
# implements &(*transformation)[index]
# Hidden ``indexref`` overload mirroring the ``index`` one above, but for a reference to a
# transformation. It reuses the vector index value/dispatch helpers.
add_builtin(
    "indexref",
    input_types={"a": transformation(dtype=Float), "i": int},
    value_func=vector_index_value_func,
    dispatch_func=vector_index_dispatch_func,
    hidden=True,
    group="Utility",
    skip_replay=True,
)
5292
6310
 
5293
6311
 
5294
6312
  # implements vector[index] = value
@@ -5297,6 +6315,7 @@ add_builtin(
5297
6315
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5298
6316
  value_type=None,
5299
6317
  hidden=True,
6318
+ export=False,
5300
6319
  group="Utility",
5301
6320
  )
5302
6321
 
@@ -5306,6 +6325,16 @@ add_builtin(
5306
6325
  input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5307
6326
  value_type=None,
5308
6327
  hidden=True,
6328
+ export=False,
6329
+ group="Utility",
6330
+ )
6331
# implements transformation[index] = value
# Hidden, non-exported in-place component assignment for transformations. It produces no
# value (value_type=None).
# NOTE(review): typed with Scalar, whereas the transformation index/indexref/add_inplace
# overloads use Float — confirm which is intended.
add_builtin(
    "assign_inplace",
    input_types={"a": transformation(dtype=Scalar), "i": int, "value": Scalar},
    value_type=None,
    hidden=True,
    export=False,
    group="Utility",
)
5311
6340
 
@@ -5321,6 +6350,7 @@ add_builtin(
5321
6350
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5322
6351
  value_func=vector_assign_value_func,
5323
6352
  hidden=True,
6353
+ export=False,
5324
6354
  group="Utility",
5325
6355
  )
5326
6356
 
@@ -5330,6 +6360,17 @@ add_builtin(
5330
6360
  input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5331
6361
  value_func=vector_assign_value_func,
5332
6362
  hidden=True,
6363
+ export=False,
6364
+ group="Utility",
6365
+ )
6366
+
6367
# implements transformation[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
# Unlike ``assign_inplace``, this overload returns a value. Its type is computed by
# ``vector_assign_value_func``.
add_builtin(
    "assign_copy",
    input_types={"a": transformation(dtype=Scalar), "i": int, "value": Scalar},
    value_func=vector_assign_value_func,
    hidden=True,
    export=False,
    group="Utility",
)
5335
6376
 
@@ -5339,6 +6380,7 @@ add_builtin(
5339
6380
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5340
6381
  value_type=None,
5341
6382
  hidden=True,
6383
+ export=False,
5342
6384
  group="Utility",
5343
6385
  )
5344
6386
 
@@ -5348,6 +6390,27 @@ add_builtin(
5348
6390
  input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5349
6391
  value_type=None,
5350
6392
  hidden=True,
6393
+ export=False,
6394
+ group="Utility",
6395
+ )
6396
+
6397
# implements transformation[idx] += scalar
# Hidden, non-exported, no return value (value_type=None).
add_builtin(
    "add_inplace",
    input_types={"a": transformation(dtype=Float), "i": int, "value": Float},
    value_type=None,
    hidden=True,
    export=False,
    group="Utility",
)

# implements transformation.p += vec3
# Takes no index: ``value`` is a length-3 vector with the transformation's Float dtype.
add_builtin(
    "transform_add_inplace",
    input_types={"a": transformation(dtype=Float), "value": vector(length=3, dtype=Float)},
    value_type=None,
    hidden=True,
    export=False,
    group="Utility",
)
5353
6416
 
@@ -5357,6 +6420,7 @@ add_builtin(
5357
6420
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5358
6421
  value_type=None,
5359
6422
  hidden=True,
6423
+ export=False,
5360
6424
  group="Utility",
5361
6425
  )
5362
6426
 
@@ -5366,6 +6430,27 @@ add_builtin(
5366
6430
  input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5367
6431
  value_type=None,
5368
6432
  hidden=True,
6433
+ export=False,
6434
+ group="Utility",
6435
+ )
6436
+
6437
# implements transformation[idx] -= scalar
# Typed with Float to match its ``+=`` counterpart (``add_inplace``) and
# ``transform_sub_inplace``; the transformation index overloads are all Float-based.
add_builtin(
    "sub_inplace",
    input_types={"a": transformation(dtype=Float), "i": int, "value": Float},
    value_type=None,
    hidden=True,
    export=False,
    group="Utility",
)
6446
+
6447
# implements transformation.p -= vec3
# Takes no index: ``value`` is a length-3 vector with the transformation's Float dtype.
# Hidden, non-exported, no return value (value_type=None).
add_builtin(
    "transform_sub_inplace",
    input_types={"a": transformation(dtype=Float), "value": vector(length=3, dtype=Float)},
    value_type=None,
    hidden=True,
    export=False,
    group="Utility",
)
5371
6456
 
@@ -5407,7 +6492,7 @@ add_builtin(
5407
6492
 
5408
6493
 
5409
6494
  def matrix_vector_sametype(arg_types: Mapping[str, Any]):
5410
- mat_size = arg_types["a"]._shape_[0]
6495
+ mat_size = arg_types["a"]._shape_[1]
5411
6496
  vec_size = arg_types["value"]._length_
5412
6497
  mat_type = arg_types["a"]._type_
5413
6498
  vec_type = arg_types["value"]._type_
@@ -5420,6 +6505,7 @@ add_builtin(
5420
6505
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5421
6506
  value_type=None,
5422
6507
  hidden=True,
6508
+ export=False,
5423
6509
  group="Utility",
5424
6510
  )
5425
6511
 
@@ -5431,6 +6517,7 @@ add_builtin(
5431
6517
  constraint=matrix_vector_sametype,
5432
6518
  value_type=None,
5433
6519
  hidden=True,
6520
+ export=False,
5434
6521
  group="Utility",
5435
6522
  )
5436
6523
 
@@ -5446,6 +6533,7 @@ add_builtin(
5446
6533
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5447
6534
  value_func=matrix_assign_value_func,
5448
6535
  hidden=True,
6536
+ export=False,
5449
6537
  group="Utility",
5450
6538
  )
5451
6539
 
@@ -5457,6 +6545,7 @@ add_builtin(
5457
6545
  constraint=matrix_vector_sametype,
5458
6546
  value_func=matrix_assign_value_func,
5459
6547
  hidden=True,
6548
+ export=False,
5460
6549
  group="Utility",
5461
6550
  )
5462
6551
 
@@ -5467,6 +6556,7 @@ add_builtin(
5467
6556
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5468
6557
  value_type=None,
5469
6558
  hidden=True,
6559
+ export=False,
5470
6560
  group="Utility",
5471
6561
  )
5472
6562
 
@@ -5478,6 +6568,7 @@ add_builtin(
5478
6568
  constraint=matrix_vector_sametype,
5479
6569
  value_type=None,
5480
6570
  hidden=True,
6571
+ export=False,
5481
6572
  group="Utility",
5482
6573
  )
5483
6574
 
@@ -5488,6 +6579,7 @@ add_builtin(
5488
6579
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5489
6580
  value_type=None,
5490
6581
  hidden=True,
6582
+ export=False,
5491
6583
  group="Utility",
5492
6584
  )
5493
6585
 
@@ -5498,6 +6590,7 @@ add_builtin(
5498
6590
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
5499
6591
  value_type=None,
5500
6592
  hidden=True,
6593
+ export=False,
5501
6594
  group="Utility",
5502
6595
  )
5503
6596
 
@@ -5522,6 +6615,7 @@ for t in scalar_types + vector_types + (bool,):
5522
6615
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
5523
6616
  group="Utility",
5524
6617
  hidden=True,
6618
+ export=False,
5525
6619
  )
5526
6620
 
5527
6621
 
@@ -5549,6 +6643,7 @@ add_builtin(
5549
6643
  doc="Prints an error to stdout if ``a`` and ``b`` are equal",
5550
6644
  group="Utility",
5551
6645
  hidden=True,
6646
+ export=False,
5552
6647
  )
5553
6648
 
5554
6649
  add_builtin(
@@ -5568,6 +6663,7 @@ add_builtin(
5568
6663
  doc="Prints an error to stdout if ``a`` and ``b`` are equal",
5569
6664
  group="Utility",
5570
6665
  hidden=True,
6666
+ export=False,
5571
6667
  )
5572
6668
 
5573
6669
  add_builtin(
@@ -5638,11 +6734,23 @@ add_builtin(
5638
6734
  group="Utility",
5639
6735
  )
5640
6736
 
6737
+
5641
6738
  # fuzzy compare for float values
6739
def expect_near_constraint(arg_types: Mapping[str, type]):
    """Constraint shared by the ``expect_near`` overloads.

    The overload matches only when ``a`` and ``b`` have the same type and
    ``tolerance`` has the scalar type of ``a``. For types that define
    ``_wp_scalar_type_`` that attribute is the scalar type; otherwise the type
    of ``a`` itself is used.
    """
    a_type = arg_types["a"]
    if types_equal(a_type, arg_types["b"]):
        # Fall back to ``a``'s own type when it has no component scalar type.
        scalar_type = getattr(a_type, "_wp_scalar_type_", a_type)
        return types_equal(scalar_type, arg_types["tolerance"])
    return False
6747
+
6748
+
5642
6749
  add_builtin(
5643
6750
  "expect_near",
5644
6751
  input_types={"a": Float, "b": Float, "tolerance": Float},
5645
6752
  defaults={"tolerance": 1.0e-6},
6753
+ constraint=expect_near_constraint,
5646
6754
  value_type=None,
5647
6755
  doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
5648
6756
  group="Utility",
@@ -5651,6 +6759,7 @@ add_builtin(
5651
6759
  "expect_near",
5652
6760
  input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "tolerance": Float},
5653
6761
  defaults={"tolerance": 1.0e-6},
6762
+ constraint=expect_near_constraint,
5654
6763
  value_type=None,
5655
6764
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
5656
6765
  group="Utility",
@@ -5659,6 +6768,7 @@ add_builtin(
5659
6768
  "expect_near",
5660
6769
  input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "tolerance": Float},
5661
6770
  defaults={"tolerance": 1.0e-6},
6771
+ constraint=expect_near_constraint,
5662
6772
  value_type=None,
5663
6773
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
5664
6774
  group="Utility",
@@ -5671,6 +6781,7 @@ add_builtin(
5671
6781
  "tolerance": Float,
5672
6782
  },
5673
6783
  defaults={"tolerance": 1.0e-6},
6784
+ constraint=expect_near_constraint,
5674
6785
  value_type=None,
5675
6786
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
5676
6787
  group="Utility",
@@ -6088,19 +7199,19 @@ add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.boo
6088
7199
  # Tile operators
6089
7200
def tile_unary_value_func(arg_types, arg_values):
    """Return type of an element-wise unary tile operator (e.g. ``neg``).

    When ``arg_types`` is ``None`` (presumably during documentation generation)
    a generic tile type is returned. Otherwise the result is a tile with the
    same dtype and shape as the operand ``x``.

    Raises:
        TypeError: if ``x`` is not a tile.
    """
    if arg_types is None:
        return tile(dtype=Scalar, shape=Tuple[int, ...])

    operand = arg_types["x"]
    if is_tile(operand):
        return tile(dtype=operand.dtype, shape=operand.shape)

    raise TypeError(f"Expected tile for unary expression, got {operand}")
6099
7210
 
6100
7211
 
6101
7212
  def tile_scalar_mul_value_func(arg_types, arg_values):
6102
7213
  if arg_types is None:
6103
- return Tile(dtype=Any, shape=Any)
7214
+ return tile(dtype=Any, shape=Tuple[int, ...])
6104
7215
 
6105
7216
  x = arg_types["x"]
6106
7217
  y = arg_types["y"]
@@ -6110,19 +7221,19 @@ def tile_scalar_mul_value_func(arg_types, arg_values):
6110
7221
  if x.dtype != y:
6111
7222
  raise TypeError(f"Scalar factor type {y} does not match tile type {x.dtype} for tile*scalar")
6112
7223
 
6113
- return TileBinaryMap(x, TileConstant(y, x.shape))
7224
+ return tile(dtype=x.dtype, shape=x.shape)
6114
7225
 
6115
7226
  # scalar*tile
6116
7227
  if is_tile(y):
6117
7228
  if y.dtype != x:
6118
7229
  raise TypeError(f"Scalar factor type {x} does not match tile type {y.dtype} for scalar*tile")
6119
7230
 
6120
- return TileBinaryMap(TileConstant(x, y.shape), y)
7231
+ return tile(dtype=y.dtype, shape=y.shape)
6121
7232
 
6122
7233
 
6123
7234
  add_builtin(
6124
7235
  "neg",
6125
- input_types={"x": Tile(dtype=Any, shape=Any)},
7236
+ input_types={"x": tile(dtype=Any, shape=Tuple[int, ...])},
6126
7237
  value_func=tile_unary_value_func,
6127
7238
  doc="Negate each element of a tile",
6128
7239
  export=False,
@@ -6132,7 +7243,7 @@ add_builtin(
6132
7243
 
6133
7244
  add_builtin(
6134
7245
  "add",
6135
- input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
7246
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
6136
7247
  value_func=tile_binary_map_value_func,
6137
7248
  # dispatch_func=tile_map_dispatch_func,
6138
7249
  # variadic=True,
@@ -6144,7 +7255,7 @@ add_builtin(
6144
7255
 
6145
7256
  add_builtin(
6146
7257
  "sub",
6147
- input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
7258
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
6148
7259
  value_func=tile_binary_map_value_func,
6149
7260
  # dispatch_func=tile_map_dispatch_func,
6150
7261
  # variadic=True,
@@ -6157,7 +7268,7 @@ add_builtin(
6157
7268
 
6158
7269
  add_builtin(
6159
7270
  "mul",
6160
- input_types={"x": Tile(dtype=Any, shape=Any), "y": Scalar},
7271
+ input_types={"x": tile(dtype=Any, shape=Tuple[int, ...]), "y": Scalar},
6161
7272
  value_func=tile_scalar_mul_value_func,
6162
7273
  doc="Multiply each element of a tile by a scalar",
6163
7274
  export=False,
@@ -6167,7 +7278,7 @@ add_builtin(
6167
7278
 
6168
7279
  add_builtin(
6169
7280
  "mul",
6170
- input_types={"x": Scalar, "y": Tile(dtype=Any, shape=Any)},
7281
+ input_types={"x": Scalar, "y": tile(dtype=Any, shape=Tuple[int, ...])},
6171
7282
  value_func=tile_scalar_mul_value_func,
6172
7283
  doc="Multiply each element of a tile by a scalar",
6173
7284
  export=False,
@@ -6176,9 +7287,48 @@ add_builtin(
6176
7287
  )
6177
7288
 
6178
7289
 
7290
def tile_inplace_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Dispatch for the in-place tile operators (``add_inplace`` / ``sub_inplace``).

    Requires both operand tiles to have identical shapes. It then forwards
    ``(a, b)`` unchanged to the native function, with no template arguments.

    Raises:
        ValueError: if the two tiles have different shapes.
    """
    lhs, rhs = args["a"], args["b"]
    lhs_shape, rhs_shape = input_types["a"].shape, input_types["b"].shape

    if lhs_shape != rhs_shape:
        raise ValueError(f"Tile inplace arguments must have the same shape, got {lhs_shape} and {rhs_shape}")

    # (function args, template args)
    return (lhs, rhs), ()
7303
+
7304
+
7305
# In-place tile operators: ``a += b`` and ``a -= b`` on tiles.
# Both are hidden, non-exported and return nothing. They share ``tile_inplace_dispatch_func``,
# which rejects operands of different shapes, and map onto the native
# ``tile_add_inplace`` / ``tile_sub_inplace`` functions.
add_builtin(
    "add_inplace",
    input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
    value_type=None,
    dispatch_func=tile_inplace_dispatch_func,
    export=False,
    hidden=True,
    native_func="tile_add_inplace",
    group="Operators",
)


add_builtin(
    "sub_inplace",
    input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
    value_type=None,
    dispatch_func=tile_inplace_dispatch_func,
    export=False,
    hidden=True,
    native_func="tile_sub_inplace",
    group="Operators",
)
7327
+
7328
+
6179
7329
  def tile_diag_add_value_func(arg_types, arg_values):
6180
7330
  if arg_types is None:
6181
- return Tile(dtype=Any, shape=Any)
7331
+ return tile(dtype=Any, shape=Tuple[int, int])
6182
7332
 
6183
7333
  a = arg_types["a"]
6184
7334
  d = arg_types["d"]
@@ -6208,7 +7358,7 @@ def tile_diag_add_value_func(arg_types, arg_values):
6208
7358
  )
6209
7359
 
6210
7360
  # use first argument to define output type
6211
- return Tile(dtype=a.dtype, shape=a.shape, storage="shared")
7361
+ return tile(dtype=a.dtype, shape=a.shape, layout=a.layout, strides=a.strides, storage="shared")
6212
7362
 
6213
7363
 
6214
7364
  def tile_diag_add_lto_dispatch_func(
@@ -6230,7 +7380,7 @@ def tile_diag_add_lto_dispatch_func(
6230
7380
 
6231
7381
  add_builtin(
6232
7382
  "tile_diag_add",
6233
- input_types={"a": Tile(dtype=Any, shape=Any), "d": Tile(dtype=Any, shape=Any)},
7383
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, int]), "d": tile(dtype=Any, shape=Tuple[int])},
6234
7384
  value_func=tile_diag_add_value_func,
6235
7385
  lto_dispatch_func=tile_diag_add_lto_dispatch_func,
6236
7386
  native_func="tile_diag_add",
@@ -6240,18 +7390,40 @@ add_builtin(
6240
7390
  )
6241
7391
 
6242
7392
 
6243
- ##
6244
- ## MathDx, LTOIR-based, Tile functions
6245
- ##
7393
+ ##
7394
+ ## MathDx, LTOIR-based, Tile functions
7395
+ ##
7396
+
7397
+
7398
+ ##
7399
+ ## Matmul
7400
+ ##
7401
+
7402
+
7403
def tile_matmul_out_value_func(arg_types, arg_values):
    """Value function for the accumulating ``tile_matmul(a, b, out)`` overload.

    Checks that ``a``, ``b`` and ``out`` are all tiles, in that order. The
    overload produces no value, so ``None`` is always returned. When
    ``arg_types`` is ``None`` (doc builds), no checks are performed.

    Raises:
        TypeError: if any of the three arguments is not a tile.
    """
    if arg_types is not None:
        for arg_name in ("a", "b", "out"):
            arg_type = arg_types[arg_name]
            if not is_tile(arg_type):
                raise TypeError(f"tile_matmul() '{arg_name}' argument must be a tile, got {arg_type!r}")

    return None
6246
7421
 
6247
7422
 
6248
- ##
6249
- ## Matmul
6250
- ##
6251
7423
  def tile_matmul_value_func(arg_types, arg_values):
6252
7424
  # return generic type (for doc builds)
6253
7425
  if arg_types is None:
6254
- return Tile(dtype=Any, shape=Any)
7426
+ return tile(dtype=Float, shape=Tuple[int, int])
6255
7427
 
6256
7428
  a = arg_types["a"]
6257
7429
  b = arg_types["b"]
@@ -6262,16 +7434,7 @@ def tile_matmul_value_func(arg_types, arg_values):
6262
7434
  if not is_tile(b):
6263
7435
  raise TypeError(f"tile_matmul() 'b' argument must be a tile, got {b!r}")
6264
7436
 
6265
- # out = wp.tile_matmul(a, b)
6266
- if len(arg_types) == 2:
6267
- return Tile(dtype=a.dtype, shape=(a.shape[0], b.shape[1]), storage="shared")
6268
-
6269
- # wp.tile_matmul(a, b, out)
6270
- elif len(arg_types) == 3:
6271
- if not is_tile(arg_types["out"]):
6272
- raise TypeError(f"tile_matmul() 'out' argument must be a tile, got {arg_types['out']!r}")
6273
-
6274
- return None
7437
+ return tile(dtype=a.dtype, shape=(a.shape[0], b.shape[1]), storage="shared")
6275
7438
 
6276
7439
 
6277
7440
  def tile_matmul_lto_dispatch_func(
@@ -6345,36 +7508,41 @@ def tile_matmul_lto_dispatch_func(
6345
7508
  num_threads,
6346
7509
  builder,
6347
7510
  )
6348
- # adjA += adjC * B^T - Transpose ~= flipped layout
6349
- (fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
6350
- M,
6351
- K,
6352
- N,
6353
- out.type.dtype,
6354
- b.type.dtype,
6355
- a.type.dtype,
6356
- out.type.layout,
6357
- tile_flip_layout(b.type.layout),
6358
- a.type.layout,
6359
- arch,
6360
- num_threads,
6361
- builder,
6362
- )
6363
- # adjB += A^T * adjC - Transpose ~= flipped layout
6364
- (fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
6365
- K,
6366
- N,
6367
- M,
6368
- a.type.dtype,
6369
- out.type.dtype,
6370
- b.type.dtype,
6371
- tile_flip_layout(a.type.layout),
6372
- out.type.layout,
6373
- b.type.layout,
6374
- arch,
6375
- num_threads,
6376
- builder,
6377
- )
7511
+ if warp.config.enable_backward:
7512
+ # adjA += adjC * B^T - Transpose ~= flipped layout
7513
+ (fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
7514
+ M,
7515
+ K,
7516
+ N,
7517
+ out.type.dtype,
7518
+ b.type.dtype,
7519
+ a.type.dtype,
7520
+ out.type.layout,
7521
+ tile_flip_layout(b.type.layout),
7522
+ a.type.layout,
7523
+ arch,
7524
+ num_threads,
7525
+ builder,
7526
+ )
7527
+ # adjB += A^T * adjC - Transpose ~= flipped layout
7528
+ (fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
7529
+ K,
7530
+ N,
7531
+ M,
7532
+ a.type.dtype,
7533
+ out.type.dtype,
7534
+ b.type.dtype,
7535
+ tile_flip_layout(a.type.layout),
7536
+ out.type.layout,
7537
+ b.type.layout,
7538
+ arch,
7539
+ num_threads,
7540
+ builder,
7541
+ )
7542
+ else:
7543
+ # adjoints aren't computed, so we reuse fun_forward as a dummy arg
7544
+ (fun_backward_A, lto_backward_A) = (fun_forward, None)
7545
+ (fun_backward_B, lto_backward_B) = (fun_forward, None)
6378
7546
 
6379
7547
  return (
6380
7548
  (
@@ -6394,11 +7562,11 @@ def tile_matmul_lto_dispatch_func(
6394
7562
  add_builtin(
6395
7563
  "tile_matmul",
6396
7564
  input_types={
6397
- "a": Tile(dtype=Any, shape=Any),
6398
- "b": Tile(dtype=Any, shape=Any),
6399
- "out": Tile(dtype=Any, shape=Any),
7565
+ "a": tile(dtype=Float, shape=Tuple[int, int]),
7566
+ "b": tile(dtype=Float, shape=Tuple[int, int]),
7567
+ "out": tile(dtype=Float, shape=Tuple[int, int]),
6400
7568
  },
6401
- value_func=tile_matmul_value_func,
7569
+ value_func=tile_matmul_out_value_func,
6402
7570
  lto_dispatch_func=tile_matmul_lto_dispatch_func,
6403
7571
  variadic=False,
6404
7572
  doc="""Computes the matrix product and accumulates ``out += a*b``.
@@ -6420,7 +7588,7 @@ add_builtin(
6420
7588
 
6421
7589
  add_builtin(
6422
7590
  "tile_matmul",
6423
- input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
7591
+ input_types={"a": tile(dtype=Float, shape=Tuple[int, int]), "b": tile(dtype=Float, shape=Tuple[int, int])},
6424
7592
  value_func=tile_matmul_value_func,
6425
7593
  lto_dispatch_func=tile_matmul_lto_dispatch_func,
6426
7594
  variadic=False,
@@ -6447,7 +7615,7 @@ add_builtin(
6447
7615
  ##
6448
7616
  def tile_fft_generic_value_func(arg_types, arg_values):
6449
7617
  if arg_types is None:
6450
- return Tile(dtype=Any, shape=Any)
7618
+ return tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int])
6451
7619
 
6452
7620
  if len(arg_types) != 1:
6453
7621
  raise TypeError(f"tile_fft() takes exactly 1 positional argument but {len(arg_types)} were given")
@@ -6475,7 +7643,7 @@ def tile_fft_generic_lto_dispatch_func(
6475
7643
  arg_values: Mapping[str, Var],
6476
7644
  options: Mapping[str, Any],
6477
7645
  builder: warp.context.ModuleBuilder,
6478
- direction: str = None,
7646
+ direction: str | None = None,
6479
7647
  ):
6480
7648
  inout = arg_values["inout"]
6481
7649
  inout.type.storage = "register"
@@ -6529,7 +7697,7 @@ def tile_fft_generic_lto_dispatch_func(
6529
7697
 
6530
7698
  add_builtin(
6531
7699
  "tile_fft",
6532
- input_types={"inout": Tile},
7700
+ input_types={"inout": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int])},
6533
7701
  value_func=tile_fft_generic_value_func,
6534
7702
  lto_dispatch_func=functools.partial(tile_fft_generic_lto_dispatch_func, direction="forward"),
6535
7703
  variadic=True,
@@ -6550,7 +7718,7 @@ add_builtin(
6550
7718
 
6551
7719
  add_builtin(
6552
7720
  "tile_ifft",
6553
- input_types={"inout": Tile},
7721
+ input_types={"inout": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int])},
6554
7722
  value_func=tile_fft_generic_value_func,
6555
7723
  lto_dispatch_func=functools.partial(tile_fft_generic_lto_dispatch_func, direction="inverse"),
6556
7724
  variadic=True,
@@ -6575,7 +7743,7 @@ add_builtin(
6575
7743
  ##
6576
7744
  def tile_cholesky_generic_value_func(arg_types, arg_values):
6577
7745
  if arg_types is None:
6578
- return Tile(dtype=Any, shape=Any)
7746
+ return tile(dtype=Float, shape=Tuple[int, int])
6579
7747
 
6580
7748
  if len(arg_types) != 1:
6581
7749
  raise TypeError("tile_cholesky() requires 1 positional args")
@@ -6591,15 +7759,19 @@ def tile_cholesky_generic_value_func(arg_types, arg_values):
6591
7759
  if a.shape[0] != a.shape[1]:
6592
7760
  raise ValueError("tile_cholesky() argument must be square")
6593
7761
 
6594
- return Tile(dtype=a.dtype, shape=a.shape, storage="shared")
7762
+ return tile(dtype=a.dtype, shape=a.shape, layout=a.layout, strides=a.strides, storage="shared")
6595
7763
 
6596
7764
 
6597
# Enum values forwarded to ``warp.build.build_lto_solver`` when generating cuSOLVERDx LTOs.
cusolver_function_map = dict(getrf=0, getrf_no_pivot=1, potrf=2, potrs=3, trsm=4)

# Warp scalar type -> (C++ type name used in the parameter list, precision enum)
cusolver_type_map = {float32: ("wp::float32", 5), float64: ("wp::float64", 6)}

cusolver_fill_mode_map = dict(upper=0, lower=1)

# "-" (-1) is the value the potrf/potrs dispatches pass, since they specify no side/diag.
cusolver_side_map = {"-": -1, "left": 0, "right": 1}

cusolver_diag_map = {"-": -1, "unit": 0, "nounit": 1}
+ cusolver_diag_map = {"-": -1, "unit": 0, "nounit": 1}
7774
+
6603
7775
 
6604
7776
  def tile_cholesky_generic_lto_dispatch_func(
6605
7777
  arg_types: Mapping[str, type],
@@ -6623,20 +7795,20 @@ def tile_cholesky_generic_lto_dispatch_func(
6623
7795
  dtype, precision_enum = cusolver_type_map[a.type.dtype]
6624
7796
 
6625
7797
  # We already ensured a is square in tile_cholesky_generic_value_func()
6626
- M, N = a.type.shape[0], a.type.shape[1]
7798
+ M, N = a.type.shape
6627
7799
  if out.type.shape[0] != M or out.type.shape[1] != M:
6628
7800
  raise ValueError("tile_cholesky() output tile must be square")
6629
7801
 
6630
7802
  solver = "potrf"
6631
7803
  solver_enum = cusolver_function_map[solver]
6632
7804
 
6633
- # cuSOLVERDx only supports col-major input/outputs,
6634
- # so we use upper to mimic a row-major input
6635
- fill_mode = cusolver_fill_mode_map["upper"]
7805
+ side_enum = cusolver_side_map["-"]
7806
+ diag_enum = cusolver_diag_map["-"]
7807
+ fill_mode = cusolver_fill_mode_map["lower"]
6636
7808
 
6637
7809
  arch = options["output_arch"]
6638
7810
  num_threads = options["block_dim"]
6639
- parameter_list = f"({dtype}*, unsigned)"
7811
+ parameter_list = f"({dtype}*, int*)"
6640
7812
 
6641
7813
  if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
6642
7814
  # CPU/no-MathDx dispatch
@@ -6646,8 +7818,13 @@ def tile_cholesky_generic_lto_dispatch_func(
6646
7818
  lto_symbol, lto_code_data = warp.build.build_lto_solver(
6647
7819
  M,
6648
7820
  N,
7821
+ 1,
6649
7822
  solver,
6650
7823
  solver_enum,
7824
+ side_enum,
7825
+ diag_enum,
7826
+ a.type.layout,
7827
+ out.type.layout,
6651
7828
  fill_mode,
6652
7829
  arch,
6653
7830
  precision_enum,
@@ -6661,20 +7838,23 @@ def tile_cholesky_generic_lto_dispatch_func(
6661
7838
 
6662
7839
  add_builtin(
6663
7840
  "tile_cholesky",
6664
- input_types={"A": Tile},
7841
+ input_types={"A": tile(dtype=Float, shape=Tuple[int, int])},
6665
7842
  value_func=tile_cholesky_generic_value_func,
6666
7843
  lto_dispatch_func=tile_cholesky_generic_lto_dispatch_func,
6667
7844
  variadic=True,
6668
7845
  doc="""Compute the Cholesky factorization L of a matrix A.
6669
7846
  L is lower triangular and satisfies LL^T = A.
6670
7847
 
7848
+ Only the lower triangular portion of A is used for the decomposition;
7849
+ the upper triangular part may be left unspecified.
7850
+
6671
7851
  Note that computing the adjoint is not yet supported.
6672
7852
 
6673
7853
  Supported datatypes are:
6674
7854
  * float32
6675
7855
  * float64
6676
7856
 
6677
- :param A: A square, symmetric positive-definite, matrix.
7857
+ :param A: A square, symmetric positive-definite, matrix. Only the lower triangular part of A is needed; the upper part is ignored.
6678
7858
  :returns L: A square, lower triangular, matrix, such that LL^T = A""",
6679
7859
  group="Tile Primitives",
6680
7860
  export=False,
@@ -6690,30 +7870,30 @@ def tile_cholesky_solve_generic_value_func(arg_types, arg_values):
6690
7870
  raise TypeError("tile_cholesky_solve() requires exactly 2 positional args")
6691
7871
 
6692
7872
  l = arg_types["L"]
6693
- x = arg_types["x"]
7873
+ y = arg_types["y"]
6694
7874
 
6695
7875
  if not is_tile(l):
6696
7876
  raise TypeError(f"tile_cholesky_solve() 'L' argument must be a tile, got {l!r}")
6697
7877
 
6698
- if not is_tile(x):
6699
- raise TypeError(f"tile_cholesky_solve() 'x' argument must be a tile, got {l!r}")
7878
+ if not is_tile(y):
7879
+ raise TypeError(f"tile_cholesky_solve() 'y' argument must be a tile, got {l!r}")
6700
7880
 
6701
- if not types_equal(l.dtype, x.dtype):
6702
- raise TypeError(f"tile_cholesky_solve() arguments must have the same dtype, got {l.dtype} and {x.dtype}")
7881
+ if not types_equal(l.dtype, y.dtype):
7882
+ raise TypeError(f"tile_cholesky_solve() arguments must have the same dtype, got {l.dtype} and {y.dtype}")
6703
7883
 
6704
7884
  if l.shape[0] != l.shape[1]:
6705
7885
  raise ValueError("tile_cholesky_solve() 'L' argument must be square")
6706
7886
 
6707
- if len(x.shape) != 1:
6708
- raise TypeError("tile_cholesky_solve() 'x' argument must be a 1D tile")
7887
+ if len(y.shape) > 2 or len(y.shape) < 1:
7888
+ raise TypeError("tile_cholesky_solve() 'y' argument must be a 1D or 2D tile")
6709
7889
 
6710
- if x.shape[0] != l.shape[0]:
7890
+ if y.shape[0] != l.shape[0]:
6711
7891
  raise ValueError(
6712
- f"tile_cholesky_solve() 'x' argument must have the same number of elements as the number of rows in 'L', "
6713
- f"got {x.shape[0]} elements in 'x' and {l.shape[0]} rows in 'L'"
7892
+ f"tile_cholesky_solve() 'y' argument must have the same number of elements as the number of rows in 'L', "
7893
+ f"got {y.shape[0]} elements in 'x' and {l.shape[0]} rows in 'L'"
6714
7894
  )
6715
7895
 
6716
- return Tile(dtype=l.dtype, shape=x.shape, storage="shared")
7896
+ return tile(dtype=l.dtype, shape=y.shape, layout=y.layout, strides=y.strides, storage="shared")
6717
7897
 
6718
7898
 
6719
7899
  def tile_cholesky_solve_generic_lto_dispatch_func(
@@ -6725,37 +7905,38 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
6725
7905
  builder: warp.context.ModuleBuilder,
6726
7906
  ):
6727
7907
  L = arg_values["L"]
6728
- x = arg_values["x"]
7908
+ y = arg_values["y"]
6729
7909
  # force the storage type of the input variables to shared memory
6730
7910
  L.type.storage = "shared"
6731
- x.type.storage = "shared"
7911
+ y.type.storage = "shared"
6732
7912
 
6733
7913
  if len(return_values) != 1:
6734
7914
  raise TypeError(f"tile_cholesky_solve() must return exactly one value, got {len(return_values)}")
6735
7915
 
6736
- y = return_values[0]
7916
+ x = return_values[0]
6737
7917
 
6738
- if any(T not in cusolver_type_map.keys() for T in [x.type.dtype, L.type.dtype]):
7918
+ if any(T not in cusolver_type_map.keys() for T in [y.type.dtype, L.type.dtype]):
6739
7919
  raise TypeError("tile_cholesky_solve() arguments be tiles of float64 or float32")
6740
7920
 
6741
7921
  dtype, precision_enum = cusolver_type_map[L.type.dtype]
6742
- M, N = L.type.shape[0], L.type.shape[1]
7922
+ M, N = L.type.shape
7923
+ NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
6743
7924
 
6744
- if len(y.type.shape) != 1:
6745
- raise TypeError("tile_cholesky_solve() output vector must be 1D")
7925
+ if len(x.type.shape) > 2 or len(x.type.shape) < 1:
7926
+ raise TypeError(f"tile_cholesky_solve() output vector must be 1D or 2D, got {len(x.type.shape)}-D")
6746
7927
 
6747
- if y.type.shape[0] != M:
7928
+ if x.type.shape[0] != M:
6748
7929
  raise ValueError(
6749
7930
  "tile_cholesky_solve() output vector must have same number of elements as the number of rows in 'L' "
6750
- f"got {y.type.shape[0]} elements in output and {M} rows in 'L'"
7931
+ f"got {x.type.shape[0]} elements in output and {M} rows in 'L'"
6751
7932
  )
6752
7933
 
6753
7934
  solver = "potrs"
6754
7935
  solver_enum = cusolver_function_map[solver]
6755
7936
 
6756
- # cuSOLVERDx only supports col-major input/outputs,
6757
- # so we use upper to mimic a row-major input
6758
- fill_mode = cusolver_fill_mode_map["upper"]
7937
+ side_enum = cusolver_side_map["-"]
7938
+ diag_enum = cusolver_diag_map["-"]
7939
+ fill_mode = cusolver_fill_mode_map["lower"]
6759
7940
 
6760
7941
  arch = options["output_arch"]
6761
7942
  num_threads = options["block_dim"]
@@ -6763,14 +7944,19 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
6763
7944
 
6764
7945
  if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
6765
7946
  # CPU/no-MathDx dispatch
6766
- return ((0, L, x, y), [], [], 0)
7947
+ return ((0, L, y, x), [], [], 0)
6767
7948
  else:
6768
7949
  # generate the LTO
6769
7950
  lto_symbol, lto_code_data = warp.build.build_lto_solver(
6770
7951
  M,
6771
7952
  N,
7953
+ NRHS,
6772
7954
  solver,
6773
7955
  solver_enum,
7956
+ side_enum,
7957
+ diag_enum,
7958
+ L.type.layout,
7959
+ y.type.layout,
6774
7960
  fill_mode,
6775
7961
  arch,
6776
7962
  precision_enum,
@@ -6779,12 +7965,12 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
6779
7965
  builder,
6780
7966
  )
6781
7967
 
6782
- return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)
7968
+ return ((Var(lto_symbol, str, False, True, False), L, y, x), [], [lto_code_data], 0)
6783
7969
 
6784
7970
 
6785
7971
  add_builtin(
6786
7972
  "tile_cholesky_solve",
6787
- input_types={"L": Tile, "x": Tile},
7973
+ input_types={"L": tile(dtype=Float, shape=Tuple[int, int]), "y": tile(dtype=Float, shape=Tuple[int])},
6788
7974
  value_func=tile_cholesky_solve_generic_value_func,
6789
7975
  lto_dispatch_func=tile_cholesky_solve_generic_lto_dispatch_func,
6790
7976
  variadic=True,
@@ -6797,13 +7983,276 @@ add_builtin(
6797
7983
  * float64
6798
7984
 
6799
7985
  :param L: A square, lower triangular, matrix, such that LL^T = A
6800
- :param x: An 1D tile of length M
6801
- :returns y: An 1D tile of length M such that LL^T y = x""",
7986
+ :param y: A 1D or 2D tile of length M
7987
+ :returns x: A tile of the same shape as y such that LL^T x = y""",
7988
+ group="Tile Primitives",
7989
+ export=False,
7990
+ namespace="",
7991
+ )
7992
+
7993
+
7994
+ def tile_lower_solve_generic_lto_dispatch_func(
7995
+ arg_types: Mapping[str, type],
7996
+ return_type: Any,
7997
+ return_values: List[Var],
7998
+ arg_values: Mapping[str, Var],
7999
+ options: Mapping[str, Any],
8000
+ builder: warp.context.ModuleBuilder,
8001
+ ):
8002
+ L = arg_values["L"]
8003
+ y = arg_values["y"]
8004
+ # force the storage type of the input variables to shared memory
8005
+ L.type.storage = "shared"
8006
+ y.type.storage = "shared"
8007
+
8008
+ if any(T not in cusolver_type_map.keys() for T in [y.type.dtype, L.type.dtype]):
8009
+ raise TypeError("tile_lower_solve() arguments must be tiles of float64 or float32")
8010
+
8011
+ if len(return_values) != 1:
8012
+ raise TypeError(f"tile_lower_solve() must return exactly one value, got {len(return_values)}")
8013
+
8014
+ z = return_values[0]
8015
+
8016
+ dtype, precision_enum = cusolver_type_map[L.type.dtype]
8017
+ M, N = L.type.shape
8018
+ NRHS = z.type.shape[1] if len(z.type.shape) > 1 else 1
8019
+
8020
+ if len(z.type.shape) > 2 or len(z.type.shape) < 1:
8021
+ raise TypeError(f"tile_lower_solve() output vector must be 1D or 2D, got {len(z.type.shape)}-D")
8022
+
8023
+ if z.type.shape[0] != M:
8024
+ raise ValueError(
8025
+ "tile_lower_solve() output vector must have same number of elements as the number of rows in 'L' "
8026
+ f"got {z.type.shape[0]} elements in output and {M} rows in 'L'"
8027
+ )
8028
+
8029
+ solver = "trsm"
8030
+ solver_enum = cusolver_function_map[solver]
8031
+
8032
+ side_enum = cusolver_side_map["left"]
8033
+ diag_enum = cusolver_diag_map["nounit"]
8034
+ fill_mode = cusolver_fill_mode_map["lower"]
8035
+
8036
+ arch = options["output_arch"]
8037
+ num_threads = options["block_dim"]
8038
+ parameter_list = f"({dtype}*, {dtype}*)"
8039
+
8040
+ if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
8041
+ # CPU/no-MathDx dispatch
8042
+ return ((0, L, y, z), [], [], 0)
8043
+ else:
8044
+ # generate the LTO
8045
+ lto_symbol, lto_code_data = warp.build.build_lto_solver(
8046
+ M,
8047
+ N,
8048
+ NRHS,
8049
+ solver,
8050
+ solver_enum,
8051
+ side_enum,
8052
+ diag_enum,
8053
+ L.type.layout,
8054
+ y.type.layout,
8055
+ fill_mode,
8056
+ arch,
8057
+ precision_enum,
8058
+ num_threads,
8059
+ parameter_list,
8060
+ builder,
8061
+ )
8062
+
8063
+ return ((Var(lto_symbol, str, False, True, False), L, y, z), [], [lto_code_data], 0)
8064
+
8065
+
8066
+ def tile_lower_solve_generic_value_func(arg_types, arg_values):
8067
+ if arg_types is None:
8068
+ return tile(dtype=Float, shape=Tuple[int])
8069
+
8070
+ if len(arg_types) != 2:
8071
+ raise TypeError("tile_lower_solve() requires exactly 2 positional args")
8072
+
8073
+ l = arg_types["L"]
8074
+ y = arg_types["y"]
8075
+
8076
+ if not is_tile(l):
8077
+ raise TypeError(f"tile_lower_solve() 'L' argument must be a tile, got {l!r}")
8078
+
8079
+ if not is_tile(y):
8080
+ raise TypeError(f"tile_lower_solve() 'y' argument must be a tile, got {y!r}")
8081
+
8082
+ if not types_equal(l.dtype, y.dtype):
8083
+ raise TypeError(f"tile_lower_solve() arguments must have the same dtype, got {l.dtype} and {y.dtype}")
8084
+
8085
+ if l.shape[0] != l.shape[1]:
8086
+ raise ValueError("tile_lower_solve() 'L' argument must be square")
8087
+
8088
+ if len(y.shape) > 2 or len(y.shape) < 1:
8089
+ raise TypeError("tile_lower_solve() 'y' argument must be a 1D or 2D tile")
8090
+
8091
+ if y.shape[0] != l.shape[0]:
8092
+ raise ValueError(
8093
+ f"tile_lower_solve() 'y' argument must have the same number of elements as the number of rows in 'L', "
8094
+ f"got {y.shape[0]} elements in 'y' and {l.shape[0]} rows in 'L'"
8095
+ )
8096
+
8097
+ return tile(dtype=l.dtype, shape=y.shape, layout=y.layout, strides=y.strides, storage="shared")
8098
+
8099
+
8100
+ add_builtin(
8101
+ "tile_lower_solve",
8102
+ input_types={"L": tile(dtype=Float, shape=Tuple[int, int]), "y": tile(dtype=Float, shape=Tuple[int])},
8103
+ value_func=tile_lower_solve_generic_value_func,
8104
+ lto_dispatch_func=tile_lower_solve_generic_lto_dispatch_func,
8105
+ variadic=True,
8106
+ doc="""Solve for z in Lz = y, where L is a lower triangular matrix.
8107
+
8108
+ This performs general forward substitution for a lower triangular system.
8109
+
8110
+ Note that computing the adjoint is not yet supported.
8111
+
8112
+ Supported datatypes are:
8113
+ * float32
8114
+ * float64
8115
+
8116
+ :param L: A square, non-singular, lower triangular matrix
8117
+ :param y: A 1D or 2D tile with compatible shape
8118
+ :returns z: A tile of the same shape as y such that Lz = y""",
8119
+ group="Tile Primitives",
8120
+ export=False,
8121
+ namespace="",
8122
+ )
8123
+
8124
+
8125
+ def tile_upper_solve_generic_lto_dispatch_func(
8126
+ arg_types: Mapping[str, type],
8127
+ return_type: Any,
8128
+ return_values: List[Var],
8129
+ arg_values: Mapping[str, Var],
8130
+ options: Mapping[str, Any],
8131
+ builder: warp.context.ModuleBuilder,
8132
+ ):
8133
+ U = arg_values["U"]
8134
+ z = arg_values["z"]
8135
+ # force the storage type of the input variables to shared memory
8136
+ U.type.storage = "shared"
8137
+ z.type.storage = "shared"
8138
+
8139
+ if any(T not in cusolver_type_map.keys() for T in [z.type.dtype, U.type.dtype]):
8140
+ raise TypeError("tile_upper_solve() arguments must be tiles of float64 or float32")
8141
+
8142
+ if len(return_values) != 1:
8143
+ raise TypeError(f"tile_upper_solve() must return exactly one value, got {len(return_values)}")
8144
+
8145
+ x = return_values[0]
8146
+
8147
+ dtype, precision_enum = cusolver_type_map[U.type.dtype]
8148
+ M, N = U.type.shape
8149
+ NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
8150
+
8151
+ if len(z.type.shape) > 2 or len(z.type.shape) < 1:
8152
+ raise TypeError(f"tile_upper_solve() output tile must be 1D or 2D, got {len(z.type.shape)}-D")
8153
+
8154
+ if z.type.shape[0] != M:
8155
+ raise ValueError(
8156
+ "tile_upper_solve() output tile must have same number of elements as the number of rows in 'U' "
8157
+ f"got {z.type.shape[0]} elements in output and {M} rows in 'U'"
8158
+ )
8159
+
8160
+ solver = "trsm"
8161
+ solver_enum = cusolver_function_map[solver]
8162
+
8163
+ side_enum = cusolver_side_map["left"]
8164
+ diag_enum = cusolver_diag_map["nounit"]
8165
+ fill_mode = cusolver_fill_mode_map["upper"]
8166
+
8167
+ arch = options["output_arch"]
8168
+ num_threads = options["block_dim"]
8169
+ parameter_list = f"({dtype}*, {dtype}*)"
8170
+
8171
+ if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
8172
+ # CPU/no-MathDx dispatch
8173
+ return ((0, U, z, x), [], [], 0)
8174
+ else:
8175
+ # generate the LTO
8176
+ lto_symbol, lto_code_data = warp.build.build_lto_solver(
8177
+ M,
8178
+ N,
8179
+ NRHS,
8180
+ solver,
8181
+ solver_enum,
8182
+ side_enum,
8183
+ diag_enum,
8184
+ U.type.layout,
8185
+ z.type.layout,
8186
+ fill_mode,
8187
+ arch,
8188
+ precision_enum,
8189
+ num_threads,
8190
+ parameter_list,
8191
+ builder,
8192
+ )
8193
+
8194
+ return ((Var(lto_symbol, str, False, True, False), U, z, x), [], [lto_code_data], 0)
8195
+
8196
+
8197
+ def tile_upper_solve_generic_value_func(arg_types, arg_values):
8198
+ if arg_types is None:
8199
+ return tile(dtype=Float, shape=Tuple[int])
8200
+
8201
+ if len(arg_types) != 2:
8202
+ raise TypeError("tile_upper_solve() requires exactly 2 positional args")
8203
+
8204
+ u = arg_types["U"]
8205
+ z = arg_types["z"]
8206
+
8207
+ if not is_tile(u):
8208
+ raise TypeError(f"tile_upper_solve() 'U' argument must be a tile, got {u!r}")
8209
+
8210
+ if not is_tile(z):
8211
+ raise TypeError(f"tile_upper_solve() 'z' argument must be a tile, got {z!r}")
8212
+
8213
+ if not types_equal(u.dtype, z.dtype):
8214
+ raise TypeError(f"tile_upper_solve() arguments must have the same dtype, got {u.dtype} and {z.dtype}")
8215
+
8216
+ if u.shape[0] != u.shape[1]:
8217
+ raise ValueError("tile_upper_solve() 'U' argument must be square")
8218
+
8219
+ if len(z.shape) > 2 or len(z.shape) < 1:
8220
+ raise TypeError("tile_upper_solve() 'z' argument must be a 1D or 2D tile")
8221
+
8222
+ if z.shape[0] != u.shape[0]:
8223
+ raise ValueError(
8224
+ f"tile_upper_solve() 'z' argument must have the same number of elements as the number of rows in 'U', "
8225
+ f"got {z.shape[0]} elements in 'z' and {u.shape[0]} rows in 'U'"
8226
+ )
8227
+
8228
+ return tile(dtype=u.dtype, shape=z.shape, layout=z.layout, strides=z.strides, storage="shared")
8229
+
8230
+
8231
+ add_builtin(
8232
+ "tile_upper_solve",
8233
+ input_types={"U": tile(dtype=Float, shape=Tuple[int, int]), "z": tile(dtype=Float, shape=Tuple[int])},
8234
+ value_func=tile_upper_solve_generic_value_func,
8235
+ lto_dispatch_func=tile_upper_solve_generic_lto_dispatch_func,
8236
+ variadic=True,
8237
+ doc="""Solve for x in U x = z, where U is an upper triangular matrix.
8238
+
8239
+ This performs general back substitution for upper triangular systems.
8240
+
8241
+ Note that computing the adjoint is not yet supported.
8242
+
8243
+ Supported datatypes are:
8244
+ * float32
8245
+ * float64
8246
+
8247
+ :param U: A square, non-singular, upper triangular matrix
8248
+ :param z: A 1D or 2D tile with compatible shape
8249
+ :returns x: A tile of the same shape as z such that U x = z""",
6802
8250
  group="Tile Primitives",
6803
8251
  export=False,
6804
8252
  namespace="",
6805
8253
  )
6806
8254
 
8255
+
6807
8256
  # ---------------------------------
6808
8257
  # Code Generation
6809
8258
 
@@ -6840,7 +8289,7 @@ def static(expr):
6840
8289
  add_builtin(
6841
8290
  "len",
6842
8291
  input_types={"a": vector(length=Any, dtype=Scalar)},
6843
- value_type=int,
8292
+ value_func=static_len_value_func,
6844
8293
  doc="Return the number of elements in a vector.",
6845
8294
  group="Utility",
6846
8295
  export=False,
@@ -6849,7 +8298,7 @@ add_builtin(
6849
8298
  add_builtin(
6850
8299
  "len",
6851
8300
  input_types={"a": quaternion(dtype=Scalar)},
6852
- value_type=int,
8301
+ value_func=static_len_value_func,
6853
8302
  doc="Return the number of elements in a quaternion.",
6854
8303
  group="Utility",
6855
8304
  export=False,
@@ -6858,7 +8307,7 @@ add_builtin(
6858
8307
  add_builtin(
6859
8308
  "len",
6860
8309
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
6861
- value_type=int,
8310
+ value_func=static_len_value_func,
6862
8311
  doc="Return the number of rows in a matrix.",
6863
8312
  group="Utility",
6864
8313
  export=False,
@@ -6867,7 +8316,7 @@ add_builtin(
6867
8316
  add_builtin(
6868
8317
  "len",
6869
8318
  input_types={"a": transformation(dtype=Float)},
6870
- value_type=int,
8319
+ value_func=static_len_value_func,
6871
8320
  doc="Return the number of elements in a transformation.",
6872
8321
  group="Utility",
6873
8322
  export=False,
@@ -6884,9 +8333,83 @@ add_builtin(
6884
8333
 
6885
8334
  add_builtin(
6886
8335
  "len",
6887
- input_types={"a": Tile(dtype=Any, shape=Any)},
6888
- value_type=int,
8336
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...])},
8337
+ value_func=static_len_value_func,
6889
8338
  doc="Return the number of rows in a tile.",
6890
8339
  group="Utility",
6891
8340
  export=False,
6892
8341
  )
8342
+
8343
+
8344
+ # ---------------------------------
8345
+ # Tuple
8346
+
8347
+
8348
+ def tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
8349
+ return tuple_t(arg_types["args"], arg_values["args"])
8350
+
8351
+
8352
+ def tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
8353
+ func_args = args.get("args", ())
8354
+ template_args = ()
8355
+ return (func_args, template_args)
8356
+
8357
+
8358
+ add_builtin(
8359
+ "tuple",
8360
+ input_types={"*args": Any},
8361
+ value_func=tuple_value_func,
8362
+ dispatch_func=tuple_dispatch_func,
8363
+ variadic=True,
8364
+ doc="Construct a tuple from a list of values",
8365
+ group="Utility",
8366
+ hidden=True,
8367
+ missing_grad=True,
8368
+ export=False,
8369
+ )
8370
+
8371
+
8372
+ def tuple_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
8373
+ tuple_type = arg_types["a"]
8374
+ elements = tuple_type.types if is_tuple(tuple_type) else tuple_type
8375
+
8376
+ if "i" not in arg_values:
8377
+ raise RuntimeError("Tuple index must be a compile time expression.")
8378
+
8379
+ index = arg_values["i"]
8380
+ if isinstance(index, Var):
8381
+ raise RuntimeError("Tuple index must be a compile time expression.")
8382
+
8383
+ length = len(elements)
8384
+ if index >= length:
8385
+ raise RuntimeError(f"Tuple index out of bounds, {index} >= {length}")
8386
+
8387
+ value_type = elements[index]
8388
+ return value_type
8389
+
8390
+
8391
+ def tuple_extract_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
8392
+ func_args = (args["a"],)
8393
+ template_args = (args["i"].constant,)
8394
+ return (func_args, template_args)
8395
+
8396
+
8397
+ add_builtin(
8398
+ "extract",
8399
+ input_types={"a": Tuple, "i": int},
8400
+ value_func=tuple_extract_value_func,
8401
+ dispatch_func=tuple_extract_dispatch_func,
8402
+ group="Utility",
8403
+ hidden=True,
8404
+ missing_grad=True,
8405
+ )
8406
+
8407
+
8408
+ add_builtin(
8409
+ "len",
8410
+ input_types={"a": Tuple},
8411
+ value_func=static_len_value_func,
8412
+ doc="Return the number of elements in a tuple.",
8413
+ group="Utility",
8414
+ export=False,
8415
+ )