warp-lang 1.9.0-py3-none-manylinux_2_34_aarch64.whl → 1.10.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2302 -307
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1546 -224
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +3 -3
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +581 -280
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +18 -17
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +580 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
warp/{thirdparty → _src/thirdparty}/unittest_parallel.py CHANGED
@@ -47,7 +47,7 @@ except ImportError:
 
 
 # The following variables are NVIDIA Modifications
-START_DIRECTORY = os.path.dirname(__file__) # The directory to start test discovery
+START_DIRECTORY = os.path.join(os.path.dirname(__file__), "..") # The directory to start test discovery
 
 
 def main(argv=None):
@@ -275,7 +275,14 @@ def main(argv=None):
         parallel_failed = True
 
     # Fallback to isolated single-process execution if parallel failed
-    if parallel_failed:
+    # Skip fallback in CI/CD environments to respect job timeouts
+    in_ci = os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS") or os.environ.get("GITLAB_CI")
+    if parallel_failed and in_ci:
+        parser.exit(
+            status=1,
+            message="Error: Parallel execution failed in CI/CD environment. Skipping single-process fallback due to job timeout constraints.\n",
+        )
+    elif parallel_failed:
         print("Running all tests in isolated single-process mode...", file=sys.stderr)
         # Run all test suites in isolated single-process pools
         results = []
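
In effect, a failed parallel run now aborts immediately when one of the common CI environment variables is set, instead of retrying everything serially past the job's timeout. A minimal sketch of that guard in isolation (the `handle_parallel_failure` helper is hypothetical; the environment-variable names are the ones checked in the diff):

import os
import sys

def handle_parallel_failure() -> None:
    # Any of these variables being set marks a CI/CD environment.
    in_ci = os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS") or os.environ.get("GITLAB_CI")
    if in_ci:
        # Fail fast so the CI job exits within its timeout budget.
        sys.exit("Error: Parallel execution failed in CI/CD environment.")
    # Outside CI, the slower single-process fallback still runs.
    print("Running all tests in isolated single-process mode...", file=sys.stderr)
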
warp/_src/torch.py ADDED
@@ -0,0 +1,393 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ctypes
+
+import numpy
+
+import warp
+import warp._src.context
+
+_wp_module_name_ = "warp.torch"
+
+
+# return the warp device corresponding to a torch device
+def device_from_torch(torch_device) -> warp._src.context.Device:
+    """Return the Warp device corresponding to a Torch device.
+
+    Args:
+        torch_device (`torch.device` or `str`): Torch device identifier
+
+    Raises:
+        RuntimeError: Torch device does not have a corresponding Warp device
+    """
+    if type(torch_device) is str:
+        warp_device = warp._src.context.runtime.device_map.get(torch_device)
+        if warp_device is not None:
+            return warp_device
+        elif torch_device == "cuda":
+            return warp._src.context.runtime.get_current_cuda_device()
+        else:
+            raise RuntimeError(f"Unsupported Torch device {torch_device}")
+    else:
+        try:
+            if torch_device.type == "cuda":
+                return warp._src.context.runtime.cuda_devices[torch_device.index]
+            elif torch_device.type == "cpu":
+                return warp._src.context.runtime.cpu_device
+            else:
+                raise RuntimeError(f"Unsupported Torch device type {torch_device.type}")
+        except Exception as e:
+            import torch
+
+            if not isinstance(torch_device, torch.device):
+                raise ValueError("Argument must be a torch.device object or a string") from e
+            raise
+
+
+def device_to_torch(warp_device: warp._src.context.Devicelike) -> str:
+    """Return the Torch device string corresponding to a Warp device.
+
+    Args:
+        warp_device: An identifier that can be resolved to a :class:`warp._src.context.Device`.
+
+    Raises:
+        RuntimeError: The Warp device is not compatible with PyTorch.
+    """
+    device = warp.get_device(warp_device)
+    if device.is_cpu or device.is_primary:
+        return str(device)
+    elif device.is_cuda and device.is_uva:
+        # it's not a primary context, but torch can access the data ptr directly thanks to UVA
+        return f"cuda:{device.ordinal}"
+    raise RuntimeError(f"Warp device {device} is not compatible with torch")
+
+
+def dtype_to_torch(warp_dtype):
+    """Return the Torch dtype corresponding to a Warp dtype.
+
+    Args:
+        warp_dtype: A Warp data type that has a corresponding ``torch.dtype``.
+            ``warp.uint16``, ``warp.uint32``, and ``warp.uint64`` are mapped
+            to the signed integer ``torch.dtype`` of the same width.
+    Raises:
+        TypeError: Unable to find a corresponding PyTorch data type.
+    """
+    # initialize lookup table on first call to defer torch import
+    if dtype_to_torch.type_map is None:
+        import torch
+
+        dtype_to_torch.type_map = {
+            warp.float16: torch.float16,
+            warp.float32: torch.float32,
+            warp.float64: torch.float64,
+            warp.int8: torch.int8,
+            warp.int16: torch.int16,
+            warp.int32: torch.int32,
+            warp.int64: torch.int64,
+            warp.uint8: torch.uint8,
+            # torch doesn't support unsigned ints bigger than 8 bits
+            warp.uint16: torch.int16,
+            warp.uint32: torch.int32,
+            warp.uint64: torch.int64,
+            warp.bool: torch.bool,
+        }
+
+    torch_dtype = dtype_to_torch.type_map.get(warp_dtype)
+    if torch_dtype is not None:
+        return torch_dtype
+    else:
+        raise TypeError(f"Cannot convert {warp_dtype} to a Torch type")
+
+
+def dtype_from_torch(torch_dtype):
+    """Return the Warp dtype corresponding to a Torch dtype.
+
+    Args:
+        torch_dtype: A ``torch.dtype`` that has a corresponding Warp data type.
+            Currently ``torch.bfloat16``, ``torch.complex64``, and
+            ``torch.complex128`` are not supported.
+
+    Raises:
+        TypeError: Unable to find a corresponding Warp data type.
+    """
+    # initialize lookup table on first call to defer torch import
+    if dtype_from_torch.type_map is None:
+        import torch
+
+        dtype_from_torch.type_map = {
+            torch.float16: warp.float16,
+            torch.float32: warp.float32,
+            torch.float64: warp.float64,
+            torch.int8: warp.int8,
+            torch.int16: warp.int16,
+            torch.int32: warp.int32,
+            torch.int64: warp.int64,
+            torch.uint8: warp.uint8,
+            torch.bool: warp.bool,
+            # currently unsupported by Warp
+            # torch.bfloat16:
+            # torch.complex64:
+            # torch.complex128:
+        }
+
+    warp_dtype = dtype_from_torch.type_map.get(torch_dtype)
+
+    if warp_dtype is not None:
+        return warp_dtype
+    else:
+        raise TypeError(f"Cannot convert {torch_dtype} to a Warp type")
+
+
+def dtype_is_compatible(torch_dtype, warp_dtype) -> bool:
+    """Evaluates whether the given torch dtype is compatible with the given Warp dtype."""
+    # initialize lookup table on first call to defer torch import
+    if dtype_is_compatible.compatible_sets is None:
+        import torch
+
+        dtype_is_compatible.compatible_sets = {
+            torch.float64: {warp.float64},
+            torch.float32: {warp.float32},
+            torch.float16: {warp.float16},
+            # allow aliasing integer tensors as signed or unsigned integer arrays
+            torch.int64: {warp.int64, warp.uint64},
+            torch.int32: {warp.int32, warp.uint32},
+            torch.int16: {warp.int16, warp.uint16},
+            torch.int8: {warp.int8, warp.uint8},
+            torch.uint8: {warp.uint8, warp.int8},
+            torch.bool: {warp.bool, warp.uint8, warp.int8},
+            # currently unsupported by Warp
+            # torch.bfloat16:
+            # torch.complex64:
+            # torch.complex128:
+        }
+
+    compatible_set = dtype_is_compatible.compatible_sets.get(torch_dtype)
+
+    if compatible_set is not None:
+        if warp_dtype in compatible_set:
+            return True
+        # check if it's a vector or matrix type
+        if hasattr(warp_dtype, "_wp_scalar_type_"):
+            return warp_dtype._wp_scalar_type_ in compatible_set
+
+    return False
+
+
+# lookup tables initialized when needed
+dtype_from_torch.type_map = None
+dtype_to_torch.type_map = None
+dtype_is_compatible.compatible_sets = None
+
+
+# wrap a torch tensor to a wp array, data is not copied
+def from_torch(t, dtype=None, requires_grad=None, grad=None, return_ctype=False):
+    """Convert a Torch tensor to a Warp array without copying the data.
+
+    Args:
+        t (torch.Tensor): The torch tensor to wrap.
+        dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.
+        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.
+        return_ctype (bool, optional): Whether to return a low-level array descriptor instead of a ``wp.array`` object (faster). The descriptor can be passed to Warp kernels.
+
+    Returns:
+        warp.array: The wrapped array or array descriptor.
+    """
+    if dtype is None:
+        dtype = dtype_from_torch(t.dtype)
+    elif not dtype_is_compatible(t.dtype, dtype):
+        raise RuntimeError(f"Cannot convert Torch type {t.dtype} to Warp type {dtype}")
+
+    # get size of underlying data type to compute strides
+    ctype_size = ctypes.sizeof(dtype._type_)
+
+    shape = tuple(t.shape)
+    strides = tuple(s * ctype_size for s in t.stride())
+
+    # if target is a vector or matrix type
+    # then check if trailing dimensions match
+    # the target type and update the shape
+    if hasattr(dtype, "_shape_"):
+        dtype_shape = dtype._shape_
+        dtype_dims = len(dtype._shape_)
+        # ensure inner shape matches
+        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:
+            raise RuntimeError(
+                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}"
+            )
+        # ensure inner strides are contiguous
+        if strides[-1] != ctype_size or (dtype_dims > 1 and strides[-2] != ctype_size * dtype_shape[-1]):
+            raise RuntimeError(
+                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
+            )
+        # trim shape and strides
+        shape = tuple(shape[:-dtype_dims]) or (1,)
+        strides = tuple(strides[:-dtype_dims]) or (ctype_size,)
+
+    # gradient
+    # - if return_ctype is False, we set `grad` to a wp.array or None
+    # - if return_ctype is True, we set `grad_ptr` and set `grad` as the owner (wp.array or torch.Tensor)
+    requires_grad = t.requires_grad if requires_grad is None else requires_grad
+    grad_ptr = 0
+    if grad is not None:
+        if isinstance(grad, warp.array):
+            if return_ctype:
+                if grad.strides != strides:
+                    raise RuntimeError(
+                        f"Gradient strides must match array strides, expected {strides} but got {grad.strides}"
+                    )
+                grad_ptr = grad.ptr
+        else:
+            # assume grad is a torch.Tensor
+            if return_ctype:
+                if t.stride() != grad.stride():
+                    raise RuntimeError(
+                        f"Gradient strides must match array strides, expected {t.stride()} but got {grad.stride()}"
+                    )
+                grad_ptr = grad.data_ptr()
+            else:
+                grad = from_torch(grad, dtype=dtype, requires_grad=False)
+    elif requires_grad:
+        # wrap the tensor gradient, allocate if necessary
+        if t.grad is not None:
+            if return_ctype:
+                grad = t.grad
+                if t.stride() != grad.stride():
+                    raise RuntimeError(
+                        f"Gradient strides must match array strides, expected {t.stride()} but got {grad.stride()}"
+                    )
+                grad_ptr = grad.data_ptr()
+            else:
+                grad = from_torch(t.grad, dtype=dtype, requires_grad=False)
+        else:
+            # allocate a zero-filled gradient if it doesn't exist
+            # Note: we use Warp to allocate the shared gradient with compatible strides
+            grad = warp.zeros(dtype=dtype, shape=shape, strides=strides, device=device_from_torch(t.device))
+            t.grad = to_torch(grad, requires_grad=False)
+            grad_ptr = grad.ptr
+
+    if return_ctype:
+        ptr = t.data_ptr()
+
+        # create array descriptor
+        array_ctype = warp._src.types.array_t(ptr, grad_ptr, len(shape), shape, strides)
+
+        # keep data and gradient alive
+        array_ctype._ref = t
+        array_ctype._gradref = grad
+
+        return array_ctype
+
+    else:
+        a = warp.array(
+            ptr=t.data_ptr(),
+            dtype=dtype,
+            shape=shape,
+            strides=strides,
+            device=device_from_torch(t.device),
+            copy=False,
+            grad=grad,
+            requires_grad=requires_grad,
+        )
+
+        # save a reference to the source tensor, otherwise it may get deallocated
+        a._tensor = t
+
+        return a
+
+
+def to_torch(a, requires_grad=None):
+    """
+    Convert a Warp array to a Torch tensor without copying the data.
+
+    Args:
+        a (warp.array): The Warp array to convert.
+        requires_grad (bool, optional): Whether the resulting tensor should convert the array's gradient, if it exists, to a grad tensor. Defaults to the array's `requires_grad` value.
+
+    Returns:
+        torch.Tensor: The converted tensor.
+    """
+    import torch
+
+    if requires_grad is None:
+        requires_grad = a.requires_grad
+
+    # Torch does not support structured arrays
+    if isinstance(a.dtype, warp._src.codegen.Struct):
+        raise RuntimeError("Cannot convert structured Warp arrays to Torch.")
+
+    if a.device.is_cpu:
+        # Torch has an issue wrapping CPU objects
+        # that support the __array_interface__ protocol
+        # in this case we need to workaround by going
+        # to an ndarray first, see https://pearu.github.io/array_interface_pytorch.html
+        t = torch.as_tensor(numpy.asarray(a))
+        t.requires_grad = requires_grad
+        if requires_grad and a.requires_grad:
+            t.grad = torch.as_tensor(numpy.asarray(a.grad))
+        return t
+
+    elif a.device.is_cuda:
+        # Torch does support the __cuda_array_interface__
+        # correctly, but we must be sure to maintain a reference
+        # to the owning object to prevent memory allocs going out of scope
+        t = torch.as_tensor(a, device=device_to_torch(a.device))
+        t.requires_grad = requires_grad
+        if requires_grad and a.requires_grad:
+            t.grad = torch.as_tensor(a.grad, device=device_to_torch(a.device))
+        return t
+
+    else:
+        raise RuntimeError("Unsupported device")
+
+
+def stream_from_torch(stream_or_device=None):
+    """Convert from a Torch CUDA stream to a Warp CUDA stream."""
+    import torch
+
+    if isinstance(stream_or_device, torch.cuda.Stream):
+        stream = stream_or_device
+    else:
+        # assume arg is a torch device
+        stream = torch.cuda.current_stream(stream_or_device)
+
+    device = device_from_torch(stream.device)
+
+    warp_stream = warp.Stream(device, cuda_stream=stream.cuda_stream)
+
+    # save a reference to the source stream, otherwise it may be destroyed
+    warp_stream._torch_stream = stream
+
+    return warp_stream
+
+
+def stream_to_torch(stream_or_device=None):
+    """Convert from a Warp CUDA stream to a Torch CUDA stream."""
+    import torch
+
+    if isinstance(stream_or_device, warp.Stream):
+        stream = stream_or_device
+    else:
+        # assume arg is a warp device
+        stream = warp.get_device(stream_or_device).stream
+
+    device = device_to_torch(stream.device)
+
+    torch_stream = torch.cuda.ExternalStream(stream.cuda_stream, device=device)
+
+    # save a reference to the source stream, otherwise it may be destroyed
+    torch_stream._warp_stream = stream
+
+    return torch_stream
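
For readers skimming the new module: a minimal, hypothetical usage sketch of the zero-copy interop above, assuming a CUDA device and that, as in previous releases, `from_torch`/`to_torch` remain re-exported at the top level as `wp.from_torch`/`wp.to_torch`:

import torch
import warp as wp

@wp.kernel
def scale(a: wp.array(dtype=float), s: float):
    i = wp.tid()
    a[i] = a[i] * s

t = torch.arange(4, dtype=torch.float32, device="cuda")
a = wp.from_torch(t)  # zero-copy wrap; dtype inferred via dtype_from_torch()
wp.launch(scale, dim=a.shape[0], inputs=[a, 2.0], device=a.device)
print(t)  # tensor([0., 2., 4., 6.], device='cuda:0'), same memory as `a`

back = wp.to_torch(a)  # zero-copy view back into Torch
assert back.data_ptr() == t.data_ptr()

`stream_from_torch()` and `stream_to_torch()` follow the same pattern for sharing the current CUDA stream; each keeps a reference to the source stream so it is not destroyed while the wrapper is alive.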