warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic.

Files changed (179)
  1. warp/__init__.py +7 -1
  2. warp/bin/libwarp-clang.dylib +0 -0
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -15,10 +15,10 @@
15
15
 
16
16
  import builtins
17
17
  import functools
18
- import tempfile
19
- from pathlib import Path
20
18
  from typing import Any, Callable, Mapping, Sequence
21
19
 
20
+ import warp.build
21
+ import warp.context
22
22
  from warp.codegen import Reference, Var, strip_reference
23
23
  from warp.types import *
24
24
 
@@ -41,7 +41,7 @@ def sametypes(arg_types: Mapping[str, Any]):
41
41
  return all(types_equal(arg_type_0, t) for t in arg_types_iter)
42
42
 
43
43
 
44
- def sametypes_create_value_func(default):
44
+ def sametypes_create_value_func(default: TypeVar):
45
45
  def fn(arg_types, arg_values):
46
46
  if arg_types is None:
47
47
  return default
@@ -399,7 +399,7 @@ add_builtin(
399
399
  )
400
400
 
401
401
 
402
- def scalar_infer_type(arg_types: Mapping[str, type]):
402
+ def scalar_infer_type(arg_types: Union[Mapping[str, type], Tuple[type, ...], None]):
403
403
  if arg_types is None:
404
404
  return Scalar
405
405
 
@@ -950,6 +950,12 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
950
950
  raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
951
951
 
952
952
  if all(type_is_vector(x) for x in variadic_arg_types):
953
+ warp.utils.warn(
954
+ "the built-in `wp.matrix()` won't support taking column vectors as input "
955
+ "in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
956
+ DeprecationWarning,
957
+ )
958
+
953
959
  if shape[1] != variadic_arg_count:
954
960
  raise RuntimeError(
955
961
  f"incompatible number of column vectors given ({variadic_arg_count}) "
@@ -1030,6 +1036,86 @@ add_builtin(
1030
1036
  )
1031
1037
 
1032
1038
 
1039
+ def matrix_from_vecs_create_value_func(cols: bool):
1040
+ def fn(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1041
+ if arg_types is None:
1042
+ return matrix(shape=(Any, Any), dtype=Scalar)
1043
+
1044
+ variadic_arg_types = arg_types.get("args", ())
1045
+ variadic_arg_count = len(variadic_arg_types)
1046
+
1047
+ if not all(type_is_vector(x) for x in variadic_arg_types):
1048
+ raise RuntimeError("all arguments are expected to be vectors")
1049
+
1050
+ length = variadic_arg_types[0]._length_
1051
+ if any(x._length_ != length for x in variadic_arg_types):
1052
+ raise RuntimeError("all vectors are expected to have the same length")
1053
+
1054
+ dtype = variadic_arg_types[0]._wp_scalar_type_
1055
+ if any(x._wp_scalar_type_ != dtype for x in variadic_arg_types):
1056
+ raise RuntimeError("all vectors are expected to have the same dtype")
1057
+
1058
+ shape = (length, variadic_arg_count) if cols else (variadic_arg_count, length)
1059
+ return matrix(shape=shape, dtype=dtype)
1060
+
1061
+ return fn
1062
+
1063
+
1064
+ def matrix_from_vecs_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1065
+ # We're in the codegen stage where we emit the code calling the built-in.
1066
+ # Further validate the given argument values if needed and map them
1067
+ # to the underlying C++ function's runtime and template params.
1068
+
1069
+ shape = return_type._shape_
1070
+ dtype = return_type._wp_scalar_type_
1071
+
1072
+ variadic_args = args.get("args", ())
1073
+
1074
+ func_args = variadic_args
1075
+
1076
+ if shape in ((2, 2), (3, 3), (4, 4)):
1077
+ # Template specializations exist for these shapes, don't pass them
1078
+ # as template parameters.
1079
+ template_args = (dtype,)
1080
+ else:
1081
+ template_args = (*shape, dtype)
1082
+
1083
+ return (func_args, template_args)
1084
+
1085
+
1086
+ def matrix_from_vecs_initializer_list_func(args, return_type):
1087
+ shape = return_type._shape_
1088
+
1089
+ return shape[0] != shape[1] or shape[0] > 4
1090
+
1091
+
1092
+ add_builtin(
1093
+ "matrix_from_cols",
1094
+ input_types={"*args": vector(length=Any, dtype=Scalar)},
1095
+ variadic=True,
1096
+ value_func=matrix_from_vecs_create_value_func(cols=True),
1097
+ dispatch_func=matrix_from_vecs_dispatch_func,
1098
+ initializer_list_func=matrix_from_vecs_initializer_list_func,
1099
+ native_func="matrix_from_cols",
1100
+ doc="Construct a matrix from column vectors.",
1101
+ group="Vector Math",
1102
+ export=False,
1103
+ )
1104
+
1105
+ add_builtin(
1106
+ "matrix_from_rows",
1107
+ input_types={"*args": vector(length=Any, dtype=Scalar)},
1108
+ variadic=True,
1109
+ value_func=matrix_from_vecs_create_value_func(cols=False),
1110
+ dispatch_func=matrix_from_vecs_dispatch_func,
1111
+ initializer_list_func=matrix_from_vecs_initializer_list_func,
1112
+ native_func="matrix_from_rows",
1113
+ doc="Construct a matrix from row vectors.",
1114
+ group="Vector Math",
1115
+ export=False,
1116
+ )
1117
+
1118
+
1033
1119
  def identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1034
1120
  if arg_types is None:
1035
1121
  return matrix(shape=(Any, Any), dtype=Scalar)
@@ -1141,6 +1227,21 @@ add_builtin(
1141
1227
  while the left and right basis vectors are returned in ``U`` and ``V``.""",
1142
1228
  )
1143
1229
 
1230
+ add_builtin(
1231
+ "svd2",
1232
+ input_types={
1233
+ "A": matrix(shape=(2, 2), dtype=Float),
1234
+ "U": matrix(shape=(2, 2), dtype=Float),
1235
+ "sigma": vector(length=2, dtype=Float),
1236
+ "V": matrix(shape=(2, 2), dtype=Scalar),
1237
+ },
1238
+ value_type=None,
1239
+ group="Vector Math",
1240
+ export=False,
1241
+ doc="""Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
1242
+ while the left and right basis vectors are returned in ``U`` and ``V``.""",
1243
+ )
1244
+
1144
1245
  add_builtin(
1145
1246
  "qr3",
1146
1247
  input_types={
@@ -1332,7 +1433,18 @@ add_builtin(
1332
1433
  input_types={"mat": matrix(shape=(3, 3), dtype=Float)},
1333
1434
  value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
1334
1435
  group="Quaternion Math",
1335
- doc="Construct a quaternion from a 3x3 matrix.",
1436
+ doc="""Construct a quaternion from a 3x3 matrix.
1437
+
1438
+ If the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.""",
1439
+ )
1440
+ add_builtin(
1441
+ "quat_from_matrix",
1442
+ input_types={"mat": matrix(shape=(4, 4), dtype=Float)},
1443
+ value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
1444
+ group="Quaternion Math",
1445
+ doc="""Construct a quaternion from a 4x4 matrix.
1446
+
1447
+ If the top-left 3x3 block of the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.""",
1336
1448
  )
1337
1449
  add_builtin(
1338
1450
  "quat_rpy",
@@ -2375,7 +2487,7 @@ add_builtin(
2375
2487
 
2376
2488
  This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
2377
2489
 
2378
- * If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
2490
+ * If the input value is a scalar, then the resulting tile has ``shape=(block_dim,)``
2379
2491
  * If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
2380
2492
 
2381
2493
  :param x: A per-thread local value, e.g. scalar, vector, or matrix.
@@ -2669,11 +2781,9 @@ def tile_broadcast_value_func(arg_types, arg_values):
2669
2781
  def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2670
2782
  tile = arg_values["a"]
2671
2783
 
2672
- template_args = []
2673
- template_args.append(return_type.shape[0])
2674
- template_args.append(return_type.shape[1])
2675
- template_args.append(return_type.strides[0])
2676
- template_args.append(return_type.strides[1])
2784
+ assert len(return_type.shape) == len(return_type.strides)
2785
+ assert 1 <= len(return_type.shape) <= 4
2786
+ template_args = [*return_type.shape, *return_type.strides]
2677
2787
 
2678
2788
  return ((tile,), template_args)
2679
2789
 
@@ -2686,52 +2796,13 @@ add_builtin(
2686
2796
  variadic=False,
2687
2797
  doc="""Broadcast a tile.
2688
2798
 
2689
- This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n).
2690
-
2799
+ Broadcasts the input tile ``a`` to the destination shape.
2691
2800
  Broadcasting follows NumPy broadcast rules.
2692
2801
 
2693
2802
  :param a: Tile to broadcast
2694
2803
  :param shape: The shape to broadcast to
2695
- :returns: Tile with broadcast ``shape=(m, n)``""",
2696
- group="Tile Primitives",
2697
- export=False,
2698
- )
2699
-
2700
-
2701
- def tile_matmul_value_func(arg_types, arg_values):
2702
- # return generic type (for doc builds)
2703
- if arg_types is None:
2704
- return Tile(dtype=Any, shape=Any)
2705
-
2706
- if len(arg_types) != 3:
2707
- raise TypeError(f"tile_matmul() takes exactly 3 positional arguments but {len(arg_types)} were given")
2708
-
2709
- return None
2710
-
2711
-
2712
- def tile_matmul_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2713
- a = arg_values["a"]
2714
- b = arg_values["b"]
2715
- out = arg_values["out"]
2716
-
2717
- # force the storage type of the input variables to shared memory
2718
- a.type.storage = "shared"
2719
- b.type.storage = "shared"
2720
- out.type.storage = "shared"
2721
-
2722
- template_args = []
2723
- return ((a, b, out), template_args)
2724
-
2725
-
2726
- add_builtin(
2727
- "tile_matmul_scalar",
2728
- input_types={"a": Tile, "b": Tile, "out": Tile},
2729
- value_func=tile_matmul_value_func,
2730
- dispatch_func=tile_matmul_dispatch_func,
2731
- variadic=True,
2732
- doc="Compute matrix product and accumulate out += a*b.",
2804
+ :returns: Tile with broadcast shape""",
2733
2805
  group="Tile Primitives",
2734
- hidden=True,
2735
2806
  export=False,
2736
2807
  )
2737
2808
 
@@ -3030,7 +3101,7 @@ def tile_binary_map_value_func(arg_types, arg_values):
3030
3101
 
3031
3102
  for i in range(len(a.shape)):
3032
3103
  if a.shape[i] != b.shape[i]:
3033
- raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape[i]} and {b.shape[i]}")
3104
+ raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")
3034
3105
 
3035
3106
  return TileBinaryMap(a, b)
3036
3107
 
@@ -3807,6 +3878,18 @@ _volume_supported_value_types = {
3807
3878
  }
3808
3879
 
3809
3880
 
3881
+ def _is_volume_type_supported(dtype):
3882
+ for typ in _volume_supported_value_types:
3883
+ if types_equal(typ, dtype):
3884
+ return True
3885
+ return False
3886
+
3887
+
3888
+ def _check_volume_type_is_supported(dtype):
3889
+ if not _is_volume_type_supported(dtype):
3890
+ raise RuntimeError(f"unsupported volume type `{type_repr(dtype)}`")
3891
+
3892
+
3810
3893
  def check_volume_value_grad_compatibility(dtype, grad_dtype):
3811
3894
  if type_is_vector(dtype):
3812
3895
  expected = matrix(shape=(type_length(dtype), 3), dtype=type_scalar_type(dtype))
@@ -3822,9 +3905,7 @@ def volume_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
3822
3905
  return Any
3823
3906
 
3824
3907
  dtype = arg_values["dtype"]
3825
-
3826
- if dtype not in _volume_supported_value_types:
3827
- raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
3908
+ _check_volume_type_is_supported(dtype)
3828
3909
 
3829
3910
  return dtype
3830
3911
 
@@ -3860,9 +3941,7 @@ def volume_sample_grad_value_func(arg_types: Mapping[str, type], arg_values: Map
3860
3941
  return Any
3861
3942
 
3862
3943
  dtype = arg_values["dtype"]
3863
-
3864
- if dtype not in _volume_supported_value_types:
3865
- raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
3944
+ _check_volume_type_is_supported(dtype)
3866
3945
 
3867
3946
  check_volume_value_grad_compatibility(dtype, arg_types["grad"])
3868
3947
 
@@ -3900,9 +3979,7 @@ def volume_lookup_value_func(arg_types: Mapping[str, type], arg_values: Mapping[
3900
3979
  return Any
3901
3980
 
3902
3981
  dtype = arg_values["dtype"]
3903
-
3904
- if dtype not in _volume_supported_value_types:
3905
- raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
3982
+ _check_volume_type_is_supported(dtype)
3906
3983
 
3907
3984
  return dtype
3908
3985
 
@@ -3939,9 +4016,7 @@ def volume_store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[s
3939
4016
  return None
3940
4017
 
3941
4018
  dtype = arg_types["value"]
3942
-
3943
- if dtype not in _volume_supported_value_types:
3944
- raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
4019
+ _check_volume_type_is_supported(dtype)
3945
4020
 
3946
4021
  return None
3947
4022
 
@@ -4191,6 +4266,20 @@ add_builtin(
4191
4266
  group="Random",
4192
4267
  doc="Return a random integer between [low, high).",
4193
4268
  )
4269
+ add_builtin(
4270
+ "randu",
4271
+ input_types={"state": uint32},
4272
+ value_type=uint32,
4273
+ group="Random",
4274
+ doc="Return a random unsigned integer in the range [0, 2^32).",
4275
+ )
4276
+ add_builtin(
4277
+ "randu",
4278
+ input_types={"state": uint32, "low": uint32, "high": uint32},
4279
+ value_type=uint32,
4280
+ group="Random",
4281
+ doc="Return a random unsigned integer between [low, high).",
4282
+ )
4194
4283
  add_builtin(
4195
4284
  "randf",
4196
4285
  input_types={"state": uint32},
@@ -4499,11 +4588,31 @@ add_builtin(
4499
4588
  export=False,
4500
4589
  group="Utility",
4501
4590
  )
4591
+
4592
+
4593
+ def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
4594
+ warp.utils.warn(
4595
+ "wp.select() is deprecated and will be removed in a future\n"
4596
+ "version. Use wp.where(cond, value_if_true, value_if_false) instead.",
4597
+ category=DeprecationWarning,
4598
+ )
4599
+
4600
+ func_args = tuple(args.values())
4601
+ template_args = ()
4602
+
4603
+ return (func_args, template_args)
4604
+
4605
+
4502
4606
  add_builtin(
4503
4607
  "select",
4504
4608
  input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
4505
4609
  value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
4506
- doc="Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``",
4610
+ dispatch_func=select_dispatch_func,
4611
+ doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
4612
+
4613
+ .. deprecated:: 1.7
4614
+ Use :func:`where` instead, which has the more intuitive argument order:
4615
+ ``where(cond, value_if_true, value_if_false)``.""",
4507
4616
  group="Utility",
4508
4617
  )
4509
4618
  for t in int_types:
@@ -4511,14 +4620,47 @@ for t in int_types:
4511
4620
  "select",
4512
4621
  input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
4513
4622
  value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
4514
- doc="Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``",
4623
+ dispatch_func=select_dispatch_func,
4624
+ doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
4625
+
4626
+ .. deprecated:: 1.7
4627
+ Use :func:`where` instead, which has the more intuitive argument order:
4628
+ ``where(cond, value_if_true, value_if_false)``.""",
4515
4629
  group="Utility",
4516
4630
  )
4517
4631
  add_builtin(
4518
4632
  "select",
4519
4633
  input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
4520
4634
  value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
4521
- doc="Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``",
4635
+ dispatch_func=select_dispatch_func,
4636
+ doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
4637
+
4638
+ .. deprecated:: 1.7
4639
+ Use :func:`where` instead, which has the more intuitive argument order:
4640
+ ``where(arr, value_if_true, value_if_false)``.""",
4641
+ group="Utility",
4642
+ )
4643
+
4644
+ add_builtin(
4645
+ "where",
4646
+ input_types={"cond": builtins.bool, "value_if_true": Any, "value_if_false": Any},
4647
+ value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
4648
+ doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
4649
+ group="Utility",
4650
+ )
4651
+ for t in int_types:
4652
+ add_builtin(
4653
+ "where",
4654
+ input_types={"cond": t, "value_if_true": Any, "value_if_false": Any},
4655
+ value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
4656
+ doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
4657
+ group="Utility",
4658
+ )
4659
+ add_builtin(
4660
+ "where",
4661
+ input_types={"arr": array(dtype=Any), "value_if_true": Any, "value_if_false": Any},
4662
+ value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
4663
+ doc="Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``.",
4522
4664
  group="Utility",
4523
4665
  )
4524
4666
 
@@ -5112,33 +5254,51 @@ add_builtin(
5112
5254
  )
5113
5255
 
5114
5256
 
5257
+ # implements vector[index] = value
5258
+ add_builtin(
5259
+ "assign_inplace",
5260
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5261
+ value_type=None,
5262
+ hidden=True,
5263
+ group="Utility",
5264
+ )
5265
+
5266
+ # implements quaternion[index] = value
5267
+ add_builtin(
5268
+ "assign_inplace",
5269
+ input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5270
+ value_type=None,
5271
+ hidden=True,
5272
+ group="Utility",
5273
+ )
5274
+
5275
+
5115
5276
  def vector_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
5116
5277
  vec_type = arg_types["a"]
5117
5278
  return vec_type
5118
5279
 
5119
5280
 
5120
- # implements vector[index] = value
5281
+ # implements vector[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
5121
5282
  add_builtin(
5122
- "assign",
5283
+ "assign_copy",
5123
5284
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5124
5285
  value_func=vector_assign_value_func,
5125
5286
  hidden=True,
5126
5287
  group="Utility",
5127
5288
  )
5128
5289
 
5129
- # implements quaternion[index] = value
5290
+ # implements quaternion[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
5130
5291
  add_builtin(
5131
- "assign",
5292
+ "assign_copy",
5132
5293
  input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5133
5294
  value_func=vector_assign_value_func,
5134
5295
  hidden=True,
5135
5296
  group="Utility",
5136
5297
  )
5137
5298
 
5138
-
5139
5299
  # implements vector[idx] += scalar
5140
5300
  add_builtin(
5141
- "augassign_add",
5301
+ "add_inplace",
5142
5302
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5143
5303
  value_type=None,
5144
5304
  hidden=True,
@@ -5147,7 +5307,7 @@ add_builtin(
5147
5307
 
5148
5308
  # implements quaternion[idx] += scalar
5149
5309
  add_builtin(
5150
- "augassign_add",
5310
+ "add_inplace",
5151
5311
  input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5152
5312
  value_type=None,
5153
5313
  hidden=True,
@@ -5156,7 +5316,7 @@ add_builtin(
5156
5316
 
5157
5317
  # implements vector[idx] -= scalar
5158
5318
  add_builtin(
5159
- "augassign_sub",
5319
+ "sub_inplace",
5160
5320
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5161
5321
  value_type=None,
5162
5322
  hidden=True,
@@ -5165,7 +5325,7 @@ add_builtin(
5165
5325
 
5166
5326
  # implements quaternion[idx] -= scalar
5167
5327
  add_builtin(
5168
- "augassign_sub",
5328
+ "sub_inplace",
5169
5329
  input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5170
5330
  value_type=None,
5171
5331
  hidden=True,
@@ -5209,11 +5369,6 @@ add_builtin(
5209
5369
  )
5210
5370
 
5211
5371
 
5212
- def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
5213
- mat_type = arg_types["a"]
5214
- return mat_type
5215
-
5216
-
5217
5372
  def matrix_vector_sametype(arg_types: Mapping[str, Any]):
5218
5373
  mat_size = arg_types["a"]._shape_[0]
5219
5374
  vec_size = arg_types["value"]._length_
@@ -5224,7 +5379,33 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
5224
5379
 
5225
5380
  # implements matrix[i,j] = scalar
5226
5381
  add_builtin(
5227
- "assign",
5382
+ "assign_inplace",
5383
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5384
+ value_type=None,
5385
+ hidden=True,
5386
+ group="Utility",
5387
+ )
5388
+
5389
+
5390
+ # implements matrix[i] = vector
5391
+ add_builtin(
5392
+ "assign_inplace",
5393
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
5394
+ constraint=matrix_vector_sametype,
5395
+ value_type=None,
5396
+ hidden=True,
5397
+ group="Utility",
5398
+ )
5399
+
5400
+
5401
+ def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
5402
+ mat_type = arg_types["a"]
5403
+ return mat_type
5404
+
5405
+
5406
+ # implements matrix[i,j] = scalar
5407
+ add_builtin(
5408
+ "assign_copy",
5228
5409
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5229
5410
  value_func=matrix_assign_value_func,
5230
5411
  hidden=True,
@@ -5234,7 +5415,7 @@ add_builtin(
5234
5415
 
5235
5416
  # implements matrix[i] = vector
5236
5417
  add_builtin(
5237
- "assign",
5418
+ "assign_copy",
5238
5419
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
5239
5420
  constraint=matrix_vector_sametype,
5240
5421
  value_func=matrix_assign_value_func,
@@ -5245,7 +5426,7 @@ add_builtin(
5245
5426
 
5246
5427
  # implements matrix[i,j] += scalar
5247
5428
  add_builtin(
5248
- "augassign_add",
5429
+ "add_inplace",
5249
5430
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5250
5431
  value_type=None,
5251
5432
  hidden=True,
@@ -5253,9 +5434,20 @@ add_builtin(
5253
5434
  )
5254
5435
 
5255
5436
 
5437
# In-place row accumulation: implements ``matrix[i] += vector``.
# Registered as a hidden builtin with no return value (``value_type=None``);
# overload resolution is gated by ``matrix_vector_sametype``.
add_builtin(
    "add_inplace",
    input_types={
        "a": matrix(shape=(Any, Any), dtype=Scalar),
        "i": int,
        "value": vector(length=Any, dtype=Scalar),
    },
    constraint=matrix_vector_sametype,
    value_type=None,
    hidden=True,
    group="Utility",
)
+ )
5446
+
5447
+
5256
5448
  # implements matrix[i,j] -= scalar
5257
5449
  add_builtin(
5258
- "augassign_sub",
5450
+ "sub_inplace",
5259
5451
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5260
5452
  value_type=None,
5261
5453
  hidden=True,
@@ -5263,6 +5455,16 @@ add_builtin(
5263
5455
  )
5264
5456
 
5265
5457
 
5458
# In-place row subtraction: implements ``matrix[i] -= vector``.
# Registered as a hidden builtin with no return value (``value_type=None``).
# The ``matrix_vector_sametype`` constraint is applied here exactly as on the
# sibling ``matrix[i] = vector`` and ``matrix[i] += vector`` overloads, so a
# vector that is incompatible with the matrix is rejected at overload
# resolution instead of reaching code generation.
add_builtin(
    "sub_inplace",
    input_types={
        "a": matrix(shape=(Any, Any), dtype=Scalar),
        "i": int,
        "value": vector(length=Any, dtype=Scalar),
    },
    constraint=matrix_vector_sametype,
    value_type=None,
    hidden=True,
    group="Utility",
)
5466
+
5467
+
5266
5468
  for t in scalar_types + vector_types + (bool,):
5267
5469
  if "vec" in t.__name__ or "mat" in t.__name__:
5268
5470
  continue
@@ -5410,7 +5612,27 @@ add_builtin(
5410
5612
  )
5411
5613
  add_builtin(
5412
5614
  "expect_near",
5413
- input_types={"a": vec3, "b": vec3, "tolerance": float},
5615
+ input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "tolerance": Float},
5616
+ defaults={"tolerance": 1.0e-6},
5617
+ value_type=None,
5618
+ doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
5619
+ group="Utility",
5620
+ )
5621
# Quaternion overload of ``expect_near``; ``tolerance`` defaults to 1.0e-6.
add_builtin(
    "expect_near",
    input_types={
        "a": quaternion(dtype=Float),
        "b": quaternion(dtype=Float),
        "tolerance": Float,
    },
    defaults={"tolerance": 1.0e-6},
    value_type=None,
    doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
    group="Utility",
)
+ )
5629
+ add_builtin(
5630
+ "expect_near",
5631
+ input_types={
5632
+ "a": matrix(shape=(Any, Any), dtype=Float),
5633
+ "b": matrix(shape=(Any, Any), dtype=Float),
5634
+ "tolerance": Float,
5635
+ },
5414
5636
  defaults={"tolerance": 1.0e-6},
5415
5637
  value_type=None,
5416
5638
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
@@ -5989,7 +6211,7 @@ add_builtin(
5989
6211
  ##
5990
6212
  ## Matmul
5991
6213
  ##
5992
- def tile_matmul_generic_value_func(arg_types, arg_values):
6214
+ def tile_matmul_value_func(arg_types, arg_values):
5993
6215
  # return generic type (for doc builds)
5994
6216
  if arg_types is None:
5995
6217
  return Tile(dtype=Any, shape=Any)
@@ -6015,7 +6237,7 @@ def tile_matmul_generic_value_func(arg_types, arg_values):
6015
6237
  return None
6016
6238
 
6017
6239
 
6018
- def tile_matmul_generic_lto_dispatch_func(
6240
+ def tile_matmul_lto_dispatch_func(
6019
6241
  arg_types: Mapping[str, type],
6020
6242
  return_type: Any,
6021
6243
  return_values: List[Var],
@@ -6054,142 +6276,82 @@ def tile_matmul_generic_lto_dispatch_func(
6054
6276
  out.type.storage = "shared"
6055
6277
  template_args = [accumulate]
6056
6278
 
6057
- # Maps Python/Warp types to C++ types and enums
6058
- def cublasdx_type_map(dtype):
6059
- if dtype == float16:
6060
- return ("wp::float16", 3, 0)
6061
- if dtype == float32:
6062
- return ("wp::float32", 5, 0)
6063
- if dtype == float64:
6064
- return ("wp::float64", 6, 0)
6065
- if dtype == vec2h:
6066
- return ("wp::vec2h", 3, 1)
6067
- if dtype == vec2f:
6068
- return ("wp::vec2f", 5, 1)
6069
- if dtype == vec2d:
6070
- return ("wp::vec2d", 6, 1)
6071
- raise TypeError("Unsupported input type in tile_matmul")
6072
-
6073
- def cublasdx_arrangement_map(layout):
6074
- if layout == "colmajor":
6075
- return 0 # CUBLASDX_ARRANGEMENT_COL_MAJOR
6076
- if layout == "rowmajor":
6077
- return 1 # CUBLASDX_ARRANGEMENT_ROW_MAJOR
6078
- raise ValueError("Unsupported layout in tile_matmul")
6079
-
6080
- # generate the LTO
6081
6279
  M, K = a.type.shape[0], a.type.shape[1]
6082
6280
  _, N = b.type.shape[0], b.type.shape[1]
6083
6281
  num_threads = options["block_dim"]
6084
6282
  arch = options["output_arch"]
6085
6283
 
6086
- def make_function(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout):
6087
- (a_dtype, a_prec, a_type) = cublasdx_type_map(adtype)
6088
- (b_dtype, b_prec, b_type) = cublasdx_type_map(bdtype)
6089
- (c_dtype, c_prec, c_type) = cublasdx_type_map(cdtype)
6090
- a_arrangement = cublasdx_arrangement_map(alayout)
6091
- b_arrangement = cublasdx_arrangement_map(blayout)
6092
- c_arrangement = cublasdx_arrangement_map(clayout)
6093
-
6094
- if a_type != b_type or a_type != c_type:
6095
- raise TypeError("time_matmul(A, B, C) requires all inputs to be real or complex")
6096
-
6097
- element_type = a_type
6098
-
6099
- lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
6284
+ if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
6285
+ # CPU/no-MathDx dispatch
6286
+ return ((0, 0, 0, a, b, out), template_args, [], 0)
6287
+ else:
6100
6288
 
6101
- # early out if LTO for this combination already exists for this module
6102
- if lto_symbol in builder.ltoirs:
6103
- return lto_symbol, builder.ltoirs[lto_symbol]
6289
+ def tile_flip_layout(layout):
6290
+ if layout == "rowmajor":
6291
+ return "colmajor"
6292
+ elif layout == "colmajor":
6293
+ return "rowmajor"
6104
6294
 
6105
- # otherwise compile LTO
6106
- lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6107
- result = warp.context.runtime.core.cuda_compile_dot(
6108
- lto_code.name.encode("utf-8"),
6109
- lto_symbol.encode("utf-8"),
6110
- 0,
6111
- None,
6112
- None,
6295
+ # generate the LTOs
6296
+ # C += A * B
6297
+ (fun_forward, lto_forward) = warp.build.build_lto_dot(
6298
+ M,
6299
+ N,
6300
+ K,
6301
+ a.type.dtype,
6302
+ b.type.dtype,
6303
+ out.type.dtype,
6304
+ a.type.layout,
6305
+ b.type.layout,
6306
+ out.type.layout,
6113
6307
  arch,
6308
+ num_threads,
6309
+ builder,
6310
+ )
6311
+ # adjA += adjC * B^T - Transpose ~= flipped layout
6312
+ (fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
6114
6313
  M,
6314
+ K,
6115
6315
  N,
6316
+ out.type.dtype,
6317
+ b.type.dtype,
6318
+ a.type.dtype,
6319
+ out.type.layout,
6320
+ tile_flip_layout(b.type.layout),
6321
+ a.type.layout,
6322
+ arch,
6323
+ num_threads,
6324
+ builder,
6325
+ )
6326
+ # adjB += A^T * adjC - Transpose ~= flipped layout
6327
+ (fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
6116
6328
  K,
6117
- a_prec,
6118
- b_prec,
6119
- c_prec,
6120
- element_type,
6121
- a_arrangement,
6122
- b_arrangement,
6123
- c_arrangement,
6329
+ N,
6330
+ M,
6331
+ a.type.dtype,
6332
+ out.type.dtype,
6333
+ b.type.dtype,
6334
+ tile_flip_layout(a.type.layout),
6335
+ out.type.layout,
6336
+ b.type.layout,
6337
+ arch,
6124
6338
  num_threads,
6339
+ builder,
6125
6340
  )
6126
- lto_code_path = Path(lto_code.name)
6127
- if not result:
6128
- lto_code.close()
6129
- if lto_code_path.exists():
6130
- lto_code_path.unlink()
6131
- raise RuntimeError("Failed to compile tile_matmul")
6132
- else:
6133
- with open(lto_code.name, "rb") as f:
6134
- lto_code_data = f.read()
6135
- lto_code.close()
6136
- lto_code_path.unlink()
6137
-
6138
- builder.ltoirs[lto_symbol] = lto_code_data
6139
- builder.ltoirs_decl[lto_symbol] = (
6140
- f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
6141
- )
6142
-
6143
- return lto_symbol, lto_code_data
6144
6341
 
6145
- def tile_flip_layout(layout):
6146
- if layout == "rowmajor":
6147
- return "colmajor"
6148
- elif layout == "colmajor":
6149
- return "rowmajor"
6150
-
6151
- # C += A * B
6152
- (fun_forward, lto_forward) = make_function(
6153
- M, N, K, a.type.dtype, b.type.dtype, out.type.dtype, a.type.layout, b.type.layout, out.type.layout
6154
- )
6155
- # adjA += adjC * B^T - Transpose ~= flipped layout
6156
- (fun_backward_A, lto_backward_A) = make_function(
6157
- M,
6158
- K,
6159
- N,
6160
- out.type.dtype,
6161
- b.type.dtype,
6162
- a.type.dtype,
6163
- out.type.layout,
6164
- tile_flip_layout(b.type.layout),
6165
- a.type.layout,
6166
- )
6167
- # adjB += A^T * adjC - Transpose ~= flipped layout
6168
- (fun_backward_B, lto_backward_B) = make_function(
6169
- K,
6170
- N,
6171
- M,
6172
- a.type.dtype,
6173
- out.type.dtype,
6174
- b.type.dtype,
6175
- tile_flip_layout(a.type.layout),
6176
- out.type.layout,
6177
- b.type.layout,
6178
- )
6179
-
6180
- return (
6181
- (
6182
- Var(fun_forward, str, False, True, False),
6183
- Var(fun_backward_A, str, False, True, False),
6184
- Var(fun_backward_B, str, False, True, False),
6185
- a,
6186
- b,
6187
- out,
6188
- ),
6189
- template_args,
6190
- [lto_forward, lto_backward_A, lto_backward_B],
6191
- 0,
6192
- )
6342
+ return (
6343
+ (
6344
+ Var(fun_forward, str, False, True, False),
6345
+ Var(fun_backward_A, str, False, True, False),
6346
+ Var(fun_backward_B, str, False, True, False),
6347
+ a,
6348
+ b,
6349
+ out,
6350
+ ),
6351
+ template_args,
6352
+ [lto_forward, lto_backward_A, lto_backward_B],
6353
+ 0,
6354
+ )
6193
6355
 
6194
6356
 
6195
6357
  add_builtin(
@@ -6199,8 +6361,8 @@ add_builtin(
6199
6361
  "b": Tile(dtype=Any, shape=Any),
6200
6362
  "out": Tile(dtype=Any, shape=Any),
6201
6363
  },
6202
- value_func=tile_matmul_generic_value_func,
6203
- lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
6364
+ value_func=tile_matmul_value_func,
6365
+ lto_dispatch_func=tile_matmul_lto_dispatch_func,
6204
6366
  variadic=False,
6205
6367
  doc="""Computes the matrix product and accumulates ``out += a*b``.
6206
6368
 
@@ -6208,7 +6370,7 @@ add_builtin(
6208
6370
  * fp16, fp32, fp64 (real)
6209
6371
  * vec2h, vec2f, vec2d (complex)
6210
6372
 
6211
- All input and output tiles must have the same datatype. Tile data will be automatically be migrated
6373
+ All input and output tiles must have the same datatype. Tile data will automatically be migrated
6212
6374
  to shared memory if necessary and will use TensorCore operations when available.
6213
6375
 
6214
6376
  :param a: A tile with ``shape=(M, K)``
@@ -6222,8 +6384,8 @@ add_builtin(
6222
6384
  add_builtin(
6223
6385
  "tile_matmul",
6224
6386
  input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
6225
- value_func=tile_matmul_generic_value_func,
6226
- lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
6387
+ value_func=tile_matmul_value_func,
6388
+ lto_dispatch_func=tile_matmul_lto_dispatch_func,
6227
6389
  variadic=False,
6228
6390
  doc="""Computes the matrix product ``out = a*b``.
6229
6391
 
@@ -6231,7 +6393,7 @@ add_builtin(
6231
6393
  * fp16, fp32, fp64 (real)
6232
6394
  * vec2h, vec2f, vec2d (complex)
6233
6395
 
6234
- Both input tiles must have the same datatype. Tile data will be automatically be migrated
6396
+ Both input tiles must have the same datatype. Tile data will automatically be migrated
6235
6397
  to shared memory if necessary and will use TensorCore operations when available.
6236
6398
 
6237
6399
  :param a: A tile with ``shape=(M, K)``
@@ -6303,59 +6465,29 @@ def tile_fft_generic_lto_dispatch_func(
6303
6465
  num_threads = options["block_dim"]
6304
6466
  arch = options["output_arch"]
6305
6467
  ept = size // num_threads
6306
- lto_symbol = f"fft_{size}_{ept}_{arch}_{direction}_{precision}"
6307
-
6308
- # early out if LTO for this combination already exists for this module
6309
- if lto_symbol in builder.ltoirs:
6310
- return lto_symbol, builder.ltoirs[lto_symbol]
6311
-
6312
- # otherwise compile LTO
6313
- lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6314
- shared_memory_size = ctypes.c_int(0)
6315
-
6316
- result = warp.context.runtime.core.cuda_compile_fft(
6317
- lto_code.name.encode("utf-8"),
6318
- lto_symbol.encode("utf-8"),
6319
- 0,
6320
- None,
6321
- None,
6322
- arch,
6323
- size,
6324
- ept,
6325
- dir,
6326
- precision,
6327
- ctypes.byref(shared_memory_size),
6328
- )
6329
- lto_code_path = Path(lto_code.name)
6330
- if not result:
6331
- lto_code.close()
6332
- if lto_code_path.exists():
6333
- lto_code_path.unlink()
6334
- raise RuntimeError("Failed to compile tile_fft")
6335
-
6336
- with open(lto_code.name, "rb") as f:
6337
- lto_code_data = f.read()
6338
-
6339
- lto_code.close()
6340
- lto_code_path.unlink()
6341
-
6342
- builder.ltoirs[lto_symbol] = lto_code_data
6343
-
6344
- shared_memory_bytes = Tile.round_up(shared_memory_size.value)
6345
-
6346
- return (
6347
- (
6348
- Var(lto_symbol, str, False, True, False),
6349
- Var(dtype, str, False, True, False),
6350
- Var(str(shared_memory_bytes), str, False, True, False),
6351
- Var(str(batch), str, False, True, False),
6352
- Var(str(ept), str, False, True, False),
6353
- inout,
6354
- ),
6355
- [],
6356
- [lto_code_data],
6357
- shared_memory_bytes,
6358
- )
6468
+
6469
+ if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
6470
+ # CPU/no-MathDx dispatch
6471
+ return ([], [], [], 0)
6472
+ else:
6473
+ # generate the LTO
6474
+ lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
6475
+ arch, size, ept, direction, dir, precision, builder
6476
+ )
6477
+
6478
+ return (
6479
+ (
6480
+ Var(lto_symbol, str, False, True, False),
6481
+ Var(dtype, str, False, True, False),
6482
+ Var(str(shared_memory_bytes), str, False, True, False),
6483
+ Var(str(batch), str, False, True, False),
6484
+ Var(str(ept), str, False, True, False),
6485
+ inout,
6486
+ ),
6487
+ [],
6488
+ [lto_code_data],
6489
+ shared_memory_bytes,
6490
+ )
6359
6491
 
6360
6492
 
6361
6493
  add_builtin(
@@ -6417,7 +6549,7 @@ def tile_cholesky_generic_value_func(arg_types, arg_values):
6417
6549
  raise TypeError(f"tile_cholesky() argument must be a tile, got {a!r}")
6418
6550
 
6419
6551
  if len(a.shape) != 2:
6420
- raise ValueError("tile_cholesky() argumust must be a 2D tile")
6552
+ raise ValueError("tile_cholesky() argument must be a 2D tile")
6421
6553
 
6422
6554
  if a.shape[0] != a.shape[1]:
6423
6555
  raise ValueError("tile_cholesky() argument must be square")
@@ -6458,57 +6590,36 @@ def tile_cholesky_generic_lto_dispatch_func(
6458
6590
  if out.type.shape[0] != M or out.type.shape[1] != M:
6459
6591
  raise ValueError("tile_cholesky() output tile must be square")
6460
6592
 
6461
- num_threads = options["block_dim"]
6462
- arch = options["output_arch"]
6463
- lto_symbol = f"potrf_{M}_{N}_{arch}_{precision_enum}"
6464
-
6465
- # early out if LTO for this combination already exists for this module
6466
- if lto_symbol in builder.ltoirs:
6467
- return lto_symbol, builder.ltoirs[lto_symbol]
6468
-
6469
- # otherwise compile LTO
6470
- lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6471
- universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6593
+ solver = "potrf"
6594
+ solver_enum = cusolver_function_map[solver]
6472
6595
 
6473
- # cuSOLVERDx only support col-major input/outputs,
6596
+ # cuSOLVERDx only supports col-major input/outputs,
6474
6597
  # so we use upper to mimic a row-major input
6475
- result = warp.context.runtime.core.cuda_compile_solver(
6476
- universal_fatbin_code.name.encode("utf-8"),
6477
- lto_code.name.encode("utf-8"),
6478
- lto_symbol.encode("utf-8"),
6479
- 0,
6480
- None,
6481
- None,
6482
- arch,
6483
- M,
6484
- N,
6485
- cusolver_function_map["potrf"],
6486
- precision_enum,
6487
- cusolver_fill_mode_map["upper"],
6488
- num_threads,
6489
- )
6598
+ fill_mode = cusolver_fill_mode_map["upper"]
6490
6599
 
6491
- if not result:
6492
- for f in [lto_code, universal_fatbin_code]:
6493
- f.close()
6494
- if Path(f.name).exists():
6495
- Path(f.name).unlink()
6496
- raise RuntimeError("Failed to compile tile_cholesky")
6600
+ arch = options["output_arch"]
6601
+ num_threads = options["block_dim"]
6602
+ parameter_list = f"({dtype}*, unsigned)"
6497
6603
 
6604
+ if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
6605
+ # CPU/no-MathDx dispatch
6606
+ return ((0, a, out), [], [], 0)
6498
6607
  else:
6499
- with open(lto_code.name, "rb") as f:
6500
- lto_code_data = f.read()
6501
- with open(universal_fatbin_code.name, "rb") as f:
6502
- universal_fatbin_code_data = f.read()
6503
- for f in [lto_code, universal_fatbin_code]:
6504
- f.close()
6505
- Path(f.name).unlink()
6506
-
6507
- builder.ltoirs[lto_symbol] = lto_code_data
6508
- builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, unsigned);"
6509
- builder.fatbins["cholesky"] = universal_fatbin_code_data
6608
+ # generate the LTO
6609
+ lto_symbol, lto_code_data = warp.build.build_lto_solver(
6610
+ M,
6611
+ N,
6612
+ solver,
6613
+ solver_enum,
6614
+ fill_mode,
6615
+ arch,
6616
+ precision_enum,
6617
+ num_threads,
6618
+ parameter_list,
6619
+ builder,
6620
+ )
6510
6621
 
6511
- return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
6622
+ return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
6512
6623
 
6513
6624
 
6514
6625
  add_builtin(
@@ -6602,57 +6713,36 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
6602
6713
  f"got {y.type.shape[0]} elements in output and {M} rows in 'L'"
6603
6714
  )
6604
6715
 
6605
- num_threads = options["block_dim"]
6606
- arch = options["output_arch"]
6607
- lto_symbol = f"potrs_{M}_{N}_{arch}_{precision_enum}"
6608
-
6609
- # early out if LTO for this combination already exists for this module
6610
- if lto_symbol in builder.ltoirs:
6611
- return lto_symbol, builder.ltoirs[lto_symbol]
6612
-
6613
- # otherwise compile LTO
6614
- lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6615
- universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6716
+ solver = "potrs"
6717
+ solver_enum = cusolver_function_map[solver]
6616
6718
 
6617
- # cuSOLVERDx only support col-major input/outputs,
6719
+ # cuSOLVERDx only supports col-major input/outputs,
6618
6720
  # so we use upper to mimic a row-major input
6619
- result = warp.context.runtime.core.cuda_compile_solver(
6620
- universal_fatbin_code.name.encode("utf-8"),
6621
- lto_code.name.encode("utf-8"),
6622
- lto_symbol.encode("utf-8"),
6623
- 0,
6624
- None,
6625
- None,
6626
- arch,
6627
- M,
6628
- N,
6629
- cusolver_function_map["potrs"],
6630
- precision_enum,
6631
- cusolver_fill_mode_map["upper"],
6632
- num_threads,
6633
- )
6721
+ fill_mode = cusolver_fill_mode_map["upper"]
6634
6722
 
6635
- if not result:
6636
- for f in [lto_code, universal_fatbin_code]:
6637
- f.close()
6638
- if Path(f.name).exists():
6639
- Path(f.name).unlink()
6640
- raise RuntimeError("Failed to compile tile_cholesky_solve")
6723
+ arch = options["output_arch"]
6724
+ num_threads = options["block_dim"]
6725
+ parameter_list = f"({dtype}*, {dtype}*)"
6641
6726
 
6727
+ if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
6728
+ # CPU/no-MathDx dispatch
6729
+ return ((0, L, x, y), [], [], 0)
6642
6730
  else:
6643
- with open(lto_code.name, "rb") as f:
6644
- lto_code_data = f.read()
6645
- with open(universal_fatbin_code.name, "rb") as f:
6646
- universal_fatbin_code_data = f.read()
6647
- for f in [lto_code, universal_fatbin_code]:
6648
- f.close()
6649
- Path(f.name).unlink()
6650
-
6651
- builder.ltoirs[lto_symbol] = lto_code_data
6652
- builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, {dtype}*);"
6653
- builder.fatbins["cholesky"] = universal_fatbin_code_data
6654
-
6655
- return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)
6731
+ # generate the LTO
6732
+ lto_symbol, lto_code_data = warp.build.build_lto_solver(
6733
+ M,
6734
+ N,
6735
+ solver,
6736
+ solver_enum,
6737
+ fill_mode,
6738
+ arch,
6739
+ precision_enum,
6740
+ num_threads,
6741
+ parameter_list,
6742
+ builder,
6743
+ )
6744
+
6745
+ return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)
6656
6746
 
6657
6747
 
6658
6748
  add_builtin(