warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (179)
  1. warp/__init__.py +7 -1
  2. warp/bin/libwarp-clang.dylib +0 -0
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/tests/test_utils.py CHANGED
@@ -87,7 +87,7 @@ def test_array_scan_error_unsupported_dtype(test, device):
87
87
 
88
88
 
89
89
  def test_radix_sort_pairs(test, device):
90
- keyTypes = [int, wp.float32]
90
+ keyTypes = [int, wp.float32, wp.int64]
91
91
 
92
92
  for keyType in keyTypes:
93
93
  keys = wp.array((7, 2, 8, 4, 1, 6, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0), dtype=keyType, device=device)
@@ -97,18 +97,46 @@ def test_radix_sort_pairs(test, device):
97
97
  assert_np_equal(values.numpy()[:8], np.array((5, 2, 8, 4, 7, 6, 1, 3)))
98
98
 
99
99
 
100
- def test_radix_sort_pairs_empty(test, device):
100
+ def test_segmented_sort_pairs(test, device):
101
101
  keyTypes = [int, wp.float32]
102
102
 
103
+ for keyType in keyTypes:
104
+ keys = wp.array((7, 2, 8, 4, 1, 6, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0), dtype=keyType, device=device)
105
+ values = wp.array((1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0), dtype=int, device=device)
106
+ wp.utils.segmented_sort_pairs(
107
+ keys,
108
+ values,
109
+ 8,
110
+ wp.array((0, 4), dtype=int, device=device),
111
+ wp.array((4, 8), dtype=int, device=device),
112
+ )
113
+ assert_np_equal(keys.numpy()[:8], np.array((2, 4, 7, 8, 1, 3, 5, 6)))
114
+ assert_np_equal(values.numpy()[:8], np.array((2, 4, 1, 3, 5, 8, 7, 6)))
115
+
116
+
117
+ def test_radix_sort_pairs_empty(test, device):
118
+ keyTypes = [int, wp.float32, wp.int64]
119
+
103
120
  for keyType in keyTypes:
104
121
  keys = wp.array((), dtype=keyType, device=device)
105
122
  values = wp.array((), dtype=int, device=device)
106
123
  wp.utils.radix_sort_pairs(keys, values, 0)
107
124
 
108
125
 
109
- def test_radix_sort_pairs_error_insufficient_storage(test, device):
126
+ def test_segmented_sort_pairs_empty(test, device):
110
127
  keyTypes = [int, wp.float32]
111
128
 
129
+ for keyType in keyTypes:
130
+ keys = wp.array((), dtype=keyType, device=device)
131
+ values = wp.array((), dtype=int, device=device)
132
+ wp.utils.segmented_sort_pairs(
133
+ keys, values, 0, wp.array((), dtype=int, device=device), wp.array((), dtype=int, device=device)
134
+ )
135
+
136
+
137
+ def test_radix_sort_pairs_error_insufficient_storage(test, device):
138
+ keyTypes = [int, wp.float32, wp.int64]
139
+
112
140
  for keyType in keyTypes:
113
141
  keys = wp.array((1, 2, 3), dtype=keyType, device=device)
114
142
  values = wp.array((1, 2, 3), dtype=int, device=device)
@@ -119,9 +147,28 @@ def test_radix_sort_pairs_error_insufficient_storage(test, device):
119
147
  wp.utils.radix_sort_pairs(keys, values, 3)
120
148
 
121
149
 
122
- def test_radix_sort_pairs_error_unsupported_dtype(test, device):
150
+ def test_segmented_sort_pairs_error_insufficient_storage(test, device):
123
151
  keyTypes = [int, wp.float32]
124
152
 
153
+ for keyType in keyTypes:
154
+ keys = wp.array((1, 2, 3), dtype=keyType, device=device)
155
+ values = wp.array((1, 2, 3), dtype=int, device=device)
156
+ with test.assertRaisesRegex(
157
+ RuntimeError,
158
+ r"Array storage must be large enough to contain 2\*count elements$",
159
+ ):
160
+ wp.utils.segmented_sort_pairs(
161
+ keys,
162
+ values,
163
+ 3,
164
+ wp.array((0,), dtype=int, device=device),
165
+ wp.array((3,), dtype=int, device=device),
166
+ )
167
+
168
+
169
+ def test_radix_sort_pairs_error_unsupported_dtype(test, device):
170
+ keyTypes = [int, wp.float32, wp.int64]
171
+
125
172
  for keyType in keyTypes:
126
173
  keys = wp.array((1.0, 2.0, 3.0), dtype=keyType, device=device)
127
174
  values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device)
@@ -132,6 +179,25 @@ def test_radix_sort_pairs_error_unsupported_dtype(test, device):
132
179
  wp.utils.radix_sort_pairs(keys, values, 1)
133
180
 
134
181
 
182
+ def test_segmented_sort_pairs_error_unsupported_dtype(test, device):
183
+ keyTypes = [int, wp.float32]
184
+
185
+ for keyType in keyTypes:
186
+ keys = wp.array((1.0, 2.0, 3.0), dtype=keyType, device=device)
187
+ values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device)
188
+ with test.assertRaisesRegex(
189
+ RuntimeError,
190
+ r"Unsupported data type$",
191
+ ):
192
+ wp.utils.segmented_sort_pairs(
193
+ keys,
194
+ values,
195
+ 1,
196
+ wp.array((0,), dtype=int, device=device),
197
+ wp.array((3,), dtype=int, device=device),
198
+ )
199
+
200
+
135
201
  def test_array_sum(test, device):
136
202
  for dtype in (wp.float32, wp.float64):
137
203
  with test.subTest(dtype=dtype):
@@ -468,6 +534,20 @@ add_function_test(
468
534
  test_radix_sort_pairs_error_unsupported_dtype,
469
535
  devices=devices,
470
536
  )
537
+ add_function_test(TestUtils, "test_segmented_sort_pairs", test_segmented_sort_pairs, devices=devices)
538
+ add_function_test(TestUtils, "test_segmented_sort_pairs_empty", test_segmented_sort_pairs, devices=devices)
539
+ add_function_test(
540
+ TestUtils,
541
+ "test_segmented_sort_pairs_error_insufficient_storage",
542
+ test_segmented_sort_pairs_error_insufficient_storage,
543
+ devices=devices,
544
+ )
545
+ add_function_test(
546
+ TestUtils,
547
+ "test_segmented_sort_pairs_error_unsupported_dtype",
548
+ test_segmented_sort_pairs_error_unsupported_dtype,
549
+ devices=devices,
550
+ )
471
551
  add_function_test(TestUtils, "test_array_sum", test_array_sum, devices=devices)
472
552
  add_function_test(
473
553
  TestUtils, "test_array_sum_error_out_dtype_mismatch", test_array_sum_error_out_dtype_mismatch, devices=devices
warp/tests/test_vec.py CHANGED
@@ -1044,7 +1044,7 @@ def test_casting_constructors(test, device, dtype, register_kernels=False):
1044
1044
  assert_np_equal(out, a_grad.numpy())
1045
1045
 
1046
1046
 
1047
- def test_vec_assign(test, device, dtype, register_kernels=False):
1047
+ def test_vector_assign_inplace(test, device, dtype, register_kernels=False):
1048
1048
  np_type = np.dtype(dtype)
1049
1049
  wp_type = wp.types.np_dtype_to_warp_type[np_type]
1050
1050
 
@@ -1085,16 +1085,6 @@ def test_vec_assign(test, device, dtype, register_kernels=False):
1085
1085
  g = a_vec[0] + a_vec[1]
1086
1086
  x[tid] = g
1087
1087
 
1088
- def vectest_in_register_overwrite(x: wp.array(dtype=vec3), a: wp.array(dtype=vec3)):
1089
- tid = wp.tid()
1090
-
1091
- f = vec3(wp_type(0.0))
1092
- a_vec = a[tid]
1093
- f = a_vec
1094
- f[1] = wp_type(3.0)
1095
-
1096
- x[tid] = f
1097
-
1098
1088
  def vectest_component(x: wp.array(dtype=vec3), y: wp.array(dtype=wp_type)):
1099
1089
  i = wp.tid()
1100
1090
 
@@ -1106,7 +1096,6 @@ def test_vec_assign(test, device, dtype, register_kernels=False):
1106
1096
 
1107
1097
  kernel_read_write_store = getkernel(vectest_read_write_store, suffix=dtype.__name__)
1108
1098
  kernel_in_register = getkernel(vectest_in_register, suffix=dtype.__name__)
1109
- kernel_in_register_overwrite = getkernel(vectest_in_register_overwrite, suffix=dtype.__name__)
1110
1099
  kernel_component = getkernel(vectest_component, suffix=dtype.__name__)
1111
1100
 
1112
1101
  if register_kernels:
@@ -1156,7 +1145,6 @@ def test_vec_assign(test, device, dtype, register_kernels=False):
1156
1145
  x = wp.zeros(1, dtype=vec3, device=device, requires_grad=True)
1157
1146
  y = wp.ones(1, dtype=wp_type, device=device, requires_grad=True)
1158
1147
 
1159
- tape = wp.Tape()
1160
1148
  with tape:
1161
1149
  wp.launch(kernel_component, dim=1, inputs=[x, y], device=device)
1162
1150
 
@@ -1165,20 +1153,6 @@ def test_vec_assign(test, device, dtype, register_kernels=False):
1165
1153
  assert_np_equal(x.numpy(), np.array([[1.0, 2.0, 3.0]], dtype=np_type))
1166
1154
  assert_np_equal(y.grad.numpy(), np.array([6.0], dtype=np_type))
1167
1155
 
1168
- tape.reset()
1169
-
1170
- x = wp.zeros(1, dtype=vec3, device=device, requires_grad=True)
1171
- a = wp.ones(1, dtype=vec3, device=device, requires_grad=True)
1172
-
1173
- tape = wp.Tape()
1174
- with tape:
1175
- wp.launch(kernel_in_register_overwrite, dim=1, inputs=[x, a], device=device)
1176
-
1177
- tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1178
-
1179
- assert_np_equal(x.numpy(), np.array([[1.0, 3.0, 1.0]], dtype=np_type))
1180
- assert_np_equal(a.grad.numpy(), np.array([[1.0, 0.0, 1.0]], dtype=np_type))
1181
-
1182
1156
 
1183
1157
  @wp.kernel
1184
1158
  def test_vector_constructor_value_func():
@@ -1325,15 +1299,15 @@ def vector_augassign_kernel(
1325
1299
  def test_vector_augassign(test, device):
1326
1300
  N = 3
1327
1301
 
1328
- a = wp.zeros(N, dtype=wp.vec3, requires_grad=True)
1329
- b = wp.ones(N, dtype=wp.vec3, requires_grad=True)
1302
+ a = wp.zeros(N, dtype=wp.vec3, requires_grad=True, device=device)
1303
+ b = wp.ones(N, dtype=wp.vec3, requires_grad=True, device=device)
1330
1304
 
1331
- c = wp.zeros(N, dtype=wp.vec3, requires_grad=True)
1332
- d = wp.ones(N, dtype=wp.vec3, requires_grad=True)
1305
+ c = wp.zeros(N, dtype=wp.vec3, requires_grad=True, device=device)
1306
+ d = wp.ones(N, dtype=wp.vec3, requires_grad=True, device=device)
1333
1307
 
1334
1308
  tape = wp.Tape()
1335
1309
  with tape:
1336
- wp.launch(vector_augassign_kernel, N, inputs=[a, b, c, d])
1310
+ wp.launch(vector_augassign_kernel, N, inputs=[a, b, c, d], device=device)
1337
1311
 
1338
1312
  tape.backward(grads={a: wp.ones_like(a), c: wp.ones_like(c)})
1339
1313
 
@@ -1346,6 +1320,38 @@ def test_vector_augassign(test, device):
1346
1320
  assert_np_equal(d.grad.numpy(), -wp.ones_like(d).numpy())
1347
1321
 
1348
1322
 
1323
+ def test_vector_assign_copy(test, device):
1324
+ saved_enable_vector_component_overwrites_setting = wp.config.enable_vector_component_overwrites
1325
+ try:
1326
+ wp.config.enable_vector_component_overwrites = True
1327
+
1328
+ @wp.kernel
1329
+ def vec_in_register_overwrite(x: wp.array(dtype=wp.vec3), a: wp.array(dtype=wp.vec3)):
1330
+ tid = wp.tid()
1331
+
1332
+ f = wp.vec3(0.0)
1333
+ a_vec = a[tid]
1334
+ f = a_vec
1335
+ f[1] = 3.0
1336
+
1337
+ x[tid] = f
1338
+
1339
+ x = wp.zeros(1, dtype=wp.vec3, device=device, requires_grad=True)
1340
+ a = wp.ones(1, dtype=wp.vec3, device=device, requires_grad=True)
1341
+
1342
+ tape = wp.Tape()
1343
+ with tape:
1344
+ wp.launch(vec_in_register_overwrite, dim=1, inputs=[x, a], device=device)
1345
+
1346
+ tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1347
+
1348
+ assert_np_equal(x.numpy(), np.array([[1.0, 3.0, 1.0]], dtype=float))
1349
+ assert_np_equal(a.grad.numpy(), np.array([[1.0, 0.0, 1.0]], dtype=float))
1350
+
1351
+ finally:
1352
+ wp.config.enable_vector_component_overwrites = saved_enable_vector_component_overwrites_setting
1353
+
1354
+
1349
1355
  devices = get_test_devices()
1350
1356
 
1351
1357
 
@@ -1414,8 +1420,8 @@ for dtype in np_float_types:
1414
1420
  )
1415
1421
  add_function_test_register_kernel(
1416
1422
  TestVec,
1417
- f"test_vec_assign_{dtype.__name__}",
1418
- test_vec_assign,
1423
+ f"test_vector_assign_inplace_{dtype.__name__}",
1424
+ test_vector_assign_inplace,
1419
1425
  devices=devices,
1420
1426
  dtype=dtype,
1421
1427
  )
@@ -1468,6 +1474,12 @@ add_function_test(
1468
1474
  test_vector_augassign,
1469
1475
  devices=devices,
1470
1476
  )
1477
+ add_function_test(
1478
+ TestVec,
1479
+ "test_vector_assign_copy",
1480
+ test_vector_assign_copy,
1481
+ devices=devices,
1482
+ )
1471
1483
 
1472
1484
 
1473
1485
  if __name__ == "__main__":
File without changes
@@ -20,8 +20,6 @@ import numpy as np
20
20
  import warp as wp
21
21
  from warp.tests.unittest_utils import *
22
22
 
23
- wp.init() # For wp.context.runtime.core.is_mathdx_enabled()
24
-
25
23
  TILE_M = wp.constant(8)
26
24
  TILE_N = wp.constant(4)
27
25
  TILE_K = wp.constant(8)
@@ -216,7 +214,6 @@ def test_tile_binary_map(test, device):
216
214
  assert_np_equal(B_wp.grad.numpy(), B_grad)
217
215
 
218
216
 
219
- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
220
217
  def test_tile_grouped_gemm(test, device):
221
218
  @wp.kernel
222
219
  def tile_grouped_gemm(A: wp.array3d(dtype=float), B: wp.array3d(dtype=float), C: wp.array3d(dtype=float)):
@@ -256,60 +253,62 @@ def test_tile_grouped_gemm(test, device):
256
253
  assert_np_equal(C_wp.numpy(), C, 1e-6)
257
254
 
258
255
 
259
- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
260
- def test_tile_gemm(test, device):
261
- @wp.kernel
262
- def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
263
- # output tile index
264
- i, j = wp.tid()
256
+ def test_tile_gemm(dtype):
257
+ def test(test, device):
258
+ @wp.kernel
259
+ def tile_gemm(A: wp.array2d(dtype=dtype), B: wp.array2d(dtype=dtype), C: wp.array2d(dtype=dtype)):
260
+ # output tile index
261
+ i, j = wp.tid()
265
262
 
266
- sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
263
+ sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=dtype)
267
264
 
268
- M = A.shape[0]
269
- N = B.shape[1]
270
- K = A.shape[1]
265
+ M = A.shape[0]
266
+ N = B.shape[1]
267
+ K = A.shape[1]
271
268
 
272
- count = int(K / TILE_K)
269
+ count = int(K / TILE_K)
273
270
 
274
- for k in range(0, count):
275
- a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
276
- b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
271
+ for k in range(0, count):
272
+ a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
273
+ b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
277
274
 
278
- # sum += a*b
279
- wp.tile_matmul(a, b, sum)
275
+ # sum += a*b
276
+ wp.tile_matmul(a, b, sum)
280
277
 
281
- wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))
278
+ wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))
282
279
 
283
- M = TILE_M * 7
284
- K = TILE_K * 6
285
- N = TILE_N * 5
280
+ M = TILE_M * 7
281
+ K = TILE_K * 6
282
+ N = TILE_N * 5
286
283
 
287
- rng = np.random.default_rng(42)
288
- A = rng.random((M, K), dtype=np.float32)
289
- B = rng.random((K, N), dtype=np.float32)
290
- C = np.zeros((M, N), dtype=np.float32)
284
+ rng = np.random.default_rng(42)
285
+ A = rng.random((M, K), dtype=float).astype(wp.dtype_to_numpy(dtype))
286
+ B = rng.random((K, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
287
+ C = np.zeros((M, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
291
288
 
292
- A_wp = wp.array(A, requires_grad=True, device=device)
293
- B_wp = wp.array(B, requires_grad=True, device=device)
294
- C_wp = wp.array(C, requires_grad=True, device=device)
289
+ A_wp = wp.array(A, requires_grad=True, device=device)
290
+ B_wp = wp.array(B, requires_grad=True, device=device)
291
+ C_wp = wp.array(C, requires_grad=True, device=device)
295
292
 
296
- with wp.Tape() as tape:
297
- wp.launch_tiled(
298
- tile_gemm,
299
- dim=(int(M / TILE_M), int(N / TILE_N)),
300
- inputs=[A_wp, B_wp, C_wp],
301
- block_dim=TILE_DIM,
302
- device=device,
303
- )
293
+ with wp.Tape() as tape:
294
+ wp.launch_tiled(
295
+ tile_gemm,
296
+ dim=(int(M / TILE_M), int(N / TILE_N)),
297
+ inputs=[A_wp, B_wp, C_wp],
298
+ block_dim=TILE_DIM,
299
+ device=device,
300
+ )
304
301
 
305
- assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-5)
302
+ assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-1)
306
303
 
307
- adj_C = np.ones_like(C)
304
+ adj_C = np.ones_like(C)
308
305
 
309
- tape.backward(grads={C_wp: wp.array(adj_C, device=device)})
306
+ tape.backward(grads={C_wp: wp.array(adj_C, device=device)})
310
307
 
311
- assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-5)
312
- assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, 1.0e-5)
308
+ assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-1)
309
+ assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, 1.0e-1)
310
+
311
+ return test
313
312
 
314
313
 
315
314
  @wp.kernel
@@ -550,7 +549,6 @@ def test_tile_transpose(test, device):
550
549
  assert_np_equal(output.numpy(), input.numpy().T)
551
550
 
552
551
 
553
- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
554
552
  def test_tile_transpose_matmul(test, device):
555
553
  @wp.kernel
556
554
  def test_tile_transpose_matmul_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
@@ -572,9 +570,36 @@ def test_tile_transpose_matmul(test, device):
572
570
 
573
571
 
574
572
  @wp.kernel
575
- def test_tile_broadcast_add_kernel(
573
+ def test_tile_broadcast_add_1d_kernel(
574
+ input_a: wp.array(dtype=float), input_b: wp.array(dtype=float), output: wp.array(dtype=float)
575
+ ):
576
+ a = wp.tile_load(input_a, shape=(10,))
577
+ b = wp.tile_load(input_b, shape=(1,))
578
+
579
+ c = wp.tile_broadcast(b, shape=(10,))
580
+ d = a + c
581
+
582
+ wp.tile_store(output, d)
583
+
584
+
585
+ def test_tile_broadcast_add_1d(test, device):
586
+ N = 10
587
+
588
+ # implicit 1-dim ([1], 1)
589
+ a = wp.array(np.arange(0, N, dtype=np.float32), device=device)
590
+ b = wp.array(np.ones(1, dtype=np.float32), device=device)
591
+ out = wp.zeros((N,), dtype=float, device=device)
592
+
593
+ wp.launch_tiled(test_tile_broadcast_add_1d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
594
+
595
+ assert_np_equal(out.numpy(), a.numpy() + b.numpy())
596
+
597
+
598
+ @wp.kernel
599
+ def test_tile_broadcast_add_2d_kernel(
576
600
  input_a: wp.array2d(dtype=float), input_b: wp.array(dtype=float), output: wp.array2d(dtype=float)
577
601
  ):
602
+ # implicit 1-dim ([1], 10)
578
603
  a = wp.tile_load(input_a, shape=(10, 10))
579
604
  b = wp.tile_load(input_b, shape=10)
580
605
 
@@ -584,7 +609,7 @@ def test_tile_broadcast_add_kernel(
584
609
  wp.tile_store(output, d)
585
610
 
586
611
 
587
- def test_tile_broadcast_add(test, device):
612
+ def test_tile_broadcast_add_2d(test, device):
588
613
  M = 10
589
614
  N = 10
590
615
 
@@ -592,7 +617,62 @@ def test_tile_broadcast_add(test, device):
592
617
  b = wp.array(np.arange(0, N, dtype=np.float32), device=device)
593
618
  out = wp.zeros((M, N), dtype=float, device=device)
594
619
 
595
- wp.launch_tiled(test_tile_broadcast_add_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
620
+ wp.launch_tiled(test_tile_broadcast_add_2d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
621
+
622
+ assert_np_equal(out.numpy(), a.numpy() + b.numpy())
623
+
624
+
625
+ @wp.kernel
626
+ def test_tile_broadcast_add_3d_kernel(
627
+ input_a: wp.array3d(dtype=float), input_b: wp.array3d(dtype=float), output: wp.array3d(dtype=float)
628
+ ):
629
+ a = wp.tile_load(input_a, shape=(4, 10, 12))
630
+ b = wp.tile_load(input_b, shape=(4, 10, 1))
631
+
632
+ c = wp.tile_broadcast(b, shape=(4, 10, 12))
633
+ d = a + c
634
+
635
+ wp.tile_store(output, d)
636
+
637
+
638
+ def test_tile_broadcast_add_3d(test, device):
639
+ M = 4
640
+ N = 10
641
+ O = 12
642
+
643
+ # explicit 1-dim (M, N, 1) to (M, N, O)
644
+ a = wp.array(np.ones((M, N, O), dtype=np.float32), device=device)
645
+ b = wp.array(np.arange(0, M * N, dtype=np.float32).reshape((M, N, 1)), device=device)
646
+ out = wp.zeros((M, N, O), dtype=float, device=device)
647
+
648
+ wp.launch_tiled(test_tile_broadcast_add_3d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
649
+ assert_np_equal(out.numpy(), a.numpy() + b.numpy())
650
+
651
+
652
+ @wp.kernel
653
+ def test_tile_broadcast_add_4d_kernel(
654
+ input_a: wp.array4d(dtype=float), input_b: wp.array4d(dtype=float), output: wp.array4d(dtype=float)
655
+ ):
656
+ a = wp.tile_load(input_a, shape=(4, 10, 5, 6))
657
+ b = wp.tile_load(input_b, shape=(4, 1, 5, 1))
658
+ c = wp.tile_broadcast(b, shape=(4, 10, 5, 6))
659
+ d = a + c
660
+
661
+ wp.tile_store(output, d)
662
+
663
+
664
+ def test_tile_broadcast_add_4d(test, device):
665
+ M = 4
666
+ N = 10
667
+ O = 5
668
+ P = 6
669
+
670
+ # explicit 1-dims (M, 1, O, 1) to (M, N, O, P)
671
+ a = wp.array(np.ones((M, N, O, P), dtype=np.float32), device=device)
672
+ b = wp.array(np.arange(0, M * O, dtype=np.float32).reshape((M, 1, O, 1)), device=device)
673
+ out = wp.zeros((M, N, O, P), dtype=float, device=device)
674
+
675
+ wp.launch_tiled(test_tile_broadcast_add_4d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
596
676
 
597
677
  assert_np_equal(out.numpy(), a.numpy() + b.numpy())
598
678
 
@@ -665,7 +745,7 @@ def test_tile_print(test, device):
665
745
  wp.synchronize()
666
746
 
667
747
 
668
- devices = get_cuda_test_devices()
748
+ devices = get_test_devices()
669
749
 
670
750
 
671
751
  class TestTile(unittest.TestCase):
@@ -677,15 +757,20 @@ add_function_test(TestTile, "test_tile_copy_2d", test_tile_copy_2d, devices=devi
677
757
  add_function_test(TestTile, "test_tile_unary_map", test_tile_unary_map, devices=devices)
678
758
  add_function_test(TestTile, "test_tile_binary_map", test_tile_binary_map, devices=devices)
679
759
  add_function_test(TestTile, "test_tile_grouped_gemm", test_tile_grouped_gemm, devices=devices)
680
- add_function_test(TestTile, "test_tile_gemm", test_tile_gemm, devices=devices)
760
+ add_function_test(TestTile, "test_tile_gemm_fp16", test_tile_gemm(wp.float16), devices=devices)
761
+ add_function_test(TestTile, "test_tile_gemm_fp32", test_tile_gemm(wp.float32), devices=devices)
762
+ add_function_test(TestTile, "test_tile_gemm_fp64", test_tile_gemm(wp.float64), devices=devices)
681
763
  add_function_test(TestTile, "test_tile_transpose", test_tile_transpose, devices=devices)
682
764
  add_function_test(TestTile, "test_tile_transpose_matmul", test_tile_transpose_matmul, devices=devices)
683
765
  add_function_test(TestTile, "test_tile_operators", test_tile_operators, devices=devices)
684
- add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices)
766
+ add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices, check_output=False)
685
767
  add_function_test(TestTile, "test_tile_sum_launch", test_tile_sum_launch, devices=devices)
686
768
  add_function_test(TestTile, "test_tile_extract", test_tile_extract, devices=devices)
687
769
  add_function_test(TestTile, "test_tile_extract_repeated", test_tile_extract_repeated, devices=devices)
688
- add_function_test(TestTile, "test_tile_broadcast_add", test_tile_broadcast_add, devices=devices)
770
+ add_function_test(TestTile, "test_tile_broadcast_add_1d", test_tile_broadcast_add_1d, devices=devices)
771
+ add_function_test(TestTile, "test_tile_broadcast_add_2d", test_tile_broadcast_add_2d, devices=devices)
772
+ add_function_test(TestTile, "test_tile_broadcast_add_3d", test_tile_broadcast_add_3d, devices=devices)
773
+ add_function_test(TestTile, "test_tile_broadcast_add_4d", test_tile_broadcast_add_4d, devices=devices)
689
774
  add_function_test(TestTile, "test_tile_broadcast_grad", test_tile_broadcast_grad, devices=devices)
690
775
  add_function_test(TestTile, "test_tile_len", test_tile_len, devices=devices)
691
776
  add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
@@ -376,7 +376,7 @@ def test_tile_load_fortran(test, device):
376
376
  assert_array_equal(B_wp.grad, A_wp.grad)
377
377
 
378
378
 
379
- devices = get_cuda_test_devices()
379
+ devices = get_test_devices()
380
380
 
381
381
 
382
382
  class TestTileLoad(unittest.TestCase):
@@ -92,6 +92,7 @@ def tile_math_fft_kernel_vec2d(gx: wp.array2d(dtype=wp.vec2d), gy: wp.array2d(dt
92
92
  wp.tile_store(gy, xy)
93
93
 
94
94
 
95
+ @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
95
96
  def test_tile_math_fft(test, device, wp_dtype):
96
97
  np_real_dtype = {wp.vec2f: np.float32, wp.vec2d: np.float64}[wp_dtype]
97
98
  np_cplx_dtype = {wp.vec2f: np.complex64, wp.vec2d: np.complex128}[wp_dtype]
@@ -172,31 +173,33 @@ def test_tile_math_cholesky(test, device):
172
173
  # TODO: implement and test backward pass
173
174
 
174
175
 
175
- devices = get_cuda_test_devices()
176
+ all_devices = get_test_devices()
177
+ cuda_devices = get_cuda_test_devices()
176
178
 
177
179
 
178
- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
179
180
  class TestTileMathDx(unittest.TestCase):
180
181
  pass
181
182
 
182
183
 
183
184
  # check_output=False so we can enable libmathdx's logging without failing the tests
184
- add_function_test(TestTileMathDx, "test_tile_math_matmul", test_tile_math_matmul, devices=devices, check_output=False)
185
185
  add_function_test(
186
- TestTileMathDx, "test_tile_math_cholesky", test_tile_math_cholesky, devices=devices, check_output=False
186
+ TestTileMathDx, "test_tile_math_matmul", test_tile_math_matmul, devices=all_devices, check_output=False
187
+ )
188
+ add_function_test(
189
+ TestTileMathDx, "test_tile_math_cholesky", test_tile_math_cholesky, devices=all_devices, check_output=False
187
190
  )
188
191
  add_function_test(
189
192
  TestTileMathDx,
190
193
  "test_tile_math_fft_vec2f",
191
194
  functools.partial(test_tile_math_fft, wp_dtype=wp.vec2f),
192
- devices=devices,
195
+ devices=cuda_devices,
193
196
  check_output=False,
194
197
  )
195
198
  add_function_test(
196
199
  TestTileMathDx,
197
200
  "test_tile_math_fft_vec2d",
198
201
  functools.partial(test_tile_math_fft, wp_dtype=wp.vec2d),
199
- devices=devices,
202
+ devices=cuda_devices,
200
203
  check_output=False,
201
204
  )
202
205