warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/optim/linear.py ADDED
@@ -0,0 +1,922 @@
1
+ from typing import Optional, Union, Callable, Tuple, Any
2
+ from math import sqrt
3
+
4
+ import warp as wp
5
+ import warp.sparse as sparse
6
+ from warp.utils import array_inner
7
+
8
+ # No need to auto-generate adjoint code for linear solvers
9
+ wp.set_module_options({"enable_backward": False})
10
+
11
+
12
+ class LinearOperator:
13
+ """
14
+ Linear operator to be used as left-hand-side of linear iterative solvers.
15
+
16
+ Args:
17
+ shape: Tuple containing the number of rows and columns of the operator
18
+ dtype: Type of the operator elements
19
+ device: Device on which computations involving the operator should be performed
20
+ matvec: Matrix-vector multiplication routine
21
+
22
+ The matrix-vector multiplication routine should have the following signature:
23
+
24
+ .. code-block:: python
25
+
26
+ def matvec(x: wp.array, y: wp.array, z: wp.array, alpha: Scalar, beta: Scalar):
27
+ '''Performs the operation z = alpha * (A @ x) + beta * y'''
28
+ ...
29
+
30
+ For performance reasons, by default the iterative linear solvers in this module will try to capture the calls
31
+ for one or more iterations in CUDA graphs. If the `matvec` routine of a custom :class:`LinearOperator`
32
+ cannot be graph-captured, the ``use_cuda_graph=False`` parameter should be passed to the solver function.
33
+
34
+ """
35
+
36
+ def __init__(self, shape: Tuple[int, int], dtype: type, device: wp.context.Device, matvec: Callable):
37
+ self._shape = shape
38
+ self._dtype = dtype
39
+ self._device = device
40
+ self._matvec = matvec
41
+
42
+ @property
43
+ def shape(self) -> Tuple[int, int]:
44
+ return self._shape
45
+
46
+ @property
47
+ def dtype(self) -> type:
48
+ return self._dtype
49
+
50
+ @property
51
+ def device(self) -> wp.context.Device:
52
+ return self._device
53
+
54
+ @property
55
+ def matvec(self) -> Callable:
56
+ return self._matvec
57
+
58
+ @property
59
+ def scalar_type(self):
60
+ return wp.types.type_scalar_type(self.dtype)
61
+
62
+
63
+ _Matrix = Union[wp.array, sparse.BsrMatrix, LinearOperator]
64
+
65
+
66
+ def aslinearoperator(A: _Matrix) -> LinearOperator:
67
+ """
68
+ Casts the dense or sparse matrix `A` as a :class:`LinearOperator`
69
+
70
+ `A` must be of one of the following types:
71
+
72
+ - :class:`warp.sparse.BsrMatrix`
73
+ - two-dimensional `warp.array`; then `A` is assumed to be a dense matrix
74
+ - one-dimensional `warp.array`; then `A` is assumed to be a diagonal matrix
75
+ - :class:`warp.sparse.LinearOperator`; no casting necessary
76
+ """
77
+
78
+ if A is None or isinstance(A, LinearOperator):
79
+ return A
80
+
81
+ def bsr_mv(x, y, z, alpha, beta):
82
+ if z.ptr != y.ptr and beta != 0.0:
83
+ wp.copy(src=y, dest=z)
84
+ sparse.bsr_mv(A, x, z, alpha, beta)
85
+
86
+ def dense_mv(x, y, z, alpha, beta):
87
+ x = x.reshape((x.shape[0], 1))
88
+ y = y.reshape((y.shape[0], 1))
89
+ z = z.reshape((y.shape[0], 1))
90
+ wp.matmul(A, x, y, z, alpha, beta)
91
+
92
+ def diag_mv(x, y, z, alpha, beta):
93
+ scalar_type = wp.types.type_scalar_type(A.dtype)
94
+ alpha = scalar_type(alpha)
95
+ beta = scalar_type(beta)
96
+ wp.launch(_diag_mv_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])
97
+
98
+ def diag_mv_vec(x, y, z, alpha, beta):
99
+ scalar_type = wp.types.type_scalar_type(A.dtype)
100
+ alpha = scalar_type(alpha)
101
+ beta = scalar_type(beta)
102
+ wp.launch(_diag_mv_vec_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])
103
+
104
+ if isinstance(A, wp.array):
105
+ if A.ndim == 2:
106
+ return LinearOperator(A.shape, A.dtype, A.device, matvec=dense_mv)
107
+ if A.ndim == 1:
108
+ if wp.types.type_is_vector(A.dtype):
109
+ return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv_vec)
110
+ return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv)
111
+ if isinstance(A, sparse.BsrMatrix):
112
+ return LinearOperator(A.shape, A.dtype, A.device, matvec=bsr_mv)
113
+
114
+ raise ValueError(f"Unable to create LinearOperator from {A}")
115
+
116
+
117
+ def preconditioner(A: _Matrix, ptype: str = "diag") -> LinearOperator:
118
+ """Constructs and returns a preconditioner for an input matrix.
119
+
120
+ Args:
121
+ A: The matrix for which to build the preconditioner
122
+ ptype: The type of preconditioner. Currently the following values are supported:
123
+
124
+ - ``"diag"``: Diagonal (a.k.a. Jacobi) preconditioner
125
+ - ``"diag_abs"``: Similar to Jacobi, but using the absolute value of diagonal coefficients
126
+ - ``"id"``: Identity (null) preconditioner
127
+ """
128
+
129
+ if ptype == "id":
130
+ return None
131
+
132
+ if ptype in ("diag", "diag_abs"):
133
+ use_abs = 1 if ptype == "diag_abs" else 0
134
+ if isinstance(A, sparse.BsrMatrix):
135
+ A_diag = sparse.bsr_get_diag(A)
136
+ if wp.types.type_is_matrix(A.dtype):
137
+ inv_diag = wp.empty(
138
+ shape=A.nrow, dtype=wp.vec(length=A.block_shape[0], dtype=A.scalar_type), device=A.device
139
+ )
140
+ wp.launch(
141
+ _extract_inverse_diagonal_blocked,
142
+ dim=inv_diag.shape,
143
+ device=inv_diag.device,
144
+ inputs=[A_diag, inv_diag, use_abs],
145
+ )
146
+ else:
147
+ inv_diag = wp.empty(shape=A.shape[0], dtype=A.scalar_type, device=A.device)
148
+ wp.launch(
149
+ _extract_inverse_diagonal_scalar,
150
+ dim=inv_diag.shape,
151
+ device=inv_diag.device,
152
+ inputs=[A_diag, inv_diag, use_abs],
153
+ )
154
+ elif isinstance(A, wp.array) and A.ndim == 2:
155
+ inv_diag = wp.empty(shape=A.shape[0], dtype=A.dtype, device=A.device)
156
+ wp.launch(
157
+ _extract_inverse_diagonal_dense,
158
+ dim=inv_diag.shape,
159
+ device=inv_diag.device,
160
+ inputs=[A, inv_diag, use_abs],
161
+ )
162
+ else:
163
+ raise ValueError("Unsupported source matrix type for building diagonal preconditioner")
164
+
165
+ return aslinearoperator(inv_diag)
166
+
167
+ raise ValueError(f"Unsupported preconditioner type '{ptype}'")
168
+
169
+
170
+ def cg(
171
+ A: _Matrix,
172
+ b: wp.array,
173
+ x: wp.array,
174
+ tol: Optional[float] = None,
175
+ atol: Optional[float] = None,
176
+ maxiter: Optional[float] = 0,
177
+ M: Optional[_Matrix] = None,
178
+ callback: Optional[Callable] = None,
179
+ check_every=10,
180
+ use_cuda_graph=True,
181
+ ) -> Tuple[int, float, float]:
182
+ """Computes an approximate solution to a symmetric, positive-definite linear system
183
+ using the Conjugate Gradient algorithm.
184
+
185
+ Args:
186
+ A: the linear system's left-hand-side
187
+ b: the linear system's right-hand-side
188
+ x: initial guess and solution vector
189
+ tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
190
+ atol: absolute tolerance for the residual
191
+ maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
192
+ Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
193
+ M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
194
+ callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
195
+ check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibly terminate the algorithm.
196
+ use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
197
+ The linear operator and preconditioner must only perform graph-friendly operations.
198
+
199
+ Returns:
200
+ Tuple (final iteration number, residual norm, absolute tolerance)
201
+
202
+ If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
203
+ """
204
+
205
+ A = aslinearoperator(A)
206
+ M = aslinearoperator(M)
207
+
208
+ if maxiter == 0:
209
+ maxiter = A.shape[0]
210
+
211
+ r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)
212
+
213
+ device = A.device
214
+ scalar_dtype = wp.types.type_scalar_type(A.dtype)
215
+
216
+ # Notations below follow pseudo-code from https://en.wikipedia.org/wiki/Conjugate_gradient_method
217
+
218
+ # z = M r
219
+ if M is not None:
220
+ z = wp.zeros_like(b)
221
+ M.matvec(r, z, z, alpha=1.0, beta=0.0)
222
+
223
+ # rz = r' z;
224
+ rz_new = wp.empty(n=1, dtype=scalar_dtype, device=device)
225
+ array_inner(r, z, out=rz_new)
226
+ else:
227
+ z = r
228
+
229
+ rz_old = wp.empty(n=1, dtype=scalar_dtype, device=device)
230
+ p_Ap = wp.empty(n=1, dtype=scalar_dtype, device=device)
231
+ Ap = wp.zeros_like(b)
232
+
233
+ p = wp.clone(z)
234
+
235
+ def do_iteration(atol_sq, rr_old, rr_new, rz_old, rz_new):
236
+ # Ap = A * p;
237
+ A.matvec(p, Ap, Ap, alpha=1, beta=0)
238
+
239
+ array_inner(p, Ap, out=p_Ap)
240
+
241
+ wp.launch(
242
+ kernel=_cg_kernel_1,
243
+ dim=x.shape[0],
244
+ device=device,
245
+ inputs=[atol_sq, rr_old, rz_old, p_Ap, x, r, p, Ap],
246
+ )
247
+ array_inner(r, r, out=rr_new)
248
+
249
+ # z = M r
250
+ if M is not None:
251
+ M.matvec(r, z, z, alpha=1.0, beta=0.0)
252
+ # rz = r' z;
253
+ array_inner(r, z, out=rz_new)
254
+
255
+ wp.launch(kernel=_cg_kernel_2, dim=z.shape[0], device=device, inputs=[atol_sq, rr_new, rz_old, rz_new, z, p])
256
+
257
+ # We do iterations by pairs, switching old and new residual norm buffers for each odd-even couple
258
+ # In the non-preconditioned case we reuse the error norm buffer for the new <r,z> computation
259
+
260
+ def do_odd_even_cycle(atol_sq: float):
261
+ # A pair of iterations, so that we're swapping the residual buffers twice
262
+ if M is None:
263
+ do_iteration(atol_sq, r_norm_sq, rz_old, r_norm_sq, rz_old)
264
+ do_iteration(atol_sq, rz_old, r_norm_sq, rz_old, r_norm_sq)
265
+ else:
266
+ do_iteration(atol_sq, r_norm_sq, r_norm_sq, rz_new, rz_old)
267
+ do_iteration(atol_sq, r_norm_sq, r_norm_sq, rz_old, rz_new)
268
+
269
+ return _run_solver_loop(
270
+ do_odd_even_cycle,
271
+ cycle_size=2,
272
+ r_norm_sq=r_norm_sq,
273
+ maxiter=maxiter,
274
+ atol=atol,
275
+ callback=callback,
276
+ check_every=check_every,
277
+ use_cuda_graph=use_cuda_graph,
278
+ device=device,
279
+ )
280
+
281
+
282
+ def bicgstab(
283
+ A: _Matrix,
284
+ b: wp.array,
285
+ x: wp.array,
286
+ tol: Optional[float] = None,
287
+ atol: Optional[float] = None,
288
+ maxiter: Optional[float] = 0,
289
+ M: Optional[_Matrix] = None,
290
+ callback: Optional[Callable] = None,
291
+ check_every=10,
292
+ use_cuda_graph=True,
293
+ is_left_preconditioner=False,
294
+ ):
295
+ """Computes an approximate solution to a linear system using the Biconjugate Gradient Stabilized method (BiCGSTAB).
296
+
297
+ Args:
298
+ A: the linear system's left-hand-side
299
+ b: the linear system's right-hand-side
300
+ x: initial guess and solution vector
301
+ tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
302
+ atol: absolute tolerance for the residual
303
+ maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
304
+ M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
305
+ callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
306
+ check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibly terminate the algorithm.
307
+ use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
308
+ The linear operator and preconditioner must only perform graph-friendly operations.
309
+ is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.
310
+
311
+ Returns:
312
+ Tuple (final iteration number, residual norm, absolute tolerance)
313
+
314
+ If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
315
+ """
316
+ A = aslinearoperator(A)
317
+ M = aslinearoperator(M)
318
+
319
+ if maxiter == 0:
320
+ maxiter = A.shape[0]
321
+
322
+ r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)
323
+
324
+ device = A.device
325
+ scalar_dtype = wp.types.type_scalar_type(A.dtype)
326
+
327
+ # Notations below follow pseudo-code from https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method
328
+
329
+ rho = wp.clone(r_norm_sq, pinned=False)
330
+ r0v = wp.empty(n=1, dtype=scalar_dtype, device=device)
331
+ st = wp.empty(n=1, dtype=scalar_dtype, device=device)
332
+ tt = wp.empty(n=1, dtype=scalar_dtype, device=device)
333
+
334
+ # work arrays
335
+ r0 = wp.clone(r)
336
+ v = wp.zeros_like(r)
337
+ t = wp.zeros_like(r)
338
+ p = wp.clone(r)
339
+
340
+ if M is not None:
341
+ y = wp.zeros_like(p)
342
+ z = wp.zeros_like(r)
343
+ if is_left_preconditioner:
344
+ Mt = wp.zeros_like(t)
345
+ else:
346
+ y = p
347
+ z = r
348
+ Mt = t
349
+
350
+ def do_iteration(atol_sq: float):
351
+ # y = M p
352
+ if M is not None:
353
+ M.matvec(p, y, y, alpha=1.0, beta=0.0)
354
+
355
+ # v = A * y;
356
+ A.matvec(y, v, v, alpha=1, beta=0)
357
+
358
+ # alpha = rho / <r0 . v>
359
+ array_inner(r0, v, out=r0v)
360
+
361
+ # x += alpha y
362
+ # r -= alpha v
363
+ wp.launch(
364
+ kernel=_bicgstab_kernel_1,
365
+ dim=x.shape[0],
366
+ device=device,
367
+ inputs=[atol_sq, r_norm_sq, rho, r0v, x, r, y, v],
368
+ )
369
+
370
+ # z = M r
371
+ if M is not None:
372
+ M.matvec(r, z, z, alpha=1.0, beta=0.0)
373
+
374
+ # t = A z
375
+ A.matvec(z, t, t, alpha=1, beta=0)
376
+
377
+ if is_left_preconditioner:
378
+ # Mt = M t
379
+ if M is not None:
380
+ M.matvec(t, Mt, Mt, alpha=1.0, beta=0.0)
381
+
382
+ # omega = <Mt, Ms> / <Mt, Mt>
383
+ array_inner(z, Mt, out=st)
384
+ array_inner(Mt, Mt, out=tt)
385
+ else:
386
+ array_inner(r, t, out=st)
387
+ array_inner(t, t, out=tt)
388
+
389
+ # x += omega z
390
+ # r -= omega t
391
+ wp.launch(
392
+ kernel=_bicgstab_kernel_2,
393
+ dim=z.shape[0],
394
+ device=device,
395
+ inputs=[atol_sq, r_norm_sq, st, tt, z, t, x, r],
396
+ )
397
+ array_inner(r, r, out=r_norm_sq)
398
+
399
+ # rho = <r0, r>
400
+ array_inner(r0, r, out=rho)
401
+
402
+ # beta = (rho / rho_old) * alpha / omega = (rho / r0v) / omega
403
+ # p = r + beta (p - omega v)
404
+ wp.launch(
405
+ kernel=_bicgstab_kernel_3,
406
+ dim=z.shape[0],
407
+ device=device,
408
+ inputs=[atol_sq, r_norm_sq, rho, r0v, st, tt, p, r, v],
409
+ )
410
+
411
+ return _run_solver_loop(
412
+ do_iteration,
413
+ cycle_size=1,
414
+ r_norm_sq=r_norm_sq,
415
+ maxiter=maxiter,
416
+ atol=atol,
417
+ callback=callback,
418
+ check_every=check_every,
419
+ use_cuda_graph=use_cuda_graph,
420
+ device=device,
421
+ )
422
+
423
+
424
def gmres(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    restart: int = 31,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every: int = 31,
    use_cuda_graph: bool = True,
    is_left_preconditioner: bool = False,
):
    """Computes an approximate solution to a linear system using the restarted Generalized Minimum Residual method (GMRES[k]).

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        restart: The restart parameter, i.e., the `k` in `GMRES[k]`. In general, increasing this parameter reduces the number of iterations but increases memory consumption.
        maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
            Note that the current implementation always performs `restart` iterations at a time, and as a result may exceed the specified maximum number of iterations by ``restart-1``.
        M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance
        check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibly terminate the algorithm.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.
        is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.

    Returns:
        Tuple (final iteration number, residual norm, absolute tolerance)

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    """

    # Normalize inputs to LinearOperator (presumably None stays None, given the
    # `M is not None` checks below — see aslinearoperator)
    A = aslinearoperator(A)
    M = aslinearoperator(M)

    if maxiter == 0:
        maxiter = A.shape[0]

    # A restart cycle can never be longer than the iteration budget, and the
    # convergence check can only happen at cycle boundaries
    restart = min(restart, maxiter)
    check_every = max(restart, check_every)

    r, r_norm_sq, atol = _initialize_residual_and_tolerance(A, b, x, tol=tol, atol=atol)

    device = A.device
    scalar_dtype = wp.types.type_scalar_type(A.dtype)

    # Squared norm of the residual at the start of each restart cycle
    beta_sq = wp.empty_like(r_norm_sq, pinned=False)
    # Hessenberg matrix built by the Arnoldi iteration, shape (restart+1, restart)
    H = wp.empty(shape=(restart + 1, restart), dtype=scalar_dtype, device=device)

    # Solution of the small least-squares problem H y = beta e1
    y = wp.empty(shape=restart + 1, dtype=scalar_dtype, device=device)

    # Scratch vector and Krylov basis (one row per basis vector)
    w = wp.zeros_like(r)
    V = wp.zeros(shape=(restart + 1, r.shape[0]), dtype=r.dtype, device=device)

    def array_coeff(H, i, j):
        # Zero-copy 1-element view of H[i, j], usable as an array_inner output
        return wp.array(
            ptr=H.ptr + i * H.strides[0] + j * H.strides[1],
            dtype=H.dtype,
            shape=(1,),
            device=H.device,
            copy=False,
            owner=False,
        )

    def array_row(V, i):
        # Zero-copy view of row i of V
        return wp.array(
            ptr=V.ptr + i * V.strides[0],
            dtype=V.dtype,
            shape=V.shape[1],
            device=V.device,
            copy=False,
            owner=False,
        )

    def do_arnoldi_iteration(j: int):
        # One step of the Arnoldi process with modified Gram-Schmidt:
        # w = (preconditioned) A v_j, orthogonalized against v_0..v_j

        vj = array_row(V, j)

        if M is not None:
            # Row j+1 of V is free at this point; use it as scratch for the
            # intermediate (pre/post-conditioned) product
            tmp = array_row(V, j + 1)

            if is_left_preconditioner:
                A.matvec(vj, tmp, tmp, alpha=1, beta=0)
                M.matvec(tmp, w, w, alpha=1, beta=0)
            else:
                M.matvec(vj, tmp, tmp, alpha=1, beta=0)
                A.matvec(tmp, w, w, alpha=1, beta=0)
        else:
            A.matvec(vj, w, w, alpha=1, beta=0)

        # Orthogonalize w against the previous basis vectors, recording the
        # projection coefficients in column j of H
        for i in range(j + 1):
            vi = array_row(V, i)
            hij = array_coeff(H, i, j)
            array_inner(w, vi, out=hij)

            wp.launch(_gmres_arnoldi_axpy_kernel, dim=w.shape, device=w.device, inputs=[vi, w, hij])

        # H[j+1, j] temporarily holds the SQUARED norm of w;
        # _gmres_normalize_lower_diagonal takes the square root later
        hjnj = array_coeff(H, j + 1, j)
        array_inner(w, w, out=hjnj)

        vjn = array_row(V, j + 1)
        wp.launch(_gmres_arnoldi_normalize_kernel, dim=w.shape, device=w.device, inputs=[w, vjn, hjnj])

    def do_restart_cycle(atol_sq: float):
        # For a left preconditioner, GMRES operates on the preconditioned residual
        if M is not None and is_left_preconditioner:
            M.matvec(r, w, w, alpha=1, beta=0)
            rh = w
        else:
            rh = r

        array_inner(rh, rh, out=beta_sq)

        v0 = array_row(V, 0)
        # v0 = r / beta
        wp.launch(_gmres_arnoldi_normalize_kernel, dim=r.shape, device=r.device, inputs=[rh, v0, beta_sq])

        for j in range(restart):
            do_arnoldi_iteration(j)

        # Convert squared norms on the Hessenberg sub-diagonal to norms,
        # then solve the (restart+1) x restart least-squares problem for y
        wp.launch(_gmres_normalize_lower_diagonal, dim=restart, device=device, inputs=[H])
        wp.launch(_gmres_solve_least_squares, dim=1, device=device, inputs=[restart, beta_sq, H, y])

        # update x
        if M is None or is_left_preconditioner:
            # x += V^T y (beta=1 accumulates into x)
            wp.launch(_gmres_update_x_kernel, dim=x.shape, device=device, inputs=[restart, scalar_dtype(1.0), y, V, x])
        else:
            # Right preconditioner: w = V^T y (beta=0 overwrites), then x += M w
            wp.launch(_gmres_update_x_kernel, dim=x.shape, device=device, inputs=[restart, scalar_dtype(0.0), y, V, w])
            M.matvec(w, x, x, alpha=1, beta=1)

        # update r and residual: r = b - A x
        # NOTE(review): the copy looks redundant since matvec reads b directly as
        # its y argument — presumably it guards against operators that alias y and
        # z; confirm before removing
        wp.copy(src=b, dest=r)
        A.matvec(x, b, r, alpha=-1.0, beta=1.0)
        array_inner(r, r, out=r_norm_sq)

    return _run_solver_loop(
        do_restart_cycle,
        cycle_size=restart,
        r_norm_sq=r_norm_sq,
        maxiter=maxiter,
        atol=atol,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
        device=device,
    )
576
+
577
+
578
def _get_absolute_tolerance(dtype, tol, atol, lhs_norm):
    """Resolve the effective absolute tolerance for the residual norm.

    Args:
        dtype: scalar type of the system; drives the default and minimum tolerances
        tol: relative tolerance (as a ratio of ``lhs_norm``), or None
        atol: absolute tolerance, or None
        lhs_norm: reference norm used to scale the relative tolerance

    Returns:
        ``max(tol * lhs_norm, atol, min_tol)``, where a missing `tol`/`atol`
        falls back to the other one, or to a dtype-dependent default when
        both are missing, and ``min_tol`` is a dtype-dependent floor.
    """
    # Defaults scale with the precision of the scalar type
    if dtype == wp.float64:
        default_tol, min_tol = 1.0e-12, 1.0e-36
    elif dtype == wp.float16:
        default_tol, min_tol = 1.0e-3, 1.0e-9
    else:
        default_tol, min_tol = 1.0e-6, 1.0e-18

    # Missing settings fall back on one another, then on the default
    if tol is None:
        tol = default_tol if atol is None else atol
    if atol is None:
        atol = tol

    return max(tol * lhs_norm, atol, min_tol)
597
+
598
+
599
def _initialize_residual_and_tolerance(A: LinearOperator, b: wp.array, x: wp.array, tol: float, atol: float):
    """Compute the initial residual ``r = b - A x`` and the absolute tolerance.

    Returns:
        Tuple (residual vector, one-element array holding ``||r||^2``, absolute tolerance)
    """
    dev = A.device
    dtype = wp.types.type_scalar_type(A.dtype)

    # One-element buffer holding a squared norm; pinned on CUDA for cheap host readback
    norm_sq = wp.empty(n=1, dtype=dtype, device=dev, pinned=dev.is_cuda)

    # The right-hand-side norm defines the absolute tolerance
    array_inner(b, b, out=norm_sq)
    abs_tol = _get_absolute_tolerance(dtype, tol, atol, sqrt(norm_sq.numpy()[0]))

    # residual = b - A x, reusing the same buffer for its squared norm
    residual = wp.empty_like(b)
    A.matvec(x, b, residual, alpha=-1.0, beta=1.0)
    array_inner(residual, residual, out=norm_sq)

    return residual, norm_sq, abs_tol
617
+
618
+
619
+ def _run_solver_loop(
620
+ do_cycle: Callable[[float], None],
621
+ cycle_size: int,
622
+ r_norm_sq: wp.array,
623
+ maxiter: int,
624
+ atol: float,
625
+ callback: Callable,
626
+ check_every: int,
627
+ use_cuda_graph: bool,
628
+ device,
629
+ ):
630
+ atol_sq = atol * atol
631
+
632
+ cur_iter = 0
633
+
634
+ err_sq = r_norm_sq.numpy()[0]
635
+ err = sqrt(err_sq)
636
+ if callback is not None:
637
+ callback(cur_iter, err, atol)
638
+
639
+ if err_sq <= atol_sq:
640
+ return cur_iter, err, atol
641
+
642
+ graph = None
643
+
644
+ while True:
645
+ # Do not do graph capture at first iteration -- modules may not be loaded yet
646
+ if device.is_cuda and use_cuda_graph and cur_iter > 0:
647
+ if graph is None:
648
+ wp.capture_begin(device, force_module_load=False)
649
+ try:
650
+ do_cycle(atol_sq)
651
+ finally:
652
+ graph = wp.capture_end(device)
653
+ wp.capture_launch(graph)
654
+ else:
655
+ do_cycle(atol_sq)
656
+
657
+ cur_iter += cycle_size
658
+
659
+ if cur_iter >= maxiter:
660
+ break
661
+
662
+ if (cur_iter % check_every) < cycle_size:
663
+ err_sq = r_norm_sq.numpy()[0]
664
+
665
+ if err_sq <= atol_sq:
666
+ break
667
+
668
+ if callback is not None:
669
+ callback(cur_iter, sqrt(err_sq), atol)
670
+
671
+ err_sq = r_norm_sq.numpy()[0]
672
+ err = sqrt(err_sq)
673
+ if callback is not None:
674
+ callback(cur_iter, err, atol)
675
+
676
+ return cur_iter, err, atol
677
+
678
+
679
# z = beta * y + alpha * (A[i] * x[i]) for a diagonal operator stored one
# coefficient per row. The per-element product `A[i] * x[i]` is a plain scalar
# product for scalar dtypes; presumably also covers matrix-block diagonals
# (block times block-vector) -- confirm against callers.
@wp.kernel
def _diag_mv_kernel(
    A: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    alpha: Any,
    beta: Any,
):
    i = wp.tid()
    z[i] = beta * y[i] + alpha * (A[i] * x[i])
690
+
691
+
692
# z = beta * y + alpha * (A[i] .* x[i]) for a diagonal operator whose entries
# are vectors: wp.cw_mul performs the component-wise product, the vector analog
# of the scalar multiply in _diag_mv_kernel.
@wp.kernel
def _diag_mv_vec_kernel(
    A: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    alpha: Any,
    beta: Any,
):
    i = wp.tid()
    z[i] = beta * y[i] + alpha * wp.cw_mul(A[i], x[i])
703
+
704
+
705
# Invert a single diagonal coefficient, optionally taking its absolute value,
# with a safe fallback of 1 for zero coefficients.
# Note wp.select(cond, a, b) returns `a` when cond is FALSE (opposite of a C
# ternary), so this reads: if coeff == 0 return 1, else return
# 1 / coeff (use_abs false) or 1 / |coeff| (use_abs true).
@wp.func
def _inverse_diag_coefficient(coeff: Any, use_abs: wp.bool):
    zero = type(coeff)(0.0)
    one = type(coeff)(1.0)
    return wp.select(coeff == zero, one / wp.select(use_abs, coeff, wp.abs(coeff)), one)
710
+
711
+
712
# Build a (block-)diagonal preconditioner from per-row diagonal blocks:
# for each row, extract the diagonal of the block and invert each coefficient
# (via _inverse_diag_coefficient, which guards against zeros).
@wp.kernel
def _extract_inverse_diagonal_blocked(
    diag_block: wp.array(dtype=Any),
    inv_diag: wp.array(dtype=Any),
    use_abs: int,
):
    i = wp.tid()

    # d is the vector of diagonal coefficients of the i-th block
    d = wp.get_diag(diag_block[i])
    for k in range(d.length):
        d[k] = _inverse_diag_coefficient(d[k], use_abs != 0)

    inv_diag[i] = d
725
+
726
+
727
# Invert a scalar diagonal: inv_diag[i] = 1 / diag_array[i] (or 1 / |.|),
# with a fallback of 1 for zero entries.
@wp.kernel
def _extract_inverse_diagonal_scalar(
    diag_array: wp.array(dtype=Any),
    inv_diag: wp.array(dtype=Any),
    use_abs: int,
):
    i = wp.tid()
    inv_diag[i] = _inverse_diag_coefficient(diag_array[i], use_abs != 0)
735
+
736
+
737
# Extract and invert the diagonal of a dense 2d matrix:
# inv_diag[i] = 1 / dense_matrix[i, i] (or 1 / |.|), fallback 1 for zeros.
@wp.kernel
def _extract_inverse_diagonal_dense(
    dense_matrix: wp.array2d(dtype=Any),
    inv_diag: wp.array(dtype=Any),
    use_abs: int,
):
    i = wp.tid()
    inv_diag[i] = _inverse_diag_coefficient(dense_matrix[i, i], use_abs != 0)
745
+
746
+
747
# CG update step: x += alpha p, r -= alpha Ap, with alpha = rz_old / <p, Ap>.
# `resid` holds the current squared residual norm and `tol` the squared
# tolerance; once converged, alpha is forced to 0 so the kernel becomes a
# no-op. This lets the iteration run inside a captured CUDA graph without a
# host-side branch. (wp.select(cond, a, b) returns `a` when cond is false.)
@wp.kernel
def _cg_kernel_1(
    tol: Any,
    resid: wp.array(dtype=Any),
    rz_old: wp.array(dtype=Any),
    p_Ap: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    Ap: wp.array(dtype=Any),
):
    i = wp.tid()

    # alpha = 0 if resid <= tol (converged), else rz_old / <p, Ap>
    alpha = wp.select(resid[0] > tol, rz_old.dtype(0.0), rz_old[0] / p_Ap[0])

    x[i] = x[i] + alpha * p[i]
    r[i] = r[i] - alpha * Ap[i]
764
+
765
+
766
# CG direction update: p = z + beta p, with beta = rz_new / rz_old.
# beta is forced to 0 once converged (resid <= tol), so the kernel stays
# graph-capture friendly. (wp.select returns its second arg when cond is false.)
@wp.kernel
def _cg_kernel_2(
    tol: Any,
    resid: wp.array(dtype=Any),
    rz_old: wp.array(dtype=Any),
    rz_new: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
):
    # p = r + (rz_new / rz_old) * p;
    i = wp.tid()

    beta = wp.select(resid[0] > tol, rz_old.dtype(0.0), rz_new[0] / rz_old[0])

    p[i] = z[i] + beta * p[i]
781
+
782
+
783
# BiCGSTAB first half-step: x += alpha y, r -= alpha v, with
# alpha = rho_old / <r0, v>. alpha is zeroed once converged (resid <= tol)
# so the kernel is a no-op inside a captured CUDA graph.
@wp.kernel
def _bicgstab_kernel_1(
    tol: Any,
    resid: wp.array(dtype=Any),
    rho_old: wp.array(dtype=Any),
    r0v: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    v: wp.array(dtype=Any),
):
    i = wp.tid()

    # alpha = 0 if converged, else rho_old / r0v
    alpha = wp.select(resid[0] > tol, rho_old.dtype(0.0), rho_old[0] / r0v[0])

    x[i] += alpha * y[i]
    r[i] -= alpha * v[i]
800
+
801
+
802
# BiCGSTAB second half-step: x += omega z, r -= omega t, with
# omega = <s, t> / <t, t> (st / tt). omega is zeroed once converged
# (resid <= tol) to keep the kernel graph-capture friendly.
@wp.kernel
def _bicgstab_kernel_2(
    tol: Any,
    resid: wp.array(dtype=Any),
    st: wp.array(dtype=Any),
    tt: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    t: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
):
    i = wp.tid()

    omega = wp.select(resid[0] > tol, st.dtype(0.0), st[0] / tt[0])

    x[i] += omega * z[i]
    r[i] -= omega * t[i]
819
+
820
+
821
# BiCGSTAB direction update: p = r + beta * (p - omega v).
# With alpha = rho_old / r0v and omega = st / tt:
#   beta       = (rho_new / rho_old) * (alpha / omega) = rho_new * tt / (r0v * st)
#   beta*omega = rho_new / r0v
# Both coefficients are zeroed once converged (resid <= tol).
@wp.kernel
def _bicgstab_kernel_3(
    tol: Any,
    resid: wp.array(dtype=Any),
    rho_new: wp.array(dtype=Any),
    r0v: wp.array(dtype=Any),
    st: wp.array(dtype=Any),
    tt: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    v: wp.array(dtype=Any),
):
    i = wp.tid()

    beta = wp.select(resid[0] > tol, st.dtype(0.0), rho_new[0] * tt[0] / (r0v[0] * st[0]))
    beta_omega = wp.select(resid[0] > tol, st.dtype(0.0), rho_new[0] / r0v[0])

    # p = r + beta * p - (beta * omega) * v
    p[i] = r[i] + beta * p[i] - beta_omega * v[i]
839
+
840
+
841
# The Arnoldi iteration stores SQUARED norms on the Hessenberg sub-diagonal
# (array_inner(w, w)); convert them to actual norms in-place.
@wp.kernel
def _gmres_normalize_lower_diagonal(H: wp.array2d(dtype=Any)):
    # normalize lower-diagonal values of Hessenberg matrix
    i = wp.tid()
    H[i + 1, i] = wp.sqrt(H[i + 1, i])
846
+
847
+
848
# Single-thread kernel: solve the small GMRES least-squares problem
# min || H y - beta e1 || via Givens rotations + triangular back-solve.
@wp.kernel
def _gmres_solve_least_squares(k: int, beta_sq: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)):
    # Solve H y = (beta, 0, ..., 0)
    # H Hessenberg matrix of shape (k+1, k)

    # Keeping H in global mem; warp kernels are launched with fixed block size,
    # so would not fit in registers

    # TODO: switch to native code with thread synchronization

    # beta_sq holds the squared residual norm at the start of the cycle
    rhs = wp.sqrt(beta_sq[0])

    # Apply 2x2 rotations to H so as to remove lower diagonal,
    # and apply similar rotations to right-hand-side
    for i in range(k):
        Ha = H[i]
        Hb = H[i + 1]

        # Givens rotation [[c s], [-s c]]
        a = Ha[i]
        b = Hb[i]
        abn = wp.sqrt(a * a + b * b)
        c = a / abn
        s = b / abn

        # Rotate H
        for j in range(i, k):
            a = Ha[j]
            b = Hb[j]
            Ha[j] = c * a + s * b
            Hb[j] = c * b - s * a

        # Rotate rhs (second component carries over to the next rotation)
        y[i] = c * rhs
        rhs = -s * rhs

    # Triangular back-solve for y (H is now upper-triangular in its top k rows)
    for ii in range(k, 0, -1):
        i = ii - 1
        Hi = H[i]
        yi = y[i]
        for j in range(ii, k):
            yi -= Hi[j] * y[j]
        y[i] = yi / Hi[i]
892
+
893
+
894
# Gram-Schmidt subtraction step: y -= alpha[0] * x, where alpha is a
# one-element array holding the projection coefficient <y, x>.
@wp.kernel
def _gmres_arnoldi_axpy_kernel(
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    alpha: wp.array(dtype=Any),
):
    tid = wp.tid()
    y[tid] -= x[tid] * alpha[0]
902
+
903
+
904
# Normalization step: y = x / sqrt(alpha[0]), where alpha is a one-element
# array holding the SQUARED norm of x (from array_inner(x, x)).
@wp.kernel
def _gmres_arnoldi_normalize_kernel(
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    alpha: wp.array(dtype=Any),
):
    tid = wp.tid()
    y[tid] = x[tid] / wp.sqrt(alpha[0])
912
+
913
+
914
# Solution update: x = beta * x + V[:k]^T y, combining the first k Krylov
# basis vectors (rows of V) with the least-squares coefficients y.
# beta=1 accumulates into x; beta=0 overwrites (used to build the
# intermediate vector in the right-preconditioned path).
@wp.kernel
def _gmres_update_x_kernel(k: int, beta: Any, y: wp.array(dtype=Any), V: wp.array2d(dtype=Any), x: wp.array(dtype=Any)):
    tid = wp.tid()

    xi = beta * x[tid]
    for j in range(k):
        xi += V[j, tid] * y[j]

    x[tid] = xi