warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +8 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +7 -6
- warp/build_dll.py +70 -79
- warp/builtins.py +10 -6
- warp/codegen.py +51 -19
- warp/config.py +7 -8
- warp/constants.py +3 -0
- warp/context.py +948 -245
- warp/dlpack.py +198 -113
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usda +42 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usda +56 -0
- warp/examples/assets/torus.usda +105 -0
- warp/examples/benchmarks/benchmark_api.py +383 -0
- warp/examples/benchmarks/benchmark_cloth.py +279 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
- warp/examples/benchmarks/benchmark_launches.py +295 -0
- warp/examples/core/example_dem.py +221 -0
- warp/examples/core/example_fluid.py +267 -0
- warp/examples/core/example_graph_capture.py +129 -0
- warp/examples/core/example_marching_cubes.py +177 -0
- warp/examples/core/example_mesh.py +154 -0
- warp/examples/core/example_mesh_intersect.py +193 -0
- warp/examples/core/example_nvdb.py +169 -0
- warp/examples/core/example_raycast.py +89 -0
- warp/examples/core/example_raymarch.py +178 -0
- warp/examples/core/example_render_opengl.py +141 -0
- warp/examples/core/example_sph.py +389 -0
- warp/examples/core/example_torch.py +181 -0
- warp/examples/core/example_wave.py +249 -0
- warp/examples/fem/bsr_utils.py +380 -0
- warp/examples/fem/example_apic_fluid.py +391 -0
- warp/examples/fem/example_convection_diffusion.py +168 -0
- warp/examples/fem/example_convection_diffusion_dg.py +209 -0
- warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
- warp/examples/fem/example_deformed_geometry.py +159 -0
- warp/examples/fem/example_diffusion.py +173 -0
- warp/examples/fem/example_diffusion_3d.py +152 -0
- warp/examples/fem/example_diffusion_mgpu.py +214 -0
- warp/examples/fem/example_mixed_elasticity.py +222 -0
- warp/examples/fem/example_navier_stokes.py +243 -0
- warp/examples/fem/example_stokes.py +192 -0
- warp/examples/fem/example_stokes_transfer.py +249 -0
- warp/examples/fem/mesh_utils.py +109 -0
- warp/examples/fem/plot_utils.py +287 -0
- warp/examples/optim/example_bounce.py +248 -0
- warp/examples/optim/example_cloth_throw.py +210 -0
- warp/examples/optim/example_diffray.py +535 -0
- warp/examples/optim/example_drone.py +850 -0
- warp/examples/optim/example_inverse_kinematics.py +169 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
- warp/examples/optim/example_spring_cage.py +234 -0
- warp/examples/optim/example_trajectory.py +201 -0
- warp/examples/sim/example_cartpole.py +128 -0
- warp/examples/sim/example_cloth.py +184 -0
- warp/examples/sim/example_granular.py +113 -0
- warp/examples/sim/example_granular_collision_sdf.py +185 -0
- warp/examples/sim/example_jacobian_ik.py +213 -0
- warp/examples/sim/example_particle_chain.py +106 -0
- warp/examples/sim/example_quadruped.py +179 -0
- warp/examples/sim/example_rigid_chain.py +191 -0
- warp/examples/sim/example_rigid_contact.py +176 -0
- warp/examples/sim/example_rigid_force.py +126 -0
- warp/examples/sim/example_rigid_gyroscopic.py +97 -0
- warp/examples/sim/example_rigid_soft_contact.py +124 -0
- warp/examples/sim/example_soft_body.py +178 -0
- warp/fabric.py +29 -20
- warp/fem/cache.py +0 -1
- warp/fem/dirichlet.py +0 -2
- warp/fem/integrate.py +0 -1
- warp/jax.py +45 -0
- warp/jax_experimental.py +339 -0
- warp/native/builtin.h +12 -0
- warp/native/bvh.cu +18 -18
- warp/native/clang/clang.cpp +8 -3
- warp/native/cuda_util.cpp +94 -5
- warp/native/cuda_util.h +35 -6
- warp/native/cutlass_gemm.cpp +1 -1
- warp/native/cutlass_gemm.cu +4 -1
- warp/native/error.cpp +66 -0
- warp/native/error.h +27 -0
- warp/native/mesh.cu +2 -2
- warp/native/reduce.cu +4 -4
- warp/native/runlength_encode.cu +2 -2
- warp/native/scan.cu +2 -2
- warp/native/sparse.cu +0 -1
- warp/native/temp_buffer.h +2 -2
- warp/native/warp.cpp +95 -60
- warp/native/warp.cu +1053 -218
- warp/native/warp.h +49 -32
- warp/optim/linear.py +33 -16
- warp/render/render_opengl.py +202 -101
- warp/render/render_usd.py +82 -40
- warp/sim/__init__.py +13 -4
- warp/sim/articulation.py +4 -5
- warp/sim/collide.py +320 -175
- warp/sim/import_mjcf.py +25 -30
- warp/sim/import_urdf.py +94 -63
- warp/sim/import_usd.py +51 -36
- warp/sim/inertia.py +3 -2
- warp/sim/integrator.py +233 -0
- warp/sim/integrator_euler.py +447 -469
- warp/sim/integrator_featherstone.py +1991 -0
- warp/sim/integrator_xpbd.py +1420 -640
- warp/sim/model.py +765 -487
- warp/sim/particles.py +2 -1
- warp/sim/render.py +35 -13
- warp/sim/utils.py +222 -11
- warp/stubs.py +8 -0
- warp/tape.py +16 -1
- warp/tests/aux_test_grad_customs.py +23 -0
- warp/tests/test_array.py +190 -1
- warp/tests/test_async.py +656 -0
- warp/tests/test_bool.py +50 -0
- warp/tests/test_dlpack.py +164 -11
- warp/tests/test_examples.py +166 -74
- warp/tests/test_fem.py +8 -1
- warp/tests/test_generics.py +15 -5
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +172 -12
- warp/tests/test_jax.py +254 -0
- warp/tests/test_large.py +29 -6
- warp/tests/test_launch.py +25 -0
- warp/tests/test_linear_solvers.py +20 -3
- warp/tests/test_matmul.py +61 -16
- warp/tests/test_matmul_lite.py +13 -13
- warp/tests/test_mempool.py +186 -0
- warp/tests/test_multigpu.py +3 -0
- warp/tests/test_options.py +16 -2
- warp/tests/test_peer.py +137 -0
- warp/tests/test_print.py +3 -1
- warp/tests/test_quat.py +23 -0
- warp/tests/test_sim_kinematics.py +97 -0
- warp/tests/test_snippet.py +126 -3
- warp/tests/test_streams.py +108 -79
- warp/tests/test_torch.py +16 -8
- warp/tests/test_utils.py +32 -27
- warp/tests/test_verify_fp.py +65 -0
- warp/tests/test_volume.py +1 -1
- warp/tests/unittest_serial.py +2 -0
- warp/tests/unittest_suites.py +12 -0
- warp/tests/unittest_utils.py +14 -7
- warp/thirdparty/unittest_parallel.py +15 -3
- warp/torch.py +10 -8
- warp/types.py +363 -246
- warp/utils.py +143 -19
- warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
- warp_lang-1.0.0.dist-info/METADATA +394 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
- warp/sim/optimizer.py +0 -138
- warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
- warp_lang-0.11.0.dist-info/METADATA +0 -238
- /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/native/warp.h
CHANGED
|
@@ -11,12 +11,21 @@
|
|
|
11
11
|
// defines all crt + builtin types
|
|
12
12
|
#include "builtin.h"
|
|
13
13
|
|
|
14
|
+
#define WP_CURRENT_STREAM ((void*)0xffffffffffffffff)
|
|
15
|
+
|
|
14
16
|
// this is the core runtime API exposed on the DLL level
|
|
15
17
|
extern "C"
|
|
16
18
|
{
|
|
17
19
|
WP_API int init();
|
|
18
20
|
//WP_API void shutdown();
|
|
19
21
|
|
|
22
|
+
// get error message from C++
|
|
23
|
+
WP_API const char* get_error_string();
|
|
24
|
+
|
|
25
|
+
// allow disabling error output, which is handy during tests that expect failure
|
|
26
|
+
WP_API void set_error_output_enabled(int enable);
|
|
27
|
+
WP_API int is_error_output_enabled();
|
|
28
|
+
|
|
20
29
|
// whether Warp was compiled with CUDA support
|
|
21
30
|
WP_API int is_cuda_enabled();
|
|
22
31
|
// whether Warp was compiled with enhanced CUDA compatibility
|
|
@@ -31,22 +40,22 @@ extern "C"
|
|
|
31
40
|
|
|
32
41
|
WP_API void* alloc_host(size_t s);
|
|
33
42
|
WP_API void* alloc_pinned(size_t s);
|
|
34
|
-
WP_API void* alloc_device(void* context, size_t s);
|
|
35
|
-
WP_API void*
|
|
43
|
+
WP_API void* alloc_device(void* context, size_t s); // uses cudaMallocAsync() if supported, cudaMalloc() otherwise
|
|
44
|
+
WP_API void* alloc_device_default(void* context, size_t s); // uses cudaMalloc()
|
|
45
|
+
WP_API void* alloc_device_async(void* context, size_t s); // uses cudaMallocAsync()
|
|
36
46
|
|
|
37
47
|
WP_API void free_host(void* ptr);
|
|
38
48
|
WP_API void free_pinned(void* ptr);
|
|
39
|
-
WP_API void free_device(void* context, void* ptr);
|
|
40
|
-
WP_API void
|
|
49
|
+
WP_API void free_device(void* context, void* ptr); // uses cudaFreeAsync() if supported, cudaFree() otherwise
|
|
50
|
+
WP_API void free_device_default(void* context, void* ptr); // uses cudaFree()
|
|
51
|
+
WP_API void free_device_async(void* context, void* ptr); // uses cudaFreeAsync()
|
|
41
52
|
|
|
42
|
-
|
|
43
|
-
WP_API
|
|
44
|
-
WP_API
|
|
45
|
-
WP_API
|
|
46
|
-
WP_API
|
|
47
|
-
WP_API void memcpy_peer(void* context, void* dest, void* src, size_t n);
|
|
53
|
+
WP_API bool memcpy_h2h(void* dest, void* src, size_t n);
|
|
54
|
+
WP_API bool memcpy_h2d(void* context, void* dest, void* src, size_t n, void* stream=WP_CURRENT_STREAM);
|
|
55
|
+
WP_API bool memcpy_d2h(void* context, void* dest, void* src, size_t n, void* stream=WP_CURRENT_STREAM);
|
|
56
|
+
WP_API bool memcpy_d2d(void* context, void* dest, void* src, size_t n, void* stream=WP_CURRENT_STREAM);
|
|
57
|
+
WP_API bool memcpy_p2p(void* dst_context, void* dst, void* src_context, void* src, size_t n, void* stream=WP_CURRENT_STREAM);
|
|
48
58
|
|
|
49
|
-
// all memsets are performed asynchronously
|
|
50
59
|
WP_API void memset_host(void* dest, int value, size_t n);
|
|
51
60
|
WP_API void memset_device(void* context, void* dest, int value, size_t n);
|
|
52
61
|
|
|
@@ -82,7 +91,7 @@ extern "C"
|
|
|
82
91
|
WP_API void hash_grid_destroy_device(uint64_t id);
|
|
83
92
|
WP_API void hash_grid_update_device(uint64_t id, float cell_width, const wp::vec3* positions, int num_points);
|
|
84
93
|
|
|
85
|
-
WP_API bool cutlass_gemm(int compute_capability, int m, int n, int k, const char* datatype,
|
|
94
|
+
WP_API bool cutlass_gemm(void* context, int compute_capability, int m, int n, int k, const char* datatype,
|
|
86
95
|
const void* a, const void* b, const void* c, void* d, float alpha, float beta,
|
|
87
96
|
bool row_major_a, bool row_major_b, bool allow_tf32x3_arith, int batch_count);
|
|
88
97
|
|
|
@@ -106,8 +115,8 @@ extern "C"
|
|
|
106
115
|
WP_API int marching_cubes_surface_device(uint64_t id, const float* field, int nx, int ny, int nz, float threshold, wp::vec3* verts, int* triangles, int max_verts, int max_tris, int* out_num_verts, int* out_num_tris);
|
|
107
116
|
|
|
108
117
|
// generic copy supporting non-contiguous arrays
|
|
109
|
-
WP_API
|
|
110
|
-
WP_API
|
|
118
|
+
WP_API bool array_copy_host(void* dst, void* src, int dst_type, int src_type, int elem_size);
|
|
119
|
+
WP_API bool array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size);
|
|
111
120
|
|
|
112
121
|
// generic fill for non-contiguous arrays
|
|
113
122
|
WP_API void array_fill_host(void* arr, int arr_type, const void* value, int value_size);
|
|
@@ -220,8 +229,7 @@ extern "C"
|
|
|
220
229
|
WP_API void nvrtc_supported_archs(int* archs);
|
|
221
230
|
|
|
222
231
|
WP_API int cuda_device_get_count();
|
|
223
|
-
WP_API void*
|
|
224
|
-
WP_API void cuda_device_primary_context_release(int ordinal);
|
|
232
|
+
WP_API void* cuda_device_get_primary_context(int ordinal);
|
|
225
233
|
WP_API const char* cuda_device_get_name(int ordinal);
|
|
226
234
|
WP_API int cuda_device_get_arch(int ordinal);
|
|
227
235
|
WP_API void cuda_device_get_uuid(int ordinal, char uuid[16]);
|
|
@@ -229,7 +237,10 @@ extern "C"
|
|
|
229
237
|
WP_API int cuda_device_get_pci_bus_id(int ordinal);
|
|
230
238
|
WP_API int cuda_device_get_pci_device_id(int ordinal);
|
|
231
239
|
WP_API int cuda_device_is_uva(int ordinal);
|
|
232
|
-
WP_API int
|
|
240
|
+
WP_API int cuda_device_is_mempool_supported(int ordinal);
|
|
241
|
+
WP_API int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold);
|
|
242
|
+
WP_API uint64_t cuda_device_get_mempool_release_threshold(int ordinal);
|
|
243
|
+
WP_API void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem);
|
|
233
244
|
|
|
234
245
|
WP_API void* cuda_context_get_current();
|
|
235
246
|
WP_API void cuda_context_set_current(void* context);
|
|
@@ -239,11 +250,8 @@ extern "C"
|
|
|
239
250
|
WP_API void cuda_context_destroy(void* context);
|
|
240
251
|
WP_API int cuda_context_get_device_ordinal(void* context);
|
|
241
252
|
WP_API int cuda_context_is_primary(void* context);
|
|
242
|
-
WP_API int cuda_context_is_memory_pool_supported(void* context);
|
|
243
253
|
WP_API void* cuda_context_get_stream(void* context);
|
|
244
|
-
WP_API void cuda_context_set_stream(void* context, void* stream);
|
|
245
|
-
WP_API int cuda_context_can_access_peer(void* context, void* peer_context);
|
|
246
|
-
WP_API int cuda_context_enable_peer_access(void* context, void* peer_context);
|
|
254
|
+
WP_API void cuda_context_set_stream(void* context, void* stream, int sync);
|
|
247
255
|
|
|
248
256
|
// ensures all device side operations have completed in the current context
|
|
249
257
|
WP_API void cuda_context_synchronize(void* context);
|
|
@@ -251,28 +259,38 @@ extern "C"
|
|
|
251
259
|
// return cudaError_t code
|
|
252
260
|
WP_API uint64_t cuda_context_check(void* context);
|
|
253
261
|
|
|
262
|
+
// peer access
|
|
263
|
+
WP_API int cuda_is_peer_access_supported(int target_ordinal, int peer_ordinal);
|
|
264
|
+
WP_API int cuda_is_peer_access_enabled(void* target_context, void* peer_context);
|
|
265
|
+
WP_API int cuda_set_peer_access_enabled(void* target_context, void* peer_context, int enable);
|
|
266
|
+
WP_API int cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal);
|
|
267
|
+
WP_API int cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable);
|
|
268
|
+
|
|
254
269
|
WP_API void* cuda_stream_create(void* context);
|
|
255
270
|
WP_API void cuda_stream_destroy(void* context, void* stream);
|
|
256
|
-
WP_API void
|
|
271
|
+
WP_API void cuda_stream_register(void* context, void* stream);
|
|
272
|
+
WP_API void cuda_stream_unregister(void* context, void* stream);
|
|
257
273
|
WP_API void* cuda_stream_get_current();
|
|
258
|
-
WP_API void
|
|
259
|
-
WP_API void
|
|
274
|
+
WP_API void cuda_stream_synchronize(void* stream);
|
|
275
|
+
WP_API void cuda_stream_wait_event(void* stream, void* event);
|
|
276
|
+
WP_API void cuda_stream_wait_stream(void* stream, void* other_stream, void* event);
|
|
277
|
+
WP_API int cuda_stream_is_capturing(void* stream);
|
|
260
278
|
|
|
261
279
|
WP_API void* cuda_event_create(void* context, unsigned flags);
|
|
262
|
-
WP_API void cuda_event_destroy(void*
|
|
263
|
-
WP_API void cuda_event_record(void*
|
|
280
|
+
WP_API void cuda_event_destroy(void* event);
|
|
281
|
+
WP_API void cuda_event_record(void* event, void* stream);
|
|
264
282
|
|
|
265
|
-
WP_API
|
|
266
|
-
WP_API
|
|
267
|
-
WP_API
|
|
268
|
-
WP_API
|
|
283
|
+
WP_API bool cuda_graph_begin_capture(void* context, void* stream, int external);
|
|
284
|
+
WP_API bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret);
|
|
285
|
+
WP_API bool cuda_graph_launch(void* graph, void* stream);
|
|
286
|
+
WP_API bool cuda_graph_destroy(void* context, void* graph);
|
|
269
287
|
|
|
270
288
|
WP_API size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_file);
|
|
271
289
|
|
|
272
290
|
WP_API void* cuda_load_module(void* context, const char* ptx);
|
|
273
291
|
WP_API void cuda_unload_module(void* context, void* module);
|
|
274
292
|
WP_API void* cuda_get_kernel(void* context, void* module, const char* name);
|
|
275
|
-
WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args);
|
|
293
|
+
WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args, void* stream);
|
|
276
294
|
|
|
277
295
|
WP_API void cuda_set_context_restore_policy(bool always_restore);
|
|
278
296
|
WP_API int cuda_get_context_restore_policy();
|
|
@@ -284,4 +302,3 @@ extern "C"
|
|
|
284
302
|
WP_API void cuda_graphics_unregister_resource(void* context, void* resource);
|
|
285
303
|
|
|
286
304
|
} // extern "C"
|
|
287
|
-
|
warp/optim/linear.py
CHANGED
|
@@ -366,6 +366,7 @@ def bicgstab(
|
|
|
366
366
|
device=device,
|
|
367
367
|
inputs=[atol_sq, r_norm_sq, rho, r0v, x, r, y, v],
|
|
368
368
|
)
|
|
369
|
+
array_inner(r, r, out=r_norm_sq)
|
|
369
370
|
|
|
370
371
|
# z = M r
|
|
371
372
|
if M is not None:
|
|
@@ -473,6 +474,8 @@ def gmres(
|
|
|
473
474
|
device = A.device
|
|
474
475
|
scalar_dtype = wp.types.type_scalar_type(A.dtype)
|
|
475
476
|
|
|
477
|
+
pivot_tolerance = _get_dtype_epsilon(scalar_dtype) ** 2
|
|
478
|
+
|
|
476
479
|
beta_sq = wp.empty_like(r_norm_sq, pinned=False)
|
|
477
480
|
H = wp.empty(shape=(restart + 1, restart), dtype=scalar_dtype, device=device)
|
|
478
481
|
|
|
@@ -488,7 +491,6 @@ def gmres(
|
|
|
488
491
|
shape=(1,),
|
|
489
492
|
device=H.device,
|
|
490
493
|
copy=False,
|
|
491
|
-
owner=False,
|
|
492
494
|
)
|
|
493
495
|
|
|
494
496
|
def array_row(V, i):
|
|
@@ -498,7 +500,6 @@ def gmres(
|
|
|
498
500
|
shape=V.shape[1],
|
|
499
501
|
device=V.device,
|
|
500
502
|
copy=False,
|
|
501
|
-
owner=False,
|
|
502
503
|
)
|
|
503
504
|
|
|
504
505
|
def do_arnoldi_iteration(j: int):
|
|
@@ -548,7 +549,7 @@ def gmres(
|
|
|
548
549
|
do_arnoldi_iteration(j)
|
|
549
550
|
|
|
550
551
|
wp.launch(_gmres_normalize_lower_diagonal, dim=restart, device=device, inputs=[H])
|
|
551
|
-
wp.launch(_gmres_solve_least_squares, dim=1, device=device, inputs=[restart, beta_sq, H, y])
|
|
552
|
+
wp.launch(_gmres_solve_least_squares, dim=1, device=device, inputs=[restart, pivot_tolerance, beta_sq, H, y])
|
|
552
553
|
|
|
553
554
|
# update x
|
|
554
555
|
if M is None or is_left_preconditioner:
|
|
@@ -575,16 +576,19 @@ def gmres(
|
|
|
575
576
|
)
|
|
576
577
|
|
|
577
578
|
|
|
578
|
-
def
|
|
579
|
+
def _get_dtype_epsilon(dtype):
|
|
579
580
|
if dtype == wp.float64:
|
|
580
|
-
|
|
581
|
-
min_tol = 1.0e-36
|
|
581
|
+
return 1.0e-16
|
|
582
582
|
elif dtype == wp.float16:
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
583
|
+
return 1.0e-4
|
|
584
|
+
|
|
585
|
+
return 1.0e-8
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
def _get_absolute_tolerance(dtype, tol, atol, lhs_norm):
|
|
589
|
+
eps_tol = _get_dtype_epsilon(dtype)
|
|
590
|
+
default_tol = eps_tol ** (3 / 4)
|
|
591
|
+
min_tol = eps_tol ** (9 / 4)
|
|
588
592
|
|
|
589
593
|
if tol is None and atol is None:
|
|
590
594
|
tol = atol = default_tol
|
|
@@ -846,7 +850,9 @@ def _gmres_normalize_lower_diagonal(H: wp.array2d(dtype=Any)):
|
|
|
846
850
|
|
|
847
851
|
|
|
848
852
|
@wp.kernel
|
|
849
|
-
def _gmres_solve_least_squares(
|
|
853
|
+
def _gmres_solve_least_squares(
|
|
854
|
+
k: int, pivot_tolerance: float, beta_sq: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
|
|
855
|
+
):
|
|
850
856
|
# Solve H y = (beta, 0, ..., 0)
|
|
851
857
|
# H Hessenberg matrix of shape (k+1, k)
|
|
852
858
|
|
|
@@ -859,6 +865,7 @@ def _gmres_solve_least_squares(k: int, beta_sq: wp.array(dtype=Any), H: wp.array
|
|
|
859
865
|
|
|
860
866
|
# Apply 2x2 rotations to H so as to remove lower diagonal,
|
|
861
867
|
# and apply similar rotations to right-hand-side
|
|
868
|
+
max_k = int(k)
|
|
862
869
|
for i in range(k):
|
|
863
870
|
Ha = H[i]
|
|
864
871
|
Hb = H[i + 1]
|
|
@@ -866,7 +873,14 @@ def _gmres_solve_least_squares(k: int, beta_sq: wp.array(dtype=Any), H: wp.array
|
|
|
866
873
|
# Givens rotation [[c s], [-s c]]
|
|
867
874
|
a = Ha[i]
|
|
868
875
|
b = Hb[i]
|
|
869
|
-
|
|
876
|
+
abn_sq = a * a + b * b
|
|
877
|
+
|
|
878
|
+
if abn_sq < type(abn_sq)(pivot_tolerance):
|
|
879
|
+
# Arnoldi iteration finished early
|
|
880
|
+
max_k = i
|
|
881
|
+
break
|
|
882
|
+
|
|
883
|
+
abn = wp.sqrt(abn_sq)
|
|
870
884
|
c = a / abn
|
|
871
885
|
s = b / abn
|
|
872
886
|
|
|
@@ -881,12 +895,15 @@ def _gmres_solve_least_squares(k: int, beta_sq: wp.array(dtype=Any), H: wp.array
|
|
|
881
895
|
y[i] = c * rhs
|
|
882
896
|
rhs = -s * rhs
|
|
883
897
|
|
|
898
|
+
for i in range(max_k, k):
|
|
899
|
+
y[i] = y.dtype(0.0)
|
|
900
|
+
|
|
884
901
|
# Triangular back-solve for y
|
|
885
|
-
for ii in range(
|
|
902
|
+
for ii in range(max_k, 0, -1):
|
|
886
903
|
i = ii - 1
|
|
887
904
|
Hi = H[i]
|
|
888
905
|
yi = y[i]
|
|
889
|
-
for j in range(ii,
|
|
906
|
+
for j in range(ii, max_k):
|
|
890
907
|
yi -= Hi[j] * y[j]
|
|
891
908
|
y[i] = yi / Hi[i]
|
|
892
909
|
|
|
@@ -908,7 +925,7 @@ def _gmres_arnoldi_normalize_kernel(
|
|
|
908
925
|
alpha: wp.array(dtype=Any),
|
|
909
926
|
):
|
|
910
927
|
tid = wp.tid()
|
|
911
|
-
y[tid] = x[tid] / wp.sqrt(alpha[0])
|
|
928
|
+
y[tid] = wp.select(alpha[0] == alpha.dtype(0.0), x[tid] / wp.sqrt(alpha[0]), x[tid])
|
|
912
929
|
|
|
913
930
|
|
|
914
931
|
@wp.kernel
|