warp-lang 1.7.2__py3-none-manylinux_2_34_aarch64.whl → 1.8.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp.so +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +125 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +257 -101
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +657 -223
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +97 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +107 -52
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +12 -17
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +99 -0
- warp/native/builtin.h +174 -31
- warp/native/coloring.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +3 -3
- warp/native/mat.h +5 -10
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/quat.h +28 -4
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/tile.h +583 -72
- warp/native/tile_radix_sort.h +1108 -0
- warp/native/tile_reduce.h +237 -2
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +6 -16
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +574 -51
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +58 -29
- warp/render/render_usd.py +124 -61
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +252 -78
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +751 -320
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +52 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +15 -1
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_atomic_cas.py +299 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +1 -24
- warp/tests/test_quat.py +6 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +51 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/tile/test_tile.py +420 -1
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_reduce.py +213 -0
- warp/tests/tile/test_tile_shared_memory.py +130 -1
- warp/tests/tile/test_tile_sort.py +117 -0
- warp/tests/unittest_suites.py +4 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py
CHANGED
|
@@ -19,6 +19,7 @@ import textwrap
|
|
|
19
19
|
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
|
|
20
20
|
|
|
21
21
|
import warp as wp
|
|
22
|
+
import warp.fem.operator as operator
|
|
22
23
|
from warp.codegen import get_annotations
|
|
23
24
|
from warp.fem import cache
|
|
24
25
|
from warp.fem.domain import GeometryDomain
|
|
@@ -35,7 +36,11 @@ from warp.fem.field import (
|
|
|
35
36
|
)
|
|
36
37
|
from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
|
|
37
38
|
from warp.fem.linalg import array_axpy, basis_coefficient
|
|
38
|
-
from warp.fem.operator import
|
|
39
|
+
from warp.fem.operator import (
|
|
40
|
+
Integrand,
|
|
41
|
+
Operator,
|
|
42
|
+
integrand,
|
|
43
|
+
)
|
|
39
44
|
from warp.fem.quadrature import Quadrature, RegularQuadrature
|
|
40
45
|
from warp.fem.types import (
|
|
41
46
|
NULL_DOF_INDEX,
|
|
@@ -49,8 +54,9 @@ from warp.fem.types import (
|
|
|
49
54
|
Sample,
|
|
50
55
|
make_free_sample,
|
|
51
56
|
)
|
|
57
|
+
from warp.fem.utils import type_zero_element
|
|
52
58
|
from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
|
|
53
|
-
from warp.types import
|
|
59
|
+
from warp.types import type_size
|
|
54
60
|
from warp.utils import array_cast
|
|
55
61
|
|
|
56
62
|
|
|
@@ -111,6 +117,8 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
111
117
|
def get_concrete_type(field: Union[FieldLike, Domain]):
|
|
112
118
|
if isinstance(field, FieldLike):
|
|
113
119
|
return field.ElementEvalArg
|
|
120
|
+
elif isinstance(field, GeometryDomain):
|
|
121
|
+
return field.DomainArg
|
|
114
122
|
return field.ElementArg
|
|
115
123
|
|
|
116
124
|
return {
|
|
@@ -232,7 +240,7 @@ class IntegrandOperatorParser(IntegrandVisitor):
|
|
|
232
240
|
|
|
233
241
|
@staticmethod
|
|
234
242
|
def apply(
|
|
235
|
-
integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Callable = None
|
|
243
|
+
integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Optional[Callable] = None
|
|
236
244
|
) -> wp.Function:
|
|
237
245
|
field_info = IntegrandVisitor._build_field_info(integrand, field_args)
|
|
238
246
|
IntegrandOperatorParser(integrand, field_info, callback=operator_callback)._apply()
|
|
@@ -267,7 +275,11 @@ class IntegrandTransformer(IntegrandVisitor):
|
|
|
267
275
|
setattr(field_info.concrete_type, pointer.key, pointer)
|
|
268
276
|
|
|
269
277
|
# also insert callee as first argument
|
|
270
|
-
call.args = [ast.Name(id=callee, ctx=ast.Load())
|
|
278
|
+
call.args = [ast.Name(id=callee, ctx=ast.Load()), *call.args]
|
|
279
|
+
|
|
280
|
+
# replace first argument with selected attribute
|
|
281
|
+
if operator.attr:
|
|
282
|
+
call.args[0] = ast.Attribute(value=call.args[0], attr=operator.attr)
|
|
271
283
|
|
|
272
284
|
def _process_integrand_call(
|
|
273
285
|
self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
|
|
@@ -456,6 +468,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
456
468
|
fields_var_name: str = "fields",
|
|
457
469
|
values_var_name: str = "values",
|
|
458
470
|
domain_var_name: str = "domain_arg",
|
|
471
|
+
domain_index_var_name: str = "domain_index_arg",
|
|
459
472
|
sample_var_name: str = "sample",
|
|
460
473
|
field_wrappers_attr: str = "_field_wrappers",
|
|
461
474
|
):
|
|
@@ -470,6 +483,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
470
483
|
self._fields_var_name = fields_var_name
|
|
471
484
|
self._values_var_name = values_var_name
|
|
472
485
|
self._domain_var_name = domain_var_name
|
|
486
|
+
self._domain_index_var_name = domain_index_var_name
|
|
473
487
|
self._sample_var_name = sample_var_name
|
|
474
488
|
|
|
475
489
|
self._field_wrappers_attr = field_wrappers_attr
|
|
@@ -485,8 +499,28 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
485
499
|
for name, field in fields.items():
|
|
486
500
|
if isinstance(field, FieldLike):
|
|
487
501
|
setattr(field_wrappers, name, field.ElementEvalArg)
|
|
502
|
+
elif isinstance(field, GeometryDomain):
|
|
503
|
+
setattr(field_wrappers, name, field.DomainArg)
|
|
488
504
|
setattr(integrand_func, self._field_wrappers_attr, field_wrappers)
|
|
489
505
|
|
|
506
|
+
def _emit_field_wrapper_call(self, field_name, *data_arguments):
|
|
507
|
+
return ast.Call(
|
|
508
|
+
func=ast.Attribute(
|
|
509
|
+
value=ast.Attribute(
|
|
510
|
+
value=ast.Name(id=self._func_name, ctx=ast.Load()),
|
|
511
|
+
attr=self._field_wrappers_attr,
|
|
512
|
+
ctx=ast.Load(),
|
|
513
|
+
),
|
|
514
|
+
attr=field_name,
|
|
515
|
+
ctx=ast.Load(),
|
|
516
|
+
),
|
|
517
|
+
args=[
|
|
518
|
+
ast.Name(id=self._domain_var_name, ctx=ast.Load()),
|
|
519
|
+
*data_arguments,
|
|
520
|
+
],
|
|
521
|
+
keywords=[],
|
|
522
|
+
)
|
|
523
|
+
|
|
490
524
|
def visit_Call(self, call: ast.Call):
|
|
491
525
|
call = self.generic_visit(call)
|
|
492
526
|
|
|
@@ -498,33 +532,25 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
498
532
|
for arg in self._arg_names:
|
|
499
533
|
if arg == self._domain_name:
|
|
500
534
|
call.args.append(
|
|
501
|
-
|
|
535
|
+
self._emit_field_wrapper_call(
|
|
536
|
+
arg,
|
|
537
|
+
ast.Name(id=self._domain_index_var_name, ctx=ast.Load()),
|
|
538
|
+
)
|
|
502
539
|
)
|
|
540
|
+
|
|
503
541
|
elif arg == self._sample_name:
|
|
504
542
|
call.args.append(
|
|
505
543
|
ast.Name(id=self._sample_var_name, ctx=ast.Load()),
|
|
506
544
|
)
|
|
507
545
|
elif arg in self._field_args:
|
|
508
546
|
call.args.append(
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
attr=self._field_wrappers_attr,
|
|
514
|
-
ctx=ast.Load(),
|
|
515
|
-
),
|
|
547
|
+
self._emit_field_wrapper_call(
|
|
548
|
+
arg,
|
|
549
|
+
ast.Attribute(
|
|
550
|
+
value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
|
|
516
551
|
attr=arg,
|
|
517
552
|
ctx=ast.Load(),
|
|
518
553
|
),
|
|
519
|
-
args=[
|
|
520
|
-
ast.Name(id=self._domain_var_name, ctx=ast.Load()),
|
|
521
|
-
ast.Attribute(
|
|
522
|
-
value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
|
|
523
|
-
attr=arg,
|
|
524
|
-
ctx=ast.Load(),
|
|
525
|
-
),
|
|
526
|
-
],
|
|
527
|
-
keywords=[],
|
|
528
554
|
)
|
|
529
555
|
)
|
|
530
556
|
elif arg in self._value_args:
|
|
@@ -704,7 +730,7 @@ def get_integrate_linear_nodal_kernel(
|
|
|
704
730
|
|
|
705
731
|
coords = test.space.node_coords_in_element(
|
|
706
732
|
domain_arg,
|
|
707
|
-
_get_test_arg(),
|
|
733
|
+
_get_test_arg().space_arg,
|
|
708
734
|
element_index,
|
|
709
735
|
node_element_index.node_index_in_element,
|
|
710
736
|
)
|
|
@@ -712,7 +738,7 @@ def get_integrate_linear_nodal_kernel(
|
|
|
712
738
|
if coords[0] != OUTSIDE:
|
|
713
739
|
node_weight = test.space.node_quadrature_weight(
|
|
714
740
|
domain_arg,
|
|
715
|
-
_get_test_arg(),
|
|
741
|
+
_get_test_arg().space_arg,
|
|
716
742
|
element_index,
|
|
717
743
|
node_element_index.node_index_in_element,
|
|
718
744
|
)
|
|
@@ -913,7 +939,7 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
913
939
|
|
|
914
940
|
coords = test.space.node_coords_in_element(
|
|
915
941
|
domain_arg,
|
|
916
|
-
_get_test_arg(),
|
|
942
|
+
_get_test_arg().space_arg,
|
|
917
943
|
element_index,
|
|
918
944
|
node_element_index.node_index_in_element,
|
|
919
945
|
)
|
|
@@ -921,7 +947,7 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
921
947
|
if coords[0] != OUTSIDE:
|
|
922
948
|
node_weight = test.space.node_quadrature_weight(
|
|
923
949
|
domain_arg,
|
|
924
|
-
_get_test_arg(),
|
|
950
|
+
_get_test_arg().space_arg,
|
|
925
951
|
element_index,
|
|
926
952
|
node_element_index.node_index_in_element,
|
|
927
953
|
)
|
|
@@ -1153,7 +1179,7 @@ def _launch_integrate_kernel(
|
|
|
1153
1179
|
field_arg_values = FieldStruct()
|
|
1154
1180
|
for k, v in fields.items():
|
|
1155
1181
|
if not isinstance(v, GeometryDomain):
|
|
1156
|
-
|
|
1182
|
+
v.fill_eval_arg(getattr(field_arg_values, k), device=device)
|
|
1157
1183
|
|
|
1158
1184
|
value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
|
|
1159
1185
|
|
|
@@ -1203,7 +1229,7 @@ def _launch_integrate_kernel(
|
|
|
1203
1229
|
array_cast(in_array=accumulate_array, out_array=output)
|
|
1204
1230
|
return output
|
|
1205
1231
|
|
|
1206
|
-
test_arg = test.space_restriction.
|
|
1232
|
+
test_arg = test.space_restriction.node_arg_value(device=device)
|
|
1207
1233
|
nodal = quadrature is None
|
|
1208
1234
|
|
|
1209
1235
|
# Linear form
|
|
@@ -1211,9 +1237,9 @@ def _launch_integrate_kernel(
|
|
|
1211
1237
|
# If an output array is provided with the correct type, accumulate directly into it
|
|
1212
1238
|
# Otherwise, grab a temporary array
|
|
1213
1239
|
if output is None:
|
|
1214
|
-
if
|
|
1240
|
+
if type_size(output_dtype) == test.node_dof_count:
|
|
1215
1241
|
output_shape = (test.space_partition.node_count(),)
|
|
1216
|
-
elif
|
|
1242
|
+
elif type_size(output_dtype) == 1:
|
|
1217
1243
|
output_shape = (test.space_partition.node_count(), test.node_dof_count)
|
|
1218
1244
|
else:
|
|
1219
1245
|
raise RuntimeError(
|
|
@@ -1236,8 +1262,8 @@ def _launch_integrate_kernel(
|
|
|
1236
1262
|
raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
|
|
1237
1263
|
|
|
1238
1264
|
output_dtype = output.dtype
|
|
1239
|
-
if
|
|
1240
|
-
if
|
|
1265
|
+
if type_size(output_dtype) != test.node_dof_count:
|
|
1266
|
+
if type_size(output_dtype) != 1:
|
|
1241
1267
|
raise RuntimeError(
|
|
1242
1268
|
f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
|
|
1243
1269
|
)
|
|
@@ -1311,7 +1337,7 @@ def _launch_integrate_kernel(
|
|
|
1311
1337
|
domain_elt_arg,
|
|
1312
1338
|
domain_elt_index_arg,
|
|
1313
1339
|
test_arg,
|
|
1314
|
-
test.
|
|
1340
|
+
test.space.space_arg_value(device),
|
|
1315
1341
|
local_result.array,
|
|
1316
1342
|
output_view,
|
|
1317
1343
|
],
|
|
@@ -1450,10 +1476,10 @@ def _launch_integrate_kernel(
|
|
|
1450
1476
|
domain_elt_arg,
|
|
1451
1477
|
domain_elt_index_arg,
|
|
1452
1478
|
test_arg,
|
|
1453
|
-
test.
|
|
1479
|
+
test.space.space_arg_value(device),
|
|
1454
1480
|
trial_partition_arg,
|
|
1455
1481
|
trial_topology_arg,
|
|
1456
|
-
trial.
|
|
1482
|
+
trial.space.space_arg_value(device),
|
|
1457
1483
|
local_result_as_vec,
|
|
1458
1484
|
triplet_rows,
|
|
1459
1485
|
triplet_cols,
|
|
@@ -1529,21 +1555,30 @@ def _pick_assembly_strategy(
|
|
|
1529
1555
|
if assembly not in ("generic", "nodal", "dispatch"):
|
|
1530
1556
|
raise ValueError(f"Invalid assembly strategy'{assembly}'")
|
|
1531
1557
|
return assembly
|
|
1532
|
-
elif nodal:
|
|
1533
|
-
|
|
1558
|
+
elif nodal is not None:
|
|
1559
|
+
wp.utils.warn(
|
|
1560
|
+
"'nodal' argument of `warp.fem.integrate` is deprecated and will be removed in a future version. Please use `assembly='nodal'` instead.",
|
|
1561
|
+
category=DeprecationWarning,
|
|
1562
|
+
stacklevel=2,
|
|
1563
|
+
)
|
|
1564
|
+
if nodal:
|
|
1565
|
+
return "nodal"
|
|
1534
1566
|
|
|
1535
|
-
test_operators = operators.get(arguments.test_name,
|
|
1536
|
-
trial_operators = operators.get(arguments.trial_name,
|
|
1537
|
-
uses_at_node = at_node in test_operators or at_node in trial_operators
|
|
1567
|
+
test_operators = operators.get(arguments.test_name, set())
|
|
1568
|
+
trial_operators = operators.get(arguments.trial_name, set())
|
|
1538
1569
|
|
|
1539
|
-
|
|
1570
|
+
uses_virtual_node_operator = {operator.at_node, operator.node_count, operator.node_index} & (
|
|
1571
|
+
test_operators | trial_operators
|
|
1572
|
+
)
|
|
1573
|
+
|
|
1574
|
+
return "generic" if uses_virtual_node_operator else "dispatch"
|
|
1540
1575
|
|
|
1541
1576
|
|
|
1542
1577
|
def integrate(
|
|
1543
1578
|
integrand: Integrand,
|
|
1544
1579
|
domain: Optional[GeometryDomain] = None,
|
|
1545
1580
|
quadrature: Optional[Quadrature] = None,
|
|
1546
|
-
nodal: bool =
|
|
1581
|
+
nodal: Optional[bool] = None,
|
|
1547
1582
|
fields: Optional[Dict[str, FieldLike]] = None,
|
|
1548
1583
|
values: Optional[Dict[str, Any]] = None,
|
|
1549
1584
|
accumulate_dtype: type = wp.float64,
|
|
@@ -1575,7 +1610,7 @@ def integrate(
|
|
|
1575
1610
|
assembly: Specifies the strategy for assembling the integrated vector or matrix:
|
|
1576
1611
|
- "nodal": For linear or bilinear forms, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
|
|
1577
1612
|
- "generic": Single-pass integration and shape-function evaluation. Makes no assumption about the integrand's content, but may lead to many redundant computations.
|
|
1578
|
-
- "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node`
|
|
1613
|
+
- "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` and `node_index` operators on test or trial functions.
|
|
1579
1614
|
- `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
|
|
1580
1615
|
add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
|
|
1581
1616
|
bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
|
|
@@ -1622,6 +1657,9 @@ def integrate(
|
|
|
1622
1657
|
|
|
1623
1658
|
_find_integrand_operators(integrand, arguments.field_args)
|
|
1624
1659
|
|
|
1660
|
+
if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
|
|
1661
|
+
wp.utils.warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
|
|
1662
|
+
|
|
1625
1663
|
assembly = _pick_assembly_strategy(assembly, nodal, arguments=arguments, operators=integrand.operators)
|
|
1626
1664
|
# print("assembly for ", integrand.name, ":", strategy)
|
|
1627
1665
|
|
|
@@ -1703,7 +1741,7 @@ def get_interpolate_to_field_function(
|
|
|
1703
1741
|
ValueStruct: wp.codegen.Struct,
|
|
1704
1742
|
dest: FieldRestriction,
|
|
1705
1743
|
):
|
|
1706
|
-
|
|
1744
|
+
zero_value = type_zero_element(dest.space.dtype)
|
|
1707
1745
|
|
|
1708
1746
|
def interpolate_to_field_fn(
|
|
1709
1747
|
local_node_index: int,
|
|
@@ -1724,7 +1762,7 @@ def get_interpolate_to_field_function(
|
|
|
1724
1762
|
# Volume-weighted average across elements
|
|
1725
1763
|
# Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces
|
|
1726
1764
|
|
|
1727
|
-
val_sum =
|
|
1765
|
+
val_sum = zero_value()
|
|
1728
1766
|
vol_sum = float(0.0)
|
|
1729
1767
|
|
|
1730
1768
|
for n in range(element_beg, element_end):
|
|
@@ -1969,6 +2007,7 @@ def get_interpolate_free_kernel(
|
|
|
1969
2007
|
def interpolate_free_nonvalued_kernel_fn(
|
|
1970
2008
|
dim: int,
|
|
1971
2009
|
domain_arg: domain.ElementArg,
|
|
2010
|
+
domain_index_arg: domain.ElementIndexArg,
|
|
1972
2011
|
fields: FieldStruct,
|
|
1973
2012
|
values: ValueStruct,
|
|
1974
2013
|
result: wp.array(dtype=float),
|
|
@@ -1987,6 +2026,7 @@ def get_interpolate_free_kernel(
|
|
|
1987
2026
|
def interpolate_free_kernel_fn(
|
|
1988
2027
|
dim: int,
|
|
1989
2028
|
domain_arg: domain.ElementArg,
|
|
2029
|
+
domain_index_arg: domain.ElementIndexArg,
|
|
1990
2030
|
fields: FieldStruct,
|
|
1991
2031
|
values: ValueStruct,
|
|
1992
2032
|
result: wp.array(dtype=value_type),
|
|
@@ -2143,12 +2183,12 @@ def _launch_interpolate_kernel(
|
|
|
2143
2183
|
field_arg_values = FieldStruct()
|
|
2144
2184
|
for k, v in fields.items():
|
|
2145
2185
|
if not isinstance(v, GeometryDomain):
|
|
2146
|
-
|
|
2186
|
+
v.fill_eval_arg(getattr(field_arg_values, k), device=device)
|
|
2147
2187
|
|
|
2148
2188
|
value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
|
|
2149
2189
|
|
|
2150
2190
|
if isinstance(dest, FieldRestriction):
|
|
2151
|
-
dest_node_arg = dest.space_restriction.
|
|
2191
|
+
dest_node_arg = dest.space_restriction.node_arg_value(device=device)
|
|
2152
2192
|
dest_eval_arg = dest.field.eval_arg_value(device=device)
|
|
2153
2193
|
|
|
2154
2194
|
wp.launch(
|
|
@@ -2170,7 +2210,7 @@ def _launch_interpolate_kernel(
|
|
|
2170
2210
|
wp.launch(
|
|
2171
2211
|
kernel=kernel,
|
|
2172
2212
|
dim=dim,
|
|
2173
|
-
inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
|
|
2213
|
+
inputs=[dim, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
|
|
2174
2214
|
device=device,
|
|
2175
2215
|
)
|
|
2176
2216
|
return
|
|
@@ -2193,7 +2233,7 @@ def _launch_interpolate_kernel(
|
|
|
2193
2233
|
f"'dest' matrix must have {quadrature.total_point_count()} rows and {trial.space_partition.node_count()} columns of blocks"
|
|
2194
2234
|
)
|
|
2195
2235
|
if dest.block_shape[1] != trial.node_dof_count:
|
|
2196
|
-
raise f"'dest' matrix blocks must have {trial.node_dof_count} columns"
|
|
2236
|
+
raise RuntimeError(f"'dest' matrix blocks must have {trial.node_dof_count} columns")
|
|
2197
2237
|
|
|
2198
2238
|
triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
|
|
2199
2239
|
triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
|
|
@@ -2243,7 +2283,7 @@ def interpolate(
|
|
|
2243
2283
|
integrand: Union[Integrand, FieldLike],
|
|
2244
2284
|
dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
|
|
2245
2285
|
quadrature: Optional[Quadrature] = None,
|
|
2246
|
-
dim: int =
|
|
2286
|
+
dim: Optional[int] = None,
|
|
2247
2287
|
domain: Optional[Domain] = None,
|
|
2248
2288
|
fields: Optional[Dict[str, FieldLike]] = None,
|
|
2249
2289
|
values: Optional[Dict[str, Any]] = None,
|
|
@@ -2290,11 +2330,13 @@ def interpolate(
|
|
|
2290
2330
|
arguments = _parse_integrand_arguments(integrand, fields)
|
|
2291
2331
|
if arguments.test_name:
|
|
2292
2332
|
raise ValueError(f"Test field '{arguments.test_name}' maybe not be used for interpolation")
|
|
2293
|
-
if arguments.trial_name and
|
|
2333
|
+
if arguments.trial_name and not isinstance(dest, BsrMatrix):
|
|
2294
2334
|
raise ValueError(
|
|
2295
|
-
f"Interpolation using trial field '{arguments.trial_name}' requires '
|
|
2335
|
+
f"Interpolation using trial field '{arguments.trial_name}' requires 'dest' to be a `warp.sparse.BsrMatrix`"
|
|
2296
2336
|
)
|
|
2297
2337
|
|
|
2338
|
+
trial = arguments.field_args.get(arguments.trial_name, None)
|
|
2339
|
+
|
|
2298
2340
|
if isinstance(dest, DiscreteField):
|
|
2299
2341
|
dest = make_restriction(dest, domain=domain)
|
|
2300
2342
|
|
|
@@ -2302,12 +2344,25 @@ def interpolate(
|
|
|
2302
2344
|
domain = dest.domain
|
|
2303
2345
|
elif quadrature is not None:
|
|
2304
2346
|
domain = quadrature.domain
|
|
2347
|
+
elif dim is None:
|
|
2348
|
+
if trial is not None:
|
|
2349
|
+
domain = trial.domain
|
|
2350
|
+
elif domain is None:
|
|
2351
|
+
raise ValueError(
|
|
2352
|
+
"Unable to determine interpolation domain, provide an explicit field restriction or quadrature"
|
|
2353
|
+
)
|
|
2354
|
+
|
|
2355
|
+
# Default to one sample per domain element
|
|
2356
|
+
quadrature = RegularQuadrature(domain, order=0)
|
|
2305
2357
|
|
|
2306
2358
|
if arguments.domain_name:
|
|
2307
2359
|
arguments.field_args[arguments.domain_name] = domain
|
|
2308
2360
|
|
|
2309
2361
|
_find_integrand_operators(integrand, arguments.field_args)
|
|
2310
2362
|
|
|
2363
|
+
if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
|
|
2364
|
+
wp.utils.warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
|
|
2365
|
+
|
|
2311
2366
|
kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
|
|
2312
2367
|
integrand=integrand,
|
|
2313
2368
|
domain=domain,
|
|
@@ -2326,7 +2381,7 @@ def interpolate(
|
|
|
2326
2381
|
dest=dest,
|
|
2327
2382
|
quadrature=quadrature,
|
|
2328
2383
|
dim=dim,
|
|
2329
|
-
trial=
|
|
2384
|
+
trial=trial,
|
|
2330
2385
|
fields=arguments.field_args,
|
|
2331
2386
|
values=values,
|
|
2332
2387
|
temporary_store=temporary_store,
|
warp/fem/linalg.py
CHANGED
|
@@ -16,80 +16,62 @@
|
|
|
16
16
|
from typing import Any
|
|
17
17
|
|
|
18
18
|
import warp as wp
|
|
19
|
+
import warp.types
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
@wp.func
|
|
22
|
-
def generalized_outer(x: Any, y: Any):
|
|
23
|
-
"""Generalized outer product allowing for
|
|
23
|
+
def generalized_outer(x: wp.vec(Any, wp.Scalar), y: wp.vec(Any, wp.Scalar)):
|
|
24
|
+
"""Generalized outer product allowing for vector or scalar arguments"""
|
|
24
25
|
return wp.outer(x, y)
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
@wp.func
|
|
28
|
-
def generalized_outer(x: wp.
|
|
29
|
+
def generalized_outer(x: wp.Scalar, y: wp.vec(Any, wp.Scalar)):
|
|
29
30
|
return x * y
|
|
30
31
|
|
|
31
32
|
|
|
32
33
|
@wp.func
|
|
33
|
-
def generalized_outer(x: wp.
|
|
34
|
+
def generalized_outer(x: wp.vec(Any, wp.Scalar), y: wp.Scalar):
|
|
34
35
|
return x * y
|
|
35
36
|
|
|
36
37
|
|
|
37
38
|
@wp.func
|
|
38
|
-
def
|
|
39
|
-
|
|
40
|
-
return wp.dot(x, y)
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
@wp.func
|
|
44
|
-
def generalized_inner(x: float, y: float):
|
|
45
|
-
return x * y
|
|
39
|
+
def generalized_outer(x: wp.quatf, y: wp.vec(Any, wp.Scalar)):
|
|
40
|
+
return generalized_outer(wp.vec4(x[0], x[1], x[2], x[3]), y)
|
|
46
41
|
|
|
47
42
|
|
|
48
43
|
@wp.func
|
|
49
|
-
def generalized_inner(x: wp.
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
@wp.func
|
|
54
|
-
def generalized_inner(x: wp.mat33, y: wp.vec3):
|
|
55
|
-
return x[0] * y[0] + x[1] * y[1] + x[2] * y[2]
|
|
44
|
+
def generalized_inner(x: wp.vec(Any, wp.Scalar), y: wp.vec(Any, wp.Scalar)):
|
|
45
|
+
"""Generalized inner product allowing for vector, tensor and scalar arguments"""
|
|
46
|
+
return wp.dot(x, y)
|
|
56
47
|
|
|
57
48
|
|
|
58
49
|
@wp.func
|
|
59
|
-
def
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
t = type(template_type)(0.0)
|
|
63
|
-
t[coord] = 1.0
|
|
64
|
-
return t
|
|
50
|
+
def generalized_inner(x: wp.Scalar, y: wp.Scalar):
|
|
51
|
+
return x * y
|
|
65
52
|
|
|
66
53
|
|
|
67
54
|
@wp.func
|
|
68
|
-
def
|
|
69
|
-
return
|
|
55
|
+
def generalized_inner(x: wp.mat((Any, Any), wp.Scalar), y: wp.vec(Any, wp.Scalar)):
|
|
56
|
+
return y @ x
|
|
70
57
|
|
|
71
58
|
|
|
72
59
|
@wp.func
|
|
73
|
-
def
|
|
74
|
-
|
|
75
|
-
row = coord // 2
|
|
76
|
-
col = coord - 2 * row
|
|
77
|
-
t[row, col] = 1.0
|
|
78
|
-
return t
|
|
60
|
+
def generalized_inner(x: wp.vec(Any, wp.Scalar), y: wp.mat((Any, Any), wp.Scalar)):
|
|
61
|
+
return y @ x
|
|
79
62
|
|
|
80
63
|
|
|
81
64
|
@wp.func
|
|
82
|
-
def
|
|
83
|
-
|
|
84
|
-
row = coord // 3
|
|
85
|
-
col = coord - 3 * row
|
|
86
|
-
t[row, col] = 1.0
|
|
87
|
-
return t
|
|
65
|
+
def basis_coefficient(val: wp.Scalar, i: int):
|
|
66
|
+
return val
|
|
88
67
|
|
|
89
68
|
|
|
90
69
|
@wp.func
|
|
91
|
-
def basis_coefficient(val: wp.
|
|
92
|
-
|
|
70
|
+
def basis_coefficient(val: wp.mat((Any, Any), wp.Scalar), i: int):
|
|
71
|
+
cols = int(type(val[0]).length)
|
|
72
|
+
row = i // cols
|
|
73
|
+
col = i - row * cols
|
|
74
|
+
return val[row, col]
|
|
93
75
|
|
|
94
76
|
|
|
95
77
|
@wp.func
|
|
@@ -98,31 +80,16 @@ def basis_coefficient(val: Any, i: int):
|
|
|
98
80
|
|
|
99
81
|
|
|
100
82
|
@wp.func
|
|
101
|
-
def basis_coefficient(val: wp.
|
|
102
|
-
# treat as row vector
|
|
103
|
-
return val[j]
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
@wp.func
|
|
107
|
-
def basis_coefficient(val: wp.vec3, i: int, j: int):
|
|
83
|
+
def basis_coefficient(val: wp.vec(Any, wp.Scalar), i: int, j: int):
|
|
108
84
|
# treat as row vector
|
|
109
85
|
return val[j]
|
|
110
86
|
|
|
111
87
|
|
|
112
88
|
@wp.func
|
|
113
|
-
def basis_coefficient(val: Any, i: int, j: int):
|
|
89
|
+
def basis_coefficient(val: wp.mat((Any, Any), wp.Scalar), i: int, j: int):
|
|
114
90
|
return val[i, j]
|
|
115
91
|
|
|
116
92
|
|
|
117
|
-
@wp.func
|
|
118
|
-
def basis_coefficient(template_type: wp.mat33, coord: int):
|
|
119
|
-
t = wp.mat33(0.0)
|
|
120
|
-
row = coord // 3
|
|
121
|
-
col = coord - 3 * row
|
|
122
|
-
t[row, col] = 1.0
|
|
123
|
-
return t
|
|
124
|
-
|
|
125
|
-
|
|
126
93
|
@wp.func
|
|
127
94
|
def symmetric_part(x: Any):
|
|
128
95
|
"""Symmetric part of a square tensor"""
|