warp-lang 1.8.1__py3-none-macosx_10_13_universal2.whl → 1.9.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +282 -103
- warp/__init__.pyi +482 -110
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +47 -67
- warp/builtins.py +955 -137
- warp/codegen.py +312 -206
- warp/config.py +1 -1
- warp/context.py +1249 -784
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +264 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +129 -51
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +1 -1
- warp/jax_experimental/ffi.py +2 -1
- warp/marching_cubes.py +708 -0
- warp/native/array.h +99 -4
- warp/native/builtin.h +82 -5
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +8 -2
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +41 -10
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +2 -2
- warp/native/mat.h +1910 -116
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +4 -2
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +331 -14
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +22 -22
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +13 -13
- warp/native/spatial.h +366 -17
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +283 -69
- warp/native/vec.h +381 -14
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +323 -192
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +85 -6
- warp/sim/graph_coloring.py +2 -2
- warp/sparse.py +558 -175
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/sim/test_coloring.py +6 -6
- warp/tests/test_array.py +56 -5
- warp/tests/test_codegen.py +3 -2
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +45 -2
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +1 -1
- warp/tests/test_mat.py +1518 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +140 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +71 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_types.py +0 -20
- warp/tests/test_vec.py +179 -34
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/tile/test_tile.py +184 -18
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_shared_memory.py +5 -5
- warp/tests/unittest_suites.py +6 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +554 -264
- warp/utils.py +68 -86
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/RECORD +131 -121
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/optim/linear.py
CHANGED
@@ -13,12 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import functools
+import math
 from typing import Any, Callable, Optional, Tuple, Union
 
 import warp as wp
 import warp.sparse as sparse
-from warp.
+from warp.types import type_length, type_scalar_type
+
+__all__ = ["LinearOperator", "aslinearoperator", "bicgstab", "cg", "cr", "gmres", "preconditioner"]
 
 # No need to auto-generate adjoint code for linear solvers
 wp.set_module_options({"enable_backward": False})
@@ -103,19 +106,33 @@ def aslinearoperator(A: _Matrix) -> LinearOperator:
         sparse.bsr_mv(A, x, z, alpha, beta)
 
     def dense_mv(x, y, z, alpha, beta):
-
+        alpha = A.dtype(alpha)
+        beta = A.dtype(beta)
+        if A.device.is_cuda:
+            tile_size = 1 << min(10, max(5, math.ceil(math.log2(A.shape[1]))))
+        else:
+            tile_size = 1
+        wp.launch(
+            _dense_mv_kernel,
+            dim=(A.shape[0], tile_size),
+            block_dim=tile_size,
+            device=A.device,
+            inputs=[A, x, y, z, alpha, beta],
+        )
 
-    def
-        scalar_type =
+    def diag_mv_impl(A, x, y, z, alpha, beta):
+        scalar_type = type_scalar_type(A.dtype)
         alpha = scalar_type(alpha)
         beta = scalar_type(beta)
         wp.launch(_diag_mv_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])
 
+    def diag_mv(x, y, z, alpha, beta):
+        return diag_mv_impl(A, x, y, z, alpha, beta)
+
     def diag_mv_vec(x, y, z, alpha, beta):
-
-
-
-        wp.launch(_diag_mv_vec_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])
+        return diag_mv_impl(
+            _as_scalar_array(A), _as_scalar_array(x), _as_scalar_array(y), _as_scalar_array(z), alpha, beta
+        )
 
     if isinstance(A, wp.array):
         if A.ndim == 2:
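Note: the reworked `dense_mv` above launches `_dense_mv_kernel` with one tile per matrix row, picking a power-of-two tile size clamped between 32 and 1024 lanes. A minimal standalone check of that heuristic (the helper name `dense_mv_tile_size` is hypothetical, used only for illustration):

import math

def dense_mv_tile_size(n_cols: int) -> int:
    # clamp ceil(log2(n_cols)) to [5, 10], i.e. tile sizes of 32..1024 lanes
    return 1 << min(10, max(5, math.ceil(math.log2(n_cols))))

assert dense_mv_tile_size(8) == 32          # small rows still get a full warp
assert dense_mv_tile_size(300) == 512       # ceil(log2(300)) == 9
assert dense_mv_tile_size(100_000) == 1024  # capped at the CUDA block limit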
@@ -183,6 +200,147 @@ def preconditioner(A: _Matrix, ptype: str = "diag") -> LinearOperator:
     raise ValueError(f"Unsupported preconditioner type '{ptype}'")
 
 
+def _as_scalar_array(x: wp.array):
+    scalar_type = type_scalar_type(x.dtype)
+    if scalar_type == x.dtype:
+        return x
+
+    dlen = type_length(x.dtype)
+    arr = wp.array(
+        ptr=x.ptr,
+        shape=(*x.shape[:-1], x.shape[-1] * dlen),
+        strides=(*x.strides[:-1], x.strides[-1] // dlen),
+        dtype=scalar_type,
+        device=x.device,
+        grad=None if x.grad is None else _as_scalar_array(x.grad),
+    )
+    arr._ref = x
+    return arr
+
+
+class TiledDot:
+    """
+    Computes the dot product of two arrays in a way that is compatible with CUDA sub-graphs.
+    """
+
+    def __init__(self, max_length: int, scalar_type: type, tile_size=512, device=None, max_column_count: int = 1):
+        self.tile_size = tile_size
+        self.device = device
+        self.max_column_count = max_column_count
+
+        num_blocks = (max_length + self.tile_size - 1) // self.tile_size
+        scratch = wp.empty(
+            shape=(2, max_column_count, num_blocks),
+            dtype=scalar_type,
+            device=self.device,
+        )
+        self.partial_sums_a = scratch[0]
+        self.partial_sums_b = scratch[1]
+
+        self.dot_kernel, self.sum_kernel = _create_tiled_dot_kernels(self.tile_size)
+
+        rounds = 0
+        length = num_blocks
+        while length > 1:
+            length = (length + self.tile_size - 1) // self.tile_size
+            rounds += 1
+
+        self.rounds = rounds
+
+        self._output = self.partial_sums_a if rounds % 2 == 0 else self.partial_sums_b
+
+        self.dot_launch: wp.Launch = wp.launch(
+            self.dot_kernel,
+            dim=(max_column_count, num_blocks, self.tile_size),
+            inputs=(self.partial_sums_a, self.partial_sums_b),
+            outputs=(self.partial_sums_a,),
+            block_dim=self.tile_size,
+            record_cmd=True,
+        )
+        self.sum_launch: wp.Launch = wp.launch(
+            self.sum_kernel,
+            dim=(max_column_count, num_blocks, self.tile_size),
+            inputs=(self.partial_sums_a,),
+            outputs=(self.partial_sums_b,),
+            block_dim=self.tile_size,
+            record_cmd=True,
+        )
+
+    # Result contains a single value, the sum of the array (will get updated by this function)
+    def compute(self, a: wp.array, b: wp.array, col_offset: int = 0):
+        a = _as_scalar_array(a)
+        b = _as_scalar_array(b)
+        if a.ndim == 1:
+            a = a.reshape((1, -1))
+        if b.ndim == 1:
+            b = b.reshape((1, -1))
+
+        column_count = a.shape[0]
+        num_blocks = (a.shape[1] + self.tile_size - 1) // self.tile_size
+
+        data_out = self.partial_sums_a[col_offset : col_offset + column_count]
+        data_in = self.partial_sums_b[col_offset : col_offset + column_count]
+
+        self.dot_launch.set_param_at_index(0, a)
+        self.dot_launch.set_param_at_index(1, b)
+        self.dot_launch.set_param_at_index(2, data_out)
+        self.dot_launch.set_dim((column_count, num_blocks, self.tile_size))
+        self.dot_launch.launch()
+
+        for _r in range(self.rounds):
+            array_length = num_blocks
+            num_blocks = (array_length + self.tile_size - 1) // self.tile_size
+            data_in, data_out = data_out, data_in
+
+            self.sum_launch.set_param_at_index(0, data_in)
+            self.sum_launch.set_param_at_index(1, data_out)
+            self.sum_launch.set_dim((column_count, num_blocks, self.tile_size))
+            self.sum_launch.launch()
+
+        return data_out
+
+    def col(self, col: int = 0):
+        return self._output[col][:1]
+
+    def cols(self, count, start: int = 0):
+        return self._output[start : start + count, :1]
+
+
+@functools.lru_cache(maxsize=None)
+def _create_tiled_dot_kernels(tile_size):
+    @wp.kernel
+    def block_dot_kernel(
+        a: wp.array2d(dtype=Any),
+        b: wp.array2d(dtype=Any),
+        partial_sums: wp.array2d(dtype=Any),
+    ):
+        column, block_id, tid_block = wp.tid()
+
+        start = block_id * tile_size
+
+        a_block = wp.tile_load(a[column], shape=tile_size, offset=start)
+        b_block = wp.tile_load(b[column], shape=tile_size, offset=start)
+        t = wp.tile_map(wp.mul, a_block, b_block)
+
+        tile_sum = wp.tile_sum(t)
+        wp.tile_store(partial_sums[column], tile_sum, offset=block_id)
+
+    @wp.kernel
+    def block_sum_kernel(
+        data: wp.array2d(dtype=Any),
+        partial_sums: wp.array2d(dtype=Any),
+    ):
+        column, block_id, tid_block = wp.tid()
+        start = block_id * tile_size
+
+        t = wp.tile_load(data[column], shape=tile_size, offset=start)
+
+        tile_sum = wp.tile_sum(t)
+        wp.tile_store(partial_sums[column], tile_sum, offset=block_id)
+
+    return block_dot_kernel, block_sum_kernel
+
+
 def cg(
     A: _Matrix,
     b: wp.array,
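Note: a minimal usage sketch of the `TiledDot` helper added above, assuming a CUDA device and float32 inputs; `compute` runs the two-stage tile reduction entirely on device (no host synchronization, so it can sit inside a graph capture), and `col(0)` is a one-element device view of the result:

import warp as wp
from warp.optim.linear import TiledDot

n = 1_000_000
a = wp.full(n, 2.0, dtype=wp.float32)
b = wp.full(n, 3.0, dtype=wp.float32)

dot = TiledDot(max_length=n, scalar_type=wp.float32)
dot.compute(a, b)          # launches the recorded dot + sum rounds
print(dot.col(0).numpy())  # -> [6000000.]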
@@ -194,7 +352,7 @@ def cg(
     callback: Optional[Callable] = None,
     check_every=10,
     use_cuda_graph=True,
-) -> Tuple[int, float, float]:
+) -> Union[Tuple[int, float, float], Tuple[wp.array, wp.array, wp.array]]:
     """Computes an approximate solution to a symmetric, positive-definite linear system
     using the Conjugate Gradient algorithm.
 
@@ -205,94 +363,107 @@ def cg(
         tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
         atol: absolute tolerance for the residual
         maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
-            Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
         M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
-        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
+        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
+            If `check_every` is 0, the callback should be a Warp kernel.
         check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
+            Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
+            If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
+            to the maximum number of iterations.
         use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
-
+            The linear operator and preconditioner must only perform graph-friendly operations.
 
     Returns:
-
+        If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
+            - final_iteration: The number of iterations performed before convergence or reaching maxiter
+            - residual_norm: The final residual norm ||b - Ax||
+            - absolute_tolerance: The absolute tolerance used for convergence checking
+
+        If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
+            - final_iteration_array: Device array containing the number of iterations performed
+            - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
+            - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
 
     If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
     """
-
     A = aslinearoperator(A)
     M = aslinearoperator(M)
 
     if maxiter == 0:
         maxiter = A.shape[0]
 
-    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)
-
     device = A.device
-
+    scalar_type = A.scalar_type
 
-    #
+    # Temp storage
+    r_and_z = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
+    p_and_Ap = wp.empty_like(r_and_z)
+    residuals = wp.empty(2, dtype=scalar_type, device=device)
 
-
-    if M is not None:
-        z = wp.zeros_like(b)
-        M.matvec(r, z, z, alpha=1.0, beta=0.0)
+    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=2)
 
-
-
-
+    # named views
+
+    # (r, r) -- so we can compute r.z and r.r at once
+    r_repeated = _repeat_first(r_and_z)
+    if M is None:
+        # without preconditioner r == z
+        r_and_z = r_repeated
+        rz_new = tiled_dot.col(0)
     else:
-
+        rz_new = tiled_dot.col(1)
+
+    r, z = r_and_z[0], r_and_z[1]
+    r_norm_sq = tiled_dot.col(0)
+
+    p, Ap = p_and_Ap[0], p_and_Ap[1]
+    rz_old, atol_sq = residuals[0:1], residuals[1:2]
+
+    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
+    Ap.zero_()
+    z.zero_()
+
+    # Initialize tolerance from right-hand-side norm
+    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
+    # Initialize residual
+    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
+
+    def update_rr_rz():
+        # z = M r
+        if M is None:
+            tiled_dot.compute(r, r)
+        else:
+            M.matvec(r, z, z, alpha=1.0, beta=0.0)
+            tiled_dot.compute(r_repeated, r_and_z)
 
-
-
-    Ap = wp.zeros_like(b)
+    update_rr_rz()
+    p.assign(z)
 
-
+    def do_iteration():
+        rz_old.assign(rz_new)
 
-    def do_iteration(atol_sq, rr_old, rr_new, rz_old, rz_new):
         # Ap = A * p;
         A.matvec(p, Ap, Ap, alpha=1, beta=0)
-
-
+        tiled_dot.compute(p, Ap, col_offset=1)
+        p_Ap = tiled_dot.col(1)
 
         wp.launch(
             kernel=_cg_kernel_1,
             dim=x.shape[0],
             device=device,
-            inputs=[atol_sq,
+            inputs=[atol_sq, r_norm_sq, rz_old, p_Ap, x, r, p, Ap],
         )
-        array_inner(r, r, out=rr_new)
-
-        # z = M r
-        if M is not None:
-            M.matvec(r, z, z, alpha=1.0, beta=0.0)
-        # rz = r' z;
-        array_inner(r, z, out=rz_new)
-
-        wp.launch(kernel=_cg_kernel_2, dim=z.shape[0], device=device, inputs=[atol_sq, rr_new, rz_old, rz_new, z, p])
 
-
-        # In the non-preconditioned case we reuse the error norm buffer for the new <r,z> computation
+        update_rr_rz()
 
-
-
-
-
-
-
-    do_iteration(atol_sq, r_norm_sq, r_norm_sq, rz_new, rz_old)
-    do_iteration(atol_sq, r_norm_sq, r_norm_sq, rz_old, rz_new)
+        wp.launch(
+            kernel=_cg_kernel_2,
+            dim=z.shape[0],
+            device=device,
+            inputs=[atol_sq, r_norm_sq, rz_old, rz_new, z, p],
+        )
 
-    return
-        do_odd_even_cycle,
-        cycle_size=2,
-        r_norm_sq=r_norm_sq,
-        maxiter=maxiter,
-        atol=atol,
-        callback=callback,
-        check_every=check_every,
-        use_cuda_graph=use_cuda_graph,
-        device=device,
-    )
+    return _run_capturable_loop(do_iteration, r_norm_sq, maxiter, atol_sq, callback, check_every, use_cuda_graph)
 
 
 def cr(
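Note: per the new docstring, passing `check_every=0` switches `cg` to the device-side convergence path and returns device arrays instead of Python scalars; any callback must then be a Warp kernel. A minimal sketch against a diagonal SPD system (the kernel's argument list mirrors what `_run_capturable_loop`, shown later in this diff, passes to kernel callbacks):

import warp as wp
from warp.optim.linear import cg

n = 1024
A_diag = wp.full(n, 4.0, dtype=wp.float32)  # a 1-D array acts as a diagonal operator
b = wp.ones(n, dtype=wp.float32)
x = wp.zeros(n, dtype=wp.float32)

@wp.kernel
def report(cur_iter: wp.array(dtype=int),
           r_norm_sq: wp.array(dtype=wp.float32),
           atol_sq: wp.array(dtype=wp.float32)):
    # device-side progress report, runs once per solver cycle
    wp.printf("iter %d  |r|^2 %f\n", cur_iter[0], r_norm_sq[0])

it_arr, r_sq_arr, atol_sq_arr = cg(A_diag, b, x, check_every=0, callback=report)
print(it_arr.numpy()[0], r_sq_arr.numpy()[0] ** 0.5)  # sync only when the host reads back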
@@ -319,13 +490,25 @@ def cr(
         maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
             Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
         M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
-        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
+        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
+            If `check_every` is 0, the callback should be a Warp kernel.
         check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
+            Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
+            If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
+            to the maximum number of iterations.
         use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
             The linear operator and preconditioner must only perform graph-friendly operations.
 
     Returns:
-
+        If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
+            - final_iteration: The number of iterations performed before convergence or reaching maxiter
+            - residual_norm: The final residual norm ||b - Ax||
+            - absolute_tolerance: The absolute tolerance used for convergence checking
+
+        If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
+            - final_iteration_array: Device array containing the number of iterations performed
+            - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
+            - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
 
     If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
     """
@@ -336,42 +519,65 @@ def cr(
     if maxiter == 0:
         maxiter = A.shape[0]
 
-    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)
-
     device = A.device
-
+    scalar_type = wp.types.type_scalar_type(A.dtype)
 
     # Notations below follow roughly pseudo-code from https://en.wikipedia.org/wiki/Conjugate_residual_method
     # with z := M^-1 r and y := M^-1 Ap
 
-    #
+    # Temp storage
+    r_and_z = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
+    r_and_Az = wp.empty_like(r_and_z)
+    y_and_Ap = wp.empty_like(r_and_z)
+    p = wp.empty_like(b)
+    residuals = wp.empty(2, dtype=scalar_type, device=device)
+
+    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=2)
+
     if M is None:
-
-
-        z = wp.zeros_like(r)
-        M.matvec(r, z, z, alpha=1.0, beta=0.0)
+        r_and_z = _repeat_first(r_and_z)
+        y_and_Ap = _repeat_first(y_and_Ap)
 
-
-
+    # named views
+    r, z = r_and_z[0], r_and_z[1]
+    r_copy, Az = r_and_Az[0], r_and_Az[1]
 
-
-    Ap = wp.clone(Az)
+    y, Ap = y_and_Ap[0], y_and_Ap[1]
 
-
-
-
-
+    r_norm_sq = tiled_dot.col(0)
+    zAz_new = tiled_dot.col(1)
+    zAz_old, atol_sq = residuals[0:1], residuals[1:2]
+
+    # Initialize tolerance from right-hand-side norm
+    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
+    # Initialize residual
+    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
+
+    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
+    y_and_Ap.zero_()
 
-
-
-
+    # z = M r
+    if M is not None:
+        z.zero_()
+        M.matvec(r, z, z, alpha=1.0, beta=0.0)
+
+    def update_rr_zAz():
+        A.matvec(z, Az, Az, alpha=1, beta=0)
+        r_copy.assign(r)
+        tiled_dot.compute(r_and_z, r_and_Az)
+
+    update_rr_zAz()
 
-
+    p.assign(z)
+    Ap.assign(Az)
+
+    def do_iteration():
+        zAz_old.assign(zAz_new)
 
-    def do_iteration(atol_sq, rr, zAz_old, zAz_new):
         if M is not None:
             M.matvec(Ap, y, y, alpha=1.0, beta=0.0)
-
+        tiled_dot.compute(Ap, y, col_offset=1)
+        y_Ap = tiled_dot.col(1)
 
         if M is None:
             # In non-preconditioned case, first kernel is same as CG
@@ -379,7 +585,7 @@ def cr(
                 kernel=_cg_kernel_1,
                 dim=x.shape[0],
                 device=device,
-                inputs=[atol_sq,
+                inputs=[atol_sq, r_norm_sq, zAz_old, y_Ap, x, r, p, Ap],
             )
         else:
             # In preconditioned case, we have one more vector to update
@@ -387,34 +593,26 @@ def cr(
                 kernel=_cr_kernel_1,
                 dim=x.shape[0],
                 device=device,
-                inputs=[atol_sq,
+                inputs=[atol_sq, r_norm_sq, zAz_old, y_Ap, x, r, z, p, Ap, y],
             )
 
-
-
-        A.matvec(z, Az, Az, alpha=1, beta=0)
-        array_inner(z, Az, out=zAz_new)
-
-        # beta = rz_new / rz_old
+        update_rr_zAz()
         wp.launch(
-            kernel=_cr_kernel_2,
+            kernel=_cr_kernel_2,
+            dim=z.shape[0],
+            device=device,
+            inputs=[atol_sq, r_norm_sq, zAz_old, zAz_new, z, p, Az, Ap],
         )
 
-
-
-
-    do_iteration(atol_sq, r_norm_sq, zAz_old, zAz_new)
-
-    return _run_solver_loop(
-        do_odd_even_cycle,
-        cycle_size=2,
+    return _run_capturable_loop(
+        do_iteration,
+        cycle_size=1,
         r_norm_sq=r_norm_sq,
         maxiter=maxiter,
-
+        atol_sq=atol_sq,
         callback=callback,
         check_every=check_every,
         use_cuda_graph=use_cuda_graph,
-        device=device,
     )
 
 
@@ -441,14 +639,26 @@ def bicgstab(
         atol: absolute tolerance for the residual
         maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
         M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
-        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
+        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
+            If `check_every` is 0, the callback should be a Warp kernel.
         check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
+            Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
+            If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
+            to the maximum number of iterations.
         use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
-
+            The linear operator and preconditioner must only perform graph-friendly operations.
         is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.
 
     Returns:
-
+        If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
+            - final_iteration: The number of iterations performed before convergence or reaching maxiter
+            - residual_norm: The final residual norm ||b - Ax||
+            - absolute_tolerance: The absolute tolerance used for convergence checking
+
+        If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
+            - final_iteration_array: Device array containing the number of iterations performed
+            - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
+            - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
 
     If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
     """
@@ -458,23 +668,19 @@ def bicgstab(
     if maxiter == 0:
         maxiter = A.shape[0]
 
-    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)
-
     device = A.device
-
+    scalar_type = wp.types.type_scalar_type(A.dtype)
 
     # Notations below follow pseudo-code from biconjugate https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method
 
-
-
-
-
+    # Temp storage
+    r_and_r0 = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
+    p = wp.empty_like(b)
+    v = wp.empty_like(b)
+    t = wp.empty_like(b)
 
-
-
-    v = wp.zeros_like(r)
-    t = wp.zeros_like(r)
-    p = wp.clone(r)
+    r, r0 = r_and_r0[0], r_and_r0[1]
+    r_repeated = _repeat_first(r_and_r0)
 
     if M is not None:
         y = wp.zeros_like(p)
@@ -486,7 +692,27 @@ def bicgstab(
         z = r
         Mt = t
 
-
+    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=5)
+    r_norm_sq = tiled_dot.col(0)
+    rho = tiled_dot.col(1)
+
+    atol_sq = wp.empty(1, dtype=scalar_type, device=device)
+
+    # Initialize tolerance from right-hand-side norm
+    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
+    # Initialize residual
+    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
+    tiled_dot.compute(r, r, col_offset=0)
+
+    p.assign(r)
+    r0.assign(r)
+    rho.assign(r_norm_sq)
+
+    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
+    v.zero_()
+    t.zero_()
+
+    def do_iteration():
         # y = M p
         if M is not None:
             M.matvec(p, y, y, alpha=1.0, beta=0.0)
@@ -495,7 +721,8 @@ def bicgstab(
         A.matvec(y, v, v, alpha=1, beta=0)
 
         # alpha = rho / <r0 . v>
-
+        tiled_dot.compute(r0, v, col_offset=2)
+        r0v = tiled_dot.col(2)
 
         # x += alpha y
         # r -= alpha v
@@ -505,7 +732,7 @@ def bicgstab(
             device=device,
             inputs=[atol_sq, r_norm_sq, rho, r0v, x, r, y, v],
         )
-
+        tiled_dot.compute(r, r, col_offset=0)
 
         # z = M r
         if M is not None:
@@ -514,17 +741,18 @@ def bicgstab(
         # t = A z
         A.matvec(z, t, t, alpha=1, beta=0)
 
-        if is_left_preconditioner:
+        if M is not None and is_left_preconditioner:
             # Mt = M t
-
-            M.matvec(t, Mt, Mt, alpha=1.0, beta=0.0)
+            M.matvec(t, Mt, Mt, alpha=1.0, beta=0.0)
 
             # omega = <Mt, Ms> / <Mt, Mt>
-
-
+            tiled_dot.compute(z, Mt, col_offset=3)
+            tiled_dot.compute(Mt, Mt, col_offset=4)
         else:
-
-
+            tiled_dot.compute(r, t, col_offset=3)
+            tiled_dot.compute(t, t, col_offset=4)
+        st = tiled_dot.col(3)
+        tt = tiled_dot.col(4)
 
         # x += omega z
         # r -= omega t
@@ -534,10 +762,9 @@ def bicgstab(
             device=device,
             inputs=[atol_sq, r_norm_sq, st, tt, z, t, x, r],
         )
-        array_inner(r, r, out=r_norm_sq)
 
-        # rho = <r0, r>
-
+        # r = <r,r>, rho = <r0, r>
+        tiled_dot.compute(r_and_r0, r_repeated, col_offset=0)
 
         # beta = (rho / rho_old) * alpha / omega = (rho / r0v) / omega
         # p = r + beta (p - omega v)
@@ -548,16 +775,14 @@ def bicgstab(
             inputs=[atol_sq, r_norm_sq, rho, r0v, st, tt, p, r, v],
         )
 
-    return
+    return _run_capturable_loop(
         do_iteration,
-        cycle_size=1,
         r_norm_sq=r_norm_sq,
         maxiter=maxiter,
-
+        atol_sq=atol_sq,
         callback=callback,
         check_every=check_every,
         use_cuda_graph=use_cuda_graph,
-        device=device,
     )
 
 
@@ -587,14 +812,26 @@ def gmres(
         maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
             Note that the current implementation always perform `restart` iterations at a time, and as a result may exceed the specified maximum number of iterations by ``restart-1``.
         M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
-        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
+        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
+            If `check_every` is 0, the callback should be a Warp kernel.
         check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
+            Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
+            If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
+            to the maximum number of iterations.
         use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
             The linear operator and preconditioner must only perform graph-friendly operations.
         is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.
 
     Returns:
-
+        If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
+            - final_iteration: The number of iterations performed before convergence or reaching maxiter
+            - residual_norm: The final residual norm ||b - Ax||
+            - absolute_tolerance: The absolute tolerance used for convergence checking
+
+        If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
+            - final_iteration_array: Device array containing the number of iterations performed
+            - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
+            - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
 
     If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
     """
@@ -606,89 +843,132 @@ def gmres(
         maxiter = A.shape[0]
 
     restart = min(restart, maxiter)
-    check_every = max(restart, check_every)
 
-
+    if check_every > 0:
+        check_every = max(restart, check_every)
 
     device = A.device
     scalar_dtype = wp.types.type_scalar_type(A.dtype)
 
    pivot_tolerance = _get_dtype_epsilon(scalar_dtype) ** 2
 
-
-
+    r = wp.empty_like(b)
+    w = wp.empty_like(r)
 
+    H = wp.empty(shape=(restart + 1, restart), dtype=scalar_dtype, device=device)
     y = wp.empty(shape=restart + 1, dtype=scalar_dtype, device=device)
 
-    w = wp.zeros_like(r)
     V = wp.zeros(shape=(restart + 1, r.shape[0]), dtype=r.dtype, device=device)
 
-
-
-        ptr=H.ptr + i * H.strides[0] + j * H.strides[1],
-        dtype=H.dtype,
-        shape=(1,),
-        device=H.device,
-        copy=False,
-    )
+    residuals = wp.empty(2, dtype=scalar_dtype, device=device)
+    beta, atol_sq = residuals[0:1], residuals[1:2]
 
-
-
-        ptr=V.ptr + i * V.strides[0],
-        dtype=V.dtype,
-        shape=V.shape[1],
-        device=V.device,
-        copy=False,
-    )
+    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_dtype, max_column_count=restart + 1)
+    r_norm_sq = tiled_dot.col(0)
 
-
-
+    w_repeated = wp.array(
+        ptr=w.ptr, shape=(restart + 1, w.shape[0]), strides=(0, w.strides[0]), dtype=w.dtype, device=w.device
+    )
 
-
+    # tile size for least square solve
+    # (need to fit in a CUDA block, so 1024 max)
+    if device.is_cuda and 4 < restart <= 1024:
+        tile_size = 1 << math.ceil(math.log2(restart))
+        least_squares_kernel = make_gmres_solve_least_squares_kernel_tiled(tile_size)
+    else:
+        tile_size = 1
+        least_squares_kernel = _gmres_solve_least_squares
+
+    # recorded launches
+    least_squares_solve = wp.launch(
+        least_squares_kernel,
+        dim=(1, tile_size),
+        block_dim=tile_size if tile_size > 1 else 256,
+        device=device,
+        inputs=[restart, pivot_tolerance, beta, H, y],
+        record_cmd=True,
+    )
+
+    normalize_anorldi_vec = wp.launch(
+        _gmres_arnoldi_normalize_kernel,
+        dim=r.shape,
+        device=r.device,
+        inputs=[r, w, tiled_dot.col(0), beta],
+        record_cmd=True,
+    )
+
+    arnoldi_axpy = wp.launch(
+        _gmres_arnoldi_axpy_kernel,
+        dim=(w.shape[0], tile_size),
+        block_dim=tile_size,
+        device=w.device,
+        inputs=[V, w, H],
+        record_cmd=True,
+    )
+
+    # Initialize tolerance from right-hand-side norm
+    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
+    # Initialize residual
+    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
+    tiled_dot.compute(r, r, col_offset=0)
+
+    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
+    w.zero_()
+
+    def array_coeff(H, i, j):
+        return H[i][j : j + 1]
 
+    def array_col(H, j):
+        return H[: j + 1, j : j + 1]
+
+    def do_arnoldi_iteration(j: int):
+        # w = A * v[j];
         if M is not None:
-            tmp =
+            tmp = V[j + 1]
 
             if is_left_preconditioner:
-                A.matvec(
+                A.matvec(V[j], tmp, tmp, alpha=1, beta=0)
                 M.matvec(tmp, w, w, alpha=1, beta=0)
             else:
-                M.matvec(
+                M.matvec(V[j], tmp, tmp, alpha=1, beta=0)
                 A.matvec(tmp, w, w, alpha=1, beta=0)
         else:
-            A.matvec(
+            A.matvec(V[j], w, w, alpha=1, beta=0)
 
-
-
-
-
+        # compute and apply dot products in rappel,
+        # since Hj columns are orthogonal
+        Hj = array_col(H, j)
+        tiled_dot.compute(w_repeated, V[: j + 1])
+        wp.copy(src=tiled_dot.cols(j + 1), dest=Hj)
 
-
+        # w -= w.vi vi
+        arnoldi_axpy.set_params([V[: j + 1], w, Hj])
+        arnoldi_axpy.launch()
 
-
-
+        # H[j+1, j] = |w.w|
+        tiled_dot.compute(w, w)
+        normalize_anorldi_vec.set_params([w, V[j + 1], tiled_dot.col(0), array_coeff(H, j + 1, j)])
 
-
-        wp.launch(_gmres_arnoldi_normalize_kernel, dim=w.shape, device=w.device, inputs=[w, vjn, hjnj])
+        normalize_anorldi_vec.launch()
 
-    def do_restart_cycle(
+    def do_restart_cycle():
         if M is not None and is_left_preconditioner:
             M.matvec(r, w, w, alpha=1, beta=0)
             rh = w
         else:
             rh = r
 
-
+        # beta^2 = rh.rh
+        tiled_dot.compute(rh, rh)
 
-
-
-
+        # v[0] = r / beta
+        normalize_anorldi_vec.set_params([rh, V[0], tiled_dot.col(0), beta])
+        normalize_anorldi_vec.launch()
 
         for j in range(restart):
             do_arnoldi_iteration(j)
 
-
-        wp.launch(_gmres_solve_least_squares, dim=1, device=device, inputs=[restart, pivot_tolerance, beta_sq, H, y])
+        least_squares_solve.launch()
 
         # update x
         if M is None or is_left_preconditioner:
@@ -700,21 +980,33 @@ def gmres(
         # update r and residual
         wp.copy(src=b, dest=r)
         A.matvec(x, b, r, alpha=-1.0, beta=1.0)
-
+        tiled_dot.compute(r, r)
 
-    return
+    return _run_capturable_loop(
         do_restart_cycle,
         cycle_size=restart,
         r_norm_sq=r_norm_sq,
         maxiter=maxiter,
-
+        atol_sq=atol_sq,
         callback=callback,
         check_every=check_every,
        use_cuda_graph=use_cuda_graph,
-        device=device,
     )
 
 
+def _repeat_first(arr: wp.array):
+    # returns a view of the first element repeated arr.shape[0] times
+    view = wp.array(
+        ptr=arr.ptr,
+        shape=arr.shape,
+        dtype=arr.dtype,
+        strides=(0, *arr.strides[1:]),
+        device=arr.device,
+    )
+    view._ref = arr
+    return view
+
+
 def _get_dtype_epsilon(dtype):
     if dtype == wp.float64:
         return 1.0e-16
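Note: `_repeat_first` above builds a zero-stride view, so a single physical row can stand in for an n-row operand; this is how the solvers fold the (r, r) and (r, z) dot products into one `TiledDot` launch. The same trick expressed in NumPy terms:

import numpy as np

a = np.arange(6, dtype=np.float32).reshape(2, 3)
rep = np.lib.stride_tricks.as_strided(a, shape=a.shape, strides=(0, a.strides[1]))
assert (rep[1] == a[0]).all()  # stride 0 on axis 0: every row aliases row 0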
@@ -724,7 +1016,7 @@ def _get_dtype_epsilon(dtype):
     return 1.0e-8
 
 
-def
+def _get_tolerances(dtype, tol, atol):
     eps_tol = _get_dtype_epsilon(dtype)
     default_tol = eps_tol ** (3 / 4)
     min_tol = eps_tol ** (9 / 4)
@@ -736,27 +1028,115 @@ def _get_absolute_tolerance(dtype, tol, atol, lhs_norm):
     elif atol is None:
         atol = tol
 
-
+    atol = max(atol, min_tol)
+    return tol, atol
 
 
-
-
-
+@wp.kernel
+def _initialize_tolerance(
+    rtol: Any,
+    atol: Any,
+    r_norm_sq: wp.array(dtype=Any),
+    atol_sq: wp.array(dtype=Any),
+):
+    atol = wp.max(rtol * wp.sqrt(r_norm_sq[0]), atol)
+    atol_sq[0] = atol * atol
 
-
-
+
+def _initialize_absolute_tolerance(
+    b: wp.array,
+    tol: float,
+    atol: float,
+    tiled_dot: TiledDot,
+    atol_sq: wp.array,
+):
+    scalar_type = atol_sq.dtype
 
     # Compute b norm to define absolute tolerance
-
-
+    tiled_dot.compute(b, b)
+    b_norm_sq = tiled_dot.col(0)
+
+    rtol, atol = _get_tolerances(scalar_type, tol, atol)
+    wp.launch(
+        kernel=_initialize_tolerance,
+        dim=1,
+        device=b.device,
+        inputs=[scalar_type(rtol), scalar_type(atol), b_norm_sq, atol_sq],
+    )
 
-    # Residual r = b - Ax
-    r = wp.empty_like(b)
-    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
 
-
+@wp.kernel
+def _update_condition(
+    maxiter: int,
+    cycle_size: int,
+    cur_iter: wp.array(dtype=int),
+    r_norm_sq: wp.array(dtype=Any),
+    atol_sq: wp.array(dtype=Any),
+    condition: wp.array(dtype=int),
+):
+    cur_iter[0] += cycle_size
+    condition[0] = wp.where(r_norm_sq[0] <= atol_sq[0] or cur_iter[0] >= maxiter, 0, 1)
+
+
+def _run_capturable_loop(
+    do_cycle: Callable,
+    r_norm_sq: wp.array,
+    maxiter: int,
+    atol_sq: wp.array,
+    callback: Optional[Callable],
+    check_every: int,
+    use_cuda_graph: bool,
+    cycle_size: int = 1,
+):
+    device = atol_sq.device
+
+    if check_every > 0:
+        atol = math.sqrt(atol_sq.numpy()[0])
+        return _run_solver_loop(
+            do_cycle, cycle_size, r_norm_sq, maxiter, atol, callback, check_every, use_cuda_graph, device
+        )
+
+    cur_iter_and_condition = wp.full((2,), value=-1, dtype=int, device=device)
+    cur_iter = cur_iter_and_condition[0:1]
+    condition = cur_iter_and_condition[1:2]
+
+    update_condition_launch = wp.launch(
+        _update_condition,
+        dim=1,
+        device=device,
+        inputs=[int(maxiter), cycle_size, cur_iter, r_norm_sq, atol_sq, condition],
+        record_cmd=True,
+    )
+
+    if isinstance(callback, wp.Kernel):
+        callback_launch = wp.launch(
+            callback, dim=1, device=device, inputs=[cur_iter, r_norm_sq, atol_sq], record_cmd=True
+        )
+    else:
+        callback_launch = None
+
+    update_condition_launch.launch()
+    if callback_launch is not None:
+        callback_launch.launch()
+
+    def do_cycle_with_condition():
+        do_cycle()
+        update_condition_launch.launch()
+        if callback_launch is not None:
+            callback_launch.launch()
 
-
+    if use_cuda_graph and device.is_cuda:
+        if device.is_capturing:
+            wp.capture_while(condition, do_cycle_with_condition)
+        else:
+            with wp.ScopedCapture() as capture:
+                wp.capture_while(condition, do_cycle_with_condition)
+            wp.capture_launch(capture.graph)
+    else:
+        for _ in range(0, maxiter, cycle_size):
+            do_cycle_with_condition()
+
+    return cur_iter, r_norm_sq, atol_sq
 
 
 def _run_solver_loop(
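Note: `_run_capturable_loop` above leans on Warp's conditional graph nodes. The general shape of that pattern, reduced to its essentials (the loop body here is a placeholder; in the solver it is one iteration cycle plus the `_update_condition` launch):

import warp as wp

condition = wp.ones(1, dtype=wp.int32)  # loop keeps running while condition[0] != 0

def body():
    # one solver cycle; a kernel in here must eventually write 0 to condition[0]
    ...

with wp.ScopedCapture() as capture:
    wp.capture_while(condition, body)
wp.capture_launch(capture.graph)  # the whole loop replays as a single graph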
@@ -771,11 +1151,12 @@ def _run_solver_loop(
     device,
 ):
     atol_sq = atol * atol
+    check_every = max(check_every, cycle_size)
 
     cur_iter = 0
 
     err_sq = r_norm_sq.numpy()[0]
-    err = sqrt(err_sq)
+    err = math.sqrt(err_sq)
     if callback is not None:
         callback(cur_iter, err, atol)
 
@@ -788,14 +1169,12 @@ def _run_solver_loop(
         # Do not do graph capture at first iteration -- modules may not be loaded yet
         if device.is_cuda and use_cuda_graph and cur_iter > 0:
             if graph is None:
-                wp.
-
-
-                finally:
-                    graph = wp.capture_end(device)
+                with wp.ScopedCapture(force_module_load=False) as capture:
+                    do_cycle()
+                graph = capture.graph
             wp.capture_launch(graph)
         else:
-            do_cycle(
+            do_cycle()
 
         cur_iter += cycle_size
 
@@ -809,24 +1188,16 @@ def _run_solver_loop(
             break
 
         if callback is not None:
-            callback(cur_iter, sqrt(err_sq), atol)
+            callback(cur_iter, math.sqrt(err_sq), atol)
 
     err_sq = r_norm_sq.numpy()[0]
-    err = sqrt(err_sq)
+    err = math.sqrt(err_sq)
     if callback is not None:
         callback(cur_iter, err, atol)
 
     return cur_iter, err, atol
 
 
-@wp.func
-def _calc_mv_product(i: wp.int32, A: wp.array2d(dtype=Any), x: wp.array1d(dtype=Any)):
-    sum = A.dtype(0)
-    for j in range(A.shape[1]):
-        sum += A[i, j] * x[j]
-    return sum
-
-
 @wp.kernel
 def _dense_mv_kernel(
     A: wp.array2d(dtype=Any),
@@ -836,25 +1207,24 @@ def _dense_mv_kernel(
     alpha: Any,
     beta: Any,
 ):
-
-    z[i] = z.dtype(beta) * y[i] + z.dtype(alpha) * _calc_mv_product(i, A, x)
+    row, lane = wp.tid()
 
+    zero = type(alpha)(0)
+    s = zero
+    if alpha != zero:
+        for col in range(lane, A.shape[1], wp.block_dim()):
+            s += A[row, col] * x[col]
 
-
-
-
-
-
-    alpha: Any,
-    beta: Any,
-):
-    i = wp.tid()
-    z[i] = beta * y[i] + alpha * (A[i] * x[i])
+    row_tile = wp.tile_sum(wp.tile(s * alpha))
+
+    if beta != zero:
+        row_tile += wp.tile_load(y, shape=1, offset=row) * beta
+
+    wp.tile_store(z, row_tile, offset=row)
 
 
 @wp.kernel
-def
+def _diag_mv_kernel(
     A: wp.array(dtype=Any),
     x: wp.array(dtype=Any),
     y: wp.array(dtype=Any),
@@ -863,7 +1233,13 @@ def _diag_mv_vec_kernel(
     beta: Any,
 ):
     i = wp.tid()
-
+    zero = type(alpha)(0)
+    s = z.dtype(zero)
+    if alpha != zero:
+        s += alpha * (A[i] * x[i])
+    if beta != zero:
+        s += beta * y[i]
+    z[i] = s
 
 
 @wp.func
@@ -910,7 +1286,7 @@ def _extract_inverse_diagonal_dense(
 
 @wp.kernel
 def _cg_kernel_1(
-    tol: Any,
+    tol: wp.array(dtype=Any),
     resid: wp.array(dtype=Any),
     rz_old: wp.array(dtype=Any),
     p_Ap: wp.array(dtype=Any),
@@ -921,7 +1297,7 @@ def _cg_kernel_1(
 ):
     i = wp.tid()
 
-    alpha = wp.where(resid[0] > tol, rz_old[0] / p_Ap[0], rz_old.dtype(0.0))
+    alpha = wp.where(resid[0] > tol[0], rz_old[0] / p_Ap[0], rz_old.dtype(0.0))
 
     x[i] = x[i] + alpha * p[i]
     r[i] = r[i] - alpha * Ap[i]
@@ -929,8 +1305,8 @@ def _cg_kernel_1(
 
 @wp.kernel
 def _cg_kernel_2(
-    tol: Any,
-
+    tol: wp.array(dtype=Any),
+    resid_new: wp.array(dtype=Any),
     rz_old: wp.array(dtype=Any),
     rz_new: wp.array(dtype=Any),
     z: wp.array(dtype=Any),
@@ -939,14 +1315,15 @@ def _cg_kernel_2(
     # p = r + (rz_new / rz_old) * p;
     i = wp.tid()
 
-
+    cond = resid_new[0] > tol[0]
+    beta = wp.where(cond, rz_new[0] / rz_old[0], rz_old.dtype(0.0))
 
     p[i] = z[i] + beta * p[i]
 
 
 @wp.kernel
 def _cr_kernel_1(
-    tol: Any,
+    tol: wp.array(dtype=Any),
     resid: wp.array(dtype=Any),
     zAz_old: wp.array(dtype=Any),
     y_Ap: wp.array(dtype=Any),
@@ -959,7 +1336,7 @@ def _cr_kernel_1(
 ):
     i = wp.tid()
 
-    alpha = wp.where(resid[0] > tol and y_Ap[0] > 0.0, zAz_old[0] / y_Ap[0], zAz_old.dtype(0.0))
+    alpha = wp.where(resid[0] > tol[0] and y_Ap[0] > 0.0, zAz_old[0] / y_Ap[0], zAz_old.dtype(0.0))
 
     x[i] = x[i] + alpha * p[i]
     r[i] = r[i] - alpha * Ap[i]
@@ -968,7 +1345,7 @@ def _cr_kernel_1(
 
 @wp.kernel
 def _cr_kernel_2(
-    tol: Any,
+    tol: wp.array(dtype=Any),
     resid: wp.array(dtype=Any),
     zAz_old: wp.array(dtype=Any),
     zAz_new: wp.array(dtype=Any),
@@ -980,7 +1357,7 @@ def _cr_kernel_2(
     # p = r + (rz_new / rz_old) * p;
     i = wp.tid()
 
-    beta = wp.where(resid[0] > tol and zAz_old[0] > 0.0, zAz_new[0] / zAz_old[0], zAz_old.dtype(0.0))
+    beta = wp.where(resid[0] > tol[0] and zAz_old[0] > 0.0, zAz_new[0] / zAz_old[0], zAz_old.dtype(0.0))
 
     p[i] = z[i] + beta * p[i]
     Ap[i] = Az[i] + beta * Ap[i]
@@ -988,7 +1365,7 @@ def _cr_kernel_2(
 
 @wp.kernel
 def _bicgstab_kernel_1(
-    tol: Any,
+    tol: wp.array(dtype=Any),
     resid: wp.array(dtype=Any),
     rho_old: wp.array(dtype=Any),
     r0v: wp.array(dtype=Any),
@@ -999,7 +1376,7 @@ def _bicgstab_kernel_1(
 ):
     i = wp.tid()
 
-    alpha = wp.where(resid[0] > tol, rho_old[0] / r0v[0], rho_old.dtype(0.0))
+    alpha = wp.where(resid[0] > tol[0], rho_old[0] / r0v[0], rho_old.dtype(0.0))
 
     x[i] += alpha * y[i]
     r[i] -= alpha * v[i]
@@ -1007,7 +1384,7 @@ def _bicgstab_kernel_1(
 
 @wp.kernel
 def _bicgstab_kernel_2(
-    tol: Any,
+    tol: wp.array(dtype=Any),
     resid: wp.array(dtype=Any),
     st: wp.array(dtype=Any),
     tt: wp.array(dtype=Any),
@@ -1018,7 +1395,7 @@ def _bicgstab_kernel_2(
 ):
     i = wp.tid()
 
-    omega = wp.where(resid[0] > tol, st[0] / tt[0], st.dtype(0.0))
+    omega = wp.where(resid[0] > tol[0], st[0] / tt[0], st.dtype(0.0))
 
     x[i] += omega * z[i]
     r[i] -= omega * t[i]
@@ -1026,7 +1403,7 @@ def _bicgstab_kernel_2(
 
 @wp.kernel
 def _bicgstab_kernel_3(
-    tol: Any,
+    tol: wp.array(dtype=Any),
     resid: wp.array(dtype=Any),
     rho_new: wp.array(dtype=Any),
     r0v: wp.array(dtype=Any),
@@ -1038,32 +1415,21 @@ def _bicgstab_kernel_3(
 ):
     i = wp.tid()
 
-    beta = wp.where(resid[0] > tol, rho_new[0] * tt[0] / (r0v[0] * st[0]), st.dtype(0.0))
-    beta_omega = wp.where(resid[0] > tol, rho_new[0] / r0v[0], st.dtype(0.0))
+    beta = wp.where(resid[0] > tol[0], rho_new[0] * tt[0] / (r0v[0] * st[0]), st.dtype(0.0))
+    beta_omega = wp.where(resid[0] > tol[0], rho_new[0] / r0v[0], st.dtype(0.0))
 
     p[i] = r[i] + beta * p[i] - beta_omega * v[i]
 
 
-@wp.kernel
-def _gmres_normalize_lower_diagonal(H: wp.array2d(dtype=Any)):
-    # normalize lower-diagonal values of Hessenberg matrix
-    i = wp.tid()
-    H[i + 1, i] = wp.sqrt(H[i + 1, i])
-
-
 @wp.kernel
 def _gmres_solve_least_squares(
-    k: int, pivot_tolerance: float,
+    k: int, pivot_tolerance: float, beta: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
 ):
     # Solve H y = (beta, 0, ..., 0)
     # H Hessenberg matrix of shape (k+1, k)
-
-    # Keeping H in global mem; warp kernels are launched with fixed block size,
     # so would not fit in registers
 
-
-
-    rhs = wp.sqrt(beta_sq[0])
+    rhs = beta[0]
 
     # Apply 2x2 rotations to H so as to remove lower diagonal,
     # and apply similar rotations to right-hand-side
@@ -1110,14 +1476,108 @@ def _gmres_solve_least_squares(
         y[i] = yi / Hi[i]
 
 
+@functools.lru_cache(maxsize=None)
+def make_gmres_solve_least_squares_kernel_tiled(K: int):
+    @wp.kernel(module="unique")
+    def gmres_solve_least_squares_tiled(
+        k: int, pivot_tolerance: float, beta: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
+    ):
+        # Assumes tiles of size K, and K at least as large as highest number of columns
+        # Limits the max restart cycle length to the max block size of 1024, but using
+        # larger restarts would be very inefficient anyway (default is ~30)
+
+        # Solve H y = (beta, 0, ..., 0)
+        # H Hessenberg matrix of shape (k+1, k)
+
+        i, lane = wp.tid()
+
+        rhs = beta[0]
+
+        zero = H.dtype(0.0)
+        one = H.dtype(1.0)
+        yi = zero
+
+        Ha = wp.tile_load(H[0], shape=(K))
+
+        # Apply 2x2 rotations to H so as to remove lower diagonal,
+        # and apply similar rotations to right-hand-side
+        max_k = int(k)
+        for i in range(k):
+            # Ha = H[i]
+            # Hb = H[i + 1]
+            Hb = wp.tile_load(H[i + 1], shape=(K))
+
+            # Givens rotation [[c s], [-s c]]
+            a = Ha[i]
+            b = Hb[i]
+            abn_sq = a * a + b * b
+
+            if abn_sq < type(abn_sq)(pivot_tolerance):
+                # Arnoldi iteration finished early
+                max_k = i
+                break
+
+            abn = wp.sqrt(abn_sq)
+            c = a / abn
+            s = b / abn
+
+            # Rotate H
+            a = wp.untile(Ha)
+            b = wp.untile(Hb)
+            a_rot = c * a + s * b
+            b_rot = c * b - s * a
+
+            # Rotate rhs
+            if lane == i:
+                yi = c * rhs
+                rhs = -s * rhs
+
+            wp.tile_store(H[i], wp.tile(a_rot))
+            Ha[lane] = b_rot
+
+        y_tile = wp.tile(yi)
+
+        # Triangular back-solve for y
+        for ii in range(max_k, 0, -1):
+            i = ii - 1
+
+            Hi = wp.tile_load(H[i], shape=(K))
+
+            il = lane + i
+            if lane == 0:
+                yl = y_tile[i]
+            elif il < max_k:
+                yl = -y_tile[il] * Hi[il]
+            else:
+                yl = zero
+
+            yit = wp.tile_sum(wp.tile(yl)) * (one / Hi[i])
+            yit[0]  # no-op, movs yit to shared
+            wp.tile_assign(y_tile, yit, offset=(i,))
+
+        wp.tile_store(y, y_tile)
+
+    return gmres_solve_least_squares_tiled
+
+
 @wp.kernel
 def _gmres_arnoldi_axpy_kernel(
-
-
-
+    V: wp.array2d(dtype=Any),
+    w: wp.array(dtype=Any),
+    Vw: wp.array2d(dtype=Any),
 ):
-    tid = wp.tid()
-
+    tid, lane = wp.tid()
+
+    s = w.dtype(Vw.dtype(0))
+
+    tile_size = wp.block_dim()
+    for k in range(lane, Vw.shape[0], tile_size):
+        s += Vw[k, 0] * V[k, tid]
+
+    wi = wp.tile_load(w, shape=1, offset=tid)
+    wi -= wp.tile_sum(wp.tile(s, preserve_type=True))
+
+    wp.tile_store(w, wi, offset=tid)
 
 
 @wp.kernel
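Note: both least-squares kernels zero the Hessenberg sub-diagonal with Givens rotations; with a = H[i, i] and b = H[i+1, i], the rotation computed above (c = a/abn, s = b/abn) satisfies

c = \frac{a}{\sqrt{a^2 + b^2}}, \quad s = \frac{b}{\sqrt{a^2 + b^2}}, \quad
\begin{pmatrix} c & s \\ -s & c \end{pmatrix}
\begin{pmatrix} a \\ b \end{pmatrix}
=
\begin{pmatrix} \sqrt{a^2 + b^2} \\ 0 \end{pmatrix}

which is why, after the elimination loop, H is upper triangular and a simple back-substitution recovers y.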
@@ -1125,9 +1585,14 @@ def _gmres_arnoldi_normalize_kernel(
     x: wp.array(dtype=Any),
     y: wp.array(dtype=Any),
     alpha: wp.array(dtype=Any),
+    alpha_copy: wp.array(dtype=Any),
 ):
     tid = wp.tid()
-
+    norm = wp.sqrt(alpha[0])
+    y[tid] = wp.where(alpha[0] == alpha.dtype(0.0), x[tid], x[tid] / norm)
+
+    if tid == 0:
+        alpha_copy[0] = norm
 
 
 @wp.kernel