warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,120 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import numpy as np
17
+ import taichi as ti
18
+
19
+
20
# Taichi device function: the result starts at 0.0 and is overwritten with 1
# when ``x`` is negative. A single trailing return is kept on purpose so the
# control flow stays identical to the original.
@ti.func
def step(x):
    result = 0.0
    if x < 0:
        result = 1
    return result
26
+
27
+
28
@ti.data_oriented
class TiIntegrator:
    """Mass-spring cloth integrator built on Taichi fields and kernels.

    The constructor initializes Taichi for the requested device, allocates
    fields sized from ``cloth.num_particles`` / ``cloth.num_springs`` and
    uploads the cloth's NumPy arrays. ``simulate`` then alternates the spring
    force kernel and the particle integration kernel.
    """

    def __init__(self, cloth, device):
        """Initialize Taichi, allocate the fields and upload the cloth data.

        Args:
            cloth: Object exposing ``num_particles``, ``num_springs``,
                ``positions``, ``velocities``, ``inv_masses``,
                ``spring_indices``, ``spring_lengths``, ``spring_stiffness``
                and ``spring_damping``.
            device: ``"cpu"`` or ``"cuda"``.

        Raises:
            RuntimeError: If ``device`` is neither ``"cpu"`` nor ``"cuda"``.
        """
        # Pick the Taichi architecture first, then initialize exactly once.
        if device == "cpu":
            arch = ti.cpu
        elif device == "cuda":
            arch = ti.gpu
        else:
            raise RuntimeError("Unsupported Taichi device")
        ti.init(arch=arch)

        self.cloth = cloth

        particle_count = self.cloth.num_particles
        spring_count = self.cloth.num_springs

        # per-particle state
        self.positions = ti.Vector.field(3, dtype=ti.f32, shape=particle_count)
        self.velocities = ti.Vector.field(3, dtype=ti.f32, shape=particle_count)
        self.inv_mass = ti.field(ti.f32, shape=particle_count)

        # per-spring data; indices are stored flat, two entries per spring
        self.spring_indices = ti.field(ti.i32, shape=spring_count * 2)
        self.spring_lengths = ti.field(ti.f32, shape=spring_count)
        self.spring_stiffness = ti.field(ti.f32, shape=spring_count)
        self.spring_damping = ti.field(ti.f32, shape=spring_count)

        # force accumulator, one vector per particle
        self.forces = ti.Vector.field(3, dtype=ti.f32, shape=particle_count)

        # upload data
        self.positions.from_numpy(cloth.positions)
        self.velocities.from_numpy(cloth.velocities)
        self.inv_mass.from_numpy(cloth.inv_masses)
        self.forces.from_numpy(np.zeros_like(self.cloth.velocities))

        self.spring_indices.from_numpy(cloth.spring_indices)
        self.spring_lengths.from_numpy(cloth.spring_lengths)
        self.spring_stiffness.from_numpy(cloth.spring_stiffness)
        self.spring_damping.from_numpy(cloth.spring_damping)

    @ti.kernel
    def eval_springs(self):
        # Accumulate the elastic + damping force of every spring onto its two
        # end particles.
        for spring in range(self.cloth.num_springs):
            a = self.spring_indices[2 * spring]
            b = self.spring_indices[2 * spring + 1]

            stiffness = self.spring_stiffness[spring]
            damping = self.spring_damping[spring]
            rest_length = self.spring_lengths[spring]

            # relative position and velocity of the two end points
            delta_x = self.positions[a] - self.positions[b]
            delta_v = self.velocities[a] - self.velocities[b]

            dist = delta_x.norm()
            direction = delta_x.normalized()

            # spring extension and its rate of change along the spring
            stretch = dist - rest_length
            stretch_rate = direction.dot(delta_v)

            spring_force = direction * (stiffness * stretch + damping * stretch_rate)

            self.forces[a] -= spring_force
            self.forces[b] += spring_force

    @ti.kernel
    def integrate_particles(self, dt: ti.f32):
        # Advance every particle by one step of size ``dt`` and reset its
        # force accumulator.
        for particle in range(self.cloth.num_particles):
            pos = self.positions[particle]
            vel = self.velocities[particle]
            force = self.forces[particle]
            inv_mass = self.inv_mass[particle]

            # gravity is only applied to particles with positive inverse mass
            gravity = ti.Vector([0.0, 0.0, 0.0])

            if inv_mass > 0.0:
                gravity = ti.Vector([0.0, -9.81, 0.0])

            # velocity is updated first, then the position uses the new velocity
            new_vel = vel + (force * inv_mass + gravity) * dt
            new_pos = pos + new_vel * dt

            self.positions[particle] = new_pos
            self.velocities[particle] = new_vel
            self.forces[particle] = ti.Vector([0.0, 0.0, 0.0])

    def simulate(self, dt, substeps):
        """Advance the cloth by ``dt`` split into ``substeps`` equal sub-steps.

        Returns:
            The particle positions after the last sub-step, as returned by
            ``self.positions.to_numpy()``.
        """
        step_dt = dt / substeps

        for _ in range(substeps):
            self.eval_springs()
            self.integrate_particles(step_dt)

        return self.positions.to_numpy()
@@ -0,0 +1,153 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warp as wp
17
+
18
+ wp.clear_kernel_cache()
19
+
20
+
21
# Warp kernel, one thread per spring: computes the elastic + damping force of
# the spring and atomically accumulates it, with opposite signs, onto the two
# particles the spring connects.
@wp.kernel
def eval_springs(
    x: wp.array(dtype=wp.vec3),
    v: wp.array(dtype=wp.vec3),
    spring_indices: wp.array(dtype=int),
    spring_rest_lengths: wp.array(dtype=float),
    spring_stiffness: wp.array(dtype=float),
    spring_damping: wp.array(dtype=float),
    f: wp.array(dtype=wp.vec3),
):
    sid = wp.tid()

    # particle indices are stored flat, two entries per spring
    p = spring_indices[sid * 2 + 0]
    q = spring_indices[sid * 2 + 1]

    stiffness = spring_stiffness[sid]
    damping = spring_damping[sid]
    rest_length = spring_rest_lengths[sid]

    # relative position and velocity of the two end points
    delta_x = x[p] - x[q]
    delta_v = v[p] - v[q]

    dist = wp.length(delta_x)
    inv_dist = 1.0 / dist

    # normalized spring direction
    direction = delta_x * inv_dist

    # extension relative to the rest length and its rate of change
    stretch = dist - rest_length
    stretch_rate = wp.dot(direction, delta_v)

    # damping based on relative velocity.
    spring_force = direction * (stiffness * stretch + damping * stretch_rate)

    wp.atomic_sub(f, p, spring_force)
    wp.atomic_add(f, q, spring_force)
63
+
64
+
65
# Warp kernel, one thread per particle: applies the accumulated force and
# gravity, writes the updated position and velocity, then zeroes the force.
@wp.kernel
def integrate_particles(
    x: wp.array(dtype=wp.vec3),
    v: wp.array(dtype=wp.vec3),
    f: wp.array(dtype=wp.vec3),
    w: wp.array(dtype=float),
    dt: float,
):
    pid = wp.tid()

    pos = x[pid]
    vel = v[pid]
    force = f[pid]
    inv_mass = w[pid]

    gravity = wp.vec3()

    # treat particles with inv_mass == 0 as kinematic
    if inv_mass > 0.0:
        gravity = wp.vec3(0.0, 0.0 - 9.81, 0.0)

    # simple semi-implicit Euler: the new velocity is computed first and is
    # then used to advance the position
    new_vel = vel + (force * inv_mass + gravity) * dt
    new_pos = pos + new_vel * dt

    x[pid] = new_pos
    v[pid] = new_vel

    # clear forces
    f[pid] = wp.vec3()
95
+
96
+
97
class WpIntegrator:
    """Mass-spring cloth integrator that runs the Warp kernels of this module.

    Particle and spring data are uploaded once in the constructor;
    ``simulate`` launches ``eval_springs`` and ``integrate_particles`` for each
    sub-step and returns the resulting positions as a NumPy array.
    """

    def __init__(self, cloth, device):
        """Upload the cloth data to ``device``.

        Args:
            cloth: Object exposing ``num_particles``, ``num_springs``,
                ``positions``, ``inv_masses``, ``spring_indices``,
                ``spring_lengths``, ``spring_stiffness`` and
                ``spring_damping``.
            device: Anything accepted by ``wp.get_device``.
        """
        self.cloth = cloth
        self.device = wp.get_device(device)

        with wp.ScopedDevice(self.device):
            self.positions = wp.from_numpy(cloth.positions, dtype=wp.vec3)
            # host-side array that receives the positions after a CUDA run
            self.positions_host = wp.from_numpy(cloth.positions, dtype=wp.vec3, device="cpu")
            self.invmass = wp.from_numpy(cloth.inv_masses, dtype=float)

            # velocities and forces both start at zero
            self.velocities = wp.zeros(cloth.num_particles, dtype=wp.vec3)
            self.forces = wp.zeros(cloth.num_particles, dtype=wp.vec3)

            self.spring_indices = wp.from_numpy(cloth.spring_indices, dtype=int)
            self.spring_lengths = wp.from_numpy(cloth.spring_lengths, dtype=float)
            self.spring_stiffness = wp.from_numpy(cloth.spring_stiffness, dtype=float)
            self.spring_damping = wp.from_numpy(cloth.spring_damping, dtype=float)

    def simulate(self, dt, substeps):
        """Advance the cloth by ``dt`` split into ``substeps`` equal sub-steps.

        Returns:
            The particle positions as a NumPy array. On a CUDA device they are
            first copied into ``self.positions_host``.
        """
        step_dt = dt / substeps

        # The kernel argument lists do not change between sub-steps.
        spring_inputs = [
            self.positions,
            self.velocities,
            self.spring_indices,
            self.spring_lengths,
            self.spring_stiffness,
            self.spring_damping,
            self.forces,
        ]
        particle_inputs = [self.positions, self.velocities, self.forces, self.invmass, step_dt]

        for _ in range(substeps):
            wp.launch(
                kernel=eval_springs,
                dim=self.cloth.num_springs,
                inputs=spring_inputs,
                outputs=[],
                device=self.device,
            )

            # integrate
            wp.launch(
                kernel=integrate_particles,
                dim=self.cloth.num_particles,
                inputs=particle_inputs,
                outputs=[],
                device=self.device,
            )

        # Non-CUDA device: the positions can be read directly.
        if not self.device.is_cuda:
            return self.positions.numpy()

        # copy data back to host
        wp.copy(self.positions_host, self.positions)
        wp.synchronize()

        return self.positions_host.numpy()
@@ -0,0 +1,164 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Compare GEMM performance between Torch and Warp (Tiled).
17
+
18
+ This script can be used to identify optimal tile parameters for a fixed-size
19
+ matrix multiplication.
20
+ """
21
+
22
+ from itertools import product
23
+ from statistics import mean, stdev
24
+ from typing import List
25
+
26
+ import numpy as np
27
+ import torch
28
+
29
+ import warp as wp
30
+
31
+
32
def create_gemm_kernel(m, n, k):
    """Return a tiled Warp GEMM kernel specialized for the given tile sizes.

    Args:
        m: Tile height used for ``A`` and for the output tile.
        n: Tile width used for ``B`` and for the output tile.
        k: Tile depth along the shared dimension.

    Returns:
        A ``wp.kernel`` taking ``(A, B, output)`` that computes one
        ``m`` x ``n`` output tile per ``wp.tid()`` pair.
    """
    # The kernel body refers to the tile sizes through these names.
    TILE_M, TILE_N, TILE_K = m, n, k

    @wp.kernel
    def gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
        tile_row, tile_col = wp.tid()
        acc = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)

        # number of whole TILE_K-wide steps along the shared dimension
        num_k_tiles = A.shape[1] // TILE_K

        for kt in range(num_k_tiles):
            a_tile = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(tile_row * TILE_M, kt * TILE_K))
            b_tile = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(kt * TILE_K, tile_col * TILE_N))

            # the accumulator tile is passed as the third argument on every step
            wp.tile_matmul(a_tile, b_tile, acc)

        wp.tile_store(output, acc, offset=(tile_row * TILE_M, tile_col * TILE_N))

    return gemm
54
+
55
+
56
def benchmark_torch(A: torch.Tensor, B: torch.Tensor, warm_up: int, iterations: int):
    """Time ``torch.matmul(A, B)`` with CUDA events.

    Args:
        A: Left operand.
        B: Right operand.
        warm_up: Number of untimed multiplications run before measuring.
        iterations: Number of timed multiplications.

    Returns:
        A ``(mean, stdev)`` tuple of the per-iteration values reported by
        ``start_event.elapsed_time(end_event)``. The standard deviation is
        NaN when fewer than two samples were collected, because the sample
        standard deviation is undefined there.

    Raises:
        statistics.StatisticsError: If ``iterations`` is less than 1, since
            the mean of an empty list is undefined.
    """
    # warm-up
    for _ in range(warm_up):
        torch.matmul(A, B)

    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    timing_results = []

    for _ in range(iterations):
        start_event.record()
        torch.matmul(A, B)
        end_event.record()

        # wait until both events have completed before reading the elapsed time
        torch.cuda.synchronize()
        timing_results.append(start_event.elapsed_time(end_event))

    # stdev() raises for a single sample; report NaN instead of crashing
    spread = stdev(timing_results) if len(timing_results) > 1 else float("nan")

    return mean(timing_results), spread
77
+
78
+
79
def benchmark_warp(A: wp.array, B: wp.array, config: List[int], warm_up: int, iterations: int):
    """Time the tiled Warp GEMM kernel for one tile/block configuration.

    Args:
        A: Left operand.
        B: Right operand.
        config: Sequence ``(TILE_M, TILE_N, TILE_K, BLOCK_DIM)``.
        warm_up: Number of untimed launches. When greater than zero the
            result is also checked against ``A.numpy() @ B.numpy()``.
        iterations: Number of timed launches.

    Returns:
        A ``(mean, stdev)`` tuple of the ``elapsed`` values collected by
        ``wp.ScopedTimer``. The standard deviation is NaN when fewer than two
        samples were collected.

    Raises:
        AssertionError: If the warm-up result does not match the NumPy
            reference within ``atol=1e-3, rtol=1e-3``.
        statistics.StatisticsError: If no timing samples were collected.
    """
    TILE_M = config[0]
    TILE_N = config[1]
    TILE_K = config[2]
    BLOCK_DIM = config[3]

    gemm_kernel = create_gemm_kernel(TILE_M, TILE_N, TILE_K)

    M = A.shape[0]
    N = B.shape[1]

    # Keep the output and the launch on the same device as the inputs instead
    # of relying on whatever the current default Warp device happens to be.
    device = A.device
    output = wp.zeros((M, N), dtype=float, device=device)

    # create launch command
    cmd = wp.launch_tiled(
        kernel=gemm_kernel,
        dim=[M // TILE_M, N // TILE_N],
        inputs=[A, B, output],
        block_dim=BLOCK_DIM,
        device=device,
        record_cmd=True,
    )

    # warm-up
    for _ in range(warm_up):
        cmd.launch()

    # check output
    if warm_up > 0:
        try:
            np.testing.assert_allclose(output.numpy(), A.numpy() @ B.numpy(), atol=1e-3, rtol=1e-3)
        except AssertionError:
            print(f"Failed with {TILE_M=}, {TILE_N=}, {TILE_K=}, {BLOCK_DIM=}")
            # bare raise re-raises the active exception unchanged
            raise

    # benchmark
    with wp.ScopedTimer("warp", print=False, synchronize=True, cuda_filter=wp.TIMING_KERNEL) as timer:
        for _ in range(iterations):
            cmd.launch()

    timing_results = [result.elapsed for result in timer.timing_results]

    # stdev() raises for a single sample; report NaN instead of crashing
    spread = stdev(timing_results) if len(timing_results) > 1 else float("nan")

    return mean(timing_results), spread
121
+
122
+
123
# Script entry point: time torch.matmul once, then sweep every combination of
# tile sizes and block dimension through benchmark_warp and print one table
# row per configuration.
if __name__ == "__main__":
    torch.backends.cuda.matmul.allow_tf32 = False  # Disable TF32 for matrix multiplications
    torch.backends.cudnn.allow_tf32 = False  # Disable TF32 for cuDNN operations

    wp.init()
    wp.clear_kernel_cache()
    wp.set_module_options({"fast_math": True, "enable_backward": False})

    # candidate tile sizes and block dimensions to sweep
    tile_m = [8, 16, 32, 64]
    tile_n = [8, 16, 32, 64]
    tile_k = [8, 16, 64]
    block = [32, 64, 128]

    M = 1024
    N = 1024
    K = 1024
    print(f"{M=}, {N=}, {K=}")

    A = torch.randn(M, K).cuda()
    B = torch.randn(K, N).cuda()

    iterations = 100
    warm_up = 5

    time_torch_mean, time_torch_std = benchmark_torch(A, B, warm_up, iterations)
    print(f"Torch: {time_torch_mean:.6g}±{time_torch_std:.2g} ms")

    configs = list(product(tile_m, tile_n, tile_k, block))

    wp.config.quiet = True

    # Convert the Torch tensors once; the same Warp arrays are passed to
    # benchmark_warp for every configuration instead of being re-created on
    # each loop iteration.
    A_wp = wp.from_torch(A)
    B_wp = wp.from_torch(B)

    # header
    print(
        f"{'TILE_M':<8s} {'TILE_N':<8s} {'TILE_K':<8s} {'BLOCK':<8s} {'Time (ms)':<10s} {'Std dev (ms)':<14s} {'Warp/Torch':<12s}"
    )
    print("-" * 79)

    for c in configs:
        time_mean, time_std = benchmark_warp(A_wp, B_wp, c, warm_up, iterations)
        print(
            f"{c[0]:<8d} {c[1]:<8d} {c[2]:<8d} {c[3]:<8d} {time_mean:<10.6g} {time_std:<#14.2g} {time_mean / time_torch_mean:<12.6g}"
        )
@@ -0,0 +1,166 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import time
17
+
18
+ import paddle
19
+
20
+ import warp as wp
21
+
22
+
23
def create_simple_kernel(dtype):
    """Build a Warp kernel with an empty body that accepts five arrays of ``dtype``.

    Args:
        dtype: Element type used for the ``wp.array`` annotation of all five
            kernel parameters.

    Returns:
        The ``wp.Kernel`` object constructed from the nested ``simple_kernel``
        function.
    """

    # The kernel body is just `pass`; presumably this keeps the benchmarks
    # focused on argument conversion and launch overhead rather than on any
    # device-side work.
    def simple_kernel(
        a: wp.array(dtype=dtype),
        b: wp.array(dtype=dtype),
        c: wp.array(dtype=dtype),
        d: wp.array(dtype=dtype),
        e: wp.array(dtype=dtype),
    ):
        pass

    return wp.Kernel(simple_kernel)
34
+
35
+
36
def test_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None):
    """Time repeated ``wp.from_paddle()`` conversions followed by a kernel launch.

    Five zero-filled Paddle tensors are allocated once, outside the timed
    region. Each timed iteration wraps all five with
    ``wp.from_paddle(tensor, dtype=warp_dtype)`` and launches ``kernel`` on the
    results. The total elapsed time is printed in milliseconds.

    Args:
        kernel: Kernel passed to ``wp.launch``; it is given five array inputs.
        num_iters: Number of timed convert-and-launch iterations.
        array_size: Size of the leading tensor dimension, also used as the
            launch ``dim``.
        device: Device specifier passed to ``wp.get_device``.
        warp_dtype: Dtype forwarded to ``wp.from_paddle``. If it has a
            ``_shape_`` attribute, that shape is appended to the tensor shape
            and the Paddle dtype comes from ``_wp_scalar_type_``. If it is
            ``None``, the tensors are ``paddle.float32``.
    """
    import time  # local import keeps this block self-contained

    warp_device = wp.get_device(device)
    paddle_device = wp.device_to_paddle(warp_device)

    if hasattr(warp_dtype, "_shape_"):
        paddle_shape = (array_size, *warp_dtype._shape_)
        paddle_dtype = wp.dtype_to_paddle(warp_dtype._wp_scalar_type_)
    else:
        paddle_shape = (array_size,)
        paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype)

    # One tensor per kernel argument; allocation happens before timing starts.
    _a, _b, _c, _d, _e = (
        paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) for _ in range(5)
    )

    wp.synchronize()

    # profiler = Profiler(interval=0.000001)
    # profiler.start()

    # perf_counter_ns() is monotonic, so the measured interval cannot be
    # distorted by system clock adjustments the way time_ns() can.
    t1 = time.perf_counter_ns()

    # The five conversions are kept as explicit statements so the timed loop
    # contains nothing but the calls being measured.
    for _ in range(num_iters):
        a = wp.from_paddle(_a, dtype=warp_dtype)
        b = wp.from_paddle(_b, dtype=warp_dtype)
        c = wp.from_paddle(_c, dtype=warp_dtype)
        d = wp.from_paddle(_d, dtype=warp_dtype)
        e = wp.from_paddle(_e, dtype=warp_dtype)
        wp.launch(kernel, dim=array_size, inputs=[a, b, c, d, e])

    # NOTE(review): there is no wp.synchronize() before this timestamp, so the
    # figure presumably reflects host-side time only - confirm that is intended.
    t2 = time.perf_counter_ns()
    print(f"{(t2 - t1) / 1_000_000:8.0f} ms from_paddle(...)")

    # profiler.stop()
    # profiler.print()
73
+
74
+
75
def test_array_ctype_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None):
    """Time repeated ``wp.from_paddle(..., return_ctype=True)`` conversions plus a launch.

    Five zero-filled Paddle tensors are allocated once, outside the timed
    region. Each timed iteration converts all five with
    ``wp.from_paddle(tensor, dtype=warp_dtype, return_ctype=True)`` and launches
    ``kernel`` on the results. The total elapsed time is printed in
    milliseconds.

    Args:
        kernel: Kernel passed to ``wp.launch``; it is given five inputs.
        num_iters: Number of timed convert-and-launch iterations.
        array_size: Size of the leading tensor dimension, also used as the
            launch ``dim``.
        device: Device specifier passed to ``wp.get_device``.
        warp_dtype: Dtype forwarded to ``wp.from_paddle``. If it has a
            ``_shape_`` attribute, that shape is appended to the tensor shape
            and the Paddle dtype comes from ``_wp_scalar_type_``. If it is
            ``None``, the tensors are ``paddle.float32``.
    """
    import time  # local import keeps this block self-contained

    warp_device = wp.get_device(device)
    paddle_device = wp.device_to_paddle(warp_device)

    if hasattr(warp_dtype, "_shape_"):
        paddle_shape = (array_size, *warp_dtype._shape_)
        paddle_dtype = wp.dtype_to_paddle(warp_dtype._wp_scalar_type_)
    else:
        paddle_shape = (array_size,)
        paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype)

    # One tensor per kernel argument; allocation happens before timing starts.
    _a, _b, _c, _d, _e = (
        paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) for _ in range(5)
    )

    wp.synchronize()

    # profiler = Profiler(interval=0.000001)
    # profiler.start()

    # perf_counter_ns() is monotonic, so the measured interval cannot be
    # distorted by system clock adjustments the way time_ns() can.
    t1 = time.perf_counter_ns()

    # The five conversions are kept as explicit statements so the timed loop
    # contains nothing but the calls being measured.
    for _ in range(num_iters):
        a = wp.from_paddle(_a, dtype=warp_dtype, return_ctype=True)
        b = wp.from_paddle(_b, dtype=warp_dtype, return_ctype=True)
        c = wp.from_paddle(_c, dtype=warp_dtype, return_ctype=True)
        d = wp.from_paddle(_d, dtype=warp_dtype, return_ctype=True)
        e = wp.from_paddle(_e, dtype=warp_dtype, return_ctype=True)
        wp.launch(kernel, dim=array_size, inputs=[a, b, c, d, e])

    # NOTE(review): there is no wp.synchronize() before this timestamp, so the
    # figure presumably reflects host-side time only - confirm that is intended.
    t2 = time.perf_counter_ns()
    print(f"{(t2 - t1) / 1_000_000:8.0f} ms from_paddle(..., return_ctype=True)")

    # profiler.stop()
    # profiler.print()
112
+
113
+
114
def test_direct_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None):
    """Time repeated kernel launches that receive Paddle tensors directly.

    Five zero-filled Paddle tensors are allocated once, outside the timed
    region. Each timed iteration passes them to ``wp.launch`` as-is, with no
    explicit ``wp.from_paddle`` call. The total elapsed time is printed in
    milliseconds.

    Args:
        kernel: Kernel passed to ``wp.launch``; it is given the five tensors.
        num_iters: Number of timed launch iterations.
        array_size: Size of the leading tensor dimension, also used as the
            launch ``dim``.
        device: Device specifier passed to ``wp.get_device``.
        warp_dtype: Used only to choose the tensor shape and Paddle dtype. If
            it has a ``_shape_`` attribute, that shape is appended to the
            tensor shape and the Paddle dtype comes from ``_wp_scalar_type_``.
            If it is ``None``, the tensors are ``paddle.float32``.
    """
    import time  # local import keeps this block self-contained

    warp_device = wp.get_device(device)
    paddle_device = wp.device_to_paddle(warp_device)

    if hasattr(warp_dtype, "_shape_"):
        paddle_shape = (array_size, *warp_dtype._shape_)
        paddle_dtype = wp.dtype_to_paddle(warp_dtype._wp_scalar_type_)
    else:
        paddle_shape = (array_size,)
        paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype)

    # One tensor per kernel argument; allocation happens before timing starts.
    _a, _b, _c, _d, _e = (
        paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) for _ in range(5)
    )

    wp.synchronize()

    # profiler = Profiler(interval=0.000001)
    # profiler.start()

    # perf_counter_ns() is monotonic, so the measured interval cannot be
    # distorted by system clock adjustments the way time_ns() can.
    t1 = time.perf_counter_ns()

    for _ in range(num_iters):
        wp.launch(kernel, dim=array_size, inputs=[_a, _b, _c, _d, _e])

    # NOTE(review): there is no wp.synchronize() before this timestamp, so the
    # figure presumably reflects host-side time only - confirm that is intended.
    t2 = time.perf_counter_ns()
    print(f"{(t2 - t1) / 1_000_000:8.0f} ms direct from paddle")

    # profiler.stop()
    # profiler.print()
146
+
147
+
148
wp.init()

# Each entry pairs the warp_dtype argument handed to the tests with the kernel
# to launch. The first entry passes None as the dtype while still using a
# float32 kernel.
params = [
    (dtype_arg, create_simple_kernel(kernel_dtype))
    for dtype_arg, kernel_dtype in (
        (None, wp.float32),
        (wp.float32, wp.float32),
        (wp.vec3f, wp.vec3f),
        (wp.mat22f, wp.mat22f),
    )
]

wp.load_module()

num_iters = 100000

# The three interop paths, run in this order for every dtype.
benchmarks = (
    test_from_paddle,
    test_array_ctype_from_paddle,
    test_direct_from_paddle,
)

for warp_dtype, kernel in params:
    print(f"\ndtype={wp.context.type_str(warp_dtype)}")
    for run_benchmark in benchmarks:
        run_benchmark(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype)