warp-lang 1.0.1__py3-none-manylinux2014_x86_64.whl → 1.1.0__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +115 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3425 -3354
- warp/codegen.py +2878 -2792
- warp/config.py +40 -36
- warp/constants.py +45 -45
- warp/context.py +5194 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +383 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -279
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
- warp/examples/benchmarks/benchmark_launches.py +295 -295
- warp/examples/browse.py +29 -28
- warp/examples/core/example_dem.py +234 -221
- warp/examples/core/example_fluid.py +293 -267
- warp/examples/core/example_graph_capture.py +144 -129
- warp/examples/core/example_marching_cubes.py +188 -176
- warp/examples/core/example_mesh.py +174 -154
- warp/examples/core/example_mesh_intersect.py +205 -193
- warp/examples/core/example_nvdb.py +176 -169
- warp/examples/core/example_raycast.py +105 -89
- warp/examples/core/example_raymarch.py +199 -178
- warp/examples/core/example_render_opengl.py +185 -141
- warp/examples/core/example_sph.py +405 -389
- warp/examples/core/example_torch.py +222 -181
- warp/examples/core/example_wave.py +263 -249
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +407 -391
- warp/examples/fem/example_convection_diffusion.py +182 -168
- warp/examples/fem/example_convection_diffusion_dg.py +219 -209
- warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
- warp/examples/fem/example_deformed_geometry.py +177 -159
- warp/examples/fem/example_diffusion.py +201 -173
- warp/examples/fem/example_diffusion_3d.py +177 -152
- warp/examples/fem/example_diffusion_mgpu.py +221 -214
- warp/examples/fem/example_mixed_elasticity.py +244 -222
- warp/examples/fem/example_navier_stokes.py +259 -243
- warp/examples/fem/example_stokes.py +220 -192
- warp/examples/fem/example_stokes_transfer.py +265 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +260 -248
- warp/examples/optim/example_cloth_throw.py +222 -210
- warp/examples/optim/example_diffray.py +566 -535
- warp/examples/optim/example_drone.py +864 -835
- warp/examples/optim/example_inverse_kinematics.py +176 -169
- warp/examples/optim/example_inverse_kinematics_torch.py +185 -170
- warp/examples/optim/example_spring_cage.py +239 -234
- warp/examples/optim/example_trajectory.py +223 -201
- warp/examples/optim/example_walker.py +306 -292
- warp/examples/sim/example_cartpole.py +139 -128
- warp/examples/sim/example_cloth.py +196 -184
- warp/examples/sim/example_granular.py +124 -113
- warp/examples/sim/example_granular_collision_sdf.py +197 -185
- warp/examples/sim/example_jacobian_ik.py +236 -213
- warp/examples/sim/example_particle_chain.py +118 -106
- warp/examples/sim/example_quadruped.py +193 -179
- warp/examples/sim/example_rigid_chain.py +197 -189
- warp/examples/sim/example_rigid_contact.py +189 -176
- warp/examples/sim/example_rigid_force.py +127 -126
- warp/examples/sim/example_rigid_gyroscopic.py +109 -97
- warp/examples/sim/example_rigid_soft_contact.py +134 -124
- warp/examples/sim/example_soft_body.py +190 -178
- warp/fabric.py +337 -335
- warp/fem/__init__.py +60 -27
- warp/fem/cache.py +401 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +15 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +744 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +441 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1630 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +294 -292
- warp/fem/space/basis_space.py +488 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +267 -267
- warp/fem/space/grid_3d_function_space.py +305 -306
- warp/fem/space/hexmesh_function_space.py +350 -352
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +368 -369
- warp/fem/space/restriction.py +158 -160
- warp/fem/space/shape/__init__.py +13 -15
- warp/fem/space/shape/cube_shape_function.py +738 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +294 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +223 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1072 -1025
- warp/native/builtin.h +1560 -1560
- warp/native/bvh.cpp +398 -398
- warp/native/bvh.cu +525 -525
- warp/native/bvh.h +429 -429
- warp/native/clang/clang.cpp +495 -464
- warp/native/crt.cpp +31 -31
- warp/native/crt.h +334 -334
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1498 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +293 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/NanoVDB.h +4782 -4782
- warp/native/nanovdb/PNanoVDB.h +2553 -2553
- warp/native/nanovdb/PNanoVDBWrite.h +294 -294
- warp/native/noise.h +850 -850
- warp/native/quat.h +1084 -1084
- warp/native/rand.h +299 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1132 -1132
- warp/native/volume.cpp +297 -297
- warp/native/volume.cu +32 -32
- warp/native/volume.h +538 -538
- warp/native/volume_builder.cu +425 -425
- warp/native/volume_builder.h +19 -19
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2943 -2828
- warp/native/warp.h +313 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3217 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1910 -1991
- warp/sim/integrator_xpbd.py +3294 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1227 -1227
- warp/stubs.py +2109 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +22 -22
- warp/tests/aux_test_grad_customs.py +23 -23
- warp/tests/aux_test_reference.py +11 -11
- warp/tests/aux_test_reference_reference.py +10 -10
- warp/tests/aux_test_square.py +17 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +239 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +157 -157
- warp/tests/test_arithmetic.py +1124 -1124
- warp/tests/test_array.py +2417 -2326
- warp/tests/test_array_reduce.py +150 -150
- warp/tests/test_async.py +668 -656
- warp/tests/test_atomic.py +141 -141
- warp/tests/test_bool.py +204 -149
- warp/tests/test_builtins_resolution.py +1292 -1292
- warp/tests/test_bvh.py +164 -171
- warp/tests/test_closest_point_edge_edge.py +228 -228
- warp/tests/test_codegen.py +566 -553
- warp/tests/test_compile_consts.py +97 -101
- warp/tests/test_conditional.py +246 -246
- warp/tests/test_copy.py +232 -215
- warp/tests/test_ctypes.py +632 -632
- warp/tests/test_dense.py +67 -67
- warp/tests/test_devices.py +91 -98
- warp/tests/test_dlpack.py +530 -529
- warp/tests/test_examples.py +400 -378
- warp/tests/test_fabricarray.py +955 -955
- warp/tests/test_fast_math.py +62 -54
- warp/tests/test_fem.py +1277 -1278
- warp/tests/test_fp16.py +130 -130
- warp/tests/test_func.py +338 -337
- warp/tests/test_generics.py +571 -571
- warp/tests/test_grad.py +746 -640
- warp/tests/test_grad_customs.py +333 -336
- warp/tests/test_hash_grid.py +210 -164
- warp/tests/test_import.py +39 -39
- warp/tests/test_indexedarray.py +1134 -1134
- warp/tests/test_intersect.py +67 -67
- warp/tests/test_jax.py +307 -307
- warp/tests/test_large.py +167 -164
- warp/tests/test_launch.py +354 -354
- warp/tests/test_lerp.py +261 -261
- warp/tests/test_linear_solvers.py +191 -171
- warp/tests/test_lvalue.py +421 -493
- warp/tests/test_marching_cubes.py +65 -65
- warp/tests/test_mat.py +1801 -1827
- warp/tests/test_mat_lite.py +115 -115
- warp/tests/test_mat_scalar_ops.py +2907 -2889
- warp/tests/test_math.py +126 -193
- warp/tests/test_matmul.py +500 -499
- warp/tests/test_matmul_lite.py +410 -410
- warp/tests/test_mempool.py +188 -190
- warp/tests/test_mesh.py +284 -324
- warp/tests/test_mesh_query_aabb.py +228 -241
- warp/tests/test_mesh_query_point.py +692 -702
- warp/tests/test_mesh_query_ray.py +292 -303
- warp/tests/test_mlp.py +276 -276
- warp/tests/test_model.py +110 -110
- warp/tests/test_modules_lite.py +39 -39
- warp/tests/test_multigpu.py +163 -163
- warp/tests/test_noise.py +248 -248
- warp/tests/test_operators.py +250 -250
- warp/tests/test_options.py +123 -125
- warp/tests/test_peer.py +133 -137
- warp/tests/test_pinned.py +78 -78
- warp/tests/test_print.py +54 -54
- warp/tests/test_quat.py +2086 -2086
- warp/tests/test_rand.py +288 -288
- warp/tests/test_reload.py +217 -217
- warp/tests/test_rounding.py +179 -179
- warp/tests/test_runlength_encode.py +190 -190
- warp/tests/test_sim_grad.py +243 -0
- warp/tests/test_sim_kinematics.py +91 -97
- warp/tests/test_smoothstep.py +168 -168
- warp/tests/test_snippet.py +305 -266
- warp/tests/test_sparse.py +468 -460
- warp/tests/test_spatial.py +2148 -2148
- warp/tests/test_streams.py +486 -473
- warp/tests/test_struct.py +710 -675
- warp/tests/test_tape.py +173 -148
- warp/tests/test_torch.py +743 -743
- warp/tests/test_transient_module.py +87 -87
- warp/tests/test_types.py +556 -659
- warp/tests/test_utils.py +490 -499
- warp/tests/test_vec.py +1264 -1268
- warp/tests/test_vec_lite.py +73 -73
- warp/tests/test_vec_scalar_ops.py +2099 -2099
- warp/tests/test_verify_fp.py +94 -94
- warp/tests/test_volume.py +737 -736
- warp/tests/test_volume_write.py +255 -265
- warp/tests/unittest_serial.py +37 -37
- warp/tests/unittest_suites.py +363 -359
- warp/tests/unittest_utils.py +603 -578
- warp/tests/unused_test_misc.py +71 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +566 -561
- warp/torch.py +321 -295
- warp/types.py +4504 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
- warp_lang-1.1.0.dist-info/RECORD +352 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp_lang-1.0.1.dist-info/RECORD +0 -352
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/examples/fem/bsr_utils.py
CHANGED
|
@@ -1,380 +1,378 @@
|
|
|
1
|
-
from typing import
|
|
2
|
-
|
|
3
|
-
import warp as wp
|
|
4
|
-
import warp.types
|
|
5
|
-
|
|
6
|
-
from warp.sparse import BsrMatrix,
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
from scipy.sparse import csr_matrix as csr_array
|
|
16
|
-
|
|
17
|
-
if matrix.block_shape == (1, 1):
|
|
18
|
-
return csr_array(
|
|
19
|
-
(
|
|
20
|
-
matrix.values.numpy().flatten()[: matrix.nnz],
|
|
21
|
-
matrix.columns.numpy()[: matrix.nnz],
|
|
22
|
-
matrix.offsets.numpy(),
|
|
23
|
-
),
|
|
24
|
-
shape=matrix.shape,
|
|
25
|
-
)
|
|
26
|
-
|
|
27
|
-
return bsr_array(
|
|
28
|
-
(
|
|
29
|
-
matrix.values.numpy().reshape((matrix.values.shape[0], *matrix.block_shape))[: matrix.nnz],
|
|
30
|
-
matrix.columns.numpy()[: matrix.nnz],
|
|
31
|
-
matrix.offsets.numpy(),
|
|
32
|
-
),
|
|
33
|
-
shape=matrix.shape,
|
|
34
|
-
)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def scipy_to_bsr(
|
|
38
|
-
sp: Union["scipy.sparse.bsr_array", "scipy.sparse.csr_array"],
|
|
39
|
-
device=None,
|
|
40
|
-
dtype=None,
|
|
41
|
-
) -> BsrMatrix:
|
|
42
|
-
try:
|
|
43
|
-
from scipy.sparse import csr_array
|
|
44
|
-
except ImportError:
|
|
45
|
-
# WAR for older scipy
|
|
46
|
-
from scipy.sparse import csr_matrix as csr_array
|
|
47
|
-
|
|
48
|
-
if dtype is None:
|
|
49
|
-
dtype = warp.types.np_dtype_to_warp_type[sp.dtype]
|
|
50
|
-
|
|
51
|
-
sp.sort_indices()
|
|
52
|
-
|
|
53
|
-
if isinstance(sp, csr_array):
|
|
54
|
-
matrix = bsr_zeros(sp.shape[0], sp.shape[1], dtype, device=device)
|
|
55
|
-
else:
|
|
56
|
-
block_shape = sp.blocksize
|
|
57
|
-
block_type = wp.types.matrix(shape=block_shape, dtype=dtype)
|
|
58
|
-
matrix = bsr_zeros(
|
|
59
|
-
sp.shape[0] // block_shape[0],
|
|
60
|
-
sp.shape[1] // block_shape[1],
|
|
61
|
-
block_type,
|
|
62
|
-
device=device,
|
|
63
|
-
)
|
|
64
|
-
|
|
65
|
-
matrix.nnz = sp.nnz
|
|
66
|
-
matrix.values = wp.array(sp.data.flatten(), dtype=matrix.values.dtype, device=device)
|
|
67
|
-
matrix.columns = wp.array(sp.indices, dtype=matrix.columns.dtype, device=device)
|
|
68
|
-
matrix.offsets = wp.array(sp.indptr, dtype=matrix.offsets.dtype, device=device)
|
|
69
|
-
|
|
70
|
-
return matrix
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
def get_linear_solver_func(method_name: str):
|
|
74
|
-
from warp.optim.linear import cg,
|
|
75
|
-
|
|
76
|
-
if method_name == "bicgstab":
|
|
77
|
-
return bicgstab
|
|
78
|
-
if method_name == "gmres":
|
|
79
|
-
return gmres
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
if mv_routine is None:
|
|
117
|
-
M = preconditioner(A, "diag") if use_diag_precond else None
|
|
118
|
-
else:
|
|
119
|
-
A = LinearOperator(A.shape, A.dtype, A.device, matvec=mv_routine)
|
|
120
|
-
M = None
|
|
121
|
-
|
|
122
|
-
func = get_linear_solver_func(method_name=method)
|
|
123
|
-
|
|
124
|
-
def print_callback(i, err, tol):
|
|
125
|
-
print(f"{func.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")
|
|
126
|
-
|
|
127
|
-
callback = None if quiet else print_callback
|
|
128
|
-
|
|
129
|
-
end_iter, err, atol = func(
|
|
130
|
-
A=A,
|
|
131
|
-
b=b,
|
|
132
|
-
x=x,
|
|
133
|
-
maxiter=max_iters,
|
|
134
|
-
tol=tol,
|
|
135
|
-
check_every=check_every,
|
|
136
|
-
M=M,
|
|
137
|
-
callback=callback,
|
|
138
|
-
)
|
|
139
|
-
|
|
140
|
-
if not quiet:
|
|
141
|
-
res_str = "OK" if err <= atol else "TRUNCATED"
|
|
142
|
-
print(f"{func.__name__}: terminated after {end_iter} iterations with error = \t {err} ({res_str})")
|
|
143
|
-
|
|
144
|
-
return err, end_iter
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
class SaddleSystem(LinearOperator):
|
|
148
|
-
"""Builds a linear operator corresponding to the saddle-point linear system [A B^T; B 0]
|
|
149
|
-
|
|
150
|
-
If `use_diag_precond` is ``True``, builds the corresponding diagonal preconditioner `[diag(A); diag(B diag(A)^-1 B^T)]`
|
|
151
|
-
"""
|
|
152
|
-
|
|
153
|
-
def __init__(
|
|
154
|
-
self,
|
|
155
|
-
A: BsrMatrix,
|
|
156
|
-
B: BsrMatrix,
|
|
157
|
-
Bt: Optional[BsrMatrix] = None,
|
|
158
|
-
use_diag_precond: bool = True,
|
|
159
|
-
):
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
self.
|
|
166
|
-
|
|
167
|
-
self.
|
|
168
|
-
|
|
169
|
-
self.
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
saddle_shape =
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
self._preconditioner =
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
)
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
bsr_mv(self.
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
wp.copy(src=
|
|
298
|
-
wp.copy(src=
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
i = wp.tid()
|
|
380
|
-
values[i] = scale * values[i] / wp.ddot(values[i], values[i])
|
|
1
|
+
from typing import Any, Optional, Tuple, Union
|
|
2
|
+
|
|
3
|
+
import warp as wp
|
|
4
|
+
import warp.types
|
|
5
|
+
from warp.optim.linear import LinearOperator, aslinearoperator, preconditioner
|
|
6
|
+
from warp.sparse import BsrMatrix, bsr_get_diag, bsr_mv, bsr_transposed, bsr_zeros
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def bsr_to_scipy(matrix: BsrMatrix) -> "scipy.sparse.bsr_array": # noqa: F821
|
|
10
|
+
try:
|
|
11
|
+
from scipy.sparse import bsr_array, csr_array
|
|
12
|
+
except ImportError:
|
|
13
|
+
# WAR for older scipy
|
|
14
|
+
from scipy.sparse import bsr_matrix as bsr_array
|
|
15
|
+
from scipy.sparse import csr_matrix as csr_array
|
|
16
|
+
|
|
17
|
+
if matrix.block_shape == (1, 1):
|
|
18
|
+
return csr_array(
|
|
19
|
+
(
|
|
20
|
+
matrix.values.numpy().flatten()[: matrix.nnz],
|
|
21
|
+
matrix.columns.numpy()[: matrix.nnz],
|
|
22
|
+
matrix.offsets.numpy(),
|
|
23
|
+
),
|
|
24
|
+
shape=matrix.shape,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
return bsr_array(
|
|
28
|
+
(
|
|
29
|
+
matrix.values.numpy().reshape((matrix.values.shape[0], *matrix.block_shape))[: matrix.nnz],
|
|
30
|
+
matrix.columns.numpy()[: matrix.nnz],
|
|
31
|
+
matrix.offsets.numpy(),
|
|
32
|
+
),
|
|
33
|
+
shape=matrix.shape,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def scipy_to_bsr(
|
|
38
|
+
sp: Union["scipy.sparse.bsr_array", "scipy.sparse.csr_array"], # noqa: F821
|
|
39
|
+
device=None,
|
|
40
|
+
dtype=None,
|
|
41
|
+
) -> BsrMatrix:
|
|
42
|
+
try:
|
|
43
|
+
from scipy.sparse import csr_array
|
|
44
|
+
except ImportError:
|
|
45
|
+
# WAR for older scipy
|
|
46
|
+
from scipy.sparse import csr_matrix as csr_array
|
|
47
|
+
|
|
48
|
+
if dtype is None:
|
|
49
|
+
dtype = warp.types.np_dtype_to_warp_type[sp.dtype]
|
|
50
|
+
|
|
51
|
+
sp.sort_indices()
|
|
52
|
+
|
|
53
|
+
if isinstance(sp, csr_array):
|
|
54
|
+
matrix = bsr_zeros(sp.shape[0], sp.shape[1], dtype, device=device)
|
|
55
|
+
else:
|
|
56
|
+
block_shape = sp.blocksize
|
|
57
|
+
block_type = wp.types.matrix(shape=block_shape, dtype=dtype)
|
|
58
|
+
matrix = bsr_zeros(
|
|
59
|
+
sp.shape[0] // block_shape[0],
|
|
60
|
+
sp.shape[1] // block_shape[1],
|
|
61
|
+
block_type,
|
|
62
|
+
device=device,
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
matrix.nnz = sp.nnz
|
|
66
|
+
matrix.values = wp.array(sp.data.flatten(), dtype=matrix.values.dtype, device=device)
|
|
67
|
+
matrix.columns = wp.array(sp.indices, dtype=matrix.columns.dtype, device=device)
|
|
68
|
+
matrix.offsets = wp.array(sp.indptr, dtype=matrix.offsets.dtype, device=device)
|
|
69
|
+
|
|
70
|
+
return matrix
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def get_linear_solver_func(method_name: str):
|
|
74
|
+
from warp.optim.linear import bicgstab, cg, cr, gmres
|
|
75
|
+
|
|
76
|
+
if method_name == "bicgstab":
|
|
77
|
+
return bicgstab
|
|
78
|
+
if method_name == "gmres":
|
|
79
|
+
return gmres
|
|
80
|
+
if method_name == "cr":
|
|
81
|
+
return cr
|
|
82
|
+
return cg
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def bsr_cg(
|
|
86
|
+
A: BsrMatrix,
|
|
87
|
+
x: wp.array,
|
|
88
|
+
b: wp.array,
|
|
89
|
+
max_iters: int = 0,
|
|
90
|
+
tol: float = 0.0001,
|
|
91
|
+
check_every=10,
|
|
92
|
+
use_diag_precond=True,
|
|
93
|
+
mv_routine=None,
|
|
94
|
+
quiet=False,
|
|
95
|
+
method: str = "cg",
|
|
96
|
+
) -> Tuple[float, int]:
|
|
97
|
+
"""Solves the linear system A x = b using an iterative solver, optionally with diagonal preconditioning
|
|
98
|
+
|
|
99
|
+
Args:
|
|
100
|
+
A: system left-hand side
|
|
101
|
+
x: result vector and initial guess
|
|
102
|
+
b: system right-hand-side
|
|
103
|
+
max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
|
|
104
|
+
tol: relative tolerance under which to stop the solve
|
|
105
|
+
check_every: interval, in iterations, at which the current residual norm is evaluated and compared against the tolerance
|
|
106
|
+
use_diag_precond: Whether to use diagonal preconditioning
|
|
107
|
+
mv_routine: Matrix-vector multiplication routine to use for multiplications with ``A``
|
|
108
|
+
quiet: if True, do not print iteration residuals
|
|
109
|
+
method: Iterative solver method to use, defaults to Conjugate Gradient
|
|
110
|
+
|
|
111
|
+
Returns:
|
|
112
|
+
Tuple (residual norm, iteration count)
|
|
113
|
+
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
if mv_routine is None:
|
|
117
|
+
M = preconditioner(A, "diag") if use_diag_precond else None
|
|
118
|
+
else:
|
|
119
|
+
A = LinearOperator(A.shape, A.dtype, A.device, matvec=mv_routine)
|
|
120
|
+
M = None
|
|
121
|
+
|
|
122
|
+
func = get_linear_solver_func(method_name=method)
|
|
123
|
+
|
|
124
|
+
def print_callback(i, err, tol):
|
|
125
|
+
print(f"{func.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")
|
|
126
|
+
|
|
127
|
+
callback = None if quiet else print_callback
|
|
128
|
+
|
|
129
|
+
end_iter, err, atol = func(
|
|
130
|
+
A=A,
|
|
131
|
+
b=b,
|
|
132
|
+
x=x,
|
|
133
|
+
maxiter=max_iters,
|
|
134
|
+
tol=tol,
|
|
135
|
+
check_every=check_every,
|
|
136
|
+
M=M,
|
|
137
|
+
callback=callback,
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
if not quiet:
|
|
141
|
+
res_str = "OK" if err <= atol else "TRUNCATED"
|
|
142
|
+
print(f"{func.__name__}: terminated after {end_iter} iterations with error = \t {err} ({res_str})")
|
|
143
|
+
|
|
144
|
+
return err, end_iter
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class SaddleSystem(LinearOperator):
|
|
148
|
+
"""Builds a linear operator corresponding to the saddle-point linear system [A B^T; B 0]
|
|
149
|
+
|
|
150
|
+
If `use_diag_precond` is ``True``, builds the corresponding diagonal preconditioner `[diag(A); diag(B diag(A)^-1 B^T)]`
|
|
151
|
+
"""
|
|
152
|
+
|
|
153
|
+
def __init__(
|
|
154
|
+
self,
|
|
155
|
+
A: BsrMatrix,
|
|
156
|
+
B: BsrMatrix,
|
|
157
|
+
Bt: Optional[BsrMatrix] = None,
|
|
158
|
+
use_diag_precond: bool = True,
|
|
159
|
+
):
|
|
160
|
+
if Bt is None:
|
|
161
|
+
Bt = bsr_transposed(B)
|
|
162
|
+
|
|
163
|
+
self._A = A
|
|
164
|
+
self._B = B
|
|
165
|
+
self._Bt = Bt
|
|
166
|
+
|
|
167
|
+
self._u_dtype = wp.vec(length=A.block_shape[0], dtype=A.scalar_type)
|
|
168
|
+
self._p_dtype = wp.vec(length=B.block_shape[0], dtype=B.scalar_type)
|
|
169
|
+
self._p_byte_offset = A.nrow * wp.types.type_size_in_bytes(self._u_dtype)
|
|
170
|
+
|
|
171
|
+
saddle_shape = (A.shape[0] + B.shape[0], A.shape[0] + B.shape[0])
|
|
172
|
+
|
|
173
|
+
super().__init__(saddle_shape, dtype=A.scalar_type, device=A.device, matvec=self._saddle_mv)
|
|
174
|
+
|
|
175
|
+
if use_diag_precond:
|
|
176
|
+
self._preconditioner = self._diag_preconditioner()
|
|
177
|
+
else:
|
|
178
|
+
self._preconditioner = None
|
|
179
|
+
|
|
180
|
+
def _diag_preconditioner(self):
|
|
181
|
+
A = self._A
|
|
182
|
+
B = self._B
|
|
183
|
+
|
|
184
|
+
M_u = preconditioner(A, "diag")
|
|
185
|
+
|
|
186
|
+
A_diag = bsr_get_diag(A)
|
|
187
|
+
|
|
188
|
+
schur_block_shape = (B.block_shape[0], B.block_shape[0])
|
|
189
|
+
schur_dtype = wp.mat(shape=schur_block_shape, dtype=B.scalar_type)
|
|
190
|
+
schur_inv_diag = wp.empty(dtype=schur_dtype, shape=B.nrow, device=self.device)
|
|
191
|
+
wp.launch(
|
|
192
|
+
_compute_schur_inverse_diagonal,
|
|
193
|
+
dim=B.nrow,
|
|
194
|
+
device=A.device,
|
|
195
|
+
inputs=[B.offsets, B.columns, B.values, A_diag, schur_inv_diag],
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
if schur_block_shape == (1, 1):
|
|
199
|
+
# Downcast 1x1 mats to scalars
|
|
200
|
+
schur_inv_diag = schur_inv_diag.view(dtype=B.scalar_type)
|
|
201
|
+
|
|
202
|
+
M_p = aslinearoperator(schur_inv_diag)
|
|
203
|
+
|
|
204
|
+
def precond_mv(x, y, z, alpha, beta):
|
|
205
|
+
x_u = self.u_slice(x)
|
|
206
|
+
x_p = self.p_slice(x)
|
|
207
|
+
y_u = self.u_slice(y)
|
|
208
|
+
y_p = self.p_slice(y)
|
|
209
|
+
z_u = self.u_slice(z)
|
|
210
|
+
z_p = self.p_slice(z)
|
|
211
|
+
|
|
212
|
+
M_u.matvec(x_u, y_u, z_u, alpha=alpha, beta=beta)
|
|
213
|
+
M_p.matvec(x_p, y_p, z_p, alpha=alpha, beta=beta)
|
|
214
|
+
|
|
215
|
+
return LinearOperator(
|
|
216
|
+
shape=self.shape,
|
|
217
|
+
dtype=self.dtype,
|
|
218
|
+
device=self.device,
|
|
219
|
+
matvec=precond_mv,
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
@property
|
|
223
|
+
def preconditioner(self):
|
|
224
|
+
return self._preconditioner
|
|
225
|
+
|
|
226
|
+
def u_slice(self, a: wp.array):
|
|
227
|
+
return wp.array(
|
|
228
|
+
ptr=a.ptr,
|
|
229
|
+
dtype=self._u_dtype,
|
|
230
|
+
shape=self._A.nrow,
|
|
231
|
+
strides=None,
|
|
232
|
+
device=a.device,
|
|
233
|
+
pinned=a.pinned,
|
|
234
|
+
copy=False,
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
def p_slice(self, a: wp.array):
|
|
238
|
+
return wp.array(
|
|
239
|
+
ptr=a.ptr + self._p_byte_offset,
|
|
240
|
+
dtype=self._p_dtype,
|
|
241
|
+
shape=self._B.nrow,
|
|
242
|
+
strides=None,
|
|
243
|
+
device=a.device,
|
|
244
|
+
pinned=a.pinned,
|
|
245
|
+
copy=False,
|
|
246
|
+
)
|
|
247
|
+
|
|
248
|
+
def _saddle_mv(self, x, y, z, alpha, beta):
|
|
249
|
+
x_u = self.u_slice(x)
|
|
250
|
+
x_p = self.p_slice(x)
|
|
251
|
+
z_u = self.u_slice(z)
|
|
252
|
+
z_p = self.p_slice(z)
|
|
253
|
+
|
|
254
|
+
if y.ptr != z.ptr and beta != 0.0:
|
|
255
|
+
wp.copy(src=y, dest=z)
|
|
256
|
+
|
|
257
|
+
bsr_mv(self._A, x_u, z_u, alpha=alpha, beta=beta)
|
|
258
|
+
bsr_mv(self._Bt, x_p, z_u, alpha=alpha, beta=1.0)
|
|
259
|
+
bsr_mv(self._B, x_u, z_p, alpha=alpha, beta=beta)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def bsr_solve_saddle(
    saddle_system: SaddleSystem,
    x_u: wp.array,
    x_p: wp.array,
    b_u: wp.array,
    b_p: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every=10,
    quiet=False,
    method: str = "cg",
) -> Tuple[float, int]:
    """Solves the saddle-point linear system [A B^T; B 0] (x_u; x_p) = (b_u; b_p) using an iterative solver, optionally with diagonal preconditioning

    The ``u`` and ``p`` parts are packed into single contiguous vectors for the solve,
    and the solution is copied back into ``x_u`` and ``x_p`` afterwards.

    Args:
        saddle_system: Saddle point system
        x_u: primal part of the result vector and initial guess
        x_p: Lagrange multiplier part of the result vector and initial guess
        b_u: primal right-hand-side
        b_p: constraint right-hand-side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        quiet: if True, do not print iteration residuals nor the final summary line
        method: name of the iterative solver, passed to ``get_linear_solver_func``; defaults to ``"cg"``

    Returns:
        Tuple (residual norm, iteration count)

    """
    # Contiguous work vectors covering both the u and p unknowns
    x = wp.empty(dtype=saddle_system.scalar_type, shape=saddle_system.shape[0], device=saddle_system.device)
    b = wp.empty_like(x)

    # Pack the initial guess and the right-hand side into the work vectors
    wp.copy(src=x_u, dest=saddle_system.u_slice(x))
    wp.copy(src=x_p, dest=saddle_system.p_slice(x))
    wp.copy(src=b_u, dest=saddle_system.u_slice(b))
    wp.copy(src=b_p, dest=saddle_system.p_slice(b))

    func = get_linear_solver_func(method_name=method)

    def print_callback(i, err, tol):
        print(f"{func.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")

    callback = None if quiet else print_callback

    end_iter, err, atol = func(
        A=saddle_system,
        b=b,
        x=x,
        maxiter=max_iters,
        tol=tol,
        check_every=check_every,
        M=saddle_system.preconditioner,
        callback=callback,
    )

    if not quiet:
        # "TRUNCATED" means the solver stopped with the error still above its absolute tolerance
        res_str = "OK" if err <= atol else "TRUNCATED"
        print(f"{func.__name__}: terminated after {end_iter} iterations with absolute error = \t {err} ({res_str})")

    # Unpack the solution back into the caller's arrays
    wp.copy(dest=x_u, src=saddle_system.u_slice(x))
    wp.copy(dest=x_p, src=saddle_system.p_slice(x))

    return err, end_iter
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
@wp.kernel
def _compute_schur_inverse_diagonal(
    B_offsets: wp.array(dtype=int),
    B_indices: wp.array(dtype=int),
    B_values: wp.array(dtype=Any),
    A_diag: wp.array(dtype=Any),
    P_diag: wp.array(dtype=Any),
):
    # One thread per row of B. Accumulates B_k * inverse(A_diag[col_k]) * transpose(B_k)
    # over the blocks stored for that row, then writes to P_diag[row] a diagonal block
    # built from the component-wise reciprocal of the accumulated block's diagonal.
    row = wp.tid()

    zero = P_diag.dtype(P_diag.dtype.dtype(0.0))

    schur = zero

    # Range of stored blocks belonging to this row
    first = B_offsets[row]
    last = B_offsets[row + 1]

    for k in range(first, last):
        col = B_indices[k]
        blk = B_values[k]
        a_inv = wp.inverse(A_diag[col])
        contrib = blk * a_inv * wp.transpose(blk)
        schur += contrib

    diag_vals = wp.get_diag(schur)
    ones = type(diag_vals)(diag_vals.dtype(1.0))

    # wp.select returns its second argument when the condition is false and the third
    # when it is true (per the Warp builtin reference): a block that is exactly zero
    # falls back to the identity instead of dividing by zero.
    recip = wp.select(schur == zero, wp.cw_div(ones, diag_vals), ones)
    P_diag[row] = wp.diag(recip)
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def invert_diagonal_bsr_mass_matrix(A: BsrMatrix):
    """Inverts each block of a block-diagonal mass matrix

    Launches ``_block_diagonal_mass_invert`` over ``A.nrow`` entries of ``A.values``;
    the kernel overwrites the values it is given, and nothing is returned.

    NOTE(review): the kernel replaces each block ``V`` with
    ``block_size * V / ddot(V, V)``, which matches the inverse for blocks that are
    a multiple of the identity -- confirm callers only pass such matrices.

    Args:
        A: block-diagonal matrix to invert
    """
    blocks = A.values
    if not wp.types.type_is_matrix(blocks.dtype):
        # Scalar values: reinterpret them as 1x1 matrices so the kernel gets matrix blocks
        one_by_one = wp.mat(shape=(1, 1), dtype=A.scalar_type)
        blocks = blocks.view(dtype=one_by_one)

    # Number of block rows, converted to the matrix scalar type for use inside the kernel
    block_dim = A.scalar_type(A.block_shape[0])

    wp.launch(
        kernel=_block_diagonal_mass_invert,
        dim=A.nrow,
        inputs=[blocks, block_dim],
        device=blocks.device,
    )
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
@wp.kernel
def _block_diagonal_mass_invert(values: wp.array(dtype=Any), scale: Any):
    # One thread per entry: overwrite values[i] with scale * values[i] / ddot(values[i], values[i]).
    i = wp.tid()
    block = values[i]
    values[i] = scale * block / wp.ddot(block, block)
|