warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/tests/test_copy.py CHANGED
@@ -5,13 +5,12 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- # include parent path
8
+ import unittest
9
+
9
10
  import numpy as np
10
11
 
11
12
  import warp as wp
12
- from warp.tests.test_base import *
13
-
14
- import unittest
13
+ from warp.tests.unittest_utils import *
15
14
 
16
15
  wp.init()
17
16
 
@@ -200,19 +199,17 @@ def test_copy_indexed(test, device):
200
199
  assert_np_equal(a4.numpy(), expected4 * s)
201
200
 
202
201
 
203
- def register(parent):
204
- devices = get_test_devices()
202
+ devices = get_test_devices()
203
+
205
204
 
206
- class TestCopy(parent):
207
- pass
205
+ class TestCopy(unittest.TestCase):
206
+ pass
208
207
 
209
- add_function_test(TestCopy, "test_copy_strided", test_copy_strided, devices=devices)
210
- add_function_test(TestCopy, "test_copy_indexed", test_copy_indexed, devices=devices)
211
208
 
212
- return TestCopy
209
+ add_function_test(TestCopy, "test_copy_strided", test_copy_strided, devices=devices)
210
+ add_function_test(TestCopy, "test_copy_indexed", test_copy_indexed, devices=devices)
213
211
 
214
212
 
215
213
  if __name__ == "__main__":
216
214
  wp.build.clear_kernel_cache()
217
- _ = register(unittest.TestCase)
218
215
  unittest.main(verbosity=2)
warp/tests/test_ctypes.py CHANGED
@@ -5,14 +5,12 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- # include parent path
8
+ import unittest
9
+
9
10
  import numpy as np
10
- import math
11
11
 
12
12
  import warp as wp
13
- from warp.tests.test_base import *
14
-
15
- import unittest
13
+ from warp.tests.unittest_utils import *
16
14
 
17
15
  wp.init()
18
16
 
@@ -24,11 +22,11 @@ def add_vec2(dest: wp.array(dtype=wp.vec2), c: wp.vec2):
24
22
 
25
23
 
26
24
  @wp.kernel
27
- def transform_vec2(dest: wp.array(dtype=wp.vec2), m: wp.mat22, v: wp.vec2):
25
+ def transform_vec2(dest_right: wp.array(dtype=wp.vec2), dest_left: wp.array(dtype=wp.vec2), m: wp.mat22, v: wp.vec2):
28
26
  tid = wp.tid()
29
27
 
30
- p = wp.mul(m, v)
31
- dest[tid] = p
28
+ dest_right[tid] = wp.mul(m, v)
29
+ dest_left[tid] = wp.mul(v, m)
32
30
 
33
31
 
34
32
  @wp.kernel
@@ -38,11 +36,11 @@ def add_vec3(dest: wp.array(dtype=wp.vec3), c: wp.vec3):
38
36
 
39
37
 
40
38
  @wp.kernel
41
- def transform_vec3(dest: wp.array(dtype=wp.vec3), m: wp.mat33, v: wp.vec3):
39
+ def transform_vec3(dest_right: wp.array(dtype=wp.vec3), dest_left: wp.array(dtype=wp.vec3), m: wp.mat33, v: wp.vec3):
42
40
  tid = wp.tid()
43
41
 
44
- p = wp.mul(m, v)
45
- dest[tid] = p
42
+ dest_right[tid] = wp.mul(m, v)
43
+ dest_left[tid] = wp.mul(v, m)
46
44
 
47
45
 
48
46
  @wp.kernel
@@ -63,12 +61,14 @@ def test_vec2_arg(test, device, n):
63
61
 
64
62
 
65
63
  def test_vec2_transform(test, device, n):
66
- dest = wp.zeros(n=n, dtype=wp.vec2, device=device)
64
+ dest_right = wp.zeros(n=n, dtype=wp.vec2, device=device)
65
+ dest_left = wp.zeros(n=n, dtype=wp.vec2, device=device)
67
66
  c = np.array((1.0, 2.0))
68
67
  m = np.array(((3.0, -1.0), (2.5, 4.0)))
69
68
 
70
- wp.launch(transform_vec2, dim=n, inputs=[dest, m, c], device=device)
71
- test.assertTrue(np.array_equal(dest.numpy(), np.tile(m @ c, (n, 1))))
69
+ wp.launch(transform_vec2, dim=n, inputs=[dest_right, dest_left, m, c], device=device)
70
+ test.assertTrue(np.array_equal(dest_right.numpy(), np.tile(m @ c, (n, 1))))
71
+ test.assertTrue(np.array_equal(dest_left.numpy(), np.tile(c @ m, (n, 1))))
72
72
 
73
73
 
74
74
  def test_vec3_arg(test, device, n):
@@ -80,12 +80,14 @@ def test_vec3_arg(test, device, n):
80
80
 
81
81
 
82
82
  def test_vec3_transform(test, device, n):
83
- dest = wp.zeros(n=n, dtype=wp.vec3, device=device)
83
+ dest_right = wp.zeros(n=n, dtype=wp.vec3, device=device)
84
+ dest_left = wp.zeros(n=n, dtype=wp.vec3, device=device)
84
85
  c = np.array((1.0, 2.0, 3.0))
85
86
  m = np.array(((1.0, 2.0, 3.0), (4.0, 5.0, 6.0), (7.0, 8.0, 9.0)))
86
87
 
87
- wp.launch(transform_vec3, dim=n, inputs=[dest, m, c], device=device)
88
- test.assertTrue(np.array_equal(dest.numpy(), np.tile(m @ c, (n, 1))))
88
+ wp.launch(transform_vec3, dim=n, inputs=[dest_right, dest_left, m, c], device=device)
89
+ test.assertTrue(np.array_equal(dest_right.numpy(), np.tile(m @ c, (n, 1))))
90
+ test.assertTrue(np.array_equal(dest_left.numpy(), np.tile(c @ m, (n, 1))))
89
91
 
90
92
 
91
93
  def test_transform_multiply(test, device, n):
@@ -552,89 +554,79 @@ def test_transform_matrix():
552
554
  wp.expect_near(r_2, r_0 - t, 1.0e-4)
553
555
 
554
556
 
555
- def register(parent):
556
- devices = get_test_devices()
557
-
558
- class TestCTypes(parent):
559
- pass
560
-
561
- inputs = [
562
- wp.vec2(1.0, 2.0),
563
- wp.vec3(1.0, 2.0, 3.0),
564
- wp.vec4(1.0, 2.0, 3.0, 4.0),
565
- wp.mat22(1.0, 2.0, 3.0, 4.0),
566
- wp.mat33(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0),
567
- wp.mat44(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0),
568
- ]
569
-
570
- add_function_test(TestCTypes, "test_mat22", test_mat22, devices=devices)
571
- add_function_test(TestCTypes, "test_mat33", test_mat33, devices=devices)
572
- add_function_test(TestCTypes, "test_mat44", test_mat44, devices=devices)
573
- add_kernel_test(
574
- TestCTypes,
575
- name="test_transformation_constructor",
576
- kernel=test_transformation_constructor,
577
- dim=1,
578
- devices=devices,
579
- )
580
- add_kernel_test(
581
- TestCTypes,
582
- name="test_spatial_vector_constructor",
583
- kernel=test_spatial_vector_constructor,
584
- dim=1,
585
- devices=devices,
586
- )
587
- add_kernel_test(
588
- TestCTypes,
589
- name="test_scalar_arg_types",
590
- kernel=test_scalar_arg_types,
591
- dim=1,
592
- inputs=[-64, 255, -64, 255, -64, 255, -64, 255, 3.14159, 3.14159],
593
- devices=devices,
594
- )
595
- add_kernel_test(
596
- TestCTypes,
597
- name="test_scalar_arg_types_explicit",
598
- kernel=test_scalar_arg_types,
599
- dim=1,
600
- inputs=[
601
- wp.int8(-64),
602
- wp.uint8(255),
603
- wp.int16(-64),
604
- wp.uint16(255),
605
- wp.int32(-64),
606
- wp.uint32(255),
607
- wp.int64(-64),
608
- wp.uint64(255),
609
- wp.float32(3.14159),
610
- wp.float64(3.14159),
611
- ],
612
- devices=devices,
613
- )
614
- add_kernel_test(
615
- TestCTypes, name="test_vector_arg_types", kernel=test_vector_arg_types, dim=1, inputs=inputs, devices=devices
616
- )
617
- add_kernel_test(TestCTypes, name="test_type_convesrions", kernel=test_type_conversions, dim=1, devices=devices)
618
-
619
- add_function_test(
620
- TestCTypes, "test_scalar_array_load", test_scalar_array_types, devices=devices, load=True, store=False
621
- )
622
- add_function_test(
623
- TestCTypes, "test_scalar_array_store", test_scalar_array_types, devices=devices, load=False, store=True
624
- )
625
- add_function_test(TestCTypes, "test_vec2_arg", test_vec2_arg, devices=devices, n=8)
626
- add_function_test(TestCTypes, "test_vec2_transform", test_vec2_transform, devices=devices, n=8)
627
- add_function_test(TestCTypes, "test_vec3_arg", test_vec3_arg, devices=devices, n=8)
628
- add_function_test(TestCTypes, "test_vec3_transform", test_vec3_transform, devices=devices, n=8)
629
- add_function_test(TestCTypes, "test_transform_multiply", test_transform_multiply, devices=devices, n=8)
630
- add_kernel_test(TestCTypes, name="test_transform_matrix", kernel=test_transform_matrix, dim=1, devices=devices)
631
- add_function_test(TestCTypes, "test_scalar_array", test_scalar_array, devices=devices)
632
- add_function_test(TestCTypes, "test_vector_array", test_vector_array, devices=devices)
633
-
634
- return TestCTypes
557
+ devices = get_test_devices()
558
+
559
+
560
+ class TestCTypes(unittest.TestCase):
561
+ pass
562
+
563
+
564
+ inputs = [
565
+ wp.vec2(1.0, 2.0),
566
+ wp.vec3(1.0, 2.0, 3.0),
567
+ wp.vec4(1.0, 2.0, 3.0, 4.0),
568
+ wp.mat22(1.0, 2.0, 3.0, 4.0),
569
+ wp.mat33(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0),
570
+ wp.mat44(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0),
571
+ ]
572
+
573
+ add_function_test(TestCTypes, "test_mat22", test_mat22, devices=devices)
574
+ add_function_test(TestCTypes, "test_mat33", test_mat33, devices=devices)
575
+ add_function_test(TestCTypes, "test_mat44", test_mat44, devices=devices)
576
+ add_kernel_test(
577
+ TestCTypes, name="test_transformation_constructor", kernel=test_transformation_constructor, dim=1, devices=devices
578
+ )
579
+ add_kernel_test(
580
+ TestCTypes, name="test_spatial_vector_constructor", kernel=test_spatial_vector_constructor, dim=1, devices=devices
581
+ )
582
+ add_kernel_test(
583
+ TestCTypes,
584
+ name="test_scalar_arg_types",
585
+ kernel=test_scalar_arg_types,
586
+ dim=1,
587
+ inputs=[-64, 255, -64, 255, -64, 255, -64, 255, 3.14159, 3.14159],
588
+ devices=devices,
589
+ )
590
+ add_kernel_test(
591
+ TestCTypes,
592
+ name="test_scalar_arg_types_explicit",
593
+ kernel=test_scalar_arg_types,
594
+ dim=1,
595
+ inputs=[
596
+ wp.int8(-64),
597
+ wp.uint8(255),
598
+ wp.int16(-64),
599
+ wp.uint16(255),
600
+ wp.int32(-64),
601
+ wp.uint32(255),
602
+ wp.int64(-64),
603
+ wp.uint64(255),
604
+ wp.float32(3.14159),
605
+ wp.float64(3.14159),
606
+ ],
607
+ devices=devices,
608
+ )
609
+ add_kernel_test(
610
+ TestCTypes, name="test_vector_arg_types", kernel=test_vector_arg_types, dim=1, inputs=inputs, devices=devices
611
+ )
612
+ add_kernel_test(TestCTypes, name="test_type_convesrions", kernel=test_type_conversions, dim=1, devices=devices)
613
+
614
+ add_function_test(
615
+ TestCTypes, "test_scalar_array_load", test_scalar_array_types, devices=devices, load=True, store=False
616
+ )
617
+ add_function_test(
618
+ TestCTypes, "test_scalar_array_store", test_scalar_array_types, devices=devices, load=False, store=True
619
+ )
620
+ add_function_test(TestCTypes, "test_vec2_arg", test_vec2_arg, devices=devices, n=8)
621
+ add_function_test(TestCTypes, "test_vec2_transform", test_vec2_transform, devices=devices, n=8)
622
+ add_function_test(TestCTypes, "test_vec3_arg", test_vec3_arg, devices=devices, n=8)
623
+ add_function_test(TestCTypes, "test_vec3_transform", test_vec3_transform, devices=devices, n=8)
624
+ add_function_test(TestCTypes, "test_transform_multiply", test_transform_multiply, devices=devices, n=8)
625
+ add_kernel_test(TestCTypes, name="test_transform_matrix", kernel=test_transform_matrix, dim=1, devices=devices)
626
+ add_function_test(TestCTypes, "test_scalar_array", test_scalar_array, devices=devices)
627
+ add_function_test(TestCTypes, "test_vector_array", test_vector_array, devices=devices)
635
628
 
636
629
 
637
630
  if __name__ == "__main__":
638
631
  wp.build.clear_kernel_cache()
639
- _ = register(unittest.TestCase)
640
632
  unittest.main(verbosity=2)
warp/tests/test_dense.py CHANGED
@@ -1,11 +1,15 @@
1
- import numpy as np
2
- import math
3
-
4
- import warp as wp
5
- from warp.tests.test_base import *
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
6
7
 
7
8
  import unittest
8
9
 
10
+ import warp as wp
11
+ from warp.tests.unittest_utils import *
12
+
9
13
  wp.init()
10
14
 
11
15
 
@@ -42,20 +46,22 @@ def eval_dense_solve(
42
46
  wp.dense_solve(n, A, L, b, x)
43
47
 
44
48
 
45
- def register(parent):
46
- devices = get_test_devices()
47
-
48
- class TestDense(parent):
49
- pass
50
-
49
+ def test_dense_compilation(test, device):
51
50
  # just testing compilation of the dense matrix routines
52
51
  # most are deprecated / WIP
53
- wp.force_load()
52
+ wp.load_module(device=device)
53
+
54
+
55
+ devices = get_test_devices()
56
+
57
+
58
+ class TestDense(unittest.TestCase):
59
+ pass
60
+
54
61
 
55
- return TestDense
62
+ add_function_test(TestDense, "test_dense_compilation", test_dense_compilation, devices=devices)
56
63
 
57
64
 
58
65
  if __name__ == "__main__":
59
66
  wp.build.clear_kernel_cache()
60
- _ = register(unittest.TestCase)
61
67
  unittest.main(verbosity=2)
@@ -8,33 +8,31 @@
8
8
  import unittest
9
9
 
10
10
  import warp as wp
11
- from warp.tests.test_base import *
11
+ from warp.tests.unittest_utils import *
12
12
 
13
13
  wp.init()
14
14
 
15
15
 
16
- def test_devices_get_device_functions(test, device):
17
- # save default device
18
- saved_device = wp.get_device()
19
-
20
- test.assertTrue(saved_device.is_cuda)
16
+ def test_devices_get_cuda_device_functions(test, device):
17
+ test.assertTrue(device.is_cuda)
21
18
  test.assertTrue(wp.is_device_available(device))
22
19
 
23
20
  device_ordinal = device.ordinal
24
21
  current_device = wp.get_cuda_device(device_ordinal)
25
22
  test.assertEqual(current_device, device)
26
-
27
23
  current_device = wp.get_cuda_device() # No-ordinal version
28
24
  test.assertTrue(wp.is_device_available(current_device))
29
25
 
26
+ if device == current_device:
27
+ test.assertEqual(device, "cuda")
28
+ else:
29
+ test.assertNotEqual(device, "cuda")
30
+
30
31
  preferred_device = wp.get_preferred_device()
31
32
  test.assertTrue(wp.is_device_available(preferred_device))
32
33
 
33
- # restore default device
34
- wp.set_device(saved_device)
35
34
 
36
-
37
- def test_devices_map_device(test, device):
35
+ def test_devices_map_cuda_device(test, device):
38
36
  with wp.ScopedDevice(device):
39
37
  saved_alias = device.alias
40
38
  # Map alias twice to check code path
@@ -58,42 +56,43 @@ def test_devices_verify_cuda_device(test, device):
58
56
  wp.config.verify_cuda = verify_cuda_saved
59
57
 
60
58
 
59
+ @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
61
60
  def test_devices_can_access_self(test, device):
62
61
  test.assertTrue(device.can_access(device))
63
62
 
64
- # Also test CPU access
65
- cpu_device = wp.get_device("cpu")
66
- if device != cpu_device:
67
- test.assertFalse(device.can_access(cpu_device))
63
+ for warp_device in wp.get_devices():
64
+ device_str = str(warp_device)
65
+
66
+ if (device.is_cpu and warp_device.is_cuda) or (device.is_cuda and warp_device.is_cpu):
67
+ test.assertFalse(device.can_access(warp_device))
68
+ test.assertNotEqual(device, warp_device)
69
+ test.assertNotEqual(device, device_str)
68
70
 
69
- test.assertNotEqual(cpu_device, "cuda")
70
71
 
72
+ devices = get_test_devices()
71
73
 
72
- def register(parent):
73
- devices = get_test_devices()
74
74
 
75
- class TestDevices(parent):
76
- pass
75
+ class TestDevices(unittest.TestCase):
76
+ pass
77
77
 
78
- add_function_test(
79
- TestDevices,
80
- "test_devices_get_device_functions",
81
- test_devices_get_device_functions,
82
- devices=wp.get_cuda_devices(),
83
- )
84
- add_function_test(TestDevices, "test_devices_map_device", test_devices_map_device, devices=wp.get_cuda_devices())
85
- add_function_test(
86
- TestDevices, "test_devices_unmap_imaginary_device", test_devices_unmap_imaginary_device, devices=devices
87
- )
88
- add_function_test(TestDevices, "test_devices_verify_cuda_device", test_devices_verify_cuda_device, devices=devices)
89
78
 
90
- if wp.is_cuda_available():
91
- add_function_test(TestDevices, "test_devices_can_access_self", test_devices_can_access_self, devices=devices)
79
+ add_function_test(
80
+ TestDevices,
81
+ "test_devices_get_cuda_device_functions",
82
+ test_devices_get_cuda_device_functions,
83
+ devices=get_unique_cuda_test_devices(),
84
+ )
85
+ add_function_test(
86
+ TestDevices, "test_devices_map_cuda_device", test_devices_map_cuda_device, devices=get_unique_cuda_test_devices()
87
+ )
88
+ add_function_test(
89
+ TestDevices, "test_devices_unmap_imaginary_device", test_devices_unmap_imaginary_device, devices=devices
90
+ )
91
+ add_function_test(TestDevices, "test_devices_verify_cuda_device", test_devices_verify_cuda_device, devices=devices)
92
92
 
93
- return TestDevices
93
+ add_function_test(TestDevices, "test_devices_can_access_self", test_devices_can_access_self, devices=devices)
94
94
 
95
95
 
96
96
  if __name__ == "__main__":
97
97
  wp.build.clear_kernel_cache()
98
- _ = register(unittest.TestCase)
99
98
  unittest.main(verbosity=2)
warp/tests/test_dlpack.py CHANGED
@@ -5,13 +5,14 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- import numpy as np
9
- import unittest
10
- import os
11
8
  import ctypes
9
+ import os
10
+ import unittest
11
+
12
+ import numpy as np
12
13
 
13
14
  import warp as wp
14
- from warp.tests.test_base import *
15
+ from warp.tests.unittest_utils import *
15
16
 
16
17
  wp.init()
17
18
 
@@ -299,79 +300,77 @@ def test_dlpack_jax_to_warp(test, device):
299
300
  assert_np_equal(a2.numpy(), np.asarray(j))
300
301
 
301
302
 
302
- def register(parent):
303
- class TestDLPack(parent):
304
- pass
305
-
306
- devices = get_test_devices()
307
-
308
- add_function_test(TestDLPack, "test_dlpack_warp_to_warp", test_dlpack_warp_to_warp, devices=devices)
309
- add_function_test(TestDLPack, "test_dlpack_dtypes_and_shapes", test_dlpack_dtypes_and_shapes, devices=devices)
310
-
311
- # torch interop via dlpack
312
- try:
313
- import torch
314
- import torch.utils.dlpack
315
-
316
- # check which Warp devices work with Torch
317
- # CUDA devices may fail if Torch was not compiled with CUDA support
318
- test_devices = get_test_devices()
319
- torch_compatible_devices = []
320
- for d in test_devices:
321
- try:
322
- t = torch.arange(10, device=wp.device_to_torch(d))
323
- t += 1
324
- torch_compatible_devices.append(d)
325
- except Exception as e:
326
- print(f"Skipping Torch DLPack tests on device '{d}' due to exception: {e}")
327
-
328
- if torch_compatible_devices:
329
- add_function_test(
330
- TestDLPack, "test_dlpack_warp_to_torch", test_dlpack_warp_to_torch, devices=torch_compatible_devices
331
- )
332
- add_function_test(
333
- TestDLPack, "test_dlpack_torch_to_warp", test_dlpack_torch_to_warp, devices=torch_compatible_devices
334
- )
335
-
336
- except Exception as e:
337
- print(f"Skipping Torch DLPack tests due to exception: {e}")
338
-
339
- # jax interop via dlpack
340
- try:
341
- # prevent Jax from gobbling up GPU memory
342
- os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
343
-
344
- import jax
345
- import jax.dlpack
346
-
347
- # check which Warp devices work with Jax
348
- # CUDA devices may fail if Jax cannot find a CUDA Toolkit
349
- test_devices = get_test_devices()
350
- jax_compatible_devices = []
351
- for d in test_devices:
352
- try:
353
- with jax.default_device(wp.device_to_jax(d)):
354
- j = jax.numpy.arange(10, dtype=jax.numpy.float32)
355
- j += 1
356
- jax_compatible_devices.append(d)
357
- except Exception as e:
358
- print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
359
-
360
- if jax_compatible_devices:
361
- add_function_test(
362
- TestDLPack, "test_dlpack_warp_to_jax", test_dlpack_warp_to_jax, devices=jax_compatible_devices
363
- )
364
- add_function_test(
365
- TestDLPack, "test_dlpack_jax_to_warp", test_dlpack_jax_to_warp, devices=jax_compatible_devices
366
- )
367
-
368
- except Exception as e:
369
- print(f"Skipping Jax DLPack tests due to exception: {e}")
370
-
371
- return TestDLPack
303
+ class TestDLPack(unittest.TestCase):
304
+ pass
305
+
306
+
307
+ devices = get_test_devices()
308
+
309
+ add_function_test(TestDLPack, "test_dlpack_warp_to_warp", test_dlpack_warp_to_warp, devices=devices)
310
+ add_function_test(TestDLPack, "test_dlpack_dtypes_and_shapes", test_dlpack_dtypes_and_shapes, devices=devices)
311
+
312
+ # torch interop via dlpack
313
+ try:
314
+ import torch
315
+ import torch.utils.dlpack
316
+
317
+ # check which Warp devices work with Torch
318
+ # CUDA devices may fail if Torch was not compiled with CUDA support
319
+ test_devices = get_test_devices()
320
+ torch_compatible_devices = []
321
+ for d in test_devices:
322
+ try:
323
+ t = torch.arange(10, device=wp.device_to_torch(d))
324
+ t += 1
325
+ torch_compatible_devices.append(d)
326
+ except Exception as e:
327
+ print(f"Skipping Torch DLPack tests on device '{d}' due to exception: {e}")
328
+
329
+ if torch_compatible_devices:
330
+ add_function_test(
331
+ TestDLPack, "test_dlpack_warp_to_torch", test_dlpack_warp_to_torch, devices=torch_compatible_devices
332
+ )
333
+ add_function_test(
334
+ TestDLPack, "test_dlpack_torch_to_warp", test_dlpack_torch_to_warp, devices=torch_compatible_devices
335
+ )
336
+
337
+ except Exception as e:
338
+ print(f"Skipping Torch DLPack tests due to exception: {e}")
339
+
340
+ # jax interop via dlpack
341
+ try:
342
+ # prevent Jax from gobbling up GPU memory
343
+ os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
344
+ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
345
+
346
+ import jax
347
+ import jax.dlpack
348
+
349
+ # check which Warp devices work with Jax
350
+ # CUDA devices may fail if Jax cannot find a CUDA Toolkit
351
+ test_devices = get_test_devices()
352
+ jax_compatible_devices = []
353
+ for d in test_devices:
354
+ try:
355
+ with jax.default_device(wp.device_to_jax(d)):
356
+ j = jax.numpy.arange(10, dtype=jax.numpy.float32)
357
+ j += 1
358
+ jax_compatible_devices.append(d)
359
+ except Exception as e:
360
+ print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
361
+
362
+ if jax_compatible_devices:
363
+ add_function_test(
364
+ TestDLPack, "test_dlpack_warp_to_jax", test_dlpack_warp_to_jax, devices=jax_compatible_devices
365
+ )
366
+ add_function_test(
367
+ TestDLPack, "test_dlpack_jax_to_warp", test_dlpack_jax_to_warp, devices=jax_compatible_devices
368
+ )
369
+
370
+ except Exception as e:
371
+ print(f"Skipping Jax DLPack tests due to exception: {e}")
372
372
 
373
373
 
374
374
  if __name__ == "__main__":
375
375
  wp.build.clear_kernel_cache()
376
- _ = register(unittest.TestCase)
377
376
  unittest.main(verbosity=2)