warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,506 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """An example implementation of a distributed Jacobi solver using MPI.
16
+
17
+ This example shows how to solve the Laplace equation using Jacobi iteration on
18
+ multiple GPUs using Warp and mpi4py. This example is based on the basic "mpi"
19
+ example from the Multi GPU Programming Models repository.
20
+
21
+ This example requires mpi4py and a CUDA-aware MPI implementation. We suggest
22
+ downloading and installing NVIDIA HPC-X, followed by installing mpi4py from its
23
+ source distribution: python -m pip install --no-binary=mpi4py mpi4py
24
+
25
+ Usage:
26
+ mpirun -n 2 python example_jacobi_mpi.py
27
+
28
+ References:
29
+ https://github.com/NVIDIA/multi-gpu-programming-models
30
+ https://developer.nvidia.com/networking/hpc-x
31
+ https://github.com/mpi4py/mpi4py
32
+ """
33
+
34
+ import math
35
+ import sys
36
+ from typing import Tuple
37
+
38
+ import numpy as np
39
+ from mpi4py import MPI
40
+
41
+ import warp as wp
42
+ from warp.types import warp_type_to_np_dtype
43
+
44
wp.config.quiet = True  # Suppress wp.init() output


# Convergence tolerance on the L2 norm of the Jacobi update: the solver loop
# (e.g. in benchmark_single_gpu) keeps iterating only while the norm exceeds tol.
tol = 1e-8
wptype = wp.float32  # Global precision setting, can set wp.float64 here for double precision
# pi is cast to wptype so kernel expressions that use it (initialize_boundaries
# multiplies it with other wptype values) stay in a single precision.
# NOTE(review): "GitHub #485" presumably refers to an upstream Warp issue about
# this cast; confirm and replace with a full link.
pi = wptype(math.pi)  # GitHub #485
50
+
51
+
52
def calc_default_device(mpi_comm: "MPI.Comm") -> wp.context.Device:
    """Return the device that should be used for the current rank.

    This function is used to ensure that multiple MPI ranks running on the same
    node are assigned to different GPUs. Ranks that share a node (grouped with
    ``MPI.COMM_TYPE_SHARED``) are mapped to CUDA devices by their node-local
    rank when more than one CUDA device is visible. With zero or one visible
    CUDA device, every rank uses Warp's current default device
    (``wp.get_device()``), so ranks on the same node may share it.

    This is a collective operation on ``mpi_comm``: every rank in the
    communicator must call it.

    Args:
        mpi_comm: The MPI communicator.

    Returns:
        The Warp device that should be used for the current rank.

    Raises:
        RuntimeError: If more than one CUDA device is visible but there are
            fewer visible devices than ranks on the node.
    """
    # Find the node-local rank and size. Split_type creates a new communicator,
    # so release it as soon as both values have been read instead of leaking it.
    local_mpi_comm = mpi_comm.Split_type(MPI.COMM_TYPE_SHARED)
    local_size = local_mpi_comm.Get_size()
    local_rank = local_mpi_comm.Get_rank()
    local_mpi_comm.Free()

    num_cuda_devices = wp.get_cuda_device_count()

    if 1 < num_cuda_devices < local_size:
        raise RuntimeError(
            f"Number of visible devices ({num_cuda_devices}) is less than number of ranks on the node ({local_size})"
        )

    if num_cuda_devices > 1:
        # One CUDA device per node-local rank.
        return wp.get_cuda_device(local_rank)

    # Zero or one CUDA device visible: fall back to the current default device.
    return wp.get_device()
86
+
87
+
88
def calc_decomp_1d(total_points: int, rank: int, total_ranks: int) -> Tuple[int, int]:
    """Split ``total_points`` into ``total_ranks`` contiguous chunks and describe chunk ``rank``.

    Returns ``(start_index, num_domain_points)``: the index of the first point
    owned by ``rank`` and the number of consecutive points it owns.

    When the points cannot be divided evenly, the remainder
    ``total_points % total_ranks`` is handed out one extra point at a time to
    the lowest-numbered ranks. For example, 10 points over 3 ranks yields
    ``(0, 4)``, ``(4, 3)`` and ``(7, 3)``.
    """
    base_count, remainder = divmod(total_points, total_ranks)

    if rank < remainder:
        count = base_count + 1
    else:
        count = base_count

    # Each lower rank owns base_count points, and min(rank, remainder) of them
    # own one extra point, so that is where this rank's chunk begins.
    first = rank * base_count + min(rank, remainder)

    return (first, count)
107
+
108
+
109
# One Jacobi relaxation sweep over the interior of local rows [iy_start, iy_end).
#
# Thread (i, j) handles point (iy_start + i, j + 1): a_new at that point becomes
# the average of its four neighbors in a. Columns 0 and nx - 1 are never written.
#
# When calculate_norm is true, each thread's squared change (new - old)^2 is
# summed across the thread block with wp.tile / wp.tile_sum and atomically added
# to l2_norm[0]. The caller must zero l2_norm before the launch (for example,
# benchmark_single_gpu calls l2_norm_d.zero_()).
@wp.kernel
def jacobi_update(
    a: wp.array2d(dtype=wptype),  # values from the previous sweep (only read here)
    iy_start: int,  # first local row to update
    iy_end: int,  # one past the last local row to update
    nx: int,  # row length, including both boundary columns
    calculate_norm: bool,  # scalar launch argument, identical for every thread
    a_new: wp.array2d(dtype=wptype),  # receives the updated interior values
    l2_norm: wp.array(dtype=wptype),  # l2_norm[0] accumulates the sum of squared changes
):
    i, j = wp.tid()

    # Convert from local thread indices to the indices used to access the arrays

    iy = i + iy_start
    ix = j + 1  # skip boundary column 0

    # Threads outside the update region contribute zero to the norm.
    local_l2_norm = wptype(0.0)

    # Defensive bounds check; with the launch shape used by benchmark_single_gpu,
    # (iy_end - iy_start, nx - 2), it is always true.
    if iy < iy_end and ix < nx - 1:
        new_val = wptype(0.25) * (a[iy - 1, ix] + a[iy + 1, ix] + a[iy, ix - 1] + a[iy, ix + 1])
        a_new[iy, ix] = new_val

        if calculate_norm:
            residue = new_val - a[iy, ix]
            local_l2_norm = residue * residue

    # Kept outside the bounds check, presumably because the tile operations are
    # cooperative across the thread block and every thread must reach them.
    if calculate_norm:
        t = wp.tile(local_l2_norm)
        s = wp.tile_sum(t)
        wp.tile_atomic_add(l2_norm, s)
140
+
141
+
142
# Write the x-direction (Dirichlet) boundary values, one row per thread.
#
# Thread i sets columns 0 and nx - 1 of row i in both a and a_new to
# sin(2 * pi * (i + offset) / (ny - 1)). Interior columns are left untouched.
# offset is 0 for the single-GPU solve in benchmark_single_gpu. Presumably it is
# the global index of local row 0 when each rank owns a slab of rows; confirm at
# the distributed call site.
@wp.kernel
def initialize_boundaries(
    nx: int,  # row length; the boundary columns are 0 and nx - 1
    ny: int,  # row count used to scale the sine profile
    offset: int,  # added to the thread index before evaluating the sine
    a: wp.array2d(dtype=wptype),
    a_new: wp.array2d(dtype=wptype),
):
    i = wp.tid()

    boundary_val = wp.sin(wptype(2.0) * pi * wptype(i + offset) / wptype(ny - 1))

    # Write to both buffers so the boundaries survive the a/a_new swaps.
    a[i, 0] = boundary_val
    a[i, nx - 1] = boundary_val
    a_new[i, 0] = boundary_val
    a_new[i, nx - 1] = boundary_val
158
+
159
+
160
def benchmark_single_gpu(nx: int, ny: int, iter_max: int, nccheck: int = 1, verbose: bool = False):
    """Compute the solution on a single GPU for performance and correctness comparisons.

    Rows ``0`` and ``ny - 1`` are periodic halo rows: after every sweep they are
    refreshed with copies of rows ``ny - 2`` and ``1`` respectively. Columns
    ``0`` and ``nx - 1`` hold the Dirichlet values written by
    ``initialize_boundaries``. Iteration stops after ``iter_max`` sweeps or once
    the L2 norm of the update is no greater than the module-level ``tol``.

    Args:
        nx: The number of points in the x-direction.
        ny: The number of points in the y-direction.
        iter_max: The maximum number of Jacobi iterations.
        nccheck: The number of iterations between norm checks. Defaults to 1.
            The norm is also checked every 100 iterations regardless.
        verbose: Whether to print verbose output. Defaults to False.

    Returns:
        tuple: A tuple containing:
            - runtime (float): The execution time of the solver loop in seconds
              (allocation and boundary initialization are excluded).
            - solution (warp.array2d): The solution as a Warp array on the host
              with dimensions ``(ny, nx)``.
    """

    a = wp.zeros((ny, nx), dtype=wptype)
    a_new = wp.zeros_like(a)

    # Device-side accumulator for the squared update, plus a pinned host mirror
    # so the device-to-host copy can be queued asynchronously on compute_stream.
    l2_norm_d = wp.zeros((1,), dtype=wptype)
    l2_norm_h = wp.ones_like(l2_norm_d, device="cpu", pinned=True)

    compute_stream = wp.Stream()
    push_top_stream = wp.Stream()
    push_bottom_stream = wp.Stream()

    compute_done = wp.Event()
    push_top_done = wp.Event()
    push_bottom_done = wp.Event()

    # Interior rows to update; rows 0 and ny - 1 are the periodic halo rows.
    iy_start = 1
    iy_end = ny - 1
    update_shape = (iy_end - iy_start, nx - 2)

    wp.launch(initialize_boundaries, dim=(ny,), inputs=[nx, ny, 0], outputs=[a, a_new])

    if verbose:
        print(
            f"Single GPU jacobi relaxation: {iter_max} iterations on {ny} x {nx} mesh with norm check every {nccheck}"
            " iterations"
        )

    iteration = 0
    l2_norm = 1.0

    # The allocations and the boundary initialization above were issued on the
    # device's current stream, which is not ordered with compute_stream. Wait for
    # them so the first sweep reads initialized boundaries and the timed region
    # covers only the solver loop.
    wp.synchronize_device()
    start_time = MPI.Wtime()

    while l2_norm > tol and iteration < iter_max:
        calculate_norm = (iteration % nccheck == 0) or (iteration % 100 == 0)

        with wp.ScopedStream(compute_stream):
            l2_norm_d.zero_()

            # The sweep reads the halo rows written by the previous iteration's
            # copies, so wait for both of them first.
            compute_stream.wait_event(push_top_done)
            compute_stream.wait_event(push_bottom_done)

            wp.launch(
                jacobi_update,
                update_shape,
                inputs=[a, iy_start, iy_end, nx, calculate_norm],
                outputs=[a_new, l2_norm_d],
            )
            wp.record_event(compute_done)

        if calculate_norm:
            wp.copy(l2_norm_h, l2_norm_d, stream=compute_stream)

        # Apply periodic boundary conditions: refresh both halo rows of a_new on
        # separate streams once the sweep has finished.
        push_top_stream.wait_event(compute_done)
        wp.copy(a_new[0], a_new[iy_end - 1], stream=push_top_stream)
        push_top_stream.record_event(push_top_done)

        push_bottom_stream.wait_event(compute_done)
        wp.copy(a_new[iy_end], a_new[iy_start], stream=push_bottom_stream)
        push_bottom_stream.record_event(push_bottom_done)

        if calculate_norm:
            # The host copy of the norm is only valid after compute_stream drains.
            wp.synchronize_stream(compute_stream)

            l2_norm = math.sqrt(l2_norm_h.numpy()[0])

            if verbose and iteration % 100 == 0:
                print(f"{iteration:5d}, {l2_norm:.6f}")

        # Swap arrays
        a, a_new = a_new, a

        iteration += 1

    wp.synchronize_device()
    stop_time = MPI.Wtime()

    a_ref_h = wp.empty((ny, nx), dtype=wptype, device="cpu")
    wp.copy(a_ref_h, a)

    return stop_time - start_time, a_ref_h
257
+
258
+
259
class Example:
    """Multi-GPU Jacobi relaxation driven by MPI, validated against a single-GPU run.

    Each MPI rank owns a contiguous band of rows of the global ``(ny, nx)`` grid and stores
    it with one extra halo row above and below. Halo rows are exchanged with the
    neighboring ranks after every iteration (see :meth:`apply_periodic_bc`).
    """

    def __init__(
        self,
        nx: int = 16384,
        ny: int = 16384,
        iter_max: int = 1000,
        nccheck: int = 1,
        csv: bool = False,
    ):
        """Set up the domain decomposition, device buffers, and boundary conditions.

        Args:
            nx: Global resolution in x.
            ny: Global resolution in y.
            iter_max: Maximum number of Jacobi iterations.
            nccheck: Check convergence every ``nccheck`` iterations.
            csv: If ``True``, print results as CSV values and suppress progress output.
        """
        self.iter_max = iter_max
        self.nx = nx  # Global resolution
        self.ny = ny  # Global resolution
        self.nccheck = nccheck
        self.csv = csv

        self.mpi_comm = MPI.COMM_WORLD
        self.mpi_rank = self.mpi_comm.Get_rank()
        self.mpi_size = self.mpi_comm.Get_size()

        # Set the default device on the current rank
        self.device = calc_default_device(self.mpi_comm)
        wp.set_device(self.device)

        # NOTE(review): the original intent was to disable memory pools for MPI peer-to-peer
        # transfers, but the call is commented out; confirm whether your MPI stack needs it.
        # wp.set_mempool_enabled(wp.get_cuda_device(), False)
        self.compute_stream = wp.Stream()
        self.compute_done = wp.Event()

        # Compute the solution on a single GPU for comparisons
        self.runtime_serial, self.a_ref_h = benchmark_single_gpu(
            self.nx, self.ny, self.iter_max, self.nccheck, not self.csv and self.mpi_rank == 0
        )

        # num_local_rows: Number of rows from the full (self.ny, self.nx) solution that
        # this rank will calculate (excludes halo regions)
        # iy_start_global: Allows us to go from a local index to a global index.

        # self.ny-2 rows are distributed among the ranks for comparison with single-GPU case,
        # which reserves the first and last rows for the boundary conditions
        iy_decomp_start, self.num_local_rows = calc_decomp_1d(self.ny - 2, self.mpi_rank, self.mpi_size)

        # Add 1 to get the global start index since the 1-D decomposition excludes the boundaries
        self.iy_start_global = iy_decomp_start + 1

        self.mpi_comm.Barrier()
        if not self.csv:
            print(
                f"Rank {self.mpi_rank} on device {wp.get_cuda_device().pci_bus_id}: "
                f"{self.num_local_rows} rows from y = {self.iy_start_global} to y = {self.iy_start_global + self.num_local_rows - 1}"
            )
        self.mpi_comm.Barrier()

        # Allocate local array (the +2 is for the halo layer on each side)
        self.a = wp.zeros((self.num_local_rows + 2, self.nx), dtype=wptype)
        self.a_new = wp.zeros_like(self.a)

        # Allocate host array for the final result
        self.a_h = wp.empty((self.ny, self.nx), dtype=wptype, device="cpu")

        self.l2_norm_d = wp.zeros((1,), dtype=wptype)
        self.l2_norm_h = wp.ones_like(self.l2_norm_d, device="cpu", pinned=True)

        # Boundary Conditions
        # - y-boundaries (iy=0 and iy=self.ny-1): Periodic
        # - x-boundaries (ix=0 and ix=self.nx-1): Dirichlet

        # Local Indices
        self.iy_start = 1
        self.iy_end = self.iy_start + self.num_local_rows  # Last owned row begins at [iy_end-1, 0]

        # Don't need to loop over the Dirichlet boundaries in the Jacobi iteration
        self.update_shape = (self.num_local_rows, self.nx - 2)

        # Used for inter-rank communication (ranks wrap around for the periodic y-boundary)
        self.lower_neighbor = (self.mpi_rank + 1) % self.mpi_size
        self.upper_neighbor = self.mpi_rank - 1 if self.mpi_rank > 0 else self.mpi_size - 1

        # Apply Dirichlet boundary conditions to both a and a_new
        wp.launch(
            initialize_boundaries,
            dim=(self.num_local_rows + 2,),
            inputs=[self.nx, self.ny, self.iy_start_global - 1],
            outputs=[self.a, self.a_new],
        )

        # MPI Warmup (an even number of swaps leaves a / a_new in their original roles)
        wp.synchronize_device()

        for _mpi_warmup in range(10):
            self.apply_periodic_bc()
            self.a, self.a_new = self.a_new, self.a

        wp.synchronize_device()

        if not self.csv and self.mpi_rank == 0:
            print(
                f"Jacobi relaxation: {self.iter_max} iterations on {self.ny} x {self.nx} mesh with norm check "
                f"every {self.nccheck} iterations"
            )

    def apply_periodic_bc(self) -> None:
        """Exchange the halo rows of ``self.a_new`` with the neighboring ranks.

        - Local row ``iy_end`` (bottom halo) receives the first owned row of the next rank,
          ``lower_neighbor`` (rank + 1).
        - Local row ``0`` (top halo) receives the last owned row of the previous rank,
          ``upper_neighbor`` (rank - 1).
        - Because the neighbor ranks wrap around, the first and last ranks exchange rows with
          each other, which gives periodic boundaries in y.

        NOTE(review): this assumes ``calc_decomp_1d`` assigns increasing row ranges to
        increasing ranks (as ``check_results`` also assumes via ``iy_start_global``).
        """
        # Send our first owned row to the previous rank (its bottom halo) and receive the
        # next rank's first owned row into our bottom halo.
        self.mpi_comm.Sendrecv(
            self.a_new[self.iy_start], self.upper_neighbor, 0, self.a_new[self.iy_end], self.lower_neighbor, 0
        )
        # Send our last owned row to the next rank (its top halo) and receive the previous
        # rank's last owned row into our top halo.
        self.mpi_comm.Sendrecv(
            self.a_new[self.iy_end - 1], self.lower_neighbor, 0, self.a_new[0], self.upper_neighbor, 0
        )

    def step(self, calculate_norm: bool) -> None:
        """Launch a single Jacobi iteration from ``self.a`` into ``self.a_new`` on the compute stream.

        Args:
            calculate_norm: Forwarded to the ``jacobi_update`` kernel, which accumulates into
                ``self.l2_norm_d`` (zeroed first).

        Records ``self.compute_done`` after the launch so the host can wait for it.
        """
        with wp.ScopedStream(self.compute_stream):
            self.l2_norm_d.zero_()
            wp.launch(
                jacobi_update,
                self.update_shape,
                inputs=[self.a, self.iy_start, self.iy_end, self.nx, calculate_norm],
                outputs=[self.a_new, self.l2_norm_d],
            )
            wp.record_event(self.compute_done)

    def run(self) -> None:
        """Run the Jacobi relaxation on multiple GPUs using MPI and compare with single-GPU results.

        Iterates until the global L2 norm drops to ``tol`` or ``iter_max`` iterations are done.
        Exits the process with status 1 if any rank's result fails :meth:`check_results`;
        otherwise rank 0 prints timing (CSV or human-readable).
        """
        iteration = 0
        l2_norm = np.array([1.0], dtype=warp_type_to_np_dtype[wptype])

        start_time = MPI.Wtime()

        while l2_norm[0] > tol and iteration < self.iter_max:
            # The norm is needed for convergence checks and for the every-100-iterations progress print.
            calculate_norm = (iteration % self.nccheck == 0) or (not self.csv and iteration % 100 == 0)

            self.step(calculate_norm)

            if calculate_norm:
                # Copy this rank's partial norm to the pinned host buffer on the compute stream.
                wp.copy(self.l2_norm_h, self.l2_norm_d, stream=self.compute_stream)

            # The halo exchange reads a_new, so the Jacobi update must be finished first.
            wp.synchronize_event(self.compute_done)

            self.apply_periodic_bc()

            if calculate_norm:
                # Make sure the norm copy has completed before reading it on the host.
                wp.synchronize_stream(self.compute_stream)

                # Allreduce (default op: SUM) combines the per-rank partial sums; the square
                # root turns the global sum into the L2 norm.
                self.mpi_comm.Allreduce(self.l2_norm_h.numpy(), l2_norm)
                l2_norm = np.sqrt(l2_norm)

                if not self.csv and self.mpi_rank == 0 and iteration % 100 == 0:
                    print(f"{iteration:5d}, {l2_norm[0]:.6f}")

            # Swap arrays
            self.a, self.a_new = self.a_new, self.a

            iteration += 1

        wp.synchronize_device()
        stop_time = MPI.Wtime()

        result_correct = self.check_results(tol)
        global_result_correct = self.mpi_comm.allreduce(result_correct, op=MPI.MIN)

        if not global_result_correct:
            sys.exit(1)

        if self.mpi_rank != 0:
            return

        elapsed = stop_time - start_time
        if self.csv:
            print(
                f"mpi, {self.nx}, {self.ny}, {self.iter_max}, {self.nccheck}, {self.mpi_size}, 1, "
                f"{elapsed}, {self.runtime_serial}"
            )
        else:
            print(f"Num GPUs: {self.mpi_size}")
            print(
                f"{self.ny}x{self.nx}: 1 GPU: {self.runtime_serial:8.4f} s, "
                f"{self.mpi_size} GPUs {elapsed:8.4f} s, "
                f"speedup: {self.runtime_serial / elapsed:8.2f}, "
                f"efficiency: {self.runtime_serial / elapsed / self.mpi_size * 100:8.2f}"
            )

    def check_results(self, tol: float = 1e-8) -> bool:
        """Returns ``True`` if multi-GPU result is within ``tol`` of the single-GPU result.

        This rank's owned rows are copied into the global host array ``self.a_h``. They are
        then compared against the single-GPU reference with a vectorized NumPy comparison.
        The Dirichlet columns ``ix = 0`` and ``ix = nx - 1`` are excluded.

        If values differ, the first mismatch in row-major order is printed.
        """
        # Copy the owned rows (skipping the top halo row) into their global position.
        wp.copy(
            self.a_h,
            self.a,
            dest_offset=self.iy_start_global * self.nx,
            src_offset=self.nx,
            count=self.num_local_rows * self.nx,
        )

        a_ref_np = self.a_ref_h.numpy()
        a_np = self.a_h.numpy()

        row_begin = self.iy_start_global
        row_end = row_begin + self.num_local_rows

        ref_band = a_ref_np[row_begin:row_end, 1 : self.nx - 1]
        local_band = a_np[row_begin:row_end, 1 : self.nx - 1]

        # np.argwhere returns indices in row-major order, so the first entry is the same
        # element the element-by-element scan would have hit first.
        mismatches = np.argwhere(np.abs(ref_band - local_band) > tol)
        if mismatches.shape[0] == 0:
            return True

        iy = int(mismatches[0, 0]) + row_begin
        ix = int(mismatches[0, 1]) + 1
        print(
            f"ERROR on rank {self.mpi_rank}: a[{iy},{ix}] = {a_np[iy, ix]} does not match "
            f"{a_ref_np[iy, ix]} (reference)"
        )
        return False
473
+
474
+
475
if __name__ == "__main__":
    import argparse

    # Command-line interface; defaults are shown in --help output.
    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    cli.add_argument("--itermax", type=int, default=1000, help="Maximum number of Jacobi iterations.")
    cli.add_argument("--nccheck", type=int, default=1, help="Check convergence every nccheck iterations.")
    cli.add_argument("--nx", type=int, default=16384, help="Total resolution in x.")
    cli.add_argument("--ny", type=int, default=16384, help="Total resolution in y.")
    cli.add_argument("-csv", action="store_true", help="Print results as CSV values.")
    cli.add_argument(
        "--visualize",
        action="store_true",
        help="Display the final solution in a graphical window using matplotlib.",
    )

    # Unrecognized arguments are ignored rather than rejected.
    options, _ignored = cli.parse_known_args()

    example = Example(
        nx=options.nx,
        ny=options.ny,
        iter_max=options.itermax,
        nccheck=options.nccheck,
        csv=options.csv,
    )
    example.run()

    if options.visualize:
        import matplotlib.pyplot as plt

        # Each rank plots only its own local array (owned rows plus the two halo rows).
        local_solution = example.a.numpy()
        plt.imshow(local_solution, cmap="viridis", origin="lower", vmin=-1, vmax=1)
        plt.colorbar(label="Value")
        plt.title(f"Rank {example.mpi_rank} Jacobi Iteration Result")
        plt.xlabel("X-axis")
        plt.ylabel("Y-axis")
        plt.show()