warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/_src/torch.py ADDED
@@ -0,0 +1,393 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ctypes
17
+
18
+ import numpy
19
+
20
+ import warp
21
+ import warp._src.context
22
+
23
+ _wp_module_name_ = "warp.torch"
24
+
25
+
26
+ # return the warp device corresponding to a torch device
27
def device_from_torch(torch_device) -> warp._src.context.Device:
    """Return the Warp device corresponding to a Torch device.

    Args:
        torch_device (`torch.device` or `str`): Torch device identifier

    Returns:
        The matching Warp device. A CUDA identifier without an explicit
        index (the string ``"cuda"`` or a ``torch.device`` whose ``index``
        is ``None``) resolves to Warp's current CUDA device.

    Raises:
        RuntimeError: Torch device does not have a corresponding Warp device
        ValueError: The argument is neither a string nor a ``torch.device``
    """
    runtime = warp._src.context.runtime

    if isinstance(torch_device, str):
        warp_device = runtime.device_map.get(torch_device)
        if warp_device is not None:
            return warp_device
        if torch_device == "cuda":
            return runtime.get_current_cuda_device()
        raise RuntimeError(f"Unsupported Torch device {torch_device}")

    try:
        if torch_device.type == "cuda":
            # a torch.device built as torch.device("cuda") carries no index;
            # treat it like the bare "cuda" string instead of indexing the
            # device list with None
            if torch_device.index is None:
                return runtime.get_current_cuda_device()
            return runtime.cuda_devices[torch_device.index]
        elif torch_device.type == "cpu":
            return runtime.cpu_device
        else:
            raise RuntimeError(f"Unsupported Torch device type {torch_device.type}")
    except Exception as e:
        # torch is imported only on the failure path so that the common case
        # never pays for (or requires) the import
        import torch

        if not isinstance(torch_device, torch.device):
            raise ValueError("Argument must be a torch.device object or a string") from e
        raise
58
+
59
+
60
def device_to_torch(warp_device: warp._src.context.Devicelike) -> str:
    """Translate a Warp device into the device string Torch understands.

    Args:
        warp_device: Anything ``warp.get_device`` can resolve to a Warp device.

    Returns:
        The Torch device string for the resolved device.

    Raises:
        RuntimeError: The Warp device is not compatible with PyTorch.
    """
    resolved = warp.get_device(warp_device)

    # CPU devices and primary CUDA contexts stringify to the name torch expects
    if resolved.is_cpu or resolved.is_primary:
        return str(resolved)

    # non-primary context: torch can still reach the memory directly via UVA
    if resolved.is_cuda and resolved.is_uva:
        return f"cuda:{resolved.ordinal}"

    raise RuntimeError(f"Warp device {resolved} is not compatible with torch")
76
+
77
+
78
def dtype_to_torch(warp_dtype):
    """Look up the Torch dtype that corresponds to a Warp dtype.

    Args:
        warp_dtype: A Warp data type that has a corresponding ``torch.dtype``.
            ``warp.uint16``, ``warp.uint32``, and ``warp.uint64`` are mapped
            to the signed integer ``torch.dtype`` of the same width.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        TypeError: Unable to find a corresponding PyTorch data type.
    """
    # the table is built on the first call so that torch is imported lazily
    if dtype_to_torch.type_map is None:
        import torch

        table = {
            # floating point
            warp.float16: torch.float16,
            warp.float32: torch.float32,
            warp.float64: torch.float64,
            # signed integers
            warp.int8: torch.int8,
            warp.int16: torch.int16,
            warp.int32: torch.int32,
            warp.int64: torch.int64,
            # unsigned integers: only the 8-bit one has a direct equivalent here,
            # the wider ones are mapped to the signed type of the same width
            warp.uint8: torch.uint8,
            warp.uint16: torch.int16,
            warp.uint32: torch.int32,
            warp.uint64: torch.int64,
            # boolean
            warp.bool: torch.bool,
        }
        dtype_to_torch.type_map = table

    mapped = dtype_to_torch.type_map.get(warp_dtype)
    if mapped is None:
        raise TypeError(f"Cannot convert {warp_dtype} to a Torch type")
    return mapped
113
+
114
+
115
def dtype_from_torch(torch_dtype):
    """Look up the Warp dtype that corresponds to a Torch dtype.

    Args:
        torch_dtype: A ``torch.dtype`` that has a corresponding Warp data type.
            Currently ``torch.bfloat16``, ``torch.complex64``, and
            ``torch.complex128`` are not supported.

    Returns:
        The matching Warp data type.

    Raises:
        TypeError: Unable to find a corresponding Warp data type.
    """
    # the table is built on the first call so that torch is imported lazily
    if dtype_from_torch.type_map is None:
        import torch

        # bfloat16 / complex64 / complex128 are deliberately absent:
        # they are currently unsupported by Warp
        table = {
            torch.bool: warp.bool,
            torch.uint8: warp.uint8,
            torch.int8: warp.int8,
            torch.int16: warp.int16,
            torch.int32: warp.int32,
            torch.int64: warp.int64,
            torch.float16: warp.float16,
            torch.float32: warp.float32,
            torch.float64: warp.float64,
        }
        dtype_from_torch.type_map = table

    mapped = dtype_from_torch.type_map.get(torch_dtype)
    if mapped is None:
        raise TypeError(f"Cannot convert {torch_dtype} to a Warp type")
    return mapped
152
+
153
+
154
def dtype_is_compatible(torch_dtype, warp_dtype) -> bool:
    """Report whether a torch dtype may be viewed as the given Warp dtype.

    Args:
        torch_dtype: The ``torch.dtype`` of the source tensor.
        warp_dtype: The candidate Warp data type; may be a scalar type or a
            type exposing ``_wp_scalar_type_`` (vector or matrix).

    Returns:
        ``True`` if ``warp_dtype``, or its ``_wp_scalar_type_`` when present,
        is listed as compatible with ``torch_dtype``; ``False`` otherwise,
        including when ``torch_dtype`` is not in the table at all.
    """
    # the table is built on the first call so that torch is imported lazily
    if dtype_is_compatible.compatible_sets is None:
        import torch

        # bfloat16 / complex64 / complex128 are deliberately absent:
        # they are currently unsupported by Warp
        table = {
            torch.float16: {warp.float16},
            torch.float32: {warp.float32},
            torch.float64: {warp.float64},
            # integer tensors may be aliased as signed or unsigned integer arrays
            torch.int8: {warp.int8, warp.uint8},
            torch.uint8: {warp.uint8, warp.int8},
            torch.int16: {warp.int16, warp.uint16},
            torch.int32: {warp.int32, warp.uint32},
            torch.int64: {warp.int64, warp.uint64},
            torch.bool: {warp.bool, warp.uint8, warp.int8},
        }
        dtype_is_compatible.compatible_sets = table

    allowed = dtype_is_compatible.compatible_sets.get(torch_dtype)
    if allowed is None:
        return False

    if warp_dtype in allowed:
        return True

    # vector and matrix types are judged by their underlying scalar type
    return hasattr(warp_dtype, "_wp_scalar_type_") and warp_dtype._wp_scalar_type_ in allowed
187
+
188
+
189
# Lazily-built lookup tables: each attribute starts out as None and is
# filled in by its owning function the first time that function is called.
dtype_is_compatible.compatible_sets = None
dtype_to_torch.type_map = None
dtype_from_torch.type_map = None
193
+
194
+
195
+ # wrap a torch tensor to a wp array, data is not copied
196
def from_torch(t, dtype=None, requires_grad=None, grad=None, return_ctype=False):
    """Convert a Torch tensor to a Warp array without copying the data.

    Args:
        t (torch.Tensor): The torch tensor to wrap.
        dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.
        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.
        grad (warp.array or torch.Tensor, optional): An explicit gradient to attach to the result. Anything that is not a ``warp.array`` is assumed to be a ``torch.Tensor``. When given, it takes precedence over ``t.grad``.
        return_ctype (bool, optional): Whether to return a low-level array descriptor instead of a ``wp.array`` object (faster). The descriptor can be passed to Warp kernels.

    Returns:
        warp.array: The wrapped array or array descriptor.

    Raises:
        RuntimeError: ``dtype`` is not compatible with the tensor's dtype, the
            tensor's trailing shape or strides do not match a vector/matrix
            ``dtype``, or (when ``return_ctype`` is true) the gradient strides
            differ from the array strides.
        TypeError: ``dtype`` was not given and the tensor's dtype has no Warp
            equivalent (raised by ``dtype_from_torch``).
    """
    if dtype is None:
        dtype = dtype_from_torch(t.dtype)
    elif not dtype_is_compatible(t.dtype, dtype):
        raise RuntimeError(f"Cannot convert Torch type {t.dtype} to Warp type {dtype}")

    # get size of underlying data type to compute strides
    ctype_size = ctypes.sizeof(dtype._type_)

    shape = tuple(t.shape)
    # t.stride() is scaled by the element size computed above
    strides = tuple(s * ctype_size for s in t.stride())

    # if target is a vector or matrix type
    # then check if trailing dimensions match
    # the target type and update the shape
    if hasattr(dtype, "_shape_"):
        dtype_shape = dtype._shape_
        dtype_dims = len(dtype._shape_)
        # ensure inner shape matches
        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:
            raise RuntimeError(
                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}"
            )
        # ensure inner strides are contiguous
        if strides[-1] != ctype_size or (dtype_dims > 1 and strides[-2] != ctype_size * dtype_shape[-1]):
            raise RuntimeError(
                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
            )
        # trim shape and strides
        # (if nothing is left after trimming, fall back to a single-element 1D shape)
        shape = tuple(shape[:-dtype_dims]) or (1,)
        strides = tuple(strides[:-dtype_dims]) or (ctype_size,)

    # gradient
    # - if return_ctype is False, we set `grad` to a wp.array or None
    # - if return_ctype is True, we set `grad_ptr` and set `grad` as the owner (wp.array or torch.Tensor)
    requires_grad = t.requires_grad if requires_grad is None else requires_grad
    grad_ptr = 0
    if grad is not None:
        # an explicitly supplied gradient is used regardless of requires_grad
        if isinstance(grad, warp.array):
            # a wp.array gradient is left untouched unless a descriptor is requested
            if return_ctype:
                if grad.strides != strides:
                    raise RuntimeError(
                        f"Gradient strides must match array strides, expected {strides} but got {grad.strides}"
                    )
                grad_ptr = grad.ptr
        else:
            # assume grad is a torch.Tensor
            if return_ctype:
                if t.stride() != grad.stride():
                    raise RuntimeError(
                        f"Gradient strides must match array strides, expected {t.stride()} but got {grad.stride()}"
                    )
                grad_ptr = grad.data_ptr()
            else:
                # wrap the gradient tensor itself; it must not track a gradient of its own
                grad = from_torch(grad, dtype=dtype, requires_grad=False)
    elif requires_grad:
        # wrap the tensor gradient, allocate if necessary
        if t.grad is not None:
            if return_ctype:
                grad = t.grad
                if t.stride() != grad.stride():
                    raise RuntimeError(
                        f"Gradient strides must match array strides, expected {t.stride()} but got {grad.stride()}"
                    )
                grad_ptr = grad.data_ptr()
            else:
                grad = from_torch(t.grad, dtype=dtype, requires_grad=False)
        else:
            # allocate a zero-filled gradient if it doesn't exist
            # Note: we use Warp to allocate the shared gradient with compatible strides
            grad = warp.zeros(dtype=dtype, shape=shape, strides=strides, device=device_from_torch(t.device))
            # side effect: the caller's tensor gets the new gradient assigned as t.grad
            t.grad = to_torch(grad, requires_grad=False)
            grad_ptr = grad.ptr

    if return_ctype:
        ptr = t.data_ptr()

        # create array descriptor
        array_ctype = warp._src.types.array_t(ptr, grad_ptr, len(shape), shape, strides)

        # keep data and gradient alive
        array_ctype._ref = t
        array_ctype._gradref = grad

        return array_ctype

    else:
        a = warp.array(
            ptr=t.data_ptr(),
            dtype=dtype,
            shape=shape,
            strides=strides,
            device=device_from_torch(t.device),
            copy=False,
            grad=grad,
            requires_grad=requires_grad,
        )

        # save a reference to the source tensor, otherwise it may get deallocated
        a._tensor = t

        return a
309
+
310
+
311
def to_torch(a, requires_grad=None):
    """
    Expose a Warp array as a Torch tensor without copying the data.

    Args:
        a (warp.array): The Warp array to convert.
        requires_grad (bool, optional): Value assigned to the resulting tensor's ``requires_grad``.
            When it is true and the array itself requires gradients, the array's gradient is
            converted as well and attached as the tensor's ``grad``.
            Defaults to the array's `requires_grad` value.

    Returns:
        torch.Tensor: The converted tensor.

    Raises:
        RuntimeError: If the array has a struct dtype, or if its device is neither CPU nor CUDA.
    """
    import torch

    # fall back to the array's own setting when the caller did not choose one
    track_grad = a.requires_grad if requires_grad is None else requires_grad

    # struct dtypes have no Torch equivalent
    if isinstance(a.dtype, warp._src.codegen.Struct):
        raise RuntimeError("Cannot convert structured Warp arrays to Torch.")

    if a.device.is_cpu:
        # Torch has trouble wrapping CPU objects that expose __array_interface__,
        # so take a detour through a NumPy ndarray first,
        # see https://pearu.github.io/array_interface_pytorch.html
        tensor = torch.as_tensor(numpy.asarray(a))
        tensor.requires_grad = track_grad
        if track_grad and a.requires_grad:
            tensor.grad = torch.as_tensor(numpy.asarray(a.grad))
        return tensor

    if a.device.is_cuda:
        # __cuda_array_interface__ is handled correctly by Torch, but a reference
        # to the owning object must be maintained so the allocation stays in scope
        tensor = torch.as_tensor(a, device=device_to_torch(a.device))
        tensor.requires_grad = track_grad
        if track_grad and a.requires_grad:
            tensor.grad = torch.as_tensor(a.grad, device=device_to_torch(a.device))
        return tensor

    raise RuntimeError("Unsupported device")
354
+
355
+
356
def stream_from_torch(stream_or_device=None):
    """Convert from a Torch CUDA stream to a Warp CUDA stream.

    Args:
        stream_or_device: Either a ``torch.cuda.Stream`` to wrap, or any other value, which is
            treated as a Torch device and handed to ``torch.cuda.current_stream()`` to look up
            the stream to wrap.

    Returns:
        warp.Stream: A Warp stream built from the Torch stream's ``cuda_stream`` handle.
    """
    import torch

    if isinstance(stream_or_device, torch.cuda.Stream):
        torch_stream = stream_or_device
    else:
        # anything that is not a stream is assumed to be a torch device
        torch_stream = torch.cuda.current_stream(stream_or_device)

    warp_device = device_from_torch(torch_stream.device)
    wrapped = warp.Stream(warp_device, cuda_stream=torch_stream.cuda_stream)

    # hold on to the source stream, otherwise it may be destroyed
    wrapped._torch_stream = torch_stream

    return wrapped
374
+
375
+
376
def stream_to_torch(stream_or_device=None):
    """Convert from a Warp CUDA stream to a Torch CUDA stream.

    Args:
        stream_or_device: Either a ``warp.Stream`` to wrap, or any other value, which is treated
            as a Warp device and resolved with ``warp.get_device()``; that device's ``stream``
            attribute is then wrapped.

    Returns:
        torch.cuda.ExternalStream: A Torch stream built from the Warp stream's ``cuda_stream`` handle.
    """
    import torch

    if isinstance(stream_or_device, warp.Stream):
        warp_stream = stream_or_device
    else:
        # anything that is not a stream is assumed to be a warp device
        warp_stream = warp.get_device(stream_or_device).stream

    torch_device = device_to_torch(warp_stream.device)
    external = torch.cuda.ExternalStream(warp_stream.cuda_stream, device=torch_device)

    # hold on to the source stream, otherwise it may be destroyed
    external._warp_stream = warp_stream

    return external