warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +794 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1382 -377
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +253 -171
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +527 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -19,11 +19,11 @@ import textwrap
|
|
|
19
19
|
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
|
|
20
20
|
|
|
21
21
|
import warp as wp
|
|
22
|
-
import warp.fem.operator as operator
|
|
23
|
-
from warp.codegen import get_annotations
|
|
24
|
-
from warp.fem import cache
|
|
25
|
-
from warp.fem.domain import GeometryDomain
|
|
26
|
-
from warp.fem.field import (
|
|
22
|
+
import warp._src.fem.operator as operator
|
|
23
|
+
from warp._src.codegen import Struct, StructInstance, get_annotations
|
|
24
|
+
from warp._src.fem import cache
|
|
25
|
+
from warp._src.fem.domain import GeometryDomain
|
|
26
|
+
from warp._src.fem.field import (
|
|
27
27
|
DiscreteField,
|
|
28
28
|
FieldLike,
|
|
29
29
|
FieldRestriction,
|
|
@@ -34,18 +34,18 @@ from warp.fem.field import (
|
|
|
34
34
|
TrialField,
|
|
35
35
|
make_restriction,
|
|
36
36
|
)
|
|
37
|
-
from warp.fem.field.virtual import (
|
|
37
|
+
from warp._src.fem.field.virtual import (
|
|
38
38
|
make_bilinear_dispatch_kernel,
|
|
39
39
|
make_linear_dispatch_kernel,
|
|
40
40
|
)
|
|
41
|
-
from warp.fem.linalg import array_axpy, basis_coefficient
|
|
42
|
-
from warp.fem.operator import (
|
|
41
|
+
from warp._src.fem.linalg import array_axpy, basis_coefficient
|
|
42
|
+
from warp._src.fem.operator import (
|
|
43
43
|
Integrand,
|
|
44
44
|
Operator,
|
|
45
45
|
integrand,
|
|
46
46
|
)
|
|
47
|
-
from warp.fem.quadrature import Quadrature, RegularQuadrature
|
|
48
|
-
from warp.fem.types import (
|
|
47
|
+
from warp._src.fem.quadrature import Quadrature, RegularQuadrature
|
|
48
|
+
from warp._src.fem.types import (
|
|
49
49
|
NULL_DOF_INDEX,
|
|
50
50
|
NULL_ELEMENT_INDEX,
|
|
51
51
|
NULL_NODE_INDEX,
|
|
@@ -57,15 +57,15 @@ from warp.fem.types import (
|
|
|
57
57
|
Sample,
|
|
58
58
|
make_free_sample,
|
|
59
59
|
)
|
|
60
|
-
from warp.fem.utils import type_zero_element
|
|
61
|
-
from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
|
|
62
|
-
from warp.types import is_array, type_size
|
|
63
|
-
from warp.utils import array_cast
|
|
60
|
+
from warp._src.fem.utils import type_zero_element
|
|
61
|
+
from warp._src.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
|
|
62
|
+
from warp._src.types import is_array, type_repr, type_scalar_type, type_size, type_to_warp
|
|
63
|
+
from warp._src.utils import array_cast, warn
|
|
64
64
|
|
|
65
65
|
|
|
66
66
|
def _resolve_path(func, node):
|
|
67
67
|
"""
|
|
68
|
-
Resolves variable and path from ast node/attribute (adapted from warp.codegen)
|
|
68
|
+
Resolves variable and path from ast node/attribute (adapted from warp._src.codegen)
|
|
69
69
|
"""
|
|
70
70
|
|
|
71
71
|
modules = []
|
|
@@ -83,20 +83,20 @@ def _resolve_path(func, node):
|
|
|
83
83
|
if len(path) == 0:
|
|
84
84
|
return None, path
|
|
85
85
|
|
|
86
|
-
|
|
86
|
+
name = path[0]
|
|
87
87
|
try:
|
|
88
|
-
#
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
88
|
+
# look up in closure variables
|
|
89
|
+
idx = func.__code__.co_freevars.index(name)
|
|
90
|
+
expr = func.__closure__[idx].cell_contents
|
|
91
|
+
except ValueError:
|
|
92
|
+
# look up in global variables
|
|
93
|
+
expr = func.__globals__.get(name)
|
|
94
|
+
|
|
95
|
+
for name in path[1:]:
|
|
96
|
+
if expr is not None:
|
|
97
|
+
expr = getattr(expr, name, None)
|
|
98
98
|
|
|
99
|
-
return
|
|
99
|
+
return expr, path
|
|
100
100
|
|
|
101
101
|
|
|
102
102
|
class IntegrandVisitor(ast.NodeTransformer):
|
|
@@ -275,7 +275,7 @@ class IntegrandTransformer(IntegrandVisitor):
|
|
|
275
275
|
try:
|
|
276
276
|
# Retrieve the function pointer corresponding to the operator implementation for the field type
|
|
277
277
|
pointer = operator.resolver(field)
|
|
278
|
-
if not isinstance(pointer, wp.
|
|
278
|
+
if not isinstance(pointer, wp.Function):
|
|
279
279
|
raise NotImplementedError(operator.resolver.__name__)
|
|
280
280
|
|
|
281
281
|
except (AttributeError, NotImplementedError) as e:
|
|
@@ -360,15 +360,13 @@ def _parse_integrand_arguments(
|
|
|
360
360
|
trial_name = None
|
|
361
361
|
|
|
362
362
|
argspec = integrand.argspec
|
|
363
|
-
for arg in argspec.
|
|
364
|
-
arg_type = argspec.annotations[arg]
|
|
363
|
+
for arg, arg_type in argspec.annotations.items():
|
|
365
364
|
if arg_type == Field:
|
|
366
365
|
try:
|
|
367
366
|
field = fields[arg]
|
|
368
367
|
except KeyError as err:
|
|
369
368
|
raise ValueError(f"Missing field for argument '{arg}' of integrand '{integrand.name}'") from err
|
|
370
|
-
|
|
371
|
-
raise ValueError(f"Passed field argument '{arg}' is not a proper Field")
|
|
369
|
+
|
|
372
370
|
if isinstance(field, TestField):
|
|
373
371
|
if test_name is not None:
|
|
374
372
|
raise ValueError(f"More than one test field argument: '{test_name}' and '{arg}'")
|
|
@@ -377,28 +375,26 @@ def _parse_integrand_arguments(
|
|
|
377
375
|
if trial_name is not None:
|
|
378
376
|
raise ValueError(f"More than one trial field argument: '{trial_name}' and '{arg}'")
|
|
379
377
|
trial_name = arg
|
|
378
|
+
elif not isinstance(field, FieldLike):
|
|
379
|
+
raise ValueError(f"Passed field argument '{arg}' is not a proper Field")
|
|
380
|
+
|
|
380
381
|
field_args[arg] = field
|
|
381
|
-
|
|
382
|
+
continue
|
|
383
|
+
|
|
384
|
+
if arg in fields:
|
|
385
|
+
raise ValueError(
|
|
386
|
+
f"Cannot pass a field argument to '{arg}' of '{integrand.name}' which is not of type 'Field'"
|
|
387
|
+
)
|
|
388
|
+
|
|
389
|
+
if arg_type == Domain:
|
|
382
390
|
if domain_name is not None:
|
|
383
391
|
raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Domain")
|
|
384
|
-
if arg in fields:
|
|
385
|
-
raise ValueError(
|
|
386
|
-
f"Domain argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
|
|
387
|
-
)
|
|
388
392
|
domain_name = arg
|
|
389
393
|
elif arg_type == Sample:
|
|
390
394
|
if sample_name is not None:
|
|
391
395
|
raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Sample")
|
|
392
|
-
if arg in fields:
|
|
393
|
-
raise ValueError(
|
|
394
|
-
f"Sample argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
|
|
395
|
-
)
|
|
396
396
|
sample_name = arg
|
|
397
397
|
else:
|
|
398
|
-
if arg in fields:
|
|
399
|
-
raise ValueError(
|
|
400
|
-
f"Cannot pass a field argument to '{arg}' of '{integrand.name}' with is not of type 'Field'"
|
|
401
|
-
)
|
|
402
398
|
value_args[arg] = arg_type
|
|
403
399
|
|
|
404
400
|
return IntegrandArguments(field_args, value_args, domain_name, sample_name, test_name, trial_name)
|
|
@@ -438,10 +434,8 @@ def _notify_operator_usage(
|
|
|
438
434
|
integrand: Integrand,
|
|
439
435
|
field_args: Dict[str, FieldLike],
|
|
440
436
|
):
|
|
441
|
-
for arg,
|
|
442
|
-
|
|
443
|
-
# print(f"{arg} {field_args[arg].name} : {', '.join(op.name for op in field_ops)}")
|
|
444
|
-
field_args[arg].notify_operator_usage(field_ops)
|
|
437
|
+
for arg, field in field_args.items():
|
|
438
|
+
field.notify_operator_usage(integrand.operators.get(arg, set()))
|
|
445
439
|
|
|
446
440
|
|
|
447
441
|
def _gen_field_struct(field_args: Dict[str, FieldLike]):
|
|
@@ -615,8 +609,8 @@ def get_integrate_constant_kernel(
|
|
|
615
609
|
integrand_func: wp.Function,
|
|
616
610
|
domain: GeometryDomain,
|
|
617
611
|
quadrature: Quadrature,
|
|
618
|
-
FieldStruct:
|
|
619
|
-
ValueStruct:
|
|
612
|
+
FieldStruct: Struct,
|
|
613
|
+
ValueStruct: Struct,
|
|
620
614
|
accumulate_dtype,
|
|
621
615
|
tile_size: int = _INTEGRATE_CONSTANT_TILE_SIZE,
|
|
622
616
|
):
|
|
@@ -641,10 +635,13 @@ def get_integrate_constant_kernel(
|
|
|
641
635
|
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
642
636
|
|
|
643
637
|
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
644
|
-
|
|
638
|
+
element_index = NULL_ELEMENT_INDEX
|
|
645
639
|
else:
|
|
646
640
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
647
641
|
|
|
642
|
+
if element_index == NULL_ELEMENT_INDEX:
|
|
643
|
+
val = zero_element()
|
|
644
|
+
else:
|
|
648
645
|
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
649
646
|
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
650
647
|
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
@@ -667,8 +664,8 @@ def get_integrate_linear_kernel(
|
|
|
667
664
|
integrand_func: wp.Function,
|
|
668
665
|
domain: GeometryDomain,
|
|
669
666
|
quadrature: Quadrature,
|
|
670
|
-
FieldStruct:
|
|
671
|
-
ValueStruct:
|
|
667
|
+
FieldStruct: Struct,
|
|
668
|
+
ValueStruct: Struct,
|
|
672
669
|
test: TestField,
|
|
673
670
|
output_dtype,
|
|
674
671
|
accumulate_dtype,
|
|
@@ -684,6 +681,9 @@ def get_integrate_linear_kernel(
|
|
|
684
681
|
):
|
|
685
682
|
local_node_index, test_dof = wp.tid()
|
|
686
683
|
node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
|
|
684
|
+
if node_index == NULL_NODE_INDEX:
|
|
685
|
+
return
|
|
686
|
+
|
|
687
687
|
element_beg, element_end = test.space_restriction.node_element_range(test_arg, node_index)
|
|
688
688
|
|
|
689
689
|
trial_dof_index = NULL_DOF_INDEX
|
|
@@ -725,8 +725,8 @@ def get_integrate_linear_kernel(
|
|
|
725
725
|
def get_integrate_linear_nodal_kernel(
|
|
726
726
|
integrand_func: wp.Function,
|
|
727
727
|
domain: GeometryDomain,
|
|
728
|
-
FieldStruct:
|
|
729
|
-
ValueStruct:
|
|
728
|
+
FieldStruct: Struct,
|
|
729
|
+
ValueStruct: Struct,
|
|
730
730
|
test: TestField,
|
|
731
731
|
output_dtype,
|
|
732
732
|
accumulate_dtype,
|
|
@@ -743,6 +743,9 @@ def get_integrate_linear_nodal_kernel(
|
|
|
743
743
|
local_node_index, dof = wp.tid()
|
|
744
744
|
|
|
745
745
|
partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
|
|
746
|
+
if partition_node_index == NULL_NODE_INDEX:
|
|
747
|
+
return
|
|
748
|
+
|
|
746
749
|
element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)
|
|
747
750
|
|
|
748
751
|
trial_dof_index = NULL_DOF_INDEX
|
|
@@ -797,8 +800,8 @@ def get_integrate_linear_local_kernel(
|
|
|
797
800
|
integrand_func: wp.Function,
|
|
798
801
|
domain: GeometryDomain,
|
|
799
802
|
quadrature: Quadrature,
|
|
800
|
-
FieldStruct:
|
|
801
|
-
ValueStruct:
|
|
803
|
+
FieldStruct: Struct,
|
|
804
|
+
ValueStruct: Struct,
|
|
802
805
|
test: LocalTestField,
|
|
803
806
|
):
|
|
804
807
|
def integrate_kernel_fn(
|
|
@@ -817,6 +820,8 @@ def get_integrate_linear_local_kernel(
|
|
|
817
820
|
return
|
|
818
821
|
|
|
819
822
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
823
|
+
if element_index == NULL_ELEMENT_INDEX:
|
|
824
|
+
return
|
|
820
825
|
|
|
821
826
|
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
822
827
|
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
@@ -838,8 +843,8 @@ def get_integrate_bilinear_kernel(
|
|
|
838
843
|
integrand_func: wp.Function,
|
|
839
844
|
domain: GeometryDomain,
|
|
840
845
|
quadrature: Quadrature,
|
|
841
|
-
FieldStruct:
|
|
842
|
-
ValueStruct:
|
|
846
|
+
FieldStruct: Struct,
|
|
847
|
+
ValueStruct: Struct,
|
|
843
848
|
test: TestField,
|
|
844
849
|
trial: TrialField,
|
|
845
850
|
output_dtype,
|
|
@@ -863,6 +868,9 @@ def get_integrate_bilinear_kernel(
|
|
|
863
868
|
test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()
|
|
864
869
|
|
|
865
870
|
test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
|
|
871
|
+
if test_node_index == NULL_NODE_INDEX:
|
|
872
|
+
return
|
|
873
|
+
|
|
866
874
|
element_beg, element_end = test.space_restriction.node_element_range(test_arg, test_node_index)
|
|
867
875
|
|
|
868
876
|
trial_dof_index = DofIndex(trial_node, trial_dof)
|
|
@@ -934,8 +942,8 @@ def get_integrate_bilinear_kernel(
|
|
|
934
942
|
def get_integrate_bilinear_nodal_kernel(
|
|
935
943
|
integrand_func: wp.Function,
|
|
936
944
|
domain: GeometryDomain,
|
|
937
|
-
FieldStruct:
|
|
938
|
-
ValueStruct:
|
|
945
|
+
FieldStruct: Struct,
|
|
946
|
+
ValueStruct: Struct,
|
|
939
947
|
test: TestField,
|
|
940
948
|
output_dtype,
|
|
941
949
|
accumulate_dtype,
|
|
@@ -954,6 +962,11 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
954
962
|
local_node_index, test_dof, trial_dof = wp.tid()
|
|
955
963
|
|
|
956
964
|
partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
|
|
965
|
+
if partition_node_index == NULL_NODE_INDEX:
|
|
966
|
+
triplet_rows[local_node_index] = -1
|
|
967
|
+
triplet_cols[local_node_index] = -1
|
|
968
|
+
return
|
|
969
|
+
|
|
957
970
|
element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)
|
|
958
971
|
|
|
959
972
|
val_sum = accumulate_dtype(0.0)
|
|
@@ -1009,8 +1022,8 @@ def get_integrate_bilinear_local_kernel(
|
|
|
1009
1022
|
integrand_func: wp.Function,
|
|
1010
1023
|
domain: GeometryDomain,
|
|
1011
1024
|
quadrature: Quadrature,
|
|
1012
|
-
FieldStruct:
|
|
1013
|
-
ValueStruct:
|
|
1025
|
+
FieldStruct: Struct,
|
|
1026
|
+
ValueStruct: Struct,
|
|
1014
1027
|
test: LocalTestField,
|
|
1015
1028
|
trial: LocalTrialField,
|
|
1016
1029
|
):
|
|
@@ -1033,6 +1046,8 @@ def get_integrate_bilinear_local_kernel(
|
|
|
1033
1046
|
return
|
|
1034
1047
|
|
|
1035
1048
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
1049
|
+
if element_index == NULL_ELEMENT_INDEX:
|
|
1050
|
+
return
|
|
1036
1051
|
|
|
1037
1052
|
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1038
1053
|
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
@@ -1066,23 +1081,27 @@ def _generate_integrate_kernel(
|
|
|
1066
1081
|
accumulate_dtype: type,
|
|
1067
1082
|
kernel_options: Optional[Dict[str, Any]] = None,
|
|
1068
1083
|
) -> wp.Kernel:
|
|
1069
|
-
output_dtype =
|
|
1070
|
-
|
|
1071
|
-
FieldStruct = _gen_field_struct(arguments.field_args)
|
|
1072
|
-
ValueStruct = cache.get_argument_struct(arguments.value_args)
|
|
1084
|
+
output_dtype = type_scalar_type(output_dtype)
|
|
1073
1085
|
|
|
1074
1086
|
_notify_operator_usage(integrand, arguments.field_args)
|
|
1075
1087
|
|
|
1076
1088
|
# Check if kernel exist in cache
|
|
1077
|
-
field_names =
|
|
1078
|
-
kernel_suffix =
|
|
1089
|
+
field_names = tuple((k, f.name) for k, f in arguments.field_args.items())
|
|
1090
|
+
kernel_suffix = ("itg", field_names, cache.pod_type_key(output_dtype), cache.pod_type_key(accumulate_dtype))
|
|
1079
1091
|
|
|
1080
1092
|
if quadrature is not None:
|
|
1081
|
-
kernel_suffix
|
|
1093
|
+
kernel_suffix = (quadrature.name, *kernel_suffix)
|
|
1082
1094
|
|
|
1083
|
-
kernel = cache.get_integrand_kernel(
|
|
1095
|
+
kernel, field_arg_values, value_struct_values = cache.get_integrand_kernel(
|
|
1096
|
+
integrand=integrand,
|
|
1097
|
+
suffix=kernel_suffix,
|
|
1098
|
+
kernel_options=kernel_options,
|
|
1099
|
+
)
|
|
1084
1100
|
if kernel is not None:
|
|
1085
|
-
return kernel,
|
|
1101
|
+
return kernel, field_arg_values, value_struct_values
|
|
1102
|
+
|
|
1103
|
+
FieldStruct = _gen_field_struct(arguments.field_args)
|
|
1104
|
+
ValueStruct = cache.get_argument_struct(arguments.value_args)
|
|
1086
1105
|
|
|
1087
1106
|
# Not found in cache, transform integrand and generate kernel
|
|
1088
1107
|
_check_field_compat(integrand, arguments, domain)
|
|
@@ -1165,7 +1184,7 @@ def _generate_integrate_kernel(
|
|
|
1165
1184
|
accumulate_dtype=accumulate_dtype,
|
|
1166
1185
|
)
|
|
1167
1186
|
|
|
1168
|
-
kernel = cache.get_integrand_kernel(
|
|
1187
|
+
kernel, _FieldStruct, _ValueStruct = cache.get_integrand_kernel(
|
|
1169
1188
|
integrand=integrand,
|
|
1170
1189
|
kernel_fn=integrate_kernel_fn,
|
|
1171
1190
|
suffix=kernel_suffix,
|
|
@@ -1175,9 +1194,11 @@ def _generate_integrate_kernel(
|
|
|
1175
1194
|
arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
|
|
1176
1195
|
)
|
|
1177
1196
|
],
|
|
1197
|
+
FieldStruct=FieldStruct,
|
|
1198
|
+
ValueStruct=ValueStruct,
|
|
1178
1199
|
)
|
|
1179
1200
|
|
|
1180
|
-
return kernel, FieldStruct, ValueStruct
|
|
1201
|
+
return kernel, FieldStruct(), ValueStruct()
|
|
1181
1202
|
|
|
1182
1203
|
|
|
1183
1204
|
def _generate_auxiliary_kernels(
|
|
@@ -1220,8 +1241,8 @@ def _launch_integrate_kernel(
|
|
|
1220
1241
|
integrand: Integrand,
|
|
1221
1242
|
kernel: wp.Kernel,
|
|
1222
1243
|
auxiliary_kernels: List[Tuple[wp.Kernel, int]],
|
|
1223
|
-
|
|
1224
|
-
|
|
1244
|
+
field_arg_values: StructInstance,
|
|
1245
|
+
value_struct_values: StructInstance,
|
|
1225
1246
|
domain: GeometryDomain,
|
|
1226
1247
|
quadrature: Quadrature,
|
|
1227
1248
|
test: Optional[TestField],
|
|
@@ -1243,12 +1264,11 @@ def _launch_integrate_kernel(
|
|
|
1243
1264
|
if quadrature is not None:
|
|
1244
1265
|
qp_arg = quadrature.arg_value(device=device)
|
|
1245
1266
|
|
|
1246
|
-
field_arg_values = FieldStruct()
|
|
1247
1267
|
for k, v in fields.items():
|
|
1248
1268
|
if not isinstance(v, GeometryDomain):
|
|
1249
1269
|
v.fill_eval_arg(getattr(field_arg_values, k), device=device)
|
|
1250
1270
|
|
|
1251
|
-
|
|
1271
|
+
cache.populate_argument_struct(value_struct_values, values, func_name=integrand.name)
|
|
1252
1272
|
|
|
1253
1273
|
# Constant form
|
|
1254
1274
|
if test is None and trial is None:
|
|
@@ -1257,14 +1277,13 @@ def _launch_integrate_kernel(
|
|
|
1257
1277
|
raise RuntimeError("Output array must be of size at least 1")
|
|
1258
1278
|
accumulate_array = output
|
|
1259
1279
|
else:
|
|
1260
|
-
|
|
1280
|
+
accumulate_array = cache.borrow_temporary(
|
|
1261
1281
|
shape=(1),
|
|
1262
1282
|
device=device,
|
|
1263
1283
|
dtype=accumulate_dtype,
|
|
1264
1284
|
temporary_store=temporary_store,
|
|
1265
1285
|
requires_grad=output is not None and output.requires_grad,
|
|
1266
1286
|
)
|
|
1267
|
-
accumulate_array = accumulate_temporary.array
|
|
1268
1287
|
|
|
1269
1288
|
if output != accumulate_array or not add_to_output:
|
|
1270
1289
|
accumulate_array.zero_()
|
|
@@ -1315,21 +1334,17 @@ def _launch_integrate_kernel(
|
|
|
1315
1334
|
output_shape = (test.space_partition.node_count(), test.node_dof_count)
|
|
1316
1335
|
else:
|
|
1317
1336
|
raise RuntimeError(
|
|
1318
|
-
f"Incompatible output type {
|
|
1337
|
+
f"Incompatible output type {type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
|
|
1319
1338
|
)
|
|
1320
1339
|
|
|
1321
|
-
|
|
1340
|
+
output = cache.borrow_temporary(
|
|
1322
1341
|
temporary_store=temporary_store,
|
|
1323
1342
|
shape=output_shape,
|
|
1324
1343
|
dtype=output_dtype,
|
|
1325
1344
|
device=device,
|
|
1326
1345
|
)
|
|
1327
1346
|
|
|
1328
|
-
output = output_temporary.array
|
|
1329
|
-
|
|
1330
1347
|
else:
|
|
1331
|
-
output_temporary = None
|
|
1332
|
-
|
|
1333
1348
|
if output.shape[0] < test.space_partition.node_count():
|
|
1334
1349
|
raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
|
|
1335
1350
|
|
|
@@ -1337,7 +1352,7 @@ def _launch_integrate_kernel(
|
|
|
1337
1352
|
if type_size(output_dtype) != test.node_dof_count:
|
|
1338
1353
|
if type_size(output_dtype) != 1:
|
|
1339
1354
|
raise RuntimeError(
|
|
1340
|
-
f"Incompatible output type {
|
|
1355
|
+
f"Incompatible output type {type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
|
|
1341
1356
|
)
|
|
1342
1357
|
if output.ndim != 2 and output.shape[1] != test.node_dof_count:
|
|
1343
1358
|
raise RuntimeError(
|
|
@@ -1355,7 +1370,7 @@ def _launch_integrate_kernel(
|
|
|
1355
1370
|
capacity=array.capacity,
|
|
1356
1371
|
device=array.device,
|
|
1357
1372
|
shape=(test.space_partition.node_count(), test.node_dof_count),
|
|
1358
|
-
dtype=
|
|
1373
|
+
dtype=type_scalar_type(output_dtype),
|
|
1359
1374
|
grad=None if array.grad is None else as_2d_array(array.grad),
|
|
1360
1375
|
)
|
|
1361
1376
|
|
|
@@ -1387,7 +1402,7 @@ def _launch_integrate_kernel(
|
|
|
1387
1402
|
|
|
1388
1403
|
wp.launch(
|
|
1389
1404
|
kernel=kernel,
|
|
1390
|
-
dim=local_result.
|
|
1405
|
+
dim=local_result.shape,
|
|
1391
1406
|
inputs=[
|
|
1392
1407
|
qp_arg,
|
|
1393
1408
|
quadrature.element_index_arg_value(device),
|
|
@@ -1395,13 +1410,13 @@ def _launch_integrate_kernel(
|
|
|
1395
1410
|
domain_elt_index_arg,
|
|
1396
1411
|
field_arg_values,
|
|
1397
1412
|
value_struct_values,
|
|
1398
|
-
local_result
|
|
1413
|
+
local_result,
|
|
1399
1414
|
],
|
|
1400
1415
|
device=device,
|
|
1401
1416
|
)
|
|
1402
1417
|
|
|
1403
1418
|
if test.TAYLOR_DOF_COUNT == 0:
|
|
1404
|
-
|
|
1419
|
+
warn(
|
|
1405
1420
|
f"Test field is never evaluated in integrand '{integrand.name}', result will be zero",
|
|
1406
1421
|
category=UserWarning,
|
|
1407
1422
|
stacklevel=2,
|
|
@@ -1418,7 +1433,7 @@ def _launch_integrate_kernel(
|
|
|
1418
1433
|
domain_elt_index_arg,
|
|
1419
1434
|
test_arg,
|
|
1420
1435
|
test.space.space_arg_value(device),
|
|
1421
|
-
local_result
|
|
1436
|
+
local_result,
|
|
1422
1437
|
output_view,
|
|
1423
1438
|
],
|
|
1424
1439
|
device=device,
|
|
@@ -1442,9 +1457,6 @@ def _launch_integrate_kernel(
|
|
|
1442
1457
|
device=device,
|
|
1443
1458
|
)
|
|
1444
1459
|
|
|
1445
|
-
if output_temporary is not None:
|
|
1446
|
-
return output_temporary.detach()
|
|
1447
|
-
|
|
1448
1460
|
return output
|
|
1449
1461
|
|
|
1450
1462
|
# Bilinear form
|
|
@@ -1475,8 +1487,6 @@ def _launch_integrate_kernel(
|
|
|
1475
1487
|
triplet_rows = triplet_rows_temp.array
|
|
1476
1488
|
triplet_values = triplet_values_temp.array
|
|
1477
1489
|
|
|
1478
|
-
triplet_values.zero_()
|
|
1479
|
-
|
|
1480
1490
|
if nodal:
|
|
1481
1491
|
wp.launch(
|
|
1482
1492
|
kernel=kernel,
|
|
@@ -1524,13 +1534,13 @@ def _launch_integrate_kernel(
|
|
|
1524
1534
|
domain_elt_index_arg,
|
|
1525
1535
|
field_arg_values,
|
|
1526
1536
|
value_struct_values,
|
|
1527
|
-
local_result
|
|
1537
|
+
local_result,
|
|
1528
1538
|
],
|
|
1529
1539
|
device=device,
|
|
1530
1540
|
)
|
|
1531
1541
|
|
|
1532
1542
|
if test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT == 0:
|
|
1533
|
-
|
|
1543
|
+
warn(
|
|
1534
1544
|
f"Test and/or trial fields are never evaluated in integrand '{integrand.name}', result will be zero",
|
|
1535
1545
|
category=UserWarning,
|
|
1536
1546
|
stacklevel=2,
|
|
@@ -1557,7 +1567,7 @@ def _launch_integrate_kernel(
|
|
|
1557
1567
|
trial_partition_arg,
|
|
1558
1568
|
trial_topology_arg,
|
|
1559
1569
|
trial.space.space_arg_value(device),
|
|
1560
|
-
local_result
|
|
1570
|
+
local_result,
|
|
1561
1571
|
triplet_rows,
|
|
1562
1572
|
triplet_cols,
|
|
1563
1573
|
triplet_values,
|
|
@@ -1626,20 +1636,12 @@ def _launch_integrate_kernel(
|
|
|
1626
1636
|
|
|
1627
1637
|
|
|
1628
1638
|
def _pick_assembly_strategy(
|
|
1629
|
-
assembly: Optional[str],
|
|
1639
|
+
assembly: Optional[str], operators: Dict[str, Set[Operator]], arguments: IntegrandArguments
|
|
1630
1640
|
):
|
|
1631
1641
|
if assembly is not None:
|
|
1632
1642
|
if assembly not in ("generic", "nodal", "dispatch"):
|
|
1633
1643
|
raise ValueError(f"Invalid assembly strategy'{assembly}'")
|
|
1634
1644
|
return assembly
|
|
1635
|
-
elif nodal is not None:
|
|
1636
|
-
wp.utils.warn(
|
|
1637
|
-
"'nodal' argument of `warp.fem.integrate` is deprecated and will be removed in a future version. Please use `assembly='nodal'` instead.",
|
|
1638
|
-
category=DeprecationWarning,
|
|
1639
|
-
stacklevel=2,
|
|
1640
|
-
)
|
|
1641
|
-
if nodal:
|
|
1642
|
-
return "nodal"
|
|
1643
1645
|
|
|
1644
1646
|
test_operators = operators.get(arguments.test_name, set())
|
|
1645
1647
|
trial_operators = operators.get(arguments.trial_name, set())
|
|
@@ -1655,7 +1657,6 @@ def integrate(
|
|
|
1655
1657
|
integrand: Integrand,
|
|
1656
1658
|
domain: Optional[GeometryDomain] = None,
|
|
1657
1659
|
quadrature: Optional[Quadrature] = None,
|
|
1658
|
-
nodal: Optional[bool] = None,
|
|
1659
1660
|
fields: Optional[Dict[str, FieldLike]] = None,
|
|
1660
1661
|
values: Optional[Dict[str, Any]] = None,
|
|
1661
1662
|
accumulate_dtype: type = wp.float64,
|
|
@@ -1675,7 +1676,6 @@ def integrate(
|
|
|
1675
1676
|
integrand: Form to be integrated, must have :func:`integrand` decorator
|
|
1676
1677
|
domain: Integration domain. If None, deduced from fields
|
|
1677
1678
|
quadrature: Quadrature formula. If None, deduced from domain and fields degree.
|
|
1678
|
-
nodal: Deprecated. Use the equivalent assembly="nodal" instead.
|
|
1679
1679
|
fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
|
|
1680
1680
|
values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
|
|
1681
1681
|
temporary_store: shared pool from which to allocate temporary arrays
|
|
@@ -1738,9 +1738,9 @@ def integrate(
|
|
|
1738
1738
|
_find_integrand_operators(integrand, arguments.field_args)
|
|
1739
1739
|
|
|
1740
1740
|
if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
|
|
1741
|
-
|
|
1741
|
+
warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
|
|
1742
1742
|
|
|
1743
|
-
assembly = _pick_assembly_strategy(assembly,
|
|
1743
|
+
assembly = _pick_assembly_strategy(assembly, arguments=arguments, operators=integrand.operators)
|
|
1744
1744
|
# print("assembly for ", integrand.name, ":", strategy)
|
|
1745
1745
|
|
|
1746
1746
|
if assembly == "dispatch":
|
|
@@ -1770,7 +1770,7 @@ def integrate(
|
|
|
1770
1770
|
raise ValueError("Incompatible integration and quadrature domain")
|
|
1771
1771
|
|
|
1772
1772
|
# Canonicalize types
|
|
1773
|
-
accumulate_dtype =
|
|
1773
|
+
accumulate_dtype = type_to_warp(accumulate_dtype)
|
|
1774
1774
|
if output is not None:
|
|
1775
1775
|
if isinstance(output, BsrMatrix):
|
|
1776
1776
|
output_dtype = output.scalar_type
|
|
@@ -1779,9 +1779,9 @@ def integrate(
|
|
|
1779
1779
|
elif output_dtype is None:
|
|
1780
1780
|
output_dtype = accumulate_dtype
|
|
1781
1781
|
else:
|
|
1782
|
-
output_dtype =
|
|
1782
|
+
output_dtype = type_to_warp(output_dtype)
|
|
1783
1783
|
|
|
1784
|
-
kernel,
|
|
1784
|
+
kernel, field_arg_values, value_struct_values = _generate_integrate_kernel(
|
|
1785
1785
|
integrand=integrand,
|
|
1786
1786
|
domain=domain,
|
|
1787
1787
|
quadrature=quadrature,
|
|
@@ -1806,8 +1806,8 @@ def integrate(
|
|
|
1806
1806
|
integrand=integrand,
|
|
1807
1807
|
kernel=kernel,
|
|
1808
1808
|
auxiliary_kernels=auxiliary_kernels,
|
|
1809
|
-
|
|
1810
|
-
|
|
1809
|
+
field_arg_values=field_arg_values,
|
|
1810
|
+
value_struct_values=value_struct_values,
|
|
1811
1811
|
domain=domain,
|
|
1812
1812
|
quadrature=quadrature,
|
|
1813
1813
|
test=test,
|
|
@@ -1827,14 +1827,14 @@ def integrate(
|
|
|
1827
1827
|
def get_interpolate_to_field_function(
|
|
1828
1828
|
integrand_func: wp.Function,
|
|
1829
1829
|
domain: GeometryDomain,
|
|
1830
|
-
FieldStruct:
|
|
1831
|
-
ValueStruct:
|
|
1830
|
+
FieldStruct: Struct,
|
|
1831
|
+
ValueStruct: Struct,
|
|
1832
1832
|
dest: FieldRestriction,
|
|
1833
1833
|
):
|
|
1834
1834
|
zero_value = type_zero_element(dest.space.dtype)
|
|
1835
1835
|
|
|
1836
1836
|
def interpolate_to_field_fn(
|
|
1837
|
-
|
|
1837
|
+
partition_node_index: int,
|
|
1838
1838
|
domain_arg: domain.ElementArg,
|
|
1839
1839
|
domain_index_arg: domain.ElementIndexArg,
|
|
1840
1840
|
dest_node_arg: dest.space_restriction.NodeArg,
|
|
@@ -1842,7 +1842,6 @@ def get_interpolate_to_field_function(
|
|
|
1842
1842
|
fields: FieldStruct,
|
|
1843
1843
|
values: ValueStruct,
|
|
1844
1844
|
):
|
|
1845
|
-
partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
|
|
1846
1845
|
element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)
|
|
1847
1846
|
|
|
1848
1847
|
test_dof_index = NULL_DOF_INDEX
|
|
@@ -1894,8 +1893,8 @@ def get_interpolate_to_field_function(
|
|
|
1894
1893
|
def get_interpolate_to_field_kernel(
|
|
1895
1894
|
interpolate_to_field_fn: wp.Function,
|
|
1896
1895
|
domain: GeometryDomain,
|
|
1897
|
-
FieldStruct:
|
|
1898
|
-
ValueStruct:
|
|
1896
|
+
FieldStruct: Struct,
|
|
1897
|
+
ValueStruct: Struct,
|
|
1899
1898
|
dest: FieldRestriction,
|
|
1900
1899
|
):
|
|
1901
1900
|
@wp.func
|
|
@@ -1932,13 +1931,15 @@ def get_interpolate_to_field_kernel(
|
|
|
1932
1931
|
):
|
|
1933
1932
|
local_node_index = wp.tid()
|
|
1934
1933
|
|
|
1934
|
+
partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
|
|
1935
|
+
if partition_node_index == NULL_NODE_INDEX:
|
|
1936
|
+
return
|
|
1937
|
+
|
|
1935
1938
|
val_sum, vol_sum = interpolate_to_field_fn(
|
|
1936
1939
|
local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
|
|
1937
1940
|
)
|
|
1938
1941
|
|
|
1939
1942
|
if vol_sum > 0.0:
|
|
1940
|
-
partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
|
|
1941
|
-
|
|
1942
1943
|
# Grab first element containing node; there must be at least one since vol_sum != 0
|
|
1943
1944
|
element_index, node_index_in_element = _find_node_in_element(
|
|
1944
1945
|
domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, partition_node_index
|
|
@@ -1959,8 +1960,8 @@ def get_interpolate_at_quadrature_kernel(
|
|
|
1959
1960
|
integrand_func: wp.Function,
|
|
1960
1961
|
domain: GeometryDomain,
|
|
1961
1962
|
quadrature: Quadrature,
|
|
1962
|
-
FieldStruct:
|
|
1963
|
-
ValueStruct:
|
|
1963
|
+
FieldStruct: Struct,
|
|
1964
|
+
ValueStruct: Struct,
|
|
1964
1965
|
value_type: type,
|
|
1965
1966
|
):
|
|
1966
1967
|
def interpolate_at_quadrature_nonvalued_kernel_fn(
|
|
@@ -1978,6 +1979,8 @@ def get_interpolate_at_quadrature_kernel(
|
|
|
1978
1979
|
return
|
|
1979
1980
|
|
|
1980
1981
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
1982
|
+
if element_index == NULL_ELEMENT_INDEX:
|
|
1983
|
+
return
|
|
1981
1984
|
|
|
1982
1985
|
test_dof_index = NULL_DOF_INDEX
|
|
1983
1986
|
trial_dof_index = NULL_DOF_INDEX
|
|
@@ -2004,6 +2007,8 @@ def get_interpolate_at_quadrature_kernel(
|
|
|
2004
2007
|
return
|
|
2005
2008
|
|
|
2006
2009
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
2010
|
+
if element_index == NULL_ELEMENT_INDEX:
|
|
2011
|
+
return
|
|
2007
2012
|
|
|
2008
2013
|
test_dof_index = NULL_DOF_INDEX
|
|
2009
2014
|
trial_dof_index = NULL_DOF_INDEX
|
|
@@ -2022,8 +2027,8 @@ def get_interpolate_jacobian_at_quadrature_kernel(
|
|
|
2022
2027
|
integrand_func: wp.Function,
|
|
2023
2028
|
domain: GeometryDomain,
|
|
2024
2029
|
quadrature: Quadrature,
|
|
2025
|
-
FieldStruct:
|
|
2026
|
-
ValueStruct:
|
|
2030
|
+
FieldStruct: Struct,
|
|
2031
|
+
ValueStruct: Struct,
|
|
2027
2032
|
trial: TrialField,
|
|
2028
2033
|
value_size: int,
|
|
2029
2034
|
value_type: type,
|
|
@@ -2046,11 +2051,13 @@ def get_interpolate_jacobian_at_quadrature_kernel(
|
|
|
2046
2051
|
):
|
|
2047
2052
|
qp_eval_index, trial_node, trial_dof = wp.tid()
|
|
2048
2053
|
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
2049
|
-
|
|
2050
2054
|
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
2051
2055
|
return
|
|
2052
2056
|
|
|
2053
2057
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
2058
|
+
if element_index == NULL_ELEMENT_INDEX:
|
|
2059
|
+
return
|
|
2060
|
+
|
|
2054
2061
|
if qp >= quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index):
|
|
2055
2062
|
return
|
|
2056
2063
|
|
|
@@ -2090,8 +2097,8 @@ def get_interpolate_jacobian_at_quadrature_kernel(
|
|
|
2090
2097
|
def get_interpolate_free_kernel(
|
|
2091
2098
|
integrand_func: wp.Function,
|
|
2092
2099
|
domain: GeometryDomain,
|
|
2093
|
-
FieldStruct:
|
|
2094
|
-
ValueStruct:
|
|
2100
|
+
FieldStruct: Struct,
|
|
2101
|
+
ValueStruct: Struct,
|
|
2095
2102
|
value_type: type,
|
|
2096
2103
|
):
|
|
2097
2104
|
def interpolate_free_nonvalued_kernel_fn(
|
|
@@ -2144,31 +2151,31 @@ def _generate_interpolate_kernel(
|
|
|
2144
2151
|
arguments: IntegrandArguments,
|
|
2145
2152
|
kernel_options: Optional[Dict[str, Any]] = None,
|
|
2146
2153
|
) -> wp.Kernel:
|
|
2147
|
-
# Generate field struct
|
|
2148
|
-
FieldStruct = _gen_field_struct(arguments.field_args)
|
|
2149
|
-
ValueStruct = cache.get_argument_struct(arguments.value_args)
|
|
2150
|
-
|
|
2151
2154
|
_notify_operator_usage(integrand, arguments.field_args)
|
|
2152
2155
|
|
|
2153
2156
|
# Check if kernel exist in cache
|
|
2154
|
-
field_names =
|
|
2157
|
+
field_names = tuple((k, f.name) for k, f in arguments.field_args.items())
|
|
2155
2158
|
if isinstance(dest, FieldRestriction):
|
|
2156
|
-
kernel_suffix =
|
|
2159
|
+
kernel_suffix = ("itp", *field_names, dest.domain.name, dest.space_restriction.space_partition.name)
|
|
2157
2160
|
else:
|
|
2158
2161
|
dest_dtype = dest.dtype if dest else None
|
|
2159
|
-
type_str =
|
|
2162
|
+
type_str = cache.pod_type_key(dest_dtype) if dest_dtype else ""
|
|
2160
2163
|
if quadrature is None:
|
|
2161
|
-
kernel_suffix =
|
|
2164
|
+
kernel_suffix = ("itp", *field_names, domain.name, type_str)
|
|
2162
2165
|
else:
|
|
2163
|
-
kernel_suffix =
|
|
2166
|
+
kernel_suffix = ("itp", *field_names, domain.name, quadrature.name, type_str)
|
|
2164
2167
|
|
|
2165
|
-
kernel = cache.get_integrand_kernel(
|
|
2168
|
+
kernel, field_arg_values, value_struct_values = cache.get_integrand_kernel(
|
|
2166
2169
|
integrand=integrand,
|
|
2167
2170
|
suffix=kernel_suffix,
|
|
2168
2171
|
kernel_options=kernel_options,
|
|
2169
2172
|
)
|
|
2170
2173
|
if kernel is not None:
|
|
2171
|
-
return kernel,
|
|
2174
|
+
return kernel, field_arg_values, value_struct_values
|
|
2175
|
+
|
|
2176
|
+
# Generate field struct
|
|
2177
|
+
FieldStruct = _gen_field_struct(arguments.field_args)
|
|
2178
|
+
ValueStruct = cache.get_argument_struct(arguments.value_args)
|
|
2172
2179
|
|
|
2173
2180
|
# Not found in cache, transform integrand and generate kernel
|
|
2174
2181
|
_check_field_compat(integrand, arguments, domain)
|
|
@@ -2235,7 +2242,7 @@ def _generate_interpolate_kernel(
|
|
|
2235
2242
|
ValueStruct=ValueStruct,
|
|
2236
2243
|
)
|
|
2237
2244
|
|
|
2238
|
-
kernel = cache.get_integrand_kernel(
|
|
2245
|
+
kernel, _FieldStruct, _ValueStruct = cache.get_integrand_kernel(
|
|
2239
2246
|
integrand=integrand,
|
|
2240
2247
|
kernel_fn=interpolate_kernel_fn,
|
|
2241
2248
|
suffix=kernel_suffix,
|
|
@@ -2245,16 +2252,18 @@ def _generate_interpolate_kernel(
|
|
|
2245
2252
|
arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
|
|
2246
2253
|
)
|
|
2247
2254
|
],
|
|
2255
|
+
FieldStruct=FieldStruct,
|
|
2256
|
+
ValueStruct=ValueStruct,
|
|
2248
2257
|
)
|
|
2249
2258
|
|
|
2250
|
-
return kernel, FieldStruct, ValueStruct
|
|
2259
|
+
return kernel, FieldStruct(), ValueStruct()
|
|
2251
2260
|
|
|
2252
2261
|
|
|
2253
2262
|
def _launch_interpolate_kernel(
|
|
2254
2263
|
integrand: Integrand,
|
|
2255
2264
|
kernel: wp.kernel,
|
|
2256
|
-
|
|
2257
|
-
|
|
2265
|
+
field_arg_values: StructInstance,
|
|
2266
|
+
value_struct_values: StructInstance,
|
|
2258
2267
|
domain: GeometryDomain,
|
|
2259
2268
|
dest: Optional[Union[FieldRestriction, wp.array]],
|
|
2260
2269
|
quadrature: Optional[Quadrature],
|
|
@@ -2270,12 +2279,10 @@ def _launch_interpolate_kernel(
|
|
|
2270
2279
|
elt_arg = domain.element_arg_value(device=device)
|
|
2271
2280
|
elt_index_arg = domain.element_index_arg_value(device=device)
|
|
2272
2281
|
|
|
2273
|
-
field_arg_values = FieldStruct()
|
|
2274
2282
|
for k, v in fields.items():
|
|
2275
2283
|
if not isinstance(v, GeometryDomain):
|
|
2276
2284
|
v.fill_eval_arg(getattr(field_arg_values, k), device=device)
|
|
2277
|
-
|
|
2278
|
-
value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
|
|
2285
|
+
cache.populate_argument_struct(value_struct_values, values, func_name=integrand.name)
|
|
2279
2286
|
|
|
2280
2287
|
if isinstance(dest, FieldRestriction):
|
|
2281
2288
|
dest_node_arg = dest.space_restriction.node_arg_value(device=device)
|
|
@@ -2313,7 +2320,7 @@ def _launch_interpolate_kernel(
|
|
|
2313
2320
|
qp_index_count = quadrature.total_point_count()
|
|
2314
2321
|
|
|
2315
2322
|
if qp_eval_count != qp_index_count:
|
|
2316
|
-
|
|
2323
|
+
warn(
|
|
2317
2324
|
f"Quadrature used for interpolation of {integrand.name} has different number of evaluation and indexed points, this may lead to incorrect results",
|
|
2318
2325
|
category=UserWarning,
|
|
2319
2326
|
stacklevel=2,
|
|
@@ -2353,7 +2360,6 @@ def _launch_interpolate_kernel(
|
|
|
2353
2360
|
triplet_rows = triplet_rows_temp.array
|
|
2354
2361
|
triplet_values = triplet_values_temp.array
|
|
2355
2362
|
triplet_rows.fill_(-1)
|
|
2356
|
-
triplet_values.zero_()
|
|
2357
2363
|
|
|
2358
2364
|
trial_partition_arg = trial.space_partition.partition_arg_value(device)
|
|
2359
2365
|
trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
|
|
@@ -2470,9 +2476,9 @@ def interpolate(
|
|
|
2470
2476
|
_find_integrand_operators(integrand, arguments.field_args)
|
|
2471
2477
|
|
|
2472
2478
|
if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
|
|
2473
|
-
|
|
2479
|
+
warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
|
|
2474
2480
|
|
|
2475
|
-
kernel,
|
|
2481
|
+
kernel, field_struct, value_struct = _generate_interpolate_kernel(
|
|
2476
2482
|
integrand=integrand,
|
|
2477
2483
|
domain=domain,
|
|
2478
2484
|
dest=dest,
|
|
@@ -2484,8 +2490,8 @@ def interpolate(
|
|
|
2484
2490
|
return _launch_interpolate_kernel(
|
|
2485
2491
|
integrand=integrand,
|
|
2486
2492
|
kernel=kernel,
|
|
2487
|
-
|
|
2488
|
-
|
|
2493
|
+
field_arg_values=field_struct,
|
|
2494
|
+
value_struct_values=value_struct,
|
|
2489
2495
|
domain=domain,
|
|
2490
2496
|
dest=dest,
|
|
2491
2497
|
quadrature=quadrature,
|