warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl → 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (60):
  1. warp/autograd.py +12 -2
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +1 -1
  5. warp/builtins.py +103 -66
  6. warp/codegen.py +48 -27
  7. warp/config.py +1 -1
  8. warp/context.py +112 -49
  9. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  10. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  11. warp/fem/cache.py +1 -1
  12. warp/fem/field/field.py +11 -1
  13. warp/fem/field/nodal_field.py +36 -22
  14. warp/fem/geometry/adaptive_nanogrid.py +7 -3
  15. warp/fem/geometry/trimesh.py +4 -12
  16. warp/jax_experimental/custom_call.py +14 -2
  17. warp/jax_experimental/ffi.py +100 -67
  18. warp/native/builtin.h +91 -65
  19. warp/native/svd.h +59 -49
  20. warp/native/tile.h +55 -26
  21. warp/native/volume.cpp +2 -2
  22. warp/native/volume_builder.cu +33 -22
  23. warp/native/warp.cu +1 -1
  24. warp/render/render_opengl.py +41 -34
  25. warp/render/render_usd.py +96 -6
  26. warp/sim/collide.py +11 -9
  27. warp/sim/inertia.py +189 -156
  28. warp/sim/integrator_euler.py +3 -0
  29. warp/sim/integrator_xpbd.py +3 -0
  30. warp/sim/model.py +56 -31
  31. warp/sim/render.py +4 -0
  32. warp/sparse.py +1 -1
  33. warp/stubs.py +73 -25
  34. warp/tests/assets/torus.usda +1 -1
  35. warp/tests/cuda/test_streams.py +1 -1
  36. warp/tests/sim/test_collision.py +237 -206
  37. warp/tests/sim/test_inertia.py +161 -0
  38. warp/tests/sim/test_model.py +5 -3
  39. warp/tests/sim/{flaky_test_sim_grad.py → test_sim_grad.py} +1 -4
  40. warp/tests/sim/test_xpbd.py +399 -0
  41. warp/tests/test_array.py +8 -7
  42. warp/tests/test_atomic.py +181 -2
  43. warp/tests/test_builtins_resolution.py +38 -38
  44. warp/tests/test_codegen.py +24 -3
  45. warp/tests/test_examples.py +16 -6
  46. warp/tests/test_fem.py +93 -14
  47. warp/tests/test_func.py +1 -1
  48. warp/tests/test_mat.py +416 -119
  49. warp/tests/test_quat.py +321 -137
  50. warp/tests/test_struct.py +116 -0
  51. warp/tests/test_vec.py +320 -174
  52. warp/tests/tile/test_tile.py +27 -0
  53. warp/tests/tile/test_tile_load.py +124 -0
  54. warp/tests/unittest_suites.py +2 -5
  55. warp/types.py +107 -9
  56. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2rc1.dist-info}/METADATA +41 -19
  57. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2rc1.dist-info}/RECORD +60 -57
  58. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2rc1.dist-info}/WHEEL +1 -1
  59. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2rc1.dist-info}/licenses/LICENSE.md +0 -26
  60. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2rc1.dist-info}/top_level.txt +0 -0
@@ -157,9 +157,6 @@ def test_sphere_pushing_on_rails(
157
157
  model.joint_attach_ke = 32000.0 * 16
158
158
  model.joint_attach_kd = 500.0 * 4
159
159
 
160
- model.shape_geo.scale.requires_grad = False
161
- model.shape_geo.thickness.requires_grad = False
162
-
163
160
  if static_contacts:
164
161
  wp.sim.eval_fk(model, model.joint_q, model.joint_qd, None, model)
165
162
  model.rigid_contact_margin = 10.0
@@ -268,7 +265,7 @@ def test_sphere_pushing_on_rails(
268
265
  gradcheck(rollout, [action_too_close], device=device, eps=0.2, tol=tol, print_grad=print_grad)
269
266
 
270
267
 
271
- devices = get_test_devices()
268
+ devices = get_test_devices(mode="basic")
272
269
 
273
270
 
274
271
  class TestSimGradients(unittest.TestCase):
@@ -0,0 +1,399 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from warp.sim.model import PARTICLE_FLAG_ACTIVE
17
+ from warp.tests.unittest_utils import *
18
+
19
# fmt: off
# Vertex positions of a 10 x 10 cloth grid (100 points), listed row by row
# (each row of 10 consecutive points shares one z value).
# Within a row, x goes from -50 to 50 in 9 equal steps and y = x + 50, so the
# sheet is tilted; z goes from -50 to 50 across the 10 rows.
# The fractional values appear to be float32-rounded multiples of 100/9.
CLOTH_POINTS = [
    (-50.0000000, 0.0000000, -50.0000000),
    (-38.8888893, 11.1111107, -50.0000000),
    (-27.7777786, 22.2222214, -50.0000000),
    (-16.6666679, 33.3333321, -50.0000000),
    (-5.5555558, 44.4444427, -50.0000000),
    (5.5555558, 55.5555573, -50.0000000),
    (16.6666679, 66.6666641, -50.0000000),
    (27.7777786, 77.7777786, -50.0000000),
    (38.8888893, 88.8888855, -50.0000000),
    (50.0000000, 100.0000000, -50.0000000),
    (-50.0000000, 0.0000000, -38.8888893),
    (-38.8888893, 11.1111107, -38.8888893),
    (-27.7777786, 22.2222214, -38.8888893),
    (-16.6666679, 33.3333321, -38.8888893),
    (-5.5555558, 44.4444427, -38.8888893),
    (5.5555558, 55.5555573, -38.8888893),
    (16.6666679, 66.6666641, -38.8888893),
    (27.7777786, 77.7777786, -38.8888893),
    (38.8888893, 88.8888855, -38.8888893),
    (50.0000000, 100.0000000, -38.8888893),
    (-50.0000000, 0.0000000, -27.7777786),
    (-38.8888893, 11.1111107, -27.7777786),
    (-27.7777786, 22.2222214, -27.7777786),
    (-16.6666679, 33.3333321, -27.7777786),
    (-5.5555558, 44.4444427, -27.7777786),
    (5.5555558, 55.5555573, -27.7777786),
    (16.6666679, 66.6666641, -27.7777786),
    (27.7777786, 77.7777786, -27.7777786),
    (38.8888893, 88.8888855, -27.7777786),
    (50.0000000, 100.0000000, -27.7777786),
    (-50.0000000, 0.0000000, -16.6666679),
    (-38.8888893, 11.1111107, -16.6666679),
    (-27.7777786, 22.2222214, -16.6666679),
    (-16.6666679, 33.3333321, -16.6666679),
    (-5.5555558, 44.4444427, -16.6666679),
    (5.5555558, 55.5555573, -16.6666679),
    (16.6666679, 66.6666641, -16.6666679),
    (27.7777786, 77.7777786, -16.6666679),
    (38.8888893, 88.8888855, -16.6666679),
    (50.0000000, 100.0000000, -16.6666679),
    (-50.0000000, 0.0000000, -5.5555558),
    (-38.8888893, 11.1111107, -5.5555558),
    (-27.7777786, 22.2222214, -5.5555558),
    (-16.6666679, 33.3333321, -5.5555558),
    (-5.5555558, 44.4444427, -5.5555558),
    (5.5555558, 55.5555573, -5.5555558),
    (16.6666679, 66.6666641, -5.5555558),
    (27.7777786, 77.7777786, -5.5555558),
    (38.8888893, 88.8888855, -5.5555558),
    (50.0000000, 100.0000000, -5.5555558),
    (-50.0000000, 0.0000000, 5.5555558),
    (-38.8888893, 11.1111107, 5.5555558),
    (-27.7777786, 22.2222214, 5.5555558),
    (-16.6666679, 33.3333321, 5.5555558),
    (-5.5555558, 44.4444427, 5.5555558),
    (5.5555558, 55.5555573, 5.5555558),
    (16.6666679, 66.6666641, 5.5555558),
    (27.7777786, 77.7777786, 5.5555558),
    (38.8888893, 88.8888855, 5.5555558),
    (50.0000000, 100.0000000, 5.5555558),
    (-50.0000000, 0.0000000, 16.6666679),
    (-38.8888893, 11.1111107, 16.6666679),
    (-27.7777786, 22.2222214, 16.6666679),
    (-16.6666679, 33.3333321, 16.6666679),
    (-5.5555558, 44.4444427, 16.6666679),
    (5.5555558, 55.5555573, 16.6666679),
    (16.6666679, 66.6666641, 16.6666679),
    (27.7777786, 77.7777786, 16.6666679),
    (38.8888893, 88.8888855, 16.6666679),
    (50.0000000, 100.0000000, 16.6666679),
    (-50.0000000, 0.0000000, 27.7777786),
    (-38.8888893, 11.1111107, 27.7777786),
    (-27.7777786, 22.2222214, 27.7777786),
    (-16.6666679, 33.3333321, 27.7777786),
    (-5.5555558, 44.4444427, 27.7777786),
    (5.5555558, 55.5555573, 27.7777786),
    (16.6666679, 66.6666641, 27.7777786),
    (27.7777786, 77.7777786, 27.7777786),
    (38.8888893, 88.8888855, 27.7777786),
    (50.0000000, 100.0000000, 27.7777786),
    (-50.0000000, 0.0000000, 38.8888893),
    (-38.8888893, 11.1111107, 38.8888893),
    (-27.7777786, 22.2222214, 38.8888893),
    (-16.6666679, 33.3333321, 38.8888893),
    (-5.5555558, 44.4444427, 38.8888893),
    (5.5555558, 55.5555573, 38.8888893),
    (16.6666679, 66.6666641, 38.8888893),
    (27.7777786, 77.7777786, 38.8888893),
    (38.8888893, 88.8888855, 38.8888893),
    (50.0000000, 100.0000000, 38.8888893),
    (-50.0000000, 0.0000000, 50.0000000),
    (-38.8888893, 11.1111107, 50.0000000),
    (-27.7777786, 22.2222214, 50.0000000),
    (-16.6666679, 33.3333321, 50.0000000),
    (-5.5555558, 44.4444427, 50.0000000),
    (5.5555558, 55.5555573, 50.0000000),
    (16.6666679, 66.6666641, 50.0000000),
    (27.7777786, 77.7777786, 50.0000000),
    (38.8888893, 88.8888855, 50.0000000),
    (50.0000000, 100.0000000, 50.0000000),
]
122
+
123
# Triangle connectivity for CLOTH_POINTS, flattened as consecutive index
# triples: 162 triangles, two per cell of the 9 x 9 grid of quads.
# Indices are 1-based (values run from 1 to 100); the test code subtracts 1
# from every entry before passing them to the cloth builder.
# The diagonal used to split each cell alternates in a checkerboard pattern.
CLOTH_FACES = [
    1, 12, 2,
    1, 11, 12,
    2, 12, 3,
    12, 13, 3,
    3, 14, 4,
    3, 13, 14,
    4, 14, 5,
    14, 15, 5,
    5, 16, 6,
    5, 15, 16,
    6, 16, 7,
    16, 17, 7,
    7, 18, 8,
    7, 17, 18,
    8, 18, 9,
    18, 19, 9,
    9, 20, 10,
    9, 19, 20,
    11, 21, 12,
    21, 22, 12,
    12, 23, 13,
    12, 22, 23,
    13, 23, 14,
    23, 24, 14,
    14, 25, 15,
    14, 24, 25,
    15, 25, 16,
    25, 26, 16,
    16, 27, 17,
    16, 26, 27,
    17, 27, 18,
    27, 28, 18,
    18, 29, 19,
    18, 28, 29,
    19, 29, 20,
    29, 30, 20,
    21, 32, 22,
    21, 31, 32,
    22, 32, 23,
    32, 33, 23,
    23, 34, 24,
    23, 33, 34,
    24, 34, 25,
    34, 35, 25,
    25, 36, 26,
    25, 35, 36,
    26, 36, 27,
    36, 37, 27,
    27, 38, 28,
    27, 37, 38,
    28, 38, 29,
    38, 39, 29,
    29, 40, 30,
    29, 39, 40,
    31, 41, 32,
    41, 42, 32,
    32, 43, 33,
    32, 42, 43,
    33, 43, 34,
    43, 44, 34,
    34, 45, 35,
    34, 44, 45,
    35, 45, 36,
    45, 46, 36,
    36, 47, 37,
    36, 46, 47,
    37, 47, 38,
    47, 48, 38,
    38, 49, 39,
    38, 48, 49,
    39, 49, 40,
    49, 50, 40,
    41, 52, 42,
    41, 51, 52,
    42, 52, 43,
    52, 53, 43,
    43, 54, 44,
    43, 53, 54,
    44, 54, 45,
    54, 55, 45,
    45, 56, 46,
    45, 55, 56,
    46, 56, 47,
    56, 57, 47,
    47, 58, 48,
    47, 57, 58,
    48, 58, 49,
    58, 59, 49,
    49, 60, 50,
    49, 59, 60,
    51, 61, 52,
    61, 62, 52,
    52, 63, 53,
    52, 62, 63,
    53, 63, 54,
    63, 64, 54,
    54, 65, 55,
    54, 64, 65,
    55, 65, 56,
    65, 66, 56,
    56, 67, 57,
    56, 66, 67,
    57, 67, 58,
    67, 68, 58,
    58, 69, 59,
    58, 68, 69,
    59, 69, 60,
    69, 70, 60,
    61, 72, 62,
    61, 71, 72,
    62, 72, 63,
    72, 73, 63,
    63, 74, 64,
    63, 73, 74,
    64, 74, 65,
    74, 75, 65,
    65, 76, 66,
    65, 75, 76,
    66, 76, 67,
    76, 77, 67,
    67, 78, 68,
    67, 77, 78,
    68, 78, 69,
    78, 79, 69,
    69, 80, 70,
    69, 79, 80,
    71, 81, 72,
    81, 82, 72,
    72, 83, 73,
    72, 82, 83,
    73, 83, 74,
    83, 84, 74,
    74, 85, 75,
    74, 84, 85,
    75, 85, 76,
    85, 86, 76,
    76, 87, 77,
    76, 86, 87,
    77, 87, 78,
    87, 88, 78,
    78, 89, 79,
    78, 88, 89,
    79, 89, 80,
    89, 90, 80,
    81, 92, 82,
    81, 91, 92,
    82, 92, 83,
    92, 93, 83,
    83, 94, 84,
    83, 93, 94,
    84, 94, 85,
    94, 95, 85,
    85, 96, 86,
    85, 95, 96,
    86, 96, 87,
    96, 97, 87,
    87, 98, 88,
    87, 97, 98,
    88, 98, 89,
    98, 99, 89,
    89, 100, 90,
    89, 99, 100
]

# fmt: on
289
class XPBDClothSim:
    """Small XPBD cloth scene shared by the tests in this file.

    Usage: construct it, populate ``self.builder`` with a ``set_*_experiment``
    method, call :meth:`finalize`, then :meth:`run`. Each integrator step
    swaps ``state0``/``state1``, so ``state0`` refers to the most recently
    written state after :meth:`simulate`.
    """

    def __init__(self, device, use_cuda_graph=False):
        # Each frame of length frame_dt is split into num_substeps steps of size dt.
        self.frame_dt = 1 / 60
        self.num_test_frames = 100
        self.num_substeps = 20
        self.iterations = 2
        self.dt = self.frame_dt / self.num_substeps
        self.device = device
        # CUDA graph capture is only attempted on CUDA devices.
        self.use_cuda_graph = self.device.is_cuda and use_cuda_graph
        self.builder = wp.sim.ModelBuilder()
        # Particle indices deactivated by finalize(). Initialized here so that
        # finalize() also works when no experiment method assigned it.
        self.fixed_particles = []

    def set_free_falling_experiment(self):
        """Add the CLOTH_POINTS/CLOTH_FACES cloth with no fixed particles, simulated for 30 frames."""
        self.input_scale_factor = 1.0
        self.renderer_scale_factor = 0.01
        vertices = [wp.vec3(v) * self.input_scale_factor for v in CLOTH_POINTS]
        # CLOTH_FACES stores 1-based vertex indices; convert to 0-based.
        faces_flatten = [fv - 1 for fv in CLOTH_FACES]

        self.builder.add_cloth_mesh(
            vertices=vertices,
            indices=faces_flatten,
            scale=0.05,
            density=10,
            pos=wp.vec3(0.0, 4.0, 0.0),
            rot=wp.quat_identity(),
            vel=wp.vec3(0.0, 0.0, 0.0),
            edge_ke=1.0e2,
            add_springs=True,
            spring_ke=1.0e3,
            spring_kd=0.0,
        )
        self.fixed_particles = []
        self.num_test_frames = 30

    def finalize(self, ground=True):
        """Build the model and integrator and allocate the two simulation states.

        Sets gravity to (0, -10, 0) and the soft-contact parameters, deactivates
        ``self.fixed_particles``, records the initial particle positions in
        ``self.init_pos`` and, when enabled, captures :meth:`simulate` into a
        CUDA graph that :meth:`run` replays.
        """
        self.model = self.builder.finalize(device=self.device)
        self.model.ground = ground
        self.model.gravity = wp.vec3(0, -10.0, 0)
        self.model.soft_contact_ke = 1.0e4
        self.model.soft_contact_kd = 1.0e2

        self.set_points_fixed(self.model, self.fixed_particles)

        self.integrator = wp.sim.XPBDIntegrator(self.iterations)
        self.state0 = self.model.state()
        self.state1 = self.model.state()

        # explicit copy so later simulation steps cannot alter the snapshot
        self.init_pos = np.array(self.state0.particle_q.numpy(), copy=True)

        self.graph = None
        if self.use_cuda_graph:
            with wp.ScopedCapture(device=self.device, force_module_load=False) as capture:
                self.simulate()
            self.graph = capture.graph

    def simulate(self):
        """Run num_substeps * num_test_frames XPBD steps, ping-ponging state0/state1."""
        for _step in range(self.num_substeps * self.num_test_frames):
            self.integrator.simulate(self.model, self.state0, self.state1, self.dt, None)
            (self.state0, self.state1) = (self.state1, self.state0)

    def run(self):
        """Advance the simulation, replaying the captured graph when one exists."""
        if self.graph:
            wp.capture_launch(self.graph)
        else:
            self.simulate()

    def set_points_fixed(self, model, fixed_particles):
        """Clear PARTICLE_FLAG_ACTIVE for every particle index in *fixed_particles*.

        The flags are edited on the host and re-uploaded as a new ``uint32``
        array assigned to ``model.particle_flags``. Does nothing when
        *fixed_particles* is empty.
        """
        # len() rather than truthiness so numpy index arrays are accepted too
        if len(fixed_particles) == 0:
            return

        flags = model.particle_flags.numpy()
        inactive_mask = ~int(PARTICLE_FLAG_ACTIVE)
        for fixed_v_id in fixed_particles:
            # int() & negative mask yields a non-negative int that fits in uint32
            flags[fixed_v_id] = int(flags[fixed_v_id]) & inactive_mask

        model.particle_flags = wp.array(flags, dtype=wp.uint32, device=model.device)
361
+
362
+
363
def test_xpbd_free_falling(test, device):
    """Drop the cloth with no ground and check that it falls freely.

    The particles must stay bounded, must move, must each travel the
    free-fall distance ``0.5 * |g| * T**2`` along gravity (``T`` is the total
    simulated time), and must show negligible motion perpendicular to gravity.
    """
    example = XPBDClothSim(device)
    example.set_free_falling_experiment()
    example.finalize(ground=False)
    initial_pos = example.state0.particle_q.numpy().copy()

    example.run()

    # The simulation must not explode. The magnitude is checked because the
    # cloth falls toward negative y; NaN/inf also fail this comparison.
    final_pos = example.state0.particle_q.numpy()
    test.assertTrue((np.abs(final_pos) < 1e5).all())
    # the simulation must have moved the particles
    test.assertTrue((example.init_pos != final_pos).any())

    gravity = np.array(example.model.gravity)
    gravity_norm = np.linalg.norm(gravity)
    diff = final_pos - initial_pos

    # Signed distance travelled along the gravity direction, per particle.
    vertical_distance = diff @ (gravity / gravity_norm)
    # run() advances num_test_frames frames of frame_dt each (num_substeps steps of dt).
    total_time = example.num_test_frames * example.frame_dt
    expected_distance = 0.5 * gravity_norm * total_time**2
    # ensure it's free-falling
    test.assertTrue((np.abs(vertical_distance - expected_distance) < 2e-1).all())

    # Remove the component along gravity; what remains is horizontal motion.
    horizontal_move = diff - np.outer(vertical_distance / gravity_norm, gravity)
    # ensure its horizontal translation is minimal
    test.assertTrue((np.abs(horizontal_move) < 1e-1).all())
385
+
386
+
387
# XPBD tests only run on the "basic" set of test devices.
devices = get_test_devices(mode="basic")


class TestXPBD(unittest.TestCase):
    """Holds the XPBD test cases attached below through add_function_test()."""


add_function_test(TestXPBD, "test_xpbd_free_falling", test_xpbd_free_falling, devices=devices)


if __name__ == "__main__":
    # start from a clean kernel cache when the file is run directly
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
warp/tests/test_array.py CHANGED
@@ -2803,7 +2803,7 @@ def test_alloc_strides(test, device):
2803
2803
 
2804
2804
  def test_casting(test, device):
2805
2805
  idxs = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)
2806
- idxs = wp.array(idxs, device=device).reshape((-1, 3))
2806
+ idxs = wp.array(idxs, device=device, dtype=wp.int32).reshape((-1, 3))
2807
2807
  idxs = wp.array(idxs, shape=idxs.shape[0], dtype=wp.vec3i, device=device)
2808
2808
  assert idxs.dtype is wp.vec3i
2809
2809
  assert idxs.shape == (4,)
@@ -2846,25 +2846,25 @@ def test_array_len(test, device):
2846
2846
 
2847
2847
  def test_cuda_interface_conversion(test, device):
2848
2848
  class MyArrayInterface:
2849
- def __init__(self, data):
2850
- self.data = np.array(data)
2849
+ def __init__(self, data, npdtype):
2850
+ self.data = np.array(data, dtype=npdtype)
2851
2851
  self.__array_interface__ = self.data.__array_interface__
2852
2852
  self.__cuda_array_interface__ = self.data.__array_interface__
2853
2853
  self.__len__ = self.data.__len__
2854
2854
 
2855
- array = MyArrayInterface((1, 2, 3))
2855
+ array = MyArrayInterface((1, 2, 3), np.int8)
2856
2856
  wp_array = wp.array(array, dtype=wp.int8, device=device)
2857
2857
  assert wp_array.ptr != 0
2858
2858
 
2859
- array = MyArrayInterface((1, 2, 3))
2859
+ array = MyArrayInterface((1, 2, 3), np.float32)
2860
2860
  wp_array = wp.array(array, dtype=wp.float32, device=device)
2861
2861
  assert wp_array.ptr != 0
2862
2862
 
2863
- array = MyArrayInterface((1, 2, 3))
2863
+ array = MyArrayInterface((1, 2, 3), np.float32)
2864
2864
  wp_array = wp.array(array, dtype=wp.vec3, device=device)
2865
2865
  assert wp_array.ptr != 0
2866
2866
 
2867
- array = MyArrayInterface((1, 2, 3, 4))
2867
+ array = MyArrayInterface((1, 2, 3, 4), np.float32)
2868
2868
  wp_array = wp.array(array, dtype=wp.mat22, device=device)
2869
2869
  assert wp_array.ptr != 0
2870
2870
 
@@ -2883,6 +2883,7 @@ add_function_test(TestArray, "test_shape", test_shape, devices=devices)
2883
2883
  add_function_test(TestArray, "test_negative_shape", test_negative_shape, devices=devices)
2884
2884
  add_function_test(TestArray, "test_flatten", test_flatten, devices=devices)
2885
2885
  add_function_test(TestArray, "test_reshape", test_reshape, devices=devices)
2886
+
2886
2887
  add_function_test(TestArray, "test_slicing", test_slicing, devices=devices)
2887
2888
  add_function_test(TestArray, "test_transpose", test_transpose, devices=devices)
2888
2889
  add_function_test(TestArray, "test_view", test_view, devices=devices)
warp/tests/test_atomic.py CHANGED
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import re
16
17
  import unittest
17
18
 
18
19
  import numpy as np
@@ -32,8 +33,8 @@ def make_atomic_test(type):
32
33
  tid = wp.tid()
33
34
 
34
35
  wp.atomic_add(out_add, 0, val[tid])
35
- wp.atomic_min(out_min, 0, val[tid])
36
- wp.atomic_max(out_max, 0, val[tid])
36
+ wp.atomic_min(out_min, wp.uint32(0), val[tid])
37
+ wp.atomic_max(out_max, wp.int64(0), val[tid])
37
38
 
38
39
  # register a custom kernel (no decorator) function
39
40
  # this lets us register the same function definition
@@ -130,6 +131,102 @@ test_atomic_mat33 = make_atomic_test(wp.mat33)
130
131
  test_atomic_mat44 = make_atomic_test(wp.mat44)
131
132
 
132
133
 
134
def test_atomic_add_supported_dtypes(test, device, dtype):
    """Check that atomic_add() builds for *dtype* and actually accumulates.

    A single thread adds ``dtype(scalar_type(1))`` to a zero-initialized
    element, so every component of the result must equal 1.
    """
    scalar_type = getattr(dtype, "_wp_scalar_type_", dtype)

    @wp.kernel
    def kernel(arr: wp.array(dtype=dtype)):
        wp.atomic_add(arr, 0, dtype(scalar_type(1)))

    arr = wp.zeros(1, dtype=dtype, device=device)
    wp.launch(kernel, dim=1, outputs=(arr,), device=device)

    # Verify the stored value, not just that the launch succeeded.
    test.assertTrue(np.all(arr.numpy() == 1))
143
+
144
+
145
def test_atomic_min_supported_dtypes(test, device, dtype):
    """Check that atomic_min() builds for *dtype* and actually takes the minimum.

    The element starts at 1 in every component and a single thread applies
    min with ``dtype(scalar_type(0))``, so every component must end up 0.
    """
    scalar_type = getattr(dtype, "_wp_scalar_type_", dtype)

    @wp.kernel
    def kernel(arr: wp.array(dtype=dtype)):
        wp.atomic_min(arr, 0, dtype(scalar_type(0)))

    # Start from ones so a working min() visibly lowers the stored value.
    arr = wp.ones(1, dtype=dtype, device=device)
    wp.launch(kernel, dim=1, outputs=(arr,), device=device)

    test.assertTrue(np.all(arr.numpy() == 0))
154
+
155
+
156
def test_atomic_max_supported_dtypes(test, device, dtype):
    """Check that atomic_max() builds for *dtype* and actually takes the maximum.

    The element starts at 0 and a single thread applies max with
    ``dtype(scalar_type(1))``, so every component must end up 1.
    """
    scalar_type = getattr(dtype, "_wp_scalar_type_", dtype)

    @wp.kernel
    def kernel(arr: wp.array(dtype=dtype)):
        wp.atomic_max(arr, 0, dtype(scalar_type(1)))

    arr = wp.zeros(1, dtype=dtype, device=device)
    wp.launch(kernel, dim=1, outputs=(arr,), device=device)

    # Verify the stored value, not just that the launch succeeded.
    test.assertTrue(np.all(arr.numpy() == 1))
165
+
166
+
167
def test_atomic_add_unsupported_dtypes(test, device, dtype):
    """Expect launching an atomic_add() kernel on a *dtype* array to raise RuntimeError.

    The full error message is matched, anchored at its end.
    """
    scalar_type = getattr(dtype, "_wp_scalar_type_", dtype)

    @wp.kernel
    def kernel(arr: wp.array(dtype=dtype)):
        wp.atomic_add(arr, 0, dtype(scalar_type(0)))

    # the dtype repr may contain regex metacharacters, so it is escaped
    escaped_dtype = re.escape(wp.types.type_repr(dtype))
    scalar_name = wp.types.type_repr(scalar_type)
    expected = (
        r"atomic_add\(\) operations only work on arrays with \[u\]int32, \[u\]int64, float16, float32, or float64 "
        rf"as the underlying scalar types, but got {escaped_dtype} \(with scalar type {scalar_name}\)$"
    )

    arr = wp.zeros(1, dtype=dtype, device=device)
    with test.assertRaisesRegex(RuntimeError, expected):
        wp.launch(kernel, dim=1, outputs=(arr,), device=device)
186
+
187
+
188
def test_atomic_min_unsupported_dtypes(test, device, dtype):
    """Expect launching an atomic_min() kernel on a *dtype* array to raise RuntimeError.

    The full error message is matched, anchored at its end.
    """
    scalar_type = getattr(dtype, "_wp_scalar_type_", dtype)

    @wp.kernel
    def kernel(arr: wp.array(dtype=dtype)):
        wp.atomic_min(arr, 0, dtype(scalar_type(0)))

    # the dtype repr may contain regex metacharacters, so it is escaped
    escaped_dtype = re.escape(wp.types.type_repr(dtype))
    scalar_name = wp.types.type_repr(scalar_type)
    expected = (
        r"atomic_min\(\) operations only work on arrays with \[u\]int32, \[u\]int64, float32, or float64 "
        rf"as the underlying scalar types, but got {escaped_dtype} \(with scalar type {scalar_name}\)$"
    )

    arr = wp.zeros(1, dtype=dtype, device=device)
    with test.assertRaisesRegex(RuntimeError, expected):
        wp.launch(kernel, dim=1, outputs=(arr,), device=device)
207
+
208
+
209
def test_atomic_max_unsupported_dtypes(test, device, dtype):
    """Expect launching an atomic_max() kernel on a *dtype* array to raise RuntimeError.

    The full error message is matched, anchored at its end.
    """
    scalar_type = getattr(dtype, "_wp_scalar_type_", dtype)

    @wp.kernel
    def kernel(arr: wp.array(dtype=dtype)):
        wp.atomic_max(arr, 0, dtype(scalar_type(0)))

    # the dtype repr may contain regex metacharacters, so it is escaped
    escaped_dtype = re.escape(wp.types.type_repr(dtype))
    scalar_name = wp.types.type_repr(scalar_type)
    expected = (
        r"atomic_max\(\) operations only work on arrays with \[u\]int32, \[u\]int64, float32, or float64 "
        rf"as the underlying scalar types, but got {escaped_dtype} \(with scalar type {scalar_name}\)$"
    )

    arr = wp.zeros(1, dtype=dtype, device=device)
    with test.assertRaisesRegex(RuntimeError, expected):
        wp.launch(kernel, dim=1, outputs=(arr,), device=device)
228
+
229
+
133
230
  devices = get_test_devices()
134
231
 
135
232
 
@@ -147,6 +244,88 @@ add_function_test(TestAtomic, "test_atomic_mat22", test_atomic_mat22, devices=de
147
244
  add_function_test(TestAtomic, "test_atomic_mat33", test_atomic_mat33, devices=devices)
148
245
  add_function_test(TestAtomic, "test_atomic_mat44", test_atomic_mat44, devices=devices)
149
246
 
247
# Register add/min/max tests for dtypes expected to work. float16-based types
# only get the atomic_add test: min/max are not expected to support them.
for dtype in (
    wp.int32,
    wp.uint32,
    wp.int64,
    wp.uint64,
    wp.float16,
    wp.float32,
    wp.float64,
    wp.vec3i,
    wp.vec3ui,
    wp.vec3l,
    wp.vec3ul,
    wp.vec3h,
    wp.vec3f,
    wp.vec3d,
):
    scalar_type = getattr(dtype, "_wp_scalar_type_", dtype)

    for op_name, op_check in (
        ("add", test_atomic_add_supported_dtypes),
        ("min", test_atomic_min_supported_dtypes),
        ("max", test_atomic_max_supported_dtypes),
    ):
        if op_name != "add" and scalar_type is wp.float16:
            continue
        add_function_test(
            TestAtomic,
            f"test_atomic_{op_name}_supported_dtypes_{dtype.__name__}",
            op_check,
            devices=devices,
            dtype=dtype,
        )


# Register "expected to fail" tests. float16-based types are skipped for
# atomic_add because they are supported there (see the loop above).
for dtype in (
    wp.int8,
    wp.uint8,
    wp.int16,
    wp.uint16,
    wp.float16,
    wp.vec3b,
    wp.vec3ub,
    wp.vec3s,
    wp.vec3us,
    wp.vec3h,
):
    scalar_type = getattr(dtype, "_wp_scalar_type_", dtype)

    for op_name, op_check in (
        ("add", test_atomic_add_unsupported_dtypes),
        ("min", test_atomic_min_unsupported_dtypes),
        ("max", test_atomic_max_unsupported_dtypes),
    ):
        if op_name == "add" and scalar_type is wp.float16:
            continue
        add_function_test(
            TestAtomic,
            f"test_atomic_{op_name}_unsupported_dtypes_{dtype.__name__}",
            op_check,
            devices=devices,
            dtype=dtype,
        )
328
+
150
329
 
151
330
  if __name__ == "__main__":
152
331
  wp.clear_kernel_cache()