warp-lang: 1.0.0b5-py3-none-manylinux2014_x86_64.whl → 1.0.0b6-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,21 @@
- import numpy as np
+ # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
  import unittest

+ import numpy as np
+
  import warp as wp
- from warp.tests.test_base import *
+ from warp.tests.unittest_utils import *

  wp.init()

+ from warp.context import runtime  # noqa: E402
+

  class gemm_test_bed_runner:
  def __init__(self, dtype, device):
@@ -21,63 +31,54 @@ class gemm_test_bed_runner:
  np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
  dtype=self.dtype,
  device=self.device,
- requires_grad=True
+ requires_grad=True,
  )
  B = wp.array2d(
  np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
  dtype=self.dtype,
  device=self.device,
- requires_grad=True
+ requires_grad=True,
  )
  C = wp.array2d(
  np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
  dtype=self.dtype,
  device=self.device,
- requires_grad=True
+ requires_grad=True,
  )
- D = wp.array2d(
- np.zeros((m, n)),
- dtype=self.dtype,
- device=self.device,
- requires_grad=True)
+ D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
  else:
  A = wp.array3d(
  np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
  dtype=self.dtype,
  device=self.device,
- requires_grad=True
+ requires_grad=True,
  )
  B = wp.array3d(
  np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
  dtype=self.dtype,
  device=self.device,
- requires_grad=True
+ requires_grad=True,
  )
  C = wp.array3d(
  np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
  dtype=self.dtype,
  device=self.device,
- requires_grad=True
- )
- D = wp.array3d(
- np.zeros((batch_count, m, n)),
- dtype=self.dtype,
- device=self.device,
- requires_grad=True
+ requires_grad=True,
  )
+ D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
  return A, B, C, D

  def run_and_verify(self, m, n, k, batch_count, alpha, beta):
  A, B, C, D = self.alloc(m, n, k, batch_count)
  ones = wp.zeros_like(D)
  ones.fill_(1.0)
-
+
  if batch_count == 1:
  tape = wp.Tape()
  with tape:
  wp.matmul(A, B, C, D, alpha, beta, False, self.device)
- tape.backward(grads={D : ones})
-
+ tape.backward(grads={D: ones})
+
  D_np = alpha * (A.numpy() @ B.numpy()) + beta * C.numpy()
  assert np.array_equal(D_np, D.numpy())

@@ -89,8 +90,8 @@ class gemm_test_bed_runner:
  tape = wp.Tape()
  with tape:
  wp.batched_matmul(A, B, C, D, alpha, beta, False, self.device)
- tape.backward(grads={D : ones})
-
+ tape.backward(grads={D: ones})
+
  D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
  assert np.array_equal(D_np, D.numpy())

@@ -132,75 +133,45 @@ class gemm_test_bed_runner_transpose:
  np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
  dtype=self.dtype,
  device=self.device,
- requires_grad=True
+ requires_grad=True,
  )
  B = wp.array2d(
  np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
  dtype=self.dtype,
  device=self.device,
- requires_grad=True
+ requires_grad=True,
  )
  C = wp.array2d(
  np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
  dtype=self.dtype,
  device=self.device,
- requires_grad=True
- )
- D = wp.array2d(
- np.zeros((m, n)),
- dtype=self.dtype,
- device=self.device,
- requires_grad=True
- )
- AT = wp.array2d(
- A.numpy().transpose([1, 0]),
- dtype=self.dtype,
- device=self.device,
- requires_grad=True
- )
- BT = wp.array2d(
- B.numpy().transpose([1, 0]),
- dtype=self.dtype,
- device=self.device,
- requires_grad=True
+ requires_grad=True,
  )
+ D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+ AT = wp.array2d(A.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
+ BT = wp.array2d(B.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
  else:
  A = wp.array3d(
  np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
  dtype=self.dtype,
  device=self.device,
- requires_grad=True
+ requires_grad=True,
  )
  B = wp.array3d(
  np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
  dtype=self.dtype,
  device=self.device,
- requires_grad=True
+ requires_grad=True,
  )
  C = wp.array3d(
  np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
  dtype=self.dtype,
  device=self.device,
- requires_grad=True
- )
- D = wp.array3d(
- np.zeros((batch_count, m, n)),
- dtype=self.dtype,
- device=self.device,
- requires_grad=True
- )
- AT = wp.array3d(
- A.numpy().transpose([0, 2, 1]),
- dtype=self.dtype,
- device=self.device,
- requires_grad=True
- )
- BT = wp.array3d(
- B.numpy().transpose([0, 2, 1]),
- dtype=self.dtype,
- device=self.device,
- requires_grad=True
+ requires_grad=True,
  )
+ D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+ AT = wp.array3d(A.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
+ BT = wp.array3d(B.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
  return A, B, C, D, AT, BT

  def run_and_verify(self, m, n, k, batch_count, alpha, beta):
@@ -219,17 +190,17 @@ class gemm_test_bed_runner_transpose:
  ones3.fill_(1.0)

  if batch_count == 1:
- ATT1 = AT1.transpose([1, 0])
+ ATT1 = AT1.transpose([1, 0])
  BTT1 = BT1.transpose([1, 0])
- ATT2 = AT2.transpose([1, 0])
+ ATT2 = AT2.transpose([1, 0])
  BTT2 = BT2.transpose([1, 0])
  tape = wp.Tape()
  with tape:
  wp.matmul(A, BTT1, C1, D1, alpha, beta, False, self.device)
  wp.matmul(ATT1, B, C2, D2, alpha, beta, False, self.device)
  wp.matmul(ATT2, BTT2, C3, D3, alpha, beta, False, self.device)
- tape.backward(grads={D1 : ones1, D2 : ones2, D3 : ones3})
-
+ tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
+
  D_np = alpha * (A.numpy() @ B.numpy()) + beta * C1.numpy()
  assert np.array_equal(D_np, D1.numpy())
  assert np.array_equal(D_np, D2.numpy())
@@ -240,7 +211,7 @@ class gemm_test_bed_runner_transpose:
  adj_C_np = beta * ones1.numpy()

  else:
- ATT1 = AT1.transpose([0, 2, 1])
+ ATT1 = AT1.transpose([0, 2, 1])
  BTT1 = BT1.transpose([0, 2, 1])
  ATT2 = AT2.transpose([0, 2, 1])
  BTT2 = BT2.transpose([0, 2, 1])
@@ -249,8 +220,8 @@ class gemm_test_bed_runner_transpose:
  wp.batched_matmul(A, BTT1, C1, D1, alpha, beta, False, self.device)
  wp.batched_matmul(ATT1, B, C2, D2, alpha, beta, False, self.device)
  wp.batched_matmul(ATT2, BTT2, C3, D3, alpha, beta, False, self.device)
- tape.backward(grads={D1 : ones1, D2 : ones2, D3 : ones3})
-
+ tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
+
  D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C1.numpy()
  assert np.array_equal(D_np, D1.numpy())
  assert np.array_equal(D_np, D2.numpy())
@@ -282,6 +253,7 @@ class gemm_test_bed_runner_transpose:
  self.run_and_verify(m, n, k, batch_count, alpha, beta)


+ @unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
  def test_f32(test, device):
  gemm_test_bed_runner(wp.float32, device).run()
  gemm_test_bed_runner_transpose(wp.float32, device).run()
@@ -293,6 +265,7 @@ def matrix_sum_kernel(arr: wp.array2d(dtype=float), loss: wp.array(dtype=float)):
  wp.atomic_add(loss, 0, arr[i, j])


+ @unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
  def test_tape(test, device):
  rng = np.random.default_rng(42)
  low = -4.5
@@ -320,6 +293,7 @@ def test_tape(test, device):

  tape.backward(loss=loss)
  A_grad = A.grad.numpy()
+ tape.reset()

  # test adjoint
  D.grad = wp.array2d(np.ones((m, n)), dtype=float, device=device)
@@ -331,6 +305,7 @@ def test_tape(test, device):
  assert_array_equal(A.grad, wp.zeros_like(A))


+ @unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
  def test_operator(test, device):
  rng = np.random.default_rng(42)
  low = -4.5
@@ -366,6 +341,7 @@ def test_operator(test, device):
  assert_array_equal(A.grad, wp.zeros_like(A))


+ @unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
  def test_large_batch_count(test, device):
  rng = np.random.default_rng(42)
  low = -4.5
@@ -375,31 +351,38 @@ def test_large_batch_count(test, device):
  k = 4
  batch_count = 65535 * 2 + int(65535 / 2)
  A = wp.array3d(
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))), dtype=float, device=device, requires_grad=True
+ np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
+ dtype=float,
+ device=device,
+ requires_grad=True,
  )
  B = wp.array3d(
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))), dtype=float, device=device, requires_grad=True
+ np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
+ dtype=float,
+ device=device,
+ requires_grad=True,
  )
  C = wp.array3d(
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))), dtype=float, device=device, requires_grad=True
- )
- D = wp.array3d(
- np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True
+ np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
+ dtype=float,
+ device=device,
+ requires_grad=True,
  )
+ D = wp.array3d(np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True)
  ones = wp.zeros_like(D)
  ones.fill_(1.0)

  alpha = 1.0
  beta = 1.0
-
+
  tape = wp.Tape()
  with tape:
  wp.batched_matmul(A, B, C, D, alpha=alpha, beta=beta, allow_tf32x3_arith=False, device=device)
- tape.backward(grads={D : ones})
+ tape.backward(grads={D: ones})

  D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
  assert np.array_equal(D_np, D.numpy())
-
+
  adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
  adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
  adj_C_np = beta * ones.numpy()
@@ -409,28 +392,19 @@ def test_large_batch_count(test, device):
  assert np.array_equal(adj_C_np, C.grad.numpy())


- def register(parent):
- devices = [d for d in get_test_devices()]
+ devices = get_test_devices()

- class TestMatmul(parent):
- pass

- if devices:
- # check if CUTLASS is available
- from warp.context import runtime
+ class TestMatmulLite(unittest.TestCase):
+ pass

- if runtime.core.is_cutlass_enabled():
- add_function_test(TestMatmul, "test_f32", test_f32, devices=devices)
- add_function_test(TestMatmul, "test_tape", test_tape, devices=devices)
- add_function_test(TestMatmul, "test_operator", test_operator, devices=devices)
- add_function_test(TestMatmul, "test_large_batch_count", test_large_batch_count, devices=devices)
- else:
- print("Skipping matmul tests because CUTLASS is not supported in this build")

- return TestMatmul
+ add_function_test(TestMatmulLite, "test_f32", test_f32, devices=devices)
+ add_function_test(TestMatmulLite, "test_tape", test_tape, devices=devices)
+ add_function_test(TestMatmulLite, "test_operator", test_operator, devices=devices)
+ add_function_test(TestMatmulLite, "test_large_batch_count", test_large_batch_count, devices=devices)


  if __name__ == "__main__":
  wp.build.clear_kernel_cache()
- _ = register(unittest.TestCase)
  unittest.main(verbosity=2, failfast=False)
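For reference, the hunks above replace the old register(parent) plumbing with module-level registration and a CUTLASS capability guard. The sketch below is an illustrative, trimmed-down version of that pattern, not the file's actual contents: the single test body and its toy matrix shapes are made up, while add_function_test, get_test_devices, and the skipUnless guard come straight from the diff (via the renamed warp.tests.unittest_utils helpers).

import unittest

import numpy as np

import warp as wp
from warp.tests.unittest_utils import *

wp.init()

from warp.context import runtime  # noqa: E402


@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
def test_f32(test, device):
    # allocate small toy inputs on the target device and run a GEMM through the Tape
    A = wp.array2d(np.ones((4, 8)), dtype=wp.float32, device=device, requires_grad=True)
    B = wp.array2d(np.ones((8, 2)), dtype=wp.float32, device=device, requires_grad=True)
    C = wp.array2d(np.zeros((4, 2)), dtype=wp.float32, device=device, requires_grad=True)
    D = wp.array2d(np.zeros((4, 2)), dtype=wp.float32, device=device, requires_grad=True)

    ones = wp.zeros_like(D)
    ones.fill_(1.0)

    tape = wp.Tape()
    with tape:
        wp.matmul(A, B, C, D, 1.0, 1.0, False, device)
    tape.backward(grads={D: ones})  # note the reformatted dict literal, {D: ones}

    # D = alpha * A @ B + beta * C with alpha = beta = 1
    test.assertTrue(np.array_equal(A.numpy() @ B.numpy() + C.numpy(), D.numpy()))


devices = get_test_devices()


class TestMatmulLite(unittest.TestCase):
    pass


add_function_test(TestMatmulLite, "test_f32", test_f32, devices=devices)


if __name__ == "__main__":
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=False)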
warp/tests/test_mesh.py CHANGED
@@ -10,7 +10,7 @@ import unittest
  import numpy as np

  import warp as wp
- from warp.tests.test_base import *
+ from warp.tests.unittest_utils import *

  # fmt: off

@@ -222,9 +222,9 @@ def query_ray_kernel(


  def test_mesh_query_ray(test, device):
- points = wp.array(POINT_POSITIONS, dtype=wp.vec3)
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)

- indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int)
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
  mesh = wp.Mesh(points=points, indices=indices)
  expected_sign = -1.0
  wp.launch(
@@ -234,9 +234,10 @@ def test_mesh_query_ray(test, device):
  mesh.id,
  expected_sign,
  ],
+ device=device,
  )

- indices = wp.array(LEFT_HANDED_FACE_VERTEX_INDICES, dtype=int)
+ indices = wp.array(LEFT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
  mesh = wp.Mesh(points=points, indices=indices)
  expected_sign = 1.0
  wp.launch(
@@ -246,76 +247,78 @@ def test_mesh_query_ray(test, device):
  mesh.id,
  expected_sign,
  ],
+ device=device,
  )


  def test_mesh_refit_graph(test, device):
- points = wp.array(POINT_POSITIONS, dtype=wp.vec3)
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)

- indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int)
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
  mesh = wp.Mesh(points=points, indices=indices)

- wp.capture_begin()
-
- mesh.refit()
-
- graph = wp.capture_end()
+ wp.capture_begin(device, force_module_load=False)
+ try:
+ mesh.refit()
+ finally:
+ graph = wp.capture_end(device)

  # replay
  num_iters = 10
  for _ in range(num_iters):
  wp.capture_launch(graph)

+ wp.synchronize_device(device)
+

  def test_mesh_exceptions(test, device):
  # points and indices must be on same device
  with test.assertRaises(RuntimeError):
  points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device="cpu")
- indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int)
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
  wp.Mesh(points=points, indices=indices)

  # points must be vec3
  with test.assertRaises(RuntimeError):
- points = wp.array(POINT_POSITIONS, dtype=wp.vec3d)
- indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int)
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3d, device=device)
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
  wp.Mesh(points=points, indices=indices)

  # velocities must be vec3
  with test.assertRaises(RuntimeError):
- points = wp.array(POINT_POSITIONS, dtype=wp.vec3)
- velocities = wp.zeros(points.shape, dtype=wp.vec3d)
- indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int)
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)
+ velocities = wp.zeros(points.shape, dtype=wp.vec3d, device=device)
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
  wp.Mesh(points=points, indices=indices, velocities=velocities)

  # indices must be int32
  with test.assertRaises(RuntimeError):
- points = wp.array(POINT_POSITIONS, dtype=wp.vec3)
- indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=wp.int64)
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=wp.int64, device=device)
  wp.Mesh(points=points, indices=indices)

  # indices must be 1d
  with test.assertRaises(RuntimeError):
- points = wp.array(POINT_POSITIONS, dtype=wp.vec3)
- indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int)
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
  indices = indices.reshape((3, -1))
  wp.Mesh(points=points, indices=indices)


- def register(parent):
- devices = get_test_devices()
+ devices = get_test_devices()
+
+
+ class TestMesh(unittest.TestCase):
+ pass

- class TestMesh(parent):
- pass

- add_function_test(TestMesh, "test_mesh_read_properties", test_mesh_read_properties, devices=devices)
- add_function_test(TestMesh, "test_mesh_query_point", test_mesh_query_point, devices=devices)
- add_function_test(TestMesh, "test_mesh_query_ray", test_mesh_query_ray, devices=devices)
- add_function_test(TestMesh, "test_mesh_refit_graph", test_mesh_refit_graph, devices=wp.get_cuda_devices())
- add_function_test(TestMesh, "test_mesh_exceptions", test_mesh_exceptions, devices=wp.get_cuda_devices())
- return TestMesh
+ add_function_test(TestMesh, "test_mesh_read_properties", test_mesh_read_properties, devices=devices)
+ add_function_test(TestMesh, "test_mesh_query_point", test_mesh_query_point, devices=devices)
+ add_function_test(TestMesh, "test_mesh_query_ray", test_mesh_query_ray, devices=devices)
+ add_function_test(TestMesh, "test_mesh_refit_graph", test_mesh_refit_graph, devices=get_unique_cuda_test_devices())
+ add_function_test(TestMesh, "test_mesh_exceptions", test_mesh_exceptions, devices=get_unique_cuda_test_devices())


  if __name__ == "__main__":
  wp.build.clear_kernel_cache()
- _ = register(unittest.TestCase)
  unittest.main(verbosity=2)
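The test_mesh_refit_graph hunk above also reflects the 1.0.0b6 graph-capture API: wp.capture_begin and wp.capture_end now take an explicit device, the capture body is wrapped in try/finally, and replay is followed by wp.synchronize_device. Below is a minimal standalone sketch of that pattern; the single-triangle mesh data and the "cuda:0" device are made up for illustration, while the wp.capture_* calls mirror the diff.

import numpy as np

import warp as wp

wp.init()

device = wp.get_device("cuda:0")  # assumes a CUDA-capable device; graph capture is CUDA-only

# a single triangle, allocated explicitly on the target device
points = wp.array(
    np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), dtype=wp.vec3, device=device
)
indices = wp.array(np.array([0, 1, 2]), dtype=int, device=device)
mesh = wp.Mesh(points=points, indices=indices)

# capture the refit into a CUDA graph; capture_begin/capture_end take the device,
# and the body is wrapped in try/finally so capture is always ended
wp.capture_begin(device, force_module_load=False)
try:
    mesh.refit()
finally:
    graph = wp.capture_end(device)

# replay the captured graph a few times and wait for the device to finish
for _ in range(10):
    wp.capture_launch(graph)
wp.synchronize_device(device)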
@@ -10,7 +10,7 @@ import unittest
  import numpy as np

  import warp as wp
- from warp.tests.test_base import *
+ from warp.tests.unittest_utils import *

  wp.init()

@@ -98,7 +98,6 @@ def test_compute_bounds(test, device):

  lower_view = lowers.numpy()
  upper_view = uppers.numpy()
- wp.synchronize()

  # Confirm the bounds of each triangle are correct.
  test.assertTrue(lower_view[0][0] == 0)
@@ -150,8 +149,6 @@ def test_mesh_query_aabb_count_overlap(test, device):
  device=device,
  )

- wp.synchronize()
-
  view = counts.numpy()

  # 2 triangles that share a vertex having overlapping AABBs.
@@ -190,8 +187,6 @@ def test_mesh_query_aabb_count_nonoverlap(test, device):
  device=device,
  )

- wp.synchronize()
-
  view = counts.numpy()

  # AABB query only returns one triangle at a time, the triangles are not close enough to overlap.
@@ -199,30 +194,28 @@ def test_mesh_query_aabb_count_nonoverlap(test, device):
  test.assertTrue(c == 1)


- def register(parent):
- devices = get_test_devices()
+ devices = get_test_devices()

- class TestMeshQueryAABBMethods(parent):
- pass

- add_function_test(TestMeshQueryAABBMethods, "test_compute_bounds", test_compute_bounds, devices=devices)
- add_function_test(
- TestMeshQueryAABBMethods,
- "test_mesh_query_aabb_count_overlap",
- test_mesh_query_aabb_count_overlap,
- devices=devices,
- )
- add_function_test(
- TestMeshQueryAABBMethods,
- "test_mesh_query_aabb_count_nonoverlap",
- test_mesh_query_aabb_count_nonoverlap,
- devices=devices,
- )
+ class TestMeshQueryAABBMethods(unittest.TestCase):
+ pass
+

- return TestMeshQueryAABBMethods
+ add_function_test(TestMeshQueryAABBMethods, "test_compute_bounds", test_compute_bounds, devices=devices)
+ add_function_test(
+ TestMeshQueryAABBMethods,
+ "test_mesh_query_aabb_count_overlap",
+ test_mesh_query_aabb_count_overlap,
+ devices=devices,
+ )
+ add_function_test(
+ TestMeshQueryAABBMethods,
+ "test_mesh_query_aabb_count_nonoverlap",
+ test_mesh_query_aabb_count_nonoverlap,
+ devices=devices,
+ )


  if __name__ == "__main__":
  wp.build.clear_kernel_cache()
- _ = register(unittest.TestCase)
  unittest.main(verbosity=2)
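Several hunks in this last file simply drop the explicit wp.synchronize() that used to sit between wp.launch and reading results back. Below is a small sketch of the resulting pattern, assuming (as the removal suggests) that reading an array back with .numpy() already copies to host and waits on outstanding work for that array's device; the fill kernel and shapes are made up for illustration.

import numpy as np

import warp as wp

wp.init()


@wp.kernel
def fill(out: wp.array(dtype=float)):
    # write the thread index into each element
    tid = wp.tid()
    out[tid] = float(tid)


device = "cuda:0" if wp.is_cuda_available() else "cpu"
counts = wp.zeros(8, dtype=float, device=device)

wp.launch(fill, dim=counts.shape[0], inputs=[counts], device=device)

# before: an explicit wp.synchronize() was called here
# after: reading back with .numpy() is relied on to synchronize on its own
view = counts.numpy()
assert np.allclose(view, np.arange(8, dtype=np.float32))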