warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1047 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import gc
18
+ from typing import Any, Dict, Optional, Tuple
19
+
20
+ import numpy as np
21
+
22
+ import warp as wp
23
+ import warp.fem as fem
24
+ from warp.context import assert_conditional_graph_support
25
+ from warp.optim.linear import LinearOperator, aslinearoperator, preconditioner
26
+ from warp.sparse import BsrMatrix, bsr_get_diag, bsr_mv, bsr_transposed
27
+
28
+ __all__ = [
29
+ "Plot",
30
+ "SaddleSystem",
31
+ "bsr_cg",
32
+ "bsr_solve_saddle",
33
+ "gen_hexmesh",
34
+ "gen_quadmesh",
35
+ "gen_tetmesh",
36
+ "gen_trimesh",
37
+ "invert_diagonal_bsr_matrix",
38
+ ]
39
+
40
+ # matrix inversion routines contain nested loops,
41
+ # default unrolling leads to code explosion
42
+ wp.set_module_options({"max_unroll": 6})
43
+
44
+ #
45
+ # Mesh utilities
46
+ #
47
+
48
+
49
def gen_trimesh(res, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[wp.vec2] = None):
    """Builds a triangular mesh by dividing each cell of a dense 2D grid into two triangles

    Args:
        res: Resolution (number of cells) of the grid along each dimension
        bounds_lo: Position of the lower bound of the axis-aligned grid, defaults to ``wp.vec2(0.0)``
        bounds_hi: Position of the upper bound of the axis-aligned grid, defaults to ``wp.vec2(1.0)``

    Returns:
        Tuple of warp arrays: (Vertex positions, Triangle vertex indices)
    """

    lower = wp.vec2(0.0) if bounds_lo is None else bounds_lo
    upper = wp.vec2(1.0) if bounds_hi is None else bounds_hi

    cells_x, cells_y = res[0], res[1]

    # A grid with N cells along an axis has N + 1 nodes along that axis
    xs = np.linspace(lower[0], upper[0], cells_x + 1)
    ys = np.linspace(lower[1], upper[1], cells_y + 1)

    # Move the coordinate axis last, then flatten to one (x, y) row per node
    node_grid = np.meshgrid(xs, ys, indexing="ij")
    node_positions = np.transpose(node_grid, axes=(1, 2, 0)).reshape(-1, 2)

    tri_indices = fem.utils.grid_to_tris(cells_x, cells_y)

    return wp.array(node_positions, dtype=wp.vec2), wp.array(tri_indices, dtype=int)
78
+
79
+
80
def gen_tetmesh(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None):
    """Builds a tetrahedral mesh by dividing each cell of a dense 3D grid into five tetrahedrons

    Args:
        res: Resolution (number of cells) of the grid along each dimension
        bounds_lo: Position of the lower bound of the axis-aligned grid, defaults to ``wp.vec3(0.0)``
        bounds_hi: Position of the upper bound of the axis-aligned grid, defaults to ``wp.vec3(1.0)``

    Returns:
        Tuple of warp arrays: (Vertex positions, Tetrahedron vertex indices)
    """

    lower = wp.vec3(0.0) if bounds_lo is None else bounds_lo
    upper = wp.vec3(1.0) if bounds_hi is None else bounds_hi

    cells_x, cells_y, cells_z = res[0], res[1], res[2]

    # A grid with N cells along an axis has N + 1 nodes along that axis
    xs = np.linspace(lower[0], upper[0], cells_x + 1)
    ys = np.linspace(lower[1], upper[1], cells_y + 1)
    zs = np.linspace(lower[2], upper[2], cells_z + 1)

    # Move the coordinate axis last, then flatten to one (x, y, z) row per node
    node_grid = np.meshgrid(xs, ys, zs, indexing="ij")
    node_positions = np.transpose(node_grid, axes=(1, 2, 3, 0)).reshape(-1, 3)

    tet_indices = fem.utils.grid_to_tets(cells_x, cells_y, cells_z)

    return wp.array(node_positions, dtype=wp.vec3), wp.array(tet_indices, dtype=int)
111
+
112
+
113
def gen_quadmesh(res, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[wp.vec2] = None):
    """Builds a quadrilateral mesh from the cells of a dense 2D grid

    Args:
        res: Resolution (number of cells) of the grid along each dimension
        bounds_lo: Position of the lower bound of the axis-aligned grid, defaults to ``wp.vec2(0.0)``
        bounds_hi: Position of the upper bound of the axis-aligned grid, defaults to ``wp.vec2(1.0)``

    Returns:
        Tuple of warp arrays: (Vertex positions, Quadrilateral vertex indices)
    """

    lower = wp.vec2(0.0) if bounds_lo is None else bounds_lo
    upper = wp.vec2(1.0) if bounds_hi is None else bounds_hi

    cells_x, cells_y = res[0], res[1]

    # A grid with N cells along an axis has N + 1 nodes along that axis
    xs = np.linspace(lower[0], upper[0], cells_x + 1)
    ys = np.linspace(lower[1], upper[1], cells_y + 1)

    # Move the coordinate axis last, then flatten to one (x, y) row per node
    node_grid = np.meshgrid(xs, ys, indexing="ij")
    node_positions = np.transpose(node_grid, axes=(1, 2, 0)).reshape(-1, 2)

    quad_indices = fem.utils.grid_to_quads(cells_x, cells_y)

    return wp.array(node_positions, dtype=wp.vec2), wp.array(quad_indices, dtype=int)
141
+
142
+
143
def gen_hexmesh(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None):
    """Builds a hexahedral mesh from the cells of a dense 3D grid

    Args:
        res: Resolution (number of cells) of the grid along each dimension
        bounds_lo: Position of the lower bound of the axis-aligned grid, defaults to ``wp.vec3(0.0)``
        bounds_hi: Position of the upper bound of the axis-aligned grid, defaults to ``wp.vec3(1.0)``

    Returns:
        Tuple of warp arrays: (Vertex positions, Hexahedron vertex indices)
    """

    lower = wp.vec3(0.0) if bounds_lo is None else bounds_lo
    upper = wp.vec3(1.0) if bounds_hi is None else bounds_hi

    cells_x, cells_y, cells_z = res[0], res[1], res[2]

    # A grid with N cells along an axis has N + 1 nodes along that axis
    xs = np.linspace(lower[0], upper[0], cells_x + 1)
    ys = np.linspace(lower[1], upper[1], cells_y + 1)
    zs = np.linspace(lower[2], upper[2], cells_z + 1)

    # Move the coordinate axis last, then flatten to one (x, y, z) row per node
    node_grid = np.meshgrid(xs, ys, zs, indexing="ij")
    node_positions = np.transpose(node_grid, axes=(1, 2, 3, 0)).reshape(-1, 3)

    hex_indices = fem.utils.grid_to_hexes(cells_x, cells_y, cells_z)

    return wp.array(node_positions, dtype=wp.vec3), wp.array(hex_indices, dtype=int)
174
+
175
+
176
def gen_volume(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None, device=None) -> wp.Volume:
    """Builds a wp.Volume whose active voxels are the cells of a dense 3D grid

    Args:
        res: Resolution (number of voxels) of the grid along each dimension
        bounds_lo: Position of the lower bound of the axis-aligned grid, defaults to ``wp.vec3(0.0)``
        bounds_hi: Position of the upper bound of the axis-aligned grid, defaults to ``wp.vec3(1.0)``
        device: Cuda device on which to allocate the grid

    Returns:
        The volume allocated from the voxel coordinates of the grid
    """

    lower = wp.vec3(0.0) if bounds_lo is None else bounds_lo
    upper = wp.vec3(1.0) if bounds_hi is None else bounds_hi

    # Per-axis voxel size: grid extents divided component-wise by the resolution
    voxel_size = wp.cw_div(upper - lower, wp.vec3(res))

    # Integer voxel coordinates along each axis
    axis_coords = [np.arange(res[axis], dtype=int) for axis in range(3)]

    # One (i, j, k) row per voxel of the dense grid
    voxel_ijk = np.transpose(np.meshgrid(*axis_coords), axes=(1, 2, 3, 0)).reshape(-1, 3)
    voxel_ijk = wp.array(voxel_ijk, dtype=wp.vec3i, device=device)

    # The translation offsets the grid by half a voxel from the lower bound
    origin = lower + 0.5 * voxel_size

    return wp.Volume.allocate_by_voxels(voxel_ijk, voxel_size=voxel_size, translation=origin, device=device)
204
+
205
+
206
+ #
207
+ # Bsr matrix utilities
208
+ #
209
+
210
+
211
def _get_linear_solver_func(method_name: str):
    """Returns the ``warp.optim.linear`` solver function matching ``method_name``

    Recognized names are ``"bicgstab"``, ``"gmres"`` and ``"cr"``;
    any other value selects the conjugate gradient solver ``cg``.
    """
    from warp.optim.linear import bicgstab, cg, cr, gmres

    solvers_by_name = {
        "bicgstab": bicgstab,
        "gmres": gmres,
        "cr": cr,
    }

    # Unknown names fall back to conjugate gradient
    return solvers_by_name.get(method_name, cg)
221
+
222
+
223
def bsr_cg(
    A: BsrMatrix,
    x: wp.array,
    b: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every=10,
    use_diag_precond=True,
    mv_routine=None,
    quiet=False,
    method: str = "cg",
    M: Optional[BsrMatrix] = None,
    mv_routine_uses_multiple_cuda_contexts: bool = False,
) -> Tuple[float, int]:
    """Solves the linear system A x = b using an iterative solver, optionally with diagonal preconditioning

    Args:
        A: system left-hand side
        x: result vector and initial guess
        b: system right-hand-side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        use_diag_precond: Whether to use diagonal preconditioning. Only used when neither ``M`` nor ``mv_routine``
            is provided.
        mv_routine: Matrix-vector multiplication routine to use for multiplications with ``A``
        quiet: if True, do not print iteration residuals
        method: Iterative solver method to use, defaults to Conjugate Gradient
        M: Optional preconditioner, converted with ``aslinearoperator``. Takes precedence over ``use_diag_precond``.
        mv_routine_uses_multiple_cuda_contexts: Whether the matrix-vector multiplication routine uses multiple CUDA contexts,
            which prevents the use of conditional CUDA graphs.

    Returns:
        Tuple (residual norm, iteration count)

    """

    if M is not None:
        M = aslinearoperator(M)
    elif mv_routine is None and use_diag_precond:
        M = preconditioner(A, "diag")

    if mv_routine is not None:
        # Honor the custom matrix-vector product even when an explicit preconditioner is supplied
        A = LinearOperator(A.shape, A.dtype, A.device, matvec=mv_routine)

    func = _get_linear_solver_func(method_name=method)

    callback = None

    use_cuda_graph = A.device.is_cuda and not wp.config.verify_cuda
    capturable = use_cuda_graph and not mv_routine_uses_multiple_cuda_contexts

    if capturable:
        try:
            assert_conditional_graph_support()
        except RuntimeError:
            capturable = False

    if not quiet:
        if capturable:

            @wp.func_native(snippet=f'printf("%s: ", "{func.__name__}");')
            def print_method_name():
                pass

            @fem.cache.dynamic_kernel(suffix=f"{check_every}{func.__name__}")
            def device_cg_callback(
                cur_iter: wp.array(dtype=int),
                err_sq: wp.array(dtype=Any),
                atol_sq: wp.array(dtype=Any),
            ):
                if cur_iter[0] % check_every == 0:
                    print_method_name()
                    wp.printf(
                        "at iteration %d error = \t %f \t tol: %f\n",
                        cur_iter[0],
                        wp.sqrt(err_sq[0]),
                        wp.sqrt(atol_sq[0]),
                    )

            if check_every > 0:
                callback = device_cg_callback
        else:

            def print_callback(i, err, tol):
                print(f"{func.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")

            callback = print_callback

    # Temporarily disable garbage collection
    # Garbage collection of externally-allocated objects during graph capture may lead to
    # invalid operations or memory access errors.
    # Remember the previous state so that a caller who had already disabled collection keeps it disabled.
    restore_gc = use_cuda_graph and gc.isenabled()
    if restore_gc:
        gc.disable()

    try:
        end_iter, err, atol = func(
            A=A,
            b=b,
            x=x,
            maxiter=max_iters,
            tol=tol,
            check_every=0 if capturable else check_every,
            M=M,
            callback=callback,
            use_cuda_graph=use_cuda_graph,
        )
    finally:
        # Re-enable collection even if the solver raised, so a failure does not leave it off process-wide
        if restore_gc:
            gc.enable()

    if isinstance(end_iter, wp.array):
        # The solver returned device arrays holding squared norms; fetch them to host
        end_iter = end_iter.numpy()[0]
        err = np.sqrt(err.numpy()[0])
        atol = np.sqrt(atol.numpy()[0])

    if not quiet:
        res_str = "OK" if err <= atol else "TRUNCATED"
        print(f"{func.__name__}: terminated after {end_iter} iterations with error = \t {err} ({res_str})")

    return err, end_iter
341
+
342
+
343
class SaddleSystem(LinearOperator):
    """Linear operator for the saddle-point system ``[A B^T; B 0]``.

    If ``use_diag_precond`` is ``True``, the matching diagonal preconditioner
    ``[diag(A); diag(B diag(A)^-1 B^T)]`` is built as well and exposed through
    the ``preconditioner`` property.
    """

    def __init__(
        self,
        A: BsrMatrix,
        B: BsrMatrix,
        Bt: Optional[BsrMatrix] = None,
        use_diag_precond: bool = True,
    ):
        """
        Args:
            A: upper-left block of the system
            B: lower-left block of the system
            Bt: upper-right block; computed with ``bsr_transposed(B)`` when omitted
            use_diag_precond: whether to assemble the diagonal preconditioner
        """
        self._A = A
        self._B = B
        self._Bt = bsr_transposed(B) if Bt is None else Bt

        # Vector types of the two unknown blocks, one entry per block-row of A and B respectively
        self._u_dtype = wp.vec(length=A.block_shape[0], dtype=A.scalar_type)
        self._p_dtype = wp.vec(length=B.block_shape[0], dtype=B.scalar_type)

        # The second block is stored right after the first one in the packed vectors
        self._p_byte_offset = A.nrow * wp.types.type_size_in_bytes(self._u_dtype)

        system_size = A.shape[0] + B.shape[0]
        super().__init__(
            (system_size, system_size),
            dtype=A.scalar_type,
            device=A.device,
            matvec=self._saddle_mv,
        )

        # Must come after the base-class initialization: the preconditioner reads shape, dtype and device
        self._preconditioner = self._diag_preconditioner() if use_diag_precond else None

    def _diag_preconditioner(self):
        """Assemble the block-diagonal preconditioner as a ``LinearOperator``."""
        mat_a, mat_b = self._A, self._B

        precond_u = preconditioner(mat_a, "diag")
        diag_a = bsr_get_diag(mat_a)

        block_rows = mat_b.block_shape[0]
        schur_block_shape = (block_rows, block_rows)
        schur_inv_diag = wp.empty(
            dtype=wp.mat(shape=schur_block_shape, dtype=mat_b.scalar_type),
            shape=mat_b.nrow,
            device=self.device,
        )
        wp.launch(
            _compute_schur_inverse_diagonal,
            dim=mat_b.nrow,
            device=mat_a.device,
            inputs=[mat_b.offsets, mat_b.columns, mat_b.values, diag_a, schur_inv_diag],
        )

        if schur_block_shape == (1, 1):
            # Downcast 1x1 mats to scalars
            schur_inv_diag = schur_inv_diag.view(dtype=mat_b.scalar_type)

        precond_p = aslinearoperator(schur_inv_diag)

        def precond_mv(x, y, z, alpha, beta):
            # Apply each diagonal block to its own part of the packed vectors
            precond_u.matvec(self.u_slice(x), self.u_slice(y), self.u_slice(z), alpha=alpha, beta=beta)
            precond_p.matvec(self.p_slice(x), self.p_slice(y), self.p_slice(z), alpha=alpha, beta=beta)

        return LinearOperator(
            shape=self.shape,
            dtype=self.dtype,
            device=self.device,
            matvec=precond_mv,
        )

    @property
    def preconditioner(self):
        """Diagonal preconditioner, or ``None`` if it was not requested."""
        return self._preconditioner

    def u_slice(self, a: wp.array):
        """Non-owning view of the first block of the packed vector ``a`` (one entry per row of ``A``)."""
        return wp.array(
            ptr=a.ptr,
            dtype=self._u_dtype,
            shape=self._A.nrow,
            strides=None,
            device=a.device,
            pinned=a.pinned,
            copy=False,
        )

    def p_slice(self, a: wp.array):
        """Non-owning view of the second block of the packed vector ``a`` (one entry per row of ``B``)."""
        return wp.array(
            ptr=a.ptr + self._p_byte_offset,
            dtype=self._p_dtype,
            shape=self._B.nrow,
            strides=None,
            device=a.device,
            pinned=a.pinned,
            copy=False,
        )

    def _saddle_mv(self, x, y, z, alpha, beta):
        """Matrix-vector product callback: ``z = alpha * [A B^T; B 0] x + beta * y``."""
        in_u, in_p = self.u_slice(x), self.p_slice(x)
        out_u, out_p = self.u_slice(z), self.p_slice(z)

        # The products below accumulate into z, so seed it with y unless y is z or is ignored
        if not (y.ptr == z.ptr or beta == 0.0):
            wp.copy(src=y, dest=z)

        bsr_mv(self._A, in_u, out_u, alpha=alpha, beta=beta)
        # out_u has already been scaled by beta above, hence beta=1 here
        bsr_mv(self._Bt, in_p, out_u, alpha=alpha, beta=1.0)
        bsr_mv(self._B, in_u, out_p, alpha=alpha, beta=beta)
456
+
457
+
458
def bsr_solve_saddle(
    saddle_system: SaddleSystem,
    x_u: wp.array,
    x_p: wp.array,
    b_u: wp.array,
    b_p: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every=10,
    quiet=False,
    method: str = "cg",
) -> Tuple[float, int]:
    """Solves the saddle-point linear system [A B^T; B 0] (x_u; x_p) = (b_u; b_p) using an iterative solver

    The two halves of the unknown and of the right-hand side are packed into single
    vectors, solved through ``bsr_cg`` with the preconditioner of ``saddle_system``
    (if it has one), then copied back into ``x_u`` and ``x_p``.

    Args:
        saddle_system: Saddle point system
        x_u: primal part of the result vector and initial guess
        x_p: Lagrange multiplier part of the result vector and initial guess
        b_u: primal part of the right-hand side
        b_p: constraint part of the right-hand side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        quiet: if True, do not print iteration residuals
        method: Iterative solver method to use, ``"cg"`` unless specified otherwise

    Returns:
        Tuple (residual norm, iteration count)

    """
    packed_x = wp.empty(
        dtype=saddle_system.scalar_type,
        shape=saddle_system.shape[0],
        device=saddle_system.device,
    )
    packed_b = wp.empty_like(packed_x)

    # Pack (u, p) pairs into the contiguous vectors expected by the saddle operator
    for packed, primal_part, multiplier_part in ((packed_x, x_u, x_p), (packed_b, b_u, b_p)):
        wp.copy(src=primal_part, dest=saddle_system.u_slice(packed))
        wp.copy(src=multiplier_part, dest=saddle_system.p_slice(packed))

    residual, iterations = bsr_cg(
        saddle_system,
        packed_x,
        packed_b,
        max_iters=max_iters,
        tol=tol,
        check_every=check_every,
        quiet=quiet,
        method=method,
        M=saddle_system.preconditioner,
    )

    # Unpack the solution into the caller's arrays
    wp.copy(dest=x_u, src=saddle_system.u_slice(packed_x))
    wp.copy(dest=x_p, src=saddle_system.p_slice(packed_x))

    return residual, iterations
512
+
513
+
514
@wp.kernel(enable_backward=False)
def _compute_schur_inverse_diagonal(
    B_offsets: wp.array(dtype=int),
    B_indices: wp.array(dtype=int),
    B_values: wp.array(dtype=Any),
    A_diag: wp.array(dtype=Any),
    P_diag: wp.array(dtype=Any),
):
    # One thread per block-row of B: accumulates sum_k B_k * inverse(A_diag[col_k]) * B_k^T
    # over the blocks of that row and stores the inverse of the sum in P_diag[row].
    row = wp.tid()

    # Zero block of the output type, used as the accumulator
    schur = P_diag.dtype(P_diag.dtype.dtype(0.0))

    row_begin = B_offsets[row]
    row_end = B_offsets[row + 1]

    for block in range(row_begin, row_end):
        B_block = B_values[block]
        A_inv = wp.inverse(A_diag[B_indices[block]])
        schur += B_block * A_inv * wp.transpose(B_block)

    P_diag[row] = fem.linalg.inverse_qr(schur)
539
+
540
+
541
def invert_diagonal_bsr_matrix(A: BsrMatrix):
    """Inverts, in place, each block of a block-diagonal mass matrix"""

    block_values = A.values
    if not wp.types.type_is_matrix(block_values.dtype):
        # Scalar blocks: reinterpret them as 1x1 matrices so that the kernel can treat all cases alike
        one_by_one = wp.mat(shape=(1, 1), dtype=A.scalar_type)
        block_values = block_values.view(dtype=one_by_one)

    wp.launch(
        kernel=_block_diagonal_invert,
        dim=A.nrow,
        inputs=[block_values],
        device=block_values.device,
    )
554
+
555
+
556
@wp.kernel(enable_backward=False)
def _block_diagonal_invert(values: wp.array(dtype=Any)):
    # One thread per block: overwrite the block with its inverse
    block_index = wp.tid()
    block = values[block_index]
    values[block_index] = fem.linalg.inverse_qr(block)
560
+
561
+
562
+ #
563
+ # Plot utilities
564
+ #
565
+
566
+
567
class Plot:
    """Records snapshots of discrete fields and displays them with pyvista or matplotlib.

    Every call to :meth:`add_field` stores the current degrees of freedom of a field
    under a name. Several snapshots stored under the same name are played back as an
    animation by :meth:`plot`. If a ``stage`` is provided, fields are also forwarded to a
    ``warp.render.UsdRenderer`` as they are added.
    """

    def __init__(self, stage=None, default_point_radius=0.01):
        """
        Args:
            stage: stage passed to ``UsdRenderer``; ``None`` disables USD output
            default_point_radius: point radius used for USD output when the radius is not taken from the field values
        """
        self.default_point_radius = default_point_radius

        # name -> (clone of the field used as scratch when drawing, list of recorded dof-value arrays)
        self._fields = {}

        self._usd_renderer = None
        if stage is not None:
            try:
                from warp.render import UsdRenderer

                self._usd_renderer = UsdRenderer(stage)
            except Exception as err:
                # Best effort: interactive plotting remains available without USD output
                print(f"Could not initialize UsdRenderer for stage '{stage}': {err}.")

    def begin_frame(self, time):
        """Start a USD frame at ``time``; no-op without a USD renderer."""
        if self._usd_renderer is not None:
            self._usd_renderer.begin_frame(time=time)

    def end_frame(self):
        """End the current USD frame; no-op without a USD renderer."""
        if self._usd_renderer is not None:
            self._usd_renderer.end_frame()

    def add_field(self, name: str, field: fem.DiscreteField):
        """Record the current values of ``field`` under ``name``.

        The values are also rendered to USD when a USD renderer is active.
        """
        if self._usd_renderer is not None:
            # The renderer needs the name to label the primitive
            self._render_to_usd(name, field)

        if name not in self._fields:
            field_clone = field.space.make_field(space_partition=field.space_partition)
            self._fields[name] = (field_clone, [])

        self._fields[name][1].append(field.dof_values.numpy())

    def _render_to_usd(self, name: str, field: fem.DiscreteField):
        """Send the nodes of ``field`` to the USD renderer as a mesh or as points."""
        points = field.space.node_positions().numpy()
        values = field.dof_values.numpy()

        if values.ndim == 2:
            if values.shape[1] == field.space.dimension:
                # use values as displacement
                points += values
            else:
                # use magnitude
                values = np.linalg.norm(values, axis=1)

        if field.space.dimension == 2:
            # Lift 2D nodes to 3D, using scalar values as height.
            # column_stack accepts both the 1-D values and the (n, 1) zero column.
            z = values if values.ndim == 1 else np.zeros((points.shape[0], 1))
            points = np.column_stack((points, z))

            if hasattr(field.space, "node_triangulation"):
                indices = field.space.node_triangulation()
                self._usd_renderer.render_mesh(name, points=points, indices=indices)
            else:
                self._usd_renderer.render_points(name, points=points, radius=self.default_point_radius)
        elif values.ndim == 1:
            self._usd_renderer.render_points(name, points, radius=values)
        else:
            self._usd_renderer.render_points(name, points, radius=self.default_point_radius)

    def plot(self, options: Optional[Dict[str, Any]] = None, backend: str = "auto"):
        """Display all recorded fields.

        Args:
            options: plot options. The keys ``"rows"`` (number of subplot rows) and ``"ref_geom"``
                (pyvista only) are global; a key equal to a field name maps to a dictionary of
                per-field options, among ``"clim"``, ``"arrows"``, ``"contours"``, ``"streamlines"``,
                ``"displacement"``, ``"xlim"`` and ``"ylim"`` (the last two for matplotlib only).
            backend: ``"pyvista"`` or ``"matplotlib"`` to force a backend. Any other value tries
                pyvista first, then matplotlib, and emits a warning if neither can be imported.
        """
        if options is None:
            options = {}

        if backend == "pyvista":
            return self._plot_pyvista(options)
        if backend == "matplotlib":
            return self._plot_matplotlib(options)

        # try both
        try:
            return self._plot_pyvista(options)
        except ModuleNotFoundError:
            try:
                return self._plot_matplotlib(options)
            except ModuleNotFoundError:
                wp.utils.warn("pyvista or matplotlib must be installed to visualize solution results")

    def _plot_pyvista(self, options: Dict[str, Any]):
        """Display the recorded fields in a pyvista window, one subplot per field."""
        import pyvista
        import pyvista.themes

        grids = {}
        scales = {}
        markers = {}

        animate = False

        ref_geom = options.get("ref_geom", None)
        if ref_geom is not None:
            if isinstance(ref_geom, tuple):
                # (vertices, per-face vertex counts, flat vertex indices) -> "count, i0, i1, ..." face list
                vertices, counts, indices = ref_geom
                offsets = np.cumsum(counts)
                ranges = np.array([offsets - counts, offsets]).T
                faces = np.concatenate(
                    [[count, *list(indices[beg:end])] for (count, (beg, end)) in zip(counts, ranges)]
                )
                ref_geom = pyvista.PolyData(vertices, faces)
            else:
                ref_geom = pyvista.PolyData(ref_geom)

        for name, (field, values) in self._fields.items():
            cells, types = field.space.vtk_cells()
            node_pos = field.space.node_positions().numpy()

            args = options.get(name, {})

            grid_scale = np.max(np.max(node_pos, axis=0) - np.min(node_pos, axis=0))
            value_range = self._get_field_value_range(values, args)
            scales[name] = (grid_scale, value_range)

            if node_pos.shape[1] == 2:
                # the grid is built from 3D points; pad 2D positions with a zero z coordinate
                node_pos = np.hstack((node_pos, np.zeros((node_pos.shape[0], 1))))

            grid = pyvista.UnstructuredGrid(cells, types, node_pos)
            grids[name] = grid

            if len(values) > 1:
                animate = True

        def set_frame_data(frame):
            # Load snapshot `frame` (modulo the snapshot count) of every field into its grid and marker
            for name, (field, values) in self._fields.items():
                if frame > 0 and len(values) == 1:
                    # static field, already set up at frame 0
                    continue

                v = values[frame % len(values)]
                grid = grids[name]
                grid_scale, value_range = scales[name]
                field_args = options.get(name, {})

                marker = None

                if field.space.dimension == 2 and v.ndim == 2 and v.shape[1] == 2:
                    grid.point_data[name] = np.hstack((v, np.zeros((v.shape[0], 1))))
                else:
                    grid.point_data[name] = v

                if v.ndim == 2:
                    grid.point_data[name + "_mag"] = np.linalg.norm(v, axis=1)

                if "arrows" in field_args:
                    glyph_scale = field_args["arrows"].get("glyph_scale", 1.0)
                    glyph_scale *= grid_scale / max(1.0e-8, value_range[1] - value_range[0])
                    marker = grid.glyph(scale=name, orient=name, factor=glyph_scale)
                elif "contours" in field_args:
                    levels = field_args["contours"].get("levels", 10)
                    if isinstance(levels, int):
                        # a level count rather than explicit values: spread levels over the value range
                        levels = np.linspace(*value_range, levels)
                    marker = grid.contour(isosurfaces=levels, scalars=name + "_mag" if v.ndim == 2 else name)
                elif field.space.dimension == 2:
                    z_scale = grid_scale / max(1.0e-8, value_range[1] - value_range[0])

                    if "streamlines" in field_args:
                        center = np.mean(grid.points, axis=0)
                        density = field_args["streamlines"].get("density", 1.0)
                        cell_size = 1.0 / np.sqrt(field.space.geometry.cell_count())

                        separating_distance = 0.5 / (30.0 * density * cell_size)
                        # Try with various sep distance until we get at least one line
                        while separating_distance * cell_size < 1.0:
                            lines = grid.streamlines_evenly_spaced_2D(
                                vectors=name,
                                start_position=center,
                                separating_distance=separating_distance,
                                separating_distance_ratio=0.5,
                                step_length=0.25,
                                compute_vorticity=False,
                            )
                            if lines.n_lines > 0:
                                break
                            separating_distance *= 1.25
                        marker = lines.tube(radius=0.0025 * grid_scale / density)
                    elif "displacement" in field_args:
                        grid.points[:, 0:2] = field.space.node_positions().numpy() + v
                    else:
                        # Extrude surface
                        z = v if v.ndim == 1 else grid.point_data[name + "_mag"]
                        grid.points[:, 2] = z * z_scale

                elif field.space.dimension == 3:
                    if "streamlines" in field_args:
                        density = field_args["streamlines"].get("density", 1.0)
                        lines = grid.streamlines(vectors=name, n_points=int(100 * density))
                        marker = lines.tube(radius=0.0025 * grid_scale / np.sqrt(density))
                    elif "displacement" in field_args:
                        grid.points = field.space.node_positions().numpy() + v

                if frame == 0:
                    if v.ndim == 1:
                        grid.set_active_scalars(name)
                    else:
                        grid.set_active_vectors(name)
                        grid.set_active_scalars(name + "_mag")
                    markers[name] = marker
                elif marker:
                    # update in place the marker that was handed to the plotter
                    markers[name].copy_from(marker)

        set_frame_data(0)

        subplot_rows = options.get("rows", 1)
        subplot_shape = (subplot_rows, (len(grids) + subplot_rows - 1) // subplot_rows)

        plotter = pyvista.Plotter(shape=subplot_shape, theme=pyvista.themes.DocumentProTheme())
        plotter.link_views()
        plotter.add_camera_orientation_widget()
        for index, (name, grid) in enumerate(grids.items()):
            plotter.subplot(index // subplot_shape[1], index % subplot_shape[1])
            _, value_range = scales[name]
            field = self._fields[name][0]
            marker = markers[name]
            if marker:
                if field.space.dimension == 2:
                    plotter.add_mesh(marker, show_scalar_bar=False)
                    plotter.add_mesh(grid, opacity=0.25, clim=value_range)
                    plotter.view_xy()
                else:
                    plotter.add_mesh(marker)
            elif field.space.geometry.cell_dimension == 3:
                plotter.add_mesh_clip_plane(grid, show_edges=True, clim=value_range, assign_to_axis="z")
            else:
                plotter.add_mesh(grid, show_edges=True, clim=value_range)

            if ref_geom:
                plotter.add_mesh(ref_geom)

        plotter.show(interactive_update=animate)

        frame = 0
        while animate and not plotter.iren.interactor.GetDone():
            frame += 1
            set_frame_data(frame)
            plotter.update()

    def _plot_matplotlib(self, options: Dict[str, Any]):
        """Display the recorded fields in a matplotlib figure, one subplot per field."""
        import matplotlib.animation as animation
        import matplotlib.pyplot as plt
        from matplotlib import cm

        def make_animation(fig, ax, cax, values, draw_func):
            def animate(i):
                cs = draw_func(ax, values[i])

                # redraw the colorbar in its existing axes
                cax.cla()
                fig.colorbar(cs, cax)

                return cs

            return animation.FuncAnimation(
                ax.figure,
                animate,
                interval=30,
                blit=False,
                frames=len(values),
            )

        def make_draw_func(field, args, plot_func, plot_opts):
            def draw_fn(axes, values):
                axes.clear()

                # the stored field clone is used as scratch to hold the snapshot being drawn
                field.dof_values = values
                cs = plot_func(field, axes=axes, **plot_opts)

                if "xlim" in args:
                    axes.set_xlim(*args["xlim"])
                if "ylim" in args:
                    axes.set_ylim(*args["ylim"])

                return cs

            return draw_fn

        # keep references to the animations for as long as the figure is shown
        anims = []

        field_count = len(self._fields)
        subplot_rows = options.get("rows", 1)
        subplot_shape = (subplot_rows, (field_count + subplot_rows - 1) // subplot_rows)

        for index, (name, (field, values)) in enumerate(self._fields.items()):
            args = options.get(name, {})
            v = values[0]

            plot_fn = None
            plot_3d = False
            plot_opts = {"cmap": cm.viridis}

            plot_opts["clim"] = self._get_field_value_range(values, args)

            if field.space.dimension == 2:
                if "contours" in args:
                    plot_opts["levels"] = args["contours"].get("levels", None)
                    plot_fn = _plot_contours
                elif v.ndim == 2 and v.shape[1] == 2:
                    if "displacement" in args:
                        plot_fn = _plot_displaced_tri_mesh
                    elif "streamlines" in args:
                        plot_opts["density"] = args["streamlines"].get("density", 1.0)
                        plot_fn = _plot_streamlines
                    elif "arrows" in args:
                        plot_opts["glyph_scale"] = args["arrows"].get("glyph_scale", 1.0)
                        plot_fn = _plot_quivers

                if plot_fn is None:
                    # default: height-field surface
                    plot_fn = _plot_surface
                    plot_3d = True

            elif field.space.dimension == 3:
                if "arrows" in args or "streamlines" in args:
                    plot_opts["glyph_scale"] = args.get("arrows", {}).get("glyph_scale", 1.0)
                    plot_fn = _plot_quivers_3d
                elif field.space.geometry.cell_dimension == 2:
                    plot_fn = _plot_surface
                else:
                    plot_fn = _plot_3d_scatter
                plot_3d = True

            subplot_kw = {"projection": "3d"} if plot_3d else {}
            axes = plt.subplot(*subplot_shape, index + 1, **subplot_kw)

            if not plot_3d:
                axes.set_aspect("equal")

            draw_fn = make_draw_func(field, args, plot_func=plot_fn, plot_opts=plot_opts)
            cs = draw_fn(axes, values[0])

            fig = plt.gcf()
            cax = fig.colorbar(cs).ax

            if len(values) > 1:
                anims.append(make_animation(fig, axes, cax, values, draw_func=draw_fn))

        plt.show()

    @staticmethod
    def _get_field_value_range(values, field_options: Dict[str, Any]):
        """Return the ``"clim"`` option if set, else the (min, max) of values or magnitudes over all snapshots."""
        value_range = field_options.get("clim", None)
        if value_range is None:
            value_range = (
                min(np.min(_value_or_magnitude(v)) for v in values),
                max(np.max(_value_or_magnitude(v)) for v in values),
            )

        return value_range
914
+
915
+
916
+ def _value_or_magnitude(values: np.ndarray):
917
+ if values.ndim == 1:
918
+ return values
919
+ return np.linalg.norm(values, axis=-1)
920
+
921
+
922
def _field_triangulation(field):
    """Build a matplotlib ``Triangulation`` from the node positions and node triangulation of the field's space."""
    import matplotlib.tri as mtri

    space = field.space
    xy = space.node_positions().numpy()
    triangles = space.node_triangulation()
    return mtri.Triangulation(x=xy[:, 0], y=xy[:, 1], triangles=triangles)
927
+
928
+
929
def _plot_surface(field, axes, **kwargs):
    """Draw the field as a surface on 3D axes, colored by value (or by magnitude for vector values).

    For a 3D space the surface goes through the node positions. Otherwise the field
    value is used as height and the z-limits are set to ``kwargs["clim"]``. Depending
    on what the space provides, the surface is drawn from its node grid, from its node
    triangulation, or as a scatter plot if neither is available.

    Args:
        field: field to draw; its space must provide ``node_positions()``
        axes: matplotlib 3D axes to draw on
        **kwargs: must contain ``"clim"`` and ``"cmap"``; forwarded to the matplotlib drawing call

    Returns:
        The matplotlib artist that was created
    """
    C = _value_or_magnitude(field.dof_values.numpy())

    positions = field.space.node_positions().numpy().T
    if field.space.dimension == 3:
        X, Y, Z = positions
    else:
        X, Y = positions
        Z = C
        axes.set_zlim(kwargs["clim"])

    if hasattr(field.space, "node_grid"):
        X, Y = field.space.node_grid()
        C = C.reshape(X.shape)
        return axes.plot_surface(X, Y, C, linewidth=0.1, antialiased=False, **kwargs)

    if hasattr(field.space, "node_triangulation"):
        triangulation = _field_triangulation(field)

        if field.space.dimension == 3:
            # `matplotlib.cm.get_cmap` was removed in matplotlib 3.9; the pyplot function
            # is still available and accepts both colormap names and Colormap instances.
            from matplotlib.colors import Normalize
            from matplotlib.pyplot import get_cmap

            plot = axes.plot_trisurf(triangulation, Z, linewidth=0.1, antialiased=False)
            # change colors -- recompute color map manually
            vmin, vmax = kwargs["clim"]
            norm = Normalize(vmin=vmin, vmax=vmax)
            # one color per triangle, from the mean of its node values
            values = np.mean(C[triangulation.triangles], axis=1)
            colors = get_cmap(kwargs["cmap"])(norm(values))
            plot.set_norm(norm)
            plot.set_fc(colors)
        else:
            plot = axes.plot_trisurf(triangulation, C, linewidth=0.1, antialiased=False, **kwargs)

        return plot

    # scatter
    return axes.scatter(X, Y, Z, c=C, **kwargs)
967
+
968
+
969
def _plot_displaced_tri_mesh(field, axes, **kwargs):
    """Draw the triangulated mesh displaced by the field values, colored by displacement magnitude.

    Returns the artist created by ``tripcolor``.
    """
    mesh = _field_triangulation(field)
    offsets = field.dof_values.numpy()

    # Move the triangulation nodes by the first two components of the field
    mesh.x += offsets[:, 0]
    mesh.y += offsets[:, 1]

    magnitude = _value_or_magnitude(offsets)

    # Filled triangles first, then the mesh edges on top
    filled = axes.tripcolor(mesh, magnitude, **kwargs)
    axes.triplot(mesh, lw=0.1)
    return filled
983
+
984
+
985
def _plot_quivers(field, axes, clim=None, glyph_scale=1.0, **kwargs):
    """Draw a 2D vector field as arrows at the node positions, colored by magnitude.

    ``clim`` is accepted so that it is not forwarded to ``quiver``; it is otherwise unused.
    ``glyph_scale`` is passed to ``quiver`` as ``scale=1.0 / glyph_scale``.
    """
    node_xy = field.space.node_positions().numpy()
    X, Y = node_xy.T

    vel = field.dof_values.numpy()
    u, v = (vel[:, axis].reshape(X.shape) for axis in range(2))
    magnitude = _value_or_magnitude(vel)

    return axes.quiver(X, Y, u, v, magnitude, scale=1.0 / glyph_scale, **kwargs)
993
+
994
+
995
def _plot_quivers_3d(field, axes, clim=None, cmap=None, glyph_scale=1.0, **kwargs):
    """Draw a 3D vector field as arrows at the node positions, colored by magnitude.

    Args:
        field: field to draw; its values must have at least three components per node
        axes: matplotlib 3D axes to draw on
        clim: (min, max) magnitude range used to normalize colors and arrow lengths; required
        cmap: callable mapping normalized magnitudes to colors; required
        glyph_scale: passed to ``quiver`` as the arrow ``length``
        **kwargs: forwarded to ``quiver``

    Returns:
        The artist created by ``quiver``
    """
    X, Y, Z = field.space.node_positions().numpy().T

    vel = field.dof_values.numpy()

    # A constant-magnitude field has an empty value range; avoid dividing by zero
    # (the pyvista backend clamps the range in the same way)
    value_span = clim[1] - clim[0]
    if abs(value_span) < 1.0e-8:
        value_span = 1.0e-8

    colors = cmap((_value_or_magnitude(vel) - clim[0]) / value_span)

    u, v, w = (vel[:, axis].reshape(X.shape) / value_span for axis in range(3))

    return axes.quiver(X, Y, Z, u, v, w, colors=colors, length=glyph_scale, clim=clim, cmap=cmap, **kwargs)
1007
+
1008
+
1009
def _plot_streamlines(field, axes, clim=None, **kwargs):
    """Draw streamlines of a 2D vector field, colored by the interpolated speed.

    The nodal values are interpolated onto a regular 100 x 100 grid that covers the
    bounding box of the triangulation. ``clim`` is accepted so that it is not
    forwarded to ``streamplot``; it is otherwise unused.

    Returns the ``lines`` collection of the streamplot.
    """
    from matplotlib.tri import CubicTriInterpolator

    mesh = _field_triangulation(field)
    vel = field.dof_values.numpy()

    interp_u = CubicTriInterpolator(mesh, vel[:, 0])
    interp_v = CubicTriInterpolator(mesh, vel[:, 1])

    grid_x = np.linspace(np.min(mesh.x), np.max(mesh.x), 100)
    grid_y = np.linspace(np.min(mesh.y), np.max(mesh.y), 100)
    X, Y = np.meshgrid(grid_x, grid_y)

    u = interp_u(X, Y)
    v = interp_v(X, Y)
    speed = np.sqrt(u * u + v * v)

    return axes.streamplot(X, Y, u, v, color=speed, **kwargs).lines
1030
+
1031
+
1032
def _plot_contours(field, axes, clim=None, **kwargs):
    """Draw filled contours, with contour lines on top, of the field value or magnitude.

    ``clim`` is accepted so that it is not forwarded to the contour calls; it is
    otherwise unused. Returns the filled contour set.
    """
    mesh = _field_triangulation(field)
    samples = _value_or_magnitude(field.dof_values.numpy())

    filled = axes.tricontourf(mesh, samples, **kwargs)
    axes.tricontour(mesh, samples, **kwargs)
    return filled
1040
+
1041
+
1042
def _plot_3d_scatter(field, axes, **kwargs):
    """Draw the nodes of a 3D field as a scatter plot, colored by value or magnitude."""
    node_xyz = field.space.node_positions().numpy()
    X, Y, Z = node_xyz.T

    samples = _value_or_magnitude(field.dof_values.numpy())
    point_colors = samples.reshape(X.shape)

    return axes.scatter(X, Y, Z, c=point_colors, **kwargs)