warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1000 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Dict, Optional, Tuple
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ import warp.fem as fem
22
+ from warp.optim.linear import LinearOperator, aslinearoperator, preconditioner
23
+ from warp.sparse import BsrMatrix, bsr_get_diag, bsr_mv, bsr_transposed
24
+
25
# Public helpers of this module (what ``from ... import *`` exposes).
# ``gen_volume`` is listed alongside the other ``gen_*`` grid/mesh generators.
__all__ = [
    "gen_hexmesh",
    "gen_quadmesh",
    "gen_tetmesh",
    "gen_trimesh",
    "gen_volume",
    "bsr_cg",
    "bsr_solve_saddle",
    "SaddleSystem",
    "invert_diagonal_bsr_matrix",
    "Plot",
]

# matrix inversion routines contain nested loops,
# default unrolling leads to code explosion
wp.set_module_options({"max_unroll": 6})
40
+
41
+ #
42
+ # Mesh utilities
43
+ #
44
+
45
+
46
def gen_trimesh(res, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[wp.vec2] = None):
    """Constructs a triangular mesh by dividing each cell of a dense 2D grid into two triangles

    Args:
        res: Resolution of the grid along each dimension
        bounds_lo: Position of the lower bound of the axis-aligned grid (defaults to the origin)
        bounds_hi: Position of the upper bound of the axis-aligned grid (defaults to (1, 1))

    Returns:
        Tuple of warp arrays: (vertex positions as ``wp.vec2``, triangle vertex indices)
    """
    lo = wp.vec2(0.0) if bounds_lo is None else bounds_lo
    hi = wp.vec2(1.0) if bounds_hi is None else bounds_hi

    cells_x, cells_y = res[0], res[1]

    # Node coordinates along each axis: one more node than cells
    xs = np.linspace(lo[0], hi[0], cells_x + 1)
    ys = np.linspace(lo[1], hi[1], cells_y + 1)

    # Flatten the (cells_x+1) x (cells_y+1) node grid with the x index varying slowest
    node_grid = np.stack(np.meshgrid(xs, ys, indexing="ij"), axis=-1)
    vertices = node_grid.reshape(-1, 2)

    tri_indices = fem.utils.grid_to_tris(cells_x, cells_y)

    return wp.array(vertices, dtype=wp.vec2), wp.array(tri_indices, dtype=int)
75
+
76
+
77
def gen_tetmesh(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None):
    """Constructs a tetrahedral mesh by dividing each cell of a dense 3D grid into five tetrahedra

    Args:
        res: Resolution of the grid along each dimension
        bounds_lo: Position of the lower bound of the axis-aligned grid (defaults to the origin)
        bounds_hi: Position of the upper bound of the axis-aligned grid (defaults to (1, 1, 1))

    Returns:
        Tuple of warp arrays: (vertex positions as ``wp.vec3``, tetrahedron vertex indices)
    """
    lo = wp.vec3(0.0) if bounds_lo is None else bounds_lo
    hi = wp.vec3(1.0) if bounds_hi is None else bounds_hi

    cells_x, cells_y, cells_z = res[0], res[1], res[2]

    # Node coordinates along each axis: one more node than cells
    xs = np.linspace(lo[0], hi[0], cells_x + 1)
    ys = np.linspace(lo[1], hi[1], cells_y + 1)
    zs = np.linspace(lo[2], hi[2], cells_z + 1)

    # Flatten the node grid with the x index varying slowest and z fastest
    node_grid = np.stack(np.meshgrid(xs, ys, zs, indexing="ij"), axis=-1)
    vertices = node_grid.reshape(-1, 3)

    tet_indices = fem.utils.grid_to_tets(cells_x, cells_y, cells_z)

    return wp.array(vertices, dtype=wp.vec3), wp.array(tet_indices, dtype=int)
108
+
109
+
110
def gen_quadmesh(res, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[wp.vec2] = None):
    """Constructs a quadrilateral mesh from a dense 2D grid

    Args:
        res: Resolution of the grid along each dimension
        bounds_lo: Position of the lower bound of the axis-aligned grid (defaults to the origin)
        bounds_hi: Position of the upper bound of the axis-aligned grid (defaults to (1, 1))

    Returns:
        Tuple of warp arrays: (vertex positions as ``wp.vec2``, quadrilateral vertex indices)
    """
    lo = wp.vec2(0.0) if bounds_lo is None else bounds_lo
    hi = wp.vec2(1.0) if bounds_hi is None else bounds_hi

    cells_x, cells_y = res[0], res[1]

    # Node coordinates along each axis: one more node than cells
    xs = np.linspace(lo[0], hi[0], cells_x + 1)
    ys = np.linspace(lo[1], hi[1], cells_y + 1)

    # Flatten the (cells_x+1) x (cells_y+1) node grid with the x index varying slowest
    node_grid = np.stack(np.meshgrid(xs, ys, indexing="ij"), axis=-1)
    vertices = node_grid.reshape(-1, 2)

    quad_indices = fem.utils.grid_to_quads(cells_x, cells_y)

    return wp.array(vertices, dtype=wp.vec2), wp.array(quad_indices, dtype=int)
138
+
139
+
140
def gen_hexmesh(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None):
    """Constructs a hexahedral mesh from a dense 3D grid

    Args:
        res: Resolution of the grid along each dimension
        bounds_lo: Position of the lower bound of the axis-aligned grid (defaults to the origin)
        bounds_hi: Position of the upper bound of the axis-aligned grid (defaults to (1, 1, 1))

    Returns:
        Tuple of warp arrays: (vertex positions as ``wp.vec3``, hexahedron vertex indices)
    """
    lo = wp.vec3(0.0) if bounds_lo is None else bounds_lo
    hi = wp.vec3(1.0) if bounds_hi is None else bounds_hi

    cells_x, cells_y, cells_z = res[0], res[1], res[2]

    # Node coordinates along each axis: one more node than cells
    xs = np.linspace(lo[0], hi[0], cells_x + 1)
    ys = np.linspace(lo[1], hi[1], cells_y + 1)
    zs = np.linspace(lo[2], hi[2], cells_z + 1)

    # Flatten the node grid with the x index varying slowest and z fastest
    node_grid = np.stack(np.meshgrid(xs, ys, zs, indexing="ij"), axis=-1)
    vertices = node_grid.reshape(-1, 3)

    hex_indices = fem.utils.grid_to_hexes(cells_x, cells_y, cells_z)

    return wp.array(vertices, dtype=wp.vec3), wp.array(hex_indices, dtype=int)
171
+
172
+
173
def gen_volume(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None, device=None) -> wp.Volume:
    """Constructs a wp.Volume from a dense 3D grid

    Every voxel of the ``res[0] x res[1] x res[2]`` grid is allocated.

    Args:
        res: Resolution of the grid along each dimension
        bounds_lo: Position of the lower bound of the axis-aligned grid. Any length-3 sequence
            (``wp.vec3``, tuple, list, NumPy array) is accepted. Defaults to the origin.
        bounds_hi: Position of the upper bound of the axis-aligned grid. Same accepted types
            as ``bounds_lo``. Defaults to (1, 1, 1).
        device: Device on which to allocate the volume

    Returns:
        The allocated ``wp.Volume``, with voxel size ``(bounds_hi - bounds_lo) / res`` (component-wise)
        and translation ``bounds_lo + 0.5 * voxel_size``.
    """

    # Normalize the bounds to wp.vec3 so that the vector arithmetic below also works when
    # they are given as plain tuples/lists (the other generators only index them element-wise).
    bounds_lo = wp.vec3(0.0) if bounds_lo is None else wp.vec3(bounds_lo)
    bounds_hi = wp.vec3(1.0) if bounds_hi is None else wp.vec3(bounds_hi)

    extents = bounds_hi - bounds_lo
    voxel_size = wp.cw_div(extents, wp.vec3(res))

    x = np.arange(res[0], dtype=int)
    y = np.arange(res[1], dtype=int)
    z = np.arange(res[2], dtype=int)

    # Integer coordinates of all voxels. The default "xy" meshgrid indexing only affects the
    # enumeration order; each row is still an (i, j, k) triple and every combination appears once.
    ijk = np.transpose(np.meshgrid(x, y, z), axes=(1, 2, 3, 0)).reshape(-1, 3)
    ijk = wp.array(ijk, dtype=wp.vec3i, device=device)
    return wp.Volume.allocate_by_voxels(
        ijk, voxel_size=voxel_size, translation=bounds_lo + 0.5 * voxel_size, device=device
    )
201
+
202
+
203
+ #
204
+ # Bsr matrix utilities
205
+ #
206
+
207
+
208
def _get_linear_solver_func(method_name: str):
    """Maps a solver name to the matching ``warp.optim.linear`` routine.

    Recognized names are ``"bicgstab"``, ``"gmres"`` and ``"cr"``; any other value
    (including ``"cg"``) selects the Conjugate Gradient solver ``cg``.
    """
    from warp.optim.linear import bicgstab, cg, cr, gmres

    named_solvers = (
        ("bicgstab", bicgstab),
        ("gmres", gmres),
        ("cr", cr),
    )
    for solver_name, solver in named_solvers:
        if method_name == solver_name:
            return solver

    # Fallback for "cg" and any unrecognized name
    return cg
218
+
219
+
220
def bsr_cg(
    A: BsrMatrix,
    x: wp.array,
    b: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every=10,
    use_diag_precond=True,
    mv_routine=None,
    quiet=False,
    method: str = "cg",
    M: Optional[BsrMatrix] = None,
) -> Tuple[float, int]:
    """Solves the linear system A x = b using an iterative solver, optionally with diagonal preconditioning

    Args:
        A: system left-hand side
        x: result vector and initial guess
        b: system right-hand-side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        use_diag_precond: Whether to use diagonal preconditioning. Only applies when neither ``M`` nor
            ``mv_routine`` is provided (the diagonal of a custom matrix-vector routine is unknown).
        mv_routine: Matrix-vector multiplication routine to use for multiplications with ``A``.
            It is used whenever provided, including when an explicit preconditioner ``M`` is given.
        quiet: if True, do not print iteration residuals
        method: Iterative solver method to use: ``"cg"`` (default), ``"bicgstab"``, ``"gmres"`` or ``"cr"``.
            Unrecognized names fall back to Conjugate Gradient.
        M: Optional explicit preconditioner, converted with ``aslinearoperator``. Takes precedence
            over ``use_diag_precond``.

    Returns:
        Tuple (residual norm, iteration count)

    """

    if mv_routine is not None:
        # Replace products with A by the user routine. Previously this was skipped when M was
        # also given, which silently ignored mv_routine.
        A = LinearOperator(A.shape, A.dtype, A.device, matvec=mv_routine)

    if M is not None:
        M = aslinearoperator(M)
    elif mv_routine is None and use_diag_precond:
        # A is still the original BsrMatrix here, so its diagonal is available
        M = preconditioner(A, "diag")
    else:
        M = None

    func = _get_linear_solver_func(method_name=method)

    def print_callback(i, err, tol):
        print(f"{func.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")

    callback = None if quiet else print_callback

    end_iter, err, atol = func(
        A=A,
        b=b,
        x=x,
        maxiter=max_iters,
        tol=tol,
        check_every=check_every,
        M=M,
        callback=callback,
        # CUDA graph capture is disabled when CUDA verification is enabled in the warp config
        use_cuda_graph=not wp.config.verify_cuda,
    )

    if not quiet:
        res_str = "OK" if err <= atol else "TRUNCATED"
        print(f"{func.__name__}: terminated after {end_iter} iterations with error = \t {err} ({res_str})")

    return err, end_iter
284
+
285
+
286
+ class SaddleSystem(LinearOperator):
287
+ """Builds a linear operator corresponding to the saddle-point linear system [A B^T; B 0]
288
+
289
+ If use_diag_precond` is ``True``, builds the corresponding diagonal preconditioner `[diag(A); diag(B diag(A)^-1 B^T)]`
290
+ """
291
+
292
+ def __init__(
293
+ self,
294
+ A: BsrMatrix,
295
+ B: BsrMatrix,
296
+ Bt: Optional[BsrMatrix] = None,
297
+ use_diag_precond: bool = True,
298
+ ):
299
+ if Bt is None:
300
+ Bt = bsr_transposed(B)
301
+
302
+ self._A = A
303
+ self._B = B
304
+ self._Bt = Bt
305
+
306
+ self._u_dtype = wp.vec(length=A.block_shape[0], dtype=A.scalar_type)
307
+ self._p_dtype = wp.vec(length=B.block_shape[0], dtype=B.scalar_type)
308
+ self._p_byte_offset = A.nrow * wp.types.type_size_in_bytes(self._u_dtype)
309
+
310
+ saddle_shape = (A.shape[0] + B.shape[0], A.shape[0] + B.shape[0])
311
+
312
+ super().__init__(saddle_shape, dtype=A.scalar_type, device=A.device, matvec=self._saddle_mv)
313
+
314
+ if use_diag_precond:
315
+ self._preconditioner = self._diag_preconditioner()
316
+ else:
317
+ self._preconditioner = None
318
+
319
+ def _diag_preconditioner(self):
320
+ A = self._A
321
+ B = self._B
322
+
323
+ M_u = preconditioner(A, "diag")
324
+
325
+ A_diag = bsr_get_diag(A)
326
+
327
+ schur_block_shape = (B.block_shape[0], B.block_shape[0])
328
+ schur_dtype = wp.mat(shape=schur_block_shape, dtype=B.scalar_type)
329
+ schur_inv_diag = wp.empty(dtype=schur_dtype, shape=B.nrow, device=self.device)
330
+ wp.launch(
331
+ _compute_schur_inverse_diagonal,
332
+ dim=B.nrow,
333
+ device=A.device,
334
+ inputs=[B.offsets, B.columns, B.values, A_diag, schur_inv_diag],
335
+ )
336
+
337
+ if schur_block_shape == (1, 1):
338
+ # Downcast 1x1 mats to scalars
339
+ schur_inv_diag = schur_inv_diag.view(dtype=B.scalar_type)
340
+
341
+ M_p = aslinearoperator(schur_inv_diag)
342
+
343
+ def precond_mv(x, y, z, alpha, beta):
344
+ x_u = self.u_slice(x)
345
+ x_p = self.p_slice(x)
346
+ y_u = self.u_slice(y)
347
+ y_p = self.p_slice(y)
348
+ z_u = self.u_slice(z)
349
+ z_p = self.p_slice(z)
350
+
351
+ M_u.matvec(x_u, y_u, z_u, alpha=alpha, beta=beta)
352
+ M_p.matvec(x_p, y_p, z_p, alpha=alpha, beta=beta)
353
+
354
+ return LinearOperator(
355
+ shape=self.shape,
356
+ dtype=self.dtype,
357
+ device=self.device,
358
+ matvec=precond_mv,
359
+ )
360
+
361
+ @property
362
+ def preconditioner(self):
363
+ return self._preconditioner
364
+
365
def u_slice(self, a: wp.array):
    """View the leading primal (u) segment of a concatenated saddle-system vector.

    The returned array has ``self._A.nrow`` entries of type ``self._u_dtype``. It aliases
    the start of ``a``'s memory (``copy=False``), so writes through it modify ``a``.
    It is built from a raw pointer, so presumably ``a`` must be kept alive while the view is in use.
    """
    # The primal segment starts at byte offset 0 of the concatenated vector
    start_ptr = a.ptr
    return wp.array(
        ptr=start_ptr,
        shape=self._A.nrow,
        dtype=self._u_dtype,
        strides=None,
        device=a.device,
        pinned=a.pinned,
        copy=False,
    )
375
+
376
def p_slice(self, a: wp.array):
    """View the trailing Lagrange-multiplier (p) segment of a concatenated saddle-system vector.

    The returned array has ``self._B.nrow`` entries of type ``self._p_dtype``. It starts
    ``self._p_byte_offset`` bytes into ``a``, right after the primal segment, and aliases
    ``a``'s memory (``copy=False``). Presumably ``a`` must be kept alive while the view is in use.
    """
    start_ptr = a.ptr + self._p_byte_offset
    return wp.array(
        ptr=start_ptr,
        shape=self._B.nrow,
        dtype=self._p_dtype,
        strides=None,
        device=a.device,
        pinned=a.pinned,
        copy=False,
    )
386
+
387
def _saddle_mv(self, x, y, z, alpha, beta):
    """Matrix-vector product of the saddle operator: ``z = alpha * [A B^T; B 0] x + beta * y``.

    ``x``, ``y`` and ``z`` are concatenated (u; p) vectors. ``y`` and ``z`` may be the same array.
    """
    # Seed z with y when y contributes and lives in a different buffer
    if beta != 0.0 and y.ptr != z.ptr:
        wp.copy(src=y, dest=z)

    x_u, x_p = self.u_slice(x), self.p_slice(x)
    z_u, z_p = self.u_slice(z), self.p_slice(z)

    # Primal rows: z_u = alpha * (A x_u + B^T x_p) + beta * z_u
    bsr_mv(self._A, x_u, z_u, alpha=alpha, beta=beta)
    bsr_mv(self._Bt, x_p, z_u, alpha=alpha, beta=1.0)

    # Constraint rows: z_p = alpha * B x_u + beta * z_p
    bsr_mv(self._B, x_u, z_p, alpha=alpha, beta=beta)
399
+
400
+
401
def bsr_solve_saddle(
    saddle_system: SaddleSystem,
    x_u: wp.array,
    x_p: wp.array,
    b_u: wp.array,
    b_p: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every: int = 10,
    quiet: bool = False,
    method: str = "cg",
) -> Tuple[float, int]:
    """Solves the saddle-point linear system [A B^T; B 0] (x_u; x_p) = (b_u; b_p) using an iterative solver,
    optionally with diagonal preconditioning (see ``saddle_system.preconditioner``)

    Args:
        saddle_system: Saddle point system
        x_u: primal part of the result vector and initial guess; overwritten with the solution
        x_p: Lagrange multiplier part of the result vector and initial guess; overwritten with the solution
        b_u: primal right-hand side
        b_p: constraint right-hand side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        quiet: if True, do not print iteration residuals
        method: name of the iterative solver method to use, resolved by ``_get_linear_solver_func``; defaults to ``"cg"``

    Returns:
        Tuple (residual norm, iteration count)

    """
    # The solvers work on single vectors: pack (u; p) into contiguous buffers
    x = wp.empty(dtype=saddle_system.scalar_type, shape=saddle_system.shape[0], device=saddle_system.device)
    b = wp.empty_like(x)

    wp.copy(src=x_u, dest=saddle_system.u_slice(x))
    wp.copy(src=x_p, dest=saddle_system.p_slice(x))
    wp.copy(src=b_u, dest=saddle_system.u_slice(b))
    wp.copy(src=b_p, dest=saddle_system.p_slice(b))

    func = _get_linear_solver_func(method_name=method)

    def print_callback(i, err, tol):
        print(f"{func.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")

    callback = None if quiet else print_callback

    end_iter, err, atol = func(
        A=saddle_system,
        b=b,
        x=x,
        maxiter=max_iters,
        tol=tol,
        check_every=check_every,
        M=saddle_system.preconditioner,
        callback=callback,
    )

    if not quiet:
        res_str = "OK" if err <= atol else "TRUNCATED"
        print(f"{func.__name__}: terminated after {end_iter} iterations with absolute error = \t {err} ({res_str})")

    # Unpack the solution back into the caller's arrays
    wp.copy(dest=x_u, src=saddle_system.u_slice(x))
    wp.copy(dest=x_p, src=saddle_system.p_slice(x))

    return err, end_iter
465
+
466
+
467
# Computes, for each block row of B, the inverse of the diagonal block of the approximate
# Schur complement S = B diag(A)^-1 B^T:
#   P_diag[row] = inverse( sum over stored blocks b of row: B_b * inverse(A_diag[col_b]) * B_b^T )
# B_offsets / B_indices / B_values are B's BSR row offsets, block column indices and block values;
# A_diag holds A's diagonal blocks, indexed by block column. Launched with one thread per row of B.
@wp.kernel(enable_backward=False)
def _compute_schur_inverse_diagonal(
    B_offsets: wp.array(dtype=int),
    B_indices: wp.array(dtype=int),
    B_values: wp.array(dtype=Any),
    A_diag: wp.array(dtype=Any),
    P_diag: wp.array(dtype=Any),
):
    row = wp.tid()

    # Zero value of the output block type (matrix of P_diag's scalar type)
    zero = P_diag.dtype(P_diag.dtype.dtype(0.0))

    schur = zero

    # Stored blocks of this row of B
    beg = B_offsets[row]
    end = B_offsets[row + 1]

    for b in range(beg, end):
        B = B_values[b]
        col = B_indices[b]
        Ai = wp.inverse(A_diag[col])
        # Contribution B_b * A_cc^-1 * B_b^T of this block to the Schur diagonal block
        S = B * Ai * wp.transpose(B)
        schur += S

    P_diag[row] = fem.utils.inverse_qr(schur)
492
+
493
+
494
def invert_diagonal_bsr_matrix(A: BsrMatrix):
    """Inverts each block of a block-diagonal mass matrix, in place.

    Launches one thread per row. The i-th stored block (``A.values[i]``) is treated as the
    diagonal block of row i, so ``A`` is assumed to hold exactly one block per row.
    """
    block_values = A.values
    is_scalar_valued = not wp.types.type_is_matrix(block_values.dtype)
    if is_scalar_valued:
        # Reinterpret scalar entries as 1x1 matrices so the kernel always works on matrix blocks
        block_values = block_values.view(dtype=wp.mat(shape=(1, 1), dtype=A.scalar_type))

    wp.launch(kernel=_block_diagonal_invert, dim=A.nrow, inputs=[block_values], device=block_values.device)
507
+
508
+
509
# Replaces each matrix block of `values` by its inverse (computed with fem.utils.inverse_qr), in place.
@wp.kernel(enable_backward=False)
def _block_diagonal_invert(values: wp.array(dtype=Any)):
    block_index = wp.tid()
    block = values[block_index]
    values[block_index] = fem.utils.inverse_qr(block)
513
+
514
+
515
+ #
516
+ # Plot utilities
517
+ #
518
+
519
+
520
class Plot:
    """Records snapshots of discrete fields and visualizes them.

    Each call to :meth:`add_field` stores a numpy copy of the field's degree-of-freedom values
    under a name. :meth:`plot` then displays every recorded field with pyvista or matplotlib,
    animating fields that have more than one snapshot. When a USD ``stage`` is given, fields
    are also written to it through ``warp.render.UsdRenderer``.
    """

    def __init__(self, stage=None, default_point_radius=0.01):
        self.default_point_radius = default_point_radius

        # name -> (field clone used for plotting, list of dof value snapshots as numpy arrays)
        self._fields = {}

        self._usd_renderer = None
        if stage is not None:
            try:
                from warp.render import UsdRenderer

                self._usd_renderer = UsdRenderer(stage)
            except Exception as err:
                # Best effort: USD output is optional, keep plotting without it
                print(f"Could not initialize UsdRenderer for stage '{stage}': {err}.")

    def begin_frame(self, time):
        """Start a USD frame at ``time`` (no-op without a USD stage)."""
        if self._usd_renderer is not None:
            self._usd_renderer.begin_frame(time=time)

    def end_frame(self):
        """Finish the current USD frame (no-op without a USD stage)."""
        if self._usd_renderer is not None:
            self._usd_renderer.end_frame()

    def add_field(self, name: str, field: fem.DiscreteField):
        """Record a snapshot of ``field`` under ``name``.

        The first call for a given name also creates a clone of the field (same space and
        partition) that is later used as the plotting target.
        """
        if self._usd_renderer is not None:
            self._render_to_usd(name, field)

        if name not in self._fields:
            field_clone = field.space.make_field(space_partition=field.space_partition)
            self._fields[name] = (field_clone, [])

        self._fields[name][1].append(field.dof_values.numpy())

    def _render_to_usd(self, name: str, field: fem.DiscreteField):
        """Write the current state of ``field`` to the USD stage as a mesh or point cloud."""
        points = field.space.node_positions().numpy()
        values = field.dof_values.numpy()

        if values.ndim == 2:
            if values.shape[1] == field.space.dimension:
                # use values as displacement (new array: do not mutate the array returned by .numpy())
                points = points + values
            else:
                # use magnitude
                values = np.linalg.norm(values, axis=1)

            if field.space.dimension == 2:
                # Lift 2D points to 3D, using the magnitude as height when available.
                # Both alternatives must be (n, 1) columns to be stacked with the (n, 2) points.
                z = values[:, np.newaxis] if values.ndim == 1 else np.zeros((points.shape[0], 1))
                points = np.hstack((points, z))

            if hasattr(field.space, "node_triangulation"):
                indices = field.space.node_triangulation()
                self._usd_renderer.render_mesh(name, points=points, indices=indices)
            else:
                self._usd_renderer.render_points(name, points=points, radius=self.default_point_radius)
        elif values.ndim == 1:
            # NOTE(review): for 2D spaces `points` stays (n, 2) here; confirm render_points accepts that
            self._usd_renderer.render_points(name, points, radius=values)
        else:
            self._usd_renderer.render_points(name, points, radius=self.default_point_radius)

    def plot(self, options: Dict[str, Any] = None, backend: str = "auto"):
        """Display all recorded fields.

        Args:
            options: Plotting options. Field names map to per-field option dicts, e.g. with keys
                ``"arrows"``, ``"contours"``, ``"streamlines"``, ``"displacement"``, ``"clim"``,
                ``"xlim"``/``"ylim"``. The global keys are ``"rows"`` (number of subplot rows) and,
                for pyvista only, ``"ref_geom"``.
            backend: ``"pyvista"`` or ``"matplotlib"``. Any other value, including the default
                ``"auto"``, tries pyvista first and falls back to matplotlib.
        """
        if options is None:
            options = {}

        if backend == "pyvista":
            return self._plot_pyvista(options)
        if backend == "matplotlib":
            return self._plot_matplotlib(options)

        # try both
        try:
            return self._plot_pyvista(options)
        except ModuleNotFoundError:
            try:
                return self._plot_matplotlib(options)
            except ModuleNotFoundError:
                wp.utils.warn("pyvista or matplotlib must be installed to visualize solution results")

    def _plot_pyvista(self, options: Dict[str, Any]):
        """Interactive pyvista visualization; animates fields with several snapshots until the window closes."""
        import pyvista
        import pyvista.themes

        grids = {}
        scales = {}
        markers = {}

        animate = False

        ref_geom = options.get("ref_geom", None)
        if ref_geom is not None:
            if isinstance(ref_geom, tuple):
                # (vertices, per-face vertex counts, flat vertex indices) -> VTK-style face array
                vertices, counts, indices = ref_geom
                offsets = np.cumsum(counts)
                ranges = np.array([offsets - counts, offsets]).T
                faces = np.concatenate(
                    [[count] + list(indices[beg:end]) for (count, (beg, end)) in zip(counts, ranges)]
                )
                ref_geom = pyvista.PolyData(vertices, faces)
            else:
                ref_geom = pyvista.PolyData(ref_geom)

        for name, (field, values) in self._fields.items():
            cells, types = field.space.vtk_cells()
            node_pos = field.space.node_positions().numpy()

            args = options.get(name, {})

            grid_scale = np.max(np.max(node_pos, axis=0) - np.min(node_pos, axis=0))
            value_range = self._get_field_value_range(values, args)
            scales[name] = (grid_scale, value_range)

            if node_pos.shape[1] == 2:
                # VTK grids are always 3D
                node_pos = np.hstack((node_pos, np.zeros((node_pos.shape[0], 1))))

            grid = pyvista.UnstructuredGrid(cells, types, node_pos)
            grids[name] = grid

            if len(values) > 1:
                animate = True

        def set_frame_data(frame):
            for name, (field, values) in self._fields.items():
                # Static fields only need to be set up on the first frame
                if frame > 0 and len(values) == 1:
                    continue

                v = values[frame % len(values)]
                grid = grids[name]
                grid_scale, value_range = scales[name]
                field_args = options.get(name, {})

                marker = None

                if field.space.dimension == 2 and v.ndim == 2 and v.shape[1] == 2:
                    grid.point_data[name] = np.hstack((v, np.zeros((v.shape[0], 1))))
                else:
                    grid.point_data[name] = v

                if v.ndim == 2:
                    grid.point_data[name + "_mag"] = np.linalg.norm(v, axis=1)

                if "arrows" in field_args:
                    glyph_scale = field_args["arrows"].get("glyph_scale", 1.0)
                    glyph_scale *= grid_scale / max(1.0e-8, value_range[1] - value_range[0])
                    marker = grid.glyph(scale=name, orient=name, factor=glyph_scale)
                elif "contours" in field_args:
                    levels = field_args["contours"].get("levels", 10)
                    if isinstance(levels, int):
                        levels = np.linspace(*value_range, levels)
                    marker = grid.contour(isosurfaces=levels, scalars=name + "_mag" if v.ndim == 2 else name)
                elif field.space.dimension == 2:
                    z_scale = grid_scale / max(1.0e-8, value_range[1] - value_range[0])

                    if "streamlines" in field_args:
                        center = np.mean(grid.points, axis=0)
                        density = field_args["streamlines"].get("density", 1.0)
                        cell_size = 1.0 / np.sqrt(field.space.geometry.cell_count())

                        separating_distance = 0.5 / (30.0 * density * cell_size)
                        # Try with increasing separating distances until we get at least one line.
                        # Always attempt once so that `lines` is defined even for tiny densities.
                        while True:
                            lines = grid.streamlines_evenly_spaced_2D(
                                vectors=name,
                                start_position=center,
                                separating_distance=separating_distance,
                                separating_distance_ratio=0.5,
                                step_length=0.25,
                                compute_vorticity=False,
                            )
                            if lines.n_lines > 0:
                                break
                            separating_distance *= 1.25
                            if separating_distance * cell_size >= 1.0:
                                break
                        marker = lines.tube(radius=0.0025 * grid_scale / density)
                    elif "displacement" in field_args:
                        grid.points[:, 0:2] = field.space.node_positions().numpy() + v
                    else:
                        # Extrude surface
                        z = v if v.ndim == 1 else grid.point_data[name + "_mag"]
                        grid.points[:, 2] = z * z_scale

                elif field.space.dimension == 3:
                    if "streamlines" in field_args:
                        density = field_args["streamlines"].get("density", 1.0)
                        lines = grid.streamlines(vectors=name, n_points=int(100 * density))
                        marker = lines.tube(radius=0.0025 * grid_scale / np.sqrt(density))
                    elif "displacement" in field_args:
                        grid.points = field.space.node_positions().numpy() + v

                if frame == 0:
                    if v.ndim == 1:
                        grid.set_active_scalars(name)
                    else:
                        grid.set_active_vectors(name)
                        grid.set_active_scalars(name + "_mag")
                    markers[name] = marker
                elif marker and markers[name] is not None:
                    # Update in place the marker mesh that was added to the plotter on frame 0
                    markers[name].copy_from(marker)

        set_frame_data(0)

        subplot_rows = options.get("rows", 1)
        subplot_shape = (subplot_rows, (len(grids) + subplot_rows - 1) // subplot_rows)

        plotter = pyvista.Plotter(shape=subplot_shape, theme=pyvista.themes.DocumentProTheme())
        plotter.link_views()
        plotter.add_camera_orientation_widget()
        for index, (name, grid) in enumerate(grids.items()):
            plotter.subplot(index // subplot_shape[1], index % subplot_shape[1])
            _, value_range = scales[name]
            field = self._fields[name][0]
            marker = markers[name]
            if marker:
                if field.space.dimension == 2:
                    plotter.add_mesh(marker, show_scalar_bar=False)
                    plotter.add_mesh(grid, opacity=0.25, clim=value_range)
                    plotter.view_xy()
                else:
                    plotter.add_mesh(marker)
            elif field.space.geometry.cell_dimension == 3:
                plotter.add_mesh_clip_plane(grid, show_edges=True, clim=value_range, assign_to_axis="z")
            else:
                plotter.add_mesh(grid, show_edges=True, clim=value_range)

        if ref_geom is not None:
            plotter.add_mesh(ref_geom)

        plotter.show(interactive_update=animate)

        frame = 0
        while animate and not plotter.iren.interactor.GetDone():
            frame += 1
            set_frame_data(frame)
            plotter.update()

    def _plot_matplotlib(self, options: Dict[str, Any]):
        """Matplotlib visualization, one subplot per field; fields with several snapshots are animated."""
        import matplotlib.animation as animation
        import matplotlib.pyplot as plt
        from matplotlib import cm

        def make_animation(fig, ax, cax, values, draw_func):
            def animate(i):
                cs = draw_func(ax, values[i])

                cax.cla()
                fig.colorbar(cs, cax)

                return cs

            return animation.FuncAnimation(
                ax.figure,
                animate,
                interval=30,
                blit=False,
                frames=len(values),
            )

        def make_draw_func(field, args, plot_func, plot_opts):
            def draw_fn(axes, values):
                axes.clear()

                field.dof_values = values
                cs = plot_func(field, axes=axes, **plot_opts)

                if "xlim" in args:
                    axes.set_xlim(*args["xlim"])
                if "ylim" in args:
                    axes.set_ylim(*args["ylim"])

                return cs

            return draw_fn

        # Animation objects must stay referenced until plt.show() returns
        anims = []

        field_count = len(self._fields)
        subplot_rows = options.get("rows", 1)
        subplot_shape = (subplot_rows, (field_count + subplot_rows - 1) // subplot_rows)

        for index, (name, (field, values)) in enumerate(self._fields.items()):
            args = options.get(name, {})
            v = values[0]

            plot_fn = None
            plot_3d = False
            plot_opts = {"cmap": cm.viridis}

            plot_opts["clim"] = self._get_field_value_range(values, args)

            if field.space.dimension == 2:
                if "contours" in args:
                    plot_opts["levels"] = args["contours"].get("levels", None)
                    plot_fn = _plot_contours
                elif v.ndim == 2 and v.shape[1] == 2:
                    if "displacement" in args:
                        plot_fn = _plot_displaced_tri_mesh
                    elif "streamlines" in args:
                        plot_opts["density"] = args["streamlines"].get("density", 1.0)
                        plot_fn = _plot_streamlines
                    elif "arrows" in args:
                        plot_opts["glyph_scale"] = args["arrows"].get("glyph_scale", 1.0)
                        plot_fn = _plot_quivers

                if plot_fn is None:
                    plot_fn = _plot_surface
                    plot_3d = True

            elif field.space.dimension == 3:
                if "arrows" in args or "streamlines" in args:
                    plot_opts["glyph_scale"] = args.get("arrows", {}).get("glyph_scale", 1.0)
                    plot_fn = _plot_quivers_3d
                elif field.space.geometry.cell_dimension == 2:
                    plot_fn = _plot_surface
                else:
                    plot_fn = _plot_3d_scatter
                plot_3d = True

            subplot_kw = {"projection": "3d"} if plot_3d else {}
            axes = plt.subplot(*subplot_shape, index + 1, **subplot_kw)

            if not plot_3d:
                axes.set_aspect("equal")

            draw_fn = make_draw_func(field, args, plot_func=plot_fn, plot_opts=plot_opts)
            cs = draw_fn(axes, values[0])

            fig = plt.gcf()
            cax = fig.colorbar(cs).ax

            if len(values) > 1:
                anims.append(make_animation(fig, axes, cax, values, draw_func=draw_fn))

        plt.show()

    @staticmethod
    def _get_field_value_range(values, field_options: Dict[str, Any]):
        """Return the (min, max) color range of a field.

        This is the ``"clim"`` option if given; otherwise it is the extrema of the values
        (or of the vector magnitudes) over all recorded snapshots.
        """
        value_range = field_options.get("clim", None)
        if value_range is None:
            value_range = (
                min(np.min(_value_or_magnitude(v)) for v in values),
                max(np.max(_value_or_magnitude(v)) for v in values),
            )

        return value_range
867
+
868
+
869
+ def _value_or_magnitude(values: np.ndarray):
870
+ if values.ndim == 1:
871
+ return values
872
+ return np.linalg.norm(values, axis=-1)
873
+
874
+
875
def _field_triangulation(field):
    """Build a matplotlib ``Triangulation`` from the x/y node positions and node triangulation of the field's space."""
    from matplotlib.tri import Triangulation

    xy = field.space.node_positions().numpy()
    tris = field.space.node_triangulation()
    return Triangulation(x=xy[:, 0], y=xy[:, 1], triangles=tris)
880
+
881
+
882
def _plot_surface(field, axes, **kwargs):
    """Plot a field as a surface on 3D ``axes``.

    For 2D spaces, the values (or vector magnitudes) are used as height. For 3D spaces,
    the node positions define the surface and the values define the colors. Structured
    spaces (with ``node_grid``) use ``plot_surface``, triangulated spaces use ``plot_trisurf``,
    and anything else falls back to a scatter plot.

    ``kwargs`` must contain ``"clim"`` (a (min, max) pair) and ``"cmap"``; the remaining
    entries are forwarded to the matplotlib plotting call.
    """
    from matplotlib.colors import Normalize

    C = _value_or_magnitude(field.dof_values.numpy())

    positions = field.space.node_positions().numpy().T
    if field.space.dimension == 3:
        X, Y, Z = positions
    else:
        X, Y = positions
        Z = C
        axes.set_zlim(kwargs["clim"])

    if hasattr(field.space, "node_grid"):
        X, Y = field.space.node_grid()
        C = C.reshape(X.shape)
        return axes.plot_surface(X, Y, C, linewidth=0.1, antialiased=False, **kwargs)

    if hasattr(field.space, "node_triangulation"):
        triangulation = _field_triangulation(field)

        if field.space.dimension == 3:
            plot = axes.plot_trisurf(triangulation, Z, linewidth=0.1, antialiased=False)
            # plot_trisurf colors faces by height; recompute face colors from the field values instead
            vmin, vmax = kwargs["clim"]
            norm = Normalize(vmin=vmin, vmax=vmax)
            values = np.mean(C[triangulation.triangles], axis=1)

            cmap = kwargs["cmap"]
            if not callable(cmap):
                # Colormap given by name (or None for the default). matplotlib.cm.get_cmap, used
                # previously for this lookup, was removed in Matplotlib 3.9.
                import matplotlib

                cmap = matplotlib.colormaps[cmap if cmap is not None else matplotlib.rcParams["image.cmap"]]
            colors = cmap(norm(values))

            plot.set_norm(norm)
            plot.set_fc(colors)
        else:
            plot = axes.plot_trisurf(triangulation, C, linewidth=0.1, antialiased=False, **kwargs)

        return plot

    # scatter
    return axes.scatter(X, Y, Z, c=C, **kwargs)
920
+
921
+
922
def _plot_displaced_tri_mesh(field, axes, **kwargs):
    """Draw a 2D vector field as displacement: a pseudocolor plot of its magnitude over the displaced mesh.

    A thin wireframe of the displaced triangles is overlaid. Returns the ``tripcolor`` artist.
    """
    tri = _field_triangulation(field)
    disp = field.dof_values.numpy()

    # Move every node by its displacement vector
    tri.x += disp[:, 0]
    tri.y += disp[:, 1]

    colored = axes.tripcolor(tri, _value_or_magnitude(disp), **kwargs)
    axes.triplot(tri, lw=0.1)
    return colored
936
+
937
+
938
def _plot_quivers(field, axes, clim=None, glyph_scale=1.0, **kwargs):
    """Draw a 2D vector field as arrows at the space's nodes, colored by magnitude.

    ``clim`` is accepted for interface uniformity but is not used. A larger ``glyph_scale``
    gives longer arrows (quiver ``scale = 1 / glyph_scale``).
    """
    X, Y = field.space.node_positions().numpy().T
    vel = field.dof_values.numpy()

    target_shape = X.shape
    u, v = (vel[:, component].reshape(target_shape) for component in range(2))

    magnitude = _value_or_magnitude(vel)
    return axes.quiver(X, Y, u, v, magnitude, scale=1.0 / glyph_scale, **kwargs)
946
+
947
+
948
def _plot_quivers_3d(field, axes, clim=None, cmap=None, glyph_scale=1.0, **kwargs):
    """Draw a 3D vector field as arrows at the space's nodes.

    Arrows are colored by magnitude through ``cmap`` and divided by the width of the value
    range, so that ``glyph_scale`` (the quiver ``length``) is relative to that range.

    Args:
        field: field whose values are (n, 3) vectors on a 3D space
        axes: 3D matplotlib axes
        clim: (min, max) magnitude range for coloring and scaling; computed from the
            current values when ``None``
        cmap: matplotlib colormap (callable); matplotlib's default colormap when ``None``
        glyph_scale: arrow length multiplier
        **kwargs: forwarded to ``axes.quiver``
    """
    X, Y, Z = field.space.node_positions().numpy().T

    vel = field.dof_values.numpy()
    magnitude = _value_or_magnitude(vel)

    if clim is None:
        clim = (np.min(magnitude), np.max(magnitude))
    if cmap is None:
        import matplotlib

        cmap = matplotlib.colormaps[matplotlib.rcParams["image.cmap"]]

    # Guard against a degenerate range (e.g. constant or all-zero field), which would
    # otherwise divide by zero and produce NaN arrows and colors
    value_span = max(1.0e-8, clim[1] - clim[0])

    colors = cmap((magnitude - clim[0]) / value_span)

    u = vel[:, 0].reshape(X.shape) / value_span
    v = vel[:, 1].reshape(X.shape) / value_span
    w = vel[:, 2].reshape(X.shape) / value_span

    return axes.quiver(X, Y, Z, u, v, w, colors=colors, length=glyph_scale, clim=clim, cmap=cmap, **kwargs)
960
+
961
+
962
def _plot_streamlines(field, axes, clim=None, **kwargs):
    """Draw streamlines of a 2D vector field, colored by speed.

    The nodal velocities are cubic-interpolated onto a regular 100x100 grid that spans the
    mesh bounding box, then passed to ``streamplot``. ``clim`` is accepted but not used.
    Returns the streamline ``LineCollection``.
    """
    import matplotlib.tri as tr

    tri = _field_triangulation(field)
    vel = field.dof_values.numpy()

    interp_x, interp_y = (tr.CubicTriInterpolator(tri, vel[:, component]) for component in (0, 1))

    sample_x = np.linspace(np.min(tri.x), np.max(tri.x), 100)
    sample_y = np.linspace(np.min(tri.y), np.max(tri.y), 100)
    X, Y = np.meshgrid(sample_x, sample_y)

    u, v = interp_x(X, Y), interp_y(X, Y)
    speed = np.sqrt(u * u + v * v)

    stream = axes.streamplot(X, Y, u, v, color=speed, **kwargs)
    return stream.lines
983
+
984
+
985
def _plot_contours(field, axes, clim=None, **kwargs):
    """Draw filled contours of a 2D field (values or vector magnitudes) with iso-lines on top.

    ``clim`` is accepted but not used. Returns the filled contour set.
    """
    tri = _field_triangulation(field)
    scalars = _value_or_magnitude(field.dof_values.numpy())

    filled = axes.tricontourf(tri, scalars, **kwargs)
    axes.tricontour(tri, scalars, **kwargs)
    return filled
993
+
994
+
995
def _plot_3d_scatter(field, axes, **kwargs):
    """Scatter-plot the nodes of a 3D space, colored by the field values (or vector magnitudes)."""
    positions = field.space.node_positions().numpy()
    X, Y, Z = positions.T

    node_colors = _value_or_magnitude(field.dof_values.numpy()).reshape(X.shape)
    return axes.scatter(X, Y, Z, c=node_colors, **kwargs)