warp-lang 1.5.1__py3-none-macosx_10_13_universal2.whl → 1.6.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (123)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1076 -480
  8. warp/codegen.py +240 -119
  9. warp/config.py +1 -1
  10. warp/context.py +298 -84
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth_self_contact.py +260 -0
  27. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  28. warp/examples/sim/example_jacobian_ik.py +0 -2
  29. warp/examples/sim/example_quadruped.py +5 -2
  30. warp/examples/tile/example_tile_cholesky.py +79 -0
  31. warp/examples/tile/example_tile_convolution.py +2 -2
  32. warp/examples/tile/example_tile_fft.py +2 -2
  33. warp/examples/tile/example_tile_filtering.py +3 -3
  34. warp/examples/tile/example_tile_matmul.py +4 -4
  35. warp/examples/tile/example_tile_mlp.py +12 -12
  36. warp/examples/tile/example_tile_nbody.py +180 -0
  37. warp/examples/tile/example_tile_walker.py +319 -0
  38. warp/math.py +147 -0
  39. warp/native/array.h +12 -0
  40. warp/native/builtin.h +0 -1
  41. warp/native/bvh.cpp +149 -70
  42. warp/native/bvh.cu +287 -68
  43. warp/native/bvh.h +195 -85
  44. warp/native/clang/clang.cpp +5 -1
  45. warp/native/cuda_util.cpp +35 -0
  46. warp/native/cuda_util.h +5 -0
  47. warp/native/exports.h +40 -40
  48. warp/native/intersect.h +17 -0
  49. warp/native/mat.h +41 -0
  50. warp/native/mathdx.cpp +19 -0
  51. warp/native/mesh.cpp +25 -8
  52. warp/native/mesh.cu +153 -101
  53. warp/native/mesh.h +482 -403
  54. warp/native/quat.h +40 -0
  55. warp/native/solid_angle.h +7 -0
  56. warp/native/sort.cpp +85 -0
  57. warp/native/sort.cu +34 -0
  58. warp/native/sort.h +3 -1
  59. warp/native/spatial.h +11 -0
  60. warp/native/tile.h +1185 -664
  61. warp/native/tile_reduce.h +8 -6
  62. warp/native/vec.h +41 -0
  63. warp/native/warp.cpp +8 -1
  64. warp/native/warp.cu +263 -40
  65. warp/native/warp.h +19 -5
  66. warp/optim/linear.py +22 -4
  67. warp/render/render_opengl.py +124 -59
  68. warp/sim/__init__.py +6 -1
  69. warp/sim/collide.py +270 -26
  70. warp/sim/integrator_euler.py +25 -7
  71. warp/sim/integrator_featherstone.py +154 -35
  72. warp/sim/integrator_vbd.py +842 -40
  73. warp/sim/model.py +111 -53
  74. warp/stubs.py +248 -115
  75. warp/tape.py +28 -30
  76. warp/tests/aux_test_module_unload.py +15 -0
  77. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  78. warp/tests/test_array.py +74 -0
  79. warp/tests/test_assert.py +242 -0
  80. warp/tests/test_codegen.py +14 -61
  81. warp/tests/test_collision.py +2 -2
  82. warp/tests/test_examples.py +9 -0
  83. warp/tests/test_grad_debug.py +87 -2
  84. warp/tests/test_hash_grid.py +1 -1
  85. warp/tests/test_ipc.py +116 -0
  86. warp/tests/test_mat.py +138 -167
  87. warp/tests/test_math.py +47 -1
  88. warp/tests/test_matmul.py +11 -7
  89. warp/tests/test_matmul_lite.py +4 -4
  90. warp/tests/test_mesh.py +84 -60
  91. warp/tests/test_mesh_query_aabb.py +165 -0
  92. warp/tests/test_mesh_query_point.py +328 -286
  93. warp/tests/test_mesh_query_ray.py +134 -121
  94. warp/tests/test_mlp.py +2 -2
  95. warp/tests/test_operators.py +43 -0
  96. warp/tests/test_overwrite.py +2 -2
  97. warp/tests/test_quat.py +77 -0
  98. warp/tests/test_reload.py +29 -0
  99. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  100. warp/tests/test_static.py +16 -0
  101. warp/tests/test_tape.py +25 -0
  102. warp/tests/test_tile.py +134 -191
  103. warp/tests/test_tile_load.py +356 -0
  104. warp/tests/test_tile_mathdx.py +61 -8
  105. warp/tests/test_tile_mlp.py +17 -17
  106. warp/tests/test_tile_reduce.py +24 -18
  107. warp/tests/test_tile_shared_memory.py +66 -17
  108. warp/tests/test_tile_view.py +165 -0
  109. warp/tests/test_torch.py +35 -0
  110. warp/tests/test_utils.py +36 -24
  111. warp/tests/test_vec.py +110 -0
  112. warp/tests/unittest_suites.py +29 -4
  113. warp/tests/unittest_utils.py +30 -11
  114. warp/thirdparty/unittest_parallel.py +2 -2
  115. warp/types.py +409 -99
  116. warp/utils.py +9 -5
  117. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/METADATA +68 -44
  118. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/RECORD +121 -110
  119. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  120. warp/examples/benchmarks/benchmark_tile.py +0 -179
  121. warp/native/tile_gemm.h +0 -341
  122. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  123. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
@@ -91,114 +91,120 @@ def test_mesh_query_ray_grad(test, device):
91
91
  mesh_points = wp.array(np.array(mesh_geom.GetPointsAttr().Get()), dtype=wp.vec3, device=device)
92
92
  mesh_indices = wp.array(np.array(tri_indices), dtype=int, device=device)
93
93
 
94
- p = wp.vec3(50.0, 50.0, 0.0)
95
- D = wp.vec3(0.0, -1.0, 0.0)
96
-
97
- # create mesh
98
- mesh = wp.Mesh(points=mesh_points, velocities=None, indices=mesh_indices)
99
-
100
- tape = wp.Tape()
101
-
102
- # analytic gradients
103
- with tape:
104
- query_points = wp.array(p, dtype=wp.vec3, device=device, requires_grad=True)
105
- query_dirs = wp.array(D, dtype=wp.vec3, device=device, requires_grad=True)
106
- intersection_points = wp.zeros(n=1, dtype=wp.vec3, device=device)
107
- loss = wp.zeros(n=1, dtype=float, device=device, requires_grad=True)
108
-
109
- wp.launch(
110
- kernel=mesh_query_ray_loss,
111
- dim=1,
112
- inputs=[mesh.id, query_points, query_dirs, intersection_points, loss],
113
- device=device,
114
- )
115
-
116
- tape.backward(loss=loss)
117
- q = intersection_points.numpy().flatten()
118
- analytic_p = tape.gradients[query_points].numpy().flatten()
119
- analytic_D = tape.gradients[query_dirs].numpy().flatten()
120
-
121
- # numeric gradients
122
-
123
- # ray origin
124
- eps = 1.0e-3
125
- loss_values_p = []
126
- numeric_p = np.zeros(3)
127
-
128
- offset_query_points = [
129
- wp.vec3(p[0] - eps, p[1], p[2]),
130
- wp.vec3(p[0] + eps, p[1], p[2]),
131
- wp.vec3(p[0], p[1] - eps, p[2]),
132
- wp.vec3(p[0], p[1] + eps, p[2]),
133
- wp.vec3(p[0], p[1], p[2] - eps),
134
- wp.vec3(p[0], p[1], p[2] + eps),
135
- ]
136
-
137
- for i in range(6):
138
- q = offset_query_points[i]
139
-
140
- query_points = wp.array(q, dtype=wp.vec3, device=device)
141
- query_dirs = wp.array(D, dtype=wp.vec3, device=device)
142
- intersection_points = wp.zeros(n=1, dtype=wp.vec3, device=device)
143
- loss = wp.zeros(n=1, dtype=float, device=device)
144
-
145
- wp.launch(
146
- kernel=mesh_query_ray_loss,
147
- dim=1,
148
- inputs=[mesh.id, query_points, query_dirs, intersection_points, loss],
149
- device=device,
150
- )
151
-
152
- loss_values_p.append(loss.numpy()[0])
153
-
154
- for i in range(3):
155
- l_0 = loss_values_p[i * 2]
156
- l_1 = loss_values_p[i * 2 + 1]
157
- gradient = (l_1 - l_0) / (2.0 * eps)
158
- numeric_p[i] = gradient
159
-
160
- # ray dir
161
- loss_values_D = []
162
- numeric_D = np.zeros(3)
163
-
164
- offset_query_dirs = [
165
- wp.vec3(D[0] - eps, D[1], D[2]),
166
- wp.vec3(D[0] + eps, D[1], D[2]),
167
- wp.vec3(D[0], D[1] - eps, D[2]),
168
- wp.vec3(D[0], D[1] + eps, D[2]),
169
- wp.vec3(D[0], D[1], D[2] - eps),
170
- wp.vec3(D[0], D[1], D[2] + eps),
171
- ]
172
-
173
- for i in range(6):
174
- q = offset_query_dirs[i]
175
-
176
- query_points = wp.array(p, dtype=wp.vec3, device=device)
177
- query_dirs = wp.array(q, dtype=wp.vec3, device=device)
178
- intersection_points = wp.zeros(n=1, dtype=wp.vec3, device=device)
179
- loss = wp.zeros(n=1, dtype=float, device=device)
180
-
181
- wp.launch(
182
- kernel=mesh_query_ray_loss,
183
- dim=1,
184
- inputs=[mesh.id, query_points, query_dirs, intersection_points, loss],
185
- device=device,
186
- )
187
-
188
- loss_values_D.append(loss.numpy()[0])
189
-
190
- for i in range(3):
191
- l_0 = loss_values_D[i * 2]
192
- l_1 = loss_values_D[i * 2 + 1]
193
- gradient = (l_1 - l_0) / (2.0 * eps)
194
- numeric_D[i] = gradient
195
-
196
- error_p = ((analytic_p - numeric_p) * (analytic_p - numeric_p)).sum(axis=0)
197
- error_D = ((analytic_D - numeric_D) * (analytic_D - numeric_D)).sum(axis=0)
198
-
199
- tolerance = 1.0e-3
200
- test.assertTrue(error_p < tolerance, f"error is {error_p} which is >= {tolerance}")
201
- test.assertTrue(error_D < tolerance, f"error is {error_D} which is >= {tolerance}")
94
+ if device.is_cpu:
95
+ constructors = ["sah", "median"]
96
+ else:
97
+ constructors = ["sah", "median", "lbvh"]
98
+
99
+ for constructor in constructors:
100
+ p = wp.vec3(50.0, 50.0, 0.0)
101
+ D = wp.vec3(0.0, -1.0, 0.0)
102
+
103
+ # create mesh
104
+ mesh = wp.Mesh(points=mesh_points, velocities=None, indices=mesh_indices, bvh_constructor=constructor)
105
+
106
+ tape = wp.Tape()
107
+
108
+ # analytic gradients
109
+ with tape:
110
+ query_points = wp.array(p, dtype=wp.vec3, device=device, requires_grad=True)
111
+ query_dirs = wp.array(D, dtype=wp.vec3, device=device, requires_grad=True)
112
+ intersection_points = wp.zeros(n=1, dtype=wp.vec3, device=device)
113
+ loss = wp.zeros(n=1, dtype=float, device=device, requires_grad=True)
114
+
115
+ wp.launch(
116
+ kernel=mesh_query_ray_loss,
117
+ dim=1,
118
+ inputs=[mesh.id, query_points, query_dirs, intersection_points, loss],
119
+ device=device,
120
+ )
121
+
122
+ tape.backward(loss=loss)
123
+ q = intersection_points.numpy().flatten()
124
+ analytic_p = tape.gradients[query_points].numpy().flatten()
125
+ analytic_D = tape.gradients[query_dirs].numpy().flatten()
126
+
127
+ # numeric gradients
128
+
129
+ # ray origin
130
+ eps = 1.0e-3
131
+ loss_values_p = []
132
+ numeric_p = np.zeros(3)
133
+
134
+ offset_query_points = [
135
+ wp.vec3(p[0] - eps, p[1], p[2]),
136
+ wp.vec3(p[0] + eps, p[1], p[2]),
137
+ wp.vec3(p[0], p[1] - eps, p[2]),
138
+ wp.vec3(p[0], p[1] + eps, p[2]),
139
+ wp.vec3(p[0], p[1], p[2] - eps),
140
+ wp.vec3(p[0], p[1], p[2] + eps),
141
+ ]
142
+
143
+ for i in range(6):
144
+ q = offset_query_points[i]
145
+
146
+ query_points = wp.array(q, dtype=wp.vec3, device=device)
147
+ query_dirs = wp.array(D, dtype=wp.vec3, device=device)
148
+ intersection_points = wp.zeros(n=1, dtype=wp.vec3, device=device)
149
+ loss = wp.zeros(n=1, dtype=float, device=device)
150
+
151
+ wp.launch(
152
+ kernel=mesh_query_ray_loss,
153
+ dim=1,
154
+ inputs=[mesh.id, query_points, query_dirs, intersection_points, loss],
155
+ device=device,
156
+ )
157
+
158
+ loss_values_p.append(loss.numpy()[0])
159
+
160
+ for i in range(3):
161
+ l_0 = loss_values_p[i * 2]
162
+ l_1 = loss_values_p[i * 2 + 1]
163
+ gradient = (l_1 - l_0) / (2.0 * eps)
164
+ numeric_p[i] = gradient
165
+
166
+ # ray dir
167
+ loss_values_D = []
168
+ numeric_D = np.zeros(3)
169
+
170
+ offset_query_dirs = [
171
+ wp.vec3(D[0] - eps, D[1], D[2]),
172
+ wp.vec3(D[0] + eps, D[1], D[2]),
173
+ wp.vec3(D[0], D[1] - eps, D[2]),
174
+ wp.vec3(D[0], D[1] + eps, D[2]),
175
+ wp.vec3(D[0], D[1], D[2] - eps),
176
+ wp.vec3(D[0], D[1], D[2] + eps),
177
+ ]
178
+
179
+ for i in range(6):
180
+ q = offset_query_dirs[i]
181
+
182
+ query_points = wp.array(p, dtype=wp.vec3, device=device)
183
+ query_dirs = wp.array(q, dtype=wp.vec3, device=device)
184
+ intersection_points = wp.zeros(n=1, dtype=wp.vec3, device=device)
185
+ loss = wp.zeros(n=1, dtype=float, device=device)
186
+
187
+ wp.launch(
188
+ kernel=mesh_query_ray_loss,
189
+ dim=1,
190
+ inputs=[mesh.id, query_points, query_dirs, intersection_points, loss],
191
+ device=device,
192
+ )
193
+
194
+ loss_values_D.append(loss.numpy()[0])
195
+
196
+ for i in range(3):
197
+ l_0 = loss_values_D[i * 2]
198
+ l_1 = loss_values_D[i * 2 + 1]
199
+ gradient = (l_1 - l_0) / (2.0 * eps)
200
+ numeric_D[i] = gradient
201
+
202
+ error_p = ((analytic_p - numeric_p) * (analytic_p - numeric_p)).sum(axis=0)
203
+ error_D = ((analytic_D - numeric_D) * (analytic_D - numeric_D)).sum(axis=0)
204
+
205
+ tolerance = 1.0e-3
206
+ test.assertTrue(error_p < tolerance, f"error is {error_p} which is >= {tolerance}")
207
+ test.assertTrue(error_D < tolerance, f"error is {error_D} which is >= {tolerance}")
202
208
 
203
209
 
204
210
  @wp.kernel
@@ -229,6 +235,11 @@ def raycast_kernel(
229
235
 
230
236
 
231
237
  def test_mesh_query_ray_edge(test, device):
238
+ if device.is_cpu:
239
+ constructors = ["sah", "median"]
240
+ else:
241
+ constructors = ["sah", "median", "lbvh"]
242
+
232
243
  # Create raycast starts and directions
233
244
  xx, yy = np.meshgrid(np.arange(0.1, 0.4, 0.01), np.arange(0.1, 0.4, 0.01))
234
245
  xx = xx.flatten().reshape(-1, 1)
@@ -239,27 +250,29 @@ def test_mesh_query_ray_edge(test, device):
239
250
  ray_dirs = np.zeros_like(ray_starts)
240
251
  ray_dirs[:, 2] = -1.0
241
252
 
253
+ n = len(ray_starts)
254
+
255
+ ray_starts = wp.array(ray_starts, shape=(n,), dtype=wp.vec3, device=device)
256
+ ray_dirs = wp.array(ray_dirs, shape=(n,), dtype=wp.vec3, device=device)
257
+
242
258
  # Create simple square mesh
243
259
  vertices = np.array([[0.0, 0.0, 0.0], [0.0, 0.5, 0.0], [0.5, 0.0, 0.0], [0.5, 0.5, 0.0]], dtype=np.float32)
244
260
 
245
261
  triangles = np.array([[1, 0, 2], [1, 2, 3]], dtype=np.int32)
246
262
 
247
- mesh = wp.Mesh(
248
- points=wp.array(vertices, dtype=wp.vec3, device=device),
249
- indices=wp.array(triangles.flatten(), dtype=int, device=device),
250
- )
251
-
252
- counts = wp.zeros(1, dtype=int, device=device)
253
-
254
- n = len(ray_starts)
263
+ for constructor in constructors:
264
+ mesh = wp.Mesh(
265
+ points=wp.array(vertices, dtype=wp.vec3, device=device),
266
+ indices=wp.array(triangles.flatten(), dtype=int, device=device),
267
+ bvh_constructor=constructor,
268
+ )
255
269
 
256
- ray_starts = wp.array(ray_starts, shape=(n,), dtype=wp.vec3, device=device)
257
- ray_dirs = wp.array(ray_dirs, shape=(n,), dtype=wp.vec3, device=device)
270
+ counts = wp.zeros(1, dtype=int, device=device)
258
271
 
259
- wp.launch(kernel=raycast_kernel, dim=n, inputs=[mesh.id, ray_starts, ray_dirs, counts], device=device)
260
- wp.synchronize()
272
+ wp.launch(kernel=raycast_kernel, dim=n, inputs=[mesh.id, ray_starts, ray_dirs, counts], device=device)
273
+ wp.synchronize()
261
274
 
262
- test.assertEqual(counts.numpy()[0], n)
275
+ test.assertEqual(counts.numpy()[0], n)
263
276
 
264
277
 
265
278
  devices = get_test_devices()
warp/tests/test_mlp.py CHANGED
@@ -265,8 +265,8 @@ class TestMLP(unittest.TestCase):
265
265
  pass
266
266
 
267
267
 
268
- add_function_test(TestMLP, "test_mlp", test_mlp, devices=devices)
269
- add_function_test(TestMLP, "test_mlp_grad", test_mlp_grad, devices=devices)
268
+ add_function_test(TestMLP, "test_mlp", test_mlp, devices=devices, check_output=False)
269
+ add_function_test(TestMLP, "test_mlp_grad", test_mlp_grad, devices=devices, check_output=False)
270
270
 
271
271
 
272
272
  if __name__ == "__main__":
@@ -224,6 +224,48 @@ def test_operators_mat44():
224
224
  expect_eq(r0[3], wp.vec4(39.0, 42.0, 45.0, 48.0))
225
225
 
226
226
 
227
+ @wp.struct
228
+ class Complex:
229
+ real: float
230
+ imag: float
231
+
232
+
233
+ @wp.func
234
+ def add(
235
+ a: Complex,
236
+ b: Complex,
237
+ ) -> Complex:
238
+ return Complex(
239
+ a.real + b.real,
240
+ a.imag + b.imag,
241
+ )
242
+
243
+
244
+ @wp.func
245
+ def mul(
246
+ a: Complex,
247
+ b: Complex,
248
+ ) -> Complex:
249
+ return Complex(
250
+ a.real * b.real - a.imag * b.imag,
251
+ a.real * b.imag + a.imag * b.real,
252
+ )
253
+
254
+
255
+ @wp.kernel
256
+ def test_operators_overload():
257
+ a = Complex(1.0, 2.0)
258
+ b = Complex(3.0, 4.0)
259
+
260
+ c = a + b
261
+ expect_eq(c.real, 4.0)
262
+ expect_eq(c.imag, 6.0)
263
+
264
+ d = a * b
265
+ expect_eq(d.real, -5.0)
266
+ expect_eq(d.imag, 10.0)
267
+
268
+
227
269
  devices = get_test_devices()
228
270
 
229
271
 
@@ -241,6 +283,7 @@ add_kernel_test(TestOperators, test_operators_vec4, dim=1, devices=devices)
241
283
  add_kernel_test(TestOperators, test_operators_mat22, dim=1, devices=devices)
242
284
  add_kernel_test(TestOperators, test_operators_mat33, dim=1, devices=devices)
243
285
  add_kernel_test(TestOperators, test_operators_mat44, dim=1, devices=devices)
286
+ add_kernel_test(TestOperators, test_operators_overload, dim=1, devices=devices)
244
287
 
245
288
 
246
289
  if __name__ == "__main__":
@@ -577,8 +577,8 @@ add_function_test(TestOverwrite, "test_views", test_views, devices=devices)
577
577
  add_function_test(TestOverwrite, "test_reset", test_reset, devices=devices)
578
578
 
579
579
  add_function_test(TestOverwrite, "test_copy", test_copy, devices=devices)
580
- add_function_test(TestOverwrite, "test_matmul", test_matmul, devices=devices)
581
- add_function_test(TestOverwrite, "test_batched_matmul", test_batched_matmul, devices=devices)
580
+ add_function_test(TestOverwrite, "test_matmul", test_matmul, devices=devices, check_output=False)
581
+ add_function_test(TestOverwrite, "test_batched_matmul", test_batched_matmul, devices=devices, check_output=False)
582
582
  add_function_test(TestOverwrite, "test_atomic_operations", test_atomic_operations, devices=devices)
583
583
 
584
584
  # Some warning are only issued during codegen, and codegen only runs on cuda_0 in the MGPU case.
warp/tests/test_quat.py CHANGED
@@ -2095,6 +2095,81 @@ def test_py_arithmetic_ops(test, device, dtype):
2095
2095
  test.assertSequenceEqual(wptype(24) / v, make_quat(12, 6, 4, 3))
2096
2096
 
2097
2097
 
2098
+ @wp.kernel
2099
+ def quat_len_kernel(
2100
+ q: wp.quat,
2101
+ out: wp.array(dtype=int),
2102
+ ):
2103
+ length = wp.static(len(q))
2104
+ wp.expect_eq(wp.static(len(q)), 4)
2105
+ out[0] = wp.static(len(q))
2106
+
2107
+ foo = wp.quat()
2108
+ length = len(foo)
2109
+ wp.expect_eq(len(foo), 4)
2110
+ out[1] = len(foo)
2111
+
2112
+
2113
+ def test_quat_len(test, device):
2114
+ q = wp.quat()
2115
+ out = wp.empty(2, dtype=int, device=device)
2116
+ wp.launch(quat_len_kernel, dim=(1,), inputs=(q,), outputs=(out,), device=device)
2117
+
2118
+ test.assertEqual(out.numpy()[0], 4)
2119
+ test.assertEqual(out.numpy()[1], 4)
2120
+
2121
+
2122
+ @wp.kernel
2123
+ def vector_augassign_kernel(
2124
+ a: wp.array(dtype=wp.quat), b: wp.array(dtype=wp.quat), c: wp.array(dtype=wp.quat), d: wp.array(dtype=wp.quat)
2125
+ ):
2126
+ i = wp.tid()
2127
+
2128
+ q1 = wp.quat()
2129
+ q2 = b[i]
2130
+
2131
+ q1[0] += q2[0]
2132
+ q1[1] += q2[1]
2133
+ q1[2] += q2[2]
2134
+ q1[3] += q2[3]
2135
+
2136
+ a[i] = q1
2137
+
2138
+ q3 = wp.quat()
2139
+ q4 = d[i]
2140
+
2141
+ q3[0] += q4[0]
2142
+ q3[1] += q4[1]
2143
+ q3[2] += q4[2]
2144
+ q3[3] += q4[3]
2145
+
2146
+ c[i] = q1
2147
+
2148
+
2149
+ def test_vector_augassign(test, device):
2150
+ N = 3
2151
+
2152
+ a = wp.zeros(N, dtype=wp.quat, requires_grad=True)
2153
+ b = wp.ones(N, dtype=wp.quat, requires_grad=True)
2154
+
2155
+ c = wp.zeros(N, dtype=wp.quat, requires_grad=True)
2156
+ d = wp.ones(N, dtype=wp.quat, requires_grad=True)
2157
+
2158
+ tape = wp.Tape()
2159
+ with tape:
2160
+ wp.launch(vector_augassign_kernel, N, inputs=[a, b, c, d])
2161
+
2162
+ tape.backward(grads={a: wp.ones_like(a), c: wp.ones_like(c)})
2163
+
2164
+ assert_np_equal(a.numpy(), wp.ones_like(a).numpy())
2165
+ assert_np_equal(a.grad.numpy(), wp.ones_like(a).numpy())
2166
+ assert_np_equal(b.grad.numpy(), wp.ones_like(a).numpy())
2167
+
2168
+ assert_np_equal(c.numpy(), -wp.ones_like(c).numpy())
2169
+ assert_np_equal(c.grad.numpy(), wp.ones_like(c).numpy())
2170
+ assert_np_equal(d.grad.numpy(), -wp.ones_like(d).numpy())
2171
+
2172
+
2098
2173
  devices = get_test_devices()
2099
2174
 
2100
2175
 
@@ -2203,6 +2278,8 @@ for dtype in np_float_types:
2203
2278
  TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
2204
2279
  )
2205
2280
 
2281
+ add_function_test(TestQuat, "test_quat_len", test_quat_len, devices=devices)
2282
+
2206
2283
 
2207
2284
  if __name__ == "__main__":
2208
2285
  wp.clear_kernel_cache()
warp/tests/test_reload.py CHANGED
@@ -241,6 +241,32 @@ def test_graph_launch_after_module_reload(test, device):
241
241
  test.assertEqual(a.numpy()[0], 42)
242
242
 
243
243
 
244
+ def test_module_unload_during_graph_capture(test, device):
245
+ @wp.kernel
246
+ def foo(a: wp.array(dtype=int)):
247
+ a[0] = 42
248
+
249
+ # preload module before graph capture
250
+ wp.load_module(device=device)
251
+
252
+ # load another module to test unloading during graph capture
253
+ other_module = wp.get_module("warp.tests.aux_test_module_unload")
254
+ other_module.load(device)
255
+
256
+ with wp.ScopedDevice(device):
257
+ a = wp.zeros(1, dtype=int)
258
+
259
+ with wp.ScopedCapture(force_module_load=False) as capture:
260
+ wp.launch(foo, dim=1, inputs=[a])
261
+
262
+ # unloading a module during graph capture should be fine (deferred until capture completes)
263
+ other_module.unload()
264
+
265
+ wp.capture_launch(capture.graph)
266
+
267
+ test.assertEqual(a.numpy()[0], 42)
268
+
269
+
244
270
  devices = get_test_devices()
245
271
  cuda_devices = get_cuda_test_devices()
246
272
 
@@ -258,6 +284,9 @@ add_function_test(TestReload, "test_reload_references", test_reload_references,
258
284
  add_function_test(
259
285
  TestReload, "test_graph_launch_after_module_reload", test_graph_launch_after_module_reload, devices=cuda_devices
260
286
  )
287
+ add_function_test(
288
+ TestReload, "test_module_unload_during_graph_capture", test_module_unload_during_graph_capture, devices=cuda_devices
289
+ )
261
290
 
262
291
 
263
292
  if __name__ == "__main__":