warp-lang 1.8.1__py3-none-manylinux_2_34_aarch64.whl → 1.9.1__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +282 -103
- warp/__init__.pyi +1904 -114
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +331 -101
- warp/builtins.py +1244 -160
- warp/codegen.py +317 -206
- warp/config.py +1 -1
- warp/context.py +1465 -789
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/examples/interop/example_jax_kernel.py +2 -1
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +264 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +129 -51
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +25 -2
- warp/jax_experimental/ffi.py +22 -1
- warp/jax_experimental/xla_ffi.py +16 -7
- warp/marching_cubes.py +708 -0
- warp/native/array.h +99 -4
- warp/native/builtin.h +86 -9
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +8 -2
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +41 -10
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +2 -2
- warp/native/mat.h +1910 -116
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +4 -2
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +331 -14
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +40 -31
- warp/native/sort.h +2 -0
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +13 -13
- warp/native/spatial.h +366 -17
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +471 -82
- warp/native/vec.h +328 -14
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +377 -216
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +99 -18
- warp/render/render_usd.py +1 -0
- warp/sim/graph_coloring.py +2 -2
- warp/sparse.py +558 -175
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/interop/test_jax.py +608 -28
- warp/tests/sim/test_coloring.py +6 -6
- warp/tests/test_array.py +58 -5
- warp/tests/test_codegen.py +4 -3
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +49 -6
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +15 -1
- warp/tests/test_mat.py +1518 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +140 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +71 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +61 -20
- warp/tests/test_vec.py +179 -34
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/tile/test_tile.py +245 -18
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_shared_memory.py +5 -5
- warp/tests/unittest_suites.py +6 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +571 -267
- warp/utils.py +68 -86
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
import ast
|
|
17
17
|
import inspect
|
|
18
18
|
import textwrap
|
|
19
|
-
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
|
|
19
|
+
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
|
|
20
20
|
|
|
21
21
|
import warp as wp
|
|
22
22
|
import warp.fem.operator as operator
|
|
@@ -34,7 +34,10 @@ from warp.fem.field import (
|
|
|
34
34
|
TrialField,
|
|
35
35
|
make_restriction,
|
|
36
36
|
)
|
|
37
|
-
from warp.fem.field.virtual import
|
|
37
|
+
from warp.fem.field.virtual import (
|
|
38
|
+
make_bilinear_dispatch_kernel,
|
|
39
|
+
make_linear_dispatch_kernel,
|
|
40
|
+
)
|
|
38
41
|
from warp.fem.linalg import array_axpy, basis_coefficient
|
|
39
42
|
from warp.fem.operator import (
|
|
40
43
|
Integrand,
|
|
@@ -101,7 +104,8 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
101
104
|
field: FieldLike
|
|
102
105
|
abstract_type: type
|
|
103
106
|
concrete_type: type
|
|
104
|
-
root_arg_name:
|
|
107
|
+
root_arg_name: str
|
|
108
|
+
local_arg_name: str
|
|
105
109
|
|
|
106
110
|
def __init__(
|
|
107
111
|
self,
|
|
@@ -111,6 +115,7 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
111
115
|
self._integrand = integrand
|
|
112
116
|
self._field_symbols = field_info.copy()
|
|
113
117
|
self._field_nodes = {}
|
|
118
|
+
self._field_arg_annotation_nodes = {}
|
|
114
119
|
|
|
115
120
|
@staticmethod
|
|
116
121
|
def _build_field_info(integrand: Integrand, field_args: Dict[str, FieldLike]):
|
|
@@ -127,6 +132,7 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
127
132
|
abstract_type=integrand.argspec.annotations[name],
|
|
128
133
|
concrete_type=get_concrete_type(field),
|
|
129
134
|
root_arg_name=name,
|
|
135
|
+
local_arg_name=name,
|
|
130
136
|
)
|
|
131
137
|
for name, field in field_args.items()
|
|
132
138
|
}
|
|
@@ -167,6 +173,7 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
167
173
|
field=res[0],
|
|
168
174
|
abstract_type=res[1],
|
|
169
175
|
concrete_type=res[2],
|
|
176
|
+
local_arg_name=field_info.local_arg_name,
|
|
170
177
|
root_arg_name=f"{field_info.root_arg_name}.{func.name}",
|
|
171
178
|
)
|
|
172
179
|
|
|
@@ -191,6 +198,13 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
191
198
|
|
|
192
199
|
return node
|
|
193
200
|
|
|
201
|
+
def visit_FunctionDef(self, node: ast.FunctionDef):
|
|
202
|
+
# record field arg annotation nodes
|
|
203
|
+
for arg in node.args.args:
|
|
204
|
+
self._field_arg_annotation_nodes[arg.arg] = arg.annotation
|
|
205
|
+
|
|
206
|
+
return self.generic_visit(node)
|
|
207
|
+
|
|
194
208
|
def _get_callee_field_args(self, callee: Integrand, args: List[ast.AST]):
|
|
195
209
|
# Get field types for call site arguments
|
|
196
210
|
call_site_field_args: List[IntegrandVisitor.FieldInfo] = []
|
|
@@ -211,7 +225,13 @@ class IntegrandVisitor(ast.NodeTransformer):
|
|
|
211
225
|
raise TypeError(
|
|
212
226
|
f"Attempting to pass a {passed_field_info.abstract_type.__name__} to argument '{arg}' of '{callee.name}' expecting a {arg_type.__name__}"
|
|
213
227
|
)
|
|
214
|
-
callee_field_args[arg] =
|
|
228
|
+
callee_field_args[arg] = IntegrandVisitor.FieldInfo(
|
|
229
|
+
field=passed_field_info.field,
|
|
230
|
+
abstract_type=passed_field_info.abstract_type,
|
|
231
|
+
concrete_type=passed_field_info.concrete_type,
|
|
232
|
+
local_arg_name=arg,
|
|
233
|
+
root_arg_name=passed_field_info.root_arg_name,
|
|
234
|
+
)
|
|
215
235
|
|
|
216
236
|
return callee_field_args
|
|
217
237
|
|
|
@@ -263,18 +283,14 @@ class IntegrandTransformer(IntegrandVisitor):
|
|
|
263
283
|
f"Operator {operator.func.__name__} is not defined for {field_info.abstract_type.__name__} {field.name}"
|
|
264
284
|
) from e
|
|
265
285
|
|
|
266
|
-
# Update the ast Call node to use the new function pointer
|
|
267
|
-
call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())
|
|
268
|
-
|
|
269
286
|
# Save the pointer as an attribute than can be accessed from the calling scope
|
|
270
|
-
#
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
else:
|
|
275
|
-
setattr(field_info.concrete_type, pointer.key, pointer)
|
|
287
|
+
# (use the annotation node of the argument this field is constructed from)
|
|
288
|
+
callee_node = self._field_arg_annotation_nodes[field_info.local_arg_name]
|
|
289
|
+
setattr(self._field_symbols[field_info.local_arg_name].abstract_type, pointer.key, pointer)
|
|
290
|
+
call.func = ast.Attribute(value=callee_node, attr=pointer.key, ctx=ast.Load())
|
|
276
291
|
|
|
277
|
-
|
|
292
|
+
# For shortcut default operator syntax, insert callee as first argument
|
|
293
|
+
if not isinstance(callee, Operator):
|
|
278
294
|
call.args = [ast.Name(id=callee, ctx=ast.Load()), *call.args]
|
|
279
295
|
|
|
280
296
|
# replace first argument with selected attribute
|
|
@@ -592,6 +608,9 @@ def _combined_kernel_options(integrand_options: Optional[Dict[str, Any]], call_s
|
|
|
592
608
|
return options
|
|
593
609
|
|
|
594
610
|
|
|
611
|
+
_INTEGRATE_CONSTANT_TILE_SIZE = 256
|
|
612
|
+
|
|
613
|
+
|
|
595
614
|
def get_integrate_constant_kernel(
|
|
596
615
|
integrand_func: wp.Function,
|
|
597
616
|
domain: GeometryDomain,
|
|
@@ -599,8 +618,12 @@ def get_integrate_constant_kernel(
|
|
|
599
618
|
FieldStruct: wp.codegen.Struct,
|
|
600
619
|
ValueStruct: wp.codegen.Struct,
|
|
601
620
|
accumulate_dtype,
|
|
621
|
+
tile_size: int = _INTEGRATE_CONSTANT_TILE_SIZE,
|
|
602
622
|
):
|
|
623
|
+
zero_element = type_zero_element(accumulate_dtype)
|
|
624
|
+
|
|
603
625
|
def integrate_kernel_fn(
|
|
626
|
+
qp_count: int,
|
|
604
627
|
qp_arg: quadrature.Arg,
|
|
605
628
|
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
606
629
|
domain_arg: domain.ElementArg,
|
|
@@ -609,26 +632,33 @@ def get_integrate_constant_kernel(
|
|
|
609
632
|
values: ValueStruct,
|
|
610
633
|
result: wp.array(dtype=accumulate_dtype),
|
|
611
634
|
):
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
615
|
-
return
|
|
635
|
+
block_index, lane = wp.tid()
|
|
636
|
+
qp_eval_index = block_index * tile_size + lane
|
|
616
637
|
|
|
617
|
-
|
|
638
|
+
if qp_eval_index >= qp_count:
|
|
639
|
+
domain_element_index, qp = NULL_ELEMENT_INDEX, 0
|
|
640
|
+
else:
|
|
641
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
618
642
|
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
643
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
644
|
+
val = zero_element()
|
|
645
|
+
else:
|
|
646
|
+
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
622
647
|
|
|
623
|
-
|
|
624
|
-
|
|
648
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
649
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
650
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
625
651
|
|
|
626
|
-
|
|
627
|
-
|
|
652
|
+
test_dof_index = NULL_DOF_INDEX
|
|
653
|
+
trial_dof_index = NULL_DOF_INDEX
|
|
628
654
|
|
|
629
|
-
|
|
655
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
656
|
+
vol = domain.element_measure(domain_arg, sample)
|
|
630
657
|
|
|
631
|
-
|
|
658
|
+
val = accumulate_dtype(qp_weight * vol * integrand_func(sample, fields, values))
|
|
659
|
+
|
|
660
|
+
tile_integral = wp.tile_sum(wp.tile(val))
|
|
661
|
+
wp.tile_atomic_add(result, tile_integral, offset=0)
|
|
632
662
|
|
|
633
663
|
return integrate_kernel_fn
|
|
634
664
|
|
|
@@ -1020,7 +1050,7 @@ def get_integrate_bilinear_local_kernel(
|
|
|
1020
1050
|
|
|
1021
1051
|
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1022
1052
|
val = integrand_func(sample, fields, values)
|
|
1023
|
-
result[
|
|
1053
|
+
result[test_dof, trial_dof, qp_eval_index, taylor_dof] = qp_vol * val
|
|
1024
1054
|
|
|
1025
1055
|
return integrate_kernel_fn
|
|
1026
1056
|
|
|
@@ -1150,9 +1180,46 @@ def _generate_integrate_kernel(
|
|
|
1150
1180
|
return kernel, FieldStruct, ValueStruct
|
|
1151
1181
|
|
|
1152
1182
|
|
|
1183
|
+
def _generate_auxiliary_kernels(
|
|
1184
|
+
quadrature: Quadrature,
|
|
1185
|
+
test: Optional[TestField],
|
|
1186
|
+
trial: Optional[TrialField],
|
|
1187
|
+
accumulate_dtype: type,
|
|
1188
|
+
device,
|
|
1189
|
+
kernel_options: Optional[Dict[str, Any]] = None,
|
|
1190
|
+
) -> List[Tuple[wp.Kernel, int]]:
|
|
1191
|
+
if test is None or not isinstance(test, LocalTestField):
|
|
1192
|
+
return ()
|
|
1193
|
+
|
|
1194
|
+
# For dispatched assembly, generate additional kernels
|
|
1195
|
+
# heuristic to use tiles for "long" quadratures
|
|
1196
|
+
dispatch_tile_size = 32
|
|
1197
|
+
qp_eval_count = quadrature.evaluation_point_count()
|
|
1198
|
+
|
|
1199
|
+
if trial is None:
|
|
1200
|
+
if (
|
|
1201
|
+
not device.is_cuda
|
|
1202
|
+
or qp_eval_count * test.space_restriction.total_node_element_count()
|
|
1203
|
+
< 3 * dispatch_tile_size * test.space_restriction.node_count() * test.domain.element_count()
|
|
1204
|
+
):
|
|
1205
|
+
dispatch_tile_size = 1
|
|
1206
|
+
dispatch_kernel = make_linear_dispatch_kernel(
|
|
1207
|
+
test, quadrature, accumulate_dtype, dispatch_tile_size, kernel_options
|
|
1208
|
+
)
|
|
1209
|
+
else:
|
|
1210
|
+
if not device.is_cuda or qp_eval_count < 3 * dispatch_tile_size * test.domain.element_count():
|
|
1211
|
+
dispatch_tile_size = 1
|
|
1212
|
+
dispatch_kernel = make_bilinear_dispatch_kernel(
|
|
1213
|
+
test, trial, quadrature, accumulate_dtype, dispatch_tile_size, kernel_options
|
|
1214
|
+
)
|
|
1215
|
+
|
|
1216
|
+
return ((dispatch_kernel, dispatch_tile_size),)
|
|
1217
|
+
|
|
1218
|
+
|
|
1153
1219
|
def _launch_integrate_kernel(
|
|
1154
1220
|
integrand: Integrand,
|
|
1155
1221
|
kernel: wp.Kernel,
|
|
1222
|
+
auxiliary_kernels: List[Tuple[wp.Kernel, int]],
|
|
1156
1223
|
FieldStruct: wp.codegen.Struct,
|
|
1157
1224
|
ValueStruct: wp.codegen.Struct,
|
|
1158
1225
|
domain: GeometryDomain,
|
|
@@ -1202,10 +1269,15 @@ def _launch_integrate_kernel(
|
|
|
1202
1269
|
if output != accumulate_array or not add_to_output:
|
|
1203
1270
|
accumulate_array.zero_()
|
|
1204
1271
|
|
|
1272
|
+
qp_count = quadrature.evaluation_point_count()
|
|
1273
|
+
tile_size = _INTEGRATE_CONSTANT_TILE_SIZE
|
|
1274
|
+
block_count = (qp_count + tile_size - 1) // tile_size
|
|
1205
1275
|
wp.launch(
|
|
1206
1276
|
kernel=kernel,
|
|
1207
|
-
dim=
|
|
1277
|
+
dim=(block_count, tile_size),
|
|
1278
|
+
block_dim=tile_size,
|
|
1208
1279
|
inputs=[
|
|
1280
|
+
qp_count,
|
|
1209
1281
|
qp_arg,
|
|
1210
1282
|
quadrature.element_index_arg_value(device),
|
|
1211
1283
|
domain_elt_arg,
|
|
@@ -1335,10 +1407,11 @@ def _launch_integrate_kernel(
|
|
|
1335
1407
|
stacklevel=2,
|
|
1336
1408
|
)
|
|
1337
1409
|
else:
|
|
1338
|
-
dispatch_kernel =
|
|
1410
|
+
dispatch_kernel, dispatch_tile_size = auxiliary_kernels[0]
|
|
1339
1411
|
wp.launch(
|
|
1340
1412
|
kernel=dispatch_kernel,
|
|
1341
|
-
dim=(test.space_restriction.node_count(),
|
|
1413
|
+
dim=(test.space_restriction.node_count(), dispatch_tile_size),
|
|
1414
|
+
block_dim=dispatch_tile_size if dispatch_tile_size > 1 else 256,
|
|
1342
1415
|
inputs=[
|
|
1343
1416
|
qp_arg,
|
|
1344
1417
|
domain_elt_arg,
|
|
@@ -1422,14 +1495,15 @@ def _launch_integrate_kernel(
|
|
|
1422
1495
|
device=device,
|
|
1423
1496
|
)
|
|
1424
1497
|
elif isinstance(test, LocalTestField):
|
|
1498
|
+
qp_eval_count = quadrature.evaluation_point_count()
|
|
1425
1499
|
local_result = cache.borrow_temporary(
|
|
1426
1500
|
temporary_store=temporary_store,
|
|
1427
1501
|
device=device,
|
|
1428
1502
|
requires_grad=False,
|
|
1429
1503
|
shape=(
|
|
1430
|
-
quadrature.evaluation_point_count(),
|
|
1431
1504
|
test.value_dof_count,
|
|
1432
1505
|
trial.value_dof_count,
|
|
1506
|
+
qp_eval_count,
|
|
1433
1507
|
test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
|
|
1434
1508
|
),
|
|
1435
1509
|
dtype=float,
|
|
@@ -1438,7 +1512,7 @@ def _launch_integrate_kernel(
|
|
|
1438
1512
|
wp.launch(
|
|
1439
1513
|
kernel=kernel,
|
|
1440
1514
|
dim=(
|
|
1441
|
-
|
|
1515
|
+
qp_eval_count,
|
|
1442
1516
|
test.value_dof_count,
|
|
1443
1517
|
trial.value_dof_count,
|
|
1444
1518
|
trial.TAYLOR_DOF_COUNT,
|
|
@@ -1455,17 +1529,6 @@ def _launch_integrate_kernel(
|
|
|
1455
1529
|
device=device,
|
|
1456
1530
|
)
|
|
1457
1531
|
|
|
1458
|
-
vec_array_shape = (*local_result.array.shape[:-1], test.TAYLOR_DOF_COUNT)
|
|
1459
|
-
vec_array_dtype = cache.cached_vec_type(length=trial.TAYLOR_DOF_COUNT, dtype=float)
|
|
1460
|
-
local_result_as_vec = wp.array(
|
|
1461
|
-
data=None,
|
|
1462
|
-
ptr=local_result.array.ptr,
|
|
1463
|
-
capacity=local_result.array.capacity,
|
|
1464
|
-
device=local_result.array.device,
|
|
1465
|
-
shape=vec_array_shape,
|
|
1466
|
-
dtype=vec_array_dtype,
|
|
1467
|
-
)
|
|
1468
|
-
|
|
1469
1532
|
if test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT == 0:
|
|
1470
1533
|
wp.utils.warn(
|
|
1471
1534
|
f"Test and/or trial fields are never evaluated in integrand '{integrand.name}', result will be zero",
|
|
@@ -1474,18 +1537,17 @@ def _launch_integrate_kernel(
|
|
|
1474
1537
|
)
|
|
1475
1538
|
triplet_rows.fill_(-1)
|
|
1476
1539
|
else:
|
|
1477
|
-
dispatch_kernel
|
|
1478
|
-
|
|
1540
|
+
dispatch_kernel, dispatch_tile_size = auxiliary_kernels[0]
|
|
1479
1541
|
trial_partition_arg = trial.space_partition.partition_arg_value(device)
|
|
1480
1542
|
trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
|
|
1481
1543
|
wp.launch(
|
|
1482
1544
|
kernel=dispatch_kernel,
|
|
1483
1545
|
dim=(
|
|
1484
|
-
test.space_restriction.
|
|
1485
|
-
test.node_dof_count,
|
|
1486
|
-
trial.node_dof_count,
|
|
1546
|
+
test.space_restriction.total_node_element_count(),
|
|
1487
1547
|
trial.space.topology.MAX_NODES_PER_ELEMENT,
|
|
1548
|
+
dispatch_tile_size,
|
|
1488
1549
|
),
|
|
1550
|
+
block_dim=dispatch_tile_size if dispatch_tile_size > 1 else 256,
|
|
1489
1551
|
inputs=[
|
|
1490
1552
|
qp_arg,
|
|
1491
1553
|
domain_elt_arg,
|
|
@@ -1495,7 +1557,7 @@ def _launch_integrate_kernel(
|
|
|
1495
1557
|
trial_partition_arg,
|
|
1496
1558
|
trial_topology_arg,
|
|
1497
1559
|
trial.space.space_arg_value(device),
|
|
1498
|
-
|
|
1560
|
+
local_result.array,
|
|
1499
1561
|
triplet_rows,
|
|
1500
1562
|
triplet_cols,
|
|
1501
1563
|
triplet_values,
|
|
@@ -1636,6 +1698,9 @@ def integrate(
|
|
|
1636
1698
|
if values is None:
|
|
1637
1699
|
values = {}
|
|
1638
1700
|
|
|
1701
|
+
if device is None:
|
|
1702
|
+
device = wp.get_device()
|
|
1703
|
+
|
|
1639
1704
|
if not isinstance(integrand, Integrand):
|
|
1640
1705
|
raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")
|
|
1641
1706
|
|
|
@@ -1728,9 +1793,19 @@ def integrate(
|
|
|
1728
1793
|
kernel_options=kernel_options,
|
|
1729
1794
|
)
|
|
1730
1795
|
|
|
1796
|
+
auxiliary_kernels = _generate_auxiliary_kernels(
|
|
1797
|
+
quadrature=quadrature,
|
|
1798
|
+
test=test,
|
|
1799
|
+
trial=trial,
|
|
1800
|
+
accumulate_dtype=accumulate_dtype,
|
|
1801
|
+
device=device,
|
|
1802
|
+
kernel_options=kernel_options,
|
|
1803
|
+
)
|
|
1804
|
+
|
|
1731
1805
|
return _launch_integrate_kernel(
|
|
1732
1806
|
integrand=integrand,
|
|
1733
1807
|
kernel=kernel,
|
|
1808
|
+
auxiliary_kernels=auxiliary_kernels,
|
|
1734
1809
|
FieldStruct=FieldStruct,
|
|
1735
1810
|
ValueStruct=ValueStruct,
|
|
1736
1811
|
domain=domain,
|
|
@@ -2355,6 +2430,9 @@ def interpolate(
|
|
|
2355
2430
|
if values is None:
|
|
2356
2431
|
values = {}
|
|
2357
2432
|
|
|
2433
|
+
if device is None:
|
|
2434
|
+
device = wp.get_device()
|
|
2435
|
+
|
|
2358
2436
|
if not isinstance(integrand, Integrand):
|
|
2359
2437
|
raise ValueError("integrand must be tagged with @integrand decorator")
|
|
2360
2438
|
|
warp/fem/space/restriction.py
CHANGED
|
@@ -159,6 +159,10 @@ class SpaceRestriction:
|
|
|
159
159
|
def node_partition_index(args: NodeArg, restriction_node_index: int):
|
|
160
160
|
return args.dof_partition_indices[restriction_node_index]
|
|
161
161
|
|
|
162
|
+
@wp.func
|
|
163
|
+
def node_partition_index_from_element_offset(args: NodeArg, element_offset: int):
|
|
164
|
+
return wp.lower_bound(args.dof_element_offsets, element_offset + 1) - 1
|
|
165
|
+
|
|
162
166
|
@wp.func
|
|
163
167
|
def node_element_range(args: NodeArg, partition_node_index: int):
|
|
164
168
|
return args.dof_element_offsets[partition_node_index], args.dof_element_offsets[partition_node_index + 1]
|
|
@@ -168,19 +168,12 @@ class TetrahedronPolynomialShapeFunctions(TetrahedronShapeFunction):
|
|
|
168
168
|
|
|
169
169
|
self.VERTEX_NODE_COUNT = wp.constant(1)
|
|
170
170
|
self.EDGE_NODE_COUNT = wp.constant(degree - 1)
|
|
171
|
+
self.FACE_NODE_COUNT = wp.constant(max(0, degree - 2) * max(0, degree - 1) // 2)
|
|
172
|
+
self.INTERIOR_NODE_COUNT = wp.constant(max(0, degree - 1) * max(0, degree - 2) * max(0, degree - 3) // 6)
|
|
173
|
+
|
|
171
174
|
self.NODES_PER_ELEMENT = wp.constant((degree + 1) * (degree + 2) * (degree + 3) // 6)
|
|
172
175
|
self.NODES_PER_SIDE = wp.constant((degree + 1) * (degree + 2) // 2)
|
|
173
176
|
|
|
174
|
-
self.SIDE_NODE_COUNT = wp.constant(self.NODES_PER_ELEMENT - 3 * (self.VERTEX_NODE_COUNT + self.EDGE_NODE_COUNT))
|
|
175
|
-
self.INTERIOR_NODE_COUNT = wp.constant(
|
|
176
|
-
self.NODES_PER_ELEMENT - 3 * (self.VERTEX_NODE_COUNT + self.EDGE_NODE_COUNT)
|
|
177
|
-
)
|
|
178
|
-
|
|
179
|
-
self.VERTEX_NODE_COUNT = wp.constant(1)
|
|
180
|
-
self.EDGE_NODE_COUNT = wp.constant(degree - 1)
|
|
181
|
-
self.FACE_NODE_COUNT = wp.constant(max(0, degree - 2) * max(0, degree - 1) // 2)
|
|
182
|
-
self.INERIOR_NODE_COUNT = wp.constant(max(0, degree - 1) * max(0, degree - 2) * max(0, degree - 3) // 6)
|
|
183
|
-
|
|
184
177
|
tet_coords = np.empty((self.NODES_PER_ELEMENT, 3), dtype=int)
|
|
185
178
|
|
|
186
179
|
for tx in range(degree + 1):
|
|
@@ -19,6 +19,7 @@ import warp as wp
|
|
|
19
19
|
from warp.context import type_str
|
|
20
20
|
from warp.jax import get_jax_device
|
|
21
21
|
from warp.types import array_t, launch_bounds_t, strides_from_shape
|
|
22
|
+
from warp.utils import warn
|
|
22
23
|
|
|
23
24
|
_jax_warp_p = None
|
|
24
25
|
|
|
@@ -28,7 +29,7 @@ _registered_kernels = [None]
|
|
|
28
29
|
_registered_kernel_to_id = {}
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
def jax_kernel(kernel, launch_dims=None):
|
|
32
|
+
def jax_kernel(kernel, launch_dims=None, quiet=False):
|
|
32
33
|
"""Create a Jax primitive from a Warp kernel.
|
|
33
34
|
|
|
34
35
|
NOTE: This is an experimental feature under development.
|
|
@@ -38,6 +39,7 @@ def jax_kernel(kernel, launch_dims=None):
|
|
|
38
39
|
launch_dims: Optional. Specify the kernel launch dimensions. If None,
|
|
39
40
|
dimensions are inferred from the shape of the first argument.
|
|
40
41
|
This option when set will specify the output dimensions.
|
|
42
|
+
quiet: Optional. If True, suppress deprecation warnings with newer JAX versions.
|
|
41
43
|
|
|
42
44
|
Limitations:
|
|
43
45
|
- All kernel arguments must be contiguous arrays.
|
|
@@ -46,6 +48,27 @@ def jax_kernel(kernel, launch_dims=None):
|
|
|
46
48
|
- Only the CUDA backend is supported.
|
|
47
49
|
"""
|
|
48
50
|
|
|
51
|
+
import jax
|
|
52
|
+
|
|
53
|
+
# check if JAX version supports this
|
|
54
|
+
if jax.__version_info__ < (0, 4, 25) or jax.__version_info__ >= (0, 8, 0):
|
|
55
|
+
msg = (
|
|
56
|
+
"This version of jax_kernel() requires JAX version 0.4.25 - 0.7.x, "
|
|
57
|
+
f"but installed JAX version is {jax.__version_info__}."
|
|
58
|
+
)
|
|
59
|
+
if jax.__version_info__ >= (0, 8, 0):
|
|
60
|
+
msg += " Please use warp.jax_experimental.ffi.jax_kernel instead."
|
|
61
|
+
raise RuntimeError(msg)
|
|
62
|
+
|
|
63
|
+
# deprecation warning
|
|
64
|
+
if jax.__version_info__ >= (0, 5, 0) and not quiet:
|
|
65
|
+
warn(
|
|
66
|
+
"This version of jax_kernel() is deprecated and will not be supported with newer JAX versions. "
|
|
67
|
+
"Please use the newer FFI version instead (warp.jax_experimental.ffi.jax_kernel). "
|
|
68
|
+
"In Warp release 1.10, the FFI version will become the default implementation of jax_kernel().",
|
|
69
|
+
DeprecationWarning,
|
|
70
|
+
)
|
|
71
|
+
|
|
49
72
|
if _jax_warp_p is None:
|
|
50
73
|
# Create and register the primitive
|
|
51
74
|
_create_jax_warp_primitive()
|
|
@@ -107,7 +130,7 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
|
|
|
107
130
|
assert hooks.forward, "Failed to find kernel entry point"
|
|
108
131
|
|
|
109
132
|
# Launch the kernel.
|
|
110
|
-
wp.context.runtime.core.
|
|
133
|
+
wp.context.runtime.core.wp_cuda_launch_kernel(
|
|
111
134
|
device.context, hooks.forward, bounds.size, 0, 256, hooks.forward_smem_bytes, kernel_params, stream
|
|
112
135
|
)
|
|
113
136
|
|
warp/jax_experimental/ffi.py
CHANGED
|
@@ -29,6 +29,18 @@ from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_war
|
|
|
29
29
|
from .xla_ffi import *
|
|
30
30
|
|
|
31
31
|
|
|
32
|
+
def check_jax_version():
|
|
33
|
+
# check if JAX version supports this
|
|
34
|
+
if jax.__version_info__ < (0, 5, 0):
|
|
35
|
+
msg = (
|
|
36
|
+
"This version of jax_kernel() requires JAX version 0.5.0 or higher, "
|
|
37
|
+
f"but installed JAX version is {jax.__version_info__}."
|
|
38
|
+
)
|
|
39
|
+
if jax.__version_info__ >= (0, 4, 25):
|
|
40
|
+
msg += " Please use warp.jax_experimental.custom_call.jax_kernel instead."
|
|
41
|
+
raise RuntimeError(msg)
|
|
42
|
+
|
|
43
|
+
|
|
32
44
|
class GraphMode(IntEnum):
|
|
33
45
|
NONE = 0 # don't capture a graph
|
|
34
46
|
JAX = 1 # let JAX capture a graph
|
|
@@ -317,7 +329,7 @@ class FfiKernel:
|
|
|
317
329
|
assert hooks.forward, "Failed to find kernel entry point"
|
|
318
330
|
|
|
319
331
|
# launch the kernel
|
|
320
|
-
wp.context.runtime.core.
|
|
332
|
+
wp.context.runtime.core.wp_cuda_launch_kernel(
|
|
321
333
|
device.context,
|
|
322
334
|
hooks.forward,
|
|
323
335
|
launch_bounds.size,
|
|
@@ -381,6 +393,7 @@ class FfiCallable:
|
|
|
381
393
|
if arg_name == "return":
|
|
382
394
|
if arg_type is not None:
|
|
383
395
|
raise TypeError("Function must not return a value")
|
|
396
|
+
continue
|
|
384
397
|
else:
|
|
385
398
|
arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
|
|
386
399
|
if arg_name in in_out_argnames:
|
|
@@ -667,8 +680,12 @@ def jax_kernel(
|
|
|
667
680
|
- There must be at least one output or input-output argument.
|
|
668
681
|
- Only the CUDA backend is supported.
|
|
669
682
|
"""
|
|
683
|
+
|
|
684
|
+
check_jax_version()
|
|
685
|
+
|
|
670
686
|
key = (
|
|
671
687
|
kernel.func,
|
|
688
|
+
kernel.sig,
|
|
672
689
|
num_outputs,
|
|
673
690
|
vmap_method,
|
|
674
691
|
tuple(launch_dims) if launch_dims else launch_dims,
|
|
@@ -725,6 +742,8 @@ def jax_callable(
|
|
|
725
742
|
- Only the CUDA backend is supported.
|
|
726
743
|
"""
|
|
727
744
|
|
|
745
|
+
check_jax_version()
|
|
746
|
+
|
|
728
747
|
if graph_compatible is not None:
|
|
729
748
|
wp.utils.warn(
|
|
730
749
|
"The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
|
|
@@ -771,6 +790,8 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr
|
|
|
771
790
|
graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
|
|
772
791
|
"""
|
|
773
792
|
|
|
793
|
+
check_jax_version()
|
|
794
|
+
|
|
774
795
|
# TODO check that the name is not already registered
|
|
775
796
|
|
|
776
797
|
def ffi_callback(call_frame):
|
warp/jax_experimental/xla_ffi.py
CHANGED
|
@@ -475,17 +475,26 @@ _xla_data_type_to_constructor = {
|
|
|
475
475
|
XLA_FFI_DataType.C64: jnp.complex64,
|
|
476
476
|
XLA_FFI_DataType.C128: jnp.complex128,
|
|
477
477
|
# XLA_FFI_DataType.TOKEN
|
|
478
|
-
XLA_FFI_DataType.F8E5M2: jnp.float8_e5m2,
|
|
479
|
-
XLA_FFI_DataType.F8E3M4: jnp.float8_e3m4,
|
|
480
|
-
XLA_FFI_DataType.F8E4M3: jnp.float8_e4m3,
|
|
481
|
-
XLA_FFI_DataType.F8E4M3FN: jnp.float8_e4m3fn,
|
|
482
|
-
XLA_FFI_DataType.F8E4M3B11FNUZ: jnp.float8_e4m3b11fnuz,
|
|
483
|
-
XLA_FFI_DataType.F8E5M2FNUZ: jnp.float8_e5m2fnuz,
|
|
484
|
-
XLA_FFI_DataType.F8E4M3FNUZ: jnp.float8_e4m3fnuz,
|
|
485
478
|
# XLA_FFI_DataType.F4E2M1FN: jnp.float4_e2m1fn.dtype,
|
|
486
479
|
# XLA_FFI_DataType.F8E8M0FNU: jnp.float8_e8m0fnu.dtype,
|
|
487
480
|
}
|
|
488
481
|
|
|
482
|
+
# newer types not supported by older versions
|
|
483
|
+
if hasattr(jnp, "float8_e5m2"):
|
|
484
|
+
_xla_data_type_to_constructor[XLA_FFI_DataType.F8E5M2] = jnp.float8_e5m2
|
|
485
|
+
if hasattr(jnp, "float8_e3m4"):
|
|
486
|
+
_xla_data_type_to_constructor[XLA_FFI_DataType.F8E3M4] = jnp.float8_e3m4
|
|
487
|
+
if hasattr(jnp, "float8_e4m3"):
|
|
488
|
+
_xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3] = jnp.float8_e4m3
|
|
489
|
+
if hasattr(jnp, "float8_e4m3fn"):
|
|
490
|
+
_xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3FN] = jnp.float8_e4m3fn
|
|
491
|
+
if hasattr(jnp, "float8_e4m3b11fnuz"):
|
|
492
|
+
_xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3B11FNUZ] = jnp.float8_e4m3b11fnuz
|
|
493
|
+
if hasattr(jnp, "float8_e5m2fnuz"):
|
|
494
|
+
_xla_data_type_to_constructor[XLA_FFI_DataType.F8E5M2FNUZ] = jnp.float8_e5m2fnuz
|
|
495
|
+
if hasattr(jnp, "float8_e4m3fnuz"):
|
|
496
|
+
_xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3FNUZ] = jnp.float8_e4m3fnuz
|
|
497
|
+
|
|
489
498
|
|
|
490
499
|
########################################################################
|
|
491
500
|
# Helpers for translating between ctypes and python types
|