warp-lang 1.8.0__py3-none-win_amd64.whl → 1.9.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +282 -103
- warp/__init__.pyi +482 -110
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +48 -63
- warp/builtins.py +955 -137
- warp/codegen.py +327 -209
- warp/config.py +1 -1
- warp/context.py +1363 -800
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +266 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +200 -91
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +1 -1
- warp/jax_experimental/ffi.py +203 -54
- warp/marching_cubes.py +708 -0
- warp/native/array.h +103 -8
- warp/native/builtin.h +90 -9
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +13 -3
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +42 -11
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +4 -4
- warp/native/mat.h +1913 -119
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +5 -3
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +337 -16
- warp/native/rand.h +7 -7
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +22 -22
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +14 -14
- warp/native/spatial.h +366 -17
- warp/native/svd.h +23 -8
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +303 -70
- warp/native/tile_radix_sort.h +5 -1
- warp/native/tile_reduce.h +16 -25
- warp/native/tuple.h +2 -2
- warp/native/vec.h +385 -18
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +337 -193
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +137 -57
- warp/render/render_usd.py +0 -1
- warp/sim/collide.py +1 -2
- warp/sim/graph_coloring.py +2 -2
- warp/sim/integrator_vbd.py +10 -2
- warp/sparse.py +559 -176
- warp/tape.py +2 -0
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/sim/test_cloth.py +89 -6
- warp/tests/sim/test_coloring.py +82 -7
- warp/tests/test_array.py +56 -5
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +127 -114
- warp/tests/test_codegen.py +3 -2
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +45 -2
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +1 -1
- warp/tests/test_mat.py +1540 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +162 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +103 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_tape.py +38 -0
- warp/tests/test_types.py +0 -20
- warp/tests/test_vec.py +216 -441
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +206 -152
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_reduce.py +100 -11
- warp/tests/tile/test_tile_shared_memory.py +16 -16
- warp/tests/tile/test_tile_sort.py +59 -55
- warp/tests/unittest_suites.py +16 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +554 -264
- warp/utils.py +68 -86
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
import ast
|
|
17
17
|
import inspect
|
|
18
18
|
import textwrap
|
|
19
|
-
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
|
|
19
|
+
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
|
|
20
20
|
|
|
21
21
|
import warp as wp
|
|
22
22
|
import warp.fem.operator as operator
|
|
@@ -34,7 +34,10 @@ from warp.fem.field import (
|
|
|
34
34
|
TrialField,
|
|
35
35
|
make_restriction,
|
|
36
36
|
)
|
|
37
|
-
from warp.fem.field.virtual import
|
|
37
|
+
from warp.fem.field.virtual import (
|
|
38
|
+
make_bilinear_dispatch_kernel,
|
|
39
|
+
make_linear_dispatch_kernel,
|
|
40
|
+
)
|
|
38
41
|
from warp.fem.linalg import array_axpy, basis_coefficient
|
|
39
42
|
from warp.fem.operator import (
|
|
40
43
|
Integrand,
|
|
@@ -56,7 +59,7 @@ from warp.fem.types import (
|
|
|
56
59
|
)
|
|
57
60
|
from warp.fem.utils import type_zero_element
|
|
58
61
|
from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
|
|
59
|
-
from warp.types import type_size
|
|
62
|
+
from warp.types import is_array, type_size
|
|
60
63
|
from warp.utils import array_cast
|
|
61
64
|
|
|
62
65
|
|
|
@@ -101,7 +104,8 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
101
104
|
field: FieldLike
|
|
102
105
|
abstract_type: type
|
|
103
106
|
concrete_type: type
|
|
104
|
-
root_arg_name:
|
|
107
|
+
root_arg_name: str
|
|
108
|
+
local_arg_name: str
|
|
105
109
|
|
|
106
110
|
def __init__(
|
|
107
111
|
self,
|
|
@@ -111,6 +115,7 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
111
115
|
self._integrand = integrand
|
|
112
116
|
self._field_symbols = field_info.copy()
|
|
113
117
|
self._field_nodes = {}
|
|
118
|
+
self._field_arg_annotation_nodes = {}
|
|
114
119
|
|
|
115
120
|
@staticmethod
|
|
116
121
|
def _build_field_info(integrand: Integrand, field_args: Dict[str, FieldLike]):
|
|
@@ -127,6 +132,7 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
127
132
|
abstract_type=integrand.argspec.annotations[name],
|
|
128
133
|
concrete_type=get_concrete_type(field),
|
|
129
134
|
root_arg_name=name,
|
|
135
|
+
local_arg_name=name,
|
|
130
136
|
)
|
|
131
137
|
for name, field in field_args.items()
|
|
132
138
|
}
|
|
@@ -167,6 +173,7 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
167
173
|
field=res[0],
|
|
168
174
|
abstract_type=res[1],
|
|
169
175
|
concrete_type=res[2],
|
|
176
|
+
local_arg_name=field_info.local_arg_name,
|
|
170
177
|
root_arg_name=f"{field_info.root_arg_name}.{func.name}",
|
|
171
178
|
)
|
|
172
179
|
|
|
@@ -191,6 +198,13 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
191
198
|
|
|
192
199
|
return node
|
|
193
200
|
|
|
201
|
+
def visit_FunctionDef(self, node: ast.FunctionDef):
|
|
202
|
+
# record field arg annotation nodes
|
|
203
|
+
for arg in node.args.args:
|
|
204
|
+
self._field_arg_annotation_nodes[arg.arg] = arg.annotation
|
|
205
|
+
|
|
206
|
+
return self.generic_visit(node)
|
|
207
|
+
|
|
194
208
|
def _get_callee_field_args(self, callee: Integrand, args: List[ast.AST]):
|
|
195
209
|
# Get field types for call site arguments
|
|
196
210
|
call_site_field_args: List[IntegrandVisitor.FieldInfo] = []
|
|
@@ -211,7 +225,13 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
211
225
|
raise TypeError(
|
|
212
226
|
f"Attempting to pass a {passed_field_info.abstract_type.__name__} to argument '{arg}' of '{callee.name}' expecting a {arg_type.__name__}"
|
|
213
227
|
)
|
|
214
|
-
callee_field_args[arg] =
|
|
228
|
+
callee_field_args[arg] = IntegrandVisitor.FieldInfo(
|
|
229
|
+
field=passed_field_info.field,
|
|
230
|
+
abstract_type=passed_field_info.abstract_type,
|
|
231
|
+
concrete_type=passed_field_info.concrete_type,
|
|
232
|
+
local_arg_name=arg,
|
|
233
|
+
root_arg_name=passed_field_info.root_arg_name,
|
|
234
|
+
)
|
|
215
235
|
|
|
216
236
|
return callee_field_args
|
|
217
237
|
|
|
@@ -263,18 +283,14 @@ class IntegrandTransformer(IntegrandVisitor):
|
|
|
263
283
|
f"Operator {operator.func.__name__} is not defined for {field_info.abstract_type.__name__} {field.name}"
|
|
264
284
|
) from e
|
|
265
285
|
|
|
266
|
-
# Update the ast Call node to use the new function pointer
|
|
267
|
-
call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())
|
|
268
|
-
|
|
269
286
|
# Save the pointer as an attribute than can be accessed from the calling scope
|
|
270
|
-
#
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
else:
|
|
275
|
-
setattr(field_info.concrete_type, pointer.key, pointer)
|
|
287
|
+
# (use the annotation node of the argument this field is constructed from)
|
|
288
|
+
callee_node = self._field_arg_annotation_nodes[field_info.local_arg_name]
|
|
289
|
+
setattr(self._field_symbols[field_info.local_arg_name].abstract_type, pointer.key, pointer)
|
|
290
|
+
call.func = ast.Attribute(value=callee_node, attr=pointer.key, ctx=ast.Load())
|
|
276
291
|
|
|
277
|
-
|
|
292
|
+
# For shortcut default operator syntax, insert callee as first argument
|
|
293
|
+
if not isinstance(callee, Operator):
|
|
278
294
|
call.args = [ast.Name(id=callee, ctx=ast.Load()), *call.args]
|
|
279
295
|
|
|
280
296
|
# replace first argument with selected attribute
|
|
@@ -592,6 +608,9 @@ def _combined_kernel_options(integrand_options: Optional[Dict[str, Any]], call_s
|
|
|
592
608
|
return options
|
|
593
609
|
|
|
594
610
|
|
|
611
|
+
_INTEGRATE_CONSTANT_TILE_SIZE = 256
|
|
612
|
+
|
|
613
|
+
|
|
595
614
|
def get_integrate_constant_kernel(
|
|
596
615
|
integrand_func: wp.Function,
|
|
597
616
|
domain: GeometryDomain,
|
|
@@ -599,8 +618,12 @@ def get_integrate_constant_kernel(
|
|
|
599
618
|
FieldStruct: wp.codegen.Struct,
|
|
600
619
|
ValueStruct: wp.codegen.Struct,
|
|
601
620
|
accumulate_dtype,
|
|
621
|
+
tile_size: int = _INTEGRATE_CONSTANT_TILE_SIZE,
|
|
602
622
|
):
|
|
623
|
+
zero_element = type_zero_element(accumulate_dtype)
|
|
624
|
+
|
|
603
625
|
def integrate_kernel_fn(
|
|
626
|
+
qp_count: int,
|
|
604
627
|
qp_arg: quadrature.Arg,
|
|
605
628
|
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
606
629
|
domain_arg: domain.ElementArg,
|
|
@@ -609,26 +632,33 @@ def get_integrate_constant_kernel(
|
|
|
609
632
|
values: ValueStruct,
|
|
610
633
|
result: wp.array(dtype=accumulate_dtype),
|
|
611
634
|
):
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
615
|
-
return
|
|
635
|
+
block_index, lane = wp.tid()
|
|
636
|
+
qp_eval_index = block_index * tile_size + lane
|
|
616
637
|
|
|
617
|
-
|
|
638
|
+
if qp_eval_index >= qp_count:
|
|
639
|
+
domain_element_index, qp = NULL_ELEMENT_INDEX, 0
|
|
640
|
+
else:
|
|
641
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
618
642
|
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
643
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
644
|
+
val = zero_element()
|
|
645
|
+
else:
|
|
646
|
+
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
622
647
|
|
|
623
|
-
|
|
624
|
-
|
|
648
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
649
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
650
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
625
651
|
|
|
626
|
-
|
|
627
|
-
|
|
652
|
+
test_dof_index = NULL_DOF_INDEX
|
|
653
|
+
trial_dof_index = NULL_DOF_INDEX
|
|
628
654
|
|
|
629
|
-
|
|
655
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
656
|
+
vol = domain.element_measure(domain_arg, sample)
|
|
657
|
+
|
|
658
|
+
val = accumulate_dtype(qp_weight * vol * integrand_func(sample, fields, values))
|
|
630
659
|
|
|
631
|
-
wp.
|
|
660
|
+
tile_integral = wp.tile_sum(wp.tile(val))
|
|
661
|
+
wp.tile_atomic_add(result, tile_integral, offset=0)
|
|
632
662
|
|
|
633
663
|
return integrate_kernel_fn
|
|
634
664
|
|
|
@@ -1020,7 +1050,7 @@ def get_integrate_bilinear_local_kernel(
|
|
|
1020
1050
|
|
|
1021
1051
|
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1022
1052
|
val = integrand_func(sample, fields, values)
|
|
1023
|
-
result[
|
|
1053
|
+
result[test_dof, trial_dof, qp_eval_index, taylor_dof] = qp_vol * val
|
|
1024
1054
|
|
|
1025
1055
|
return integrate_kernel_fn
|
|
1026
1056
|
|
|
@@ -1150,9 +1180,46 @@ def _generate_integrate_kernel(
|
|
|
1150
1180
|
return kernel, FieldStruct, ValueStruct
|
|
1151
1181
|
|
|
1152
1182
|
|
|
1183
|
+
def _generate_auxiliary_kernels(
|
|
1184
|
+
quadrature: Quadrature,
|
|
1185
|
+
test: Optional[TestField],
|
|
1186
|
+
trial: Optional[TrialField],
|
|
1187
|
+
accumulate_dtype: type,
|
|
1188
|
+
device,
|
|
1189
|
+
kernel_options: Optional[Dict[str, Any]] = None,
|
|
1190
|
+
) -> List[Tuple[wp.Kernel, int]]:
|
|
1191
|
+
if test is None or not isinstance(test, LocalTestField):
|
|
1192
|
+
return ()
|
|
1193
|
+
|
|
1194
|
+
# For dispatched assembly, generate additional kernels
|
|
1195
|
+
# heuristic to use tiles for "long" quadratures
|
|
1196
|
+
dispatch_tile_size = 32
|
|
1197
|
+
qp_eval_count = quadrature.evaluation_point_count()
|
|
1198
|
+
|
|
1199
|
+
if trial is None:
|
|
1200
|
+
if (
|
|
1201
|
+
not device.is_cuda
|
|
1202
|
+
or qp_eval_count * test.space_restriction.total_node_element_count()
|
|
1203
|
+
< 3 * dispatch_tile_size * test.space_restriction.node_count() * test.domain.element_count()
|
|
1204
|
+
):
|
|
1205
|
+
dispatch_tile_size = 1
|
|
1206
|
+
dispatch_kernel = make_linear_dispatch_kernel(
|
|
1207
|
+
test, quadrature, accumulate_dtype, dispatch_tile_size, kernel_options
|
|
1208
|
+
)
|
|
1209
|
+
else:
|
|
1210
|
+
if not device.is_cuda or qp_eval_count < 3 * dispatch_tile_size * test.domain.element_count():
|
|
1211
|
+
dispatch_tile_size = 1
|
|
1212
|
+
dispatch_kernel = make_bilinear_dispatch_kernel(
|
|
1213
|
+
test, trial, quadrature, accumulate_dtype, dispatch_tile_size, kernel_options
|
|
1214
|
+
)
|
|
1215
|
+
|
|
1216
|
+
return ((dispatch_kernel, dispatch_tile_size),)
|
|
1217
|
+
|
|
1218
|
+
|
|
1153
1219
|
def _launch_integrate_kernel(
|
|
1154
1220
|
integrand: Integrand,
|
|
1155
1221
|
kernel: wp.Kernel,
|
|
1222
|
+
auxiliary_kernels: List[Tuple[wp.Kernel, int]],
|
|
1156
1223
|
FieldStruct: wp.codegen.Struct,
|
|
1157
1224
|
ValueStruct: wp.codegen.Struct,
|
|
1158
1225
|
domain: GeometryDomain,
|
|
@@ -1202,10 +1269,15 @@ def _launch_integrate_kernel(
|
|
|
1202
1269
|
if output != accumulate_array or not add_to_output:
|
|
1203
1270
|
accumulate_array.zero_()
|
|
1204
1271
|
|
|
1272
|
+
qp_count = quadrature.evaluation_point_count()
|
|
1273
|
+
tile_size = _INTEGRATE_CONSTANT_TILE_SIZE
|
|
1274
|
+
block_count = (qp_count + tile_size - 1) // tile_size
|
|
1205
1275
|
wp.launch(
|
|
1206
1276
|
kernel=kernel,
|
|
1207
|
-
dim=
|
|
1277
|
+
dim=(block_count, tile_size),
|
|
1278
|
+
block_dim=tile_size,
|
|
1208
1279
|
inputs=[
|
|
1280
|
+
qp_count,
|
|
1209
1281
|
qp_arg,
|
|
1210
1282
|
quadrature.element_index_arg_value(device),
|
|
1211
1283
|
domain_elt_arg,
|
|
@@ -1328,21 +1400,29 @@ def _launch_integrate_kernel(
|
|
|
1328
1400
|
device=device,
|
|
1329
1401
|
)
|
|
1330
1402
|
|
|
1331
|
-
|
|
1332
|
-
|
|
1333
|
-
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
|
|
1337
|
-
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1344
|
-
|
|
1345
|
-
|
|
1403
|
+
if test.TAYLOR_DOF_COUNT == 0:
|
|
1404
|
+
wp.utils.warn(
|
|
1405
|
+
f"Test field is never evaluated in integrand '{integrand.name}', result will be zero",
|
|
1406
|
+
category=UserWarning,
|
|
1407
|
+
stacklevel=2,
|
|
1408
|
+
)
|
|
1409
|
+
else:
|
|
1410
|
+
dispatch_kernel, dispatch_tile_size = auxiliary_kernels[0]
|
|
1411
|
+
wp.launch(
|
|
1412
|
+
kernel=dispatch_kernel,
|
|
1413
|
+
dim=(test.space_restriction.node_count(), dispatch_tile_size),
|
|
1414
|
+
block_dim=dispatch_tile_size if dispatch_tile_size > 1 else 256,
|
|
1415
|
+
inputs=[
|
|
1416
|
+
qp_arg,
|
|
1417
|
+
domain_elt_arg,
|
|
1418
|
+
domain_elt_index_arg,
|
|
1419
|
+
test_arg,
|
|
1420
|
+
test.space.space_arg_value(device),
|
|
1421
|
+
local_result.array,
|
|
1422
|
+
output_view,
|
|
1423
|
+
],
|
|
1424
|
+
device=device,
|
|
1425
|
+
)
|
|
1346
1426
|
|
|
1347
1427
|
local_result.release()
|
|
1348
1428
|
|
|
@@ -1415,14 +1495,15 @@ def _launch_integrate_kernel(
|
|
|
1415
1495
|
device=device,
|
|
1416
1496
|
)
|
|
1417
1497
|
elif isinstance(test, LocalTestField):
|
|
1498
|
+
qp_eval_count = quadrature.evaluation_point_count()
|
|
1418
1499
|
local_result = cache.borrow_temporary(
|
|
1419
1500
|
temporary_store=temporary_store,
|
|
1420
1501
|
device=device,
|
|
1421
1502
|
requires_grad=False,
|
|
1422
1503
|
shape=(
|
|
1423
|
-
quadrature.evaluation_point_count(),
|
|
1424
1504
|
test.value_dof_count,
|
|
1425
1505
|
trial.value_dof_count,
|
|
1506
|
+
qp_eval_count,
|
|
1426
1507
|
test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
|
|
1427
1508
|
),
|
|
1428
1509
|
dtype=float,
|
|
@@ -1431,7 +1512,7 @@ def _launch_integrate_kernel(
|
|
|
1431
1512
|
wp.launch(
|
|
1432
1513
|
kernel=kernel,
|
|
1433
1514
|
dim=(
|
|
1434
|
-
|
|
1515
|
+
qp_eval_count,
|
|
1435
1516
|
test.value_dof_count,
|
|
1436
1517
|
trial.value_dof_count,
|
|
1437
1518
|
trial.TAYLOR_DOF_COUNT,
|
|
@@ -1448,45 +1529,41 @@ def _launch_integrate_kernel(
|
|
|
1448
1529
|
device=device,
|
|
1449
1530
|
)
|
|
1450
1531
|
|
|
1451
|
-
|
|
1452
|
-
|
|
1453
|
-
|
|
1454
|
-
|
|
1455
|
-
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
|
|
1464
|
-
|
|
1465
|
-
|
|
1466
|
-
|
|
1467
|
-
|
|
1468
|
-
|
|
1469
|
-
|
|
1470
|
-
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
|
|
1474
|
-
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
triplet_values,
|
|
1487
|
-
],
|
|
1488
|
-
device=device,
|
|
1489
|
-
)
|
|
1532
|
+
if test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT == 0:
|
|
1533
|
+
wp.utils.warn(
|
|
1534
|
+
f"Test and/or trial fields are never evaluated in integrand '{integrand.name}', result will be zero",
|
|
1535
|
+
category=UserWarning,
|
|
1536
|
+
stacklevel=2,
|
|
1537
|
+
)
|
|
1538
|
+
triplet_rows.fill_(-1)
|
|
1539
|
+
else:
|
|
1540
|
+
dispatch_kernel, dispatch_tile_size = auxiliary_kernels[0]
|
|
1541
|
+
trial_partition_arg = trial.space_partition.partition_arg_value(device)
|
|
1542
|
+
trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
|
|
1543
|
+
wp.launch(
|
|
1544
|
+
kernel=dispatch_kernel,
|
|
1545
|
+
dim=(
|
|
1546
|
+
test.space_restriction.total_node_element_count(),
|
|
1547
|
+
trial.space.topology.MAX_NODES_PER_ELEMENT,
|
|
1548
|
+
dispatch_tile_size,
|
|
1549
|
+
),
|
|
1550
|
+
block_dim=dispatch_tile_size if dispatch_tile_size > 1 else 256,
|
|
1551
|
+
inputs=[
|
|
1552
|
+
qp_arg,
|
|
1553
|
+
domain_elt_arg,
|
|
1554
|
+
domain_elt_index_arg,
|
|
1555
|
+
test_arg,
|
|
1556
|
+
test.space.space_arg_value(device),
|
|
1557
|
+
trial_partition_arg,
|
|
1558
|
+
trial_topology_arg,
|
|
1559
|
+
trial.space.space_arg_value(device),
|
|
1560
|
+
local_result.array,
|
|
1561
|
+
triplet_rows,
|
|
1562
|
+
triplet_cols,
|
|
1563
|
+
triplet_values,
|
|
1564
|
+
],
|
|
1565
|
+
device=device,
|
|
1566
|
+
)
|
|
1490
1567
|
|
|
1491
1568
|
local_result.release()
|
|
1492
1569
|
|
|
@@ -1621,6 +1698,9 @@ def integrate(
|
|
|
1621
1698
|
if values is None:
|
|
1622
1699
|
values = {}
|
|
1623
1700
|
|
|
1701
|
+
if device is None:
|
|
1702
|
+
device = wp.get_device()
|
|
1703
|
+
|
|
1624
1704
|
if not isinstance(integrand, Integrand):
|
|
1625
1705
|
raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")
|
|
1626
1706
|
|
|
@@ -1713,9 +1793,19 @@ def integrate(
|
|
|
1713
1793
|
kernel_options=kernel_options,
|
|
1714
1794
|
)
|
|
1715
1795
|
|
|
1796
|
+
auxiliary_kernels = _generate_auxiliary_kernels(
|
|
1797
|
+
quadrature=quadrature,
|
|
1798
|
+
test=test,
|
|
1799
|
+
trial=trial,
|
|
1800
|
+
accumulate_dtype=accumulate_dtype,
|
|
1801
|
+
device=device,
|
|
1802
|
+
kernel_options=kernel_options,
|
|
1803
|
+
)
|
|
1804
|
+
|
|
1716
1805
|
return _launch_integrate_kernel(
|
|
1717
1806
|
integrand=integrand,
|
|
1718
1807
|
kernel=kernel,
|
|
1808
|
+
auxiliary_kernels=auxiliary_kernels,
|
|
1719
1809
|
FieldStruct=FieldStruct,
|
|
1720
1810
|
ValueStruct=ValueStruct,
|
|
1721
1811
|
domain=domain,
|
|
@@ -2207,6 +2297,9 @@ def _launch_interpolate_kernel(
|
|
|
2207
2297
|
return
|
|
2208
2298
|
|
|
2209
2299
|
if quadrature is None:
|
|
2300
|
+
if dest is not None and (not is_array(dest) or dest.shape[0] != dim):
|
|
2301
|
+
raise ValueError(f"dest must be a warp array with {dim} rows")
|
|
2302
|
+
|
|
2210
2303
|
wp.launch(
|
|
2211
2304
|
kernel=kernel,
|
|
2212
2305
|
dim=dim,
|
|
@@ -2216,21 +2309,34 @@ def _launch_interpolate_kernel(
|
|
|
2216
2309
|
return
|
|
2217
2310
|
|
|
2218
2311
|
qp_arg = quadrature.arg_value(device)
|
|
2312
|
+
qp_eval_count = quadrature.evaluation_point_count()
|
|
2313
|
+
qp_index_count = quadrature.total_point_count()
|
|
2314
|
+
|
|
2315
|
+
if qp_eval_count != qp_index_count:
|
|
2316
|
+
wp.utils.warn(
|
|
2317
|
+
f"Quadrature used for interpolation of {integrand.name} has different number of evaluation and indexed points, this may lead to incorrect results",
|
|
2318
|
+
category=UserWarning,
|
|
2319
|
+
stacklevel=2,
|
|
2320
|
+
)
|
|
2321
|
+
|
|
2219
2322
|
qp_element_index_arg = quadrature.element_index_arg_value(device)
|
|
2220
2323
|
if trial is None:
|
|
2324
|
+
if dest is not None and (not is_array(dest) or dest.shape[0] != qp_index_count):
|
|
2325
|
+
raise ValueError(f"dest must be a warp array with {qp_index_count} rows")
|
|
2326
|
+
|
|
2221
2327
|
wp.launch(
|
|
2222
2328
|
kernel=kernel,
|
|
2223
|
-
dim=
|
|
2329
|
+
dim=qp_eval_count,
|
|
2224
2330
|
inputs=[qp_arg, qp_element_index_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
|
|
2225
2331
|
device=device,
|
|
2226
2332
|
)
|
|
2227
2333
|
return
|
|
2228
2334
|
|
|
2229
|
-
nnz =
|
|
2335
|
+
nnz = qp_eval_count * trial.space.topology.MAX_NODES_PER_ELEMENT
|
|
2230
2336
|
|
|
2231
|
-
if dest.nrow !=
|
|
2337
|
+
if dest.nrow != qp_index_count or dest.ncol != trial.space_partition.node_count():
|
|
2232
2338
|
raise RuntimeError(
|
|
2233
|
-
f"'dest' matrix must have {
|
|
2339
|
+
f"'dest' matrix must have {qp_index_count} rows and {trial.space_partition.node_count()} columns of blocks"
|
|
2234
2340
|
)
|
|
2235
2341
|
if dest.block_shape[1] != trial.node_dof_count:
|
|
2236
2342
|
raise RuntimeError(f"'dest' matrix blocks must have {trial.node_dof_count} columns")
|
|
@@ -2324,6 +2430,9 @@ def interpolate(
|
|
|
2324
2430
|
if values is None:
|
|
2325
2431
|
values = {}
|
|
2326
2432
|
|
|
2433
|
+
if device is None:
|
|
2434
|
+
device = wp.get_device()
|
|
2435
|
+
|
|
2327
2436
|
if not isinstance(integrand, Integrand):
|
|
2328
2437
|
raise ValueError("integrand must be tagged with @integrand decorator")
|
|
2329
2438
|
|
warp/fem/space/restriction.py
CHANGED
|
@@ -159,6 +159,10 @@ class SpaceRestriction:
|
|
|
159
159
|
def node_partition_index(args: NodeArg, restriction_node_index: int):
|
|
160
160
|
return args.dof_partition_indices[restriction_node_index]
|
|
161
161
|
|
|
162
|
+
@wp.func
|
|
163
|
+
def node_partition_index_from_element_offset(args: NodeArg, element_offset: int):
|
|
164
|
+
return wp.lower_bound(args.dof_element_offsets, element_offset + 1) - 1
|
|
165
|
+
|
|
162
166
|
@wp.func
|
|
163
167
|
def node_element_range(args: NodeArg, partition_node_index: int):
|
|
164
168
|
return args.dof_element_offsets[partition_node_index], args.dof_element_offsets[partition_node_index + 1]
|
|
@@ -168,19 +168,12 @@ class TetrahedronPolynomialShapeFunctions(TetrahedronShapeFunction):
|
|
|
168
168
|
|
|
169
169
|
self.VERTEX_NODE_COUNT = wp.constant(1)
|
|
170
170
|
self.EDGE_NODE_COUNT = wp.constant(degree - 1)
|
|
171
|
+
self.FACE_NODE_COUNT = wp.constant(max(0, degree - 2) * max(0, degree - 1) // 2)
|
|
172
|
+
self.INTERIOR_NODE_COUNT = wp.constant(max(0, degree - 1) * max(0, degree - 2) * max(0, degree - 3) // 6)
|
|
173
|
+
|
|
171
174
|
self.NODES_PER_ELEMENT = wp.constant((degree + 1) * (degree + 2) * (degree + 3) // 6)
|
|
172
175
|
self.NODES_PER_SIDE = wp.constant((degree + 1) * (degree + 2) // 2)
|
|
173
176
|
|
|
174
|
-
self.SIDE_NODE_COUNT = wp.constant(self.NODES_PER_ELEMENT - 3 * (self.VERTEX_NODE_COUNT + self.EDGE_NODE_COUNT))
|
|
175
|
-
self.INTERIOR_NODE_COUNT = wp.constant(
|
|
176
|
-
self.NODES_PER_ELEMENT - 3 * (self.VERTEX_NODE_COUNT + self.EDGE_NODE_COUNT)
|
|
177
|
-
)
|
|
178
|
-
|
|
179
|
-
self.VERTEX_NODE_COUNT = wp.constant(1)
|
|
180
|
-
self.EDGE_NODE_COUNT = wp.constant(degree - 1)
|
|
181
|
-
self.FACE_NODE_COUNT = wp.constant(max(0, degree - 2) * max(0, degree - 1) // 2)
|
|
182
|
-
self.INERIOR_NODE_COUNT = wp.constant(max(0, degree - 1) * max(0, degree - 2) * max(0, degree - 3) // 6)
|
|
183
|
-
|
|
184
177
|
tet_coords = np.empty((self.NODES_PER_ELEMENT, 3), dtype=int)
|
|
185
178
|
|
|
186
179
|
for tx in range(degree + 1):
|
|
@@ -107,7 +107,7 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
|
|
|
107
107
|
assert hooks.forward, "Failed to find kernel entry point"
|
|
108
108
|
|
|
109
109
|
# Launch the kernel.
|
|
110
|
-
wp.context.runtime.core.
|
|
110
|
+
wp.context.runtime.core.wp_cuda_launch_kernel(
|
|
111
111
|
device.context, hooks.forward, bounds.size, 0, 256, hooks.forward_smem_bytes, kernel_params, stream
|
|
112
112
|
)
|
|
113
113
|
|