warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +882 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1435 -379
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +3 -3
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +521 -250
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +18 -17
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +578 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
@@ -47,7 +47,7 @@ except ImportError:
47
47
 
48
48
 
49
49
  # The following variables are NVIDIA Modifications
50
- START_DIRECTORY = os.path.dirname(__file__) # The directory to start test discovery
50
+ START_DIRECTORY = os.path.join(os.path.dirname(__file__), "..") # The directory to start test discovery
51
51
 
52
52
 
53
53
  def main(argv=None):
@@ -275,7 +275,14 @@ def main(argv=None):
275
275
  parallel_failed = True
276
276
 
277
277
  # Fallback to isolated single-process execution if parallel failed
278
- if parallel_failed:
278
+ # Skip fallback in CI/CD environments to respect job timeouts
279
+ in_ci = os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS") or os.environ.get("GITLAB_CI")
280
+ if parallel_failed and in_ci:
281
+ parser.exit(
282
+ status=1,
283
+ message="Error: Parallel execution failed in CI/CD environment. Skipping single-process fallback due to job timeout constraints.\n",
284
+ )
285
+ elif parallel_failed:
279
286
  print("Running all tests in isolated single-process mode...", file=sys.stderr)
280
287
  # Run all test suites in isolated single-process pools
281
288
  results = []
warp/_src/torch.py ADDED
@@ -0,0 +1,393 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ctypes
17
+
18
+ import numpy
19
+
20
+ import warp
21
+ import warp._src.context
22
+
23
+ _wp_module_name_ = "warp.torch"
24
+
25
+
26
# return the warp device corresponding to a torch device
def device_from_torch(torch_device) -> warp._src.context.Device:
    """Return the Warp device corresponding to a Torch device.

    Args:
        torch_device (`torch.device` or `str`): Torch device identifier. A CUDA
            device given without an explicit index (``"cuda"`` or
            ``torch.device("cuda")``) resolves to the current Warp CUDA device.

    Raises:
        RuntimeError: Torch device does not have a corresponding Warp device
        ValueError: ``torch_device`` is neither a string nor a ``torch.device``.
    """
    if isinstance(torch_device, str):
        warp_device = warp._src.context.runtime.device_map.get(torch_device)
        if warp_device is not None:
            return warp_device
        elif torch_device == "cuda":
            return warp._src.context.runtime.get_current_cuda_device()
        else:
            raise RuntimeError(f"Unsupported Torch device {torch_device}")
    else:
        try:
            if torch_device.type == "cuda":
                if torch_device.index is None:
                    # torch.device("cuda") carries no ordinal; resolve it like the "cuda" string
                    # instead of indexing the device list with None
                    return warp._src.context.runtime.get_current_cuda_device()
                return warp._src.context.runtime.cuda_devices[torch_device.index]
            elif torch_device.type == "cpu":
                return warp._src.context.runtime.cpu_device
            else:
                raise RuntimeError(f"Unsupported Torch device type {torch_device.type}")
        except Exception as e:
            # torch is imported lazily so this module does not require it at import time.
            # Non-torch.device arguments are reported as a usage error; errors for real
            # torch.device inputs propagate unchanged.
            import torch

            if not isinstance(torch_device, torch.device):
                raise ValueError("Argument must be a torch.device object or a string") from e
            raise
58
+
59
+
60
def device_to_torch(warp_device: warp._src.context.Devicelike) -> str:
    """Map a Warp device to the device string PyTorch understands.

    Args:
        warp_device: Anything :func:`warp.get_device` can resolve to a
            :class:`warp._src.context.Device`.

    Returns:
        str: ``str(device)`` for CPU or primary-context devices, otherwise
        ``"cuda:<ordinal>"`` for non-primary CUDA devices that support UVA.

    Raises:
        RuntimeError: The Warp device is not compatible with PyTorch.
    """
    resolved = warp.get_device(warp_device)

    if not (resolved.is_cpu or resolved.is_primary):
        if not (resolved.is_cuda and resolved.is_uva):
            raise RuntimeError(f"Warp device {resolved} is not compatible with torch")
        # Non-primary context: Torch can still dereference the pointers thanks to UVA.
        return f"cuda:{resolved.ordinal}"

    return str(resolved)
76
+
77
+
78
def dtype_to_torch(warp_dtype):
    """Look up the ``torch.dtype`` that corresponds to a Warp scalar dtype.

    Args:
        warp_dtype: A Warp data type that has a corresponding ``torch.dtype``.
            ``warp.uint16``, ``warp.uint32``, and ``warp.uint64`` are mapped
            to the signed integer ``torch.dtype`` of the same width.

    Raises:
        TypeError: Unable to find a corresponding PyTorch data type.
    """
    lookup = dtype_to_torch.type_map
    if lookup is None:
        # Built on first use so that importing this module does not import torch.
        import torch

        lookup = {
            warp.bool: torch.bool,
            warp.float16: torch.float16,
            warp.float32: torch.float32,
            warp.float64: torch.float64,
            warp.int8: torch.int8,
            warp.int16: torch.int16,
            warp.int32: torch.int32,
            warp.int64: torch.int64,
            warp.uint8: torch.uint8,
            # unsigned integers wider than 8 bits alias the signed type of equal width
            warp.uint16: torch.int16,
            warp.uint32: torch.int32,
            warp.uint64: torch.int64,
        }
        dtype_to_torch.type_map = lookup

    torch_dtype = lookup.get(warp_dtype)
    if torch_dtype is None:
        raise TypeError(f"Cannot convert {warp_dtype} to a Torch type")
    return torch_dtype
113
+
114
+
115
def dtype_from_torch(torch_dtype):
    """Look up the Warp dtype that corresponds to a ``torch.dtype``.

    Args:
        torch_dtype: A ``torch.dtype`` that has a corresponding Warp data type.
            Currently ``torch.bfloat16``, ``torch.complex64``, and
            ``torch.complex128`` are not supported.

    Raises:
        TypeError: Unable to find a corresponding Warp data type.
    """
    lookup = dtype_from_torch.type_map
    if lookup is None:
        # Built on first use so that importing this module does not import torch.
        import torch

        # torch.bfloat16, torch.complex64 and torch.complex128 have no Warp counterpart yet
        lookup = {
            torch.bool: warp.bool,
            torch.float16: warp.float16,
            torch.float32: warp.float32,
            torch.float64: warp.float64,
            torch.int8: warp.int8,
            torch.int16: warp.int16,
            torch.int32: warp.int32,
            torch.int64: warp.int64,
            torch.uint8: warp.uint8,
        }
        dtype_from_torch.type_map = lookup

    warp_dtype = lookup.get(torch_dtype)
    if warp_dtype is None:
        raise TypeError(f"Cannot convert {torch_dtype} to a Warp type")
    return warp_dtype
152
+
153
+
154
def dtype_is_compatible(torch_dtype, warp_dtype) -> bool:
    """Check whether memory of ``torch_dtype`` may be viewed as ``warp_dtype``.

    Integer tensors may be aliased as either signed or unsigned Warp integers of
    the same width, and ``torch.bool`` may also be viewed as 8-bit integers.
    Vector and matrix Warp types are compatible when their scalar element type
    (``_wp_scalar_type_``) is.

    Args:
        torch_dtype: The source ``torch.dtype``.
        warp_dtype: The target Warp scalar, vector, or matrix type.

    Returns:
        bool: ``True`` if the two types are compatible, ``False`` otherwise.
    """
    table = dtype_is_compatible.compatible_sets
    if table is None:
        # Built on first use so that importing this module does not import torch.
        import torch

        # torch.bfloat16, torch.complex64 and torch.complex128 are not supported by Warp
        table = {
            torch.bool: {warp.bool, warp.uint8, warp.int8},
            torch.float16: {warp.float16},
            torch.float32: {warp.float32},
            torch.float64: {warp.float64},
            torch.uint8: {warp.uint8, warp.int8},
            torch.int8: {warp.int8, warp.uint8},
            torch.int16: {warp.int16, warp.uint16},
            torch.int32: {warp.int32, warp.uint32},
            torch.int64: {warp.int64, warp.uint64},
        }
        dtype_is_compatible.compatible_sets = table

    allowed = table.get(torch_dtype)
    if allowed is None:
        return False
    if warp_dtype in allowed:
        return True
    # fall back to the element type of vector/matrix types
    return hasattr(warp_dtype, "_wp_scalar_type_") and warp_dtype._wp_scalar_type_ in allowed
187
+
188
+
189
# The dtype conversion tables start out unset; each function builds its own table
# the first time it is called, which keeps torch out of this module's import path.
dtype_to_torch.type_map = None
dtype_from_torch.type_map = None
dtype_is_compatible.compatible_sets = None
193
+
194
+
195
# wrap a torch tensor to a wp array, data is not copied
def from_torch(t, dtype=None, requires_grad=None, grad=None, return_ctype=False):
    """Convert a Torch tensor to a Warp array without copying the data.

    Args:
        t (torch.Tensor): The torch tensor to wrap.
        dtype (warp.dtype, optional): The target data type of the resulting Warp array.
            Defaults to the tensor value type mapped to a Warp array value type.
            For a vector or matrix dtype, the trailing dimensions of ``t`` must match
            the dtype's shape and be contiguous; they are folded into the element type.
        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's
            gradient, if it exists (the grad tensor will be allocated otherwise).
            Defaults to the tensor's `requires_grad` value.
        grad (warp.array or torch.Tensor, optional): An explicit gradient to attach instead
            of ``t.grad``. A ``torch.Tensor`` is wrapped using the same ``dtype``. When
            ``return_ctype`` is ``True``, its strides must match the array strides.
        return_ctype (bool, optional): Whether to return a low-level array descriptor instead
            of a ``wp.array`` object (faster). The descriptor can be passed to Warp kernels.

    Returns:
        warp.array or warp._src.types.array_t: A ``wp.array`` aliasing the tensor memory, or,
        when ``return_ctype`` is ``True``, an ``array_t`` descriptor that keeps references to
        the tensor and its gradient owner.

    Raises:
        TypeError: ``dtype`` is not given and the tensor dtype has no Warp equivalent.
        RuntimeError: ``dtype`` is incompatible with the tensor dtype, the tensor's inner
            shape or strides do not fit a vector/matrix ``dtype``, or gradient strides do not
            match the array strides.

    Note:
        When a gradient is required but ``t.grad`` is ``None`` (and no ``grad`` is given),
        a zero-filled gradient is allocated with Warp and assigned to ``t.grad``.
    """
    if dtype is None:
        dtype = dtype_from_torch(t.dtype)
    elif not dtype_is_compatible(t.dtype, dtype):
        raise RuntimeError(f"Cannot convert Torch type {t.dtype} to Warp type {dtype}")

    # get size of underlying (scalar) data type to compute byte strides;
    # torch reports strides in elements, Warp expects bytes
    ctype_size = ctypes.sizeof(dtype._type_)

    shape = tuple(t.shape)
    strides = tuple(s * ctype_size for s in t.stride())

    # if target is a vector or matrix type
    # then check if trailing dimensions match
    # the target type and update the shape
    if hasattr(dtype, "_shape_"):
        dtype_shape = dtype._shape_
        dtype_dims = len(dtype._shape_)
        # ensure inner shape matches
        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:
            raise RuntimeError(
                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}"
            )
        # ensure inner strides are contiguous
        if strides[-1] != ctype_size or (dtype_dims > 1 and strides[-2] != ctype_size * dtype_shape[-1]):
            raise RuntimeError(
                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
            )
        # trim shape and strides; a tensor holding a single element becomes a 1-element array
        shape = tuple(shape[:-dtype_dims]) or (1,)
        strides = tuple(strides[:-dtype_dims]) or (ctype_size,)

    # gradient
    # - if return_ctype is False, we set `grad` to a wp.array or None
    # - if return_ctype is True, we set `grad_ptr` and set `grad` as the owner (wp.array or torch.Tensor)
    requires_grad = t.requires_grad if requires_grad is None else requires_grad
    grad_ptr = 0
    if grad is not None:
        if isinstance(grad, warp.array):
            if return_ctype:
                # both sides are byte strides of the (possibly trimmed) array shape
                if grad.strides != strides:
                    raise RuntimeError(
                        f"Gradient strides must match array strides, expected {strides} but got {grad.strides}"
                    )
                grad_ptr = grad.ptr
        else:
            # assume grad is a torch.Tensor
            if return_ctype:
                # both sides are torch element strides
                if t.stride() != grad.stride():
                    raise RuntimeError(
                        f"Gradient strides must match array strides, expected {t.stride()} but got {grad.stride()}"
                    )
                grad_ptr = grad.data_ptr()
            else:
                grad = from_torch(grad, dtype=dtype, requires_grad=False)
    elif requires_grad:
        # wrap the tensor gradient, allocate if necessary
        if t.grad is not None:
            if return_ctype:
                grad = t.grad
                if t.stride() != grad.stride():
                    raise RuntimeError(
                        f"Gradient strides must match array strides, expected {t.stride()} but got {grad.stride()}"
                    )
                grad_ptr = grad.data_ptr()
            else:
                grad = from_torch(t.grad, dtype=dtype, requires_grad=False)
        else:
            # allocate a zero-filled gradient if it doesn't exist
            # Note: we use Warp to allocate the shared gradient with compatible strides
            grad = warp.zeros(dtype=dtype, shape=shape, strides=strides, device=device_from_torch(t.device))
            t.grad = to_torch(grad, requires_grad=False)
            grad_ptr = grad.ptr

    if return_ctype:
        ptr = t.data_ptr()

        # create array descriptor
        array_ctype = warp._src.types.array_t(ptr, grad_ptr, len(shape), shape, strides)

        # keep data and gradient alive
        array_ctype._ref = t
        array_ctype._gradref = grad

        return array_ctype

    else:
        a = warp.array(
            ptr=t.data_ptr(),
            dtype=dtype,
            shape=shape,
            strides=strides,
            device=device_from_torch(t.device),
            copy=False,
            grad=grad,
            requires_grad=requires_grad,
        )

        # save a reference to the source tensor, otherwise it may get deallocated
        a._tensor = t

        return a
309
+
310
+
311
def to_torch(a, requires_grad=None):
    """
    Convert a Warp array to a Torch tensor without copying the data.

    Args:
        a (warp.array): The Warp array to convert.
        requires_grad (bool, optional): Whether the resulting tensor should convert the
            array's gradient, if it exists, to a grad tensor. Defaults to the array's
            `requires_grad` value.

    Returns:
        torch.Tensor: The converted tensor.

    Raises:
        RuntimeError: ``a`` has a struct dtype, or lives on a device that is neither
            CPU nor CUDA.
    """
    import torch

    if requires_grad is None:
        requires_grad = a.requires_grad

    # Structured (struct-typed) arrays have no Torch equivalent.
    if isinstance(a.dtype, warp._src.codegen.Struct):
        raise RuntimeError("Cannot convert structured Warp arrays to Torch.")

    if a.device.is_cpu:
        # Wrapping CPU objects that expose __array_interface__ directly is problematic
        # in Torch, so go through a NumPy ndarray first,
        # see https://pearu.github.io/array_interface_pytorch.html
        def wrap(arr):
            return torch.as_tensor(numpy.asarray(arr))

    elif a.device.is_cuda:
        # CUDA arrays are consumed by Torch through __cuda_array_interface__.
        def wrap(arr):
            return torch.as_tensor(arr, device=device_to_torch(a.device))

    else:
        raise RuntimeError("Unsupported device")

    tensor = wrap(a)
    tensor.requires_grad = requires_grad
    if requires_grad and a.requires_grad:
        tensor.grad = wrap(a.grad)
    return tensor
354
+
355
+
356
def stream_from_torch(stream_or_device=None):
    """Convert from a Torch CUDA stream to a Warp CUDA stream.

    Args:
        stream_or_device: A ``torch.cuda.Stream``; any other value is passed to
            ``torch.cuda.current_stream`` and that device's current stream is used.

    Returns:
        warp.Stream: A Warp stream wrapping the same CUDA stream handle.
    """
    import torch

    torch_stream = (
        stream_or_device
        if isinstance(stream_or_device, torch.cuda.Stream)
        else torch.cuda.current_stream(stream_or_device)
    )

    warp_stream = warp.Stream(device_from_torch(torch_stream.device), cuda_stream=torch_stream.cuda_stream)

    # The Warp stream only borrows the handle, so hold on to the Torch stream object.
    warp_stream._torch_stream = torch_stream

    return warp_stream
374
+
375
+
376
def stream_to_torch(stream_or_device=None):
    """Convert from a Warp CUDA stream to a Torch CUDA stream.

    Args:
        stream_or_device: A ``warp.Stream``; any other value is resolved with
            ``warp.get_device`` and that device's current stream is used.

    Returns:
        torch.cuda.ExternalStream: A Torch stream wrapping the same CUDA stream handle.
    """
    import torch

    warp_stream = (
        stream_or_device if isinstance(stream_or_device, warp.Stream) else warp.get_device(stream_or_device).stream
    )

    torch_device = device_to_torch(warp_stream.device)
    torch_stream = torch.cuda.ExternalStream(warp_stream.cuda_stream, device=torch_device)

    # The Torch stream only borrows the handle, so hold on to the Warp stream object.
    torch_stream._warp_stream = warp_stream

    return torch_stream