warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1000 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import Any, Dict, Optional, Tuple
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
import warp as wp
|
|
21
|
+
import warp.fem as fem
|
|
22
|
+
from warp.optim.linear import LinearOperator, aslinearoperator, preconditioner
|
|
23
|
+
from warp.sparse import BsrMatrix, bsr_get_diag, bsr_mv, bsr_transposed
|
|
24
|
+
|
|
25
|
+
__all__ = [
    "gen_hexmesh",
    "gen_quadmesh",
    "gen_tetmesh",
    "gen_trimesh",
    "gen_volume",
    "bsr_cg",
    "bsr_solve_saddle",
    "SaddleSystem",
    "invert_diagonal_bsr_matrix",
    "Plot",
]

# matrix inversion routines contain nested loops,
# default unrolling leads to code explosion
wp.set_module_options({"max_unroll": 6})
|
|
40
|
+
|
|
41
|
+
#
|
|
42
|
+
# Mesh utilities
|
|
43
|
+
#
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def gen_trimesh(res, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[wp.vec2] = None):
    """Constructs a triangular mesh by dividing each cell of a dense 2D grid into two triangles

    Vertices are placed on the ``(res[0] + 1) x (res[1] + 1)`` grid nodes; node ``(i, j)`` is stored
    at index ``i * (res[1] + 1) + j``.

    Args:
        res: Resolution of the grid along each dimension (number of cells along x and y)
        bounds_lo: Position of the lower bound of the axis-aligned grid. Defaults to ``(0, 0)``.
        bounds_hi: Position of the upper bound of the axis-aligned grid. Defaults to ``(1, 1)``.

    Returns:
        Tuple of warp arrays: (Vertex positions with dtype ``wp.vec2``, Triangle vertex indices with dtype ``int``)
    """

    if bounds_lo is None:
        bounds_lo = wp.vec2(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec2(1.0)

    Nx = res[0]
    Ny = res[1]

    x = np.linspace(bounds_lo[0], bounds_hi[0], Nx + 1)
    y = np.linspace(bounds_lo[1], bounds_hi[1], Ny + 1)

    # "ij" indexing + moving the coordinate axis last yields row-major (i, j) node ordering
    positions = np.transpose(np.meshgrid(x, y, indexing="ij"), axes=(1, 2, 0)).reshape(-1, 2)

    vidx = fem.utils.grid_to_tris(Nx, Ny)

    return wp.array(positions, dtype=wp.vec2), wp.array(vidx, dtype=int)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def gen_tetmesh(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None):
    """Constructs a tetrahedral mesh by dividing each cell of a dense 3D grid into five tetrahedra

    Vertices are placed on the ``(res[0] + 1) x (res[1] + 1) x (res[2] + 1)`` grid nodes; node ``(i, j, k)``
    is stored at index ``(i * (res[1] + 1) + j) * (res[2] + 1) + k``.

    Args:
        res: Resolution of the grid along each dimension (number of cells along x, y and z)
        bounds_lo: Position of the lower bound of the axis-aligned grid. Defaults to ``(0, 0, 0)``.
        bounds_hi: Position of the upper bound of the axis-aligned grid. Defaults to ``(1, 1, 1)``.

    Returns:
        Tuple of warp arrays: (Vertex positions with dtype ``wp.vec3``, Tetrahedron vertex indices with dtype ``int``)
    """

    if bounds_lo is None:
        bounds_lo = wp.vec3(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec3(1.0)

    Nx = res[0]
    Ny = res[1]
    Nz = res[2]

    x = np.linspace(bounds_lo[0], bounds_hi[0], Nx + 1)
    y = np.linspace(bounds_lo[1], bounds_hi[1], Ny + 1)
    z = np.linspace(bounds_lo[2], bounds_hi[2], Nz + 1)

    # "ij" indexing + moving the coordinate axis last yields row-major (i, j, k) node ordering
    positions = np.transpose(np.meshgrid(x, y, z, indexing="ij"), axes=(1, 2, 3, 0)).reshape(-1, 3)

    vidx = fem.utils.grid_to_tets(Nx, Ny, Nz)

    return wp.array(positions, dtype=wp.vec3), wp.array(vidx, dtype=int)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def gen_quadmesh(res, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[wp.vec2] = None):
    """Constructs a quadrilateral mesh from a dense 2D grid

    Vertices are placed on the ``(res[0] + 1) x (res[1] + 1)`` grid nodes; node ``(i, j)`` is stored
    at index ``i * (res[1] + 1) + j``.

    Args:
        res: Resolution of the grid along each dimension (number of cells along x and y)
        bounds_lo: Position of the lower bound of the axis-aligned grid. Defaults to ``(0, 0)``.
        bounds_hi: Position of the upper bound of the axis-aligned grid. Defaults to ``(1, 1)``.

    Returns:
        Tuple of warp arrays: (Vertex positions with dtype ``wp.vec2``, Quadrilateral vertex indices with dtype ``int``)
    """

    if bounds_lo is None:
        bounds_lo = wp.vec2(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec2(1.0)

    Nx = res[0]
    Ny = res[1]

    x = np.linspace(bounds_lo[0], bounds_hi[0], Nx + 1)
    y = np.linspace(bounds_lo[1], bounds_hi[1], Ny + 1)

    # "ij" indexing + moving the coordinate axis last yields row-major (i, j) node ordering
    positions = np.transpose(np.meshgrid(x, y, indexing="ij"), axes=(1, 2, 0)).reshape(-1, 2)

    vidx = fem.utils.grid_to_quads(Nx, Ny)

    return wp.array(positions, dtype=wp.vec2), wp.array(vidx, dtype=int)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def gen_hexmesh(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None):
    """Constructs a hexahedral mesh from a dense 3D grid

    Vertices are placed on the ``(res[0] + 1) x (res[1] + 1) x (res[2] + 1)`` grid nodes; node ``(i, j, k)``
    is stored at index ``(i * (res[1] + 1) + j) * (res[2] + 1) + k``.

    Args:
        res: Resolution of the grid along each dimension (number of cells along x, y and z)
        bounds_lo: Position of the lower bound of the axis-aligned grid. Defaults to ``(0, 0, 0)``.
        bounds_hi: Position of the upper bound of the axis-aligned grid. Defaults to ``(1, 1, 1)``.

    Returns:
        Tuple of warp arrays: (Vertex positions with dtype ``wp.vec3``, Hexahedron vertex indices with dtype ``int``)
    """

    if bounds_lo is None:
        bounds_lo = wp.vec3(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec3(1.0)

    Nx = res[0]
    Ny = res[1]
    Nz = res[2]

    x = np.linspace(bounds_lo[0], bounds_hi[0], Nx + 1)
    y = np.linspace(bounds_lo[1], bounds_hi[1], Ny + 1)
    z = np.linspace(bounds_lo[2], bounds_hi[2], Nz + 1)

    # "ij" indexing + moving the coordinate axis last yields row-major (i, j, k) node ordering
    positions = np.transpose(np.meshgrid(x, y, z, indexing="ij"), axes=(1, 2, 3, 0)).reshape(-1, 3)

    vidx = fem.utils.grid_to_hexes(Nx, Ny, Nz)

    return wp.array(positions, dtype=wp.vec3), wp.array(vidx, dtype=int)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def gen_volume(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None, device=None) -> wp.Volume:
    """Constructs a wp.Volume from a dense 3D grid

    One voxel is allocated for every cell of the ``res[0] x res[1] x res[2]`` grid, with a voxel size of
    ``(bounds_hi - bounds_lo) / res`` along each axis.

    Args:
        res: Resolution of the grid along each dimension (number of voxels along x, y and z)
        bounds_lo: Position of the lower bound of the axis-aligned grid. Defaults to ``(0, 0, 0)``.
        bounds_hi: Position of the upper bound of the axis-aligned grid. Defaults to ``(1, 1, 1)``.
        device: CUDA device on which to allocate the grid

    Returns:
        The allocated ``wp.Volume``. Its translation is ``bounds_lo`` offset by half a voxel, which presumably
        places voxel centers at the grid cell centers (assuming world = translation + voxel_size * ijk).
    """

    if bounds_lo is None:
        bounds_lo = wp.vec3(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec3(1.0)

    extents = bounds_hi - bounds_lo
    voxel_size = wp.cw_div(extents, wp.vec3(res))

    x = np.arange(res[0], dtype=int)
    y = np.arange(res[1], dtype=int)
    z = np.arange(res[2], dtype=int)

    # Enumerate every integer voxel coordinate (i, j, k). The default "xy" meshgrid indexing only
    # affects the enumeration order of the coordinates, not which voxels are listed.
    ijk = np.transpose(np.meshgrid(x, y, z), axes=(1, 2, 3, 0)).reshape(-1, 3)
    ijk = wp.array(ijk, dtype=wp.vec3i, device=device)
    return wp.Volume.allocate_by_voxels(
        ijk, voxel_size=voxel_size, translation=bounds_lo + 0.5 * voxel_size, device=device
    )
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
#
|
|
204
|
+
# Bsr matrix utilities
|
|
205
|
+
#
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _get_linear_solver_func(method_name: str):
    """Returns the ``warp.optim.linear`` solver function selected by ``method_name``.

    Recognized names are ``"bicgstab"``, ``"gmres"`` and ``"cr"``; any other value selects ``cg``.
    """
    # Deferred import: resolved when a solver is requested rather than at module load
    from warp.optim.linear import bicgstab, cg, cr, gmres

    named_solvers = (
        ("bicgstab", bicgstab),
        ("gmres", gmres),
        ("cr", cr),
    )
    return next((solver for name, solver in named_solvers if method_name == name), cg)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def bsr_cg(
    A: BsrMatrix,
    x: wp.array,
    b: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every=10,
    use_diag_precond=True,
    mv_routine=None,
    quiet=False,
    method: str = "cg",
    M: Optional[BsrMatrix] = None,
) -> Tuple[float, int]:
    """Solves the linear system A x = b using an iterative solver, optionally with diagonal preconditioning

    Args:
        A: system left-hand side
        x: result vector and initial guess
        b: system right-hand-side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        use_diag_precond: Whether to use diagonal preconditioning. Ignored when ``M`` or ``mv_routine`` is provided.
        mv_routine: Matrix-vector multiplication routine to use for multiplications with ``A``.
            When provided (and ``M`` is not), no preconditioner is used.
        quiet: if True, do not print iteration residuals
        method: Iterative solver method to use: ``"cg"`` (default, Conjugate Gradient), ``"cr"``, ``"bicgstab"``
            or ``"gmres"``. Unrecognized names fall back to Conjugate Gradient.
        M: Optional explicit preconditioner, converted with ``aslinearoperator``. Takes precedence over
            ``use_diag_precond``.

    Returns:
        Tuple (residual norm, iteration count)

    """

    use_custom_mv = mv_routine is not None

    if M is not None:
        M = aslinearoperator(M)
    elif use_diag_precond and not use_custom_mv:
        # Built from the BsrMatrix itself, before A may get wrapped into a LinearOperator below
        M = preconditioner(A, "diag")

    # Wrap A independently of the preconditioner choice, so that providing both ``M`` and
    # ``mv_routine`` no longer silently discards ``mv_routine``
    if use_custom_mv:
        A = LinearOperator(A.shape, A.dtype, A.device, matvec=mv_routine)

    func = _get_linear_solver_func(method_name=method)

    def print_callback(i, err, tol):
        print(f"{func.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")

    callback = None if quiet else print_callback

    end_iter, err, atol = func(
        A=A,
        b=b,
        x=x,
        maxiter=max_iters,
        tol=tol,
        check_every=check_every,
        M=M,
        callback=callback,
        # CUDA graph capture is disabled when CUDA verification is enabled
        use_cuda_graph=not wp.config.verify_cuda,
    )

    if not quiet:
        res_str = "OK" if err <= atol else "TRUNCATED"
        print(f"{func.__name__}: terminated after {end_iter} iterations with error = \t {err} ({res_str})")

    return err, end_iter
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class SaddleSystem(LinearOperator):
|
|
287
|
+
"""Builds a linear operator corresponding to the saddle-point linear system [A B^T; B 0]
|
|
288
|
+
|
|
289
|
+
If use_diag_precond` is ``True``, builds the corresponding diagonal preconditioner `[diag(A); diag(B diag(A)^-1 B^T)]`
|
|
290
|
+
"""
|
|
291
|
+
|
|
292
|
+
def __init__(
|
|
293
|
+
self,
|
|
294
|
+
A: BsrMatrix,
|
|
295
|
+
B: BsrMatrix,
|
|
296
|
+
Bt: Optional[BsrMatrix] = None,
|
|
297
|
+
use_diag_precond: bool = True,
|
|
298
|
+
):
|
|
299
|
+
if Bt is None:
|
|
300
|
+
Bt = bsr_transposed(B)
|
|
301
|
+
|
|
302
|
+
self._A = A
|
|
303
|
+
self._B = B
|
|
304
|
+
self._Bt = Bt
|
|
305
|
+
|
|
306
|
+
self._u_dtype = wp.vec(length=A.block_shape[0], dtype=A.scalar_type)
|
|
307
|
+
self._p_dtype = wp.vec(length=B.block_shape[0], dtype=B.scalar_type)
|
|
308
|
+
self._p_byte_offset = A.nrow * wp.types.type_size_in_bytes(self._u_dtype)
|
|
309
|
+
|
|
310
|
+
saddle_shape = (A.shape[0] + B.shape[0], A.shape[0] + B.shape[0])
|
|
311
|
+
|
|
312
|
+
super().__init__(saddle_shape, dtype=A.scalar_type, device=A.device, matvec=self._saddle_mv)
|
|
313
|
+
|
|
314
|
+
if use_diag_precond:
|
|
315
|
+
self._preconditioner = self._diag_preconditioner()
|
|
316
|
+
else:
|
|
317
|
+
self._preconditioner = None
|
|
318
|
+
|
|
319
|
+
def _diag_preconditioner(self):
|
|
320
|
+
A = self._A
|
|
321
|
+
B = self._B
|
|
322
|
+
|
|
323
|
+
M_u = preconditioner(A, "diag")
|
|
324
|
+
|
|
325
|
+
A_diag = bsr_get_diag(A)
|
|
326
|
+
|
|
327
|
+
schur_block_shape = (B.block_shape[0], B.block_shape[0])
|
|
328
|
+
schur_dtype = wp.mat(shape=schur_block_shape, dtype=B.scalar_type)
|
|
329
|
+
schur_inv_diag = wp.empty(dtype=schur_dtype, shape=B.nrow, device=self.device)
|
|
330
|
+
wp.launch(
|
|
331
|
+
_compute_schur_inverse_diagonal,
|
|
332
|
+
dim=B.nrow,
|
|
333
|
+
device=A.device,
|
|
334
|
+
inputs=[B.offsets, B.columns, B.values, A_diag, schur_inv_diag],
|
|
335
|
+
)
|
|
336
|
+
|
|
337
|
+
if schur_block_shape == (1, 1):
|
|
338
|
+
# Downcast 1x1 mats to scalars
|
|
339
|
+
schur_inv_diag = schur_inv_diag.view(dtype=B.scalar_type)
|
|
340
|
+
|
|
341
|
+
M_p = aslinearoperator(schur_inv_diag)
|
|
342
|
+
|
|
343
|
+
def precond_mv(x, y, z, alpha, beta):
|
|
344
|
+
x_u = self.u_slice(x)
|
|
345
|
+
x_p = self.p_slice(x)
|
|
346
|
+
y_u = self.u_slice(y)
|
|
347
|
+
y_p = self.p_slice(y)
|
|
348
|
+
z_u = self.u_slice(z)
|
|
349
|
+
z_p = self.p_slice(z)
|
|
350
|
+
|
|
351
|
+
M_u.matvec(x_u, y_u, z_u, alpha=alpha, beta=beta)
|
|
352
|
+
M_p.matvec(x_p, y_p, z_p, alpha=alpha, beta=beta)
|
|
353
|
+
|
|
354
|
+
return LinearOperator(
|
|
355
|
+
shape=self.shape,
|
|
356
|
+
dtype=self.dtype,
|
|
357
|
+
device=self.device,
|
|
358
|
+
matvec=precond_mv,
|
|
359
|
+
)
|
|
360
|
+
|
|
361
|
+
@property
def preconditioner(self):
    """Preconditioner operator built at construction time, or ``None`` when disabled."""
    precond = self._preconditioner
    return precond
|
|
364
|
+
|
|
365
|
+
def u_slice(self, a: wp.array):
    """Return a view on the primal (``u``) part of a saddle-system vector.

    The view starts at ``a.ptr``, has ``A.nrow`` entries of the primal block dtype,
    and is built with ``copy=False`` on the same device/pinning as ``a``.
    """
    view_options = {"strides": None, "device": a.device, "pinned": a.pinned, "copy": False}
    return wp.array(ptr=a.ptr, dtype=self._u_dtype, shape=self._A.nrow, **view_options)
|
|
375
|
+
|
|
376
|
+
def p_slice(self, a: wp.array):
    """Return a view on the constraint (``p``) part of a saddle-system vector.

    The view starts ``self._p_byte_offset`` bytes past ``a.ptr`` (just after the primal
    part), has ``B.nrow`` entries of the constraint block dtype, and is built with
    ``copy=False`` on the same device/pinning as ``a``.
    """
    start = a.ptr + self._p_byte_offset
    view_options = {"strides": None, "device": a.device, "pinned": a.pinned, "copy": False}
    return wp.array(ptr=start, dtype=self._p_dtype, shape=self._B.nrow, **view_options)
|
|
386
|
+
|
|
387
|
+
def _saddle_mv(self, x, y, z, alpha, beta):
    """Compute ``z = alpha * [A B^T; B 0] x + beta * y`` block by block.

    ``y`` and ``z`` may alias; when they do not and ``beta`` is non-zero, ``y`` is first
    copied into ``z`` so that the ``beta``-scaled accumulation below reads ``y``'s values.
    """
    x_primal, x_dual = self.u_slice(x), self.p_slice(x)
    z_primal, z_dual = self.u_slice(z), self.p_slice(z)

    if y.ptr != z.ptr and beta != 0.0:
        wp.copy(src=y, dest=z)

    # primal rows: alpha * A x_u + beta * y_u, then add alpha * B^T x_p (beta=1 keeps the partial sum)
    bsr_mv(self._A, x_primal, z_primal, alpha=alpha, beta=beta)
    bsr_mv(self._Bt, x_dual, z_primal, alpha=alpha, beta=1.0)
    # constraint rows: alpha * B x_u + beta * y_p (the lower-right block is zero)
    bsr_mv(self._B, x_primal, z_dual, alpha=alpha, beta=beta)
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def bsr_solve_saddle(
    saddle_system: SaddleSystem,
    x_u: wp.array,
    x_p: wp.array,
    b_u: wp.array,
    b_p: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every=10,
    quiet=False,
    method: str = "cg",
) -> Tuple[float, int]:
    """Solves the saddle-point linear system [A B^T; B 0] (x_u; x_p) = (b_u; b_p) using an iterative solver, optionally with diagonal preconditioning

    The primal and constraint parts are packed into contiguous work vectors, solved with the
    requested iterative method (using ``saddle_system.preconditioner`` if one was built), and
    the solution is copied back into ``x_u`` and ``x_p``.

    Args:
        saddle_system: Saddle point system
        x_u: primal part of the result vector and initial guess (overwritten with the solution)
        x_p: Lagrange multiplier part of the result vector and initial guess (overwritten with the solution)
        b_u: primal right-hand side
        b_p: constraint right-hand side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        quiet: if True, do not print iteration residuals
        method: Iterative solver method to use, defaults to ``"cg"`` (conjugate gradient)

    Returns:
        Tuple (residual norm, iteration count), where the residual norm is the absolute error reported by the solver

    """
    # Pack (x_u; x_p) and (b_u; b_p) into contiguous vectors matching the saddle operator layout
    x = wp.empty(dtype=saddle_system.scalar_type, shape=saddle_system.shape[0], device=saddle_system.device)
    b = wp.empty_like(x)

    wp.copy(src=x_u, dest=saddle_system.u_slice(x))
    wp.copy(src=x_p, dest=saddle_system.p_slice(x))
    wp.copy(src=b_u, dest=saddle_system.u_slice(b))
    wp.copy(src=b_p, dest=saddle_system.p_slice(b))

    func = _get_linear_solver_func(method_name=method)

    def print_callback(i, err, tol):
        print(f"{func.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")

    callback = None if quiet else print_callback

    end_iter, err, atol = func(
        A=saddle_system,
        b=b,
        x=x,
        maxiter=max_iters,
        tol=tol,
        check_every=check_every,
        M=saddle_system.preconditioner,
        callback=callback,
    )

    if not quiet:
        res_str = "OK" if err <= atol else "TRUNCATED"
        print(f"{func.__name__}: terminated after {end_iter} iterations with absolute error = \t {err} ({res_str})")

    # Unpack the solution back into the caller's arrays
    wp.copy(dest=x_u, src=saddle_system.u_slice(x))
    wp.copy(dest=x_p, src=saddle_system.p_slice(x))

    return err, end_iter
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
@wp.kernel(enable_backward=False)
def _compute_schur_inverse_diagonal(
    B_offsets: wp.array(dtype=int),  # row offsets of B: row `r` owns blocks [B_offsets[r], B_offsets[r + 1])
    B_indices: wp.array(dtype=int),  # block column index of each stored block of B
    B_values: wp.array(dtype=Any),  # stored blocks of B
    A_diag: wp.array(dtype=Any),  # diagonal blocks of A, indexed by B's block column
    P_diag: wp.array(dtype=Any),  # output: one inverted Schur-complement block per row of B
):
    # One thread per block row of B. Accumulates the approximate Schur-complement block
    #   S_row = sum_k B_k * inverse(A_diag[col_k]) * B_k^T
    # over the stored blocks of that row, then writes its inverse (via fem.utils.inverse_qr)
    # into P_diag[row].
    row = wp.tid()

    # Zero block of P_diag's matrix type (P_diag.dtype.dtype is its scalar type)
    zero = P_diag.dtype(P_diag.dtype.dtype(0.0))

    schur = zero

    beg = B_offsets[row]
    end = B_offsets[row + 1]

    for b in range(beg, end):
        B = B_values[b]
        col = B_indices[b]
        Ai = wp.inverse(A_diag[col])
        S = B * Ai * wp.transpose(B)
        schur += S

    # NOTE(review): a row of B with no stored blocks yields a zero Schur block here;
    # the result of inverse_qr on it is not checked — confirm this cannot happen for callers.
    P_diag[row] = fem.utils.inverse_qr(schur)
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def invert_diagonal_bsr_matrix(A: BsrMatrix):
    """Invert, in place, each block of a block-diagonal BSR matrix (e.g. a mass matrix).

    One kernel thread is launched per row, each inverting ``A.values[row]``. This assumes
    exactly one stored (diagonal) block per row. Scalar-valued matrices are viewed as 1x1
    matrix blocks so that the same kernel applies.
    """
    stored = A.values
    if wp.types.type_is_matrix(stored.dtype):
        blocks = stored
    else:
        blocks = stored.view(dtype=wp.mat(shape=(1, 1), dtype=A.scalar_type))

    wp.launch(kernel=_block_diagonal_invert, dim=A.nrow, inputs=[blocks], device=blocks.device)
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
@wp.kernel(enable_backward=False)
def _block_diagonal_invert(values: wp.array(dtype=Any)):
    # One thread per entry: replace values[i] with its inverse, in place, using fem.utils.inverse_qr.
    # Expects matrix-typed entries (scalar arrays must be viewed as 1x1 matrices before launch).
    i = wp.tid()
    values[i] = fem.utils.inverse_qr(values[i])
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
#
|
|
516
|
+
# Plot utilities
|
|
517
|
+
#
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
class Plot:
    """Records snapshots of discrete fields and visualizes them.

    Fields are recorded with :meth:`add_field`. Repeated calls with the same name append
    snapshots, which are animated. They are displayed with :meth:`plot` through pyvista or
    matplotlib. If a USD ``stage`` is given, each added field is also rendered to it with
    ``warp.render.UsdRenderer``.
    """

    def __init__(self, stage=None, default_point_radius=0.01):
        self.default_point_radius = default_point_radius

        # name -> (field clone used for plotting, list of numpy dof-value snapshots)
        self._fields = {}

        self._usd_renderer = None
        if stage is not None:
            try:
                from warp.render import UsdRenderer

                self._usd_renderer = UsdRenderer(stage)
            except Exception as err:
                # USD output is optional: report and continue without it
                print(f"Could not initialize UsdRenderer for stage '{stage}': {err}.")

    def begin_frame(self, time):
        """Start a USD frame at ``time`` (no-op without a USD renderer)."""
        if self._usd_renderer is not None:
            self._usd_renderer.begin_frame(time=time)

    def end_frame(self):
        """End the current USD frame (no-op without a USD renderer)."""
        if self._usd_renderer is not None:
            self._usd_renderer.end_frame()

    def add_field(self, name: str, field: fem.DiscreteField):
        """Record a snapshot of ``field``'s dof values under ``name`` (and render it to USD if enabled)."""
        if self._usd_renderer is not None:
            # _render_to_usd needs the prim name as well as the field
            self._render_to_usd(name, field)

        if name not in self._fields:
            field_clone = field.space.make_field(space_partition=field.space_partition)
            self._fields[name] = (field_clone, [])

        self._fields[name][1].append(field.dof_values.numpy())

    def _render_to_usd(self, name: str, field: fem.DiscreteField):
        """Render ``field`` to the USD stage as a mesh (2D triangulated spaces) or as points."""
        points = field.space.node_positions().numpy()
        values = field.dof_values.numpy()

        if values.ndim == 2:
            if values.shape[1] == field.space.dimension:
                # use values as displacement
                points += values
            else:
                # use magnitude
                values = np.linalg.norm(values, axis=1)

        if field.space.dimension == 2:
            # Lift nodes to 3D: scalar values become the height, otherwise stay in the z=0 plane.
            # Scalars must be a column so np.hstack sees matching (n, k) shapes.
            if values.ndim == 1:
                z = values.reshape(-1, 1)
            else:
                z = np.zeros((points.shape[0], 1))
            points = np.hstack((points, z))

            if hasattr(field.space, "node_triangulation"):
                indices = field.space.node_triangulation()
                self._usd_renderer.render_mesh(name, points=points, indices=indices)
            else:
                self._usd_renderer.render_points(name, points=points, radius=self.default_point_radius)
        elif values.ndim == 1:
            self._usd_renderer.render_points(name, points, radius=values)
        else:
            self._usd_renderer.render_points(name, points, radius=self.default_point_radius)

    def plot(self, options: Dict[str, Any] = None, backend: str = "auto"):
        """Display the recorded fields.

        Args:
            options: per-field display options keyed by field name (e.g. ``{"arrows": {...}}``,
                ``{"contours": {...}}``, ``{"streamlines": {...}}``, ``{"displacement": {}}``,
                ``"clim"``), plus global keys ``"rows"`` and ``"ref_geom"``.
            backend: ``"pyvista"``, ``"matplotlib"``, or ``"auto"`` (try pyvista, then matplotlib).
        """
        if options is None:
            options = {}

        if backend == "pyvista":
            return self._plot_pyvista(options)
        if backend == "matplotlib":
            return self._plot_matplotlib(options)

        # try both
        try:
            return self._plot_pyvista(options)
        except ModuleNotFoundError:
            try:
                return self._plot_matplotlib(options)
            except ModuleNotFoundError:
                wp.utils.warn("pyvista or matplotlib must be installed to visualize solution results")

    def _plot_pyvista(self, options: Dict[str, Any]):
        import pyvista
        import pyvista.themes

        grids = {}
        scales = {}
        markers = {}

        animate = False

        ref_geom = options.get("ref_geom", None)
        if ref_geom is not None:
            if isinstance(ref_geom, tuple):
                # (vertices, per-face vertex counts, flat face indices) -> VTK face stream
                vertices, counts, indices = ref_geom
                offsets = np.cumsum(counts)
                ranges = np.array([offsets - counts, offsets]).T
                faces = np.concatenate(
                    [[count] + list(indices[beg:end]) for (count, (beg, end)) in zip(counts, ranges)]
                )
                ref_geom = pyvista.PolyData(vertices, faces)
            else:
                ref_geom = pyvista.PolyData(ref_geom)

        for name, (field, values) in self._fields.items():
            cells, types = field.space.vtk_cells()
            node_pos = field.space.node_positions().numpy()

            args = options.get(name, {})

            grid_scale = np.max(np.max(node_pos, axis=0) - np.min(node_pos, axis=0))
            value_range = self._get_field_value_range(values, args)
            scales[name] = (grid_scale, value_range)

            if node_pos.shape[1] == 2:
                node_pos = np.hstack((node_pos, np.zeros((node_pos.shape[0], 1))))

            grid = pyvista.UnstructuredGrid(cells, types, node_pos)
            grids[name] = grid

            if len(values) > 1:
                animate = True

        def set_frame_data(frame):
            for name, (field, values) in self._fields.items():
                # Single-snapshot fields are only set up once
                if frame > 0 and len(values) == 1:
                    continue

                v = values[frame % len(values)]
                grid = grids[name]
                grid_scale, value_range = scales[name]
                field_args = options.get(name, {})

                marker = None

                if field.space.dimension == 2 and v.ndim == 2 and v.shape[1] == 2:
                    grid.point_data[name] = np.hstack((v, np.zeros((v.shape[0], 1))))
                else:
                    grid.point_data[name] = v

                if v.ndim == 2:
                    grid.point_data[name + "_mag"] = np.linalg.norm(v, axis=1)

                if "arrows" in field_args:
                    glyph_scale = field_args["arrows"].get("glyph_scale", 1.0)
                    glyph_scale *= grid_scale / max(1.0e-8, value_range[1] - value_range[0])
                    marker = grid.glyph(scale=name, orient=name, factor=glyph_scale)
                elif "contours" in field_args:
                    levels = field_args["contours"].get("levels", 10)
                    if isinstance(levels, (int, np.integer)):
                        levels = np.linspace(*value_range, levels)
                    marker = grid.contour(isosurfaces=levels, scalars=name + "_mag" if v.ndim == 2 else name)
                elif field.space.dimension == 2:
                    z_scale = grid_scale / max(1.0e-8, value_range[1] - value_range[0])

                    if "streamlines" in field_args:
                        center = np.mean(grid.points, axis=0)
                        density = field_args["streamlines"].get("density", 1.0)
                        cell_size = 1.0 / np.sqrt(field.space.geometry.cell_count())

                        separating_distance = 0.5 / (30.0 * density * cell_size)
                        lines = None
                        # Try with various sep distance until we get at least one line
                        while separating_distance * cell_size < 1.0:
                            lines = grid.streamlines_evenly_spaced_2D(
                                vectors=name,
                                start_position=center,
                                separating_distance=separating_distance,
                                separating_distance_ratio=0.5,
                                step_length=0.25,
                                compute_vorticity=False,
                            )
                            if lines.n_lines > 0:
                                break
                            separating_distance *= 1.25
                        if lines is not None:
                            marker = lines.tube(radius=0.0025 * grid_scale / density)
                    elif "displacement" in field_args:
                        grid.points[:, 0:2] = field.space.node_positions().numpy() + v
                    else:
                        # Extrude surface
                        z = v if v.ndim == 1 else grid.point_data[name + "_mag"]
                        grid.points[:, 2] = z * z_scale

                elif field.space.dimension == 3:
                    if "streamlines" in field_args:
                        density = field_args["streamlines"].get("density", 1.0)
                        lines = grid.streamlines(vectors=name, n_points=int(100 * density))
                        marker = lines.tube(radius=0.0025 * grid_scale / np.sqrt(density))
                    elif "displacement" in field_args:
                        grid.points = field.space.node_positions().numpy() + v

                if frame == 0:
                    if v.ndim == 1:
                        grid.set_active_scalars(name)
                    else:
                        grid.set_active_vectors(name)
                        grid.set_active_scalars(name + "_mag")
                    markers[name] = marker
                elif marker:
                    markers[name].copy_from(marker)

        set_frame_data(0)

        subplot_rows = options.get("rows", 1)
        subplot_shape = (subplot_rows, (len(grids) + subplot_rows - 1) // subplot_rows)

        plotter = pyvista.Plotter(shape=subplot_shape, theme=pyvista.themes.DocumentProTheme())
        plotter.link_views()
        plotter.add_camera_orientation_widget()
        for index, (name, grid) in enumerate(grids.items()):
            plotter.subplot(index // subplot_shape[1], index % subplot_shape[1])
            grid_scale, value_range = scales[name]
            field = self._fields[name][0]
            marker = markers[name]
            if marker:
                if field.space.dimension == 2:
                    plotter.add_mesh(marker, show_scalar_bar=False)
                    plotter.add_mesh(grid, opacity=0.25, clim=value_range)
                    plotter.view_xy()
                else:
                    plotter.add_mesh(marker)
            elif field.space.geometry.cell_dimension == 3:
                plotter.add_mesh_clip_plane(grid, show_edges=True, clim=value_range, assign_to_axis="z")
            else:
                plotter.add_mesh(grid, show_edges=True, clim=value_range)

            if ref_geom:
                plotter.add_mesh(ref_geom)

        plotter.show(interactive_update=animate)

        frame = 0
        while animate and not plotter.iren.interactor.GetDone():
            frame += 1
            set_frame_data(frame)
            plotter.update()

    def _plot_matplotlib(self, options: Dict[str, Any]):
        import matplotlib.animation as animation
        import matplotlib.pyplot as plt
        from matplotlib import cm

        def make_animation(fig, ax, cax, values, draw_func):
            def animate(i):
                cs = draw_func(ax, values[i])

                cax.cla()
                fig.colorbar(cs, cax)

                return cs

            return animation.FuncAnimation(
                ax.figure,
                animate,
                interval=30,
                blit=False,
                frames=len(values),
            )

        def make_draw_func(field, args, plot_func, plot_opts):
            def draw_fn(axes, values):
                axes.clear()

                field.dof_values = values
                cs = plot_func(field, axes=axes, **plot_opts)

                if "xlim" in args:
                    axes.set_xlim(*args["xlim"])
                if "ylim" in args:
                    axes.set_ylim(*args["ylim"])

                return cs

            return draw_fn

        # Keep references to the animations so they are not garbage-collected before plt.show()
        anims = []

        field_count = len(self._fields)
        subplot_rows = options.get("rows", 1)
        subplot_shape = (subplot_rows, (field_count + subplot_rows - 1) // subplot_rows)

        for index, (name, (field, values)) in enumerate(self._fields.items()):
            args = options.get(name, {})
            v = values[0]

            plot_fn = None
            plot_3d = False
            plot_opts = {"cmap": cm.viridis}

            plot_opts["clim"] = self._get_field_value_range(values, args)

            if field.space.dimension == 2:
                if "contours" in args:
                    plot_opts["levels"] = args["contours"].get("levels", None)
                    plot_fn = _plot_contours
                elif v.ndim == 2 and v.shape[1] == 2:
                    if "displacement" in args:
                        plot_fn = _plot_displaced_tri_mesh
                    elif "streamlines" in args:
                        plot_opts["density"] = args["streamlines"].get("density", 1.0)
                        plot_fn = _plot_streamlines
                    elif "arrows" in args:
                        plot_opts["glyph_scale"] = args["arrows"].get("glyph_scale", 1.0)
                        plot_fn = _plot_quivers

                if plot_fn is None:
                    plot_fn = _plot_surface
                    plot_3d = True

            elif field.space.dimension == 3:
                if "arrows" in args or "streamlines" in args:
                    plot_opts["glyph_scale"] = args.get("arrows", {}).get("glyph_scale", 1.0)
                    plot_fn = _plot_quivers_3d
                elif field.space.geometry.cell_dimension == 2:
                    plot_fn = _plot_surface
                else:
                    plot_fn = _plot_3d_scatter
                plot_3d = True

            subplot_kw = {"projection": "3d"} if plot_3d else {}
            axes = plt.subplot(*subplot_shape, index + 1, **subplot_kw)

            if not plot_3d:
                axes.set_aspect("equal")

            draw_fn = make_draw_func(field, args, plot_func=plot_fn, plot_opts=plot_opts)
            cs = draw_fn(axes, values[0])

            fig = plt.gcf()
            cax = fig.colorbar(cs).ax

            if len(values) > 1:
                anims.append(make_animation(fig, axes, cax, values, draw_func=draw_fn))

        plt.show()

    @staticmethod
    def _get_field_value_range(values, field_options: Dict[str, Any]):
        """Return ``field_options["clim"]`` if set, else the (min, max) of values/magnitudes over all snapshots."""
        value_range = field_options.get("clim", None)
        if value_range is None:
            value_range = (
                min((np.min(_value_or_magnitude(v)) for v in values)),
                max((np.max(_value_or_magnitude(v)) for v in values)),
            )

        return value_range
|
|
867
|
+
|
|
868
|
+
|
|
869
|
+
def _value_or_magnitude(values: np.ndarray):
|
|
870
|
+
if values.ndim == 1:
|
|
871
|
+
return values
|
|
872
|
+
return np.linalg.norm(values, axis=-1)
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
def _field_triangulation(field):
    """Build a matplotlib ``Triangulation`` from the field's node x/y coordinates and node triangles."""
    from matplotlib.tri import Triangulation

    nodes = field.space.node_positions().numpy()
    xs, ys = nodes[:, 0], nodes[:, 1]
    return Triangulation(x=xs, y=ys, triangles=field.space.node_triangulation())
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
def _plot_surface(field, axes, **kwargs):
    """Plot a field as a surface on 3D axes.

    For 2D spaces the field value (or magnitude) is used as height. For 3D spaces the
    node z-coordinate is the height and faces are colored by the field value. Uses
    ``node_grid`` if the space provides it, else ``node_triangulation``, else a scatter.

    ``kwargs`` must contain ``"clim"`` and ``"cmap"``; they are forwarded to matplotlib.
    """
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap remains available
    import matplotlib.pyplot as plt
    from matplotlib.colors import Normalize

    C = _value_or_magnitude(field.dof_values.numpy())

    positions = field.space.node_positions().numpy().T
    if field.space.dimension == 3:
        X, Y, Z = positions
    else:
        X, Y = positions
        Z = C
        axes.set_zlim(kwargs["clim"])

    if hasattr(field.space, "node_grid"):
        X, Y = field.space.node_grid()
        C = C.reshape(X.shape)
        return axes.plot_surface(X, Y, C, linewidth=0.1, antialiased=False, **kwargs)

    if hasattr(field.space, "node_triangulation"):
        triangulation = _field_triangulation(field)

        if field.space.dimension == 3:
            plot = axes.plot_trisurf(triangulation, Z, linewidth=0.1, antialiased=False)
            # change colors -- recompute color map manually from per-triangle mean values
            vmin, vmax = kwargs["clim"]
            norm = Normalize(vmin=vmin, vmax=vmax)
            values = np.mean(C[triangulation.triangles], axis=1)
            colors = plt.get_cmap(kwargs["cmap"])(norm(values))
            plot.set_norm(norm)
            plot.set_fc(colors)
        else:
            plot = axes.plot_trisurf(triangulation, C, linewidth=0.1, antialiased=False, **kwargs)

        return plot

    # scatter
    return axes.scatter(X, Y, Z, c=C, **kwargs)
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
def _plot_displaced_tri_mesh(field, axes, **kwargs):
    """Plot a 2D vector field as a displacement of its triangulated mesh.

    Node coordinates are shifted by the field values. The mesh is drawn with ``tripcolor``
    using the displacement magnitude, and the deformed edges are overlaid.
    Returns the ``tripcolor`` collection.
    """
    tri = _field_triangulation(field)
    disp = field.dof_values.numpy()

    # shift the triangulation's node coordinates in place
    tri.x += disp[:, 0]
    tri.y += disp[:, 1]

    mesh = axes.tripcolor(tri, _value_or_magnitude(disp), **kwargs)
    axes.triplot(tri, lw=0.1)
    return mesh
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
def _plot_quivers(field, axes, clim=None, glyph_scale=1.0, **kwargs):
    """Draw a 2D quiver plot of a vector field at its nodes, colored by magnitude.

    ``clim`` is accepted so shared plot options can be passed through, and is not used.
    Arrow scale is ``1 / glyph_scale``.
    """
    X, Y = field.space.node_positions().numpy().T

    vel = field.dof_values.numpy()
    u, v = (vel[:, axis].reshape(X.shape) for axis in (0, 1))

    magnitude = _value_or_magnitude(vel)
    return axes.quiver(X, Y, u, v, magnitude, scale=1.0 / glyph_scale, **kwargs)
|
|
946
|
+
|
|
947
|
+
|
|
948
|
+
def _plot_quivers_3d(field, axes, clim=None, cmap=None, glyph_scale=1.0, **kwargs):
    """Draw a 3D quiver plot of a vector field at its nodes, colored by magnitude.

    Args:
        clim: (min, max) value range used for coloring and for normalizing arrow lengths.
            Defaults to the field's magnitude range. A zero-width range is clamped to
            1e-8 to avoid division by zero (e.g. for a constant field).
        cmap: matplotlib colormap; defaults to matplotlib's default colormap.
        glyph_scale: arrow length passed to ``Axes3D.quiver``.
    """
    X, Y, Z = field.space.node_positions().numpy().T

    vel = field.dof_values.numpy()
    magnitude = _value_or_magnitude(vel)

    if clim is None:
        clim = (np.min(magnitude), np.max(magnitude))
    if cmap is None:
        import matplotlib.pyplot as plt

        cmap = plt.get_cmap()

    value_span = max(1.0e-8, clim[1] - clim[0])

    colors = cmap((magnitude - clim[0]) / value_span)

    u = vel[:, 0].reshape(X.shape) / value_span
    v = vel[:, 1].reshape(X.shape) / value_span
    w = vel[:, 2].reshape(X.shape) / value_span

    return axes.quiver(X, Y, Z, u, v, w, colors=colors, length=glyph_scale, clim=clim, cmap=cmap, **kwargs)
|
|
960
|
+
|
|
961
|
+
|
|
962
|
+
def _plot_streamlines(field, axes, clim=None, **kwargs):
    """Draw streamlines of a 2D vector field, resampled on a 100x100 regular grid.

    Velocity components are interpolated from the nodes with ``CubicTriInterpolator``, and
    the lines are colored by speed. ``clim`` is accepted and not used. Returns the
    streamplot's line collection.
    """
    import matplotlib.tri as tr

    triangulation = _field_triangulation(field)

    vel = field.dof_values.numpy()

    interpolators = [tr.CubicTriInterpolator(triangulation, vel[:, k]) for k in (0, 1)]

    def _samples(coords):
        return np.linspace(np.min(coords), np.max(coords), 100)

    X, Y = np.meshgrid(_samples(triangulation.x), _samples(triangulation.y))

    u, v = (itp(X, Y) for itp in interpolators)
    speed = np.sqrt(u * u + v * v)

    stream = axes.streamplot(X, Y, u, v, color=speed, **kwargs)
    return stream.lines
|
|
983
|
+
|
|
984
|
+
|
|
985
|
+
def _plot_contours(field, axes, clim=None, **kwargs):
    """Draw filled contours of a 2D field (or its magnitude) with contour lines overlaid.

    ``clim`` is accepted so shared plot options can be passed through, and is not used.
    Returns the filled contour set.
    """
    tri = _field_triangulation(field)
    scalar = _value_or_magnitude(field.dof_values.numpy())

    filled = axes.tricontourf(tri, scalar, **kwargs)
    axes.tricontour(tri, scalar, **kwargs)
    return filled
|
|
993
|
+
|
|
994
|
+
|
|
995
|
+
def _plot_3d_scatter(field, axes, **kwargs):
    """Scatter the field's nodes in 3D, colored by the field value (or vector magnitude)."""
    coords = field.space.node_positions().numpy().T
    X, Y, Z = coords

    node_colors = _value_or_magnitude(field.dof_values.numpy()).reshape(X.shape)
    return axes.scatter(X, Y, Z, c=node_colors, **kwargs)
|