warp_lang-1.3.2-py3-none-manylinux2014_x86_64.whl → warp_lang-1.4.0-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (107)
  1. warp/__init__.py +6 -0
  2. warp/autograd.py +59 -6
  3. warp/bin/warp.so +0 -0
  4. warp/build_dll.py +8 -10
  5. warp/builtins.py +126 -4
  6. warp/codegen.py +435 -53
  7. warp/config.py +1 -1
  8. warp/context.py +678 -403
  9. warp/dlpack.py +2 -0
  10. warp/examples/benchmarks/benchmark_cloth.py +10 -0
  11. warp/examples/core/example_render_opengl.py +12 -10
  12. warp/examples/fem/example_adaptive_grid.py +251 -0
  13. warp/examples/fem/example_apic_fluid.py +1 -1
  14. warp/examples/fem/example_diffusion_3d.py +2 -2
  15. warp/examples/fem/example_magnetostatics.py +1 -1
  16. warp/examples/fem/example_streamlines.py +1 -0
  17. warp/examples/fem/utils.py +23 -4
  18. warp/examples/sim/example_cloth.py +50 -6
  19. warp/fem/__init__.py +2 -0
  20. warp/fem/adaptivity.py +493 -0
  21. warp/fem/field/field.py +2 -1
  22. warp/fem/field/nodal_field.py +18 -26
  23. warp/fem/field/test.py +4 -4
  24. warp/fem/field/trial.py +4 -4
  25. warp/fem/geometry/__init__.py +1 -0
  26. warp/fem/geometry/adaptive_nanogrid.py +843 -0
  27. warp/fem/geometry/nanogrid.py +55 -28
  28. warp/fem/space/__init__.py +1 -1
  29. warp/fem/space/nanogrid_function_space.py +69 -35
  30. warp/fem/utils.py +113 -107
  31. warp/jax_experimental.py +28 -15
  32. warp/native/array.h +0 -1
  33. warp/native/builtin.h +103 -6
  34. warp/native/bvh.cu +2 -0
  35. warp/native/cuda_util.cpp +14 -0
  36. warp/native/cuda_util.h +2 -0
  37. warp/native/error.cpp +4 -2
  38. warp/native/exports.h +99 -17
  39. warp/native/mat.h +97 -0
  40. warp/native/mesh.cpp +36 -0
  41. warp/native/mesh.cu +51 -0
  42. warp/native/mesh.h +1 -0
  43. warp/native/quat.h +43 -0
  44. warp/native/spatial.h +6 -0
  45. warp/native/vec.h +74 -0
  46. warp/native/warp.cpp +2 -1
  47. warp/native/warp.cu +10 -3
  48. warp/native/warp.h +8 -1
  49. warp/paddle.py +382 -0
  50. warp/sim/__init__.py +1 -0
  51. warp/sim/collide.py +519 -0
  52. warp/sim/integrator_euler.py +18 -5
  53. warp/sim/integrator_featherstone.py +5 -5
  54. warp/sim/integrator_vbd.py +1026 -0
  55. warp/sim/model.py +49 -23
  56. warp/stubs.py +459 -0
  57. warp/tape.py +2 -0
  58. warp/tests/aux_test_dependent.py +1 -0
  59. warp/tests/aux_test_name_clash1.py +32 -0
  60. warp/tests/aux_test_name_clash2.py +32 -0
  61. warp/tests/aux_test_square.py +1 -0
  62. warp/tests/test_array.py +222 -0
  63. warp/tests/test_async.py +3 -3
  64. warp/tests/test_atomic.py +6 -0
  65. warp/tests/test_closest_point_edge_edge.py +93 -1
  66. warp/tests/test_codegen.py +62 -15
  67. warp/tests/test_codegen_instancing.py +1457 -0
  68. warp/tests/test_collision.py +486 -0
  69. warp/tests/test_compile_consts.py +3 -28
  70. warp/tests/test_dlpack.py +170 -0
  71. warp/tests/test_examples.py +22 -8
  72. warp/tests/test_fast_math.py +10 -4
  73. warp/tests/test_fem.py +64 -0
  74. warp/tests/test_func.py +46 -0
  75. warp/tests/test_implicit_init.py +49 -0
  76. warp/tests/test_jax.py +58 -0
  77. warp/tests/test_mat.py +84 -0
  78. warp/tests/test_mesh_query_point.py +188 -0
  79. warp/tests/test_module_hashing.py +40 -0
  80. warp/tests/test_multigpu.py +3 -3
  81. warp/tests/test_overwrite.py +8 -0
  82. warp/tests/test_paddle.py +852 -0
  83. warp/tests/test_print.py +89 -0
  84. warp/tests/test_quat.py +111 -0
  85. warp/tests/test_reload.py +31 -1
  86. warp/tests/test_scalar_ops.py +2 -0
  87. warp/tests/test_static.py +412 -0
  88. warp/tests/test_streams.py +64 -3
  89. warp/tests/test_struct.py +4 -4
  90. warp/tests/test_torch.py +24 -0
  91. warp/tests/test_triangle_closest_point.py +137 -0
  92. warp/tests/test_types.py +1 -1
  93. warp/tests/test_vbd.py +386 -0
  94. warp/tests/test_vec.py +143 -0
  95. warp/tests/test_vec_scalar_ops.py +139 -0
  96. warp/tests/test_volume.py +30 -0
  97. warp/tests/unittest_suites.py +12 -0
  98. warp/tests/unittest_utils.py +9 -5
  99. warp/thirdparty/dlpack.py +3 -1
  100. warp/types.py +157 -34
  101. warp/utils.py +37 -14
  102. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/METADATA +10 -8
  103. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/RECORD +106 -94
  104. warp/tests/test_point_triangle_closest_point.py +0 -143
  105. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/LICENSE.md +0 -0
  106. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/WHEEL +0 -0
  107. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/top_level.txt +0 -0
warp/tests/test_dlpack.py CHANGED
@@ -188,6 +188,57 @@ def test_dlpack_dtypes_and_shapes(test, device):
         wrap_scalar_to_matrix_tensor(mat_type)
 
 
+def test_dlpack_stream_arg(test, device):
+    # test valid range for the stream argument to array.__dlpack__()
+
+    data = np.arange(10)
+
+    def check_result(capsule):
+        result = wp.dlpack._from_dlpack(capsule)
+        assert_np_equal(result.numpy(), data)
+
+    with wp.ScopedDevice(device):
+        a = wp.array(data=data)
+
+        # stream arguments supported for all devices
+        check_result(a.__dlpack__())
+        check_result(a.__dlpack__(stream=None))
+        check_result(a.__dlpack__(stream=-1))
+
+        # device-specific stream arguments
+        if device.is_cuda:
+            check_result(a.__dlpack__(stream=0))  # default stream
+            check_result(a.__dlpack__(stream=1))  # legacy default stream
+            check_result(a.__dlpack__(stream=2))  # per thread default stream
+
+            # custom stream
+            stream = wp.Stream(device)
+            check_result(a.__dlpack__(stream=stream.cuda_stream))
+
+            # unsupported stream arguments
+            expected_error = r"DLPack stream must None or an integer >= -1"
+            with test.assertRaisesRegex(TypeError, expected_error):
+                check_result(a.__dlpack__(stream=-2))
+            with test.assertRaisesRegex(TypeError, expected_error):
+                check_result(a.__dlpack__(stream="nope"))
+        else:
+            expected_error = r"DLPack stream must be None or -1 for CPU device"
+
+            with test.assertRaisesRegex(TypeError, expected_error):
+                check_result(a.__dlpack__(stream=0))
+            with test.assertRaisesRegex(TypeError, expected_error):
+                check_result(a.__dlpack__(stream=1))
+            with test.assertRaisesRegex(TypeError, expected_error):
+                check_result(a.__dlpack__(stream=2))
+            with test.assertRaisesRegex(TypeError, expected_error):
+                check_result(a.__dlpack__(stream=1742))
+
+            with test.assertRaisesRegex(TypeError, expected_error):
+                check_result(a.__dlpack__(stream=-2))
+            with test.assertRaisesRegex(TypeError, expected_error):
+                check_result(a.__dlpack__(stream="nope"))
+
+
 def test_dlpack_warp_to_torch(test, device):
     import torch.utils.dlpack
 
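The stream argument added to array.__dlpack__() follows the DLPack protocol: the consumer passes the stream it will read the data on so the producer can synchronize before handing over the capsule. None and -1 are accepted on every device; CUDA devices additionally accept 0 (default stream), 1 (legacy default stream), 2 (per-thread default stream), or a raw stream handle. A minimal consumer-side sketch, assuming a CUDA device is available:

    import numpy as np
    import warp as wp

    wp.init()
    a = wp.array(np.arange(10, dtype=np.float32), device="cuda:0")

    # synchronize against the legacy default stream before sharing
    capsule = a.__dlpack__(stream=1)

    # a consumer-owned stream is passed as its raw handle
    s = wp.Stream(wp.get_device("cuda:0"))
    capsule_on_s = a.__dlpack__(stream=s.cuda_stream)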
@@ -299,6 +350,34 @@ def test_dlpack_torch_to_warp_v2(test, device):
     assert_np_equal(a.numpy(), t.cpu().numpy())
 
 
+def test_dlpack_paddle_to_warp(test, device):
+    import paddle
+    import paddle.utils.dlpack
+
+    t = paddle.arange(N, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
+
+    # paddle does not implement __dlpack__ yet, so only test to_dlpack here
+    a = wp.from_dlpack(paddle.utils.dlpack.to_dlpack(t))
+
+    item_size = wp.types.type_size_in_bytes(a.dtype)
+
+    test.assertEqual(a.ptr, t.data_ptr())
+    test.assertEqual(a.device, wp.device_from_paddle(t.place))
+    test.assertEqual(a.dtype, wp.dtype_from_paddle(t.dtype))
+    test.assertEqual(a.shape, tuple(t.shape))
+    test.assertEqual(a.strides, tuple(s * item_size for s in t.strides))
+
+    assert_np_equal(a.numpy(), t.numpy())
+
+    wp.launch(inc, dim=a.size, inputs=[a], device=device)
+
+    assert_np_equal(a.numpy(), t.numpy())
+
+    paddle.assign(t + 1, t)
+
+    assert_np_equal(a.numpy(), t.numpy())
+
+
 def test_dlpack_warp_to_jax(test, device):
     import jax
     import jax.dlpack
@@ -370,6 +449,61 @@ def test_dlpack_warp_to_jax_v2(test, device):
     assert_np_equal(a.numpy(), np.asarray(j2))
 
 
+def test_dlpack_warp_to_paddle(test, device):
+    import paddle.utils.dlpack
+
+    a = wp.array(data=np.arange(N, dtype=np.float32), device=device)
+
+    t = paddle.utils.dlpack.from_dlpack(wp.to_dlpack(a))
+
+    item_size = wp.types.type_size_in_bytes(a.dtype)
+
+    test.assertEqual(a.ptr, t.data_ptr())
+    test.assertEqual(a.device, wp.device_from_paddle(t.place))
+    test.assertEqual(a.dtype, wp.dtype_from_paddle(t.dtype))
+    test.assertEqual(a.shape, tuple(t.shape))
+    test.assertEqual(a.strides, tuple(s * item_size for s in t.strides))
+
+    assert_np_equal(a.numpy(), t.cpu().numpy())
+
+    wp.launch(inc, dim=a.size, inputs=[a], device=device)
+
+    assert_np_equal(a.numpy(), t.cpu().numpy())
+
+    paddle.assign(t + 1, t)
+
+    assert_np_equal(a.numpy(), t.cpu().numpy())
+
+
+def test_dlpack_warp_to_paddle_v2(test, device):
+    # same as the original test, but uses the newer __dlpack__() method
+
+    import paddle.utils.dlpack
+
+    a = wp.array(data=np.arange(N, dtype=np.float32), device=device)
+
+    # pass the array directly
+    t = paddle.utils.dlpack.from_dlpack(a)
+
+    item_size = wp.types.type_size_in_bytes(a.dtype)
+
+    test.assertEqual(a.ptr, t.data_ptr())
+    test.assertEqual(a.device, wp.device_from_paddle(t.place))
+    test.assertEqual(a.dtype, wp.dtype_from_paddle(t.dtype))
+    test.assertEqual(a.shape, tuple(t.shape))
+    test.assertEqual(a.strides, tuple(s * item_size for s in t.strides))
+
+    assert_np_equal(a.numpy(), t.numpy())
+
+    wp.launch(inc, dim=a.size, inputs=[a], device=device)
+
+    assert_np_equal(a.numpy(), t.numpy())
+
+    paddle.assign(t + 1, t)
+
+    assert_np_equal(a.numpy(), t.numpy())
+
+
 def test_dlpack_jax_to_warp(test, device):
     import jax
     import jax.dlpack
@@ -448,6 +582,7 @@ devices = get_test_devices()
 
 add_function_test(TestDLPack, "test_dlpack_warp_to_warp", test_dlpack_warp_to_warp, devices=devices)
 add_function_test(TestDLPack, "test_dlpack_dtypes_and_shapes", test_dlpack_dtypes_and_shapes, devices=devices)
+add_function_test(TestDLPack, "test_dlpack_stream_arg", test_dlpack_stream_arg, devices=devices)
 
 # torch interop via dlpack
 try:
@@ -523,6 +658,41 @@ except Exception as e:
     print(f"Skipping Jax DLPack tests due to exception: {e}")
 
 
+# paddle interop via dlpack
+try:
+    import paddle
+    import paddle.utils.dlpack
+
+    # check which Warp devices work with paddle
+    # CUDA devices may fail if paddle was not compiled with CUDA support
+    test_devices = get_test_devices()
+    paddle_compatible_devices = []
+    for d in test_devices:
+        try:
+            t = paddle.arange(10).to(device=wp.device_to_paddle(d))
+            paddle.assign(t + 1, t)
+            paddle_compatible_devices.append(d)
+        except Exception as e:
+            print(f"Skipping paddle DLPack tests on device '{d}' due to exception: {e}")
+
+    if paddle_compatible_devices:
+        add_function_test(
+            TestDLPack, "test_dlpack_warp_to_paddle", test_dlpack_warp_to_paddle, devices=paddle_compatible_devices
+        )
+        add_function_test(
+            TestDLPack,
+            "test_dlpack_warp_to_paddle_v2",
+            test_dlpack_warp_to_paddle_v2,
+            devices=paddle_compatible_devices,
+        )
+        add_function_test(
+            TestDLPack, "test_dlpack_paddle_to_warp", test_dlpack_paddle_to_warp, devices=paddle_compatible_devices
+        )
+
+except Exception as e:
+    print(f"Skipping Paddle DLPack tests due to exception: {e}")
+
+
 if __name__ == "__main__":
     wp.clear_kernel_cache()
     unittest.main(verbosity=2)
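The three Paddle tests above exercise the interop helpers introduced with warp/paddle.py in this release (wp.device_to_paddle(), wp.device_from_paddle(), wp.dtype_from_paddle()). A condensed round-trip sketch distilled from those tests, assuming a Paddle build that matches the target device:

    import numpy as np
    import paddle
    import paddle.utils.dlpack
    import warp as wp

    wp.init()
    device = wp.get_device()

    # Warp -> Paddle: Warp arrays implement __dlpack__(), so Paddle consumes them directly
    a = wp.array(np.arange(16, dtype=np.float32), device=device)
    t = paddle.utils.dlpack.from_dlpack(a)  # zero-copy view of the same buffer

    # Paddle -> Warp: Paddle does not implement __dlpack__() yet, so go through to_dlpack()
    t2 = paddle.arange(16, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
    a2 = wp.from_dlpack(paddle.utils.dlpack.to_dlpack(t2))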
warp/tests/test_examples.py CHANGED
@@ -20,7 +20,12 @@ override example defaults so the example can run in less than ten seconds.
 Use {"usd_required": True} and {"torch_required": True} to skip running the test
 if usd-core or torch are not found in the Python environment.
 
+Use {"cutlass_required": True} to skip the test if Warp was not built with
+CUTLASS support.
+
 Use the "num_frames" and "train_iters" keys to control the number of steps.
+
+Use "test_timeout" to override the default test timeout threshold of 300 seconds.
 """
 
 import os
@@ -37,6 +42,9 @@ from warp.tests.unittest_utils import (
     get_test_devices,
     sanitize_identifier,
 )
+from warp.utils import check_p2p
+
+wp.init()  # For wp.context.runtime.core.is_cutlass_enabled()
 
 
 def _build_command_line_options(test_options: Dict[str, Any]) -> list:
@@ -103,6 +111,10 @@ def add_example_test(
         if usd_required and not USD_AVAILABLE:
             test.skipTest("Requires usd-core")
 
+        cutlass_required = options.pop("cutlass_required", False)
+        if cutlass_required and not wp.context.runtime.core.is_cutlass_enabled():
+            test.skipTest("Warp was not built with CUTLASS support")
+
         # Find the current Warp cache
         warp_cache_path = wp.config.kernel_cache_dir
 
@@ -286,6 +298,7 @@ add_example_test(
     test_options_cuda={
         "train_iters": 1 if warp.context.runtime.core.is_debug_enabled() else 3,
         "num_frames": 1 if warp.context.runtime.core.is_debug_enabled() else 60,
+        "cutlass_required": True,
     },
     test_options_cpu={"train_iters": 1, "num_frames": 30},
 )
@@ -340,12 +353,14 @@ class TestFemDiffusionExamples(unittest.TestCase):
     pass
 
 
-add_example_test(
-    TestFemDiffusionExamples,
-    name="fem.example_diffusion_mgpu",
-    devices=get_selected_cuda_test_devices(mode="basic"),
-    test_options={"headless": True},
-)
+# MGPU tests may fail on systems where P2P transfers are misconfigured
+if check_p2p():
+    add_example_test(
+        TestFemDiffusionExamples,
+        name="fem.example_diffusion_mgpu",
+        devices=get_selected_cuda_test_devices(mode="basic"),
+        test_options={"headless": True},
+    )
 
 add_example_test(
     TestFemExamples,
@@ -433,5 +448,4 @@ add_example_test(
 if __name__ == "__main__":
     # force rebuild of all kernels
     wp.clear_kernel_cache()
-
-    unittest.main(verbosity=2, failfast=True)
+    unittest.main(verbosity=2)
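For reference, a hypothetical registration showing how these option keys combine; the test class and example name below are made up for illustration:

    add_example_test(
        TestCoreExamples,                      # hypothetical test class
        name="core.example_foo",               # hypothetical example module
        devices=get_test_devices(),
        test_options={
            "num_frames": 10,                  # shorten the run
            "test_timeout": 600,               # override the 300-second default
            "cutlass_required": True,          # skip if Warp was built without CUTLASS
        },
    )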
warp/tests/test_fast_math.py CHANGED
@@ -12,12 +12,19 @@ from warp.tests.unittest_utils import *
 
 
 @wp.kernel
-def test_pow(e: float, result: float):
+def test_pow(e: float, expected: float):
     tid = wp.tid()
 
     y = wp.pow(-2.0, e)
 
-    wp.expect_eq(y, result)
+    # Since equality comparisons with NaNs are false, we have to do something manually
+    if wp.isnan(expected):
+        if not wp.isnan(y):
+            print("Error, comparison failed")
+            wp.printf("    Expected: %f\n", expected)
+            wp.printf("    Actual: %f\n", y)
+    else:
+        wp.expect_eq(y, expected)
 
 
 def test_fast_math_disabled(test, device):
@@ -26,14 +33,13 @@ def test_fast_math_disabled(test, device):
     wp.launch(test_pow, dim=1, inputs=[2.0, 4.0], device=device)
 
 
-@unittest.expectedFailure
 def test_fast_math_cuda(test, device):
     # on CUDA with --fast-math enabled taking the pow()
     # of a negative number will result in a NaN
 
     wp.set_module_options({"fast_math": True})
     try:
-        wp.launch(test_pow, dim=1, inputs=[2.0, 4.0], device=device)
+        wp.launch(test_pow, dim=1, inputs=[2.0, wp.NAN], device=device)
     finally:
         # Turn fast math back off
         wp.set_module_options({"fast_math": False})
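The test changes from an @unittest.expectedFailure to an explicit wp.NAN expectation because, under CUDA fast math, pow() with a negative base genuinely produces NaN (the fast path typically evaluates exp2(e * log2(x)), which is undefined for x < 0). A minimal sketch of scoping the module option around a launch, assuming a CUDA device:

    import warp as wp

    wp.set_module_options({"fast_math": True})

    @wp.kernel
    def pow_neg_base(out: wp.array(dtype=float)):
        # with fast math enabled this stores NaN rather than 4.0
        out[0] = wp.pow(-2.0, 2.0)

    out = wp.zeros(1, dtype=float, device="cuda:0")
    try:
        wp.launch(pow_neg_base, dim=1, outputs=[out], device="cuda:0")
    finally:
        wp.set_module_options({"fast_math": False})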
warp/tests/test_fem.py CHANGED
@@ -430,6 +430,9 @@ def _launch_test_geometry_kernel(geo: fem.Geometry, device):
         pos_inner = geo.cell_position(cell_arg, inner_s)
         pos_outer = geo.cell_position(cell_arg, outer_s)
 
+        # if wp.length(pos_outer - pos_side) > 0.1:
+        #     wp.print(side_index)
+
         for k in range(type(pos_side).length):
             wp.expect_near(pos_side[k], pos_inner[k], 0.0001)
             wp.expect_near(pos_side[k], pos_outer[k], 0.0001)
@@ -616,6 +619,66 @@ def test_nanogrid(test, device):
     assert_np_equal(cell_measures.numpy(), np.full(cell_measures.shape, 1.0 / (N**3)), tol=1.0e-4)
 
 
+@wp.func
+def _refinement_field(x: wp.vec3):
+    return 4.0 * (wp.length(x) - 0.5)
+
+
+def test_adaptive_nanogrid(test, device):
+    # 3 res-1 voxels, 8 res-0 voxels
+
+    res0 = wp.array(
+        [
+            [2, 2, 0],
+            [2, 3, 0],
+            [3, 2, 0],
+            [3, 3, 0],
+            [2, 2, 1],
+            [2, 3, 1],
+            [3, 2, 1],
+            [3, 3, 1],
+        ],
+        dtype=int,
+        device=device,
+    )
+    res1 = wp.array(
+        [
+            [0, 0, 0],
+            [0, 1, 0],
+            [1, 0, 0],
+            [1, 1, 0],
+        ],
+        dtype=int,
+        device=device,
+    )
+
+    grid0 = wp.Volume.allocate_by_voxels(res0, 0.5, device=device)
+    grid1 = wp.Volume.allocate_by_voxels(res1, 1.0, device=device)
+    geo = fem.adaptive_nanogrid_from_hierarchy([grid0, grid1])
+
+    test.assertEqual(geo.cell_count(), 3 + 8)
+    test.assertEqual(geo.vertex_count(), 2 * 9 + 27 - 8)
+    test.assertEqual(geo.side_count(), 2 * 4 + 6 * 2 + (3 * (2 + 1) * 2**2 - 6))
+    test.assertEqual(geo.boundary_side_count(), 2 * 4 + 4 * 2 + (4 * 4 - 4))
+    # test.assertEqual(geo.edge_count(), 6 * 4 + 9 + (3 * 2 * (2 + 1) ** 2 - 12))
+    test.assertEqual(geo.stacked_face_count(), geo.side_count() + 2)
+    test.assertEqual(geo.stacked_edge_count(), 6 * 4 + 9 + (3 * 2 * (2 + 1) ** 2 - 12) + 7)
+
+    side_measures, cell_measures = _launch_test_geometry_kernel(geo, device)
+
+    test.assertAlmostEqual(np.sum(cell_measures.numpy()), 4.0, places=4)
+    test.assertAlmostEqual(np.sum(side_measures.numpy()), 20 + 3.0, places=4)
+
+    # Test with non-graded geometry
+    ref_field = fem.ImplicitField(fem.Cells(geo), func=_refinement_field)
+    non_graded_geo = fem.adaptive_nanogrid_from_field(grid1, level_count=3, refinement_field=ref_field)
+    _launch_test_geometry_kernel(geo, device)
+
+    # Test automatic grading
+    graded_geo = fem.adaptive_nanogrid_from_field(grid1, level_count=3, refinement_field=ref_field, grading="face")
+    test.assertEqual(non_graded_geo.cell_count() + 7, graded_geo.cell_count())
+
+
 @integrand
 def _rigid_deformation_field(s: Sample, domain: Domain, translation: wp.vec3, rotation: wp.vec3, scale: float):
     q = wp.quat_from_axis_angle(wp.normalize(rotation), wp.length(rotation))
@@ -1531,6 +1594,7 @@ add_function_test(TestFem, "test_grid_3d", test_grid_3d, devices=devices)
 add_function_test(TestFem, "test_tet_mesh", test_tet_mesh, devices=devices)
 add_function_test(TestFem, "test_hex_mesh", test_hex_mesh, devices=devices)
 add_function_test(TestFem, "test_nanogrid", test_nanogrid, devices=cuda_devices)
+add_function_test(TestFem, "test_adaptive_nanogrid", test_adaptive_nanogrid, devices=cuda_devices)
 add_function_test(TestFem, "test_deformed_geometry", test_deformed_geometry, devices=devices)
 add_function_test(TestFem, "test_dof_mapper", test_dof_mapper)
 add_function_test(TestFem, "test_point_basis", test_point_basis)
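test_adaptive_nanogrid covers the adaptive grid geometry added in warp/fem/geometry/adaptive_nanogrid.py, built either from an explicit finest-first list of volumes or by refining a coarse volume against a scalar field, optionally with face grading to limit the level jump between neighboring cells. A condensed sketch distilled from the test above (the voxel layout is illustrative), assuming a CUDA device since NanoVDB volumes require one:

    import warp as wp
    import warp.fem as fem

    @wp.func
    def refinement(x: wp.vec3):
        return 4.0 * (wp.length(x) - 0.5)

    device = "cuda:0"
    voxels = wp.array([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]], dtype=int, device=device)
    coarse = wp.Volume.allocate_by_voxels(voxels, 1.0, device=device)

    # from an explicit hierarchy of volumes
    geo = fem.adaptive_nanogrid_from_hierarchy([coarse])

    # or by refining the coarse grid against a field, with automatic grading
    ref_field = fem.ImplicitField(fem.Cells(geo), func=refinement)
    graded = fem.adaptive_nanogrid_from_field(coarse, level_count=3, refinement_field=ref_field, grading="face")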
warp/tests/test_func.py CHANGED
@@ -7,6 +7,7 @@
 
 import math
 import unittest
+from typing import Tuple
 
 import numpy as np
 
@@ -155,6 +156,41 @@ def test_builtin_shadowing():
     wp.expect_eq(sign(1.23), 123.0)
 
 
+@wp.func
+def user_func_with_defaults(a: int = 123, b: int = 234) -> int:
+    return a + b
+
+
+@wp.kernel
+def test_user_func_with_defaults():
+    a = user_func_with_defaults()
+    wp.expect_eq(a, 357)
+
+    b = user_func_with_defaults(111)
+    wp.expect_eq(b, 345)
+
+    c = user_func_with_defaults(111, 222)
+    wp.expect_eq(c, 333)
+
+    d = user_func_with_defaults(a=111)
+    wp.expect_eq(d, 345)
+
+    e = user_func_with_defaults(b=111)
+    wp.expect_eq(e, 234)
+
+
+@wp.func
+def user_func_return_multiple_values(a: int, b: float) -> Tuple[int, float]:
+    return a + a, b * b
+
+
+@wp.kernel
+def test_user_func_return_multiple_values():
+    a, b = user_func_return_multiple_values(123, 234.0)
+    wp.expect_eq(a, 246)
+    wp.expect_eq(b, 54756.0)
+
+
 devices = get_test_devices()
 
 
@@ -329,6 +365,16 @@ add_function_test(TestFunc, func=test_func_closure_capture, name="test_func_clos
 add_function_test(TestFunc, func=test_multi_valued_func, name="test_multi_valued_func", devices=devices)
 add_kernel_test(TestFunc, kernel=test_func_defaults, name="test_func_defaults", dim=1, devices=devices)
 add_kernel_test(TestFunc, kernel=test_builtin_shadowing, name="test_builtin_shadowing", dim=1, devices=devices)
+add_kernel_test(
+    TestFunc, kernel=test_user_func_with_defaults, name="test_user_func_with_defaults", dim=1, devices=devices
+)
+add_kernel_test(
+    TestFunc,
+    kernel=test_user_func_return_multiple_values,
+    name="test_user_func_return_multiple_values",
+    dim=1,
+    devices=devices,
+)
 
 
 if __name__ == "__main__":
warp/tests/test_implicit_init.py CHANGED
@@ -347,6 +347,55 @@ add_function_test(
 )
 
 
+# Structs
+# ------------------------------------------------------------------------------
+
+
+def test_struct_member_init(test, device):
+    @wp.struct
+    class S:
+        # fp16 requires conversion functions from warp.so
+        x: wp.float16
+        v: wp.vec3h
+
+    s = S()
+    s.x = 42.0
+    s.v = wp.vec3h(1.0, 2.0, 3.0)
+
+
+class TestImplicitInitStructMemberInit(unittest.TestCase):
+    pass
+
+
+add_function_test(
+    TestImplicitInitStructMemberInit,
+    "test_struct_member_init",
+    test_struct_member_init,
+    check_output=False,
+)
+
+
+# Tape
+# ------------------------------------------------------------------------------
+
+
+def test_tape(test, device):
+    with wp.Tape():
+        pass
+
+
+class TestImplicitInitTape(unittest.TestCase):
+    pass
+
+
+add_function_test(
+    TestImplicitInitTape,
+    "test_tape",
+    test_tape,
+    check_output=False,
+)
+
+
 if __name__ == "__main__":
     # Do not clear the kernel cache or call anything that would initialize Warp
     # since these tests are specifically aiming to catch issues where Warp isn't
warp/tests/test_jax.py CHANGED
@@ -246,6 +246,60 @@ def test_jax_kernel_multiarg(test, device):
     assert_np_equal(result_y, expected_y)
 
 
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_launch_dims(test, device):
+    import jax.numpy as jp
+
+    from warp.jax_experimental import jax_kernel
+
+    n = 64
+    m = 32
+
+    # Test with 1D launch dims
+    @wp.kernel
+    def add_one_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
+        tid = wp.tid()
+        y[tid] = x[tid] + 1.0
+
+    jax_add_one = jax_kernel(
+        add_one_kernel, launch_dims=(n - 2,)
+    )  # Intentionally not the same as the first dimension of the input
+
+    @jax.jit
+    def f_1d():
+        x = jp.arange(n, dtype=jp.float32)
+        return jax_add_one(x)
+
+    # Test with 2D launch dims
+    @wp.kernel
+    def add_one_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
+        i, j = wp.tid()
+        y[i, j] = x[i, j] + 1.0
+
+    jax_add_one_2d = jax_kernel(
+        add_one_2d_kernel, launch_dims=(n - 2, m - 2)
+    )  # Intentionally not the same as the first dimension of the input
+
+    @jax.jit
+    def f_2d():
+        x = jp.zeros((n, m), dtype=jp.float32) + 3.0
+        return jax_add_one_2d(x)
+
+    # run on the given device
+    with jax.default_device(wp.device_to_jax(device)):
+        y_1d = f_1d()
+        y_2d = f_2d()
+
+    result_1d = np.asarray(y_1d).reshape((n - 2,))
+    expected_1d = np.arange(n - 2, dtype=np.float32) + 1.0
+
+    result_2d = np.asarray(y_2d).reshape((n - 2, m - 2))
+    expected_2d = np.full((n - 2, m - 2), 4.0, dtype=np.float32)
+
+    assert_np_equal(result_1d, expected_1d)
+    assert_np_equal(result_2d, expected_2d)
+
+
 class TestJax(unittest.TestCase):
     pass
 
@@ -296,6 +350,10 @@
         TestJax, "test_jax_kernel_multiarg", test_jax_kernel_multiarg, devices=jax_compatible_cuda_devices
     )
 
+    add_function_test(
+        TestJax, "test_jax_kernel_launch_dims", test_jax_kernel_launch_dims, devices=jax_compatible_cuda_devices
+    )
+
 except Exception as e:
     print(f"Skipping Jax tests due to exception: {e}")
 
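The new launch_dims argument decouples the launch grid from the shape of the first input, which jax_kernel otherwise infers it from; judging by the reshapes in the test, output shapes follow the launch dims as well. A minimal sketch based on the test above:

    import jax
    import jax.numpy as jp
    import warp as wp
    from warp.jax_experimental import jax_kernel

    @wp.kernel
    def scale(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        tid = wp.tid()
        y[tid] = 2.0 * x[tid]

    # launch 48 threads even though the input holds 64 elements
    jax_scale = jax_kernel(scale, launch_dims=(48,))

    @jax.jit
    def f():
        return jax_scale(jp.arange(64, dtype=jp.float32))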
warp/tests/test_mat.py CHANGED
@@ -1559,6 +1559,83 @@ def test_transform_vector(test, device, dtype, register_kernels=False):
     tape.zero()
 
 
+def test_mat_array_type_indexing(test, device, dtype, register_kernels=False):
+    np_type = np.dtype(dtype)
+    wp_type = wp.types.np_dtype_to_warp_type[np_type]
+
+    vec2 = wp.types.vector(length=2, dtype=wp_type)
+    mat22 = wp.types.matrix(shape=(2, 2), dtype=wp_type)
+    mat33 = wp.types.matrix(shape=(3, 3), dtype=wp_type)
+
+    def mattest_read_write_store(x: wp.array(dtype=wp_type), a: wp.array(dtype=mat22)):
+        tid = wp.tid()
+
+        t = a[tid]
+        t[0, 0] = x[tid]
+        a[tid] = t
+
+    def mattest_in_register(x: wp.array2d(dtype=mat22), y: wp.array(dtype=vec2)):
+        i, j = wp.tid()
+
+        a = mat22(wp_type(0.0))
+        a[0] = y[i]
+        a[1, 1] = wp_type(3.0)
+        x[i, j] = a
+
+    def mattest_in_register_overwrite(x: wp.array2d(dtype=mat22), y: wp.array(dtype=vec2)):
+        i, j = wp.tid()
+
+        a = mat22(wp_type(0.0))
+        a[0] = y[i]
+        a[0, 1] = wp_type(3.0)
+        x[i, j] = a
+
+    kernel_read_write_store = getkernel(mattest_read_write_store, suffix=dtype.__name__)
+    kernel_in_register = getkernel(mattest_in_register, suffix=dtype.__name__)
+    kernel_in_register_overwrite = getkernel(mattest_in_register_overwrite, suffix=dtype.__name__)
+
+    if register_kernels:
+        return
+
+    a = wp.ones(1, dtype=mat22, device=device, requires_grad=True)
+    x = wp.full(1, value=2.0, dtype=wp_type, device=device, requires_grad=True)
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(kernel_read_write_store, dim=1, inputs=[x, a], device=device)
+
+    tape.backward(grads={a: wp.ones_like(a, requires_grad=False)})
+
+    assert_np_equal(a.numpy(), np.array([[[2.0, 1.0], [1.0, 1.0]]], dtype=np_type))
+    assert_np_equal(x.grad.numpy(), np.array([1.0], dtype=np_type))
+
+    tape.reset()
+
+    x = wp.zeros((1, 1), dtype=mat22, device=device, requires_grad=True)
+    y = wp.ones(1, dtype=vec2, device=device, requires_grad=True)
+
+    with tape:
+        wp.launch(kernel_in_register, dim=(1, 1), inputs=[x, y], device=device)
+
+    tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
+
+    assert_np_equal(x.numpy(), np.array([[[[1.0, 1.0], [0.0, 3.0]]]], dtype=np_type))
+    assert_np_equal(y.grad.numpy(), np.array([[1.0, 1.0]], dtype=np_type))
+
+    tape.reset()
+
+    x = wp.zeros((1, 1), dtype=mat22, device=device, requires_grad=True)
+    y = wp.ones(1, dtype=vec2, device=device, requires_grad=True)
+
+    with tape:
+        wp.launch(kernel_in_register_overwrite, dim=(1, 1), inputs=[x, y], device=device)
+
+    tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
+
+    assert_np_equal(x.numpy(), np.array([[[[1.0, 3.0], [0.0, 0.0]]]], dtype=np_type))
+    assert_np_equal(y.grad.numpy(), np.array([[1.0, 0.0]], dtype=np_type))
+
+
 # Test matrix constructors using explicit type (float16)
 # note that these tests are specifically not using generics / closure
 # args to create kernels dynamically (like the rest of this file)
@@ -1791,6 +1868,13 @@ for dtype in np_float_types:
         TestMat, f"test_determinant_{dtype.__name__}", test_determinant, devices=devices, dtype=dtype
     )
     add_function_test_register_kernel(TestMat, f"test_skew_{dtype.__name__}", test_skew, devices=devices, dtype=dtype)
+    add_function_test_register_kernel(
+        TestMat,
+        f"test_mat_array_type_indexing_{dtype.__name__}",
+        test_mat_array_type_indexing,
+        devices=devices,
+        dtype=dtype,
+    )
 
 
 if __name__ == "__main__":
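The new test validates in-kernel matrix mutation through whole-row assignment (m[i] = v) and per-element assignment (m[i, j] = s), including correct adjoints when a component written via a row is later overwritten. A distilled sketch of the pattern exercised above:

    import warp as wp

    @wp.kernel
    def build(rows: wp.array(dtype=wp.vec2), out: wp.array(dtype=wp.mat22)):
        tid = wp.tid()
        m = wp.mat22(0.0)
        m[0] = rows[tid]  # assign a whole row from a vec2
        m[1, 1] = 3.0     # assign a single element
        out[tid] = m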