warp-lang 1.9.1-py3-none-win_amd64.whl → 1.10.0rc2-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/_src/optim/linear.py
@@ -0,0 +1,1606 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ import math
18
+ from typing import Any, Callable, Optional, Tuple, Union
19
+
20
+ import warp as wp
21
+ import warp._src.sparse as sparse
22
+ from warp._src.types import type_length, type_scalar_type
23
+
24
+ __all__ = ["LinearOperator", "aslinearoperator", "bicgstab", "cg", "cr", "gmres", "preconditioner"]
25
+
26
+ # No need to auto-generate adjoint code for linear solvers
27
+ wp.set_module_options({"enable_backward": False})
28
+
29
+
30
+ class LinearOperator:
31
+ """
32
+ Linear operator to be used as left-hand-side of linear iterative solvers.
33
+
34
+ Args:
35
+ shape: Tuple containing the number of rows and columns of the operator
36
+ dtype: Type of the operator elements
37
+ device: Device on which computations involving the operator should be performed
38
+ matvec: Matrix-vector multiplication routine
39
+
40
+ The matrix-vector multiplication routine should have the following signature:
41
+
42
+ .. code-block:: python
43
+
44
+ def matvec(x: wp.array, y: wp.array, z: wp.array, alpha: Scalar, beta: Scalar):
45
+ '''Perform a generalized matrix-vector product.
46
+
47
+ This function computes the operation z = alpha * (A @ x) + beta * y, where 'A'
48
+ is the linear operator represented by this class.
49
+ '''
50
+ ...
51
+
52
+ For performance reasons, by default the iterative linear solvers in this module will try to capture the calls
53
+ for one or more iterations in CUDA graphs. If the `matvec` routine of a custom :class:`LinearOperator`
54
+ cannot be graph-captured, the ``use_cuda_graph=False`` parameter should be passed to the solver function.
55
+
56
+ """
57
+
58
+ def __init__(self, shape: Tuple[int, int], dtype: type, device: wp._src.context.Device, matvec: Callable):
59
+ self._shape = shape
60
+ self._dtype = dtype
61
+ self._device = device
62
+ self._matvec = matvec
63
+
64
+ @property
65
+ def shape(self) -> Tuple[int, int]:
66
+ return self._shape
67
+
68
+ @property
69
+ def dtype(self) -> type:
70
+ return self._dtype
71
+
72
+ @property
73
+ def device(self) -> wp._src.context.Device:
74
+ return self._device
75
+
76
+ @property
77
+ def matvec(self) -> Callable:
78
+ return self._matvec
79
+
80
+ @property
81
+ def scalar_type(self):
82
+ return wp._src.types.type_scalar_type(self.dtype)
83
+
84
+
85
+ _Matrix = Union[wp.array, sparse.BsrMatrix, LinearOperator]
86
+
87
+
88
+ def aslinearoperator(A: _Matrix) -> LinearOperator:
89
+ """
90
+ Casts the dense or sparse matrix `A` as a :class:`LinearOperator`
91
+
92
+ `A` must be of one of the following types:
93
+
94
+ - :class:`warp.sparse.BsrMatrix`
95
+ - two-dimensional `warp.array`; then `A` is assumed to be a dense matrix
96
+ - one-dimensional `warp.array`; then `A` is assumed to be a diagonal matrix
97
+ - :class:`warp.sparse.LinearOperator`; no casting necessary
98
+ """
99
+
100
+ if A is None or isinstance(A, LinearOperator):
101
+ return A
102
+
103
+ def bsr_mv(x, y, z, alpha, beta):
104
+ if z.ptr != y.ptr and beta != 0.0:
105
+ wp.copy(src=y, dest=z)
106
+ sparse.bsr_mv(A, x, z, alpha, beta)
107
+
108
+ def dense_mv(x, y, z, alpha, beta):
109
+ alpha = A.dtype(alpha)
110
+ beta = A.dtype(beta)
111
+ if A.device.is_cuda:
112
+ tile_size = 1 << min(10, max(5, math.ceil(math.log2(A.shape[1]))))
113
+ else:
114
+ tile_size = 1
115
+ wp.launch(
116
+ _dense_mv_kernel,
117
+ dim=(A.shape[0], tile_size),
118
+ block_dim=tile_size,
119
+ device=A.device,
120
+ inputs=[A, x, y, z, alpha, beta],
121
+ )
122
+
123
+ def diag_mv_impl(A, x, y, z, alpha, beta):
124
+ scalar_type = type_scalar_type(A.dtype)
125
+ alpha = scalar_type(alpha)
126
+ beta = scalar_type(beta)
127
+ wp.launch(_diag_mv_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])
128
+
129
+ def diag_mv(x, y, z, alpha, beta):
130
+ return diag_mv_impl(A, x, y, z, alpha, beta)
131
+
132
+ def diag_mv_vec(x, y, z, alpha, beta):
133
+ return diag_mv_impl(
134
+ _as_scalar_array(A), _as_scalar_array(x), _as_scalar_array(y), _as_scalar_array(z), alpha, beta
135
+ )
136
+
137
+ if isinstance(A, wp.array):
138
+ if A.ndim == 2:
139
+ return LinearOperator(A.shape, A.dtype, A.device, matvec=dense_mv)
140
+ if A.ndim == 1:
141
+ if wp._src.types.type_is_vector(A.dtype):
142
+ return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv_vec)
143
+ return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv)
144
+ if isinstance(A, sparse.BsrMatrix):
145
+ return LinearOperator(A.shape, A.dtype, A.device, matvec=bsr_mv)
146
+
147
+ raise ValueError(f"Unable to create LinearOperator from {A}")
148
+
149
+
150
+ def preconditioner(A: _Matrix, ptype: str = "diag") -> LinearOperator:
151
+ """Constructs and returns a preconditioner for an input matrix.
152
+
153
+ Args:
154
+ A: The matrix for which to build the preconditioner
155
+ ptype: The type of preconditioner. Currently the following values are supported:
156
+
157
+ - ``"diag"``: Diagonal (a.k.a. Jacobi) preconditioner
158
+ - ``"diag_abs"``: Similar to Jacobi, but using the absolute value of diagonal coefficients
159
+ - ``"id"``: Identity (null) preconditioner
160
+ """
161
+
162
+ if ptype == "id":
163
+ return None
164
+
165
+ if ptype in ("diag", "diag_abs"):
166
+ use_abs = 1 if ptype == "diag_abs" else 0
167
+ if isinstance(A, sparse.BsrMatrix):
168
+ A_diag = sparse.bsr_get_diag(A)
169
+ if wp._src.types.type_is_matrix(A.dtype):
170
+ inv_diag = wp.empty(
171
+ shape=A.nrow, dtype=wp.vec(length=A.block_shape[0], dtype=A.scalar_type), device=A.device
172
+ )
173
+ wp.launch(
174
+ _extract_inverse_diagonal_blocked,
175
+ dim=inv_diag.shape,
176
+ device=inv_diag.device,
177
+ inputs=[A_diag, inv_diag, use_abs],
178
+ )
179
+ else:
180
+ inv_diag = wp.empty(shape=A.shape[0], dtype=A.scalar_type, device=A.device)
181
+ wp.launch(
182
+ _extract_inverse_diagonal_scalar,
183
+ dim=inv_diag.shape,
184
+ device=inv_diag.device,
185
+ inputs=[A_diag, inv_diag, use_abs],
186
+ )
187
+ elif isinstance(A, wp.array) and A.ndim == 2:
188
+ inv_diag = wp.empty(shape=A.shape[0], dtype=A.dtype, device=A.device)
189
+ wp.launch(
190
+ _extract_inverse_diagonal_dense,
191
+ dim=inv_diag.shape,
192
+ device=inv_diag.device,
193
+ inputs=[A, inv_diag, use_abs],
194
+ )
195
+ else:
196
+ raise ValueError("Unsupported source matrix type for building diagonal preconditioner")
197
+
198
+ return aslinearoperator(inv_diag)
199
+
200
+ raise ValueError(f"Unsupported preconditioner type '{ptype}'")
201
+
202
+
203
+ def _as_scalar_array(x: wp.array):
204
+ scalar_type = type_scalar_type(x.dtype)
205
+ if scalar_type == x.dtype:
206
+ return x
207
+
208
+ dlen = type_length(x.dtype)
209
+ arr = wp.array(
210
+ ptr=x.ptr,
211
+ shape=(*x.shape[:-1], x.shape[-1] * dlen),
212
+ strides=(*x.strides[:-1], x.strides[-1] // dlen),
213
+ dtype=scalar_type,
214
+ device=x.device,
215
+ grad=None if x.grad is None else _as_scalar_array(x.grad),
216
+ )
217
+ arr._ref = x
218
+ return arr
219
+
220
+
221
+ class TiledDot:
222
+ """
223
+ Computes the dot product of two arrays in a way that is compatible with CUDA sub-graphs.
224
+ """
225
+
226
+ def __init__(self, max_length: int, scalar_type: type, tile_size=512, device=None, max_column_count: int = 1):
227
+ self.tile_size = tile_size
228
+ self.device = device
229
+ self.max_column_count = max_column_count
230
+
231
+ num_blocks = (max_length + self.tile_size - 1) // self.tile_size
232
+ scratch = wp.empty(
233
+ shape=(2, max_column_count, num_blocks),
234
+ dtype=scalar_type,
235
+ device=self.device,
236
+ )
237
+ self.partial_sums_a = scratch[0]
238
+ self.partial_sums_b = scratch[1]
239
+
240
+ self.dot_kernel, self.sum_kernel = _create_tiled_dot_kernels(self.tile_size)
241
+
242
+ rounds = 0
243
+ length = num_blocks
244
+ while length > 1:
245
+ length = (length + self.tile_size - 1) // self.tile_size
246
+ rounds += 1
247
+
248
+ self.rounds = rounds
249
+
250
+ self._output = self.partial_sums_a if rounds % 2 == 0 else self.partial_sums_b
251
+
252
+ self.dot_launch: wp.Launch = wp.launch(
253
+ self.dot_kernel,
254
+ dim=(max_column_count, num_blocks, self.tile_size),
255
+ inputs=(self.partial_sums_a, self.partial_sums_b),
256
+ outputs=(self.partial_sums_a,),
257
+ block_dim=self.tile_size,
258
+ record_cmd=True,
259
+ )
260
+ self.sum_launch: wp.Launch = wp.launch(
261
+ self.sum_kernel,
262
+ dim=(max_column_count, num_blocks, self.tile_size),
263
+ inputs=(self.partial_sums_a,),
264
+ outputs=(self.partial_sums_b,),
265
+ block_dim=self.tile_size,
266
+ record_cmd=True,
267
+ )
268
+
269
+ # Result contains a single value, the sum of the array (will get updated by this function)
270
+ def compute(self, a: wp.array, b: wp.array, col_offset: int = 0):
271
+ a = _as_scalar_array(a)
272
+ b = _as_scalar_array(b)
273
+ if a.ndim == 1:
274
+ a = a.reshape((1, -1))
275
+ if b.ndim == 1:
276
+ b = b.reshape((1, -1))
277
+
278
+ column_count = a.shape[0]
279
+ num_blocks = (a.shape[1] + self.tile_size - 1) // self.tile_size
280
+
281
+ data_out = self.partial_sums_a[col_offset : col_offset + column_count]
282
+ data_in = self.partial_sums_b[col_offset : col_offset + column_count]
283
+
284
+ self.dot_launch.set_param_at_index(0, a)
285
+ self.dot_launch.set_param_at_index(1, b)
286
+ self.dot_launch.set_param_at_index(2, data_out)
287
+ self.dot_launch.set_dim((column_count, num_blocks, self.tile_size))
288
+ self.dot_launch.launch()
289
+
290
+ for _r in range(self.rounds):
291
+ array_length = num_blocks
292
+ num_blocks = (array_length + self.tile_size - 1) // self.tile_size
293
+ data_in, data_out = data_out, data_in
294
+
295
+ self.sum_launch.set_param_at_index(0, data_in)
296
+ self.sum_launch.set_param_at_index(1, data_out)
297
+ self.sum_launch.set_dim((column_count, num_blocks, self.tile_size))
298
+ self.sum_launch.launch()
299
+
300
+ return data_out
301
+
302
+ def col(self, col: int = 0):
303
+ return self._output[col][:1]
304
+
305
+ def cols(self, count, start: int = 0):
306
+ return self._output[start : start + count, :1]
307
+
308
+
309
+ @functools.lru_cache(maxsize=None)
310
+ def _create_tiled_dot_kernels(tile_size):
311
+ @wp.kernel
312
+ def block_dot_kernel(
313
+ a: wp.array2d(dtype=Any),
314
+ b: wp.array2d(dtype=Any),
315
+ partial_sums: wp.array2d(dtype=Any),
316
+ ):
317
+ column, block_id, tid_block = wp.tid()
318
+
319
+ start = block_id * tile_size
320
+
321
+ a_block = wp.tile_load(a[column], shape=tile_size, offset=start)
322
+ b_block = wp.tile_load(b[column], shape=tile_size, offset=start)
323
+ t = wp.tile_map(wp.mul, a_block, b_block)
324
+
325
+ tile_sum = wp.tile_sum(t)
326
+ wp.tile_store(partial_sums[column], tile_sum, offset=block_id)
327
+
328
+ @wp.kernel
329
+ def block_sum_kernel(
330
+ data: wp.array2d(dtype=Any),
331
+ partial_sums: wp.array2d(dtype=Any),
332
+ ):
333
+ column, block_id, tid_block = wp.tid()
334
+ start = block_id * tile_size
335
+
336
+ t = wp.tile_load(data[column], shape=tile_size, offset=start)
337
+
338
+ tile_sum = wp.tile_sum(t)
339
+ wp.tile_store(partial_sums[column], tile_sum, offset=block_id)
340
+
341
+ return block_dot_kernel, block_sum_kernel
342
+
343
+
344
+ def cg(
345
+ A: _Matrix,
346
+ b: wp.array,
347
+ x: wp.array,
348
+ tol: Optional[float] = None,
349
+ atol: Optional[float] = None,
350
+ maxiter: Optional[float] = 0,
351
+ M: Optional[_Matrix] = None,
352
+ callback: Optional[Callable] = None,
353
+ check_every=10,
354
+ use_cuda_graph=True,
355
+ ) -> Union[Tuple[int, float, float], Tuple[wp.array, wp.array, wp.array]]:
356
+ """Computes an approximate solution to a symmetric, positive-definite linear system
357
+ using the Conjugate Gradient algorithm.
358
+
359
+ Args:
360
+ A: the linear system's left-hand-side
361
+ b: the linear system's right-hand-side
362
+ x: initial guess and solution vector
363
+ tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
364
+ atol: absolute tolerance for the residual
365
+ maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
366
+ M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
367
+ callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
368
+ If `check_every` is 0, the callback should be a Warp kernel.
369
+ check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
370
+ Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
371
+ If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
372
+ to the maximum number of iterations.
373
+ use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
374
+ The linear operator and preconditioner must only perform graph-friendly operations.
375
+
376
+ Returns:
377
+ If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
378
+ - final_iteration: The number of iterations performed before convergence or reaching maxiter
379
+ - residual_norm: The final residual norm ||b - Ax||
380
+ - absolute_tolerance: The absolute tolerance used for convergence checking
381
+
382
+ If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
383
+ - final_iteration_array: Device array containing the number of iterations performed
384
+ - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
385
+ - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
386
+
387
+ If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
388
+ """
389
+ A = aslinearoperator(A)
390
+ M = aslinearoperator(M)
391
+
392
+ if maxiter == 0:
393
+ maxiter = A.shape[0]
394
+
395
+ device = A.device
396
+ scalar_type = A.scalar_type
397
+
398
+ # Temp storage
399
+ r_and_z = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
400
+ p_and_Ap = wp.empty_like(r_and_z)
401
+ residuals = wp.empty(2, dtype=scalar_type, device=device)
402
+
403
+ tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=2)
404
+
405
+ # named views
406
+
407
+ # (r, r) -- so we can compute r.z and r.r at once
408
+ r_repeated = _repeat_first(r_and_z)
409
+ if M is None:
410
+ # without preconditioner r == z
411
+ r_and_z = r_repeated
412
+ rz_new = tiled_dot.col(0)
413
+ else:
414
+ rz_new = tiled_dot.col(1)
415
+
416
+ r, z = r_and_z[0], r_and_z[1]
417
+ r_norm_sq = tiled_dot.col(0)
418
+
419
+ p, Ap = p_and_Ap[0], p_and_Ap[1]
420
+ rz_old, atol_sq = residuals[0:1], residuals[1:2]
421
+
422
+ # Not strictly necessary, but makes it more robust to user-provided LinearOperators
423
+ Ap.zero_()
424
+ z.zero_()
425
+
426
+ # Initialize tolerance from right-hand-side norm
427
+ _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
428
+ # Initialize residual
429
+ A.matvec(x, b, r, alpha=-1.0, beta=1.0)
430
+
431
+ def update_rr_rz():
432
+ # z = M r
433
+ if M is None:
434
+ tiled_dot.compute(r, r)
435
+ else:
436
+ M.matvec(r, z, z, alpha=1.0, beta=0.0)
437
+ tiled_dot.compute(r_repeated, r_and_z)
438
+
439
+ update_rr_rz()
440
+ p.assign(z)
441
+
442
+ def do_iteration():
443
+ rz_old.assign(rz_new)
444
+
445
+ # Ap = A * p;
446
+ A.matvec(p, Ap, Ap, alpha=1, beta=0)
447
+ tiled_dot.compute(p, Ap, col_offset=1)
448
+ p_Ap = tiled_dot.col(1)
449
+
450
+ wp.launch(
451
+ kernel=_cg_kernel_1,
452
+ dim=x.shape[0],
453
+ device=device,
454
+ inputs=[atol_sq, r_norm_sq, rz_old, p_Ap, x, r, p, Ap],
455
+ )
456
+
457
+ update_rr_rz()
458
+
459
+ wp.launch(
460
+ kernel=_cg_kernel_2,
461
+ dim=z.shape[0],
462
+ device=device,
463
+ inputs=[atol_sq, r_norm_sq, rz_old, rz_new, z, p],
464
+ )
465
+
466
+ return _run_capturable_loop(do_iteration, r_norm_sq, maxiter, atol_sq, callback, check_every, use_cuda_graph)
467
+
468
+
469
+ def cr(
470
+ A: _Matrix,
471
+ b: wp.array,
472
+ x: wp.array,
473
+ tol: Optional[float] = None,
474
+ atol: Optional[float] = None,
475
+ maxiter: Optional[float] = 0,
476
+ M: Optional[_Matrix] = None,
477
+ callback: Optional[Callable] = None,
478
+ check_every=10,
479
+ use_cuda_graph=True,
480
+ ) -> Tuple[int, float, float]:
481
+ """Computes an approximate solution to a symmetric, positive-definite linear system
482
+ using the Conjugate Residual algorithm.
483
+
484
+ Args:
485
+ A: the linear system's left-hand-side
486
+ b: the linear system's right-hand-side
487
+ x: initial guess and solution vector
488
+ tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
489
+ atol: absolute tolerance for the residual
490
+ maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
491
+ Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
492
+ M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
493
+ callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
494
+ If `check_every` is 0, the callback should be a Warp kernel.
495
+ check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
496
+ Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
497
+ If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
498
+ to the maximum number of iterations.
499
+ use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
500
+ The linear operator and preconditioner must only perform graph-friendly operations.
501
+
502
+ Returns:
503
+ If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
504
+ - final_iteration: The number of iterations performed before convergence or reaching maxiter
505
+ - residual_norm: The final residual norm ||b - Ax||
506
+ - absolute_tolerance: The absolute tolerance used for convergence checking
507
+
508
+ If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
509
+ - final_iteration_array: Device array containing the number of iterations performed
510
+ - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
511
+ - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
512
+
513
+ If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
514
+ """
515
+
516
+ A = aslinearoperator(A)
517
+ M = aslinearoperator(M)
518
+
519
+ if maxiter == 0:
520
+ maxiter = A.shape[0]
521
+
522
+ device = A.device
523
+ scalar_type = wp._src.types.type_scalar_type(A.dtype)
524
+
525
+ # Notations below follow roughly pseudo-code from https://en.wikipedia.org/wiki/Conjugate_residual_method
526
+ # with z := M^-1 r and y := M^-1 Ap
527
+
528
+ # Temp storage
529
+ r_and_z = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
530
+ r_and_Az = wp.empty_like(r_and_z)
531
+ y_and_Ap = wp.empty_like(r_and_z)
532
+ p = wp.empty_like(b)
533
+ residuals = wp.empty(2, dtype=scalar_type, device=device)
534
+
535
+ tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=2)
536
+
537
+ if M is None:
538
+ r_and_z = _repeat_first(r_and_z)
539
+ y_and_Ap = _repeat_first(y_and_Ap)
540
+
541
+ # named views
542
+ r, z = r_and_z[0], r_and_z[1]
543
+ r_copy, Az = r_and_Az[0], r_and_Az[1]
544
+
545
+ y, Ap = y_and_Ap[0], y_and_Ap[1]
546
+
547
+ r_norm_sq = tiled_dot.col(0)
548
+ zAz_new = tiled_dot.col(1)
549
+ zAz_old, atol_sq = residuals[0:1], residuals[1:2]
550
+
551
+ # Initialize tolerance from right-hand-side norm
552
+ _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
553
+ # Initialize residual
554
+ A.matvec(x, b, r, alpha=-1.0, beta=1.0)
555
+
556
+ # Not strictly necessary, but makes it more robust to user-provided LinearOperators
557
+ y_and_Ap.zero_()
558
+
559
+ # z = M r
560
+ if M is not None:
561
+ z.zero_()
562
+ M.matvec(r, z, z, alpha=1.0, beta=0.0)
563
+
564
+ def update_rr_zAz():
565
+ A.matvec(z, Az, Az, alpha=1, beta=0)
566
+ r_copy.assign(r)
567
+ tiled_dot.compute(r_and_z, r_and_Az)
568
+
569
+ update_rr_zAz()
570
+
571
+ p.assign(z)
572
+ Ap.assign(Az)
573
+
574
+ def do_iteration():
575
+ zAz_old.assign(zAz_new)
576
+
577
+ if M is not None:
578
+ M.matvec(Ap, y, y, alpha=1.0, beta=0.0)
579
+ tiled_dot.compute(Ap, y, col_offset=1)
580
+ y_Ap = tiled_dot.col(1)
581
+
582
+ if M is None:
583
+ # In non-preconditioned case, first kernel is same as CG
584
+ wp.launch(
585
+ kernel=_cg_kernel_1,
586
+ dim=x.shape[0],
587
+ device=device,
588
+ inputs=[atol_sq, r_norm_sq, zAz_old, y_Ap, x, r, p, Ap],
589
+ )
590
+ else:
591
+ # In preconditioned case, we have one more vector to update
592
+ wp.launch(
593
+ kernel=_cr_kernel_1,
594
+ dim=x.shape[0],
595
+ device=device,
596
+ inputs=[atol_sq, r_norm_sq, zAz_old, y_Ap, x, r, z, p, Ap, y],
597
+ )
598
+
599
+ update_rr_zAz()
600
+ wp.launch(
601
+ kernel=_cr_kernel_2,
602
+ dim=z.shape[0],
603
+ device=device,
604
+ inputs=[atol_sq, r_norm_sq, zAz_old, zAz_new, z, p, Az, Ap],
605
+ )
606
+
607
+ return _run_capturable_loop(
608
+ do_iteration,
609
+ cycle_size=1,
610
+ r_norm_sq=r_norm_sq,
611
+ maxiter=maxiter,
612
+ atol_sq=atol_sq,
613
+ callback=callback,
614
+ check_every=check_every,
615
+ use_cuda_graph=use_cuda_graph,
616
+ )
617
+
618
+
619
+ def bicgstab(
620
+ A: _Matrix,
621
+ b: wp.array,
622
+ x: wp.array,
623
+ tol: Optional[float] = None,
624
+ atol: Optional[float] = None,
625
+ maxiter: Optional[float] = 0,
626
+ M: Optional[_Matrix] = None,
627
+ callback: Optional[Callable] = None,
628
+ check_every=10,
629
+ use_cuda_graph=True,
630
+ is_left_preconditioner=False,
631
+ ):
632
+ """Computes an approximate solution to a linear system using the Biconjugate Gradient Stabilized method (BiCGSTAB).
633
+
634
+ Args:
635
+ A: the linear system's left-hand-side
636
+ b: the linear system's right-hand-side
637
+ x: initial guess and solution vector
638
+ tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
639
+ atol: absolute tolerance for the residual
640
+ maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
641
+ M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
642
+ callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
643
+ If `check_every` is 0, the callback should be a Warp kernel.
644
+ check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
645
+ Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
646
+ If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
647
+ to the maximum number of iterations.
648
+ use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
649
+ The linear operator and preconditioner must only perform graph-friendly operations.
650
+ is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.
651
+
652
+ Returns:
653
+ If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
654
+ - final_iteration: The number of iterations performed before convergence or reaching maxiter
655
+ - residual_norm: The final residual norm ||b - Ax||
656
+ - absolute_tolerance: The absolute tolerance used for convergence checking
657
+
658
+ If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
659
+ - final_iteration_array: Device array containing the number of iterations performed
660
+ - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
661
+ - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
662
+
663
+ If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
664
+ """
665
+ A = aslinearoperator(A)
666
+ M = aslinearoperator(M)
667
+
668
+ if maxiter == 0:
669
+ maxiter = A.shape[0]
670
+
671
+ device = A.device
672
+ scalar_type = wp._src.types.type_scalar_type(A.dtype)
673
+
674
+ # Notations below follow pseudo-code from biconjugate https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method
675
+
676
+ # Temp storage
677
+ r_and_r0 = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
678
+ p = wp.empty_like(b)
679
+ v = wp.empty_like(b)
680
+ t = wp.empty_like(b)
681
+
682
+ r, r0 = r_and_r0[0], r_and_r0[1]
683
+ r_repeated = _repeat_first(r_and_r0)
684
+
685
+ if M is not None:
686
+ y = wp.zeros_like(p)
687
+ z = wp.zeros_like(r)
688
+ if is_left_preconditioner:
689
+ Mt = wp.zeros_like(t)
690
+ else:
691
+ y = p
692
+ z = r
693
+ Mt = t
694
+
695
+ tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=5)
696
+ r_norm_sq = tiled_dot.col(0)
697
+ rho = tiled_dot.col(1)
698
+
699
+ atol_sq = wp.empty(1, dtype=scalar_type, device=device)
700
+
701
+ # Initialize tolerance from right-hand-side norm
702
+ _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
703
+ # Initialize residual
704
+ A.matvec(x, b, r, alpha=-1.0, beta=1.0)
705
+ tiled_dot.compute(r, r, col_offset=0)
706
+
707
+ p.assign(r)
708
+ r0.assign(r)
709
+ rho.assign(r_norm_sq)
710
+
711
+ # Not strictly necessary, but makes it more robust to user-provided LinearOperators
712
+ v.zero_()
713
+ t.zero_()
714
+
715
+ def do_iteration():
716
+ # y = M p
717
+ if M is not None:
718
+ M.matvec(p, y, y, alpha=1.0, beta=0.0)
719
+
720
+ # v = A * y;
721
+ A.matvec(y, v, v, alpha=1, beta=0)
722
+
723
+ # alpha = rho / <r0 . v>
724
+ tiled_dot.compute(r0, v, col_offset=2)
725
+ r0v = tiled_dot.col(2)
726
+
727
+ # x += alpha y
728
+ # r -= alpha v
729
+ wp.launch(
730
+ kernel=_bicgstab_kernel_1,
731
+ dim=x.shape[0],
732
+ device=device,
733
+ inputs=[atol_sq, r_norm_sq, rho, r0v, x, r, y, v],
734
+ )
735
+ tiled_dot.compute(r, r, col_offset=0)
736
+
737
+ # z = M r
738
+ if M is not None:
739
+ M.matvec(r, z, z, alpha=1.0, beta=0.0)
740
+
741
+ # t = A z
742
+ A.matvec(z, t, t, alpha=1, beta=0)
743
+
744
+ if M is not None and is_left_preconditioner:
745
+ # Mt = M t
746
+ M.matvec(t, Mt, Mt, alpha=1.0, beta=0.0)
747
+
748
+ # omega = <Mt, Ms> / <Mt, Mt>
749
+ tiled_dot.compute(z, Mt, col_offset=3)
750
+ tiled_dot.compute(Mt, Mt, col_offset=4)
751
+ else:
752
+ tiled_dot.compute(r, t, col_offset=3)
753
+ tiled_dot.compute(t, t, col_offset=4)
754
+ st = tiled_dot.col(3)
755
+ tt = tiled_dot.col(4)
756
+
757
+ # x += omega z
758
+ # r -= omega t
759
+ wp.launch(
760
+ kernel=_bicgstab_kernel_2,
761
+ dim=z.shape[0],
762
+ device=device,
763
+ inputs=[atol_sq, r_norm_sq, st, tt, z, t, x, r],
764
+ )
765
+
766
+ # r = <r,r>, rho = <r0, r>
767
+ tiled_dot.compute(r_and_r0, r_repeated, col_offset=0)
768
+
769
+ # beta = (rho / rho_old) * alpha / omega = (rho / r0v) / omega
770
+ # p = r + beta (p - omega v)
771
+ wp.launch(
772
+ kernel=_bicgstab_kernel_3,
773
+ dim=z.shape[0],
774
+ device=device,
775
+ inputs=[atol_sq, r_norm_sq, rho, r0v, st, tt, p, r, v],
776
+ )
777
+
778
+ return _run_capturable_loop(
779
+ do_iteration,
780
+ r_norm_sq=r_norm_sq,
781
+ maxiter=maxiter,
782
+ atol_sq=atol_sq,
783
+ callback=callback,
784
+ check_every=check_every,
785
+ use_cuda_graph=use_cuda_graph,
786
+ )
787
+
788
+
789
+ def gmres(
790
+ A: _Matrix,
791
+ b: wp.array,
792
+ x: wp.array,
793
+ tol: Optional[float] = None,
794
+ atol: Optional[float] = None,
795
+ restart=31,
796
+ maxiter: Optional[float] = 0,
797
+ M: Optional[_Matrix] = None,
798
+ callback: Optional[Callable] = None,
799
+ check_every=31,
800
+ use_cuda_graph=True,
801
+ is_left_preconditioner=False,
802
+ ):
803
+ """Computes an approximate solution to a linear system using the restarted Generalized Minimum Residual method (GMRES[k]).
804
+
805
+ Args:
806
+ A: the linear system's left-hand-side
807
+ b: the linear system's right-hand-side
808
+ x: initial guess and solution vector
809
+ tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
810
+ atol: absolute tolerance for the residual
811
+ restart: The restart parameter, i.e, the `k` in `GMRES[k]`. In general, increasing this parameter reduces the number of iterations but increases memory consumption.
812
+ maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
813
+ Note that the current implementation always perform `restart` iterations at a time, and as a result may exceed the specified maximum number of iterations by ``restart-1``.
814
+ M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
815
+ callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
816
+ If `check_every` is 0, the callback should be a Warp kernel.
817
+ check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
818
+ Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
819
+ If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
820
+ to the maximum number of iterations.
821
+ use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
822
+ The linear operator and preconditioner must only perform graph-friendly operations.
823
+ is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.
824
+
825
+ Returns:
826
+ If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
827
+ - final_iteration: The number of iterations performed before convergence or reaching maxiter
828
+ - residual_norm: The final residual norm ||b - Ax||
829
+ - absolute_tolerance: The absolute tolerance used for convergence checking
830
+
831
+ If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
832
+ - final_iteration_array: Device array containing the number of iterations performed
833
+ - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
834
+ - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
835
+
836
+ If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
837
+ """
838
+
839
+ A = aslinearoperator(A)
840
+ M = aslinearoperator(M)
841
+
842
+ if maxiter == 0:
843
+ maxiter = A.shape[0]
844
+
845
+ restart = min(restart, maxiter)
846
+
847
+ if check_every > 0:
848
+ check_every = max(restart, check_every)
849
+
850
+ device = A.device
851
+ scalar_dtype = wp._src.types.type_scalar_type(A.dtype)
852
+
853
+ pivot_tolerance = _get_dtype_epsilon(scalar_dtype) ** 2
854
+
855
+ r = wp.empty_like(b)
856
+ w = wp.empty_like(r)
857
+
858
+ H = wp.empty(shape=(restart + 1, restart), dtype=scalar_dtype, device=device)
859
+ y = wp.empty(shape=restart + 1, dtype=scalar_dtype, device=device)
860
+
861
+ V = wp.zeros(shape=(restart + 1, r.shape[0]), dtype=r.dtype, device=device)
862
+
863
+ residuals = wp.empty(2, dtype=scalar_dtype, device=device)
864
+ beta, atol_sq = residuals[0:1], residuals[1:2]
865
+
866
+ tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_dtype, max_column_count=restart + 1)
867
+ r_norm_sq = tiled_dot.col(0)
868
+
869
+ w_repeated = wp.array(
870
+ ptr=w.ptr, shape=(restart + 1, w.shape[0]), strides=(0, w.strides[0]), dtype=w.dtype, device=w.device
871
+ )
872
+
873
+ # tile size for least square solve
874
+ # (need to fit in a CUDA block, so 1024 max)
875
+ if device.is_cuda and 4 < restart <= 1024:
876
+ tile_size = 1 << math.ceil(math.log2(restart))
877
+ least_squares_kernel = make_gmres_solve_least_squares_kernel_tiled(tile_size)
878
+ else:
879
+ tile_size = 1
880
+ least_squares_kernel = _gmres_solve_least_squares
881
+
882
+ # recorded launches
883
+ least_squares_solve = wp.launch(
884
+ least_squares_kernel,
885
+ dim=(1, tile_size),
886
+ block_dim=tile_size if tile_size > 1 else 256,
887
+ device=device,
888
+ inputs=[restart, pivot_tolerance, beta, H, y],
889
+ record_cmd=True,
890
+ )
891
+
892
+ normalize_anorldi_vec = wp.launch(
893
+ _gmres_arnoldi_normalize_kernel,
894
+ dim=r.shape,
895
+ device=r.device,
896
+ inputs=[r, w, tiled_dot.col(0), beta],
897
+ record_cmd=True,
898
+ )
899
+
900
+ arnoldi_axpy = wp.launch(
901
+ _gmres_arnoldi_axpy_kernel,
902
+ dim=(w.shape[0], tile_size),
903
+ block_dim=tile_size,
904
+ device=w.device,
905
+ inputs=[V, w, H],
906
+ record_cmd=True,
907
+ )
908
+
909
+ # Initialize tolerance from right-hand-side norm
910
+ _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
911
+ # Initialize residual
912
+ A.matvec(x, b, r, alpha=-1.0, beta=1.0)
913
+ tiled_dot.compute(r, r, col_offset=0)
914
+
915
+ # Not strictly necessary, but makes it more robust to user-provided LinearOperators
916
+ w.zero_()
917
+
918
+ def array_coeff(H, i, j):
919
+ return H[i][j : j + 1]
920
+
921
+ def array_col(H, j):
922
+ return H[: j + 1, j : j + 1]
923
+
924
+ def do_arnoldi_iteration(j: int):
925
+ # w = A * v[j];
926
+ if M is not None:
927
+ tmp = V[j + 1]
928
+
929
+ if is_left_preconditioner:
930
+ A.matvec(V[j], tmp, tmp, alpha=1, beta=0)
931
+ M.matvec(tmp, w, w, alpha=1, beta=0)
932
+ else:
933
+ M.matvec(V[j], tmp, tmp, alpha=1, beta=0)
934
+ A.matvec(tmp, w, w, alpha=1, beta=0)
935
+ else:
936
+ A.matvec(V[j], w, w, alpha=1, beta=0)
937
+
938
+ # compute and apply dot products in rappel,
939
+ # since Hj columns are orthogonal
940
+ Hj = array_col(H, j)
941
+ tiled_dot.compute(w_repeated, V[: j + 1])
942
+ wp.copy(src=tiled_dot.cols(j + 1), dest=Hj)
943
+
944
+ # w -= w.vi vi
945
+ arnoldi_axpy.set_params([V[: j + 1], w, Hj])
946
+ arnoldi_axpy.launch()
947
+
948
+ # H[j+1, j] = |w.w|
949
+ tiled_dot.compute(w, w)
950
+ normalize_anorldi_vec.set_params([w, V[j + 1], tiled_dot.col(0), array_coeff(H, j + 1, j)])
951
+
952
+ normalize_anorldi_vec.launch()
953
+
954
+ def do_restart_cycle():
955
+ if M is not None and is_left_preconditioner:
956
+ M.matvec(r, w, w, alpha=1, beta=0)
957
+ rh = w
958
+ else:
959
+ rh = r
960
+
961
+ # beta^2 = rh.rh
962
+ tiled_dot.compute(rh, rh)
963
+
964
+ # v[0] = r / beta
965
+ normalize_anorldi_vec.set_params([rh, V[0], tiled_dot.col(0), beta])
966
+ normalize_anorldi_vec.launch()
967
+
968
+ for j in range(restart):
969
+ do_arnoldi_iteration(j)
970
+
971
+ least_squares_solve.launch()
972
+
973
+ # update x
974
+ if M is None or is_left_preconditioner:
975
+ wp.launch(_gmres_update_x_kernel, dim=x.shape, device=device, inputs=[restart, scalar_dtype(1.0), y, V, x])
976
+ else:
977
+ wp.launch(_gmres_update_x_kernel, dim=x.shape, device=device, inputs=[restart, scalar_dtype(0.0), y, V, w])
978
+ M.matvec(w, x, x, alpha=1, beta=1)
979
+
980
+ # update r and residual
981
+ wp.copy(src=b, dest=r)
982
+ A.matvec(x, b, r, alpha=-1.0, beta=1.0)
983
+ tiled_dot.compute(r, r)
984
+
985
+ return _run_capturable_loop(
986
+ do_restart_cycle,
987
+ cycle_size=restart,
988
+ r_norm_sq=r_norm_sq,
989
+ maxiter=maxiter,
990
+ atol_sq=atol_sq,
991
+ callback=callback,
992
+ check_every=check_every,
993
+ use_cuda_graph=use_cuda_graph,
994
+ )
995
+
996
+
997
+ def _repeat_first(arr: wp.array):
998
+ # returns a view of the first element repeated arr.shape[0] times
999
+ view = wp.array(
1000
+ ptr=arr.ptr,
1001
+ shape=arr.shape,
1002
+ dtype=arr.dtype,
1003
+ strides=(0, *arr.strides[1:]),
1004
+ device=arr.device,
1005
+ )
1006
+ view._ref = arr
1007
+ return view
1008
+
1009
+
1010
+ def _get_dtype_epsilon(dtype):
1011
+ if dtype == wp.float64:
1012
+ return 1.0e-16
1013
+ elif dtype == wp.float16:
1014
+ return 1.0e-4
1015
+
1016
+ return 1.0e-8
1017
+
1018
+
1019
+ def _get_tolerances(dtype, tol, atol):
1020
+ eps_tol = _get_dtype_epsilon(dtype)
1021
+ default_tol = eps_tol ** (3 / 4)
1022
+ min_tol = eps_tol ** (9 / 4)
1023
+
1024
+ if tol is None and atol is None:
1025
+ tol = atol = default_tol
1026
+ elif tol is None:
1027
+ tol = atol
1028
+ elif atol is None:
1029
+ atol = tol
1030
+
1031
+ atol = max(atol, min_tol)
1032
+ return tol, atol
1033
+
1034
+
1035
+ @wp.kernel
1036
+ def _initialize_tolerance(
1037
+ rtol: Any,
1038
+ atol: Any,
1039
+ r_norm_sq: wp.array(dtype=Any),
1040
+ atol_sq: wp.array(dtype=Any),
1041
+ ):
1042
+ atol = wp.max(rtol * wp.sqrt(r_norm_sq[0]), atol)
1043
+ atol_sq[0] = atol * atol
1044
+
1045
+
1046
+ def _initialize_absolute_tolerance(
1047
+ b: wp.array,
1048
+ tol: float,
1049
+ atol: float,
1050
+ tiled_dot: TiledDot,
1051
+ atol_sq: wp.array,
1052
+ ):
1053
+ scalar_type = atol_sq.dtype
1054
+
1055
+ # Compute b norm to define absolute tolerance
1056
+ tiled_dot.compute(b, b)
1057
+ b_norm_sq = tiled_dot.col(0)
1058
+
1059
+ rtol, atol = _get_tolerances(scalar_type, tol, atol)
1060
+ wp.launch(
1061
+ kernel=_initialize_tolerance,
1062
+ dim=1,
1063
+ device=b.device,
1064
+ inputs=[scalar_type(rtol), scalar_type(atol), b_norm_sq, atol_sq],
1065
+ )
1066
+
1067
+
1068
+ @wp.kernel
1069
+ def _update_condition(
1070
+ maxiter: int,
1071
+ cycle_size: int,
1072
+ cur_iter: wp.array(dtype=int),
1073
+ r_norm_sq: wp.array(dtype=Any),
1074
+ atol_sq: wp.array(dtype=Any),
1075
+ condition: wp.array(dtype=int),
1076
+ ):
1077
+ cur_iter[0] += cycle_size
1078
+ condition[0] = wp.where(r_norm_sq[0] <= atol_sq[0] or cur_iter[0] >= maxiter, 0, 1)
1079
+
1080
+
1081
+ def _run_capturable_loop(
+     do_cycle: Callable,
+     r_norm_sq: wp.array,
+     maxiter: int,
+     atol_sq: wp.array,
+     callback: Optional[Callable],
+     check_every: int,
+     use_cuda_graph: bool,
+     cycle_size: int = 1,
+ ):
+     device = atol_sq.device
+
+     if check_every > 0:
+         atol = math.sqrt(atol_sq.numpy()[0])
+         return _run_solver_loop(
+             do_cycle, cycle_size, r_norm_sq, maxiter, atol, callback, check_every, use_cuda_graph, device
+         )
+
+     cur_iter_and_condition = wp.full((2,), value=-1, dtype=int, device=device)
+     cur_iter = cur_iter_and_condition[0:1]
+     condition = cur_iter_and_condition[1:2]
+
+     update_condition_launch = wp.launch(
+         _update_condition,
+         dim=1,
+         device=device,
+         inputs=[int(maxiter), cycle_size, cur_iter, r_norm_sq, atol_sq, condition],
+         record_cmd=True,
+     )
+
+     if isinstance(callback, wp.Kernel):
+         callback_launch = wp.launch(
+             callback, dim=1, device=device, inputs=[cur_iter, r_norm_sq, atol_sq], record_cmd=True
+         )
+     else:
+         callback_launch = None
+
+     update_condition_launch.launch()
+     if callback_launch is not None:
+         callback_launch.launch()
+
+     def do_cycle_with_condition():
+         do_cycle()
+         update_condition_launch.launch()
+         if callback_launch is not None:
+             callback_launch.launch()
+
+     if use_cuda_graph and device.is_cuda:
+         if device.is_capturing:
+             wp.capture_while(condition, do_cycle_with_condition)
+         else:
+             with wp.ScopedCapture() as capture:
+                 wp.capture_while(condition, do_cycle_with_condition)
+             wp.capture_launch(capture.graph)
+     else:
+         for _ in range(0, maxiter, cycle_size):
+             do_cycle_with_condition()
+
+     return cur_iter, r_norm_sq, atol_sq
+
+
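# A plain-Python model of the capturable-loop pattern above (illustrative, no
# Warp types): the stop flag and iteration counter live in device memory and
# are refreshed by _update_condition after every cycle, so no host
# synchronization is needed inside the captured graph.
def capturable_loop_sketch(do_cycle, condition, update_condition):
    # wp.capture_while(condition, body) behaves like this host-side loop,
    # except the loop itself is encoded into the CUDA graph
    update_condition()
    while condition[0] != 0:
        do_cycle()
        update_condition()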
+ def _run_solver_loop(
+     do_cycle: Callable[[float], None],
+     cycle_size: int,
+     r_norm_sq: wp.array,
+     maxiter: int,
+     atol: float,
+     callback: Callable,
+     check_every: int,
+     use_cuda_graph: bool,
+     device,
+ ):
+     atol_sq = atol * atol
+     check_every = max(check_every, cycle_size)
+
+     cur_iter = 0
+
+     err_sq = r_norm_sq.numpy()[0]
+     err = math.sqrt(err_sq)
+     if callback is not None:
+         callback(cur_iter, err, atol)
+
+     if err_sq <= atol_sq:
+         return cur_iter, err, atol
+
+     graph = None
+
+     while True:
+         # Do not capture a graph on the first iteration -- modules may not be loaded yet
+         if device.is_cuda and use_cuda_graph and cur_iter > 0:
+             if graph is None:
+                 with wp.ScopedCapture(force_module_load=False) as capture:
+                     do_cycle()
+                 graph = capture.graph
+             wp.capture_launch(graph)
+         else:
+             do_cycle()
+
+         cur_iter += cycle_size
+
+         if cur_iter >= maxiter:
+             break
+
+         if (cur_iter % check_every) < cycle_size:
+             err_sq = r_norm_sq.numpy()[0]
+
+             if err_sq <= atol_sq:
+                 break
+
+             if callback is not None:
+                 callback(cur_iter, math.sqrt(err_sq), atol)
+
+     err_sq = r_norm_sq.numpy()[0]
+     err = math.sqrt(err_sq)
+     if callback is not None:
+         callback(cur_iter, err, atol)
+
+     return cur_iter, err, atol
+
+
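# As the loop above shows, a host-side callback is invoked as
# callback(cur_iter, err, atol); a minimal progress printer could look like
# this (illustrative):
def progress_callback(cur_iter, err, atol):
    print(f"iteration {cur_iter}: residual norm {err:.3e} (target {atol:.3e})")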
+ @wp.kernel
+ def _dense_mv_kernel(
+     A: wp.array2d(dtype=Any),
+     x: wp.array1d(dtype=Any),
+     y: wp.array1d(dtype=Any),
+     z: wp.array1d(dtype=Any),
+     alpha: Any,
+     beta: Any,
+ ):
+     row, lane = wp.tid()
+
+     zero = type(alpha)(0)
+     s = zero
+     if alpha != zero:
+         for col in range(lane, A.shape[1], wp.block_dim()):
+             s += A[row, col] * x[col]
+
+     row_tile = wp.tile_sum(wp.tile(s * alpha))
+
+     if beta != zero:
+         row_tile += wp.tile_load(y, shape=1, offset=row) * beta
+
+     wp.tile_store(z, row_tile, offset=row)
+
+
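# Numerically, _dense_mv_kernel computes the classic BLAS-style update; a
# NumPy reference with the same semantics (illustrative) is:
import numpy as np

def dense_mv_reference(A, x, y, alpha, beta):
    # z = alpha * A @ x + beta * y, with either term skipped when its scale is 0
    return alpha * (A @ x) + beta * y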
+ @wp.kernel
+ def _diag_mv_kernel(
+     A: wp.array(dtype=Any),
+     x: wp.array(dtype=Any),
+     y: wp.array(dtype=Any),
+     z: wp.array(dtype=Any),
+     alpha: Any,
+     beta: Any,
+ ):
+     i = wp.tid()
+     zero = type(alpha)(0)
+     s = z.dtype(zero)
+     if alpha != zero:
+         s += alpha * (A[i] * x[i])
+     if beta != zero:
+         s += beta * y[i]
+     z[i] = s
+
+
+ @wp.func
+ def _inverse_diag_coefficient(coeff: Any, use_abs: wp.bool):
+     zero = type(coeff)(0.0)
+     one = type(coeff)(1.0)
+     return wp.where(coeff == zero, one, one / wp.where(use_abs, wp.abs(coeff), coeff))
+
+
+ @wp.kernel
+ def _extract_inverse_diagonal_blocked(
+     diag_block: wp.array(dtype=Any),
+     inv_diag: wp.array(dtype=Any),
+     use_abs: int,
+ ):
+     i = wp.tid()
+
+     d = wp.get_diag(diag_block[i])
+     for k in range(d.length):
+         d[k] = _inverse_diag_coefficient(d[k], use_abs != 0)
+
+     inv_diag[i] = d
+
+
+ @wp.kernel
+ def _extract_inverse_diagonal_scalar(
+     diag_array: wp.array(dtype=Any),
+     inv_diag: wp.array(dtype=Any),
+     use_abs: int,
+ ):
+     i = wp.tid()
+     inv_diag[i] = _inverse_diag_coefficient(diag_array[i], use_abs != 0)
+
+
+ @wp.kernel
+ def _extract_inverse_diagonal_dense(
+     dense_matrix: wp.array2d(dtype=Any),
+     inv_diag: wp.array(dtype=Any),
+     use_abs: int,
+ ):
+     i = wp.tid()
+     inv_diag[i] = _inverse_diag_coefficient(dense_matrix[i, i], use_abs != 0)
+
+
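# The kernels above build a Jacobi (inverse-diagonal) preconditioner; a NumPy
# sketch of the shared coefficient rule (illustrative) is:
import numpy as np

def inverse_diagonal_reference(d, use_abs):
    d = np.abs(d) if use_abs else np.asarray(d, dtype=float)
    safe = np.where(d == 0, 1.0, d)        # zero diagonal entries map to 1
    return np.where(d == 0, 1.0, 1.0 / safe)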
+ @wp.kernel
+ def _cg_kernel_1(
+     tol: wp.array(dtype=Any),
+     resid: wp.array(dtype=Any),
+     rz_old: wp.array(dtype=Any),
+     p_Ap: wp.array(dtype=Any),
+     x: wp.array(dtype=Any),
+     r: wp.array(dtype=Any),
+     p: wp.array(dtype=Any),
+     Ap: wp.array(dtype=Any),
+ ):
+     i = wp.tid()
+
+     alpha = wp.where(resid[0] > tol[0], rz_old[0] / p_Ap[0], rz_old.dtype(0.0))
+
+     x[i] = x[i] + alpha * p[i]
+     r[i] = r[i] - alpha * Ap[i]
+
+
+ @wp.kernel
+ def _cg_kernel_2(
+     tol: wp.array(dtype=Any),
+     resid_new: wp.array(dtype=Any),
+     rz_old: wp.array(dtype=Any),
+     rz_new: wp.array(dtype=Any),
+     z: wp.array(dtype=Any),
+     p: wp.array(dtype=Any),
+ ):
+     # p = z + (rz_new / rz_old) * p
+     i = wp.tid()
+
+     cond = resid_new[0] > tol[0]
+     beta = wp.where(cond, rz_new[0] / rz_old[0], rz_old.dtype(0.0))
+
+     p[i] = z[i] + beta * p[i]
+
+
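# Together, the two fused kernels above implement one preconditioned CG step;
# an illustrative NumPy transcription (z is the preconditioned residual; the
# tol guard that zeroes alpha and beta once converged is omitted here):
import numpy as np

def pcg_step_sketch(A, M_inv, x, r, p, rz_old):
    Ap = A @ p
    alpha = rz_old / (p @ Ap)
    x = x + alpha * p              # _cg_kernel_1
    r = r - alpha * Ap
    z = M_inv @ r
    rz_new = r @ z
    p = z + (rz_new / rz_old) * p  # _cg_kernel_2
    return x, r, p, rz_new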
+ @wp.kernel
+ def _cr_kernel_1(
+     tol: wp.array(dtype=Any),
+     resid: wp.array(dtype=Any),
+     zAz_old: wp.array(dtype=Any),
+     y_Ap: wp.array(dtype=Any),
+     x: wp.array(dtype=Any),
+     r: wp.array(dtype=Any),
+     z: wp.array(dtype=Any),
+     p: wp.array(dtype=Any),
+     Ap: wp.array(dtype=Any),
+     y: wp.array(dtype=Any),
+ ):
+     i = wp.tid()
+
+     alpha = wp.where(resid[0] > tol[0] and y_Ap[0] > 0.0, zAz_old[0] / y_Ap[0], zAz_old.dtype(0.0))
+
+     x[i] = x[i] + alpha * p[i]
+     r[i] = r[i] - alpha * Ap[i]
+     z[i] = z[i] - alpha * y[i]
+
+
+ @wp.kernel
+ def _cr_kernel_2(
+     tol: wp.array(dtype=Any),
+     resid: wp.array(dtype=Any),
+     zAz_old: wp.array(dtype=Any),
+     zAz_new: wp.array(dtype=Any),
+     z: wp.array(dtype=Any),
+     p: wp.array(dtype=Any),
+     Az: wp.array(dtype=Any),
+     Ap: wp.array(dtype=Any),
+ ):
+     # p = z + (zAz_new / zAz_old) * p
+     i = wp.tid()
+
+     beta = wp.where(resid[0] > tol[0] and zAz_old[0] > 0.0, zAz_new[0] / zAz_old[0], zAz_old.dtype(0.0))
+
+     p[i] = z[i] + beta * p[i]
+     Ap[i] = Az[i] + beta * Ap[i]
+
+
+ @wp.kernel
+ def _bicgstab_kernel_1(
+     tol: wp.array(dtype=Any),
+     resid: wp.array(dtype=Any),
+     rho_old: wp.array(dtype=Any),
+     r0v: wp.array(dtype=Any),
+     x: wp.array(dtype=Any),
+     r: wp.array(dtype=Any),
+     y: wp.array(dtype=Any),
+     v: wp.array(dtype=Any),
+ ):
+     i = wp.tid()
+
+     alpha = wp.where(resid[0] > tol[0], rho_old[0] / r0v[0], rho_old.dtype(0.0))
+
+     x[i] += alpha * y[i]
+     r[i] -= alpha * v[i]
+
+
+ @wp.kernel
+ def _bicgstab_kernel_2(
+     tol: wp.array(dtype=Any),
+     resid: wp.array(dtype=Any),
+     st: wp.array(dtype=Any),
+     tt: wp.array(dtype=Any),
+     z: wp.array(dtype=Any),
+     t: wp.array(dtype=Any),
+     x: wp.array(dtype=Any),
+     r: wp.array(dtype=Any),
+ ):
+     i = wp.tid()
+
+     omega = wp.where(resid[0] > tol[0], st[0] / tt[0], st.dtype(0.0))
+
+     x[i] += omega * z[i]
+     r[i] -= omega * t[i]
+
+
+ @wp.kernel
+ def _bicgstab_kernel_3(
+     tol: wp.array(dtype=Any),
+     resid: wp.array(dtype=Any),
+     rho_new: wp.array(dtype=Any),
+     r0v: wp.array(dtype=Any),
+     st: wp.array(dtype=Any),
+     tt: wp.array(dtype=Any),
+     p: wp.array(dtype=Any),
+     r: wp.array(dtype=Any),
+     v: wp.array(dtype=Any),
+ ):
+     i = wp.tid()
+
+     beta = wp.where(resid[0] > tol[0], rho_new[0] * tt[0] / (r0v[0] * st[0]), st.dtype(0.0))
+     beta_omega = wp.where(resid[0] > tol[0], rho_new[0] / r0v[0], st.dtype(0.0))
+
+     p[i] = r[i] + beta * p[i] - beta_omega * v[i]
+
+
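# The three kernels above are the fused vector updates of a BiCGStab
# iteration; an illustrative unpreconditioned NumPy transcription (in the
# kernels, y and z are the preconditioned search and shadow vectors, and the
# tol guards that zero the scalars after convergence are omitted here):
import numpy as np

def bicgstab_step_sketch(A, x, r, r0, p, rho_old):
    v = A @ p
    alpha = rho_old / (r0 @ v)
    x = x + alpha * p               # _bicgstab_kernel_1 (y = p here)
    r = r - alpha * v
    t = A @ r
    omega = (r @ t) / (t @ t)
    x = x + omega * r               # _bicgstab_kernel_2 (z = r here)
    r = r - omega * t
    rho_new = r0 @ r
    beta = (rho_new / rho_old) * (alpha / omega)
    p = r + beta * (p - omega * v)  # _bicgstab_kernel_3
    return x, r, p, rho_new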
+ @wp.kernel
+ def _gmres_solve_least_squares(
+     k: int, pivot_tolerance: float, beta: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
+ ):
+     # Solve H y = (beta, 0, ..., 0)
+     # H is a Hessenberg matrix of shape (k+1, k),
+     # so it would not fit in registers
+
+     rhs = beta[0]
+
+     # Apply 2x2 rotations to H so as to remove the lower diagonal,
+     # and apply the same rotations to the right-hand side
+     max_k = int(k)
+     for i in range(k):
+         Ha = H[i]
+         Hb = H[i + 1]
+
+         # Givens rotation [[c s], [-s c]]
+         a = Ha[i]
+         b = Hb[i]
+         abn_sq = a * a + b * b
+
+         if abn_sq < type(abn_sq)(pivot_tolerance):
+             # Arnoldi iteration finished early
+             max_k = i
+             break
+
+         abn = wp.sqrt(abn_sq)
+         c = a / abn
+         s = b / abn
+
+         # Rotate H
+         for j in range(i, k):
+             a = Ha[j]
+             b = Hb[j]
+             Ha[j] = c * a + s * b
+             Hb[j] = c * b - s * a
+
+         # Rotate rhs
+         y[i] = c * rhs
+         rhs = -s * rhs
+
+     for i in range(max_k, k):
+         y[i] = y.dtype(0.0)
+
+     # Triangular back-solve for y
+     for ii in range(max_k, 0, -1):
+         i = ii - 1
+         Hi = H[i]
+         yi = y[i]
+         for j in range(ii, max_k):
+             yi -= Hi[j] * y[j]
+         y[i] = yi / Hi[i]
+
+
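# Reference check for the Givens-rotation solve above (illustrative): its
# result should match a dense least-squares solve of the (k+1, k) Hessenberg
# system against beta * e1.
import numpy as np

def gmres_lstsq_reference(H, beta, k):
    rhs = np.zeros(k + 1)
    rhs[0] = beta
    y, *_ = np.linalg.lstsq(H[: k + 1, :k], rhs, rcond=None)
    return y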
+ @functools.lru_cache(maxsize=None)
+ def make_gmres_solve_least_squares_kernel_tiled(K: int):
+     @wp.kernel(module="unique")
+     def gmres_solve_least_squares_tiled(
+         k: int, pivot_tolerance: float, beta: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
+     ):
+         # Assumes tiles of size K, with K at least as large as the highest number of columns.
+         # This limits the max restart cycle length to the max block size of 1024, but using
+         # larger restarts would be very inefficient anyway (the default is ~30)
+
+         # Solve H y = (beta, 0, ..., 0)
+         # H is a Hessenberg matrix of shape (k+1, k)
+
+         i, lane = wp.tid()
+
+         rhs = beta[0]
+
+         zero = H.dtype(0.0)
+         one = H.dtype(1.0)
+         yi = zero
+
+         Ha = wp.tile_load(H[0], shape=(K))
+
+         # Apply 2x2 rotations to H so as to remove the lower diagonal,
+         # and apply the same rotations to the right-hand side
+         max_k = int(k)
+         for i in range(k):
+             # Ha = H[i]
+             # Hb = H[i + 1]
+             Hb = wp.tile_load(H[i + 1], shape=(K))
+
+             # Givens rotation [[c s], [-s c]]
+             a = Ha[i]
+             b = Hb[i]
+             abn_sq = a * a + b * b
+
+             if abn_sq < type(abn_sq)(pivot_tolerance):
+                 # Arnoldi iteration finished early
+                 max_k = i
+                 break
+
+             abn = wp.sqrt(abn_sq)
+             c = a / abn
+             s = b / abn
+
+             # Rotate H
+             a = wp.untile(Ha)
+             b = wp.untile(Hb)
+             a_rot = c * a + s * b
+             b_rot = c * b - s * a
+
+             # Rotate rhs
+             if lane == i:
+                 yi = c * rhs
+             rhs = -s * rhs
+
+             wp.tile_store(H[i], wp.tile(a_rot))
+             Ha[lane] = b_rot
+
+         y_tile = wp.tile(yi)
+
+         # Triangular back-solve for y
+         for ii in range(max_k, 0, -1):
+             i = ii - 1
+
+             Hi = wp.tile_load(H[i], shape=(K))
+
+             il = lane + i
+             if lane == 0:
+                 yl = y_tile[i]
+             elif il < max_k:
+                 yl = -y_tile[il] * Hi[il]
+             else:
+                 yl = zero
+
+             yit = wp.tile_sum(wp.tile(yl)) * (one / Hi[i])
+             yit[0]  # no-op; moves yit to shared memory
+             wp.tile_assign(y_tile, yit, offset=(i,))
+
+         wp.tile_store(y, y_tile)
+
+     return gmres_solve_least_squares_tiled
+
+
+ @wp.kernel
+ def _gmres_arnoldi_axpy_kernel(
+     V: wp.array2d(dtype=Any),
+     w: wp.array(dtype=Any),
+     Vw: wp.array2d(dtype=Any),
+ ):
+     tid, lane = wp.tid()
+
+     s = w.dtype(Vw.dtype(0))
+
+     tile_size = wp.block_dim()
+     for k in range(lane, Vw.shape[0], tile_size):
+         s += Vw[k, 0] * V[k, tid]
+
+     wi = wp.tile_load(w, shape=1, offset=tid)
+     wi -= wp.tile_sum(wp.tile(s, preserve_type=True))
+
+     wp.tile_store(w, wi, offset=tid)
+
+
+ @wp.kernel
+ def _gmres_arnoldi_normalize_kernel(
+     x: wp.array(dtype=Any),
+     y: wp.array(dtype=Any),
+     alpha: wp.array(dtype=Any),
+     alpha_copy: wp.array(dtype=Any),
+ ):
+     tid = wp.tid()
+     norm = wp.sqrt(alpha[0])
+     y[tid] = wp.where(alpha[0] == alpha.dtype(0.0), x[tid], x[tid] / norm)
+
+     if tid == 0:
+         alpha_copy[0] = norm
+
+
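# The two kernels above perform the orthogonalize-and-normalize step of an
# Arnoldi iteration; an illustrative NumPy version of the same step, assuming
# the projection coefficients h = V @ w come from a separate dot product (Vw
# in the kernel):
import numpy as np

def arnoldi_orthogonalize_sketch(V, w):
    h = V @ w                  # projections onto the existing basis
    w = w - V.T @ h            # _gmres_arnoldi_axpy_kernel
    h_next = np.linalg.norm(w)
    v_next = w if h_next == 0.0 else w / h_next  # _gmres_arnoldi_normalize_kernel
    return v_next, h, h_next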
+ @wp.kernel
+ def _gmres_update_x_kernel(k: int, beta: Any, y: wp.array(dtype=Any), V: wp.array2d(dtype=Any), x: wp.array(dtype=Any)):
+     tid = wp.tid()
+
+     xi = beta * x[tid]
+     for j in range(k):
+         xi += V[j, tid] * y[j]
+
+     x[tid] = xi