warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,133 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Sim Rigid Force
18
+ #
19
+ # Shows how to apply an external force (torque) to a rigid body causing
20
+ # it to roll.
21
+ #
22
+ ###########################################################################
23
+
24
+ import warp as wp
25
+ import warp.sim
26
+ import warp.sim.render
27
+
28
+
29
class Example:
    """Single rigid box driven by a constant external load, simulated with XPBD.

    One box body (half-extents 1.0, density 100.0) is spawned at (0, 10, 0) with
    the ground plane enabled. Every substep the spatial vector
    ``[0, 0, -7000, 0, 0, 0]`` is written into ``state_0.body_f``; the file header
    describes this as an external force (torque) that makes the box roll.

    Args:
        stage_path: Path of the USD file written by ``SimRenderer``. When
            ``use_opengl`` is set it is passed as the second argument of
            ``SimRendererOpenGL`` instead. If falsy and ``use_opengl`` is False,
            rendering is disabled (``self.renderer is None``).
        use_opengl: Use the interactive ``SimRendererOpenGL`` renderer.
    """

    def __init__(self, stage_path="example_rigid_force.usd", use_opengl=False):
        fps = 60
        self.frame_dt = 1.0 / fps
        self.sim_substeps = 5
        self.sim_dt = self.frame_dt / self.sim_substeps
        self.sim_time = 0.0

        builder = wp.sim.ModelBuilder()

        b = builder.add_body(origin=wp.transform((0.0, 10.0, 0.0), wp.quat_identity()))
        builder.add_shape_box(body=b, hx=1.0, hy=1.0, hz=1.0, density=100.0)

        self.model = builder.finalize()
        self.model.ground = True

        self.integrator = wp.sim.XPBDIntegrator()

        # Two state buffers that simulate() ping-pongs between each substep.
        self.state_0 = self.model.state()
        self.state_1 = self.model.state()

        if use_opengl:
            self.renderer = wp.sim.render.SimRendererOpenGL(self.model, stage_path)
        elif stage_path:
            self.renderer = wp.sim.render.SimRenderer(self.model, stage_path)
        else:
            self.renderer = None

        # simulate() allocates memory via a clone, so we can't use graph capture if the device does not support mempools
        self.use_cuda_graph = wp.get_device().is_cuda and wp.is_mempool_enabled(wp.get_device())

        # Second CUDA graph, only used when sim_substeps is odd (see step()).
        self._graph_alt = None
        if self.use_cuda_graph:
            self.graph = self._capture_simulate()
            if self.sim_substeps % 2 == 1:
                # With an odd substep count, simulate() leaves its result in the
                # buffer it did NOT start from. A captured graph replays fixed
                # buffers, so replaying a single graph every frame would restart
                # from the buffer last written by the second-to-last substep and
                # silently drop one substep per frame. Capture a second graph that
                # starts from the other buffer and alternate between the two.
                self._graph_alt = self._capture_simulate()

    def _capture_simulate(self):
        """Record one call to simulate() into a CUDA graph and return the graph."""
        with wp.ScopedCapture() as capture:
            self.simulate()
        return capture.graph

    def simulate(self):
        """Advance one frame: run ``sim_substeps`` collide/integrate substeps.

        Each substep re-applies the external load after clearing forces, then
        integrates from ``state_0`` into ``state_1`` and swaps the two references.
        """
        for _ in range(self.sim_substeps):
            wp.sim.collide(self.model, self.state_0)

            self.state_0.clear_forces()
            self.state_1.clear_forces()

            self.state_0.body_f.assign(
                [
                    [0.0, 0.0, -7000.0, 0.0, 0.0, 0.0],
                ]
            )

            self.integrator.simulate(self.model, self.state_0, self.state_1, self.sim_dt)

            # swap states
            (self.state_0, self.state_1) = (self.state_1, self.state_0)

    def step(self):
        """Advance the simulation by one frame (``frame_dt``)."""
        with wp.ScopedTimer("step"):
            if self.use_cuda_graph:
                wp.capture_launch(self.graph)
                if self._graph_alt is not None:
                    # The launched graph wrote its result into the other state
                    # buffer: mirror that on the Python side, and queue the graph
                    # that starts from that buffer for the next frame.
                    self.state_0, self.state_1 = self.state_1, self.state_0
                    self.graph, self._graph_alt = self._graph_alt, self.graph
            else:
                self.simulate()
        self.sim_time += self.frame_dt

    def render(self):
        """Render ``state_0`` at ``sim_time``; does nothing when rendering is disabled."""
        if self.renderer is None:
            return

        with wp.ScopedTimer("render"):
            self.renderer.begin_frame(self.sim_time)
            self.renderer.render(self.state_0)
            self.renderer.end_frame()
98
+
99
+
100
if __name__ == "__main__":
    import argparse

    def _optional_path(text):
        # The literal string "None" on the command line disables USD output.
        return None if text == "None" else str(text)

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    parser.add_argument(
        "--stage_path",
        type=_optional_path,
        default="example_rigid_force.usd",
        help="Path to the output USD file.",
    )
    parser.add_argument("--num_frames", type=int, default=300, help="Total number of frames.")
    parser.add_argument(
        "--opengl",
        action="store_true",
        help="Open an interactive window to play back animations in real time. Ignores --num_frames if used.",
    )

    # Unrecognized command-line arguments are tolerated and ignored.
    args, _unknown_args = parser.parse_known_args()

    with wp.ScopedDevice(args.device):
        example = Example(stage_path=args.stage_path, use_opengl=args.opengl)

        def _advance_frame():
            example.step()
            example.render()

        if args.opengl:
            # Keep stepping until the OpenGL window reports it is no longer running.
            while example.renderer.is_running():
                _advance_frame()
        else:
            for _frame in range(args.num_frames):
                _advance_frame()

        if example.renderer:
            example.renderer.save()
@@ -0,0 +1,115 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Sim Rigid Gyroscopic
18
+ #
19
+ # Demonstrates the Dzhanibekov effect where rigid bodies will tumble in
20
+ # free space due to unstable axes of rotation.
21
+ #
22
+ ###########################################################################
23
+
24
+ import warp as wp
25
+ import warp.sim
26
+ import warp.sim.render
27
+
28
+
29
class Example:
    """Free-floating spinning rigid body demonstrating the Dzhanibekov effect.

    A single body is built from two boxes (an "axis" box and a "tip" box), given
    an initial spin, and simulated with gravity disabled and no ground plane
    using the semi-implicit integrator, one ``sim_dt`` step per ``step()`` call.

    Args:
        stage_path: Output USD path for ``SimRenderer`` (scaling 100.0); falsy
            disables rendering (``self.renderer is None``).
    """

    def __init__(self, stage_path="example_rigid_gyroscopic.usd"):
        fps = 120
        self.sim_dt = 1.0 / fps
        self.sim_time = 0.0

        # Uniform scale applied to the box offsets and half-extents below.
        self.scale = 0.5

        builder = wp.sim.ModelBuilder()

        b = builder.add_body()

        # axis shape
        builder.add_shape_box(
            pos=wp.vec3(0.3 * self.scale, 0.0, 0.0),
            hx=0.25 * self.scale,
            hy=0.1 * self.scale,
            hz=0.1 * self.scale,
            density=100.0,
            body=b,
        )

        # tip shape
        builder.add_shape_box(
            pos=wp.vec3(0.0, 0.0, 0.0),
            hx=0.05 * self.scale,
            hy=0.2 * self.scale,
            hz=1.0 * self.scale,
            density=100.0,
            body=b,
        )

        # initial spin -- indexed by the handle returned from add_body() rather
        # than assuming the body is index 0
        builder.body_qd[b] = (25.0, 0.01, 0.01, 0.0, 0.0, 0.0)

        builder.gravity = 0.0
        self.model = builder.finalize()
        self.model.ground = False

        self.integrator = wp.sim.SemiImplicitIntegrator()
        self.state = self.model.state()

        if stage_path:
            self.renderer = wp.sim.render.SimRenderer(self.model, stage_path, scaling=100.0)
        else:
            self.renderer = None

    def step(self):
        """Advance the simulation by one ``sim_dt`` step."""
        with wp.ScopedTimer("step"):
            self.state.clear_forces()
            # The same State is passed as input and output, so the update lands in
            # self.state; no need to rely on the integrator's return value.
            self.integrator.simulate(self.model, self.state, self.state, self.sim_dt)
        self.sim_time += self.sim_dt

    def render(self):
        """Render the current state at ``sim_time``; does nothing when rendering is disabled."""
        if self.renderer is None:
            return

        with wp.ScopedTimer("render"):
            self.renderer.begin_frame(self.sim_time)
            self.renderer.render(self.state)
            self.renderer.end_frame()
90
+
91
+
92
if __name__ == "__main__":
    import argparse

    def _stage_path_or_none(raw):
        # Passing the literal "None" turns off USD output.
        return None if raw == "None" else str(raw)

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    parser.add_argument(
        "--stage_path",
        type=_stage_path_or_none,
        default="example_rigid_gyroscopic.usd",
        help="Path to the output USD file.",
    )
    parser.add_argument("--num_frames", type=int, default=2000, help="Total number of frames.")

    # Extra, unrecognized arguments are ignored.
    args, _ignored = parser.parse_known_args()

    with wp.ScopedDevice(args.device):
        example = Example(stage_path=args.stage_path)

        for _frame_index in range(args.num_frames):
            example.step()
            example.render()

        if example.renderer:
            example.renderer.save()
@@ -0,0 +1,140 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Sim Rigid FEM
18
+ #
19
+ # Shows how to set up a rigid sphere colliding with an FEM beam
20
+ # using wp.sim.ModelBuilder().
21
+ #
22
+ ###########################################################################
23
+
24
+ import warp as wp
25
+ import warp.sim
26
+ import warp.sim.render
27
+
28
+
29
class Example:
    """Rigid sphere colliding with a soft FEM grid, integrated semi-implicitly.

    Each frame runs ``sim_substeps`` collide/integrate substeps. On CUDA devices
    one frame of substeps is captured into a CUDA graph and replayed by ``step()``.

    Args:
        stage_path: Output USD path for ``SimRenderer`` (scaling 1.0); falsy
            disables rendering (``self.renderer is None``).
    """

    def __init__(self, stage_path="example_rigid_soft_contact.usd"):
        # Kept as public attributes; not read anywhere in this class.
        self.sim_width = 8
        self.sim_height = 8

        frames_per_second = 60
        self.frame_dt = 1.0 / frames_per_second
        self.sim_substeps = 32
        self.sim_dt = self.frame_dt / self.sim_substeps
        self.sim_time = 0.0
        self.sim_iterations = 1
        self.sim_relaxation = 1.0
        # Timing results collected by wp.ScopedTimer in step().
        self.profiler = {}

        builder = wp.sim.ModelBuilder()
        builder.default_particle_radius = 0.01

        soft_grid_args = {
            "pos": wp.vec3(0.0, 0.0, 0.0),
            "rot": wp.quat_identity(),
            "vel": wp.vec3(0.0, 0.0, 0.0),
            "dim_x": 20,
            "dim_y": 10,
            "dim_z": 10,
            "cell_x": 0.1,
            "cell_y": 0.1,
            "cell_z": 0.1,
            "density": 100.0,
            "k_mu": 50000.0,
            "k_lambda": 20000.0,
            "k_damp": 0.0,
        }
        builder.add_soft_grid(**soft_grid_args)

        # Rigid sphere (radius 0.75) starting above the soft grid.
        sphere_body = builder.add_body(origin=wp.transform((0.5, 2.5, 0.5), wp.quat_identity()))
        builder.add_shape_sphere(body=sphere_body, radius=0.75, density=100.0)

        model = builder.finalize()
        model.ground = True
        model.soft_contact_ke = 1.0e3
        model.soft_contact_kd = 0.0
        model.soft_contact_kf = 1.0e3
        self.model = model

        self.integrator = wp.sim.SemiImplicitIntegrator()

        # Double-buffered states, swapped after every substep in simulate().
        self.state_0 = self.model.state()
        self.state_1 = self.model.state()

        self.renderer = (
            wp.sim.render.SimRenderer(self.model, stage_path, scaling=1.0) if stage_path else None
        )

        self.use_cuda_graph = wp.get_device().is_cuda
        if self.use_cuda_graph:
            with wp.ScopedCapture() as capture:
                self.simulate()
            self.graph = capture.graph

    def simulate(self):
        """Run one frame's worth of collide/integrate substeps, ping-ponging states."""
        for _substep in range(self.sim_substeps):
            wp.sim.collide(self.model, self.state_0)

            for buffer in (self.state_0, self.state_1):
                buffer.clear_forces()

            self.integrator.simulate(self.model, self.state_0, self.state_1, self.sim_dt)

            # the freshly integrated state becomes the input of the next substep
            self.state_1, self.state_0 = self.state_0, self.state_1

    def step(self):
        """Advance one frame, replaying the captured graph when available."""
        with wp.ScopedTimer("step", dict=self.profiler):
            if not self.use_cuda_graph:
                self.simulate()
            else:
                wp.capture_launch(self.graph)
        self.sim_time += self.frame_dt

    def render(self):
        """Render ``state_0`` at ``sim_time``; a no-op when rendering is disabled."""
        renderer = self.renderer
        if renderer is None:
            return

        with wp.ScopedTimer("render"):
            renderer.begin_frame(self.sim_time)
            renderer.render(self.state_0)
            renderer.end_frame()
115
+
116
+
117
if __name__ == "__main__":
    import argparse

    def _parse_stage_path(value):
        # "None" on the command line means: don't write a USD stage.
        return None if value == "None" else str(value)

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    parser.add_argument(
        "--stage_path",
        type=_parse_stage_path,
        default="example_rigid_soft_contact.usd",
        help="Path to the output USD file.",
    )
    parser.add_argument("--num_frames", type=int, default=300, help="Total number of frames.")

    # Unknown arguments are accepted and ignored.
    args, _rest = parser.parse_known_args()

    with wp.ScopedDevice(args.device):
        example = Example(stage_path=args.stage_path)

        for _frame in range(args.num_frames):
            example.step()
            example.render()

        if example.renderer:
            example.renderer.save()
@@ -0,0 +1,196 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Sim Neo-Hookean
18
+ #
19
+ # Shows a simulation of an Neo-Hookean FEM beam being twisted through a
20
+ # 180 degree rotation.
21
+ #
22
+ ###########################################################################
23
+ import math
24
+
25
+ import warp as wp
26
+ import warp.sim
27
+ import warp.sim.render
28
+
29
+
30
+ @wp.kernel
31
+ def twist_points(
32
+ rest: wp.array(dtype=wp.vec3), points: wp.array(dtype=wp.vec3), mass: wp.array(dtype=float), xform: wp.transform
33
+ ):
34
+ tid = wp.tid()
35
+
36
+ r = rest[tid]
37
+ p = points[tid]
38
+ m = mass[tid]
39
+
40
+ # twist the top layer of particles in the beam
41
+ if m == 0 and p[1] != 0.0:
42
+ points[tid] = wp.transform_point(xform, r)
43
+
44
+
45
+ @wp.kernel
46
+ def compute_volume(points: wp.array(dtype=wp.vec3), indices: wp.array2d(dtype=int), volume: wp.array(dtype=float)):
47
+ tid = wp.tid()
48
+
49
+ i = indices[tid, 0]
50
+ j = indices[tid, 1]
51
+ k = indices[tid, 2]
52
+ l = indices[tid, 3]
53
+
54
+ x0 = points[i]
55
+ x1 = points[j]
56
+ x2 = points[k]
57
+ x3 = points[l]
58
+
59
+ x10 = x1 - x0
60
+ x20 = x2 - x0
61
+ x30 = x3 - x0
62
+
63
+ v = wp.dot(x10, wp.cross(x20, x30)) / 6.0
64
+
65
+ wp.atomic_add(volume, 0, v)
66
+
67
+
68
+ class Example:
69
+ def __init__(self, stage_path="example_soft_body.usd", num_frames=300):
70
+ self.sim_substeps = 64
71
+ self.num_frames = num_frames
72
+ fps = 60
73
+ sim_duration = self.num_frames / fps
74
+ self.frame_dt = 1.0 / fps
75
+ self.sim_dt = self.frame_dt / self.sim_substeps
76
+ self.sim_time = 0.0
77
+ self.lift_speed = 2.5 / sim_duration * 2.0 # from Smith et al.
78
+ self.rot_speed = math.pi / sim_duration
79
+
80
+ builder = wp.sim.ModelBuilder()
81
+
82
+ cell_dim = 15
83
+ cell_size = 2.0 / cell_dim
84
+
85
+ center = cell_size * cell_dim * 0.5
86
+
87
+ builder.add_soft_grid(
88
+ pos=wp.vec3(-center, 0.0, -center),
89
+ rot=wp.quat_identity(),
90
+ vel=wp.vec3(0.0, 0.0, 0.0),
91
+ dim_x=cell_dim,
92
+ dim_y=cell_dim,
93
+ dim_z=cell_dim,
94
+ cell_x=cell_size,
95
+ cell_y=cell_size,
96
+ cell_z=cell_size,
97
+ density=100.0,
98
+ fix_bottom=True,
99
+ fix_top=True,
100
+ k_mu=1000.0,
101
+ k_lambda=5000.0,
102
+ k_damp=0.0,
103
+ )
104
+
105
+ self.model = builder.finalize()
106
+ self.model.ground = False
107
+ self.model.gravity[1] = 0.0
108
+
109
+ self.integrator = wp.sim.SemiImplicitIntegrator()
110
+
111
+ self.rest = self.model.state()
112
+ self.rest_vol = (cell_size * cell_dim) ** 3
113
+
114
+ self.state_0 = self.model.state()
115
+ self.state_1 = self.model.state()
116
+
117
+ self.volume = wp.zeros(1, dtype=wp.float32)
118
+
119
+ if stage_path:
120
+ self.renderer = wp.sim.render.SimRenderer(self.model, stage_path, scaling=20.0)
121
+ else:
122
+ self.renderer = None
123
+
124
+ self.use_cuda_graph = wp.get_device().is_cuda
125
+ if self.use_cuda_graph:
126
+ with wp.ScopedCapture() as capture:
127
+ self.simulate()
128
+ self.graph = capture.graph
129
+
130
+ def simulate(self):
131
+ for _ in range(self.sim_substeps):
132
+ self.state_0.clear_forces()
133
+ self.state_1.clear_forces()
134
+
135
+ self.integrator.simulate(self.model, self.state_0, self.state_1, self.sim_dt)
136
+
137
+ # swap states
138
+ (self.state_0, self.state_1) = (self.state_1, self.state_0)
139
+
140
+ def step(self):
141
+ with wp.ScopedTimer("step"):
142
+ xform = wp.transform(
143
+ (0.0, self.lift_speed * self.sim_time, 0.0),
144
+ wp.quat_from_axis_angle(wp.vec3(0.0, 1.0, 0.0), self.rot_speed * self.sim_time),
145
+ )
146
+ wp.launch(
147
+ kernel=twist_points,
148
+ dim=len(self.state_0.particle_q),
149
+ inputs=[self.rest.particle_q, self.state_0.particle_q, self.model.particle_mass, xform],
150
+ )
151
+ if self.use_cuda_graph:
152
+ wp.capture_launch(self.graph)
153
+ else:
154
+ self.simulate()
155
+ self.volume.zero_()
156
+ wp.launch(
157
+ kernel=compute_volume,
158
+ dim=self.model.tet_count,
159
+ inputs=[self.state_0.particle_q, self.model.tet_indices, self.volume],
160
+ )
161
+ self.sim_time += self.frame_dt
162
+
163
+ def render(self):
164
+ if self.renderer is None:
165
+ return
166
+
167
+ with wp.ScopedTimer("render"):
168
+ self.renderer.begin_frame(self.sim_time)
169
+ self.renderer.render(self.state_0)
170
+ self.renderer.end_frame()
171
+
172
+
173
+ if __name__ == "__main__":
174
+ import argparse
175
+
176
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
177
+ parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
178
+ parser.add_argument(
179
+ "--stage_path",
180
+ type=lambda x: None if x == "None" else str(x),
181
+ default="example_soft_body.usd",
182
+ help="Path to the output USD file.",
183
+ )
184
+ parser.add_argument("--num_frames", type=int, default=300, help="Total number of frames.")
185
+
186
+ args = parser.parse_known_args()[0]
187
+
188
+ with wp.ScopedDevice(args.device):
189
+ example = Example(stage_path=args.stage_path, num_frames=args.num_frames)
190
+
191
+ for _ in range(args.num_frames):
192
+ example.step()
193
+ example.render()
194
+
195
+ if example.renderer:
196
+ example.renderer.save()
@@ -0,0 +1,87 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Tile Cholesky
18
+ #
19
+ # Shows how to write a simple kernel computing a Cholesky factorize and
20
+ # triangular solve using Warp Cholesky Tile APIs.
21
+ #
22
+ ###########################################################################
23
+
24
+ import numpy as np
25
+
26
+ import warp as wp
27
+
28
+ wp.init()
29
+ wp.set_module_options({"enable_backward": False})
30
+
31
+ BLOCK_DIM = 128
32
+ TILE = 32
33
+
34
+ # Both should work
35
+ np_type, wp_type = np.float64, wp.float64
36
+ # np_type, wp_type = np.float32, wp.float32
37
+
38
+
39
+ @wp.kernel
40
+ def cholesky(
41
+ A: wp.array2d(dtype=wp_type),
42
+ L: wp.array2d(dtype=wp_type),
43
+ X: wp.array1d(dtype=wp_type),
44
+ Y: wp.array1d(dtype=wp_type),
45
+ ):
46
+ i, j, _ = wp.tid()
47
+
48
+ a = wp.tile_load(A, shape=(TILE, TILE))
49
+ l = wp.tile_cholesky(a)
50
+ wp.tile_store(L, l)
51
+
52
+ x = wp.tile_load(X, shape=TILE)
53
+ y = wp.tile_cholesky_solve(l, x)
54
+ wp.tile_store(Y, y)
55
+
56
+
57
+ if __name__ == "__main__":
58
+ wp.set_device("cuda:0")
59
+
60
+ A_h = np.ones((TILE, TILE), dtype=np_type) + 5 * np.diag(np.ones(TILE), 0)
61
+ L_h = np.zeros_like(A_h)
62
+
63
+ A_wp = wp.array2d(A_h, dtype=wp_type)
64
+ L_wp = wp.array2d(L_h, dtype=wp_type)
65
+
66
+ X_h = np.arange(TILE, dtype=np_type)
67
+ Y_h = np.zeros_like(X_h)
68
+
69
+ X_wp = wp.array1d(X_h, dtype=wp_type)
70
+ Y_wp = wp.array1d(Y_h, dtype=wp_type)
71
+
72
+ wp.launch_tiled(cholesky, dim=[1, 1], inputs=[A_wp, L_wp, X_wp, Y_wp], block_dim=BLOCK_DIM)
73
+
74
+ L_np = np.linalg.cholesky(A_h)
75
+ Y_np = np.linalg.solve(A_h, X_h)
76
+
77
+ print("A:\n", A_h)
78
+ print("L (Warp):\n", L_wp)
79
+ print("L (Numpy):\n", L_np)
80
+
81
+ print("x:\n", X_h)
82
+ print("A\\n (Warp):\n", Y_wp.numpy())
83
+ print("A\\x (Numpy):\n", Y_np)
84
+
85
+ assert np.allclose(Y_wp.numpy(), Y_np) and np.allclose(L_wp.numpy(), L_np)
86
+
87
+ print("Example Tile Cholesky passed")