warp_lang-1.6.2-py3-none-win_amd64.whl → warp_lang-1.7.1-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang was flagged as possibly problematic.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -15,10 +15,10 @@

  import builtins
  import functools
- import tempfile
- from pathlib import Path
  from typing import Any, Callable, Mapping, Sequence

+ import warp.build
+ import warp.context
  from warp.codegen import Reference, Var, strip_reference
  from warp.types import *

@@ -41,7 +41,7 @@ def sametypes(arg_types: Mapping[str, Any]):
      return all(types_equal(arg_type_0, t) for t in arg_types_iter)


- def sametypes_create_value_func(default):
+ def sametypes_create_value_func(default: TypeVar):
      def fn(arg_types, arg_values):
          if arg_types is None:
              return default
@@ -399,7 +399,7 @@ add_builtin(
  )


- def scalar_infer_type(arg_types: Mapping[str, type]):
+ def scalar_infer_type(arg_types: Union[Mapping[str, type], Tuple[type, ...], None]):
      if arg_types is None:
          return Scalar

@@ -836,7 +836,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An

      if dtype is None:
          dtype = value_type
-     elif value_type != dtype:
+     elif not warp.types.scalars_equal(value_type, dtype):
          raise RuntimeError(
              f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
          )
@@ -857,9 +857,9 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An

      if dtype is None:
          dtype = value_type
-     elif value_type != dtype:
+     elif not warp.types.scalars_equal(value_type, dtype):
          raise RuntimeError(
-             f"all values used to initialize this vector matrix are expected to be of the type `{dtype.__name__}`"
+             f"all values used to initialize this vector are expected to be of the type `{dtype.__name__}`"
          )

      if length is None:
@@ -940,7 +940,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An

      if dtype is None:
          dtype = value_type
-     elif value_type != dtype:
+     elif not warp.types.scalars_equal(value_type, dtype):
          raise RuntimeError(
              f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
          )
@@ -950,6 +950,12 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
          raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")

      if all(type_is_vector(x) for x in variadic_arg_types):
+         warp.utils.warn(
+             "the built-in `wp.matrix()` won't support taking column vectors as input "
+             "in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
+             DeprecationWarning,
+         )
+
          if shape[1] != variadic_arg_count:
              raise RuntimeError(
                  f"incompatible number of column vectors given ({variadic_arg_count}) "
@@ -973,7 +979,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An

      if dtype is None:
          dtype = value_type
-     elif value_type != dtype:
+     elif not warp.types.scalars_equal(value_type, dtype):
          raise RuntimeError(
              f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
          )
@@ -1030,6 +1036,86 @@ add_builtin(
  )


+ def matrix_from_vecs_create_value_func(cols: bool):
+     def fn(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+         if arg_types is None:
+             return matrix(shape=(Any, Any), dtype=Scalar)
+
+         variadic_arg_types = arg_types.get("args", ())
+         variadic_arg_count = len(variadic_arg_types)
+
+         if not all(type_is_vector(x) for x in variadic_arg_types):
+             raise RuntimeError("all arguments are expected to be vectors")
+
+         length = variadic_arg_types[0]._length_
+         if any(x._length_ != length for x in variadic_arg_types):
+             raise RuntimeError("all vectors are expected to have the same length")
+
+         dtype = variadic_arg_types[0]._wp_scalar_type_
+         if any(x._wp_scalar_type_ != dtype for x in variadic_arg_types):
+             raise RuntimeError("all vectors are expected to have the same dtype")
+
+         shape = (length, variadic_arg_count) if cols else (variadic_arg_count, length)
+         return matrix(shape=shape, dtype=dtype)
+
+     return fn
+
+
+ def matrix_from_vecs_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+     # We're in the codegen stage where we emit the code calling the built-in.
+     # Further validate the given argument values if needed and map them
+     # to the underlying C++ function's runtime and template params.
+
+     shape = return_type._shape_
+     dtype = return_type._wp_scalar_type_
+
+     variadic_args = args.get("args", ())
+
+     func_args = variadic_args
+
+     if shape in ((2, 2), (3, 3), (4, 4)):
+         # Template specializations exist for these shapes, don't pass them
+         # as template parameters.
+         template_args = (dtype,)
+     else:
+         template_args = (*shape, dtype)
+
+     return (func_args, template_args)
+
+
+ def matrix_from_vecs_initializer_list_func(args, return_type):
+     shape = return_type._shape_
+
+     return shape[0] != shape[1] or shape[0] > 4
+
+
+ add_builtin(
+     "matrix_from_cols",
+     input_types={"*args": vector(length=Any, dtype=Scalar)},
+     variadic=True,
+     value_func=matrix_from_vecs_create_value_func(cols=True),
+     dispatch_func=matrix_from_vecs_dispatch_func,
+     initializer_list_func=matrix_from_vecs_initializer_list_func,
+     native_func="matrix_from_cols",
+     doc="Construct a matrix from column vectors.",
+     group="Vector Math",
+     export=False,
+ )
+
+ add_builtin(
+     "matrix_from_rows",
+     input_types={"*args": vector(length=Any, dtype=Scalar)},
+     variadic=True,
+     value_func=matrix_from_vecs_create_value_func(cols=False),
+     dispatch_func=matrix_from_vecs_dispatch_func,
+     initializer_list_func=matrix_from_vecs_initializer_list_func,
+     native_func="matrix_from_rows",
+     doc="Construct a matrix from row vectors.",
+     group="Vector Math",
+     export=False,
+ )
+
+
  def identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
      if arg_types is None:
          return matrix(shape=(Any, Any), dtype=Scalar)
1084
1170
 
1085
1171
  if dtype is None:
1086
1172
  dtype = value_type
1087
- elif value_type != dtype:
1173
+ elif not warp.types.scalars_equal(value_type, dtype):
1088
1174
  raise RuntimeError(
1089
1175
  f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
1090
1176
  )
@@ -1141,6 +1227,21 @@ add_builtin(
      while the left and right basis vectors are returned in ``U`` and ``V``.""",
  )

+ add_builtin(
+     "svd2",
+     input_types={
+         "A": matrix(shape=(2, 2), dtype=Float),
+         "U": matrix(shape=(2, 2), dtype=Float),
+         "sigma": vector(length=2, dtype=Float),
+         "V": matrix(shape=(2, 2), dtype=Scalar),
+     },
+     value_type=None,
+     group="Vector Math",
+     export=False,
+     doc="""Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
+     while the left and right basis vectors are returned in ``U`` and ``V``.""",
+ )
+
  add_builtin(
      "qr3",
      input_types={
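A sketch of calling the new 2x2 SVD from a kernel, assuming it follows the same in-place output convention as the existing wp.svd3() (kernel and array names are illustrative):

    import warp as wp

    @wp.kernel
    def svd_demo(mats: wp.array(dtype=wp.mat22), sigmas: wp.array(dtype=wp.vec2)):
        tid = wp.tid()
        U = wp.mat22()
        sigma = wp.vec2()
        V = wp.mat22()
        # U, sigma, and V are written in place
        wp.svd2(mats[tid], U, sigma, V)
        sigmas[tid] = sigma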
@@ -1204,7 +1305,7 @@ def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str

      if dtype is None:
          dtype = value_type
-     elif value_type != dtype:
+     elif not warp.types.scalars_equal(value_type, dtype):
          raise RuntimeError(
              f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
          )
@@ -1244,7 +1345,8 @@ add_builtin(
  )
  add_builtin(
      "quaternion",
-     input_types={"x": Float, "y": Float, "z": Float, "w": Float},
+     input_types={"x": Float, "y": Float, "z": Float, "w": Float, "dtype": Scalar},
+     defaults={"dtype": None},
      value_func=quaternion_value_func,
      export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
      dispatch_func=quaternion_dispatch_func,
@@ -1332,7 +1434,18 @@ add_builtin(
      input_types={"mat": matrix(shape=(3, 3), dtype=Float)},
      value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
      group="Quaternion Math",
-     doc="Construct a quaternion from a 3x3 matrix.",
+     doc="""Construct a quaternion from a 3x3 matrix.
+
+     If the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.""",
+ )
+ add_builtin(
+     "quat_from_matrix",
+     input_types={"mat": matrix(shape=(4, 4), dtype=Float)},
+     value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
+     group="Quaternion Math",
+     doc="""Construct a quaternion from a 4x4 matrix.
+
+     If the top-left 3x3 block of the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.""",
  )
  add_builtin(
      "quat_rpy",
@@ -1403,7 +1516,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
      dtype = arg_values.get("dtype", None)
      if dtype is None:
          dtype = value_type
-     elif value_type != dtype:
+     elif not warp.types.scalars_equal(value_type, dtype):
          raise RuntimeError(
              f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
          )
@@ -1570,7 +1683,7 @@ def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping

      if dtype is None:
          dtype = value_type
-     elif value_type != dtype:
+     elif not warp.types.scalars_equal(value_type, dtype):
          raise RuntimeError(
              f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
          )
@@ -2375,7 +2488,7 @@ add_builtin(

      This function converts values computed using scalar kernel code to a tile representation for input into collective operations.

-     * If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
+     * If the input value is a scalar, then the resulting tile has ``shape=(block_dim,)``
      * If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``

      :param x: A per-thread local value, e.g. scalar, vector, or matrix.
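A sketch of the corrected scalar shape behavior, assuming a regular wp.launch() with an explicit block_dim (names are ours):

    import warp as wp

    @wp.kernel
    def tile_demo(out: wp.array(dtype=float)):
        x = float(wp.tid())      # per-thread scalar
        t = wp.tile(x)           # tile with shape=(block_dim,), not (1, block_dim)
        s = wp.tile_sum(t)       # cooperative reduction to a shape=(1,) tile
        wp.tile_store(out, s)

    out = wp.zeros(1, dtype=float)
    wp.launch(tile_demo, dim=64, inputs=[out], block_dim=64)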
@@ -2669,11 +2782,9 @@ def tile_broadcast_value_func(arg_types, arg_values):
  def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
      tile = arg_values["a"]

-     template_args = []
-     template_args.append(return_type.shape[0])
-     template_args.append(return_type.shape[1])
-     template_args.append(return_type.strides[0])
-     template_args.append(return_type.strides[1])
+     assert len(return_type.shape) == len(return_type.strides)
+     assert 1 <= len(return_type.shape) <= 4
+     template_args = [*return_type.shape, *return_type.strides]

      return ((tile,), template_args)

@@ -2686,56 +2797,17 @@ add_builtin(
      variadic=False,
      doc="""Broadcast a tile.

-     This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n).
-
+     Broadcasts the input tile ``a`` to the destination shape.
      Broadcasting follows NumPy broadcast rules.

      :param a: Tile to broadcast
      :param shape: The shape to broadcast to
-     :returns: Tile with broadcast ``shape=(m, n)``""",
+     :returns: Tile with broadcast shape""",
      group="Tile Primitives",
      export=False,
  )


- def tile_matmul_value_func(arg_types, arg_values):
-     # return generic type (for doc builds)
-     if arg_types is None:
-         return Tile(dtype=Any, shape=Any)
-
-     if len(arg_types) != 3:
-         raise TypeError(f"tile_matmul() takes exactly 3 positional arguments but {len(arg_types)} were given")
-
-     return None
-
-
- def tile_matmul_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
-     a = arg_values["a"]
-     b = arg_values["b"]
-     out = arg_values["out"]
-
-     # force the storage type of the input variables to shared memory
-     a.type.storage = "shared"
-     b.type.storage = "shared"
-     out.type.storage = "shared"
-
-     template_args = []
-     return ((a, b, out), template_args)
-
-
- add_builtin(
-     "tile_matmul_scalar",
-     input_types={"a": Tile, "b": Tile, "out": Tile},
-     value_func=tile_matmul_value_func,
-     dispatch_func=tile_matmul_dispatch_func,
-     variadic=True,
-     doc="Compute matrix product and accumulate out += a*b.",
-     group="Tile Primitives",
-     hidden=True,
-     export=False,
- )
-
-
  def tile_sum_value_func(arg_types, arg_values):
      # return generic type (for doc builds)
      if arg_types is None:
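A sketch of broadcasting with the shape argument (names are ours; assumes a tiled launch):

    import warp as wp

    @wp.kernel
    def broadcast_demo(out: wp.array2d(dtype=float)):
        a = wp.tile_ones(shape=(1, 8), dtype=float)
        b = wp.tile_broadcast(a, shape=(4, 8))  # NumPy-style (1, 8) -> (4, 8)
        wp.tile_store(out, b)

    out = wp.zeros((4, 8), dtype=float)
    wp.launch_tiled(broadcast_demo, dim=[1], inputs=[out], block_dim=32)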
@@ -3030,7 +3102,7 @@ def tile_binary_map_value_func(arg_types, arg_values):

      for i in range(len(a.shape)):
          if a.shape[i] != b.shape[i]:
-             raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape[i]} and {b.shape[i]}")
+             raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")

      return TileBinaryMap(a, b)

@@ -3807,6 +3879,18 @@ _volume_supported_value_types = {
  }


+ def _is_volume_type_supported(dtype):
+     for typ in _volume_supported_value_types:
+         if types_equal(typ, dtype):
+             return True
+     return False
+
+
+ def _check_volume_type_is_supported(dtype):
+     if not _is_volume_type_supported(dtype):
+         raise RuntimeError(f"unsupported volume type `{type_repr(dtype)}`")
+
+
  def check_volume_value_grad_compatibility(dtype, grad_dtype):
      if type_is_vector(dtype):
          expected = matrix(shape=(type_length(dtype), 3), dtype=type_scalar_type(dtype))
@@ -3822,9 +3906,7 @@ def volume_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
          return Any

      dtype = arg_values["dtype"]
-
-     if dtype not in _volume_supported_value_types:
-         raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
+     _check_volume_type_is_supported(dtype)

      return dtype

@@ -3860,9 +3942,7 @@ def volume_sample_grad_value_func(arg_types: Mapping[str, type], arg_values: Map
          return Any

      dtype = arg_values["dtype"]
-
-     if dtype not in _volume_supported_value_types:
-         raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
+     _check_volume_type_is_supported(dtype)

      check_volume_value_grad_compatibility(dtype, arg_types["grad"])

@@ -3900,9 +3980,7 @@ def volume_lookup_value_func(arg_types: Mapping[str, type], arg_values: Mapping[
          return Any

      dtype = arg_values["dtype"]
-
-     if dtype not in _volume_supported_value_types:
-         raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
+     _check_volume_type_is_supported(dtype)

      return dtype

@@ -3939,9 +4017,7 @@ def volume_store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[s
          return None

      dtype = arg_types["value"]
-
-     if dtype not in _volume_supported_value_types:
-         raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
+     _check_volume_type_is_supported(dtype)

      return None

@@ -4191,6 +4267,20 @@ add_builtin(
      group="Random",
      doc="Return a random integer between [low, high).",
  )
+ add_builtin(
+     "randu",
+     input_types={"state": uint32},
+     value_type=uint32,
+     group="Random",
+     doc="Return a random unsigned integer in the range [0, 2^32).",
+ )
+ add_builtin(
+     "randu",
+     input_types={"state": uint32, "low": uint32, "high": uint32},
+     value_type=uint32,
+     group="Random",
+     doc="Return a random unsigned integer between [low, high).",
+ )
  add_builtin(
      "randf",
      input_types={"state": uint32},
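A sketch of the new unsigned draws, assuming the same seeding convention as the existing wp.randi() (names are ours):

    import warp as wp

    @wp.kernel
    def randu_demo(seed: int, out: wp.array(dtype=wp.uint32)):
        tid = wp.tid()
        state = wp.rand_init(seed, tid)
        # bounded draw in [16, 32); wp.randu(state) alone covers [0, 2^32)
        out[tid] = wp.randu(state, wp.uint32(16), wp.uint32(32))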
@@ -4499,11 +4589,31 @@ add_builtin(
      export=False,
      group="Utility",
  )
+
+
+ def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+     warp.utils.warn(
+         "wp.select() is deprecated and will be removed in a future\n"
+         "version. Use wp.where(cond, value_if_true, value_if_false) instead.",
+         category=DeprecationWarning,
+     )
+
+     func_args = tuple(args.values())
+     template_args = ()
+
+     return (func_args, template_args)
+
+
  add_builtin(
      "select",
      input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
      value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
-     doc="Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``",
+     dispatch_func=select_dispatch_func,
+     doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
+
+     .. deprecated:: 1.7
+         Use :func:`where` instead, which has the more intuitive argument order:
+         ``where(cond, value_if_true, value_if_false)``.""",
      group="Utility",
  )
  for t in int_types:
@@ -4511,14 +4621,47 @@ for t in int_types:
          "select",
          input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
          value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
-         doc="Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``",
+         dispatch_func=select_dispatch_func,
+         doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
+
+         .. deprecated:: 1.7
+             Use :func:`where` instead, which has the more intuitive argument order:
+             ``where(cond, value_if_true, value_if_false)``.""",
          group="Utility",
      )
  add_builtin(
      "select",
      input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
      value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
-     doc="Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``",
+     dispatch_func=select_dispatch_func,
+     doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
+
+     .. deprecated:: 1.7
+         Use :func:`where` instead, which has the more intuitive argument order:
+         ``where(arr, value_if_true, value_if_false)``.""",
+     group="Utility",
+ )
+
+ add_builtin(
+     "where",
+     input_types={"cond": builtins.bool, "value_if_true": Any, "value_if_false": Any},
+     value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
+     doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
+     group="Utility",
+ )
+ for t in int_types:
+     add_builtin(
+         "where",
+         input_types={"cond": t, "value_if_true": Any, "value_if_false": Any},
+         value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
+         doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
+         group="Utility",
+     )
+ add_builtin(
+     "where",
+     input_types={"arr": array(dtype=Any), "value_if_true": Any, "value_if_false": Any},
+     value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
+     doc="Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``.",
      group="Utility",
  )

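The migration from wp.select() is mechanical, but the argument order flips. A sketch (names are ours):

    import warp as wp

    @wp.kernel
    def relu(x: wp.array(dtype=float), out: wp.array(dtype=float)):
        tid = wp.tid()
        # wp.where(cond, value_if_true, value_if_false);
        # the deprecated wp.select() took (cond, value_if_false, value_if_true)
        out[tid] = wp.where(x[tid] > 0.0, x[tid], 0.0)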
@@ -5112,33 +5255,51 @@ add_builtin(
  )


+ # implements vector[index] = value
+ add_builtin(
+     "assign_inplace",
+     input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
+     value_type=None,
+     hidden=True,
+     group="Utility",
+ )
+
+ # implements quaternion[index] = value
+ add_builtin(
+     "assign_inplace",
+     input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
+     value_type=None,
+     hidden=True,
+     group="Utility",
+ )
+
+
  def vector_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
      vec_type = arg_types["a"]
      return vec_type


- # implements vector[index] = value
+ # implements vector[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
  add_builtin(
-     "assign",
+     "assign_copy",
      input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
      value_func=vector_assign_value_func,
      hidden=True,
      group="Utility",
  )

- # implements quaternion[index] = value
+ # implements quaternion[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
  add_builtin(
-     "assign",
+     "assign_copy",
      input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
      value_func=vector_assign_value_func,
      hidden=True,
      group="Utility",
  )

-
  # implements vector[idx] += scalar
  add_builtin(
-     "augassign_add",
+     "add_inplace",
      input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
      value_type=None,
      hidden=True,
@@ -5147,7 +5308,7 @@ add_builtin(

  # implements quaternion[idx] += scalar
  add_builtin(
-     "augassign_add",
+     "add_inplace",
      input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
      value_type=None,
      hidden=True,
@@ -5156,7 +5317,7 @@ add_builtin(

  # implements vector[idx] -= scalar
  add_builtin(
-     "augassign_sub",
+     "sub_inplace",
      input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
      value_type=None,
      hidden=True,
@@ -5165,7 +5326,7 @@ add_builtin(

  # implements quaternion[idx] -= scalar
  add_builtin(
-     "augassign_sub",
+     "sub_inplace",
      input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
      value_type=None,
      hidden=True,
@@ -5209,11 +5370,6 @@ add_builtin(
  )


- def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
-     mat_type = arg_types["a"]
-     return mat_type
-
-
  def matrix_vector_sametype(arg_types: Mapping[str, Any]):
      mat_size = arg_types["a"]._shape_[0]
      vec_size = arg_types["value"]._length_
@@ -5224,7 +5380,33 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):

  # implements matrix[i,j] = scalar
  add_builtin(
-     "assign",
+     "assign_inplace",
+     input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
+     value_type=None,
+     hidden=True,
+     group="Utility",
+ )
+
+
+ # implements matrix[i] = vector
+ add_builtin(
+     "assign_inplace",
+     input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
+     constraint=matrix_vector_sametype,
+     value_type=None,
+     hidden=True,
+     group="Utility",
+ )
+
+
+ def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+     mat_type = arg_types["a"]
+     return mat_type
+
+
+ # implements matrix[i,j] = scalar
+ add_builtin(
+     "assign_copy",
      input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
      value_func=matrix_assign_value_func,
      hidden=True,
@@ -5234,7 +5416,7 @@ add_builtin(

  # implements matrix[i] = vector
  add_builtin(
-     "assign",
+     "assign_copy",
      input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
      constraint=matrix_vector_sametype,
      value_func=matrix_assign_value_func,
@@ -5245,7 +5427,7 @@ add_builtin(

  # implements matrix[i,j] += scalar
  add_builtin(
-     "augassign_add",
+     "add_inplace",
      input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
      value_type=None,
      hidden=True,
@@ -5253,9 +5435,20 @@ add_builtin(
  )


+ # implements matrix[i] += vector
+ add_builtin(
+     "add_inplace",
+     input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
+     constraint=matrix_vector_sametype,
+     value_type=None,
+     hidden=True,
+     group="Utility",
+ )
+
+
  # implements matrix[i,j] -= scalar
  add_builtin(
-     "augassign_sub",
+     "sub_inplace",
      input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
      value_type=None,
      hidden=True,
@@ -5263,6 +5456,16 @@ add_builtin(
  )


+ # implements matrix[i] -= vector
+ add_builtin(
+     "sub_inplace",
+     input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
+     value_type=None,
+     hidden=True,
+     group="Utility",
+ )
+
+
  for t in scalar_types + vector_types + (bool,):
      if "vec" in t.__name__ or "mat" in t.__name__:
          continue
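These hidden builtins back the kernel-level indexing syntax. A sketch of what they enable, including the new row-vector forms (names are ours):

    import warp as wp

    @wp.kernel
    def row_updates(out: wp.array(dtype=wp.mat33)):
        m = wp.mat33(0.0)
        m[0] = wp.vec3(1.0, 2.0, 3.0)    # matrix[i] = vector
        m[1] += wp.vec3(0.5, 0.5, 0.5)   # matrix[i] += vector, new in 1.7
        m[2, 2] -= 1.0                   # matrix[i,j] -= scalar
        out[0] = m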
@@ -5410,7 +5613,27 @@ add_builtin(
  )
  add_builtin(
      "expect_near",
-     input_types={"a": vec3, "b": vec3, "tolerance": float},
+     input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "tolerance": Float},
+     defaults={"tolerance": 1.0e-6},
+     value_type=None,
+     doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
+     group="Utility",
+ )
+ add_builtin(
+     "expect_near",
+     input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "tolerance": Float},
+     defaults={"tolerance": 1.0e-6},
+     value_type=None,
+     doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
+     group="Utility",
+ )
+ add_builtin(
+     "expect_near",
+     input_types={
+         "a": matrix(shape=(Any, Any), dtype=Float),
+         "b": matrix(shape=(Any, Any), dtype=Float),
+         "tolerance": Float,
+     },
      defaults={"tolerance": 1.0e-6},
      value_type=None,
      doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
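expect_near() now accepts any Float vector, quaternion, or matrix instead of only vec3. A sketch (names are ours):

    import warp as wp

    @wp.kernel
    def compare(a: wp.array(dtype=wp.mat33), b: wp.array(dtype=wp.mat33)):
        tid = wp.tid()
        wp.expect_near(a[tid], b[tid], 1.0e-4)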
@@ -5989,7 +6212,7 @@ add_builtin(
  ##
  ## Matmul
  ##
- def tile_matmul_generic_value_func(arg_types, arg_values):
+ def tile_matmul_value_func(arg_types, arg_values):
      # return generic type (for doc builds)
      if arg_types is None:
          return Tile(dtype=Any, shape=Any)
@@ -6015,7 +6238,7 @@ def tile_matmul_generic_value_func(arg_types, arg_values):
      return None


- def tile_matmul_generic_lto_dispatch_func(
+ def tile_matmul_lto_dispatch_func(
      arg_types: Mapping[str, type],
      return_type: Any,
      return_values: List[Var],
@@ -6054,142 +6277,82 @@ def tile_matmul_generic_lto_dispatch_func(
      out.type.storage = "shared"
      template_args = [accumulate]

-     # Maps Python/Warp types to C++ types and enums
-     def cublasdx_type_map(dtype):
-         if dtype == float16:
-             return ("wp::float16", 3, 0)
-         if dtype == float32:
-             return ("wp::float32", 5, 0)
-         if dtype == float64:
-             return ("wp::float64", 6, 0)
-         if dtype == vec2h:
-             return ("wp::vec2h", 3, 1)
-         if dtype == vec2f:
-             return ("wp::vec2f", 5, 1)
-         if dtype == vec2d:
-             return ("wp::vec2d", 6, 1)
-         raise TypeError("Unsupported input type in tile_matmul")
-
-     def cublasdx_arrangement_map(layout):
-         if layout == "colmajor":
-             return 0  # CUBLASDX_ARRANGEMENT_COL_MAJOR
-         if layout == "rowmajor":
-             return 1  # CUBLASDX_ARRANGEMENT_ROW_MAJOR
-         raise ValueError("Unsupported layout in tile_matmul")
-
-     # generate the LTO
      M, K = a.type.shape[0], a.type.shape[1]
      _, N = b.type.shape[0], b.type.shape[1]
      num_threads = options["block_dim"]
      arch = options["output_arch"]

-     def make_function(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout):
-         (a_dtype, a_prec, a_type) = cublasdx_type_map(adtype)
-         (b_dtype, b_prec, b_type) = cublasdx_type_map(bdtype)
-         (c_dtype, c_prec, c_type) = cublasdx_type_map(cdtype)
-         a_arrangement = cublasdx_arrangement_map(alayout)
-         b_arrangement = cublasdx_arrangement_map(blayout)
-         c_arrangement = cublasdx_arrangement_map(clayout)
-
-         if a_type != b_type or a_type != c_type:
-             raise TypeError("time_matmul(A, B, C) requires all inputs to be real or complex")
-
-         element_type = a_type
-
-         lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
+     if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
+         # CPU/no-MathDx dispatch
+         return ((0, 0, 0, a, b, out), template_args, [], 0)
+     else:

-         # early out if LTO for this combination already exists for this module
-         if lto_symbol in builder.ltoirs:
-             return lto_symbol, builder.ltoirs[lto_symbol]
+         def tile_flip_layout(layout):
+             if layout == "rowmajor":
+                 return "colmajor"
+             elif layout == "colmajor":
+                 return "rowmajor"

-         # otherwise compile LTO
-         lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
-         result = warp.context.runtime.core.cuda_compile_dot(
-             lto_code.name.encode("utf-8"),
-             lto_symbol.encode("utf-8"),
-             0,
-             None,
-             None,
+         # generate the LTOs
+         # C += A * B
+         (fun_forward, lto_forward) = warp.build.build_lto_dot(
+             M,
+             N,
+             K,
+             a.type.dtype,
+             b.type.dtype,
+             out.type.dtype,
+             a.type.layout,
+             b.type.layout,
+             out.type.layout,
              arch,
+             num_threads,
+             builder,
+         )
+         # adjA += adjC * B^T - Transpose ~= flipped layout
+         (fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
              M,
+             K,
              N,
+             out.type.dtype,
+             b.type.dtype,
+             a.type.dtype,
+             out.type.layout,
+             tile_flip_layout(b.type.layout),
+             a.type.layout,
+             arch,
+             num_threads,
+             builder,
+         )
+         # adjB += A^T * adjC - Transpose ~= flipped layout
+         (fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
              K,
-             a_prec,
-             b_prec,
-             c_prec,
-             element_type,
-             a_arrangement,
-             b_arrangement,
-             c_arrangement,
+             N,
+             M,
+             a.type.dtype,
+             out.type.dtype,
+             b.type.dtype,
+             tile_flip_layout(a.type.layout),
+             out.type.layout,
+             b.type.layout,
+             arch,
              num_threads,
+             builder,
          )
-         lto_code_path = Path(lto_code.name)
-         if not result:
-             lto_code.close()
-             if lto_code_path.exists():
-                 lto_code_path.unlink()
-             raise RuntimeError("Failed to compile tile_matmul")
-         else:
-             with open(lto_code.name, "rb") as f:
-                 lto_code_data = f.read()
-             lto_code.close()
-             lto_code_path.unlink()
-
-             builder.ltoirs[lto_symbol] = lto_code_data
-             builder.ltoirs_decl[lto_symbol] = (
-                 f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
-             )
-
-         return lto_symbol, lto_code_data

-     def tile_flip_layout(layout):
-         if layout == "rowmajor":
-             return "colmajor"
-         elif layout == "colmajor":
-             return "rowmajor"
-
-     # C += A * B
-     (fun_forward, lto_forward) = make_function(
-         M, N, K, a.type.dtype, b.type.dtype, out.type.dtype, a.type.layout, b.type.layout, out.type.layout
-     )
-     # adjA += adjC * B^T - Transpose ~= flipped layout
-     (fun_backward_A, lto_backward_A) = make_function(
-         M,
-         K,
-         N,
-         out.type.dtype,
-         b.type.dtype,
-         a.type.dtype,
-         out.type.layout,
-         tile_flip_layout(b.type.layout),
-         a.type.layout,
-     )
-     # adjB += A^T * adjC - Transpose ~= flipped layout
-     (fun_backward_B, lto_backward_B) = make_function(
-         K,
-         N,
-         M,
-         a.type.dtype,
-         out.type.dtype,
-         b.type.dtype,
-         tile_flip_layout(a.type.layout),
-         out.type.layout,
-         b.type.layout,
-     )
-
-     return (
-         (
-             Var(fun_forward, str, False, True, False),
-             Var(fun_backward_A, str, False, True, False),
-             Var(fun_backward_B, str, False, True, False),
-             a,
-             b,
-             out,
-         ),
-         template_args,
-         [lto_forward, lto_backward_A, lto_backward_B],
-         0,
-     )
+         return (
+             (
+                 Var(fun_forward, str, False, True, False),
+                 Var(fun_backward_A, str, False, True, False),
+                 Var(fun_backward_B, str, False, True, False),
+                 a,
+                 b,
+                 out,
+             ),
+             template_args,
+             [lto_forward, lto_backward_A, lto_backward_B],
+             0,
+         )


  add_builtin(
@@ -6199,8 +6362,8 @@ add_builtin(
      "b": Tile(dtype=Any, shape=Any),
      "out": Tile(dtype=Any, shape=Any),
  },
- value_func=tile_matmul_generic_value_func,
- lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
+ value_func=tile_matmul_value_func,
+ lto_dispatch_func=tile_matmul_lto_dispatch_func,
  variadic=False,
  doc="""Computes the matrix product and accumulates ``out += a*b``.

@@ -6208,7 +6371,7 @@ add_builtin(
  * fp16, fp32, fp64 (real)
  * vec2h, vec2f, vec2d (complex)

- All input and output tiles must have the same datatype. Tile data will be automatically be migrated
+ All input and output tiles must have the same datatype. Tile data will automatically be migrated
  to shared memory if necessary and will use TensorCore operations when available.

  :param a: A tile with ``shape=(M, K)``
@@ -6222,8 +6385,8 @@ add_builtin(
  add_builtin(
      "tile_matmul",
      input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
-     value_func=tile_matmul_generic_value_func,
-     lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
+     value_func=tile_matmul_value_func,
+     lto_dispatch_func=tile_matmul_lto_dispatch_func,
      variadic=False,
      doc="""Computes the matrix product ``out = a*b``.

@@ -6231,7 +6394,7 @@ add_builtin(
      * fp16, fp32, fp64 (real)
      * vec2h, vec2f, vec2d (complex)

-     Both input tiles must have the same datatype. Tile data will be automatically be migrated
+     Both input tiles must have the same datatype. Tile data will automatically be migrated
      to shared memory if necessary and will use TensorCore operations when available.

      :param a: A tile with ``shape=(M, K)``
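A blocked GEMM sketch in the style of the Warp tile documentation (tile sizes and names are ours):

    import warp as wp

    TILE_M, TILE_N, TILE_K = 32, 32, 16

    @wp.kernel
    def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
        i, j = wp.tid()
        acc = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=float)
        count = int(A.shape[1] / TILE_K)
        for k in range(count):
            a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
            b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
            wp.tile_matmul(a, b, acc)   # acc += a*b
        wp.tile_store(C, acc, offset=(i * TILE_M, j * TILE_N))

    # launched with one block per output tile, e.g.:
    # wp.launch_tiled(tile_gemm, dim=[M // TILE_M, N // TILE_N], inputs=[A, B, C], block_dim=64)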
@@ -6303,59 +6466,29 @@ def tile_fft_generic_lto_dispatch_func(
      num_threads = options["block_dim"]
      arch = options["output_arch"]
      ept = size // num_threads
-     lto_symbol = f"fft_{size}_{ept}_{arch}_{direction}_{precision}"
-
-     # early out if LTO for this combination already exists for this module
-     if lto_symbol in builder.ltoirs:
-         return lto_symbol, builder.ltoirs[lto_symbol]
-
-     # otherwise compile LTO
-     lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
-     shared_memory_size = ctypes.c_int(0)
-
-     result = warp.context.runtime.core.cuda_compile_fft(
-         lto_code.name.encode("utf-8"),
-         lto_symbol.encode("utf-8"),
-         0,
-         None,
-         None,
-         arch,
-         size,
-         ept,
-         dir,
-         precision,
-         ctypes.byref(shared_memory_size),
-     )
-     lto_code_path = Path(lto_code.name)
-     if not result:
-         lto_code.close()
-         if lto_code_path.exists():
-             lto_code_path.unlink()
-         raise RuntimeError("Failed to compile tile_fft")
-
-     with open(lto_code.name, "rb") as f:
-         lto_code_data = f.read()
-
-     lto_code.close()
-     lto_code_path.unlink()
-
-     builder.ltoirs[lto_symbol] = lto_code_data
-
-     shared_memory_bytes = Tile.round_up(shared_memory_size.value)
-
-     return (
-         (
-             Var(lto_symbol, str, False, True, False),
-             Var(dtype, str, False, True, False),
-             Var(str(shared_memory_bytes), str, False, True, False),
-             Var(str(batch), str, False, True, False),
-             Var(str(ept), str, False, True, False),
-             inout,
-         ),
-         [],
-         [lto_code_data],
-         shared_memory_bytes,
-     )
+
+     if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
+         # CPU/no-MathDx dispatch
+         return ([], [], [], 0)
+     else:
+         # generate the LTO
+         lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
+             arch, size, ept, direction, dir, precision, builder
+         )
+
+         return (
+             (
+                 Var(lto_symbol, str, False, True, False),
+                 Var(dtype, str, False, True, False),
+                 Var(str(shared_memory_bytes), str, False, True, False),
+                 Var(str(batch), str, False, True, False),
+                 Var(str(ept), str, False, True, False),
+                 inout,
+             ),
+             [],
+             [lto_code_data],
+             shared_memory_bytes,
+         )


  add_builtin(
@@ -6417,7 +6550,7 @@ def tile_cholesky_generic_value_func(arg_types, arg_values):
          raise TypeError(f"tile_cholesky() argument must be a tile, got {a!r}")

      if len(a.shape) != 2:
-         raise ValueError("tile_cholesky() argumust must be a 2D tile")
+         raise ValueError("tile_cholesky() argument must be a 2D tile")

      if a.shape[0] != a.shape[1]:
          raise ValueError("tile_cholesky() argument must be square")
@@ -6458,57 +6591,36 @@ def tile_cholesky_generic_lto_dispatch_func(
      if out.type.shape[0] != M or out.type.shape[1] != M:
          raise ValueError("tile_cholesky() output tile must be square")

-     num_threads = options["block_dim"]
-     arch = options["output_arch"]
-     lto_symbol = f"potrf_{M}_{N}_{arch}_{precision_enum}"
-
-     # early out if LTO for this combination already exists for this module
-     if lto_symbol in builder.ltoirs:
-         return lto_symbol, builder.ltoirs[lto_symbol]
-
-     # otherwise compile LTO
-     lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
-     universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
+     solver = "potrf"
+     solver_enum = cusolver_function_map[solver]

-     # cuSOLVERDx only support col-major input/outputs,
+     # cuSOLVERDx only supports col-major input/outputs,
      # so we use upper to mimic a row-major input
-     result = warp.context.runtime.core.cuda_compile_solver(
-         universal_fatbin_code.name.encode("utf-8"),
-         lto_code.name.encode("utf-8"),
-         lto_symbol.encode("utf-8"),
-         0,
-         None,
-         None,
-         arch,
-         M,
-         N,
-         cusolver_function_map["potrf"],
-         precision_enum,
-         cusolver_fill_mode_map["upper"],
-         num_threads,
-     )
+     fill_mode = cusolver_fill_mode_map["upper"]

-     if not result:
-         for f in [lto_code, universal_fatbin_code]:
-             f.close()
-             if Path(f.name).exists():
-                 Path(f.name).unlink()
-         raise RuntimeError("Failed to compile tile_cholesky")
+     arch = options["output_arch"]
+     num_threads = options["block_dim"]
+     parameter_list = f"({dtype}*, unsigned)"

+     if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
+         # CPU/no-MathDx dispatch
+         return ((0, a, out), [], [], 0)
      else:
-         with open(lto_code.name, "rb") as f:
-             lto_code_data = f.read()
-         with open(universal_fatbin_code.name, "rb") as f:
-             universal_fatbin_code_data = f.read()
-         for f in [lto_code, universal_fatbin_code]:
-             f.close()
-             Path(f.name).unlink()
-
-         builder.ltoirs[lto_symbol] = lto_code_data
-         builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, unsigned);"
-         builder.fatbins["cholesky"] = universal_fatbin_code_data
+         # generate the LTO
+         lto_symbol, lto_code_data = warp.build.build_lto_solver(
+             M,
+             N,
+             solver,
+             solver_enum,
+             fill_mode,
+             arch,
+             precision_enum,
+             num_threads,
+             parameter_list,
+             builder,
+         )

-         return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
+         return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)


  add_builtin(
@@ -6602,57 +6714,36 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
          f"got {y.type.shape[0]} elements in output and {M} rows in 'L'"
      )

-     num_threads = options["block_dim"]
-     arch = options["output_arch"]
-     lto_symbol = f"potrs_{M}_{N}_{arch}_{precision_enum}"
-
-     # early out if LTO for this combination already exists for this module
-     if lto_symbol in builder.ltoirs:
-         return lto_symbol, builder.ltoirs[lto_symbol]
-
-     # otherwise compile LTO
-     lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
-     universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
+     solver = "potrs"
+     solver_enum = cusolver_function_map[solver]

-     # cuSOLVERDx only support col-major input/outputs,
+     # cuSOLVERDx only supports col-major input/outputs,
      # so we use upper to mimic a row-major input
-     result = warp.context.runtime.core.cuda_compile_solver(
-         universal_fatbin_code.name.encode("utf-8"),
-         lto_code.name.encode("utf-8"),
-         lto_symbol.encode("utf-8"),
-         0,
-         None,
-         None,
-         arch,
-         M,
-         N,
-         cusolver_function_map["potrs"],
-         precision_enum,
-         cusolver_fill_mode_map["upper"],
-         num_threads,
-     )
+     fill_mode = cusolver_fill_mode_map["upper"]

-     if not result:
-         for f in [lto_code, universal_fatbin_code]:
-             f.close()
-             if Path(f.name).exists():
-                 Path(f.name).unlink()
-         raise RuntimeError("Failed to compile tile_cholesky_solve")
+     arch = options["output_arch"]
+     num_threads = options["block_dim"]
+     parameter_list = f"({dtype}*, {dtype}*)"

+     if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
+         # CPU/no-MathDx dispatch
+         return ((0, L, x, y), [], [], 0)
      else:
-         with open(lto_code.name, "rb") as f:
-             lto_code_data = f.read()
-         with open(universal_fatbin_code.name, "rb") as f:
-             universal_fatbin_code_data = f.read()
-         for f in [lto_code, universal_fatbin_code]:
-             f.close()
-             Path(f.name).unlink()
-
-         builder.ltoirs[lto_symbol] = lto_code_data
-         builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, {dtype}*);"
-         builder.fatbins["cholesky"] = universal_fatbin_code_data
-
-         return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)
+         # generate the LTO
+         lto_symbol, lto_code_data = warp.build.build_lto_solver(
+             M,
+             N,
+             solver,
+             solver_enum,
+             fill_mode,
+             arch,
+             precision_enum,
+             num_threads,
+             parameter_list,
+             builder,
+         )
+
+         return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)


  add_builtin(