warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +882 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1077 -0
- warp/_src/build.py +620 -0
- warp/_src/build_dll.py +642 -0
- warp/{builtins.py → _src/builtins.py} +1435 -379
- warp/_src/codegen.py +4361 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +59 -0
- warp/_src/context.py +8352 -0
- warp/_src/dlpack.py +464 -0
- warp/_src/fabric.py +362 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +510 -0
- warp/_src/fem/cache.py +689 -0
- warp/_src/fem/dirichlet.py +190 -0
- warp/{fem → _src/fem}/domain.py +42 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +703 -0
- warp/{fem → _src/fem}/field/nodal_field.py +32 -15
- warp/{fem → _src/fem}/field/restriction.py +3 -1
- warp/{fem → _src/fem}/field/virtual.py +55 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
- warp/_src/fem/geometry/closest_point.py +99 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
- warp/{fem → _src/fem}/geometry/element.py +34 -10
- warp/{fem → _src/fem}/geometry/geometry.py +50 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
- warp/{fem → _src/fem}/geometry/partition.py +123 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
- warp/{fem → _src/fem}/integrate.py +166 -158
- warp/_src/fem/linalg.py +385 -0
- warp/_src/fem/operator.py +398 -0
- warp/_src/fem/polynomial.py +231 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
- warp/_src/fem/space/basis_space.py +681 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
- warp/{fem → _src/fem}/space/function_space.py +16 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
- warp/{fem → _src/fem}/space/partition.py +119 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
- warp/{fem → _src/fem}/space/restriction.py +68 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
- warp/_src/fem/space/topology.py +461 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
- warp/_src/fem/types.py +114 -0
- warp/_src/fem/utils.py +488 -0
- warp/_src/jax.py +188 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +389 -0
- warp/_src/jax_experimental/ffi.py +1286 -0
- warp/_src/jax_experimental/xla_ffi.py +658 -0
- warp/_src/marching_cubes.py +710 -0
- warp/_src/math.py +416 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +165 -0
- warp/_src/optim/linear.py +1608 -0
- warp/_src/optim/sgd.py +114 -0
- warp/_src/paddle.py +408 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +291 -0
- warp/_src/render/render_opengl.py +3638 -0
- warp/_src/render/render_usd.py +939 -0
- warp/_src/render/utils.py +162 -0
- warp/_src/sparse.py +2718 -0
- warp/_src/tape.py +1208 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +393 -0
- warp/_src/types.py +5888 -0
- warp/_src/utils.py +1695 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +3 -3
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +521 -250
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +18 -17
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +578 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/METADATA +46 -99
- warp_lang-1.10.0.dist-info/RECORD +468 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
|
@@ -19,11 +19,11 @@ import textwrap
|
|
|
19
19
|
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
|
|
20
20
|
|
|
21
21
|
import warp as wp
|
|
22
|
-
import warp.fem.operator as operator
|
|
23
|
-
from warp.codegen import get_annotations
|
|
24
|
-
from warp.fem import cache
|
|
25
|
-
from warp.fem.domain import GeometryDomain
|
|
26
|
-
from warp.fem.field import (
|
|
22
|
+
import warp._src.fem.operator as operator
|
|
23
|
+
from warp._src.codegen import Struct, StructInstance, get_annotations
|
|
24
|
+
from warp._src.fem import cache
|
|
25
|
+
from warp._src.fem.domain import GeometryDomain
|
|
26
|
+
from warp._src.fem.field import (
|
|
27
27
|
DiscreteField,
|
|
28
28
|
FieldLike,
|
|
29
29
|
FieldRestriction,
|
|
@@ -34,18 +34,18 @@ from warp.fem.field import (
|
|
|
34
34
|
TrialField,
|
|
35
35
|
make_restriction,
|
|
36
36
|
)
|
|
37
|
-
from warp.fem.field.virtual import (
|
|
37
|
+
from warp._src.fem.field.virtual import (
|
|
38
38
|
make_bilinear_dispatch_kernel,
|
|
39
39
|
make_linear_dispatch_kernel,
|
|
40
40
|
)
|
|
41
|
-
from warp.fem.linalg import array_axpy, basis_coefficient
|
|
42
|
-
from warp.fem.operator import (
|
|
41
|
+
from warp._src.fem.linalg import array_axpy, basis_coefficient
|
|
42
|
+
from warp._src.fem.operator import (
|
|
43
43
|
Integrand,
|
|
44
44
|
Operator,
|
|
45
45
|
integrand,
|
|
46
46
|
)
|
|
47
|
-
from warp.fem.quadrature import Quadrature, RegularQuadrature
|
|
48
|
-
from warp.fem.types import (
|
|
47
|
+
from warp._src.fem.quadrature import Quadrature, RegularQuadrature
|
|
48
|
+
from warp._src.fem.types import (
|
|
49
49
|
NULL_DOF_INDEX,
|
|
50
50
|
NULL_ELEMENT_INDEX,
|
|
51
51
|
NULL_NODE_INDEX,
|
|
@@ -57,15 +57,17 @@ from warp.fem.types import (
|
|
|
57
57
|
Sample,
|
|
58
58
|
make_free_sample,
|
|
59
59
|
)
|
|
60
|
-
from warp.fem.utils import type_zero_element
|
|
61
|
-
from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
|
|
62
|
-
from warp.types import is_array, type_size
|
|
63
|
-
from warp.utils import array_cast
|
|
60
|
+
from warp._src.fem.utils import type_zero_element
|
|
61
|
+
from warp._src.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
|
|
62
|
+
from warp._src.types import is_array, type_repr, type_scalar_type, type_size, type_to_warp
|
|
63
|
+
from warp._src.utils import array_cast, warn
|
|
64
|
+
|
|
65
|
+
_wp_module_name_ = "warp.fem.integrate"
|
|
64
66
|
|
|
65
67
|
|
|
66
68
|
def _resolve_path(func, node):
|
|
67
69
|
"""
|
|
68
|
-
Resolves variable and path from ast node/attribute (adapted from warp.codegen)
|
|
70
|
+
Resolves variable and path from ast node/attribute (adapted from warp._src.codegen)
|
|
69
71
|
"""
|
|
70
72
|
|
|
71
73
|
modules = []
|
|
@@ -83,20 +85,20 @@ def _resolve_path(func, node):
|
|
|
83
85
|
if len(path) == 0:
|
|
84
86
|
return None, path
|
|
85
87
|
|
|
86
|
-
|
|
88
|
+
name = path[0]
|
|
87
89
|
try:
|
|
88
|
-
#
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
90
|
+
# look up in closure variables
|
|
91
|
+
idx = func.__code__.co_freevars.index(name)
|
|
92
|
+
expr = func.__closure__[idx].cell_contents
|
|
93
|
+
except ValueError:
|
|
94
|
+
# look up in global variables
|
|
95
|
+
expr = func.__globals__.get(name)
|
|
96
|
+
|
|
97
|
+
for name in path[1:]:
|
|
98
|
+
if expr is not None:
|
|
99
|
+
expr = getattr(expr, name, None)
|
|
98
100
|
|
|
99
|
-
return
|
|
101
|
+
return expr, path
|
|
100
102
|
|
|
101
103
|
|
|
102
104
|
class IntegrandVisitor(ast.NodeTransformer):
|
|
@@ -275,7 +277,7 @@ class IntegrandTransformer(IntegrandVisitor):
|
|
|
275
277
|
try:
|
|
276
278
|
# Retrieve the function pointer corresponding to the operator implementation for the field type
|
|
277
279
|
pointer = operator.resolver(field)
|
|
278
|
-
if not isinstance(pointer, wp.
|
|
280
|
+
if not isinstance(pointer, wp.Function):
|
|
279
281
|
raise NotImplementedError(operator.resolver.__name__)
|
|
280
282
|
|
|
281
283
|
except (AttributeError, NotImplementedError) as e:
|
|
@@ -360,15 +362,13 @@ def _parse_integrand_arguments(
|
|
|
360
362
|
trial_name = None
|
|
361
363
|
|
|
362
364
|
argspec = integrand.argspec
|
|
363
|
-
for arg in argspec.
|
|
364
|
-
arg_type = argspec.annotations[arg]
|
|
365
|
+
for arg, arg_type in argspec.annotations.items():
|
|
365
366
|
if arg_type == Field:
|
|
366
367
|
try:
|
|
367
368
|
field = fields[arg]
|
|
368
369
|
except KeyError as err:
|
|
369
370
|
raise ValueError(f"Missing field for argument '{arg}' of integrand '{integrand.name}'") from err
|
|
370
|
-
|
|
371
|
-
raise ValueError(f"Passed field argument '{arg}' is not a proper Field")
|
|
371
|
+
|
|
372
372
|
if isinstance(field, TestField):
|
|
373
373
|
if test_name is not None:
|
|
374
374
|
raise ValueError(f"More than one test field argument: '{test_name}' and '{arg}'")
|
|
@@ -377,28 +377,26 @@ def _parse_integrand_arguments(
|
|
|
377
377
|
if trial_name is not None:
|
|
378
378
|
raise ValueError(f"More than one trial field argument: '{trial_name}' and '{arg}'")
|
|
379
379
|
trial_name = arg
|
|
380
|
+
elif not isinstance(field, FieldLike):
|
|
381
|
+
raise ValueError(f"Passed field argument '{arg}' is not a proper Field")
|
|
382
|
+
|
|
380
383
|
field_args[arg] = field
|
|
381
|
-
|
|
384
|
+
continue
|
|
385
|
+
|
|
386
|
+
if arg in fields:
|
|
387
|
+
raise ValueError(
|
|
388
|
+
f"Cannot pass a field argument to '{arg}' of '{integrand.name}' which is not of type 'Field'"
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
if arg_type == Domain:
|
|
382
392
|
if domain_name is not None:
|
|
383
393
|
raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Domain")
|
|
384
|
-
if arg in fields:
|
|
385
|
-
raise ValueError(
|
|
386
|
-
f"Domain argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
|
|
387
|
-
)
|
|
388
394
|
domain_name = arg
|
|
389
395
|
elif arg_type == Sample:
|
|
390
396
|
if sample_name is not None:
|
|
391
397
|
raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Sample")
|
|
392
|
-
if arg in fields:
|
|
393
|
-
raise ValueError(
|
|
394
|
-
f"Sample argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
|
|
395
|
-
)
|
|
396
398
|
sample_name = arg
|
|
397
399
|
else:
|
|
398
|
-
if arg in fields:
|
|
399
|
-
raise ValueError(
|
|
400
|
-
f"Cannot pass a field argument to '{arg}' of '{integrand.name}' with is not of type 'Field'"
|
|
401
|
-
)
|
|
402
400
|
value_args[arg] = arg_type
|
|
403
401
|
|
|
404
402
|
return IntegrandArguments(field_args, value_args, domain_name, sample_name, test_name, trial_name)
|
|
@@ -438,10 +436,8 @@ def _notify_operator_usage(
|
|
|
438
436
|
integrand: Integrand,
|
|
439
437
|
field_args: Dict[str, FieldLike],
|
|
440
438
|
):
|
|
441
|
-
for arg,
|
|
442
|
-
|
|
443
|
-
# print(f"{arg} {field_args[arg].name} : {', '.join(op.name for op in field_ops)}")
|
|
444
|
-
field_args[arg].notify_operator_usage(field_ops)
|
|
439
|
+
for arg, field in field_args.items():
|
|
440
|
+
field.notify_operator_usage(integrand.operators.get(arg, set()))
|
|
445
441
|
|
|
446
442
|
|
|
447
443
|
def _gen_field_struct(field_args: Dict[str, FieldLike]):
|
|
@@ -615,8 +611,8 @@ def get_integrate_constant_kernel(
|
|
|
615
611
|
integrand_func: wp.Function,
|
|
616
612
|
domain: GeometryDomain,
|
|
617
613
|
quadrature: Quadrature,
|
|
618
|
-
FieldStruct:
|
|
619
|
-
ValueStruct:
|
|
614
|
+
FieldStruct: Struct,
|
|
615
|
+
ValueStruct: Struct,
|
|
620
616
|
accumulate_dtype,
|
|
621
617
|
tile_size: int = _INTEGRATE_CONSTANT_TILE_SIZE,
|
|
622
618
|
):
|
|
@@ -641,10 +637,13 @@ def get_integrate_constant_kernel(
|
|
|
641
637
|
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
642
638
|
|
|
643
639
|
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
644
|
-
|
|
640
|
+
element_index = NULL_ELEMENT_INDEX
|
|
645
641
|
else:
|
|
646
642
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
647
643
|
|
|
644
|
+
if element_index == NULL_ELEMENT_INDEX:
|
|
645
|
+
val = zero_element()
|
|
646
|
+
else:
|
|
648
647
|
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
649
648
|
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
650
649
|
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
@@ -667,8 +666,8 @@ def get_integrate_linear_kernel(
|
|
|
667
666
|
integrand_func: wp.Function,
|
|
668
667
|
domain: GeometryDomain,
|
|
669
668
|
quadrature: Quadrature,
|
|
670
|
-
FieldStruct:
|
|
671
|
-
ValueStruct:
|
|
669
|
+
FieldStruct: Struct,
|
|
670
|
+
ValueStruct: Struct,
|
|
672
671
|
test: TestField,
|
|
673
672
|
output_dtype,
|
|
674
673
|
accumulate_dtype,
|
|
@@ -684,6 +683,9 @@ def get_integrate_linear_kernel(
|
|
|
684
683
|
):
|
|
685
684
|
local_node_index, test_dof = wp.tid()
|
|
686
685
|
node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
|
|
686
|
+
if node_index == NULL_NODE_INDEX:
|
|
687
|
+
return
|
|
688
|
+
|
|
687
689
|
element_beg, element_end = test.space_restriction.node_element_range(test_arg, node_index)
|
|
688
690
|
|
|
689
691
|
trial_dof_index = NULL_DOF_INDEX
|
|
@@ -725,8 +727,8 @@ def get_integrate_linear_kernel(
|
|
|
725
727
|
def get_integrate_linear_nodal_kernel(
|
|
726
728
|
integrand_func: wp.Function,
|
|
727
729
|
domain: GeometryDomain,
|
|
728
|
-
FieldStruct:
|
|
729
|
-
ValueStruct:
|
|
730
|
+
FieldStruct: Struct,
|
|
731
|
+
ValueStruct: Struct,
|
|
730
732
|
test: TestField,
|
|
731
733
|
output_dtype,
|
|
732
734
|
accumulate_dtype,
|
|
@@ -743,6 +745,9 @@ def get_integrate_linear_nodal_kernel(
|
|
|
743
745
|
local_node_index, dof = wp.tid()
|
|
744
746
|
|
|
745
747
|
partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
|
|
748
|
+
if partition_node_index == NULL_NODE_INDEX:
|
|
749
|
+
return
|
|
750
|
+
|
|
746
751
|
element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)
|
|
747
752
|
|
|
748
753
|
trial_dof_index = NULL_DOF_INDEX
|
|
@@ -797,8 +802,8 @@ def get_integrate_linear_local_kernel(
|
|
|
797
802
|
integrand_func: wp.Function,
|
|
798
803
|
domain: GeometryDomain,
|
|
799
804
|
quadrature: Quadrature,
|
|
800
|
-
FieldStruct:
|
|
801
|
-
ValueStruct:
|
|
805
|
+
FieldStruct: Struct,
|
|
806
|
+
ValueStruct: Struct,
|
|
802
807
|
test: LocalTestField,
|
|
803
808
|
):
|
|
804
809
|
def integrate_kernel_fn(
|
|
@@ -817,6 +822,8 @@ def get_integrate_linear_local_kernel(
|
|
|
817
822
|
return
|
|
818
823
|
|
|
819
824
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
825
|
+
if element_index == NULL_ELEMENT_INDEX:
|
|
826
|
+
return
|
|
820
827
|
|
|
821
828
|
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
822
829
|
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
@@ -838,8 +845,8 @@ def get_integrate_bilinear_kernel(
|
|
|
838
845
|
integrand_func: wp.Function,
|
|
839
846
|
domain: GeometryDomain,
|
|
840
847
|
quadrature: Quadrature,
|
|
841
|
-
FieldStruct:
|
|
842
|
-
ValueStruct:
|
|
848
|
+
FieldStruct: Struct,
|
|
849
|
+
ValueStruct: Struct,
|
|
843
850
|
test: TestField,
|
|
844
851
|
trial: TrialField,
|
|
845
852
|
output_dtype,
|
|
@@ -863,6 +870,9 @@ def get_integrate_bilinear_kernel(
|
|
|
863
870
|
test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()
|
|
864
871
|
|
|
865
872
|
test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
|
|
873
|
+
if test_node_index == NULL_NODE_INDEX:
|
|
874
|
+
return
|
|
875
|
+
|
|
866
876
|
element_beg, element_end = test.space_restriction.node_element_range(test_arg, test_node_index)
|
|
867
877
|
|
|
868
878
|
trial_dof_index = DofIndex(trial_node, trial_dof)
|
|
@@ -934,8 +944,8 @@ def get_integrate_bilinear_kernel(
|
|
|
934
944
|
def get_integrate_bilinear_nodal_kernel(
|
|
935
945
|
integrand_func: wp.Function,
|
|
936
946
|
domain: GeometryDomain,
|
|
937
|
-
FieldStruct:
|
|
938
|
-
ValueStruct:
|
|
947
|
+
FieldStruct: Struct,
|
|
948
|
+
ValueStruct: Struct,
|
|
939
949
|
test: TestField,
|
|
940
950
|
output_dtype,
|
|
941
951
|
accumulate_dtype,
|
|
@@ -954,6 +964,11 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
954
964
|
local_node_index, test_dof, trial_dof = wp.tid()
|
|
955
965
|
|
|
956
966
|
partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
|
|
967
|
+
if partition_node_index == NULL_NODE_INDEX:
|
|
968
|
+
triplet_rows[local_node_index] = -1
|
|
969
|
+
triplet_cols[local_node_index] = -1
|
|
970
|
+
return
|
|
971
|
+
|
|
957
972
|
element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)
|
|
958
973
|
|
|
959
974
|
val_sum = accumulate_dtype(0.0)
|
|
@@ -1009,8 +1024,8 @@ def get_integrate_bilinear_local_kernel(
|
|
|
1009
1024
|
integrand_func: wp.Function,
|
|
1010
1025
|
domain: GeometryDomain,
|
|
1011
1026
|
quadrature: Quadrature,
|
|
1012
|
-
FieldStruct:
|
|
1013
|
-
ValueStruct:
|
|
1027
|
+
FieldStruct: Struct,
|
|
1028
|
+
ValueStruct: Struct,
|
|
1014
1029
|
test: LocalTestField,
|
|
1015
1030
|
trial: LocalTrialField,
|
|
1016
1031
|
):
|
|
@@ -1033,6 +1048,8 @@ def get_integrate_bilinear_local_kernel(
|
|
|
1033
1048
|
return
|
|
1034
1049
|
|
|
1035
1050
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
1051
|
+
if element_index == NULL_ELEMENT_INDEX:
|
|
1052
|
+
return
|
|
1036
1053
|
|
|
1037
1054
|
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1038
1055
|
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
@@ -1066,23 +1083,27 @@ def _generate_integrate_kernel(
|
|
|
1066
1083
|
accumulate_dtype: type,
|
|
1067
1084
|
kernel_options: Optional[Dict[str, Any]] = None,
|
|
1068
1085
|
) -> wp.Kernel:
|
|
1069
|
-
output_dtype =
|
|
1070
|
-
|
|
1071
|
-
FieldStruct = _gen_field_struct(arguments.field_args)
|
|
1072
|
-
ValueStruct = cache.get_argument_struct(arguments.value_args)
|
|
1086
|
+
output_dtype = type_scalar_type(output_dtype)
|
|
1073
1087
|
|
|
1074
1088
|
_notify_operator_usage(integrand, arguments.field_args)
|
|
1075
1089
|
|
|
1076
1090
|
# Check if kernel exist in cache
|
|
1077
|
-
field_names =
|
|
1078
|
-
kernel_suffix =
|
|
1091
|
+
field_names = tuple((k, f.name) for k, f in arguments.field_args.items())
|
|
1092
|
+
kernel_suffix = ("itg", field_names, cache.pod_type_key(output_dtype), cache.pod_type_key(accumulate_dtype))
|
|
1079
1093
|
|
|
1080
1094
|
if quadrature is not None:
|
|
1081
|
-
kernel_suffix
|
|
1095
|
+
kernel_suffix = (quadrature.name, *kernel_suffix)
|
|
1082
1096
|
|
|
1083
|
-
kernel = cache.get_integrand_kernel(
|
|
1097
|
+
kernel, field_arg_values, value_struct_values = cache.get_integrand_kernel(
|
|
1098
|
+
integrand=integrand,
|
|
1099
|
+
suffix=kernel_suffix,
|
|
1100
|
+
kernel_options=kernel_options,
|
|
1101
|
+
)
|
|
1084
1102
|
if kernel is not None:
|
|
1085
|
-
return kernel,
|
|
1103
|
+
return kernel, field_arg_values, value_struct_values
|
|
1104
|
+
|
|
1105
|
+
FieldStruct = _gen_field_struct(arguments.field_args)
|
|
1106
|
+
ValueStruct = cache.get_argument_struct(arguments.value_args)
|
|
1086
1107
|
|
|
1087
1108
|
# Not found in cache, transform integrand and generate kernel
|
|
1088
1109
|
_check_field_compat(integrand, arguments, domain)
|
|
@@ -1165,7 +1186,7 @@ def _generate_integrate_kernel(
|
|
|
1165
1186
|
accumulate_dtype=accumulate_dtype,
|
|
1166
1187
|
)
|
|
1167
1188
|
|
|
1168
|
-
kernel = cache.get_integrand_kernel(
|
|
1189
|
+
kernel, _FieldStruct, _ValueStruct = cache.get_integrand_kernel(
|
|
1169
1190
|
integrand=integrand,
|
|
1170
1191
|
kernel_fn=integrate_kernel_fn,
|
|
1171
1192
|
suffix=kernel_suffix,
|
|
@@ -1175,9 +1196,11 @@ def _generate_integrate_kernel(
|
|
|
1175
1196
|
arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
|
|
1176
1197
|
)
|
|
1177
1198
|
],
|
|
1199
|
+
FieldStruct=FieldStruct,
|
|
1200
|
+
ValueStruct=ValueStruct,
|
|
1178
1201
|
)
|
|
1179
1202
|
|
|
1180
|
-
return kernel, FieldStruct, ValueStruct
|
|
1203
|
+
return kernel, FieldStruct(), ValueStruct()
|
|
1181
1204
|
|
|
1182
1205
|
|
|
1183
1206
|
def _generate_auxiliary_kernels(
|
|
@@ -1220,8 +1243,8 @@ def _launch_integrate_kernel(
|
|
|
1220
1243
|
integrand: Integrand,
|
|
1221
1244
|
kernel: wp.Kernel,
|
|
1222
1245
|
auxiliary_kernels: List[Tuple[wp.Kernel, int]],
|
|
1223
|
-
|
|
1224
|
-
|
|
1246
|
+
field_arg_values: StructInstance,
|
|
1247
|
+
value_struct_values: StructInstance,
|
|
1225
1248
|
domain: GeometryDomain,
|
|
1226
1249
|
quadrature: Quadrature,
|
|
1227
1250
|
test: Optional[TestField],
|
|
@@ -1243,12 +1266,11 @@ def _launch_integrate_kernel(
|
|
|
1243
1266
|
if quadrature is not None:
|
|
1244
1267
|
qp_arg = quadrature.arg_value(device=device)
|
|
1245
1268
|
|
|
1246
|
-
field_arg_values = FieldStruct()
|
|
1247
1269
|
for k, v in fields.items():
|
|
1248
1270
|
if not isinstance(v, GeometryDomain):
|
|
1249
1271
|
v.fill_eval_arg(getattr(field_arg_values, k), device=device)
|
|
1250
1272
|
|
|
1251
|
-
|
|
1273
|
+
cache.populate_argument_struct(value_struct_values, values, func_name=integrand.name)
|
|
1252
1274
|
|
|
1253
1275
|
# Constant form
|
|
1254
1276
|
if test is None and trial is None:
|
|
@@ -1257,14 +1279,13 @@ def _launch_integrate_kernel(
|
|
|
1257
1279
|
raise RuntimeError("Output array must be of size at least 1")
|
|
1258
1280
|
accumulate_array = output
|
|
1259
1281
|
else:
|
|
1260
|
-
|
|
1282
|
+
accumulate_array = cache.borrow_temporary(
|
|
1261
1283
|
shape=(1),
|
|
1262
1284
|
device=device,
|
|
1263
1285
|
dtype=accumulate_dtype,
|
|
1264
1286
|
temporary_store=temporary_store,
|
|
1265
1287
|
requires_grad=output is not None and output.requires_grad,
|
|
1266
1288
|
)
|
|
1267
|
-
accumulate_array = accumulate_temporary.array
|
|
1268
1289
|
|
|
1269
1290
|
if output != accumulate_array or not add_to_output:
|
|
1270
1291
|
accumulate_array.zero_()
|
|
@@ -1315,21 +1336,17 @@ def _launch_integrate_kernel(
|
|
|
1315
1336
|
output_shape = (test.space_partition.node_count(), test.node_dof_count)
|
|
1316
1337
|
else:
|
|
1317
1338
|
raise RuntimeError(
|
|
1318
|
-
f"Incompatible output type {
|
|
1339
|
+
f"Incompatible output type {type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
|
|
1319
1340
|
)
|
|
1320
1341
|
|
|
1321
|
-
|
|
1342
|
+
output = cache.borrow_temporary(
|
|
1322
1343
|
temporary_store=temporary_store,
|
|
1323
1344
|
shape=output_shape,
|
|
1324
1345
|
dtype=output_dtype,
|
|
1325
1346
|
device=device,
|
|
1326
1347
|
)
|
|
1327
1348
|
|
|
1328
|
-
output = output_temporary.array
|
|
1329
|
-
|
|
1330
1349
|
else:
|
|
1331
|
-
output_temporary = None
|
|
1332
|
-
|
|
1333
1350
|
if output.shape[0] < test.space_partition.node_count():
|
|
1334
1351
|
raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
|
|
1335
1352
|
|
|
@@ -1337,7 +1354,7 @@ def _launch_integrate_kernel(
|
|
|
1337
1354
|
if type_size(output_dtype) != test.node_dof_count:
|
|
1338
1355
|
if type_size(output_dtype) != 1:
|
|
1339
1356
|
raise RuntimeError(
|
|
1340
|
-
f"Incompatible output type {
|
|
1357
|
+
f"Incompatible output type {type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
|
|
1341
1358
|
)
|
|
1342
1359
|
if output.ndim != 2 and output.shape[1] != test.node_dof_count:
|
|
1343
1360
|
raise RuntimeError(
|
|
@@ -1355,7 +1372,7 @@ def _launch_integrate_kernel(
|
|
|
1355
1372
|
capacity=array.capacity,
|
|
1356
1373
|
device=array.device,
|
|
1357
1374
|
shape=(test.space_partition.node_count(), test.node_dof_count),
|
|
1358
|
-
dtype=
|
|
1375
|
+
dtype=type_scalar_type(output_dtype),
|
|
1359
1376
|
grad=None if array.grad is None else as_2d_array(array.grad),
|
|
1360
1377
|
)
|
|
1361
1378
|
|
|
@@ -1387,7 +1404,7 @@ def _launch_integrate_kernel(
|
|
|
1387
1404
|
|
|
1388
1405
|
wp.launch(
|
|
1389
1406
|
kernel=kernel,
|
|
1390
|
-
dim=local_result.
|
|
1407
|
+
dim=local_result.shape,
|
|
1391
1408
|
inputs=[
|
|
1392
1409
|
qp_arg,
|
|
1393
1410
|
quadrature.element_index_arg_value(device),
|
|
@@ -1395,13 +1412,13 @@ def _launch_integrate_kernel(
|
|
|
1395
1412
|
domain_elt_index_arg,
|
|
1396
1413
|
field_arg_values,
|
|
1397
1414
|
value_struct_values,
|
|
1398
|
-
local_result
|
|
1415
|
+
local_result,
|
|
1399
1416
|
],
|
|
1400
1417
|
device=device,
|
|
1401
1418
|
)
|
|
1402
1419
|
|
|
1403
1420
|
if test.TAYLOR_DOF_COUNT == 0:
|
|
1404
|
-
|
|
1421
|
+
warn(
|
|
1405
1422
|
f"Test field is never evaluated in integrand '{integrand.name}', result will be zero",
|
|
1406
1423
|
category=UserWarning,
|
|
1407
1424
|
stacklevel=2,
|
|
@@ -1418,7 +1435,7 @@ def _launch_integrate_kernel(
|
|
|
1418
1435
|
domain_elt_index_arg,
|
|
1419
1436
|
test_arg,
|
|
1420
1437
|
test.space.space_arg_value(device),
|
|
1421
|
-
local_result
|
|
1438
|
+
local_result,
|
|
1422
1439
|
output_view,
|
|
1423
1440
|
],
|
|
1424
1441
|
device=device,
|
|
@@ -1442,9 +1459,6 @@ def _launch_integrate_kernel(
|
|
|
1442
1459
|
device=device,
|
|
1443
1460
|
)
|
|
1444
1461
|
|
|
1445
|
-
if output_temporary is not None:
|
|
1446
|
-
return output_temporary.detach()
|
|
1447
|
-
|
|
1448
1462
|
return output
|
|
1449
1463
|
|
|
1450
1464
|
# Bilinear form
|
|
@@ -1475,8 +1489,6 @@ def _launch_integrate_kernel(
|
|
|
1475
1489
|
triplet_rows = triplet_rows_temp.array
|
|
1476
1490
|
triplet_values = triplet_values_temp.array
|
|
1477
1491
|
|
|
1478
|
-
triplet_values.zero_()
|
|
1479
|
-
|
|
1480
1492
|
if nodal:
|
|
1481
1493
|
wp.launch(
|
|
1482
1494
|
kernel=kernel,
|
|
@@ -1524,13 +1536,13 @@ def _launch_integrate_kernel(
|
|
|
1524
1536
|
domain_elt_index_arg,
|
|
1525
1537
|
field_arg_values,
|
|
1526
1538
|
value_struct_values,
|
|
1527
|
-
local_result
|
|
1539
|
+
local_result,
|
|
1528
1540
|
],
|
|
1529
1541
|
device=device,
|
|
1530
1542
|
)
|
|
1531
1543
|
|
|
1532
1544
|
if test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT == 0:
|
|
1533
|
-
|
|
1545
|
+
warn(
|
|
1534
1546
|
f"Test and/or trial fields are never evaluated in integrand '{integrand.name}', result will be zero",
|
|
1535
1547
|
category=UserWarning,
|
|
1536
1548
|
stacklevel=2,
|
|
@@ -1557,7 +1569,7 @@ def _launch_integrate_kernel(
|
|
|
1557
1569
|
trial_partition_arg,
|
|
1558
1570
|
trial_topology_arg,
|
|
1559
1571
|
trial.space.space_arg_value(device),
|
|
1560
|
-
local_result
|
|
1572
|
+
local_result,
|
|
1561
1573
|
triplet_rows,
|
|
1562
1574
|
triplet_cols,
|
|
1563
1575
|
triplet_values,
|
|
@@ -1626,20 +1638,12 @@ def _launch_integrate_kernel(
|
|
|
1626
1638
|
|
|
1627
1639
|
|
|
1628
1640
|
def _pick_assembly_strategy(
|
|
1629
|
-
assembly: Optional[str],
|
|
1641
|
+
assembly: Optional[str], operators: Dict[str, Set[Operator]], arguments: IntegrandArguments
|
|
1630
1642
|
):
|
|
1631
1643
|
if assembly is not None:
|
|
1632
1644
|
if assembly not in ("generic", "nodal", "dispatch"):
|
|
1633
1645
|
raise ValueError(f"Invalid assembly strategy'{assembly}'")
|
|
1634
1646
|
return assembly
|
|
1635
|
-
elif nodal is not None:
|
|
1636
|
-
wp.utils.warn(
|
|
1637
|
-
"'nodal' argument of `warp.fem.integrate` is deprecated and will be removed in a future version. Please use `assembly='nodal'` instead.",
|
|
1638
|
-
category=DeprecationWarning,
|
|
1639
|
-
stacklevel=2,
|
|
1640
|
-
)
|
|
1641
|
-
if nodal:
|
|
1642
|
-
return "nodal"
|
|
1643
1647
|
|
|
1644
1648
|
test_operators = operators.get(arguments.test_name, set())
|
|
1645
1649
|
trial_operators = operators.get(arguments.trial_name, set())
|
|
@@ -1655,7 +1659,6 @@ def integrate(
|
|
|
1655
1659
|
integrand: Integrand,
|
|
1656
1660
|
domain: Optional[GeometryDomain] = None,
|
|
1657
1661
|
quadrature: Optional[Quadrature] = None,
|
|
1658
|
-
nodal: Optional[bool] = None,
|
|
1659
1662
|
fields: Optional[Dict[str, FieldLike]] = None,
|
|
1660
1663
|
values: Optional[Dict[str, Any]] = None,
|
|
1661
1664
|
accumulate_dtype: type = wp.float64,
|
|
@@ -1675,7 +1678,6 @@ def integrate(
|
|
|
1675
1678
|
integrand: Form to be integrated, must have :func:`integrand` decorator
|
|
1676
1679
|
domain: Integration domain. If None, deduced from fields
|
|
1677
1680
|
quadrature: Quadrature formula. If None, deduced from domain and fields degree.
|
|
1678
|
-
nodal: Deprecated. Use the equivalent assembly="nodal" instead.
|
|
1679
1681
|
fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
|
|
1680
1682
|
values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
|
|
1681
1683
|
temporary_store: shared pool from which to allocate temporary arrays
|
|
@@ -1738,9 +1740,9 @@ def integrate(
|
|
|
1738
1740
|
_find_integrand_operators(integrand, arguments.field_args)
|
|
1739
1741
|
|
|
1740
1742
|
if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
|
|
1741
|
-
|
|
1743
|
+
warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
|
|
1742
1744
|
|
|
1743
|
-
assembly = _pick_assembly_strategy(assembly,
|
|
1745
|
+
assembly = _pick_assembly_strategy(assembly, arguments=arguments, operators=integrand.operators)
|
|
1744
1746
|
# print("assembly for ", integrand.name, ":", strategy)
|
|
1745
1747
|
|
|
1746
1748
|
if assembly == "dispatch":
|
|
@@ -1770,7 +1772,7 @@ def integrate(
|
|
|
1770
1772
|
raise ValueError("Incompatible integration and quadrature domain")
|
|
1771
1773
|
|
|
1772
1774
|
# Canonicalize types
|
|
1773
|
-
accumulate_dtype =
|
|
1775
|
+
accumulate_dtype = type_to_warp(accumulate_dtype)
|
|
1774
1776
|
if output is not None:
|
|
1775
1777
|
if isinstance(output, BsrMatrix):
|
|
1776
1778
|
output_dtype = output.scalar_type
|
|
@@ -1779,9 +1781,9 @@ def integrate(
|
|
|
1779
1781
|
elif output_dtype is None:
|
|
1780
1782
|
output_dtype = accumulate_dtype
|
|
1781
1783
|
else:
|
|
1782
|
-
output_dtype =
|
|
1784
|
+
output_dtype = type_to_warp(output_dtype)
|
|
1783
1785
|
|
|
1784
|
-
kernel,
|
|
1786
|
+
kernel, field_arg_values, value_struct_values = _generate_integrate_kernel(
|
|
1785
1787
|
integrand=integrand,
|
|
1786
1788
|
domain=domain,
|
|
1787
1789
|
quadrature=quadrature,
|
|
@@ -1806,8 +1808,8 @@ def integrate(
|
|
|
1806
1808
|
integrand=integrand,
|
|
1807
1809
|
kernel=kernel,
|
|
1808
1810
|
auxiliary_kernels=auxiliary_kernels,
|
|
1809
|
-
|
|
1810
|
-
|
|
1811
|
+
field_arg_values=field_arg_values,
|
|
1812
|
+
value_struct_values=value_struct_values,
|
|
1811
1813
|
domain=domain,
|
|
1812
1814
|
quadrature=quadrature,
|
|
1813
1815
|
test=test,
|
|
@@ -1827,14 +1829,14 @@ def integrate(
|
|
|
1827
1829
|
def get_interpolate_to_field_function(
|
|
1828
1830
|
integrand_func: wp.Function,
|
|
1829
1831
|
domain: GeometryDomain,
|
|
1830
|
-
FieldStruct:
|
|
1831
|
-
ValueStruct:
|
|
1832
|
+
FieldStruct: Struct,
|
|
1833
|
+
ValueStruct: Struct,
|
|
1832
1834
|
dest: FieldRestriction,
|
|
1833
1835
|
):
|
|
1834
1836
|
zero_value = type_zero_element(dest.space.dtype)
|
|
1835
1837
|
|
|
1836
1838
|
def interpolate_to_field_fn(
|
|
1837
|
-
|
|
1839
|
+
partition_node_index: int,
|
|
1838
1840
|
domain_arg: domain.ElementArg,
|
|
1839
1841
|
domain_index_arg: domain.ElementIndexArg,
|
|
1840
1842
|
dest_node_arg: dest.space_restriction.NodeArg,
|
|
@@ -1842,7 +1844,6 @@ def get_interpolate_to_field_function(
|
|
|
1842
1844
|
fields: FieldStruct,
|
|
1843
1845
|
values: ValueStruct,
|
|
1844
1846
|
):
|
|
1845
|
-
partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
|
|
1846
1847
|
element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)
|
|
1847
1848
|
|
|
1848
1849
|
test_dof_index = NULL_DOF_INDEX
|
|
@@ -1894,8 +1895,8 @@ def get_interpolate_to_field_function(
|
|
|
1894
1895
|
def get_interpolate_to_field_kernel(
|
|
1895
1896
|
interpolate_to_field_fn: wp.Function,
|
|
1896
1897
|
domain: GeometryDomain,
|
|
1897
|
-
FieldStruct:
|
|
1898
|
-
ValueStruct:
|
|
1898
|
+
FieldStruct: Struct,
|
|
1899
|
+
ValueStruct: Struct,
|
|
1899
1900
|
dest: FieldRestriction,
|
|
1900
1901
|
):
|
|
1901
1902
|
@wp.func
|
|
@@ -1932,13 +1933,15 @@ def get_interpolate_to_field_kernel(
|
|
|
1932
1933
|
):
|
|
1933
1934
|
local_node_index = wp.tid()
|
|
1934
1935
|
|
|
1936
|
+
partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
|
|
1937
|
+
if partition_node_index == NULL_NODE_INDEX:
|
|
1938
|
+
return
|
|
1939
|
+
|
|
1935
1940
|
val_sum, vol_sum = interpolate_to_field_fn(
|
|
1936
1941
|
local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
|
|
1937
1942
|
)
|
|
1938
1943
|
|
|
1939
1944
|
if vol_sum > 0.0:
|
|
1940
|
-
partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
|
|
1941
|
-
|
|
1942
1945
|
# Grab first element containing node; there must be at least one since vol_sum != 0
|
|
1943
1946
|
element_index, node_index_in_element = _find_node_in_element(
|
|
1944
1947
|
domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, partition_node_index
|
|
@@ -1959,8 +1962,8 @@ def get_interpolate_at_quadrature_kernel(
|
|
|
1959
1962
|
integrand_func: wp.Function,
|
|
1960
1963
|
domain: GeometryDomain,
|
|
1961
1964
|
quadrature: Quadrature,
|
|
1962
|
-
FieldStruct:
|
|
1963
|
-
ValueStruct:
|
|
1965
|
+
FieldStruct: Struct,
|
|
1966
|
+
ValueStruct: Struct,
|
|
1964
1967
|
value_type: type,
|
|
1965
1968
|
):
|
|
1966
1969
|
def interpolate_at_quadrature_nonvalued_kernel_fn(
|
|
@@ -1978,6 +1981,8 @@ def get_interpolate_at_quadrature_kernel(
|
|
|
1978
1981
|
return
|
|
1979
1982
|
|
|
1980
1983
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
1984
|
+
if element_index == NULL_ELEMENT_INDEX:
|
|
1985
|
+
return
|
|
1981
1986
|
|
|
1982
1987
|
test_dof_index = NULL_DOF_INDEX
|
|
1983
1988
|
trial_dof_index = NULL_DOF_INDEX
|
|
@@ -2004,6 +2009,8 @@ def get_interpolate_at_quadrature_kernel(
|
|
|
2004
2009
|
return
|
|
2005
2010
|
|
|
2006
2011
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
2012
|
+
if element_index == NULL_ELEMENT_INDEX:
|
|
2013
|
+
return
|
|
2007
2014
|
|
|
2008
2015
|
test_dof_index = NULL_DOF_INDEX
|
|
2009
2016
|
trial_dof_index = NULL_DOF_INDEX
|
|
@@ -2022,8 +2029,8 @@ def get_interpolate_jacobian_at_quadrature_kernel(
|
|
|
2022
2029
|
integrand_func: wp.Function,
|
|
2023
2030
|
domain: GeometryDomain,
|
|
2024
2031
|
quadrature: Quadrature,
|
|
2025
|
-
FieldStruct:
|
|
2026
|
-
ValueStruct:
|
|
2032
|
+
FieldStruct: Struct,
|
|
2033
|
+
ValueStruct: Struct,
|
|
2027
2034
|
trial: TrialField,
|
|
2028
2035
|
value_size: int,
|
|
2029
2036
|
value_type: type,
|
|
@@ -2046,11 +2053,13 @@ def get_interpolate_jacobian_at_quadrature_kernel(
|
|
|
2046
2053
|
):
|
|
2047
2054
|
qp_eval_index, trial_node, trial_dof = wp.tid()
|
|
2048
2055
|
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
2049
|
-
|
|
2050
2056
|
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
2051
2057
|
return
|
|
2052
2058
|
|
|
2053
2059
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
2060
|
+
if element_index == NULL_ELEMENT_INDEX:
|
|
2061
|
+
return
|
|
2062
|
+
|
|
2054
2063
|
if qp >= quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index):
|
|
2055
2064
|
return
|
|
2056
2065
|
|
|
@@ -2090,8 +2099,8 @@ def get_interpolate_jacobian_at_quadrature_kernel(
|
|
|
2090
2099
|
def get_interpolate_free_kernel(
|
|
2091
2100
|
integrand_func: wp.Function,
|
|
2092
2101
|
domain: GeometryDomain,
|
|
2093
|
-
FieldStruct:
|
|
2094
|
-
ValueStruct:
|
|
2102
|
+
FieldStruct: Struct,
|
|
2103
|
+
ValueStruct: Struct,
|
|
2095
2104
|
value_type: type,
|
|
2096
2105
|
):
|
|
2097
2106
|
def interpolate_free_nonvalued_kernel_fn(
|
|
@@ -2144,31 +2153,31 @@ def _generate_interpolate_kernel(
|
|
|
2144
2153
|
arguments: IntegrandArguments,
|
|
2145
2154
|
kernel_options: Optional[Dict[str, Any]] = None,
|
|
2146
2155
|
) -> wp.Kernel:
|
|
2147
|
-
# Generate field struct
|
|
2148
|
-
FieldStruct = _gen_field_struct(arguments.field_args)
|
|
2149
|
-
ValueStruct = cache.get_argument_struct(arguments.value_args)
|
|
2150
|
-
|
|
2151
2156
|
_notify_operator_usage(integrand, arguments.field_args)
|
|
2152
2157
|
|
|
2153
2158
|
# Check if kernel exist in cache
|
|
2154
|
-
field_names =
|
|
2159
|
+
field_names = tuple((k, f.name) for k, f in arguments.field_args.items())
|
|
2155
2160
|
if isinstance(dest, FieldRestriction):
|
|
2156
|
-
kernel_suffix =
|
|
2161
|
+
kernel_suffix = ("itp", *field_names, dest.domain.name, dest.space_restriction.space_partition.name)
|
|
2157
2162
|
else:
|
|
2158
2163
|
dest_dtype = dest.dtype if dest else None
|
|
2159
|
-
type_str =
|
|
2164
|
+
type_str = cache.pod_type_key(dest_dtype) if dest_dtype else ""
|
|
2160
2165
|
if quadrature is None:
|
|
2161
|
-
kernel_suffix =
|
|
2166
|
+
kernel_suffix = ("itp", *field_names, domain.name, type_str)
|
|
2162
2167
|
else:
|
|
2163
|
-
kernel_suffix =
|
|
2168
|
+
kernel_suffix = ("itp", *field_names, domain.name, quadrature.name, type_str)
|
|
2164
2169
|
|
|
2165
|
-
kernel = cache.get_integrand_kernel(
|
|
2170
|
+
kernel, field_arg_values, value_struct_values = cache.get_integrand_kernel(
|
|
2166
2171
|
integrand=integrand,
|
|
2167
2172
|
suffix=kernel_suffix,
|
|
2168
2173
|
kernel_options=kernel_options,
|
|
2169
2174
|
)
|
|
2170
2175
|
if kernel is not None:
|
|
2171
|
-
return kernel,
|
|
2176
|
+
return kernel, field_arg_values, value_struct_values
|
|
2177
|
+
|
|
2178
|
+
# Generate field struct
|
|
2179
|
+
FieldStruct = _gen_field_struct(arguments.field_args)
|
|
2180
|
+
ValueStruct = cache.get_argument_struct(arguments.value_args)
|
|
2172
2181
|
|
|
2173
2182
|
# Not found in cache, transform integrand and generate kernel
|
|
2174
2183
|
_check_field_compat(integrand, arguments, domain)
|
|
@@ -2235,7 +2244,7 @@ def _generate_interpolate_kernel(
|
|
|
2235
2244
|
ValueStruct=ValueStruct,
|
|
2236
2245
|
)
|
|
2237
2246
|
|
|
2238
|
-
kernel = cache.get_integrand_kernel(
|
|
2247
|
+
kernel, _FieldStruct, _ValueStruct = cache.get_integrand_kernel(
|
|
2239
2248
|
integrand=integrand,
|
|
2240
2249
|
kernel_fn=interpolate_kernel_fn,
|
|
2241
2250
|
suffix=kernel_suffix,
|
|
@@ -2245,16 +2254,18 @@ def _generate_interpolate_kernel(
|
|
|
2245
2254
|
arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
|
|
2246
2255
|
)
|
|
2247
2256
|
],
|
|
2257
|
+
FieldStruct=FieldStruct,
|
|
2258
|
+
ValueStruct=ValueStruct,
|
|
2248
2259
|
)
|
|
2249
2260
|
|
|
2250
|
-
return kernel, FieldStruct, ValueStruct
|
|
2261
|
+
return kernel, FieldStruct(), ValueStruct()
|
|
2251
2262
|
|
|
2252
2263
|
|
|
2253
2264
|
def _launch_interpolate_kernel(
|
|
2254
2265
|
integrand: Integrand,
|
|
2255
2266
|
kernel: wp.kernel,
|
|
2256
|
-
|
|
2257
|
-
|
|
2267
|
+
field_arg_values: StructInstance,
|
|
2268
|
+
value_struct_values: StructInstance,
|
|
2258
2269
|
domain: GeometryDomain,
|
|
2259
2270
|
dest: Optional[Union[FieldRestriction, wp.array]],
|
|
2260
2271
|
quadrature: Optional[Quadrature],
|
|
@@ -2270,12 +2281,10 @@ def _launch_interpolate_kernel(
|
|
|
2270
2281
|
elt_arg = domain.element_arg_value(device=device)
|
|
2271
2282
|
elt_index_arg = domain.element_index_arg_value(device=device)
|
|
2272
2283
|
|
|
2273
|
-
field_arg_values = FieldStruct()
|
|
2274
2284
|
for k, v in fields.items():
|
|
2275
2285
|
if not isinstance(v, GeometryDomain):
|
|
2276
2286
|
v.fill_eval_arg(getattr(field_arg_values, k), device=device)
|
|
2277
|
-
|
|
2278
|
-
value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
|
|
2287
|
+
cache.populate_argument_struct(value_struct_values, values, func_name=integrand.name)
|
|
2279
2288
|
|
|
2280
2289
|
if isinstance(dest, FieldRestriction):
|
|
2281
2290
|
dest_node_arg = dest.space_restriction.node_arg_value(device=device)
|
|
@@ -2313,7 +2322,7 @@ def _launch_interpolate_kernel(
|
|
|
2313
2322
|
qp_index_count = quadrature.total_point_count()
|
|
2314
2323
|
|
|
2315
2324
|
if qp_eval_count != qp_index_count:
|
|
2316
|
-
|
|
2325
|
+
warn(
|
|
2317
2326
|
f"Quadrature used for interpolation of {integrand.name} has different number of evaluation and indexed points, this may lead to incorrect results",
|
|
2318
2327
|
category=UserWarning,
|
|
2319
2328
|
stacklevel=2,
|
|
@@ -2353,7 +2362,6 @@ def _launch_interpolate_kernel(
|
|
|
2353
2362
|
triplet_rows = triplet_rows_temp.array
|
|
2354
2363
|
triplet_values = triplet_values_temp.array
|
|
2355
2364
|
triplet_rows.fill_(-1)
|
|
2356
|
-
triplet_values.zero_()
|
|
2357
2365
|
|
|
2358
2366
|
trial_partition_arg = trial.space_partition.partition_arg_value(device)
|
|
2359
2367
|
trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
|
|
@@ -2470,9 +2478,9 @@ def interpolate(
|
|
|
2470
2478
|
_find_integrand_operators(integrand, arguments.field_args)
|
|
2471
2479
|
|
|
2472
2480
|
if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
|
|
2473
|
-
|
|
2481
|
+
warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
|
|
2474
2482
|
|
|
2475
|
-
kernel,
|
|
2483
|
+
kernel, field_struct, value_struct = _generate_interpolate_kernel(
|
|
2476
2484
|
integrand=integrand,
|
|
2477
2485
|
domain=domain,
|
|
2478
2486
|
dest=dest,
|
|
@@ -2484,8 +2492,8 @@ def interpolate(
|
|
|
2484
2492
|
return _launch_interpolate_kernel(
|
|
2485
2493
|
integrand=integrand,
|
|
2486
2494
|
kernel=kernel,
|
|
2487
|
-
|
|
2488
|
-
|
|
2495
|
+
field_arg_values=field_struct,
|
|
2496
|
+
value_struct_values=value_struct,
|
|
2489
2497
|
domain=domain,
|
|
2490
2498
|
dest=dest,
|
|
2491
2499
|
quadrature=quadrature,
|