warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1608 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ import math
18
+ from typing import Any, Callable, Optional, Tuple, Union
19
+
20
+ import warp as wp
21
+ import warp._src.sparse as sparse
22
+ from warp._src.types import type_length, type_scalar_type
23
+
24
+ _wp_module_name_ = "warp.optim.linear"
25
+
26
+ __all__ = ["LinearOperator", "aslinearoperator", "bicgstab", "cg", "cr", "gmres", "preconditioner"]
27
+
28
+ # No need to auto-generate adjoint code for linear solvers
29
+ wp.set_module_options({"enable_backward": False})
30
+
31
+
32
class LinearOperator:
    """
    Wraps a generalized matrix-vector product so that it can serve as the
    left-hand side of the iterative linear solvers in this module.

    Args:
        shape: Number of rows and columns of the operator, as a tuple
        dtype: Element type of the operator
        device: Device on which products involving the operator are evaluated
        matvec: Callable implementing the generalized matrix-vector product

    The ``matvec`` callable must follow this signature:

    .. code-block:: python

        def matvec(x: wp.array, y: wp.array, z: wp.array, alpha: Scalar, beta: Scalar):
            '''Perform a generalized matrix-vector product.

            This function computes the operation z = alpha * (A @ x) + beta * y, where 'A'
            is the linear operator represented by this class.
            '''
            ...

    By default the iterative linear solvers in this module try, for performance reasons,
    to capture the calls for one or more iterations in CUDA graphs. If the ``matvec``
    routine of a custom :class:`LinearOperator` cannot be graph-captured, pass the
    ``use_cuda_graph=False`` parameter to the solver function.
    """

    def __init__(self, shape: Tuple[int, int], dtype: type, device: wp._src.context.Device, matvec: Callable):
        # Stored privately; exposed through the read-only properties below.
        self._shape = shape
        self._dtype = dtype
        self._device = device
        self._matvec = matvec

    @property
    def matvec(self) -> Callable:
        # The generalized matrix-vector product routine
        return self._matvec

    @property
    def shape(self) -> Tuple[int, int]:
        return self._shape

    @property
    def dtype(self) -> type:
        return self._dtype

    @property
    def device(self) -> wp._src.context.Device:
        return self._device

    @property
    def scalar_type(self):
        # Scalar type underlying the (possibly vector/matrix-valued) dtype
        return type_scalar_type(self._dtype)
85
+
86
+
87
# Union of the matrix-like types accepted by `aslinearoperator` / `preconditioner`:
# a dense or diagonal `wp.array`, a sparse BSR matrix, or an existing operator.
_Matrix = Union[wp.array, sparse.BsrMatrix, LinearOperator]
88
+
89
+
90
def aslinearoperator(A: _Matrix) -> LinearOperator:
    """
    Casts the dense or sparse matrix `A` as a :class:`LinearOperator`

    `A` must be of one of the following types:

    - :class:`warp.sparse.BsrMatrix`
    - two-dimensional `warp.array`; then `A` is assumed to be a dense matrix
    - one-dimensional `warp.array`; then `A` is assumed to be a diagonal matrix
    - :class:`LinearOperator`; no casting necessary

    ``None`` is accepted and returned unchanged, so callers may forward an optional
    (absent) operator — e.g. a disabled preconditioner — without special-casing.

    Raises:
        ValueError: if `A` is of an unsupported type or dimensionality.
    """

    if A is None or isinstance(A, LinearOperator):
        return A

    def bsr_mv(x, y, z, alpha, beta):
        # sparse.bsr_mv accumulates into its output in place; when z and y are
        # distinct buffers, copy y into z first so the beta term applies to y,
        # matching the z = alpha * (A @ x) + beta * y contract.
        if z.ptr != y.ptr and beta != 0.0:
            wp.copy(src=y, dest=z)
        sparse.bsr_mv(A, x, z, alpha, beta)

    def dense_mv(x, y, z, alpha, beta):
        # Convert scalars to the operator element type expected by the kernel
        alpha = A.dtype(alpha)
        beta = A.dtype(beta)
        if A.device.is_cuda:
            # Power-of-two tile size clamped to [2^5, 2^10], covering the row length
            tile_size = 1 << min(10, max(5, math.ceil(math.log2(A.shape[1]))))
        else:
            tile_size = 1
        wp.launch(
            _dense_mv_kernel,
            dim=(A.shape[0], tile_size),
            block_dim=tile_size,
            device=A.device,
            inputs=[A, x, y, z, alpha, beta],
        )

    def diag_mv_impl(A, x, y, z, alpha, beta):
        scalar_type = type_scalar_type(A.dtype)
        alpha = scalar_type(alpha)
        beta = scalar_type(beta)
        wp.launch(_diag_mv_kernel, dim=A.shape, device=A.device, inputs=[A, x, y, z, alpha, beta])

    def diag_mv(x, y, z, alpha, beta):
        return diag_mv_impl(A, x, y, z, alpha, beta)

    def diag_mv_vec(x, y, z, alpha, beta):
        # Vector-typed diagonal: reinterpret all operands as flat scalar arrays first
        return diag_mv_impl(
            _as_scalar_array(A), _as_scalar_array(x), _as_scalar_array(y), _as_scalar_array(z), alpha, beta
        )

    if isinstance(A, wp.array):
        if A.ndim == 2:
            return LinearOperator(A.shape, A.dtype, A.device, matvec=dense_mv)
        if A.ndim == 1:
            if wp._src.types.type_is_vector(A.dtype):
                return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv_vec)
            return LinearOperator(A.shape, A.dtype, A.device, matvec=diag_mv)
    if isinstance(A, sparse.BsrMatrix):
        return LinearOperator(A.shape, A.dtype, A.device, matvec=bsr_mv)

    raise ValueError(f"Unable to create LinearOperator from {A}")
150
+
151
+
152
def preconditioner(A: _Matrix, ptype: str = "diag") -> LinearOperator:
    """Builds and returns a preconditioner suited for the input matrix.

    Args:
        A: The matrix for which to build the preconditioner
        ptype: The kind of preconditioner to construct; one of:

            - ``"diag"``: Diagonal (a.k.a. Jacobi) preconditioner
            - ``"diag_abs"``: Similar to Jacobi, but using the absolute value of diagonal coefficients
            - ``"id"``: Identity (null) preconditioner
    """

    if ptype == "id":
        # The identity preconditioner is represented by the absence of an operator
        return None

    if ptype not in ("diag", "diag_abs"):
        raise ValueError(f"Unsupported preconditioner type '{ptype}'")

    # Flag forwarded to the extraction kernels: 1 -> use |diag|, 0 -> use diag
    use_abs = int(ptype == "diag_abs")

    if isinstance(A, sparse.BsrMatrix):
        diag = sparse.bsr_get_diag(A)
        if wp._src.types.type_is_matrix(A.dtype):
            # Matrix-valued blocks: one vector of inverse diagonal coefficients per block row
            block_vec = wp.vec(length=A.block_shape[0], dtype=A.scalar_type)
            inv_diag = wp.empty(shape=A.nrow, dtype=block_vec, device=A.device)
            extract_kernel = _extract_inverse_diagonal_blocked
        else:
            inv_diag = wp.empty(shape=A.shape[0], dtype=A.scalar_type, device=A.device)
            extract_kernel = _extract_inverse_diagonal_scalar
        wp.launch(
            extract_kernel,
            dim=inv_diag.shape,
            device=inv_diag.device,
            inputs=[diag, inv_diag, use_abs],
        )
    elif isinstance(A, wp.array) and A.ndim == 2:
        inv_diag = wp.empty(shape=A.shape[0], dtype=A.dtype, device=A.device)
        wp.launch(
            _extract_inverse_diagonal_dense,
            dim=inv_diag.shape,
            device=inv_diag.device,
            inputs=[A, inv_diag, use_abs],
        )
    else:
        raise ValueError("Unsupported source matrix type for building diagonal preconditioner")

    # The inverse diagonal is then applied as a diagonal linear operator
    return aslinearoperator(inv_diag)
203
+
204
+
205
def _as_scalar_array(x: wp.array):
    """Reinterprets a vector/matrix-typed array as a flat scalar-typed view (zero-copy)."""
    elem_type = type_scalar_type(x.dtype)
    if elem_type == x.dtype:
        # Already scalar-typed; nothing to reinterpret
        return x

    n = type_length(x.dtype)
    if x.grad is None:
        grad_view = None
    else:
        # Gradient array gets the same scalar reinterpretation
        grad_view = _as_scalar_array(x.grad)

    view = wp.array(
        ptr=x.ptr,
        shape=x.shape[:-1] + (x.shape[-1] * n,),
        strides=x.strides[:-1] + (x.strides[-1] // n,),
        dtype=elem_type,
        device=x.device,
        grad=grad_view,
    )
    # Keep the source array alive for as long as the aliasing view exists
    view._ref = x
    return view
221
+
222
+
223
class TiledDot:
    """
    Computes the dot product of two arrays in a way that is compatible with CUDA sub-graphs.

    Per-tile partial sums are written to one of two ping-pong scratch buffers and then
    repeatedly reduced until a single value per column remains. All kernel launches are
    pre-recorded at construction; `compute` only rebinds their parameters and replays them.
    """

    def __init__(self, max_length: int, scalar_type: type, tile_size=512, device=None, max_column_count: int = 1):
        # max_length: largest per-column array length `compute` will be called with
        # max_column_count: number of simultaneous dot products (rows of 2d operands)
        self.tile_size = tile_size
        self.device = device
        self.max_column_count = max_column_count

        # Number of tiles needed to cover max_length (ceil division)
        num_blocks = (max_length + self.tile_size - 1) // self.tile_size
        # One allocation split into the two ping-pong planes used by the reduction
        scratch = wp.empty(
            shape=(2, max_column_count, num_blocks),
            dtype=scalar_type,
            device=self.device,
        )
        self.partial_sums_a = scratch[0]
        self.partial_sums_b = scratch[1]

        # Kernels are specialized per tile size (cached via lru_cache)
        self.dot_kernel, self.sum_kernel = _create_tiled_dot_kernels(self.tile_size)

        # Number of extra reduction passes needed after the initial dot pass to
        # collapse num_blocks partial sums down to a single value per column
        rounds = 0
        length = num_blocks
        while length > 1:
            length = (length + self.tile_size - 1) // self.tile_size
            rounds += 1

        self.rounds = rounds

        # Which plane holds the final result depends on the parity of the number
        # of buffer swaps performed in `compute`
        self._output = self.partial_sums_a if rounds % 2 == 0 else self.partial_sums_b

        # Pre-recorded launches (record_cmd=True); `compute` updates their
        # parameters/dimensions and replays them
        self.dot_launch: wp.Launch = wp.launch(
            self.dot_kernel,
            dim=(max_column_count, num_blocks, self.tile_size),
            inputs=(self.partial_sums_a, self.partial_sums_b),
            outputs=(self.partial_sums_a,),
            block_dim=self.tile_size,
            record_cmd=True,
        )
        self.sum_launch: wp.Launch = wp.launch(
            self.sum_kernel,
            dim=(max_column_count, num_blocks, self.tile_size),
            inputs=(self.partial_sums_a,),
            outputs=(self.partial_sums_b,),
            block_dim=self.tile_size,
            record_cmd=True,
        )

    # Result contains a single value, the sum of the array (will get updated by this function)
    def compute(self, a: wp.array, b: wp.array, col_offset: int = 0):
        """Computes per-column dot products of `a` and `b`.

        Returns a view of the scratch buffer whose first element along the last
        axis holds each column's result (also readable later via `col`/`cols`).
        """
        # Reinterpret vector-typed operands as flat scalar arrays, then promote
        # one-dimensional inputs to a single-column 2d layout
        a = _as_scalar_array(a)
        b = _as_scalar_array(b)
        if a.ndim == 1:
            a = a.reshape((1, -1))
        if b.ndim == 1:
            b = b.reshape((1, -1))

        column_count = a.shape[0]
        num_blocks = (a.shape[1] + self.tile_size - 1) // self.tile_size

        data_out = self.partial_sums_a[col_offset : col_offset + column_count]
        data_in = self.partial_sums_b[col_offset : col_offset + column_count]

        # First pass: per-tile dot products written into data_out
        self.dot_launch.set_param_at_index(0, a)
        self.dot_launch.set_param_at_index(1, b)
        self.dot_launch.set_param_at_index(2, data_out)
        self.dot_launch.set_dim((column_count, num_blocks, self.tile_size))
        self.dot_launch.launch()

        # Remaining passes: tree-reduce the partial sums, swapping buffers each round
        for _r in range(self.rounds):
            array_length = num_blocks
            num_blocks = (array_length + self.tile_size - 1) // self.tile_size
            data_in, data_out = data_out, data_in

            self.sum_launch.set_param_at_index(0, data_in)
            self.sum_launch.set_param_at_index(1, data_out)
            self.sum_launch.set_dim((column_count, num_blocks, self.tile_size))
            self.sum_launch.launch()

        return data_out

    def col(self, col: int = 0):
        # Single-element view of the result for one column
        return self._output[col][:1]

    def cols(self, count, start: int = 0):
        # View of the results for `count` consecutive columns starting at `start`
        return self._output[start : start + count, :1]
309
+
310
+
311
# Build (and cache per tile_size) the two reduction kernels used by TiledDot:
# one that multiplies-and-reduces two arrays, one that reduces partial sums.
@functools.lru_cache(maxsize=None)
def _create_tiled_dot_kernels(tile_size):
    @wp.kernel
    def block_dot_kernel(
        a: wp.array2d(dtype=Any),
        b: wp.array2d(dtype=Any),
        partial_sums: wp.array2d(dtype=Any),
    ):
        # One tile of `tile_size` elements per (column, block_id) pair;
        # tid_block is the lane index within the tile.
        column, block_id, tid_block = wp.tid()

        start = block_id * tile_size

        a_block = wp.tile_load(a[column], shape=tile_size, offset=start)
        b_block = wp.tile_load(b[column], shape=tile_size, offset=start)
        # Elementwise product, then reduce the tile to a single partial sum.
        t = wp.tile_map(wp.mul, a_block, b_block)

        tile_sum = wp.tile_sum(t)
        wp.tile_store(partial_sums[column], tile_sum, offset=block_id)

    @wp.kernel
    def block_sum_kernel(
        data: wp.array2d(dtype=Any),
        partial_sums: wp.array2d(dtype=Any),
    ):
        # Same layout as block_dot_kernel, but reduces existing partial sums.
        column, block_id, tid_block = wp.tid()
        start = block_id * tile_size

        t = wp.tile_load(data[column], shape=tile_size, offset=start)

        tile_sum = wp.tile_sum(t)
        wp.tile_store(partial_sums[column], tile_sum, offset=block_id)

    return block_dot_kernel, block_sum_kernel
344
+
345
+
346
def cg(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every=10,
    use_cuda_graph=True,
) -> Union[Tuple[int, float, float], Tuple[wp.array, wp.array, wp.array]]:
    """Computes an approximate solution to a symmetric, positive-definite linear system
    using the Conjugate Gradient algorithm.

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
        M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
            If `check_every` is 0, the callback should be a Warp kernel.
        check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibly terminate the algorithm.
            Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
            If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
            to the maximum number of iterations.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.

    Returns:
        If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
            - final_iteration: The number of iterations performed before convergence or reaching maxiter
            - residual_norm: The final residual norm ||b - Ax||
            - absolute_tolerance: The absolute tolerance used for convergence checking

        If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
            - final_iteration_array: Device array containing the number of iterations performed
            - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
            - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    """
    A = aslinearoperator(A)
    M = aslinearoperator(M)

    if maxiter == 0:
        maxiter = A.shape[0]

    device = A.device
    scalar_type = A.scalar_type

    # Temp storage. r and z (preconditioned residual) are stored contiguously so a
    # single two-column TiledDot pass can compute r.r and r.z together.
    r_and_z = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
    p_and_Ap = wp.empty_like(r_and_z)
    residuals = wp.empty(2, dtype=scalar_type, device=device)

    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=2)

    # named views

    # (r, r) -- so we can compute r.z and r.r at once
    r_repeated = _repeat_first(r_and_z)
    if M is None:
        # without preconditioner r == z
        r_and_z = r_repeated
        rz_new = tiled_dot.col(0)
    else:
        rz_new = tiled_dot.col(1)

    r, z = r_and_z[0], r_and_z[1]
    r_norm_sq = tiled_dot.col(0)

    p, Ap = p_and_Ap[0], p_and_Ap[1]
    rz_old, atol_sq = residuals[0:1], residuals[1:2]

    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
    Ap.zero_()
    z.zero_()

    # Initialize tolerance from right-hand-side norm
    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
    # Initialize residual: r = b - A x
    A.matvec(x, b, r, alpha=-1.0, beta=1.0)

    def update_rr_rz():
        # Refresh r.r (and r.z when preconditioned) on device.
        # z = M r
        if M is None:
            tiled_dot.compute(r, r)
        else:
            M.matvec(r, z, z, alpha=1.0, beta=0.0)
            tiled_dot.compute(r_repeated, r_and_z)

    update_rr_rz()
    p.assign(z)

    def do_iteration():
        # One CG iteration; all scalars (alpha, beta) stay on device so the
        # whole body is CUDA-graph capturable.
        rz_old.assign(rz_new)

        # Ap = A * p;
        A.matvec(p, Ap, Ap, alpha=1, beta=0)
        tiled_dot.compute(p, Ap, col_offset=1)
        p_Ap = tiled_dot.col(1)

        # x += alpha p; r -= alpha Ap  with alpha = rz_old / p.Ap
        wp.launch(
            kernel=_cg_kernel_1,
            dim=x.shape[0],
            device=device,
            inputs=[atol_sq, r_norm_sq, rz_old, p_Ap, x, r, p, Ap],
        )

        update_rr_rz()

        # p = z + beta p  with beta = rz_new / rz_old
        wp.launch(
            kernel=_cg_kernel_2,
            dim=z.shape[0],
            device=device,
            inputs=[atol_sq, r_norm_sq, rz_old, rz_new, z, p],
        )

    return _run_capturable_loop(do_iteration, r_norm_sq, maxiter, atol_sq, callback, check_every, use_cuda_graph)
469
+
470
+
471
def cr(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every=10,
    use_cuda_graph=True,
) -> Tuple[int, float, float]:
    """Computes an approximate solution to a symmetric, positive-definite linear system
    using the Conjugate Residual algorithm.

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
            Note that the current implementation always performs iterations in pairs, and as a result may exceed the specified maximum number of iterations by one.
        M: optional left-preconditioner, ideally chosen such that ``M A`` is close to identity.
        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
            If `check_every` is 0, the callback should be a Warp kernel.
        check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibly terminate the algorithm.
            Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
            If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
            to the maximum number of iterations.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.

    Returns:
        If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
            - final_iteration: The number of iterations performed before convergence or reaching maxiter
            - residual_norm: The final residual norm ||b - Ax||
            - absolute_tolerance: The absolute tolerance used for convergence checking

        If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
            - final_iteration_array: Device array containing the number of iterations performed
            - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
            - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    """

    A = aslinearoperator(A)
    M = aslinearoperator(M)

    if maxiter == 0:
        maxiter = A.shape[0]

    device = A.device
    scalar_type = wp._src.types.type_scalar_type(A.dtype)

    # Notations below follow roughly pseudo-code from https://en.wikipedia.org/wiki/Conjugate_residual_method
    # with z := M^-1 r and y := M^-1 Ap

    # Temp storage; paired vectors are stored contiguously so each TiledDot pass
    # can compute two dot products at once.
    r_and_z = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
    r_and_Az = wp.empty_like(r_and_z)
    y_and_Ap = wp.empty_like(r_and_z)
    p = wp.empty_like(b)
    residuals = wp.empty(2, dtype=scalar_type, device=device)

    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=2)

    if M is None:
        # Without a preconditioner z == r and y == Ap; alias via zero-stride views.
        r_and_z = _repeat_first(r_and_z)
        y_and_Ap = _repeat_first(y_and_Ap)

    # named views
    r, z = r_and_z[0], r_and_z[1]
    r_copy, Az = r_and_Az[0], r_and_Az[1]

    y, Ap = y_and_Ap[0], y_and_Ap[1]

    r_norm_sq = tiled_dot.col(0)
    zAz_new = tiled_dot.col(1)
    zAz_old, atol_sq = residuals[0:1], residuals[1:2]

    # Initialize tolerance from right-hand-side norm
    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
    # Initialize residual: r = b - A x
    A.matvec(x, b, r, alpha=-1.0, beta=1.0)

    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
    y_and_Ap.zero_()

    # z = M r
    if M is not None:
        z.zero_()
        M.matvec(r, z, z, alpha=1.0, beta=0.0)

    def update_rr_zAz():
        # Refresh Az, then compute r.r and z.Az in a single two-column pass.
        A.matvec(z, Az, Az, alpha=1, beta=0)
        r_copy.assign(r)
        tiled_dot.compute(r_and_z, r_and_Az)

    update_rr_zAz()

    p.assign(z)
    Ap.assign(Az)

    def do_iteration():
        # One CR iteration; scalars stay on device for graph capture.
        zAz_old.assign(zAz_new)

        # y = M Ap
        if M is not None:
            M.matvec(Ap, y, y, alpha=1.0, beta=0.0)
        tiled_dot.compute(Ap, y, col_offset=1)
        y_Ap = tiled_dot.col(1)

        if M is None:
            # In non-preconditioned case, first kernel is same as CG
            wp.launch(
                kernel=_cg_kernel_1,
                dim=x.shape[0],
                device=device,
                inputs=[atol_sq, r_norm_sq, zAz_old, y_Ap, x, r, p, Ap],
            )
        else:
            # In preconditioned case, we have one more vector to update
            wp.launch(
                kernel=_cr_kernel_1,
                dim=x.shape[0],
                device=device,
                inputs=[atol_sq, r_norm_sq, zAz_old, y_Ap, x, r, z, p, Ap, y],
            )

        update_rr_zAz()
        wp.launch(
            kernel=_cr_kernel_2,
            dim=z.shape[0],
            device=device,
            inputs=[atol_sq, r_norm_sq, zAz_old, zAz_new, z, p, Az, Ap],
        )

    return _run_capturable_loop(
        do_iteration,
        cycle_size=1,
        r_norm_sq=r_norm_sq,
        maxiter=maxiter,
        atol_sq=atol_sq,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
    )
619
+
620
+
621
def bicgstab(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every=10,
    use_cuda_graph=True,
    is_left_preconditioner=False,
):
    """Computes an approximate solution to a linear system using the Biconjugate Gradient Stabilized method (BiCGSTAB).

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
        M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
            If `check_every` is 0, the callback should be a Warp kernel.
        check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibly terminate the algorithm.
            Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
            If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
            to the maximum number of iterations.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.
        is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.

    Returns:
        If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
            - final_iteration: The number of iterations performed before convergence or reaching maxiter
            - residual_norm: The final residual norm ||b - Ax||
            - absolute_tolerance: The absolute tolerance used for convergence checking

        If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
            - final_iteration_array: Device array containing the number of iterations performed
            - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
            - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    """
    A = aslinearoperator(A)
    M = aslinearoperator(M)

    if maxiter == 0:
        maxiter = A.shape[0]

    device = A.device
    scalar_type = wp._src.types.type_scalar_type(A.dtype)

    # Notations below follow pseudo-code from biconjugate https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method

    # Temp storage; r and the shadow residual r0 are contiguous so one TiledDot
    # pass can compute r.r and r0.r together.
    r_and_r0 = wp.empty((2, b.shape[0]), dtype=b.dtype, device=device)
    p = wp.empty_like(b)
    v = wp.empty_like(b)
    t = wp.empty_like(b)

    r, r0 = r_and_r0[0], r_and_r0[1]
    r_repeated = _repeat_first(r_and_r0)

    if M is not None:
        y = wp.zeros_like(p)
        z = wp.zeros_like(r)
        if is_left_preconditioner:
            Mt = wp.zeros_like(t)
    else:
        # No preconditioner: the preconditioned vectors alias the originals.
        y = p
        z = r
        Mt = t

    # Columns: 0 = r.r, 1 = rho (r0.r), 2 = r0.v, 3 = <s,t>, 4 = <t,t>
    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_type, max_column_count=5)
    r_norm_sq = tiled_dot.col(0)
    rho = tiled_dot.col(1)

    atol_sq = wp.empty(1, dtype=scalar_type, device=device)

    # Initialize tolerance from right-hand-side norm
    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
    # Initialize residual: r = b - A x
    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
    tiled_dot.compute(r, r, col_offset=0)

    p.assign(r)
    r0.assign(r)
    rho.assign(r_norm_sq)

    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
    v.zero_()
    t.zero_()

    def do_iteration():
        # One BiCGSTAB iteration; scalars stay on device for graph capture.
        # y = M p
        if M is not None:
            M.matvec(p, y, y, alpha=1.0, beta=0.0)

        # v = A * y;
        A.matvec(y, v, v, alpha=1, beta=0)

        # alpha = rho / <r0 . v>
        tiled_dot.compute(r0, v, col_offset=2)
        r0v = tiled_dot.col(2)

        # x += alpha y
        # r -= alpha v
        wp.launch(
            kernel=_bicgstab_kernel_1,
            dim=x.shape[0],
            device=device,
            inputs=[atol_sq, r_norm_sq, rho, r0v, x, r, y, v],
        )
        tiled_dot.compute(r, r, col_offset=0)

        # z = M r
        if M is not None:
            M.matvec(r, z, z, alpha=1.0, beta=0.0)

        # t = A z
        A.matvec(z, t, t, alpha=1, beta=0)

        if M is not None and is_left_preconditioner:
            # Mt = M t
            M.matvec(t, Mt, Mt, alpha=1.0, beta=0.0)

            # omega = <Mt, Ms> / <Mt, Mt>
            tiled_dot.compute(z, Mt, col_offset=3)
            tiled_dot.compute(Mt, Mt, col_offset=4)
        else:
            tiled_dot.compute(r, t, col_offset=3)
            tiled_dot.compute(t, t, col_offset=4)
        st = tiled_dot.col(3)
        tt = tiled_dot.col(4)

        # x += omega z
        # r -= omega t
        wp.launch(
            kernel=_bicgstab_kernel_2,
            dim=z.shape[0],
            device=device,
            inputs=[atol_sq, r_norm_sq, st, tt, z, t, x, r],
        )

        # r = <r,r>, rho = <r0, r>
        tiled_dot.compute(r_and_r0, r_repeated, col_offset=0)

        # beta = (rho / rho_old) * alpha / omega = (rho / r0v) / omega
        # p = r + beta (p - omega v)
        wp.launch(
            kernel=_bicgstab_kernel_3,
            dim=z.shape[0],
            device=device,
            inputs=[atol_sq, r_norm_sq, rho, r0v, st, tt, p, r, v],
        )

    return _run_capturable_loop(
        do_iteration,
        r_norm_sq=r_norm_sq,
        maxiter=maxiter,
        atol_sq=atol_sq,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
    )
789
+
790
+
791
def gmres(
    A: _Matrix,
    b: wp.array,
    x: wp.array,
    tol: Optional[float] = None,
    atol: Optional[float] = None,
    restart=31,
    maxiter: Optional[int] = 0,
    M: Optional[_Matrix] = None,
    callback: Optional[Callable] = None,
    check_every=31,
    use_cuda_graph=True,
    is_left_preconditioner=False,
):
    """Computes an approximate solution to a linear system using the restarted Generalized Minimum Residual method (GMRES[k]).

    Args:
        A: the linear system's left-hand-side
        b: the linear system's right-hand-side
        x: initial guess and solution vector
        tol: relative tolerance for the residual, as a ratio of the right-hand-side norm
        atol: absolute tolerance for the residual
        restart: The restart parameter, i.e, the `k` in `GMRES[k]`. In general, increasing this parameter reduces the number of iterations but increases memory consumption.
        maxiter: maximum number of iterations to perform before aborting. Defaults to the system size.
            Note that the current implementation always perform `restart` iterations at a time, and as a result may exceed the specified maximum number of iterations by ``restart-1``.
        M: optional left- or right-preconditioner, ideally chosen such that ``M A`` (resp ``A M``) is close to identity.
        callback: function to be called every `check_every` iteration with the current iteration number, residual and tolerance.
            If `check_every` is 0, the callback should be a Warp kernel.
        check_every: number of iterations every which to call `callback`, check the residual against the tolerance and possibly terminate the algorithm.
            Setting `check_every` to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable.
            If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run
            to the maximum number of iterations.
        use_cuda_graph: If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead.
            The linear operator and preconditioner must only perform graph-friendly operations.
        is_left_preconditioner: whether `M` should be used as a left- or right- preconditioner.

    Returns:
        If `check_every` > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
            - final_iteration: The number of iterations performed before convergence or reaching maxiter
            - residual_norm: The final residual norm ||b - Ax||
            - absolute_tolerance: The absolute tolerance used for convergence checking

        If `check_every` is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
            - final_iteration_array: Device array containing the number of iterations performed
            - residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
            - absolute_tolerance_squared_array: Device array containing the squared absolute tolerance

    If both `tol` and `atol` are provided, the absolute tolerance used as the termination criterion for the residual norm is ``max(atol, tol * norm(b))``.
    """

    A = aslinearoperator(A)
    M = aslinearoperator(M)

    if maxiter == 0:
        maxiter = A.shape[0]

    restart = min(restart, maxiter)

    # Host-side checks only make sense at restart boundaries.
    if check_every > 0:
        check_every = max(restart, check_every)

    device = A.device
    scalar_dtype = wp._src.types.type_scalar_type(A.dtype)

    pivot_tolerance = _get_dtype_epsilon(scalar_dtype) ** 2

    r = wp.empty_like(b)
    w = wp.empty_like(r)

    # H: upper-Hessenberg matrix of the Arnoldi process; y: least-squares solution.
    H = wp.empty(shape=(restart + 1, restart), dtype=scalar_dtype, device=device)
    y = wp.empty(shape=restart + 1, dtype=scalar_dtype, device=device)

    # V: Krylov basis vectors, one row per Arnoldi vector.
    V = wp.zeros(shape=(restart + 1, r.shape[0]), dtype=r.dtype, device=device)

    residuals = wp.empty(2, dtype=scalar_dtype, device=device)
    beta, atol_sq = residuals[0:1], residuals[1:2]

    tiled_dot = TiledDot(max_length=A.shape[0], device=device, scalar_type=scalar_dtype, max_column_count=restart + 1)
    r_norm_sq = tiled_dot.col(0)

    # Zero-stride view of w, so w can be dotted against all V rows in one pass.
    w_repeated = wp.array(
        ptr=w.ptr, shape=(restart + 1, w.shape[0]), strides=(0, w.strides[0]), dtype=w.dtype, device=w.device
    )

    # tile size for least square solve
    # (need to fit in a CUDA block, so 1024 max)
    if device.is_cuda and 4 < restart <= 1024:
        tile_size = 1 << math.ceil(math.log2(restart))
        least_squares_kernel = make_gmres_solve_least_squares_kernel_tiled(tile_size)
    else:
        tile_size = 1
        least_squares_kernel = _gmres_solve_least_squares

    # recorded launches (record_cmd=True), re-used with patched params each cycle
    least_squares_solve = wp.launch(
        least_squares_kernel,
        dim=(1, tile_size),
        block_dim=tile_size if tile_size > 1 else 256,
        device=device,
        inputs=[restart, pivot_tolerance, beta, H, y],
        record_cmd=True,
    )

    normalize_anorldi_vec = wp.launch(
        _gmres_arnoldi_normalize_kernel,
        dim=r.shape,
        device=r.device,
        inputs=[r, w, tiled_dot.col(0), beta],
        record_cmd=True,
    )

    arnoldi_axpy = wp.launch(
        _gmres_arnoldi_axpy_kernel,
        dim=(w.shape[0], tile_size),
        block_dim=tile_size,
        device=w.device,
        inputs=[V, w, H],
        record_cmd=True,
    )

    # Initialize tolerance from right-hand-side norm
    _initialize_absolute_tolerance(b, tol, atol, tiled_dot, atol_sq)
    # Initialize residual: r = b - A x
    A.matvec(x, b, r, alpha=-1.0, beta=1.0)
    tiled_dot.compute(r, r, col_offset=0)

    # Not strictly necessary, but makes it more robust to user-provided LinearOperators
    w.zero_()

    def array_coeff(H, i, j):
        # Single-element view of H[i, j]
        return H[i][j : j + 1]

    def array_col(H, j):
        # (j+1, 1) view of the top of column j of H
        return H[: j + 1, j : j + 1]

    def do_arnoldi_iteration(j: int):
        # Build Krylov vector j+1 and column j of H (modified Gram-Schmidt done
        # as one batched projection, see below).

        # w = A * v[j];
        if M is not None:
            tmp = V[j + 1]

            if is_left_preconditioner:
                A.matvec(V[j], tmp, tmp, alpha=1, beta=0)
                M.matvec(tmp, w, w, alpha=1, beta=0)
            else:
                M.matvec(V[j], tmp, tmp, alpha=1, beta=0)
                A.matvec(tmp, w, w, alpha=1, beta=0)
        else:
            A.matvec(V[j], w, w, alpha=1, beta=0)

        # compute and apply dot products in parallel,
        # since Hj columns are orthogonal
        Hj = array_col(H, j)
        tiled_dot.compute(w_repeated, V[: j + 1])
        wp.copy(src=tiled_dot.cols(j + 1), dest=Hj)

        # w -= w.vi vi
        arnoldi_axpy.set_params([V[: j + 1], w, Hj])
        arnoldi_axpy.launch()

        # H[j+1, j] = |w.w|
        tiled_dot.compute(w, w)
        normalize_anorldi_vec.set_params([w, V[j + 1], tiled_dot.col(0), array_coeff(H, j + 1, j)])

        normalize_anorldi_vec.launch()

    def do_restart_cycle():
        # One GMRES[restart] cycle: build the Krylov basis, solve the small
        # least-squares problem, update x, then recompute the true residual.
        if M is not None and is_left_preconditioner:
            M.matvec(r, w, w, alpha=1, beta=0)
            rh = w
        else:
            rh = r

        # beta^2 = rh.rh
        tiled_dot.compute(rh, rh)

        # v[0] = r / beta
        normalize_anorldi_vec.set_params([rh, V[0], tiled_dot.col(0), beta])
        normalize_anorldi_vec.launch()

        for j in range(restart):
            do_arnoldi_iteration(j)

        least_squares_solve.launch()

        # update x
        if M is None or is_left_preconditioner:
            wp.launch(_gmres_update_x_kernel, dim=x.shape, device=device, inputs=[restart, scalar_dtype(1.0), y, V, x])
        else:
            # Right-preconditioned: accumulate the Krylov combination in w, then x += M w
            wp.launch(_gmres_update_x_kernel, dim=x.shape, device=device, inputs=[restart, scalar_dtype(0.0), y, V, w])
            M.matvec(w, x, x, alpha=1, beta=1)

        # update r and residual
        wp.copy(src=b, dest=r)
        A.matvec(x, b, r, alpha=-1.0, beta=1.0)
        tiled_dot.compute(r, r)

    return _run_capturable_loop(
        do_restart_cycle,
        cycle_size=restart,
        r_norm_sq=r_norm_sq,
        maxiter=maxiter,
        atol_sq=atol_sq,
        callback=callback,
        check_every=check_every,
        use_cuda_graph=use_cuda_graph,
    )
997
+
998
+
999
def _repeat_first(arr: wp.array):
    """Return a zero-copy view of ``arr`` where the first row is repeated ``arr.shape[0]`` times.

    Setting the leading stride to zero makes every index along axis 0 resolve
    to row 0 without duplicating any data.
    """
    broadcast_strides = (0,) + tuple(arr.strides[1:])
    repeated = wp.array(
        ptr=arr.ptr,
        dtype=arr.dtype,
        shape=arr.shape,
        strides=broadcast_strides,
        device=arr.device,
    )
    # Keep the source array alive for the lifetime of the view.
    repeated._ref = arr
    return repeated
1010
+
1011
+
1012
def _get_dtype_epsilon(dtype):
    """Return a representative machine epsilon for the given Warp scalar dtype."""
    if dtype == wp.float64:
        return 1.0e-16
    if dtype == wp.float16:
        return 1.0e-4
    # Anything else (notably float32) gets single-precision epsilon.
    return 1.0e-8
1019
+
1020
+
1021
def _get_tolerances(dtype, tol, atol):
    """Resolve user-supplied relative/absolute tolerances into concrete values.

    Missing values default to each other, or to a dtype-dependent default when
    both are omitted; the absolute tolerance is clamped to a dtype-dependent
    minimum. Returns ``(tol, atol)``.
    """
    eps = _get_dtype_epsilon(dtype)
    default_tol = eps ** (3 / 4)
    min_tol = eps ** (9 / 4)

    if tol is None:
        # Fall back to atol when provided, otherwise to the dtype default.
        tol = default_tol if atol is None else atol
    if atol is None:
        atol = tol

    return tol, max(atol, min_tol)
1035
+
1036
+
1037
# Device-side computation of the squared absolute tolerance:
# atol_sq[0] = max(rtol * sqrt(r_norm_sq[0]), atol)^2
@wp.kernel
def _initialize_tolerance(
    rtol: Any,
    atol: Any,
    r_norm_sq: wp.array(dtype=Any),
    atol_sq: wp.array(dtype=Any),
):
    # Effective absolute tolerance is the larger of atol and rtol * ||b||.
    atol = wp.max(rtol * wp.sqrt(r_norm_sq[0]), atol)
    atol_sq[0] = atol * atol
1046
+
1047
+
1048
def _initialize_absolute_tolerance(
    b: wp.array,
    tol: float,
    atol: float,
    tiled_dot: TiledDot,
    atol_sq: wp.array,
):
    """Write the squared absolute tolerance ``max(atol, tol * |b|)**2`` into ``atol_sq``.

    The right-hand-side norm is computed on device through ``tiled_dot`` so the
    initialization stays CUDA-graph friendly.
    """
    dtype = atol_sq.dtype

    # ||b||^2, left in column 0 of the tiled-dot output buffer.
    tiled_dot.compute(b, b)
    rhs_norm_sq = tiled_dot.col(0)

    rel_tol, abs_tol = _get_tolerances(dtype, tol, atol)

    # Combine the resolved tolerances with the device-side norm.
    wp.launch(
        kernel=_initialize_tolerance,
        dim=1,
        device=b.device,
        inputs=[dtype(rel_tol), dtype(abs_tol), rhs_norm_sq, atol_sq],
    )
1068
+
1069
+
1070
# Advances the device-side iteration counter and sets the loop-continue flag
# used by wp.capture_while: 0 (stop) when converged or maxiter reached, else 1.
@wp.kernel
def _update_condition(
    maxiter: int,
    cycle_size: int,
    cur_iter: wp.array(dtype=int),
    r_norm_sq: wp.array(dtype=Any),
    atol_sq: wp.array(dtype=Any),
    condition: wp.array(dtype=int),
):
    # Each call corresponds to one completed cycle of `cycle_size` iterations.
    cur_iter[0] += cycle_size
    condition[0] = wp.where(r_norm_sq[0] <= atol_sq[0] or cur_iter[0] >= maxiter, 0, 1)
1081
+
1082
+
1083
def _run_capturable_loop(
    do_cycle: Callable,
    r_norm_sq: wp.array,
    maxiter: int,
    atol_sq: wp.array,
    callback: Optional[Callable],
    check_every: int,
    use_cuda_graph: bool,
    cycle_size: int = 1,
):
    """Drive the solver iteration loop.

    With ``check_every > 0`` this delegates to the host-checked loop
    (:func:`_run_solver_loop`) and returns host scalars ``(iters, err, atol)``.
    With ``check_every == 0`` the whole loop is run device-side (optionally
    captured with ``wp.capture_while``) and device arrays
    ``(cur_iter, r_norm_sq, atol_sq)`` are returned instead.

    Args:
        do_cycle: callable performing one cycle of `cycle_size` iterations
        r_norm_sq: device array holding the current squared residual norm
        maxiter: maximum number of iterations
        atol_sq: device array holding the squared absolute tolerance
        callback: host callable, or (when check_every == 0) a Warp kernel
        check_every: iterations between host-side convergence checks; 0 = device-side only
        use_cuda_graph: capture cycles as a CUDA graph when on a CUDA device
        cycle_size: iterations performed per call to `do_cycle`
    """
    device = atol_sq.device

    if check_every > 0:
        # Host-checked path: read the tolerance back once, then loop on host.
        atol = math.sqrt(atol_sq.numpy()[0])
        return _run_solver_loop(
            do_cycle, cycle_size, r_norm_sq, maxiter, atol, callback, check_every, use_cuda_graph, device
        )

    # Device-side path: iteration counter and continue-flag live in one array.
    # Counter starts at -1 so the priming _update_condition call below sets it
    # to cycle_size - 1... NOTE(review): starts at -1 and is incremented by
    # cycle_size before the first cycle — confirm intended off-by-cycle_size accounting.
    cur_iter_and_condition = wp.full((2,), value=-1, dtype=int, device=device)
    cur_iter = cur_iter_and_condition[0:1]
    condition = cur_iter_and_condition[1:2]

    update_condition_launch = wp.launch(
        _update_condition,
        dim=1,
        device=device,
        inputs=[int(maxiter), cycle_size, cur_iter, r_norm_sq, atol_sq, condition],
        record_cmd=True,
    )

    # A Warp-kernel callback is launched device-side after every cycle.
    if isinstance(callback, wp.Kernel):
        callback_launch = wp.launch(
            callback, dim=1, device=device, inputs=[cur_iter, r_norm_sq, atol_sq], record_cmd=True
        )
    else:
        callback_launch = None

    # Prime the condition flag (handles already-converged initial guess).
    update_condition_launch.launch()
    if callback_launch is not None:
        callback_launch.launch()

    def do_cycle_with_condition():
        do_cycle()
        update_condition_launch.launch()
        if callback_launch is not None:
            callback_launch.launch()

    if use_cuda_graph and device.is_cuda:
        if device.is_capturing:
            # Already inside an outer capture: nest the conditional loop.
            wp.capture_while(condition, do_cycle_with_condition)
        else:
            with wp.ScopedCapture() as capture:
                wp.capture_while(condition, do_cycle_with_condition)
            wp.capture_launch(capture.graph)
    else:
        # No graph support: run a fixed number of cycles; the condition flag is
        # still updated but cannot terminate the loop early host-side.
        for _ in range(0, maxiter, cycle_size):
            do_cycle_with_condition()

    return cur_iter, r_norm_sq, atol_sq
1142
+
1143
+
1144
def _run_solver_loop(
    do_cycle: Callable[[float], None],
    cycle_size: int,
    r_norm_sq: wp.array,
    maxiter: int,
    atol: float,
    callback: Callable,
    check_every: int,
    use_cuda_graph: bool,
    device,
):
    """Host-driven solver loop with periodic residual checks.

    Runs `do_cycle` until the residual norm drops below `atol` or `maxiter`
    iterations are reached, reading the residual back to host every
    `check_every` iterations. Returns ``(iterations, residual_norm, atol)``.
    """
    atol_sq = atol * atol
    # Can't check more often than one cycle.
    check_every = max(check_every, cycle_size)

    cur_iter = 0

    # Early-out if the initial guess already satisfies the tolerance.
    err_sq = r_norm_sq.numpy()[0]
    err = math.sqrt(err_sq)
    if callback is not None:
        callback(cur_iter, err, atol)

    if err_sq <= atol_sq:
        return cur_iter, err, atol

    graph = None

    while True:
        # Do not do graph capture at first iteration -- modules may not be loaded yet
        if device.is_cuda and use_cuda_graph and cur_iter > 0:
            if graph is None:
                # Capture one cycle lazily, then replay it on subsequent cycles.
                with wp.ScopedCapture(force_module_load=False) as capture:
                    do_cycle()
                graph = capture.graph
            wp.capture_launch(graph)
        else:
            do_cycle()

        cur_iter += cycle_size

        if cur_iter >= maxiter:
            break

        # True once per `check_every` window (cycle counts advance by cycle_size).
        if (cur_iter % check_every) < cycle_size:
            # Host read — synchronizes with the device.
            err_sq = r_norm_sq.numpy()[0]

            if err_sq <= atol_sq:
                break

            if callback is not None:
                callback(cur_iter, math.sqrt(err_sq), atol)

    # Final residual read and callback after loop exit.
    err_sq = r_norm_sq.numpy()[0]
    err = math.sqrt(err_sq)
    if callback is not None:
        callback(cur_iter, err, atol)

    return cur_iter, err, atol
1201
+
1202
+
1203
@wp.kernel
def _dense_mv_kernel(
    A: wp.array2d(dtype=Any),
    x: wp.array1d(dtype=Any),
    y: wp.array1d(dtype=Any),
    z: wp.array1d(dtype=Any),
    alpha: Any,
    beta: Any,
):
    # Tiled dense GEMV: z[row] = alpha * dot(A[row], x) + beta * y[row].
    # One block per row; each lane accumulates a strided partial dot product,
    # then tile_sum reduces the partials across the block.
    row, lane = wp.tid()

    zero = type(alpha)(0)
    s = zero
    if alpha != zero:  # skip the row product entirely when alpha == 0
        for col in range(lane, A.shape[1], wp.block_dim()):
            s += A[row, col] * x[col]

    row_tile = wp.tile_sum(wp.tile(s * alpha))

    if beta != zero:  # skip reading y when beta == 0
        row_tile += wp.tile_load(y, shape=1, offset=row) * beta

    wp.tile_store(z, row_tile, offset=row)
1226
+
1227
+
1228
@wp.kernel
def _diag_mv_kernel(
    A: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    alpha: Any,
    beta: Any,
):
    # Elementwise z = alpha * (A * x) + beta * y for a diagonal operator A
    # stored as a flat array. Terms with a zero coefficient are skipped so
    # their operands are never read.
    row = wp.tid()

    zero = type(alpha)(0)
    acc = z.dtype(zero)
    if alpha != zero:
        acc += alpha * (A[row] * x[row])
    if beta != zero:
        acc += beta * y[row]
    z[row] = acc
1245
+
1246
+
1247
@wp.func
def _inverse_diag_coefficient(coeff: Any, use_abs: wp.bool):
    # Reciprocal of the (optionally absolute) diagonal coefficient; an exact
    # zero maps to one so empty diagonal entries act as identity.
    zero = type(coeff)(0.0)
    one = type(coeff)(1.0)
    denom = wp.where(use_abs, wp.abs(coeff), coeff)
    return wp.where(coeff == zero, one, one / denom)
1252
+
1253
+
1254
@wp.kernel
def _extract_inverse_diagonal_blocked(
    diag_block: wp.array(dtype=Any),
    inv_diag: wp.array(dtype=Any),
    use_abs: int,
):
    # Invert the diagonal of each matrix-valued block, writing the result as
    # a vector of reciprocals (zeros map to one inside the helper).
    row = wp.tid()

    vec = wp.get_diag(diag_block[row])
    for j in range(vec.length):
        vec[j] = _inverse_diag_coefficient(vec[j], use_abs != 0)

    inv_diag[row] = vec
1267
+
1268
+
1269
@wp.kernel
def _extract_inverse_diagonal_scalar(
    diag_array: wp.array(dtype=Any),
    inv_diag: wp.array(dtype=Any),
    use_abs: int,
):
    # Scalar case: invert each diagonal entry in place-for-place fashion.
    row = wp.tid()
    inv_diag[row] = _inverse_diag_coefficient(diag_array[row], use_abs != 0)
1277
+
1278
+
1279
@wp.kernel
def _extract_inverse_diagonal_dense(
    dense_matrix: wp.array2d(dtype=Any),
    inv_diag: wp.array(dtype=Any),
    use_abs: int,
):
    # Dense case: pull the diagonal entry of each row and invert it.
    row = wp.tid()
    inv_diag[row] = _inverse_diag_coefficient(dense_matrix[row, row], use_abs != 0)
1287
+
1288
+
1289
@wp.kernel
def _cg_kernel_1(
    tol: wp.array(dtype=Any),
    resid: wp.array(dtype=Any),
    rz_old: wp.array(dtype=Any),
    p_Ap: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    Ap: wp.array(dtype=Any),
):
    # CG iterate update: x += alpha p, r -= alpha Ap.
    # The step is forced to zero once the residual meets the tolerance, so
    # extra (graph-captured) launches after convergence become no-ops.
    i = wp.tid()

    step = wp.where(resid[0] > tol[0], rz_old[0] / p_Ap[0], rz_old.dtype(0.0))

    x[i] += step * p[i]
    r[i] -= step * Ap[i]
1306
+
1307
+
1308
@wp.kernel
def _cg_kernel_2(
    tol: wp.array(dtype=Any),
    resid_new: wp.array(dtype=Any),
    rz_old: wp.array(dtype=Any),
    rz_new: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
):
    # CG direction update: p = z + (rz_new / rz_old) * p,
    # with the coefficient zeroed once the residual meets the tolerance.
    i = wp.tid()

    not_converged = resid_new[0] > tol[0]
    scale = wp.where(not_converged, rz_new[0] / rz_old[0], rz_old.dtype(0.0))

    p[i] = z[i] + scale * p[i]
1324
+
1325
+
1326
@wp.kernel
def _cr_kernel_1(
    tol: wp.array(dtype=Any),
    resid: wp.array(dtype=Any),
    zAz_old: wp.array(dtype=Any),
    y_Ap: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    Ap: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
):
    # CR iterate update: x += alpha p, r -= alpha Ap, z -= alpha y.
    # The step is zeroed on convergence, and also when y_Ap is not strictly
    # positive (guards against dividing by a non-positive denominator).
    i = wp.tid()

    step = wp.where(resid[0] > tol[0] and y_Ap[0] > 0.0, zAz_old[0] / y_Ap[0], zAz_old.dtype(0.0))

    x[i] += step * p[i]
    r[i] -= step * Ap[i]
    z[i] -= step * y[i]
1346
+
1347
+
1348
@wp.kernel
def _cr_kernel_2(
    tol: wp.array(dtype=Any),
    resid: wp.array(dtype=Any),
    zAz_old: wp.array(dtype=Any),
    zAz_new: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    Az: wp.array(dtype=Any),
    Ap: wp.array(dtype=Any),
):
    # CR direction update: p = z + beta p and Ap = Az + beta Ap.
    # beta is zeroed on convergence or when zAz_old is not strictly positive.
    i = wp.tid()

    scale = wp.where(resid[0] > tol[0] and zAz_old[0] > 0.0, zAz_new[0] / zAz_old[0], zAz_old.dtype(0.0))

    p[i] = z[i] + scale * p[i]
    Ap[i] = Az[i] + scale * Ap[i]
1366
+
1367
+
1368
@wp.kernel
def _bicgstab_kernel_1(
    tol: wp.array(dtype=Any),
    resid: wp.array(dtype=Any),
    rho_old: wp.array(dtype=Any),
    r0v: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    v: wp.array(dtype=Any),
):
    # BiCGSTAB first half-step: x += alpha y, r -= alpha v,
    # with alpha zeroed once the residual meets the tolerance.
    i = wp.tid()

    step = wp.where(resid[0] > tol[0], rho_old[0] / r0v[0], rho_old.dtype(0.0))

    x[i] = x[i] + step * y[i]
    r[i] = r[i] - step * v[i]
1385
+
1386
+
1387
@wp.kernel
def _bicgstab_kernel_2(
    tol: wp.array(dtype=Any),
    resid: wp.array(dtype=Any),
    st: wp.array(dtype=Any),
    tt: wp.array(dtype=Any),
    z: wp.array(dtype=Any),
    t: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
):
    # BiCGSTAB stabilization half-step: x += omega z, r -= omega t,
    # with omega zeroed once the residual meets the tolerance.
    i = wp.tid()

    relax = wp.where(resid[0] > tol[0], st[0] / tt[0], st.dtype(0.0))

    x[i] = x[i] + relax * z[i]
    r[i] = r[i] - relax * t[i]
1404
+
1405
+
1406
@wp.kernel
def _bicgstab_kernel_3(
    tol: wp.array(dtype=Any),
    resid: wp.array(dtype=Any),
    rho_new: wp.array(dtype=Any),
    r0v: wp.array(dtype=Any),
    st: wp.array(dtype=Any),
    tt: wp.array(dtype=Any),
    p: wp.array(dtype=Any),
    r: wp.array(dtype=Any),
    v: wp.array(dtype=Any),
):
    # BiCGSTAB direction update: p = r + beta * (p - omega * v),
    # expanded as p = r + beta p - (beta * omega) v. Both coefficients are
    # zeroed once the residual meets the tolerance.
    i = wp.tid()

    active = resid[0] > tol[0]
    scale = wp.where(active, rho_new[0] * tt[0] / (r0v[0] * st[0]), st.dtype(0.0))
    scale_omega = wp.where(active, rho_new[0] / r0v[0], st.dtype(0.0))

    p[i] = r[i] + scale * p[i] - scale_omega * v[i]
1424
+
1425
+
1426
@wp.kernel
def _gmres_solve_least_squares(
    k: int, pivot_tolerance: float, beta: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
):
    # Serial least-squares solve (kernel does not use wp.tid(); expected to
    # run with a single thread).
    # Solve H y = (beta, 0, ..., 0)
    # H is the Hessenberg matrix of shape (k+1, k), stored in global memory
    # since it would not fit in registers.

    rhs = beta[0]

    # Apply 2x2 Givens rotations to H so as to remove the lower diagonal,
    # applying the same rotations to the right-hand side.
    max_k = int(k)
    for i in range(k):
        # Ha/Hb are views of rows i and i+1; the rotation writes them in place.
        Ha = H[i]
        Hb = H[i + 1]

        # Givens rotation [[c s], [-s c]]
        a = Ha[i]
        b = Hb[i]
        abn_sq = a * a + b * b

        if abn_sq < type(abn_sq)(pivot_tolerance):
            # Arnoldi iteration finished early; solve the truncated system.
            max_k = i
            break

        abn = wp.sqrt(abn_sq)
        c = a / abn
        s = b / abn

        # Rotate H (only columns i..k-1 are affected; earlier ones are zero)
        for j in range(i, k):
            a = Ha[j]
            b = Hb[j]
            Ha[j] = c * a + s * b
            Hb[j] = c * b - s * a

        # Rotate rhs; y temporarily holds the rotated right-hand side.
        y[i] = c * rhs
        rhs = -s * rhs

    # Zero the coefficients past the (possibly truncated) system size.
    for i in range(max_k, k):
        y[i] = y.dtype(0.0)

    # Triangular back-solve for y (H is now upper triangular)
    for ii in range(max_k, 0, -1):
        i = ii - 1
        Hi = H[i]
        yi = y[i]
        for j in range(ii, max_k):
            yi -= Hi[j] * y[j]
        y[i] = yi / Hi[i]
1479
+
1480
+
1481
@functools.lru_cache(maxsize=None)
def make_gmres_solve_least_squares_kernel_tiled(K: int):
    """Build (and cache) a tiled least-squares kernel specialized for tile size ``K``.

    ``K`` must be at least the number of columns of the Hessenberg matrix;
    the kernel is compiled into its own module (``module="unique"``) so each
    specialization is cached independently.
    """

    @wp.kernel(module="unique")
    def gmres_solve_least_squares_tiled(
        k: int, pivot_tolerance: float, beta: wp.array(dtype=Any), H: wp.array2d(dtype=Any), y: wp.array(dtype=Any)
    ):
        # Assumes tiles of size K, and K at least as large as highest number of columns
        # Limits the max restart cycle length to the max block size of 1024, but using
        # larger restarts would be very inefficient anyway (default is ~30)

        # Solve H y = (beta, 0, ..., 0)
        # H Hessenberg matrix of shape (k+1, k)

        i, lane = wp.tid()

        rhs = beta[0]

        zero = H.dtype(0.0)
        one = H.dtype(1.0)
        yi = zero

        # Row i of H lives in a tile; Ha is reused across iterations as the
        # "current" row after each rotation.
        Ha = wp.tile_load(H[0], shape=(K))

        # Apply 2x2 rotations to H so as to remove lower diagonal,
        # and apply similar rotations to right-hand-side
        max_k = int(k)
        for i in range(k):
            # Ha = H[i]
            # Hb = H[i + 1]
            Hb = wp.tile_load(H[i + 1], shape=(K))

            # Givens rotation [[c s], [-s c]]
            a = Ha[i]
            b = Hb[i]
            abn_sq = a * a + b * b

            if abn_sq < type(abn_sq)(pivot_tolerance):
                # Arnoldi iteration finished early
                max_k = i
                break

            abn = wp.sqrt(abn_sq)
            c = a / abn
            s = b / abn

            # Rotate H: each lane rotates its own column of the two rows.
            a = wp.untile(Ha)
            b = wp.untile(Hb)
            a_rot = c * a + s * b
            b_rot = c * b - s * a

            # Rotate rhs; lane i keeps the rotated right-hand-side entry.
            if lane == i:
                yi = c * rhs
            rhs = -s * rhs

            # Persist the rotated row i; the rotated row i+1 becomes the next Ha.
            wp.tile_store(H[i], wp.tile(a_rot))
            Ha[lane] = b_rot

        y_tile = wp.tile(yi)

        # Triangular back-solve for y
        for ii in range(max_k, 0, -1):
            i = ii - 1

            Hi = wp.tile_load(H[i], shape=(K))

            # Lane 0 carries y[i]; other lanes contribute -H[i, j] * y[j]
            # partial terms which tile_sum reduces below.
            il = lane + i
            if lane == 0:
                yl = y_tile[i]
            elif il < max_k:
                yl = -y_tile[il] * Hi[il]
            else:
                yl = zero

            yit = wp.tile_sum(wp.tile(yl)) * (one / Hi[i])
            yit[0]  # no-op, moves yit to shared
            wp.tile_assign(y_tile, yit, offset=(i,))

        wp.tile_store(y, y_tile)

    return gmres_solve_least_squares_tiled
1563
+
1564
+
1565
@wp.kernel
def _gmres_arnoldi_axpy_kernel(
    V: wp.array2d(dtype=Any),
    w: wp.array(dtype=Any),
    Vw: wp.array2d(dtype=Any),
):
    # Gram-Schmidt style update: w[tid] -= sum_k Vw[k, 0] * V[k, tid].
    # One block per entry of w; lanes stride over the k index and the partial
    # products are reduced across the block with tile_sum.
    tid, lane = wp.tid()

    s = w.dtype(Vw.dtype(0))

    tile_size = wp.block_dim()
    for k in range(lane, Vw.shape[0], tile_size):
        s += Vw[k, 0] * V[k, tid]

    wi = wp.tile_load(w, shape=1, offset=tid)
    wi -= wp.tile_sum(wp.tile(s, preserve_type=True))

    wp.tile_store(w, wi, offset=tid)
1583
+
1584
+
1585
@wp.kernel
def _gmres_arnoldi_normalize_kernel(
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
    alpha: wp.array(dtype=Any),
    alpha_copy: wp.array(dtype=Any),
):
    # Normalize: y = x / sqrt(alpha[0]), passing x through unchanged when the
    # squared norm is exactly zero (Arnoldi breakdown). The computed norm is
    # also written out to alpha_copy by the first thread.
    row = wp.tid()

    norm = wp.sqrt(alpha[0])
    y[row] = wp.where(alpha[0] == alpha.dtype(0.0), x[row], x[row] / norm)

    if row == 0:
        alpha_copy[0] = norm
1598
+
1599
+
1600
@wp.kernel
def _gmres_update_x_kernel(k: int, beta: Any, y: wp.array(dtype=Any), V: wp.array2d(dtype=Any), x: wp.array(dtype=Any)):
    # Recombine the solution from the first k Krylov basis vectors:
    # x = beta * x + sum_j V[j] * y[j].
    row = wp.tid()

    acc = beta * x[row]
    for j in range(k):
        acc += V[j, row] * y[j]

    x[row] = acc