warp-lang 1.2.2__py3-none-macosx_10_13_universal2.whl → 1.3.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +8 -6
- warp/autograd.py +823 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +6 -2
- warp/builtins.py +1410 -886
- warp/codegen.py +503 -166
- warp/config.py +48 -18
- warp/context.py +400 -198
- warp/dlpack.py +8 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
- warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
- warp/examples/benchmarks/benchmark_launches.py +1 -1
- warp/examples/core/example_cupy.py +78 -0
- warp/examples/fem/example_apic_fluid.py +17 -36
- warp/examples/fem/example_burgers.py +9 -18
- warp/examples/fem/example_convection_diffusion.py +7 -17
- warp/examples/fem/example_convection_diffusion_dg.py +27 -47
- warp/examples/fem/example_deformed_geometry.py +11 -22
- warp/examples/fem/example_diffusion.py +7 -18
- warp/examples/fem/example_diffusion_3d.py +24 -28
- warp/examples/fem/example_diffusion_mgpu.py +7 -14
- warp/examples/fem/example_magnetostatics.py +190 -0
- warp/examples/fem/example_mixed_elasticity.py +111 -80
- warp/examples/fem/example_navier_stokes.py +30 -34
- warp/examples/fem/example_nonconforming_contact.py +290 -0
- warp/examples/fem/example_stokes.py +17 -32
- warp/examples/fem/example_stokes_transfer.py +12 -21
- warp/examples/fem/example_streamlines.py +350 -0
- warp/examples/fem/utils.py +936 -0
- warp/fabric.py +5 -2
- warp/fem/__init__.py +13 -3
- warp/fem/cache.py +161 -11
- warp/fem/dirichlet.py +37 -28
- warp/fem/domain.py +105 -14
- warp/fem/field/__init__.py +14 -3
- warp/fem/field/field.py +454 -11
- warp/fem/field/nodal_field.py +33 -18
- warp/fem/geometry/deformed_geometry.py +50 -15
- warp/fem/geometry/hexmesh.py +12 -24
- warp/fem/geometry/nanogrid.py +106 -31
- warp/fem/geometry/quadmesh_2d.py +6 -11
- warp/fem/geometry/tetmesh.py +103 -61
- warp/fem/geometry/trimesh_2d.py +98 -47
- warp/fem/integrate.py +231 -186
- warp/fem/operator.py +14 -9
- warp/fem/quadrature/pic_quadrature.py +35 -9
- warp/fem/quadrature/quadrature.py +119 -32
- warp/fem/space/basis_space.py +98 -22
- warp/fem/space/collocated_function_space.py +3 -1
- warp/fem/space/function_space.py +7 -2
- warp/fem/space/grid_2d_function_space.py +3 -3
- warp/fem/space/grid_3d_function_space.py +4 -4
- warp/fem/space/hexmesh_function_space.py +3 -2
- warp/fem/space/nanogrid_function_space.py +12 -14
- warp/fem/space/partition.py +45 -47
- warp/fem/space/restriction.py +19 -16
- warp/fem/space/shape/cube_shape_function.py +91 -3
- warp/fem/space/shape/shape_function.py +7 -0
- warp/fem/space/shape/square_shape_function.py +32 -0
- warp/fem/space/shape/tet_shape_function.py +11 -7
- warp/fem/space/shape/triangle_shape_function.py +10 -1
- warp/fem/space/topology.py +116 -42
- warp/fem/types.py +8 -1
- warp/fem/utils.py +301 -83
- warp/native/array.h +16 -0
- warp/native/builtin.h +0 -15
- warp/native/cuda_util.cpp +14 -6
- warp/native/exports.h +1348 -1308
- warp/native/quat.h +79 -0
- warp/native/rand.h +27 -4
- warp/native/sparse.cpp +83 -81
- warp/native/sparse.cu +381 -453
- warp/native/vec.h +64 -0
- warp/native/volume.cpp +40 -49
- warp/native/volume_builder.cu +2 -3
- warp/native/volume_builder.h +12 -17
- warp/native/warp.cu +3 -3
- warp/native/warp.h +69 -59
- warp/render/render_opengl.py +17 -9
- warp/sim/articulation.py +117 -17
- warp/sim/collide.py +35 -29
- warp/sim/model.py +123 -18
- warp/sim/render.py +3 -1
- warp/sparse.py +867 -203
- warp/stubs.py +312 -541
- warp/tape.py +29 -1
- warp/tests/disabled_kinematics.py +1 -1
- warp/tests/test_adam.py +1 -1
- warp/tests/test_arithmetic.py +1 -1
- warp/tests/test_array.py +58 -1
- warp/tests/test_array_reduce.py +1 -1
- warp/tests/test_async.py +1 -1
- warp/tests/test_atomic.py +1 -1
- warp/tests/test_bool.py +1 -1
- warp/tests/test_builtins_resolution.py +1 -1
- warp/tests/test_bvh.py +6 -1
- warp/tests/test_closest_point_edge_edge.py +1 -1
- warp/tests/test_codegen.py +66 -1
- warp/tests/test_compile_consts.py +1 -1
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_copy.py +1 -1
- warp/tests/test_ctypes.py +1 -1
- warp/tests/test_dense.py +1 -1
- warp/tests/test_devices.py +1 -1
- warp/tests/test_dlpack.py +1 -1
- warp/tests/test_examples.py +33 -4
- warp/tests/test_fabricarray.py +5 -2
- warp/tests/test_fast_math.py +1 -1
- warp/tests/test_fem.py +213 -6
- warp/tests/test_fp16.py +1 -1
- warp/tests/test_func.py +1 -1
- warp/tests/test_future_annotations.py +90 -0
- warp/tests/test_generics.py +1 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +1 -1
- warp/tests/test_grad_debug.py +247 -0
- warp/tests/test_hash_grid.py +6 -1
- warp/tests/test_implicit_init.py +354 -0
- warp/tests/test_import.py +1 -1
- warp/tests/test_indexedarray.py +1 -1
- warp/tests/test_intersect.py +1 -1
- warp/tests/test_jax.py +1 -1
- warp/tests/test_large.py +1 -1
- warp/tests/test_launch.py +1 -1
- warp/tests/test_lerp.py +1 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_lvalue.py +1 -1
- warp/tests/test_marching_cubes.py +5 -2
- warp/tests/test_mat.py +34 -35
- warp/tests/test_mat_lite.py +2 -1
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_math.py +1 -1
- warp/tests/test_matmul.py +20 -16
- warp/tests/test_matmul_lite.py +1 -1
- warp/tests/test_mempool.py +1 -1
- warp/tests/test_mesh.py +5 -2
- warp/tests/test_mesh_query_aabb.py +1 -1
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_mesh_query_ray.py +1 -1
- warp/tests/test_mlp.py +1 -1
- warp/tests/test_model.py +1 -1
- warp/tests/test_module_hashing.py +77 -1
- warp/tests/test_modules_lite.py +1 -1
- warp/tests/test_multigpu.py +1 -1
- warp/tests/test_noise.py +1 -1
- warp/tests/test_operators.py +1 -1
- warp/tests/test_options.py +1 -1
- warp/tests/test_overwrite.py +542 -0
- warp/tests/test_peer.py +1 -1
- warp/tests/test_pinned.py +1 -1
- warp/tests/test_print.py +1 -1
- warp/tests/test_quat.py +15 -1
- warp/tests/test_rand.py +1 -1
- warp/tests/test_reload.py +1 -1
- warp/tests/test_rounding.py +1 -1
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +95 -0
- warp/tests/test_sim_grad.py +1 -1
- warp/tests/test_sim_kinematics.py +1 -1
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +82 -15
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_special_values.py +2 -11
- warp/tests/test_streams.py +11 -1
- warp/tests/test_struct.py +1 -1
- warp/tests/test_tape.py +1 -1
- warp/tests/test_torch.py +194 -1
- warp/tests/test_transient_module.py +1 -1
- warp/tests/test_types.py +1 -1
- warp/tests/test_utils.py +1 -1
- warp/tests/test_vec.py +15 -63
- warp/tests/test_vec_lite.py +2 -1
- warp/tests/test_vec_scalar_ops.py +65 -1
- warp/tests/test_verify_fp.py +1 -1
- warp/tests/test_volume.py +28 -2
- warp/tests/test_volume_write.py +1 -1
- warp/tests/unittest_serial.py +1 -1
- warp/tests/unittest_suites.py +9 -1
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +2 -5
- warp/torch.py +103 -41
- warp/types.py +341 -224
- warp/utils.py +11 -2
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
- warp_lang-1.3.0.dist-info/RECORD +368 -0
- warp/examples/fem/bsr_utils.py +0 -378
- warp/examples/fem/mesh_utils.py +0 -133
- warp/examples/fem/plot_utils.py +0 -292
- warp_lang-1.2.2.dist-info/RECORD +0 -359
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py
CHANGED
|
@@ -9,14 +9,25 @@ from warp.fem.field import (
|
|
|
9
9
|
DiscreteField,
|
|
10
10
|
FieldLike,
|
|
11
11
|
FieldRestriction,
|
|
12
|
-
|
|
12
|
+
GeometryField,
|
|
13
13
|
TestField,
|
|
14
14
|
TrialField,
|
|
15
15
|
make_restriction,
|
|
16
16
|
)
|
|
17
|
-
from warp.fem.operator import Integrand, Operator
|
|
17
|
+
from warp.fem.operator import Integrand, Operator, integrand
|
|
18
18
|
from warp.fem.quadrature import Quadrature, RegularQuadrature
|
|
19
|
-
from warp.fem.types import
|
|
19
|
+
from warp.fem.types import (
|
|
20
|
+
NULL_DOF_INDEX,
|
|
21
|
+
NULL_ELEMENT_INDEX,
|
|
22
|
+
NULL_NODE_INDEX,
|
|
23
|
+
OUTSIDE,
|
|
24
|
+
Coords,
|
|
25
|
+
DofIndex,
|
|
26
|
+
Domain,
|
|
27
|
+
Field,
|
|
28
|
+
Sample,
|
|
29
|
+
make_free_sample,
|
|
30
|
+
)
|
|
20
31
|
from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
|
|
21
32
|
from warp.types import type_length
|
|
22
33
|
from warp.utils import array_cast
|
|
@@ -58,24 +69,11 @@ def _resolve_path(func, node):
|
|
|
58
69
|
return None, path
|
|
59
70
|
|
|
60
71
|
|
|
61
|
-
def _path_to_ast_attribute(name: str) -> ast.Attribute:
|
|
62
|
-
path = name.split(".")
|
|
63
|
-
path.reverse()
|
|
64
|
-
|
|
65
|
-
node = ast.Name(id=path.pop(), ctx=ast.Load())
|
|
66
|
-
while len(path):
|
|
67
|
-
node = ast.Attribute(
|
|
68
|
-
value=node,
|
|
69
|
-
attr=path.pop(),
|
|
70
|
-
ctx=ast.Load(),
|
|
71
|
-
)
|
|
72
|
-
return node
|
|
73
|
-
|
|
74
|
-
|
|
75
72
|
class IntegrandTransformer(ast.NodeTransformer):
|
|
76
|
-
def __init__(self, integrand: Integrand, field_args: Dict[str, FieldLike]):
|
|
73
|
+
def __init__(self, integrand: Integrand, field_args: Dict[str, FieldLike], annotations: Dict[str, Any]):
|
|
77
74
|
self._integrand = integrand
|
|
78
75
|
self._field_args = field_args
|
|
76
|
+
self._annotations = annotations
|
|
79
77
|
|
|
80
78
|
def visit_Call(self, call: ast.Call):
|
|
81
79
|
call = self.generic_visit(call)
|
|
@@ -85,18 +83,15 @@ class IntegrandTransformer(ast.NodeTransformer):
|
|
|
85
83
|
# Shortcut for evaluating fields as f(x...)
|
|
86
84
|
field = self._field_args[callee]
|
|
87
85
|
|
|
88
|
-
|
|
89
|
-
|
|
86
|
+
# Replace with default call operator
|
|
87
|
+
abstract_arg_type = self._integrand.argspec.annotations[callee]
|
|
88
|
+
default_operator = abstract_arg_type.call_operator
|
|
89
|
+
concrete_arg_type = self._annotations[callee]
|
|
90
|
+
self._replace_call_func(call, concrete_arg_type, default_operator, field)
|
|
90
91
|
|
|
91
|
-
|
|
92
|
-
value=_path_to_ast_attribute(f"{arg_type.__module__}.{arg_type.__qualname__}"),
|
|
93
|
-
attr="call_operator",
|
|
94
|
-
ctx=ast.Load(),
|
|
95
|
-
)
|
|
92
|
+
# insert callee as first argument
|
|
96
93
|
call.args = [ast.Name(id=callee, ctx=ast.Load())] + call.args
|
|
97
94
|
|
|
98
|
-
self._replace_call_func(call, operator, field)
|
|
99
|
-
|
|
100
95
|
return call
|
|
101
96
|
|
|
102
97
|
func, _ = _resolve_path(self._integrand.func, call.func)
|
|
@@ -106,7 +101,7 @@ class IntegrandTransformer(ast.NodeTransformer):
|
|
|
106
101
|
callee = getattr(call.args[0], "id", None)
|
|
107
102
|
if callee in self._field_args:
|
|
108
103
|
field = self._field_args[callee]
|
|
109
|
-
self._replace_call_func(call, func, field)
|
|
104
|
+
self._replace_call_func(call, func, func, field)
|
|
110
105
|
|
|
111
106
|
if isinstance(func, Integrand):
|
|
112
107
|
key = self._translate_callee(func, call.args)
|
|
@@ -120,12 +115,18 @@ class IntegrandTransformer(ast.NodeTransformer):
|
|
|
120
115
|
|
|
121
116
|
return call
|
|
122
117
|
|
|
123
|
-
def _replace_call_func(self, call: ast.Call, operator: Operator, field: FieldLike):
|
|
118
|
+
def _replace_call_func(self, call: ast.Call, callee: Union[type, Operator], operator: Operator, field: FieldLike):
|
|
124
119
|
try:
|
|
120
|
+
# Retrieve the function pointer corresponding to the operator implementation for the field type
|
|
125
121
|
pointer = operator.resolver(field)
|
|
126
|
-
|
|
127
|
-
|
|
122
|
+
if pointer is None:
|
|
123
|
+
raise NotImplementedError(operator.resolver.__name__)
|
|
124
|
+
|
|
125
|
+
except (AttributeError, NotImplementedError) as e:
|
|
128
126
|
raise ValueError(f"Operator {operator.func.__name__} is not defined for field {field.name}") from e
|
|
127
|
+
# Save the pointer as an attribute than can be accessed from the callee scope
|
|
128
|
+
setattr(callee, pointer.key, pointer)
|
|
129
|
+
# Update the ast Call node to use the new function pointer
|
|
129
130
|
call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())
|
|
130
131
|
|
|
131
132
|
def _translate_callee(self, callee: Integrand, args: List[ast.AST]):
|
|
@@ -162,7 +163,7 @@ def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike])
|
|
|
162
163
|
annotations[arg] = arg_type
|
|
163
164
|
|
|
164
165
|
# Transform field evaluation calls
|
|
165
|
-
transformer = IntegrandTransformer(integrand, field_args)
|
|
166
|
+
transformer = IntegrandTransformer(integrand, field_args, annotations)
|
|
166
167
|
|
|
167
168
|
suffix = "_".join([f.name for f in field_args.values()])
|
|
168
169
|
|
|
@@ -215,46 +216,22 @@ def _check_field_compat(
|
|
|
215
216
|
field_args: Dict[str, FieldLike],
|
|
216
217
|
domain: GeometryDomain = None,
|
|
217
218
|
):
|
|
218
|
-
# Check field
|
|
219
|
+
# Check field compatibility
|
|
219
220
|
for name, field in fields.items():
|
|
220
221
|
if name not in field_args:
|
|
221
222
|
raise ValueError(
|
|
222
223
|
f"Passed field argument '{name}' does not match any parameter of integrand '{integrand.name}'"
|
|
223
224
|
)
|
|
224
225
|
|
|
225
|
-
if isinstance(field,
|
|
226
|
-
|
|
227
|
-
if space.geometry != domain.geometry:
|
|
226
|
+
if isinstance(field, GeometryField) and domain is not None:
|
|
227
|
+
if field.geometry != domain.geometry:
|
|
228
228
|
raise ValueError(f"Field '{name}' must be defined on the same geometry as the integration domain")
|
|
229
|
-
if
|
|
229
|
+
if field.element_kind != domain.element_kind:
|
|
230
230
|
raise ValueError(
|
|
231
|
-
f"Field '{name}'
|
|
231
|
+
f"Field '{name}' is not defined on the same kind of elements (cells or sides) as the integration domain. Maybe a forgotten `.trace()`?"
|
|
232
232
|
)
|
|
233
233
|
|
|
234
234
|
|
|
235
|
-
def _populate_value_struct(ValueStruct: wp.codegen.Struct, values: Dict[str, Any], integrand_name: str):
|
|
236
|
-
value_struct_values = ValueStruct()
|
|
237
|
-
for k, v in values.items():
|
|
238
|
-
try:
|
|
239
|
-
setattr(value_struct_values, k, v)
|
|
240
|
-
except Exception as err:
|
|
241
|
-
if k not in ValueStruct.vars:
|
|
242
|
-
raise ValueError(
|
|
243
|
-
f"Passed value argument '{k}' does not match any of the integrand '{integrand_name}' parameters"
|
|
244
|
-
) from err
|
|
245
|
-
raise ValueError(
|
|
246
|
-
f"Passed value argument '{k}' of type '{wp.types.type_repr(v)}' is incompatible with the integrand '{integrand_name}' parameter of type '{wp.types.type_repr(ValueStruct.vars[k].type)}'"
|
|
247
|
-
) from err
|
|
248
|
-
|
|
249
|
-
missing_values = ValueStruct.vars.keys() - values.keys()
|
|
250
|
-
if missing_values:
|
|
251
|
-
wp.utils.warn(
|
|
252
|
-
f"Missing values for parameter(s) '{', '.join(missing_values)}' of the integrand '{integrand_name}', will be zero-initialized"
|
|
253
|
-
)
|
|
254
|
-
|
|
255
|
-
return value_struct_values
|
|
256
|
-
|
|
257
|
-
|
|
258
235
|
def _get_test_and_trial_fields(
|
|
259
236
|
fields: Dict[str, FieldLike],
|
|
260
237
|
):
|
|
@@ -310,36 +287,6 @@ def _gen_field_struct(field_args: Dict[str, FieldLike]):
|
|
|
310
287
|
return cache.get_struct(Fields, suffix=suffix)
|
|
311
288
|
|
|
312
289
|
|
|
313
|
-
def _gen_value_struct(value_args: Dict[str, type]):
|
|
314
|
-
class Values:
|
|
315
|
-
pass
|
|
316
|
-
|
|
317
|
-
annotations = get_annotations(Values)
|
|
318
|
-
|
|
319
|
-
for name, arg_type in value_args.items():
|
|
320
|
-
setattr(Values, name, None)
|
|
321
|
-
annotations[name] = arg_type
|
|
322
|
-
|
|
323
|
-
def arg_type_name(arg_type):
|
|
324
|
-
if isinstance(arg_type, wp.codegen.Struct):
|
|
325
|
-
return arg_type_name(arg_type.cls)
|
|
326
|
-
return getattr(arg_type, "__name__", str(arg_type))
|
|
327
|
-
|
|
328
|
-
def arg_type_name(arg_type):
|
|
329
|
-
if isinstance(arg_type, wp.codegen.Struct):
|
|
330
|
-
return arg_type_name(arg_type.cls)
|
|
331
|
-
return getattr(arg_type, "__name__", str(arg_type))
|
|
332
|
-
|
|
333
|
-
try:
|
|
334
|
-
Values.__annotations__ = annotations
|
|
335
|
-
except AttributeError:
|
|
336
|
-
Values.__dict__.__annotations__ = annotations
|
|
337
|
-
|
|
338
|
-
suffix = "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])
|
|
339
|
-
|
|
340
|
-
return cache.get_struct(Values, suffix=suffix)
|
|
341
|
-
|
|
342
|
-
|
|
343
290
|
def _get_trial_arg():
|
|
344
291
|
pass
|
|
345
292
|
|
|
@@ -474,17 +421,18 @@ def get_integrate_constant_kernel(
|
|
|
474
421
|
values: ValueStruct,
|
|
475
422
|
result: wp.array(dtype=accumulate_dtype),
|
|
476
423
|
):
|
|
477
|
-
|
|
424
|
+
domain_element_index = wp.tid()
|
|
425
|
+
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
478
426
|
elem_sum = accumulate_dtype(0.0)
|
|
479
427
|
|
|
480
428
|
test_dof_index = NULL_DOF_INDEX
|
|
481
429
|
trial_dof_index = NULL_DOF_INDEX
|
|
482
430
|
|
|
483
|
-
qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
|
|
431
|
+
qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
|
|
484
432
|
for k in range(qp_point_count):
|
|
485
|
-
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
486
|
-
coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
487
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
433
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
434
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
435
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
488
436
|
|
|
489
437
|
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
490
438
|
vol = domain.element_measure(domain_arg, sample)
|
|
@@ -519,23 +467,31 @@ def get_integrate_linear_kernel(
|
|
|
519
467
|
):
|
|
520
468
|
local_node_index, test_dof = wp.tid()
|
|
521
469
|
node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
|
|
522
|
-
|
|
470
|
+
element_beg, element_end = test.space_restriction.node_element_range(test_arg, node_index)
|
|
523
471
|
|
|
524
472
|
trial_dof_index = NULL_DOF_INDEX
|
|
525
473
|
|
|
526
474
|
val_sum = accumulate_dtype(0.0)
|
|
527
475
|
|
|
528
|
-
for n in range(
|
|
529
|
-
node_element_index = test.space_restriction.node_element_index(test_arg,
|
|
476
|
+
for n in range(element_beg, element_end):
|
|
477
|
+
node_element_index = test.space_restriction.node_element_index(test_arg, n)
|
|
530
478
|
element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
|
|
531
479
|
|
|
532
480
|
test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
|
|
533
481
|
|
|
534
|
-
qp_point_count = quadrature.point_count(
|
|
482
|
+
qp_point_count = quadrature.point_count(
|
|
483
|
+
domain_arg, qp_arg, node_element_index.domain_element_index, element_index
|
|
484
|
+
)
|
|
535
485
|
for k in range(qp_point_count):
|
|
536
|
-
qp_index = quadrature.point_index(
|
|
537
|
-
|
|
538
|
-
|
|
486
|
+
qp_index = quadrature.point_index(
|
|
487
|
+
domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
|
|
488
|
+
)
|
|
489
|
+
qp_coords = quadrature.point_coords(
|
|
490
|
+
domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
|
|
491
|
+
)
|
|
492
|
+
qp_weight = quadrature.point_weight(
|
|
493
|
+
domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
|
|
494
|
+
)
|
|
539
495
|
|
|
540
496
|
vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
|
|
541
497
|
|
|
@@ -562,23 +518,29 @@ def get_integrate_linear_nodal_kernel(
|
|
|
562
518
|
domain_arg: domain.ElementArg,
|
|
563
519
|
domain_index_arg: domain.ElementIndexArg,
|
|
564
520
|
test_restriction_arg: test.space_restriction.NodeArg,
|
|
521
|
+
test_topo_arg: test.space.topology.TopologyArg,
|
|
565
522
|
fields: FieldStruct,
|
|
566
523
|
values: ValueStruct,
|
|
567
524
|
result: wp.array2d(dtype=output_dtype),
|
|
568
525
|
):
|
|
569
526
|
local_node_index, dof = wp.tid()
|
|
570
527
|
|
|
571
|
-
|
|
572
|
-
|
|
528
|
+
partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
|
|
529
|
+
element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)
|
|
573
530
|
|
|
574
531
|
trial_dof_index = NULL_DOF_INDEX
|
|
575
532
|
|
|
576
533
|
val_sum = accumulate_dtype(0.0)
|
|
577
534
|
|
|
578
|
-
for n in range(
|
|
579
|
-
node_element_index = test.space_restriction.node_element_index(test_restriction_arg,
|
|
535
|
+
for n in range(element_beg, element_end):
|
|
536
|
+
node_element_index = test.space_restriction.node_element_index(test_restriction_arg, n)
|
|
580
537
|
element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
|
|
581
538
|
|
|
539
|
+
if n == element_beg:
|
|
540
|
+
node_index = test.space.topology.element_node_index(
|
|
541
|
+
domain_arg, test_topo_arg, element_index, node_element_index.node_index_in_element
|
|
542
|
+
)
|
|
543
|
+
|
|
582
544
|
coords = test.space.node_coords_in_element(
|
|
583
545
|
domain_arg,
|
|
584
546
|
_get_test_arg(),
|
|
@@ -609,7 +571,7 @@ def get_integrate_linear_nodal_kernel(
|
|
|
609
571
|
|
|
610
572
|
val_sum += accumulate_dtype(node_weight * vol * val)
|
|
611
573
|
|
|
612
|
-
result[
|
|
574
|
+
result[partition_node_index, dof] = output_dtype(val_sum)
|
|
613
575
|
|
|
614
576
|
return integrate_kernel_fn
|
|
615
577
|
|
|
@@ -625,7 +587,7 @@ def get_integrate_bilinear_kernel(
|
|
|
625
587
|
output_dtype,
|
|
626
588
|
accumulate_dtype,
|
|
627
589
|
):
|
|
628
|
-
|
|
590
|
+
MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT
|
|
629
591
|
|
|
630
592
|
def integrate_kernel_fn(
|
|
631
593
|
qp_arg: quadrature.Arg,
|
|
@@ -636,22 +598,29 @@ def get_integrate_bilinear_kernel(
|
|
|
636
598
|
trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
|
|
637
599
|
fields: FieldStruct,
|
|
638
600
|
values: ValueStruct,
|
|
639
|
-
row_offsets: wp.array(dtype=int),
|
|
640
601
|
triplet_rows: wp.array(dtype=int),
|
|
641
602
|
triplet_cols: wp.array(dtype=int),
|
|
642
603
|
triplet_values: wp.array3d(dtype=output_dtype),
|
|
643
604
|
):
|
|
644
605
|
test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()
|
|
645
606
|
|
|
646
|
-
element_count = test.space_restriction.node_element_count(test_arg, test_local_node_index)
|
|
647
607
|
test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
|
|
608
|
+
element_beg, element_end = test.space_restriction.node_element_range(test_arg, test_node_index)
|
|
648
609
|
|
|
649
610
|
trial_dof_index = DofIndex(trial_node, trial_dof)
|
|
650
611
|
|
|
651
|
-
for element in range(
|
|
652
|
-
test_element_index = test.space_restriction.node_element_index(test_arg,
|
|
612
|
+
for element in range(element_beg, element_end):
|
|
613
|
+
test_element_index = test.space_restriction.node_element_index(test_arg, element)
|
|
653
614
|
element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
|
|
654
|
-
|
|
615
|
+
|
|
616
|
+
element_trial_node_count = trial.space.topology.element_node_count(
|
|
617
|
+
domain_arg, trial_topology_arg, element_index
|
|
618
|
+
)
|
|
619
|
+
qp_point_count = wp.select(
|
|
620
|
+
trial_node < element_trial_node_count,
|
|
621
|
+
0,
|
|
622
|
+
quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
|
|
623
|
+
)
|
|
655
624
|
|
|
656
625
|
test_dof_index = DofIndex(
|
|
657
626
|
test_element_index.node_index_in_element,
|
|
@@ -661,10 +630,16 @@ def get_integrate_bilinear_kernel(
|
|
|
661
630
|
val_sum = accumulate_dtype(0.0)
|
|
662
631
|
|
|
663
632
|
for k in range(qp_point_count):
|
|
664
|
-
qp_index = quadrature.point_index(
|
|
665
|
-
|
|
633
|
+
qp_index = quadrature.point_index(
|
|
634
|
+
domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
|
|
635
|
+
)
|
|
636
|
+
coords = quadrature.point_coords(
|
|
637
|
+
domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
|
|
638
|
+
)
|
|
666
639
|
|
|
667
|
-
qp_weight = quadrature.point_weight(
|
|
640
|
+
qp_weight = quadrature.point_weight(
|
|
641
|
+
domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
|
|
642
|
+
)
|
|
668
643
|
vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))
|
|
669
644
|
|
|
670
645
|
sample = Sample(
|
|
@@ -678,15 +653,20 @@ def get_integrate_bilinear_kernel(
|
|
|
678
653
|
val = integrand_func(sample, fields, values)
|
|
679
654
|
val_sum += accumulate_dtype(qp_weight * vol * val)
|
|
680
655
|
|
|
681
|
-
block_offset =
|
|
656
|
+
block_offset = element * MAX_NODES_PER_ELEMENT + trial_node
|
|
682
657
|
triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)
|
|
683
658
|
|
|
684
659
|
# Set row and column indices
|
|
685
660
|
if test_dof == 0 and trial_dof == 0:
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
661
|
+
if trial_node < element_trial_node_count:
|
|
662
|
+
trial_node_index = trial.space_partition.partition_node_index(
|
|
663
|
+
trial_partition_arg,
|
|
664
|
+
trial.space.topology.element_node_index(
|
|
665
|
+
domain_arg, trial_topology_arg, element_index, trial_node
|
|
666
|
+
),
|
|
667
|
+
)
|
|
668
|
+
else:
|
|
669
|
+
trial_node_index = NULL_NODE_INDEX # will get ignored when converting to bsr
|
|
690
670
|
triplet_rows[block_offset] = test_node_index
|
|
691
671
|
triplet_cols[block_offset] = trial_node_index
|
|
692
672
|
|
|
@@ -706,6 +686,7 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
706
686
|
domain_arg: domain.ElementArg,
|
|
707
687
|
domain_index_arg: domain.ElementIndexArg,
|
|
708
688
|
test_restriction_arg: test.space_restriction.NodeArg,
|
|
689
|
+
test_topo_arg: test.space.topology.TopologyArg,
|
|
709
690
|
fields: FieldStruct,
|
|
710
691
|
values: ValueStruct,
|
|
711
692
|
triplet_rows: wp.array(dtype=int),
|
|
@@ -714,15 +695,20 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
714
695
|
):
|
|
715
696
|
local_node_index, test_dof, trial_dof = wp.tid()
|
|
716
697
|
|
|
717
|
-
|
|
718
|
-
|
|
698
|
+
partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
|
|
699
|
+
element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)
|
|
719
700
|
|
|
720
701
|
val_sum = accumulate_dtype(0.0)
|
|
721
702
|
|
|
722
|
-
for n in range(
|
|
723
|
-
node_element_index = test.space_restriction.node_element_index(test_restriction_arg,
|
|
703
|
+
for n in range(element_beg, element_end):
|
|
704
|
+
node_element_index = test.space_restriction.node_element_index(test_restriction_arg, n)
|
|
724
705
|
element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
|
|
725
706
|
|
|
707
|
+
if n == element_beg:
|
|
708
|
+
node_index = test.space.topology.element_node_index(
|
|
709
|
+
domain_arg, test_topo_arg, element_index, node_element_index.node_index_in_element
|
|
710
|
+
)
|
|
711
|
+
|
|
726
712
|
coords = test.space.node_coords_in_element(
|
|
727
713
|
domain_arg,
|
|
728
714
|
_get_test_arg(),
|
|
@@ -755,8 +741,8 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
755
741
|
val_sum += accumulate_dtype(node_weight * vol * val)
|
|
756
742
|
|
|
757
743
|
triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
|
|
758
|
-
triplet_rows[local_node_index] =
|
|
759
|
-
triplet_cols[local_node_index] =
|
|
744
|
+
triplet_rows[local_node_index] = partition_node_index
|
|
745
|
+
triplet_cols[local_node_index] = partition_node_index
|
|
760
746
|
|
|
761
747
|
return integrate_kernel_fn
|
|
762
748
|
|
|
@@ -786,7 +772,7 @@ def _generate_integrate_kernel(
|
|
|
786
772
|
)
|
|
787
773
|
|
|
788
774
|
FieldStruct = _gen_field_struct(field_args)
|
|
789
|
-
ValueStruct =
|
|
775
|
+
ValueStruct = cache.get_argument_struct(value_args)
|
|
790
776
|
|
|
791
777
|
# Check if kernel exist in cache
|
|
792
778
|
kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
|
|
@@ -923,7 +909,7 @@ def _launch_integrate_kernel(
|
|
|
923
909
|
for k, v in fields.items():
|
|
924
910
|
setattr(field_arg_values, k, v.eval_arg_value(device=device))
|
|
925
911
|
|
|
926
|
-
value_struct_values =
|
|
912
|
+
value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
|
|
927
913
|
|
|
928
914
|
# Constant form
|
|
929
915
|
if test is None and trial is None:
|
|
@@ -1030,6 +1016,7 @@ def _launch_integrate_kernel(
|
|
|
1030
1016
|
domain_elt_arg,
|
|
1031
1017
|
domain_elt_index_arg,
|
|
1032
1018
|
test_arg,
|
|
1019
|
+
test.space.topology.topo_arg_value(device),
|
|
1033
1020
|
field_arg_values,
|
|
1034
1021
|
value_struct_values,
|
|
1035
1022
|
output_view,
|
|
@@ -1069,7 +1056,7 @@ def _launch_integrate_kernel(
|
|
|
1069
1056
|
if nodal:
|
|
1070
1057
|
nnz = test.space_restriction.node_count()
|
|
1071
1058
|
else:
|
|
1072
|
-
nnz = test.space_restriction.total_node_element_count() * trial.space.topology.
|
|
1059
|
+
nnz = test.space_restriction.total_node_element_count() * trial.space.topology.MAX_NODES_PER_ELEMENT
|
|
1073
1060
|
|
|
1074
1061
|
triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
|
|
1075
1062
|
triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
|
|
@@ -1097,6 +1084,7 @@ def _launch_integrate_kernel(
|
|
|
1097
1084
|
domain_elt_arg,
|
|
1098
1085
|
domain_elt_index_arg,
|
|
1099
1086
|
test_arg,
|
|
1087
|
+
test.space.topology.topo_arg_value(device),
|
|
1100
1088
|
field_arg_values,
|
|
1101
1089
|
value_struct_values,
|
|
1102
1090
|
triplet_rows,
|
|
@@ -1107,15 +1095,13 @@ def _launch_integrate_kernel(
|
|
|
1107
1095
|
)
|
|
1108
1096
|
|
|
1109
1097
|
else:
|
|
1110
|
-
offsets = test.space_restriction.partition_element_offsets()
|
|
1111
|
-
|
|
1112
1098
|
trial_partition_arg = trial.space_partition.partition_arg_value(device)
|
|
1113
1099
|
trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
|
|
1114
1100
|
wp.launch(
|
|
1115
1101
|
kernel=kernel,
|
|
1116
1102
|
dim=(
|
|
1117
1103
|
test.space_restriction.node_count(),
|
|
1118
|
-
trial.space.topology.
|
|
1104
|
+
trial.space.topology.MAX_NODES_PER_ELEMENT,
|
|
1119
1105
|
test.space.VALUE_DOF_COUNT,
|
|
1120
1106
|
trial.space.VALUE_DOF_COUNT,
|
|
1121
1107
|
),
|
|
@@ -1128,7 +1114,6 @@ def _launch_integrate_kernel(
|
|
|
1128
1114
|
trial_topology_arg,
|
|
1129
1115
|
field_arg_values,
|
|
1130
1116
|
value_struct_values,
|
|
1131
|
-
offsets,
|
|
1132
1117
|
triplet_rows,
|
|
1133
1118
|
triplet_cols,
|
|
1134
1119
|
triplet_values,
|
|
@@ -1299,8 +1284,8 @@ def get_interpolate_to_field_function(
|
|
|
1299
1284
|
fields: FieldStruct,
|
|
1300
1285
|
values: ValueStruct,
|
|
1301
1286
|
):
|
|
1302
|
-
|
|
1303
|
-
|
|
1287
|
+
partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
|
|
1288
|
+
element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)
|
|
1304
1289
|
|
|
1305
1290
|
test_dof_index = NULL_DOF_INDEX
|
|
1306
1291
|
trial_dof_index = NULL_DOF_INDEX
|
|
@@ -1312,10 +1297,15 @@ def get_interpolate_to_field_function(
|
|
|
1312
1297
|
val_sum = value_type(0.0)
|
|
1313
1298
|
vol_sum = float(0.0)
|
|
1314
1299
|
|
|
1315
|
-
for n in range(
|
|
1316
|
-
node_element_index = dest.space_restriction.node_element_index(dest_node_arg,
|
|
1300
|
+
for n in range(element_beg, element_end):
|
|
1301
|
+
node_element_index = dest.space_restriction.node_element_index(dest_node_arg, n)
|
|
1317
1302
|
element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
|
|
1318
1303
|
|
|
1304
|
+
if n == element_beg:
|
|
1305
|
+
node_index = dest.space.topology.element_node_index(
|
|
1306
|
+
domain_arg, dest_eval_arg.topology_arg, element_index, node_element_index.node_index_in_element
|
|
1307
|
+
)
|
|
1308
|
+
|
|
1319
1309
|
coords = dest.space.node_coords_in_element(
|
|
1320
1310
|
domain_arg,
|
|
1321
1311
|
dest_eval_arg.space_arg,
|
|
@@ -1371,7 +1361,7 @@ def get_interpolate_to_field_kernel(
|
|
|
1371
1361
|
return interpolate_to_field_kernel_fn
|
|
1372
1362
|
|
|
1373
1363
|
|
|
1374
|
-
def
|
|
1364
|
+
def get_interpolate_at_quadrature_kernel(
|
|
1375
1365
|
integrand_func: wp.Function,
|
|
1376
1366
|
domain: GeometryDomain,
|
|
1377
1367
|
quadrature: Quadrature,
|
|
@@ -1379,61 +1369,100 @@ def get_interpolate_to_array_kernel(
|
|
|
1379
1369
|
ValueStruct: wp.codegen.Struct,
|
|
1380
1370
|
value_type: type,
|
|
1381
1371
|
):
|
|
1382
|
-
def
|
|
1372
|
+
def interpolate_at_quadrature_nonvalued_kernel_fn(
|
|
1383
1373
|
qp_arg: quadrature.Arg,
|
|
1384
1374
|
domain_arg: quadrature.domain.ElementArg,
|
|
1385
1375
|
domain_index_arg: quadrature.domain.ElementIndexArg,
|
|
1386
1376
|
fields: FieldStruct,
|
|
1387
1377
|
values: ValueStruct,
|
|
1388
|
-
result: wp.array(dtype=
|
|
1378
|
+
result: wp.array(dtype=float),
|
|
1389
1379
|
):
|
|
1390
|
-
|
|
1380
|
+
domain_element_index = wp.tid()
|
|
1381
|
+
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
1391
1382
|
|
|
1392
1383
|
test_dof_index = NULL_DOF_INDEX
|
|
1393
1384
|
trial_dof_index = NULL_DOF_INDEX
|
|
1394
1385
|
|
|
1395
|
-
qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
|
|
1386
|
+
qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
|
|
1396
1387
|
for k in range(qp_point_count):
|
|
1397
|
-
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
1398
|
-
coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
1399
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
1388
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1389
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1390
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1400
1391
|
|
|
1401
1392
|
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1393
|
+
integrand_func(sample, fields, values)
|
|
1394
|
+
|
|
1395
|
+
def interpolate_at_quadrature_kernel_fn(
|
|
1396
|
+
qp_arg: quadrature.Arg,
|
|
1397
|
+
domain_arg: quadrature.domain.ElementArg,
|
|
1398
|
+
domain_index_arg: quadrature.domain.ElementIndexArg,
|
|
1399
|
+
fields: FieldStruct,
|
|
1400
|
+
values: ValueStruct,
|
|
1401
|
+
result: wp.array(dtype=value_type),
|
|
1402
|
+
):
|
|
1403
|
+
domain_element_index = wp.tid()
|
|
1404
|
+
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
1405
|
+
|
|
1406
|
+
test_dof_index = NULL_DOF_INDEX
|
|
1407
|
+
trial_dof_index = NULL_DOF_INDEX
|
|
1402
1408
|
|
|
1409
|
+
qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
|
|
1410
|
+
for k in range(qp_point_count):
|
|
1411
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1412
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1413
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1414
|
+
|
|
1415
|
+
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1403
1416
|
result[qp_index] = integrand_func(sample, fields, values)
|
|
1404
1417
|
|
|
1405
|
-
return
|
|
1418
|
+
return interpolate_at_quadrature_nonvalued_kernel_fn if value_type is None else interpolate_at_quadrature_kernel_fn
|
|
1406
1419
|
|
|
1407
1420
|
|
|
1408
|
-
def
|
|
1421
|
+
def get_interpolate_free_kernel(
|
|
1409
1422
|
integrand_func: wp.Function,
|
|
1410
1423
|
domain: GeometryDomain,
|
|
1411
|
-
quadrature: Quadrature,
|
|
1412
1424
|
FieldStruct: wp.codegen.Struct,
|
|
1413
1425
|
ValueStruct: wp.codegen.Struct,
|
|
1426
|
+
value_type: type,
|
|
1414
1427
|
):
|
|
1415
|
-
def
|
|
1416
|
-
|
|
1417
|
-
domain_arg:
|
|
1418
|
-
domain_index_arg: quadrature.domain.ElementIndexArg,
|
|
1428
|
+
def interpolate_free_nonvalued_kernel_fn(
|
|
1429
|
+
dim: int,
|
|
1430
|
+
domain_arg: domain.ElementArg,
|
|
1419
1431
|
fields: FieldStruct,
|
|
1420
1432
|
values: ValueStruct,
|
|
1433
|
+
result: wp.array(dtype=float),
|
|
1421
1434
|
):
|
|
1422
|
-
|
|
1435
|
+
qp_index = wp.tid()
|
|
1436
|
+
qp_weight = 1.0 / float(dim)
|
|
1437
|
+
element_index = NULL_ELEMENT_INDEX
|
|
1438
|
+
coords = Coords(OUTSIDE)
|
|
1423
1439
|
|
|
1424
1440
|
test_dof_index = NULL_DOF_INDEX
|
|
1425
1441
|
trial_dof_index = NULL_DOF_INDEX
|
|
1426
1442
|
|
|
1427
|
-
|
|
1428
|
-
|
|
1429
|
-
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
1430
|
-
coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
1431
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
1443
|
+
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1444
|
+
integrand_func(sample, fields, values)
|
|
1432
1445
|
|
|
1433
|
-
|
|
1434
|
-
|
|
1446
|
+
def interpolate_free_kernel_fn(
|
|
1447
|
+
dim: int,
|
|
1448
|
+
domain_arg: domain.ElementArg,
|
|
1449
|
+
fields: FieldStruct,
|
|
1450
|
+
values: ValueStruct,
|
|
1451
|
+
result: wp.array(dtype=value_type),
|
|
1452
|
+
):
|
|
1453
|
+
qp_index = wp.tid()
|
|
1454
|
+
qp_weight = 1.0 / float(dim)
|
|
1455
|
+
element_index = NULL_ELEMENT_INDEX
|
|
1456
|
+
coords = Coords(OUTSIDE)
|
|
1457
|
+
|
|
1458
|
+
test_dof_index = NULL_DOF_INDEX
|
|
1459
|
+
trial_dof_index = NULL_DOF_INDEX
|
|
1460
|
+
|
|
1461
|
+
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1435
1462
|
|
|
1436
|
-
|
|
1463
|
+
result[qp_index] = integrand_func(sample, fields, values)
|
|
1464
|
+
|
|
1465
|
+
return interpolate_free_nonvalued_kernel_fn if value_type is None else interpolate_free_kernel_fn
|
|
1437
1466
|
|
|
1438
1467
|
|
|
1439
1468
|
def _generate_interpolate_kernel(
|
|
@@ -1461,17 +1490,20 @@ def _generate_interpolate_kernel(
|
|
|
1461
1490
|
_register_integrand_field_wrappers(integrand_func, fields)
|
|
1462
1491
|
|
|
1463
1492
|
FieldStruct = _gen_field_struct(field_args)
|
|
1464
|
-
ValueStruct =
|
|
1493
|
+
ValueStruct = cache.get_argument_struct(value_args)
|
|
1465
1494
|
|
|
1466
1495
|
# Check if kernel exist in cache
|
|
1467
1496
|
if isinstance(dest, FieldRestriction):
|
|
1468
1497
|
kernel_suffix = (
|
|
1469
1498
|
f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
|
|
1470
1499
|
)
|
|
1471
|
-
elif wp.types.is_array(dest):
|
|
1472
|
-
kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{wp.types.type_repr(dest.dtype)}"
|
|
1473
1500
|
else:
|
|
1474
|
-
|
|
1501
|
+
dest_dtype = dest.dtype if dest else None
|
|
1502
|
+
type_str = wp.types.get_type_code(dest_dtype) if dest_dtype else ""
|
|
1503
|
+
if quadrature is None:
|
|
1504
|
+
kernel_suffix = f"_itp_{FieldStruct.key}_{type_str}"
|
|
1505
|
+
else:
|
|
1506
|
+
kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{type_str}"
|
|
1475
1507
|
|
|
1476
1508
|
kernel = cache.get_integrand_kernel(
|
|
1477
1509
|
integrand=integrand,
|
|
@@ -1515,20 +1547,20 @@ def _generate_interpolate_kernel(
|
|
|
1515
1547
|
FieldStruct=FieldStruct,
|
|
1516
1548
|
ValueStruct=ValueStruct,
|
|
1517
1549
|
)
|
|
1518
|
-
elif
|
|
1519
|
-
interpolate_kernel_fn =
|
|
1550
|
+
elif quadrature is not None:
|
|
1551
|
+
interpolate_kernel_fn = get_interpolate_at_quadrature_kernel(
|
|
1520
1552
|
integrand_func,
|
|
1521
1553
|
domain=domain,
|
|
1522
1554
|
quadrature=quadrature,
|
|
1523
|
-
value_type=
|
|
1555
|
+
value_type=dest_dtype,
|
|
1524
1556
|
FieldStruct=FieldStruct,
|
|
1525
1557
|
ValueStruct=ValueStruct,
|
|
1526
1558
|
)
|
|
1527
1559
|
else:
|
|
1528
|
-
interpolate_kernel_fn =
|
|
1560
|
+
interpolate_kernel_fn = get_interpolate_free_kernel(
|
|
1529
1561
|
integrand_func,
|
|
1530
1562
|
domain=domain,
|
|
1531
|
-
|
|
1563
|
+
value_type=dest_dtype,
|
|
1532
1564
|
FieldStruct=FieldStruct,
|
|
1533
1565
|
ValueStruct=ValueStruct,
|
|
1534
1566
|
)
|
|
@@ -1560,6 +1592,7 @@ def _launch_interpolate_kernel(
|
|
|
1560
1592
|
domain: GeometryDomain,
|
|
1561
1593
|
dest: Optional[Union[FieldRestriction, wp.array]],
|
|
1562
1594
|
quadrature: Optional[Quadrature],
|
|
1595
|
+
dim: int,
|
|
1563
1596
|
fields: Dict[str, FieldLike],
|
|
1564
1597
|
values: Dict[str, Any],
|
|
1565
1598
|
device,
|
|
@@ -1572,7 +1605,7 @@ def _launch_interpolate_kernel(
|
|
|
1572
1605
|
for k, v in fields.items():
|
|
1573
1606
|
setattr(field_arg_values, k, v.eval_arg_value(device=device))
|
|
1574
1607
|
|
|
1575
|
-
value_struct_values =
|
|
1608
|
+
value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
|
|
1576
1609
|
|
|
1577
1610
|
if isinstance(dest, FieldRestriction):
|
|
1578
1611
|
dest_node_arg = dest.space_restriction.node_arg(device=device)
|
|
@@ -1591,7 +1624,7 @@ def _launch_interpolate_kernel(
|
|
|
1591
1624
|
],
|
|
1592
1625
|
device=device,
|
|
1593
1626
|
)
|
|
1594
|
-
elif
|
|
1627
|
+
elif quadrature is not None:
|
|
1595
1628
|
qp_arg = quadrature.arg_value(device)
|
|
1596
1629
|
wp.launch(
|
|
1597
1630
|
kernel=kernel,
|
|
@@ -1600,19 +1633,25 @@ def _launch_interpolate_kernel(
|
|
|
1600
1633
|
device=device,
|
|
1601
1634
|
)
|
|
1602
1635
|
else:
|
|
1603
|
-
qp_arg = quadrature.arg_value(device)
|
|
1604
1636
|
wp.launch(
|
|
1605
1637
|
kernel=kernel,
|
|
1606
|
-
dim=
|
|
1607
|
-
inputs=[
|
|
1638
|
+
dim=dim,
|
|
1639
|
+
inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
|
|
1608
1640
|
device=device,
|
|
1609
1641
|
)
|
|
1610
1642
|
|
|
1611
1643
|
|
|
1644
|
+
@integrand
|
|
1645
|
+
def _identity_field(field: Field, s: Sample):
|
|
1646
|
+
return field(s)
|
|
1647
|
+
|
|
1648
|
+
|
|
1612
1649
|
def interpolate(
|
|
1613
|
-
integrand: Integrand,
|
|
1650
|
+
integrand: Union[Integrand, FieldLike],
|
|
1614
1651
|
dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
|
|
1615
1652
|
quadrature: Optional[Quadrature] = None,
|
|
1653
|
+
dim: int = 0,
|
|
1654
|
+
domain: Optional[Domain] = None,
|
|
1616
1655
|
fields: Optional[Dict[str, FieldLike]] = None,
|
|
1617
1656
|
values: Optional[Dict[str, Any]] = None,
|
|
1618
1657
|
device=None,
|
|
@@ -1622,18 +1661,26 @@ def interpolate(
|
|
|
1622
1661
|
Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.
|
|
1623
1662
|
|
|
1624
1663
|
Args:
|
|
1625
|
-
integrand: Function to be interpolated
|
|
1664
|
+
integrand: Function to be interpolated: either a function with :func:`warp.fem.integrand` decorator or a field
|
|
1626
1665
|
dest: Where to store the interpolation result. Can be either
|
|
1627
1666
|
|
|
1628
1667
|
- a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
|
|
1629
|
-
- a normal warp array
|
|
1630
|
-
- ``None``. In this case, the `quadrature` argument must also be provided and the `integrand` function is responsible for dealing with the interpolation result.
|
|
1668
|
+
- a normal warp ``array``, or ``None``. In this case, the interpolation samples will determined by the `quadrature` or `dim` arguments, in that order.
|
|
1631
1669
|
quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
|
|
1670
|
+
dim: Number of interpolation samples if `dest` is not a discrete field or restriction and `quadrature` is ``None``.
|
|
1671
|
+
In this case, the ``Sample`` passed to the `integrand` will be invalid, but the sample point index ``s.qp_index`` can be used to define custom interpolation logic.
|
|
1672
|
+
domain: Interpolation domain, only used if `dest` is not a field restriction and `quadrature` is ``None``
|
|
1632
1673
|
fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameters names.
|
|
1633
1674
|
values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
|
|
1634
1675
|
device: Device on which to perform the interpolation
|
|
1635
1676
|
kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
|
|
1636
1677
|
"""
|
|
1678
|
+
|
|
1679
|
+
if isinstance(integrand, FieldLike):
|
|
1680
|
+
fields = {"field": integrand}
|
|
1681
|
+
values = {}
|
|
1682
|
+
integrand = _identity_field
|
|
1683
|
+
|
|
1637
1684
|
if fields is None:
|
|
1638
1685
|
fields = {}
|
|
1639
1686
|
|
|
@@ -1651,14 +1698,11 @@ def interpolate(
|
|
|
1651
1698
|
raise ValueError("Test or Trial fields should not be used for interpolation")
|
|
1652
1699
|
|
|
1653
1700
|
if isinstance(dest, DiscreteField):
|
|
1654
|
-
dest = make_restriction(dest)
|
|
1701
|
+
dest = make_restriction(dest, domain=domain)
|
|
1655
1702
|
|
|
1656
1703
|
if isinstance(dest, FieldRestriction):
|
|
1657
1704
|
domain = dest.domain
|
|
1658
|
-
|
|
1659
|
-
if quadrature is None:
|
|
1660
|
-
raise ValueError("When not interpolating to a field, a quadrature formula must be provided")
|
|
1661
|
-
|
|
1705
|
+
elif quadrature is not None:
|
|
1662
1706
|
domain = quadrature.domain
|
|
1663
1707
|
|
|
1664
1708
|
kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
|
|
@@ -1678,6 +1722,7 @@ def interpolate(
|
|
|
1678
1722
|
domain=domain,
|
|
1679
1723
|
dest=dest,
|
|
1680
1724
|
quadrature=quadrature,
|
|
1725
|
+
dim=dim,
|
|
1681
1726
|
fields=fields,
|
|
1682
1727
|
values=values,
|
|
1683
1728
|
device=device,
|