warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/tests/test_struct.py CHANGED
@@ -5,15 +5,16 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import unittest
8
9
  from typing import Any
10
+
9
11
  import numpy as np
12
+
10
13
  import warp as wp
11
- from warp.tests.test_base import *
14
+ from warp.tests.unittest_utils import *
12
15
 
13
16
  from warp.fem import Sample as StructFromAnotherModule
14
17
 
15
- import unittest
16
-
17
18
  wp.init()
18
19
 
19
20
 
@@ -291,7 +292,7 @@ def test_struct_math_conversions(test, device):
291
292
  s.m5 = [10, 20, 30, 40]
292
293
  s.m6 = np.array([100, 200, 300, 400])
293
294
 
294
- wp.launch(check_math_conversions, dim=1, inputs=[s])
295
+ wp.launch(check_math_conversions, dim=1, inputs=[s], device=device)
295
296
 
296
297
 
297
298
  @wp.struct
@@ -416,9 +417,9 @@ def test_nested_array_struct(test, device):
416
417
  var2.i = 2
417
418
 
418
419
  struct = ArrayStruct()
419
- struct.array = wp.array([var1, var2], dtype=InnerStruct)
420
+ struct.array = wp.array([var1, var2], dtype=InnerStruct, device=device)
420
421
 
421
- wp.launch(struct2_reader, dim=2, inputs=[struct])
422
+ wp.launch(struct2_reader, dim=2, inputs=[struct], device=device)
422
423
 
423
424
 
424
425
  @wp.struct
@@ -564,81 +565,111 @@ def test_dependent_module_import(c: DependentModuleImport_C):
564
565
  wp.tid() # nop, we're just testing codegen
565
566
 
566
567
 
567
- def register(parent):
568
- devices = get_test_devices()
568
+ devices = get_test_devices()
569
569
 
570
- class TestStruct(parent):
571
- pass
572
570
 
573
- add_function_test(TestStruct, "test_step", test_step, devices=devices)
574
- add_function_test(TestStruct, "test_step_grad", test_step_grad, devices=devices)
575
- add_kernel_test(TestStruct, kernel=test_empty, name="test_empty", dim=1, inputs=[Empty()], devices=devices)
576
- add_kernel_test(
577
- TestStruct,
578
- kernel=test_uninitialized,
579
- name="test_uninitialized",
580
- dim=1,
581
- inputs=[Uninitialized()],
582
- devices=devices,
583
- )
584
- add_kernel_test(TestStruct, kernel=test_return, name="test_return", dim=1, inputs=[], devices=devices)
585
- add_function_test(TestStruct, "test_struct_attribute_error", test_struct_attribute_error, devices=devices)
586
- add_function_test(TestStruct, "test_nested_struct", test_nested_struct, devices=devices)
587
- add_function_test(TestStruct, "test_nested_array_struct", test_nested_array_struct, devices=devices)
588
- add_function_test(TestStruct, "test_nested_empty_struct", test_nested_empty_struct, devices=devices)
589
- add_function_test(TestStruct, "test_struct_math_conversions", test_struct_math_conversions, devices=devices)
590
- add_function_test(
591
- TestStruct, "test_struct_default_attributes_python", test_struct_default_attributes_python, devices=devices
592
- )
593
- add_kernel_test(
594
- TestStruct,
595
- name="test_struct_default_attributes",
596
- kernel=test_struct_default_attributes_kernel,
597
- dim=1,
598
- inputs=[],
599
- devices=devices,
600
- )
571
+ class TestStruct(unittest.TestCase):
572
+ pass
573
+
601
574
 
575
+ add_function_test(TestStruct, "test_step", test_step, devices=devices)
576
+ add_function_test(TestStruct, "test_step_grad", test_step_grad, devices=devices)
577
+ add_kernel_test(TestStruct, kernel=test_empty, name="test_empty", dim=1, inputs=[Empty()], devices=devices)
578
+ add_kernel_test(
579
+ TestStruct,
580
+ kernel=test_uninitialized,
581
+ name="test_uninitialized",
582
+ dim=1,
583
+ inputs=[Uninitialized()],
584
+ devices=devices,
585
+ )
586
+ add_kernel_test(TestStruct, kernel=test_return, name="test_return", dim=1, inputs=[], devices=devices)
587
+ add_function_test(TestStruct, "test_nested_struct", test_nested_struct, devices=devices)
588
+ add_function_test(TestStruct, "test_nested_array_struct", test_nested_array_struct, devices=devices)
589
+ add_function_test(TestStruct, "test_nested_empty_struct", test_nested_empty_struct, devices=devices)
590
+ add_function_test(TestStruct, "test_struct_math_conversions", test_struct_math_conversions, devices=devices)
591
+ add_function_test(
592
+ TestStruct, "test_struct_default_attributes_python", test_struct_default_attributes_python, devices=devices
593
+ )
594
+ add_kernel_test(
595
+ TestStruct,
596
+ name="test_struct_default_attributes",
597
+ kernel=test_struct_default_attributes_kernel,
598
+ dim=1,
599
+ inputs=[],
600
+ devices=devices,
601
+ )
602
+
603
+ add_kernel_test(
604
+ TestStruct,
605
+ name="test_struct_mutate_attributes",
606
+ kernel=test_struct_mutate_attributes_kernel,
607
+ dim=1,
608
+ inputs=[],
609
+ devices=devices,
610
+ )
611
+ add_kernel_test(
612
+ TestStruct,
613
+ kernel=test_uninitialized,
614
+ name="test_uninitialized",
615
+ dim=1,
616
+ inputs=[Uninitialized()],
617
+ devices=devices,
618
+ )
619
+ add_kernel_test(TestStruct, kernel=test_return, name="test_return", dim=1, inputs=[], devices=devices)
620
+ add_function_test(TestStruct, "test_nested_struct", test_nested_struct, devices=devices)
621
+ add_function_test(TestStruct, "test_nested_array_struct", test_nested_array_struct, devices=devices)
622
+ add_function_test(TestStruct, "test_nested_empty_struct", test_nested_empty_struct, devices=devices)
623
+ add_function_test(TestStruct, "test_struct_math_conversions", test_struct_math_conversions, devices=devices)
624
+ add_function_test(
625
+ TestStruct, "test_struct_default_attributes_python", test_struct_default_attributes_python, devices=devices
626
+ )
627
+ add_kernel_test(
628
+ TestStruct,
629
+ name="test_struct_default_attributes",
630
+ kernel=test_struct_default_attributes_kernel,
631
+ dim=1,
632
+ inputs=[],
633
+ devices=devices,
634
+ )
635
+
636
+ add_kernel_test(
637
+ TestStruct,
638
+ name="test_struct_mutate_attributes",
639
+ kernel=test_struct_mutate_attributes_kernel,
640
+ dim=1,
641
+ inputs=[],
642
+ devices=devices,
643
+ )
644
+
645
+ for device in devices:
602
646
  add_kernel_test(
603
647
  TestStruct,
604
- name="test_struct_mutate_attributes",
605
- kernel=test_struct_mutate_attributes_kernel,
648
+ kernel=test_struct_instantiate,
649
+ name="test_struct_instantiate",
606
650
  dim=1,
607
- inputs=[],
608
- devices=devices,
651
+ inputs=[wp.array([1], dtype=int, device=device)],
652
+ devices=[device],
609
653
  )
610
-
611
- for device in devices:
612
- add_kernel_test(
613
- TestStruct,
614
- kernel=test_struct_instantiate,
615
- name="test_struct_instantiate",
616
- dim=1,
617
- inputs=[wp.array([1], dtype=int, device=device)],
618
- devices=[device],
619
- )
620
- add_kernel_test(
621
- TestStruct,
622
- kernel=test_return_struct,
623
- name="test_return_struct",
624
- dim=1,
625
- inputs=[wp.zeros(10, dtype=int, device=device)],
626
- devices=[device],
627
- )
628
-
629
654
  add_kernel_test(
630
655
  TestStruct,
631
- kernel=test_dependent_module_import,
632
- name="test_dependent_module_import",
656
+ kernel=test_return_struct,
657
+ name="test_return_struct",
633
658
  dim=1,
634
- inputs=[DependentModuleImport_C()],
635
- devices=devices,
659
+ inputs=[wp.zeros(10, dtype=int, device=device)],
660
+ devices=[device],
636
661
  )
637
662
 
638
- return TestStruct
663
+ add_kernel_test(
664
+ TestStruct,
665
+ kernel=test_dependent_module_import,
666
+ name="test_dependent_module_import",
667
+ dim=1,
668
+ inputs=[DependentModuleImport_C()],
669
+ devices=devices,
670
+ )
639
671
 
640
672
 
641
673
  if __name__ == "__main__":
642
674
  wp.build.clear_kernel_cache()
643
- _ = register(unittest.TestCase)
644
675
  unittest.main(verbosity=2)
warp/tests/test_tape.py CHANGED
@@ -10,7 +10,7 @@ import unittest
10
10
  import numpy as np
11
11
 
12
12
  import warp as wp
13
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
14
14
 
15
15
  wp.init()
16
16
 
@@ -127,28 +127,22 @@ def test_tape_dot_product(test, device):
127
127
  assert_np_equal(tape.gradients[y].numpy(), x.numpy())
128
128
 
129
129
 
130
- def test_tape_no_nested_tapes(test, device):
131
- with test.assertRaises(RuntimeError):
132
- with wp.Tape():
133
- with wp.Tape():
134
- pass
135
-
130
+ devices = get_test_devices()
136
131
 
137
- def register(parent):
138
- devices = get_test_devices()
139
132
 
140
- class TestTape(parent):
141
- pass
133
+ class TestTape(unittest.TestCase):
134
+ def test_tape_no_nested_tapes(self):
135
+ with self.assertRaises(RuntimeError):
136
+ with wp.Tape():
137
+ with wp.Tape():
138
+ pass
142
139
 
143
- add_function_test(TestTape, "test_tape_mul_constant", test_tape_mul_constant, devices=devices)
144
- add_function_test(TestTape, "test_tape_mul_variable", test_tape_mul_variable, devices=devices)
145
- add_function_test(TestTape, "test_tape_dot_product", test_tape_dot_product, devices=devices)
146
- add_function_test(TestTape, "test_tape_no_nested_tapes", test_tape_no_nested_tapes, devices=devices)
147
140
 
148
- return TestTape
141
+ add_function_test(TestTape, "test_tape_mul_constant", test_tape_mul_constant, devices=devices)
142
+ add_function_test(TestTape, "test_tape_mul_variable", test_tape_mul_variable, devices=devices)
143
+ add_function_test(TestTape, "test_tape_dot_product", test_tape_dot_product, devices=devices)
149
144
 
150
145
 
151
146
  if __name__ == "__main__":
152
147
  wp.build.clear_kernel_cache()
153
- _ = register(unittest.TestCase)
154
148
  unittest.main(verbosity=2)
warp/tests/test_torch.py CHANGED
@@ -7,11 +7,10 @@
7
7
 
8
8
  import unittest
9
9
 
10
- # include parent path
11
10
  import numpy as np
12
11
 
13
12
  import warp as wp
14
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
15
14
 
16
15
  wp.init()
17
16
 
@@ -470,6 +469,8 @@ def test_torch_autograd(test, device):
470
469
  def test_torch_graph_torch_stream(test, device):
471
470
  """Capture Torch graph on Torch stream"""
472
471
 
472
+ wp.load_module(device=device)
473
+
473
474
  import torch
474
475
 
475
476
  torch_device = wp.device_to_torch(device)
@@ -551,12 +552,14 @@ def test_warp_graph_warp_stream(test, device):
551
552
 
552
553
  # capture graph
553
554
  with wp.ScopedDevice(device), torch.cuda.stream(torch_stream):
554
- wp.capture_begin()
555
- t += 1.0
556
- wp.launch(inc, dim=n, inputs=[a])
557
- t += 1.0
558
- wp.launch(inc, dim=n, inputs=[a])
559
- g = wp.capture_end()
555
+ wp.capture_begin(force_module_load=False)
556
+ try:
557
+ t += 1.0
558
+ wp.launch(inc, dim=n, inputs=[a])
559
+ t += 1.0
560
+ wp.launch(inc, dim=n, inputs=[a])
561
+ finally:
562
+ g = wp.capture_end()
560
563
 
561
564
  # replay graph
562
565
  num_iters = 10
@@ -570,6 +573,8 @@ def test_warp_graph_warp_stream(test, device):
570
573
  def test_warp_graph_torch_stream(test, device):
571
574
  """Capture Warp graph on Torch stream"""
572
575
 
576
+ wp.load_module(device=device)
577
+
573
578
  import torch
574
579
 
575
580
  torch_device = wp.device_to_torch(device)
@@ -587,12 +592,14 @@ def test_warp_graph_torch_stream(test, device):
587
592
 
588
593
  # capture graph
589
594
  with wp.ScopedStream(warp_stream), torch.cuda.stream(torch_stream):
590
- wp.capture_begin()
591
- t += 1.0
592
- wp.launch(inc, dim=n, inputs=[a])
593
- t += 1.0
594
- wp.launch(inc, dim=n, inputs=[a])
595
- g = wp.capture_end()
595
+ wp.capture_begin(force_module_load=False)
596
+ try:
597
+ t += 1.0
598
+ wp.launch(inc, dim=n, inputs=[a])
599
+ t += 1.0
600
+ wp.launch(inc, dim=n, inputs=[a])
601
+ finally:
602
+ g = wp.capture_end()
596
603
 
597
604
  # replay graph
598
605
  num_iters = 10
@@ -603,83 +610,79 @@ def test_warp_graph_torch_stream(test, device):
603
610
  assert passed.item()
604
611
 
605
612
 
606
- def register(parent):
607
- class TestTorch(parent):
608
- pass
609
-
610
- try:
611
- import torch
612
-
613
- # check which Warp devices work with Torch
614
- # CUDA devices may fail if Torch was not compiled with CUDA support
615
- test_devices = get_test_devices()
616
- torch_compatible_devices = []
617
- torch_compatible_cuda_devices = []
618
-
619
- for d in test_devices:
620
- try:
621
- t = torch.arange(10, device=wp.device_to_torch(d))
622
- t += 1
623
- torch_compatible_devices.append(d)
624
- if d.is_cuda:
625
- torch_compatible_cuda_devices.append(d)
626
- except Exception as e:
627
- print(f"Skipping Torch tests on device '{d}' due to exception: {e}")
628
-
629
- if torch_compatible_devices:
630
- add_function_test(TestTorch, "test_from_torch", test_from_torch, devices=torch_compatible_devices)
631
- add_function_test(
632
- TestTorch, "test_from_torch_slices", test_from_torch_slices, devices=torch_compatible_devices
633
- )
634
- add_function_test(
635
- TestTorch,
636
- "test_from_torch_zero_strides",
637
- test_from_torch_zero_strides,
638
- devices=torch_compatible_devices,
639
- )
640
- add_function_test(TestTorch, "test_to_torch", test_to_torch, devices=torch_compatible_devices)
641
- add_function_test(TestTorch, "test_torch_zerocopy", test_torch_zerocopy, devices=torch_compatible_devices)
642
- add_function_test(TestTorch, "test_torch_autograd", test_torch_autograd, devices=torch_compatible_devices)
643
-
644
- if torch_compatible_cuda_devices:
645
- add_function_test(
646
- TestTorch,
647
- "test_torch_graph_torch_stream",
648
- test_torch_graph_torch_stream,
649
- devices=torch_compatible_cuda_devices,
650
- )
651
- add_function_test(
652
- TestTorch,
653
- "test_torch_graph_warp_stream",
654
- test_torch_graph_warp_stream,
655
- devices=torch_compatible_cuda_devices,
656
- )
657
- add_function_test(
658
- TestTorch,
659
- "test_warp_graph_warp_stream",
660
- test_warp_graph_warp_stream,
661
- devices=torch_compatible_cuda_devices,
662
- )
663
- add_function_test(
664
- TestTorch,
665
- "test_warp_graph_torch_stream",
666
- test_warp_graph_torch_stream,
667
- devices=torch_compatible_cuda_devices,
668
- )
613
+ class TestTorch(unittest.TestCase):
614
+ pass
669
615
 
670
- # multi-GPU tests
671
- if len(torch_compatible_cuda_devices) > 1:
672
- add_function_test(TestTorch, "test_torch_mgpu_from_torch", test_torch_mgpu_from_torch)
673
- add_function_test(TestTorch, "test_torch_mgpu_to_torch", test_torch_mgpu_to_torch)
674
- add_function_test(TestTorch, "test_torch_mgpu_interop", test_torch_mgpu_interop)
675
616
 
676
- except Exception as e:
677
- print(f"Skipping Torch tests due to exception: {e}")
617
+ test_devices = get_test_devices()
618
+
619
+ try:
620
+ import torch
678
621
 
679
- return TestTorch
622
+ # check which Warp devices work with Torch
623
+ # CUDA devices may fail if Torch was not compiled with CUDA support
624
+ torch_compatible_devices = []
625
+ torch_compatible_cuda_devices = []
626
+
627
+ for d in test_devices:
628
+ try:
629
+ t = torch.arange(10, device=wp.device_to_torch(d))
630
+ t += 1
631
+ torch_compatible_devices.append(d)
632
+ if d.is_cuda:
633
+ torch_compatible_cuda_devices.append(d)
634
+ except Exception as e:
635
+ print(f"Skipping Torch tests on device '{d}' due to exception: {e}")
636
+
637
+ if torch_compatible_devices:
638
+ add_function_test(TestTorch, "test_from_torch", test_from_torch, devices=torch_compatible_devices)
639
+ add_function_test(TestTorch, "test_from_torch_slices", test_from_torch_slices, devices=torch_compatible_devices)
640
+ add_function_test(
641
+ TestTorch,
642
+ "test_from_torch_zero_strides",
643
+ test_from_torch_zero_strides,
644
+ devices=torch_compatible_devices,
645
+ )
646
+ add_function_test(TestTorch, "test_to_torch", test_to_torch, devices=torch_compatible_devices)
647
+ add_function_test(TestTorch, "test_torch_zerocopy", test_torch_zerocopy, devices=torch_compatible_devices)
648
+ add_function_test(TestTorch, "test_torch_autograd", test_torch_autograd, devices=torch_compatible_devices)
649
+
650
+ if torch_compatible_cuda_devices:
651
+ add_function_test(
652
+ TestTorch,
653
+ "test_torch_graph_torch_stream",
654
+ test_torch_graph_torch_stream,
655
+ devices=torch_compatible_cuda_devices,
656
+ )
657
+ add_function_test(
658
+ TestTorch,
659
+ "test_torch_graph_warp_stream",
660
+ test_torch_graph_warp_stream,
661
+ devices=torch_compatible_cuda_devices,
662
+ )
663
+ add_function_test(
664
+ TestTorch,
665
+ "test_warp_graph_warp_stream",
666
+ test_warp_graph_warp_stream,
667
+ devices=torch_compatible_cuda_devices,
668
+ )
669
+ add_function_test(
670
+ TestTorch,
671
+ "test_warp_graph_torch_stream",
672
+ test_warp_graph_torch_stream,
673
+ devices=torch_compatible_cuda_devices,
674
+ )
675
+
676
+ # multi-GPU tests
677
+ if len(torch_compatible_cuda_devices) > 1:
678
+ add_function_test(TestTorch, "test_torch_mgpu_from_torch", test_torch_mgpu_from_torch)
679
+ add_function_test(TestTorch, "test_torch_mgpu_to_torch", test_torch_mgpu_to_torch)
680
+ add_function_test(TestTorch, "test_torch_mgpu_interop", test_torch_mgpu_interop)
681
+
682
+ except Exception as e:
683
+ print(f"Skipping Torch tests due to exception: {e}")
680
684
 
681
685
 
682
686
  if __name__ == "__main__":
683
687
  wp.build.clear_kernel_cache()
684
- _ = register(unittest.TestCase)
685
688
  unittest.main(verbosity=2)
@@ -5,14 +5,13 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- import importlib
9
8
  import os
10
9
  import tempfile
11
10
  import unittest
11
+ from importlib import util
12
12
 
13
13
  import warp as wp
14
- from warp.tests.test_base import *
15
- from importlib import util
14
+ from warp.tests.unittest_utils import *
16
15
 
17
16
  CODE = """# -*- coding: utf-8 -*-
18
17
 
@@ -64,27 +63,25 @@ def test_transient_module(test, device):
64
63
  assert len(module.compute.module.functions) == 1
65
64
 
66
65
  data = module.Data()
67
- data.x = wp.array([123], dtype=int)
66
+ data.x = wp.array([123], dtype=int, device=device)
68
67
 
69
68
  wp.set_module_options({"foo": "bar"}, module=module)
70
69
  assert wp.get_module_options(module=module).get("foo") == "bar"
71
70
  assert module.compute.module.options.get("foo") == "bar"
72
71
 
73
- wp.launch(module.compute, dim=1, inputs=[data])
72
+ wp.launch(module.compute, dim=1, inputs=[data], device=device)
74
73
  assert_np_equal(data.x.numpy(), np.array([124]))
75
74
 
76
75
 
77
- def register(parent):
78
- devices = get_test_devices()
76
+ devices = get_test_devices()
77
+
79
78
 
80
- class TestTransientModule(parent):
81
- pass
79
+ class TestTransientModule(unittest.TestCase):
80
+ pass
82
81
 
83
- add_function_test(TestTransientModule, "test_transient_module", test_transient_module, devices=devices)
84
- return TestTransientModule
85
82
 
83
+ add_function_test(TestTransientModule, "test_transient_module", test_transient_module, devices=devices)
86
84
 
87
85
  if __name__ == "__main__":
88
86
  wp.build.clear_kernel_cache()
89
- _ = register(unittest.TestCase)
90
87
  unittest.main(verbosity=2)