warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of warp-lang might be problematic.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/_src/optim/sgd.py ADDED
@@ -0,0 +1,114 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any
+
+ import warp as wp
+
+ _wp_module_name_ = "warp.optim.sgd"
+
+
+ @wp.kernel
+ def sgd_step_kernel(
+     g: wp.array(dtype=Any),
+     b: wp.array(dtype=Any),
+     lr: float,
+     weight_decay: float,
+     momentum: float,
+     damping: float,
+     nesterov: int,
+     t: int,
+     params: wp.array(dtype=Any),
+ ):
+     i = wp.tid()
+     gt = g[i]
+     if weight_decay != 0.0:
+         gt += weight_decay * params[i]
+     if momentum != 0.0:
+         bt = b[i]
+         if t > 0:
+             bt = momentum * bt + (1.0 - damping) * gt
+         else:
+             bt = gt
+         if nesterov == 1:
+             gt += momentum * bt
+         else:
+             gt = bt
+         b[i] = bt
+     params[i] = params[i] - lr * gt
+
+
+ class SGD:
+     """An implementation of the Stochastic Gradient Descent Optimizer
+     It is designed to mimic Pytorch's version.
+     https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
+     """
+
+     def __init__(self, params=None, lr=0.001, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False):
+         self.b = []  # momentum buffer
+         self.set_params(params)
+         self.lr = lr
+         self.momentum = momentum
+         self.dampening = dampening
+         self.weight_decay = weight_decay
+         self.nesterov = nesterov
+         self.t = 0
+
+     def set_params(self, params):
+         self.params = params
+         if params is not None and isinstance(params, list) and len(params) > 0:
+             if len(self.b) != len(params):
+                 self.b = [None] * len(params)
+             for i in range(len(params)):
+                 param = params[i]
+                 if self.b[i] is None or self.b[i].shape != param.shape or self.b[i].dtype != param.dtype:
+                     self.b[i] = wp.zeros_like(param)
+                 # Overload the kernel for each parameter so we can precompile the SGD kernel
+                 if param is not None:
+                     wp.overload(sgd_step_kernel, {"g": param, "b": param, "params": param})
+
+     def reset_internal_state(self):
+         for b_i in self.b:
+             b_i.zero_()
+         self.t = 0
+
+     def step(self, grad):
+         assert self.params is not None
+         for i in range(len(self.params)):
+             SGD.step_detail(
+                 grad[i],
+                 self.b[i],
+                 self.lr,
+                 self.momentum,
+                 self.dampening,
+                 self.weight_decay,
+                 self.nesterov,
+                 self.t,
+                 self.params[i],
+             )
+         self.t = self.t + 1
+
+     @staticmethod
+     def step_detail(g, b, lr, momentum, dampening, weight_decay, nesterov, t, params):
+         assert params.dtype == g.dtype
+         assert params.dtype == b.dtype
+         assert params.shape == g.shape
+         kernel_inputs = [g, b, lr, momentum, dampening, weight_decay, int(nesterov), t, params]
+         wp.launch(
+             kernel=sgd_step_kernel,
+             dim=len(params),
+             inputs=kernel_inputs,
+             device=params.device,
+         )
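
Note: the optimizer above is normally consumed through the public `warp.optim` shim listed earlier (`warp/optim/sgd.py`). A minimal usage sketch follows, assuming that shim re-exports `SGD` and that gradients come from `wp.Tape()`; the toy loss kernel and array size are illustrative only and not part of the package.

    import numpy as np
    import warp as wp
    from warp.optim import SGD  # assumed re-export of the class defined above

    wp.init()

    @wp.kernel
    def loss_kernel(x: wp.array(dtype=float), loss: wp.array(dtype=float)):
        # toy objective: sum of squares, accumulated atomically into loss[0]
        i = wp.tid()
        wp.atomic_add(loss, 0, x[i] * x[i])

    x = wp.array(np.random.randn(1024).astype(np.float32), requires_grad=True)
    loss = wp.zeros(1, dtype=float, requires_grad=True)
    opt = SGD([x], lr=1e-2, momentum=0.9, nesterov=True)

    for _ in range(10):
        loss.zero_()
        tape = wp.Tape()
        with tape:
            wp.launch(loss_kernel, dim=x.shape[0], inputs=[x, loss])
        tape.backward(loss)   # populates x.grad
        opt.step([x.grad])    # one update over the parameter list
        tape.zero()           # clear gradients before the next iteration

Each `step()` call launches `sgd_step_kernel` once per parameter array via `step_detail`, and each launch uses that parameter's own device.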
warp/_src/paddle.py ADDED
@@ -0,0 +1,408 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import ctypes
+ from typing import TYPE_CHECKING
+
+ import numpy
+
+ import warp
+ import warp._src.context
+
+ if TYPE_CHECKING:
+     import paddle
+     from paddle.base.libpaddle import CPUPlace, CUDAPinnedPlace, CUDAPlace, Place
+
+ _wp_module_name_ = "warp.paddle"
+
+
+ # return the warp device corresponding to a paddle device
+ def device_from_paddle(paddle_device: Place | CPUPlace | CUDAPinnedPlace | CUDAPlace | str) -> warp._src.context.Device:
+     """Return the Warp device corresponding to a Paddle device.
+
+     Args:
+         paddle_device (`Place`, `CPUPlace`, `CUDAPinnedPlace`, `CUDAPlace`, or `str`): Paddle device identifier
+
+     Raises:
+         RuntimeError: Paddle device does not have a corresponding Warp device
+     """
+     if type(paddle_device) is str:
+         if paddle_device.startswith("gpu:"):
+             paddle_device = paddle_device.replace("gpu:", "cuda:")
+         warp_device = warp._src.context.runtime.device_map.get(paddle_device)
+         if warp_device is not None:
+             return warp_device
+         elif paddle_device == "gpu":
+             return warp._src.context.runtime.get_current_cuda_device()
+         else:
+             raise RuntimeError(f"Unsupported Paddle device {paddle_device}")
+     else:
+         try:
+             from paddle.base.libpaddle import CPUPlace, CUDAPinnedPlace, CUDAPlace, Place
+
+             if isinstance(paddle_device, Place):
+                 if paddle_device.is_gpu_place():
+                     return warp._src.context.runtime.cuda_devices[paddle_device.gpu_device_id()]
+                 elif paddle_device.is_cpu_place():
+                     return warp._src.context.runtime.cpu_device
+                 else:
+                     raise RuntimeError(f"Unsupported Paddle device type {paddle_device}")
+             elif isinstance(paddle_device, (CPUPlace, CUDAPinnedPlace)):
+                 return warp._src.context.runtime.cpu_device
+             elif isinstance(paddle_device, CUDAPlace):
+                 return warp._src.context.runtime.cuda_devices[paddle_device.get_device_id()]
+             else:
+                 raise RuntimeError(f"Unsupported Paddle device type {paddle_device}")
+         except ModuleNotFoundError as e:
+             raise ModuleNotFoundError("Please install paddlepaddle first.") from e
+         except Exception as e:
+             if not isinstance(paddle_device, (Place, CPUPlace, CUDAPinnedPlace, CUDAPlace)):
+                 raise TypeError(
+                     "device_from_paddle() received an invalid argument - "
+                     f"got {paddle_device}({type(paddle_device)}), but expected one of:\n"
+                     "* paddle.base.libpaddle.Place\n"
+                     "* paddle.CPUPlace\n"
+                     "* paddle.CUDAPinnedPlace\n"
+                     "* paddle.CUDAPlace or 'gpu' or 'gpu:x'(x means device id)"
+                 ) from e
+             raise
+
+
+ def device_to_paddle(warp_device: warp._src.context.Devicelike) -> str:
+     """Return the Paddle device string corresponding to a Warp device.
+
+     Args:
+         warp_device: An identifier that can be resolved to a :class:`warp._src.context.Device`.
+
+     Raises:
+         RuntimeError: The Warp device is not compatible with PyPaddle.
+     """
+     device = warp.get_device(warp_device)
+     if device.is_cpu or device.is_primary:
+         return str(device).replace("cuda", "gpu")
+     elif device.is_cuda and device.is_uva:
+         # it's not a primary context, but paddle can access the data ptr directly thanks to UVA
+         return f"gpu:{device.ordinal}"
+     raise RuntimeError(f"Warp device {device} is not compatible with paddle")
+
+
+ def dtype_to_paddle(warp_dtype):
+     """Return the Paddle dtype corresponding to a Warp dtype.
+
+     Args:
+         warp_dtype: A Warp data type that has a corresponding ``paddle.dtype``.
+             ``warp.uint16``, ``warp.uint32``, and ``warp.uint64`` are mapped
+             to the signed integer ``paddle.dtype`` of the same width.
+     Raises:
+         TypeError: Unable to find a corresponding PyPaddle data type.
+     """
+     # initialize lookup table on first call to defer paddle import
+     if dtype_to_paddle.type_map is None:
+         import paddle
+
+         dtype_to_paddle.type_map = {
+             warp.float16: paddle.float16,
+             warp.float32: paddle.float32,
+             warp.float64: paddle.float64,
+             warp.int8: paddle.int8,
+             warp.int16: paddle.int16,
+             warp.int32: paddle.int32,
+             warp.int64: paddle.int64,
+             warp.uint8: paddle.uint8,
+             warp.bool: paddle.bool,
+             # paddle doesn't support unsigned ints bigger than 8 bits
+             warp.uint16: paddle.int16,
+             warp.uint32: paddle.int32,
+             warp.uint64: paddle.int64,
+         }
+
+     paddle_dtype = dtype_to_paddle.type_map.get(warp_dtype)
+     if paddle_dtype is not None:
+         return paddle_dtype
+     else:
+         raise TypeError(f"Cannot convert {warp_dtype} to a Paddle type")
+
+
+ def dtype_from_paddle(paddle_dtype):
+     """Return the Warp dtype corresponding to a Paddle dtype.
+
+     Args:
+         paddle_dtype: A ``paddle.dtype`` that has a corresponding Warp data type.
+             Currently ``paddle.bfloat16``, ``paddle.complex64``, and
+             ``paddle.complex128`` are not supported.
+
+     Raises:
+         TypeError: Unable to find a corresponding Warp data type.
+     """
+     # initialize lookup table on first call to defer paddle import
+     if dtype_from_paddle.type_map is None:
+         import paddle
+
+         dtype_from_paddle.type_map = {
+             paddle.float16: warp.float16,
+             paddle.float32: warp.float32,
+             paddle.float64: warp.float64,
+             paddle.int8: warp.int8,
+             paddle.int16: warp.int16,
+             paddle.int32: warp.int32,
+             paddle.int64: warp.int64,
+             paddle.uint8: warp.uint8,
+             paddle.bool: warp.bool,
+             # currently unsupported by Warp
+             # paddle.bfloat16:
+             # paddle.complex64:
+             # paddle.complex128:
+         }
+
+     warp_dtype = dtype_from_paddle.type_map.get(paddle_dtype)
+
+     if warp_dtype is not None:
+         return warp_dtype
+     else:
+         raise TypeError(f"Cannot convert {paddle_dtype} to a Warp type")
+
+
+ def dtype_is_compatible(paddle_dtype: paddle.dtype, warp_dtype) -> bool:
+     """Evaluates whether the given paddle dtype is compatible with the given Warp dtype."""
+     # initialize lookup table on first call to defer paddle import
+     if dtype_is_compatible.compatible_sets is None:
+         import paddle
+
+         dtype_is_compatible.compatible_sets = {
+             paddle.float64: {warp.float64},
+             paddle.float32: {warp.float32},
+             paddle.float16: {warp.float16},
+             # allow aliasing integer tensors as signed or unsigned integer arrays
+             paddle.int64: {warp.int64, warp.uint64},
+             paddle.int32: {warp.int32, warp.uint32},
+             paddle.int16: {warp.int16, warp.uint16},
+             paddle.int8: {warp.int8, warp.uint8},
+             paddle.uint8: {warp.uint8, warp.int8},
+             paddle.bool: {warp.bool, warp.uint8, warp.int8},
+             # currently unsupported by Warp
+             # paddle.bfloat16:
+             # paddle.complex64:
+             # paddle.complex128:
+         }
+
+     compatible_set = dtype_is_compatible.compatible_sets.get(paddle_dtype)
+
+     if compatible_set is not None:
+         if warp_dtype in compatible_set:
+             return True
+         # check if it's a vector or matrix type
+         if hasattr(warp_dtype, "_wp_scalar_type_"):
+             return warp_dtype._wp_scalar_type_ in compatible_set
+
+     return False
+
+
+ # lookup tables initialized when needed
+ dtype_from_paddle.type_map = None
+ dtype_to_paddle.type_map = None
+ dtype_is_compatible.compatible_sets = None
+
+
+ # wrap a paddle tensor to a wp array, data is not copied
+ def from_paddle(
+     t: paddle.Tensor,
+     dtype: paddle.dtype | None = None,
+     requires_grad: bool | None = None,
+     grad: paddle.Tensor | None = None,
+     return_ctype: bool = False,
+ ) -> warp.array:
+     """Convert a Paddle tensor to a Warp array without copying the data.
+
+     Args:
+         t (paddle.Tensor): The paddle tensor to wrap.
+         dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.
+         requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.
+         grad (paddle.Tensor, optional): The grad attached to given tensor. Defaults to None.
+         return_ctype (bool, optional): Whether to return a low-level array descriptor instead of a ``wp.array`` object (faster). The descriptor can be passed to Warp kernels.
+
+     Returns:
+         warp.array: The wrapped array or array descriptor.
+     """
+     if dtype is None:
+         dtype = dtype_from_paddle(t.dtype)
+     elif not dtype_is_compatible(t.dtype, dtype):
+         raise RuntimeError(f"Cannot convert Paddle type {t.dtype} to Warp type {dtype}")
+
+     # get size of underlying data type to compute strides
+     ctype_size = ctypes.sizeof(dtype._type_)
+
+     shape = tuple(t.shape)
+     strides = tuple(s * ctype_size for s in t.strides)
+
+     # if target is a vector or matrix type
+     # then check if trailing dimensions match
+     # the target type and update the shape
+     if hasattr(dtype, "_shape_"):
+         dtype_shape = dtype._shape_
+         dtype_dims = len(dtype._shape_)
+         # ensure inner shape matches
+         if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:
+             raise RuntimeError(
+                 f"Could not convert Paddle tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}"
+             )
+         # ensure inner strides are contiguous
+         if strides[-1] != ctype_size or (dtype_dims > 1 and strides[-2] != ctype_size * dtype_shape[-1]):
+             raise RuntimeError(
+                 f"Could not convert Paddle tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
+             )
+         # trim shape and strides
+         shape = tuple(shape[:-dtype_dims]) or (1,)
+         strides = tuple(strides[:-dtype_dims]) or (ctype_size,)
+
+     # gradient
+     # - if return_ctype is False, we set `grad` to a wp.array or None
+     # - if return_ctype is True, we set `grad_ptr` and set `grad` as the owner (wp.array or paddle.Tensor)
+     requires_grad = (not t.stop_gradient) if requires_grad is None else requires_grad
+     grad_ptr = 0
+     if grad is not None:
+         if isinstance(grad, warp.array):
+             if return_ctype:
+                 if grad.strides != strides:
+                     raise RuntimeError(
+                         f"Gradient strides must match array strides, expected {strides} but got {grad.strides}"
+                     )
+                 grad_ptr = grad.ptr
+         else:
+             # assume grad is a paddle.Tensor
+             if return_ctype:
+                 if t.strides != grad.strides:
+                     raise RuntimeError(
+                         f"Gradient strides must match array strides, expected {t.strides} but got {grad.strides}"
+                     )
+                 grad_ptr = grad.data_ptr()
+             else:
+                 grad = from_paddle(grad, dtype=dtype, requires_grad=False)
+     elif requires_grad:
+         # wrap the tensor gradient, allocate if necessary
+         if t.grad is not None:
+             if return_ctype:
+                 grad = t.grad
+                 if t.strides != grad.strides:
+                     raise RuntimeError(
+                         f"Gradient strides must match array strides, expected {t.strides} but got {grad.strides}"
+                     )
+                 grad_ptr = grad.data_ptr()
+             else:
+                 grad = from_paddle(t.grad, dtype=dtype, requires_grad=False)
+         else:
+             # allocate a zero-filled gradient if it doesn't exist
+             # Note: we use Warp to allocate the shared gradient with compatible strides
+             grad = warp.zeros(dtype=dtype, shape=shape, strides=strides, device=device_from_paddle(t.place))
+             # use .grad_ for zero-copy
+             t.grad_ = to_paddle(grad, requires_grad=False)
+             grad_ptr = grad.ptr
+
+     if return_ctype:
+         ptr = t.data_ptr()
+
+         # create array descriptor
+         array_ctype = warp._src.types.array_t(ptr, grad_ptr, len(shape), shape, strides)
+
+         # keep data and gradient alive
+         array_ctype._ref = t
+         array_ctype._gradref = grad
+
+         return array_ctype
+
+     else:
+         a = warp.array(
+             ptr=t.data_ptr(),
+             dtype=dtype,
+             shape=shape,
+             strides=strides,
+             device=device_from_paddle(t.place),
+             copy=False,
+             grad=grad,
+             requires_grad=requires_grad,
+         )
+
+         # save a reference to the source tensor, otherwise it may get deallocated
+         a._tensor = t
+
+         return a
+
+
+ def to_paddle(a: warp.array, requires_grad: bool | None = None) -> paddle.Tensor:
+     """Convert a Warp array to a Paddle tensor without copying the data.
+
+     Args:
+         a (warp.array): The Warp array to convert.
+         requires_grad (bool, optional): Whether the resulting tensor should convert the array's gradient, if it exists, to a grad tensor. Defaults to the array's `requires_grad` value.
+
+     Returns:
+         paddle.Tensor: The converted tensor.
+     """
+     import paddle
+     import paddle.utils.dlpack
+
+     if requires_grad is None:
+         requires_grad = a.requires_grad
+
+     # Paddle does not support structured arrays
+     if isinstance(a.dtype, warp._src.codegen.Struct):
+         raise RuntimeError("Cannot convert structured Warp arrays to Paddle.")
+
+     if a.device.is_cpu:
+         # Paddle has an issue wrapping CPU objects
+         # that support the __array_interface__ protocol
+         # in this case we need to workaround by going
+         # to an ndarray first, see https://pearu.github.io/array_interface_pypaddle.html
+         t = paddle.to_tensor(numpy.asarray(a), place="cpu")
+         t.stop_gradient = not requires_grad
+         if requires_grad and a.requires_grad:
+             # use .grad_ for zero-copy
+             t.grad_ = paddle.to_tensor(numpy.asarray(a.grad), place="cpu")
+         return t
+
+     elif a.device.is_cuda:
+         # Paddle does support the __cuda_array_interface__
+         # correctly, but we must be sure to maintain a reference
+         # to the owning object to prevent memory allocs going out of scope
+         t = paddle.utils.dlpack.from_dlpack(warp.to_dlpack(a)).to(device=device_to_paddle(a.device))
+         t.stop_gradient = not requires_grad
+         if requires_grad and a.requires_grad:
+             # use .grad_ for zero-copy
+             t.grad_ = paddle.utils.dlpack.from_dlpack(warp.to_dlpack(a.grad)).to(device=device_to_paddle(a.device))
+         return t
+
+     else:
+         raise RuntimeError("Unsupported device")
+
+
+ def stream_from_paddle(stream_or_device=None):
+     """Convert from a Paddle CUDA stream to a Warp CUDA stream."""
+     import paddle
+
+     if isinstance(stream_or_device, paddle.device.Stream):
+         stream = stream_or_device
+     else:
+         # assume arg is a paddle device
+         stream = paddle.device.current_stream(stream_or_device)
+
+     device = device_from_paddle(stream.device)
+
+     warp_stream = warp.Stream(device, cuda_stream=stream.stream_base.cuda_stream)
+
+     # save a reference to the source stream, otherwise it may be destroyed
+     warp_stream._paddle_stream = stream
+
+     return warp_stream
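
For orientation, a small zero-copy round trip using the helpers above is sketched below. It assumes `from_paddle`/`to_paddle` are reachable through the top-level `warp` namespace (via the `warp/paddle.py` shim listed earlier) and that `paddlepaddle` is installed; the shape and the `wp.vec3` reinterpretation are illustrative only.

    import paddle
    import warp as wp

    wp.init()
    paddle.set_device("cpu")  # the same pattern works on "gpu" when CUDA is available

    # wrap an (N, 3) float32 tensor as a Warp array of vec3 without copying;
    # the trailing dimension must match the vector width and be contiguous
    t = paddle.ones([16, 3], dtype="float32")
    points = wp.from_paddle(t, dtype=wp.vec3)

    @wp.kernel
    def scale(points: wp.array(dtype=wp.vec3), s: float):
        i = wp.tid()
        points[i] = points[i] * s

    wp.launch(scale, dim=points.shape[0], inputs=[points, 2.0])
    wp.synchronize()

    print(t.numpy()[0])          # the source tensor observes the update (shared storage)
    back = wp.to_paddle(points)  # convert back to a Paddle tensor

On the CUDA path, `to_paddle` goes through DLPack, and `stream_from_paddle` can be used to keep Warp launches on the stream Paddle is already using.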
@@ -0,0 +1,14 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.