warp-lang 1.8.0-py3-none-manylinux_2_34_aarch64.whl → 1.9.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (153)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/optim/linear.py CHANGED
@@ -13,12 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from math import sqrt
+import functools
+import math
 from typing import Any, Callable, Optional, Tuple, Union
 
 import warp as wp
 import warp.sparse as sparse
-from warp.utils import array_inner
+from warp.types import type_length, type_scalar_type
+
+__all__ = ["LinearOperator", "aslinearoperator", "bicgstab", "cg", "cr", "gmres", "preconditioner"]
 
 # No need to auto-generate adjoint code for linear solvers
 wp.set_module_options({"enable_backward": False})
@@ -103,19 +106,33 @@ def aslinearoperator(A: _Matrix) -> LinearOperator:
         sparse.bsr_mv(A, x, z, alpha, beta)
 
     def dense_mv(x, y, z, alpha, beta):
-        wp.launch(_dense_mv_kernel, dim=A.shape[1], device=A.device, inputs=[A, x, y, z, alpha, beta])
+        alpha = A.dtype(alpha)
+        beta = A.dtype(beta)
+        if A.device.is_cuda:
+            tile_size = 1 << min(10, max(5, math.ceil(math.log2(A.shape[1]))))
+        else:
+            tile_size = 1
+        wp.launch(
+            _dense_mv_kernel,
+            dim=(A.shape[0], tile_size),
+            block_dim=tile_size,
+            device=A.device,
+            inputs=[A, x, y, z, alpha, beta],
+        )
 
-    def diag_mv(x, y, z, alpha, beta):
-        scalar_type = wp.types.type_scalar_type(A.dtype)
+    def diag_mv_impl(A, x, y, z, alpha, beta):
+        scalar_type = type_scalar_type(A.dtype)
         alpha = scalar_type(alpha)
         beta = scalar_type(beta)
         wp.launch(_diag_mv_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])
 
+    def diag_mv(x, y, z, alpha, beta):
+        return diag_mv_impl(A, x, y, z, alpha, beta)
+
     def diag_mv_vec(x, y, z, alpha, beta):
-        scalar_type = wp.types.type_scalar_type(A.dtype)
-        alpha = scalar_type(alpha)
-        beta = scalar_type(beta)
-        wp.launch(_diag_mv_vec_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])
+        return diag_mv_impl(
+            _as_scalar_array(A), _as_scalar_array(x), _as_scalar_array(y), _as_scalar_array(z), alpha, beta
+        )
 
     if isinstance(A, wp.array):
         if A.ndim == 2:
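The rewritten `dense_mv` picks its CUDA block size from the column count: the next power of two covering the columns, clamped to the [32, 1024] range. A standalone sketch of that arithmetic (plain Python, no GPU required):

    import math

    # Block size used by the tiled dense mat-vec: next power of two covering
    # the column count, clamped to [2**5, 2**10] = [32, 1024].
    for num_cols in (10, 100, 5000):
        tile_size = 1 << min(10, max(5, math.ceil(math.log2(num_cols))))
        print(num_cols, "->", tile_size)  # 10 -> 32, 100 -> 128, 5000 -> 1024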
@@ -183,6 +200,147 @@ def preconditioner(A: _Matrix, ptype: str = "diag") -> LinearOperator:
     raise ValueError(f"Unsupported preconditioner type '{ptype}'")
 
 
+def _as_scalar_array(x: wp.array):
+    scalar_type = type_scalar_type(x.dtype)
+    if scalar_type == x.dtype:
+        return x
+
+    dlen = type_length(x.dtype)
+    arr = wp.array(
+        ptr=x.ptr,
+        shape=(*x.shape[:-1], x.shape[-1] * dlen),
+        strides=(*x.strides[:-1], x.strides[-1] // dlen),
+        dtype=scalar_type,
+        device=x.device,
+        grad=None if x.grad is None else _as_scalar_array(x.grad),
+    )
+    arr._ref = x
+    return arr
+
+
+class TiledDot:
+    """
+    Computes the dot product of two arrays in a way that is compatible with CUDA sub-graphs.
+    """
+
+    def __init__(self, max_length: int, scalar_type: type, tile_size=512, device=None, max_column_count: int = 1):
+        self.tile_size = tile_size
+        self.device = device
+        self.max_column_count = max_column_count
+
+        num_blocks = (max_length + self.tile_size - 1) // self.tile_size
+        scratch = wp.empty(
+            shape=(2, max_column_count, num_blocks),
+            dtype=scalar_type,
+            device=self.device,
+        )
+        self.partial_sums_a = scratch[0]
+        self.partial_sums_b = scratch[1]
+
+        self.dot_kernel, self.sum_kernel = _create_tiled_dot_kernels(self.tile_size)
+
+        rounds = 0
+        length = num_blocks
+        while length > 1:
+            length = (length + self.tile_size - 1) // self.tile_size
+            rounds += 1
+
+        self.rounds = rounds
+
+        self._output = self.partial_sums_a if rounds % 2 == 0 else self.partial_sums_b
+
+        self.dot_launch: wp.Launch = wp.launch(
+            self.dot_kernel,
+            dim=(max_column_count, num_blocks, self.tile_size),
+            inputs=(self.partial_sums_a, self.partial_sums_b),
+            outputs=(self.partial_sums_a,),
+            block_dim=self.tile_size,
+            record_cmd=True,
+        )
+        self.sum_launch: wp.Launch = wp.launch(
+            self.sum_kernel,
+            dim=(max_column_count, num_blocks, self.tile_size),
+            inputs=(self.partial_sums_a,),
+            outputs=(self.partial_sums_b,),
+            block_dim=self.tile_size,
+            record_cmd=True,
+        )
+
+    # Result contains a single value, the sum of the array (will get updated by this function)
+    def compute(self, a: wp.array, b: wp.array, col_offset: int = 0):
+        a = _as_scalar_array(a)
+        b = _as_scalar_array(b)
+        if a.ndim == 1:
+            a = a.reshape((1, -1))
+        if b.ndim == 1:
+            b = b.reshape((1, -1))
+
+        column_count = a.shape[0]
+        num_blocks = (a.shape[1] + self.tile_size - 1) // self.tile_size
+
+        data_out = self.partial_sums_a[col_offset : col_offset + column_count]
+        data_in = self.partial_sums_b[col_offset : col_offset + column_count]
+
+        self.dot_launch.set_param_at_index(0, a)
+        self.dot_launch.set_param_at_index(1, b)
+        self.dot_launch.set_param_at_index(2, data_out)
+        self.dot_launch.set_dim((column_count, num_blocks, self.tile_size))
+        self.dot_launch.launch()
+
+        for _r in range(self.rounds):
+            array_length = num_blocks
+            num_blocks = (array_length + self.tile_size - 1) // self.tile_size
+            data_in, data_out = data_out, data_in
+
+            self.sum_launch.set_param_at_index(0, data_in)
+            self.sum_launch.set_param_at_index(1, data_out)
+            self.sum_launch.set_dim((column_count, num_blocks, self.tile_size))
+            self.sum_launch.launch()
+
+        return data_out
+
+    def col(self, col: int = 0):
+        return self._output[col][:1]
+
+    def cols(self, count, start: int = 0):
+        return self._output[start : start + count, :1]
+
+
+@functools.lru_cache(maxsize=None)
+def _create_tiled_dot_kernels(tile_size):
+    @wp.kernel
+    def block_dot_kernel(
+        a: wp.array2d(dtype=Any),
+        b: wp.array2d(dtype=Any),
+        partial_sums: wp.array2d(dtype=Any),
+    ):
+        column, block_id, tid_block = wp.tid()
+
+        start = block_id * tile_size
+
+        a_block = wp.tile_load(a[column], shape=tile_size, offset=start)
+        b_block = wp.tile_load(b[column], shape=tile_size, offset=start)
+        t = wp.tile_map(wp.mul, a_block, b_block)
+
+        tile_sum = wp.tile_sum(t)
+        wp.tile_store(partial_sums[column], tile_sum, offset=block_id)
+
+    @wp.kernel
+    def block_sum_kernel(
+        data: wp.array2d(dtype=Any),
+        partial_sums: wp.array2d(dtype=Any),
+    ):
+        column, block_id, tid_block = wp.tid()
+        start = block_id * tile_size
+
+        t = wp.tile_load(data[column], shape=tile_size, offset=start)
+
+        tile_sum = wp.tile_sum(t)
+        wp.tile_store(partial_sums[column], tile_sum, offset=block_id)
+
+    return block_dot_kernel, block_sum_kernel
+
+
 def cg(
     A: _Matrix,
     b: wp.array,
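`TiledDot` replaces the old `array_inner` reductions: per-block partial sums land in fixed scratch buffers and are folded over a fixed number of kernel rounds, so no allocation or host synchronization happens per call — the property that makes the solvers capturable into CUDA sub-graphs. A minimal usage sketch, assuming a CUDA device (note the class is module-internal and not part of `__all__`):

    import warp as wp
    from warp.optim.linear import TiledDot  # internal helper, not exported via __all__

    wp.init()

    n = 1 << 20
    a = wp.full(n, 2.0, dtype=wp.float32, device="cuda:0")
    b = wp.full(n, 3.0, dtype=wp.float32, device="cuda:0")

    # Scratch buffers are sized once for max_length, so repeated compute()
    # calls allocate nothing and stay CUDA-graph friendly.
    dot = TiledDot(max_length=n, scalar_type=wp.float32, device="cuda:0")
    dot.compute(a, b)

    # col(0) is a one-element device array holding <a, b>; numpy() synchronizes,
    # so only read it back outside of any graph capture.
    print(dot.col(0).numpy()[0])  # expected: 6.0 * n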
@@ -194,7 +352,7 @@ def cg(
     callback: Optional[Callable] = None,
     check_every=10,
     use_cuda_graph=True,
-) -> Tuple[int, float, float]:
+) -> Union[Tuple[int, float, float], Tuple[wp.array, wp.array, wp.array]]:
     """Computes an approximate solution to a symmetric, positive-definite linear system
     using the Conjugate Gradient algorithm.
 
@@ -205,94 +363,107 @@
         tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
         atol: absolute tolerance for the residual
         maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
-            Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
         M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
-        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
+        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
+            If `check_every` is 0, the callback should be a Warp kernel.
         check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
+            Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
+            If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
+            to the maximum number of iterations.
         use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
-           The linear operator and preconditioner must only perform graph-friendly operations.
+            The linear operator and preconditioner must only perform graph-friendly operations.
 
     Returns:
-        Tuple (final iteration number, residual norm, absolute tolerance)
+        If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
+            - final_iteration: The number of iterations performed before convergence or reaching maxiter
+            - residual_norm: The final residual norm ||b - Ax||
+            - absolute_tolerance: The absolute tolerance used for convergence checking
+
+        If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
+            - final_iteration_array: Device array containing the number of iterations performed
+            - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
+            - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
 
     If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
     """
-
     A = aslinearoperator(A)
     M = aslinearoperator(M)
 
     if maxiter == 0:
         maxiter = A.shape[0]
 
-    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)
-
     device = A.device
-    scalar_dtype = wp.types.type_scalar_type(A.dtype)
+    scalar_type = A.scalar_type
 
-    # Notations below follow pseudo-code from https://en.wikipedia.org/wiki/Conjugate_gradient_method
+    # Temp storage
+    r_and_z = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
+    p_and_Ap = wp.empty_like(r_and_z)
+    residuals = wp.empty(2, dtype=scalar_type, device=device)
 
-    # z = M r
-    if M is not None:
-        z = wp.zeros_like(b)
-        M.matvec(r, z, z, alpha=1.0, beta=0.0)
+    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=2)
 
-    # rz = r' z;
-    rz_new = wp.empty(n=1, dtype=scalar_dtype, device=device)
-    array_inner(r, z, out=rz_new)
+    # named views
+
+    # (r, r) -- so we can compute r.z and r.r at once
+    r_repeated = _repeat_first(r_and_z)
+    if M is None:
+        # without preconditioner r == z
+        r_and_z = r_repeated
+        rz_new = tiled_dot.col(0)
     else:
-        z = r
+        rz_new = tiled_dot.col(1)
+
+    r, z = r_and_z[0], r_and_z[1]
+    r_norm_sq = tiled_dot.col(0)
+
+    p, Ap = p_and_Ap[0], p_and_Ap[1]
+    rz_old, atol_sq = residuals[0:1], residuals[1:2]
+
+    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
+    Ap.zero_()
+    z.zero_()
+
+    # Initialize tolerance from right-hand-side norm
+    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
+    # Initialize residual
+    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
+
+    def update_rr_rz():
+        # z = M r
+        if M is None:
+            tiled_dot.compute(r, r)
+        else:
+            M.matvec(r, z, z, alpha=1.0, beta=0.0)
+            tiled_dot.compute(r_repeated, r_and_z)
 
-    rz_old = wp.empty(n=1, dtype=scalar_dtype, device=device)
-    p_Ap = wp.empty(n=1, dtype=scalar_dtype, device=device)
-    Ap = wp.zeros_like(b)
+    update_rr_rz()
+    p.assign(z)
 
-    p = wp.clone(z)
+    def do_iteration():
+        rz_old.assign(rz_new)
 
-    def do_iteration(atol_sq, rr_old, rr_new, rz_old, rz_new):
         # Ap = A * p;
         A.matvec(p, Ap, Ap, alpha=1, beta=0)
-
-        array_inner(p, Ap, out=p_Ap)
+        tiled_dot.compute(p, Ap, col_offset=1)
+        p_Ap = tiled_dot.col(1)
 
         wp.launch(
             kernel=_cg_kernel_1,
             dim=x.shape[0],
             device=device,
-            inputs=[atol_sq, rr_old, rz_old, p_Ap, x, r, p, Ap],
+            inputs=[atol_sq, r_norm_sq, rz_old, p_Ap, x, r, p, Ap],
         )
-        array_inner(r, r, out=rr_new)
-
-        # z = M r
-        if M is not None:
-            M.matvec(r, z, z, alpha=1.0, beta=0.0)
-        # rz = r' z;
-        array_inner(r, z, out=rz_new)
-
-        wp.launch(kernel=_cg_kernel_2, dim=z.shape[0], device=device, inputs=[atol_sq, rr_new, rz_old, rz_new, z, p])
 
-    # We do iterations by pairs, switching old and new residual norm buffers for each odd-even couple
-    # In the non-preconditioned case we reuse the error norm buffer for the new <r,z> computation
+        update_rr_rz()
 
-    def do_odd_even_cycle(atol_sq: float):
-        # A pair of iterations, so that we're swapping the residual buffers twice
-        if M is None:
-            do_iteration(atol_sq, r_norm_sq, rz_old, r_norm_sq, rz_old)
-            do_iteration(atol_sq, rz_old, r_norm_sq, rz_old, r_norm_sq)
-        else:
-            do_iteration(atol_sq, r_norm_sq, r_norm_sq, rz_new, rz_old)
-            do_iteration(atol_sq, r_norm_sq, r_norm_sq, rz_old, rz_new)
+        wp.launch(
+            kernel=_cg_kernel_2,
+            dim=z.shape[0],
+            device=device,
+            inputs=[atol_sq, r_norm_sq, rz_old, rz_new, z, p],
+        )
 
-    return _run_solver_loop(
-        do_odd_even_cycle,
-        cycle_size=2,
-        r_norm_sq=r_norm_sq,
-        maxiter=maxiter,
-        atol=atol,
-        callback=callback,
-        check_every=check_every,
-        use_cuda_graph=use_cuda_graph,
-        device=device,
-    )
+    return _run_capturable_loop(do_iteration, r_norm_sq, maxiter, atol_sq, callback, check_every, use_cuda_graph)
 
 
 def cr(
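With `check_every=0`, the rewritten `cg` never synchronizes with the host: convergence is tracked on-device and the three return values are device arrays rather than Python scalars. A sketch of the new calling convention (hypothetical sizes; any SPD system works):

    import numpy as np
    import warp as wp
    from warp.optim.linear import cg

    wp.init()
    device = "cuda:0"

    n = 256
    rng = np.random.default_rng(0)
    L = np.tril(rng.uniform(0.1, 1.0, (n, n)))
    A_np = (L @ L.T + n * np.eye(n)).astype(np.float32)  # SPD by construction

    A = wp.array(A_np, dtype=wp.float32, device=device)
    b = wp.ones(n, dtype=wp.float32, device=device)
    x = wp.zeros(n, dtype=wp.float32, device=device)

    # check_every=0: no host-side residual checks, so the whole solve is
    # graph-capturable; results come back as one-element device arrays.
    final_iter, r_norm_sq, atol_sq = cg(A, b, x, tol=1.0e-6, check_every=0)

    print("iterations:", final_iter.numpy()[0])
    print("residual norm:", float(np.sqrt(r_norm_sq.numpy()[0])))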
@@ -319,13 +490,25 @@
         maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
             Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
         M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
-        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
+        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
+            If `check_every` is 0, the callback should be a Warp kernel.
         check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
+            Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
+            If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
+            to the maximum number of iterations.
         use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
             The linear operator and preconditioner must only perform graph-friendly operations.
 
     Returns:
-        Tuple (final iteration number, residual norm, absolute tolerance)
+        If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
+            - final_iteration: The number of iterations performed before convergence or reaching maxiter
+            - residual_norm: The final residual norm ||b - Ax||
+            - absolute_tolerance: The absolute tolerance used for convergence checking
+
+        If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
+            - final_iteration_array: Device array containing the number of iterations performed
+            - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
+            - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
 
     If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
     """
@@ -336,42 +519,65 @@
     if maxiter == 0:
         maxiter = A.shape[0]
 
-    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)
-
     device = A.device
-    scalar_dtype = wp.types.type_scalar_type(A.dtype)
+    scalar_type = wp.types.type_scalar_type(A.dtype)
 
     # Notations below follow roughly pseudo-code from https://en.wikipedia.org/wiki/Conjugate_residual_method
     # with z := M^-1 r and y := M^-1 Ap
 
-    # z = M r
+    # Temp storage
+    r_and_z = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
+    r_and_Az = wp.empty_like(r_and_z)
+    y_and_Ap = wp.empty_like(r_and_z)
+    p = wp.empty_like(b)
+    residuals = wp.empty(2, dtype=scalar_type, device=device)
+
+    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=2)
+
     if M is None:
-        z = r
-    else:
-        z = wp.zeros_like(r)
-        M.matvec(r, z, z, alpha=1.0, beta=0.0)
+        r_and_z = _repeat_first(r_and_z)
+        y_and_Ap = _repeat_first(y_and_Ap)
 
-    Az = wp.zeros_like(b)
-    A.matvec(z, Az, Az, alpha=1, beta=0)
+    # named views
+    r, z = r_and_z[0], r_and_z[1]
+    r_copy, Az = r_and_Az[0], r_and_Az[1]
 
-    p = wp.clone(z)
-    Ap = wp.clone(Az)
+    y, Ap = y_and_Ap[0], y_and_Ap[1]
 
-    if M is None:
-        y = Ap
-    else:
-        y = wp.zeros_like(Ap)
+    r_norm_sq = tiled_dot.col(0)
+    zAz_new = tiled_dot.col(1)
+    zAz_old, atol_sq = residuals[0:1], residuals[1:2]
+
+    # Initialize tolerance from right-hand-side norm
+    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
+    # Initialize residual
+    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
+
+    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
+    y_and_Ap.zero_()
 
-    zAz_old = wp.empty(n=1, dtype=scalar_dtype, device=device)
-    zAz_new = wp.empty(n=1, dtype=scalar_dtype, device=device)
-    y_Ap = wp.empty(n=1, dtype=scalar_dtype, device=device)
+    # z = M r
+    if M is not None:
+        z.zero_()
+        M.matvec(r, z, z, alpha=1.0, beta=0.0)
+
+    def update_rr_zAz():
+        A.matvec(z, Az, Az, alpha=1, beta=0)
+        r_copy.assign(r)
+        tiled_dot.compute(r_and_z, r_and_Az)
+
+    update_rr_zAz()
 
-    array_inner(z, Az, out=zAz_new)
+    p.assign(z)
+    Ap.assign(Az)
+
+    def do_iteration():
+        zAz_old.assign(zAz_new)
 
-    def do_iteration(atol_sq, rr, zAz_old, zAz_new):
         if M is not None:
             M.matvec(Ap, y, y, alpha=1.0, beta=0.0)
-        array_inner(Ap, y, out=y_Ap)
+        tiled_dot.compute(Ap, y, col_offset=1)
+        y_Ap = tiled_dot.col(1)
 
         if M is None:
             # In non-preconditioned case, first kernel is same as CG
@@ -379,7 +585,7 @@ def cr(
                 kernel=_cg_kernel_1,
                 dim=x.shape[0],
                 device=device,
-                inputs=[atol_sq, rr, zAz_old, y_Ap, x, r, p, Ap],
+                inputs=[atol_sq, r_norm_sq, zAz_old, y_Ap, x, r, p, Ap],
             )
         else:
             # In preconditioned case, we have one more vector to update
@@ -387,34 +593,26 @@
                 kernel=_cr_kernel_1,
                 dim=x.shape[0],
                 device=device,
-                inputs=[atol_sq, rr, zAz_old, y_Ap, x, r, z, p, Ap, y],
+                inputs=[atol_sq, r_norm_sq, zAz_old, y_Ap, x, r, z, p, Ap, y],
             )
 
-        array_inner(r, r, out=rr)
-
-        A.matvec(z, Az, Az, alpha=1, beta=0)
-        array_inner(z, Az, out=zAz_new)
-
-        # beta = rz_new / rz_old
+        update_rr_zAz()
         wp.launch(
-            kernel=_cr_kernel_2, dim=z.shape[0], device=device, inputs=[atol_sq, rr, zAz_old, zAz_new, z, p, Az, Ap]
+            kernel=_cr_kernel_2,
+            dim=z.shape[0],
+            device=device,
+            inputs=[atol_sq, r_norm_sq, zAz_old, zAz_new, z, p, Az, Ap],
         )
 
-    # We do iterations by pairs, switching old and new residual norm buffers for each odd-even couple
-    def do_odd_even_cycle(atol_sq: float):
-        do_iteration(atol_sq, r_norm_sq, zAz_new, zAz_old)
-        do_iteration(atol_sq, r_norm_sq, zAz_old, zAz_new)
-
-    return _run_solver_loop(
-        do_odd_even_cycle,
-        cycle_size=2,
+    return _run_capturable_loop(
+        do_iteration,
+        cycle_size=1,
         r_norm_sq=r_norm_sq,
         maxiter=maxiter,
-        atol=atol,
+        atol_sq=atol_sq,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
-        device=device,
    )
 
 
@@ -441,14 +639,26 @@
         atol: absolute tolerance for the residual
         maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
         M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
-        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
+        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
+            If `check_every` is 0, the callback should be a Warp kernel.
         check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
+            Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
+            If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
+            to the maximum number of iterations.
         use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
-           The linear operator and preconditioner must only perform graph-friendly operations.
+            The linear operator and preconditioner must only perform graph-friendly operations.
         is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.
 
     Returns:
-        Tuple (final iteration number, residual norm, absolute tolerance)
+        If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
+            - final_iteration: The number of iterations performed before convergence or reaching maxiter
+            - residual_norm: The final residual norm ||b - Ax||
+            - absolute_tolerance: The absolute tolerance used for convergence checking
+
+        If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
+            - final_iteration_array: Device array containing the number of iterations performed
+            - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
+            - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
 
     If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
     """
@@ -458,23 +668,19 @@
     if maxiter == 0:
         maxiter = A.shape[0]
 
-    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)
-
     device = A.device
-    scalar_dtype = wp.types.type_scalar_type(A.dtype)
+    scalar_type = wp.types.type_scalar_type(A.dtype)
 
     # Notations below follow pseudo-code from biconjugate https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method
 
-    rho = wp.clone(r_norm_sq, pinned=False)
-    r0v = wp.empty(n=1, dtype=scalar_dtype, device=device)
-    st = wp.empty(n=1, dtype=scalar_dtype, device=device)
-    tt = wp.empty(n=1, dtype=scalar_dtype, device=device)
+    # Temp storage
+    r_and_r0 = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
+    p = wp.empty_like(b)
+    v = wp.empty_like(b)
+    t = wp.empty_like(b)
 
-    # work arrays
-    r0 = wp.clone(r)
-    v = wp.zeros_like(r)
-    t = wp.zeros_like(r)
-    p = wp.clone(r)
+    r, r0 = r_and_r0[0], r_and_r0[1]
+    r_repeated = _repeat_first(r_and_r0)
 
     if M is not None:
         y = wp.zeros_like(p)
@@ -486,7 +692,27 @@
         z = r
         Mt = t
 
-    def do_iteration(atol_sq: float):
+    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=5)
+    r_norm_sq = tiled_dot.col(0)
+    rho = tiled_dot.col(1)
+
+    atol_sq = wp.empty(1, dtype=scalar_type, device=device)
+
+    # Initialize tolerance from right-hand-side norm
+    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
+    # Initialize residual
+    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
+    tiled_dot.compute(r, r, col_offset=0)
+
+    p.assign(r)
+    r0.assign(r)
+    rho.assign(r_norm_sq)
+
+    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
+    v.zero_()
+    t.zero_()
+
+    def do_iteration():
         # y = M p
         if M is not None:
             M.matvec(p, y, y, alpha=1.0, beta=0.0)
@@ -495,7 +721,8 @@
         A.matvec(y, v, v, alpha=1, beta=0)
 
         # alpha = rho / <r0 . v>
-        array_inner(r0, v, out=r0v)
+        tiled_dot.compute(r0, v, col_offset=2)
+        r0v = tiled_dot.col(2)
 
         # x += alpha y
         # r -= alpha v
@@ -505,7 +732,7 @@
             device=device,
             inputs=[atol_sq, r_norm_sq, rho, r0v, x, r, y, v],
         )
-        array_inner(r, r, out=r_norm_sq)
+        tiled_dot.compute(r, r, col_offset=0)
 
         # z = M r
         if M is not None:
@@ -514,17 +741,18 @@
         # t = A z
         A.matvec(z, t, t, alpha=1, beta=0)
 
-        if is_left_preconditioner:
+        if M is not None and is_left_preconditioner:
             # Mt = M t
-            if M is not None:
-                M.matvec(t, Mt, Mt, alpha=1.0, beta=0.0)
+            M.matvec(t, Mt, Mt, alpha=1.0, beta=0.0)
 
             # omega = <Mt, Ms> / <Mt, Mt>
-            array_inner(z, Mt, out=st)
-            array_inner(Mt, Mt, out=tt)
+            tiled_dot.compute(z, Mt, col_offset=3)
+            tiled_dot.compute(Mt, Mt, col_offset=4)
         else:
-            array_inner(r, t, out=st)
-            array_inner(t, t, out=tt)
+            tiled_dot.compute(r, t, col_offset=3)
+            tiled_dot.compute(t, t, col_offset=4)
+        st = tiled_dot.col(3)
+        tt = tiled_dot.col(4)
 
         # x += omega z
         # r -= omega t
@@ -534,10 +762,9 @@
             device=device,
             inputs=[atol_sq, r_norm_sq, st, tt, z, t, x, r],
         )
-        array_inner(r, r, out=r_norm_sq)
 
-        # rho = <r0, r>
-        array_inner(r0, r, out=rho)
+        # r = <r,r>, rho = <r0, r>
+        tiled_dot.compute(r_and_r0, r_repeated, col_offset=0)
 
         # beta = (rho / rho_old) * alpha / omega = (rho / r0v) / omega
         # p = r + beta (p - omega v)
@@ -548,16 +775,14 @@
             inputs=[atol_sq, r_norm_sq, rho, r0v, st, tt, p, r, v],
         )
 
-    return _run_solver_loop(
+    return _run_capturable_loop(
         do_iteration,
-        cycle_size=1,
         r_norm_sq=r_norm_sq,
        maxiter=maxiter,
-        atol=atol,
+        atol_sq=atol_sq,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
-        device=device,
    )
 
 
@@ -587,14 +812,26 @@
         maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
             Note that the current implementation always perform `restart` iterations at a time, and as a result may exceed the specified maximum number of iterations by ``restart-1``.
         M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
-        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
+        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
+            If `check_every` is 0, the callback should be a Warp kernel.
         check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibility terminate the algorithm.
+            Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
+            If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
+            to the maximum number of iterations.
         use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
             The linear operator and preconditioner must only perform graph-friendly operations.
         is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.
 
     Returns:
-        Tuple (final iteration number, residual norm, absolute tolerance)
+        If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
+            - final_iteration: The number of iterations performed before convergence or reaching maxiter
+            - residual_norm: The final residual norm ||b - Ax||
+            - absolute_tolerance: The absolute tolerance used for convergence checking
+
+        If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
+            - final_iteration_array: Device array containing the number of iterations performed
+            - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
+            - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
 
     If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
     """
@@ -606,89 +843,132 @@
         maxiter = A.shape[0]
 
     restart = min(restart, maxiter)
-    check_every = max(restart, check_every)
 
-    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)
+    if check_every > 0:
+        check_every = max(restart, check_every)
 
     device = A.device
     scalar_dtype = wp.types.type_scalar_type(A.dtype)
 
     pivot_tolerance = _get_dtype_epsilon(scalar_dtype) ** 2
 
-    beta_sq = wp.empty_like(r_norm_sq, pinned=False)
-    H = wp.empty(shape=(restart + 1, restart), dtype=scalar_dtype, device=device)
+    r = wp.empty_like(b)
+    w = wp.empty_like(r)
 
+    H = wp.empty(shape=(restart + 1, restart), dtype=scalar_dtype, device=device)
     y = wp.empty(shape=restart + 1, dtype=scalar_dtype, device=device)
 
-    w = wp.zeros_like(r)
     V = wp.zeros(shape=(restart + 1, r.shape[0]), dtype=r.dtype, device=device)
 
-    def array_coeff(H, i, j):
-        return wp.array(
-            ptr=H.ptr + i * H.strides[0] + j * H.strides[1],
-            dtype=H.dtype,
-            shape=(1,),
-            device=H.device,
-            copy=False,
-        )
+    residuals = wp.empty(2, dtype=scalar_dtype, device=device)
+    beta, atol_sq = residuals[0:1], residuals[1:2]
 
-    def array_row(V, i):
-        return wp.array(
-            ptr=V.ptr + i * V.strides[0],
-            dtype=V.dtype,
-            shape=V.shape[1],
-            device=V.device,
-            copy=False,
-        )
+    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_dtype, max_column_count=restart + 1)
+    r_norm_sq = tiled_dot.col(0)
 
-    def do_arnoldi_iteration(j: int):
-        # w = A * v;
+    w_repeated = wp.array(
+        ptr=w.ptr, shape=(restart + 1, w.shape[0]), strides=(0, w.strides[0]), dtype=w.dtype, device=w.device
+    )
 
-        vj = array_row(V, j)
+    # tile size for least square solve
+    # (need to fit in a CUDA block, so 1024 max)
+    if device.is_cuda and 4 < restart <= 1024:
+        tile_size = 1 << math.ceil(math.log2(restart))
+        least_squares_kernel = make_gmres_solve_least_squares_kernel_tiled(tile_size)
+    else:
+        tile_size = 1
+        least_squares_kernel = _gmres_solve_least_squares
+
+    # recorded launches
+    least_squares_solve = wp.launch(
+        least_squares_kernel,
+        dim=(1, tile_size),
+        block_dim=tile_size if tile_size > 1 else 256,
+        device=device,
+        inputs=[restart, pivot_tolerance, beta, H, y],
+        record_cmd=True,
+    )
+
+    normalize_anorldi_vec = wp.launch(
+        _gmres_arnoldi_normalize_kernel,
+        dim=r.shape,
+        device=r.device,
+        inputs=[r, w, tiled_dot.col(0), beta],
+        record_cmd=True,
+    )
+
+    arnoldi_axpy = wp.launch(
+        _gmres_arnoldi_axpy_kernel,
+        dim=(w.shape[0], tile_size),
+        block_dim=tile_size,
+        device=w.device,
+        inputs=[V, w, H],
+        record_cmd=True,
+    )
+
+    # Initialize tolerance from right-hand-side norm
+    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
+    # Initialize residual
+    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
+    tiled_dot.compute(r, r, col_offset=0)
+
+    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
+    w.zero_()
+
+    def array_coeff(H, i, j):
+        return H[i][j : j + 1]
 
+    def array_col(H, j):
+        return H[: j + 1, j : j + 1]
+
+    def do_arnoldi_iteration(j: int):
+        # w = A * v[j];
         if M is not None:
-            tmp = array_row(V, j + 1)
+            tmp = V[j + 1]
 
             if is_left_preconditioner:
-                A.matvec(vj, tmp, tmp, alpha=1, beta=0)
+                A.matvec(V[j], tmp, tmp, alpha=1, beta=0)
                 M.matvec(tmp, w, w, alpha=1, beta=0)
             else:
-                M.matvec(vj, tmp, tmp, alpha=1, beta=0)
+                M.matvec(V[j], tmp, tmp, alpha=1, beta=0)
                 A.matvec(tmp, w, w, alpha=1, beta=0)
         else:
-            A.matvec(vj, w, w, alpha=1, beta=0)
+            A.matvec(V[j], w, w, alpha=1, beta=0)
 
-        for i in range(j + 1):
-            vi = array_row(V, i)
-            hij = array_coeff(H, i, j)
-            array_inner(w, vi, out=hij)
+        # compute and apply dot products in rappel,
+        # since Hj columns are orthogonal
+        Hj = array_col(H, j)
+        tiled_dot.compute(w_repeated, V[: j + 1])
+        wp.copy(src=tiled_dot.cols(j + 1), dest=Hj)
 
-            wp.launch(_gmres_arnoldi_axpy_kernel, dim=w.shape, device=w.device, inputs=[vi, w, hij])
+        # w -= w.vi vi
+        arnoldi_axpy.set_params([V[: j + 1], w, Hj])
+        arnoldi_axpy.launch()
 
-        hjnj = array_coeff(H, j + 1, j)
-        array_inner(w, w, out=hjnj)
+        # H[j+1, j] = |w.w|
+        tiled_dot.compute(w, w)
+        normalize_anorldi_vec.set_params([w, V[j + 1], tiled_dot.col(0), array_coeff(H, j + 1, j)])
 
-        vjn = array_row(V, j + 1)
-        wp.launch(_gmres_arnoldi_normalize_kernel, dim=w.shape, device=w.device, inputs=[w, vjn, hjnj])
+        normalize_anorldi_vec.launch()
 
-    def do_restart_cycle(atol_sq: float):
+    def do_restart_cycle():
         if M is not None and is_left_preconditioner:
             M.matvec(r, w, w, alpha=1, beta=0)
             rh = w
         else:
             rh = r
 
-        array_inner(rh, rh, out=beta_sq)
+        # beta^2 = rh.rh
+        tiled_dot.compute(rh, rh)
 
-        v0 = array_row(V, 0)
-        # v0 = r / beta
-        wp.launch(_gmres_arnoldi_normalize_kernel, dim=r.shape, device=r.device, inputs=[rh, v0, beta_sq])
+        # v[0] = r / beta
+        normalize_anorldi_vec.set_params([rh, V[0], tiled_dot.col(0), beta])
+        normalize_anorldi_vec.launch()
 
         for j in range(restart):
             do_arnoldi_iteration(j)
 
-        wp.launch(_gmres_normalize_lower_diagonal, dim=restart, device=device, inputs=[H])
-        wp.launch(_gmres_solve_least_squares, dim=1, device=device, inputs=[restart, pivot_tolerance, beta_sq, H, y])
+        least_squares_solve.launch()
 
         # update x
         if M is None or is_left_preconditioner:
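For the device-side least-squares solve, the tile size is the next power of two covering `restart`, and it must fit a CUDA block (hence the `4 < restart <= 1024` guard above). The resulting sizes, sketched in plain Python:

    import math

    # Tile size for the on-device Hessenberg least-squares solve: next power
    # of two >= restart, bounded by the 1024-thread CUDA block limit.
    for restart in (5, 30, 100, 1024):
        tile_size = 1 << math.ceil(math.log2(restart))
        print(restart, "->", tile_size)  # 5 -> 8, 30 -> 32, 100 -> 128, 1024 -> 1024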
@@ -700,21 +980,33 @@
         # update r and residual
         wp.copy(src=b, dest=r)
         A.matvec(x, b, r, alpha=-1.0, beta=1.0)
-        array_inner(r, r, out=r_norm_sq)
+        tiled_dot.compute(r, r)
 
-    return _run_solver_loop(
+    return _run_capturable_loop(
         do_restart_cycle,
         cycle_size=restart,
         r_norm_sq=r_norm_sq,
         maxiter=maxiter,
-        atol=atol,
+        atol_sq=atol_sq,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
-        device=device,
    )
 
 
+def _repeat_first(arr: wp.array):
+    # returns a view of the first element repeated arr.shape[0] times
+    view = wp.array(
+        ptr=arr.ptr,
+        shape=arr.shape,
+        dtype=arr.dtype,
+        strides=(0, *arr.strides[1:]),
+        device=arr.device,
+    )
+    view._ref = arr
+    return view
+
+
 def _get_dtype_epsilon(dtype):
     if dtype == wp.float64:
         return 1.0e-16
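`_repeat_first` aliases row 0 across the whole first axis by giving the view a zero row stride, which is how `cg` computes <r, r> and <r, z> in one batched `TiledDot.compute` call. The same idea expressed with NumPy stride tricks, as a rough analogy:

    import numpy as np
    from numpy.lib.stride_tricks import as_strided

    r_and_z = np.array([[1.0, 2.0, 3.0],   # r
                        [4.0, 5.0, 6.0]])  # z

    # Zero row stride: both logical rows alias row 0 (r); nothing is copied.
    r_repeated = as_strided(r_and_z, shape=r_and_z.shape, strides=(0, r_and_z.strides[1]))

    # One batched inner product then yields <r, r> and <r, z> together.
    print((r_repeated * r_and_z).sum(axis=1))  # [14. 32.]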
@@ -724,7 +1016,7 @@ def _get_dtype_epsilon(dtype):
     return 1.0e-8
 
 
-def _get_absolute_tolerance(dtype, tol, atol, lhs_norm):
+def _get_tolerances(dtype, tol, atol):
     eps_tol = _get_dtype_epsilon(dtype)
     default_tol = eps_tol ** (3 / 4)
     min_tol = eps_tol ** (9 / 4)
@@ -736,27 +1028,115 @@
     elif atol is None:
         atol = tol
 
-    return max(tol * lhs_norm, atol, min_tol)
+    atol = max(atol, min_tol)
+    return tol, atol
 
 
-def _initialize_residual_and_tolerance(A: LinearOperator, b: wp.array, x: wp.array, tol: float, atol: float):
-    scalar_dtype = wp.types.type_scalar_type(A.dtype)
-    device = A.device
+@wp.kernel
+def _initialize_tolerance(
+    rtol: Any,
+    atol: Any,
+    r_norm_sq: wp.array(dtype=Any),
+    atol_sq: wp.array(dtype=Any),
+):
+    atol = wp.max(rtol * wp.sqrt(r_norm_sq[0]), atol)
+    atol_sq[0] = atol * atol
 
-    # Buffer for storing square norm or residual
-    r_norm_sq = wp.empty(n=1, dtype=scalar_dtype, device=device, pinned=device.is_cuda)
+
+def _initialize_absolute_tolerance(
+    b: wp.array,
+    tol: float,
+    atol: float,
+    tiled_dot: TiledDot,
+    atol_sq: wp.array,
+):
+    scalar_type = atol_sq.dtype
 
     # Compute b norm to define absolute tolerance
-    array_inner(b, b, out=r_norm_sq)
-    atol = _get_absolute_tolerance(scalar_dtype, tol, atol, sqrt(r_norm_sq.numpy()[0]))
+    tiled_dot.compute(b, b)
+    b_norm_sq = tiled_dot.col(0)
+
+    rtol, atol = _get_tolerances(scalar_type, tol, atol)
+    wp.launch(
+        kernel=_initialize_tolerance,
+        dim=1,
+        device=b.device,
+        inputs=[scalar_type(rtol), scalar_type(atol), b_norm_sq, atol_sq],
+    )
 
-    # Residual r = b - Ax
-    r = wp.empty_like(b)
-    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
 
-    array_inner(r, r, out=r_norm_sq)
+@wp.kernel
+def _update_condition(
+    maxiter: int,
+    cycle_size: int,
+    cur_iter: wp.array(dtype=int),
+    r_norm_sq: wp.array(dtype=Any),
+    atol_sq: wp.array(dtype=Any),
+    condition: wp.array(dtype=int),
+):
+    cur_iter[0] += cycle_size
+    condition[0] = wp.where(r_norm_sq[0] <= atol_sq[0] or cur_iter[0] >= maxiter, 0, 1)
+
+
+def _run_capturable_loop(
+    do_cycle: Callable,
+    r_norm_sq: wp.array,
+    maxiter: int,
+    atol_sq: wp.array,
+    callback: Optional[Callable],
+    check_every: int,
+    use_cuda_graph: bool,
+    cycle_size: int = 1,
+):
+    device = atol_sq.device
+
+    if check_every > 0:
+        atol = math.sqrt(atol_sq.numpy()[0])
+        return _run_solver_loop(
+            do_cycle, cycle_size, r_norm_sq, maxiter, atol, callback, check_every, use_cuda_graph, device
+        )
+
+    cur_iter_and_condition = wp.full((2,), value=-1, dtype=int, device=device)
+    cur_iter = cur_iter_and_condition[0:1]
+    condition = cur_iter_and_condition[1:2]
+
+    update_condition_launch = wp.launch(
+        _update_condition,
+        dim=1,
+        device=device,
+        inputs=[int(maxiter), cycle_size, cur_iter, r_norm_sq, atol_sq, condition],
+        record_cmd=True,
+    )
+
+    if isinstance(callback, wp.Kernel):
+        callback_launch = wp.launch(
+            callback, dim=1, device=device, inputs=[cur_iter, r_norm_sq, atol_sq], record_cmd=True
+        )
+    else:
+        callback_launch = None
+
+    update_condition_launch.launch()
+    if callback_launch is not None:
+        callback_launch.launch()
+
+    def do_cycle_with_condition():
+        do_cycle()
+        update_condition_launch.launch()
+        if callback_launch is not None:
+            callback_launch.launch()
 
-    return r, r_norm_sq, atol
+    if use_cuda_graph and device.is_cuda:
+        if device.is_capturing:
+            wp.capture_while(condition, do_cycle_with_condition)
+        else:
+            with wp.ScopedCapture() as capture:
+                wp.capture_while(condition, do_cycle_with_condition)
+            wp.capture_launch(capture.graph)
+    else:
+        for _ in range(0, maxiter, cycle_size):
+            do_cycle_with_condition()
+
+    return cur_iter, r_norm_sq, atol_sq
 
 
 def _run_solver_loop(
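When `check_every` is 0, `_run_capturable_loop` launches the callback as a kernel with `dim=1` and `inputs=[cur_iter, r_norm_sq, atol_sq]`, so a device-side progress callback must accept exactly those three arrays. A sketch for a float32 solve (the array dtypes must match the solver's scalar type):

    import warp as wp

    @wp.kernel
    def progress_callback(
        cur_iter: wp.array(dtype=int),
        r_norm_sq: wp.array(dtype=float),
        atol_sq: wp.array(dtype=float),
    ):
        # Printed from the device each cycle, with no host synchronization.
        wp.printf("iter %d: |r|^2 = %f (target %f)\n", cur_iter[0], r_norm_sq[0], atol_sq[0])

    # usage sketch: cg(A, b, x, check_every=0, callback=progress_callback)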
@@ -771,11 +1151,12 @@ def _run_solver_loop(
     device,
 ):
     atol_sq = atol * atol
+    check_every = max(check_every, cycle_size)
 
     cur_iter = 0
 
     err_sq = r_norm_sq.numpy()[0]
-    err = sqrt(err_sq)
+    err = math.sqrt(err_sq)
     if callback is not None:
         callback(cur_iter, err, atol)
 
@@ -788,14 +1169,12 @@
         # Do not do graph capture at first iteration -- modules may not be loaded yet
         if device.is_cuda and use_cuda_graph and cur_iter > 0:
             if graph is None:
-                wp.capture_begin(device, force_module_load=False)
-                try:
-                    do_cycle(atol_sq)
-                finally:
-                    graph = wp.capture_end(device)
+                with wp.ScopedCapture(force_module_load=False) as capture:
+                    do_cycle()
+                graph = capture.graph
             wp.capture_launch(graph)
         else:
-            do_cycle(atol_sq)
+            do_cycle()
 
         cur_iter += cycle_size
 
@@ -809,24 +1188,16 @@
             break
 
         if callback is not None:
-            callback(cur_iter, sqrt(err_sq), atol)
+            callback(cur_iter, math.sqrt(err_sq), atol)
 
     err_sq = r_norm_sq.numpy()[0]
-    err = sqrt(err_sq)
+    err = math.sqrt(err_sq)
     if callback is not None:
         callback(cur_iter, err, atol)
 
     return cur_iter, err, atol
 
 
-@wp.func
-def _calc_mv_product(i: wp.int32, A: wp.array2d(dtype=Any), x: wp.array1d(dtype=Any)):
-    sum = A.dtype(0)
-    for j in range(A.shape[1]):
-        sum += A[i, j] * x[j]
-    return sum
-
-
 @wp.kernel
 def _dense_mv_kernel(
     A: wp.array2d(dtype=Any),
@@ -836,25 +1207,24 @@ def _dense_mv_kernel(
     alpha: Any,
     beta: Any,
 ):
-    i = wp.tid()
-    z[i] = z.dtype(beta) * y[i] + z.dtype(alpha) * _calc_mv_product(i, A, x)
+    row, lane = wp.tid()
 
+    zero = type(alpha)(0)
+    s = zero
+    if alpha != zero:
+        for col in range(lane, A.shape[1], wp.block_dim()):
+            s += A[row, col] * x[col]
 
-@wp.kernel
-def _diag_mv_kernel(
-    A: wp.array(dtype=Any),
-    x: wp.array(dtype=Any),
-    y: wp.array(dtype=Any),
-    z: wp.array(dtype=Any),
-    alpha: Any,
-    beta: Any,
-):
-    i = wp.tid()
-    z[i] = beta * y[i] + alpha * (A[i] * x[i])
+    row_tile = wp.tile_sum(wp.tile(s * alpha))
+
+    if beta != zero:
+        row_tile += wp.tile_load(y, shape=1, offset=row) * beta
+
+    wp.tile_store(z, row_tile, offset=row)
 
 
 @wp.kernel
-def _diag_mv_vec_kernel(
+def _diag_mv_kernel(
     A: wp.array(dtype=Any),
     x: wp.array(dtype=Any),
     y: wp.array(dtype=Any),
@@ -863,7 +1233,13 @@ def _diag_mv_vec_kernel(
     beta: Any,
 ):
     i = wp.tid()
-    z[i] = beta * y[i] + alpha * wp.cw_mul(A[i], x[i])
+    zero = type(alpha)(0)
+    s = z.dtype(zero)
+    if alpha != zero:
+        s += alpha * (A[i] * x[i])
+    if beta != zero:
+        s += beta * y[i]
+    z[i] = s
 
 
 @wp.func
@@ -910,7 +1286,7 @@ def _extract_inverse_diagonal_dense(
 
 @wp.kernel
 def _cg_kernel_1(
-    tol: Any,
+    tol: wp.array(dtype=Any),
     resid: wp.array(dtype=Any),
     rz_old: wp.array(dtype=Any),
     p_Ap: wp.array(dtype=Any),
@@ -921,7 +1297,7 @@
 ):
     i = wp.tid()
 
-    alpha = wp.where(resid[0] > tol, rz_old[0] / p_Ap[0], rz_old.dtype(0.0))
+    alpha = wp.where(resid[0] > tol[0], rz_old[0] / p_Ap[0], rz_old.dtype(0.0))
 
     x[i] = x[i] + alpha * p[i]
     r[i] = r[i] - alpha * Ap[i]
@@ -929,8 +1305,8 @@
 
 @wp.kernel
 def _cg_kernel_2(
-    tol: Any,
-    resid: wp.array(dtype=Any),
+    tol: wp.array(dtype=Any),
+    resid_new: wp.array(dtype=Any),
     rz_old: wp.array(dtype=Any),
     rz_new: wp.array(dtype=Any),
     z: wp.array(dtype=Any),
@@ -939,14 +1315,15 @@
     # p = r + (rz_new / rz_old) * p;
     i = wp.tid()
 
-    beta = wp.where(resid[0] > tol, rz_new[0] / rz_old[0], rz_old.dtype(0.0))
+    cond = resid_new[0] > tol[0]
+    beta = wp.where(cond, rz_new[0] / rz_old[0], rz_old.dtype(0.0))
 
     p[i] = z[i] + beta * p[i]
 
 
 @wp.kernel
 def _cr_kernel_1(
-    tol: Any,
+    tol: wp.array(dtype=Any),
     resid: wp.array(dtype=Any),
     zAz_old: wp.array(dtype=Any),
     y_Ap: wp.array(dtype=Any),
@@ -959,7 +1336,7 @@ def _cr_kernel_1(
 ):
     i = wp.tid()
 
-    alpha = wp.where(resid[0] > tol and y_Ap[0] > 0.0, zAz_old[0] / y_Ap[0], zAz_old.dtype(0.0))
+    alpha = wp.where(resid[0] > tol[0] and y_Ap[0] > 0.0, zAz_old[0] / y_Ap[0], zAz_old.dtype(0.0))
 
     x[i] = x[i] + alpha * p[i]
     r[i] = r[i] - alpha * Ap[i]
@@ -968,7 +1345,7 @@
 
 @wp.kernel
 def _cr_kernel_2(
-    tol: Any,
+    tol: wp.array(dtype=Any),
     resid: wp.array(dtype=Any),
     zAz_old: wp.array(dtype=Any),
     zAz_new: wp.array(dtype=Any),
@@ -980,7 +1357,7 @@
     # p = r + (rz_new / rz_old) * p;
     i = wp.tid()
 
-    beta = wp.where(resid[0] > tol and zAz_old[0] > 0.0, zAz_new[0] / zAz_old[0], zAz_old.dtype(0.0))
+    beta = wp.where(resid[0] > tol[0] and zAz_old[0] > 0.0, zAz_new[0] / zAz_old[0], zAz_old.dtype(0.0))
 
     p[i] = z[i] + beta * p[i]
     Ap[i] = Az[i] + beta * Ap[i]
@@ -988,7 +1365,7 @@
 
 @wp.kernel
 def _bicgstab_kernel_1(
-    tol: Any,
+    tol: wp.array(dtype=Any),
     resid: wp.array(dtype=Any),
     rho_old: wp.array(dtype=Any),
     r0v: wp.array(dtype=Any),
@@ -999,7 +1376,7 @@
 ):
     i = wp.tid()
 
-    alpha = wp.where(resid[0] > tol, rho_old[0] / r0v[0], rho_old.dtype(0.0))
+    alpha = wp.where(resid[0] > tol[0], rho_old[0] / r0v[0], rho_old.dtype(0.0))
 
     x[i] += alpha * y[i]
     r[i] -= alpha * v[i]
@@ -1007,7 +1384,7 @@
 
 @wp.kernel
 def _bicgstab_kernel_2(
-    tol: Any,
+    tol: wp.array(dtype=Any),
     resid: wp.array(dtype=Any),
     st: wp.array(dtype=Any),
     tt: wp.array(dtype=Any),
@@ -1018,7 +1395,7 @@
 ):
     i = wp.tid()
 
-    omega = wp.where(resid[0] > tol, st[0] / tt[0], st.dtype(0.0))
+    omega = wp.where(resid[0] > tol[0], st[0] / tt[0], st.dtype(0.0))
 
     x[i] += omega * z[i]
     r[i] -= omega * t[i]
@@ -1026,7 +1403,7 @@
 
 @wp.kernel
 def _bicgstab_kernel_3(
-    tol: Any,
+    tol: wp.array(dtype=Any),
     resid: wp.array(dtype=Any),
     rho_new: wp.array(dtype=Any),
     r0v: wp.array(dtype=Any),
@@ -1038,32 +1415,21 @@ def _bicgstab_kernel_3(
 ):
     i = wp.tid()
 
-    beta = wp.where(resid[0] > tol, rho_new[0] * tt[0] / (r0v[0] * st[0]), st.dtype(0.0))
-    beta_omega = wp.where(resid[0] > tol, rho_new[0] / r0v[0], st.dtype(0.0))
+    beta = wp.where(resid[0] > tol[0], rho_new[0] * tt[0] / (r0v[0] * st[0]), st.dtype(0.0))
+    beta_omega = wp.where(resid[0] > tol[0], rho_new[0] / r0v[0], st.dtype(0.0))
 
     p[i] = r[i] + beta * p[i] - beta_omega * v[i]
 
 
-@wp.kernel
-def _gmres_normalize_lower_diagonal(H: wp.array2d(dtype=Any)):
-    # normalize lower-diagonal values of Hessenberg matrix
-    i = wp.tid()
-    H[i + 1, i] = wp.sqrt(H[i + 1, i])
-
-
 @wp.kernel
 def _gmres_solve_least_squares(
-    k: int, pivot_tolerance: float, beta_sq: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
+    k: int, pivot_tolerance: float, beta: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
 ):
     # Solve H y = (beta, 0, ..., 0)
     # H Hessenberg matrix of shape (k+1, k)
-
-    # Keeping H in global mem; warp kernels are launched with fixed block size,
     # so would not fit in registers
 
-    # TODO: switch to native code with thread synchronization
-
-    rhs = wp.sqrt(beta_sq[0])
+    rhs = beta[0]
 
     # Apply 2x2 rotations to H so as to remove lower diagonal,
     # and apply similar rotations to right-hand-side
@@ -1110,14 +1476,108 @@
         y[i] = yi / Hi[i]
 
 
+@functools.lru_cache(maxsize=None)
+def make_gmres_solve_least_squares_kernel_tiled(K: int):
+    @wp.kernel(module="unique")
+    def gmres_solve_least_squares_tiled(
+        k: int, pivot_tolerance: float, beta: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
+    ):
+        # Assumes tiles of size K, and K at least as large as highest number of columns
+        # Limits the max restart cycle length to the max block size of 1024, but using
+        # larger restarts would be very inefficient anyway (default is ~30)
+
+        # Solve H y = (beta, 0, ..., 0)
+        # H Hessenberg matrix of shape (k+1, k)
+
+        i, lane = wp.tid()
+
+        rhs = beta[0]
+
+        zero = H.dtype(0.0)
+        one = H.dtype(1.0)
+        yi = zero
+
+        Ha = wp.tile_load(H[0], shape=(K))
+
+        # Apply 2x2 rotations to H so as to remove lower diagonal,
+        # and apply similar rotations to right-hand-side
+        max_k = int(k)
+        for i in range(k):
+            # Ha = H[i]
+            # Hb = H[i + 1]
+            Hb = wp.tile_load(H[i + 1], shape=(K))
+
+            # Givens rotation [[c s], [-s c]]
+            a = Ha[i]
+            b = Hb[i]
+            abn_sq = a * a + b * b
+
+            if abn_sq < type(abn_sq)(pivot_tolerance):
+                # Arnoldi iteration finished early
+                max_k = i
+                break
+
+            abn = wp.sqrt(abn_sq)
+            c = a / abn
+            s = b / abn
+
+            # Rotate H
+            a = wp.untile(Ha)
+            b = wp.untile(Hb)
+            a_rot = c * a + s * b
+            b_rot = c * b - s * a
+
+            # Rotate rhs
+            if lane == i:
+                yi = c * rhs
+            rhs = -s * rhs
+
+            wp.tile_store(H[i], wp.tile(a_rot))
+            Ha[lane] = b_rot
+
+        y_tile = wp.tile(yi)
+
+        # Triangular back-solve for y
+        for ii in range(max_k, 0, -1):
+            i = ii - 1
+
+            Hi = wp.tile_load(H[i], shape=(K))
+
+            il = lane + i
+            if lane == 0:
+                yl = y_tile[i]
+            elif il < max_k:
+                yl = -y_tile[il] * Hi[il]
+            else:
+                yl = zero
+
+            yit = wp.tile_sum(wp.tile(yl)) * (one / Hi[i])
+            yit[0]  # no-op, movs yit to shared
+            wp.tile_assign(y_tile, yit, offset=(i,))
+
+        wp.tile_store(y, y_tile)
+
+    return gmres_solve_least_squares_tiled
+
+
 @wp.kernel
 def _gmres_arnoldi_axpy_kernel(
-    x: wp.array(dtype=Any),
-    y: wp.array(dtype=Any),
-    alpha: wp.array(dtype=Any),
+    V: wp.array2d(dtype=Any),
+    w: wp.array(dtype=Any),
+    Vw: wp.array2d(dtype=Any),
 ):
-    tid = wp.tid()
-    y[tid] -= x[tid] * alpha[0]
+    tid, lane = wp.tid()
+
+    s = w.dtype(Vw.dtype(0))
+
+    tile_size = wp.block_dim()
+    for k in range(lane, Vw.shape[0], tile_size):
+        s += Vw[k, 0] * V[k, tid]
+
+    wi = wp.tile_load(w, shape=1, offset=tid)
+    wi -= wp.tile_sum(wp.tile(s, preserve_type=True))
+
+    wp.tile_store(w, wi, offset=tid)
 
 
 @wp.kernel
@@ -1125,9 +1585,14 @@ def _gmres_arnoldi_normalize_kernel(
     x: wp.array(dtype=Any),
     y: wp.array(dtype=Any),
     alpha: wp.array(dtype=Any),
+    alpha_copy: wp.array(dtype=Any),
 ):
     tid = wp.tid()
-    y[tid] = wp.where(alpha[0] == alpha.dtype(0.0), x[tid], x[tid] / wp.sqrt(alpha[0]))
+    norm = wp.sqrt(alpha[0])
+    y[tid] = wp.where(alpha[0] == alpha.dtype(0.0), x[tid], x[tid] / norm)
+
+    if tid == 0:
+        alpha_copy[0] = norm
 
 
 @wp.kernel