warp-lang 1.2.2__py3-none-manylinux2014_aarch64.whl → 1.3.1__py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +8 -6
- warp/autograd.py +823 -0
- warp/bin/warp.so +0 -0
- warp/build.py +6 -2
- warp/builtins.py +1412 -888
- warp/codegen.py +503 -166
- warp/config.py +48 -18
- warp/context.py +400 -198
- warp/dlpack.py +8 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
- warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
- warp/examples/benchmarks/benchmark_launches.py +1 -1
- warp/examples/core/example_cupy.py +78 -0
- warp/examples/fem/example_apic_fluid.py +17 -36
- warp/examples/fem/example_burgers.py +9 -18
- warp/examples/fem/example_convection_diffusion.py +7 -17
- warp/examples/fem/example_convection_diffusion_dg.py +27 -47
- warp/examples/fem/example_deformed_geometry.py +11 -22
- warp/examples/fem/example_diffusion.py +7 -18
- warp/examples/fem/example_diffusion_3d.py +24 -28
- warp/examples/fem/example_diffusion_mgpu.py +7 -14
- warp/examples/fem/example_magnetostatics.py +190 -0
- warp/examples/fem/example_mixed_elasticity.py +111 -80
- warp/examples/fem/example_navier_stokes.py +30 -34
- warp/examples/fem/example_nonconforming_contact.py +290 -0
- warp/examples/fem/example_stokes.py +17 -32
- warp/examples/fem/example_stokes_transfer.py +12 -21
- warp/examples/fem/example_streamlines.py +350 -0
- warp/examples/fem/utils.py +936 -0
- warp/fabric.py +5 -2
- warp/fem/__init__.py +13 -3
- warp/fem/cache.py +161 -11
- warp/fem/dirichlet.py +37 -28
- warp/fem/domain.py +105 -14
- warp/fem/field/__init__.py +14 -3
- warp/fem/field/field.py +454 -11
- warp/fem/field/nodal_field.py +33 -18
- warp/fem/geometry/deformed_geometry.py +50 -15
- warp/fem/geometry/hexmesh.py +12 -24
- warp/fem/geometry/nanogrid.py +106 -31
- warp/fem/geometry/quadmesh_2d.py +6 -11
- warp/fem/geometry/tetmesh.py +103 -61
- warp/fem/geometry/trimesh_2d.py +98 -47
- warp/fem/integrate.py +231 -186
- warp/fem/operator.py +14 -9
- warp/fem/quadrature/pic_quadrature.py +35 -9
- warp/fem/quadrature/quadrature.py +119 -32
- warp/fem/space/basis_space.py +98 -22
- warp/fem/space/collocated_function_space.py +3 -1
- warp/fem/space/function_space.py +7 -2
- warp/fem/space/grid_2d_function_space.py +3 -3
- warp/fem/space/grid_3d_function_space.py +4 -4
- warp/fem/space/hexmesh_function_space.py +3 -2
- warp/fem/space/nanogrid_function_space.py +12 -14
- warp/fem/space/partition.py +45 -47
- warp/fem/space/restriction.py +19 -16
- warp/fem/space/shape/cube_shape_function.py +91 -3
- warp/fem/space/shape/shape_function.py +7 -0
- warp/fem/space/shape/square_shape_function.py +32 -0
- warp/fem/space/shape/tet_shape_function.py +11 -7
- warp/fem/space/shape/triangle_shape_function.py +10 -1
- warp/fem/space/topology.py +116 -42
- warp/fem/types.py +8 -1
- warp/fem/utils.py +301 -83
- warp/native/array.h +16 -0
- warp/native/builtin.h +0 -15
- warp/native/cuda_util.cpp +14 -6
- warp/native/exports.h +1348 -1308
- warp/native/quat.h +79 -0
- warp/native/rand.h +27 -4
- warp/native/sparse.cpp +83 -81
- warp/native/sparse.cu +381 -453
- warp/native/vec.h +64 -0
- warp/native/volume.cpp +40 -49
- warp/native/volume_builder.cu +2 -3
- warp/native/volume_builder.h +12 -17
- warp/native/warp.cu +3 -3
- warp/native/warp.h +69 -59
- warp/render/render_opengl.py +17 -9
- warp/sim/articulation.py +117 -17
- warp/sim/collide.py +35 -29
- warp/sim/model.py +123 -18
- warp/sim/render.py +3 -1
- warp/sparse.py +867 -203
- warp/stubs.py +312 -541
- warp/tape.py +29 -1
- warp/tests/disabled_kinematics.py +1 -1
- warp/tests/test_adam.py +1 -1
- warp/tests/test_arithmetic.py +1 -1
- warp/tests/test_array.py +58 -1
- warp/tests/test_array_reduce.py +1 -1
- warp/tests/test_async.py +1 -1
- warp/tests/test_atomic.py +1 -1
- warp/tests/test_bool.py +1 -1
- warp/tests/test_builtins_resolution.py +1 -1
- warp/tests/test_bvh.py +6 -1
- warp/tests/test_closest_point_edge_edge.py +1 -1
- warp/tests/test_codegen.py +91 -1
- warp/tests/test_compile_consts.py +1 -1
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_copy.py +1 -1
- warp/tests/test_ctypes.py +1 -1
- warp/tests/test_dense.py +1 -1
- warp/tests/test_devices.py +1 -1
- warp/tests/test_dlpack.py +1 -1
- warp/tests/test_examples.py +33 -4
- warp/tests/test_fabricarray.py +5 -2
- warp/tests/test_fast_math.py +1 -1
- warp/tests/test_fem.py +213 -6
- warp/tests/test_fp16.py +1 -1
- warp/tests/test_func.py +1 -1
- warp/tests/test_future_annotations.py +90 -0
- warp/tests/test_generics.py +1 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +1 -1
- warp/tests/test_grad_debug.py +247 -0
- warp/tests/test_hash_grid.py +6 -1
- warp/tests/test_implicit_init.py +354 -0
- warp/tests/test_import.py +1 -1
- warp/tests/test_indexedarray.py +1 -1
- warp/tests/test_intersect.py +1 -1
- warp/tests/test_jax.py +1 -1
- warp/tests/test_large.py +1 -1
- warp/tests/test_launch.py +1 -1
- warp/tests/test_lerp.py +1 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_lvalue.py +1 -1
- warp/tests/test_marching_cubes.py +5 -2
- warp/tests/test_mat.py +34 -35
- warp/tests/test_mat_lite.py +2 -1
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_math.py +1 -1
- warp/tests/test_matmul.py +20 -16
- warp/tests/test_matmul_lite.py +1 -1
- warp/tests/test_mempool.py +1 -1
- warp/tests/test_mesh.py +5 -2
- warp/tests/test_mesh_query_aabb.py +1 -1
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_mesh_query_ray.py +1 -1
- warp/tests/test_mlp.py +1 -1
- warp/tests/test_model.py +1 -1
- warp/tests/test_module_hashing.py +77 -1
- warp/tests/test_modules_lite.py +1 -1
- warp/tests/test_multigpu.py +1 -1
- warp/tests/test_noise.py +1 -1
- warp/tests/test_operators.py +1 -1
- warp/tests/test_options.py +1 -1
- warp/tests/test_overwrite.py +542 -0
- warp/tests/test_peer.py +1 -1
- warp/tests/test_pinned.py +1 -1
- warp/tests/test_print.py +1 -1
- warp/tests/test_quat.py +15 -1
- warp/tests/test_rand.py +1 -1
- warp/tests/test_reload.py +1 -1
- warp/tests/test_rounding.py +1 -1
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +95 -0
- warp/tests/test_sim_grad.py +1 -1
- warp/tests/test_sim_kinematics.py +1 -1
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +82 -15
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_special_values.py +2 -11
- warp/tests/test_streams.py +11 -1
- warp/tests/test_struct.py +1 -1
- warp/tests/test_tape.py +1 -1
- warp/tests/test_torch.py +194 -1
- warp/tests/test_transient_module.py +1 -1
- warp/tests/test_types.py +1 -1
- warp/tests/test_utils.py +1 -1
- warp/tests/test_vec.py +15 -63
- warp/tests/test_vec_lite.py +2 -1
- warp/tests/test_vec_scalar_ops.py +65 -1
- warp/tests/test_verify_fp.py +1 -1
- warp/tests/test_volume.py +28 -2
- warp/tests/test_volume_write.py +1 -1
- warp/tests/unittest_serial.py +1 -1
- warp/tests/unittest_suites.py +9 -1
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +2 -5
- warp/torch.py +103 -41
- warp/types.py +341 -224
- warp/utils.py +11 -2
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/METADATA +99 -46
- warp_lang-1.3.1.dist-info/RECORD +368 -0
- warp/examples/fem/bsr_utils.py +0 -378
- warp/examples/fem/mesh_utils.py +0 -133
- warp/examples/fem/plot_utils.py +0 -292
- warp_lang-1.2.2.dist-info/RECORD +0 -359
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/top_level.txt +0 -0
warp/fem/operator.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
import inspect
|
|
2
1
|
from typing import Any, Callable
|
|
3
2
|
|
|
4
3
|
import warp as wp
|
|
5
4
|
from warp.fem import utils
|
|
6
|
-
from warp.fem.types import Domain, Field, Sample
|
|
5
|
+
from warp.fem.types import Domain, Field, NodeIndex, Sample
|
|
7
6
|
|
|
8
7
|
|
|
9
8
|
class Integrand:
|
|
@@ -15,7 +14,7 @@ class Integrand:
|
|
|
15
14
|
self.func = func
|
|
16
15
|
self.name = wp.codegen.make_full_qualified_name(self.func)
|
|
17
16
|
self.module = wp.get_module(self.func.__module__)
|
|
18
|
-
self.argspec = inspect.getfullargspec(self.func)
|
|
17
|
+
self.argspec = wp.codegen.get_full_arg_spec(self.func)
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
class Operator:
|
|
@@ -55,7 +54,7 @@ def position(domain: Domain, s: Sample):
|
|
|
55
54
|
pass
|
|
56
55
|
|
|
57
56
|
|
|
58
|
-
@operator(resolver=lambda dmn: dmn.
|
|
57
|
+
@operator(resolver=lambda dmn: dmn.element_normal)
|
|
59
58
|
def normal(domain: Domain, s: Sample):
|
|
60
59
|
"""Evaluates the element normal at the sample point `s`. Null for interior points."""
|
|
61
60
|
pass
|
|
@@ -71,13 +70,12 @@ def deformation_gradient(domain: Domain, s: Sample):
|
|
|
71
70
|
def lookup(domain: Domain, x: Any) -> Sample:
|
|
72
71
|
"""Looks-up the sample point corresponding to a world position `x`, projecting to the closest point on the domain.
|
|
73
72
|
|
|
74
|
-
|
|
73
|
+
Args:
|
|
75
74
|
x: world position of the point to look-up in the geometry
|
|
76
75
|
guess: (optional) :class:`Sample` initial guess, may help perform the query
|
|
77
76
|
|
|
78
|
-
|
|
79
|
-
Currently this operator is
|
|
80
|
-
For :class:`TriangleMesh2D` and :class:`Tetmesh` geometries, the operator requires providing `guess`.
|
|
77
|
+
Note:
|
|
78
|
+
Currently this operator is unsupported for :class:`Hexmesh`, :class:`Quadmesh2D` and deformed geometries.
|
|
81
79
|
"""
|
|
82
80
|
pass
|
|
83
81
|
|
|
@@ -142,7 +140,14 @@ def degree(f: Field):
|
|
|
142
140
|
|
|
143
141
|
@operator(resolver=lambda f: f.at_node)
|
|
144
142
|
def at_node(f: Field, s: Sample):
|
|
145
|
-
"""For a Test or Trial field
|
|
143
|
+
"""For a Test or Trial field `f`, returns a copy of the Sample `s` moved to the coordinates of the node being evaluated"""
|
|
144
|
+
pass
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@operator(resolver=lambda f: f.node_partition_index)
|
|
148
|
+
def node_partition_index(f: Field, node_index: NodeIndex):
|
|
149
|
+
"""For a NodalField `f`, returns the index of a given node in the fields's space partition,
|
|
150
|
+
or ``NULL_NODE_INDEX`` if it does not exists"""
|
|
146
151
|
pass
|
|
147
152
|
|
|
148
153
|
|
|
warp/fem/quadrature/pic_quadrature.py
CHANGED
|
@@ -8,8 +8,6 @@ from warp.fem.utils import compress_node_indices
|
|
|
8
8
|
|
|
9
9
|
from .quadrature import Quadrature
|
|
10
10
|
|
|
11
|
-
wp.set_module_options({"enable_backward": False})
|
|
12
|
-
|
|
13
11
|
|
|
14
12
|
class PicQuadrature(Quadrature):
|
|
15
13
|
"""Particle-based quadrature formula, using a global set of points unevenly spread out over geometry elements.
|
|
@@ -23,6 +21,7 @@ class PicQuadrature(Quadrature):
|
|
|
23
21
|
define a global :meth:`Geometry.cell_lookup` method; currently this is only available for :class:`Grid2D` and :class:`Grid3D`.
|
|
24
22
|
measures: Array containing the measure (area/volume) of each particle, used to defined the integration weights.
|
|
25
23
|
If ``None``, defaults to the cell measure divided by the number of particles in the cell.
|
|
24
|
+
requires_grad: Whether gradients should be allocated for the computed quantities
|
|
26
25
|
temporary_store: shared pool from which to allocate temporary arrays
|
|
27
26
|
"""
|
|
28
27
|
|
|
@@ -37,11 +36,14 @@ class PicQuadrature(Quadrature):
|
|
|
37
36
|
],
|
|
38
37
|
],
|
|
39
38
|
measures: Optional["wp.array(dtype=float)"] = None,
|
|
39
|
+
requires_grad: bool = False,
|
|
40
40
|
temporary_store: TemporaryStore = None,
|
|
41
41
|
):
|
|
42
42
|
super().__init__(domain)
|
|
43
43
|
|
|
44
|
+
self._requires_grad = requires_grad
|
|
44
45
|
self._bin_particles(positions, measures, temporary_store)
|
|
46
|
+
self._max_particles_per_cell: int = None
|
|
45
47
|
|
|
46
48
|
@property
|
|
47
49
|
def name(self):
|
|
@@ -82,22 +84,40 @@ class PicQuadrature(Quadrature):
|
|
|
82
84
|
"""Number of cells containing at least one particle"""
|
|
83
85
|
return self._cell_count
|
|
84
86
|
|
|
87
|
+
def max_points_per_element(self):
|
|
88
|
+
if self._max_particles_per_cell is None:
|
|
89
|
+
max_ppc = wp.zeros(shape=(1,), dtype=int, device=self._cell_particle_offsets.array.device)
|
|
90
|
+
wp.launch(
|
|
91
|
+
PicQuadrature._max_particles_per_cell_kernel,
|
|
92
|
+
self._cell_particle_offsets.array.shape[0] - 1,
|
|
93
|
+
device=max_ppc.device,
|
|
94
|
+
inputs=[self._cell_particle_offsets.array, max_ppc],
|
|
95
|
+
)
|
|
96
|
+
self._max_particles_per_cell = int(max_ppc.numpy()[0])
|
|
97
|
+
return self._max_particles_per_cell
|
|
98
|
+
|
|
85
99
|
@wp.func
|
|
86
|
-
def point_count(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex):
|
|
100
|
+
def point_count(elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex):
|
|
87
101
|
return qp_arg.cell_particle_offsets[element_index + 1] - qp_arg.cell_particle_offsets[element_index]
|
|
88
102
|
|
|
89
103
|
@wp.func
|
|
90
|
-
def point_coords(
|
|
104
|
+
def point_coords(
|
|
105
|
+
elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
|
|
106
|
+
):
|
|
91
107
|
particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
|
|
92
108
|
return qp_arg.particle_coords[particle_index]
|
|
93
109
|
|
|
94
110
|
@wp.func
|
|
95
|
-
def point_weight(
|
|
111
|
+
def point_weight(
|
|
112
|
+
elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
|
|
113
|
+
):
|
|
96
114
|
particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
|
|
97
115
|
return qp_arg.particle_fraction[particle_index]
|
|
98
116
|
|
|
99
117
|
@wp.func
|
|
100
|
-
def point_index(
|
|
118
|
+
def point_index(
|
|
119
|
+
elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
|
|
120
|
+
):
|
|
101
121
|
particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
|
|
102
122
|
return particle_index
|
|
103
123
|
|
|
@@ -158,7 +178,7 @@ class PicQuadrature(Quadrature):
|
|
|
158
178
|
cell_index = cell_index_temp.array
|
|
159
179
|
|
|
160
180
|
self._particle_coords_temp = borrow_temporary(
|
|
161
|
-
temporary_store, shape=positions.shape, dtype=Coords, device=device
|
|
181
|
+
temporary_store, shape=positions.shape, dtype=Coords, device=device, requires_grad=self._requires_grad
|
|
162
182
|
)
|
|
163
183
|
self._particle_coords = self._particle_coords_temp.array
|
|
164
184
|
|
|
@@ -183,7 +203,7 @@ class PicQuadrature(Quadrature):
|
|
|
183
203
|
self._particle_coords_temp = None
|
|
184
204
|
|
|
185
205
|
self._cell_particle_offsets, self._cell_particle_indices, self._cell_count, _ = compress_node_indices(
|
|
186
|
-
self.domain.geometry_element_count(), cell_index
|
|
206
|
+
self.domain.geometry_element_count(), cell_index, return_unique_nodes=True, temporary_store=temporary_store
|
|
187
207
|
)
|
|
188
208
|
|
|
189
209
|
self._compute_fraction(cell_index, measures, temporary_store)
|
|
@@ -192,7 +212,7 @@ class PicQuadrature(Quadrature):
|
|
|
192
212
|
device = cell_index.device
|
|
193
213
|
|
|
194
214
|
self._particle_fraction_temp = borrow_temporary(
|
|
195
|
-
temporary_store, shape=cell_index.shape, dtype=float, device=device
|
|
215
|
+
temporary_store, shape=cell_index.shape, dtype=float, device=device, requires_grad=self._requires_grad
|
|
196
216
|
)
|
|
197
217
|
self._particle_fraction = self._particle_fraction_temp.array
|
|
198
218
|
|
|
@@ -241,3 +261,9 @@ class PicQuadrature(Quadrature):
|
|
|
241
261
|
],
|
|
242
262
|
device=device,
|
|
243
263
|
)
|
|
264
|
+
|
|
265
|
+
@wp.kernel
|
|
266
|
+
def _max_particles_per_cell_kernel(offsets: wp.array(dtype=int), max_count: wp.array(dtype=int)):
|
|
267
|
+
cell = wp.tid()
|
|
268
|
+
particle_count = offsets[cell + 1] - offsets[cell]
|
|
269
|
+
wp.atomic_max(max_count, 0, particle_count)
|
|
warp/fem/quadrature/quadrature.py
CHANGED
|
@@ -36,8 +36,8 @@ class Quadrature:
|
|
|
36
36
|
"""Total number of quadrature points over the domain"""
|
|
37
37
|
raise NotImplementedError()
|
|
38
38
|
|
|
39
|
-
def
|
|
40
|
-
"""
|
|
39
|
+
def max_points_per_element(self):
|
|
40
|
+
"""Maximum number of points per element if known, or ``None`` otherwise"""
|
|
41
41
|
return None
|
|
42
42
|
|
|
43
43
|
@staticmethod
|
|
@@ -61,7 +61,11 @@ class Quadrature:
|
|
|
61
61
|
|
|
62
62
|
@staticmethod
|
|
63
63
|
def point_index(
|
|
64
|
-
elt_arg: "domain.GeometryDomain.ElementArg",
|
|
64
|
+
elt_arg: "domain.GeometryDomain.ElementArg",
|
|
65
|
+
qp_arg: Arg,
|
|
66
|
+
domain_element_index: ElementIndex,
|
|
67
|
+
geo_element_index: ElementIndex,
|
|
68
|
+
element_qp_index: int,
|
|
65
69
|
):
|
|
66
70
|
"""Global index of the element's qp_index'th quadrature point"""
|
|
67
71
|
raise NotImplementedError()
|
|
@@ -106,7 +110,7 @@ class RegularQuadrature(Quadrature):
|
|
|
106
110
|
def total_point_count(self):
|
|
107
111
|
return len(self.points) * self.domain.geometry_element_count()
|
|
108
112
|
|
|
109
|
-
def
|
|
113
|
+
def max_points_per_element(self):
|
|
110
114
|
return self._N
|
|
111
115
|
|
|
112
116
|
@property
|
|
@@ -121,7 +125,12 @@ class RegularQuadrature(Quadrature):
|
|
|
121
125
|
N = self._N
|
|
122
126
|
|
|
123
127
|
@cache.dynamic_func(suffix=self.name)
|
|
124
|
-
def point_count(
|
|
128
|
+
def point_count(
|
|
129
|
+
elt_arg: self.domain.ElementArg,
|
|
130
|
+
qp_arg: self.Arg,
|
|
131
|
+
domain_element_index: ElementIndex,
|
|
132
|
+
element_index: ElementIndex,
|
|
133
|
+
):
|
|
125
134
|
return N
|
|
126
135
|
|
|
127
136
|
return point_count
|
|
@@ -130,7 +139,13 @@ class RegularQuadrature(Quadrature):
|
|
|
130
139
|
POINTS = self._POINTS
|
|
131
140
|
|
|
132
141
|
@cache.dynamic_func(suffix=self.name)
|
|
133
|
-
def point_coords(
|
|
142
|
+
def point_coords(
|
|
143
|
+
elt_arg: self.domain.ElementArg,
|
|
144
|
+
qp_arg: self.Arg,
|
|
145
|
+
domain_element_index: ElementIndex,
|
|
146
|
+
element_index: ElementIndex,
|
|
147
|
+
qp_index: int,
|
|
148
|
+
):
|
|
134
149
|
return Coords(POINTS[qp_index, 0], POINTS[qp_index, 1], POINTS[qp_index, 2])
|
|
135
150
|
|
|
136
151
|
return point_coords
|
|
@@ -139,7 +154,13 @@ class RegularQuadrature(Quadrature):
|
|
|
139
154
|
WEIGHTS = self._WEIGHTS
|
|
140
155
|
|
|
141
156
|
@cache.dynamic_func(suffix=self.name)
|
|
142
|
-
def point_weight(
|
|
157
|
+
def point_weight(
|
|
158
|
+
elt_arg: self.domain.ElementArg,
|
|
159
|
+
qp_arg: self.Arg,
|
|
160
|
+
domain_element_index: ElementIndex,
|
|
161
|
+
element_index: ElementIndex,
|
|
162
|
+
qp_index: int,
|
|
163
|
+
):
|
|
143
164
|
return WEIGHTS[qp_index]
|
|
144
165
|
|
|
145
166
|
return point_weight
|
|
@@ -148,8 +169,14 @@ class RegularQuadrature(Quadrature):
|
|
|
148
169
|
N = self._N
|
|
149
170
|
|
|
150
171
|
@cache.dynamic_func(suffix=self.name)
|
|
151
|
-
def point_index(
|
|
152
|
-
|
|
172
|
+
def point_index(
|
|
173
|
+
elt_arg: self.domain.ElementArg,
|
|
174
|
+
qp_arg: self.Arg,
|
|
175
|
+
domain_element_index: ElementIndex,
|
|
176
|
+
element_index: ElementIndex,
|
|
177
|
+
qp_index: int,
|
|
178
|
+
):
|
|
179
|
+
return N * domain_element_index + qp_index
|
|
153
180
|
|
|
154
181
|
return point_index
|
|
155
182
|
|
|
@@ -157,8 +184,8 @@ class RegularQuadrature(Quadrature):
|
|
|
157
184
|
class NodalQuadrature(Quadrature):
|
|
158
185
|
"""Quadrature using space node points as quadrature points
|
|
159
186
|
|
|
160
|
-
Note that in contrast to the `nodal=True` flag for :func:`integrate`, this quadrature
|
|
161
|
-
about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
|
|
187
|
+
Note that in contrast to the `nodal=True` flag for :func:`integrate`, using this quadrature does not imply
|
|
188
|
+
any assumption about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
|
|
162
189
|
"""
|
|
163
190
|
|
|
164
191
|
def __init__(self, domain: domain.GeometryDomain, space: FunctionSpace):
|
|
@@ -180,8 +207,8 @@ class NodalQuadrature(Quadrature):
|
|
|
180
207
|
def total_point_count(self):
|
|
181
208
|
return self._space.node_count()
|
|
182
209
|
|
|
183
|
-
def
|
|
184
|
-
return self._space.topology.
|
|
210
|
+
def max_points_per_element(self):
|
|
211
|
+
return self._space.topology.MAX_NODES_PER_ELEMENT
|
|
185
212
|
|
|
186
213
|
def _make_arg(self):
|
|
187
214
|
@cache.dynamic_struct(suffix=self.name)
|
|
@@ -199,44 +226,67 @@ class NodalQuadrature(Quadrature):
|
|
|
199
226
|
return arg
|
|
200
227
|
|
|
201
228
|
def _make_point_count(self):
|
|
202
|
-
N = self._space.topology.NODES_PER_ELEMENT
|
|
203
|
-
|
|
204
229
|
@cache.dynamic_func(suffix=self.name)
|
|
205
|
-
def point_count(
|
|
206
|
-
|
|
230
|
+
def point_count(
|
|
231
|
+
elt_arg: self.domain.ElementArg,
|
|
232
|
+
qp_arg: self.Arg,
|
|
233
|
+
domain_element_index: ElementIndex,
|
|
234
|
+
element_index: ElementIndex,
|
|
235
|
+
):
|
|
236
|
+
return self._space.topology.element_node_count(elt_arg, qp_arg.topo_arg, element_index)
|
|
207
237
|
|
|
208
238
|
return point_count
|
|
209
239
|
|
|
210
240
|
def _make_point_coords(self):
|
|
211
241
|
@cache.dynamic_func(suffix=self.name)
|
|
212
|
-
def point_coords(
|
|
242
|
+
def point_coords(
|
|
243
|
+
elt_arg: self.domain.ElementArg,
|
|
244
|
+
qp_arg: self.Arg,
|
|
245
|
+
domain_element_index: ElementIndex,
|
|
246
|
+
element_index: ElementIndex,
|
|
247
|
+
qp_index: int,
|
|
248
|
+
):
|
|
213
249
|
return self._space.node_coords_in_element(elt_arg, qp_arg.space_arg, element_index, qp_index)
|
|
214
250
|
|
|
215
251
|
return point_coords
|
|
216
252
|
|
|
217
253
|
def _make_point_weight(self):
|
|
218
254
|
@cache.dynamic_func(suffix=self.name)
|
|
219
|
-
def point_weight(
|
|
255
|
+
def point_weight(
|
|
256
|
+
elt_arg: self.domain.ElementArg,
|
|
257
|
+
qp_arg: self.Arg,
|
|
258
|
+
domain_element_index: ElementIndex,
|
|
259
|
+
element_index: ElementIndex,
|
|
260
|
+
qp_index: int,
|
|
261
|
+
):
|
|
220
262
|
return self._space.node_quadrature_weight(elt_arg, qp_arg.space_arg, element_index, qp_index)
|
|
221
263
|
|
|
222
264
|
return point_weight
|
|
223
265
|
|
|
224
266
|
def _make_point_index(self):
|
|
225
267
|
@cache.dynamic_func(suffix=self.name)
|
|
226
|
-
def point_index(
|
|
268
|
+
def point_index(
|
|
269
|
+
elt_arg: self.domain.ElementArg,
|
|
270
|
+
qp_arg: self.Arg,
|
|
271
|
+
domain_element_index: ElementIndex,
|
|
272
|
+
element_index: ElementIndex,
|
|
273
|
+
qp_index: int,
|
|
274
|
+
):
|
|
227
275
|
return self._space.topology.element_node_index(elt_arg, qp_arg.topo_arg, element_index, qp_index)
|
|
228
276
|
|
|
229
277
|
return point_index
|
|
230
278
|
|
|
231
279
|
|
|
232
280
|
class ExplicitQuadrature(Quadrature):
|
|
233
|
-
"""Quadrature using explicit per-cell points and weights.
|
|
234
|
-
|
|
281
|
+
"""Quadrature using explicit per-cell points and weights.
|
|
282
|
+
|
|
283
|
+
The number of quadrature points per cell is assumed to be constant and deduced from the shape of the points and weights arrays.
|
|
284
|
+
Quadrature points may be provided for either the whole geometry or just the domain's elements.
|
|
235
285
|
|
|
236
286
|
Args:
|
|
237
287
|
domain: Domain of definition of the quadrature formula
|
|
238
|
-
points: 2d array of shape ``(domain.
|
|
239
|
-
weights: 2d array of shape ``(domain.
|
|
288
|
+
points: 2d array of shape ``(domain.element_count(), points_per_cell)`` or ``(domain.geometry_element_count(), points_per_cell)`` containing the coordinates of each quadrature point.
|
|
289
|
+
weights: 2d array of shape ``(domain.element_count(), points_per_cell)`` or ``(domain.geometry_element_count(), points_per_cell)`` containing the weight for each quadrature point.
|
|
240
290
|
|
|
241
291
|
See also: :class:`PicQuadrature`
|
|
242
292
|
"""
|
|
@@ -255,41 +305,78 @@ class ExplicitQuadrature(Quadrature):
|
|
|
255
305
|
if points.shape != weights.shape:
|
|
256
306
|
raise ValueError("Points and weights arrays must have the same shape")
|
|
257
307
|
|
|
308
|
+
if points.shape[0] == domain.geometry_element_count():
|
|
309
|
+
self.point_index = ExplicitQuadrature._point_index_geo
|
|
310
|
+
self.point_coords = ExplicitQuadrature._point_coords_geo
|
|
311
|
+
self.point_weight = ExplicitQuadrature._point_weight_geo
|
|
312
|
+
elif points.shape[0] == domain.element_count():
|
|
313
|
+
self.point_index = ExplicitQuadrature._point_index_domain
|
|
314
|
+
self.point_coords = ExplicitQuadrature._point_coords_domain
|
|
315
|
+
self.point_weight = ExplicitQuadrature._point_weight_domain
|
|
316
|
+
else:
|
|
317
|
+
raise NotImplementedError(
|
|
318
|
+
"The number of rows of points and weights must match the element count of either the domain or the geometry"
|
|
319
|
+
)
|
|
320
|
+
|
|
258
321
|
self._points_per_cell = points.shape[1]
|
|
322
|
+
self._whole_geo = points.shape[0] == domain.geometry_element_count()
|
|
259
323
|
self._points = points
|
|
260
324
|
self._weights = weights
|
|
261
325
|
|
|
262
326
|
@property
|
|
263
327
|
def name(self):
|
|
264
|
-
return f"{self.__class__.__name__}"
|
|
328
|
+
return f"{self.__class__.__name__}_{self._whole_geo}"
|
|
265
329
|
|
|
266
330
|
def total_point_count(self):
|
|
267
331
|
return self._weights.size
|
|
268
332
|
|
|
269
|
-
def
|
|
333
|
+
def max_points_per_element(self):
|
|
270
334
|
return self._points_per_cell
|
|
271
335
|
|
|
272
336
|
@cache.cached_arg_value
|
|
273
337
|
def arg_value(self, device):
|
|
274
338
|
arg = self.Arg()
|
|
275
|
-
arg.points_per_cell = self._points_per_cell
|
|
276
339
|
arg.points = self._points.to(device)
|
|
277
340
|
arg.weights = self._weights.to(device)
|
|
278
341
|
|
|
279
342
|
return arg
|
|
280
343
|
|
|
281
344
|
@wp.func
|
|
282
|
-
def point_count(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex):
|
|
283
|
-
return qp_arg.points_per_cell
|
|
345
|
+
def point_count(elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex):
|
|
346
|
+
return qp_arg.points.shape[1]
|
|
284
347
|
|
|
285
348
|
@wp.func
|
|
286
|
-
def
|
|
349
|
+
def _point_coords_domain(
|
|
350
|
+
elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
|
|
351
|
+
):
|
|
352
|
+
return qp_arg.points[domain_element_index, qp_index]
|
|
353
|
+
|
|
354
|
+
@wp.func
|
|
355
|
+
def _point_weight_domain(
|
|
356
|
+
elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
|
|
357
|
+
):
|
|
358
|
+
return qp_arg.weights[domain_element_index, qp_index]
|
|
359
|
+
|
|
360
|
+
@wp.func
|
|
361
|
+
def _point_index_domain(
|
|
362
|
+
elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
|
|
363
|
+
):
|
|
364
|
+
return qp_arg.points_per_cell * domain_element_index + qp_index
|
|
365
|
+
|
|
366
|
+
@wp.func
|
|
367
|
+
def _point_coords_geo(
|
|
368
|
+
elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
|
|
369
|
+
):
|
|
287
370
|
return qp_arg.points[element_index, qp_index]
|
|
288
371
|
|
|
289
372
|
@wp.func
|
|
290
|
-
def
|
|
373
|
+
def _point_weight_geo(
|
|
374
|
+
elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
|
|
375
|
+
):
|
|
291
376
|
return qp_arg.weights[element_index, qp_index]
|
|
292
377
|
|
|
293
378
|
@wp.func
|
|
294
|
-
def
|
|
379
|
+
def _point_index_geo(
|
|
380
|
+
elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
|
|
381
|
+
):
|
|
295
382
|
return qp_arg.points_per_cell * element_index + qp_index
|