warp_lang-1.9.0-py3-none-win_amd64.whl → warp_lang-1.10.0rc2-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/{thirdparty → _src/thirdparty}/unittest_parallel.py CHANGED
@@ -47,7 +47,7 @@ except ImportError:
 
 
 # The following variables are NVIDIA Modifications
-START_DIRECTORY = os.path.dirname(__file__) # The directory to start test discovery
+START_DIRECTORY = os.path.join(os.path.dirname(__file__), "..") # The directory to start test discovery
 
 
 def main(argv=None):
@@ -275,7 +275,14 @@ def main(argv=None):
         parallel_failed = True
 
     # Fallback to isolated single-process execution if parallel failed
-    if parallel_failed:
+    # Skip fallback in CI/CD environments to respect job timeouts
+    in_ci = os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS") or os.environ.get("GITLAB_CI")
+    if parallel_failed and in_ci:
+        parser.exit(
+            status=1,
+            message="Error: Parallel execution failed in CI/CD environment. Skipping single-process fallback due to job timeout constraints.\n",
+        )
+    elif parallel_failed:
         print("Running all tests in isolated single-process mode...", file=sys.stderr)
         # Run all test suites in isolated single-process pools
         results = []
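
As a side note, here is a minimal standalone sketch of the CI detection introduced above. The environment variable names are the ones used in the hunk; note that the check is a plain truthiness test on the raw string values, so even CI="false" would count as a CI run:

import os

# A run counts as CI when any of these variables is set to a non-empty string.
in_ci = os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS") or os.environ.get("GITLAB_CI")

if in_ci:
    print("CI detected: a parallel-run failure exits with status 1 instead of falling back")
else:
    print("local run: a parallel-run failure falls back to single-process execution")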
warp/_src/torch.py ADDED
@@ -0,0 +1,391 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ctypes
+
+import numpy
+
+import warp
+import warp._src.context
+
+
+# return the warp device corresponding to a torch device
+def device_from_torch(torch_device) -> warp._src.context.Device:
+    """Return the Warp device corresponding to a Torch device.
+
+    Args:
+        torch_device (`torch.device` or `str`): Torch device identifier
+
+    Raises:
+        RuntimeError: Torch device does not have a corresponding Warp device
+    """
+    if type(torch_device) is str:
+        warp_device = warp._src.context.runtime.device_map.get(torch_device)
+        if warp_device is not None:
+            return warp_device
+        elif torch_device == "cuda":
+            return warp._src.context.runtime.get_current_cuda_device()
+        else:
+            raise RuntimeError(f"Unsupported Torch device {torch_device}")
+    else:
+        try:
+            if torch_device.type == "cuda":
+                return warp._src.context.runtime.cuda_devices[torch_device.index]
+            elif torch_device.type == "cpu":
+                return warp._src.context.runtime.cpu_device
+            else:
+                raise RuntimeError(f"Unsupported Torch device type {torch_device.type}")
+        except Exception as e:
+            import torch
+
+            if not isinstance(torch_device, torch.device):
+                raise ValueError("Argument must be a torch.device object or a string") from e
+            raise
+
+
+def device_to_torch(warp_device: warp._src.context.Devicelike) -> str:
+    """Return the Torch device string corresponding to a Warp device.
+
+    Args:
+        warp_device: An identifier that can be resolved to a :class:`warp._src.context.Device`.
+
+    Raises:
+        RuntimeError: The Warp device is not compatible with PyTorch.
+    """
+    device = warp.get_device(warp_device)
+    if device.is_cpu or device.is_primary:
+        return str(device)
+    elif device.is_cuda and device.is_uva:
+        # it's not a primary context, but torch can access the data ptr directly thanks to UVA
+        return f"cuda:{device.ordinal}"
+    raise RuntimeError(f"Warp device {device} is not compatible with torch")
+
+
+def dtype_to_torch(warp_dtype):
+    """Return the Torch dtype corresponding to a Warp dtype.
+
+    Args:
+        warp_dtype: A Warp data type that has a corresponding ``torch.dtype``.
+            ``warp.uint16``, ``warp.uint32``, and ``warp.uint64`` are mapped
+            to the signed integer ``torch.dtype`` of the same width.
+    Raises:
+        TypeError: Unable to find a corresponding PyTorch data type.
+    """
+    # initialize lookup table on first call to defer torch import
+    if dtype_to_torch.type_map is None:
+        import torch
+
+        dtype_to_torch.type_map = {
+            warp.float16: torch.float16,
+            warp.float32: torch.float32,
+            warp.float64: torch.float64,
+            warp.int8: torch.int8,
+            warp.int16: torch.int16,
+            warp.int32: torch.int32,
+            warp.int64: torch.int64,
+            warp.uint8: torch.uint8,
+            # torch doesn't support unsigned ints bigger than 8 bits
+            warp.uint16: torch.int16,
+            warp.uint32: torch.int32,
+            warp.uint64: torch.int64,
+            warp.bool: torch.bool,
+        }
+
+    torch_dtype = dtype_to_torch.type_map.get(warp_dtype)
+    if torch_dtype is not None:
+        return torch_dtype
+    else:
+        raise TypeError(f"Cannot convert {warp_dtype} to a Torch type")
+
+
+def dtype_from_torch(torch_dtype):
+    """Return the Warp dtype corresponding to a Torch dtype.
+
+    Args:
+        torch_dtype: A ``torch.dtype`` that has a corresponding Warp data type.
+            Currently ``torch.bfloat16``, ``torch.complex64``, and
+            ``torch.complex128`` are not supported.
+
+    Raises:
+        TypeError: Unable to find a corresponding Warp data type.
+    """
+    # initialize lookup table on first call to defer torch import
+    if dtype_from_torch.type_map is None:
+        import torch
+
+        dtype_from_torch.type_map = {
+            torch.float16: warp.float16,
+            torch.float32: warp.float32,
+            torch.float64: warp.float64,
+            torch.int8: warp.int8,
+            torch.int16: warp.int16,
+            torch.int32: warp.int32,
+            torch.int64: warp.int64,
+            torch.uint8: warp.uint8,
+            torch.bool: warp.bool,
+            # currently unsupported by Warp
+            # torch.bfloat16:
+            # torch.complex64:
+            # torch.complex128:
+        }
+
+    warp_dtype = dtype_from_torch.type_map.get(torch_dtype)
+
+    if warp_dtype is not None:
+        return warp_dtype
+    else:
+        raise TypeError(f"Cannot convert {torch_dtype} to a Warp type")
+
+
+def dtype_is_compatible(torch_dtype, warp_dtype) -> bool:
+    """Evaluates whether the given torch dtype is compatible with the given Warp dtype."""
+    # initialize lookup table on first call to defer torch import
+    if dtype_is_compatible.compatible_sets is None:
+        import torch
+
+        dtype_is_compatible.compatible_sets = {
+            torch.float64: {warp.float64},
+            torch.float32: {warp.float32},
+            torch.float16: {warp.float16},
+            # allow aliasing integer tensors as signed or unsigned integer arrays
+            torch.int64: {warp.int64, warp.uint64},
+            torch.int32: {warp.int32, warp.uint32},
+            torch.int16: {warp.int16, warp.uint16},
+            torch.int8: {warp.int8, warp.uint8},
+            torch.uint8: {warp.uint8, warp.int8},
+            torch.bool: {warp.bool, warp.uint8, warp.int8},
+            # currently unsupported by Warp
+            # torch.bfloat16:
+            # torch.complex64:
+            # torch.complex128:
+        }
+
+    compatible_set = dtype_is_compatible.compatible_sets.get(torch_dtype)
+
+    if compatible_set is not None:
+        if warp_dtype in compatible_set:
+            return True
+        # check if it's a vector or matrix type
+        if hasattr(warp_dtype, "_wp_scalar_type_"):
+            return warp_dtype._wp_scalar_type_ in compatible_set
+
+    return False
+
+
+# lookup tables initialized when needed
+dtype_from_torch.type_map = None
+dtype_to_torch.type_map = None
+dtype_is_compatible.compatible_sets = None
+
+
+# wrap a torch tensor to a wp array, data is not copied
+def from_torch(t, dtype=None, requires_grad=None, grad=None, return_ctype=False):
+    """Convert a Torch tensor to a Warp array without copying the data.
+
+    Args:
+        t (torch.Tensor): The torch tensor to wrap.
+        dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.
+        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.
+        return_ctype (bool, optional): Whether to return a low-level array descriptor instead of a ``wp.array`` object (faster). The descriptor can be passed to Warp kernels.
+
+    Returns:
+        warp.array: The wrapped array or array descriptor.
+    """
+    if dtype is None:
+        dtype = dtype_from_torch(t.dtype)
+    elif not dtype_is_compatible(t.dtype, dtype):
+        raise RuntimeError(f"Cannot convert Torch type {t.dtype} to Warp type {dtype}")
+
+    # get size of underlying data type to compute strides
+    ctype_size = ctypes.sizeof(dtype._type_)
+
+    shape = tuple(t.shape)
+    strides = tuple(s * ctype_size for s in t.stride())
+
+    # if target is a vector or matrix type
+    # then check if trailing dimensions match
+    # the target type and update the shape
+    if hasattr(dtype, "_shape_"):
+        dtype_shape = dtype._shape_
+        dtype_dims = len(dtype._shape_)
+        # ensure inner shape matches
+        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:
+            raise RuntimeError(
+                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}"
+            )
+        # ensure inner strides are contiguous
+        if strides[-1] != ctype_size or (dtype_dims > 1 and strides[-2] != ctype_size * dtype_shape[-1]):
+            raise RuntimeError(
+                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
+            )
+        # trim shape and strides
+        shape = tuple(shape[:-dtype_dims]) or (1,)
+        strides = tuple(strides[:-dtype_dims]) or (ctype_size,)
+
+    # gradient
+    # - if return_ctype is False, we set `grad` to a wp.array or None
+    # - if return_ctype is True, we set `grad_ptr` and set `grad` as the owner (wp.array or torch.Tensor)
+    requires_grad = t.requires_grad if requires_grad is None else requires_grad
+    grad_ptr = 0
+    if grad is not None:
+        if isinstance(grad, warp.array):
+            if return_ctype:
+                if grad.strides != strides:
+                    raise RuntimeError(
+                        f"Gradient strides must match array strides, expected {strides} but got {grad.strides}"
+                    )
+                grad_ptr = grad.ptr
+        else:
+            # assume grad is a torch.Tensor
+            if return_ctype:
+                if t.stride() != grad.stride():
+                    raise RuntimeError(
+                        f"Gradient strides must match array strides, expected {t.stride()} but got {grad.stride()}"
+                    )
+                grad_ptr = grad.data_ptr()
+            else:
+                grad = from_torch(grad, dtype=dtype, requires_grad=False)
+    elif requires_grad:
+        # wrap the tensor gradient, allocate if necessary
+        if t.grad is not None:
+            if return_ctype:
+                grad = t.grad
+                if t.stride() != grad.stride():
+                    raise RuntimeError(
+                        f"Gradient strides must match array strides, expected {t.stride()} but got {grad.stride()}"
+                    )
+                grad_ptr = grad.data_ptr()
+            else:
+                grad = from_torch(t.grad, dtype=dtype, requires_grad=False)
+        else:
+            # allocate a zero-filled gradient if it doesn't exist
+            # Note: we use Warp to allocate the shared gradient with compatible strides
+            grad = warp.zeros(dtype=dtype, shape=shape, strides=strides, device=device_from_torch(t.device))
+            t.grad = to_torch(grad, requires_grad=False)
+            grad_ptr = grad.ptr
+
+    if return_ctype:
+        ptr = t.data_ptr()
+
+        # create array descriptor
+        array_ctype = warp._src.types.array_t(ptr, grad_ptr, len(shape), shape, strides)
+
+        # keep data and gradient alive
+        array_ctype._ref = t
+        array_ctype._gradref = grad
+
+        return array_ctype
+
+    else:
+        a = warp.array(
+            ptr=t.data_ptr(),
+            dtype=dtype,
+            shape=shape,
+            strides=strides,
+            device=device_from_torch(t.device),
+            copy=False,
+            grad=grad,
+            requires_grad=requires_grad,
+        )
+
+        # save a reference to the source tensor, otherwise it may get deallocated
+        a._tensor = t
+
+        return a
+
+
+def to_torch(a, requires_grad=None):
+    """
+    Convert a Warp array to a Torch tensor without copying the data.
+
+    Args:
+        a (warp.array): The Warp array to convert.
+        requires_grad (bool, optional): Whether the resulting tensor should convert the array's gradient, if it exists, to a grad tensor. Defaults to the array's `requires_grad` value.
+
+    Returns:
+        torch.Tensor: The converted tensor.
+    """
+    import torch
+
+    if requires_grad is None:
+        requires_grad = a.requires_grad
+
+    # Torch does not support structured arrays
+    if isinstance(a.dtype, warp._src.codegen.Struct):
+        raise RuntimeError("Cannot convert structured Warp arrays to Torch.")
+
+    if a.device.is_cpu:
+        # Torch has an issue wrapping CPU objects
+        # that support the __array_interface__ protocol
+        # in this case we need to workaround by going
+        # to an ndarray first, see https://pearu.github.io/array_interface_pytorch.html
+        t = torch.as_tensor(numpy.asarray(a))
+        t.requires_grad = requires_grad
+        if requires_grad and a.requires_grad:
+            t.grad = torch.as_tensor(numpy.asarray(a.grad))
+        return t
+
+    elif a.device.is_cuda:
+        # Torch does support the __cuda_array_interface__
+        # correctly, but we must be sure to maintain a reference
+        # to the owning object to prevent memory allocs going out of scope
+        t = torch.as_tensor(a, device=device_to_torch(a.device))
+        t.requires_grad = requires_grad
+        if requires_grad and a.requires_grad:
+            t.grad = torch.as_tensor(a.grad, device=device_to_torch(a.device))
+        return t
+
+    else:
+        raise RuntimeError("Unsupported device")
+
+
+def stream_from_torch(stream_or_device=None):
+    """Convert from a Torch CUDA stream to a Warp CUDA stream."""
+    import torch
+
+    if isinstance(stream_or_device, torch.cuda.Stream):
+        stream = stream_or_device
+    else:
+        # assume arg is a torch device
+        stream = torch.cuda.current_stream(stream_or_device)
+
+    device = device_from_torch(stream.device)
+
+    warp_stream = warp.Stream(device, cuda_stream=stream.cuda_stream)
+
+    # save a reference to the source stream, otherwise it may be destroyed
+    warp_stream._torch_stream = stream
+
+    return warp_stream
+
+
+def stream_to_torch(stream_or_device=None):
+    """Convert from a Warp CUDA stream to a Torch CUDA stream."""
+    import torch
+
+    if isinstance(stream_or_device, warp.Stream):
+        stream = stream_or_device
+    else:
+        # assume arg is a warp device
+        stream = warp.get_device(stream_or_device).stream
+
+    device = device_to_torch(stream.device)
+
+    torch_stream = torch.cuda.ExternalStream(stream.cuda_stream, device=device)
+
+    # save a reference to the source stream, otherwise it may be destroyed
+    torch_stream._warp_stream = stream
+
+    return torch_stream
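
To show how the relocated interop module is typically consumed, here is a minimal round-trip sketch using the public wrappers that warp re-exports from this module (wp.from_torch / wp.to_torch). It assumes a CUDA-enabled PyTorch install alongside Warp; the scale kernel is purely illustrative and not part of the diff:

import torch
import warp as wp

@wp.kernel
def scale(a: wp.array(dtype=float), s: float):
    # illustrative kernel: multiply each element in place
    i = wp.tid()
    a[i] = a[i] * s

t = torch.arange(16, dtype=torch.float32, device="cuda:0")
a = wp.from_torch(t)  # zero-copy wrap; shares the tensor's storage
wp.launch(scale, dim=a.shape[0], inputs=[a, 2.0], device=a.device)
back = wp.to_torch(a)  # zero-copy view back into the same storage
assert back.data_ptr() == t.data_ptr()

# Integer tensors may also be aliased as unsigned arrays, per dtype_is_compatible:
u = wp.from_torch(torch.ones(8, dtype=torch.int32, device="cuda:0"), dtype=wp.uint32)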