warp-lang 1.4.2__py3-none-manylinux2014_x86_64.whl → 1.5.1__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1819 -7
- warp/codegen.py +197 -61
- warp/config.py +2 -2
- warp/context.py +379 -107
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/sim/example_cloth.py +4 -25
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -7
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +15 -0
- warp/native/builtin.h +66 -26
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +604 -0
- warp/native/cuda_util.cpp +68 -51
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1854 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +140 -67
- warp/sim/graph_coloring.py +292 -0
- warp/sim/import_urdf.py +8 -8
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +109 -32
- warp/sparse.py +1 -1
- warp/stubs.py +569 -4
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +39 -0
- warp/tests/test_codegen.py +81 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +251 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +21 -5
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +34 -4
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_lerp.py +13 -87
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_matmul.py +6 -9
- warp/tests/test_matmul_lite.py +6 -11
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_overwrite.py +45 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +56 -1
- warp/tests/test_smoothstep.py +17 -83
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_static.py +3 -3
- warp/tests/test_tile.py +744 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +19 -2
- warp/tests/unittest_utils.py +4 -2
- warp/types.py +340 -74
- warp/utils.py +23 -3
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import ast
|
|
2
|
-
|
|
2
|
+
import inspect
|
|
3
|
+
import textwrap
|
|
4
|
+
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
|
|
3
5
|
|
|
4
6
|
import warp as wp
|
|
5
7
|
from warp.codegen import get_annotations
|
|
@@ -10,11 +12,15 @@ from warp.fem.field import (
|
|
|
10
12
|
FieldLike,
|
|
11
13
|
FieldRestriction,
|
|
12
14
|
GeometryField,
|
|
15
|
+
LocalTestField,
|
|
16
|
+
LocalTrialField,
|
|
13
17
|
TestField,
|
|
14
18
|
TrialField,
|
|
15
19
|
make_restriction,
|
|
16
20
|
)
|
|
17
|
-
from warp.fem.
|
|
21
|
+
from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
|
|
22
|
+
from warp.fem.linalg import array_axpy
|
|
23
|
+
from warp.fem.operator import Integrand, Operator, at_node, integrand
|
|
18
24
|
from warp.fem.quadrature import Quadrature, RegularQuadrature
|
|
19
25
|
from warp.fem.types import (
|
|
20
26
|
NULL_DOF_INDEX,
|
|
@@ -69,28 +75,58 @@ def _resolve_path(func, node):
|
|
|
69
75
|
return None, path
|
|
70
76
|
|
|
71
77
|
|
|
72
|
-
class
|
|
73
|
-
|
|
78
|
+
class IntegrandVisitor(ast.NodeTransformer):
|
|
79
|
+
class FieldInfo(NamedTuple):
|
|
80
|
+
field: FieldLike
|
|
81
|
+
abstract_type: type
|
|
82
|
+
concrete_type: type
|
|
83
|
+
root_arg_name: type
|
|
84
|
+
|
|
85
|
+
def __init__(
|
|
86
|
+
self,
|
|
87
|
+
integrand: Integrand,
|
|
88
|
+
field_info: Dict[str, FieldInfo],
|
|
89
|
+
):
|
|
74
90
|
self._integrand = integrand
|
|
75
|
-
self.
|
|
76
|
-
self.
|
|
91
|
+
self._field_symbols = field_info.copy()
|
|
92
|
+
self._field_nodes = {}
|
|
93
|
+
|
|
94
|
+
@staticmethod
|
|
95
|
+
def _build_field_info(integrand: Integrand, field_args: Dict[str, FieldLike]):
|
|
96
|
+
def get_concrete_type(field: Union[FieldLike, Domain]):
|
|
97
|
+
if isinstance(field, FieldLike):
|
|
98
|
+
return field.ElementEvalArg
|
|
99
|
+
return field.ElementArg
|
|
100
|
+
|
|
101
|
+
return {
|
|
102
|
+
name: IntegrandVisitor.FieldInfo(
|
|
103
|
+
field=field,
|
|
104
|
+
abstract_type=integrand.argspec.annotations[name],
|
|
105
|
+
concrete_type=get_concrete_type(field),
|
|
106
|
+
root_arg_name=name,
|
|
107
|
+
)
|
|
108
|
+
for name, field in field_args.items()
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
def _get_field_info(self, node: ast.expr):
|
|
112
|
+
field_info = self._field_nodes.get(node)
|
|
113
|
+
if field_info is None and isinstance(node, ast.Name):
|
|
114
|
+
field_info = self._field_symbols.get(node.id)
|
|
115
|
+
|
|
116
|
+
return field_info
|
|
77
117
|
|
|
78
118
|
def visit_Call(self, call: ast.Call):
|
|
79
119
|
call = self.generic_visit(call)
|
|
80
120
|
|
|
81
121
|
callee = getattr(call.func, "id", None)
|
|
82
|
-
if callee in self.
|
|
122
|
+
if callee in self._field_symbols:
|
|
83
123
|
# Shortcut for evaluating fields as f(x...)
|
|
84
|
-
|
|
124
|
+
field_info = self._field_symbols[callee]
|
|
85
125
|
|
|
86
126
|
# Replace with default call operator
|
|
87
|
-
|
|
88
|
-
default_operator = abstract_arg_type.call_operator
|
|
89
|
-
concrete_arg_type = self._annotations[callee]
|
|
90
|
-
self._replace_call_func(call, concrete_arg_type, default_operator, field)
|
|
127
|
+
default_operator = field_info.abstract_type.call_operator
|
|
91
128
|
|
|
92
|
-
|
|
93
|
-
call.args = [ast.Name(id=callee, ctx=ast.Load())] + call.args
|
|
129
|
+
self._process_operator_call(call, callee, default_operator, field_info)
|
|
94
130
|
|
|
95
131
|
return call
|
|
96
132
|
|
|
@@ -98,44 +134,47 @@ class IntegrandTransformer(ast.NodeTransformer):
|
|
|
98
134
|
|
|
99
135
|
if isinstance(func, Operator) and len(call.args) > 0:
|
|
100
136
|
# Evaluating operators as op(field, x, ...)
|
|
101
|
-
|
|
102
|
-
if
|
|
103
|
-
|
|
104
|
-
|
|
137
|
+
field_info = self._get_field_info(call.args[0])
|
|
138
|
+
if field_info is not None:
|
|
139
|
+
self._process_operator_call(call, func, func, field_info)
|
|
140
|
+
|
|
141
|
+
if func.field_result:
|
|
142
|
+
res = func.field_result(field_info.field)
|
|
143
|
+
self._field_nodes[call] = IntegrandVisitor.FieldInfo(
|
|
144
|
+
field=res[0],
|
|
145
|
+
abstract_type=res[1],
|
|
146
|
+
concrete_type=res[2],
|
|
147
|
+
root_arg_name=f"{field_info.root_arg_name}.{func.name}",
|
|
148
|
+
)
|
|
105
149
|
|
|
106
150
|
if isinstance(func, Integrand):
|
|
107
|
-
|
|
108
|
-
call
|
|
109
|
-
value=call.func,
|
|
110
|
-
attr=key,
|
|
111
|
-
ctx=ast.Load(),
|
|
112
|
-
)
|
|
151
|
+
callee_field_args = self._get_callee_field_args(func, call.args)
|
|
152
|
+
self._process_integrand_call(call, func, callee_field_args)
|
|
113
153
|
|
|
114
154
|
# print(ast.dump(call, indent=4))
|
|
115
155
|
|
|
116
156
|
return call
|
|
117
157
|
|
|
118
|
-
def
|
|
119
|
-
|
|
120
|
-
# Retrieve the function pointer corresponding to the operator implementation for the field type
|
|
121
|
-
pointer = operator.resolver(field)
|
|
122
|
-
if pointer is None:
|
|
123
|
-
raise NotImplementedError(operator.resolver.__name__)
|
|
158
|
+
def visit_Assign(self, node: ast.Assign):
|
|
159
|
+
node = self.generic_visit(node)
|
|
124
160
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
161
|
+
# Check if we're assigning a field
|
|
162
|
+
src_field_info = self._get_field_info(node.value)
|
|
163
|
+
if src_field_info is not None:
|
|
164
|
+
if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
|
|
165
|
+
raise NotImplementedError("warp.fem Fields and Domains may only be assigned to simple variables")
|
|
166
|
+
|
|
167
|
+
self._field_symbols[node.targets[0].id] = src_field_info
|
|
168
|
+
|
|
169
|
+
return node
|
|
131
170
|
|
|
132
|
-
def
|
|
171
|
+
def _get_callee_field_args(self, callee: Integrand, args: List[ast.AST]):
|
|
133
172
|
# Get field types for call site arguments
|
|
134
|
-
call_site_field_args = []
|
|
173
|
+
call_site_field_args: List[IntegrandVisitor.FieldInfo] = []
|
|
135
174
|
for arg in args:
|
|
136
|
-
|
|
137
|
-
if
|
|
138
|
-
call_site_field_args.append(
|
|
175
|
+
field_info = self._get_field_info(arg)
|
|
176
|
+
if field_info is not None:
|
|
177
|
+
call_site_field_args.append(field_info)
|
|
139
178
|
|
|
140
179
|
call_site_field_args.reverse()
|
|
141
180
|
|
|
@@ -144,46 +183,129 @@ class IntegrandTransformer(ast.NodeTransformer):
|
|
|
144
183
|
for arg in callee.argspec.args:
|
|
145
184
|
arg_type = callee.argspec.annotations[arg]
|
|
146
185
|
if arg_type in (Field, Domain):
|
|
147
|
-
|
|
186
|
+
passed_field_info = call_site_field_args.pop()
|
|
187
|
+
if passed_field_info.abstract_type != arg_type:
|
|
188
|
+
raise TypeError(
|
|
189
|
+
f"Attempting to pass a {passed_field_info.abstract_type.__name__} to argument '{arg}' of '{callee.name}' expecting a {arg_type.__name__}"
|
|
190
|
+
)
|
|
191
|
+
callee_field_args[arg] = passed_field_info
|
|
148
192
|
|
|
149
|
-
return
|
|
193
|
+
return callee_field_args
|
|
150
194
|
|
|
151
195
|
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
196
|
+
class IntegrandOperatorParser(IntegrandVisitor):
|
|
197
|
+
def __init__(self, integrand: Integrand, field_info: Dict[str, IntegrandVisitor.FieldInfo], callback: Callable):
|
|
198
|
+
super().__init__(integrand, field_info)
|
|
199
|
+
self._operator_callback = callback
|
|
200
|
+
|
|
201
|
+
def _process_operator_call(
|
|
202
|
+
self, call: ast.Call, callee: Union[str, Operator], operator: Operator, field_info: IntegrandVisitor.FieldInfo
|
|
203
|
+
):
|
|
204
|
+
self._operator_callback(field_info, operator)
|
|
205
|
+
|
|
206
|
+
def _process_integrand_call(
|
|
207
|
+
self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
|
|
208
|
+
):
|
|
209
|
+
callee_field_args = self._get_callee_field_args(callee, call.args)
|
|
210
|
+
callee_parser = IntegrandOperatorParser(callee, callee_field_args, callback=self._operator_callback)
|
|
211
|
+
callee_parser._apply()
|
|
212
|
+
|
|
213
|
+
def _apply(self):
|
|
214
|
+
source = textwrap.dedent(inspect.getsource(self._integrand.func))
|
|
215
|
+
tree = ast.parse(source)
|
|
216
|
+
self.visit(tree)
|
|
217
|
+
|
|
218
|
+
@staticmethod
|
|
219
|
+
def apply(
|
|
220
|
+
integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Callable = None
|
|
221
|
+
) -> wp.Function:
|
|
222
|
+
field_info = IntegrandVisitor._build_field_info(integrand, field_args)
|
|
223
|
+
IntegrandOperatorParser(integrand, field_info, callback=operator_callback)._apply()
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class IntegrandTransformer(IntegrandVisitor):
|
|
227
|
+
def _process_operator_call(
|
|
228
|
+
self, call: ast.Call, callee: Union[str, Operator], operator: Operator, field_info: IntegrandVisitor.FieldInfo
|
|
229
|
+
):
|
|
230
|
+
field = field_info.field
|
|
231
|
+
|
|
232
|
+
try:
|
|
233
|
+
# Retrieve the function pointer corresponding to the operator implementation for the field type
|
|
234
|
+
pointer = operator.resolver(field)
|
|
235
|
+
if not isinstance(pointer, wp.context.Function):
|
|
236
|
+
raise NotImplementedError(operator.resolver.__name__)
|
|
237
|
+
|
|
238
|
+
except (AttributeError, NotImplementedError) as e:
|
|
239
|
+
raise TypeError(
|
|
240
|
+
f"Operator {operator.func.__name__} is not defined for {field_info.abstract_type.__name__} {field.name}"
|
|
241
|
+
) from e
|
|
242
|
+
|
|
243
|
+
# Update the ast Call node to use the new function pointer
|
|
244
|
+
call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())
|
|
245
|
+
|
|
246
|
+
# Save the pointer as an attribute than can be accessed from the calling scope
|
|
247
|
+
# For usual operator call syntax, we can use the operator itself, but for the
|
|
248
|
+
# shortcut default operator syntax, we store it on the callee's concrete type
|
|
249
|
+
if isinstance(callee, Operator):
|
|
250
|
+
setattr(callee, pointer.key, pointer)
|
|
162
251
|
else:
|
|
163
|
-
|
|
252
|
+
setattr(field_info.concrete_type, pointer.key, pointer)
|
|
253
|
+
|
|
254
|
+
# also insert callee as first argument
|
|
255
|
+
call.args = [ast.Name(id=callee, ctx=ast.Load())] + call.args
|
|
256
|
+
|
|
257
|
+
def _process_integrand_call(
|
|
258
|
+
self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
|
|
259
|
+
):
|
|
260
|
+
callee_field_args = self._get_callee_field_args(callee, call.args)
|
|
261
|
+
transformer = IntegrandTransformer(callee, callee_field_args)
|
|
262
|
+
key = transformer._apply().key
|
|
263
|
+
call.func = ast.Attribute(
|
|
264
|
+
value=call.func,
|
|
265
|
+
attr=key,
|
|
266
|
+
ctx=ast.Load(),
|
|
267
|
+
)
|
|
164
268
|
|
|
165
|
-
|
|
166
|
-
|
|
269
|
+
def _apply(self) -> wp.Function:
|
|
270
|
+
# Transform field evaluation calls
|
|
271
|
+
field_info = self._field_symbols
|
|
272
|
+
|
|
273
|
+
# Specialize field argument types
|
|
274
|
+
argspec = self._integrand.argspec
|
|
275
|
+
annotations = argspec.annotations.copy()
|
|
276
|
+
annotations.update({name: f.concrete_type for name, f in field_info.items()})
|
|
277
|
+
|
|
278
|
+
suffix = "_".join([f.field.name for f in field_info.values()])
|
|
279
|
+
func = cache.get_integrand_function(
|
|
280
|
+
integrand=self._integrand,
|
|
281
|
+
suffix=suffix,
|
|
282
|
+
annotations=annotations,
|
|
283
|
+
code_transformers=[self],
|
|
284
|
+
)
|
|
167
285
|
|
|
168
|
-
|
|
286
|
+
# func = self._integrand.module.functions[func.key] #no longer needed?
|
|
287
|
+
setattr(self._integrand, func.key, func)
|
|
169
288
|
|
|
170
|
-
|
|
171
|
-
integrand=integrand,
|
|
172
|
-
suffix=suffix,
|
|
173
|
-
annotations=annotations,
|
|
174
|
-
code_transformers=[transformer],
|
|
175
|
-
)
|
|
289
|
+
return func
|
|
176
290
|
|
|
177
|
-
|
|
178
|
-
|
|
291
|
+
@staticmethod
|
|
292
|
+
def apply(integrand: Integrand, field_args: Dict[str, FieldLike]) -> wp.Function:
|
|
293
|
+
field_info = IntegrandVisitor._build_field_info(integrand, field_args)
|
|
294
|
+
return IntegrandTransformer(integrand, field_info)._apply()
|
|
179
295
|
|
|
180
|
-
return getattr(integrand, key)
|
|
181
296
|
|
|
297
|
+
class IntegrandArguments(NamedTuple):
|
|
298
|
+
field_args: Dict[str, Union[FieldLike, GeometryDomain]]
|
|
299
|
+
value_args: Dict[str, Any]
|
|
300
|
+
domain_name: str
|
|
301
|
+
sample_name: str
|
|
302
|
+
test_name: str
|
|
303
|
+
trial_name: str
|
|
182
304
|
|
|
183
|
-
|
|
305
|
+
|
|
306
|
+
def _parse_integrand_arguments(
|
|
184
307
|
integrand: Integrand,
|
|
185
308
|
fields: Dict[str, FieldLike],
|
|
186
|
-
domain: GeometryDomain = None,
|
|
187
309
|
):
|
|
188
310
|
# parse argument types
|
|
189
311
|
field_args = {}
|
|
@@ -191,38 +313,57 @@ def _get_integrand_field_arguments(
|
|
|
191
313
|
|
|
192
314
|
domain_name = None
|
|
193
315
|
sample_name = None
|
|
316
|
+
test_name = None
|
|
317
|
+
trial_name = None
|
|
194
318
|
|
|
195
319
|
argspec = integrand.argspec
|
|
196
320
|
for arg in argspec.args:
|
|
197
321
|
arg_type = argspec.annotations[arg]
|
|
198
322
|
if arg_type == Field:
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
323
|
+
try:
|
|
324
|
+
field = fields[arg]
|
|
325
|
+
except KeyError as err:
|
|
326
|
+
raise ValueError(f"Missing field for argument '{arg}' of integrand '{integrand.name}'") from err
|
|
327
|
+
if not isinstance(field, FieldLike):
|
|
328
|
+
raise ValueError(f"Passed field argument '{arg}' is not a proper Field")
|
|
329
|
+
if isinstance(field, TestField):
|
|
330
|
+
if test_name is not None:
|
|
331
|
+
raise ValueError(f"More than one test field argument: '{test_name}' and '{arg}'")
|
|
332
|
+
test_name = arg
|
|
333
|
+
elif isinstance(field, TrialField):
|
|
334
|
+
if trial_name is not None:
|
|
335
|
+
raise ValueError(f"More than one trial field argument: '{trial_name}' and '{arg}'")
|
|
336
|
+
trial_name = arg
|
|
337
|
+
field_args[arg] = field
|
|
202
338
|
elif arg_type == Domain:
|
|
339
|
+
if domain_name is not None:
|
|
340
|
+
raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Domain")
|
|
341
|
+
if arg in fields:
|
|
342
|
+
raise ValueError(
|
|
343
|
+
f"Domain argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
|
|
344
|
+
)
|
|
203
345
|
domain_name = arg
|
|
204
|
-
field_args[arg] = domain
|
|
205
346
|
elif arg_type == Sample:
|
|
347
|
+
if sample_name is not None:
|
|
348
|
+
raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Sample")
|
|
349
|
+
if arg in fields:
|
|
350
|
+
raise ValueError(
|
|
351
|
+
f"Sample argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
|
|
352
|
+
)
|
|
206
353
|
sample_name = arg
|
|
207
354
|
else:
|
|
355
|
+
if arg in fields:
|
|
356
|
+
raise ValueError(
|
|
357
|
+
f"Cannot pass a field argument to '{arg}' of '{integrand.name}' with is not of type 'Field'"
|
|
358
|
+
)
|
|
208
359
|
value_args[arg] = arg_type
|
|
209
360
|
|
|
210
|
-
return field_args, value_args, domain_name, sample_name
|
|
361
|
+
return IntegrandArguments(field_args, value_args, domain_name, sample_name, test_name, trial_name)
|
|
211
362
|
|
|
212
363
|
|
|
213
|
-
def _check_field_compat(
|
|
214
|
-
integrand: Integrand,
|
|
215
|
-
fields: Dict[str, FieldLike],
|
|
216
|
-
field_args: Dict[str, FieldLike],
|
|
217
|
-
domain: GeometryDomain = None,
|
|
218
|
-
):
|
|
364
|
+
def _check_field_compat(integrand: Integrand, arguments: IntegrandArguments, domain: GeometryDomain):
|
|
219
365
|
# Check field compatibility
|
|
220
|
-
for name, field in
|
|
221
|
-
if name not in field_args:
|
|
222
|
-
raise ValueError(
|
|
223
|
-
f"Passed field argument '{name}' does not match any parameter of integrand '{integrand.name}'"
|
|
224
|
-
)
|
|
225
|
-
|
|
366
|
+
for name, field in arguments.field_args.items():
|
|
226
367
|
if isinstance(field, GeometryField) and domain is not None:
|
|
227
368
|
if field.geometry != domain.geometry:
|
|
228
369
|
raise ValueError(f"Field '{name}' must be defined on the same geometry as the integration domain")
|
|
@@ -232,37 +373,32 @@ def _check_field_compat(
|
|
|
232
373
|
)
|
|
233
374
|
|
|
234
375
|
|
|
235
|
-
def
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
trial = None
|
|
240
|
-
test_name = None
|
|
241
|
-
trial_name = None
|
|
376
|
+
def _find_integrand_operators(integrand: Integrand, field_args: Dict[str, FieldLike]):
|
|
377
|
+
if integrand.operators is None:
|
|
378
|
+
# Integrands operator dictionary does not depend on concrete field type,
|
|
379
|
+
# so only needs to be built once per integrand
|
|
242
380
|
|
|
243
|
-
|
|
244
|
-
if not isinstance(field, FieldLike):
|
|
245
|
-
raise ValueError(f"Passed field argument '{name}' is not a proper Field")
|
|
246
|
-
|
|
247
|
-
if isinstance(field, TestField):
|
|
248
|
-
if test is not None:
|
|
249
|
-
raise ValueError(f"More than one test field argument: '{test_name}' and '{name}'")
|
|
250
|
-
test = field
|
|
251
|
-
test_name = name
|
|
252
|
-
elif isinstance(field, TrialField):
|
|
253
|
-
if trial is not None:
|
|
254
|
-
raise ValueError(f"More than one trial field argument: '{trial_name}' and '{name}'")
|
|
255
|
-
trial = field
|
|
256
|
-
trial_name = name
|
|
257
|
-
|
|
258
|
-
if trial is not None:
|
|
259
|
-
if test is None:
|
|
260
|
-
raise ValueError("A trial field cannot be provided without a test field")
|
|
381
|
+
operators = {}
|
|
261
382
|
|
|
262
|
-
|
|
263
|
-
|
|
383
|
+
def operator_callback(field: IntegrandVisitor.FieldInfo, op: Operator):
|
|
384
|
+
if field.root_arg_name in operators:
|
|
385
|
+
operators[field.root_arg_name].add(op)
|
|
386
|
+
else:
|
|
387
|
+
operators[field.root_arg_name] = {op}
|
|
388
|
+
|
|
389
|
+
IntegrandOperatorParser.apply(integrand, field_args, operator_callback=operator_callback)
|
|
390
|
+
|
|
391
|
+
integrand.operators = operators
|
|
264
392
|
|
|
265
|
-
|
|
393
|
+
|
|
394
|
+
def _notify_operator_usage(
|
|
395
|
+
integrand: Integrand,
|
|
396
|
+
field_args: Dict[str, FieldLike],
|
|
397
|
+
):
|
|
398
|
+
for arg, field_ops in integrand.operators.items():
|
|
399
|
+
if arg in field_args:
|
|
400
|
+
# print(f"{arg} {field_args[arg].name} : {', '.join(op.name for op in field_ops)}")
|
|
401
|
+
field_args[arg].notify_operator_usage(field_ops)
|
|
266
402
|
|
|
267
403
|
|
|
268
404
|
def _gen_field_struct(field_args: Dict[str, FieldLike]):
|
|
@@ -295,26 +431,12 @@ def _get_test_arg():
|
|
|
295
431
|
pass
|
|
296
432
|
|
|
297
433
|
|
|
298
|
-
class _FieldWrappers:
|
|
299
|
-
pass
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
def _register_integrand_field_wrappers(integrand_func: wp.Function, fields: Dict[str, FieldLike]):
|
|
303
|
-
integrand_func._field_wrappers = _FieldWrappers()
|
|
304
|
-
for name, field in fields.items():
|
|
305
|
-
setattr(integrand_func._field_wrappers, name, field.ElementEvalArg)
|
|
306
|
-
|
|
307
|
-
|
|
308
434
|
class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
309
435
|
def __init__(
|
|
310
436
|
self,
|
|
311
437
|
arg_names: List[str],
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
sample_name: str,
|
|
315
|
-
domain_name: str,
|
|
316
|
-
test_name: str = None,
|
|
317
|
-
trial_name: str = None,
|
|
438
|
+
parsed_args: IntegrandArguments,
|
|
439
|
+
integrand_func: wp.Function,
|
|
318
440
|
func_name: str = "integrand_func",
|
|
319
441
|
fields_var_name: str = "fields",
|
|
320
442
|
values_var_name: str = "values",
|
|
@@ -323,18 +445,32 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
323
445
|
field_wrappers_attr: str = "_field_wrappers",
|
|
324
446
|
):
|
|
325
447
|
self._arg_names = arg_names
|
|
326
|
-
self._field_args = field_args
|
|
327
|
-
self._value_args = value_args
|
|
328
|
-
self._domain_name = domain_name
|
|
329
|
-
self._sample_name = sample_name
|
|
448
|
+
self._field_args = parsed_args.field_args
|
|
449
|
+
self._value_args = parsed_args.value_args
|
|
450
|
+
self._domain_name = parsed_args.domain_name
|
|
451
|
+
self._sample_name = parsed_args.sample_name
|
|
452
|
+
self._test_name = parsed_args.test_name
|
|
453
|
+
self._trial_name = parsed_args.trial_name
|
|
330
454
|
self._func_name = func_name
|
|
331
|
-
self._test_name = test_name
|
|
332
|
-
self._trial_name = trial_name
|
|
333
455
|
self._fields_var_name = fields_var_name
|
|
334
456
|
self._values_var_name = values_var_name
|
|
335
457
|
self._domain_var_name = domain_var_name
|
|
336
458
|
self._sample_var_name = sample_var_name
|
|
459
|
+
|
|
337
460
|
self._field_wrappers_attr = field_wrappers_attr
|
|
461
|
+
self._register_integrand_field_wrappers(integrand_func, parsed_args.field_args)
|
|
462
|
+
|
|
463
|
+
class _FieldWrappers:
|
|
464
|
+
pass
|
|
465
|
+
|
|
466
|
+
def _register_integrand_field_wrappers(self, integrand_func: wp.Function, fields: Dict[str, FieldLike]):
|
|
467
|
+
# Mechanism to pass the geometry argument only once to the root kernel
|
|
468
|
+
# Field wrappers are used to forward it to all fields in nested integrand calls
|
|
469
|
+
field_wrappers = PassFieldArgsToIntegrand._FieldWrappers()
|
|
470
|
+
for name, field in fields.items():
|
|
471
|
+
if isinstance(field, FieldLike):
|
|
472
|
+
setattr(field_wrappers, name, field.ElementEvalArg)
|
|
473
|
+
setattr(integrand_func, self._field_wrappers_attr, field_wrappers)
|
|
338
474
|
|
|
339
475
|
def visit_Call(self, call: ast.Call):
|
|
340
476
|
call = self.generic_visit(call)
|
|
@@ -405,6 +541,16 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
405
541
|
return call
|
|
406
542
|
|
|
407
543
|
|
|
544
|
+
def _combined_kernel_options(integrand_options: Optional[Dict[str, Any]], call_site_options: Optional[Dict[str, Any]]):
|
|
545
|
+
if integrand_options is None:
|
|
546
|
+
return {} if call_site_options is None else call_site_options
|
|
547
|
+
|
|
548
|
+
options = integrand_options.copy()
|
|
549
|
+
if call_site_options is not None:
|
|
550
|
+
options.update(call_site_options)
|
|
551
|
+
return options
|
|
552
|
+
|
|
553
|
+
|
|
408
554
|
def get_integrate_constant_kernel(
|
|
409
555
|
integrand_func: wp.Function,
|
|
410
556
|
domain: GeometryDomain,
|
|
@@ -500,7 +646,7 @@ def get_integrate_linear_kernel(
|
|
|
500
646
|
|
|
501
647
|
val_sum += accumulate_dtype(qp_weight * vol * val)
|
|
502
648
|
|
|
503
|
-
result[node_index, test_dof]
|
|
649
|
+
result[node_index, test_dof] += output_dtype(val_sum)
|
|
504
650
|
|
|
505
651
|
return integrate_kernel_fn
|
|
506
652
|
|
|
@@ -571,7 +717,48 @@ def get_integrate_linear_nodal_kernel(
|
|
|
571
717
|
|
|
572
718
|
val_sum += accumulate_dtype(node_weight * vol * val)
|
|
573
719
|
|
|
574
|
-
result[partition_node_index, dof]
|
|
720
|
+
result[partition_node_index, dof] += output_dtype(val_sum)
|
|
721
|
+
|
|
722
|
+
return integrate_kernel_fn
|
|
723
|
+
|
|
724
|
+
|
|
725
|
+
def get_integrate_linear_local_kernel(
|
|
726
|
+
integrand_func: wp.Function,
|
|
727
|
+
domain: GeometryDomain,
|
|
728
|
+
quadrature: Quadrature,
|
|
729
|
+
FieldStruct: wp.codegen.Struct,
|
|
730
|
+
ValueStruct: wp.codegen.Struct,
|
|
731
|
+
test: LocalTestField,
|
|
732
|
+
):
|
|
733
|
+
TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
|
|
734
|
+
|
|
735
|
+
def integrate_kernel_fn(
|
|
736
|
+
qp_arg: quadrature.Arg,
|
|
737
|
+
domain_arg: domain.ElementArg,
|
|
738
|
+
domain_index_arg: domain.ElementIndexArg,
|
|
739
|
+
fields: FieldStruct,
|
|
740
|
+
values: ValueStruct,
|
|
741
|
+
result: wp.array3d(dtype=float),
|
|
742
|
+
):
|
|
743
|
+
domain_element_index, taylor_dof, test_dof = wp.tid()
|
|
744
|
+
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
745
|
+
|
|
746
|
+
trial_dof_index = NULL_DOF_INDEX
|
|
747
|
+
test_dof_offset = test_dof * TAYLOR_DOF_COUNT
|
|
748
|
+
|
|
749
|
+
qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
|
|
750
|
+
for qp in range(qp_point_count):
|
|
751
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
752
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
753
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
754
|
+
|
|
755
|
+
vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
|
|
756
|
+
|
|
757
|
+
test_dof_index = DofIndex(qp_index, test_dof_offset + taylor_dof)
|
|
758
|
+
|
|
759
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
760
|
+
val = integrand_func(sample, fields, values)
|
|
761
|
+
result[qp_index, taylor_dof, test_dof] = qp_weight * vol * val
|
|
575
762
|
|
|
576
763
|
return integrate_kernel_fn
|
|
577
764
|
|
|
@@ -747,62 +934,89 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
747
934
|
return integrate_kernel_fn
|
|
748
935
|
|
|
749
936
|
|
|
937
|
+
def get_integrate_bilinear_local_kernel(
|
|
938
|
+
integrand_func: wp.Function,
|
|
939
|
+
domain: GeometryDomain,
|
|
940
|
+
quadrature: Quadrature,
|
|
941
|
+
FieldStruct: wp.codegen.Struct,
|
|
942
|
+
ValueStruct: wp.codegen.Struct,
|
|
943
|
+
test: LocalTestField,
|
|
944
|
+
trial: LocalTrialField,
|
|
945
|
+
):
|
|
946
|
+
TEST_TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
|
|
947
|
+
TRIAL_TAYLOR_DOF_COUNT = trial.TAYLOR_DOF_COUNT
|
|
948
|
+
|
|
949
|
+
def integrate_kernel_fn(
|
|
950
|
+
qp_arg: quadrature.Arg,
|
|
951
|
+
domain_arg: domain.ElementArg,
|
|
952
|
+
domain_index_arg: domain.ElementIndexArg,
|
|
953
|
+
fields: FieldStruct,
|
|
954
|
+
values: ValueStruct,
|
|
955
|
+
result: wp.array4d(dtype=float),
|
|
956
|
+
):
|
|
957
|
+
domain_element_index, test_dof, trial_dof, trial_taylor_dof = wp.tid()
|
|
958
|
+
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
959
|
+
|
|
960
|
+
test_dof_offset = TEST_TAYLOR_DOF_COUNT * test_dof
|
|
961
|
+
trial_dof_offset = TRIAL_TAYLOR_DOF_COUNT * trial_dof
|
|
962
|
+
|
|
963
|
+
qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
|
|
964
|
+
for k in range(qp_point_count):
|
|
965
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
966
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
967
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
968
|
+
|
|
969
|
+
vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
|
|
970
|
+
qp_vol = vol * qp_weight
|
|
971
|
+
|
|
972
|
+
for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
|
|
973
|
+
taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof
|
|
974
|
+
|
|
975
|
+
test_dof_index = DofIndex(qp_index, test_dof_offset + test_taylor_dof)
|
|
976
|
+
trial_dof_index = DofIndex(qp_index, trial_dof_offset + trial_taylor_dof)
|
|
977
|
+
|
|
978
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
979
|
+
val = integrand_func(sample, fields, values)
|
|
980
|
+
result[qp_index, test_dof, trial_dof, taylor_dof] = qp_vol * val
|
|
981
|
+
|
|
982
|
+
return integrate_kernel_fn
|
|
983
|
+
|
|
984
|
+
|
|
750
985
|
def _generate_integrate_kernel(
|
|
751
986
|
integrand: Integrand,
|
|
752
987
|
domain: GeometryDomain,
|
|
753
|
-
nodal: bool,
|
|
754
988
|
quadrature: Quadrature,
|
|
989
|
+
arguments: IntegrandArguments,
|
|
755
990
|
test: Optional[TestField],
|
|
756
|
-
test_name: str,
|
|
757
991
|
trial: Optional[TrialField],
|
|
758
|
-
trial_name: str,
|
|
759
|
-
fields: Dict[str, FieldLike],
|
|
760
992
|
output_dtype: type,
|
|
761
993
|
accumulate_dtype: type,
|
|
762
994
|
kernel_options: Optional[Dict[str, Any]] = None,
|
|
763
995
|
) -> wp.Kernel:
|
|
764
|
-
if kernel_options is None:
|
|
765
|
-
kernel_options = {}
|
|
766
|
-
|
|
767
996
|
output_dtype = wp.types.type_scalar_type(output_dtype)
|
|
768
997
|
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
integrand, fields=fields, domain=domain
|
|
772
|
-
)
|
|
998
|
+
FieldStruct = _gen_field_struct(arguments.field_args)
|
|
999
|
+
ValueStruct = cache.get_argument_struct(arguments.value_args)
|
|
773
1000
|
|
|
774
|
-
|
|
775
|
-
ValueStruct = cache.get_argument_struct(value_args)
|
|
1001
|
+
_notify_operator_usage(integrand, arguments.field_args)
|
|
776
1002
|
|
|
777
1003
|
# Check if kernel exist in cache
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
kernel_suffix += "_nodal"
|
|
781
|
-
else:
|
|
782
|
-
kernel_suffix += quadrature.name
|
|
1004
|
+
field_names = "_".join(f"{k}{f.name}" for k, f in arguments.field_args.items())
|
|
1005
|
+
kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{field_names}"
|
|
783
1006
|
|
|
784
|
-
if
|
|
785
|
-
kernel_suffix +=
|
|
786
|
-
if trial:
|
|
787
|
-
kernel_suffix += f"_trial_{trial.space_partition.name}_{trial.space.name}"
|
|
1007
|
+
if quadrature is not None:
|
|
1008
|
+
kernel_suffix += quadrature.name
|
|
788
1009
|
|
|
789
|
-
kernel = cache.get_integrand_kernel(
|
|
790
|
-
integrand=integrand,
|
|
791
|
-
suffix=kernel_suffix,
|
|
792
|
-
)
|
|
1010
|
+
kernel = cache.get_integrand_kernel(integrand=integrand, suffix=kernel_suffix, kernel_options=kernel_options)
|
|
793
1011
|
if kernel is not None:
|
|
794
1012
|
return kernel, FieldStruct, ValueStruct
|
|
795
1013
|
|
|
796
|
-
# Not found in cache, transform integrand and generate
|
|
1014
|
+
# Not found in cache, transform integrand and generate kernel
|
|
1015
|
+
_check_field_compat(integrand, arguments, domain)
|
|
797
1016
|
|
|
798
|
-
|
|
1017
|
+
integrand_func = IntegrandTransformer.apply(integrand, arguments.field_args)
|
|
799
1018
|
|
|
800
|
-
|
|
801
|
-
integrand,
|
|
802
|
-
field_args,
|
|
803
|
-
)
|
|
804
|
-
|
|
805
|
-
_register_integrand_field_wrappers(integrand_func, fields)
|
|
1019
|
+
nodal = quadrature is None
|
|
806
1020
|
|
|
807
1021
|
if test is None and trial is None:
|
|
808
1022
|
integrate_kernel_fn = get_integrate_constant_kernel(
|
|
@@ -824,6 +1038,15 @@ def _generate_integrate_kernel(
|
|
|
824
1038
|
output_dtype=output_dtype,
|
|
825
1039
|
accumulate_dtype=accumulate_dtype,
|
|
826
1040
|
)
|
|
1041
|
+
elif isinstance(test, LocalTestField):
|
|
1042
|
+
integrate_kernel_fn = get_integrate_linear_local_kernel(
|
|
1043
|
+
integrand_func,
|
|
1044
|
+
domain,
|
|
1045
|
+
quadrature,
|
|
1046
|
+
FieldStruct,
|
|
1047
|
+
ValueStruct,
|
|
1048
|
+
test=test,
|
|
1049
|
+
)
|
|
827
1050
|
else:
|
|
828
1051
|
integrate_kernel_fn = get_integrate_linear_kernel(
|
|
829
1052
|
integrand_func,
|
|
@@ -846,6 +1069,16 @@ def _generate_integrate_kernel(
|
|
|
846
1069
|
output_dtype=output_dtype,
|
|
847
1070
|
accumulate_dtype=accumulate_dtype,
|
|
848
1071
|
)
|
|
1072
|
+
elif isinstance(test, LocalTestField):
|
|
1073
|
+
integrate_kernel_fn = get_integrate_bilinear_local_kernel(
|
|
1074
|
+
integrand_func,
|
|
1075
|
+
domain,
|
|
1076
|
+
quadrature,
|
|
1077
|
+
FieldStruct,
|
|
1078
|
+
ValueStruct,
|
|
1079
|
+
test=test,
|
|
1080
|
+
trial=trial,
|
|
1081
|
+
)
|
|
849
1082
|
else:
|
|
850
1083
|
integrate_kernel_fn = get_integrate_bilinear_kernel(
|
|
851
1084
|
integrand_func,
|
|
@@ -866,13 +1099,7 @@ def _generate_integrate_kernel(
|
|
|
866
1099
|
kernel_options=kernel_options,
|
|
867
1100
|
code_transformers=[
|
|
868
1101
|
PassFieldArgsToIntegrand(
|
|
869
|
-
arg_names=integrand.argspec.args,
|
|
870
|
-
field_args=field_args.keys(),
|
|
871
|
-
value_args=value_args.keys(),
|
|
872
|
-
sample_name=sample_name,
|
|
873
|
-
domain_name=domain_name,
|
|
874
|
-
test_name=test_name,
|
|
875
|
-
trial_name=trial_name,
|
|
1102
|
+
arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
|
|
876
1103
|
)
|
|
877
1104
|
],
|
|
878
1105
|
)
|
|
@@ -886,7 +1113,6 @@ def _launch_integrate_kernel(
|
|
|
886
1113
|
FieldStruct: wp.codegen.Struct,
|
|
887
1114
|
ValueStruct: wp.codegen.Struct,
|
|
888
1115
|
domain: GeometryDomain,
|
|
889
|
-
nodal: bool,
|
|
890
1116
|
quadrature: Quadrature,
|
|
891
1117
|
test: Optional[TestField],
|
|
892
1118
|
trial: Optional[TrialField],
|
|
@@ -896,6 +1122,7 @@ def _launch_integrate_kernel(
|
|
|
896
1122
|
temporary_store: Optional[cache.TemporaryStore],
|
|
897
1123
|
output_dtype: type,
|
|
898
1124
|
output: Optional[Union[wp.array, BsrMatrix]],
|
|
1125
|
+
add_to_output: bool,
|
|
899
1126
|
device,
|
|
900
1127
|
):
|
|
901
1128
|
# Set-up launch arguments
|
|
@@ -907,7 +1134,8 @@ def _launch_integrate_kernel(
|
|
|
907
1134
|
|
|
908
1135
|
field_arg_values = FieldStruct()
|
|
909
1136
|
for k, v in fields.items():
|
|
910
|
-
|
|
1137
|
+
if not isinstance(v, GeometryDomain):
|
|
1138
|
+
setattr(field_arg_values, k, v.eval_arg_value(device=device))
|
|
911
1139
|
|
|
912
1140
|
value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
|
|
913
1141
|
|
|
@@ -927,7 +1155,9 @@ def _launch_integrate_kernel(
|
|
|
927
1155
|
)
|
|
928
1156
|
accumulate_array = accumulate_temporary.array
|
|
929
1157
|
|
|
930
|
-
accumulate_array
|
|
1158
|
+
if output != accumulate_array or not add_to_output:
|
|
1159
|
+
accumulate_array.zero_()
|
|
1160
|
+
|
|
931
1161
|
wp.launch(
|
|
932
1162
|
kernel=kernel,
|
|
933
1163
|
dim=domain.element_count(),
|
|
@@ -944,26 +1174,31 @@ def _launch_integrate_kernel(
|
|
|
944
1174
|
|
|
945
1175
|
if output == accumulate_array:
|
|
946
1176
|
return output
|
|
947
|
-
|
|
1177
|
+
if output is None:
|
|
948
1178
|
return accumulate_array.numpy()[0]
|
|
1179
|
+
|
|
1180
|
+
if add_to_output:
|
|
1181
|
+
# accumulate dtype is distinct from output dtype
|
|
1182
|
+
array_axpy(x=accumulate_array, y=output)
|
|
949
1183
|
else:
|
|
950
1184
|
array_cast(in_array=accumulate_array, out_array=output)
|
|
951
|
-
|
|
1185
|
+
return output
|
|
952
1186
|
|
|
953
1187
|
test_arg = test.space_restriction.node_arg(device=device)
|
|
1188
|
+
nodal = quadrature is None
|
|
954
1189
|
|
|
955
1190
|
# Linear form
|
|
956
1191
|
if trial is None:
|
|
957
1192
|
# If an output array is provided with the correct type, accumulate directly into it
|
|
958
1193
|
# Otherwise, grab a temporary array
|
|
959
1194
|
if output is None:
|
|
960
|
-
if type_length(output_dtype) == test.
|
|
1195
|
+
if type_length(output_dtype) == test.node_dof_count:
|
|
961
1196
|
output_shape = (test.space_partition.node_count(),)
|
|
962
1197
|
elif type_length(output_dtype) == 1:
|
|
963
|
-
output_shape = (test.space_partition.node_count(), test.
|
|
1198
|
+
output_shape = (test.space_partition.node_count(), test.node_dof_count)
|
|
964
1199
|
else:
|
|
965
1200
|
raise RuntimeError(
|
|
966
|
-
f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.
|
|
1201
|
+
f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
|
|
967
1202
|
)
|
|
968
1203
|
|
|
969
1204
|
output_temporary = cache.borrow_temporary(
|
|
@@ -982,18 +1217,19 @@ def _launch_integrate_kernel(
|
|
|
982
1217
|
raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
|
|
983
1218
|
|
|
984
1219
|
output_dtype = output.dtype
|
|
985
|
-
if type_length(output_dtype) != test.
|
|
1220
|
+
if type_length(output_dtype) != test.node_dof_count:
|
|
986
1221
|
if type_length(output_dtype) != 1:
|
|
987
1222
|
raise RuntimeError(
|
|
988
|
-
f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.
|
|
1223
|
+
f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
|
|
989
1224
|
)
|
|
990
|
-
if output.ndim != 2 and output.shape[1] != test.
|
|
1225
|
+
if output.ndim != 2 and output.shape[1] != test.node_dof_count:
|
|
991
1226
|
raise RuntimeError(
|
|
992
|
-
f"Incompatible output array shape, last dimension must be of size {test.
|
|
1227
|
+
f"Incompatible output array shape, last dimension must be of size {test.node_dof_count}"
|
|
993
1228
|
)
|
|
994
1229
|
|
|
995
1230
|
# Launch the integration on the kernel on a 2d scalar view of the actual array
|
|
996
|
-
|
|
1231
|
+
if not add_to_output:
|
|
1232
|
+
output.zero_()
|
|
997
1233
|
|
|
998
1234
|
def as_2d_array(array):
|
|
999
1235
|
return wp.array(
|
|
@@ -1001,7 +1237,7 @@ def _launch_integrate_kernel(
|
|
|
1001
1237
|
ptr=array.ptr,
|
|
1002
1238
|
capacity=array.capacity,
|
|
1003
1239
|
device=array.device,
|
|
1004
|
-
shape=(test.space_partition.node_count(), test.
|
|
1240
|
+
shape=(test.space_partition.node_count(), test.node_dof_count),
|
|
1005
1241
|
dtype=wp.types.type_scalar_type(output_dtype),
|
|
1006
1242
|
grad=None if array.grad is None else as_2d_array(array.grad),
|
|
1007
1243
|
)
|
|
@@ -1011,7 +1247,7 @@ def _launch_integrate_kernel(
|
|
|
1011
1247
|
if nodal:
|
|
1012
1248
|
wp.launch(
|
|
1013
1249
|
kernel=kernel,
|
|
1014
|
-
dim=(test.space_restriction.node_count(), test.
|
|
1250
|
+
dim=(test.space_restriction.node_count(), test.node_dof_count),
|
|
1015
1251
|
inputs=[
|
|
1016
1252
|
domain_elt_arg,
|
|
1017
1253
|
domain_elt_index_arg,
|
|
@@ -1023,10 +1259,51 @@ def _launch_integrate_kernel(
|
|
|
1023
1259
|
],
|
|
1024
1260
|
device=device,
|
|
1025
1261
|
)
|
|
1262
|
+
elif isinstance(test, LocalTestField):
|
|
1263
|
+
local_result = cache.borrow_temporary(
|
|
1264
|
+
temporary_store=temporary_store,
|
|
1265
|
+
device=device,
|
|
1266
|
+
requires_grad=output.requires_grad,
|
|
1267
|
+
shape=(quadrature.total_point_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
|
|
1268
|
+
dtype=float,
|
|
1269
|
+
)
|
|
1270
|
+
|
|
1271
|
+
wp.launch(
|
|
1272
|
+
kernel=kernel,
|
|
1273
|
+
dim=(domain.element_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
|
|
1274
|
+
inputs=[
|
|
1275
|
+
qp_arg,
|
|
1276
|
+
domain_elt_arg,
|
|
1277
|
+
domain_elt_index_arg,
|
|
1278
|
+
field_arg_values,
|
|
1279
|
+
value_struct_values,
|
|
1280
|
+
local_result.array,
|
|
1281
|
+
],
|
|
1282
|
+
device=device,
|
|
1283
|
+
)
|
|
1284
|
+
|
|
1285
|
+
dispatch_kernel = make_linear_dispatch_kernel(test, quadrature, accumulate_dtype)
|
|
1286
|
+
wp.launch(
|
|
1287
|
+
kernel=dispatch_kernel,
|
|
1288
|
+
dim=(test.space_restriction.node_count(), test.node_dof_count),
|
|
1289
|
+
inputs=[
|
|
1290
|
+
qp_arg,
|
|
1291
|
+
domain_elt_arg,
|
|
1292
|
+
domain_elt_index_arg,
|
|
1293
|
+
test_arg,
|
|
1294
|
+
test.global_field.eval_arg_value(device),
|
|
1295
|
+
local_result.array,
|
|
1296
|
+
output_view,
|
|
1297
|
+
],
|
|
1298
|
+
device=device,
|
|
1299
|
+
)
|
|
1300
|
+
|
|
1301
|
+
local_result.release()
|
|
1302
|
+
|
|
1026
1303
|
else:
|
|
1027
1304
|
wp.launch(
|
|
1028
1305
|
kernel=kernel,
|
|
1029
|
-
dim=(test.space_restriction.node_count(), test.
|
|
1306
|
+
dim=(test.space_restriction.node_count(), test.node_dof_count),
|
|
1030
1307
|
inputs=[
|
|
1031
1308
|
qp_arg,
|
|
1032
1309
|
domain_elt_arg,
|
|
@@ -1046,12 +1323,10 @@ def _launch_integrate_kernel(
|
|
|
1046
1323
|
|
|
1047
1324
|
# Bilinear form
|
|
1048
1325
|
|
|
1049
|
-
if test.
|
|
1326
|
+
if test.node_dof_count == 1 and trial.node_dof_count == 1:
|
|
1050
1327
|
block_type = output_dtype
|
|
1051
1328
|
else:
|
|
1052
|
-
block_type = cache.cached_mat_type(
|
|
1053
|
-
shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=output_dtype
|
|
1054
|
-
)
|
|
1329
|
+
block_type = cache.cached_mat_type(shape=(test.node_dof_count, trial.node_dof_count), dtype=output_dtype)
|
|
1055
1330
|
|
|
1056
1331
|
if nodal:
|
|
1057
1332
|
nnz = test.space_restriction.node_count()
|
|
@@ -1064,8 +1339,8 @@ def _launch_integrate_kernel(
|
|
|
1064
1339
|
temporary_store,
|
|
1065
1340
|
shape=(
|
|
1066
1341
|
nnz,
|
|
1067
|
-
test.
|
|
1068
|
-
trial.
|
|
1342
|
+
test.node_dof_count,
|
|
1343
|
+
trial.node_dof_count,
|
|
1069
1344
|
),
|
|
1070
1345
|
dtype=output_dtype,
|
|
1071
1346
|
device=device,
|
|
@@ -1093,6 +1368,75 @@ def _launch_integrate_kernel(
|
|
|
1093
1368
|
],
|
|
1094
1369
|
device=device,
|
|
1095
1370
|
)
|
|
1371
|
+
elif isinstance(test, LocalTestField):
|
|
1372
|
+
local_result = cache.borrow_temporary(
|
|
1373
|
+
temporary_store=temporary_store,
|
|
1374
|
+
device=device,
|
|
1375
|
+
requires_grad=False,
|
|
1376
|
+
shape=(
|
|
1377
|
+
quadrature.total_point_count(),
|
|
1378
|
+
test.value_dof_count,
|
|
1379
|
+
trial.value_dof_count,
|
|
1380
|
+
test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
|
|
1381
|
+
),
|
|
1382
|
+
dtype=float,
|
|
1383
|
+
)
|
|
1384
|
+
|
|
1385
|
+
wp.launch(
|
|
1386
|
+
kernel=kernel,
|
|
1387
|
+
dim=(domain.element_count(), test.value_dof_count, trial.value_dof_count, trial.TAYLOR_DOF_COUNT),
|
|
1388
|
+
inputs=[
|
|
1389
|
+
qp_arg,
|
|
1390
|
+
domain_elt_arg,
|
|
1391
|
+
domain_elt_index_arg,
|
|
1392
|
+
field_arg_values,
|
|
1393
|
+
value_struct_values,
|
|
1394
|
+
local_result.array,
|
|
1395
|
+
],
|
|
1396
|
+
device=device,
|
|
1397
|
+
)
|
|
1398
|
+
|
|
1399
|
+
vec_array_shape = (*local_result.array.shape[:-1], test.TAYLOR_DOF_COUNT)
|
|
1400
|
+
vec_array_dtype = cache.cached_vec_type(length=trial.TAYLOR_DOF_COUNT, dtype=float)
|
|
1401
|
+
local_result_as_vec = wp.array(
|
|
1402
|
+
data=None,
|
|
1403
|
+
ptr=local_result.array.ptr,
|
|
1404
|
+
capacity=local_result.array.capacity,
|
|
1405
|
+
device=local_result.array.device,
|
|
1406
|
+
shape=vec_array_shape,
|
|
1407
|
+
dtype=vec_array_dtype,
|
|
1408
|
+
)
|
|
1409
|
+
|
|
1410
|
+
dispatch_kernel = make_bilinear_dispatch_kernel(test, trial, quadrature, accumulate_dtype)
|
|
1411
|
+
|
|
1412
|
+
trial_partition_arg = trial.space_partition.partition_arg_value(device)
|
|
1413
|
+
trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
|
|
1414
|
+
wp.launch(
|
|
1415
|
+
kernel=dispatch_kernel,
|
|
1416
|
+
dim=(
|
|
1417
|
+
test.space_restriction.node_count(),
|
|
1418
|
+
test.node_dof_count,
|
|
1419
|
+
trial.node_dof_count,
|
|
1420
|
+
trial.space.topology.MAX_NODES_PER_ELEMENT,
|
|
1421
|
+
),
|
|
1422
|
+
inputs=[
|
|
1423
|
+
qp_arg,
|
|
1424
|
+
domain_elt_arg,
|
|
1425
|
+
domain_elt_index_arg,
|
|
1426
|
+
test_arg,
|
|
1427
|
+
test.global_field.eval_arg_value(device),
|
|
1428
|
+
trial_partition_arg,
|
|
1429
|
+
trial_topology_arg,
|
|
1430
|
+
trial.global_field.eval_arg_value(device),
|
|
1431
|
+
local_result_as_vec,
|
|
1432
|
+
triplet_rows,
|
|
1433
|
+
triplet_cols,
|
|
1434
|
+
triplet_values,
|
|
1435
|
+
],
|
|
1436
|
+
device=device,
|
|
1437
|
+
)
|
|
1438
|
+
|
|
1439
|
+
local_result.release()
|
|
1096
1440
|
|
|
1097
1441
|
else:
|
|
1098
1442
|
trial_partition_arg = trial.space_partition.partition_arg_value(device)
|
|
@@ -1102,8 +1446,8 @@ def _launch_integrate_kernel(
|
|
|
1102
1446
|
dim=(
|
|
1103
1447
|
test.space_restriction.node_count(),
|
|
1104
1448
|
trial.space.topology.MAX_NODES_PER_ELEMENT,
|
|
1105
|
-
test.
|
|
1106
|
-
trial.
|
|
1449
|
+
test.node_dof_count,
|
|
1450
|
+
trial.node_dof_count,
|
|
1107
1451
|
),
|
|
1108
1452
|
inputs=[
|
|
1109
1453
|
qp_arg,
|
|
@@ -1127,24 +1471,48 @@ def _launch_integrate_kernel(
|
|
|
1127
1471
|
f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
|
|
1128
1472
|
)
|
|
1129
1473
|
|
|
1130
|
-
|
|
1131
|
-
|
|
1474
|
+
if output is None or add_to_output:
|
|
1475
|
+
bsr_result = bsr_zeros(
|
|
1132
1476
|
rows_of_blocks=test.space_partition.node_count(),
|
|
1133
1477
|
cols_of_blocks=trial.space_partition.node_count(),
|
|
1134
1478
|
block_type=block_type,
|
|
1135
1479
|
device=device,
|
|
1136
1480
|
)
|
|
1481
|
+
else:
|
|
1482
|
+
bsr_result = output
|
|
1137
1483
|
|
|
1138
|
-
bsr_set_from_triplets(
|
|
1484
|
+
bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values)
|
|
1139
1485
|
|
|
1140
1486
|
# Do not wait for garbage collection
|
|
1141
1487
|
triplet_values_temp.release()
|
|
1142
1488
|
triplet_rows_temp.release()
|
|
1143
1489
|
triplet_cols_temp.release()
|
|
1144
1490
|
|
|
1491
|
+
if add_to_output:
|
|
1492
|
+
output += bsr_result
|
|
1493
|
+
else:
|
|
1494
|
+
output = bsr_result
|
|
1495
|
+
|
|
1145
1496
|
return output
|
|
1146
1497
|
|
|
1147
1498
|
|
|
1499
|
+
def _pick_assembly_strategy(
|
|
1500
|
+
assembly: Optional[str], nodal: bool, operators: Dict[str, Set[Operator]], arguments: IntegrandArguments
|
|
1501
|
+
):
|
|
1502
|
+
if assembly is not None:
|
|
1503
|
+
if assembly not in ("generic", "nodal", "dispatch"):
|
|
1504
|
+
raise ValueError(f"Invalid assembly strategy'{assembly}'")
|
|
1505
|
+
return assembly
|
|
1506
|
+
elif nodal:
|
|
1507
|
+
return "nodal"
|
|
1508
|
+
|
|
1509
|
+
test_operators = operators.get(arguments.test_name, {})
|
|
1510
|
+
trial_operators = operators.get(arguments.trial_name, {})
|
|
1511
|
+
uses_at_node = at_node in test_operators or at_node in trial_operators
|
|
1512
|
+
|
|
1513
|
+
return "generic" if uses_at_node else "dispatch"
|
|
1514
|
+
|
|
1515
|
+
|
|
1148
1516
|
def integrate(
|
|
1149
1517
|
integrand: Integrand,
|
|
1150
1518
|
domain: Optional[GeometryDomain] = None,
|
|
@@ -1158,6 +1526,8 @@ def integrate(
|
|
|
1158
1526
|
device=None,
|
|
1159
1527
|
temporary_store: Optional[cache.TemporaryStore] = None,
|
|
1160
1528
|
kernel_options: Optional[Dict[str, Any]] = None,
|
|
1529
|
+
assembly: str = None,
|
|
1530
|
+
add: bool = False,
|
|
1161
1531
|
):
|
|
1162
1532
|
"""
|
|
1163
1533
|
Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.
|
|
@@ -1166,7 +1536,7 @@ def integrate(
|
|
|
1166
1536
|
integrand: Form to be integrated, must have :func:`integrand` decorator
|
|
1167
1537
|
domain: Integration domain. If None, deduced from fields
|
|
1168
1538
|
quadrature: Quadrature formula. If None, deduced from domain and fields degree.
|
|
1169
|
-
nodal:
|
|
1539
|
+
nodal: Deprecated. Use the equivalent assembly="nodal" instead.
|
|
1170
1540
|
fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
|
|
1171
1541
|
values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
|
|
1172
1542
|
temporary_store: shared pool from which to allocate temporary arrays
|
|
@@ -1175,6 +1545,12 @@ def integrate(
|
|
|
1175
1545
|
output_dtype: Scalar type for returned results in `output` is not provided. If None, defaults to `accumulate_dtype`
|
|
1176
1546
|
device: Device on which to perform the integration
|
|
1177
1547
|
kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
|
|
1548
|
+
assembly: Specifies the strategy for assembling the integrated vector or matrix:
|
|
1549
|
+
- "nodal": For linear or bilinear forms, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
|
|
1550
|
+
- "generic": Single-pass integration and shape-function evaluation. Makes no assumption about the integrand's content, but may lead to many redundant computations.
|
|
1551
|
+
- "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` operator on test or trial functions.
|
|
1552
|
+
- `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
|
|
1553
|
+
add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
|
|
1178
1554
|
"""
|
|
1179
1555
|
if fields is None:
|
|
1180
1556
|
fields = {}
|
|
@@ -1182,13 +1558,22 @@ def integrate(
|
|
|
1182
1558
|
if values is None:
|
|
1183
1559
|
values = {}
|
|
1184
1560
|
|
|
1185
|
-
if kernel_options is None:
|
|
1186
|
-
kernel_options = {}
|
|
1187
|
-
|
|
1188
1561
|
if not isinstance(integrand, Integrand):
|
|
1189
1562
|
raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")
|
|
1190
1563
|
|
|
1191
|
-
test, test_name, trial, trial_name = _get_test_and_trial_fields(fields)
|
|
1564
|
+
# test, test_name, trial, trial_name = _get_test_and_trial_fields(fields)
|
|
1565
|
+
arguments = _parse_integrand_arguments(integrand, fields)
|
|
1566
|
+
|
|
1567
|
+
test = None
|
|
1568
|
+
if arguments.test_name:
|
|
1569
|
+
test = arguments.field_args[arguments.test_name]
|
|
1570
|
+
trial = None
|
|
1571
|
+
if arguments.trial_name:
|
|
1572
|
+
if test is None:
|
|
1573
|
+
raise ValueError("A trial field cannot be provided without a test field")
|
|
1574
|
+
trial = arguments.field_args[arguments.trial_name]
|
|
1575
|
+
if test.domain != trial.domain:
|
|
1576
|
+
raise ValueError("Incompatible test and trial domains")
|
|
1192
1577
|
|
|
1193
1578
|
if domain is None:
|
|
1194
1579
|
if quadrature is not None:
|
|
@@ -1201,7 +1586,26 @@ def integrate(
|
|
|
1201
1586
|
if test is not None and domain != test.domain:
|
|
1202
1587
|
raise NotImplementedError("Mixing integration and test domain is not supported yet")
|
|
1203
1588
|
|
|
1204
|
-
if
|
|
1589
|
+
if add and output is None:
|
|
1590
|
+
raise ValueError("An 'output' array or matrix needs to be provided for add=True")
|
|
1591
|
+
|
|
1592
|
+
if arguments.domain_name is not None:
|
|
1593
|
+
arguments.field_args[arguments.domain_name] = domain
|
|
1594
|
+
|
|
1595
|
+
_find_integrand_operators(integrand, arguments.field_args)
|
|
1596
|
+
|
|
1597
|
+
assembly = _pick_assembly_strategy(assembly, nodal, arguments=arguments, operators=integrand.operators)
|
|
1598
|
+
# print("assembly for ", integrand.name, ":", strategy)
|
|
1599
|
+
|
|
1600
|
+
if assembly == "dispatch":
|
|
1601
|
+
if test is not None:
|
|
1602
|
+
test = LocalTestField(test)
|
|
1603
|
+
arguments.field_args[arguments.test_name] = test
|
|
1604
|
+
if trial is not None:
|
|
1605
|
+
trial = LocalTrialField(trial)
|
|
1606
|
+
arguments.field_args[arguments.trial_name] = trial
|
|
1607
|
+
|
|
1608
|
+
if assembly == "nodal":
|
|
1205
1609
|
if quadrature is not None:
|
|
1206
1610
|
raise ValueError("Cannot specify quadrature for nodal integration")
|
|
1207
1611
|
|
|
@@ -1234,13 +1638,10 @@ def integrate(
|
|
|
1234
1638
|
kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
|
|
1235
1639
|
integrand=integrand,
|
|
1236
1640
|
domain=domain,
|
|
1237
|
-
nodal=nodal,
|
|
1238
1641
|
quadrature=quadrature,
|
|
1642
|
+
arguments=arguments,
|
|
1239
1643
|
test=test,
|
|
1240
|
-
test_name=test_name,
|
|
1241
1644
|
trial=trial,
|
|
1242
|
-
trial_name=trial_name,
|
|
1243
|
-
fields=fields,
|
|
1244
1645
|
accumulate_dtype=accumulate_dtype,
|
|
1245
1646
|
output_dtype=output_dtype,
|
|
1246
1647
|
kernel_options=kernel_options,
|
|
@@ -1252,16 +1653,16 @@ def integrate(
|
|
|
1252
1653
|
FieldStruct=FieldStruct,
|
|
1253
1654
|
ValueStruct=ValueStruct,
|
|
1254
1655
|
domain=domain,
|
|
1255
|
-
nodal=nodal,
|
|
1256
1656
|
quadrature=quadrature,
|
|
1257
1657
|
test=test,
|
|
1258
1658
|
trial=trial,
|
|
1259
|
-
fields=
|
|
1659
|
+
fields=arguments.field_args,
|
|
1260
1660
|
values=values,
|
|
1261
1661
|
accumulate_dtype=accumulate_dtype,
|
|
1262
1662
|
temporary_store=temporary_store,
|
|
1263
1663
|
output_dtype=output_dtype,
|
|
1264
1664
|
output=output,
|
|
1665
|
+
add_to_output=add,
|
|
1265
1666
|
device=device,
|
|
1266
1667
|
)
|
|
1267
1668
|
|
|
@@ -1340,6 +1741,30 @@ def get_interpolate_to_field_kernel(
|
|
|
1340
1741
|
ValueStruct: wp.codegen.Struct,
|
|
1341
1742
|
dest: FieldRestriction,
|
|
1342
1743
|
):
|
|
1744
|
+
@wp.func
|
|
1745
|
+
def _find_node_in_element(
|
|
1746
|
+
domain_arg: domain.ElementArg,
|
|
1747
|
+
domain_index_arg: domain.ElementIndexArg,
|
|
1748
|
+
dest_node_arg: dest.space_restriction.NodeArg,
|
|
1749
|
+
dest_eval_arg: dest.field.EvalArg,
|
|
1750
|
+
partition_node_index: int,
|
|
1751
|
+
):
|
|
1752
|
+
element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)
|
|
1753
|
+
|
|
1754
|
+
for n in range(element_beg, element_end):
|
|
1755
|
+
node_element_index = dest.space_restriction.node_element_index(dest_node_arg, n)
|
|
1756
|
+
element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
|
|
1757
|
+
coords = dest.space.node_coords_in_element(
|
|
1758
|
+
domain_arg,
|
|
1759
|
+
dest_eval_arg.space_arg,
|
|
1760
|
+
element_index,
|
|
1761
|
+
node_element_index.node_index_in_element,
|
|
1762
|
+
)
|
|
1763
|
+
if coords[0] != OUTSIDE:
|
|
1764
|
+
return element_index, node_element_index.node_index_in_element
|
|
1765
|
+
|
|
1766
|
+
return NULL_ELEMENT_INDEX, NULL_NODE_INDEX
|
|
1767
|
+
|
|
1343
1768
|
def interpolate_to_field_kernel_fn(
|
|
1344
1769
|
domain_arg: domain.ElementArg,
|
|
1345
1770
|
domain_index_arg: domain.ElementIndexArg,
|
|
@@ -1355,8 +1780,20 @@ def get_interpolate_to_field_kernel(
|
|
|
1355
1780
|
)
|
|
1356
1781
|
|
|
1357
1782
|
if vol_sum > 0.0:
|
|
1358
|
-
|
|
1359
|
-
|
|
1783
|
+
partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
|
|
1784
|
+
|
|
1785
|
+
# Grab first element containing node; there must be at least one since vol_sum != 0
|
|
1786
|
+
element_index, node_index_in_element = _find_node_in_element(
|
|
1787
|
+
domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, partition_node_index
|
|
1788
|
+
)
|
|
1789
|
+
dest.field.set_node_value(
|
|
1790
|
+
domain_arg,
|
|
1791
|
+
dest_eval_arg,
|
|
1792
|
+
element_index,
|
|
1793
|
+
node_index_in_element,
|
|
1794
|
+
partition_node_index,
|
|
1795
|
+
val_sum / vol_sum,
|
|
1796
|
+
)
|
|
1360
1797
|
|
|
1361
1798
|
return interpolate_to_field_kernel_fn
|
|
1362
1799
|
|
|
@@ -1470,53 +1907,43 @@ def _generate_interpolate_kernel(
|
|
|
1470
1907
|
domain: GeometryDomain,
|
|
1471
1908
|
dest: Optional[Union[FieldLike, wp.array]],
|
|
1472
1909
|
quadrature: Optional[Quadrature],
|
|
1473
|
-
|
|
1910
|
+
arguments: IntegrandArguments,
|
|
1474
1911
|
kernel_options: Optional[Dict[str, Any]] = None,
|
|
1475
1912
|
) -> wp.Kernel:
|
|
1476
|
-
if kernel_options is None:
|
|
1477
|
-
kernel_options = {}
|
|
1478
|
-
|
|
1479
|
-
# Extract field arguments from integrand
|
|
1480
|
-
field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
|
|
1481
|
-
integrand, fields=fields, domain=domain
|
|
1482
|
-
)
|
|
1483
|
-
|
|
1484
1913
|
# Generate field struct
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
field_args,
|
|
1488
|
-
)
|
|
1489
|
-
|
|
1490
|
-
_register_integrand_field_wrappers(integrand_func, fields)
|
|
1914
|
+
FieldStruct = _gen_field_struct(arguments.field_args)
|
|
1915
|
+
ValueStruct = cache.get_argument_struct(arguments.value_args)
|
|
1491
1916
|
|
|
1492
|
-
|
|
1493
|
-
ValueStruct = cache.get_argument_struct(value_args)
|
|
1917
|
+
_notify_operator_usage(integrand, arguments.field_args)
|
|
1494
1918
|
|
|
1495
1919
|
# Check if kernel exist in cache
|
|
1920
|
+
field_names = "_".join(f"{k}{f.name}" for k, f in arguments.field_args.items())
|
|
1496
1921
|
if isinstance(dest, FieldRestriction):
|
|
1497
|
-
kernel_suffix =
|
|
1498
|
-
f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
|
|
1499
|
-
)
|
|
1922
|
+
kernel_suffix = f"_itp_{field_names}_{dest.domain.name}_{dest.space_restriction.space_partition.name}"
|
|
1500
1923
|
else:
|
|
1501
1924
|
dest_dtype = dest.dtype if dest else None
|
|
1502
1925
|
type_str = wp.types.get_type_code(dest_dtype) if dest_dtype else ""
|
|
1503
1926
|
if quadrature is None:
|
|
1504
|
-
kernel_suffix = f"_itp_{
|
|
1927
|
+
kernel_suffix = f"_itp_{field_names}_{type_str}"
|
|
1505
1928
|
else:
|
|
1506
|
-
kernel_suffix = f"_itp_{
|
|
1929
|
+
kernel_suffix = f"_itp_{field_names}_{quadrature.name}_{type_str}"
|
|
1507
1930
|
|
|
1508
1931
|
kernel = cache.get_integrand_kernel(
|
|
1509
1932
|
integrand=integrand,
|
|
1510
1933
|
suffix=kernel_suffix,
|
|
1934
|
+
kernel_options=kernel_options,
|
|
1511
1935
|
)
|
|
1512
1936
|
if kernel is not None:
|
|
1513
1937
|
return kernel, FieldStruct, ValueStruct
|
|
1514
1938
|
|
|
1515
|
-
|
|
1939
|
+
# Not found in cache, transform integrand and generate kernel
|
|
1940
|
+
_check_field_compat(integrand, arguments, domain)
|
|
1941
|
+
|
|
1942
|
+
integrand_func = IntegrandTransformer.apply(integrand, arguments.field_args)
|
|
1516
1943
|
|
|
1517
1944
|
# Generate interpolation kernel
|
|
1518
1945
|
if isinstance(dest, FieldRestriction):
|
|
1519
|
-
# need to split into kernel + function for
|
|
1946
|
+
# need to split into kernel + function for differentiability
|
|
1520
1947
|
interpolate_fn = get_interpolate_to_field_function(
|
|
1521
1948
|
integrand_func,
|
|
1522
1949
|
domain,
|
|
@@ -1531,11 +1958,7 @@ def _generate_interpolate_kernel(
|
|
|
1531
1958
|
suffix=kernel_suffix,
|
|
1532
1959
|
code_transformers=[
|
|
1533
1960
|
PassFieldArgsToIntegrand(
|
|
1534
|
-
arg_names=integrand.argspec.args,
|
|
1535
|
-
field_args=field_args.keys(),
|
|
1536
|
-
value_args=value_args.keys(),
|
|
1537
|
-
sample_name=sample_name,
|
|
1538
|
-
domain_name=domain_name,
|
|
1961
|
+
arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
|
|
1539
1962
|
)
|
|
1540
1963
|
],
|
|
1541
1964
|
)
|
|
@@ -1572,11 +1995,7 @@ def _generate_interpolate_kernel(
|
|
|
1572
1995
|
kernel_options=kernel_options,
|
|
1573
1996
|
code_transformers=[
|
|
1574
1997
|
PassFieldArgsToIntegrand(
|
|
1575
|
-
arg_names=integrand.argspec.args,
|
|
1576
|
-
field_args=field_args.keys(),
|
|
1577
|
-
value_args=value_args.keys(),
|
|
1578
|
-
sample_name=sample_name,
|
|
1579
|
-
domain_name=domain_name,
|
|
1998
|
+
arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
|
|
1580
1999
|
)
|
|
1581
2000
|
],
|
|
1582
2001
|
)
|
|
@@ -1603,7 +2022,8 @@ def _launch_interpolate_kernel(
|
|
|
1603
2022
|
|
|
1604
2023
|
field_arg_values = FieldStruct()
|
|
1605
2024
|
for k, v in fields.items():
|
|
1606
|
-
|
|
2025
|
+
if not isinstance(v, GeometryDomain):
|
|
2026
|
+
setattr(field_arg_values, k, v.eval_arg_value(device=device))
|
|
1607
2027
|
|
|
1608
2028
|
value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
|
|
1609
2029
|
|
|
@@ -1687,14 +2107,11 @@ def interpolate(
|
|
|
1687
2107
|
if values is None:
|
|
1688
2108
|
values = {}
|
|
1689
2109
|
|
|
1690
|
-
if kernel_options is None:
|
|
1691
|
-
kernel_options = {}
|
|
1692
|
-
|
|
1693
2110
|
if not isinstance(integrand, Integrand):
|
|
1694
2111
|
raise ValueError("integrand must be tagged with @integrand decorator")
|
|
1695
2112
|
|
|
1696
|
-
|
|
1697
|
-
if
|
|
2113
|
+
arguments = _parse_integrand_arguments(integrand, fields)
|
|
2114
|
+
if arguments.test_name or arguments.trial_name:
|
|
1698
2115
|
raise ValueError("Test or Trial fields should not be used for interpolation")
|
|
1699
2116
|
|
|
1700
2117
|
if isinstance(dest, DiscreteField):
|
|
@@ -1705,12 +2122,17 @@ def interpolate(
|
|
|
1705
2122
|
elif quadrature is not None:
|
|
1706
2123
|
domain = quadrature.domain
|
|
1707
2124
|
|
|
2125
|
+
if arguments.domain_name:
|
|
2126
|
+
arguments.field_args[arguments.domain_name] = domain
|
|
2127
|
+
|
|
2128
|
+
_find_integrand_operators(integrand, arguments.field_args)
|
|
2129
|
+
|
|
1708
2130
|
kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
|
|
1709
2131
|
integrand=integrand,
|
|
1710
2132
|
domain=domain,
|
|
1711
2133
|
dest=dest,
|
|
1712
2134
|
quadrature=quadrature,
|
|
1713
|
-
|
|
2135
|
+
arguments=arguments,
|
|
1714
2136
|
kernel_options=kernel_options,
|
|
1715
2137
|
)
|
|
1716
2138
|
|
|
@@ -1723,7 +2145,7 @@ def interpolate(
|
|
|
1723
2145
|
dest=dest,
|
|
1724
2146
|
quadrature=quadrature,
|
|
1725
2147
|
dim=dim,
|
|
1726
|
-
fields=
|
|
2148
|
+
fields=arguments.field_args,
|
|
1727
2149
|
values=values,
|
|
1728
2150
|
device=device,
|
|
1729
2151
|
)
|