warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/tests/test_tape.py CHANGED
@@ -5,9 +5,12 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import unittest
9
+
8
10
  import numpy as np
11
+
9
12
  import warp as wp
10
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
11
14
 
12
15
  wp.init()
13
16
 
@@ -19,11 +22,17 @@ def mul_constant(x: wp.array(dtype=float), y: wp.array(dtype=float)):
19
22
  y[tid] = x[tid] * 2.0
20
23
 
21
24
 
25
+ @wp.struct
26
+ class Multiplicands:
27
+ x: wp.array(dtype=float)
28
+ y: wp.array(dtype=float)
29
+
30
+
22
31
  @wp.kernel
23
- def mul_variable(x: wp.array(dtype=float), y: wp.array(dtype=float), z: wp.array(dtype=float)):
32
+ def mul_variable(mutiplicands: Multiplicands, z: wp.array(dtype=float)):
24
33
  tid = wp.tid()
25
34
 
26
- z[tid] = x[tid] * y[tid]
35
+ z[tid] = mutiplicands.x[tid] * mutiplicands.y[tid]
27
36
 
28
37
 
29
38
  @wp.kernel
@@ -65,12 +74,13 @@ def test_tape_mul_variable(test, device):
65
74
 
66
75
  # record onto tape
67
76
  with tape:
68
- # input data
69
- x = wp.array(np.ones(dim) * 16.0, dtype=wp.float32, device=device, requires_grad=True)
70
- y = wp.array(np.ones(dim) * 32.0, dtype=wp.float32, device=device, requires_grad=True)
71
- z = wp.zeros_like(x)
77
+ # input data (Note: We're intentionally testing structs in tapes here)
78
+ multiplicands = Multiplicands()
79
+ multiplicands.x = wp.array(np.ones(dim) * 16.0, dtype=wp.float32, device=device, requires_grad=True)
80
+ multiplicands.y = wp.array(np.ones(dim) * 32.0, dtype=wp.float32, device=device, requires_grad=True)
81
+ z = wp.zeros_like(multiplicands.x)
72
82
 
73
- wp.launch(kernel=mul_variable, dim=dim, inputs=[x, y], outputs=[z], device=device)
83
+ wp.launch(kernel=mul_variable, dim=dim, inputs=[multiplicands], outputs=[z], device=device)
74
84
 
75
85
  # loss = wp.sum(x)
76
86
  z.grad = wp.array(np.ones(dim), device=device, dtype=wp.float32)
@@ -79,16 +89,21 @@ def test_tape_mul_variable(test, device):
79
89
  tape.backward()
80
90
 
81
91
  # grad_x=y, grad_y=x
82
- assert_np_equal(tape.gradients[x].numpy(), y.numpy())
83
- assert_np_equal(tape.gradients[y].numpy(), x.numpy())
92
+ assert_np_equal(tape.gradients[multiplicands].x.numpy(), multiplicands.y.numpy())
93
+ assert_np_equal(tape.gradients[multiplicands].y.numpy(), multiplicands.x.numpy())
84
94
 
85
95
  # run backward again with different incoming gradient
86
96
  # should accumulate the same gradients again onto output
87
97
  # so gradients = 2.0*prev
88
98
  tape.backward()
89
99
 
90
- assert_np_equal(tape.gradients[x].numpy(), y.numpy() * 2.0)
91
- assert_np_equal(tape.gradients[y].numpy(), x.numpy() * 2.0)
100
+ assert_np_equal(tape.gradients[multiplicands].x.numpy(), multiplicands.y.numpy() * 2.0)
101
+ assert_np_equal(tape.gradients[multiplicands].y.numpy(), multiplicands.x.numpy() * 2.0)
102
+
103
+ # Clear launches and zero out the gradients
104
+ tape.reset()
105
+ assert_np_equal(tape.gradients[multiplicands].x.numpy(), np.zeros_like(tape.gradients[multiplicands].x.numpy()))
106
+ test.assertFalse(tape.launches)
92
107
 
93
108
 
94
109
  def test_tape_dot_product(test, device):
@@ -112,19 +127,22 @@ def test_tape_dot_product(test, device):
112
127
  assert_np_equal(tape.gradients[y].numpy(), x.numpy())
113
128
 
114
129
 
115
- def register(parent):
116
- devices = get_test_devices()
130
+ devices = get_test_devices()
131
+
117
132
 
118
- class TestTape(parent):
119
- pass
133
+ class TestTape(unittest.TestCase):
134
+ def test_tape_no_nested_tapes(self):
135
+ with self.assertRaises(RuntimeError):
136
+ with wp.Tape():
137
+ with wp.Tape():
138
+ pass
120
139
 
121
- add_function_test(TestTape, "test_tape_mul_constant", test_tape_mul_constant, devices=devices)
122
- add_function_test(TestTape, "test_tape_mul_variable", test_tape_mul_variable, devices=devices)
123
- add_function_test(TestTape, "test_tape_dot_product", test_tape_dot_product, devices=devices)
124
140
 
125
- return TestTape
141
+ add_function_test(TestTape, "test_tape_mul_constant", test_tape_mul_constant, devices=devices)
142
+ add_function_test(TestTape, "test_tape_mul_variable", test_tape_mul_variable, devices=devices)
143
+ add_function_test(TestTape, "test_tape_dot_product", test_tape_dot_product, devices=devices)
126
144
 
127
145
 
128
146
  if __name__ == "__main__":
129
- c = register(unittest.TestCase)
147
+ wp.build.clear_kernel_cache()
130
148
  unittest.main(verbosity=2)
warp/tests/test_torch.py CHANGED
@@ -7,11 +7,10 @@
7
7
 
8
8
  import unittest
9
9
 
10
- # include parent path
11
10
  import numpy as np
12
11
 
13
12
  import warp as wp
14
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
15
14
 
16
15
  wp.init()
17
16
 
@@ -470,6 +469,8 @@ def test_torch_autograd(test, device):
470
469
  def test_torch_graph_torch_stream(test, device):
471
470
  """Capture Torch graph on Torch stream"""
472
471
 
472
+ wp.load_module(device=device)
473
+
473
474
  import torch
474
475
 
475
476
  torch_device = wp.device_to_torch(device)
@@ -551,12 +552,14 @@ def test_warp_graph_warp_stream(test, device):
551
552
 
552
553
  # capture graph
553
554
  with wp.ScopedDevice(device), torch.cuda.stream(torch_stream):
554
- wp.capture_begin()
555
- t += 1.0
556
- wp.launch(inc, dim=n, inputs=[a])
557
- t += 1.0
558
- wp.launch(inc, dim=n, inputs=[a])
559
- g = wp.capture_end()
555
+ wp.capture_begin(force_module_load=False)
556
+ try:
557
+ t += 1.0
558
+ wp.launch(inc, dim=n, inputs=[a])
559
+ t += 1.0
560
+ wp.launch(inc, dim=n, inputs=[a])
561
+ finally:
562
+ g = wp.capture_end()
560
563
 
561
564
  # replay graph
562
565
  num_iters = 10
@@ -570,6 +573,8 @@ def test_warp_graph_warp_stream(test, device):
570
573
  def test_warp_graph_torch_stream(test, device):
571
574
  """Capture Warp graph on Torch stream"""
572
575
 
576
+ wp.load_module(device=device)
577
+
573
578
  import torch
574
579
 
575
580
  torch_device = wp.device_to_torch(device)
@@ -587,12 +592,14 @@ def test_warp_graph_torch_stream(test, device):
587
592
 
588
593
  # capture graph
589
594
  with wp.ScopedStream(warp_stream), torch.cuda.stream(torch_stream):
590
- wp.capture_begin()
591
- t += 1.0
592
- wp.launch(inc, dim=n, inputs=[a])
593
- t += 1.0
594
- wp.launch(inc, dim=n, inputs=[a])
595
- g = wp.capture_end()
595
+ wp.capture_begin(force_module_load=False)
596
+ try:
597
+ t += 1.0
598
+ wp.launch(inc, dim=n, inputs=[a])
599
+ t += 1.0
600
+ wp.launch(inc, dim=n, inputs=[a])
601
+ finally:
602
+ g = wp.capture_end()
596
603
 
597
604
  # replay graph
598
605
  num_iters = 10
@@ -603,82 +610,79 @@ def test_warp_graph_torch_stream(test, device):
603
610
  assert passed.item()
604
611
 
605
612
 
606
- def register(parent):
607
- class TestTorch(parent):
608
- pass
609
-
610
- try:
611
- import torch
612
-
613
- # check which Warp devices work with Torch
614
- # CUDA devices may fail if Torch was not compiled with CUDA support
615
- test_devices = get_test_devices()
616
- torch_compatible_devices = []
617
- torch_compatible_cuda_devices = []
618
-
619
- for d in test_devices:
620
- try:
621
- t = torch.arange(10, device=wp.device_to_torch(d))
622
- t += 1
623
- torch_compatible_devices.append(d)
624
- if d.is_cuda:
625
- torch_compatible_cuda_devices.append(d)
626
- except Exception as e:
627
- print(f"Skipping Torch tests on device '{d}' due to exception: {e}")
628
-
629
- if torch_compatible_devices:
630
- add_function_test(TestTorch, "test_from_torch", test_from_torch, devices=torch_compatible_devices)
631
- add_function_test(
632
- TestTorch, "test_from_torch_slices", test_from_torch_slices, devices=torch_compatible_devices
633
- )
634
- add_function_test(
635
- TestTorch,
636
- "test_from_torch_zero_strides",
637
- test_from_torch_zero_strides,
638
- devices=torch_compatible_devices,
639
- )
640
- add_function_test(TestTorch, "test_to_torch", test_to_torch, devices=torch_compatible_devices)
641
- add_function_test(TestTorch, "test_torch_zerocopy", test_torch_zerocopy, devices=torch_compatible_devices)
642
- add_function_test(TestTorch, "test_torch_autograd", test_torch_autograd, devices=torch_compatible_devices)
643
-
644
- if torch_compatible_cuda_devices:
645
- add_function_test(
646
- TestTorch,
647
- "test_torch_graph_torch_stream",
648
- test_torch_graph_torch_stream,
649
- devices=torch_compatible_cuda_devices,
650
- )
651
- add_function_test(
652
- TestTorch,
653
- "test_torch_graph_warp_stream",
654
- test_torch_graph_warp_stream,
655
- devices=torch_compatible_cuda_devices,
656
- )
657
- add_function_test(
658
- TestTorch,
659
- "test_warp_graph_warp_stream",
660
- test_warp_graph_warp_stream,
661
- devices=torch_compatible_cuda_devices,
662
- )
663
- add_function_test(
664
- TestTorch,
665
- "test_warp_graph_torch_stream",
666
- test_warp_graph_torch_stream,
667
- devices=torch_compatible_cuda_devices,
668
- )
613
+ class TestTorch(unittest.TestCase):
614
+ pass
669
615
 
670
- # multi-GPU tests
671
- if len(torch_compatible_cuda_devices) > 1:
672
- add_function_test(TestTorch, "test_torch_mgpu_from_torch", test_torch_mgpu_from_torch)
673
- add_function_test(TestTorch, "test_torch_mgpu_to_torch", test_torch_mgpu_to_torch)
674
- add_function_test(TestTorch, "test_torch_mgpu_interop", test_torch_mgpu_interop)
675
616
 
676
- except Exception as e:
677
- print(f"Skipping Torch tests due to exception: {e}")
617
+ test_devices = get_test_devices()
618
+
619
+ try:
620
+ import torch
678
621
 
679
- return TestTorch
622
+ # check which Warp devices work with Torch
623
+ # CUDA devices may fail if Torch was not compiled with CUDA support
624
+ torch_compatible_devices = []
625
+ torch_compatible_cuda_devices = []
626
+
627
+ for d in test_devices:
628
+ try:
629
+ t = torch.arange(10, device=wp.device_to_torch(d))
630
+ t += 1
631
+ torch_compatible_devices.append(d)
632
+ if d.is_cuda:
633
+ torch_compatible_cuda_devices.append(d)
634
+ except Exception as e:
635
+ print(f"Skipping Torch tests on device '{d}' due to exception: {e}")
636
+
637
+ if torch_compatible_devices:
638
+ add_function_test(TestTorch, "test_from_torch", test_from_torch, devices=torch_compatible_devices)
639
+ add_function_test(TestTorch, "test_from_torch_slices", test_from_torch_slices, devices=torch_compatible_devices)
640
+ add_function_test(
641
+ TestTorch,
642
+ "test_from_torch_zero_strides",
643
+ test_from_torch_zero_strides,
644
+ devices=torch_compatible_devices,
645
+ )
646
+ add_function_test(TestTorch, "test_to_torch", test_to_torch, devices=torch_compatible_devices)
647
+ add_function_test(TestTorch, "test_torch_zerocopy", test_torch_zerocopy, devices=torch_compatible_devices)
648
+ add_function_test(TestTorch, "test_torch_autograd", test_torch_autograd, devices=torch_compatible_devices)
649
+
650
+ if torch_compatible_cuda_devices:
651
+ add_function_test(
652
+ TestTorch,
653
+ "test_torch_graph_torch_stream",
654
+ test_torch_graph_torch_stream,
655
+ devices=torch_compatible_cuda_devices,
656
+ )
657
+ add_function_test(
658
+ TestTorch,
659
+ "test_torch_graph_warp_stream",
660
+ test_torch_graph_warp_stream,
661
+ devices=torch_compatible_cuda_devices,
662
+ )
663
+ add_function_test(
664
+ TestTorch,
665
+ "test_warp_graph_warp_stream",
666
+ test_warp_graph_warp_stream,
667
+ devices=torch_compatible_cuda_devices,
668
+ )
669
+ add_function_test(
670
+ TestTorch,
671
+ "test_warp_graph_torch_stream",
672
+ test_warp_graph_torch_stream,
673
+ devices=torch_compatible_cuda_devices,
674
+ )
675
+
676
+ # multi-GPU tests
677
+ if len(torch_compatible_cuda_devices) > 1:
678
+ add_function_test(TestTorch, "test_torch_mgpu_from_torch", test_torch_mgpu_from_torch)
679
+ add_function_test(TestTorch, "test_torch_mgpu_to_torch", test_torch_mgpu_to_torch)
680
+ add_function_test(TestTorch, "test_torch_mgpu_interop", test_torch_mgpu_interop)
681
+
682
+ except Exception as e:
683
+ print(f"Skipping Torch tests due to exception: {e}")
680
684
 
681
685
 
682
686
  if __name__ == "__main__":
683
- c = register(unittest.TestCase)
687
+ wp.build.clear_kernel_cache()
684
688
  unittest.main(verbosity=2)
warp/tests/test_transient_module.py CHANGED
@@ -5,14 +5,13 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- import importlib
9
8
  import os
10
9
  import tempfile
11
10
  import unittest
11
+ from importlib import util
12
12
 
13
13
  import warp as wp
14
- from warp.tests.test_base import *
15
- from importlib import util
14
+ from warp.tests.unittest_utils import *
16
15
 
17
16
  CODE = """# -*- coding: utf-8 -*-
18
17
 
@@ -64,26 +63,25 @@ def test_transient_module(test, device):
64
63
  assert len(module.compute.module.functions) == 1
65
64
 
66
65
  data = module.Data()
67
- data.x = wp.array([123], dtype=int)
66
+ data.x = wp.array([123], dtype=int, device=device)
68
67
 
69
68
  wp.set_module_options({"foo": "bar"}, module=module)
70
69
  assert wp.get_module_options(module=module).get("foo") == "bar"
71
70
  assert module.compute.module.options.get("foo") == "bar"
72
71
 
73
- wp.launch(module.compute, dim=1, inputs=[data])
72
+ wp.launch(module.compute, dim=1, inputs=[data], device=device)
74
73
  assert_np_equal(data.x.numpy(), np.array([124]))
75
74
 
76
75
 
77
- def register(parent):
78
- devices = get_test_devices()
76
+ devices = get_test_devices()
77
+
79
78
 
80
- class TestTransientModule(parent):
81
- pass
79
+ class TestTransientModule(unittest.TestCase):
80
+ pass
82
81
 
83
- add_function_test(TestTransientModule, "test_transient_module", test_transient_module, devices=devices)
84
- return TestTransientModule
85
82
 
83
+ add_function_test(TestTransientModule, "test_transient_module", test_transient_module, devices=devices)
86
84
 
87
85
  if __name__ == "__main__":
88
- _ = register(unittest.TestCase)
86
+ wp.build.clear_kernel_cache()
89
87
  unittest.main(verbosity=2)