warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,18 @@
1
- from functools import partial
1
+ # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
2
8
  import unittest
9
+ from functools import partial
3
10
 
4
11
  import numpy as np
5
- import warp as wp
6
12
 
13
+ import warp as wp
14
+ from warp.tests.unittest_utils import *
7
15
  from warp.utils import runlength_encode
8
- from warp.tests.test_base import *
9
16
 
10
17
  wp.init()
11
18
 
@@ -29,39 +36,10 @@ def test_runlength_encode_int(test, device, n):
29
36
  assert_np_equal(unique_counts.numpy()[:run_count], unique_counts_np[:run_count])
30
37
 
31
38
 
32
- def test_runlength_encode_error_devices_mismatch(test, device):
33
- values = wp.zeros(123, dtype=int, device="cpu")
34
- run_values = wp.empty_like(values, device="cuda:0")
35
- run_lengths = wp.empty_like(values, device="cuda:0")
36
- with test.assertRaisesRegex(
37
- RuntimeError,
38
- r"Array storage devices do not match$",
39
- ):
40
- runlength_encode(values, run_values, run_lengths)
41
-
42
- values = wp.zeros(123, dtype=int, device="cpu")
43
- run_values = wp.empty_like(values, device="cpu")
44
- run_lengths = wp.empty_like(values, device="cuda:0")
45
- with test.assertRaisesRegex(
46
- RuntimeError,
47
- r"Array storage devices do not match$",
48
- ):
49
- runlength_encode(values, run_values, run_lengths)
50
-
51
- values = wp.zeros(123, dtype=int, device="cpu")
52
- run_values = wp.empty_like(values, device="cuda:0")
53
- run_lengths = wp.empty_like(values, device="cpu")
54
- with test.assertRaisesRegex(
55
- RuntimeError,
56
- r"Array storage devices do not match$",
57
- ):
58
- runlength_encode(values, run_values, run_lengths)
59
-
60
-
61
39
  def test_runlength_encode_error_insufficient_storage(test, device):
62
- values = wp.zeros(123, dtype=int, device="cpu")
63
- run_values = wp.empty(1, dtype=int, device="cpu")
64
- run_lengths = wp.empty(123, dtype=int, device="cpu")
40
+ values = wp.zeros(123, dtype=int, device=device)
41
+ run_values = wp.empty(1, dtype=int, device=device)
42
+ run_lengths = wp.empty(123, dtype=int, device=device)
65
43
  with test.assertRaisesRegex(
66
44
  RuntimeError,
67
45
  r"Output array storage sizes must be at least equal to value_count$",
@@ -79,9 +57,9 @@ def test_runlength_encode_error_insufficient_storage(test, device):
79
57
 
80
58
 
81
59
  def test_runlength_encode_error_dtypes_mismatch(test, device):
82
- values = wp.zeros(123, dtype=int, device="cpu")
83
- run_values = wp.empty(123, dtype=float, device="cpu")
84
- run_lengths = wp.empty_like(values)
60
+ values = wp.zeros(123, dtype=int, device=device)
61
+ run_values = wp.empty(123, dtype=float, device=device)
62
+ run_lengths = wp.empty_like(values, device=device)
85
63
  with test.assertRaisesRegex(
86
64
  RuntimeError,
87
65
  r"values and run_values data types do not match$",
@@ -90,9 +68,9 @@ def test_runlength_encode_error_dtypes_mismatch(test, device):
90
68
 
91
69
 
92
70
  def test_runlength_encode_error_run_length_unsupported_dtype(test, device):
93
- values = wp.zeros(123, dtype=int, device="cpu")
94
- run_values = wp.empty(123, dtype=int, device="cpu")
95
- run_lengths = wp.empty(123, dtype=float, device="cpu")
71
+ values = wp.zeros(123, dtype=int, device=device)
72
+ run_values = wp.empty(123, dtype=int, device=device)
73
+ run_lengths = wp.empty(123, dtype=float, device=device)
96
74
  with test.assertRaisesRegex(
97
75
  RuntimeError,
98
76
  r"run_lengths array must be of type int32$",
@@ -100,23 +78,11 @@ def test_runlength_encode_error_run_length_unsupported_dtype(test, device):
100
78
  runlength_encode(values, run_values, run_lengths)
101
79
 
102
80
 
103
- def test_runlength_encode_error_run_count_device_mismatch(test, device):
104
- values = wp.zeros(123, dtype=int, device="cpu")
105
- run_values = wp.empty_like(values, device="cpu")
106
- run_lengths = wp.empty_like(values, device="cpu")
107
- run_count = wp.empty(shape=(1,), dtype=int, device="cuda:0")
108
- with test.assertRaisesRegex(
109
- RuntimeError,
110
- r"run_count storage device does not match other arrays$",
111
- ):
112
- runlength_encode(values, run_values, run_lengths, run_count=run_count)
113
-
114
-
115
81
  def test_runlength_encode_error_run_count_unsupported_dtype(test, device):
116
- values = wp.zeros(123, dtype=int, device="cpu")
117
- run_values = wp.empty_like(values, device="cpu")
118
- run_lengths = wp.empty_like(values, device="cpu")
119
- run_count = wp.empty(shape=(1,), dtype=float, device="cpu")
82
+ values = wp.zeros(123, dtype=int, device=device)
83
+ run_values = wp.empty_like(values, device=device)
84
+ run_lengths = wp.empty_like(values, device=device)
85
+ run_count = wp.empty(shape=(1,), dtype=float, device=device)
120
86
  with test.assertRaisesRegex(
121
87
  RuntimeError,
122
88
  r"run_count array must be of type int32$",
@@ -135,57 +101,90 @@ def test_runlength_encode_error_unsupported_dtype(test, device):
135
101
  runlength_encode(values, run_values, run_lengths)
136
102
 
137
103
 
138
- def register(parent):
139
- devices = get_test_devices()
140
-
141
- class TestRunlengthEncode(parent):
142
- pass
143
-
144
- add_function_test(
145
- TestRunlengthEncode, "test_runlength_encode_int", partial(test_runlength_encode_int, n=100), devices=devices
146
- )
147
- add_function_test(
148
- TestRunlengthEncode, "test_runlength_encode_empty", partial(test_runlength_encode_int, n=0), devices=devices
149
- )
150
- add_function_test(
151
- TestRunlengthEncode,
152
- "test_runlength_encode_error_devices_mismatch",
153
- test_runlength_encode_error_devices_mismatch,
154
- )
155
- add_function_test(
156
- TestRunlengthEncode,
157
- "test_runlength_encode_error_insufficient_storage",
158
- test_runlength_encode_error_insufficient_storage,
159
- )
160
- add_function_test(
161
- TestRunlengthEncode, "test_runlength_encode_error_dtypes_mismatch", test_runlength_encode_error_dtypes_mismatch
162
- )
163
- add_function_test(
164
- TestRunlengthEncode,
165
- "test_runlength_encode_error_run_length_unsupported_dtype",
166
- test_runlength_encode_error_run_length_unsupported_dtype,
167
- )
168
- add_function_test(
169
- TestRunlengthEncode,
170
- "test_runlength_encode_error_run_count_device_mismatch",
171
- test_runlength_encode_error_run_count_device_mismatch,
172
- )
173
- add_function_test(
174
- TestRunlengthEncode,
175
- "test_runlength_encode_error_run_count_unsupported_dtype",
176
- test_runlength_encode_error_run_count_unsupported_dtype,
177
- )
178
- add_function_test(
179
- TestRunlengthEncode,
180
- "test_runlength_encode_error_unsupported_dtype",
181
- test_runlength_encode_error_unsupported_dtype,
182
- devices=devices,
183
- )
184
-
185
- return TestRunlengthEncode
104
+ devices = get_test_devices()
105
+
106
+
107
+ class TestRunlengthEncode(unittest.TestCase):
108
+ @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
109
+ def test_runlength_encode_error_devices_mismatch(self):
110
+ values = wp.zeros(123, dtype=int, device="cpu")
111
+ run_values = wp.empty_like(values, device="cuda:0")
112
+ run_lengths = wp.empty_like(values, device="cuda:0")
113
+ with self.assertRaisesRegex(
114
+ RuntimeError,
115
+ r"Array storage devices do not match$",
116
+ ):
117
+ runlength_encode(values, run_values, run_lengths)
118
+
119
+ values = wp.zeros(123, dtype=int, device="cpu")
120
+ run_values = wp.empty_like(values, device="cpu")
121
+ run_lengths = wp.empty_like(values, device="cuda:0")
122
+ with self.assertRaisesRegex(
123
+ RuntimeError,
124
+ r"Array storage devices do not match$",
125
+ ):
126
+ runlength_encode(values, run_values, run_lengths)
127
+
128
+ values = wp.zeros(123, dtype=int, device="cpu")
129
+ run_values = wp.empty_like(values, device="cuda:0")
130
+ run_lengths = wp.empty_like(values, device="cpu")
131
+ with self.assertRaisesRegex(
132
+ RuntimeError,
133
+ r"Array storage devices do not match$",
134
+ ):
135
+ runlength_encode(values, run_values, run_lengths)
136
+
137
+ @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
138
+ def test_runlength_encode_error_run_count_device_mismatch(self):
139
+ values = wp.zeros(123, dtype=int, device="cpu")
140
+ run_values = wp.empty_like(values, device="cpu")
141
+ run_lengths = wp.empty_like(values, device="cpu")
142
+ run_count = wp.empty(shape=(1,), dtype=int, device="cuda:0")
143
+ with self.assertRaisesRegex(
144
+ RuntimeError,
145
+ r"run_count storage device does not match other arrays$",
146
+ ):
147
+ runlength_encode(values, run_values, run_lengths, run_count=run_count)
148
+
149
+
150
+ add_function_test(
151
+ TestRunlengthEncode, "test_runlength_encode_int", partial(test_runlength_encode_int, n=100), devices=devices
152
+ )
153
+ add_function_test(
154
+ TestRunlengthEncode, "test_runlength_encode_empty", partial(test_runlength_encode_int, n=0), devices=devices
155
+ )
156
+ add_function_test(
157
+ TestRunlengthEncode,
158
+ "test_runlength_encode_error_insufficient_storage",
159
+ test_runlength_encode_error_insufficient_storage,
160
+ devices=devices,
161
+ )
162
+ add_function_test(
163
+ TestRunlengthEncode,
164
+ "test_runlength_encode_error_dtypes_mismatch",
165
+ test_runlength_encode_error_dtypes_mismatch,
166
+ devices=devices,
167
+ )
168
+ add_function_test(
169
+ TestRunlengthEncode,
170
+ "test_runlength_encode_error_run_length_unsupported_dtype",
171
+ test_runlength_encode_error_run_length_unsupported_dtype,
172
+ devices=devices,
173
+ )
174
+ add_function_test(
175
+ TestRunlengthEncode,
176
+ "test_runlength_encode_error_run_count_unsupported_dtype",
177
+ test_runlength_encode_error_run_count_unsupported_dtype,
178
+ devices=devices,
179
+ )
180
+ add_function_test(
181
+ TestRunlengthEncode,
182
+ "test_runlength_encode_error_unsupported_dtype",
183
+ test_runlength_encode_error_unsupported_dtype,
184
+ devices=devices,
185
+ )
186
186
 
187
187
 
188
188
  if __name__ == "__main__":
189
189
  wp.build.clear_kernel_cache()
190
- _ = register(unittest.TestCase)
191
190
  unittest.main(verbosity=2)
@@ -5,14 +5,14 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import unittest
8
9
  from dataclasses import dataclass
9
10
  from typing import Any
10
- import unittest
11
11
 
12
12
  import numpy as np
13
13
 
14
14
  import warp as wp
15
- from warp.tests.test_base import *
15
+ from warp.tests.unittest_utils import *
16
16
 
17
17
 
18
18
  @dataclass
@@ -153,17 +153,16 @@ def test_smoothstep(test, device):
153
153
  )
154
154
 
155
155
 
156
- def register(parent):
157
- devices = get_test_devices()
156
+ devices = get_test_devices()
157
+
158
+
159
+ class TestSmoothstep(unittest.TestCase):
160
+ pass
158
161
 
159
- class TestSmoothstep(parent):
160
- pass
161
162
 
162
- add_function_test(TestSmoothstep, "test_smoothstep", test_smoothstep, devices=devices)
163
- return TestSmoothstep
163
+ add_function_test(TestSmoothstep, "test_smoothstep", test_smoothstep, devices=devices)
164
164
 
165
165
 
166
166
  if __name__ == "__main__":
167
167
  wp.build.clear_kernel_cache()
168
- _ = register(unittest.TestCase)
169
168
  unittest.main(verbosity=2)
@@ -1,8 +1,10 @@
1
- import warp as wp
2
- from warp.tests.test_base import *
3
- import numpy as np
4
1
  import unittest
5
2
 
3
+ import numpy as np
4
+
5
+ import warp as wp
6
+ from warp.tests.unittest_utils import *
7
+
6
8
  wp.init()
7
9
 
8
10
 
@@ -101,11 +103,11 @@ def test_shared_memory(test, device):
101
103
 
102
104
 
103
105
  def test_cpu_snippet(test, device):
104
-
105
106
  snippet = """
106
107
  int inc = 1;
107
108
  out[tid] = x[tid] + inc;
108
109
  """
110
+
109
111
  @wp.func_native(snippet)
110
112
  def increment_snippet(
111
113
  x: wp.array(dtype=wp.int32),
@@ -115,10 +117,7 @@ def test_cpu_snippet(test, device):
115
117
  ...
116
118
 
117
119
  @wp.kernel
118
- def increment(
119
- x: wp.array(dtype=wp.int32),
120
- out: wp.array(dtype=wp.int32)
121
- ):
120
+ def increment(x: wp.array(dtype=wp.int32), out: wp.array(dtype=wp.int32)):
122
121
  tid = wp.tid()
123
122
  increment_snippet(x, out, tid)
124
123
 
@@ -128,25 +127,17 @@ def test_cpu_snippet(test, device):
128
127
 
129
128
  wp.launch(kernel=increment, dim=N, inputs=[x], outputs=[out], device=device)
130
129
 
131
- assert_np_equal(out.numpy(), np.arange(1, N+1, 1, dtype=np.int32))
130
+ assert_np_equal(out.numpy(), np.arange(1, N + 1, 1, dtype=np.int32))
132
131
 
133
132
 
134
- def register(parent):
135
-
136
- class TestSnippets(parent):
137
- pass
133
+ class TestSnippets(unittest.TestCase):
134
+ pass
138
135
 
139
- if wp.is_cuda_available():
140
- cuda_device = [wp.get_cuda_device()]
141
- add_function_test(TestSnippets, "test_basic", test_basic, devices=cuda_device)
142
- add_function_test(TestSnippets, "test_shared_memory", test_shared_memory, devices=cuda_device)
143
-
144
- if wp.is_cpu_available():
145
- add_function_test(TestSnippets, "test_cpu_snippet", test_cpu_snippet, devices=["cpu"])
146
136
 
147
- return TestSnippets
137
+ add_function_test(TestSnippets, "test_basic", test_basic, devices=get_unique_cuda_test_devices())
138
+ add_function_test(TestSnippets, "test_shared_memory", test_shared_memory, devices=get_unique_cuda_test_devices())
139
+ add_function_test(TestSnippets, "test_cpu_snippet", test_cpu_snippet, devices=["cpu"])
148
140
 
149
141
 
150
142
  if __name__ == "__main__":
151
- c = register(unittest.TestCase)
152
143
  unittest.main(verbosity=2)
warp/tests/test_sparse.py CHANGED
@@ -1,13 +1,20 @@
1
+ # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
1
8
  import unittest
2
9
 
3
10
  import numpy as np
4
- import warp as wp
5
- import unittest
6
11
 
12
+ import warp as wp
7
13
  from warp.sparse import bsr_zeros, bsr_set_from_triplets, bsr_get_diag, bsr_diag, bsr_identity, bsr_copy, bsr_scale
8
14
  from warp.sparse import bsr_set_transpose, bsr_transposed
9
15
  from warp.sparse import bsr_axpy, bsr_mm, bsr_axpy_work_arrays, bsr_mm_work_arrays, bsr_mv
10
- from warp.tests.test_base import *
16
+ from warp.tests.unittest_utils import *
17
+
11
18
 
12
19
  wp.init()
13
20
 
@@ -419,41 +426,35 @@ def make_test_bsr_mv(block_shape, scalar_type):
419
426
  return test_bsr_mv
420
427
 
421
428
 
422
- def register(parent):
423
- devices = get_test_devices()
429
+ devices = get_test_devices()
424
430
 
425
- class TestSparse(parent):
426
- pass
427
431
 
428
- add_function_test(TestSparse, "test_csr_from_triplets", test_csr_from_triplets, devices=devices)
429
- add_function_test(TestSparse, "test_bsr_from_triplets", test_bsr_from_triplets, devices=devices)
430
- add_function_test(TestSparse, "test_bsr_get_diag", test_bsr_get_set_diag, devices=devices)
431
- add_function_test(TestSparse, "test_bsr_copy_scale", test_bsr_copy_scale, devices=devices)
432
+ class TestSparse(unittest.TestCase):
433
+ pass
432
434
 
433
- add_function_test(TestSparse, "test_csr_transpose", make_test_bsr_transpose((1, 1), wp.float32), devices=devices)
434
- add_function_test(
435
- TestSparse, "test_bsr_transpose_1_3", make_test_bsr_transpose((1, 3), wp.float32), devices=devices
436
- )
437
- add_function_test(
438
- TestSparse, "test_bsr_transpose_3_3", make_test_bsr_transpose((3, 3), wp.float64), devices=devices
439
- )
440
435
 
441
- add_function_test(TestSparse, "test_csr_axpy", make_test_bsr_axpy((1, 1), wp.float32), devices=devices)
442
- add_function_test(TestSparse, "test_bsr_axpy_1_3", make_test_bsr_axpy((1, 3), wp.float32), devices=devices)
443
- add_function_test(TestSparse, "test_bsr_axpy_3_3", make_test_bsr_axpy((3, 3), wp.float64), devices=devices)
436
+ add_function_test(TestSparse, "test_csr_from_triplets", test_csr_from_triplets, devices=devices)
437
+ add_function_test(TestSparse, "test_bsr_from_triplets", test_bsr_from_triplets, devices=devices)
438
+ add_function_test(TestSparse, "test_bsr_get_diag", test_bsr_get_set_diag, devices=devices)
439
+ add_function_test(TestSparse, "test_bsr_copy_scale", test_bsr_copy_scale, devices=devices)
440
+
441
+ add_function_test(TestSparse, "test_csr_transpose", make_test_bsr_transpose((1, 1), wp.float32), devices=devices)
442
+ add_function_test(TestSparse, "test_bsr_transpose_1_3", make_test_bsr_transpose((1, 3), wp.float32), devices=devices)
443
+ add_function_test(TestSparse, "test_bsr_transpose_3_3", make_test_bsr_transpose((3, 3), wp.float64), devices=devices)
444
444
 
445
- add_function_test(TestSparse, "test_csr_mm", make_test_bsr_mm((1, 1), wp.float32), devices=devices)
446
- add_function_test(TestSparse, "test_bsr_mm_1_3", make_test_bsr_mm((1, 3), wp.float32), devices=devices)
447
- add_function_test(TestSparse, "test_bsr_mm_3_3", make_test_bsr_mm((3, 3), wp.float64), devices=devices)
445
+ add_function_test(TestSparse, "test_csr_axpy", make_test_bsr_axpy((1, 1), wp.float32), devices=devices)
446
+ add_function_test(TestSparse, "test_bsr_axpy_1_3", make_test_bsr_axpy((1, 3), wp.float32), devices=devices)
447
+ add_function_test(TestSparse, "test_bsr_axpy_3_3", make_test_bsr_axpy((3, 3), wp.float64), devices=devices)
448
448
 
449
- add_function_test(TestSparse, "test_csr_mv", make_test_bsr_mv((1, 1), wp.float32), devices=devices)
450
- add_function_test(TestSparse, "test_bsr_mv_1_3", make_test_bsr_mv((1, 3), wp.float32), devices=devices)
451
- add_function_test(TestSparse, "test_bsr_mv_3_3", make_test_bsr_mv((3, 3), wp.float64), devices=devices)
449
+ add_function_test(TestSparse, "test_csr_mm", make_test_bsr_mm((1, 1), wp.float32), devices=devices)
450
+ add_function_test(TestSparse, "test_bsr_mm_1_3", make_test_bsr_mm((1, 3), wp.float32), devices=devices)
451
+ add_function_test(TestSparse, "test_bsr_mm_3_3", make_test_bsr_mm((3, 3), wp.float64), devices=devices)
452
452
 
453
- return TestSparse
453
+ add_function_test(TestSparse, "test_csr_mv", make_test_bsr_mv((1, 1), wp.float32), devices=devices)
454
+ add_function_test(TestSparse, "test_bsr_mv_1_3", make_test_bsr_mv((1, 3), wp.float32), devices=devices)
455
+ add_function_test(TestSparse, "test_bsr_mv_3_3", make_test_bsr_mv((3, 3), wp.float64), devices=devices)
454
456
 
455
457
 
456
458
  if __name__ == "__main__":
457
459
  wp.build.clear_kernel_cache()
458
- _ = register(unittest.TestCase)
459
460
  unittest.main(verbosity=2)