warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (170)
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/native/warp.h CHANGED
@@ -11,12 +11,21 @@
11
11
  // defines all crt + builtin types
12
12
  #include "builtin.h"
13
13
 
14
+ #define WP_CURRENT_STREAM ((void*)0xffffffffffffffff)
15
+
14
16
  // this is the core runtime API exposed on the DLL level
15
17
  extern "C"
16
18
  {
17
19
  WP_API int init();
18
20
  //WP_API void shutdown();
19
21
 
22
+ // get error message from C++
23
+ WP_API const char* get_error_string();
24
+
25
+ // allow disabling error output, which is handy during tests that expect failure
26
+ WP_API void set_error_output_enabled(int enable);
27
+ WP_API int is_error_output_enabled();
28
+
20
29
  // whether Warp was compiled with CUDA support
21
30
  WP_API int is_cuda_enabled();
22
31
  // whether Warp was compiled with enhanced CUDA compatibility
@@ -31,22 +40,22 @@ extern "C"
31
40
 
32
41
  WP_API void* alloc_host(size_t s);
33
42
  WP_API void* alloc_pinned(size_t s);
34
- WP_API void* alloc_device(void* context, size_t s);
35
- WP_API void* alloc_temp_device(void* context, size_t s);
43
+ WP_API void* alloc_device(void* context, size_t s); // uses cudaMallocAsync() if supported, cudaMalloc() otherwise
44
+ WP_API void* alloc_device_default(void* context, size_t s); // uses cudaMalloc()
45
+ WP_API void* alloc_device_async(void* context, size_t s); // uses cudaMallocAsync()
36
46
 
37
47
  WP_API void free_host(void* ptr);
38
48
  WP_API void free_pinned(void* ptr);
39
- WP_API void free_device(void* context, void* ptr);
40
- WP_API void free_temp_device(void* context, void* ptr);
49
+ WP_API void free_device(void* context, void* ptr); // uses cudaFreeAsync() if supported, cudaFree() otherwise
50
+ WP_API void free_device_default(void* context, void* ptr); // uses cudaFree()
51
+ WP_API void free_device_async(void* context, void* ptr); // uses cudaFreeAsync()
41
52
 
42
- // all memcpys are performed asynchronously
43
- WP_API void memcpy_h2h(void* dest, void* src, size_t n);
44
- WP_API void memcpy_h2d(void* context, void* dest, void* src, size_t n);
45
- WP_API void memcpy_d2h(void* context, void* dest, void* src, size_t n);
46
- WP_API void memcpy_d2d(void* context, void* dest, void* src, size_t n);
47
- WP_API void memcpy_peer(void* context, void* dest, void* src, size_t n);
53
+ WP_API bool memcpy_h2h(void* dest, void* src, size_t n);
54
+ WP_API bool memcpy_h2d(void* context, void* dest, void* src, size_t n, void* stream=WP_CURRENT_STREAM);
55
+ WP_API bool memcpy_d2h(void* context, void* dest, void* src, size_t n, void* stream=WP_CURRENT_STREAM);
56
+ WP_API bool memcpy_d2d(void* context, void* dest, void* src, size_t n, void* stream=WP_CURRENT_STREAM);
57
+ WP_API bool memcpy_p2p(void* dst_context, void* dst, void* src_context, void* src, size_t n, void* stream=WP_CURRENT_STREAM);
48
58
 
49
- // all memsets are performed asynchronously
50
59
  WP_API void memset_host(void* dest, int value, size_t n);
51
60
  WP_API void memset_device(void* context, void* dest, int value, size_t n);
52
61
 
@@ -82,7 +91,7 @@ extern "C"
82
91
  WP_API void hash_grid_destroy_device(uint64_t id);
83
92
  WP_API void hash_grid_update_device(uint64_t id, float cell_width, const wp::vec3* positions, int num_points);
84
93
 
85
- WP_API bool cutlass_gemm(int compute_capability, int m, int n, int k, const char* datatype,
94
+ WP_API bool cutlass_gemm(void* context, int compute_capability, int m, int n, int k, const char* datatype,
86
95
  const void* a, const void* b, const void* c, void* d, float alpha, float beta,
87
96
  bool row_major_a, bool row_major_b, bool allow_tf32x3_arith, int batch_count);
88
97
 
@@ -106,8 +115,8 @@ extern "C"
106
115
  WP_API int marching_cubes_surface_device(uint64_t id, const float* field, int nx, int ny, int nz, float threshold, wp::vec3* verts, int* triangles, int max_verts, int max_tris, int* out_num_verts, int* out_num_tris);
107
116
 
108
117
  // generic copy supporting non-contiguous arrays
109
- WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type, int elem_size);
110
- WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size);
118
+ WP_API bool array_copy_host(void* dst, void* src, int dst_type, int src_type, int elem_size);
119
+ WP_API bool array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size);
111
120
 
112
121
  // generic fill for non-contiguous arrays
113
122
  WP_API void array_fill_host(void* arr, int arr_type, const void* value, int value_size);
@@ -220,8 +229,7 @@ extern "C"
220
229
  WP_API void nvrtc_supported_archs(int* archs);
221
230
 
222
231
  WP_API int cuda_device_get_count();
223
- WP_API void* cuda_device_primary_context_retain(int ordinal);
224
- WP_API void cuda_device_primary_context_release(int ordinal);
232
+ WP_API void* cuda_device_get_primary_context(int ordinal);
225
233
  WP_API const char* cuda_device_get_name(int ordinal);
226
234
  WP_API int cuda_device_get_arch(int ordinal);
227
235
  WP_API void cuda_device_get_uuid(int ordinal, char uuid[16]);
@@ -229,7 +237,10 @@ extern "C"
229
237
  WP_API int cuda_device_get_pci_bus_id(int ordinal);
230
238
  WP_API int cuda_device_get_pci_device_id(int ordinal);
231
239
  WP_API int cuda_device_is_uva(int ordinal);
232
- WP_API int cuda_device_is_memory_pool_supported(int ordinal);
240
+ WP_API int cuda_device_is_mempool_supported(int ordinal);
241
+ WP_API int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold);
242
+ WP_API uint64_t cuda_device_get_mempool_release_threshold(int ordinal);
243
+ WP_API void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem);
233
244
 
234
245
  WP_API void* cuda_context_get_current();
235
246
  WP_API void cuda_context_set_current(void* context);
@@ -239,11 +250,8 @@ extern "C"
239
250
  WP_API void cuda_context_destroy(void* context);
240
251
  WP_API int cuda_context_get_device_ordinal(void* context);
241
252
  WP_API int cuda_context_is_primary(void* context);
242
- WP_API int cuda_context_is_memory_pool_supported(void* context);
243
253
  WP_API void* cuda_context_get_stream(void* context);
244
- WP_API void cuda_context_set_stream(void* context, void* stream);
245
- WP_API int cuda_context_can_access_peer(void* context, void* peer_context);
246
- WP_API int cuda_context_enable_peer_access(void* context, void* peer_context);
254
+ WP_API void cuda_context_set_stream(void* context, void* stream, int sync);
247
255
 
248
256
  // ensures all device side operations have completed in the current context
249
257
  WP_API void cuda_context_synchronize(void* context);
@@ -251,28 +259,38 @@ extern "C"
251
259
  // return cudaError_t code
252
260
  WP_API uint64_t cuda_context_check(void* context);
253
261
 
262
+ // peer access
263
+ WP_API int cuda_is_peer_access_supported(int target_ordinal, int peer_ordinal);
264
+ WP_API int cuda_is_peer_access_enabled(void* target_context, void* peer_context);
265
+ WP_API int cuda_set_peer_access_enabled(void* target_context, void* peer_context, int enable);
266
+ WP_API int cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal);
267
+ WP_API int cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable);
268
+
254
269
  WP_API void* cuda_stream_create(void* context);
255
270
  WP_API void cuda_stream_destroy(void* context, void* stream);
256
- WP_API void cuda_stream_synchronize(void* context, void* stream);
271
+ WP_API void cuda_stream_register(void* context, void* stream);
272
+ WP_API void cuda_stream_unregister(void* context, void* stream);
257
273
  WP_API void* cuda_stream_get_current();
258
- WP_API void cuda_stream_wait_event(void* context, void* stream, void* event);
259
- WP_API void cuda_stream_wait_stream(void* context, void* stream, void* other_stream, void* event);
274
+ WP_API void cuda_stream_synchronize(void* stream);
275
+ WP_API void cuda_stream_wait_event(void* stream, void* event);
276
+ WP_API void cuda_stream_wait_stream(void* stream, void* other_stream, void* event);
277
+ WP_API int cuda_stream_is_capturing(void* stream);
260
278
 
261
279
  WP_API void* cuda_event_create(void* context, unsigned flags);
262
- WP_API void cuda_event_destroy(void* context, void* event);
263
- WP_API void cuda_event_record(void* context, void* event, void* stream);
280
+ WP_API void cuda_event_destroy(void* event);
281
+ WP_API void cuda_event_record(void* event, void* stream);
264
282
 
265
- WP_API void cuda_graph_begin_capture(void* context);
266
- WP_API void* cuda_graph_end_capture(void* context);
267
- WP_API void cuda_graph_launch(void* context, void* graph);
268
- WP_API void cuda_graph_destroy(void* context, void* graph);
283
+ WP_API bool cuda_graph_begin_capture(void* context, void* stream, int external);
284
+ WP_API bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret);
285
+ WP_API bool cuda_graph_launch(void* graph, void* stream);
286
+ WP_API bool cuda_graph_destroy(void* context, void* graph);
269
287
 
270
288
  WP_API size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_file);
271
289
 
272
290
  WP_API void* cuda_load_module(void* context, const char* ptx);
273
291
  WP_API void cuda_unload_module(void* context, void* module);
274
292
  WP_API void* cuda_get_kernel(void* context, void* module, const char* name);
275
- WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args);
293
+ WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args, void* stream);
276
294
 
277
295
  WP_API void cuda_set_context_restore_policy(bool always_restore);
278
296
  WP_API int cuda_get_context_restore_policy();
@@ -284,4 +302,3 @@ extern "C"
284
302
  WP_API void cuda_graphics_unregister_resource(void* context, void* resource);
285
303
 
286
304
  } // extern "C"
287
-
warp/optim/linear.py CHANGED
@@ -366,6 +366,7 @@ def bicgstab(
366
366
  device=device,
367
367
  inputs=[atol_sq, r_norm_sq, rho, r0v, x, r, y, v],
368
368
  )
369
+ array_inner(r, r, out=r_norm_sq)
369
370
 
370
371
  # z = M r
371
372
  if M is not None:
@@ -473,6 +474,8 @@ def gmres(
473
474
  device = A.device
474
475
  scalar_dtype = wp.types.type_scalar_type(A.dtype)
475
476
 
477
+ pivot_tolerance = _get_dtype_epsilon(scalar_dtype) ** 2
478
+
476
479
  beta_sq = wp.empty_like(r_norm_sq, pinned=False)
477
480
  H = wp.empty(shape=(restart + 1, restart), dtype=scalar_dtype, device=device)
478
481
 
@@ -488,7 +491,6 @@ def gmres(
488
491
  shape=(1,),
489
492
  device=H.device,
490
493
  copy=False,
491
- owner=False,
492
494
  )
493
495
 
494
496
  def array_row(V, i):
@@ -498,7 +500,6 @@ def gmres(
498
500
  shape=V.shape[1],
499
501
  device=V.device,
500
502
  copy=False,
501
- owner=False,
502
503
  )
503
504
 
504
505
  def do_arnoldi_iteration(j: int):
@@ -548,7 +549,7 @@ def gmres(
548
549
  do_arnoldi_iteration(j)
549
550
 
550
551
  wp.launch(_gmres_normalize_lower_diagonal, dim=restart, device=device, inputs=[H])
551
- wp.launch(_gmres_solve_least_squares, dim=1, device=device, inputs=[restart, beta_sq, H, y])
552
+ wp.launch(_gmres_solve_least_squares, dim=1, device=device, inputs=[restart, pivot_tolerance, beta_sq, H, y])
552
553
 
553
554
  # update x
554
555
  if M is None or is_left_preconditioner:
@@ -575,16 +576,19 @@ def gmres(
575
576
  )
576
577
 
577
578
 
578
- def _get_absolute_tolerance(dtype, tol, atol, lhs_norm):
579
+ def _get_dtype_epsilon(dtype):
579
580
  if dtype == wp.float64:
580
- default_tol = 1.0e-12
581
- min_tol = 1.0e-36
581
+ return 1.0e-16
582
582
  elif dtype == wp.float16:
583
- default_tol = 1.0e-3
584
- min_tol = 1.0e-9
585
- else:
586
- default_tol = 1.0e-6
587
- min_tol = 1.0e-18
583
+ return 1.0e-4
584
+
585
+ return 1.0e-8
586
+
587
+
588
+ def _get_absolute_tolerance(dtype, tol, atol, lhs_norm):
589
+ eps_tol = _get_dtype_epsilon(dtype)
590
+ default_tol = eps_tol ** (3 / 4)
591
+ min_tol = eps_tol ** (9 / 4)
588
592
 
589
593
  if tol is None and atol is None:
590
594
  tol = atol = default_tol
@@ -846,7 +850,9 @@ def _gmres_normalize_lower_diagonal(H: wp.array2d(dtype=Any)):
846
850
 
847
851
 
848
852
  @wp.kernel
849
- def _gmres_solve_least_squares(k: int, beta_sq: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)):
853
+ def _gmres_solve_least_squares(
854
+ k: int, pivot_tolerance: float, beta_sq: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
855
+ ):
850
856
  # Solve H y = (beta, 0, ..., 0)
851
857
  # H Hessenberg matrix of shape (k+1, k)
852
858
 
@@ -859,6 +865,7 @@ def _gmres_solve_least_squares(k: int, beta_sq: wp.array(dtype=Any), H: wp.array
859
865
 
860
866
  # Apply 2x2 rotations to H so as to remove lower diagonal,
861
867
  # and apply similar rotations to right-hand-side
868
+ max_k = int(k)
862
869
  for i in range(k):
863
870
  Ha = H[i]
864
871
  Hb = H[i + 1]
@@ -866,7 +873,14 @@ def _gmres_solve_least_squares(k: int, beta_sq: wp.array(dtype=Any), H: wp.array
866
873
  # Givens rotation [[c s], [-s c]]
867
874
  a = Ha[i]
868
875
  b = Hb[i]
869
- abn = wp.sqrt(a * a + b * b)
876
+ abn_sq = a * a + b * b
877
+
878
+ if abn_sq < type(abn_sq)(pivot_tolerance):
879
+ # Arnoldi iteration finished early
880
+ max_k = i
881
+ break
882
+
883
+ abn = wp.sqrt(abn_sq)
870
884
  c = a / abn
871
885
  s = b / abn
872
886
 
@@ -881,12 +895,15 @@ def _gmres_solve_least_squares(k: int, beta_sq: wp.array(dtype=Any), H: wp.array
881
895
  y[i] = c * rhs
882
896
  rhs = -s * rhs
883
897
 
898
+ for i in range(max_k, k):
899
+ y[i] = y.dtype(0.0)
900
+
884
901
  # Triangular back-solve for y
885
- for ii in range(k, 0, -1):
902
+ for ii in range(max_k, 0, -1):
886
903
  i = ii - 1
887
904
  Hi = H[i]
888
905
  yi = y[i]
889
- for j in range(ii, k):
906
+ for j in range(ii, max_k):
890
907
  yi -= Hi[j] * y[j]
891
908
  y[i] = yi / Hi[i]
892
909
 
@@ -908,7 +925,7 @@ def _gmres_arnoldi_normalize_kernel(
908
925
  alpha: wp.array(dtype=Any),
909
926
  ):
910
927
  tid = wp.tid()
911
- y[tid] = x[tid] / wp.sqrt(alpha[0])
928
+ y[tid] = wp.select(alpha[0] == alpha.dtype(0.0), x[tid] / wp.sqrt(alpha[0]), x[tid])
912
929
 
913
930
 
914
931
  @wp.kernel