warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,196 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Diffusion
18
+ #
19
+ # This example solves a 2d diffusion problem:
20
+ #
21
+ # -nu Div(Grad u) = 1
22
+ #
23
+ # with Dirichlet boundary conditions on vertical edges and
24
+ # homogeneous Neumann on horizontal edges.
25
+ ###########################################################################
26
+
27
+ import warp as wp
28
+ import warp.examples.fem.utils as fem_example_utils
29
+ import warp.fem as fem
30
+ from warp.fem.linalg import array_axpy
31
+
32
+
33
@fem.integrand
def linear_form(
    s: fem.Sample,
    v: fem.Field,
):
    """Forcing term of the problem: a linear form with constant slope 1.

    The integrand is simply the field ``v`` evaluated at the sample ``s``.
    """
    test_value = v(s)
    return test_value
40
+
41
+
42
@fem.integrand
def diffusion_form(s: fem.Sample, u: fem.Field, v: fem.Field, nu: float):
    """Bilinear diffusion form with a constant coefficient ``nu``.

    Evaluates ``nu * dot(grad(u), grad(v))`` at the sample ``s``.
    """
    grad_u = fem.grad(u, s)
    grad_v = fem.grad(v, s)
    return nu * wp.dot(grad_u, grad_v)
49
+
50
+
51
@fem.integrand
def y_boundary_value_form(s: fem.Sample, domain: fem.Domain, v: fem.Field, val: float):
    """Linear form with coefficient ``val`` on vertical edges, zero elsewhere.

    The field value is weighted by the absolute value of the first
    component of the normal returned by ``fem.normal`` at the sample.
    """
    side_normal = fem.normal(domain, s)
    x_weight = wp.abs(side_normal[0])
    return val * v(s) * x_weight
56
+
57
+
58
@fem.integrand
def y_boundary_projector_form(
    s: fem.Sample,
    domain: fem.Domain,
    u: fem.Field,
    v: fem.Field,
):
    """Bilinear boundary-condition projector form, non-zero on vertical edges only."""
    # Evaluate one of the participating fields here and hand it to the linear
    # boundary form as a plain scalar, so that implementation is reused.
    u_value = u(s)
    return y_boundary_value_form(s, domain, v, u_value)
70
+
71
+
72
class Example:
    """Assemble and solve the 2D diffusion problem with ``warp.fem``.

    The constructor builds the geometry, the scalar function space and the
    discrete field. ``step`` assembles and solves the linear system, and
    ``render`` hands the resulting field to the plotting helper.
    """

    def __init__(
        self,
        quiet=False,
        degree=2,
        resolution=50,
        mesh="grid",
        serendipity=False,
        viscosity=2.0,
        boundary_value=5.0,
        boundary_compliance=0.0,
    ):
        """Store the problem parameters and create geometry, space and field.

        Args:
            quiet: Forwarded to the conjugate-gradient helper in ``step``.
            degree: Polynomial degree passed to ``fem.make_polynomial_space``.
            resolution: Resolution used for both components of the 2D resolution vector.
            mesh: ``"tri"`` or ``"quad"`` selects the matching mesh type;
                any other value selects a regular grid.
            serendipity: Request ``fem.ElementBasis.SERENDIPITY`` instead of
                leaving the element basis unspecified.
            viscosity: Value bound to ``nu`` in ``diffusion_form``.
            boundary_value: Value bound to ``val`` in ``y_boundary_value_form``.
            boundary_compliance: ``0.0`` enforces the boundary condition by
                projecting the linear system; any other value adds the
                boundary terms scaled by its inverse.
        """
        self._quiet = quiet

        self._viscosity = viscosity
        self._boundary_value = boundary_value
        self._boundary_compliance = boundary_compliance

        # Geometry: triangle mesh, quad mesh, or (by default) a regular grid
        res = wp.vec2i(resolution)
        if mesh == "tri":
            positions, tri_vidx = fem_example_utils.gen_trimesh(res=res)
            geometry = fem.Trimesh2D(tri_vertex_indices=tri_vidx, positions=positions)
        elif mesh == "quad":
            positions, quad_vidx = fem_example_utils.gen_quadmesh(res=res)
            geometry = fem.Quadmesh2D(quad_vertex_indices=quad_vidx, positions=positions)
        else:
            geometry = fem.Grid2D(res=res)
        self._geo = geometry

        # Scalar function space over the geometry
        element_basis = None
        if serendipity:
            element_basis = fem.ElementBasis.SERENDIPITY
        self._scalar_space = fem.make_polynomial_space(self._geo, degree=degree, element_basis=element_basis)

        # Discrete scalar field that will receive the solution
        self._scalar_field = self._scalar_space.make_field()

        self.renderer = fem_example_utils.Plot()

    def step(self):
        """Assemble the linear system, solve it, and store the result in the field."""
        geo = self._geo
        space = self._scalar_space

        cells = fem.Cells(geometry=geo)

        # Right-hand side (forcing term)
        test = fem.make_test(space=space, domain=cells)
        rhs = fem.integrate(linear_form, fields={"v": test})

        # Diffusion matrix
        trial = fem.make_trial(space=space, domain=cells)
        matrix = fem.integrate(
            diffusion_form,
            fields={"u": trial, "v": test},
            values={"nu": self._viscosity},
        )

        # Boundary conditions on Y sides.
        # Nodal integration lets the boundary condition be specified on each node independently.
        boundary = fem.BoundarySides(geo)
        bd_test = fem.make_test(space=space, domain=boundary)
        bd_trial = fem.make_trial(space=space, domain=boundary)

        bd_matrix = fem.integrate(
            y_boundary_projector_form,
            fields={"u": bd_trial, "v": bd_test},
            assembly="nodal",
        )
        bd_rhs = fem.integrate(
            y_boundary_value_form,
            fields={"v": bd_test},
            values={"val": self._boundary_value},
            assembly="nodal",
        )

        # Combine the diffusion system with the boundary condition
        if self._boundary_compliance != 0.0:
            # Weak BC: add the scaled boundary matrix and right-hand side
            boundary_strength = 1.0 / self._boundary_compliance
            matrix += bd_matrix * boundary_strength
            array_axpy(x=bd_rhs, y=rhs, alpha=boundary_strength, beta=1)
        else:
            # Hard BC: project the linear system
            fem.project_linear_system(matrix, rhs, bd_matrix, bd_rhs)

        # Solve the linear system with Conjugate Gradient, starting from zero
        solution = wp.zeros_like(rhs)
        fem_example_utils.bsr_cg(matrix, b=rhs, x=solution, quiet=self._quiet)

        # Store the solved degrees of freedom in the discrete field
        self._scalar_field.dof_values = solution

    def render(self):
        """Register the solution field with the renderer under the name ``"solution"``."""
        self.renderer.add_field("solution", self._scalar_field)
152
+
153
+
154
if __name__ == "__main__":
    import argparse

    # Turn off the "enable_backward" module option for this script
    wp.set_module_options({"enable_backward": False})

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Device and discretization options
    parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    parser.add_argument("--resolution", type=int, default=50, help="Grid resolution.")
    parser.add_argument("--degree", type=int, default=2, help="Polynomial degree of shape functions.")
    parser.add_argument("--serendipity", action="store_true", default=False, help="Use Serendipity basis functions.")

    # Problem parameters
    parser.add_argument("--viscosity", type=float, default=2.0, help="Fluid viscosity parameter.")
    parser.add_argument(
        "--boundary_value",
        type=float,
        default=5.0,
        help="Value of Dirichlet boundary condition on vertical edges.",
    )
    parser.add_argument(
        "--boundary_compliance",
        type=float,
        default=0.0,
        help="Dirichlet boundary condition compliance.",
    )
    parser.add_argument("--mesh", choices=("grid", "tri", "quad"), default="grid", help="Mesh type.")

    # Output options
    parser.add_argument(
        "--headless",
        action="store_true",
        help="Run in headless mode, suppressing the opening of any graphical windows.",
    )
    parser.add_argument("--quiet", action="store_true", help="Suppresses the printing out of iteration residuals.")

    # Unrecognized command-line arguments are ignored rather than rejected
    args, _unknown = parser.parse_known_args()

    with wp.ScopedDevice(args.device):
        example = Example(
            quiet=args.quiet,
            degree=args.degree,
            resolution=args.resolution,
            mesh=args.mesh,
            serendipity=args.serendipity,
            viscosity=args.viscosity,
            boundary_value=args.boundary_value,
            boundary_compliance=args.boundary_compliance,
        )
        example.step()
        example.render()

        # Only open the plot when not running headless
        if not args.headless:
            example.renderer.plot()
@@ -0,0 +1,225 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Diffusion 3D
18
+ #
19
+ # This example solves a 3d diffusion problem:
20
+ #
21
+ # nu Div u = 1
22
+ #
23
+ # with homogeneous Neumann conditions on horizontal sides
24
+ # and homogeneous Dirichlet boundary conditions on the other sides.
25
+ ###########################################################################
26
+
27
+ import numpy as np
28
+
29
+ import warp as wp
30
+ import warp.examples.fem.utils as fem_example_utils
31
+ import warp.fem as fem
32
+ from warp.examples.fem.example_diffusion import diffusion_form, linear_form
33
+ from warp.sparse import bsr_axpy
34
+
35
+
36
@fem.integrand
def vertical_boundary_projector_form(
    s: fem.Sample,
    domain: fem.Domain,
    u: fem.Field,
    v: fem.Field,
):
    # Boundary projector used for volume geometries: the product u * v is
    # weighted by 1 - |n_y|, where n is the normal returned by fem.normal at
    # the sample. The weight vanishes where the normal is aligned with Y, so
    # only the XY and YZ faces are constrained.
    side_normal = fem.normal(domain, s)
    y_alignment = wp.abs(side_normal[1])
    return (1.0 - y_alignment) * u(s) * v(s)
47
+
48
+
49
@fem.integrand
def y_boundary_projector_form(
    s: fem.Sample,
    domain: fem.Domain,
    u: fem.Field,
    v: fem.Field,
):
    # Boundary projector used for surface geometries: constrain Y edges.
    # The product u * v is weighted by the absolute value of component 1 of
    # the deformation gradient at the sample (presumably the edge tangent's
    # Y component on a boundary edge -- TODO confirm).
    edge_tangent = fem.deformation_gradient(domain, s)
    weight = wp.abs(edge_tangent[1])
    return weight * u(s) * v(s)
59
+
60
+
61
class Example:
    """3D diffusion example: builds a geometry and scalar space, solves, and exposes the result for plotting."""

    def __init__(
        self,
        quiet=False,
        degree=2,
        resolution=10,
        mesh="grid",
        serendipity=False,
        viscosity=2.0,
        boundary_compliance=0.0,
    ):
        """Create the geometry, the polynomial function space and the solution field.

        Args:
            quiet: Forwarded to the CG solver in :meth:`step`.
            degree: Polynomial degree passed to ``fem.make_polynomial_space``.
            resolution: Base resolution; the grid uses
                ``(resolution, max(1, resolution // 2), resolution * 2)`` cells.
            mesh: One of "tet", "hex", "nano", "tri", "quad"; any other value
                selects a regular ``fem.Grid3D``.
            serendipity: Use the Serendipity element basis instead of the default.
            viscosity: Value given to the diffusion form as ``nu``.
            boundary_compliance: ``0.0`` selects hard boundary conditions, any
                other value weak ones with strength ``1 / boundary_compliance``.
        """
        self._quiet = quiet
        self._viscosity = viscosity
        self._boundary_compliance = boundary_compliance

        # Every mesh generator receives the same resolution and bounding box.
        mesh_args = {
            "res": wp.vec3i(resolution, max(1, resolution // 2), resolution * 2),
            "bounds_lo": wp.vec3(0.0, 0.0, 0.0),
            "bounds_hi": wp.vec3(1.0, 0.5, 2.0),
        }

        if mesh == "tet":
            positions, tet_indices = fem_example_utils.gen_tetmesh(**mesh_args)
            self._geo = fem.Tetmesh(tet_indices, positions)
        elif mesh == "hex":
            positions, hex_indices = fem_example_utils.gen_hexmesh(**mesh_args)
            self._geo = fem.Hexmesh(hex_indices, positions)
        elif mesh == "nano":
            self._geo = fem.Nanogrid(fem_example_utils.gen_volume(**mesh_args))
        elif mesh == "tri":
            positions, tri_indices = fem_example_utils.gen_trimesh(**mesh_args)
            self._geo = fem.Trimesh3D(tri_indices, self._lift_to_surface(positions))
        elif mesh == "quad":
            positions, quad_indices = fem_example_utils.gen_quadmesh(**mesh_args)
            self._geo = fem.Quadmesh3D(quad_indices, self._lift_to_surface(positions))
        else:
            self._geo = fem.Grid3D(**mesh_args)

        # Domain and function spaces
        if serendipity:
            element_basis = fem.ElementBasis.SERENDIPITY
        else:
            element_basis = None
        self._scalar_space = fem.make_polynomial_space(self._geo, degree=degree, element_basis=element_basis)

        # Scalar field over our function space
        self._scalar_field: fem.DiscreteField = self._scalar_space.make_field()

        self.renderer = fem_example_utils.Plot()

    @staticmethod
    def _lift_to_surface(planar_pos):
        # Turn planar vertex positions into 3D ones by appending
        # z = cos(3 x) * sin(4 y) as the third coordinate.
        xy = planar_pos.numpy()
        height = np.cos(3.0 * xy[:, 0]) * np.sin(4.0 * xy[:, 1])
        lifted = np.hstack((xy, np.expand_dims(height, axis=1)))
        return wp.array(lifted, dtype=wp.vec3)

    def step(self):
        """Assemble the diffusion system with its boundary conditions and solve it with CG.

        The solution vector is stored as the degrees of freedom of the scalar field.
        """
        geometry = self._geo
        space = self._scalar_space

        cells = fem.Cells(geometry=geometry)

        # Right-hand-side
        test_fn = fem.make_test(space=space, domain=cells)
        rhs = fem.integrate(linear_form, fields={"v": test_fn})

        with wp.ScopedTimer("Integrate"):
            # Boundary condition matrix, assembled over the boundary sides
            boundary = fem.BoundarySides(geometry)

            bd_test = fem.make_test(space=space, domain=boundary)
            bd_trial = fem.make_trial(space=space, domain=boundary)

            # Pick boundary conditions depending on whether our geometry is a 3d surface or a volume
            if self._geo.cell_dimension == 3:
                projector_form = vertical_boundary_projector_form
            else:
                projector_form = y_boundary_projector_form
            bd_matrix = fem.integrate(projector_form, fields={"u": bd_trial, "v": bd_test}, assembly="nodal")

            # Diffusion form
            trial_fn = fem.make_trial(space=space, domain=cells)
            matrix = fem.integrate(
                diffusion_form,
                fields={"u": trial_fn, "v": test_fn},
                values={"nu": self._viscosity},
            )

        if self._boundary_compliance == 0.0:
            # Hard BC: project linear system (homogeneous boundary values)
            fem.project_linear_system(matrix, rhs, bd_matrix, wp.zeros_like(rhs))
        else:
            # Weak BC: add together diffusion and boundary condition matrices
            bsr_axpy(x=bd_matrix, y=matrix, alpha=1.0 / self._boundary_compliance, beta=1)

        with wp.ScopedTimer("CG solve"):
            solution = wp.zeros_like(rhs)
            fem_example_utils.bsr_cg(matrix, b=rhs, x=solution, quiet=self._quiet)
            self._scalar_field.dof_values = solution

    def render(self):
        """Hand the solved scalar field to the renderer under the name "solution"."""
        self.renderer.add_field("solution", self._scalar_field)
182
+
183
+
184
if __name__ == "__main__":
    import argparse

    # Turn off the "enable_backward" module option for this script.
    wp.set_module_options({"enable_backward": False})

    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    cli.add_argument("--resolution", type=int, default=10, help="Grid resolution.")
    cli.add_argument("--degree", type=int, default=2, help="Polynomial degree of shape functions.")
    cli.add_argument("--serendipity", action="store_true", default=False, help="Use Serendipity basis functions.")
    cli.add_argument("--viscosity", type=float, default=2.0, help="Fluid viscosity parameter.")
    cli.add_argument(
        "--boundary_compliance", type=float, default=0.0, help="Dirichlet boundary condition compliance."
    )
    # NOTE(review): "anano" is accepted here, but Example.__init__ has no
    # matching branch, so it ends up on the default regular grid -- confirm
    # whether that is intended.
    cli.add_argument(
        "--mesh", choices=("grid", "tet", "hex", "nano", "anano", "tri", "quad"), default="grid", help="Mesh type."
    )
    cli.add_argument(
        "--headless",
        action="store_true",
        help="Run in headless mode, suppressing the opening of any graphical windows.",
    )
    cli.add_argument("--quiet", action="store_true", help="Suppresses the printing out of iteration residuals.")

    # Unrecognized command-line arguments are tolerated and discarded.
    opts, _unknown = cli.parse_known_args()

    # Build, solve and display the example with the requested device as default.
    with wp.ScopedDevice(opts.device):
        example = Example(
            quiet=opts.quiet,
            degree=opts.degree,
            resolution=opts.resolution,
            mesh=opts.mesh,
            serendipity=opts.serendipity,
            viscosity=opts.viscosity,
            boundary_compliance=opts.boundary_compliance,
        )

        example.step()
        example.render()

        if not opts.headless:
            example.renderer.plot()
@@ -0,0 +1,225 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Diffusion MGPU
18
+ #
19
+ # This example illustrates using domain decomposition to
20
+ # solve a diffusion PDE over multiple devices
21
+ ###########################################################################
22
+
23
+ from typing import Tuple
24
+
25
+ import warp as wp
26
+ import warp.examples.fem.utils as fem_example_utils
27
+ import warp.fem as fem
28
+ from warp.examples.fem.example_diffusion import diffusion_form, linear_form
29
+ from warp.sparse import bsr_axpy, bsr_mv
30
+ from warp.utils import array_cast
31
+
32
+
33
@fem.integrand
def mass_form(
    s: fem.Sample,
    u: fem.Field,
    v: fem.Field,
):
    # Integrand of the mass form: product of the two fields at the sample.
    trial_value = u(s)
    test_value = v(s)
    return trial_value * test_value
40
+
41
+
42
@wp.kernel
def scal_kernel(a: wp.array(dtype=wp.float64), res: wp.array(dtype=wp.float64), alpha: wp.float64):
    # Element-wise scaling: res[i] = alpha * a[i], one thread per entry.
    i = wp.tid()
    res[i] = a[i] * alpha
45
+
46
+
47
@wp.kernel
def sum_kernel(a: wp.indexedarray(dtype=wp.float64), b: wp.array(dtype=wp.float64)):
    # Element-wise accumulation of b into the indexed view a, one thread per entry.
    i = wp.tid()
    a[i] = a[i] + b[i]
50
+
51
+
52
def sum_vecs(vecs, indices, sum: wp.array, tmp: wp.array):
    """Accumulate several vectors into ``sum`` at the positions selected by ``indices``.

    Args:
        vecs: Vectors to add, one per entry of ``indices``.
        indices: For each vector, the indices of the entries of ``sum`` it contributes to.
        sum: Destination array, updated in place.
        tmp: Scratch array that each vector is copied into before being added.

    Returns:
        The ``sum`` array that was passed in.
    """
    for local_vec, node_ids in zip(vecs, indices):
        # Stage the vector in the scratch buffer (presumably to bring it onto
        # the device of ``sum`` -- TODO confirm).
        wp.copy(dest=tmp, src=local_vec)
        wp.launch(
            kernel=sum_kernel,
            dim=node_ids.shape,
            device=sum.device,
            inputs=[wp.indexedarray(sum, node_ids), tmp],
        )

    return sum
59
+
60
+
61
class DistributedSystem:
    """Container describing a linear system split over several devices.

    The attributes are filled in by the caller after construction;
    :meth:`mv_routine` then provides the matrix-vector product over the
    per-rank pieces stored in ``rank_data``.
    """

    # Device assigned by the caller for the global system
    device = None
    # Scalar type of the system. ``dtype`` is the attribute actually assigned
    # in ``Example.step``; ``scalar_type`` is kept for backward compatibility.
    scalar_type: type
    dtype: type
    # Scratch buffer used to stage data between the main device and the ranks
    tmp_buf: wp.array

    nrow: int
    # Annotation only: the original assigned the typing alias as the value.
    shape: Tuple[int, int]
    # Tuple of per-rank sequences: (matrices, x vectors, y vectors, node indices)
    rank_data = None

    def mv_routine(self, x: wp.array, y: wp.array, z: wp.array, alpha=1.0, beta=0.0):
        """Distributed matrix-vector multiplication routine, for example purposes.

        Writes ``beta * y`` into ``z``, then for every rank gathers the rank's
        entries of ``x``, multiplies them by the rank's matrix with ``bsr_mv``
        (scaled by ``alpha``) and adds the result to the matching entries of ``z``.

        Args:
            x: Global input vector.
            y: Global vector that is scaled by ``beta``.
            z: Global output vector, overwritten.
            alpha: Scale applied to the matrix-vector products.
            beta: Scale applied to ``y``.
        """

        tmp = self.tmp_buf

        # z = beta * y
        wp.launch(kernel=scal_kernel, dim=y.shape, device=y.device, inputs=[y, z, wp.float64(beta)])

        stream = wp.get_stream()

        for mat_i, x_i, y_i, idx in zip(*self.rank_data):
            # Only the first idx.size entries of the scratch buffer are needed
            tmp_i = tmp[: idx.size]

            # Compress rhs on rank 0
            x_idx = wp.indexedarray(x, idx)
            wp.copy(dest=tmp_i, src=x_idx, count=idx.size, stream=stream)

            # Send to rank i
            wp.copy(dest=x_i, src=tmp_i, count=idx.size, stream=stream)

            with wp.ScopedDevice(x_i.device):
                # The rank's stream must not start before the copy has been issued
                wp.wait_stream(stream)
                bsr_mv(A=mat_i, x=x_i, y=y_i, alpha=alpha, beta=0.0)

            # Wait for the rank's product before reading y_i back
            wp.wait_stream(wp.get_stream(x_i.device))

            # Back to rank 0 for sum
            wp.copy(dest=tmp_i, src=y_i, count=idx.size, stream=stream)
            z_idx = wp.indexedarray(z, idx)
            wp.launch(kernel=sum_kernel, dim=idx.shape, device=z_idx.device, inputs=[z_idx, tmp_i], stream=stream)

        wp.wait_stream(stream)
101
+
102
+
103
class Example:
    """Diffusion example solved with a domain decomposition over the available CUDA devices."""

    def __init__(self, quiet=False, device=None):
        """Create the 2D grid, the scalar function space and the solution field.

        Args:
            quiet: Forwarded to the CG solver in :meth:`step`.
            device: Device hosting the global system; resolved with ``wp.get_device``.
        """
        self._quiet = quiet
        # Weight of the boundary matrix added to each local diffusion matrix
        self._bd_weight = 100.0

        self._geo = fem.Grid2D(res=wp.vec2i(25))

        self._main_device = wp.get_device(device)

        with wp.ScopedDevice(self._main_device):
            self._scalar_space = fem.make_polynomial_space(self._geo, degree=3)
            self._scalar_field = self._scalar_space.make_field()

        self.renderer = fem_example_utils.Plot()

    def step(self):
        """Assemble one local system per CUDA device and solve the global problem with CG."""
        main_device = self._main_device
        devices = wp.get_cuda_devices()
        device_count = len(devices)

        local_matrices = []
        local_rhs = []
        local_results = []
        node_indices = []

        # Build local system for each device
        for rank, rank_device in enumerate(devices):
            with wp.ScopedDevice(rank_device):
                # Construct the partition corresponding to the rank'th device
                partition = fem.LinearGeometryPartition(self._geo, rank, device_count)
                matrix, rhs, partition_nodes = self._assemble_local_system(partition)

                local_rhs.append(rhs)
                local_results.append(wp.empty_like(rhs))
                local_matrices.append(matrix)
                node_indices.append(partition_nodes.to(main_device))

        # Global rhs as sum of all local rhs
        node_count = self._scalar_space.node_count()
        glob_rhs = wp.zeros(n=node_count, dtype=wp.float64, device=main_device)

        # This temporary buffer will be used for peer-to-peer copying during graph capture,
        # so we allocate it using the default CUDA allocator. This ensures that the copying
        # will succeed without enabling mempool access between devices, which is not supported
        # on all systems.
        with wp.ScopedMempool(main_device, False):
            scratch = wp.empty_like(glob_rhs)

        sum_vecs(local_rhs, node_indices, glob_rhs, scratch)

        # Distributed CG
        global_res = wp.zeros_like(glob_rhs)

        system = DistributedSystem()
        system.device = main_device
        system.dtype = glob_rhs.dtype
        system.nrow = node_count
        system.shape = (system.nrow, system.nrow)
        system.tmp_buf = scratch
        system.rank_data = (local_matrices, local_rhs, local_results, node_indices)

        with wp.ScopedDevice(main_device):
            fem_example_utils.bsr_cg(
                system,
                x=global_res,
                b=glob_rhs,
                use_diag_precond=False,
                quiet=self._quiet,
                mv_routine=system.mv_routine,
                mv_routine_uses_multiple_cuda_contexts=True,
            )

        # Convert the solution into the field's degrees of freedom
        array_cast(in_array=global_res, out_array=self._scalar_field.dof_values)

    def render(self):
        """Hand the solved scalar field to the renderer under the name "solution"."""
        self.renderer.add_field("solution", self._scalar_field)

    def _assemble_local_system(self, geo_partition: fem.GeometryPartition):
        """Assemble the system matrix and right-hand side restricted to one geometry partition.

        Returns:
            A tuple ``(matrix, rhs, node_indices)`` where ``node_indices`` is the
            result of ``space_node_indices()`` for the partition.
        """
        space = self._scalar_space
        partition = fem.make_space_partition(space, geo_partition)

        cells = fem.Cells(geometry=geo_partition)

        # Right-hand-side
        test_fn = fem.make_test(space=space, space_partition=partition, domain=cells)
        rhs = fem.integrate(linear_form, fields={"v": test_fn})

        # Weakly-imposed boundary conditions on all sides
        sides = fem.BoundarySides(geometry=geo_partition)
        bd_test = fem.make_test(space=space, space_partition=partition, domain=sides)
        bd_trial = fem.make_trial(space=space, space_partition=partition, domain=sides)
        bd_matrix = fem.integrate(mass_form, fields={"u": bd_trial, "v": bd_test})

        # Diffusion form
        trial_fn = fem.make_trial(space=space, space_partition=partition, domain=cells)
        matrix = fem.integrate(diffusion_form, fields={"u": trial_fn, "v": test_fn}, values={"nu": 1.0})

        # Add the weighted boundary matrix to the diffusion matrix
        bsr_axpy(y=matrix, x=bd_matrix, alpha=self._bd_weight)

        return matrix, rhs, partition.space_node_indices()
200
+
201
+
202
if __name__ == "__main__":
    import argparse

    # Turn off the "enable_backward" module option for this script.
    wp.set_module_options({"enable_backward": False})

    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    cli.add_argument("--quiet", action="store_true", help="Suppresses the printing out of iteration residuals.")
    cli.add_argument(
        "--headless",
        action="store_true",
        help="Run in headless mode, suppressing the opening of any graphical windows.",
    )

    # Unrecognized command-line arguments are tolerated and discarded.
    opts, _unknown = cli.parse_known_args()

    # Time the whole run, labelled with this script's path.
    with wp.ScopedTimer(__file__):
        example = Example(quiet=opts.quiet, device=opts.device)

        example.step()
        example.render()

        if not opts.headless:
            example.renderer.plot()