warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.1__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py
CHANGED
|
@@ -19,6 +19,7 @@ import textwrap
|
|
|
19
19
|
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
|
|
20
20
|
|
|
21
21
|
import warp as wp
|
|
22
|
+
import warp.fem.operator as operator
|
|
22
23
|
from warp.codegen import get_annotations
|
|
23
24
|
from warp.fem import cache
|
|
24
25
|
from warp.fem.domain import GeometryDomain
|
|
@@ -35,7 +36,11 @@ from warp.fem.field import (
|
|
|
35
36
|
)
|
|
36
37
|
from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
|
|
37
38
|
from warp.fem.linalg import array_axpy, basis_coefficient
|
|
38
|
-
from warp.fem.operator import
|
|
39
|
+
from warp.fem.operator import (
|
|
40
|
+
Integrand,
|
|
41
|
+
Operator,
|
|
42
|
+
integrand,
|
|
43
|
+
)
|
|
39
44
|
from warp.fem.quadrature import Quadrature, RegularQuadrature
|
|
40
45
|
from warp.fem.types import (
|
|
41
46
|
NULL_DOF_INDEX,
|
|
@@ -49,8 +54,9 @@ from warp.fem.types import (
|
|
|
49
54
|
Sample,
|
|
50
55
|
make_free_sample,
|
|
51
56
|
)
|
|
57
|
+
from warp.fem.utils import type_zero_element
|
|
52
58
|
from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
|
|
53
|
-
from warp.types import
|
|
59
|
+
from warp.types import is_array, type_size
|
|
54
60
|
from warp.utils import array_cast
|
|
55
61
|
|
|
56
62
|
|
|
@@ -111,6 +117,8 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
111
117
|
def get_concrete_type(field: Union[FieldLike, Domain]):
|
|
112
118
|
if isinstance(field, FieldLike):
|
|
113
119
|
return field.ElementEvalArg
|
|
120
|
+
elif isinstance(field, GeometryDomain):
|
|
121
|
+
return field.DomainArg
|
|
114
122
|
return field.ElementArg
|
|
115
123
|
|
|
116
124
|
return {
|
|
@@ -232,7 +240,7 @@ class IntegrandOperatorParser(IntegrandVisitor):
|
|
|
232
240
|
|
|
233
241
|
@staticmethod
|
|
234
242
|
def apply(
|
|
235
|
-
integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Callable = None
|
|
243
|
+
integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Optional[Callable] = None
|
|
236
244
|
) -> wp.Function:
|
|
237
245
|
field_info = IntegrandVisitor._build_field_info(integrand, field_args)
|
|
238
246
|
IntegrandOperatorParser(integrand, field_info, callback=operator_callback)._apply()
|
|
@@ -267,7 +275,11 @@ class IntegrandTransformer(IntegrandVisitor):
|
|
|
267
275
|
setattr(field_info.concrete_type, pointer.key, pointer)
|
|
268
276
|
|
|
269
277
|
# also insert callee as first argument
|
|
270
|
-
call.args = [ast.Name(id=callee, ctx=ast.Load())
|
|
278
|
+
call.args = [ast.Name(id=callee, ctx=ast.Load()), *call.args]
|
|
279
|
+
|
|
280
|
+
# replace first argument with selected attribute
|
|
281
|
+
if operator.attr:
|
|
282
|
+
call.args[0] = ast.Attribute(value=call.args[0], attr=operator.attr)
|
|
271
283
|
|
|
272
284
|
def _process_integrand_call(
|
|
273
285
|
self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
|
|
@@ -456,6 +468,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
456
468
|
fields_var_name: str = "fields",
|
|
457
469
|
values_var_name: str = "values",
|
|
458
470
|
domain_var_name: str = "domain_arg",
|
|
471
|
+
domain_index_var_name: str = "domain_index_arg",
|
|
459
472
|
sample_var_name: str = "sample",
|
|
460
473
|
field_wrappers_attr: str = "_field_wrappers",
|
|
461
474
|
):
|
|
@@ -470,6 +483,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
470
483
|
self._fields_var_name = fields_var_name
|
|
471
484
|
self._values_var_name = values_var_name
|
|
472
485
|
self._domain_var_name = domain_var_name
|
|
486
|
+
self._domain_index_var_name = domain_index_var_name
|
|
473
487
|
self._sample_var_name = sample_var_name
|
|
474
488
|
|
|
475
489
|
self._field_wrappers_attr = field_wrappers_attr
|
|
@@ -485,8 +499,28 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
485
499
|
for name, field in fields.items():
|
|
486
500
|
if isinstance(field, FieldLike):
|
|
487
501
|
setattr(field_wrappers, name, field.ElementEvalArg)
|
|
502
|
+
elif isinstance(field, GeometryDomain):
|
|
503
|
+
setattr(field_wrappers, name, field.DomainArg)
|
|
488
504
|
setattr(integrand_func, self._field_wrappers_attr, field_wrappers)
|
|
489
505
|
|
|
506
|
+
def _emit_field_wrapper_call(self, field_name, *data_arguments):
|
|
507
|
+
return ast.Call(
|
|
508
|
+
func=ast.Attribute(
|
|
509
|
+
value=ast.Attribute(
|
|
510
|
+
value=ast.Name(id=self._func_name, ctx=ast.Load()),
|
|
511
|
+
attr=self._field_wrappers_attr,
|
|
512
|
+
ctx=ast.Load(),
|
|
513
|
+
),
|
|
514
|
+
attr=field_name,
|
|
515
|
+
ctx=ast.Load(),
|
|
516
|
+
),
|
|
517
|
+
args=[
|
|
518
|
+
ast.Name(id=self._domain_var_name, ctx=ast.Load()),
|
|
519
|
+
*data_arguments,
|
|
520
|
+
],
|
|
521
|
+
keywords=[],
|
|
522
|
+
)
|
|
523
|
+
|
|
490
524
|
def visit_Call(self, call: ast.Call):
|
|
491
525
|
call = self.generic_visit(call)
|
|
492
526
|
|
|
@@ -498,33 +532,25 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
498
532
|
for arg in self._arg_names:
|
|
499
533
|
if arg == self._domain_name:
|
|
500
534
|
call.args.append(
|
|
501
|
-
|
|
535
|
+
self._emit_field_wrapper_call(
|
|
536
|
+
arg,
|
|
537
|
+
ast.Name(id=self._domain_index_var_name, ctx=ast.Load()),
|
|
538
|
+
)
|
|
502
539
|
)
|
|
540
|
+
|
|
503
541
|
elif arg == self._sample_name:
|
|
504
542
|
call.args.append(
|
|
505
543
|
ast.Name(id=self._sample_var_name, ctx=ast.Load()),
|
|
506
544
|
)
|
|
507
545
|
elif arg in self._field_args:
|
|
508
546
|
call.args.append(
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
attr=self._field_wrappers_attr,
|
|
514
|
-
ctx=ast.Load(),
|
|
515
|
-
),
|
|
547
|
+
self._emit_field_wrapper_call(
|
|
548
|
+
arg,
|
|
549
|
+
ast.Attribute(
|
|
550
|
+
value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
|
|
516
551
|
attr=arg,
|
|
517
552
|
ctx=ast.Load(),
|
|
518
553
|
),
|
|
519
|
-
args=[
|
|
520
|
-
ast.Name(id=self._domain_var_name, ctx=ast.Load()),
|
|
521
|
-
ast.Attribute(
|
|
522
|
-
value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
|
|
523
|
-
attr=arg,
|
|
524
|
-
ctx=ast.Load(),
|
|
525
|
-
),
|
|
526
|
-
],
|
|
527
|
-
keywords=[],
|
|
528
554
|
)
|
|
529
555
|
)
|
|
530
556
|
elif arg in self._value_args:
|
|
@@ -704,7 +730,7 @@ def get_integrate_linear_nodal_kernel(
|
|
|
704
730
|
|
|
705
731
|
coords = test.space.node_coords_in_element(
|
|
706
732
|
domain_arg,
|
|
707
|
-
_get_test_arg(),
|
|
733
|
+
_get_test_arg().space_arg,
|
|
708
734
|
element_index,
|
|
709
735
|
node_element_index.node_index_in_element,
|
|
710
736
|
)
|
|
@@ -712,7 +738,7 @@ def get_integrate_linear_nodal_kernel(
|
|
|
712
738
|
if coords[0] != OUTSIDE:
|
|
713
739
|
node_weight = test.space.node_quadrature_weight(
|
|
714
740
|
domain_arg,
|
|
715
|
-
_get_test_arg(),
|
|
741
|
+
_get_test_arg().space_arg,
|
|
716
742
|
element_index,
|
|
717
743
|
node_element_index.node_index_in_element,
|
|
718
744
|
)
|
|
@@ -913,7 +939,7 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
913
939
|
|
|
914
940
|
coords = test.space.node_coords_in_element(
|
|
915
941
|
domain_arg,
|
|
916
|
-
_get_test_arg(),
|
|
942
|
+
_get_test_arg().space_arg,
|
|
917
943
|
element_index,
|
|
918
944
|
node_element_index.node_index_in_element,
|
|
919
945
|
)
|
|
@@ -921,7 +947,7 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
921
947
|
if coords[0] != OUTSIDE:
|
|
922
948
|
node_weight = test.space.node_quadrature_weight(
|
|
923
949
|
domain_arg,
|
|
924
|
-
_get_test_arg(),
|
|
950
|
+
_get_test_arg().space_arg,
|
|
925
951
|
element_index,
|
|
926
952
|
node_element_index.node_index_in_element,
|
|
927
953
|
)
|
|
@@ -1153,7 +1179,7 @@ def _launch_integrate_kernel(
|
|
|
1153
1179
|
field_arg_values = FieldStruct()
|
|
1154
1180
|
for k, v in fields.items():
|
|
1155
1181
|
if not isinstance(v, GeometryDomain):
|
|
1156
|
-
|
|
1182
|
+
v.fill_eval_arg(getattr(field_arg_values, k), device=device)
|
|
1157
1183
|
|
|
1158
1184
|
value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
|
|
1159
1185
|
|
|
@@ -1203,7 +1229,7 @@ def _launch_integrate_kernel(
|
|
|
1203
1229
|
array_cast(in_array=accumulate_array, out_array=output)
|
|
1204
1230
|
return output
|
|
1205
1231
|
|
|
1206
|
-
test_arg = test.space_restriction.
|
|
1232
|
+
test_arg = test.space_restriction.node_arg_value(device=device)
|
|
1207
1233
|
nodal = quadrature is None
|
|
1208
1234
|
|
|
1209
1235
|
# Linear form
|
|
@@ -1211,9 +1237,9 @@ def _launch_integrate_kernel(
|
|
|
1211
1237
|
# If an output array is provided with the correct type, accumulate directly into it
|
|
1212
1238
|
# Otherwise, grab a temporary array
|
|
1213
1239
|
if output is None:
|
|
1214
|
-
if
|
|
1240
|
+
if type_size(output_dtype) == test.node_dof_count:
|
|
1215
1241
|
output_shape = (test.space_partition.node_count(),)
|
|
1216
|
-
elif
|
|
1242
|
+
elif type_size(output_dtype) == 1:
|
|
1217
1243
|
output_shape = (test.space_partition.node_count(), test.node_dof_count)
|
|
1218
1244
|
else:
|
|
1219
1245
|
raise RuntimeError(
|
|
@@ -1236,8 +1262,8 @@ def _launch_integrate_kernel(
|
|
|
1236
1262
|
raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
|
|
1237
1263
|
|
|
1238
1264
|
output_dtype = output.dtype
|
|
1239
|
-
if
|
|
1240
|
-
if
|
|
1265
|
+
if type_size(output_dtype) != test.node_dof_count:
|
|
1266
|
+
if type_size(output_dtype) != 1:
|
|
1241
1267
|
raise RuntimeError(
|
|
1242
1268
|
f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
|
|
1243
1269
|
)
|
|
@@ -1302,21 +1328,28 @@ def _launch_integrate_kernel(
|
|
|
1302
1328
|
device=device,
|
|
1303
1329
|
)
|
|
1304
1330
|
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1310
|
-
|
|
1311
|
-
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
|
|
1319
|
-
|
|
1331
|
+
if test.TAYLOR_DOF_COUNT == 0:
|
|
1332
|
+
wp.utils.warn(
|
|
1333
|
+
f"Test field is never evaluated in integrand '{integrand.name}', result will be zero",
|
|
1334
|
+
category=UserWarning,
|
|
1335
|
+
stacklevel=2,
|
|
1336
|
+
)
|
|
1337
|
+
else:
|
|
1338
|
+
dispatch_kernel = make_linear_dispatch_kernel(test, quadrature, accumulate_dtype)
|
|
1339
|
+
wp.launch(
|
|
1340
|
+
kernel=dispatch_kernel,
|
|
1341
|
+
dim=(test.space_restriction.node_count(), test.node_dof_count),
|
|
1342
|
+
inputs=[
|
|
1343
|
+
qp_arg,
|
|
1344
|
+
domain_elt_arg,
|
|
1345
|
+
domain_elt_index_arg,
|
|
1346
|
+
test_arg,
|
|
1347
|
+
test.space.space_arg_value(device),
|
|
1348
|
+
local_result.array,
|
|
1349
|
+
output_view,
|
|
1350
|
+
],
|
|
1351
|
+
device=device,
|
|
1352
|
+
)
|
|
1320
1353
|
|
|
1321
1354
|
local_result.release()
|
|
1322
1355
|
|
|
@@ -1433,34 +1466,42 @@ def _launch_integrate_kernel(
|
|
|
1433
1466
|
dtype=vec_array_dtype,
|
|
1434
1467
|
)
|
|
1435
1468
|
|
|
1436
|
-
|
|
1469
|
+
if test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT == 0:
|
|
1470
|
+
wp.utils.warn(
|
|
1471
|
+
f"Test and/or trial fields are never evaluated in integrand '{integrand.name}', result will be zero",
|
|
1472
|
+
category=UserWarning,
|
|
1473
|
+
stacklevel=2,
|
|
1474
|
+
)
|
|
1475
|
+
triplet_rows.fill_(-1)
|
|
1476
|
+
else:
|
|
1477
|
+
dispatch_kernel = make_bilinear_dispatch_kernel(test, trial, quadrature, accumulate_dtype)
|
|
1437
1478
|
|
|
1438
|
-
|
|
1439
|
-
|
|
1440
|
-
|
|
1441
|
-
|
|
1442
|
-
|
|
1443
|
-
|
|
1444
|
-
|
|
1445
|
-
|
|
1446
|
-
|
|
1447
|
-
|
|
1448
|
-
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
|
|
1452
|
-
|
|
1453
|
-
|
|
1454
|
-
|
|
1455
|
-
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
|
|
1479
|
+
trial_partition_arg = trial.space_partition.partition_arg_value(device)
|
|
1480
|
+
trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
|
|
1481
|
+
wp.launch(
|
|
1482
|
+
kernel=dispatch_kernel,
|
|
1483
|
+
dim=(
|
|
1484
|
+
test.space_restriction.node_count(),
|
|
1485
|
+
test.node_dof_count,
|
|
1486
|
+
trial.node_dof_count,
|
|
1487
|
+
trial.space.topology.MAX_NODES_PER_ELEMENT,
|
|
1488
|
+
),
|
|
1489
|
+
inputs=[
|
|
1490
|
+
qp_arg,
|
|
1491
|
+
domain_elt_arg,
|
|
1492
|
+
domain_elt_index_arg,
|
|
1493
|
+
test_arg,
|
|
1494
|
+
test.space.space_arg_value(device),
|
|
1495
|
+
trial_partition_arg,
|
|
1496
|
+
trial_topology_arg,
|
|
1497
|
+
trial.space.space_arg_value(device),
|
|
1498
|
+
local_result_as_vec,
|
|
1499
|
+
triplet_rows,
|
|
1500
|
+
triplet_cols,
|
|
1501
|
+
triplet_values,
|
|
1502
|
+
],
|
|
1503
|
+
device=device,
|
|
1504
|
+
)
|
|
1464
1505
|
|
|
1465
1506
|
local_result.release()
|
|
1466
1507
|
|
|
@@ -1529,21 +1570,30 @@ def _pick_assembly_strategy(
|
|
|
1529
1570
|
if assembly not in ("generic", "nodal", "dispatch"):
|
|
1530
1571
|
raise ValueError(f"Invalid assembly strategy'{assembly}'")
|
|
1531
1572
|
return assembly
|
|
1532
|
-
elif nodal:
|
|
1533
|
-
|
|
1573
|
+
elif nodal is not None:
|
|
1574
|
+
wp.utils.warn(
|
|
1575
|
+
"'nodal' argument of `warp.fem.integrate` is deprecated and will be removed in a future version. Please use `assembly='nodal'` instead.",
|
|
1576
|
+
category=DeprecationWarning,
|
|
1577
|
+
stacklevel=2,
|
|
1578
|
+
)
|
|
1579
|
+
if nodal:
|
|
1580
|
+
return "nodal"
|
|
1534
1581
|
|
|
1535
|
-
test_operators = operators.get(arguments.test_name,
|
|
1536
|
-
trial_operators = operators.get(arguments.trial_name,
|
|
1537
|
-
uses_at_node = at_node in test_operators or at_node in trial_operators
|
|
1582
|
+
test_operators = operators.get(arguments.test_name, set())
|
|
1583
|
+
trial_operators = operators.get(arguments.trial_name, set())
|
|
1538
1584
|
|
|
1539
|
-
|
|
1585
|
+
uses_virtual_node_operator = {operator.at_node, operator.node_count, operator.node_index} & (
|
|
1586
|
+
test_operators | trial_operators
|
|
1587
|
+
)
|
|
1588
|
+
|
|
1589
|
+
return "generic" if uses_virtual_node_operator else "dispatch"
|
|
1540
1590
|
|
|
1541
1591
|
|
|
1542
1592
|
def integrate(
|
|
1543
1593
|
integrand: Integrand,
|
|
1544
1594
|
domain: Optional[GeometryDomain] = None,
|
|
1545
1595
|
quadrature: Optional[Quadrature] = None,
|
|
1546
|
-
nodal: bool =
|
|
1596
|
+
nodal: Optional[bool] = None,
|
|
1547
1597
|
fields: Optional[Dict[str, FieldLike]] = None,
|
|
1548
1598
|
values: Optional[Dict[str, Any]] = None,
|
|
1549
1599
|
accumulate_dtype: type = wp.float64,
|
|
@@ -1575,7 +1625,7 @@ def integrate(
|
|
|
1575
1625
|
assembly: Specifies the strategy for assembling the integrated vector or matrix:
|
|
1576
1626
|
- "nodal": For linear or bilinear forms, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
|
|
1577
1627
|
- "generic": Single-pass integration and shape-function evaluation. Makes no assumption about the integrand's content, but may lead to many redundant computations.
|
|
1578
|
-
- "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node`
|
|
1628
|
+
- "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` and `node_index` operators on test or trial functions.
|
|
1579
1629
|
- `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
|
|
1580
1630
|
add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
|
|
1581
1631
|
bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
|
|
@@ -1622,6 +1672,9 @@ def integrate(
|
|
|
1622
1672
|
|
|
1623
1673
|
_find_integrand_operators(integrand, arguments.field_args)
|
|
1624
1674
|
|
|
1675
|
+
if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
|
|
1676
|
+
wp.utils.warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
|
|
1677
|
+
|
|
1625
1678
|
assembly = _pick_assembly_strategy(assembly, nodal, arguments=arguments, operators=integrand.operators)
|
|
1626
1679
|
# print("assembly for ", integrand.name, ":", strategy)
|
|
1627
1680
|
|
|
@@ -1703,7 +1756,7 @@ def get_interpolate_to_field_function(
|
|
|
1703
1756
|
ValueStruct: wp.codegen.Struct,
|
|
1704
1757
|
dest: FieldRestriction,
|
|
1705
1758
|
):
|
|
1706
|
-
|
|
1759
|
+
zero_value = type_zero_element(dest.space.dtype)
|
|
1707
1760
|
|
|
1708
1761
|
def interpolate_to_field_fn(
|
|
1709
1762
|
local_node_index: int,
|
|
@@ -1724,7 +1777,7 @@ def get_interpolate_to_field_function(
|
|
|
1724
1777
|
# Volume-weighted average across elements
|
|
1725
1778
|
# Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces
|
|
1726
1779
|
|
|
1727
|
-
val_sum =
|
|
1780
|
+
val_sum = zero_value()
|
|
1728
1781
|
vol_sum = float(0.0)
|
|
1729
1782
|
|
|
1730
1783
|
for n in range(element_beg, element_end):
|
|
@@ -1969,6 +2022,7 @@ def get_interpolate_free_kernel(
|
|
|
1969
2022
|
def interpolate_free_nonvalued_kernel_fn(
|
|
1970
2023
|
dim: int,
|
|
1971
2024
|
domain_arg: domain.ElementArg,
|
|
2025
|
+
domain_index_arg: domain.ElementIndexArg,
|
|
1972
2026
|
fields: FieldStruct,
|
|
1973
2027
|
values: ValueStruct,
|
|
1974
2028
|
result: wp.array(dtype=float),
|
|
@@ -1987,6 +2041,7 @@ def get_interpolate_free_kernel(
|
|
|
1987
2041
|
def interpolate_free_kernel_fn(
|
|
1988
2042
|
dim: int,
|
|
1989
2043
|
domain_arg: domain.ElementArg,
|
|
2044
|
+
domain_index_arg: domain.ElementIndexArg,
|
|
1990
2045
|
fields: FieldStruct,
|
|
1991
2046
|
values: ValueStruct,
|
|
1992
2047
|
result: wp.array(dtype=value_type),
|
|
@@ -2143,12 +2198,12 @@ def _launch_interpolate_kernel(
|
|
|
2143
2198
|
field_arg_values = FieldStruct()
|
|
2144
2199
|
for k, v in fields.items():
|
|
2145
2200
|
if not isinstance(v, GeometryDomain):
|
|
2146
|
-
|
|
2201
|
+
v.fill_eval_arg(getattr(field_arg_values, k), device=device)
|
|
2147
2202
|
|
|
2148
2203
|
value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
|
|
2149
2204
|
|
|
2150
2205
|
if isinstance(dest, FieldRestriction):
|
|
2151
|
-
dest_node_arg = dest.space_restriction.
|
|
2206
|
+
dest_node_arg = dest.space_restriction.node_arg_value(device=device)
|
|
2152
2207
|
dest_eval_arg = dest.field.eval_arg_value(device=device)
|
|
2153
2208
|
|
|
2154
2209
|
wp.launch(
|
|
@@ -2167,33 +2222,49 @@ def _launch_interpolate_kernel(
|
|
|
2167
2222
|
return
|
|
2168
2223
|
|
|
2169
2224
|
if quadrature is None:
|
|
2225
|
+
if dest is not None and (not is_array(dest) or dest.shape[0] != dim):
|
|
2226
|
+
raise ValueError(f"dest must be a warp array with {dim} rows")
|
|
2227
|
+
|
|
2170
2228
|
wp.launch(
|
|
2171
2229
|
kernel=kernel,
|
|
2172
2230
|
dim=dim,
|
|
2173
|
-
inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
|
|
2231
|
+
inputs=[dim, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
|
|
2174
2232
|
device=device,
|
|
2175
2233
|
)
|
|
2176
2234
|
return
|
|
2177
2235
|
|
|
2178
2236
|
qp_arg = quadrature.arg_value(device)
|
|
2237
|
+
qp_eval_count = quadrature.evaluation_point_count()
|
|
2238
|
+
qp_index_count = quadrature.total_point_count()
|
|
2239
|
+
|
|
2240
|
+
if qp_eval_count != qp_index_count:
|
|
2241
|
+
wp.utils.warn(
|
|
2242
|
+
f"Quadrature used for interpolation of {integrand.name} has different number of evaluation and indexed points, this may lead to incorrect results",
|
|
2243
|
+
category=UserWarning,
|
|
2244
|
+
stacklevel=2,
|
|
2245
|
+
)
|
|
2246
|
+
|
|
2179
2247
|
qp_element_index_arg = quadrature.element_index_arg_value(device)
|
|
2180
2248
|
if trial is None:
|
|
2249
|
+
if dest is not None and (not is_array(dest) or dest.shape[0] != qp_index_count):
|
|
2250
|
+
raise ValueError(f"dest must be a warp array with {qp_index_count} rows")
|
|
2251
|
+
|
|
2181
2252
|
wp.launch(
|
|
2182
2253
|
kernel=kernel,
|
|
2183
|
-
dim=
|
|
2254
|
+
dim=qp_eval_count,
|
|
2184
2255
|
inputs=[qp_arg, qp_element_index_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
|
|
2185
2256
|
device=device,
|
|
2186
2257
|
)
|
|
2187
2258
|
return
|
|
2188
2259
|
|
|
2189
|
-
nnz =
|
|
2260
|
+
nnz = qp_eval_count * trial.space.topology.MAX_NODES_PER_ELEMENT
|
|
2190
2261
|
|
|
2191
|
-
if dest.nrow !=
|
|
2262
|
+
if dest.nrow != qp_index_count or dest.ncol != trial.space_partition.node_count():
|
|
2192
2263
|
raise RuntimeError(
|
|
2193
|
-
f"'dest' matrix must have {
|
|
2264
|
+
f"'dest' matrix must have {qp_index_count} rows and {trial.space_partition.node_count()} columns of blocks"
|
|
2194
2265
|
)
|
|
2195
2266
|
if dest.block_shape[1] != trial.node_dof_count:
|
|
2196
|
-
raise f"'dest' matrix blocks must have {trial.node_dof_count} columns"
|
|
2267
|
+
raise RuntimeError(f"'dest' matrix blocks must have {trial.node_dof_count} columns")
|
|
2197
2268
|
|
|
2198
2269
|
triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
|
|
2199
2270
|
triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
|
|
@@ -2243,7 +2314,7 @@ def interpolate(
|
|
|
2243
2314
|
integrand: Union[Integrand, FieldLike],
|
|
2244
2315
|
dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
|
|
2245
2316
|
quadrature: Optional[Quadrature] = None,
|
|
2246
|
-
dim: int =
|
|
2317
|
+
dim: Optional[int] = None,
|
|
2247
2318
|
domain: Optional[Domain] = None,
|
|
2248
2319
|
fields: Optional[Dict[str, FieldLike]] = None,
|
|
2249
2320
|
values: Optional[Dict[str, Any]] = None,
|
|
@@ -2290,11 +2361,13 @@ def interpolate(
|
|
|
2290
2361
|
arguments = _parse_integrand_arguments(integrand, fields)
|
|
2291
2362
|
if arguments.test_name:
|
|
2292
2363
|
raise ValueError(f"Test field '{arguments.test_name}' maybe not be used for interpolation")
|
|
2293
|
-
if arguments.trial_name and
|
|
2364
|
+
if arguments.trial_name and not isinstance(dest, BsrMatrix):
|
|
2294
2365
|
raise ValueError(
|
|
2295
|
-
f"Interpolation using trial field '{arguments.trial_name}' requires '
|
|
2366
|
+
f"Interpolation using trial field '{arguments.trial_name}' requires 'dest' to be a `warp.sparse.BsrMatrix`"
|
|
2296
2367
|
)
|
|
2297
2368
|
|
|
2369
|
+
trial = arguments.field_args.get(arguments.trial_name, None)
|
|
2370
|
+
|
|
2298
2371
|
if isinstance(dest, DiscreteField):
|
|
2299
2372
|
dest = make_restriction(dest, domain=domain)
|
|
2300
2373
|
|
|
@@ -2302,12 +2375,25 @@ def interpolate(
|
|
|
2302
2375
|
domain = dest.domain
|
|
2303
2376
|
elif quadrature is not None:
|
|
2304
2377
|
domain = quadrature.domain
|
|
2378
|
+
elif dim is None:
|
|
2379
|
+
if trial is not None:
|
|
2380
|
+
domain = trial.domain
|
|
2381
|
+
elif domain is None:
|
|
2382
|
+
raise ValueError(
|
|
2383
|
+
"Unable to determine interpolation domain, provide an explicit field restriction or quadrature"
|
|
2384
|
+
)
|
|
2385
|
+
|
|
2386
|
+
# Default to one sample per domain element
|
|
2387
|
+
quadrature = RegularQuadrature(domain, order=0)
|
|
2305
2388
|
|
|
2306
2389
|
if arguments.domain_name:
|
|
2307
2390
|
arguments.field_args[arguments.domain_name] = domain
|
|
2308
2391
|
|
|
2309
2392
|
_find_integrand_operators(integrand, arguments.field_args)
|
|
2310
2393
|
|
|
2394
|
+
if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
|
|
2395
|
+
wp.utils.warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
|
|
2396
|
+
|
|
2311
2397
|
kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
|
|
2312
2398
|
integrand=integrand,
|
|
2313
2399
|
domain=domain,
|
|
@@ -2326,7 +2412,7 @@ def interpolate(
|
|
|
2326
2412
|
dest=dest,
|
|
2327
2413
|
quadrature=quadrature,
|
|
2328
2414
|
dim=dim,
|
|
2329
|
-
trial=
|
|
2415
|
+
trial=trial,
|
|
2330
2416
|
fields=arguments.field_args,
|
|
2331
2417
|
values=values,
|
|
2332
2418
|
temporary_store=temporary_store,
|