warp-lang 1.7.1__py3-none-manylinux_2_34_aarch64.whl → 1.7.2__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/builtins.py +92 -56
- warp/codegen.py +31 -22
- warp/config.py +1 -1
- warp/context.py +106 -49
- warp/fem/cache.py +1 -1
- warp/jax_experimental/ffi.py +95 -66
- warp/native/builtin.h +91 -65
- warp/native/svd.h +59 -49
- warp/native/tile.h +46 -17
- warp/native/volume.cpp +2 -2
- warp/native/volume_builder.cu +33 -22
- warp/render/render_opengl.py +22 -17
- warp/render/render_usd.py +3 -3
- warp/sim/model.py +29 -21
- warp/sparse.py +1 -1
- warp/stubs.py +72 -24
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/sim/test_model.py +5 -3
- warp/tests/sim/test_sim_grad.py +1 -8
- warp/tests/test_array.py +8 -7
- warp/tests/test_atomic.py +181 -2
- warp/tests/test_builtins_resolution.py +38 -38
- warp/tests/test_fem.py +20 -6
- warp/tests/test_func.py +1 -1
- warp/tests/test_mat.py +46 -16
- warp/tests/test_struct.py +116 -0
- warp/tests/tile/test_tile.py +27 -0
- warp/tests/tile/test_tile_load.py +27 -0
- warp/types.py +42 -1
- {warp_lang-1.7.1.dist-info → warp_lang-1.7.2.dist-info}/METADATA +26 -16
- {warp_lang-1.7.1.dist-info → warp_lang-1.7.2.dist-info}/RECORD +36 -36
- {warp_lang-1.7.1.dist-info → warp_lang-1.7.2.dist-info}/WHEEL +1 -1
- {warp_lang-1.7.1.dist-info → warp_lang-1.7.2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.1.dist-info → warp_lang-1.7.2.dist-info}/top_level.txt +0 -0
|
@@ -84,25 +84,25 @@ def test_int_int_args_support(test, device, dtype):
|
|
|
84
84
|
else:
|
|
85
85
|
with test.assertRaisesRegex(
|
|
86
86
|
RuntimeError,
|
|
87
|
-
rf"Couldn't find a function 'mul' compatible with
|
|
87
|
+
rf"Couldn't find a function 'mul' compatible with the arguments '{dtype.__name__}, int'$",
|
|
88
88
|
):
|
|
89
89
|
wp.mul(dtype(value), value)
|
|
90
90
|
|
|
91
91
|
with test.assertRaisesRegex(
|
|
92
92
|
RuntimeError,
|
|
93
|
-
rf"Couldn't find a function 'mul' compatible with
|
|
93
|
+
rf"Couldn't find a function 'mul' compatible with the arguments '{np_type.__name__}, int'$",
|
|
94
94
|
):
|
|
95
95
|
wp.mul(nps(np_type, value), value)
|
|
96
96
|
|
|
97
97
|
with test.assertRaisesRegex(
|
|
98
98
|
RuntimeError,
|
|
99
|
-
rf"Couldn't find a function 'mul' compatible with
|
|
99
|
+
rf"Couldn't find a function 'mul' compatible with the arguments 'int, {dtype.__name__}'$",
|
|
100
100
|
):
|
|
101
101
|
wp.mul(value, dtype(value))
|
|
102
102
|
|
|
103
103
|
with test.assertRaisesRegex(
|
|
104
104
|
RuntimeError,
|
|
105
|
-
rf"Couldn't find a function 'mul' compatible with
|
|
105
|
+
rf"Couldn't find a function 'mul' compatible with the arguments 'int, {np_type.__name__}'$",
|
|
106
106
|
):
|
|
107
107
|
wp.mul(value, nps(np_type, value))
|
|
108
108
|
|
|
@@ -189,73 +189,73 @@ def test_mat_mat_args_support(test, device, dtype):
|
|
|
189
189
|
else:
|
|
190
190
|
with test.assertRaisesRegex(
|
|
191
191
|
RuntimeError,
|
|
192
|
-
r"Couldn't find a function 'ddot' compatible with
|
|
192
|
+
r"Couldn't find a function 'ddot' compatible with the arguments 'mat_t, tuple'$",
|
|
193
193
|
):
|
|
194
194
|
wp.ddot(mat_cls(*a_values), b_values)
|
|
195
195
|
|
|
196
196
|
with test.assertRaisesRegex(
|
|
197
197
|
RuntimeError,
|
|
198
|
-
r"Couldn't find a function 'ddot' compatible with
|
|
198
|
+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
|
|
199
199
|
):
|
|
200
200
|
wp.ddot(wpv(dtype, a_values), b_values)
|
|
201
201
|
|
|
202
202
|
with test.assertRaisesRegex(
|
|
203
203
|
RuntimeError,
|
|
204
|
-
r"Couldn't find a function 'ddot' compatible with
|
|
204
|
+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
|
|
205
205
|
):
|
|
206
206
|
wp.ddot(wpm(dtype, 3, a_values), b_values)
|
|
207
207
|
|
|
208
208
|
with test.assertRaisesRegex(
|
|
209
209
|
RuntimeError,
|
|
210
|
-
r"Couldn't find a function 'ddot' compatible with
|
|
210
|
+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
|
|
211
211
|
):
|
|
212
212
|
wp.ddot(npv(np_type, a_values), b_values)
|
|
213
213
|
|
|
214
214
|
with test.assertRaisesRegex(
|
|
215
215
|
RuntimeError,
|
|
216
|
-
r"Couldn't find a function 'ddot' compatible with
|
|
216
|
+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
|
|
217
217
|
):
|
|
218
218
|
wp.ddot(npm(np_type, 3, a_values), b_values)
|
|
219
219
|
|
|
220
220
|
with test.assertRaisesRegex(
|
|
221
221
|
RuntimeError,
|
|
222
|
-
r"Couldn't find a function 'ddot' compatible with
|
|
222
|
+
r"Couldn't find a function 'ddot' compatible with the arguments 'ndarray, tuple'$",
|
|
223
223
|
):
|
|
224
224
|
wp.ddot(np.array(npv(np_type, a_values)), b_values)
|
|
225
225
|
|
|
226
226
|
with test.assertRaisesRegex(
|
|
227
227
|
RuntimeError,
|
|
228
|
-
r"Couldn't find a function 'ddot' compatible with
|
|
228
|
+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, mat_t'$",
|
|
229
229
|
):
|
|
230
230
|
wp.ddot(a_values, mat_cls(*b_values))
|
|
231
231
|
|
|
232
232
|
with test.assertRaisesRegex(
|
|
233
233
|
RuntimeError,
|
|
234
|
-
r"Couldn't find a function 'ddot' compatible with
|
|
234
|
+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
|
|
235
235
|
):
|
|
236
236
|
wp.ddot(a_values, wpv(dtype, b_values))
|
|
237
237
|
|
|
238
238
|
with test.assertRaisesRegex(
|
|
239
239
|
RuntimeError,
|
|
240
|
-
r"Couldn't find a function 'ddot' compatible with
|
|
240
|
+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
|
|
241
241
|
):
|
|
242
242
|
wp.ddot(a_values, wpm(dtype, 3, b_values))
|
|
243
243
|
|
|
244
244
|
with test.assertRaisesRegex(
|
|
245
245
|
RuntimeError,
|
|
246
|
-
r"Couldn't find a function 'ddot' compatible with
|
|
246
|
+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
|
|
247
247
|
):
|
|
248
248
|
wp.ddot(a_values, npv(np_type, b_values))
|
|
249
249
|
|
|
250
250
|
with test.assertRaisesRegex(
|
|
251
251
|
RuntimeError,
|
|
252
|
-
r"Couldn't find a function 'ddot' compatible with
|
|
252
|
+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
|
|
253
253
|
):
|
|
254
254
|
wp.ddot(a_values, npm(np_type, 3, b_values))
|
|
255
255
|
|
|
256
256
|
with test.assertRaisesRegex(
|
|
257
257
|
RuntimeError,
|
|
258
|
-
r"Couldn't find a function 'ddot' compatible with
|
|
258
|
+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, ndarray'$",
|
|
259
259
|
):
|
|
260
260
|
wp.ddot(a_values, np.array(npv(np_type, b_values)))
|
|
261
261
|
|
|
@@ -300,49 +300,49 @@ def test_mat_float_args_support(test, device, dtype):
|
|
|
300
300
|
else:
|
|
301
301
|
with test.assertRaisesRegex(
|
|
302
302
|
RuntimeError,
|
|
303
|
-
r"Couldn't find a function 'mul' compatible with
|
|
303
|
+
r"Couldn't find a function 'mul' compatible with the arguments 'mat_t, float'$",
|
|
304
304
|
):
|
|
305
305
|
wp.mul(mat_cls(*a_values), b_value)
|
|
306
306
|
|
|
307
307
|
with test.assertRaisesRegex(
|
|
308
308
|
RuntimeError,
|
|
309
|
-
r"Couldn't find a function 'mul' compatible with
|
|
309
|
+
r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
|
|
310
310
|
):
|
|
311
311
|
wp.mul(wpv(dtype, a_values), b_value)
|
|
312
312
|
|
|
313
313
|
with test.assertRaisesRegex(
|
|
314
314
|
RuntimeError,
|
|
315
|
-
r"Couldn't find a function 'mul' compatible with
|
|
315
|
+
r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
|
|
316
316
|
):
|
|
317
317
|
wp.mul(wpm(dtype, 3, a_values), b_value)
|
|
318
318
|
|
|
319
319
|
with test.assertRaisesRegex(
|
|
320
320
|
RuntimeError,
|
|
321
|
-
r"Couldn't find a function 'mul' compatible with
|
|
321
|
+
r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
|
|
322
322
|
):
|
|
323
323
|
wp.mul(npv(np_type, a_values), b_value)
|
|
324
324
|
|
|
325
325
|
with test.assertRaisesRegex(
|
|
326
326
|
RuntimeError,
|
|
327
|
-
r"Couldn't find a function 'mul' compatible with
|
|
327
|
+
r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
|
|
328
328
|
):
|
|
329
329
|
wp.mul(npm(np_type, 3, a_values), b_value)
|
|
330
330
|
|
|
331
331
|
with test.assertRaisesRegex(
|
|
332
332
|
RuntimeError,
|
|
333
|
-
r"Couldn't find a function 'mul' compatible with
|
|
333
|
+
r"Couldn't find a function 'mul' compatible with the arguments 'ndarray, float'$",
|
|
334
334
|
):
|
|
335
335
|
wp.mul(np.array(npv(np_type, a_values)), b_value)
|
|
336
336
|
|
|
337
337
|
with test.assertRaisesRegex(
|
|
338
338
|
RuntimeError,
|
|
339
|
-
rf"Couldn't find a function 'mul' compatible with
|
|
339
|
+
rf"Couldn't find a function 'mul' compatible with the arguments 'tuple, {dtype.__name__}'$",
|
|
340
340
|
):
|
|
341
341
|
wp.mul(a_values, dtype(b_value))
|
|
342
342
|
|
|
343
343
|
with test.assertRaisesRegex(
|
|
344
344
|
RuntimeError,
|
|
345
|
-
rf"Couldn't find a function 'mul' compatible with
|
|
345
|
+
rf"Couldn't find a function 'mul' compatible with the arguments 'tuple, {np_type.__name__}'$",
|
|
346
346
|
):
|
|
347
347
|
wp.mul(a_values, nps(np_type, b_value))
|
|
348
348
|
|
|
@@ -401,49 +401,49 @@ def test_vec_vec_args_support(test, device, dtype):
|
|
|
401
401
|
else:
|
|
402
402
|
with test.assertRaisesRegex(
|
|
403
403
|
RuntimeError,
|
|
404
|
-
r"Couldn't find a function 'dot' compatible with
|
|
404
|
+
r"Couldn't find a function 'dot' compatible with the arguments 'vec_t, tuple'$",
|
|
405
405
|
):
|
|
406
406
|
wp.dot(vec_cls(*a_values), b_values)
|
|
407
407
|
|
|
408
408
|
with test.assertRaisesRegex(
|
|
409
409
|
RuntimeError,
|
|
410
|
-
r"Couldn't find a function 'dot' compatible with
|
|
410
|
+
r"Couldn't find a function 'dot' compatible with the arguments 'tuple, tuple'$",
|
|
411
411
|
):
|
|
412
412
|
wp.dot(wpv(dtype, a_values), b_values)
|
|
413
413
|
|
|
414
414
|
with test.assertRaisesRegex(
|
|
415
415
|
RuntimeError,
|
|
416
|
-
r"Couldn't find a function 'dot' compatible with
|
|
416
|
+
r"Couldn't find a function 'dot' compatible with the arguments 'tuple, tuple'$",
|
|
417
417
|
):
|
|
418
418
|
wp.dot(npv(np_type, a_values), b_values)
|
|
419
419
|
|
|
420
420
|
with test.assertRaisesRegex(
|
|
421
421
|
RuntimeError,
|
|
422
|
-
r"Couldn't find a function 'dot' compatible with
|
|
422
|
+
r"Couldn't find a function 'dot' compatible with the arguments 'ndarray, tuple'$",
|
|
423
423
|
):
|
|
424
424
|
wp.dot(np.array(npv(np_type, a_values)), b_values)
|
|
425
425
|
|
|
426
426
|
with test.assertRaisesRegex(
|
|
427
427
|
RuntimeError,
|
|
428
|
-
r"Couldn't find a function 'dot' compatible with
|
|
428
|
+
r"Couldn't find a function 'dot' compatible with the arguments 'tuple, vec_t'$",
|
|
429
429
|
):
|
|
430
430
|
wp.dot(a_values, vec_cls(*b_values))
|
|
431
431
|
|
|
432
432
|
with test.assertRaisesRegex(
|
|
433
433
|
RuntimeError,
|
|
434
|
-
r"Couldn't find a function 'dot' compatible with
|
|
434
|
+
r"Couldn't find a function 'dot' compatible with the arguments 'tuple, tuple'$",
|
|
435
435
|
):
|
|
436
436
|
wp.dot(a_values, wpv(dtype, b_values))
|
|
437
437
|
|
|
438
438
|
with test.assertRaisesRegex(
|
|
439
439
|
RuntimeError,
|
|
440
|
-
r"Couldn't find a function 'dot' compatible with
|
|
440
|
+
r"Couldn't find a function 'dot' compatible with the arguments 'tuple, tuple'$",
|
|
441
441
|
):
|
|
442
442
|
wp.dot(a_values, npv(np_type, b_values))
|
|
443
443
|
|
|
444
444
|
with test.assertRaisesRegex(
|
|
445
445
|
RuntimeError,
|
|
446
|
-
r"Couldn't find a function 'dot' compatible with
|
|
446
|
+
r"Couldn't find a function 'dot' compatible with the arguments 'tuple, ndarray'$",
|
|
447
447
|
):
|
|
448
448
|
wp.dot(a_values, np.array(npv(np_type, b_values)))
|
|
449
449
|
|
|
@@ -480,37 +480,37 @@ def test_vec_float_args_support(test, device, dtype):
|
|
|
480
480
|
else:
|
|
481
481
|
with test.assertRaisesRegex(
|
|
482
482
|
RuntimeError,
|
|
483
|
-
r"Couldn't find a function 'mul' compatible with
|
|
483
|
+
r"Couldn't find a function 'mul' compatible with the arguments 'vec_t, float'$",
|
|
484
484
|
):
|
|
485
485
|
wp.mul(vec_cls(*a_values), b_value)
|
|
486
486
|
|
|
487
487
|
with test.assertRaisesRegex(
|
|
488
488
|
RuntimeError,
|
|
489
|
-
r"Couldn't find a function 'mul' compatible with
|
|
489
|
+
r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
|
|
490
490
|
):
|
|
491
491
|
wp.mul(wpv(dtype, a_values), b_value)
|
|
492
492
|
|
|
493
493
|
with test.assertRaisesRegex(
|
|
494
494
|
RuntimeError,
|
|
495
|
-
r"Couldn't find a function 'mul' compatible with
|
|
495
|
+
r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
|
|
496
496
|
):
|
|
497
497
|
wp.mul(npv(np_type, a_values), b_value)
|
|
498
498
|
|
|
499
499
|
with test.assertRaisesRegex(
|
|
500
500
|
RuntimeError,
|
|
501
|
-
r"Couldn't find a function 'mul' compatible with
|
|
501
|
+
r"Couldn't find a function 'mul' compatible with the arguments 'ndarray, float'$",
|
|
502
502
|
):
|
|
503
503
|
wp.mul(np.array(npv(np_type, a_values)), b_value)
|
|
504
504
|
|
|
505
505
|
with test.assertRaisesRegex(
|
|
506
506
|
RuntimeError,
|
|
507
|
-
rf"Couldn't find a function 'mul' compatible with
|
|
507
|
+
rf"Couldn't find a function 'mul' compatible with the arguments 'tuple, {dtype.__name__}'$",
|
|
508
508
|
):
|
|
509
509
|
wp.mul(a_values, dtype(b_value))
|
|
510
510
|
|
|
511
511
|
with test.assertRaisesRegex(
|
|
512
512
|
RuntimeError,
|
|
513
|
-
rf"Couldn't find a function 'mul' compatible with
|
|
513
|
+
rf"Couldn't find a function 'mul' compatible with the arguments 'tuple, {np_type.__name__}'$",
|
|
514
514
|
):
|
|
515
515
|
wp.mul(a_values, nps(np_type, b_value))
|
|
516
516
|
|
warp/tests/test_fem.py
CHANGED
|
@@ -818,15 +818,15 @@ def _rigid_deformation_field(s: Sample, domain: Domain, translation: wp.vec3, ro
|
|
|
818
818
|
def test_deformed_geometry(test, device):
|
|
819
819
|
N = 3
|
|
820
820
|
|
|
821
|
+
translation = [1.0, 2.0, 3.0]
|
|
822
|
+
rotation = [0.0, math.pi / 4.0, 0.0]
|
|
823
|
+
scale = 2.0
|
|
824
|
+
|
|
821
825
|
with wp.ScopedDevice(device):
|
|
822
826
|
positions, tet_vidx = _gen_tetmesh(N, N, N)
|
|
823
827
|
|
|
824
828
|
geo = fem.Tetmesh(tet_vertex_indices=tet_vidx, positions=positions)
|
|
825
829
|
|
|
826
|
-
translation = [1.0, 2.0, 3.0]
|
|
827
|
-
rotation = [0.0, math.pi / 4.0, 0.0]
|
|
828
|
-
scale = 2.0
|
|
829
|
-
|
|
830
830
|
vector_space = fem.make_polynomial_space(geo, dtype=wp.vec3, degree=2)
|
|
831
831
|
pos_field = vector_space.make_field()
|
|
832
832
|
fem.interpolate(
|
|
@@ -878,6 +878,15 @@ def test_deformed_geometry(test, device):
|
|
|
878
878
|
],
|
|
879
879
|
)
|
|
880
880
|
|
|
881
|
+
|
|
882
|
+
def test_deformed_geometry_codimensional(test, device):
|
|
883
|
+
N = 3
|
|
884
|
+
|
|
885
|
+
translation = [1.0, 2.0, 3.0]
|
|
886
|
+
rotation = [0.0, math.pi / 4.0, 0.0]
|
|
887
|
+
scale = 2.0
|
|
888
|
+
|
|
889
|
+
with wp.ScopedDevice(device):
|
|
881
890
|
# Test with Trimesh3d (different space and cell dimensions)
|
|
882
891
|
positions, tri_vidx = _gen_trimesh(N, N)
|
|
883
892
|
positions = positions.numpy()
|
|
@@ -897,7 +906,9 @@ def test_deformed_geometry(test, device):
|
|
|
897
906
|
deformed_geo = pos_field.make_deformed_geometry()
|
|
898
907
|
|
|
899
908
|
@wp.kernel
|
|
900
|
-
def
|
|
909
|
+
def _test_deformed_geometry_normal_codimensional(
|
|
910
|
+
geo_arg: geo.CellArg, def_arg: deformed_geo.CellArg, rotation: wp.vec3
|
|
911
|
+
):
|
|
901
912
|
i = wp.tid()
|
|
902
913
|
|
|
903
914
|
s = make_free_sample(i, Coords(0.5, 0.5, 0.0))
|
|
@@ -908,7 +919,7 @@ def test_deformed_geometry(test, device):
|
|
|
908
919
|
wp.expect_near(wp.quat_rotate(q, geo_n), def_n, 0.001)
|
|
909
920
|
|
|
910
921
|
wp.launch(
|
|
911
|
-
|
|
922
|
+
_test_deformed_geometry_normal_codimensional,
|
|
912
923
|
dim=geo.cell_count(),
|
|
913
924
|
inputs=[
|
|
914
925
|
geo.cell_arg_value(wp.get_device()),
|
|
@@ -2035,6 +2046,9 @@ add_function_test(TestFem, "test_hex_mesh", test_hex_mesh, devices=devices)
|
|
|
2035
2046
|
add_function_test(TestFem, "test_nanogrid", test_nanogrid, devices=cuda_devices)
|
|
2036
2047
|
add_function_test(TestFem, "test_adaptive_nanogrid", test_adaptive_nanogrid, devices=cuda_devices)
|
|
2037
2048
|
add_function_test(TestFem, "test_deformed_geometry", test_deformed_geometry, devices=devices)
|
|
2049
|
+
add_function_test(
|
|
2050
|
+
TestFem, "test_deformed_geometry_codimensional", test_deformed_geometry_codimensional, devices=devices
|
|
2051
|
+
)
|
|
2038
2052
|
add_function_test(TestFem, "test_vector_spaces", test_vector_spaces, devices=devices)
|
|
2039
2053
|
add_function_test(TestFem, "test_dof_mapper", test_dof_mapper)
|
|
2040
2054
|
add_function_test(TestFem, "test_point_basis", test_point_basis)
|
warp/tests/test_func.py
CHANGED
|
@@ -421,7 +421,7 @@ class TestFunc(unittest.TestCase):
|
|
|
421
421
|
b = wp.mat22d(1.0, 2.0, 3.0, 4.0)
|
|
422
422
|
with self.assertRaisesRegex(
|
|
423
423
|
RuntimeError,
|
|
424
|
-
r"^Couldn't find a function 'mul' compatible with
|
|
424
|
+
r"^Couldn't find a function 'mul' compatible with the arguments 'mat22f, mat22d'$",
|
|
425
425
|
):
|
|
426
426
|
a * b
|
|
427
427
|
|
warp/tests/test_mat.py
CHANGED
|
@@ -1078,15 +1078,21 @@ def test_svd_2D(test, device, dtype, register_kernels=False):
|
|
|
1078
1078
|
Vout: wp.array(dtype=mat22),
|
|
1079
1079
|
outcomponents: wp.array(dtype=wptype),
|
|
1080
1080
|
):
|
|
1081
|
+
tid = wp.tid()
|
|
1082
|
+
|
|
1081
1083
|
U = mat22()
|
|
1082
1084
|
sigma = vec2()
|
|
1083
1085
|
V = mat22()
|
|
1084
1086
|
|
|
1085
|
-
wp.svd2(m2[
|
|
1087
|
+
wp.svd2(m2[tid], U, sigma, V) # Assuming there's a 2D SVD kernel
|
|
1086
1088
|
|
|
1087
|
-
Uout[
|
|
1088
|
-
sigmaout[
|
|
1089
|
-
Vout[
|
|
1089
|
+
Uout[tid] = U
|
|
1090
|
+
sigmaout[tid] = sigma
|
|
1091
|
+
Vout[tid] = V
|
|
1092
|
+
|
|
1093
|
+
# backprop test only for first input
|
|
1094
|
+
if tid > 0:
|
|
1095
|
+
return
|
|
1090
1096
|
|
|
1091
1097
|
# multiply outputs by 2 so we've got something to backpropagate:
|
|
1092
1098
|
idx = 0
|
|
@@ -1111,22 +1117,46 @@ def test_svd_2D(test, device, dtype, register_kernels=False):
|
|
|
1111
1117
|
if register_kernels:
|
|
1112
1118
|
return
|
|
1113
1119
|
|
|
1114
|
-
|
|
1120
|
+
mats = np.concatenate(
|
|
1121
|
+
(
|
|
1122
|
+
randvals(rng, [24, 2, 2], dtype) + np.eye(2),
|
|
1123
|
+
# rng unlikely to hit edge cases, build them manually
|
|
1124
|
+
[
|
|
1125
|
+
np.zeros((2, 2)),
|
|
1126
|
+
np.eye(2),
|
|
1127
|
+
5.0 * np.eye(2),
|
|
1128
|
+
np.array([[1.0, 0.0], [0.0, 0.0]]),
|
|
1129
|
+
np.array([[0.0, 0.0], [0.0, 2.0]]),
|
|
1130
|
+
np.array([[1.0, 1.0], [-1.0, -1.0]]),
|
|
1131
|
+
np.array([[3.0, 0.0], [4.0, 5.0]]),
|
|
1132
|
+
np.eye(2) + tol * np.array([[1.0, 1.0], [-1.0, -1.0]]),
|
|
1133
|
+
],
|
|
1134
|
+
),
|
|
1135
|
+
axis=0,
|
|
1136
|
+
)
|
|
1137
|
+
M = len(mats)
|
|
1138
|
+
m2 = wp.array(mats, dtype=mat22, requires_grad=True, device=device)
|
|
1115
1139
|
|
|
1116
1140
|
outcomponents = wp.zeros(2 * 2 * 2 + 2, dtype=wptype, requires_grad=True, device=device)
|
|
1117
|
-
Uout = wp.zeros(
|
|
1118
|
-
sigmaout = wp.zeros(
|
|
1119
|
-
Vout = wp.zeros(
|
|
1141
|
+
Uout = wp.zeros(M, dtype=mat22, requires_grad=True, device=device)
|
|
1142
|
+
sigmaout = wp.zeros(M, dtype=vec2, requires_grad=True, device=device)
|
|
1143
|
+
Vout = wp.zeros(M, dtype=mat22, requires_grad=True, device=device)
|
|
1120
1144
|
|
|
1121
|
-
wp.launch(kernel, dim=
|
|
1145
|
+
wp.launch(kernel, dim=M, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
|
|
1122
1146
|
|
|
1123
|
-
Uout_np = Uout.numpy()
|
|
1124
|
-
sigmaout_np =
|
|
1125
|
-
Vout_np = Vout.numpy()
|
|
1147
|
+
Uout_np = Uout.numpy().astype(np.float64)
|
|
1148
|
+
sigmaout_np = sigmaout.numpy().astype(np.float64)
|
|
1149
|
+
Vout_np = Vout.numpy().astype(np.float64)
|
|
1150
|
+
|
|
1151
|
+
USVt_np = Uout_np @ (sigmaout_np[..., None] * np.transpose(Vout_np, axes=(0, 2, 1)))
|
|
1126
1152
|
|
|
1127
1153
|
assert_np_equal(
|
|
1128
|
-
np.
|
|
1154
|
+
Uout_np @ np.transpose(Uout_np, axes=(0, 2, 1)), np.broadcast_to(np.eye(2), shape=(M, 2, 2)), tol=30 * tol
|
|
1129
1155
|
)
|
|
1156
|
+
assert_np_equal(
|
|
1157
|
+
Vout_np @ np.transpose(Vout_np, axes=(0, 2, 1)), np.broadcast_to(np.eye(2), shape=(M, 2, 2)), tol=30 * tol
|
|
1158
|
+
)
|
|
1159
|
+
assert_np_equal(USVt_np, m2.numpy().astype(np.float64), tol=30 * tol)
|
|
1130
1160
|
|
|
1131
1161
|
if dtype == np.float16:
|
|
1132
1162
|
# Skip gradient check for float16 due to rounding errors
|
|
@@ -1145,7 +1175,7 @@ def test_svd_2D(test, device, dtype, register_kernels=False):
|
|
|
1145
1175
|
|
|
1146
1176
|
tape.zero()
|
|
1147
1177
|
|
|
1148
|
-
dx = 0.
|
|
1178
|
+
dx = 0.001
|
|
1149
1179
|
fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2
|
|
1150
1180
|
for ii in range(2):
|
|
1151
1181
|
for jj in range(2):
|
|
@@ -1180,9 +1210,9 @@ def test_qr(test, device, dtype, register_kernels=False):
|
|
|
1180
1210
|
rng = np.random.default_rng(123)
|
|
1181
1211
|
|
|
1182
1212
|
tol = {
|
|
1183
|
-
np.float16: 2.
|
|
1213
|
+
np.float16: 2.5e-3,
|
|
1184
1214
|
np.float32: 1.0e-6,
|
|
1185
|
-
np.float64: 1.0e-
|
|
1215
|
+
np.float64: 1.0e-12,
|
|
1186
1216
|
}.get(dtype, 0)
|
|
1187
1217
|
|
|
1188
1218
|
wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
|
warp/tests/test_struct.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import gc # Added for garbage collection tests
|
|
16
17
|
import unittest
|
|
17
18
|
from typing import Any
|
|
18
19
|
|
|
@@ -221,6 +222,11 @@ def test_nested_struct(test, device):
|
|
|
221
222
|
foo.bar.y = 1.23
|
|
222
223
|
foo.x = 123
|
|
223
224
|
|
|
225
|
+
# verify that struct attributes are instances of their original class
|
|
226
|
+
assert isinstance(foo, Foo.cls)
|
|
227
|
+
assert isinstance(foo.bar, Bar.cls)
|
|
228
|
+
assert isinstance(foo.bar.baz, Baz.cls)
|
|
229
|
+
|
|
224
230
|
wp.launch(kernel_nested_struct, dim=dim, inputs=[foo], device=device)
|
|
225
231
|
|
|
226
232
|
assert_array_equal(
|
|
@@ -243,6 +249,18 @@ def test_struct_attribute_error(test, device):
|
|
|
243
249
|
)
|
|
244
250
|
|
|
245
251
|
|
|
252
|
+
def test_struct_inheritance_error(test, device):
|
|
253
|
+
with test.assertRaisesRegex(RuntimeError, r"Warp structs must be defined as base classes$"):
|
|
254
|
+
|
|
255
|
+
@wp.struct
|
|
256
|
+
class Parent:
|
|
257
|
+
x: int
|
|
258
|
+
|
|
259
|
+
@wp.struct
|
|
260
|
+
class Child(Parent):
|
|
261
|
+
y: int
|
|
262
|
+
|
|
263
|
+
|
|
246
264
|
@wp.kernel
|
|
247
265
|
def test_struct_instantiate(data: wp.array(dtype=int)):
|
|
248
266
|
baz = Baz(data, wp.vec3(0.0, 0.0, 26.0))
|
|
@@ -643,6 +661,96 @@ def test_struct_array_hash(test, device):
|
|
|
643
661
|
)
|
|
644
662
|
|
|
645
663
|
|
|
664
|
+
# Tests for garbage collection behavior with arrays in structs
|
|
665
|
+
@wp.struct
|
|
666
|
+
class StructWithArray:
|
|
667
|
+
data: wp.array(dtype=float)
|
|
668
|
+
some_value: int
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
@wp.kernel
|
|
672
|
+
def access_array_kernel(s: StructWithArray, out: wp.array(dtype=float)):
|
|
673
|
+
# This kernel is used to verify data integrity by reading the first element.
|
|
674
|
+
# Assumes s.data has at least 1 element for this test.
|
|
675
|
+
out[0] = s.data[0]
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
@wp.kernel
|
|
679
|
+
def compute_loss_from_struct_array_kernel(s_in: StructWithArray, loss_val: wp.array(dtype=float)):
|
|
680
|
+
# Compute a simple scalar loss from the array elements for grad testing.
|
|
681
|
+
# Assumes s_in.data has at least 2 elements for this test.
|
|
682
|
+
res = 0.0
|
|
683
|
+
res += s_in.data[0] * 2.0 # Example weight
|
|
684
|
+
res += s_in.data[1] * 3.0 # Example weight
|
|
685
|
+
loss_val[0] = res
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
def test_struct_array_gc_direct_assignment(test, device):
|
|
689
|
+
"""
|
|
690
|
+
Tests that an array assigned to a struct (with no other direct Python
|
|
691
|
+
references) is not garbage collected prematurely.
|
|
692
|
+
"""
|
|
693
|
+
wp.init()
|
|
694
|
+
|
|
695
|
+
s = StructWithArray()
|
|
696
|
+
s.some_value = 20
|
|
697
|
+
|
|
698
|
+
# Create an array, then assign it to the struct.
|
|
699
|
+
# After this assignment, 's.data' is the primary way to access it from
|
|
700
|
+
# Python's perspective, though Warp's context should also hold a reference.
|
|
701
|
+
local_array = wp.array([4.0, 5.0, 6.0], dtype=float, device=device)
|
|
702
|
+
s.data = local_array
|
|
703
|
+
del local_array # Remove the direct Python reference
|
|
704
|
+
|
|
705
|
+
# Force garbage collection
|
|
706
|
+
gc.collect()
|
|
707
|
+
|
|
708
|
+
# Attempt to access the array in a kernel
|
|
709
|
+
out_wp = wp.zeros(1, dtype=float, device=device)
|
|
710
|
+
try:
|
|
711
|
+
wp.launch(kernel=access_array_kernel, dim=1, inputs=[s, out_wp], device=device)
|
|
712
|
+
|
|
713
|
+
# We expect to read 4.0 if the array is still valid
|
|
714
|
+
assert out_wp.numpy()[0] == 4.0, "Array data was not accessible or incorrect after GC with direct assignment."
|
|
715
|
+
except Exception as e:
|
|
716
|
+
test.fail(f"Kernel execution failed after GC with direct assignment: {e}")
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
def test_struct_array_gc_requires_grad_toggle(test, device):
|
|
720
|
+
"""
|
|
721
|
+
Tests that an array within a struct is not garbage collected prematurely
|
|
722
|
+
when its requires_grad flag is toggled, and that backward pass works.
|
|
723
|
+
"""
|
|
724
|
+
wp.init()
|
|
725
|
+
|
|
726
|
+
s = StructWithArray()
|
|
727
|
+
s.some_value = 10
|
|
728
|
+
# Initialize array with requires_grad=True. Content: [1.0, 2.0, 3.0]
|
|
729
|
+
s.data = wp.array([1.0, 2.0, 3.0], dtype=float, device=device, requires_grad=True)
|
|
730
|
+
|
|
731
|
+
loss_wp = wp.zeros(1, dtype=float, device=device, requires_grad=True)
|
|
732
|
+
|
|
733
|
+
tape = wp.Tape()
|
|
734
|
+
with tape:
|
|
735
|
+
# Launch kernel that uses s.data to compute a loss
|
|
736
|
+
wp.launch(
|
|
737
|
+
kernel=compute_loss_from_struct_array_kernel,
|
|
738
|
+
dim=1,
|
|
739
|
+
inputs=[s, loss_wp],
|
|
740
|
+
device=device,
|
|
741
|
+
)
|
|
742
|
+
|
|
743
|
+
# Expected loss = 1.0*2.0 + 2.0*3.0 = 2.0 + 6.0 = 8.0
|
|
744
|
+
|
|
745
|
+
# After the forward pass is recorded, toggle requires_grad and run GC
|
|
746
|
+
s.data.requires_grad = False
|
|
747
|
+
gc.collect()
|
|
748
|
+
|
|
749
|
+
# will cause a memory access violation if grad array has been garbage collected
|
|
750
|
+
# or struct is not updated correctly
|
|
751
|
+
tape.backward(loss=loss_wp)
|
|
752
|
+
|
|
753
|
+
|
|
646
754
|
devices = get_test_devices()
|
|
647
755
|
|
|
648
756
|
|
|
@@ -677,6 +785,8 @@ add_kernel_test(
|
|
|
677
785
|
)
|
|
678
786
|
add_kernel_test(TestStruct, kernel=test_return, name="test_return", dim=1, inputs=[], devices=devices)
|
|
679
787
|
add_function_test(TestStruct, "test_nested_struct", test_nested_struct, devices=devices)
|
|
788
|
+
add_function_test(TestStruct, "test_struct_attribute_error", test_struct_attribute_error, devices=devices)
|
|
789
|
+
add_function_test(TestStruct, "test_struct_inheritance_error", test_struct_inheritance_error, devices=devices)
|
|
680
790
|
add_function_test(TestStruct, "test_nested_array_struct", test_nested_array_struct, devices=devices)
|
|
681
791
|
add_function_test(TestStruct, "test_convert_to_device", test_convert_to_device, devices=devices)
|
|
682
792
|
add_function_test(TestStruct, "test_nested_empty_struct", test_nested_empty_struct, devices=devices)
|
|
@@ -727,6 +837,12 @@ add_kernel_test(
|
|
|
727
837
|
)
|
|
728
838
|
|
|
729
839
|
add_function_test(TestStruct, "test_struct_array_hash", test_struct_array_hash, devices=None)
|
|
840
|
+
add_function_test(
|
|
841
|
+
TestStruct, "test_struct_array_gc_requires_grad_toggle", test_struct_array_gc_requires_grad_toggle, devices=devices
|
|
842
|
+
)
|
|
843
|
+
add_function_test(
|
|
844
|
+
TestStruct, "test_struct_array_gc_direct_assignment", test_struct_array_gc_direct_assignment, devices=devices
|
|
845
|
+
)
|
|
730
846
|
|
|
731
847
|
|
|
732
848
|
if __name__ == "__main__":
|
warp/tests/tile/test_tile.py
CHANGED
|
@@ -531,6 +531,32 @@ def test_tile_extract_repeated(test, device):
|
|
|
531
531
|
assert_np_equal(a.grad.numpy(), expected_grad)
|
|
532
532
|
|
|
533
533
|
|
|
534
|
+
@wp.kernel
|
|
535
|
+
def test_tile_assign_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
|
|
536
|
+
i, j = wp.tid()
|
|
537
|
+
|
|
538
|
+
a = wp.tile_zeros(shape=(TILE_M,), dtype=float)
|
|
539
|
+
|
|
540
|
+
a[j] = x[j]
|
|
541
|
+
|
|
542
|
+
wp.tile_atomic_add(y, a, offset=(0,))
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
def test_tile_assign(test, device):
|
|
546
|
+
x = wp.full(TILE_M, 2.0, dtype=float, device=device, requires_grad=True)
|
|
547
|
+
y = wp.zeros(TILE_M, dtype=float, device=device, requires_grad=True)
|
|
548
|
+
|
|
549
|
+
tape = wp.Tape()
|
|
550
|
+
with tape:
|
|
551
|
+
wp.launch(test_tile_assign_kernel, dim=[1, TILE_M], inputs=[x], outputs=[y], block_dim=64, device=device)
|
|
552
|
+
|
|
553
|
+
y.grad = wp.ones_like(y)
|
|
554
|
+
tape.backward()
|
|
555
|
+
|
|
556
|
+
assert_np_equal(y.numpy(), np.full(TILE_M, 2.0, dtype=np.float32))
|
|
557
|
+
assert_np_equal(x.grad.numpy(), np.full(TILE_M, 1.0, dtype=np.float32))
|
|
558
|
+
|
|
559
|
+
|
|
534
560
|
@wp.kernel
|
|
535
561
|
def test_tile_transpose_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
|
|
536
562
|
x = wp.tile_load(input, shape=(TILE_M, TILE_N))
|
|
@@ -767,6 +793,7 @@ add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices, che
|
|
|
767
793
|
add_function_test(TestTile, "test_tile_sum_launch", test_tile_sum_launch, devices=devices)
|
|
768
794
|
add_function_test(TestTile, "test_tile_extract", test_tile_extract, devices=devices)
|
|
769
795
|
add_function_test(TestTile, "test_tile_extract_repeated", test_tile_extract_repeated, devices=devices)
|
|
796
|
+
add_function_test(TestTile, "test_tile_assign", test_tile_assign, devices=devices)
|
|
770
797
|
add_function_test(TestTile, "test_tile_broadcast_add_1d", test_tile_broadcast_add_1d, devices=devices)
|
|
771
798
|
add_function_test(TestTile, "test_tile_broadcast_add_2d", test_tile_broadcast_add_2d, devices=devices)
|
|
772
799
|
add_function_test(TestTile, "test_tile_broadcast_add_3d", test_tile_broadcast_add_3d, devices=devices)
|