warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- docs/conf.py +17 -5
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/env/env_usd.py +4 -1
- examples/env/environment.py +8 -9
- examples/example_dem.py +34 -33
- examples/example_diffray.py +364 -337
- examples/example_fluid.py +32 -23
- examples/example_jacobian_ik.py +97 -93
- examples/example_marching_cubes.py +6 -16
- examples/example_mesh.py +6 -16
- examples/example_mesh_intersect.py +16 -14
- examples/example_nvdb.py +14 -16
- examples/example_raycast.py +14 -13
- examples/example_raymarch.py +16 -23
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +82 -78
- examples/example_sim_cloth.py +45 -48
- examples/example_sim_fk_grad.py +51 -44
- examples/example_sim_fk_grad_torch.py +47 -40
- examples/example_sim_grad_bounce.py +108 -133
- examples/example_sim_grad_cloth.py +99 -113
- examples/example_sim_granular.py +5 -6
- examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
- examples/example_sim_neo_hookean.py +51 -55
- examples/example_sim_particle_chain.py +4 -4
- examples/example_sim_quadruped.py +126 -81
- examples/example_sim_rigid_chain.py +54 -61
- examples/example_sim_rigid_contact.py +66 -70
- examples/example_sim_rigid_fem.py +3 -3
- examples/example_sim_rigid_force.py +1 -1
- examples/example_sim_rigid_gyroscopic.py +3 -4
- examples/example_sim_rigid_kinematics.py +28 -39
- examples/example_sim_trajopt.py +112 -110
- examples/example_sph.py +9 -8
- examples/example_wave.py +7 -7
- examples/fem/bsr_utils.py +30 -17
- examples/fem/example_apic_fluid.py +85 -69
- examples/fem/example_convection_diffusion.py +97 -93
- examples/fem/example_convection_diffusion_dg.py +142 -149
- examples/fem/example_convection_diffusion_dg0.py +141 -136
- examples/fem/example_deformed_geometry.py +146 -0
- examples/fem/example_diffusion.py +115 -84
- examples/fem/example_diffusion_3d.py +116 -86
- examples/fem/example_diffusion_mgpu.py +102 -79
- examples/fem/example_mixed_elasticity.py +139 -100
- examples/fem/example_navier_stokes.py +175 -162
- examples/fem/example_stokes.py +143 -111
- examples/fem/example_stokes_transfer.py +186 -157
- examples/fem/mesh_utils.py +59 -97
- examples/fem/plot_utils.py +138 -17
- tools/ci/publishing/build_nodes_info.py +54 -0
- warp/__init__.py +4 -3
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +836 -492
- warp/codegen.py +864 -553
- warp/config.py +3 -1
- warp/context.py +389 -172
- warp/fem/__init__.py +24 -6
- warp/fem/cache.py +318 -25
- warp/fem/dirichlet.py +7 -3
- warp/fem/domain.py +14 -0
- warp/fem/field/__init__.py +30 -38
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +244 -138
- warp/fem/field/restriction.py +8 -6
- warp/fem/field/test.py +127 -59
- warp/fem/field/trial.py +117 -60
- warp/fem/geometry/__init__.py +5 -1
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +24 -1
- warp/fem/geometry/geometry.py +86 -14
- warp/fem/geometry/grid_2d.py +112 -54
- warp/fem/geometry/grid_3d.py +134 -65
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +85 -33
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +451 -115
- warp/fem/geometry/trimesh_2d.py +197 -92
- warp/fem/integrate.py +534 -268
- warp/fem/operator.py +58 -31
- warp/fem/polynomial.py +11 -0
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +150 -58
- warp/fem/quadrature/quadrature.py +209 -57
- warp/fem/space/__init__.py +230 -53
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +49 -2
- warp/fem/space/function_space.py +90 -39
- warp/fem/space/grid_2d_function_space.py +149 -496
- warp/fem/space/grid_3d_function_space.py +173 -538
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +129 -76
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +46 -34
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +132 -1039
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +104 -742
- warp/fem/types.py +13 -11
- warp/fem/utils.py +335 -60
- warp/native/array.h +120 -34
- warp/native/builtin.h +101 -72
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +22 -40
- warp/native/clang/clang.cpp +1 -0
- warp/native/crt.h +2 -0
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1522 -1243
- warp/native/intersect.h +19 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +76 -17
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -18
- warp/native/mesh.h +395 -40
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +44 -34
- warp/native/reduce.cpp +1 -1
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +163 -155
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +18 -14
- warp/native/vec.h +103 -21
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +28 -3
- warp/native/warp.h +4 -3
- warp/render/render_opengl.py +261 -109
- warp/sim/__init__.py +1 -2
- warp/sim/articulation.py +385 -185
- warp/sim/import_mjcf.py +59 -48
- warp/sim/import_urdf.py +15 -15
- warp/sim/import_usd.py +174 -102
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_xpbd.py +4 -3
- warp/sim/model.py +330 -250
- warp/sim/render.py +1 -1
- warp/sparse.py +625 -152
- warp/stubs.py +341 -309
- warp/tape.py +9 -6
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +94 -74
- warp/tests/test_array.py +82 -101
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +22 -12
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +18 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +165 -134
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +237 -0
- warp/tests/test_fabricarray.py +22 -24
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1034 -124
- warp/tests/test_fp16.py +23 -16
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +123 -181
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +35 -34
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +24 -25
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +304 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +60 -22
- warp/tests/test_mesh_query_aabb.py +21 -25
- warp/tests/test_mesh_query_point.py +111 -22
- warp/tests/test_mesh_query_ray.py +12 -24
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +90 -86
- warp/tests/test_transient_module.py +10 -12
- warp/tests/test_types.py +363 -0
- warp/tests/test_utils.py +451 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +418 -376
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/unittest_utils.py +342 -0
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +589 -0
- warp/types.py +622 -211
- warp/utils.py +54 -393
- warp_lang-1.0.0b6.dist-info/METADATA +238 -0
- warp_lang-1.0.0b6.dist-info/RECORD +409 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- examples/example_cache_management.py +0 -40
- examples/example_multigpu.py +0 -54
- examples/example_struct.py +0 -65
- examples/fem/example_stokes_transfer_3d.py +0 -210
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/fem/field/discrete_field.py +0 -80
- warp/fem/space/nodal_function_space.py +0 -233
- warp/tests/test_all.py +0 -223
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-1.0.0b2.dist-info/METADATA +0 -26
- warp_lang-1.0.0b2.dist-info/RECORD +0 -380
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py
CHANGED
|
@@ -5,13 +5,12 @@ import warp as wp
|
|
|
5
5
|
import re
|
|
6
6
|
import ast
|
|
7
7
|
|
|
8
|
-
from warp.sparse import BsrMatrix, bsr_zeros, bsr_set_from_triplets, bsr_copy,
|
|
8
|
+
from warp.sparse import BsrMatrix, bsr_zeros, bsr_set_from_triplets, bsr_copy, bsr_assign
|
|
9
9
|
from warp.types import type_length
|
|
10
10
|
from warp.utils import array_cast
|
|
11
11
|
from warp.codegen import get_annotations
|
|
12
12
|
|
|
13
13
|
from warp.fem.domain import GeometryDomain
|
|
14
|
-
from warp.fem.space import SpaceRestriction
|
|
15
14
|
from warp.fem.field import (
|
|
16
15
|
TestField,
|
|
17
16
|
TrialField,
|
|
@@ -23,7 +22,7 @@ from warp.fem.field import (
|
|
|
23
22
|
from warp.fem.quadrature import Quadrature, RegularQuadrature
|
|
24
23
|
from warp.fem.operator import Operator, Integrand
|
|
25
24
|
from warp.fem import cache
|
|
26
|
-
from warp.fem.types import Domain, Field, Sample, DofIndex, NULL_DOF_INDEX, OUTSIDE
|
|
25
|
+
from warp.fem.types import Domain, Field, Sample, DofIndex, NULL_DOF_INDEX, OUTSIDE, make_free_sample
|
|
27
26
|
|
|
28
27
|
|
|
29
28
|
def _resolve_path(func, node):
|
|
@@ -98,7 +97,7 @@ class IntegrandTransformer(ast.NodeTransformer):
|
|
|
98
97
|
operator = arg_type.call_operator
|
|
99
98
|
|
|
100
99
|
call.func = ast.Attribute(
|
|
101
|
-
value=_path_to_ast_attribute(arg_type.__qualname__),
|
|
100
|
+
value=_path_to_ast_attribute(f"{arg_type.__module__}.{arg_type.__qualname__}"),
|
|
102
101
|
attr="call_operator",
|
|
103
102
|
ctx=ast.Load(),
|
|
104
103
|
)
|
|
@@ -164,7 +163,7 @@ def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike])
|
|
|
164
163
|
for arg in argspec.args:
|
|
165
164
|
arg_type = argspec.annotations[arg]
|
|
166
165
|
if arg_type == Field:
|
|
167
|
-
annotations[arg] = field_args[arg].
|
|
166
|
+
annotations[arg] = field_args[arg].ElementEvalArg
|
|
168
167
|
elif arg_type == Domain:
|
|
169
168
|
annotations[arg] = field_args[arg].ElementArg
|
|
170
169
|
else:
|
|
@@ -174,11 +173,9 @@ def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike])
|
|
|
174
173
|
transformer = IntegrandTransformer(integrand, field_args)
|
|
175
174
|
|
|
176
175
|
def is_field_like(f):
|
|
177
|
-
|
|
178
|
-
return any(isinstance(f, field_class) for field_class in FieldLike.__args__)
|
|
176
|
+
return isinstance(f, FieldLike)
|
|
179
177
|
|
|
180
178
|
suffix = "_".join([f.name for f in field_args.values() if is_field_like(f)])
|
|
181
|
-
key = integrand.name + suffix
|
|
182
179
|
|
|
183
180
|
func = cache.get_integrand_function(
|
|
184
181
|
integrand=integrand,
|
|
@@ -265,18 +262,14 @@ def _gen_field_struct(field_args: Dict[str, FieldLike]):
|
|
|
265
262
|
setattr(Fields, name, arg.EvalArg())
|
|
266
263
|
annotations[name] = arg.EvalArg
|
|
267
264
|
|
|
268
|
-
Fields.__qualname__ = (
|
|
269
|
-
Fields.__name__
|
|
270
|
-
+ "_"
|
|
271
|
-
+ "_".join([f"{name}_{arg_struct.cls.__qualname__}" for name, arg_struct in annotations.items()])
|
|
272
|
-
)
|
|
273
|
-
|
|
274
265
|
try:
|
|
275
266
|
Fields.__annotations__ = annotations
|
|
276
267
|
except AttributeError:
|
|
277
268
|
setattr(Fields.__dict__, "__annotations__", annotations)
|
|
278
269
|
|
|
279
|
-
|
|
270
|
+
suffix = "_".join([f"{name}_{arg_struct.cls.__qualname__}" for name, arg_struct in annotations.items()])
|
|
271
|
+
|
|
272
|
+
return cache.get_struct(Fields, suffix=suffix)
|
|
280
273
|
|
|
281
274
|
|
|
282
275
|
def _gen_value_struct(value_args: Dict[str, type]):
|
|
@@ -299,25 +292,34 @@ def _gen_value_struct(value_args: Dict[str, type]):
|
|
|
299
292
|
return arg_type_name(arg_type.cls)
|
|
300
293
|
return getattr(arg_type, "__name__", str(arg_type))
|
|
301
294
|
|
|
302
|
-
Values.__qualname__ = (
|
|
303
|
-
Values.__name__
|
|
304
|
-
+ "_"
|
|
305
|
-
+ "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])
|
|
306
|
-
)
|
|
307
|
-
|
|
308
295
|
try:
|
|
309
296
|
Values.__annotations__ = annotations
|
|
310
297
|
except AttributeError:
|
|
311
298
|
setattr(Values.__dict__, "__annotations__", annotations)
|
|
312
299
|
|
|
313
|
-
|
|
300
|
+
suffix = "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])
|
|
301
|
+
|
|
302
|
+
return cache.get_struct(Values, suffix=suffix)
|
|
314
303
|
|
|
315
304
|
|
|
316
305
|
def _get_trial_arg():
|
|
317
306
|
pass
|
|
318
307
|
|
|
308
|
+
|
|
319
309
|
def _get_test_arg():
|
|
320
310
|
pass
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
class _FieldWrappers:
|
|
314
|
+
pass
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def _register_integrand_field_wrappers(integrand_func: wp.Function, fields: Dict[str, FieldLike]):
|
|
318
|
+
integrand_func._field_wrappers = _FieldWrappers()
|
|
319
|
+
for name, field in fields.items():
|
|
320
|
+
setattr(integrand_func._field_wrappers, name, field.ElementEvalArg)
|
|
321
|
+
|
|
322
|
+
|
|
321
323
|
class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
322
324
|
def __init__(
|
|
323
325
|
self,
|
|
@@ -333,6 +335,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
333
335
|
values_var_name: str = "values",
|
|
334
336
|
domain_var_name: str = "domain_arg",
|
|
335
337
|
sample_var_name: str = "sample",
|
|
338
|
+
field_wrappers_attr: str = "_field_wrappers",
|
|
336
339
|
):
|
|
337
340
|
self._arg_names = arg_names
|
|
338
341
|
self._field_args = field_args
|
|
@@ -346,6 +349,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
346
349
|
self._values_var_name = values_var_name
|
|
347
350
|
self._domain_var_name = domain_var_name
|
|
348
351
|
self._sample_var_name = sample_var_name
|
|
352
|
+
self._field_wrappers_attr = field_wrappers_attr
|
|
349
353
|
|
|
350
354
|
def visit_Call(self, call: ast.Call):
|
|
351
355
|
call = self.generic_visit(call)
|
|
@@ -366,10 +370,25 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
366
370
|
)
|
|
367
371
|
elif arg in self._field_args:
|
|
368
372
|
call.args.append(
|
|
369
|
-
ast.
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
+
ast.Call(
|
|
374
|
+
func=ast.Attribute(
|
|
375
|
+
value=ast.Attribute(
|
|
376
|
+
value=ast.Name(id=self._func_name, ctx=ast.Load()),
|
|
377
|
+
attr=self._field_wrappers_attr,
|
|
378
|
+
ctx=ast.Load(),
|
|
379
|
+
),
|
|
380
|
+
attr=arg,
|
|
381
|
+
ctx=ast.Load(),
|
|
382
|
+
),
|
|
383
|
+
args=[
|
|
384
|
+
ast.Name(id=self._domain_var_name, ctx=ast.Load()),
|
|
385
|
+
ast.Attribute(
|
|
386
|
+
value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
|
|
387
|
+
attr=arg,
|
|
388
|
+
ctx=ast.Load(),
|
|
389
|
+
),
|
|
390
|
+
],
|
|
391
|
+
keywords=[],
|
|
373
392
|
)
|
|
374
393
|
)
|
|
375
394
|
elif arg in self._value_args:
|
|
@@ -401,36 +420,6 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
401
420
|
return call
|
|
402
421
|
|
|
403
422
|
|
|
404
|
-
def get_integrate_null_kernel(
|
|
405
|
-
integrand_func: wp.Function,
|
|
406
|
-
domain: GeometryDomain,
|
|
407
|
-
quadrature: Quadrature,
|
|
408
|
-
FieldStruct: wp.codegen.Struct,
|
|
409
|
-
ValueStruct: wp.codegen.Struct,
|
|
410
|
-
):
|
|
411
|
-
def integrate_kernel_fn(
|
|
412
|
-
qp_arg: quadrature.Arg,
|
|
413
|
-
domain_arg: domain.ElementArg,
|
|
414
|
-
domain_index_arg: domain.ElementIndexArg,
|
|
415
|
-
fields: FieldStruct,
|
|
416
|
-
values: ValueStruct,
|
|
417
|
-
):
|
|
418
|
-
element_index = domain.element_index(domain_index_arg, wp.tid())
|
|
419
|
-
|
|
420
|
-
test_dof_index = NULL_DOF_INDEX
|
|
421
|
-
trial_dof_index = NULL_DOF_INDEX
|
|
422
|
-
|
|
423
|
-
qp_point_count = quadrature.point_count(qp_arg, element_index)
|
|
424
|
-
for k in range(qp_point_count):
|
|
425
|
-
qp_index = quadrature.point_index(qp_arg, element_index, k)
|
|
426
|
-
qp_coords = quadrature.point_coords(qp_arg, element_index, k)
|
|
427
|
-
qp_weight = quadrature.point_weight(qp_arg, element_index, k)
|
|
428
|
-
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
429
|
-
integrand_func(sample, fields, values)
|
|
430
|
-
|
|
431
|
-
return integrate_kernel_fn
|
|
432
|
-
|
|
433
|
-
|
|
434
423
|
def get_integrate_constant_kernel(
|
|
435
424
|
integrand_func: wp.Function,
|
|
436
425
|
domain: GeometryDomain,
|
|
@@ -453,14 +442,15 @@ def get_integrate_constant_kernel(
|
|
|
453
442
|
test_dof_index = NULL_DOF_INDEX
|
|
454
443
|
trial_dof_index = NULL_DOF_INDEX
|
|
455
444
|
|
|
456
|
-
qp_point_count = quadrature.point_count(qp_arg, element_index)
|
|
445
|
+
qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
|
|
457
446
|
for k in range(qp_point_count):
|
|
458
|
-
qp_index = quadrature.point_index(qp_arg, element_index, k)
|
|
459
|
-
coords = quadrature.point_coords(qp_arg, element_index, k)
|
|
460
|
-
qp_weight = quadrature.point_weight(qp_arg, element_index, k)
|
|
461
|
-
vol = domain.element_measure(domain_arg, element_index, coords)
|
|
447
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
448
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
449
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
462
450
|
|
|
463
451
|
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
452
|
+
vol = domain.element_measure(domain_arg, sample)
|
|
453
|
+
|
|
464
454
|
val = integrand_func(sample, fields, values)
|
|
465
455
|
|
|
466
456
|
elem_sum += accumulate_dtype(qp_weight * vol * val)
|
|
@@ -476,42 +466,47 @@ def get_integrate_linear_kernel(
|
|
|
476
466
|
quadrature: Quadrature,
|
|
477
467
|
FieldStruct: wp.codegen.Struct,
|
|
478
468
|
ValueStruct: wp.codegen.Struct,
|
|
479
|
-
|
|
469
|
+
test: TestField,
|
|
470
|
+
output_dtype,
|
|
480
471
|
accumulate_dtype,
|
|
481
472
|
):
|
|
482
473
|
def integrate_kernel_fn(
|
|
483
474
|
qp_arg: quadrature.Arg,
|
|
484
475
|
domain_arg: domain.ElementArg,
|
|
485
476
|
domain_index_arg: domain.ElementIndexArg,
|
|
486
|
-
test_arg:
|
|
477
|
+
test_arg: test.space_restriction.NodeArg,
|
|
487
478
|
fields: FieldStruct,
|
|
488
479
|
values: ValueStruct,
|
|
489
|
-
result: wp.array2d(dtype=
|
|
480
|
+
result: wp.array2d(dtype=output_dtype),
|
|
490
481
|
):
|
|
491
|
-
local_node_index = wp.tid()
|
|
492
|
-
node_index =
|
|
493
|
-
element_count =
|
|
482
|
+
local_node_index, test_dof = wp.tid()
|
|
483
|
+
node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
|
|
484
|
+
element_count = test.space_restriction.node_element_count(test_arg, local_node_index)
|
|
494
485
|
|
|
495
486
|
trial_dof_index = NULL_DOF_INDEX
|
|
496
487
|
|
|
488
|
+
val_sum = accumulate_dtype(0.0)
|
|
489
|
+
|
|
497
490
|
for n in range(element_count):
|
|
498
|
-
node_element_index =
|
|
491
|
+
node_element_index = test.space_restriction.node_element_index(test_arg, local_node_index, n)
|
|
499
492
|
element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
|
|
500
493
|
|
|
501
|
-
|
|
494
|
+
test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
|
|
495
|
+
|
|
496
|
+
qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
|
|
502
497
|
for k in range(qp_point_count):
|
|
503
|
-
qp_index = quadrature.point_index(qp_arg, element_index, k)
|
|
504
|
-
|
|
498
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
499
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
500
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
505
501
|
|
|
506
|
-
|
|
507
|
-
vol = domain.element_measure(domain_arg, element_index, coords)
|
|
502
|
+
vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
|
|
508
503
|
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
504
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
505
|
+
val = integrand_func(sample, fields, values)
|
|
506
|
+
|
|
507
|
+
val_sum += accumulate_dtype(qp_weight * vol * val)
|
|
513
508
|
|
|
514
|
-
|
|
509
|
+
result[node_index, test_dof] = output_dtype(val_sum)
|
|
515
510
|
|
|
516
511
|
return integrate_kernel_fn
|
|
517
512
|
|
|
@@ -522,6 +517,7 @@ def get_integrate_linear_nodal_kernel(
|
|
|
522
517
|
FieldStruct: wp.codegen.Struct,
|
|
523
518
|
ValueStruct: wp.codegen.Struct,
|
|
524
519
|
test: TestField,
|
|
520
|
+
output_dtype,
|
|
525
521
|
accumulate_dtype,
|
|
526
522
|
):
|
|
527
523
|
def integrate_kernel_fn(
|
|
@@ -530,7 +526,7 @@ def get_integrate_linear_nodal_kernel(
|
|
|
530
526
|
test_restriction_arg: test.space_restriction.NodeArg,
|
|
531
527
|
fields: FieldStruct,
|
|
532
528
|
values: ValueStruct,
|
|
533
|
-
result: wp.array2d(dtype=
|
|
529
|
+
result: wp.array2d(dtype=output_dtype),
|
|
534
530
|
):
|
|
535
531
|
local_node_index, dof = wp.tid()
|
|
536
532
|
|
|
@@ -546,6 +542,7 @@ def get_integrate_linear_nodal_kernel(
|
|
|
546
542
|
element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
|
|
547
543
|
|
|
548
544
|
coords = test.space.node_coords_in_element(
|
|
545
|
+
domain_arg,
|
|
549
546
|
_get_test_arg(),
|
|
550
547
|
element_index,
|
|
551
548
|
node_element_index.node_index_in_element,
|
|
@@ -553,12 +550,12 @@ def get_integrate_linear_nodal_kernel(
|
|
|
553
550
|
|
|
554
551
|
if coords[0] != OUTSIDE:
|
|
555
552
|
node_weight = test.space.node_quadrature_weight(
|
|
553
|
+
domain_arg,
|
|
556
554
|
_get_test_arg(),
|
|
557
555
|
element_index,
|
|
558
556
|
node_element_index.node_index_in_element,
|
|
559
557
|
)
|
|
560
558
|
|
|
561
|
-
vol = domain.element_measure(domain_arg, element_index, coords)
|
|
562
559
|
test_dof_index = DofIndex(node_element_index.node_index_in_element, dof)
|
|
563
560
|
|
|
564
561
|
sample = Sample(
|
|
@@ -569,11 +566,12 @@ def get_integrate_linear_nodal_kernel(
|
|
|
569
566
|
test_dof_index,
|
|
570
567
|
trial_dof_index,
|
|
571
568
|
)
|
|
569
|
+
vol = domain.element_measure(domain_arg, sample)
|
|
572
570
|
val = integrand_func(sample, fields, values)
|
|
573
571
|
|
|
574
572
|
val_sum += accumulate_dtype(node_weight * vol * val)
|
|
575
573
|
|
|
576
|
-
result[node_index, dof] = val_sum
|
|
574
|
+
result[node_index, dof] = output_dtype(val_sum)
|
|
577
575
|
|
|
578
576
|
return integrate_kernel_fn
|
|
579
577
|
|
|
@@ -584,80 +582,75 @@ def get_integrate_bilinear_kernel(
|
|
|
584
582
|
quadrature: Quadrature,
|
|
585
583
|
FieldStruct: wp.codegen.Struct,
|
|
586
584
|
ValueStruct: wp.codegen.Struct,
|
|
587
|
-
|
|
585
|
+
test: TestField,
|
|
588
586
|
trial: TrialField,
|
|
587
|
+
output_dtype,
|
|
589
588
|
accumulate_dtype,
|
|
590
589
|
):
|
|
591
|
-
NODES_PER_ELEMENT = trial.space.NODES_PER_ELEMENT
|
|
590
|
+
NODES_PER_ELEMENT = trial.space.topology.NODES_PER_ELEMENT
|
|
592
591
|
|
|
593
592
|
def integrate_kernel_fn(
|
|
594
593
|
qp_arg: quadrature.Arg,
|
|
595
594
|
domain_arg: domain.ElementArg,
|
|
596
595
|
domain_index_arg: domain.ElementIndexArg,
|
|
597
|
-
test_arg:
|
|
596
|
+
test_arg: test.space_restriction.NodeArg,
|
|
598
597
|
trial_partition_arg: trial.space_partition.PartitionArg,
|
|
598
|
+
trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
|
|
599
599
|
fields: FieldStruct,
|
|
600
600
|
values: ValueStruct,
|
|
601
601
|
row_offsets: wp.array(dtype=int),
|
|
602
602
|
triplet_rows: wp.array(dtype=int),
|
|
603
603
|
triplet_cols: wp.array(dtype=int),
|
|
604
|
-
triplet_values: wp.array3d(dtype=
|
|
604
|
+
triplet_values: wp.array3d(dtype=output_dtype),
|
|
605
605
|
):
|
|
606
|
-
test_local_node_index = wp.tid()
|
|
606
|
+
test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()
|
|
607
|
+
|
|
608
|
+
element_count = test.space_restriction.node_element_count(test_arg, test_local_node_index)
|
|
609
|
+
test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
|
|
607
610
|
|
|
608
|
-
|
|
609
|
-
test_node_index = test_space.node_partition_index(test_arg, test_local_node_index)
|
|
611
|
+
trial_dof_index = DofIndex(trial_node, trial_dof)
|
|
610
612
|
|
|
611
613
|
for element in range(element_count):
|
|
612
|
-
test_element_index =
|
|
614
|
+
test_element_index = test.space_restriction.node_element_index(test_arg, test_local_node_index, element)
|
|
613
615
|
element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
|
|
614
|
-
qp_point_count = quadrature.point_count(qp_arg, element_index)
|
|
616
|
+
qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
|
|
615
617
|
|
|
616
|
-
|
|
618
|
+
test_dof_index = DofIndex(
|
|
619
|
+
test_element_index.node_index_in_element,
|
|
620
|
+
test_dof,
|
|
621
|
+
)
|
|
622
|
+
|
|
623
|
+
val_sum = accumulate_dtype(0.0)
|
|
617
624
|
|
|
618
625
|
for k in range(qp_point_count):
|
|
619
|
-
qp_index = quadrature.point_index(qp_arg, element_index, k)
|
|
620
|
-
coords = quadrature.point_coords(qp_arg, element_index, k)
|
|
621
|
-
|
|
622
|
-
qp_weight = quadrature.point_weight(qp_arg, element_index, k)
|
|
623
|
-
vol = domain.element_measure(domain_arg, element_index, coords)
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
trial_dof_index,
|
|
642
|
-
)
|
|
643
|
-
val = integrand_func(sample, fields, values)
|
|
644
|
-
triplet_values[offset_cur, i, j] = triplet_values[offset_cur, i, j] + accumulate_dtype(
|
|
645
|
-
qp_weight * vol * val
|
|
646
|
-
)
|
|
647
|
-
|
|
648
|
-
offset_cur += 1
|
|
649
|
-
|
|
650
|
-
# Set column indices
|
|
651
|
-
offset_cur = start_offset
|
|
652
|
-
for trial_n in range(NODES_PER_ELEMENT):
|
|
626
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
627
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
628
|
+
|
|
629
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
630
|
+
vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))
|
|
631
|
+
|
|
632
|
+
sample = Sample(
|
|
633
|
+
element_index,
|
|
634
|
+
coords,
|
|
635
|
+
qp_index,
|
|
636
|
+
qp_weight,
|
|
637
|
+
test_dof_index,
|
|
638
|
+
trial_dof_index,
|
|
639
|
+
)
|
|
640
|
+
val = integrand_func(sample, fields, values)
|
|
641
|
+
val_sum += accumulate_dtype(qp_weight * vol * val)
|
|
642
|
+
|
|
643
|
+
block_offset = (row_offsets[test_node_index] + element) * NODES_PER_ELEMENT + trial_node
|
|
644
|
+
triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)
|
|
645
|
+
|
|
646
|
+
# Set row and column indices
|
|
647
|
+
if test_dof == 0 and trial_dof == 0:
|
|
653
648
|
trial_node_index = trial.space_partition.partition_node_index(
|
|
654
649
|
trial_partition_arg,
|
|
655
|
-
trial.space.element_node_index(
|
|
650
|
+
trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
|
|
656
651
|
)
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
triplet_cols[offset_cur] = trial_node_index
|
|
660
|
-
offset_cur += 1
|
|
652
|
+
triplet_rows[block_offset] = test_node_index
|
|
653
|
+
triplet_cols[block_offset] = trial_node_index
|
|
661
654
|
|
|
662
655
|
return integrate_kernel_fn
|
|
663
656
|
|
|
@@ -668,6 +661,7 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
668
661
|
FieldStruct: wp.codegen.Struct,
|
|
669
662
|
ValueStruct: wp.codegen.Struct,
|
|
670
663
|
test: TestField,
|
|
664
|
+
output_dtype,
|
|
671
665
|
accumulate_dtype,
|
|
672
666
|
):
|
|
673
667
|
def integrate_kernel_fn(
|
|
@@ -678,7 +672,7 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
678
672
|
values: ValueStruct,
|
|
679
673
|
triplet_rows: wp.array(dtype=int),
|
|
680
674
|
triplet_cols: wp.array(dtype=int),
|
|
681
|
-
triplet_values: wp.array3d(dtype=
|
|
675
|
+
triplet_values: wp.array3d(dtype=output_dtype),
|
|
682
676
|
):
|
|
683
677
|
local_node_index, test_dof, trial_dof = wp.tid()
|
|
684
678
|
|
|
@@ -692,6 +686,7 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
692
686
|
element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
|
|
693
687
|
|
|
694
688
|
coords = test.space.node_coords_in_element(
|
|
689
|
+
domain_arg,
|
|
695
690
|
_get_test_arg(),
|
|
696
691
|
element_index,
|
|
697
692
|
node_element_index.node_index_in_element,
|
|
@@ -699,13 +694,12 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
699
694
|
|
|
700
695
|
if coords[0] != OUTSIDE:
|
|
701
696
|
node_weight = test.space.node_quadrature_weight(
|
|
697
|
+
domain_arg,
|
|
702
698
|
_get_test_arg(),
|
|
703
699
|
element_index,
|
|
704
700
|
node_element_index.node_index_in_element,
|
|
705
701
|
)
|
|
706
702
|
|
|
707
|
-
vol = domain.element_measure(domain_arg, element_index, coords)
|
|
708
|
-
|
|
709
703
|
test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
|
|
710
704
|
trial_dof_index = DofIndex(node_element_index.node_index_in_element, trial_dof)
|
|
711
705
|
|
|
@@ -717,11 +711,12 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
717
711
|
test_dof_index,
|
|
718
712
|
trial_dof_index,
|
|
719
713
|
)
|
|
714
|
+
vol = domain.element_measure(domain_arg, sample)
|
|
720
715
|
val = integrand_func(sample, fields, values)
|
|
721
716
|
|
|
722
717
|
val_sum += accumulate_dtype(node_weight * vol * val)
|
|
723
718
|
|
|
724
|
-
triplet_values[local_node_index, test_dof, trial_dof] = val_sum
|
|
719
|
+
triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
|
|
725
720
|
triplet_rows[local_node_index] = node_index
|
|
726
721
|
triplet_cols[local_node_index] = node_index
|
|
727
722
|
|
|
@@ -738,8 +733,12 @@ def _generate_integrate_kernel(
|
|
|
738
733
|
trial: Optional[TrialField],
|
|
739
734
|
trial_name: str,
|
|
740
735
|
fields: Dict[str, FieldLike],
|
|
736
|
+
output_dtype: type,
|
|
741
737
|
accumulate_dtype: type,
|
|
738
|
+
kernel_options: Dict[str, Any] = {},
|
|
742
739
|
) -> wp.Kernel:
|
|
740
|
+
output_dtype = wp.types.type_scalar_type(output_dtype)
|
|
741
|
+
|
|
743
742
|
# Extract field arguments from integrand
|
|
744
743
|
field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
|
|
745
744
|
integrand, fields=fields, domain=domain
|
|
@@ -749,7 +748,7 @@ def _generate_integrate_kernel(
|
|
|
749
748
|
ValueStruct = _gen_value_struct(value_args)
|
|
750
749
|
|
|
751
750
|
# Check if kernel exist in cache
|
|
752
|
-
kernel_suffix = f"_itg_{domain.name}_{FieldStruct.key}"
|
|
751
|
+
kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
|
|
753
752
|
if nodal:
|
|
754
753
|
kernel_suffix += "_nodal"
|
|
755
754
|
else:
|
|
@@ -774,6 +773,8 @@ def _generate_integrate_kernel(
|
|
|
774
773
|
field_args,
|
|
775
774
|
)
|
|
776
775
|
|
|
776
|
+
_register_integrand_field_wrappers(integrand_func, fields)
|
|
777
|
+
|
|
777
778
|
if test is None and trial is None:
|
|
778
779
|
integrate_kernel_fn = get_integrate_constant_kernel(
|
|
779
780
|
integrand_func,
|
|
@@ -791,6 +792,7 @@ def _generate_integrate_kernel(
|
|
|
791
792
|
FieldStruct,
|
|
792
793
|
ValueStruct,
|
|
793
794
|
test=test,
|
|
795
|
+
output_dtype=output_dtype,
|
|
794
796
|
accumulate_dtype=accumulate_dtype,
|
|
795
797
|
)
|
|
796
798
|
else:
|
|
@@ -800,7 +802,8 @@ def _generate_integrate_kernel(
|
|
|
800
802
|
quadrature,
|
|
801
803
|
FieldStruct,
|
|
802
804
|
ValueStruct,
|
|
803
|
-
|
|
805
|
+
test=test,
|
|
806
|
+
output_dtype=output_dtype,
|
|
804
807
|
accumulate_dtype=accumulate_dtype,
|
|
805
808
|
)
|
|
806
809
|
else:
|
|
@@ -811,6 +814,7 @@ def _generate_integrate_kernel(
|
|
|
811
814
|
FieldStruct,
|
|
812
815
|
ValueStruct,
|
|
813
816
|
test=test,
|
|
817
|
+
output_dtype=output_dtype,
|
|
814
818
|
accumulate_dtype=accumulate_dtype,
|
|
815
819
|
)
|
|
816
820
|
else:
|
|
@@ -820,8 +824,9 @@ def _generate_integrate_kernel(
|
|
|
820
824
|
quadrature,
|
|
821
825
|
FieldStruct,
|
|
822
826
|
ValueStruct,
|
|
823
|
-
|
|
827
|
+
test=test,
|
|
824
828
|
trial=trial,
|
|
829
|
+
output_dtype=output_dtype,
|
|
825
830
|
accumulate_dtype=accumulate_dtype,
|
|
826
831
|
)
|
|
827
832
|
|
|
@@ -829,6 +834,7 @@ def _generate_integrate_kernel(
|
|
|
829
834
|
integrand=integrand,
|
|
830
835
|
kernel_fn=integrate_kernel_fn,
|
|
831
836
|
suffix=kernel_suffix,
|
|
837
|
+
kernel_options=kernel_options,
|
|
832
838
|
code_transformers=[
|
|
833
839
|
PassFieldArgsToIntegrand(
|
|
834
840
|
arg_names=integrand.argspec.args,
|
|
@@ -837,7 +843,7 @@ def _generate_integrate_kernel(
|
|
|
837
843
|
sample_name=sample_name,
|
|
838
844
|
domain_name=domain_name,
|
|
839
845
|
test_name=test_name,
|
|
840
|
-
trial_name=trial_name
|
|
846
|
+
trial_name=trial_name,
|
|
841
847
|
)
|
|
842
848
|
],
|
|
843
849
|
)
|
|
@@ -846,7 +852,7 @@ def _generate_integrate_kernel(
|
|
|
846
852
|
|
|
847
853
|
|
|
848
854
|
def _launch_integrate_kernel(
|
|
849
|
-
kernel: wp.
|
|
855
|
+
kernel: wp.Kernel,
|
|
850
856
|
FieldStruct: wp.codegen.Struct,
|
|
851
857
|
ValueStruct: wp.codegen.Struct,
|
|
852
858
|
domain: GeometryDomain,
|
|
@@ -857,16 +863,11 @@ def _launch_integrate_kernel(
|
|
|
857
863
|
fields: Dict[str, FieldLike],
|
|
858
864
|
values: Dict[str, Any],
|
|
859
865
|
accumulate_dtype: type,
|
|
866
|
+
temporary_store: Optional[cache.TemporaryStore],
|
|
860
867
|
output_dtype: type,
|
|
861
868
|
output: Optional[Union[wp.array, BsrMatrix]],
|
|
862
869
|
device,
|
|
863
|
-
)
|
|
864
|
-
if output_dtype is None:
|
|
865
|
-
if output is not None:
|
|
866
|
-
output_dtype = output.dtype
|
|
867
|
-
else:
|
|
868
|
-
output_dtype = accumulate_dtype
|
|
869
|
-
|
|
870
|
+
):
|
|
870
871
|
# Set-up launch arguments
|
|
871
872
|
domain_elt_arg = domain.element_arg_value(device=device)
|
|
872
873
|
domain_elt_index_arg = domain.element_index_arg_value(device=device)
|
|
@@ -882,14 +883,23 @@ def _launch_integrate_kernel(
|
|
|
882
883
|
for k, v in values.items():
|
|
883
884
|
setattr(value_struct_values, k, v)
|
|
884
885
|
|
|
885
|
-
# Constant
|
|
886
|
+
# Constant form
|
|
886
887
|
if test is None and trial is None:
|
|
887
|
-
if output is None
|
|
888
|
-
|
|
888
|
+
if output is not None and output.dtype == accumulate_dtype:
|
|
889
|
+
if output.size < 1:
|
|
890
|
+
raise RuntimeError("Output array must be of size at least 1")
|
|
891
|
+
accumulate_array = output
|
|
889
892
|
else:
|
|
890
|
-
|
|
891
|
-
|
|
893
|
+
accumulate_temporary = cache.borrow_temporary(
|
|
894
|
+
shape=(1),
|
|
895
|
+
device=device,
|
|
896
|
+
dtype=accumulate_dtype,
|
|
897
|
+
temporary_store=temporary_store,
|
|
898
|
+
requires_grad=output is not None and output.requires_grad,
|
|
899
|
+
)
|
|
900
|
+
accumulate_array = accumulate_temporary.array
|
|
892
901
|
|
|
902
|
+
accumulate_array.zero_()
|
|
893
903
|
wp.launch(
|
|
894
904
|
kernel=kernel,
|
|
895
905
|
dim=domain.element_count(),
|
|
@@ -899,43 +909,77 @@ def _launch_integrate_kernel(
|
|
|
899
909
|
domain_elt_index_arg,
|
|
900
910
|
field_arg_values,
|
|
901
911
|
value_struct_values,
|
|
902
|
-
|
|
912
|
+
accumulate_array,
|
|
903
913
|
],
|
|
904
914
|
device=device,
|
|
905
915
|
)
|
|
906
916
|
|
|
907
|
-
if output
|
|
908
|
-
return
|
|
917
|
+
if output == accumulate_array:
|
|
918
|
+
return output
|
|
919
|
+
elif output is None:
|
|
920
|
+
return accumulate_array.numpy()[0]
|
|
909
921
|
else:
|
|
910
|
-
|
|
911
|
-
array_cast(in_array=result, out_array=output)
|
|
922
|
+
array_cast(in_array=accumulate_array, out_array=output)
|
|
912
923
|
return output
|
|
913
924
|
|
|
914
925
|
test_arg = test.space_restriction.node_arg(device=device)
|
|
915
926
|
|
|
916
927
|
# Linear form
|
|
917
928
|
if trial is None:
|
|
918
|
-
|
|
919
|
-
|
|
929
|
+
# If an output array is provided with the correct type, accumulate directly into it
|
|
930
|
+
# Otherwise, grab a temporary array
|
|
931
|
+
if output is None:
|
|
932
|
+
if type_length(output_dtype) == test.space.VALUE_DOF_COUNT:
|
|
933
|
+
output_shape = (test.space_partition.node_count(),)
|
|
934
|
+
elif type_length(output_dtype) == 1:
|
|
935
|
+
output_shape = (test.space_partition.node_count(), test.space.VALUE_DOF_COUNT)
|
|
936
|
+
else:
|
|
937
|
+
raise RuntimeError(
|
|
938
|
+
f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
|
|
939
|
+
)
|
|
940
|
+
|
|
941
|
+
output_temporary = cache.borrow_temporary(
|
|
942
|
+
temporary_store=temporary_store,
|
|
943
|
+
shape=output_shape,
|
|
944
|
+
dtype=output_dtype,
|
|
945
|
+
device=device,
|
|
946
|
+
)
|
|
947
|
+
|
|
948
|
+
output = output_temporary.array
|
|
949
|
+
|
|
920
950
|
else:
|
|
921
|
-
|
|
951
|
+
output_temporary = None
|
|
922
952
|
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
953
|
+
if output.shape[0] < test.space_partition.node_count():
|
|
954
|
+
raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
|
|
955
|
+
|
|
956
|
+
output_dtype = output.dtype
|
|
957
|
+
if type_length(output_dtype) != test.space.VALUE_DOF_COUNT:
|
|
958
|
+
if type_length(output_dtype) != 1:
|
|
959
|
+
raise RuntimeError(
|
|
960
|
+
f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
|
|
961
|
+
)
|
|
962
|
+
if output.ndim != 2 and output.shape[1] != test.space.VALUE_DOF_COUNT:
|
|
963
|
+
raise RuntimeError(
|
|
964
|
+
f"Incompatible output array shape, last dimension must be of size {test.space.VALUE_DOF_COUNT}"
|
|
965
|
+
)
|
|
928
966
|
|
|
929
967
|
# Launch the integration on the kernel on a 2d scalar view of the actual array
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
968
|
+
output.zero_()
|
|
969
|
+
|
|
970
|
+
def as_2d_array(array):
|
|
971
|
+
return wp.array(
|
|
972
|
+
data=None,
|
|
973
|
+
ptr=array.ptr,
|
|
974
|
+
capacity=array.capacity,
|
|
975
|
+
owner=False,
|
|
976
|
+
device=array.device,
|
|
977
|
+
shape=(test.space_partition.node_count(), test.space.VALUE_DOF_COUNT),
|
|
978
|
+
dtype=wp.types.type_scalar_type(output_dtype),
|
|
979
|
+
grad=None if array.grad is None else as_2d_array(array.grad),
|
|
980
|
+
)
|
|
981
|
+
|
|
982
|
+
output_view = output if output.ndim == 2 else as_2d_array(output)
|
|
939
983
|
|
|
940
984
|
if nodal:
|
|
941
985
|
wp.launch(
|
|
@@ -947,14 +991,14 @@ def _launch_integrate_kernel(
|
|
|
947
991
|
test_arg,
|
|
948
992
|
field_arg_values,
|
|
949
993
|
value_struct_values,
|
|
950
|
-
|
|
994
|
+
output_view,
|
|
951
995
|
],
|
|
952
996
|
device=device,
|
|
953
997
|
)
|
|
954
998
|
else:
|
|
955
999
|
wp.launch(
|
|
956
1000
|
kernel=kernel,
|
|
957
|
-
dim=test.space_restriction.node_count(),
|
|
1001
|
+
dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
|
|
958
1002
|
inputs=[
|
|
959
1003
|
qp_arg,
|
|
960
1004
|
domain_elt_arg,
|
|
@@ -962,55 +1006,47 @@ def _launch_integrate_kernel(
|
|
|
962
1006
|
test_arg,
|
|
963
1007
|
field_arg_values,
|
|
964
1008
|
value_struct_values,
|
|
965
|
-
|
|
1009
|
+
output_view,
|
|
966
1010
|
],
|
|
967
1011
|
device=device,
|
|
968
1012
|
)
|
|
969
1013
|
|
|
970
|
-
if
|
|
971
|
-
return
|
|
972
|
-
|
|
973
|
-
output_type_length = type_length(output_dtype)
|
|
974
|
-
if output_type_length == test.space.VALUE_DOF_COUNT:
|
|
975
|
-
cast_result = wp.empty(dtype=output_dtype, shape=result_array.shape)
|
|
976
|
-
else:
|
|
977
|
-
cast_result = wp.empty(dtype=output_dtype, shape=result_2d_view.shape)
|
|
1014
|
+
if output_temporary is not None:
|
|
1015
|
+
return output_temporary.detach()
|
|
978
1016
|
|
|
979
|
-
|
|
980
|
-
return cast_result
|
|
1017
|
+
return output
|
|
981
1018
|
|
|
982
1019
|
# Bilinear form
|
|
983
1020
|
|
|
984
1021
|
if test.space.VALUE_DOF_COUNT == 1 and trial.space.VALUE_DOF_COUNT == 1:
|
|
985
|
-
block_type =
|
|
1022
|
+
block_type = output_dtype
|
|
986
1023
|
else:
|
|
987
|
-
block_type =
|
|
988
|
-
shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=
|
|
1024
|
+
block_type = cache.cached_mat_type(
|
|
1025
|
+
shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=output_dtype
|
|
989
1026
|
)
|
|
990
1027
|
|
|
991
|
-
bsr_matrix = bsr_zeros(
|
|
992
|
-
rows_of_blocks=test.space_partition.node_count(),
|
|
993
|
-
cols_of_blocks=trial.space_partition.node_count(),
|
|
994
|
-
block_type=block_type,
|
|
995
|
-
device=device,
|
|
996
|
-
)
|
|
997
|
-
|
|
998
1028
|
if nodal:
|
|
999
1029
|
nnz = test.space_restriction.node_count()
|
|
1000
1030
|
else:
|
|
1001
|
-
nnz = test.space_restriction.total_node_element_count() * trial.space.NODES_PER_ELEMENT
|
|
1031
|
+
nnz = test.space_restriction.total_node_element_count() * trial.space.topology.NODES_PER_ELEMENT
|
|
1002
1032
|
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1033
|
+
triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
|
|
1034
|
+
triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
|
|
1035
|
+
triplet_values_temp = cache.borrow_temporary(
|
|
1036
|
+
temporary_store,
|
|
1006
1037
|
shape=(
|
|
1007
1038
|
nnz,
|
|
1008
1039
|
test.space.VALUE_DOF_COUNT,
|
|
1009
1040
|
trial.space.VALUE_DOF_COUNT,
|
|
1010
1041
|
),
|
|
1011
|
-
dtype=
|
|
1042
|
+
dtype=output_dtype,
|
|
1012
1043
|
device=device,
|
|
1013
1044
|
)
|
|
1045
|
+
triplet_cols = triplet_cols_temp.array
|
|
1046
|
+
triplet_rows = triplet_rows_temp.array
|
|
1047
|
+
triplet_values = triplet_values_temp.array
|
|
1048
|
+
|
|
1049
|
+
triplet_values.zero_()
|
|
1014
1050
|
|
|
1015
1051
|
if nodal:
|
|
1016
1052
|
wp.launch(
|
|
@@ -1033,15 +1069,22 @@ def _launch_integrate_kernel(
|
|
|
1033
1069
|
offsets = test.space_restriction.partition_element_offsets()
|
|
1034
1070
|
|
|
1035
1071
|
trial_partition_arg = trial.space_partition.partition_arg_value(device)
|
|
1072
|
+
trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
|
|
1036
1073
|
wp.launch(
|
|
1037
1074
|
kernel=kernel,
|
|
1038
|
-
dim=
|
|
1075
|
+
dim=(
|
|
1076
|
+
test.space_restriction.node_count(),
|
|
1077
|
+
trial.space.topology.NODES_PER_ELEMENT,
|
|
1078
|
+
test.space.VALUE_DOF_COUNT,
|
|
1079
|
+
trial.space.VALUE_DOF_COUNT,
|
|
1080
|
+
),
|
|
1039
1081
|
inputs=[
|
|
1040
1082
|
qp_arg,
|
|
1041
1083
|
domain_elt_arg,
|
|
1042
1084
|
domain_elt_index_arg,
|
|
1043
1085
|
test_arg,
|
|
1044
1086
|
trial_partition_arg,
|
|
1087
|
+
trial_topology_arg,
|
|
1045
1088
|
field_arg_values,
|
|
1046
1089
|
value_struct_values,
|
|
1047
1090
|
offsets,
|
|
@@ -1052,38 +1095,63 @@ def _launch_integrate_kernel(
|
|
|
1052
1095
|
device=device,
|
|
1053
1096
|
)
|
|
1054
1097
|
|
|
1055
|
-
|
|
1056
|
-
|
|
1098
|
+
if output is not None:
|
|
1099
|
+
if output.nrow != test.space_partition.node_count() or output.ncol != trial.space_partition.node_count():
|
|
1100
|
+
raise RuntimeError(
|
|
1101
|
+
f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
|
|
1102
|
+
)
|
|
1103
|
+
|
|
1104
|
+
else:
|
|
1105
|
+
output = bsr_zeros(
|
|
1106
|
+
rows_of_blocks=test.space_partition.node_count(),
|
|
1107
|
+
cols_of_blocks=trial.space_partition.node_count(),
|
|
1108
|
+
block_type=block_type,
|
|
1109
|
+
device=device,
|
|
1110
|
+
)
|
|
1111
|
+
|
|
1112
|
+
bsr_set_from_triplets(output, triplet_rows, triplet_cols, triplet_values)
|
|
1113
|
+
|
|
1114
|
+
# Do not wait for garbage collection
|
|
1115
|
+
triplet_values_temp.release()
|
|
1116
|
+
triplet_rows_temp.release()
|
|
1117
|
+
triplet_cols_temp.release()
|
|
1118
|
+
|
|
1119
|
+
return output
|
|
1057
1120
|
|
|
1058
1121
|
|
|
1059
1122
|
def integrate(
|
|
1060
1123
|
integrand: Integrand,
|
|
1061
|
-
domain: GeometryDomain = None,
|
|
1062
|
-
quadrature: Quadrature = None,
|
|
1124
|
+
domain: Optional[GeometryDomain] = None,
|
|
1125
|
+
quadrature: Optional[Quadrature] = None,
|
|
1063
1126
|
nodal: bool = False,
|
|
1064
|
-
fields={},
|
|
1065
|
-
values={},
|
|
1127
|
+
fields: Dict[str, FieldLike] = {},
|
|
1128
|
+
values: Dict[str, Any] = {},
|
|
1129
|
+
accumulate_dtype: type = wp.float64,
|
|
1130
|
+
output_dtype: Optional[type] = None,
|
|
1131
|
+
output: Optional[Union[BsrMatrix, wp.array]] = None,
|
|
1066
1132
|
device=None,
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
output=None,
|
|
1133
|
+
temporary_store: Optional[cache.TemporaryStore] = None,
|
|
1134
|
+
kernel_options: Dict[str, Any] = {},
|
|
1070
1135
|
):
|
|
1071
1136
|
"""
|
|
1072
1137
|
Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.
|
|
1073
1138
|
|
|
1074
1139
|
Args:
|
|
1075
|
-
integrand: Form to be integrated, must have
|
|
1140
|
+
integrand: Form to be integrated, must have :func:`integrand` decorator
|
|
1076
1141
|
domain: Integration domain. If None, deduced from fields
|
|
1077
1142
|
quadrature: Quadrature formula. If None, deduced from domain and fields degree.
|
|
1078
1143
|
nodal: For linear or bilinear form only, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
|
|
1079
1144
|
fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
|
|
1080
|
-
values: Additional variable values to be passed to the integrand, can
|
|
1081
|
-
|
|
1145
|
+
values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launchs. Keys in the dictionary must match integrand parameter names.
|
|
1146
|
+
temporary_store: shared pool from which to allocate temporary arrays
|
|
1082
1147
|
accumulate_dtype: Scalar type to be used for accumulating integration samples
|
|
1083
|
-
|
|
1148
|
+
output: Sparse matrix or warp array into which to store the result of the integration
|
|
1149
|
+
output_dtype: Scalar type for returned results in `output` is notr provided. If None, defaults to `accumulate_dtype`
|
|
1150
|
+
device: Device on which to perform the integration
|
|
1151
|
+
kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
|
|
1084
1152
|
"""
|
|
1085
1153
|
if not isinstance(integrand, Integrand):
|
|
1086
|
-
raise ValueError("integrand must be tagged with @integrand decorator")
|
|
1154
|
+
raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")
|
|
1087
1155
|
|
|
1088
1156
|
test, test_name, trial, trial_name = _get_test_and_trial_fields(fields)
|
|
1089
1157
|
|
|
@@ -1111,15 +1179,23 @@ def integrate(
|
|
|
1111
1179
|
)
|
|
1112
1180
|
else:
|
|
1113
1181
|
if quadrature is None:
|
|
1114
|
-
order =
|
|
1115
|
-
if test is not None:
|
|
1116
|
-
order += test.space.degree
|
|
1117
|
-
if trial is not None:
|
|
1118
|
-
order += trial.space.degree
|
|
1182
|
+
order = sum(field.degree for field in fields.values())
|
|
1119
1183
|
quadrature = RegularQuadrature(domain=domain, order=order)
|
|
1120
1184
|
elif domain != quadrature.domain:
|
|
1121
1185
|
raise ValueError("Incompatible integration and quadrature domain")
|
|
1122
1186
|
|
|
1187
|
+
# Canonicalize types
|
|
1188
|
+
accumulate_dtype = wp.types.type_to_warp(accumulate_dtype)
|
|
1189
|
+
if output is not None:
|
|
1190
|
+
if isinstance(output, BsrMatrix):
|
|
1191
|
+
output_dtype = output.scalar_type
|
|
1192
|
+
else:
|
|
1193
|
+
output_dtype = output.dtype
|
|
1194
|
+
elif output_dtype is None:
|
|
1195
|
+
output_dtype = accumulate_dtype
|
|
1196
|
+
else:
|
|
1197
|
+
output_dtype = wp.types.type_to_warp(output_dtype)
|
|
1198
|
+
|
|
1123
1199
|
kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
|
|
1124
1200
|
integrand=integrand,
|
|
1125
1201
|
domain=domain,
|
|
@@ -1131,6 +1207,8 @@ def integrate(
|
|
|
1131
1207
|
trial_name=trial_name,
|
|
1132
1208
|
fields=fields,
|
|
1133
1209
|
accumulate_dtype=accumulate_dtype,
|
|
1210
|
+
output_dtype=output_dtype,
|
|
1211
|
+
kernel_options=kernel_options,
|
|
1134
1212
|
)
|
|
1135
1213
|
|
|
1136
1214
|
return _launch_integrate_kernel(
|
|
@@ -1145,13 +1223,14 @@ def integrate(
|
|
|
1145
1223
|
fields=fields,
|
|
1146
1224
|
values=values,
|
|
1147
1225
|
accumulate_dtype=accumulate_dtype,
|
|
1226
|
+
temporary_store=temporary_store,
|
|
1148
1227
|
output_dtype=output_dtype,
|
|
1149
1228
|
output=output,
|
|
1150
1229
|
device=device,
|
|
1151
1230
|
)
|
|
1152
1231
|
|
|
1153
1232
|
|
|
1154
|
-
def
|
|
1233
|
+
def get_interpolate_to_field_function(
|
|
1155
1234
|
integrand_func: wp.Function,
|
|
1156
1235
|
domain: GeometryDomain,
|
|
1157
1236
|
FieldStruct: wp.codegen.Struct,
|
|
@@ -1160,7 +1239,8 @@ def get_interpolate_kernel(
|
|
|
1160
1239
|
):
|
|
1161
1240
|
value_type = dest.space.dtype
|
|
1162
1241
|
|
|
1163
|
-
def
|
|
1242
|
+
def interpolate_to_field_fn(
|
|
1243
|
+
local_node_index: int,
|
|
1164
1244
|
domain_arg: domain.ElementArg,
|
|
1165
1245
|
domain_index_arg: domain.ElementIndexArg,
|
|
1166
1246
|
dest_node_arg: dest.space_restriction.NodeArg,
|
|
@@ -1168,19 +1248,15 @@ def get_interpolate_kernel(
|
|
|
1168
1248
|
fields: FieldStruct,
|
|
1169
1249
|
values: ValueStruct,
|
|
1170
1250
|
):
|
|
1171
|
-
local_node_index = wp.tid()
|
|
1172
1251
|
node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
|
|
1173
|
-
|
|
1174
1252
|
element_count = dest.space_restriction.node_element_count(dest_node_arg, local_node_index)
|
|
1175
|
-
if element_count == 0:
|
|
1176
|
-
return
|
|
1177
1253
|
|
|
1178
1254
|
test_dof_index = NULL_DOF_INDEX
|
|
1179
1255
|
trial_dof_index = NULL_DOF_INDEX
|
|
1180
1256
|
node_weight = 1.0
|
|
1181
1257
|
|
|
1182
|
-
# Volume-weighted average
|
|
1183
|
-
# Superfluous if the function is continuous, but
|
|
1258
|
+
# Volume-weighted average across elements
|
|
1259
|
+
# Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces
|
|
1184
1260
|
|
|
1185
1261
|
val_sum = value_type(0.0)
|
|
1186
1262
|
vol_sum = float(0.0)
|
|
@@ -1190,14 +1266,13 @@ def get_interpolate_kernel(
|
|
|
1190
1266
|
element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
|
|
1191
1267
|
|
|
1192
1268
|
coords = dest.space.node_coords_in_element(
|
|
1269
|
+
domain_arg,
|
|
1193
1270
|
dest_eval_arg.space_arg,
|
|
1194
1271
|
element_index,
|
|
1195
1272
|
node_element_index.node_index_in_element,
|
|
1196
1273
|
)
|
|
1197
1274
|
|
|
1198
1275
|
if coords[0] != OUTSIDE:
|
|
1199
|
-
vol = domain.element_measure(domain_arg, element_index, coords)
|
|
1200
|
-
|
|
1201
1276
|
sample = Sample(
|
|
1202
1277
|
element_index,
|
|
1203
1278
|
coords,
|
|
@@ -1206,20 +1281,118 @@ def get_interpolate_kernel(
|
|
|
1206
1281
|
test_dof_index,
|
|
1207
1282
|
trial_dof_index,
|
|
1208
1283
|
)
|
|
1284
|
+
vol = domain.element_measure(domain_arg, sample)
|
|
1209
1285
|
val = integrand_func(sample, fields, values)
|
|
1210
1286
|
|
|
1211
1287
|
vol_sum += vol
|
|
1212
1288
|
val_sum += vol * val
|
|
1213
1289
|
|
|
1290
|
+
return val_sum, vol_sum
|
|
1291
|
+
|
|
1292
|
+
return interpolate_to_field_fn
|
|
1293
|
+
|
|
1294
|
+
|
|
1295
|
+
def get_interpolate_to_field_kernel(
|
|
1296
|
+
interpolate_to_field_fn: wp.Function,
|
|
1297
|
+
domain: GeometryDomain,
|
|
1298
|
+
FieldStruct: wp.codegen.Struct,
|
|
1299
|
+
ValueStruct: wp.codegen.Struct,
|
|
1300
|
+
dest: FieldRestriction,
|
|
1301
|
+
):
|
|
1302
|
+
def interpolate_to_field_kernel_fn(
|
|
1303
|
+
domain_arg: domain.ElementArg,
|
|
1304
|
+
domain_index_arg: domain.ElementIndexArg,
|
|
1305
|
+
dest_node_arg: dest.space_restriction.NodeArg,
|
|
1306
|
+
dest_eval_arg: dest.field.EvalArg,
|
|
1307
|
+
fields: FieldStruct,
|
|
1308
|
+
values: ValueStruct,
|
|
1309
|
+
):
|
|
1310
|
+
local_node_index = wp.tid()
|
|
1311
|
+
|
|
1312
|
+
val_sum, vol_sum = interpolate_to_field_fn(
|
|
1313
|
+
local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
|
|
1314
|
+
)
|
|
1315
|
+
|
|
1214
1316
|
if vol_sum > 0.0:
|
|
1317
|
+
node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
|
|
1215
1318
|
dest.field.set_node_value(dest_eval_arg, node_index, val_sum / vol_sum)
|
|
1216
1319
|
|
|
1217
|
-
return
|
|
1320
|
+
return interpolate_to_field_kernel_fn
|
|
1321
|
+
|
|
1322
|
+
|
|
1323
|
+
def get_interpolate_to_array_kernel(
|
|
1324
|
+
integrand_func: wp.Function,
|
|
1325
|
+
domain: GeometryDomain,
|
|
1326
|
+
quadrature: Quadrature,
|
|
1327
|
+
FieldStruct: wp.codegen.Struct,
|
|
1328
|
+
ValueStruct: wp.codegen.Struct,
|
|
1329
|
+
value_type: type,
|
|
1330
|
+
):
|
|
1331
|
+
def interpolate_to_array_kernel_fn(
|
|
1332
|
+
qp_arg: quadrature.Arg,
|
|
1333
|
+
domain_arg: quadrature.domain.ElementArg,
|
|
1334
|
+
domain_index_arg: quadrature.domain.ElementIndexArg,
|
|
1335
|
+
fields: FieldStruct,
|
|
1336
|
+
values: ValueStruct,
|
|
1337
|
+
result: wp.array(dtype=value_type),
|
|
1338
|
+
):
|
|
1339
|
+
element_index = domain.element_index(domain_index_arg, wp.tid())
|
|
1340
|
+
|
|
1341
|
+
test_dof_index = NULL_DOF_INDEX
|
|
1342
|
+
trial_dof_index = NULL_DOF_INDEX
|
|
1343
|
+
|
|
1344
|
+
qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
|
|
1345
|
+
for k in range(qp_point_count):
|
|
1346
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
1347
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
1348
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
1349
|
+
|
|
1350
|
+
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1351
|
+
|
|
1352
|
+
result[qp_index] = integrand_func(sample, fields, values)
|
|
1353
|
+
|
|
1354
|
+
return interpolate_to_array_kernel_fn
|
|
1355
|
+
|
|
1356
|
+
|
|
1357
|
+
def get_interpolate_nonvalued_kernel(
|
|
1358
|
+
integrand_func: wp.Function,
|
|
1359
|
+
domain: GeometryDomain,
|
|
1360
|
+
quadrature: Quadrature,
|
|
1361
|
+
FieldStruct: wp.codegen.Struct,
|
|
1362
|
+
ValueStruct: wp.codegen.Struct,
|
|
1363
|
+
):
|
|
1364
|
+
def interpolate_nonvalued_kernel_fn(
|
|
1365
|
+
qp_arg: quadrature.Arg,
|
|
1366
|
+
domain_arg: quadrature.domain.ElementArg,
|
|
1367
|
+
domain_index_arg: quadrature.domain.ElementIndexArg,
|
|
1368
|
+
fields: FieldStruct,
|
|
1369
|
+
values: ValueStruct,
|
|
1370
|
+
):
|
|
1371
|
+
element_index = domain.element_index(domain_index_arg, wp.tid())
|
|
1372
|
+
|
|
1373
|
+
test_dof_index = NULL_DOF_INDEX
|
|
1374
|
+
trial_dof_index = NULL_DOF_INDEX
|
|
1375
|
+
|
|
1376
|
+
qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
|
|
1377
|
+
for k in range(qp_point_count):
|
|
1378
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
1379
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
1380
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
1381
|
+
|
|
1382
|
+
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1383
|
+
integrand_func(sample, fields, values)
|
|
1218
1384
|
|
|
1385
|
+
return interpolate_nonvalued_kernel_fn
|
|
1219
1386
|
|
|
1220
|
-
def _generate_interpolate_kernel(integrand: Integrand, dest: FieldLike, fields: Dict[str, FieldLike]) -> wp.Kernel:
|
|
1221
|
-
domain = dest.domain
|
|
1222
1387
|
|
|
1388
|
+
def _generate_interpolate_kernel(
|
|
1389
|
+
integrand: Integrand,
|
|
1390
|
+
domain: GeometryDomain,
|
|
1391
|
+
dest: Optional[Union[FieldLike, wp.array]],
|
|
1392
|
+
quadrature: Optional[Quadrature],
|
|
1393
|
+
fields: Dict[str, FieldLike],
|
|
1394
|
+
kernel_options: Dict[str, Any] = {},
|
|
1395
|
+
) -> wp.Kernel:
|
|
1223
1396
|
# Extract field arguments from integrand
|
|
1224
1397
|
field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
|
|
1225
1398
|
integrand, fields=fields, domain=domain
|
|
@@ -1231,11 +1404,20 @@ def _generate_interpolate_kernel(integrand: Integrand, dest: FieldLike, fields:
|
|
|
1231
1404
|
field_args,
|
|
1232
1405
|
)
|
|
1233
1406
|
|
|
1407
|
+
_register_integrand_field_wrappers(integrand_func, fields)
|
|
1408
|
+
|
|
1234
1409
|
FieldStruct = _gen_field_struct(field_args)
|
|
1235
1410
|
ValueStruct = _gen_value_struct(value_args)
|
|
1236
1411
|
|
|
1237
1412
|
# Check if kernel exist in cache
|
|
1238
|
-
|
|
1413
|
+
if isinstance(dest, FieldRestriction):
|
|
1414
|
+
kernel_suffix = (
|
|
1415
|
+
f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
|
|
1416
|
+
)
|
|
1417
|
+
elif wp.types.is_array(dest):
|
|
1418
|
+
kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{wp.types.type_repr(dest.dtype)}"
|
|
1419
|
+
else:
|
|
1420
|
+
kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}"
|
|
1239
1421
|
|
|
1240
1422
|
kernel = cache.get_integrand_kernel(
|
|
1241
1423
|
integrand=integrand,
|
|
@@ -1245,18 +1427,61 @@ def _generate_interpolate_kernel(integrand: Integrand, dest: FieldLike, fields:
|
|
|
1245
1427
|
return kernel, FieldStruct, ValueStruct
|
|
1246
1428
|
|
|
1247
1429
|
# Generate interpolation kernel
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1430
|
+
if isinstance(dest, FieldRestriction):
|
|
1431
|
+
# need to split into kernel + function for diffferentiability
|
|
1432
|
+
interpolate_fn = get_interpolate_to_field_function(
|
|
1433
|
+
integrand_func,
|
|
1434
|
+
domain,
|
|
1435
|
+
dest=dest,
|
|
1436
|
+
FieldStruct=FieldStruct,
|
|
1437
|
+
ValueStruct=ValueStruct,
|
|
1438
|
+
)
|
|
1439
|
+
|
|
1440
|
+
interpolate_fn = cache.get_integrand_function(
|
|
1441
|
+
integrand=integrand,
|
|
1442
|
+
func=interpolate_fn,
|
|
1443
|
+
suffix=kernel_suffix,
|
|
1444
|
+
code_transformers=[
|
|
1445
|
+
PassFieldArgsToIntegrand(
|
|
1446
|
+
arg_names=integrand.argspec.args,
|
|
1447
|
+
field_args=field_args.keys(),
|
|
1448
|
+
value_args=value_args.keys(),
|
|
1449
|
+
sample_name=sample_name,
|
|
1450
|
+
domain_name=domain_name,
|
|
1451
|
+
)
|
|
1452
|
+
],
|
|
1453
|
+
)
|
|
1454
|
+
|
|
1455
|
+
interpolate_kernel_fn = get_interpolate_to_field_kernel(
|
|
1456
|
+
interpolate_fn,
|
|
1457
|
+
domain,
|
|
1458
|
+
dest=dest,
|
|
1459
|
+
FieldStruct=FieldStruct,
|
|
1460
|
+
ValueStruct=ValueStruct,
|
|
1461
|
+
)
|
|
1462
|
+
elif wp.types.is_array(dest):
|
|
1463
|
+
interpolate_kernel_fn = get_interpolate_to_array_kernel(
|
|
1464
|
+
integrand_func,
|
|
1465
|
+
domain=domain,
|
|
1466
|
+
quadrature=quadrature,
|
|
1467
|
+
value_type=dest.dtype,
|
|
1468
|
+
FieldStruct=FieldStruct,
|
|
1469
|
+
ValueStruct=ValueStruct,
|
|
1470
|
+
)
|
|
1471
|
+
else:
|
|
1472
|
+
interpolate_kernel_fn = get_interpolate_nonvalued_kernel(
|
|
1473
|
+
integrand_func,
|
|
1474
|
+
domain=domain,
|
|
1475
|
+
quadrature=quadrature,
|
|
1476
|
+
FieldStruct=FieldStruct,
|
|
1477
|
+
ValueStruct=ValueStruct,
|
|
1478
|
+
)
|
|
1255
1479
|
|
|
1256
1480
|
kernel = cache.get_integrand_kernel(
|
|
1257
1481
|
integrand=integrand,
|
|
1258
1482
|
kernel_fn=interpolate_kernel_fn,
|
|
1259
1483
|
suffix=kernel_suffix,
|
|
1484
|
+
kernel_options=kernel_options,
|
|
1260
1485
|
code_transformers=[
|
|
1261
1486
|
PassFieldArgsToIntegrand(
|
|
1262
1487
|
arg_names=integrand.argspec.args,
|
|
@@ -1275,16 +1500,16 @@ def _launch_interpolate_kernel(
|
|
|
1275
1500
|
kernel: wp.kernel,
|
|
1276
1501
|
FieldStruct: wp.codegen.Struct,
|
|
1277
1502
|
ValueStruct: wp.codegen.Struct,
|
|
1278
|
-
|
|
1503
|
+
domain: GeometryDomain,
|
|
1504
|
+
dest: Optional[Union[FieldRestriction, wp.array]],
|
|
1505
|
+
quadrature: Optional[Quadrature],
|
|
1279
1506
|
fields: Dict[str, FieldLike],
|
|
1280
1507
|
values: Dict[str, Any],
|
|
1281
1508
|
device,
|
|
1282
1509
|
) -> wp.Kernel:
|
|
1283
1510
|
# Set-up launch arguments
|
|
1284
|
-
elt_arg =
|
|
1285
|
-
elt_index_arg =
|
|
1286
|
-
dest_node_arg = dest.space_restriction.node_arg(device=device)
|
|
1287
|
-
dest_eval_arg = dest.field.eval_arg_value(device=device)
|
|
1511
|
+
elt_arg = domain.element_arg_value(device=device)
|
|
1512
|
+
elt_index_arg = domain.element_index_arg_value(device=device)
|
|
1288
1513
|
|
|
1289
1514
|
field_arg_values = FieldStruct()
|
|
1290
1515
|
for k, v in fields.items():
|
|
@@ -1294,37 +1519,65 @@ def _launch_interpolate_kernel(
|
|
|
1294
1519
|
for k, v in values.items():
|
|
1295
1520
|
setattr(value_struct_values, k, v)
|
|
1296
1521
|
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1522
|
+
if isinstance(dest, FieldRestriction):
|
|
1523
|
+
dest_node_arg = dest.space_restriction.node_arg(device=device)
|
|
1524
|
+
dest_eval_arg = dest.field.eval_arg_value(device=device)
|
|
1525
|
+
|
|
1526
|
+
wp.launch(
|
|
1527
|
+
kernel=kernel,
|
|
1528
|
+
dim=dest.space_restriction.node_count(),
|
|
1529
|
+
inputs=[
|
|
1530
|
+
elt_arg,
|
|
1531
|
+
elt_index_arg,
|
|
1532
|
+
dest_node_arg,
|
|
1533
|
+
dest_eval_arg,
|
|
1534
|
+
field_arg_values,
|
|
1535
|
+
value_struct_values,
|
|
1536
|
+
],
|
|
1537
|
+
device=device,
|
|
1538
|
+
)
|
|
1539
|
+
elif wp.types.is_array(dest):
|
|
1540
|
+
qp_arg = quadrature.arg_value(device)
|
|
1541
|
+
wp.launch(
|
|
1542
|
+
kernel=kernel,
|
|
1543
|
+
dim=domain.element_count(),
|
|
1544
|
+
inputs=[qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
|
|
1545
|
+
device=device,
|
|
1546
|
+
)
|
|
1547
|
+
else:
|
|
1548
|
+
qp_arg = quadrature.arg_value(device)
|
|
1549
|
+
wp.launch(
|
|
1550
|
+
kernel=kernel,
|
|
1551
|
+
dim=domain.element_count(),
|
|
1552
|
+
inputs=[qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values],
|
|
1553
|
+
device=device,
|
|
1554
|
+
)
|
|
1310
1555
|
|
|
1311
1556
|
|
|
1312
1557
|
def interpolate(
|
|
1313
1558
|
integrand: Integrand,
|
|
1314
|
-
dest: Union[DiscreteField, FieldRestriction],
|
|
1315
|
-
|
|
1316
|
-
|
|
1559
|
+
dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
|
|
1560
|
+
quadrature: Optional[Quadrature] = None,
|
|
1561
|
+
fields: Dict[str, FieldLike] = {},
|
|
1562
|
+
values: Dict[str, Any] = {},
|
|
1317
1563
|
device=None,
|
|
1564
|
+
kernel_options: Dict[str, Any] = {},
|
|
1318
1565
|
):
|
|
1319
1566
|
"""
|
|
1320
|
-
Interpolates a function and assigns the result to a discrete field.
|
|
1567
|
+
Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.
|
|
1321
1568
|
|
|
1322
1569
|
Args:
|
|
1323
|
-
integrand: Function to be interpolated, must have
|
|
1324
|
-
dest:
|
|
1570
|
+
integrand: Function to be interpolated, must have :func:`integrand` decorator
|
|
1571
|
+
dest: Where to store the interpolation result. Can be either
|
|
1572
|
+
|
|
1573
|
+
- a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
|
|
1574
|
+
- a normal warp array. In this case, the `quadrature` argument defining the interpolation locations must be provided and the result of the `integrand` at each quadrature point will be assigned to the array.
|
|
1575
|
+
- ``None``. In this case, the `quadrature` argument must also be provided and the `integrand` function is reponsible for dealing with the interpolation result.
|
|
1576
|
+
quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
|
|
1325
1577
|
fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameters names.
|
|
1326
|
-
values: Additional variable values to be passed to the integrand, can
|
|
1578
|
+
values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launchs. Keys in the dictionary must match integrand parameter names.
|
|
1327
1579
|
device: Device on which to perform the interpolation
|
|
1580
|
+
kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
|
|
1328
1581
|
"""
|
|
1329
1582
|
if not isinstance(integrand, Integrand):
|
|
1330
1583
|
raise ValueError("integrand must be tagged with @integrand decorator")
|
|
@@ -1333,20 +1586,33 @@ def interpolate(
|
|
|
1333
1586
|
if test is not None or trial is not None:
|
|
1334
1587
|
raise ValueError("Test or Trial fields should not be used for interpolation")
|
|
1335
1588
|
|
|
1336
|
-
if
|
|
1589
|
+
if isinstance(dest, DiscreteField):
|
|
1337
1590
|
dest = make_restriction(dest)
|
|
1338
1591
|
|
|
1592
|
+
if isinstance(dest, FieldRestriction):
|
|
1593
|
+
domain = dest.domain
|
|
1594
|
+
else:
|
|
1595
|
+
if quadrature is None:
|
|
1596
|
+
raise ValueError("When not interpolating to a field, a quadrature formula must be provided")
|
|
1597
|
+
|
|
1598
|
+
domain = quadrature.domain
|
|
1599
|
+
|
|
1339
1600
|
kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
|
|
1340
1601
|
integrand=integrand,
|
|
1602
|
+
domain=domain,
|
|
1341
1603
|
dest=dest,
|
|
1604
|
+
quadrature=quadrature,
|
|
1342
1605
|
fields=fields,
|
|
1606
|
+
kernel_options=kernel_options,
|
|
1343
1607
|
)
|
|
1344
1608
|
|
|
1345
1609
|
return _launch_interpolate_kernel(
|
|
1346
1610
|
kernel=kernel,
|
|
1347
1611
|
FieldStruct=FieldStruct,
|
|
1348
1612
|
ValueStruct=ValueStruct,
|
|
1613
|
+
domain=domain,
|
|
1349
1614
|
dest=dest,
|
|
1615
|
+
quadrature=quadrature,
|
|
1350
1616
|
fields=fields,
|
|
1351
1617
|
values=values,
|
|
1352
1618
|
device=device,
|