warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/optim/linear.py ADDED
@@ -0,0 +1,1137 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from math import sqrt
17
+ from typing import Any, Callable, Optional, Tuple, Union
18
+
19
+ import warp as wp
20
+ import warp.sparse as sparse
21
+ from warp.utils import array_inner
22
+
23
# No need to auto-generate adjoint code for linear solvers: the kernels of this
# module are compiled with backward (adjoint) code generation disabled.
wp.set_module_options({"enable_backward": False})
25
+
26
+
27
class LinearOperator:
    """
    Linear operator to be used as left-hand-side of linear iterative solvers.

    Args:
        shape: Tuple containing the number of rows and columns of the operator
        dtype: Type of the operator elements
        device: Device on which computations involving the operator should be performed
        matvec: Matrix-vector multiplication routine

    The matrix-vector multiplication routine should have the following signature:

    .. code-block:: python

        def matvec(x: wp.array, y: wp.array, z: wp.array, alpha: Scalar, beta: Scalar):
            '''Performs the operation z = alpha * (A @ x) + beta * y, where A is this operator'''
            ...

    The output ``z`` may be the very same array as the addend ``y``: the solvers in this module
    call ``matvec(x, z, z, alpha, beta)``, so implementations must not assume distinct buffers.

    For performance reasons, by default the iterative linear solvers in this module will try to capture the calls
    for one or more iterations in CUDA graphs. If the `matvec` routine of a custom :class:`LinearOperator`
    cannot be graph-captured, the ``use_cuda_graph=False`` parameter should be passed to the solver function.

    """

    def __init__(self, shape: Tuple[int, int], dtype: type, device: wp.context.Device, matvec: Callable):
        self._shape = shape
        self._dtype = dtype
        self._device = device
        self._matvec = matvec

    @property
    def shape(self) -> Tuple[int, int]:
        """Number of rows and columns of the operator."""
        return self._shape

    @property
    def dtype(self) -> type:
        """Type of the operator elements."""
        return self._dtype

    @property
    def device(self) -> wp.context.Device:
        """Device on which operator computations are performed."""
        return self._device

    @property
    def matvec(self) -> Callable:
        """Routine computing ``z = alpha * (A @ x) + beta * y``; see the class docstring for its signature."""
        return self._matvec

    @property
    def scalar_type(self):
        """Scalar type underlying :attr:`dtype`, as reported by ``wp.types.type_scalar_type``."""
        return wp.types.type_scalar_type(self.dtype)
76
+
77
+
78
# Matrix-like inputs accepted by the solvers and by `aslinearoperator`: a dense (2-D) or
# diagonal (1-D) warp array, a block-sparse BSR matrix, or an already-built LinearOperator.
_Matrix = Union[wp.array, sparse.BsrMatrix, LinearOperator]
79
+
80
+
81
def aslinearoperator(A: _Matrix) -> LinearOperator:
    """
    Casts the dense or sparse matrix `A` as a :class:`LinearOperator`

    `A` must be of one of the following types:

    - :class:`warp.sparse.BsrMatrix`
    - two-dimensional `warp.array`; then `A` is assumed to be a dense matrix
    - one-dimensional `warp.array`; then `A` is assumed to be a diagonal matrix
    - :class:`LinearOperator`; no casting necessary

    ``None`` is returned unchanged, so optional operators (e.g. a missing preconditioner)
    can be passed through this function.

    Raises:
        ValueError: if `A` is not one of the supported types.
    """

    if A is None or isinstance(A, LinearOperator):
        return A

    def bsr_mv(x, y, z, alpha, beta):
        # sparse.bsr_mv takes its addend from its output buffer, so seed z with y first,
        # unless both are the same array or the addend is unused (beta == 0).
        if z.ptr != y.ptr and beta != 0.0:
            wp.copy(src=y, dest=z)
        sparse.bsr_mv(A, x, z, alpha, beta)

    def dense_mv(x, y, z, alpha, beta):
        # Cast the coefficients to the matrix scalar type (as diag_mv does), since callers such
        # as cg pass plain Python ints/floats.
        scalar_type = wp.types.type_scalar_type(A.dtype)
        alpha = scalar_type(alpha)
        beta = scalar_type(beta)
        # The result z has one entry per row of A, so launch over rows (A.shape[0]), not columns.
        wp.launch(_dense_mv_kernel, dim=A.shape[0], device=A.device, inputs=[A, x, y, z, alpha, beta])

    def diag_mv(x, y, z, alpha, beta):
        scalar_type = wp.types.type_scalar_type(A.dtype)
        alpha = scalar_type(alpha)
        beta = scalar_type(beta)
        wp.launch(_diag_mv_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])

    def diag_mv_vec(x, y, z, alpha, beta):
        scalar_type = wp.types.type_scalar_type(A.dtype)
        alpha = scalar_type(alpha)
        beta = scalar_type(beta)
        wp.launch(_diag_mv_vec_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])

    if isinstance(A, wp.array):
        if A.ndim == 2:
            return LinearOperator(A.shape, A.dtype, A.device, matvec=dense_mv)
        if A.ndim == 1:
            # Diagonal matrix; vector-valued entries need the dedicated per-component kernel
            if wp.types.type_is_vector(A.dtype):
                return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv_vec)
            return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv)
    if isinstance(A, sparse.BsrMatrix):
        return LinearOperator(A.shape, A.dtype, A.device, matvec=bsr_mv)

    raise ValueError(f"Unable to create LinearOperator from {A}")
127
+
128
+
129
def preconditioner(A: _Matrix, ptype: str = "diag") -> Optional[LinearOperator]:
    """Constructs and returns a preconditioner for an input matrix.

    Args:
        A: The matrix for which to build the preconditioner. Diagonal preconditioners support
            :class:`warp.sparse.BsrMatrix` and two-dimensional (dense) `warp.array` inputs.
        ptype: The type of preconditioner. Currently the following values are supported:

         - ``"diag"``: Diagonal (a.k.a. Jacobi) preconditioner
         - ``"diag_abs"``: Similar to Jacobi, but using the absolute value of diagonal coefficients
         - ``"id"``: Identity (null) preconditioner

    Returns:
        A :class:`LinearOperator` applying the preconditioner, or ``None`` for ``"id"``.
        The solvers in this module treat a ``None`` preconditioner as the identity.

    Raises:
        ValueError: if `ptype` is not supported, or if `A` is not a supported matrix type
            for a diagonal preconditioner.
    """

    if ptype == "id":
        return None

    if ptype in ("diag", "diag_abs"):
        # Passed to the extraction kernels as an integer flag
        use_abs = 1 if ptype == "diag_abs" else 0
        if isinstance(A, sparse.BsrMatrix):
            A_diag = sparse.bsr_get_diag(A)
            if wp.types.type_is_matrix(A.dtype):
                # Block matrix: store the inverted diagonal of each diagonal block as a vector
                inv_diag = wp.empty(
                    shape=A.nrow, dtype=wp.vec(length=A.block_shape[0], dtype=A.scalar_type), device=A.device
                )
                wp.launch(
                    _extract_inverse_diagonal_blocked,
                    dim=inv_diag.shape,
                    device=inv_diag.device,
                    inputs=[A_diag, inv_diag, use_abs],
                )
            else:
                inv_diag = wp.empty(shape=A.shape[0], dtype=A.scalar_type, device=A.device)
                wp.launch(
                    _extract_inverse_diagonal_scalar,
                    dim=inv_diag.shape,
                    device=inv_diag.device,
                    inputs=[A_diag, inv_diag, use_abs],
                )
        elif isinstance(A, wp.array) and A.ndim == 2:
            inv_diag = wp.empty(shape=A.shape[0], dtype=A.dtype, device=A.device)
            wp.launch(
                _extract_inverse_diagonal_dense,
                dim=inv_diag.shape,
                device=inv_diag.device,
                inputs=[A, inv_diag, use_abs],
            )
        else:
            raise ValueError("Unsupported source matrix type for building diagonal preconditioner")

        # A 1-D array is wrapped as a diagonal operator
        return aslinearoperator(inv_diag)

    raise ValueError(f"Unsupported preconditioner type '{ptype}'")
180
+
181
+
182
def cg(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every: int = 10,
    use_cuda_graph: bool = True,
) -> Tuple[int, float, float]:
    """Computes an approximate solution to a symmetric, positive-definite linear system
    using the Conjugate Gradient algorithm.

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        maxiter: maximum number of iterations to perform before aborting. If 0 or ``None``, defaults to the system size.
            Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
        M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
        callback: function to be called every `check_every` iterations with the current iteration number, residual and tolerance
        check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibly terminate the algorithm.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.

    Returns:
        Tuple (final iteration number, residual norm, absolute tolerance)

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    """

    A = aslinearoperator(A)
    M = aslinearoperator(M)

    # Both 0 (the default) and None mean "use the system size"
    if not maxiter:
        maxiter = A.shape[0]

    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)

    device = A.device
    scalar_dtype = wp.types.type_scalar_type(A.dtype)

    # Notations below follow pseudo-code from https://en.wikipedia.org/wiki/Conjugate_gradient_method

    # z = M r
    if M is not None:
        z = wp.zeros_like(b)
        M.matvec(r, z, z, alpha=1.0, beta=0.0)

        # rz = r' z;
        rz_new = wp.empty(n=1, dtype=scalar_dtype, device=device)
        array_inner(r, z, out=rz_new)
    else:
        # Without preconditioner, z aliases r (so <r, z> == <r, r>)
        z = r

    rz_old = wp.empty(n=1, dtype=scalar_dtype, device=device)
    p_Ap = wp.empty(n=1, dtype=scalar_dtype, device=device)
    Ap = wp.zeros_like(b)

    p = wp.clone(z)

    def do_iteration(atol_sq, rr_old, rr_new, rz_old, rz_new):
        # Ap = A * p;
        A.matvec(p, Ap, Ap, alpha=1, beta=0)

        array_inner(p, Ap, out=p_Ap)

        wp.launch(
            kernel=_cg_kernel_1,
            dim=x.shape[0],
            device=device,
            inputs=[atol_sq, rr_old, rz_old, p_Ap, x, r, p, Ap],
        )
        array_inner(r, r, out=rr_new)

        # z = M r
        if M is not None:
            M.matvec(r, z, z, alpha=1.0, beta=0.0)
            # rz = r' z;
            array_inner(r, z, out=rz_new)

        wp.launch(kernel=_cg_kernel_2, dim=z.shape[0], device=device, inputs=[atol_sq, rr_new, rz_old, rz_new, z, p])

    # We do iterations by pairs, switching old and new residual norm buffers for each odd-even couple
    # In the non-preconditioned case we reuse the error norm buffer for the new <r,z> computation

    def do_odd_even_cycle(atol_sq: float):
        # A pair of iterations, so that we're swapping the residual buffers twice
        if M is None:
            do_iteration(atol_sq, r_norm_sq, rz_old, r_norm_sq, rz_old)
            do_iteration(atol_sq, rz_old, r_norm_sq, rz_old, r_norm_sq)
        else:
            do_iteration(atol_sq, r_norm_sq, r_norm_sq, rz_new, rz_old)
            do_iteration(atol_sq, r_norm_sq, r_norm_sq, rz_old, rz_new)

    return _run_solver_loop(
        do_odd_even_cycle,
        cycle_size=2,
        r_norm_sq=r_norm_sq,
        maxiter=maxiter,
        atol=atol,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
        device=device,
    )
292
+
293
+
294
def cr(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every=10,
    use_cuda_graph=True,
) -> Tuple[int, float, float]:
    """Computes an approximate solution to a symmetric, positive-definite linear system
    using the Conjugate Residual algorithm.

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        maxiter: maximum number of iterations to perform before aborting. Defaults to the system size
            (both ``0`` and ``None`` select the default).
            Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
        M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
        callback: function to be called every `check_every` iterations with the current iteration number, residual and tolerance
        check_every: number of iterations between successive calls to `callback`, checks of the residual against the tolerance,
            and possible early termination of the algorithm.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.

    Returns:
        Tuple (final iteration number, residual norm, absolute tolerance)

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    """

    A = aslinearoperator(A)
    M = aslinearoperator(M)

    # `None` previously fell through and crashed in the solver loop (`cur_iter >= None`)
    if maxiter is None or maxiter == 0:
        maxiter = A.shape[0]

    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)

    device = A.device
    scalar_dtype = wp.types.type_scalar_type(A.dtype)

    # Notations below follow roughly pseudo-code from https://en.wikipedia.org/wiki/Conjugate_residual_method
    # with z := M r and y := M Ap, where M is the preconditioner (approximating A^-1)

    # z = M r
    if M is None:
        z = r
    else:
        z = wp.zeros_like(r)
        M.matvec(r, z, z, alpha=1.0, beta=0.0)

    Az = wp.zeros_like(b)
    A.matvec(z, Az, Az, alpha=1, beta=0)

    p = wp.clone(z)
    Ap = wp.clone(Az)

    if M is None:
        y = Ap
    else:
        y = wp.zeros_like(Ap)

    zAz_old = wp.empty(n=1, dtype=scalar_dtype, device=device)
    zAz_new = wp.empty(n=1, dtype=scalar_dtype, device=device)
    y_Ap = wp.empty(n=1, dtype=scalar_dtype, device=device)

    array_inner(z, Az, out=zAz_new)

    def do_iteration(atol_sq, rr, zAz_old, zAz_new):
        # y = M Ap (aliases Ap when there is no preconditioner)
        if M is not None:
            M.matvec(Ap, y, y, alpha=1.0, beta=0.0)
        array_inner(Ap, y, out=y_Ap)

        if M is None:
            # In non-preconditioned case, first kernel is same as CG
            wp.launch(
                kernel=_cg_kernel_1,
                dim=x.shape[0],
                device=device,
                inputs=[atol_sq, rr, zAz_old, y_Ap, x, r, p, Ap],
            )
        else:
            # In preconditioned case, we have one more vector to update
            wp.launch(
                kernel=_cr_kernel_1,
                dim=x.shape[0],
                device=device,
                inputs=[atol_sq, rr, zAz_old, y_Ap, x, r, z, p, Ap, y],
            )

        array_inner(r, r, out=rr)

        A.matvec(z, Az, Az, alpha=1, beta=0)
        array_inner(z, Az, out=zAz_new)

        # beta = zAz_new / zAz_old; p = z + beta p; Ap = Az + beta Ap
        wp.launch(
            kernel=_cr_kernel_2, dim=z.shape[0], device=device, inputs=[atol_sq, rr, zAz_old, zAz_new, z, p, Az, Ap]
        )

    # We do iterations by pairs, switching old and new <z, Az> buffers for each odd-even couple
    def do_odd_even_cycle(atol_sq: float):
        do_iteration(atol_sq, r_norm_sq, zAz_new, zAz_old)
        do_iteration(atol_sq, r_norm_sq, zAz_old, zAz_new)

    return _run_solver_loop(
        do_odd_even_cycle,
        cycle_size=2,
        r_norm_sq=r_norm_sq,
        maxiter=maxiter,
        atol=atol,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
        device=device,
    )
415
+
416
+
417
def bicgstab(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every=10,
    use_cuda_graph=True,
    is_left_preconditioner=False,
) -> Tuple[int, float, float]:
    """Computes an approximate solution to a linear system using the Biconjugate Gradient Stabilized method (BiCGSTAB).

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        maxiter: maximum number of iterations to perform before aborting. Defaults to the system size
            (both ``0`` and ``None`` select the default).
        M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
        callback: function to be called every `check_every` iterations with the current iteration number, residual and tolerance
        check_every: number of iterations between successive calls to `callback`, checks of the residual against the tolerance,
            and possible early termination of the algorithm.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.
        is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.

    Returns:
        Tuple (final iteration number, residual norm, absolute tolerance)

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    """
    A = aslinearoperator(A)
    M = aslinearoperator(M)

    # `None` previously fell through and crashed in the solver loop (`cur_iter >= None`)
    if maxiter is None or maxiter == 0:
        maxiter = A.shape[0]

    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)

    device = A.device
    scalar_dtype = wp.types.type_scalar_type(A.dtype)

    # Notations below follow the pseudo-code from https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method

    # r0 == r initially, so rho = <r0, r> = <r, r>
    rho = wp.clone(r_norm_sq, pinned=False)
    r0v = wp.empty(n=1, dtype=scalar_dtype, device=device)
    st = wp.empty(n=1, dtype=scalar_dtype, device=device)
    tt = wp.empty(n=1, dtype=scalar_dtype, device=device)

    # work arrays
    r0 = wp.clone(r)
    v = wp.zeros_like(r)
    t = wp.zeros_like(r)
    p = wp.clone(r)

    if M is not None:
        y = wp.zeros_like(p)
        z = wp.zeros_like(r)
        if is_left_preconditioner:
            Mt = wp.zeros_like(t)
    else:
        # Without preconditioner, the preconditioned vectors alias the plain ones
        y = p
        z = r
        Mt = t

    def do_iteration(atol_sq: float):
        # y = M p
        if M is not None:
            M.matvec(p, y, y, alpha=1.0, beta=0.0)

        # v = A * y;
        A.matvec(y, v, v, alpha=1, beta=0)

        # alpha = rho / <r0 . v>
        array_inner(r0, v, out=r0v)

        # x += alpha y
        # r -= alpha v
        wp.launch(
            kernel=_bicgstab_kernel_1,
            dim=x.shape[0],
            device=device,
            inputs=[atol_sq, r_norm_sq, rho, r0v, x, r, y, v],
        )
        array_inner(r, r, out=r_norm_sq)

        # z = M r
        if M is not None:
            M.matvec(r, z, z, alpha=1.0, beta=0.0)

        # t = A z
        A.matvec(z, t, t, alpha=1, beta=0)

        if is_left_preconditioner:
            # Mt = M t
            if M is not None:
                M.matvec(t, Mt, Mt, alpha=1.0, beta=0.0)

            # omega = <Mt, Ms> / <Mt, Mt>
            array_inner(z, Mt, out=st)
            array_inner(Mt, Mt, out=tt)
        else:
            array_inner(r, t, out=st)
            array_inner(t, t, out=tt)

        # x += omega z
        # r -= omega t
        wp.launch(
            kernel=_bicgstab_kernel_2,
            dim=z.shape[0],
            device=device,
            inputs=[atol_sq, r_norm_sq, st, tt, z, t, x, r],
        )
        array_inner(r, r, out=r_norm_sq)

        # rho = <r0, r>
        array_inner(r0, r, out=rho)

        # beta = (rho / rho_old) * alpha / omega = (rho / r0v) / omega
        # p = r + beta (p - omega v)
        wp.launch(
            kernel=_bicgstab_kernel_3,
            dim=z.shape[0],
            device=device,
            inputs=[atol_sq, r_norm_sq, rho, r0v, st, tt, p, r, v],
        )

    return _run_solver_loop(
        do_iteration,
        cycle_size=1,
        r_norm_sq=r_norm_sq,
        maxiter=maxiter,
        atol=atol,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
        device=device,
    )
558
+
559
+
560
def gmres(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    restart=31,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every=31,
    use_cuda_graph=True,
    is_left_preconditioner=False,
) -> Tuple[int, float, float]:
    """Computes an approximate solution to a linear system using the restarted Generalized Minimum Residual method (GMRES[k]).

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        restart: The restart parameter, i.e., the `k` in `GMRES[k]`. In general, increasing this parameter reduces the number of iterations but increases memory consumption.
        maxiter: maximum number of iterations to perform before aborting. Defaults to the system size
            (both ``0`` and ``None`` select the default).
            Note that the current implementation always performs `restart` iterations at a time, and as a result may exceed the specified maximum number of iterations by ``restart-1``.
        M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
        callback: function to be called every `check_every` iterations with the current iteration number, residual and tolerance
        check_every: number of iterations between successive calls to `callback`, checks of the residual against the tolerance,
            and possible early termination of the algorithm. Values smaller than `restart` are raised to `restart`.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.
        is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.

    Returns:
        Tuple (final iteration number, residual norm, absolute tolerance)

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    """

    A = aslinearoperator(A)
    M = aslinearoperator(M)

    # `None` previously crashed in `min(restart, maxiter)` below
    if maxiter is None or maxiter == 0:
        maxiter = A.shape[0]

    restart = min(restart, maxiter)
    check_every = max(restart, check_every)

    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)

    device = A.device
    scalar_dtype = wp.types.type_scalar_type(A.dtype)

    pivot_tolerance = _get_dtype_epsilon(scalar_dtype) ** 2

    beta_sq = wp.empty_like(r_norm_sq, pinned=False)
    # Hessenberg matrix; lower-diagonal entries hold squared norms until `_gmres_normalize_lower_diagonal`
    H = wp.empty(shape=(restart + 1, restart), dtype=scalar_dtype, device=device)

    y = wp.empty(shape=restart + 1, dtype=scalar_dtype, device=device)

    w = wp.zeros_like(r)
    # Krylov basis vectors, one per row
    V = wp.zeros(shape=(restart + 1, r.shape[0]), dtype=r.dtype, device=device)

    def array_coeff(H, i, j):
        # Single-element, non-owning view of H[i, j]
        return wp.array(
            ptr=H.ptr + i * H.strides[0] + j * H.strides[1],
            dtype=H.dtype,
            shape=(1,),
            device=H.device,
            copy=False,
        )

    def array_row(V, i):
        # Non-owning view of row i of V
        return wp.array(
            ptr=V.ptr + i * V.strides[0],
            dtype=V.dtype,
            shape=V.shape[1],
            device=V.device,
            copy=False,
        )

    def do_arnoldi_iteration(j: int):
        # w = A * v;

        vj = array_row(V, j)

        if M is not None:
            # Row j + 1 is not yet populated, so it serves as scratch space here
            tmp = array_row(V, j + 1)

            if is_left_preconditioner:
                A.matvec(vj, tmp, tmp, alpha=1, beta=0)
                M.matvec(tmp, w, w, alpha=1, beta=0)
            else:
                M.matvec(vj, tmp, tmp, alpha=1, beta=0)
                A.matvec(tmp, w, w, alpha=1, beta=0)
        else:
            A.matvec(vj, w, w, alpha=1, beta=0)

        # Modified Gram-Schmidt orthogonalization against previous basis vectors
        for i in range(j + 1):
            vi = array_row(V, i)
            hij = array_coeff(H, i, j)
            array_inner(w, vi, out=hij)

            wp.launch(_gmres_arnoldi_axpy_kernel, dim=w.shape, device=w.device, inputs=[vi, w, hij])

        hjnj = array_coeff(H, j + 1, j)
        array_inner(w, w, out=hjnj)

        vjn = array_row(V, j + 1)
        wp.launch(_gmres_arnoldi_normalize_kernel, dim=w.shape, device=w.device, inputs=[w, vjn, hjnj])

    def do_restart_cycle(atol_sq: float):
        if M is not None and is_left_preconditioner:
            M.matvec(r, w, w, alpha=1, beta=0)
            rh = w
        else:
            rh = r

        array_inner(rh, rh, out=beta_sq)

        v0 = array_row(V, 0)
        # v0 = r / beta
        wp.launch(_gmres_arnoldi_normalize_kernel, dim=r.shape, device=r.device, inputs=[rh, v0, beta_sq])

        for j in range(restart):
            do_arnoldi_iteration(j)

        wp.launch(_gmres_normalize_lower_diagonal, dim=restart, device=device, inputs=[H])
        wp.launch(_gmres_solve_least_squares, dim=1, device=device, inputs=[restart, pivot_tolerance, beta_sq, H, y])

        # update x
        if M is None or is_left_preconditioner:
            wp.launch(_gmres_update_x_kernel, dim=x.shape, device=device, inputs=[restart, scalar_dtype(1.0), y, V, x])
        else:
            wp.launch(_gmres_update_x_kernel, dim=x.shape, device=device, inputs=[restart, scalar_dtype(0.0), y, V, w])
            M.matvec(w, x, x, alpha=1, beta=1)

        # update r and residual
        wp.copy(src=b, dest=r)
        A.matvec(x, b, r, alpha=-1.0, beta=1.0)
        array_inner(r, r, out=r_norm_sq)

    return _run_solver_loop(
        do_restart_cycle,
        cycle_size=restart,
        r_norm_sq=r_norm_sq,
        maxiter=maxiter,
        atol=atol,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
        device=device,
    )
712
+
713
+
714
def _get_dtype_epsilon(dtype):
    """Return the base tolerance associated with a Warp scalar type.

    ``wp.float64`` maps to 1e-16 and ``wp.float16`` maps to 1e-4. Every other
    type (e.g. ``wp.float32``) maps to 1e-8. Callers in this module derive
    default/minimum solver tolerances and the GMRES pivot tolerance from it.
    """
    for candidate, eps in ((wp.float64, 1.0e-16), (wp.float16, 1.0e-4)):
        if dtype == candidate:
            return eps
    return 1.0e-8
721
+
722
+
723
def _get_absolute_tolerance(dtype, tol, atol, lhs_norm):
    """Combine relative and absolute tolerances into one absolute threshold.

    Args:
        dtype: Warp scalar type, used to pick the default and minimum tolerances.
        tol: relative tolerance (ratio of ``lhs_norm``), or ``None``.
        atol: absolute tolerance, or ``None``.
        lhs_norm: norm that ``tol`` is relative to (the solvers pass ``norm(b)``).

    Returns:
        ``max(tol * lhs_norm, atol, eps ** (9/4))``. When only one of ``tol``/``atol``
        is given, that value is used for both. When neither is given, both default
        to ``eps ** (3/4)``. Here ``eps`` is the value returned by ``_get_dtype_epsilon(dtype)``.
    """
    eps = _get_dtype_epsilon(dtype)

    if tol is None:
        tol = atol if atol is not None else eps ** (3 / 4)
    if atol is None:
        atol = tol

    return max(tol * lhs_norm, atol, eps ** (9 / 4))
736
+
737
+
738
def _initialize_residual_and_tolerance(A: LinearOperator, b: wp.array, x: wp.array, tol: float, atol: float):
    """Compute the initial residual ``r = b - A x`` and the absolute tolerance.

    Returns:
        Tuple ``(r, r_norm_sq, atol)``. ``r`` is a new array, and ``r_norm_sq`` is a
        single-element device array holding ``<r, r>``. ``atol`` is the absolute
        tolerance derived from ``tol``/``atol`` and ``norm(b)``.
    """
    dtype = wp.types.type_scalar_type(A.dtype)
    dev = A.device

    # One-element buffer, first used for <b, b> and then for <r, r>.
    # It is pinned on CUDA devices, presumably to speed up the host readbacks done via .numpy().
    norm_sq = wp.empty(n=1, dtype=dtype, device=dev, pinned=dev.is_cuda)

    array_inner(b, b, out=norm_sq)
    b_norm = sqrt(norm_sq.numpy()[0])
    abs_tol = _get_absolute_tolerance(dtype, tol, atol, b_norm)

    # residual = 1 * b - 1 * A x
    residual = wp.empty_like(b)
    A.matvec(x, b, residual, alpha=-1.0, beta=1.0)
    array_inner(residual, residual, out=norm_sq)

    return residual, norm_sq, abs_tol
756
+
757
+
758
def _run_solver_loop(
    do_cycle: Callable[[float], None],
    cycle_size: int,
    r_norm_sq: wp.array,
    maxiter: int,
    atol: float,
    callback: Callable,
    check_every: int,
    use_cuda_graph: bool,
    device,
):
    """Repeatedly run a solver cycle until convergence or the iteration budget is exhausted.

    Args:
        do_cycle: performs ``cycle_size`` solver iterations. It receives the squared absolute tolerance.
        cycle_size: number of iterations performed by one ``do_cycle`` call.
        r_norm_sq: single-element array holding the current squared residual norm, updated by ``do_cycle``.
        maxiter: stop once at least this many iterations have been performed.
        atol: absolute tolerance on the residual norm.
        callback: optional ``callback(iteration, residual_norm, atol)``. It is called before the first cycle,
            at each periodic check that does not terminate, and once at the end.
        check_every: period (in iterations) of host-side residual checks. ``0`` disables periodic checks,
            in which case the loop runs until ``maxiter``.
        use_cuda_graph: on CUDA devices, capture ``do_cycle`` into a CUDA graph once and replay it.
        device: device the solver runs on.

    Returns:
        Tuple ``(iterations performed, final residual norm, atol)``.
    """
    atol_sq = atol * atol

    cur_iter = 0

    err_sq = r_norm_sq.numpy()[0]
    err = sqrt(err_sq)
    if callback is not None:
        callback(cur_iter, err, atol)

    if err_sq <= atol_sq:
        return cur_iter, err, atol

    graph = None

    while True:
        # Do not do graph capture at first iteration -- modules may not be loaded yet
        if device.is_cuda and use_cuda_graph and cur_iter > 0:
            if graph is None:
                wp.capture_begin(device, force_module_load=False)
                try:
                    do_cycle(atol_sq)
                finally:
                    graph = wp.capture_end(device)
            # Capturing does not execute the work, so the graph is launched on every cycle
            wp.capture_launch(graph)
        else:
            do_cycle(atol_sq)

        cur_iter += cycle_size

        if cur_iter >= maxiter:
            break

        # check_every == 0 disables periodic host-side checks; it previously raised ZeroDivisionError here
        if check_every and (cur_iter % check_every) < cycle_size:
            err_sq = r_norm_sq.numpy()[0]

            if err_sq <= atol_sq:
                break

            if callback is not None:
                callback(cur_iter, sqrt(err_sq), atol)

    err_sq = r_norm_sq.numpy()[0]
    err = sqrt(err_sq)
    if callback is not None:
        callback(cur_iter, err, atol)

    return cur_iter, err, atol
816
+
817
+
818
@wp.func
def _calc_mv_product(i: wp.int32, A: wp.array2d(dtype=Any), x: wp.array1d(dtype=Any)):
    # Dot product of row `i` of the 2D array `A` with the vector `x`.
    # The accumulator starts as A.dtype(0), so the result has A's element type.
    acc = A.dtype(0)
    ncols = A.shape[1]
    for col in range(ncols):
        acc += A[i, col] * x[col]
    return acc
824
+
825
+
826
@wp.kernel
def _dense_mv_kernel(
    A: wp.array2d(dtype=Any),
    x: wp.array1d(dtype=Any),
    y: wp.array1d(dtype=Any),
    z: wp.array1d(dtype=Any),
    alpha: Any,
    beta: Any,
):
    # One thread per output row: z = beta * y + alpha * (A @ x),
    # with alpha/beta cast to z's element type.
    row = wp.tid()
    ax_row = _calc_mv_product(row, A, x)
    z[row] = z.dtype(beta) * y[row] + z.dtype(alpha) * ax_row
837
+
838
+
839
@wp.kernel
def _diag_mv_kernel(
    A: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    alpha: Any,
    beta: Any,
):
    # Diagonal matrix-vector product, one thread per entry:
    # z = beta * y + alpha * (A * x), with A holding the diagonal coefficients
    row = wp.tid()
    diag_times_x = A[row] * x[row]
    z[row] = beta * y[row] + alpha * diag_times_x
850
+
851
+
852
@wp.kernel
def _diag_mv_vec_kernel(
    A: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    alpha: Any,
    beta: Any,
):
    # Same as _diag_mv_kernel, but each entry is multiplied component-wise
    # (wp.cw_mul), e.g. for vector-valued diagonal entries
    row = wp.tid()
    diag_times_x = wp.cw_mul(A[row], x[row])
    z[row] = beta * y[row] + alpha * diag_times_x
863
+
864
+
865
@wp.func
def _inverse_diag_coefficient(coeff: Any, use_abs: wp.bool):
    # Returns 1 / coeff (or 1 / |coeff| when use_abs is set).
    # Zero coefficients map to 1 instead of producing an infinity.
    one = type(coeff)(1.0)
    zero = type(coeff)(0.0)
    denom = wp.where(use_abs, wp.abs(coeff), coeff)
    is_zero = coeff == zero
    return wp.where(is_zero, one, one / denom)
870
+
871
+
872
@wp.kernel
def _extract_inverse_diagonal_blocked(
    diag_block: wp.array(dtype=Any),
    inv_diag: wp.array(dtype=Any),
    use_abs: int,
):
    # For each diagonal block, store the element-wise inverse of the block's
    # diagonal (as a vector) into inv_diag
    block = wp.tid()
    take_abs = use_abs != 0

    diag = wp.get_diag(diag_block[block])
    for k in range(diag.length):
        diag[k] = _inverse_diag_coefficient(diag[k], take_abs)

    inv_diag[block] = diag
885
+
886
+
887
@wp.kernel
def _extract_inverse_diagonal_scalar(
    diag_array: wp.array(dtype=Any),
    inv_diag: wp.array(dtype=Any),
    use_abs: int,
):
    # Element-wise inverse of a scalar diagonal (zeros map to one)
    row = wp.tid()
    take_abs = use_abs != 0
    inv_diag[row] = _inverse_diag_coefficient(diag_array[row], take_abs)
895
+
896
+
897
@wp.kernel
def _extract_inverse_diagonal_dense(
    dense_matrix: wp.array2d(dtype=Any),
    inv_diag: wp.array(dtype=Any),
    use_abs: int,
):
    # Inverse of the main diagonal of a dense 2D matrix (zeros map to one)
    row = wp.tid()
    take_abs = use_abs != 0
    inv_diag[row] = _inverse_diag_coefficient(dense_matrix[row, row], take_abs)
905
+
906
+
907
@wp.kernel
def _cg_kernel_1(
    tol: Any,
    resid: wp.array(dtype=Any),
    rz_old: wp.array(dtype=Any),
    p_Ap: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    Ap: wp.array(dtype=Any),
):
    # Step update: alpha = rz_old / <p, Ap>; x += alpha p; r -= alpha Ap.
    # `tol` is compared against resid[0]. The solvers in this file pass the squared
    # absolute tolerance and the squared residual norm, and once resid[0] <= tol
    # the step is zero, which leaves x and r unchanged.
    tid = wp.tid()
    not_converged = resid[0] > tol
    step = wp.where(not_converged, rz_old[0] / p_Ap[0], rz_old.dtype(0.0))

    x[tid] = x[tid] + step * p[tid]
    r[tid] = r[tid] - step * Ap[tid]
924
+
925
+
926
@wp.kernel
def _cg_kernel_2(
    tol: Any,
    resid: wp.array(dtype=Any),
    rz_old: wp.array(dtype=Any),
    rz_new: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
):
    # Search direction update: p = z + (rz_new / rz_old) * p,
    # where z is the (possibly preconditioned) residual.
    # The ratio is replaced by zero once resid[0] <= tol.
    tid = wp.tid()
    not_converged = resid[0] > tol
    ratio = wp.where(not_converged, rz_new[0] / rz_old[0], rz_old.dtype(0.0))

    p[tid] = z[tid] + ratio * p[tid]
941
+
942
+
943
@wp.kernel
def _cr_kernel_1(
    tol: Any,
    resid: wp.array(dtype=Any),
    zAz_old: wp.array(dtype=Any),
    y_Ap: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    Ap: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
):
    # Preconditioned conjugate-residual step:
    # alpha = <z, Az> / <y, Ap>; x += alpha p; r -= alpha Ap; z -= alpha y.
    # The step is zero when converged (resid[0] <= tol) or when <y, Ap> is not positive.
    tid = wp.tid()
    can_step = resid[0] > tol and y_Ap[0] > 0.0
    step = wp.where(can_step, zAz_old[0] / y_Ap[0], zAz_old.dtype(0.0))

    x[tid] = x[tid] + step * p[tid]
    r[tid] = r[tid] - step * Ap[tid]
    z[tid] = z[tid] - step * y[tid]
963
+
964
+
965
@wp.kernel
def _cr_kernel_2(
    tol: Any,
    resid: wp.array(dtype=Any),
    zAz_old: wp.array(dtype=Any),
    zAz_new: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    Az: wp.array(dtype=Any),
    Ap: wp.array(dtype=Any),
):
    # Direction update: with beta = zAz_new / zAz_old,
    # p = z + beta p and Ap = Az + beta Ap (so Ap needs no extra matvec).
    # beta is zero when converged (resid[0] <= tol) or when zAz_old is not positive.
    tid = wp.tid()
    can_update = resid[0] > tol and zAz_old[0] > 0.0
    ratio = wp.where(can_update, zAz_new[0] / zAz_old[0], zAz_old.dtype(0.0))

    p[tid] = z[tid] + ratio * p[tid]
    Ap[tid] = Az[tid] + ratio * Ap[tid]
983
+
984
+
985
@wp.kernel
def _bicgstab_kernel_1(
    tol: Any,
    resid: wp.array(dtype=Any),
    rho_old: wp.array(dtype=Any),
    r0v: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    v: wp.array(dtype=Any),
):
    # First BiCGSTAB half-step: alpha = rho / <r0, v>; x += alpha y; r -= alpha v.
    # alpha is zero once resid[0] <= tol.
    # NOTE: in-place `+=`/`-=` on Warp arrays is kept as in the original; Warp may
    # lower these to atomic operations.
    tid = wp.tid()
    not_converged = resid[0] > tol
    step = wp.where(not_converged, rho_old[0] / r0v[0], rho_old.dtype(0.0))

    x[tid] += step * y[tid]
    r[tid] -= step * v[tid]
1002
+
1003
+
1004
@wp.kernel
def _bicgstab_kernel_2(
    tol: Any,
    resid: wp.array(dtype=Any),
    st: wp.array(dtype=Any),
    tt: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    t: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
):
    # Second BiCGSTAB half-step: omega = st / tt; x += omega z; r -= omega t.
    # omega is zero once resid[0] <= tol.
    tid = wp.tid()
    not_converged = resid[0] > tol
    omega = wp.where(not_converged, st[0] / tt[0], st.dtype(0.0))

    x[tid] += omega * z[tid]
    r[tid] -= omega * t[tid]
1021
+
1022
+
1023
@wp.kernel
def _bicgstab_kernel_3(
    tol: Any,
    resid: wp.array(dtype=Any),
    rho_new: wp.array(dtype=Any),
    r0v: wp.array(dtype=Any),
    st: wp.array(dtype=Any),
    tt: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    v: wp.array(dtype=Any),
):
    # Direction update p = r + beta (p - omega v), expanded as
    # p = r + beta p - (beta omega) v, where
    #   beta = rho_new * tt / (r0v * st)  and  beta * omega = rho_new / r0v.
    # Both coefficients are zero once resid[0] <= tol.
    tid = wp.tid()
    not_converged = resid[0] > tol
    zero = st.dtype(0.0)
    beta = wp.where(not_converged, rho_new[0] * tt[0] / (r0v[0] * st[0]), zero)
    beta_omega = wp.where(not_converged, rho_new[0] / r0v[0], zero)

    p[tid] = r[tid] + beta * p[tid] - beta_omega * v[tid]
1041
+
1042
+
1043
@wp.kernel
def _gmres_normalize_lower_diagonal(H: wp.array2d(dtype=Any)):
    # The Arnoldi step stores squared norms on the sub-diagonal of the
    # Hessenberg matrix; replace each H[i + 1, i] by its square root.
    col = wp.tid()
    below = col + 1
    H[below, col] = wp.sqrt(H[below, col])
1048
+
1049
+
1050
@wp.kernel
def _gmres_solve_least_squares(
    k: int, pivot_tolerance: float, beta_sq: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
):
    # Solve H y = (beta, 0, ..., 0) in the least-squares sense, with beta = sqrt(beta_sq[0])
    # H Hessenberg matrix of shape (k+1, k); it is overwritten in place by the rotations below.
    # The GMRES solver in this file launches this kernel with dim=1, i.e. a single thread does all the work.

    # Keeping H in global mem; warp kernels are launched with fixed block size,
    # so would not fit in registers

    # TODO: switch to native code with thread synchronization

    rhs = wp.sqrt(beta_sq[0])

    # Apply 2x2 rotations to H so as to remove lower diagonal,
    # and apply similar rotations to right-hand-side
    # max_k: number of leading columns actually used (reduced on early termination)
    max_k = int(k)
    for i in range(k):
        Ha = H[i]
        Hb = H[i + 1]

        # Givens rotation [[c s], [-s c]]
        a = Ha[i]
        b = Hb[i]
        abn_sq = a * a + b * b

        if abn_sq < type(abn_sq)(pivot_tolerance):
            # Arnoldi iteration finished early
            max_k = i
            break

        abn = wp.sqrt(abn_sq)
        c = a / abn
        s = b / abn

        # Rotate H
        # (rows i and i+1, columns i..k-1; `a`/`b` are reused as temporaries)
        for j in range(i, k):
            a = Ha[j]
            b = Hb[j]
            Ha[j] = c * a + s * b
            Hb[j] = c * b - s * a

        # Rotate rhs
        # y[i] receives the rotated rhs component; the remainder carries to the next row
        y[i] = c * rhs
        rhs = -s * rhs

    # Coefficients beyond max_k are not used in the solution
    for i in range(max_k, k):
        y[i] = y.dtype(0.0)

    # Triangular back-solve for y
    # (H's leading max_k x max_k block is now upper-triangular)
    for ii in range(max_k, 0, -1):
        i = ii - 1
        Hi = H[i]
        yi = y[i]
        for j in range(ii, max_k):
            yi -= Hi[j] * y[j]
        y[i] = yi / Hi[i]
1107
+
1108
+
1109
@wp.kernel
def _gmres_arnoldi_axpy_kernel(
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    alpha: wp.array(dtype=Any),
):
    # y -= alpha[0] * x  (one Gram-Schmidt projection removal in the Arnoldi step).
    # The in-place `-=` form is kept as in the original.
    idx = wp.tid()
    coeff = alpha[0]
    y[idx] -= x[idx] * coeff
1117
+
1118
+
1119
@wp.kernel
def _gmres_arnoldi_normalize_kernel(
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    alpha: wp.array(dtype=Any),
):
    # y = x / sqrt(alpha[0]), where alpha[0] holds a squared norm.
    # A zero norm copies x unchanged instead of dividing by zero.
    idx = wp.tid()
    norm_sq = alpha[0]
    is_zero = norm_sq == alpha.dtype(0.0)
    y[idx] = wp.where(is_zero, x[idx], x[idx] / wp.sqrt(norm_sq))
1127
+
1128
+
1129
@wp.kernel
def _gmres_update_x_kernel(k: int, beta: Any, y: wp.array(dtype=Any), V: wp.array2d(dtype=Any), x: wp.array(dtype=Any)):
    # x = beta * x + sum_{j < k} y[j] * V[j]; V stores one basis vector per row.
    # beta = 1 accumulates into x, and beta = 0 overwrites it.
    row = wp.tid()

    acc = beta * x[row]
    for j in range(k):
        acc += V[j, row] * y[j]

    x[row] = acc