warp-lang 1.7.1__py3-none-manylinux_2_34_aarch64.whl → 1.7.2__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

@@ -84,25 +84,25 @@ def test_int_int_args_support(test, device, dtype):
84
84
  else:
85
85
  with test.assertRaisesRegex(
86
86
  RuntimeError,
87
- rf"Couldn't find a function 'mul' compatible with " rf"the arguments '{dtype.__name__}, int'$",
87
+ rf"Couldn't find a function 'mul' compatible with the arguments '{dtype.__name__}, int'$",
88
88
  ):
89
89
  wp.mul(dtype(value), value)
90
90
 
91
91
  with test.assertRaisesRegex(
92
92
  RuntimeError,
93
- rf"Couldn't find a function 'mul' compatible with " rf"the arguments '{np_type.__name__}, int'$",
93
+ rf"Couldn't find a function 'mul' compatible with the arguments '{np_type.__name__}, int'$",
94
94
  ):
95
95
  wp.mul(nps(np_type, value), value)
96
96
 
97
97
  with test.assertRaisesRegex(
98
98
  RuntimeError,
99
- rf"Couldn't find a function 'mul' compatible with " rf"the arguments 'int, {dtype.__name__}'$",
99
+ rf"Couldn't find a function 'mul' compatible with the arguments 'int, {dtype.__name__}'$",
100
100
  ):
101
101
  wp.mul(value, dtype(value))
102
102
 
103
103
  with test.assertRaisesRegex(
104
104
  RuntimeError,
105
- rf"Couldn't find a function 'mul' compatible with " rf"the arguments 'int, {np_type.__name__}'$",
105
+ rf"Couldn't find a function 'mul' compatible with the arguments 'int, {np_type.__name__}'$",
106
106
  ):
107
107
  wp.mul(value, nps(np_type, value))
108
108
 
@@ -189,73 +189,73 @@ def test_mat_mat_args_support(test, device, dtype):
189
189
  else:
190
190
  with test.assertRaisesRegex(
191
191
  RuntimeError,
192
- r"Couldn't find a function 'ddot' compatible with " r"the arguments 'mat_t, tuple'$",
192
+ r"Couldn't find a function 'ddot' compatible with the arguments 'mat_t, tuple'$",
193
193
  ):
194
194
  wp.ddot(mat_cls(*a_values), b_values)
195
195
 
196
196
  with test.assertRaisesRegex(
197
197
  RuntimeError,
198
- r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
198
+ r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
199
199
  ):
200
200
  wp.ddot(wpv(dtype, a_values), b_values)
201
201
 
202
202
  with test.assertRaisesRegex(
203
203
  RuntimeError,
204
- r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
204
+ r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
205
205
  ):
206
206
  wp.ddot(wpm(dtype, 3, a_values), b_values)
207
207
 
208
208
  with test.assertRaisesRegex(
209
209
  RuntimeError,
210
- r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
210
+ r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
211
211
  ):
212
212
  wp.ddot(npv(np_type, a_values), b_values)
213
213
 
214
214
  with test.assertRaisesRegex(
215
215
  RuntimeError,
216
- r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
216
+ r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
217
217
  ):
218
218
  wp.ddot(npm(np_type, 3, a_values), b_values)
219
219
 
220
220
  with test.assertRaisesRegex(
221
221
  RuntimeError,
222
- r"Couldn't find a function 'ddot' compatible with " r"the arguments 'ndarray, tuple'$",
222
+ r"Couldn't find a function 'ddot' compatible with the arguments 'ndarray, tuple'$",
223
223
  ):
224
224
  wp.ddot(np.array(npv(np_type, a_values)), b_values)
225
225
 
226
226
  with test.assertRaisesRegex(
227
227
  RuntimeError,
228
- r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, mat_t'$",
228
+ r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, mat_t'$",
229
229
  ):
230
230
  wp.ddot(a_values, mat_cls(*b_values))
231
231
 
232
232
  with test.assertRaisesRegex(
233
233
  RuntimeError,
234
- r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
234
+ r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
235
235
  ):
236
236
  wp.ddot(a_values, wpv(dtype, b_values))
237
237
 
238
238
  with test.assertRaisesRegex(
239
239
  RuntimeError,
240
- r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
240
+ r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
241
241
  ):
242
242
  wp.ddot(a_values, wpm(dtype, 3, b_values))
243
243
 
244
244
  with test.assertRaisesRegex(
245
245
  RuntimeError,
246
- r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
246
+ r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
247
247
  ):
248
248
  wp.ddot(a_values, npv(np_type, b_values))
249
249
 
250
250
  with test.assertRaisesRegex(
251
251
  RuntimeError,
252
- r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
252
+ r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
253
253
  ):
254
254
  wp.ddot(a_values, npm(np_type, 3, b_values))
255
255
 
256
256
  with test.assertRaisesRegex(
257
257
  RuntimeError,
258
- r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, ndarray'$",
258
+ r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, ndarray'$",
259
259
  ):
260
260
  wp.ddot(a_values, np.array(npv(np_type, b_values)))
261
261
 
@@ -300,49 +300,49 @@ def test_mat_float_args_support(test, device, dtype):
300
300
  else:
301
301
  with test.assertRaisesRegex(
302
302
  RuntimeError,
303
- r"Couldn't find a function 'mul' compatible with " r"the arguments 'mat_t, float'$",
303
+ r"Couldn't find a function 'mul' compatible with the arguments 'mat_t, float'$",
304
304
  ):
305
305
  wp.mul(mat_cls(*a_values), b_value)
306
306
 
307
307
  with test.assertRaisesRegex(
308
308
  RuntimeError,
309
- r"Couldn't find a function 'mul' compatible with " r"the arguments 'tuple, float'$",
309
+ r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
310
310
  ):
311
311
  wp.mul(wpv(dtype, a_values), b_value)
312
312
 
313
313
  with test.assertRaisesRegex(
314
314
  RuntimeError,
315
- r"Couldn't find a function 'mul' compatible with " r"the arguments 'tuple, float'$",
315
+ r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
316
316
  ):
317
317
  wp.mul(wpm(dtype, 3, a_values), b_value)
318
318
 
319
319
  with test.assertRaisesRegex(
320
320
  RuntimeError,
321
- r"Couldn't find a function 'mul' compatible with " r"the arguments 'tuple, float'$",
321
+ r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
322
322
  ):
323
323
  wp.mul(npv(np_type, a_values), b_value)
324
324
 
325
325
  with test.assertRaisesRegex(
326
326
  RuntimeError,
327
- r"Couldn't find a function 'mul' compatible with " r"the arguments 'tuple, float'$",
327
+ r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
328
328
  ):
329
329
  wp.mul(npm(np_type, 3, a_values), b_value)
330
330
 
331
331
  with test.assertRaisesRegex(
332
332
  RuntimeError,
333
- r"Couldn't find a function 'mul' compatible with " r"the arguments 'ndarray, float'$",
333
+ r"Couldn't find a function 'mul' compatible with the arguments 'ndarray, float'$",
334
334
  ):
335
335
  wp.mul(np.array(npv(np_type, a_values)), b_value)
336
336
 
337
337
  with test.assertRaisesRegex(
338
338
  RuntimeError,
339
- rf"Couldn't find a function 'mul' compatible with " rf"the arguments 'tuple, {dtype.__name__}'$",
339
+ rf"Couldn't find a function 'mul' compatible with the arguments 'tuple, {dtype.__name__}'$",
340
340
  ):
341
341
  wp.mul(a_values, dtype(b_value))
342
342
 
343
343
  with test.assertRaisesRegex(
344
344
  RuntimeError,
345
- rf"Couldn't find a function 'mul' compatible with " rf"the arguments 'tuple, {np_type.__name__}'$",
345
+ rf"Couldn't find a function 'mul' compatible with the arguments 'tuple, {np_type.__name__}'$",
346
346
  ):
347
347
  wp.mul(a_values, nps(np_type, b_value))
348
348
 
@@ -401,49 +401,49 @@ def test_vec_vec_args_support(test, device, dtype):
401
401
  else:
402
402
  with test.assertRaisesRegex(
403
403
  RuntimeError,
404
- r"Couldn't find a function 'dot' compatible with " r"the arguments 'vec_t, tuple'$",
404
+ r"Couldn't find a function 'dot' compatible with the arguments 'vec_t, tuple'$",
405
405
  ):
406
406
  wp.dot(vec_cls(*a_values), b_values)
407
407
 
408
408
  with test.assertRaisesRegex(
409
409
  RuntimeError,
410
- r"Couldn't find a function 'dot' compatible with " r"the arguments 'tuple, tuple'$",
410
+ r"Couldn't find a function 'dot' compatible with the arguments 'tuple, tuple'$",
411
411
  ):
412
412
  wp.dot(wpv(dtype, a_values), b_values)
413
413
 
414
414
  with test.assertRaisesRegex(
415
415
  RuntimeError,
416
- r"Couldn't find a function 'dot' compatible with " r"the arguments 'tuple, tuple'$",
416
+ r"Couldn't find a function 'dot' compatible with the arguments 'tuple, tuple'$",
417
417
  ):
418
418
  wp.dot(npv(np_type, a_values), b_values)
419
419
 
420
420
  with test.assertRaisesRegex(
421
421
  RuntimeError,
422
- r"Couldn't find a function 'dot' compatible with " r"the arguments 'ndarray, tuple'$",
422
+ r"Couldn't find a function 'dot' compatible with the arguments 'ndarray, tuple'$",
423
423
  ):
424
424
  wp.dot(np.array(npv(np_type, a_values)), b_values)
425
425
 
426
426
  with test.assertRaisesRegex(
427
427
  RuntimeError,
428
- r"Couldn't find a function 'dot' compatible with " r"the arguments 'tuple, vec_t'$",
428
+ r"Couldn't find a function 'dot' compatible with the arguments 'tuple, vec_t'$",
429
429
  ):
430
430
  wp.dot(a_values, vec_cls(*b_values))
431
431
 
432
432
  with test.assertRaisesRegex(
433
433
  RuntimeError,
434
- r"Couldn't find a function 'dot' compatible with " r"the arguments 'tuple, tuple'$",
434
+ r"Couldn't find a function 'dot' compatible with the arguments 'tuple, tuple'$",
435
435
  ):
436
436
  wp.dot(a_values, wpv(dtype, b_values))
437
437
 
438
438
  with test.assertRaisesRegex(
439
439
  RuntimeError,
440
- r"Couldn't find a function 'dot' compatible with " r"the arguments 'tuple, tuple'$",
440
+ r"Couldn't find a function 'dot' compatible with the arguments 'tuple, tuple'$",
441
441
  ):
442
442
  wp.dot(a_values, npv(np_type, b_values))
443
443
 
444
444
  with test.assertRaisesRegex(
445
445
  RuntimeError,
446
- r"Couldn't find a function 'dot' compatible with " r"the arguments 'tuple, ndarray'$",
446
+ r"Couldn't find a function 'dot' compatible with the arguments 'tuple, ndarray'$",
447
447
  ):
448
448
  wp.dot(a_values, np.array(npv(np_type, b_values)))
449
449
 
@@ -480,37 +480,37 @@ def test_vec_float_args_support(test, device, dtype):
480
480
  else:
481
481
  with test.assertRaisesRegex(
482
482
  RuntimeError,
483
- r"Couldn't find a function 'mul' compatible with " r"the arguments 'vec_t, float'$",
483
+ r"Couldn't find a function 'mul' compatible with the arguments 'vec_t, float'$",
484
484
  ):
485
485
  wp.mul(vec_cls(*a_values), b_value)
486
486
 
487
487
  with test.assertRaisesRegex(
488
488
  RuntimeError,
489
- r"Couldn't find a function 'mul' compatible with " r"the arguments 'tuple, float'$",
489
+ r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
490
490
  ):
491
491
  wp.mul(wpv(dtype, a_values), b_value)
492
492
 
493
493
  with test.assertRaisesRegex(
494
494
  RuntimeError,
495
- r"Couldn't find a function 'mul' compatible with " r"the arguments 'tuple, float'$",
495
+ r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
496
496
  ):
497
497
  wp.mul(npv(np_type, a_values), b_value)
498
498
 
499
499
  with test.assertRaisesRegex(
500
500
  RuntimeError,
501
- r"Couldn't find a function 'mul' compatible with " r"the arguments 'ndarray, float'$",
501
+ r"Couldn't find a function 'mul' compatible with the arguments 'ndarray, float'$",
502
502
  ):
503
503
  wp.mul(np.array(npv(np_type, a_values)), b_value)
504
504
 
505
505
  with test.assertRaisesRegex(
506
506
  RuntimeError,
507
- rf"Couldn't find a function 'mul' compatible with " rf"the arguments 'tuple, {dtype.__name__}'$",
507
+ rf"Couldn't find a function 'mul' compatible with the arguments 'tuple, {dtype.__name__}'$",
508
508
  ):
509
509
  wp.mul(a_values, dtype(b_value))
510
510
 
511
511
  with test.assertRaisesRegex(
512
512
  RuntimeError,
513
- rf"Couldn't find a function 'mul' compatible with " rf"the arguments 'tuple, {np_type.__name__}'$",
513
+ rf"Couldn't find a function 'mul' compatible with the arguments 'tuple, {np_type.__name__}'$",
514
514
  ):
515
515
  wp.mul(a_values, nps(np_type, b_value))
516
516
 
warp/tests/test_fem.py CHANGED
@@ -818,15 +818,15 @@ def _rigid_deformation_field(s: Sample, domain: Domain, translation: wp.vec3, ro
818
818
  def test_deformed_geometry(test, device):
819
819
  N = 3
820
820
 
821
+ translation = [1.0, 2.0, 3.0]
822
+ rotation = [0.0, math.pi / 4.0, 0.0]
823
+ scale = 2.0
824
+
821
825
  with wp.ScopedDevice(device):
822
826
  positions, tet_vidx = _gen_tetmesh(N, N, N)
823
827
 
824
828
  geo = fem.Tetmesh(tet_vertex_indices=tet_vidx, positions=positions)
825
829
 
826
- translation = [1.0, 2.0, 3.0]
827
- rotation = [0.0, math.pi / 4.0, 0.0]
828
- scale = 2.0
829
-
830
830
  vector_space = fem.make_polynomial_space(geo, dtype=wp.vec3, degree=2)
831
831
  pos_field = vector_space.make_field()
832
832
  fem.interpolate(
@@ -878,6 +878,15 @@ def test_deformed_geometry(test, device):
878
878
  ],
879
879
  )
880
880
 
881
+
882
+ def test_deformed_geometry_codimensional(test, device):
883
+ N = 3
884
+
885
+ translation = [1.0, 2.0, 3.0]
886
+ rotation = [0.0, math.pi / 4.0, 0.0]
887
+ scale = 2.0
888
+
889
+ with wp.ScopedDevice(device):
881
890
  # Test with Trimesh3d (different space and cell dimensions)
882
891
  positions, tri_vidx = _gen_trimesh(N, N)
883
892
  positions = positions.numpy()
@@ -897,7 +906,9 @@ def test_deformed_geometry(test, device):
897
906
  deformed_geo = pos_field.make_deformed_geometry()
898
907
 
899
908
  @wp.kernel
900
- def _test_deformed_geometry_normal(geo_arg: geo.CellArg, def_arg: deformed_geo.CellArg, rotation: wp.vec3):
909
+ def _test_deformed_geometry_normal_codimensional(
910
+ geo_arg: geo.CellArg, def_arg: deformed_geo.CellArg, rotation: wp.vec3
911
+ ):
901
912
  i = wp.tid()
902
913
 
903
914
  s = make_free_sample(i, Coords(0.5, 0.5, 0.0))
@@ -908,7 +919,7 @@ def test_deformed_geometry(test, device):
908
919
  wp.expect_near(wp.quat_rotate(q, geo_n), def_n, 0.001)
909
920
 
910
921
  wp.launch(
911
- _test_deformed_geometry_normal,
922
+ _test_deformed_geometry_normal_codimensional,
912
923
  dim=geo.cell_count(),
913
924
  inputs=[
914
925
  geo.cell_arg_value(wp.get_device()),
@@ -2035,6 +2046,9 @@ add_function_test(TestFem, "test_hex_mesh", test_hex_mesh, devices=devices)
2035
2046
  add_function_test(TestFem, "test_nanogrid", test_nanogrid, devices=cuda_devices)
2036
2047
  add_function_test(TestFem, "test_adaptive_nanogrid", test_adaptive_nanogrid, devices=cuda_devices)
2037
2048
  add_function_test(TestFem, "test_deformed_geometry", test_deformed_geometry, devices=devices)
2049
+ add_function_test(
2050
+ TestFem, "test_deformed_geometry_codimensional", test_deformed_geometry_codimensional, devices=devices
2051
+ )
2038
2052
  add_function_test(TestFem, "test_vector_spaces", test_vector_spaces, devices=devices)
2039
2053
  add_function_test(TestFem, "test_dof_mapper", test_dof_mapper)
2040
2054
  add_function_test(TestFem, "test_point_basis", test_point_basis)
warp/tests/test_func.py CHANGED
@@ -421,7 +421,7 @@ class TestFunc(unittest.TestCase):
421
421
  b = wp.mat22d(1.0, 2.0, 3.0, 4.0)
422
422
  with self.assertRaisesRegex(
423
423
  RuntimeError,
424
- r"^Couldn't find a function 'mul' compatible with " r"the arguments 'mat22f, mat22d'$",
424
+ r"^Couldn't find a function 'mul' compatible with the arguments 'mat22f, mat22d'$",
425
425
  ):
426
426
  a * b
427
427
 
warp/tests/test_mat.py CHANGED
@@ -1078,15 +1078,21 @@ def test_svd_2D(test, device, dtype, register_kernels=False):
1078
1078
  Vout: wp.array(dtype=mat22),
1079
1079
  outcomponents: wp.array(dtype=wptype),
1080
1080
  ):
1081
+ tid = wp.tid()
1082
+
1081
1083
  U = mat22()
1082
1084
  sigma = vec2()
1083
1085
  V = mat22()
1084
1086
 
1085
- wp.svd2(m2[0], U, sigma, V) # Assuming there's a 2D SVD kernel
1087
+ wp.svd2(m2[tid], U, sigma, V) # Assuming there's a 2D SVD kernel
1086
1088
 
1087
- Uout[0] = U
1088
- sigmaout[0] = sigma
1089
- Vout[0] = V
1089
+ Uout[tid] = U
1090
+ sigmaout[tid] = sigma
1091
+ Vout[tid] = V
1092
+
1093
+ # backprop test only for first input
1094
+ if tid > 0:
1095
+ return
1090
1096
 
1091
1097
  # multiply outputs by 2 so we've got something to backpropagate:
1092
1098
  idx = 0
@@ -1111,22 +1117,46 @@ def test_svd_2D(test, device, dtype, register_kernels=False):
1111
1117
  if register_kernels:
1112
1118
  return
1113
1119
 
1114
- m2 = wp.array(randvals(rng, [1, 2, 2], dtype) + np.eye(2), dtype=mat22, requires_grad=True, device=device)
1120
+ mats = np.concatenate(
1121
+ (
1122
+ randvals(rng, [24, 2, 2], dtype) + np.eye(2),
1123
+ # rng unlikely to hit edge cases, build them manually
1124
+ [
1125
+ np.zeros((2, 2)),
1126
+ np.eye(2),
1127
+ 5.0 * np.eye(2),
1128
+ np.array([[1.0, 0.0], [0.0, 0.0]]),
1129
+ np.array([[0.0, 0.0], [0.0, 2.0]]),
1130
+ np.array([[1.0, 1.0], [-1.0, -1.0]]),
1131
+ np.array([[3.0, 0.0], [4.0, 5.0]]),
1132
+ np.eye(2) + tol * np.array([[1.0, 1.0], [-1.0, -1.0]]),
1133
+ ],
1134
+ ),
1135
+ axis=0,
1136
+ )
1137
+ M = len(mats)
1138
+ m2 = wp.array(mats, dtype=mat22, requires_grad=True, device=device)
1115
1139
 
1116
1140
  outcomponents = wp.zeros(2 * 2 * 2 + 2, dtype=wptype, requires_grad=True, device=device)
1117
- Uout = wp.zeros(1, dtype=mat22, requires_grad=True, device=device)
1118
- sigmaout = wp.zeros(1, dtype=vec2, requires_grad=True, device=device)
1119
- Vout = wp.zeros(1, dtype=mat22, requires_grad=True, device=device)
1141
+ Uout = wp.zeros(M, dtype=mat22, requires_grad=True, device=device)
1142
+ sigmaout = wp.zeros(M, dtype=vec2, requires_grad=True, device=device)
1143
+ Vout = wp.zeros(M, dtype=mat22, requires_grad=True, device=device)
1120
1144
 
1121
- wp.launch(kernel, dim=1, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
1145
+ wp.launch(kernel, dim=M, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
1122
1146
 
1123
- Uout_np = Uout.numpy()[0].astype(np.float64)
1124
- sigmaout_np = np.diag(sigmaout.numpy()[0].astype(np.float64))
1125
- Vout_np = Vout.numpy()[0].astype(np.float64)
1147
+ Uout_np = Uout.numpy().astype(np.float64)
1148
+ sigmaout_np = sigmaout.numpy().astype(np.float64)
1149
+ Vout_np = Vout.numpy().astype(np.float64)
1150
+
1151
+ USVt_np = Uout_np @ (sigmaout_np[..., None] * np.transpose(Vout_np, axes=(0, 2, 1)))
1126
1152
 
1127
1153
  assert_np_equal(
1128
- np.matmul(Uout_np, np.matmul(sigmaout_np, Vout_np.T)), m2.numpy()[0].astype(np.float64), tol=30 * tol
1154
+ Uout_np @ np.transpose(Uout_np, axes=(0, 2, 1)), np.broadcast_to(np.eye(2), shape=(M, 2, 2)), tol=30 * tol
1129
1155
  )
1156
+ assert_np_equal(
1157
+ Vout_np @ np.transpose(Vout_np, axes=(0, 2, 1)), np.broadcast_to(np.eye(2), shape=(M, 2, 2)), tol=30 * tol
1158
+ )
1159
+ assert_np_equal(USVt_np, m2.numpy().astype(np.float64), tol=30 * tol)
1130
1160
 
1131
1161
  if dtype == np.float16:
1132
1162
  # Skip gradient check for float16 due to rounding errors
@@ -1145,7 +1175,7 @@ def test_svd_2D(test, device, dtype, register_kernels=False):
1145
1175
 
1146
1176
  tape.zero()
1147
1177
 
1148
- dx = 0.0001
1178
+ dx = 0.001
1149
1179
  fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2
1150
1180
  for ii in range(2):
1151
1181
  for jj in range(2):
@@ -1180,9 +1210,9 @@ def test_qr(test, device, dtype, register_kernels=False):
1180
1210
  rng = np.random.default_rng(123)
1181
1211
 
1182
1212
  tol = {
1183
- np.float16: 2.0e-3,
1213
+ np.float16: 2.5e-3,
1184
1214
  np.float32: 1.0e-6,
1185
- np.float64: 1.0e-6,
1215
+ np.float64: 1.0e-12,
1186
1216
  }.get(dtype, 0)
1187
1217
 
1188
1218
  wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
warp/tests/test_struct.py CHANGED
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import gc # Added for garbage collection tests
16
17
  import unittest
17
18
  from typing import Any
18
19
 
@@ -221,6 +222,11 @@ def test_nested_struct(test, device):
221
222
  foo.bar.y = 1.23
222
223
  foo.x = 123
223
224
 
225
+ # verify that struct attributes are instances of their original class
226
+ assert isinstance(foo, Foo.cls)
227
+ assert isinstance(foo.bar, Bar.cls)
228
+ assert isinstance(foo.bar.baz, Baz.cls)
229
+
224
230
  wp.launch(kernel_nested_struct, dim=dim, inputs=[foo], device=device)
225
231
 
226
232
  assert_array_equal(
@@ -243,6 +249,18 @@ def test_struct_attribute_error(test, device):
243
249
  )
244
250
 
245
251
 
252
+ def test_struct_inheritance_error(test, device):
253
+ with test.assertRaisesRegex(RuntimeError, r"Warp structs must be defined as base classes$"):
254
+
255
+ @wp.struct
256
+ class Parent:
257
+ x: int
258
+
259
+ @wp.struct
260
+ class Child(Parent):
261
+ y: int
262
+
263
+
246
264
  @wp.kernel
247
265
  def test_struct_instantiate(data: wp.array(dtype=int)):
248
266
  baz = Baz(data, wp.vec3(0.0, 0.0, 26.0))
@@ -643,6 +661,96 @@ def test_struct_array_hash(test, device):
643
661
  )
644
662
 
645
663
 
664
+ # Tests for garbage collection behavior with arrays in structs
665
+ @wp.struct
666
+ class StructWithArray:
667
+ data: wp.array(dtype=float)
668
+ some_value: int
669
+
670
+
671
+ @wp.kernel
672
+ def access_array_kernel(s: StructWithArray, out: wp.array(dtype=float)):
673
+ # This kernel is used to verify data integrity by reading the first element.
674
+ # Assumes s.data has at least 1 element for this test.
675
+ out[0] = s.data[0]
676
+
677
+
678
+ @wp.kernel
679
+ def compute_loss_from_struct_array_kernel(s_in: StructWithArray, loss_val: wp.array(dtype=float)):
680
+ # Compute a simple scalar loss from the array elements for grad testing.
681
+ # Assumes s_in.data has at least 2 elements for this test.
682
+ res = 0.0
683
+ res += s_in.data[0] * 2.0 # Example weight
684
+ res += s_in.data[1] * 3.0 # Example weight
685
+ loss_val[0] = res
686
+
687
+
688
+ def test_struct_array_gc_direct_assignment(test, device):
689
+ """
690
+ Tests that an array assigned to a struct (with no other direct Python
691
+ references) is not garbage collected prematurely.
692
+ """
693
+ wp.init()
694
+
695
+ s = StructWithArray()
696
+ s.some_value = 20
697
+
698
+ # Create an array, then assign it to the struct.
699
+ # After this assignment, 's.data' is the primary way to access it from
700
+ # Python's perspective, though Warp's context should also hold a reference.
701
+ local_array = wp.array([4.0, 5.0, 6.0], dtype=float, device=device)
702
+ s.data = local_array
703
+ del local_array # Remove the direct Python reference
704
+
705
+ # Force garbage collection
706
+ gc.collect()
707
+
708
+ # Attempt to access the array in a kernel
709
+ out_wp = wp.zeros(1, dtype=float, device=device)
710
+ try:
711
+ wp.launch(kernel=access_array_kernel, dim=1, inputs=[s, out_wp], device=device)
712
+
713
+ # We expect to read 4.0 if the array is still valid
714
+ assert out_wp.numpy()[0] == 4.0, "Array data was not accessible or incorrect after GC with direct assignment."
715
+ except Exception as e:
716
+ test.fail(f"Kernel execution failed after GC with direct assignment: {e}")
717
+
718
+
719
+ def test_struct_array_gc_requires_grad_toggle(test, device):
720
+ """
721
+ Tests that an array within a struct is not garbage collected prematurely
722
+ when its requires_grad flag is toggled, and that backward pass works.
723
+ """
724
+ wp.init()
725
+
726
+ s = StructWithArray()
727
+ s.some_value = 10
728
+ # Initialize array with requires_grad=True. Content: [1.0, 2.0, 3.0]
729
+ s.data = wp.array([1.0, 2.0, 3.0], dtype=float, device=device, requires_grad=True)
730
+
731
+ loss_wp = wp.zeros(1, dtype=float, device=device, requires_grad=True)
732
+
733
+ tape = wp.Tape()
734
+ with tape:
735
+ # Launch kernel that uses s.data to compute a loss
736
+ wp.launch(
737
+ kernel=compute_loss_from_struct_array_kernel,
738
+ dim=1,
739
+ inputs=[s, loss_wp],
740
+ device=device,
741
+ )
742
+
743
+ # Expected loss = 1.0*2.0 + 2.0*3.0 = 2.0 + 6.0 = 8.0
744
+
745
+ # After the forward pass is recorded, toggle requires_grad and run GC
746
+ s.data.requires_grad = False
747
+ gc.collect()
748
+
749
+ # will cause a memory access violation if grad array has been garbage collected
750
+ # or struct is not updated correctly
751
+ tape.backward(loss=loss_wp)
752
+
753
+
646
754
  devices = get_test_devices()
647
755
 
648
756
 
@@ -677,6 +785,8 @@ add_kernel_test(
677
785
  )
678
786
  add_kernel_test(TestStruct, kernel=test_return, name="test_return", dim=1, inputs=[], devices=devices)
679
787
  add_function_test(TestStruct, "test_nested_struct", test_nested_struct, devices=devices)
788
+ add_function_test(TestStruct, "test_struct_attribute_error", test_struct_attribute_error, devices=devices)
789
+ add_function_test(TestStruct, "test_struct_inheritance_error", test_struct_inheritance_error, devices=devices)
680
790
  add_function_test(TestStruct, "test_nested_array_struct", test_nested_array_struct, devices=devices)
681
791
  add_function_test(TestStruct, "test_convert_to_device", test_convert_to_device, devices=devices)
682
792
  add_function_test(TestStruct, "test_nested_empty_struct", test_nested_empty_struct, devices=devices)
@@ -727,6 +837,12 @@ add_kernel_test(
727
837
  )
728
838
 
729
839
  add_function_test(TestStruct, "test_struct_array_hash", test_struct_array_hash, devices=None)
840
+ add_function_test(
841
+ TestStruct, "test_struct_array_gc_requires_grad_toggle", test_struct_array_gc_requires_grad_toggle, devices=devices
842
+ )
843
+ add_function_test(
844
+ TestStruct, "test_struct_array_gc_direct_assignment", test_struct_array_gc_direct_assignment, devices=devices
845
+ )
730
846
 
731
847
 
732
848
  if __name__ == "__main__":
@@ -531,6 +531,32 @@ def test_tile_extract_repeated(test, device):
531
531
  assert_np_equal(a.grad.numpy(), expected_grad)
532
532
 
533
533
 
534
+ @wp.kernel
535
+ def test_tile_assign_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
536
+ i, j = wp.tid()
537
+
538
+ a = wp.tile_zeros(shape=(TILE_M,), dtype=float)
539
+
540
+ a[j] = x[j]
541
+
542
+ wp.tile_atomic_add(y, a, offset=(0,))
543
+
544
+
545
+ def test_tile_assign(test, device):
546
+ x = wp.full(TILE_M, 2.0, dtype=float, device=device, requires_grad=True)
547
+ y = wp.zeros(TILE_M, dtype=float, device=device, requires_grad=True)
548
+
549
+ tape = wp.Tape()
550
+ with tape:
551
+ wp.launch(test_tile_assign_kernel, dim=[1, TILE_M], inputs=[x], outputs=[y], block_dim=64, device=device)
552
+
553
+ y.grad = wp.ones_like(y)
554
+ tape.backward()
555
+
556
+ assert_np_equal(y.numpy(), np.full(TILE_M, 2.0, dtype=np.float32))
557
+ assert_np_equal(x.grad.numpy(), np.full(TILE_M, 1.0, dtype=np.float32))
558
+
559
+
534
560
  @wp.kernel
535
561
  def test_tile_transpose_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
536
562
  x = wp.tile_load(input, shape=(TILE_M, TILE_N))
@@ -767,6 +793,7 @@ add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices, che
767
793
  add_function_test(TestTile, "test_tile_sum_launch", test_tile_sum_launch, devices=devices)
768
794
  add_function_test(TestTile, "test_tile_extract", test_tile_extract, devices=devices)
769
795
  add_function_test(TestTile, "test_tile_extract_repeated", test_tile_extract_repeated, devices=devices)
796
+ add_function_test(TestTile, "test_tile_assign", test_tile_assign, devices=devices)
770
797
  add_function_test(TestTile, "test_tile_broadcast_add_1d", test_tile_broadcast_add_1d, devices=devices)
771
798
  add_function_test(TestTile, "test_tile_broadcast_add_2d", test_tile_broadcast_add_2d, devices=devices)
772
799
  add_function_test(TestTile, "test_tile_broadcast_add_3d", test_tile_broadcast_add_3d, devices=devices)