warp-lang 1.0.1__py3-none-manylinux2014_aarch64.whl → 1.1.0__py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +115 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3425 -3354
- warp/codegen.py +2878 -2792
- warp/config.py +40 -36
- warp/constants.py +45 -45
- warp/context.py +5194 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +383 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -279
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
- warp/examples/benchmarks/benchmark_launches.py +295 -295
- warp/examples/browse.py +29 -28
- warp/examples/core/example_dem.py +234 -221
- warp/examples/core/example_fluid.py +293 -267
- warp/examples/core/example_graph_capture.py +144 -129
- warp/examples/core/example_marching_cubes.py +188 -176
- warp/examples/core/example_mesh.py +174 -154
- warp/examples/core/example_mesh_intersect.py +205 -193
- warp/examples/core/example_nvdb.py +176 -169
- warp/examples/core/example_raycast.py +105 -89
- warp/examples/core/example_raymarch.py +199 -178
- warp/examples/core/example_render_opengl.py +185 -141
- warp/examples/core/example_sph.py +405 -389
- warp/examples/core/example_torch.py +222 -181
- warp/examples/core/example_wave.py +263 -249
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +407 -391
- warp/examples/fem/example_convection_diffusion.py +182 -168
- warp/examples/fem/example_convection_diffusion_dg.py +219 -209
- warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
- warp/examples/fem/example_deformed_geometry.py +177 -159
- warp/examples/fem/example_diffusion.py +201 -173
- warp/examples/fem/example_diffusion_3d.py +177 -152
- warp/examples/fem/example_diffusion_mgpu.py +221 -214
- warp/examples/fem/example_mixed_elasticity.py +244 -222
- warp/examples/fem/example_navier_stokes.py +259 -243
- warp/examples/fem/example_stokes.py +220 -192
- warp/examples/fem/example_stokes_transfer.py +265 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +260 -248
- warp/examples/optim/example_cloth_throw.py +222 -210
- warp/examples/optim/example_diffray.py +566 -535
- warp/examples/optim/example_drone.py +864 -835
- warp/examples/optim/example_inverse_kinematics.py +176 -169
- warp/examples/optim/example_inverse_kinematics_torch.py +185 -170
- warp/examples/optim/example_spring_cage.py +239 -234
- warp/examples/optim/example_trajectory.py +223 -201
- warp/examples/optim/example_walker.py +306 -292
- warp/examples/sim/example_cartpole.py +139 -128
- warp/examples/sim/example_cloth.py +196 -184
- warp/examples/sim/example_granular.py +124 -113
- warp/examples/sim/example_granular_collision_sdf.py +197 -185
- warp/examples/sim/example_jacobian_ik.py +236 -213
- warp/examples/sim/example_particle_chain.py +118 -106
- warp/examples/sim/example_quadruped.py +193 -179
- warp/examples/sim/example_rigid_chain.py +197 -189
- warp/examples/sim/example_rigid_contact.py +189 -176
- warp/examples/sim/example_rigid_force.py +127 -126
- warp/examples/sim/example_rigid_gyroscopic.py +109 -97
- warp/examples/sim/example_rigid_soft_contact.py +134 -124
- warp/examples/sim/example_soft_body.py +190 -178
- warp/fabric.py +337 -335
- warp/fem/__init__.py +60 -27
- warp/fem/cache.py +401 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +15 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +744 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +441 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1630 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +294 -292
- warp/fem/space/basis_space.py +488 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +267 -267
- warp/fem/space/grid_3d_function_space.py +305 -306
- warp/fem/space/hexmesh_function_space.py +350 -352
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +368 -369
- warp/fem/space/restriction.py +158 -160
- warp/fem/space/shape/__init__.py +13 -15
- warp/fem/space/shape/cube_shape_function.py +738 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +294 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +223 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1072 -1025
- warp/native/builtin.h +1560 -1560
- warp/native/bvh.cpp +398 -398
- warp/native/bvh.cu +525 -525
- warp/native/bvh.h +429 -429
- warp/native/clang/clang.cpp +495 -464
- warp/native/crt.cpp +31 -31
- warp/native/crt.h +334 -334
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1498 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +293 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/NanoVDB.h +4782 -4782
- warp/native/nanovdb/PNanoVDB.h +2553 -2553
- warp/native/nanovdb/PNanoVDBWrite.h +294 -294
- warp/native/noise.h +850 -850
- warp/native/quat.h +1084 -1084
- warp/native/rand.h +299 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1132 -1132
- warp/native/volume.cpp +297 -297
- warp/native/volume.cu +32 -32
- warp/native/volume.h +538 -538
- warp/native/volume_builder.cu +425 -425
- warp/native/volume_builder.h +19 -19
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2943 -2828
- warp/native/warp.h +313 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3217 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1910 -1991
- warp/sim/integrator_xpbd.py +3294 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1227 -1227
- warp/stubs.py +2109 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +22 -22
- warp/tests/aux_test_grad_customs.py +23 -23
- warp/tests/aux_test_reference.py +11 -11
- warp/tests/aux_test_reference_reference.py +10 -10
- warp/tests/aux_test_square.py +17 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +239 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +157 -157
- warp/tests/test_arithmetic.py +1124 -1124
- warp/tests/test_array.py +2417 -2326
- warp/tests/test_array_reduce.py +150 -150
- warp/tests/test_async.py +668 -656
- warp/tests/test_atomic.py +141 -141
- warp/tests/test_bool.py +204 -149
- warp/tests/test_builtins_resolution.py +1292 -1292
- warp/tests/test_bvh.py +164 -171
- warp/tests/test_closest_point_edge_edge.py +228 -228
- warp/tests/test_codegen.py +566 -553
- warp/tests/test_compile_consts.py +97 -101
- warp/tests/test_conditional.py +246 -246
- warp/tests/test_copy.py +232 -215
- warp/tests/test_ctypes.py +632 -632
- warp/tests/test_dense.py +67 -67
- warp/tests/test_devices.py +91 -98
- warp/tests/test_dlpack.py +530 -529
- warp/tests/test_examples.py +400 -378
- warp/tests/test_fabricarray.py +955 -955
- warp/tests/test_fast_math.py +62 -54
- warp/tests/test_fem.py +1277 -1278
- warp/tests/test_fp16.py +130 -130
- warp/tests/test_func.py +338 -337
- warp/tests/test_generics.py +571 -571
- warp/tests/test_grad.py +746 -640
- warp/tests/test_grad_customs.py +333 -336
- warp/tests/test_hash_grid.py +210 -164
- warp/tests/test_import.py +39 -39
- warp/tests/test_indexedarray.py +1134 -1134
- warp/tests/test_intersect.py +67 -67
- warp/tests/test_jax.py +307 -307
- warp/tests/test_large.py +167 -164
- warp/tests/test_launch.py +354 -354
- warp/tests/test_lerp.py +261 -261
- warp/tests/test_linear_solvers.py +191 -171
- warp/tests/test_lvalue.py +421 -493
- warp/tests/test_marching_cubes.py +65 -65
- warp/tests/test_mat.py +1801 -1827
- warp/tests/test_mat_lite.py +115 -115
- warp/tests/test_mat_scalar_ops.py +2907 -2889
- warp/tests/test_math.py +126 -193
- warp/tests/test_matmul.py +500 -499
- warp/tests/test_matmul_lite.py +410 -410
- warp/tests/test_mempool.py +188 -190
- warp/tests/test_mesh.py +284 -324
- warp/tests/test_mesh_query_aabb.py +228 -241
- warp/tests/test_mesh_query_point.py +692 -702
- warp/tests/test_mesh_query_ray.py +292 -303
- warp/tests/test_mlp.py +276 -276
- warp/tests/test_model.py +110 -110
- warp/tests/test_modules_lite.py +39 -39
- warp/tests/test_multigpu.py +163 -163
- warp/tests/test_noise.py +248 -248
- warp/tests/test_operators.py +250 -250
- warp/tests/test_options.py +123 -125
- warp/tests/test_peer.py +133 -137
- warp/tests/test_pinned.py +78 -78
- warp/tests/test_print.py +54 -54
- warp/tests/test_quat.py +2086 -2086
- warp/tests/test_rand.py +288 -288
- warp/tests/test_reload.py +217 -217
- warp/tests/test_rounding.py +179 -179
- warp/tests/test_runlength_encode.py +190 -190
- warp/tests/test_sim_grad.py +243 -0
- warp/tests/test_sim_kinematics.py +91 -97
- warp/tests/test_smoothstep.py +168 -168
- warp/tests/test_snippet.py +305 -266
- warp/tests/test_sparse.py +468 -460
- warp/tests/test_spatial.py +2148 -2148
- warp/tests/test_streams.py +486 -473
- warp/tests/test_struct.py +710 -675
- warp/tests/test_tape.py +173 -148
- warp/tests/test_torch.py +743 -743
- warp/tests/test_transient_module.py +87 -87
- warp/tests/test_types.py +556 -659
- warp/tests/test_utils.py +490 -499
- warp/tests/test_vec.py +1264 -1268
- warp/tests/test_vec_lite.py +73 -73
- warp/tests/test_vec_scalar_ops.py +2099 -2099
- warp/tests/test_verify_fp.py +94 -94
- warp/tests/test_volume.py +737 -736
- warp/tests/test_volume_write.py +255 -265
- warp/tests/unittest_serial.py +37 -37
- warp/tests/unittest_suites.py +363 -359
- warp/tests/unittest_utils.py +603 -578
- warp/tests/unused_test_misc.py +71 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +566 -561
- warp/torch.py +321 -295
- warp/types.py +4504 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
- warp_lang-1.1.0.dist-info/RECORD +352 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp_lang-1.0.1.dist-info/RECORD +0 -352
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/optim/linear.py
CHANGED
|
@@ -1,939 +1,1104 @@
|
|
|
1
|
-
from
|
|
2
|
-
from
|
|
3
|
-
|
|
4
|
-
import warp as wp
|
|
5
|
-
import warp.sparse as sparse
|
|
6
|
-
from warp.utils import array_inner
|
|
7
|
-
|
|
8
|
-
# No need to auto-generate adjoint code for linear solvers
|
|
9
|
-
wp.set_module_options({"enable_backward": False})
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class LinearOperator:
|
|
13
|
-
"""
|
|
14
|
-
Linear operator to be used as left-hand-side of linear iterative solvers.
|
|
15
|
-
|
|
16
|
-
Args:
|
|
17
|
-
shape: Tuple containing the number of rows and columns of the operator
|
|
18
|
-
dtype: Type of the operator elements
|
|
19
|
-
device: Device on which computations involving the operator should be performed
|
|
20
|
-
matvec: Matrix-vector multiplication routine
|
|
21
|
-
|
|
22
|
-
The matrix-vector multiplication routine should have the following signature:
|
|
23
|
-
|
|
24
|
-
.. code-block:: python
|
|
25
|
-
|
|
26
|
-
def matvec(x: wp.array, y: wp.array, z: wp.array, alpha: Scalar, beta: Scalar):
|
|
27
|
-
'''Performs the operation z = alpha * x + beta * y'''
|
|
28
|
-
...
|
|
29
|
-
|
|
30
|
-
For performance reasons, by default the iterative linear solvers in this module will try to capture the calls
|
|
31
|
-
for one or more iterations in CUDA graphs. If the `matvec` routine of a custom :class:`LinearOperator`
|
|
32
|
-
cannot be graph-captured, the ``use_cuda_graph=False`` parameter should be passed to the solver function.
|
|
33
|
-
|
|
34
|
-
"""
|
|
35
|
-
|
|
36
|
-
def __init__(self, shape: Tuple[int, int], dtype: type, device: wp.context.Device, matvec: Callable):
|
|
37
|
-
self._shape = shape
|
|
38
|
-
self._dtype = dtype
|
|
39
|
-
self._device = device
|
|
40
|
-
self._matvec = matvec
|
|
41
|
-
|
|
42
|
-
@property
|
|
43
|
-
def shape(self) -> Tuple[int, int]:
|
|
44
|
-
return self._shape
|
|
45
|
-
|
|
46
|
-
@property
|
|
47
|
-
def dtype(self) -> type:
|
|
48
|
-
return self._dtype
|
|
49
|
-
|
|
50
|
-
@property
|
|
51
|
-
def device(self) -> wp.context.Device:
|
|
52
|
-
return self._device
|
|
53
|
-
|
|
54
|
-
@property
|
|
55
|
-
def matvec(self) -> Callable:
|
|
56
|
-
return self._matvec
|
|
57
|
-
|
|
58
|
-
@property
|
|
59
|
-
def scalar_type(self):
|
|
60
|
-
return wp.types.type_scalar_type(self.dtype)
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
_Matrix = Union[wp.array, sparse.BsrMatrix, LinearOperator]
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
def aslinearoperator(A: _Matrix) -> LinearOperator:
|
|
67
|
-
"""
|
|
68
|
-
Casts the dense or sparse matrix `A` as a :class:`LinearOperator`
|
|
69
|
-
|
|
70
|
-
`A` must be of one of the following types:
|
|
71
|
-
|
|
72
|
-
- :class:`warp.sparse.BsrMatrix`
|
|
73
|
-
- two-dimensional `warp.array`; then `A` is assumed to be a dense matrix
|
|
74
|
-
- one-dimensional `warp.array`; then `A` is assumed to be a diagonal matrix
|
|
75
|
-
- :class:`warp.sparse.LinearOperator`; no casting necessary
|
|
76
|
-
"""
|
|
77
|
-
|
|
78
|
-
if A is None or isinstance(A, LinearOperator):
|
|
79
|
-
return A
|
|
80
|
-
|
|
81
|
-
def bsr_mv(x, y, z, alpha, beta):
|
|
82
|
-
if z.ptr != y.ptr and beta != 0.0:
|
|
83
|
-
wp.copy(src=y, dest=z)
|
|
84
|
-
sparse.bsr_mv(A, x, z, alpha, beta)
|
|
85
|
-
|
|
86
|
-
def dense_mv(x, y, z, alpha, beta):
|
|
87
|
-
x = x.reshape((x.shape[0], 1))
|
|
88
|
-
y = y.reshape((y.shape[0], 1))
|
|
89
|
-
z = z.reshape((y.shape[0], 1))
|
|
90
|
-
wp.matmul(A, x, y, z, alpha, beta)
|
|
91
|
-
|
|
92
|
-
def diag_mv(x, y, z, alpha, beta):
|
|
93
|
-
scalar_type = wp.types.type_scalar_type(A.dtype)
|
|
94
|
-
alpha = scalar_type(alpha)
|
|
95
|
-
beta = scalar_type(beta)
|
|
96
|
-
wp.launch(_diag_mv_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])
|
|
97
|
-
|
|
98
|
-
def diag_mv_vec(x, y, z, alpha, beta):
|
|
99
|
-
scalar_type = wp.types.type_scalar_type(A.dtype)
|
|
100
|
-
alpha = scalar_type(alpha)
|
|
101
|
-
beta = scalar_type(beta)
|
|
102
|
-
wp.launch(_diag_mv_vec_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])
|
|
103
|
-
|
|
104
|
-
if isinstance(A, wp.array):
|
|
105
|
-
if A.ndim == 2:
|
|
106
|
-
return LinearOperator(A.shape, A.dtype, A.device, matvec=dense_mv)
|
|
107
|
-
if A.ndim == 1:
|
|
108
|
-
if wp.types.type_is_vector(A.dtype):
|
|
109
|
-
return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv_vec)
|
|
110
|
-
return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv)
|
|
111
|
-
if isinstance(A, sparse.BsrMatrix):
|
|
112
|
-
return LinearOperator(A.shape, A.dtype, A.device, matvec=bsr_mv)
|
|
113
|
-
|
|
114
|
-
raise ValueError(f"Unable to create LinearOperator from {A}")
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
def preconditioner(A: _Matrix, ptype: str = "diag") -> LinearOperator:
|
|
118
|
-
"""Constructs and returns a preconditioner for an input matrix.
|
|
119
|
-
|
|
120
|
-
Args:
|
|
121
|
-
A: The matrix for which to build the preconditioner
|
|
122
|
-
ptype: The type of preconditioner. Currently the following values are supported:
|
|
123
|
-
|
|
124
|
-
- ``"diag"``: Diagonal (a.k.a. Jacobi) preconditioner
|
|
125
|
-
- ``"diag_abs"``: Similar to Jacobi, but using the absolute value of diagonal coefficients
|
|
126
|
-
- ``"id"``: Identity (null) preconditioner
|
|
127
|
-
"""
|
|
128
|
-
|
|
129
|
-
if ptype == "id":
|
|
130
|
-
return None
|
|
131
|
-
|
|
132
|
-
if ptype in ("diag", "diag_abs"):
|
|
133
|
-
use_abs = 1 if ptype == "diag_abs" else 0
|
|
134
|
-
if isinstance(A, sparse.BsrMatrix):
|
|
135
|
-
A_diag = sparse.bsr_get_diag(A)
|
|
136
|
-
if wp.types.type_is_matrix(A.dtype):
|
|
137
|
-
inv_diag = wp.empty(
|
|
138
|
-
shape=A.nrow, dtype=wp.vec(length=A.block_shape[0], dtype=A.scalar_type), device=A.device
|
|
139
|
-
)
|
|
140
|
-
wp.launch(
|
|
141
|
-
_extract_inverse_diagonal_blocked,
|
|
142
|
-
dim=inv_diag.shape,
|
|
143
|
-
device=inv_diag.device,
|
|
144
|
-
inputs=[A_diag, inv_diag, use_abs],
|
|
145
|
-
)
|
|
146
|
-
else:
|
|
147
|
-
inv_diag = wp.empty(shape=A.shape[0], dtype=A.scalar_type, device=A.device)
|
|
148
|
-
wp.launch(
|
|
149
|
-
_extract_inverse_diagonal_scalar,
|
|
150
|
-
dim=inv_diag.shape,
|
|
151
|
-
device=inv_diag.device,
|
|
152
|
-
inputs=[A_diag, inv_diag, use_abs],
|
|
153
|
-
)
|
|
154
|
-
elif isinstance(A, wp.array) and A.ndim == 2:
|
|
155
|
-
inv_diag = wp.empty(shape=A.shape[0], dtype=A.dtype, device=A.device)
|
|
156
|
-
wp.launch(
|
|
157
|
-
_extract_inverse_diagonal_dense,
|
|
158
|
-
dim=inv_diag.shape,
|
|
159
|
-
device=inv_diag.device,
|
|
160
|
-
inputs=[A, inv_diag, use_abs],
|
|
161
|
-
)
|
|
162
|
-
else:
|
|
163
|
-
raise ValueError("Unsupported source matrix type for building diagonal preconditioner")
|
|
164
|
-
|
|
165
|
-
return aslinearoperator(inv_diag)
|
|
166
|
-
|
|
167
|
-
raise ValueError(f"Unsupported preconditioner type '{ptype}'")
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
def cg(
|
|
171
|
-
A: _Matrix,
|
|
172
|
-
b: wp.array,
|
|
173
|
-
x: wp.array,
|
|
174
|
-
tol: Optional[float] = None,
|
|
175
|
-
atol: Optional[float] = None,
|
|
176
|
-
maxiter: Optional[float] = 0,
|
|
177
|
-
M: Optional[_Matrix] = None,
|
|
178
|
-
callback: Optional[Callable] = None,
|
|
179
|
-
check_every=10,
|
|
180
|
-
use_cuda_graph=True,
|
|
181
|
-
) -> Tuple[int, float, float]:
|
|
182
|
-
"""Computes an approximate solution to a symmetric, positive-definite linear system
|
|
183
|
-
using the Conjugate Gradient algorithm.
|
|
184
|
-
|
|
185
|
-
Args:
|
|
186
|
-
A: the linear system's left-hand-side
|
|
187
|
-
b: the linear system's right-hand-side
|
|
188
|
-
x: initial guess and solution vector
|
|
189
|
-
tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
|
|
190
|
-
atol: absolute tolerance for the residual
|
|
191
|
-
maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
|
|
192
|
-
Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
|
|
193
|
-
M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
|
|
194
|
-
callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
|
|
195
|
-
check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
|
|
196
|
-
use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
|
|
197
|
-
The linear operator and preconditioner must only perform graph-friendly operations.
|
|
198
|
-
|
|
199
|
-
Returns:
|
|
200
|
-
Tuple (final iteration number, residual norm, absolute tolerance)
|
|
201
|
-
|
|
202
|
-
If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
|
|
203
|
-
"""
|
|
204
|
-
|
|
205
|
-
A = aslinearoperator(A)
|
|
206
|
-
M = aslinearoperator(M)
|
|
207
|
-
|
|
208
|
-
if maxiter == 0:
|
|
209
|
-
maxiter = A.shape[0]
|
|
210
|
-
|
|
211
|
-
r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)
|
|
212
|
-
|
|
213
|
-
device = A.device
|
|
214
|
-
scalar_dtype = wp.types.type_scalar_type(A.dtype)
|
|
215
|
-
|
|
216
|
-
# Notations below follow pseudo-code from https://en.wikipedia.org/wiki/Conjugate_gradient_method
|
|
217
|
-
|
|
218
|
-
# z = M r
|
|
219
|
-
if M is not None:
|
|
220
|
-
z = wp.zeros_like(b)
|
|
221
|
-
M.matvec(r, z, z, alpha=1.0, beta=0.0)
|
|
222
|
-
|
|
223
|
-
# rz = r' z;
|
|
224
|
-
rz_new = wp.empty(n=1, dtype=scalar_dtype, device=device)
|
|
225
|
-
array_inner(r, z, out=rz_new)
|
|
226
|
-
else:
|
|
227
|
-
z = r
|
|
228
|
-
|
|
229
|
-
rz_old = wp.empty(n=1, dtype=scalar_dtype, device=device)
|
|
230
|
-
p_Ap = wp.empty(n=1, dtype=scalar_dtype, device=device)
|
|
231
|
-
Ap = wp.zeros_like(b)
|
|
232
|
-
|
|
233
|
-
p = wp.clone(z)
|
|
234
|
-
|
|
235
|
-
def do_iteration(atol_sq, rr_old, rr_new, rz_old, rz_new):
|
|
236
|
-
# Ap = A * p;
|
|
237
|
-
A.matvec(p, Ap, Ap, alpha=1, beta=0)
|
|
238
|
-
|
|
239
|
-
array_inner(p, Ap, out=p_Ap)
|
|
240
|
-
|
|
241
|
-
wp.launch(
|
|
242
|
-
kernel=_cg_kernel_1,
|
|
243
|
-
dim=x.shape[0],
|
|
244
|
-
device=device,
|
|
245
|
-
inputs=[atol_sq, rr_old, rz_old, p_Ap, x, r, p, Ap],
|
|
246
|
-
)
|
|
247
|
-
array_inner(r, r, out=rr_new)
|
|
248
|
-
|
|
249
|
-
# z = M r
|
|
250
|
-
if M is not None:
|
|
251
|
-
M.matvec(r, z, z, alpha=1.0, beta=0.0)
|
|
252
|
-
# rz = r' z;
|
|
253
|
-
array_inner(r, z, out=rz_new)
|
|
254
|
-
|
|
255
|
-
wp.launch(kernel=_cg_kernel_2, dim=z.shape[0], device=device, inputs=[atol_sq, rr_new, rz_old, rz_new, z, p])
|
|
256
|
-
|
|
257
|
-
# We do iterations by pairs, switching old and new residual norm buffers for each odd-even couple
|
|
258
|
-
# In the non-preconditioned case we reuse the error norm buffer for the new <r,z> computation
|
|
259
|
-
|
|
260
|
-
def do_odd_even_cycle(atol_sq: float):
|
|
261
|
-
# A pair of iterations, so that we're swapping the residual buffers twice
|
|
262
|
-
if M is None:
|
|
263
|
-
do_iteration(atol_sq, r_norm_sq, rz_old, r_norm_sq, rz_old)
|
|
264
|
-
do_iteration(atol_sq, rz_old, r_norm_sq, rz_old, r_norm_sq)
|
|
265
|
-
else:
|
|
266
|
-
do_iteration(atol_sq, r_norm_sq, r_norm_sq, rz_new, rz_old)
|
|
267
|
-
do_iteration(atol_sq, r_norm_sq, r_norm_sq, rz_old, rz_new)
|
|
268
|
-
|
|
269
|
-
return _run_solver_loop(
|
|
270
|
-
do_odd_even_cycle,
|
|
271
|
-
cycle_size=2,
|
|
272
|
-
r_norm_sq=r_norm_sq,
|
|
273
|
-
maxiter=maxiter,
|
|
274
|
-
atol=atol,
|
|
275
|
-
callback=callback,
|
|
276
|
-
check_every=check_every,
|
|
277
|
-
use_cuda_graph=use_cuda_graph,
|
|
278
|
-
device=device,
|
|
279
|
-
)
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
def
|
|
283
|
-
A: _Matrix,
|
|
284
|
-
b: wp.array,
|
|
285
|
-
x: wp.array,
|
|
286
|
-
tol: Optional[float] = None,
|
|
287
|
-
atol: Optional[float] = None,
|
|
288
|
-
maxiter: Optional[float] = 0,
|
|
289
|
-
M: Optional[_Matrix] = None,
|
|
290
|
-
callback: Optional[Callable] = None,
|
|
291
|
-
check_every=10,
|
|
292
|
-
use_cuda_graph=True,
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
Args:
|
|
298
|
-
A: the linear system's left-hand-side
|
|
299
|
-
b: the linear system's right-hand-side
|
|
300
|
-
x: initial guess and solution vector
|
|
301
|
-
tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
|
|
302
|
-
atol: absolute tolerance for the residual
|
|
303
|
-
maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
Returns:
|
|
312
|
-
Tuple (final iteration number, residual norm, absolute tolerance)
|
|
313
|
-
|
|
314
|
-
If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
|
|
315
|
-
"""
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
)
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
@wp.kernel
|
|
807
|
-
def
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
)
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
)
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
)
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
alpha
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
1
|
+
from math import sqrt
|
|
2
|
+
from typing import Any, Callable, Optional, Tuple, Union
|
|
3
|
+
|
|
4
|
+
import warp as wp
|
|
5
|
+
import warp.sparse as sparse
|
|
6
|
+
from warp.utils import array_inner
|
|
7
|
+
|
|
8
|
+
# No need to auto-generate adjoint code for linear solvers:
# disable backward (adjoint) code generation for the kernels defined in this module.
wp.set_module_options({"enable_backward": False})
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LinearOperator:
    """
    Linear operator to be used as left-hand-side of linear iterative solvers.

    Args:
        shape: Tuple containing the number of rows and columns of the operator
        dtype: Type of the operator elements
        device: Device on which computations involving the operator should be performed
        matvec: Matrix-vector multiplication routine

    The matrix-vector multiplication routine should have the following signature:

    .. code-block:: python

        def matvec(x: wp.array, y: wp.array, z: wp.array, alpha: Scalar, beta: Scalar):
            '''Performs the operation z = alpha * A * x + beta * y, where A is this operator'''
            ...

    The output array `z` may be the same array as `y`: the solvers in this module routinely
    call ``matvec(v, z, z, alpha=1, beta=0)`` to compute ``z = A v`` in place.

    For performance reasons, by default the iterative linear solvers in this module will try to capture the calls
    for one or more iterations in CUDA graphs. If the `matvec` routine of a custom :class:`LinearOperator`
    cannot be graph-captured, the ``use_cuda_graph=False`` parameter should be passed to the solver function.

    """

    def __init__(self, shape: Tuple[int, int], dtype: type, device: wp.context.Device, matvec: Callable):
        self._shape = shape
        self._dtype = dtype
        self._device = device
        self._matvec = matvec

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the operator, as passed to the constructor."""
        return self._shape

    @property
    def dtype(self) -> type:
        """Element type of the operator, as passed to the constructor."""
        return self._dtype

    @property
    def device(self) -> wp.context.Device:
        """Device on which computations involving the operator are performed."""
        return self._device

    @property
    def matvec(self) -> Callable:
        """Matrix-vector product routine ``matvec(x, y, z, alpha, beta)`` computing ``z = alpha * A x + beta * y``."""
        return self._matvec

    @property
    def scalar_type(self) -> type:
        """Scalar type underlying :attr:`dtype`, as returned by ``wp.types.type_scalar_type``."""
        return wp.types.type_scalar_type(self.dtype)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Matrix-like types accepted by the solvers and by `aslinearoperator`: a dense (2D) or
# diagonal (1D) warp array, a BSR sparse matrix, or an explicit LinearOperator.
_Matrix = Union[wp.array, sparse.BsrMatrix, LinearOperator]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def aslinearoperator(A: _Matrix) -> Optional[LinearOperator]:
    """
    Casts the dense or sparse matrix `A` as a :class:`LinearOperator`

    `A` must be of one of the following types:

        - :class:`warp.sparse.BsrMatrix`
        - two-dimensional `warp.array`; then `A` is assumed to be a dense matrix
        - one-dimensional `warp.array`; then `A` is assumed to be a diagonal matrix
        - :class:`LinearOperator`; no casting necessary

    ``None`` is also accepted and returned unchanged, so that optional preconditioners
    can be passed through this function.

    Raises:
        ValueError: if `A` is none of the types listed above.
    """

    if A is None or isinstance(A, LinearOperator):
        return A

    def bsr_mv(x, y, z, alpha, beta):
        # bsr_mv accumulates into its output, so seed z with y unless they already alias
        # (or y is ignored because beta == 0)
        if z.ptr != y.ptr and beta != 0.0:
            wp.copy(src=y, dest=z)
        sparse.bsr_mv(A, x, z, alpha, beta)

    def dense_mv(x, y, z, alpha, beta):
        # View the vectors as single-column matrices for wp.matmul
        x = x.reshape((x.shape[0], 1))
        y = y.reshape((y.shape[0], 1))
        z = z.reshape((z.shape[0], 1))
        wp.matmul(A, x, y, z, alpha, beta)

    def diag_mv(x, y, z, alpha, beta):
        # Kernel arguments must match the scalar type of the diagonal
        scalar_type = wp.types.type_scalar_type(A.dtype)
        alpha = scalar_type(alpha)
        beta = scalar_type(beta)
        wp.launch(_diag_mv_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])

    def diag_mv_vec(x, y, z, alpha, beta):
        scalar_type = wp.types.type_scalar_type(A.dtype)
        alpha = scalar_type(alpha)
        beta = scalar_type(beta)
        wp.launch(_diag_mv_vec_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])

    if isinstance(A, wp.array):
        if A.ndim == 2:
            return LinearOperator(A.shape, A.dtype, A.device, matvec=dense_mv)
        if A.ndim == 1:
            if wp.types.type_is_vector(A.dtype):
                return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv_vec)
            return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv)
    if isinstance(A, sparse.BsrMatrix):
        return LinearOperator(A.shape, A.dtype, A.device, matvec=bsr_mv)

    raise ValueError(f"Unable to create LinearOperator from {A}")
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def preconditioner(A: _Matrix, ptype: str = "diag") -> Optional[LinearOperator]:
    """Constructs and returns a preconditioner for an input matrix.

    Args:
        A: The matrix for which to build the preconditioner
        ptype: The type of preconditioner. Currently the following values are supported:

         - ``"diag"``: Diagonal (a.k.a. Jacobi) preconditioner
         - ``"diag_abs"``: Similar to Jacobi, but using the absolute value of diagonal coefficients
         - ``"id"``: Identity (null) preconditioner

    Returns:
        For the diagonal variants, a diagonal :class:`LinearOperator` built from the inverse of the
        diagonal of `A` (computed by the ``_extract_inverse_diagonal_*`` kernels).
        For ``"id"``, ``None``, which the solvers in this module treat as "no preconditioner".

    Raises:
        ValueError: if `ptype` is not supported, or if, for the diagonal variants, `A` is neither a
            :class:`warp.sparse.BsrMatrix` nor a two-dimensional `warp.array`.
    """

    if ptype == "id":
        return None

    if ptype in ("diag", "diag_abs"):
        use_abs = 1 if ptype == "diag_abs" else 0
        if isinstance(A, sparse.BsrMatrix):
            A_diag = sparse.bsr_get_diag(A)
            if wp.types.type_is_matrix(A.dtype):
                # Block matrix: one vector of inverse diagonal coefficients per block row
                inv_diag = wp.empty(
                    shape=A.nrow, dtype=wp.vec(length=A.block_shape[0], dtype=A.scalar_type), device=A.device
                )
                wp.launch(
                    _extract_inverse_diagonal_blocked,
                    dim=inv_diag.shape,
                    device=inv_diag.device,
                    inputs=[A_diag, inv_diag, use_abs],
                )
            else:
                inv_diag = wp.empty(shape=A.shape[0], dtype=A.scalar_type, device=A.device)
                wp.launch(
                    _extract_inverse_diagonal_scalar,
                    dim=inv_diag.shape,
                    device=inv_diag.device,
                    inputs=[A_diag, inv_diag, use_abs],
                )
        elif isinstance(A, wp.array) and A.ndim == 2:
            inv_diag = wp.empty(shape=A.shape[0], dtype=A.dtype, device=A.device)
            wp.launch(
                _extract_inverse_diagonal_dense,
                dim=inv_diag.shape,
                device=inv_diag.device,
                inputs=[A, inv_diag, use_abs],
            )
        else:
            raise ValueError("Unsupported source matrix type for building diagonal preconditioner")

        # A one-dimensional array is wrapped as a diagonal operator
        return aslinearoperator(inv_diag)

    raise ValueError(f"Unsupported preconditioner type '{ptype}'")
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def cg(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every=10,
    use_cuda_graph=True,
) -> Tuple[int, float, float]:
    """Computes an approximate solution to a symmetric, positive-definite linear system
    using the Conjugate Gradient algorithm.

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        maxiter: maximum number of iterations to perform before aborting. Defaults to the system size
            (also when ``0`` or ``None`` is passed).
            Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
        M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
        check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibly terminate the algorithm.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.

    Returns:
        Tuple (final iteration number, residual norm, absolute tolerance)

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    If only one of them is provided, the other takes the same value; if neither is, both default to a precision-dependent value.
    """

    A = aslinearoperator(A)
    M = aslinearoperator(M)

    # `maxiter` is Optional: treat both 0 and None as "use the system size"
    if not maxiter:
        maxiter = A.shape[0]

    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)

    device = A.device
    scalar_dtype = wp.types.type_scalar_type(A.dtype)

    # Notations below follow pseudo-code from https://en.wikipedia.org/wiki/Conjugate_gradient_method

    # z = M r
    if M is not None:
        z = wp.zeros_like(b)
        M.matvec(r, z, z, alpha=1.0, beta=0.0)

        # rz = r' z;
        rz_new = wp.empty(n=1, dtype=scalar_dtype, device=device)
        array_inner(r, z, out=rz_new)
    else:
        z = r

    rz_old = wp.empty(n=1, dtype=scalar_dtype, device=device)
    p_Ap = wp.empty(n=1, dtype=scalar_dtype, device=device)
    Ap = wp.zeros_like(b)

    p = wp.clone(z)

    def do_iteration(atol_sq, rr_old, rr_new, rz_old, rz_new):
        # Ap = A * p;
        A.matvec(p, Ap, Ap, alpha=1, beta=0)

        array_inner(p, Ap, out=p_Ap)

        wp.launch(
            kernel=_cg_kernel_1,
            dim=x.shape[0],
            device=device,
            inputs=[atol_sq, rr_old, rz_old, p_Ap, x, r, p, Ap],
        )
        array_inner(r, r, out=rr_new)

        # z = M r
        if M is not None:
            M.matvec(r, z, z, alpha=1.0, beta=0.0)
            # rz = r' z;
            array_inner(r, z, out=rz_new)

        wp.launch(kernel=_cg_kernel_2, dim=z.shape[0], device=device, inputs=[atol_sq, rr_new, rz_old, rz_new, z, p])

    # We do iterations by pairs, switching old and new residual norm buffers for each odd-even couple
    # In the non-preconditioned case we reuse the error norm buffer for the new <r,z> computation

    def do_odd_even_cycle(atol_sq: float):
        # A pair of iterations, so that we're swapping the residual buffers twice
        if M is None:
            do_iteration(atol_sq, r_norm_sq, rz_old, r_norm_sq, rz_old)
            do_iteration(atol_sq, rz_old, r_norm_sq, rz_old, r_norm_sq)
        else:
            do_iteration(atol_sq, r_norm_sq, r_norm_sq, rz_new, rz_old)
            do_iteration(atol_sq, r_norm_sq, r_norm_sq, rz_old, rz_new)

    return _run_solver_loop(
        do_odd_even_cycle,
        cycle_size=2,
        r_norm_sq=r_norm_sq,
        maxiter=maxiter,
        atol=atol,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
        device=device,
    )
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def cr(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every=10,
    use_cuda_graph=True,
) -> Tuple[int, float, float]:
    """Computes an approximate solution to a symmetric, positive-definite linear system
    using the Conjugate Residual algorithm.

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        maxiter: maximum number of iterations to perform before aborting. Defaults to the system size
            (also when ``0`` or ``None`` is passed).
            Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
        M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
        check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibly terminate the algorithm.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.

    Returns:
        Tuple (final iteration number, residual norm, absolute tolerance)

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    If only one of them is provided, the other takes the same value; if neither is, both default to a precision-dependent value.
    """

    A = aslinearoperator(A)
    M = aslinearoperator(M)

    # `maxiter` is Optional: treat both 0 and None as "use the system size"
    if not maxiter:
        maxiter = A.shape[0]

    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)

    device = A.device
    scalar_dtype = wp.types.type_scalar_type(A.dtype)

    # Notations below follow roughly pseudo-code from https://en.wikipedia.org/wiki/Conjugate_residual_method
    # with z := M r and y := M Ap (M being the preconditioner, i.e. an approximate inverse of A)

    # z = M r
    if M is None:
        z = r
    else:
        z = wp.zeros_like(r)
        M.matvec(r, z, z, alpha=1.0, beta=0.0)

    Az = wp.zeros_like(b)
    A.matvec(z, Az, Az, alpha=1, beta=0)

    p = wp.clone(z)
    Ap = wp.clone(Az)

    if M is None:
        y = Ap
    else:
        y = wp.zeros_like(Ap)

    zAz_old = wp.empty(n=1, dtype=scalar_dtype, device=device)
    zAz_new = wp.empty(n=1, dtype=scalar_dtype, device=device)
    y_Ap = wp.empty(n=1, dtype=scalar_dtype, device=device)

    array_inner(z, Az, out=zAz_new)

    def do_iteration(atol_sq, rr, zAz_old, zAz_new):
        # y = M Ap (y aliases Ap when there is no preconditioner)
        if M is not None:
            M.matvec(Ap, y, y, alpha=1.0, beta=0.0)
        array_inner(Ap, y, out=y_Ap)

        if M is None:
            # In non-preconditioned case, first kernel is same as CG
            wp.launch(
                kernel=_cg_kernel_1,
                dim=x.shape[0],
                device=device,
                inputs=[atol_sq, rr, zAz_old, y_Ap, x, r, p, Ap],
            )
        else:
            # In preconditioned case, we have one more vector to update
            wp.launch(
                kernel=_cr_kernel_1,
                dim=x.shape[0],
                device=device,
                inputs=[atol_sq, rr, zAz_old, y_Ap, x, r, z, p, Ap, y],
            )

        array_inner(r, r, out=rr)

        A.matvec(z, Az, Az, alpha=1, beta=0)
        array_inner(z, Az, out=zAz_new)

        # beta = zAz_new / zAz_old
        wp.launch(
            kernel=_cr_kernel_2, dim=z.shape[0], device=device, inputs=[atol_sq, rr, zAz_old, zAz_new, z, p, Az, Ap]
        )

    # We do iterations by pairs, switching old and new residual norm buffers for each odd-even couple
    def do_odd_even_cycle(atol_sq: float):
        do_iteration(atol_sq, r_norm_sq, zAz_new, zAz_old)
        do_iteration(atol_sq, r_norm_sq, zAz_old, zAz_new)

    return _run_solver_loop(
        do_odd_even_cycle,
        cycle_size=2,
        r_norm_sq=r_norm_sq,
        maxiter=maxiter,
        atol=atol,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
        device=device,
    )
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def bicgstab(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every=10,
    use_cuda_graph=True,
    is_left_preconditioner=False,
) -> Tuple[int, float, float]:
    """Computes an approximate solution to a linear system using the Biconjugate Gradient Stabilized method (BiCGSTAB).

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        maxiter: maximum number of iterations to perform before aborting. Defaults to the system size
            (also when ``0`` or ``None`` is passed).
        M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
        check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibly terminate the algorithm.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.
        is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.

    Returns:
        Tuple (final iteration number, residual norm, absolute tolerance)

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    If only one of them is provided, the other takes the same value; if neither is, both default to a precision-dependent value.
    """
    A = aslinearoperator(A)
    M = aslinearoperator(M)

    # `maxiter` is Optional: treat both 0 and None as "use the system size"
    if not maxiter:
        maxiter = A.shape[0]

    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)

    device = A.device
    scalar_dtype = wp.types.type_scalar_type(A.dtype)

    # Notations below follow pseudo-code from biconjugate https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method

    # r0 is a copy of r, so the initial rho = <r0, r> equals <r, r>
    rho = wp.clone(r_norm_sq, pinned=False)
    r0v = wp.empty(n=1, dtype=scalar_dtype, device=device)
    st = wp.empty(n=1, dtype=scalar_dtype, device=device)
    tt = wp.empty(n=1, dtype=scalar_dtype, device=device)

    # work arrays
    r0 = wp.clone(r)
    v = wp.zeros_like(r)
    t = wp.zeros_like(r)
    p = wp.clone(r)

    if M is not None:
        y = wp.zeros_like(p)
        z = wp.zeros_like(r)
        if is_left_preconditioner:
            Mt = wp.zeros_like(t)
    else:
        # Without a preconditioner, the preconditioned vectors alias their sources
        y = p
        z = r
        Mt = t

    def do_iteration(atol_sq: float):
        # y = M p
        if M is not None:
            M.matvec(p, y, y, alpha=1.0, beta=0.0)

        # v = A * y;
        A.matvec(y, v, v, alpha=1, beta=0)

        # alpha = rho / <r0 . v>
        array_inner(r0, v, out=r0v)

        # x += alpha y
        # r -= alpha v
        wp.launch(
            kernel=_bicgstab_kernel_1,
            dim=x.shape[0],
            device=device,
            inputs=[atol_sq, r_norm_sq, rho, r0v, x, r, y, v],
        )
        array_inner(r, r, out=r_norm_sq)

        # z = M r
        if M is not None:
            M.matvec(r, z, z, alpha=1.0, beta=0.0)

        # t = A z
        A.matvec(z, t, t, alpha=1, beta=0)

        if is_left_preconditioner:
            # Mt = M t
            if M is not None:
                M.matvec(t, Mt, Mt, alpha=1.0, beta=0.0)

            # omega = <Mt, Ms> / <Mt, Mt>
            array_inner(z, Mt, out=st)
            array_inner(Mt, Mt, out=tt)
        else:
            array_inner(r, t, out=st)
            array_inner(t, t, out=tt)

        # x += omega z
        # r -= omega t
        wp.launch(
            kernel=_bicgstab_kernel_2,
            dim=z.shape[0],
            device=device,
            inputs=[atol_sq, r_norm_sq, st, tt, z, t, x, r],
        )
        array_inner(r, r, out=r_norm_sq)

        # rho = <r0, r>
        array_inner(r0, r, out=rho)

        # beta = (rho / rho_old) * alpha / omega = (rho / r0v) / omega
        # p = r + beta (p - omega v)
        wp.launch(
            kernel=_bicgstab_kernel_3,
            dim=z.shape[0],
            device=device,
            inputs=[atol_sq, r_norm_sq, rho, r0v, st, tt, p, r, v],
        )

    return _run_solver_loop(
        do_iteration,
        cycle_size=1,
        r_norm_sq=r_norm_sq,
        maxiter=maxiter,
        atol=atol,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
        device=device,
    )
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
def gmres(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    restart=31,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every=31,
    use_cuda_graph=True,
    is_left_preconditioner=False,
) -> Tuple[int, float, float]:
    """Computes an approximate solution to a linear system using the restarted Generalized Minimum Residual method (GMRES[k]).

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        restart: The restart parameter, i.e, the `k` in `GMRES[k]`. In general, increasing this parameter reduces the number of iterations but increases memory consumption.
        maxiter: maximum number of iterations to perform before aborting. Defaults to the system size
            (also when ``0`` or ``None`` is passed).
            Note that the current implementation always performs `restart` iterations at a time, and as a result may exceed the specified maximum number of iterations by ``restart-1``.
        M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
        check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibly terminate the algorithm.
            Values smaller than `restart` are raised to `restart`.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.
        is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.

    Returns:
        Tuple (final iteration number, residual norm, absolute tolerance)

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    If only one of them is provided, the other takes the same value; if neither is, both default to a precision-dependent value.
    """

    A = aslinearoperator(A)
    M = aslinearoperator(M)

    # `maxiter` is Optional: treat both 0 and None as "use the system size"
    # (None previously crashed in `min(restart, maxiter)` below)
    if not maxiter:
        maxiter = A.shape[0]

    restart = min(restart, maxiter)
    check_every = max(restart, check_every)

    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)

    device = A.device
    scalar_dtype = wp.types.type_scalar_type(A.dtype)

    pivot_tolerance = _get_dtype_epsilon(scalar_dtype) ** 2

    beta_sq = wp.empty_like(r_norm_sq, pinned=False)
    # Hessenberg matrix of the Arnoldi process
    H = wp.empty(shape=(restart + 1, restart), dtype=scalar_dtype, device=device)

    y = wp.empty(shape=restart + 1, dtype=scalar_dtype, device=device)

    w = wp.zeros_like(r)
    # Krylov basis vectors, one per row
    V = wp.zeros(shape=(restart + 1, r.shape[0]), dtype=r.dtype, device=device)

    def array_coeff(H, i, j):
        # Non-owning single-element view of H[i, j]
        return wp.array(
            ptr=H.ptr + i * H.strides[0] + j * H.strides[1],
            dtype=H.dtype,
            shape=(1,),
            device=H.device,
            copy=False,
        )

    def array_row(V, i):
        # Non-owning view of row i of V
        return wp.array(
            ptr=V.ptr + i * V.strides[0],
            dtype=V.dtype,
            shape=V.shape[1],
            device=V.device,
            copy=False,
        )

    def do_arnoldi_iteration(j: int):
        # w = A * v;

        vj = array_row(V, j)

        if M is not None:
            # Row j+1 is not filled yet, use it as scratch space
            tmp = array_row(V, j + 1)

            if is_left_preconditioner:
                A.matvec(vj, tmp, tmp, alpha=1, beta=0)
                M.matvec(tmp, w, w, alpha=1, beta=0)
            else:
                M.matvec(vj, tmp, tmp, alpha=1, beta=0)
                A.matvec(tmp, w, w, alpha=1, beta=0)
        else:
            A.matvec(vj, w, w, alpha=1, beta=0)

        # Orthogonalize w against the previous basis vectors
        for i in range(j + 1):
            vi = array_row(V, i)
            hij = array_coeff(H, i, j)
            array_inner(w, vi, out=hij)

            wp.launch(_gmres_arnoldi_axpy_kernel, dim=w.shape, device=w.device, inputs=[vi, w, hij])

        hjnj = array_coeff(H, j + 1, j)
        array_inner(w, w, out=hjnj)

        vjn = array_row(V, j + 1)
        wp.launch(_gmres_arnoldi_normalize_kernel, dim=w.shape, device=w.device, inputs=[w, vjn, hjnj])

    def do_restart_cycle(atol_sq: float):
        if M is not None and is_left_preconditioner:
            M.matvec(r, w, w, alpha=1, beta=0)
            rh = w
        else:
            rh = r

        array_inner(rh, rh, out=beta_sq)

        v0 = array_row(V, 0)
        # v0 = r / beta
        wp.launch(_gmres_arnoldi_normalize_kernel, dim=r.shape, device=r.device, inputs=[rh, v0, beta_sq])

        for j in range(restart):
            do_arnoldi_iteration(j)

        wp.launch(_gmres_normalize_lower_diagonal, dim=restart, device=device, inputs=[H])
        wp.launch(_gmres_solve_least_squares, dim=1, device=device, inputs=[restart, pivot_tolerance, beta_sq, H, y])

        # update x
        if M is None or is_left_preconditioner:
            wp.launch(_gmres_update_x_kernel, dim=x.shape, device=device, inputs=[restart, scalar_dtype(1.0), y, V, x])
        else:
            # Right preconditioning: x += M (V^T y)
            wp.launch(_gmres_update_x_kernel, dim=x.shape, device=device, inputs=[restart, scalar_dtype(0.0), y, V, w])
            M.matvec(w, x, x, alpha=1, beta=1)

        # update r and residual
        wp.copy(src=b, dest=r)
        A.matvec(x, b, r, alpha=-1.0, beta=1.0)
        array_inner(r, r, out=r_norm_sq)

    return _run_solver_loop(
        do_restart_cycle,
        cycle_size=restart,
        r_norm_sq=r_norm_sq,
        maxiter=maxiter,
        atol=atol,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
        device=device,
    )
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
def _get_dtype_epsilon(dtype):
    """Return the precision threshold used to derive solver tolerances for a scalar type.

    ``wp.float64`` maps to ``1.0e-16``, ``wp.float16`` to ``1.0e-4``, and any other
    type (including ``wp.float32``) to ``1.0e-8``.
    """
    for candidate, eps in ((wp.float64, 1.0e-16), (wp.float16, 1.0e-4)):
        if dtype == candidate:
            return eps
    return 1.0e-8
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
def _get_absolute_tolerance(dtype, tol, atol, lhs_norm):
    """Resolve the absolute residual-norm tolerance used as the solvers' stopping criterion.

    With ``eps = _get_dtype_epsilon(dtype)``: if neither `tol` nor `atol` is given, both default
    to ``eps ** (3/4)``; if only one is given, the other takes the same value. The result is
    ``max(tol * lhs_norm, atol, eps ** (9/4))``.

    Args:
        dtype: scalar type of the linear system
        tol: relative tolerance, or ``None``
        atol: absolute tolerance, or ``None``
        lhs_norm: norm by which `tol` is scaled (callers in this module pass the norm of ``b``)
    """
    eps = _get_dtype_epsilon(dtype)

    if tol is None:
        tol = eps ** (3 / 4) if atol is None else atol
    if atol is None:
        atol = tol

    # Never go below a tiny floor relative to the working precision
    return max(tol * lhs_norm, atol, eps ** (9 / 4))
|
|
724
|
+
|
|
725
|
+
|
|
726
|
+
def _initialize_residual_and_tolerance(A: LinearOperator, b: wp.array, x: wp.array, tol: float, atol: float):
    """Compute the initial residual of ``A x = b`` and the effective absolute tolerance.

    Args:
        A: Linear operator of the system.
        b: Right-hand side.
        x: Initial guess.
        tol: Relative tolerance (scaled by the norm of ``b``), or None.
        atol: Absolute tolerance, or None.

    Returns:
        Tuple ``(r, r_norm_sq, atol)``: the residual ``r = b - A x``, a
        one-element array holding the squared norm of ``r``, and the absolute
        tolerance computed by ``_get_absolute_tolerance``.
    """
    device = A.device
    scalar_dtype = wp.types.type_scalar_type(A.dtype)

    # One-element buffer that receives squared norms (first of b, then of the residual)
    r_norm_sq = wp.empty(n=1, dtype=scalar_dtype, device=device, pinned=device.is_cuda)

    # The norm of the right-hand side scales the relative tolerance
    array_inner(b, b, out=r_norm_sq)
    b_norm = sqrt(r_norm_sq.numpy()[0])
    atol = _get_absolute_tolerance(scalar_dtype, tol, atol, b_norm)

    # Residual r = b - A x, then its squared norm, reusing the same buffer
    r = wp.empty_like(b)
    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
    array_inner(r, r, out=r_norm_sq)

    return r, r_norm_sq, atol
|
|
744
|
+
|
|
745
|
+
|
|
746
|
+
def _run_solver_loop(
    do_cycle: Callable[[float], None],
    cycle_size: int,
    r_norm_sq: wp.array,
    maxiter: int,
    atol: float,
    callback: Callable,
    check_every: int,
    use_cuda_graph: bool,
    device,
):
    """Drive an iterative solver until convergence or until ``maxiter`` is reached.

    Each call to ``do_cycle(atol_sq)`` advances the solver by ``cycle_size``
    iterations and is expected to refresh ``r_norm_sq[0]`` with the current
    squared residual norm. That value is read back to the host with ``numpy()``
    before the first cycle, periodically (see ``check_every``), and once at
    the end.

    Args:
        do_cycle: Performs one solver cycle; receives the squared absolute tolerance.
        cycle_size: Number of iterations performed by one ``do_cycle`` call.
        r_norm_sq: One-element array holding the squared residual norm.
        maxiter: Iteration budget. The loop stops as soon as the iteration count
            reaches or exceeds it, but at least one cycle runs unless the initial
            residual already satisfies the tolerance.
        atol: Absolute tolerance on the residual norm.
        callback: Optional ``callback(iteration, residual_norm, atol)`` invoked
            each time the residual is read back.
        check_every: Approximate number of iterations between host-side residual
            checks. Values ``<= 0`` disable intermediate checks, so the loop runs
            until ``maxiter``. The residual is still read before the first and
            after the last cycle.
        use_cuda_graph: On a CUDA device, capture one cycle into a CUDA graph
            (after a first, uncaptured cycle) and replay it for later cycles.
        device: Device the solver runs on.

    Returns:
        Tuple ``(iterations, residual_norm, atol)``.
    """
    atol_sq = atol * atol

    # A non-positive interval would make the modulo below divide by zero;
    # interpret it as "no intermediate host-side residual checks".
    check_residual = check_every > 0

    cur_iter = 0

    err_sq = r_norm_sq.numpy()[0]
    err = sqrt(err_sq)
    if callback is not None:
        callback(cur_iter, err, atol)

    if err_sq <= atol_sq:
        return cur_iter, err, atol

    graph = None

    while True:
        # Do not do graph capture at first iteration -- modules may not be loaded yet
        if device.is_cuda and use_cuda_graph and cur_iter > 0:
            if graph is None:
                wp.capture_begin(device, force_module_load=False)
                try:
                    do_cycle(atol_sq)
                finally:
                    graph = wp.capture_end(device)
            wp.capture_launch(graph)
        else:
            do_cycle(atol_sq)

        cur_iter += cycle_size

        if cur_iter >= maxiter:
            break

        # True when the cycle just completed reached or crossed a multiple of check_every
        if check_residual and (cur_iter % check_every) < cycle_size:
            err_sq = r_norm_sq.numpy()[0]

            if err_sq <= atol_sq:
                break

            if callback is not None:
                callback(cur_iter, sqrt(err_sq), atol)

    err_sq = r_norm_sq.numpy()[0]
    err = sqrt(err_sq)
    if callback is not None:
        callback(cur_iter, err, atol)

    return cur_iter, err, atol
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
@wp.kernel
def _diag_mv_kernel(
    A: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    alpha: Any,
    beta: Any,
):
    # Element-wise z = beta * y + alpha * (A * x); one thread per entry.
    row = wp.tid()
    product = A[row] * x[row]
    z[row] = beta * y[row] + alpha * product
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
@wp.kernel
def _diag_mv_vec_kernel(
    A: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    alpha: Any,
    beta: Any,
):
    # Same as _diag_mv_kernel, but A[i] and x[i] are multiplied component-wise
    # (wp.cw_mul): z = beta * y + alpha * cw_mul(A, x); one thread per entry.
    row = wp.tid()
    product = wp.cw_mul(A[row], x[row])
    z[row] = beta * y[row] + alpha * product
|
|
830
|
+
|
|
831
|
+
|
|
832
|
+
@wp.func
def _inverse_diag_coefficient(coeff: Any, use_abs: wp.bool):
    # Reciprocal of a diagonal coefficient: 1 / coeff, or 1 / |coeff| when use_abs is true.
    # A zero coefficient yields one instead of dividing by zero.
    one = type(coeff)(1.0)
    denom = wp.select(use_abs, coeff, wp.abs(coeff))
    return wp.select(coeff == type(coeff)(0.0), one / denom, one)
|
|
837
|
+
|
|
838
|
+
|
|
839
|
+
@wp.kernel
def _extract_inverse_diagonal_blocked(
    diag_block: wp.array(dtype=Any),
    inv_diag: wp.array(dtype=Any),
    use_abs: int,
):
    # One thread per diagonal block: take the diagonal of the block (wp.get_diag)
    # and store the component-wise inverse computed by _inverse_diag_coefficient.
    # Zero components map to one; a non-zero `use_abs` inverts magnitudes instead of signed values.
    i = wp.tid()

    d = wp.get_diag(diag_block[i])
    # Invert each component of the extracted diagonal vector in place
    for k in range(d.length):
        d[k] = _inverse_diag_coefficient(d[k], use_abs != 0)

    inv_diag[i] = d
|
|
852
|
+
|
|
853
|
+
|
|
854
|
+
@wp.kernel
def _extract_inverse_diagonal_scalar(
    diag_array: wp.array(dtype=Any),
    inv_diag: wp.array(dtype=Any),
    use_abs: int,
):
    # inv_diag[k] = 1 / diag_array[k], or 1 / |diag_array[k]| when use_abs is non-zero;
    # zero entries map to one.
    entry = wp.tid()
    take_abs = use_abs != 0
    inv_diag[entry] = _inverse_diag_coefficient(diag_array[entry], take_abs)
|
|
862
|
+
|
|
863
|
+
|
|
864
|
+
@wp.kernel
def _extract_inverse_diagonal_dense(
    dense_matrix: wp.array2d(dtype=Any),
    inv_diag: wp.array(dtype=Any),
    use_abs: int,
):
    # inv_diag[k] = 1 / dense_matrix[k, k], or its magnitude's reciprocal when
    # use_abs is non-zero; zero diagonal entries map to one.
    entry = wp.tid()
    take_abs = use_abs != 0
    inv_diag[entry] = _inverse_diag_coefficient(dense_matrix[entry, entry], take_abs)
|
|
872
|
+
|
|
873
|
+
|
|
874
|
+
@wp.kernel
def _cg_kernel_1(
    tol: Any,
    resid: wp.array(dtype=Any),
    rz_old: wp.array(dtype=Any),
    p_Ap: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    Ap: wp.array(dtype=Any),
):
    # CG update of the iterate and residual:
    #   x += alpha * p,  r -= alpha * Ap,  alpha = rz_old[0] / p_Ap[0]
    # alpha is forced to zero once resid[0] <= tol, so the update becomes a no-op.
    idx = wp.tid()

    not_converged = resid[0] > tol
    step = wp.select(not_converged, rz_old.dtype(0.0), rz_old[0] / p_Ap[0])

    x[idx] = x[idx] + step * p[idx]
    r[idx] = r[idx] - step * Ap[idx]
|
|
891
|
+
|
|
892
|
+
|
|
893
|
+
@wp.kernel
def _cg_kernel_2(
    tol: Any,
    resid: wp.array(dtype=Any),
    rz_old: wp.array(dtype=Any),
    rz_new: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
):
    # CG search-direction update: p = z + beta * p, beta = rz_new[0] / rz_old[0]
    # (beta is zero once resid[0] <= tol).
    idx = wp.tid()

    not_converged = resid[0] > tol
    weight = wp.select(not_converged, rz_old.dtype(0.0), rz_new[0] / rz_old[0])

    p[idx] = z[idx] + weight * p[idx]
|
|
908
|
+
|
|
909
|
+
|
|
910
|
+
@wp.kernel
def _cr_kernel_1(
    tol: Any,
    resid: wp.array(dtype=Any),
    zAz_old: wp.array(dtype=Any),
    y_Ap: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    Ap: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
):
    # Conjugate-residual update of x, r and z with alpha = zAz_old[0] / y_Ap[0].
    # alpha is zero once resid[0] <= tol, or if y_Ap[0] is not strictly positive.
    idx = wp.tid()

    can_step = resid[0] > tol and y_Ap[0] > 0.0
    step = wp.select(can_step, zAz_old.dtype(0.0), zAz_old[0] / y_Ap[0])

    x[idx] = x[idx] + step * p[idx]
    r[idx] = r[idx] - step * Ap[idx]
    z[idx] = z[idx] - step * y[idx]
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
@wp.kernel
def _cr_kernel_2(
    tol: Any,
    resid: wp.array(dtype=Any),
    zAz_old: wp.array(dtype=Any),
    zAz_new: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    Az: wp.array(dtype=Any),
    Ap: wp.array(dtype=Any),
):
    # Conjugate-residual direction update, with beta = zAz_new[0] / zAz_old[0]:
    #   p = z + beta * p,  Ap = Az + beta * Ap
    # beta is zero once resid[0] <= tol, or if zAz_old[0] is not strictly positive.
    idx = wp.tid()

    can_update = resid[0] > tol and zAz_old[0] > 0.0
    weight = wp.select(can_update, zAz_old.dtype(0.0), zAz_new[0] / zAz_old[0])

    p[idx] = z[idx] + weight * p[idx]
    Ap[idx] = Az[idx] + weight * Ap[idx]
|
|
950
|
+
|
|
951
|
+
|
|
952
|
+
@wp.kernel
def _bicgstab_kernel_1(
    tol: Any,
    resid: wp.array(dtype=Any),
    rho_old: wp.array(dtype=Any),
    r0v: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    v: wp.array(dtype=Any),
):
    # First BiCGSTAB half-step: x += alpha * y, r -= alpha * v,
    # alpha = rho_old[0] / r0v[0] (zero once resid[0] <= tol).
    idx = wp.tid()

    not_converged = resid[0] > tol
    step = wp.select(not_converged, rho_old.dtype(0.0), rho_old[0] / r0v[0])

    x[idx] += step * y[idx]
    r[idx] -= step * v[idx]
|
|
969
|
+
|
|
970
|
+
|
|
971
|
+
@wp.kernel
def _bicgstab_kernel_2(
    tol: Any,
    resid: wp.array(dtype=Any),
    st: wp.array(dtype=Any),
    tt: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    t: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
):
    # Second BiCGSTAB half-step: x += omega * z, r -= omega * t,
    # omega = st[0] / tt[0] (zero once resid[0] <= tol).
    idx = wp.tid()

    not_converged = resid[0] > tol
    relax = wp.select(not_converged, st.dtype(0.0), st[0] / tt[0])

    x[idx] += relax * z[idx]
    r[idx] -= relax * t[idx]
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
@wp.kernel
def _bicgstab_kernel_3(
    tol: Any,
    resid: wp.array(dtype=Any),
    rho_new: wp.array(dtype=Any),
    r0v: wp.array(dtype=Any),
    st: wp.array(dtype=Any),
    tt: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    v: wp.array(dtype=Any),
):
    # BiCGSTAB direction update p = r + beta * (p - omega * v), expanded as
    #   p = r + beta * p - (beta * omega) * v
    # with beta = rho_new * tt / (r0v * st) and beta * omega = rho_new / r0v.
    # Both coefficients are zero once resid[0] <= tol.
    idx = wp.tid()

    not_converged = resid[0] > tol
    zero = st.dtype(0.0)
    weight = wp.select(not_converged, zero, rho_new[0] * tt[0] / (r0v[0] * st[0]))
    weight_omega = wp.select(not_converged, zero, rho_new[0] / r0v[0])

    p[idx] = r[idx] + weight * p[idx] - weight_omega * v[idx]
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
@wp.kernel
def _gmres_normalize_lower_diagonal(H: wp.array2d(dtype=Any)):
    # Replace each sub-diagonal entry H[col + 1, col] of the Hessenberg matrix by its
    # square root (the entries presumably hold squared norms at this point -- TODO confirm).
    col = wp.tid()
    sub_diag = H[col + 1, col]
    H[col + 1, col] = wp.sqrt(sub_diag)
|
|
1015
|
+
|
|
1016
|
+
|
|
1017
|
+
@wp.kernel
def _gmres_solve_least_squares(
    k: int, pivot_tolerance: float, beta_sq: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
):
    # Solve H y = (beta, 0, ..., 0)
    # H Hessenberg matrix of shape (k+1, k)
    # The (k+1) x k system is solved in the least-squares sense: Givens rotations
    # triangularize H (rotating H in place through the row views Ha/Hb), then a
    # back-substitution yields y[0:k]. beta = sqrt(beta_sq[0]).
    # Written for a single thread (the caller in this file launches it with dim=1).

    # Keeping H in global mem; warp kernels are launched with fixed block size,
    # so would not fit in registers

    # TODO: switch to native code with thread synchronization

    rhs = wp.sqrt(beta_sq[0])

    # Apply 2x2 rotations to H so as to remove lower diagonal,
    # and apply similar rotations to right-hand-side
    # max_k is a mutable local copy of k: the number of columns actually solved
    max_k = int(k)
    for i in range(k):
        Ha = H[i]
        Hb = H[i + 1]

        # Givens rotation [[c s], [-s c]]
        a = Ha[i]
        b = Hb[i]
        abn_sq = a * a + b * b

        if abn_sq < type(abn_sq)(pivot_tolerance):
            # Arnoldi iteration finished early
            # (degenerate pivot: only the first i columns are used)
            max_k = i
            break

        abn = wp.sqrt(abn_sq)
        c = a / abn
        s = b / abn

        # Rotate H
        for j in range(i, k):
            a = Ha[j]
            b = Hb[j]
            Ha[j] = c * a + s * b
            Hb[j] = c * b - s * a

        # Rotate rhs
        # (the second component -s * rhs carries over to the next row)
        y[i] = c * rhs
        rhs = -s * rhs

    # Coefficients of unused columns are zero
    for i in range(max_k, k):
        y[i] = y.dtype(0.0)

    # Triangular back-solve for y
    # (upper-triangular rows 0..max_k-1, solved from the last row upwards)
    for ii in range(max_k, 0, -1):
        i = ii - 1
        Hi = H[i]
        yi = y[i]
        for j in range(ii, max_k):
            yi -= Hi[j] * y[j]
        y[i] = yi / Hi[i]
|
|
1074
|
+
|
|
1075
|
+
|
|
1076
|
+
@wp.kernel
def _gmres_arnoldi_axpy_kernel(
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    alpha: wp.array(dtype=Any),
):
    # In-place y -= x * alpha[0]; the scale factor is read from device memory.
    idx = wp.tid()
    scale = alpha[0]
    y[idx] -= x[idx] * scale
|
|
1084
|
+
|
|
1085
|
+
|
|
1086
|
+
@wp.kernel
def _gmres_arnoldi_normalize_kernel(
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    alpha: wp.array(dtype=Any),
):
    # y = x / sqrt(alpha[0]); when alpha[0] is exactly zero, y is a plain copy of x.
    # alpha[0] is presumably a squared norm -- TODO confirm against callers.
    idx = wp.tid()
    norm_sq = alpha[0]
    src = x[idx]
    scaled = src / wp.sqrt(norm_sq)
    y[idx] = wp.select(norm_sq == alpha.dtype(0.0), scaled, src)
|
|
1094
|
+
|
|
1095
|
+
|
|
1096
|
+
@wp.kernel
def _gmres_update_x_kernel(k: int, beta: Any, y: wp.array(dtype=Any), V: wp.array2d(dtype=Any), x: wp.array(dtype=Any)):
    # x = beta * x + sum_{j < k} y[j] * V[j, :]
    # (rows of V are presumably the Krylov basis vectors -- TODO confirm)
    row = wp.tid()

    acc = beta * x[row]
    for basis in range(k):
        acc += V[basis, row] * y[basis]

    x[row] = acc
|