warp-lang 1.9.0-py3-none-win_amd64.whl → 1.10.0rc2-py3-none-win_amd64.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +301 -287
- warp/__init__.pyi +2220 -313
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1497 -226
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -471
- warp/codegen.py +6 -4246
- warp/constants.py +6 -39
- warp/context.py +12 -7851
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +3 -2
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -342
- warp/jax_experimental/ffi.py +17 -853
- warp/jax_experimental/xla_ffi.py +5 -596
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +316 -39
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sort.cu +22 -13
- warp/native/sort.h +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +837 -70
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -53
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +60 -32
- warp/native/warp.cu +313 -201
- warp/native/warp.h +14 -11
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3616
- warp/render/render_usd.py +6 -918
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +1382 -79
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +529 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +34 -15
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +60 -14
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +49 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +82 -9
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +239 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5750
- warp/utils.py +10 -1659
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.0.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/optim/linear.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c)
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,1594 +13,23 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
|
|
17
|
-
import math
|
|
18
|
-
from typing import Any, Callable, Optional, Tuple, Union
|
|
16
|
+
# isort: skip_file
|
|
19
17
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
from warp.
|
|
18
|
+
from warp._src.optim.linear import LinearOperator as LinearOperator
|
|
19
|
+
from warp._src.optim.linear import aslinearoperator as aslinearoperator
|
|
20
|
+
from warp._src.optim.linear import bicgstab as bicgstab
|
|
21
|
+
from warp._src.optim.linear import cg as cg
|
|
22
|
+
from warp._src.optim.linear import cr as cr
|
|
23
|
+
from warp._src.optim.linear import gmres as gmres
|
|
24
|
+
from warp._src.optim.linear import preconditioner as preconditioner
|
|
23
25
|
|
|
24
|
-
__all__ = ["LinearOperator", "aslinearoperator", "bicgstab", "cg", "cr", "gmres", "preconditioner"]
|
|
25
26
|
|
|
26
|
-
#
|
|
27
|
-
wp.set_module_options({"enable_backward": False})
|
|
27
|
+
# TODO: Remove after cleaning up the public API.
|
|
28
28
|
|
|
29
|
+
from warp._src.optim import linear as _linear
|
|
29
30
|
|
|
30
|
-
class LinearOperator:
|
|
31
|
-
"""
|
|
32
|
-
Linear operator to be used as left-hand-side of linear iterative solvers.
|
|
33
31
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
dtype: Type of the operator elements
|
|
37
|
-
device: Device on which computations involving the operator should be performed
|
|
38
|
-
matvec: Matrix-vector multiplication routine
|
|
32
|
+
def __getattr__(name):
|
|
33
|
+
from warp._src.utils import get_deprecated_api
|
|
39
34
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
.. code-block:: python
|
|
43
|
-
|
|
44
|
-
def matvec(x: wp.array, y: wp.array, z: wp.array, alpha: Scalar, beta: Scalar):
|
|
45
|
-
'''Perform a generalized matrix-vector product.
|
|
46
|
-
|
|
47
|
-
This function computes the operation z = alpha * (A @ x) + beta * y, where 'A'
|
|
48
|
-
is the linear operator represented by this class.
|
|
49
|
-
'''
|
|
50
|
-
...
|
|
51
|
-
|
|
52
|
-
For performance reasons, by default the iterative linear solvers in this module will try to capture the calls
|
|
53
|
-
for one or more iterations in CUDA graphs. If the `matvec` routine of a custom :class:`LinearOperator`
|
|
54
|
-
cannot be graph-captured, the ``use_cuda_graph=False`` parameter should be passed to the solver function.
|
|
55
|
-
|
|
56
|
-
"""
|
|
57
|
-
|
|
58
|
-
def __init__(self, shape: Tuple[int, int], dtype: type, device: wp.context.Device, matvec: Callable):
|
|
59
|
-
self._shape = shape
|
|
60
|
-
self._dtype = dtype
|
|
61
|
-
self._device = device
|
|
62
|
-
self._matvec = matvec
|
|
63
|
-
|
|
64
|
-
@property
|
|
65
|
-
def shape(self) -> Tuple[int, int]:
|
|
66
|
-
return self._shape
|
|
67
|
-
|
|
68
|
-
@property
|
|
69
|
-
def dtype(self) -> type:
|
|
70
|
-
return self._dtype
|
|
71
|
-
|
|
72
|
-
@property
|
|
73
|
-
def device(self) -> wp.context.Device:
|
|
74
|
-
return self._device
|
|
75
|
-
|
|
76
|
-
@property
|
|
77
|
-
def matvec(self) -> Callable:
|
|
78
|
-
return self._matvec
|
|
79
|
-
|
|
80
|
-
@property
|
|
81
|
-
def scalar_type(self):
|
|
82
|
-
return wp.types.type_scalar_type(self.dtype)
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
_Matrix = Union[wp.array, sparse.BsrMatrix, LinearOperator]
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
def aslinearoperator(A: _Matrix) -> LinearOperator:
|
|
89
|
-
"""
|
|
90
|
-
Casts the dense or sparse matrix `A` as a :class:`LinearOperator`
|
|
91
|
-
|
|
92
|
-
`A` must be of one of the following types:
|
|
93
|
-
|
|
94
|
-
- :class:`warp.sparse.BsrMatrix`
|
|
95
|
-
- two-dimensional `warp.array`; then `A` is assumed to be a dense matrix
|
|
96
|
-
- one-dimensional `warp.array`; then `A` is assumed to be a diagonal matrix
|
|
97
|
-
- :class:`warp.sparse.LinearOperator`; no casting necessary
|
|
98
|
-
"""
|
|
99
|
-
|
|
100
|
-
if A is None or isinstance(A, LinearOperator):
|
|
101
|
-
return A
|
|
102
|
-
|
|
103
|
-
def bsr_mv(x, y, z, alpha, beta):
|
|
104
|
-
if z.ptr != y.ptr and beta != 0.0:
|
|
105
|
-
wp.copy(src=y, dest=z)
|
|
106
|
-
sparse.bsr_mv(A, x, z, alpha, beta)
|
|
107
|
-
|
|
108
|
-
def dense_mv(x, y, z, alpha, beta):
|
|
109
|
-
alpha = A.dtype(alpha)
|
|
110
|
-
beta = A.dtype(beta)
|
|
111
|
-
if A.device.is_cuda:
|
|
112
|
-
tile_size = 1 << min(10, max(5, math.ceil(math.log2(A.shape[1]))))
|
|
113
|
-
else:
|
|
114
|
-
tile_size = 1
|
|
115
|
-
wp.launch(
|
|
116
|
-
_dense_mv_kernel,
|
|
117
|
-
dim=(A.shape[0], tile_size),
|
|
118
|
-
block_dim=tile_size,
|
|
119
|
-
device=A.device,
|
|
120
|
-
inputs=[A, x, y, z, alpha, beta],
|
|
121
|
-
)
|
|
122
|
-
|
|
123
|
-
def diag_mv_impl(A, x, y, z, alpha, beta):
|
|
124
|
-
scalar_type = type_scalar_type(A.dtype)
|
|
125
|
-
alpha = scalar_type(alpha)
|
|
126
|
-
beta = scalar_type(beta)
|
|
127
|
-
wp.launch(_diag_mv_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])
|
|
128
|
-
|
|
129
|
-
def diag_mv(x, y, z, alpha, beta):
|
|
130
|
-
return diag_mv_impl(A, x, y, z, alpha, beta)
|
|
131
|
-
|
|
132
|
-
def diag_mv_vec(x, y, z, alpha, beta):
|
|
133
|
-
return diag_mv_impl(
|
|
134
|
-
_as_scalar_array(A), _as_scalar_array(x), _as_scalar_array(y), _as_scalar_array(z), alpha, beta
|
|
135
|
-
)
|
|
136
|
-
|
|
137
|
-
if isinstance(A, wp.array):
|
|
138
|
-
if A.ndim == 2:
|
|
139
|
-
return LinearOperator(A.shape, A.dtype, A.device, matvec=dense_mv)
|
|
140
|
-
if A.ndim == 1:
|
|
141
|
-
if wp.types.type_is_vector(A.dtype):
|
|
142
|
-
return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv_vec)
|
|
143
|
-
return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv)
|
|
144
|
-
if isinstance(A, sparse.BsrMatrix):
|
|
145
|
-
return LinearOperator(A.shape, A.dtype, A.device, matvec=bsr_mv)
|
|
146
|
-
|
|
147
|
-
raise ValueError(f"Unable to create LinearOperator from {A}")
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
def preconditioner(A: _Matrix, ptype: str = "diag") -> LinearOperator:
|
|
151
|
-
"""Constructs and returns a preconditioner for an input matrix.
|
|
152
|
-
|
|
153
|
-
Args:
|
|
154
|
-
A: The matrix for which to build the preconditioner
|
|
155
|
-
ptype: The type of preconditioner. Currently the following values are supported:
|
|
156
|
-
|
|
157
|
-
- ``"diag"``: Diagonal (a.k.a. Jacobi) preconditioner
|
|
158
|
-
- ``"diag_abs"``: Similar to Jacobi, but using the absolute value of diagonal coefficients
|
|
159
|
-
- ``"id"``: Identity (null) preconditioner
|
|
160
|
-
"""
|
|
161
|
-
|
|
162
|
-
if ptype == "id":
|
|
163
|
-
return None
|
|
164
|
-
|
|
165
|
-
if ptype in ("diag", "diag_abs"):
|
|
166
|
-
use_abs = 1 if ptype == "diag_abs" else 0
|
|
167
|
-
if isinstance(A, sparse.BsrMatrix):
|
|
168
|
-
A_diag = sparse.bsr_get_diag(A)
|
|
169
|
-
if wp.types.type_is_matrix(A.dtype):
|
|
170
|
-
inv_diag = wp.empty(
|
|
171
|
-
shape=A.nrow, dtype=wp.vec(length=A.block_shape[0], dtype=A.scalar_type), device=A.device
|
|
172
|
-
)
|
|
173
|
-
wp.launch(
|
|
174
|
-
_extract_inverse_diagonal_blocked,
|
|
175
|
-
dim=inv_diag.shape,
|
|
176
|
-
device=inv_diag.device,
|
|
177
|
-
inputs=[A_diag, inv_diag, use_abs],
|
|
178
|
-
)
|
|
179
|
-
else:
|
|
180
|
-
inv_diag = wp.empty(shape=A.shape[0], dtype=A.scalar_type, device=A.device)
|
|
181
|
-
wp.launch(
|
|
182
|
-
_extract_inverse_diagonal_scalar,
|
|
183
|
-
dim=inv_diag.shape,
|
|
184
|
-
device=inv_diag.device,
|
|
185
|
-
inputs=[A_diag, inv_diag, use_abs],
|
|
186
|
-
)
|
|
187
|
-
elif isinstance(A, wp.array) and A.ndim == 2:
|
|
188
|
-
inv_diag = wp.empty(shape=A.shape[0], dtype=A.dtype, device=A.device)
|
|
189
|
-
wp.launch(
|
|
190
|
-
_extract_inverse_diagonal_dense,
|
|
191
|
-
dim=inv_diag.shape,
|
|
192
|
-
device=inv_diag.device,
|
|
193
|
-
inputs=[A, inv_diag, use_abs],
|
|
194
|
-
)
|
|
195
|
-
else:
|
|
196
|
-
raise ValueError("Unsupported source matrix type for building diagonal preconditioner")
|
|
197
|
-
|
|
198
|
-
return aslinearoperator(inv_diag)
|
|
199
|
-
|
|
200
|
-
raise ValueError(f"Unsupported preconditioner type '{ptype}'")
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
def _as_scalar_array(x: wp.array):
|
|
204
|
-
scalar_type = type_scalar_type(x.dtype)
|
|
205
|
-
if scalar_type == x.dtype:
|
|
206
|
-
return x
|
|
207
|
-
|
|
208
|
-
dlen = type_length(x.dtype)
|
|
209
|
-
arr = wp.array(
|
|
210
|
-
ptr=x.ptr,
|
|
211
|
-
shape=(*x.shape[:-1], x.shape[-1] * dlen),
|
|
212
|
-
strides=(*x.strides[:-1], x.strides[-1] // dlen),
|
|
213
|
-
dtype=scalar_type,
|
|
214
|
-
device=x.device,
|
|
215
|
-
grad=None if x.grad is None else _as_scalar_array(x.grad),
|
|
216
|
-
)
|
|
217
|
-
arr._ref = x
|
|
218
|
-
return arr
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
class TiledDot:
|
|
222
|
-
"""
|
|
223
|
-
Computes the dot product of two arrays in a way that is compatible with CUDA sub-graphs.
|
|
224
|
-
"""
|
|
225
|
-
|
|
226
|
-
def __init__(self, max_length: int, scalar_type: type, tile_size=512, device=None, max_column_count: int = 1):
|
|
227
|
-
self.tile_size = tile_size
|
|
228
|
-
self.device = device
|
|
229
|
-
self.max_column_count = max_column_count
|
|
230
|
-
|
|
231
|
-
num_blocks = (max_length + self.tile_size - 1) // self.tile_size
|
|
232
|
-
scratch = wp.empty(
|
|
233
|
-
shape=(2, max_column_count, num_blocks),
|
|
234
|
-
dtype=scalar_type,
|
|
235
|
-
device=self.device,
|
|
236
|
-
)
|
|
237
|
-
self.partial_sums_a = scratch[0]
|
|
238
|
-
self.partial_sums_b = scratch[1]
|
|
239
|
-
|
|
240
|
-
self.dot_kernel, self.sum_kernel = _create_tiled_dot_kernels(self.tile_size)
|
|
241
|
-
|
|
242
|
-
rounds = 0
|
|
243
|
-
length = num_blocks
|
|
244
|
-
while length > 1:
|
|
245
|
-
length = (length + self.tile_size - 1) // self.tile_size
|
|
246
|
-
rounds += 1
|
|
247
|
-
|
|
248
|
-
self.rounds = rounds
|
|
249
|
-
|
|
250
|
-
self._output = self.partial_sums_a if rounds % 2 == 0 else self.partial_sums_b
|
|
251
|
-
|
|
252
|
-
self.dot_launch: wp.Launch = wp.launch(
|
|
253
|
-
self.dot_kernel,
|
|
254
|
-
dim=(max_column_count, num_blocks, self.tile_size),
|
|
255
|
-
inputs=(self.partial_sums_a, self.partial_sums_b),
|
|
256
|
-
outputs=(self.partial_sums_a,),
|
|
257
|
-
block_dim=self.tile_size,
|
|
258
|
-
record_cmd=True,
|
|
259
|
-
)
|
|
260
|
-
self.sum_launch: wp.Launch = wp.launch(
|
|
261
|
-
self.sum_kernel,
|
|
262
|
-
dim=(max_column_count, num_blocks, self.tile_size),
|
|
263
|
-
inputs=(self.partial_sums_a,),
|
|
264
|
-
outputs=(self.partial_sums_b,),
|
|
265
|
-
block_dim=self.tile_size,
|
|
266
|
-
record_cmd=True,
|
|
267
|
-
)
|
|
268
|
-
|
|
269
|
-
# Result contains a single value, the sum of the array (will get updated by this function)
|
|
270
|
-
def compute(self, a: wp.array, b: wp.array, col_offset: int = 0):
|
|
271
|
-
a = _as_scalar_array(a)
|
|
272
|
-
b = _as_scalar_array(b)
|
|
273
|
-
if a.ndim == 1:
|
|
274
|
-
a = a.reshape((1, -1))
|
|
275
|
-
if b.ndim == 1:
|
|
276
|
-
b = b.reshape((1, -1))
|
|
277
|
-
|
|
278
|
-
column_count = a.shape[0]
|
|
279
|
-
num_blocks = (a.shape[1] + self.tile_size - 1) // self.tile_size
|
|
280
|
-
|
|
281
|
-
data_out = self.partial_sums_a[col_offset : col_offset + column_count]
|
|
282
|
-
data_in = self.partial_sums_b[col_offset : col_offset + column_count]
|
|
283
|
-
|
|
284
|
-
self.dot_launch.set_param_at_index(0, a)
|
|
285
|
-
self.dot_launch.set_param_at_index(1, b)
|
|
286
|
-
self.dot_launch.set_param_at_index(2, data_out)
|
|
287
|
-
self.dot_launch.set_dim((column_count, num_blocks, self.tile_size))
|
|
288
|
-
self.dot_launch.launch()
|
|
289
|
-
|
|
290
|
-
for _r in range(self.rounds):
|
|
291
|
-
array_length = num_blocks
|
|
292
|
-
num_blocks = (array_length + self.tile_size - 1) // self.tile_size
|
|
293
|
-
data_in, data_out = data_out, data_in
|
|
294
|
-
|
|
295
|
-
self.sum_launch.set_param_at_index(0, data_in)
|
|
296
|
-
self.sum_launch.set_param_at_index(1, data_out)
|
|
297
|
-
self.sum_launch.set_dim((column_count, num_blocks, self.tile_size))
|
|
298
|
-
self.sum_launch.launch()
|
|
299
|
-
|
|
300
|
-
return data_out
|
|
301
|
-
|
|
302
|
-
def col(self, col: int = 0):
|
|
303
|
-
return self._output[col][:1]
|
|
304
|
-
|
|
305
|
-
def cols(self, count, start: int = 0):
|
|
306
|
-
return self._output[start : start + count, :1]
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
@functools.lru_cache(maxsize=None)
|
|
310
|
-
def _create_tiled_dot_kernels(tile_size):
|
|
311
|
-
@wp.kernel
|
|
312
|
-
def block_dot_kernel(
|
|
313
|
-
a: wp.array2d(dtype=Any),
|
|
314
|
-
b: wp.array2d(dtype=Any),
|
|
315
|
-
partial_sums: wp.array2d(dtype=Any),
|
|
316
|
-
):
|
|
317
|
-
column, block_id, tid_block = wp.tid()
|
|
318
|
-
|
|
319
|
-
start = block_id * tile_size
|
|
320
|
-
|
|
321
|
-
a_block = wp.tile_load(a[column], shape=tile_size, offset=start)
|
|
322
|
-
b_block = wp.tile_load(b[column], shape=tile_size, offset=start)
|
|
323
|
-
t = wp.tile_map(wp.mul, a_block, b_block)
|
|
324
|
-
|
|
325
|
-
tile_sum = wp.tile_sum(t)
|
|
326
|
-
wp.tile_store(partial_sums[column], tile_sum, offset=block_id)
|
|
327
|
-
|
|
328
|
-
@wp.kernel
|
|
329
|
-
def block_sum_kernel(
|
|
330
|
-
data: wp.array2d(dtype=Any),
|
|
331
|
-
partial_sums: wp.array2d(dtype=Any),
|
|
332
|
-
):
|
|
333
|
-
column, block_id, tid_block = wp.tid()
|
|
334
|
-
start = block_id * tile_size
|
|
335
|
-
|
|
336
|
-
t = wp.tile_load(data[column], shape=tile_size, offset=start)
|
|
337
|
-
|
|
338
|
-
tile_sum = wp.tile_sum(t)
|
|
339
|
-
wp.tile_store(partial_sums[column], tile_sum, offset=block_id)
|
|
340
|
-
|
|
341
|
-
return block_dot_kernel, block_sum_kernel
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
def cg(
|
|
345
|
-
A: _Matrix,
|
|
346
|
-
b: wp.array,
|
|
347
|
-
x: wp.array,
|
|
348
|
-
tol: Optional[float] = None,
|
|
349
|
-
atol: Optional[float] = None,
|
|
350
|
-
maxiter: Optional[float] = 0,
|
|
351
|
-
M: Optional[_Matrix] = None,
|
|
352
|
-
callback: Optional[Callable] = None,
|
|
353
|
-
check_every=10,
|
|
354
|
-
use_cuda_graph=True,
|
|
355
|
-
) -> Union[Tuple[int, float, float], Tuple[wp.array, wp.array, wp.array]]:
|
|
356
|
-
"""Computes an approximate solution to a symmetric, positive-definite linear system
|
|
357
|
-
using the Conjugate Gradient algorithm.
|
|
358
|
-
|
|
359
|
-
Args:
|
|
360
|
-
A: the linear system's left-hand-side
|
|
361
|
-
b: the linear system's right-hand-side
|
|
362
|
-
x: initial guess and solution vector
|
|
363
|
-
tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
|
|
364
|
-
atol: absolute tolerance for the residual
|
|
365
|
-
maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
|
|
366
|
-
M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
|
|
367
|
-
callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
|
|
368
|
-
If `check_every` is 0, the callback should be a Warp kernel.
|
|
369
|
-
check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
|
|
370
|
-
Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
|
|
371
|
-
If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
|
|
372
|
-
to the maximum number of iterations.
|
|
373
|
-
use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
|
|
374
|
-
The linear operator and preconditioner must only perform graph-friendly operations.
|
|
375
|
-
|
|
376
|
-
Returns:
|
|
377
|
-
If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
|
|
378
|
-
- final_iteration: The number of iterations performed before convergence or reaching maxiter
|
|
379
|
-
- residual_norm: The final residual norm ||b - Ax||
|
|
380
|
-
- absolute_tolerance: The absolute tolerance used for convergence checking
|
|
381
|
-
|
|
382
|
-
If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
|
|
383
|
-
- final_iteration_array: Device array containing the number of iterations performed
|
|
384
|
-
- residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
|
|
385
|
-
- absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
|
|
386
|
-
|
|
387
|
-
If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
|
|
388
|
-
"""
|
|
389
|
-
A = aslinearoperator(A)
|
|
390
|
-
M = aslinearoperator(M)
|
|
391
|
-
|
|
392
|
-
if maxiter == 0:
|
|
393
|
-
maxiter = A.shape[0]
|
|
394
|
-
|
|
395
|
-
device = A.device
|
|
396
|
-
scalar_type = A.scalar_type
|
|
397
|
-
|
|
398
|
-
# Temp storage
|
|
399
|
-
r_and_z = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
|
|
400
|
-
p_and_Ap = wp.empty_like(r_and_z)
|
|
401
|
-
residuals = wp.empty(2, dtype=scalar_type, device=device)
|
|
402
|
-
|
|
403
|
-
tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=2)
|
|
404
|
-
|
|
405
|
-
# named views
|
|
406
|
-
|
|
407
|
-
# (r, r) -- so we can compute r.z and r.r at once
|
|
408
|
-
r_repeated = _repeat_first(r_and_z)
|
|
409
|
-
if M is None:
|
|
410
|
-
# without preconditioner r == z
|
|
411
|
-
r_and_z = r_repeated
|
|
412
|
-
rz_new = tiled_dot.col(0)
|
|
413
|
-
else:
|
|
414
|
-
rz_new = tiled_dot.col(1)
|
|
415
|
-
|
|
416
|
-
r, z = r_and_z[0], r_and_z[1]
|
|
417
|
-
r_norm_sq = tiled_dot.col(0)
|
|
418
|
-
|
|
419
|
-
p, Ap = p_and_Ap[0], p_and_Ap[1]
|
|
420
|
-
rz_old, atol_sq = residuals[0:1], residuals[1:2]
|
|
421
|
-
|
|
422
|
-
# Not strictly necessary, but makes it more robust to user-provided LinearOperators
|
|
423
|
-
Ap.zero_()
|
|
424
|
-
z.zero_()
|
|
425
|
-
|
|
426
|
-
# Initialize tolerance from right-hand-side norm
|
|
427
|
-
_initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
|
|
428
|
-
# Initialize residual
|
|
429
|
-
A.matvec(x, b, r, alpha=-1.0, beta=1.0)
|
|
430
|
-
|
|
431
|
-
def update_rr_rz():
|
|
432
|
-
# z = M r
|
|
433
|
-
if M is None:
|
|
434
|
-
tiled_dot.compute(r, r)
|
|
435
|
-
else:
|
|
436
|
-
M.matvec(r, z, z, alpha=1.0, beta=0.0)
|
|
437
|
-
tiled_dot.compute(r_repeated, r_and_z)
|
|
438
|
-
|
|
439
|
-
update_rr_rz()
|
|
440
|
-
p.assign(z)
|
|
441
|
-
|
|
442
|
-
def do_iteration():
|
|
443
|
-
rz_old.assign(rz_new)
|
|
444
|
-
|
|
445
|
-
# Ap = A * p;
|
|
446
|
-
A.matvec(p, Ap, Ap, alpha=1, beta=0)
|
|
447
|
-
tiled_dot.compute(p, Ap, col_offset=1)
|
|
448
|
-
p_Ap = tiled_dot.col(1)
|
|
449
|
-
|
|
450
|
-
wp.launch(
|
|
451
|
-
kernel=_cg_kernel_1,
|
|
452
|
-
dim=x.shape[0],
|
|
453
|
-
device=device,
|
|
454
|
-
inputs=[atol_sq, r_norm_sq, rz_old, p_Ap, x, r, p, Ap],
|
|
455
|
-
)
|
|
456
|
-
|
|
457
|
-
update_rr_rz()
|
|
458
|
-
|
|
459
|
-
wp.launch(
|
|
460
|
-
kernel=_cg_kernel_2,
|
|
461
|
-
dim=z.shape[0],
|
|
462
|
-
device=device,
|
|
463
|
-
inputs=[atol_sq, r_norm_sq, rz_old, rz_new, z, p],
|
|
464
|
-
)
|
|
465
|
-
|
|
466
|
-
return _run_capturable_loop(do_iteration, r_norm_sq, maxiter, atol_sq, callback, check_every, use_cuda_graph)
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
def cr(
|
|
470
|
-
A: _Matrix,
|
|
471
|
-
b: wp.array,
|
|
472
|
-
x: wp.array,
|
|
473
|
-
tol: Optional[float] = None,
|
|
474
|
-
atol: Optional[float] = None,
|
|
475
|
-
maxiter: Optional[float] = 0,
|
|
476
|
-
M: Optional[_Matrix] = None,
|
|
477
|
-
callback: Optional[Callable] = None,
|
|
478
|
-
check_every=10,
|
|
479
|
-
use_cuda_graph=True,
|
|
480
|
-
) -> Tuple[int, float, float]:
|
|
481
|
-
"""Computes an approximate solution to a symmetric, positive-definite linear system
|
|
482
|
-
using the Conjugate Residual algorithm.
|
|
483
|
-
|
|
484
|
-
Args:
|
|
485
|
-
A: the linear system's left-hand-side
|
|
486
|
-
b: the linear system's right-hand-side
|
|
487
|
-
x: initial guess and solution vector
|
|
488
|
-
tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
|
|
489
|
-
atol: absolute tolerance for the residual
|
|
490
|
-
maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
|
|
491
|
-
Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
|
|
492
|
-
M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
|
|
493
|
-
callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
|
|
494
|
-
If `check_every` is 0, the callback should be a Warp kernel.
|
|
495
|
-
check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
|
|
496
|
-
Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
|
|
497
|
-
If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
|
|
498
|
-
to the maximum number of iterations.
|
|
499
|
-
use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
|
|
500
|
-
The linear operator and preconditioner must only perform graph-friendly operations.
|
|
501
|
-
|
|
502
|
-
Returns:
|
|
503
|
-
If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
|
|
504
|
-
- final_iteration: The number of iterations performed before convergence or reaching maxiter
|
|
505
|
-
- residual_norm: The final residual norm ||b - Ax||
|
|
506
|
-
- absolute_tolerance: The absolute tolerance used for convergence checking
|
|
507
|
-
|
|
508
|
-
If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
|
|
509
|
-
- final_iteration_array: Device array containing the number of iterations performed
|
|
510
|
-
- residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
|
|
511
|
-
- absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
|
|
512
|
-
|
|
513
|
-
If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
|
|
514
|
-
"""
|
|
515
|
-
|
|
516
|
-
A = aslinearoperator(A)
|
|
517
|
-
M = aslinearoperator(M)
|
|
518
|
-
|
|
519
|
-
if maxiter == 0:
|
|
520
|
-
maxiter = A.shape[0]
|
|
521
|
-
|
|
522
|
-
device = A.device
|
|
523
|
-
scalar_type = wp.types.type_scalar_type(A.dtype)
|
|
524
|
-
|
|
525
|
-
# Notations below follow roughly pseudo-code from https://en.wikipedia.org/wiki/Conjugate_residual_method
|
|
526
|
-
# with z := M^-1 r and y := M^-1 Ap
|
|
527
|
-
|
|
528
|
-
# Temp storage
|
|
529
|
-
r_and_z = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
|
|
530
|
-
r_and_Az = wp.empty_like(r_and_z)
|
|
531
|
-
y_and_Ap = wp.empty_like(r_and_z)
|
|
532
|
-
p = wp.empty_like(b)
|
|
533
|
-
residuals = wp.empty(2, dtype=scalar_type, device=device)
|
|
534
|
-
|
|
535
|
-
tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=2)
|
|
536
|
-
|
|
537
|
-
if M is None:
|
|
538
|
-
r_and_z = _repeat_first(r_and_z)
|
|
539
|
-
y_and_Ap = _repeat_first(y_and_Ap)
|
|
540
|
-
|
|
541
|
-
# named views
|
|
542
|
-
r, z = r_and_z[0], r_and_z[1]
|
|
543
|
-
r_copy, Az = r_and_Az[0], r_and_Az[1]
|
|
544
|
-
|
|
545
|
-
y, Ap = y_and_Ap[0], y_and_Ap[1]
|
|
546
|
-
|
|
547
|
-
r_norm_sq = tiled_dot.col(0)
|
|
548
|
-
zAz_new = tiled_dot.col(1)
|
|
549
|
-
zAz_old, atol_sq = residuals[0:1], residuals[1:2]
|
|
550
|
-
|
|
551
|
-
# Initialize tolerance from right-hand-side norm
|
|
552
|
-
_initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
|
|
553
|
-
# Initialize residual
|
|
554
|
-
A.matvec(x, b, r, alpha=-1.0, beta=1.0)
|
|
555
|
-
|
|
556
|
-
# Not strictly necessary, but makes it more robust to user-provided LinearOperators
|
|
557
|
-
y_and_Ap.zero_()
|
|
558
|
-
|
|
559
|
-
# z = M r
|
|
560
|
-
if M is not None:
|
|
561
|
-
z.zero_()
|
|
562
|
-
M.matvec(r, z, z, alpha=1.0, beta=0.0)
|
|
563
|
-
|
|
564
|
-
def update_rr_zAz():
|
|
565
|
-
A.matvec(z, Az, Az, alpha=1, beta=0)
|
|
566
|
-
r_copy.assign(r)
|
|
567
|
-
tiled_dot.compute(r_and_z, r_and_Az)
|
|
568
|
-
|
|
569
|
-
update_rr_zAz()
|
|
570
|
-
|
|
571
|
-
p.assign(z)
|
|
572
|
-
Ap.assign(Az)
|
|
573
|
-
|
|
574
|
-
def do_iteration():
|
|
575
|
-
zAz_old.assign(zAz_new)
|
|
576
|
-
|
|
577
|
-
if M is not None:
|
|
578
|
-
M.matvec(Ap, y, y, alpha=1.0, beta=0.0)
|
|
579
|
-
tiled_dot.compute(Ap, y, col_offset=1)
|
|
580
|
-
y_Ap = tiled_dot.col(1)
|
|
581
|
-
|
|
582
|
-
if M is None:
|
|
583
|
-
# In non-preconditioned case, first kernel is same as CG
|
|
584
|
-
wp.launch(
|
|
585
|
-
kernel=_cg_kernel_1,
|
|
586
|
-
dim=x.shape[0],
|
|
587
|
-
device=device,
|
|
588
|
-
inputs=[atol_sq, r_norm_sq, zAz_old, y_Ap, x, r, p, Ap],
|
|
589
|
-
)
|
|
590
|
-
else:
|
|
591
|
-
# In preconditioned case, we have one more vector to update
|
|
592
|
-
wp.launch(
|
|
593
|
-
kernel=_cr_kernel_1,
|
|
594
|
-
dim=x.shape[0],
|
|
595
|
-
device=device,
|
|
596
|
-
inputs=[atol_sq, r_norm_sq, zAz_old, y_Ap, x, r, z, p, Ap, y],
|
|
597
|
-
)
|
|
598
|
-
|
|
599
|
-
update_rr_zAz()
|
|
600
|
-
wp.launch(
|
|
601
|
-
kernel=_cr_kernel_2,
|
|
602
|
-
dim=z.shape[0],
|
|
603
|
-
device=device,
|
|
604
|
-
inputs=[atol_sq, r_norm_sq, zAz_old, zAz_new, z, p, Az, Ap],
|
|
605
|
-
)
|
|
606
|
-
|
|
607
|
-
return _run_capturable_loop(
|
|
608
|
-
do_iteration,
|
|
609
|
-
cycle_size=1,
|
|
610
|
-
r_norm_sq=r_norm_sq,
|
|
611
|
-
maxiter=maxiter,
|
|
612
|
-
atol_sq=atol_sq,
|
|
613
|
-
callback=callback,
|
|
614
|
-
check_every=check_every,
|
|
615
|
-
use_cuda_graph=use_cuda_graph,
|
|
616
|
-
)
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
def bicgstab(
|
|
620
|
-
A: _Matrix,
|
|
621
|
-
b: wp.array,
|
|
622
|
-
x: wp.array,
|
|
623
|
-
tol: Optional[float] = None,
|
|
624
|
-
atol: Optional[float] = None,
|
|
625
|
-
maxiter: Optional[float] = 0,
|
|
626
|
-
M: Optional[_Matrix] = None,
|
|
627
|
-
callback: Optional[Callable] = None,
|
|
628
|
-
check_every=10,
|
|
629
|
-
use_cuda_graph=True,
|
|
630
|
-
is_left_preconditioner=False,
|
|
631
|
-
):
|
|
632
|
-
"""Computes an approximate solution to a linear system using the Biconjugate Gradient Stabilized method (BiCGSTAB).
|
|
633
|
-
|
|
634
|
-
Args:
|
|
635
|
-
A: the linear system's left-hand-side
|
|
636
|
-
b: the linear system's right-hand-side
|
|
637
|
-
x: initial guess and solution vector
|
|
638
|
-
tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
|
|
639
|
-
atol: absolute tolerance for the residual
|
|
640
|
-
maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
|
|
641
|
-
M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
|
|
642
|
-
callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
|
|
643
|
-
If `check_every` is 0, the callback should be a Warp kernel.
|
|
644
|
-
check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
|
|
645
|
-
Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
|
|
646
|
-
If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
|
|
647
|
-
to the maximum number of iterations.
|
|
648
|
-
use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
|
|
649
|
-
The linear operator and preconditioner must only perform graph-friendly operations.
|
|
650
|
-
is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.
|
|
651
|
-
|
|
652
|
-
Returns:
|
|
653
|
-
If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
|
|
654
|
-
- final_iteration: The number of iterations performed before convergence or reaching maxiter
|
|
655
|
-
- residual_norm: The final residual norm ||b - Ax||
|
|
656
|
-
- absolute_tolerance: The absolute tolerance used for convergence checking
|
|
657
|
-
|
|
658
|
-
If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
|
|
659
|
-
- final_iteration_array: Device array containing the number of iterations performed
|
|
660
|
-
- residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
|
|
661
|
-
- absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
|
|
662
|
-
|
|
663
|
-
If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
|
|
664
|
-
"""
|
|
665
|
-
A = aslinearoperator(A)
|
|
666
|
-
M = aslinearoperator(M)
|
|
667
|
-
|
|
668
|
-
if maxiter == 0:
|
|
669
|
-
maxiter = A.shape[0]
|
|
670
|
-
|
|
671
|
-
device = A.device
|
|
672
|
-
scalar_type = wp.types.type_scalar_type(A.dtype)
|
|
673
|
-
|
|
674
|
-
# Notations below follow pseudo-code from biconjugate https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method
|
|
675
|
-
|
|
676
|
-
# Temp storage
|
|
677
|
-
r_and_r0 = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
|
|
678
|
-
p = wp.empty_like(b)
|
|
679
|
-
v = wp.empty_like(b)
|
|
680
|
-
t = wp.empty_like(b)
|
|
681
|
-
|
|
682
|
-
r, r0 = r_and_r0[0], r_and_r0[1]
|
|
683
|
-
r_repeated = _repeat_first(r_and_r0)
|
|
684
|
-
|
|
685
|
-
if M is not None:
|
|
686
|
-
y = wp.zeros_like(p)
|
|
687
|
-
z = wp.zeros_like(r)
|
|
688
|
-
if is_left_preconditioner:
|
|
689
|
-
Mt = wp.zeros_like(t)
|
|
690
|
-
else:
|
|
691
|
-
y = p
|
|
692
|
-
z = r
|
|
693
|
-
Mt = t
|
|
694
|
-
|
|
695
|
-
tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=5)
|
|
696
|
-
r_norm_sq = tiled_dot.col(0)
|
|
697
|
-
rho = tiled_dot.col(1)
|
|
698
|
-
|
|
699
|
-
atol_sq = wp.empty(1, dtype=scalar_type, device=device)
|
|
700
|
-
|
|
701
|
-
# Initialize tolerance from right-hand-side norm
|
|
702
|
-
_initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
|
|
703
|
-
# Initialize residual
|
|
704
|
-
A.matvec(x, b, r, alpha=-1.0, beta=1.0)
|
|
705
|
-
tiled_dot.compute(r, r, col_offset=0)
|
|
706
|
-
|
|
707
|
-
p.assign(r)
|
|
708
|
-
r0.assign(r)
|
|
709
|
-
rho.assign(r_norm_sq)
|
|
710
|
-
|
|
711
|
-
# Not strictly necessary, but makes it more robust to user-provided LinearOperators
|
|
712
|
-
v.zero_()
|
|
713
|
-
t.zero_()
|
|
714
|
-
|
|
715
|
-
def do_iteration():
|
|
716
|
-
# y = M p
|
|
717
|
-
if M is not None:
|
|
718
|
-
M.matvec(p, y, y, alpha=1.0, beta=0.0)
|
|
719
|
-
|
|
720
|
-
# v = A * y;
|
|
721
|
-
A.matvec(y, v, v, alpha=1, beta=0)
|
|
722
|
-
|
|
723
|
-
# alpha = rho / <r0 . v>
|
|
724
|
-
tiled_dot.compute(r0, v, col_offset=2)
|
|
725
|
-
r0v = tiled_dot.col(2)
|
|
726
|
-
|
|
727
|
-
# x += alpha y
|
|
728
|
-
# r -= alpha v
|
|
729
|
-
wp.launch(
|
|
730
|
-
kernel=_bicgstab_kernel_1,
|
|
731
|
-
dim=x.shape[0],
|
|
732
|
-
device=device,
|
|
733
|
-
inputs=[atol_sq, r_norm_sq, rho, r0v, x, r, y, v],
|
|
734
|
-
)
|
|
735
|
-
tiled_dot.compute(r, r, col_offset=0)
|
|
736
|
-
|
|
737
|
-
# z = M r
|
|
738
|
-
if M is not None:
|
|
739
|
-
M.matvec(r, z, z, alpha=1.0, beta=0.0)
|
|
740
|
-
|
|
741
|
-
# t = A z
|
|
742
|
-
A.matvec(z, t, t, alpha=1, beta=0)
|
|
743
|
-
|
|
744
|
-
if M is not None and is_left_preconditioner:
|
|
745
|
-
# Mt = M t
|
|
746
|
-
M.matvec(t, Mt, Mt, alpha=1.0, beta=0.0)
|
|
747
|
-
|
|
748
|
-
# omega = <Mt, Ms> / <Mt, Mt>
|
|
749
|
-
tiled_dot.compute(z, Mt, col_offset=3)
|
|
750
|
-
tiled_dot.compute(Mt, Mt, col_offset=4)
|
|
751
|
-
else:
|
|
752
|
-
tiled_dot.compute(r, t, col_offset=3)
|
|
753
|
-
tiled_dot.compute(t, t, col_offset=4)
|
|
754
|
-
st = tiled_dot.col(3)
|
|
755
|
-
tt = tiled_dot.col(4)
|
|
756
|
-
|
|
757
|
-
# x += omega z
|
|
758
|
-
# r -= omega t
|
|
759
|
-
wp.launch(
|
|
760
|
-
kernel=_bicgstab_kernel_2,
|
|
761
|
-
dim=z.shape[0],
|
|
762
|
-
device=device,
|
|
763
|
-
inputs=[atol_sq, r_norm_sq, st, tt, z, t, x, r],
|
|
764
|
-
)
|
|
765
|
-
|
|
766
|
-
# r = <r,r>, rho = <r0, r>
|
|
767
|
-
tiled_dot.compute(r_and_r0, r_repeated, col_offset=0)
|
|
768
|
-
|
|
769
|
-
# beta = (rho / rho_old) * alpha / omega = (rho / r0v) / omega
|
|
770
|
-
# p = r + beta (p - omega v)
|
|
771
|
-
wp.launch(
|
|
772
|
-
kernel=_bicgstab_kernel_3,
|
|
773
|
-
dim=z.shape[0],
|
|
774
|
-
device=device,
|
|
775
|
-
inputs=[atol_sq, r_norm_sq, rho, r0v, st, tt, p, r, v],
|
|
776
|
-
)
|
|
777
|
-
|
|
778
|
-
return _run_capturable_loop(
|
|
779
|
-
do_iteration,
|
|
780
|
-
r_norm_sq=r_norm_sq,
|
|
781
|
-
maxiter=maxiter,
|
|
782
|
-
atol_sq=atol_sq,
|
|
783
|
-
callback=callback,
|
|
784
|
-
check_every=check_every,
|
|
785
|
-
use_cuda_graph=use_cuda_graph,
|
|
786
|
-
)
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
def gmres(
|
|
790
|
-
A: _Matrix,
|
|
791
|
-
b: wp.array,
|
|
792
|
-
x: wp.array,
|
|
793
|
-
tol: Optional[float] = None,
|
|
794
|
-
atol: Optional[float] = None,
|
|
795
|
-
restart=31,
|
|
796
|
-
maxiter: Optional[float] = 0,
|
|
797
|
-
M: Optional[_Matrix] = None,
|
|
798
|
-
callback: Optional[Callable] = None,
|
|
799
|
-
check_every=31,
|
|
800
|
-
use_cuda_graph=True,
|
|
801
|
-
is_left_preconditioner=False,
|
|
802
|
-
):
|
|
803
|
-
"""Computes an approximate solution to a linear system using the restarted Generalized Minimum Residual method (GMRES[k]).
|
|
804
|
-
|
|
805
|
-
Args:
|
|
806
|
-
A: the linear system's left-hand-side
|
|
807
|
-
b: the linear system's right-hand-side
|
|
808
|
-
x: initial guess and solution vector
|
|
809
|
-
tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
|
|
810
|
-
atol: absolute tolerance for the residual
|
|
811
|
-
restart: The restart parameter, i.e, the `k` in `GMRES[k]`. In general, increasing this parameter reduces the number of iterations but increases memory consumption.
|
|
812
|
-
maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
|
|
813
|
-
Note that the current implementation always perform `restart` iterations at a time, and as a result may exceed the specified maximum number of iterations by ``restart-1``.
|
|
814
|
-
M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
|
|
815
|
-
callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
|
|
816
|
-
If `check_every` is 0, the callback should be a Warp kernel.
|
|
817
|
-
check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
|
|
818
|
-
Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
|
|
819
|
-
If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
|
|
820
|
-
to the maximum number of iterations.
|
|
821
|
-
use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
|
|
822
|
-
The linear operator and preconditioner must only perform graph-friendly operations.
|
|
823
|
-
is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.
|
|
824
|
-
|
|
825
|
-
Returns:
|
|
826
|
-
If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
|
|
827
|
-
- final_iteration: The number of iterations performed before convergence or reaching maxiter
|
|
828
|
-
- residual_norm: The final residual norm ||b - Ax||
|
|
829
|
-
- absolute_tolerance: The absolute tolerance used for convergence checking
|
|
830
|
-
|
|
831
|
-
If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
|
|
832
|
-
- final_iteration_array: Device array containing the number of iterations performed
|
|
833
|
-
- residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
|
|
834
|
-
- absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
|
|
835
|
-
|
|
836
|
-
If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
|
|
837
|
-
"""
|
|
838
|
-
|
|
839
|
-
A = aslinearoperator(A)
|
|
840
|
-
M = aslinearoperator(M)
|
|
841
|
-
|
|
842
|
-
if maxiter == 0:
|
|
843
|
-
maxiter = A.shape[0]
|
|
844
|
-
|
|
845
|
-
restart = min(restart, maxiter)
|
|
846
|
-
|
|
847
|
-
if check_every > 0:
|
|
848
|
-
check_every = max(restart, check_every)
|
|
849
|
-
|
|
850
|
-
device = A.device
|
|
851
|
-
scalar_dtype = wp.types.type_scalar_type(A.dtype)
|
|
852
|
-
|
|
853
|
-
pivot_tolerance = _get_dtype_epsilon(scalar_dtype) ** 2
|
|
854
|
-
|
-    r = wp.empty_like(b)
-    w = wp.empty_like(r)
-
-    H = wp.empty(shape=(restart + 1, restart), dtype=scalar_dtype, device=device)
-    y = wp.empty(shape=restart + 1, dtype=scalar_dtype, device=device)
-
-    V = wp.zeros(shape=(restart + 1, r.shape[0]), dtype=r.dtype, device=device)
-
-    residuals = wp.empty(2, dtype=scalar_dtype, device=device)
-    beta, atol_sq = residuals[0:1], residuals[1:2]
-
-    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_dtype, max_column_count=restart + 1)
-    r_norm_sq = tiled_dot.col(0)
-
-    w_repeated = wp.array(
-        ptr=w.ptr, shape=(restart + 1, w.shape[0]), strides=(0, w.strides[0]), dtype=w.dtype, device=w.device
-    )
-
-    # tile size for the least-squares solve
-    # (needs to fit in a CUDA block, so 1024 max)
-    if device.is_cuda and 4 < restart <= 1024:
-        tile_size = 1 << math.ceil(math.log2(restart))
-        least_squares_kernel = make_gmres_solve_least_squares_kernel_tiled(tile_size)
-    else:
-        tile_size = 1
-        least_squares_kernel = _gmres_solve_least_squares
-
-    # recorded launches
-    least_squares_solve = wp.launch(
-        least_squares_kernel,
-        dim=(1, tile_size),
-        block_dim=tile_size if tile_size > 1 else 256,
-        device=device,
-        inputs=[restart, pivot_tolerance, beta, H, y],
-        record_cmd=True,
-    )
-
-    normalize_arnoldi_vec = wp.launch(
-        _gmres_arnoldi_normalize_kernel,
-        dim=r.shape,
-        device=r.device,
-        inputs=[r, w, tiled_dot.col(0), beta],
-        record_cmd=True,
-    )
-
-    arnoldi_axpy = wp.launch(
-        _gmres_arnoldi_axpy_kernel,
-        dim=(w.shape[0], tile_size),
-        block_dim=tile_size,
-        device=w.device,
-        inputs=[V, w, H],
-        record_cmd=True,
-    )
-
-    # Initialize tolerance from right-hand-side norm
-    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
-    # Initialize residual
-    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
-    tiled_dot.compute(r, r, col_offset=0)
-
-    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
-    w.zero_()
-
-    def array_coeff(H, i, j):
-        return H[i][j : j + 1]
-
-    def array_col(H, j):
-        return H[: j + 1, j : j + 1]
-
-    def do_arnoldi_iteration(j: int):
-        # w = A * v[j];
-        if M is not None:
-            tmp = V[j + 1]
-
-            if is_left_preconditioner:
-                A.matvec(V[j], tmp, tmp, alpha=1, beta=0)
-                M.matvec(tmp, w, w, alpha=1, beta=0)
-            else:
-                M.matvec(V[j], tmp, tmp, alpha=1, beta=0)
-                A.matvec(tmp, w, w, alpha=1, beta=0)
-        else:
-            A.matvec(V[j], w, w, alpha=1, beta=0)
-
-        # compute and apply dot products in parallel,
-        # since Hj columns are orthogonal
-        Hj = array_col(H, j)
-        tiled_dot.compute(w_repeated, V[: j + 1])
-        wp.copy(src=tiled_dot.cols(j + 1), dest=Hj)
-
-        # w -= (w . vi) * vi
-        arnoldi_axpy.set_params([V[: j + 1], w, Hj])
-        arnoldi_axpy.launch()
-
-        # H[j+1, j] = |w|
-        tiled_dot.compute(w, w)
-        normalize_arnoldi_vec.set_params([w, V[j + 1], tiled_dot.col(0), array_coeff(H, j + 1, j)])
-
-        normalize_arnoldi_vec.launch()
-
-    def do_restart_cycle():
-        if M is not None and is_left_preconditioner:
-            M.matvec(r, w, w, alpha=1, beta=0)
-            rh = w
-        else:
-            rh = r
-
-        # beta^2 = rh.rh
-        tiled_dot.compute(rh, rh)
-
-        # v[0] = r / beta
-        normalize_arnoldi_vec.set_params([rh, V[0], tiled_dot.col(0), beta])
-        normalize_arnoldi_vec.launch()
-
-        for j in range(restart):
-            do_arnoldi_iteration(j)
-
-        least_squares_solve.launch()
-
-        # update x
-        if M is None or is_left_preconditioner:
-            wp.launch(_gmres_update_x_kernel, dim=x.shape, device=device, inputs=[restart, scalar_dtype(1.0), y, V, x])
-        else:
-            wp.launch(_gmres_update_x_kernel, dim=x.shape, device=device, inputs=[restart, scalar_dtype(0.0), y, V, w])
-            M.matvec(w, x, x, alpha=1, beta=1)
-
-        # update r and residual
-        wp.copy(src=b, dest=r)
-        A.matvec(x, b, r, alpha=-1.0, beta=1.0)
-        tiled_dot.compute(r, r)
-
-    return _run_capturable_loop(
-        do_restart_cycle,
-        cycle_size=restart,
-        r_norm_sq=r_norm_sq,
-        maxiter=maxiter,
-        atol_sq=atol_sq,
-        callback=callback,
-        check_every=check_every,
-        use_cuda_graph=use_cuda_graph,
-    )
-
-
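The block above is the tail of the relocated GMRES driver. For context, a minimal sketch of driving it through the public wrapper; the import path `warp.optim.linear` and the exact `gmres`/`aslinearoperator` signatures are assumptions based on Warp's documented API, not part of this diff:

import numpy as np
import warp as wp
from warp.optim.linear import aslinearoperator, gmres  # assumed public path

n = 256
rng = np.random.default_rng(42)
# A diagonally dominant system, so GMRES converges in a few restart cycles
A = wp.array(np.eye(n) * 10.0 + rng.standard_normal((n, n)) * 0.1, dtype=wp.float64)
b = wp.array(rng.standard_normal(n), dtype=wp.float64)
x = wp.zeros(n, dtype=wp.float64)

# restart bounds the Arnoldi cycle length (the `restart` above); each cycle
# stores restart + 1 basis vectors in V, which dominates the memory footprint
cur_iter, err, atol = gmres(aslinearoperator(A), b, x, maxiter=1000, restart=31, use_cuda_graph=False)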
-def _repeat_first(arr: wp.array):
-    # returns a view of the first element repeated arr.shape[0] times
-    view = wp.array(
-        ptr=arr.ptr,
-        shape=arr.shape,
-        dtype=arr.dtype,
-        strides=(0, *arr.strides[1:]),
-        device=arr.device,
-    )
-    view._ref = arr
-    return view
-
-
-def _get_dtype_epsilon(dtype):
-    if dtype == wp.float64:
-        return 1.0e-16
-    elif dtype == wp.float16:
-        return 1.0e-4
-
-    return 1.0e-8
-
-
-def _get_tolerances(dtype, tol, atol):
-    eps_tol = _get_dtype_epsilon(dtype)
-    default_tol = eps_tol ** (3 / 4)
-    min_tol = eps_tol ** (9 / 4)
-
-    if tol is None and atol is None:
-        tol = atol = default_tol
-    elif tol is None:
-        tol = atol
-    elif atol is None:
-        atol = tol
-
-    atol = max(atol, min_tol)
-    return tol, atol
-
-
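`_get_tolerances` derives its defaults from a per-dtype epsilon: the default tolerance is eps ** (3 / 4) and the floor is eps ** (9 / 4). Evaluating those formulas (plain Python, epsilons taken from `_get_dtype_epsilon` above):

for eps in (1.0e-8, 1.0e-16, 1.0e-4):  # float32, float64, float16 epsilons used above
    print(f"eps={eps:.0e}  default_tol={eps ** (3 / 4):.1e}  min_tol={eps ** (9 / 4):.1e}")
# eps=1e-08  default_tol=1.0e-06  min_tol=1.0e-18
# eps=1e-16  default_tol=1.0e-12  min_tol=1.0e-36
# eps=1e-04  default_tol=1.0e-03  min_tol=1.0e-09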
-@wp.kernel
-def _initialize_tolerance(
-    rtol: Any,
-    atol: Any,
-    r_norm_sq: wp.array(dtype=Any),
-    atol_sq: wp.array(dtype=Any),
-):
-    atol = wp.max(rtol * wp.sqrt(r_norm_sq[0]), atol)
-    atol_sq[0] = atol * atol
-
-
-def _initialize_absolute_tolerance(
-    b: wp.array,
-    tol: float,
-    atol: float,
-    tiled_dot: TiledDot,
-    atol_sq: wp.array,
-):
-    scalar_type = atol_sq.dtype
-
-    # Compute b norm to define absolute tolerance
-    tiled_dot.compute(b, b)
-    b_norm_sq = tiled_dot.col(0)
-
-    rtol, atol = _get_tolerances(scalar_type, tol, atol)
-    wp.launch(
-        kernel=_initialize_tolerance,
-        dim=1,
-        device=b.device,
-        inputs=[scalar_type(rtol), scalar_type(atol), b_norm_sq, atol_sq],
-    )
-
-
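Read together, the two helpers above implement a standard stopping rule, kept squared so the solvers never need a square root on the hot path:

\mathrm{atol}_{\mathrm{eff}} = \max\bigl(\mathrm{rtol} \cdot \lVert b \rVert,\ \mathrm{atol}\bigr),
\qquad
\text{stop when } \lVert r \rVert^2 \le \mathrm{atol}_{\mathrm{eff}}^2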
-@wp.kernel
-def _update_condition(
-    maxiter: int,
-    cycle_size: int,
-    cur_iter: wp.array(dtype=int),
-    r_norm_sq: wp.array(dtype=Any),
-    atol_sq: wp.array(dtype=Any),
-    condition: wp.array(dtype=int),
-):
-    cur_iter[0] += cycle_size
-    condition[0] = wp.where(r_norm_sq[0] <= atol_sq[0] or cur_iter[0] >= maxiter, 0, 1)
-
-
-def _run_capturable_loop(
-    do_cycle: Callable,
-    r_norm_sq: wp.array,
-    maxiter: int,
-    atol_sq: wp.array,
-    callback: Optional[Callable],
-    check_every: int,
-    use_cuda_graph: bool,
-    cycle_size: int = 1,
-):
-    device = atol_sq.device
-
-    if check_every > 0:
-        atol = math.sqrt(atol_sq.numpy()[0])
-        return _run_solver_loop(
-            do_cycle, cycle_size, r_norm_sq, maxiter, atol, callback, check_every, use_cuda_graph, device
-        )
-
-    cur_iter_and_condition = wp.full((2,), value=-1, dtype=int, device=device)
-    cur_iter = cur_iter_and_condition[0:1]
-    condition = cur_iter_and_condition[1:2]
-
-    update_condition_launch = wp.launch(
-        _update_condition,
-        dim=1,
-        device=device,
-        inputs=[int(maxiter), cycle_size, cur_iter, r_norm_sq, atol_sq, condition],
-        record_cmd=True,
-    )
-
-    if isinstance(callback, wp.Kernel):
-        callback_launch = wp.launch(
-            callback, dim=1, device=device, inputs=[cur_iter, r_norm_sq, atol_sq], record_cmd=True
-        )
-    else:
-        callback_launch = None
-
-    update_condition_launch.launch()
-    if callback_launch is not None:
-        callback_launch.launch()
-
-    def do_cycle_with_condition():
-        do_cycle()
-        update_condition_launch.launch()
-        if callback_launch is not None:
-            callback_launch.launch()
-
-    if use_cuda_graph and device.is_cuda:
-        if device.is_capturing:
-            wp.capture_while(condition, do_cycle_with_condition)
-        else:
-            with wp.ScopedCapture() as capture:
-                wp.capture_while(condition, do_cycle_with_condition)
-            wp.capture_launch(capture.graph)
-    else:
-        for _ in range(0, maxiter, cycle_size):
-            do_cycle_with_condition()
-
-    return cur_iter, r_norm_sq, atol_sq
-
-
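`_run_capturable_loop` keeps the stopping condition in device memory so the entire iteration can be recorded as one CUDA graph with a conditional while-node, with no host round-trip per cycle. A standalone sketch of the same pattern, assuming the `wp.capture_while` / `wp.ScopedCapture` API used above (CUDA device required; the `countdown` kernel is illustrative only):

import warp as wp

@wp.kernel
def countdown(counter: wp.array(dtype=int), condition: wp.array(dtype=int)):
    counter[0] -= 1
    condition[0] = wp.where(counter[0] > 0, 1, 0)

counter = wp.array([10], dtype=int)
condition = wp.array([1], dtype=int)

def body():
    wp.launch(countdown, dim=1, inputs=[counter, condition])

# The whole while-loop is recorded into a single graph; the device-side
# `condition` array decides when iteration stops, as in the solver above.
with wp.ScopedCapture() as capture:
    wp.capture_while(condition, body)
wp.capture_launch(capture.graph)
print(counter.numpy())  # [0]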
-def _run_solver_loop(
-    do_cycle: Callable[[float], None],
-    cycle_size: int,
-    r_norm_sq: wp.array,
-    maxiter: int,
-    atol: float,
-    callback: Callable,
-    check_every: int,
-    use_cuda_graph: bool,
-    device,
-):
-    atol_sq = atol * atol
-    check_every = max(check_every, cycle_size)
-
-    cur_iter = 0
-
-    err_sq = r_norm_sq.numpy()[0]
-    err = math.sqrt(err_sq)
-    if callback is not None:
-        callback(cur_iter, err, atol)
-
-    if err_sq <= atol_sq:
-        return cur_iter, err, atol
-
-    graph = None
-
-    while True:
-        # Do not do graph capture at first iteration -- modules may not be loaded yet
-        if device.is_cuda and use_cuda_graph and cur_iter > 0:
-            if graph is None:
-                with wp.ScopedCapture(force_module_load=False) as capture:
-                    do_cycle()
-                graph = capture.graph
-            wp.capture_launch(graph)
-        else:
-            do_cycle()
-
-        cur_iter += cycle_size
-
-        if cur_iter >= maxiter:
-            break
-
-        if (cur_iter % check_every) < cycle_size:
-            err_sq = r_norm_sq.numpy()[0]
-
-            if err_sq <= atol_sq:
-                break
-
-            if callback is not None:
-                callback(cur_iter, math.sqrt(err_sq), atol)
-
-    err_sq = r_norm_sq.numpy()[0]
-    err = math.sqrt(err_sq)
-    if callback is not None:
-        callback(cur_iter, err, atol)
-
-    return cur_iter, err, atol
-
-
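`_run_solver_loop` is the host-driven fallback: it runs the first cycle eagerly (so kernels are compiled and loaded before capture), then records one cycle into a graph and replays it, paying a device-to-host sync only every `check_every` iterations. The same capture-once/replay-many pattern in isolation (hypothetical `do_cycle`; CUDA device assumed):

import warp as wp

@wp.kernel
def axpy(alpha: float, x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    y[i] += alpha * x[i]

x = wp.ones(1024, dtype=float)
y = wp.zeros(1024, dtype=float)

def do_cycle():
    wp.launch(axpy, dim=x.shape, inputs=[0.5, x, y])

do_cycle()  # first cycle eagerly: compiles and loads the module before capture
with wp.ScopedCapture(force_module_load=False) as capture:
    do_cycle()
graph = capture.graph
for _ in range(99):
    wp.capture_launch(graph)  # replay; no host work per iteration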
-@wp.kernel
-def _dense_mv_kernel(
-    A: wp.array2d(dtype=Any),
-    x: wp.array1d(dtype=Any),
-    y: wp.array1d(dtype=Any),
-    z: wp.array1d(dtype=Any),
-    alpha: Any,
-    beta: Any,
-):
-    row, lane = wp.tid()
-
-    zero = type(alpha)(0)
-    s = zero
-    if alpha != zero:
-        for col in range(lane, A.shape[1], wp.block_dim()):
-            s += A[row, col] * x[col]
-
-    row_tile = wp.tile_sum(wp.tile(s * alpha))
-
-    if beta != zero:
-        row_tile += wp.tile_load(y, shape=1, offset=row) * beta
-
-    wp.tile_store(z, row_tile, offset=row)
-
-
-@wp.kernel
-def _diag_mv_kernel(
-    A: wp.array(dtype=Any),
-    x: wp.array(dtype=Any),
-    y: wp.array(dtype=Any),
-    z: wp.array(dtype=Any),
-    alpha: Any,
-    beta: Any,
-):
-    i = wp.tid()
-    zero = type(alpha)(0)
-    s = z.dtype(zero)
-    if alpha != zero:
-        s += alpha * (A[i] * x[i])
-    if beta != zero:
-        s += beta * y[i]
-    z[i] = s
-
-
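These two kernels back the `A.matvec(x, y, z, alpha, beta)` contract used throughout the solvers, i.e. z = alpha * A @ x + beta * y, for dense and diagonal operators respectively. A sketch of plugging a diagonal operator and a Jacobi preconditioner into the CG entry point; `aslinearoperator`, `preconditioner`, the `"diag"` mode, and the convention that a 1-D array acts as a diagonal are all assumptions based on Warp's documented API, not part of this diff:

import warp as wp
from warp.optim.linear import aslinearoperator, cg, preconditioner  # assumed public path

n = 1024
diag = wp.full(n, 4.0, dtype=wp.float32)
A = aslinearoperator(diag)        # 1-D array -> diagonal operator (served by _diag_mv_kernel)
b = wp.ones(n, dtype=wp.float32)
x = wp.zeros(n, dtype=wp.float32)

M = preconditioner(A, "diag")     # Jacobi preconditioner, built via the extraction kernels below
cur_iter, err, atol = cg(A, b, x, maxiter=100, M=M, use_cuda_graph=False)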
-@wp.func
-def _inverse_diag_coefficient(coeff: Any, use_abs: wp.bool):
-    zero = type(coeff)(0.0)
-    one = type(coeff)(1.0)
-    return wp.where(coeff == zero, one, one / wp.where(use_abs, wp.abs(coeff), coeff))
-
-
-@wp.kernel
-def _extract_inverse_diagonal_blocked(
-    diag_block: wp.array(dtype=Any),
-    inv_diag: wp.array(dtype=Any),
-    use_abs: int,
-):
-    i = wp.tid()
-
-    d = wp.get_diag(diag_block[i])
-    for k in range(d.length):
-        d[k] = _inverse_diag_coefficient(d[k], use_abs != 0)
-
-    inv_diag[i] = d
-
-
-@wp.kernel
-def _extract_inverse_diagonal_scalar(
-    diag_array: wp.array(dtype=Any),
-    inv_diag: wp.array(dtype=Any),
-    use_abs: int,
-):
-    i = wp.tid()
-    inv_diag[i] = _inverse_diag_coefficient(diag_array[i], use_abs != 0)
-
-
-@wp.kernel
-def _extract_inverse_diagonal_dense(
-    dense_matrix: wp.array2d(dtype=Any),
-    inv_diag: wp.array(dtype=Any),
-    use_abs: int,
-):
-    i = wp.tid()
-    inv_diag[i] = _inverse_diag_coefficient(dense_matrix[i, i], use_abs != 0)
-
-
-@wp.kernel
-def _cg_kernel_1(
-    tol: wp.array(dtype=Any),
-    resid: wp.array(dtype=Any),
-    rz_old: wp.array(dtype=Any),
-    p_Ap: wp.array(dtype=Any),
-    x: wp.array(dtype=Any),
-    r: wp.array(dtype=Any),
-    p: wp.array(dtype=Any),
-    Ap: wp.array(dtype=Any),
-):
-    i = wp.tid()
-
-    alpha = wp.where(resid[0] > tol[0], rz_old[0] / p_Ap[0], rz_old.dtype(0.0))
-
-    x[i] = x[i] + alpha * p[i]
-    r[i] = r[i] - alpha * Ap[i]
-
-
-@wp.kernel
-def _cg_kernel_2(
-    tol: wp.array(dtype=Any),
-    resid_new: wp.array(dtype=Any),
-    rz_old: wp.array(dtype=Any),
-    rz_new: wp.array(dtype=Any),
-    z: wp.array(dtype=Any),
-    p: wp.array(dtype=Any),
-):
-    # p = z + (rz_new / rz_old) * p
-    i = wp.tid()
-
-    cond = resid_new[0] > tol[0]
-    beta = wp.where(cond, rz_new[0] / rz_old[0], rz_old.dtype(0.0))
-
-    p[i] = z[i] + beta * p[i]
-
-
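`_cg_kernel_1` and `_cg_kernel_2` are the two fused update steps of preconditioned CG, with z = M r the preconditioned residual:

\alpha = \frac{r \cdot z}{p \cdot A p}, \qquad
x \leftarrow x + \alpha p, \qquad
r \leftarrow r - \alpha A p,

\beta = \frac{(r \cdot z)_{\mathrm{new}}}{(r \cdot z)_{\mathrm{old}}}, \qquad
p \leftarrow z + \beta p

Note the guard in both kernels: once `resid` drops below `tol`, alpha and beta are forced to zero, so extra iterations baked into a captured CUDA graph become no-ops.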
-@wp.kernel
-def _cr_kernel_1(
-    tol: wp.array(dtype=Any),
-    resid: wp.array(dtype=Any),
-    zAz_old: wp.array(dtype=Any),
-    y_Ap: wp.array(dtype=Any),
-    x: wp.array(dtype=Any),
-    r: wp.array(dtype=Any),
-    z: wp.array(dtype=Any),
-    p: wp.array(dtype=Any),
-    Ap: wp.array(dtype=Any),
-    y: wp.array(dtype=Any),
-):
-    i = wp.tid()
-
-    alpha = wp.where(resid[0] > tol[0] and y_Ap[0] > 0.0, zAz_old[0] / y_Ap[0], zAz_old.dtype(0.0))
-
-    x[i] = x[i] + alpha * p[i]
-    r[i] = r[i] - alpha * Ap[i]
-    z[i] = z[i] - alpha * y[i]
-
-
-@wp.kernel
-def _cr_kernel_2(
-    tol: wp.array(dtype=Any),
-    resid: wp.array(dtype=Any),
-    zAz_old: wp.array(dtype=Any),
-    zAz_new: wp.array(dtype=Any),
-    z: wp.array(dtype=Any),
-    p: wp.array(dtype=Any),
-    Az: wp.array(dtype=Any),
-    Ap: wp.array(dtype=Any),
-):
-    # p = z + (zAz_new / zAz_old) * p
-    i = wp.tid()
-
-    beta = wp.where(resid[0] > tol[0] and zAz_old[0] > 0.0, zAz_new[0] / zAz_old[0], zAz_old.dtype(0.0))
-
-    p[i] = z[i] + beta * p[i]
-    Ap[i] = Az[i] + beta * Ap[i]
-
-
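The conjugate-residual kernels follow the same scheme; reading the coefficients off the kernel arguments (zAz = z . Az, y_Ap = y . Ap, with y the operator applied to Ap):

\alpha = \frac{z \cdot A z}{y \cdot A p}, \qquad
x \leftarrow x + \alpha p, \quad
r \leftarrow r - \alpha A p, \quad
z \leftarrow z - \alpha y,

\beta = \frac{(z \cdot A z)_{\mathrm{new}}}{(z \cdot A z)_{\mathrm{old}}}, \qquad
p \leftarrow z + \beta p, \quad
A p \leftarrow A z + \beta A p

Recurring Ap alongside p saves one operator application per iteration, at the cost of the extra guard on zAz_old > 0 (CR is only well defined for positive-definite systems).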
-@wp.kernel
-def _bicgstab_kernel_1(
-    tol: wp.array(dtype=Any),
-    resid: wp.array(dtype=Any),
-    rho_old: wp.array(dtype=Any),
-    r0v: wp.array(dtype=Any),
-    x: wp.array(dtype=Any),
-    r: wp.array(dtype=Any),
-    y: wp.array(dtype=Any),
-    v: wp.array(dtype=Any),
-):
-    i = wp.tid()
-
-    alpha = wp.where(resid[0] > tol[0], rho_old[0] / r0v[0], rho_old.dtype(0.0))
-
-    x[i] += alpha * y[i]
-    r[i] -= alpha * v[i]
-
-
-@wp.kernel
-def _bicgstab_kernel_2(
-    tol: wp.array(dtype=Any),
-    resid: wp.array(dtype=Any),
-    st: wp.array(dtype=Any),
-    tt: wp.array(dtype=Any),
-    z: wp.array(dtype=Any),
-    t: wp.array(dtype=Any),
-    x: wp.array(dtype=Any),
-    r: wp.array(dtype=Any),
-):
-    i = wp.tid()
-
-    omega = wp.where(resid[0] > tol[0], st[0] / tt[0], st.dtype(0.0))
-
-    x[i] += omega * z[i]
-    r[i] -= omega * t[i]
-
-
-@wp.kernel
-def _bicgstab_kernel_3(
-    tol: wp.array(dtype=Any),
-    resid: wp.array(dtype=Any),
-    rho_new: wp.array(dtype=Any),
-    r0v: wp.array(dtype=Any),
-    st: wp.array(dtype=Any),
-    tt: wp.array(dtype=Any),
-    p: wp.array(dtype=Any),
-    r: wp.array(dtype=Any),
-    v: wp.array(dtype=Any),
-):
-    i = wp.tid()
-
-    beta = wp.where(resid[0] > tol[0], rho_new[0] * tt[0] / (r0v[0] * st[0]), st.dtype(0.0))
-    beta_omega = wp.where(resid[0] > tol[0], rho_new[0] / r0v[0], st.dtype(0.0))
-
-    p[i] = r[i] + beta * p[i] - beta_omega * v[i]
-
-
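In textbook BiCGSTAB the recurrence coefficient is beta = (rho_new / rho_old)(alpha / omega); substituting alpha = rho_old / (r0 . v) from kernel 1 and omega = (s . t) / (t . t) from kernel 2 shows that `_bicgstab_kernel_3` uses exactly the fused forms:

\beta = \frac{\rho_{\mathrm{new}}}{\rho_{\mathrm{old}}} \cdot \frac{\alpha}{\omega}
      = \frac{\rho_{\mathrm{new}} \,(t \cdot t)}{(r_0 \cdot v)\,(s \cdot t)},
\qquad
\beta \omega = \frac{\rho_{\mathrm{new}}}{r_0 \cdot v},
\qquad
p \leftarrow r + \beta p - \beta \omega \, v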
-@wp.kernel
-def _gmres_solve_least_squares(
-    k: int, pivot_tolerance: float, beta: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
-):
-    # Solve H y = (beta, 0, ..., 0)
-    # H is a Hessenberg matrix of shape (k+1, k), modified in place
-    # (k is dynamic, so H would not fit in registers)
-
-    rhs = beta[0]
-
-    # Apply 2x2 rotations to H so as to remove the lower diagonal,
-    # and apply the same rotations to the right-hand side
-    max_k = int(k)
-    for i in range(k):
-        Ha = H[i]
-        Hb = H[i + 1]
-
-        # Givens rotation [[c s], [-s c]]
-        a = Ha[i]
-        b = Hb[i]
-        abn_sq = a * a + b * b
-
-        if abn_sq < type(abn_sq)(pivot_tolerance):
-            # Arnoldi iteration finished early
-            max_k = i
-            break
-
-        abn = wp.sqrt(abn_sq)
-        c = a / abn
-        s = b / abn
-
-        # Rotate H
-        for j in range(i, k):
-            a = Ha[j]
-            b = Hb[j]
-            Ha[j] = c * a + s * b
-            Hb[j] = c * b - s * a
-
-        # Rotate rhs
-        y[i] = c * rhs
-        rhs = -s * rhs
-
-    for i in range(max_k, k):
-        y[i] = y.dtype(0.0)
-
-    # Triangular back-solve for y
-    for ii in range(max_k, 0, -1):
-        i = ii - 1
-        Hi = H[i]
-        yi = y[i]
-        for j in range(ii, max_k):
-            yi -= Hi[j] * y[j]
-        y[i] = yi / Hi[i]
-
-
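Each pass of the loop above zeroes the subdiagonal entry H[i+1, i] with a Givens rotation and applies the same rotation to the right-hand side (beta, 0, ..., 0), leaving an upper-triangular system for the back-solve:

\begin{pmatrix} c & s \\ -s & c \end{pmatrix}
\begin{pmatrix} a \\ b \end{pmatrix}
=
\begin{pmatrix} \sqrt{a^2 + b^2} \\ 0 \end{pmatrix},
\qquad
c = \frac{a}{\sqrt{a^2 + b^2}}, \quad
s = \frac{b}{\sqrt{a^2 + b^2}}, \quad
a = H_{i,i}, \; b = H_{i+1,i}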
-@functools.lru_cache(maxsize=None)
-def make_gmres_solve_least_squares_kernel_tiled(K: int):
-    @wp.kernel(module="unique")
-    def gmres_solve_least_squares_tiled(
-        k: int, pivot_tolerance: float, beta: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
-    ):
-        # Assumes tiles of size K, with K at least as large as the highest number of columns.
-        # Limits the max restart cycle length to the max block size of 1024, but using
-        # larger restarts would be very inefficient anyway (default is ~30)
-
-        # Solve H y = (beta, 0, ..., 0)
-        # H Hessenberg matrix of shape (k+1, k)
-
-        i, lane = wp.tid()
-
-        rhs = beta[0]
-
-        zero = H.dtype(0.0)
-        one = H.dtype(1.0)
-        yi = zero
-
-        Ha = wp.tile_load(H[0], shape=(K))
-
-        # Apply 2x2 rotations to H so as to remove the lower diagonal,
-        # and apply the same rotations to the right-hand side
-        max_k = int(k)
-        for i in range(k):
-            # Ha = H[i]
-            # Hb = H[i + 1]
-            Hb = wp.tile_load(H[i + 1], shape=(K))
-
-            # Givens rotation [[c s], [-s c]]
-            a = Ha[i]
-            b = Hb[i]
-            abn_sq = a * a + b * b
-
-            if abn_sq < type(abn_sq)(pivot_tolerance):
-                # Arnoldi iteration finished early
-                max_k = i
-                break
-
-            abn = wp.sqrt(abn_sq)
-            c = a / abn
-            s = b / abn
-
-            # Rotate H
-            a = wp.untile(Ha)
-            b = wp.untile(Hb)
-            a_rot = c * a + s * b
-            b_rot = c * b - s * a
-
-            # Rotate rhs
-            if lane == i:
-                yi = c * rhs
-            rhs = -s * rhs
-
-            wp.tile_store(H[i], wp.tile(a_rot))
-            Ha[lane] = b_rot
-
-        y_tile = wp.tile(yi)
-
-        # Triangular back-solve for y
-        for ii in range(max_k, 0, -1):
-            i = ii - 1
-
-            Hi = wp.tile_load(H[i], shape=(K))
-
-            il = lane + i
-            if lane == 0:
-                yl = y_tile[i]
-            elif il < max_k:
-                yl = -y_tile[il] * Hi[il]
-            else:
-                yl = zero
-
-            yit = wp.tile_sum(wp.tile(yl)) * (one / Hi[i])
-            yit[0]  # no-op, moves yit to shared memory
-            wp.tile_assign(y_tile, yit, offset=(i,))
-
-        wp.tile_store(y, y_tile)
-
-    return gmres_solve_least_squares_tiled
-
-
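Tile shapes in Warp kernels are compile-time constants, hence the factory pattern above: close over `K`, compile one kernel per tile size, and memoize with `functools.lru_cache` (`module="unique"` keeps each specialization in its own module). A toy version of the same pattern; the kernel body here is illustrative, not from this diff:

import functools
import warp as wp

@functools.lru_cache(maxsize=None)
def make_tile_sum_kernel(K: int):
    @wp.kernel(module="unique")  # one isolated module per specialization
    def tile_sum_kernel(x: wp.array(dtype=float), out: wp.array(dtype=float)):
        i, lane = wp.tid()
        t = wp.tile_load(x, shape=(K,), offset=i * K)  # K is baked in at compile time
        s = wp.tile_sum(t)
        wp.tile_store(out, s, offset=i)

    return tile_sum_kernel

x = wp.ones(256, dtype=float)
out = wp.zeros(4, dtype=float)
# 4 tiles of 64 lanes each; repeated calls with K=64 reuse the cached kernel
wp.launch(make_tile_sum_kernel(64), dim=(4, 64), block_dim=64, inputs=[x, out])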
-@wp.kernel
-def _gmres_arnoldi_axpy_kernel(
-    V: wp.array2d(dtype=Any),
-    w: wp.array(dtype=Any),
-    Vw: wp.array2d(dtype=Any),
-):
-    tid, lane = wp.tid()
-
-    s = w.dtype(Vw.dtype(0))
-
-    tile_size = wp.block_dim()
-    for k in range(lane, Vw.shape[0], tile_size):
-        s += Vw[k, 0] * V[k, tid]
-
-    wi = wp.tile_load(w, shape=1, offset=tid)
-    wi -= wp.tile_sum(wp.tile(s, preserve_type=True))
-
-    wp.tile_store(w, wi, offset=tid)
-
-
-@wp.kernel
-def _gmres_arnoldi_normalize_kernel(
-    x: wp.array(dtype=Any),
-    y: wp.array(dtype=Any),
-    alpha: wp.array(dtype=Any),
-    alpha_copy: wp.array(dtype=Any),
-):
-    tid = wp.tid()
-    norm = wp.sqrt(alpha[0])
-    y[tid] = wp.where(alpha[0] == alpha.dtype(0.0), x[tid], x[tid] / norm)
-
-    if tid == 0:
-        alpha_copy[0] = norm
-
-
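Together with `TiledDot`, the axpy kernel implements a classical (rather than modified) Gram-Schmidt step: all projection coefficients are computed up front, which is what allows `do_arnoldi_iteration` to batch them "in parallel", followed by a single fused subtraction and a normalization:

h_k = v_k \cdot w \ (k \le j), \qquad
w \leftarrow w - \sum_{k=0}^{j} h_k \, v_k, \qquad
H_{j+1,j} = \lVert w \rVert, \quad
v_{j+1} = w / \lVert w \rVert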
-@wp.kernel
-def _gmres_update_x_kernel(k: int, beta: Any, y: wp.array(dtype=Any), V: wp.array2d(dtype=Any), x: wp.array(dtype=Any)):
-    tid = wp.tid()
-
-    xi = beta * x[tid]
-    for j in range(k):
-        xi += V[j, tid] * y[j]
-
-    x[tid] = xi
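Finally, the solution update accumulates the Krylov correction, matching its two uses in `do_restart_cycle`: beta = 1 adds the correction directly into x, while beta = 0 writes the raw correction into w for right preconditioning (followed by x += M w):

x \leftarrow \beta x + \sum_{j=0}^{k-1} y_j \, v_j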
+    return get_deprecated_api(_linear, "wp.optim", name)