warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- docs/conf.py +17 -5
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/env/env_usd.py +4 -1
- examples/env/environment.py +8 -9
- examples/example_dem.py +34 -33
- examples/example_diffray.py +364 -337
- examples/example_fluid.py +32 -23
- examples/example_jacobian_ik.py +97 -93
- examples/example_marching_cubes.py +6 -16
- examples/example_mesh.py +6 -16
- examples/example_mesh_intersect.py +16 -14
- examples/example_nvdb.py +14 -16
- examples/example_raycast.py +14 -13
- examples/example_raymarch.py +16 -23
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +82 -78
- examples/example_sim_cloth.py +45 -48
- examples/example_sim_fk_grad.py +51 -44
- examples/example_sim_fk_grad_torch.py +47 -40
- examples/example_sim_grad_bounce.py +108 -133
- examples/example_sim_grad_cloth.py +99 -113
- examples/example_sim_granular.py +5 -6
- examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
- examples/example_sim_neo_hookean.py +51 -55
- examples/example_sim_particle_chain.py +4 -4
- examples/example_sim_quadruped.py +126 -81
- examples/example_sim_rigid_chain.py +54 -61
- examples/example_sim_rigid_contact.py +66 -70
- examples/example_sim_rigid_fem.py +3 -3
- examples/example_sim_rigid_force.py +1 -1
- examples/example_sim_rigid_gyroscopic.py +3 -4
- examples/example_sim_rigid_kinematics.py +28 -39
- examples/example_sim_trajopt.py +112 -110
- examples/example_sph.py +9 -8
- examples/example_wave.py +7 -7
- examples/fem/bsr_utils.py +30 -17
- examples/fem/example_apic_fluid.py +85 -69
- examples/fem/example_convection_diffusion.py +97 -93
- examples/fem/example_convection_diffusion_dg.py +142 -149
- examples/fem/example_convection_diffusion_dg0.py +141 -136
- examples/fem/example_deformed_geometry.py +146 -0
- examples/fem/example_diffusion.py +115 -84
- examples/fem/example_diffusion_3d.py +116 -86
- examples/fem/example_diffusion_mgpu.py +102 -79
- examples/fem/example_mixed_elasticity.py +139 -100
- examples/fem/example_navier_stokes.py +175 -162
- examples/fem/example_stokes.py +143 -111
- examples/fem/example_stokes_transfer.py +186 -157
- examples/fem/mesh_utils.py +59 -97
- examples/fem/plot_utils.py +138 -17
- tools/ci/publishing/build_nodes_info.py +54 -0
- warp/__init__.py +4 -3
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +836 -492
- warp/codegen.py +864 -553
- warp/config.py +3 -1
- warp/context.py +389 -172
- warp/fem/__init__.py +24 -6
- warp/fem/cache.py +318 -25
- warp/fem/dirichlet.py +7 -3
- warp/fem/domain.py +14 -0
- warp/fem/field/__init__.py +30 -38
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +244 -138
- warp/fem/field/restriction.py +8 -6
- warp/fem/field/test.py +127 -59
- warp/fem/field/trial.py +117 -60
- warp/fem/geometry/__init__.py +5 -1
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +24 -1
- warp/fem/geometry/geometry.py +86 -14
- warp/fem/geometry/grid_2d.py +112 -54
- warp/fem/geometry/grid_3d.py +134 -65
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +85 -33
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +451 -115
- warp/fem/geometry/trimesh_2d.py +197 -92
- warp/fem/integrate.py +534 -268
- warp/fem/operator.py +58 -31
- warp/fem/polynomial.py +11 -0
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +150 -58
- warp/fem/quadrature/quadrature.py +209 -57
- warp/fem/space/__init__.py +230 -53
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +49 -2
- warp/fem/space/function_space.py +90 -39
- warp/fem/space/grid_2d_function_space.py +149 -496
- warp/fem/space/grid_3d_function_space.py +173 -538
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +129 -76
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +46 -34
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +132 -1039
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +104 -742
- warp/fem/types.py +13 -11
- warp/fem/utils.py +335 -60
- warp/native/array.h +120 -34
- warp/native/builtin.h +101 -72
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +22 -40
- warp/native/clang/clang.cpp +1 -0
- warp/native/crt.h +2 -0
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1522 -1243
- warp/native/intersect.h +19 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +76 -17
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -18
- warp/native/mesh.h +395 -40
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +44 -34
- warp/native/reduce.cpp +1 -1
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +163 -155
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +18 -14
- warp/native/vec.h +103 -21
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +28 -3
- warp/native/warp.h +4 -3
- warp/render/render_opengl.py +261 -109
- warp/sim/__init__.py +1 -2
- warp/sim/articulation.py +385 -185
- warp/sim/import_mjcf.py +59 -48
- warp/sim/import_urdf.py +15 -15
- warp/sim/import_usd.py +174 -102
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_xpbd.py +4 -3
- warp/sim/model.py +330 -250
- warp/sim/render.py +1 -1
- warp/sparse.py +625 -152
- warp/stubs.py +341 -309
- warp/tape.py +9 -6
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +94 -74
- warp/tests/test_array.py +82 -101
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +22 -12
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +18 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +165 -134
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +237 -0
- warp/tests/test_fabricarray.py +22 -24
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1034 -124
- warp/tests/test_fp16.py +23 -16
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +123 -181
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +35 -34
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +24 -25
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +304 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +60 -22
- warp/tests/test_mesh_query_aabb.py +21 -25
- warp/tests/test_mesh_query_point.py +111 -22
- warp/tests/test_mesh_query_ray.py +12 -24
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +90 -86
- warp/tests/test_transient_module.py +10 -12
- warp/tests/test_types.py +363 -0
- warp/tests/test_utils.py +451 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +418 -376
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/unittest_utils.py +342 -0
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +589 -0
- warp/types.py +622 -211
- warp/utils.py +54 -393
- warp_lang-1.0.0b6.dist-info/METADATA +238 -0
- warp_lang-1.0.0b6.dist-info/RECORD +409 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- examples/example_cache_management.py +0 -40
- examples/example_multigpu.py +0 -54
- examples/example_struct.py +0 -65
- examples/fem/example_stokes_transfer_3d.py +0 -210
- warp/fem/field/discrete_field.py +0 -80
- warp/fem/space/nodal_function_space.py +0 -233
- warp/tests/test_all.py +0 -223
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-1.0.0b2.dist-info/METADATA +0 -26
- warp_lang-1.0.0b2.dist-info/RECORD +0 -378
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/fem/space/partition.py
CHANGED
|
@@ -1,13 +1,18 @@
|
|
|
1
|
-
from typing import Any, Optional
|
|
1
|
+
from typing import Any, Optional, Union
|
|
2
2
|
|
|
3
3
|
import warp as wp
|
|
4
|
-
|
|
4
|
+
from warp.fem.cache import (
|
|
5
|
+
TemporaryStore,
|
|
6
|
+
borrow_temporary,
|
|
7
|
+
borrow_temporary_like,
|
|
8
|
+
cached_arg_value,
|
|
9
|
+
)
|
|
5
10
|
from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
|
|
6
|
-
from warp.fem.utils import compress_node_indices, _iota_kernel
|
|
7
11
|
from warp.fem.types import NULL_NODE_INDEX
|
|
12
|
+
from warp.fem.utils import _iota_kernel, compress_node_indices
|
|
8
13
|
|
|
9
14
|
from .function_space import FunctionSpace
|
|
10
|
-
|
|
15
|
+
from .topology import SpaceTopology
|
|
11
16
|
|
|
12
17
|
wp.set_module_options({"enable_backward": False})
|
|
13
18
|
|
|
@@ -16,8 +21,8 @@ class SpacePartition:
|
|
|
16
21
|
class PartitionArg:
|
|
17
22
|
pass
|
|
18
23
|
|
|
19
|
-
def __init__(self,
|
|
20
|
-
self.
|
|
24
|
+
def __init__(self, space_topology: SpaceTopology, geo_partition: GeometryPartition):
|
|
25
|
+
self.space_topology = space_topology
|
|
21
26
|
self.geo_partition = geo_partition
|
|
22
27
|
|
|
23
28
|
def node_count(self):
|
|
@@ -35,7 +40,8 @@ class SpacePartition:
|
|
|
35
40
|
def partition_arg_value(self, device):
|
|
36
41
|
pass
|
|
37
42
|
|
|
38
|
-
|
|
43
|
+
@staticmethod
|
|
44
|
+
def partition_node_index(args: "PartitionArg", space_node_index: int):
|
|
39
45
|
"""Returns the index in the partition of a function space node, or -1 if it does not exist"""
|
|
40
46
|
|
|
41
47
|
def __str__(self) -> str:
|
|
@@ -51,28 +57,28 @@ class WholeSpacePartition(SpacePartition):
|
|
|
51
57
|
class PartitionArg:
|
|
52
58
|
pass
|
|
53
59
|
|
|
54
|
-
def __init__(self,
|
|
55
|
-
super().__init__(
|
|
60
|
+
def __init__(self, space_topology: SpaceTopology):
|
|
61
|
+
super().__init__(space_topology, WholeGeometryPartition(space_topology.geometry))
|
|
56
62
|
self._node_indices = None
|
|
57
63
|
|
|
58
64
|
def node_count(self):
|
|
59
65
|
"""Returns number of nodes in this partition"""
|
|
60
|
-
return self.
|
|
66
|
+
return self.space_topology.node_count()
|
|
61
67
|
|
|
62
68
|
def owned_node_count(self) -> int:
|
|
63
69
|
"""Returns number of nodes in this partition, excluding exterior halo"""
|
|
64
|
-
return self.
|
|
70
|
+
return self.space_topology.node_count()
|
|
65
71
|
|
|
66
72
|
def interior_node_count(self) -> int:
|
|
67
73
|
"""Returns number of interior nodes in this partition"""
|
|
68
|
-
return self.
|
|
74
|
+
return self.space_topology.node_count()
|
|
69
75
|
|
|
70
76
|
def space_node_indices(self):
|
|
71
77
|
"""Return the global function space indices for nodes in this partition"""
|
|
72
78
|
if self._node_indices is None:
|
|
73
|
-
self._node_indices =
|
|
74
|
-
wp.launch(kernel=_iota_kernel, dim=self.
|
|
75
|
-
return self._node_indices
|
|
79
|
+
self._node_indices = borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
|
|
80
|
+
wp.launch(kernel=_iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array, 1])
|
|
81
|
+
return self._node_indices.array
|
|
76
82
|
|
|
77
83
|
def partition_arg_value(self, device):
|
|
78
84
|
return WholeSpacePartition.PartitionArg()
|
|
@@ -82,7 +88,11 @@ class WholeSpacePartition(SpacePartition):
|
|
|
82
88
|
return space_node_index
|
|
83
89
|
|
|
84
90
|
def __eq__(self, other: SpacePartition) -> bool:
|
|
85
|
-
return isinstance(other, SpacePartition) and self.
|
|
91
|
+
return isinstance(other, SpacePartition) and self.space_topology == other.space_topology
|
|
92
|
+
|
|
93
|
+
@property
|
|
94
|
+
def name(self) -> str:
|
|
95
|
+
return "Whole"
|
|
86
96
|
|
|
87
97
|
|
|
88
98
|
class NodeCategory:
|
|
@@ -105,46 +115,56 @@ class NodePartition(SpacePartition):
|
|
|
105
115
|
class PartitionArg:
|
|
106
116
|
space_to_partition: wp.array(dtype=int)
|
|
107
117
|
|
|
108
|
-
def __init__(
|
|
109
|
-
|
|
118
|
+
def __init__(
|
|
119
|
+
self,
|
|
120
|
+
space_topology: SpaceTopology,
|
|
121
|
+
geo_partition: GeometryPartition,
|
|
122
|
+
with_halo: bool = True,
|
|
123
|
+
device=None,
|
|
124
|
+
temporary_store: TemporaryStore = None,
|
|
125
|
+
):
|
|
126
|
+
super().__init__(space_topology=space_topology, geo_partition=geo_partition)
|
|
110
127
|
|
|
111
|
-
self._compute_node_indices_from_sides(device, with_halo)
|
|
128
|
+
self._compute_node_indices_from_sides(device, with_halo, temporary_store)
|
|
112
129
|
|
|
113
130
|
def node_count(self) -> int:
|
|
114
131
|
"""Returns number of nodes referenced by this partition, including exterior halo"""
|
|
115
|
-
return int(self._category_offsets[NodeCategory.HALO_OTHER_SIDE + 1])
|
|
132
|
+
return int(self._category_offsets.array.numpy()[NodeCategory.HALO_OTHER_SIDE + 1])
|
|
116
133
|
|
|
117
134
|
def owned_node_count(self) -> int:
|
|
118
135
|
"""Returns number of nodes in this partition, excluding exterior halo"""
|
|
119
|
-
return int(self._category_offsets[NodeCategory.OWNED_FRONTIER + 1])
|
|
136
|
+
return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_FRONTIER + 1])
|
|
120
137
|
|
|
121
138
|
def interior_node_count(self) -> int:
|
|
122
139
|
"""Returns number of interior nodes in this partition"""
|
|
123
|
-
return int(self._category_offsets[NodeCategory.OWNED_INTERIOR + 1])
|
|
140
|
+
return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_INTERIOR + 1])
|
|
124
141
|
|
|
125
142
|
def space_node_indices(self):
|
|
126
143
|
"""Return the global function space indices for nodes in this partition"""
|
|
127
|
-
return self._node_indices
|
|
144
|
+
return self._node_indices.array
|
|
128
145
|
|
|
146
|
+
@cached_arg_value
|
|
129
147
|
def partition_arg_value(self, device):
|
|
130
148
|
arg = NodePartition.PartitionArg()
|
|
131
|
-
arg.space_to_partition = self._space_to_partition.to(device)
|
|
149
|
+
arg.space_to_partition = self._space_to_partition.array.to(device)
|
|
132
150
|
return arg
|
|
133
151
|
|
|
134
152
|
@wp.func
|
|
135
153
|
def partition_node_index(args: PartitionArg, space_node_index: int):
|
|
136
154
|
return args.space_to_partition[space_node_index]
|
|
137
155
|
|
|
138
|
-
def _compute_node_indices_from_sides(self, device, with_halo: bool):
|
|
156
|
+
def _compute_node_indices_from_sides(self, device, with_halo: bool, temporary_store: TemporaryStore):
|
|
139
157
|
from warp.fem import cache
|
|
140
158
|
|
|
141
|
-
|
|
142
|
-
NODES_PER_CELL = self.
|
|
143
|
-
NODES_PER_SIDE =
|
|
159
|
+
trace_topology = self.space_topology.trace()
|
|
160
|
+
NODES_PER_CELL = self.space_topology.NODES_PER_ELEMENT
|
|
161
|
+
NODES_PER_SIDE = trace_topology.NODES_PER_ELEMENT
|
|
144
162
|
|
|
145
|
-
|
|
163
|
+
@cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
|
|
164
|
+
def node_category_from_cells_kernel(
|
|
165
|
+
geo_arg: self.geo_partition.geometry.CellArg,
|
|
146
166
|
geo_partition_arg: self.geo_partition.CellArg,
|
|
147
|
-
space_arg: self.
|
|
167
|
+
space_arg: self.space_topology.TopologyArg,
|
|
148
168
|
node_mask: wp.array(dtype=int),
|
|
149
169
|
):
|
|
150
170
|
partition_cell_index = wp.tid()
|
|
@@ -152,12 +172,14 @@ class NodePartition(SpacePartition):
|
|
|
152
172
|
cell_index = self.geo_partition.cell_index(geo_partition_arg, partition_cell_index)
|
|
153
173
|
|
|
154
174
|
for n in range(NODES_PER_CELL):
|
|
155
|
-
space_nidx = self.
|
|
175
|
+
space_nidx = self.space_topology.element_node_index(geo_arg, space_arg, cell_index, n)
|
|
156
176
|
node_mask[space_nidx] = NodeCategory.OWNED_INTERIOR
|
|
157
177
|
|
|
158
|
-
|
|
178
|
+
@cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
|
|
179
|
+
def node_category_from_owned_sides_kernel(
|
|
180
|
+
geo_arg: self.geo_partition.geometry.SideArg,
|
|
159
181
|
geo_partition_arg: self.geo_partition.SideArg,
|
|
160
|
-
space_arg:
|
|
182
|
+
space_arg: trace_topology.TopologyArg,
|
|
161
183
|
node_mask: wp.array(dtype=int),
|
|
162
184
|
):
|
|
163
185
|
partition_side_index = wp.tid()
|
|
@@ -165,13 +187,16 @@ class NodePartition(SpacePartition):
|
|
|
165
187
|
side_index = self.geo_partition.side_index(geo_partition_arg, partition_side_index)
|
|
166
188
|
|
|
167
189
|
for n in range(NODES_PER_SIDE):
|
|
168
|
-
space_nidx =
|
|
190
|
+
space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
|
|
191
|
+
|
|
169
192
|
if node_mask[space_nidx] == NodeCategory.EXTERIOR:
|
|
170
193
|
node_mask[space_nidx] = NodeCategory.HALO_LOCAL_SIDE
|
|
171
194
|
|
|
172
|
-
|
|
195
|
+
@cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
|
|
196
|
+
def node_category_from_frontier_sides_kernel(
|
|
197
|
+
geo_arg: self.geo_partition.geometry.SideArg,
|
|
173
198
|
geo_partition_arg: self.geo_partition.SideArg,
|
|
174
|
-
space_arg:
|
|
199
|
+
space_arg: trace_topology.TopologyArg,
|
|
175
200
|
node_mask: wp.array(dtype=int),
|
|
176
201
|
):
|
|
177
202
|
frontier_side_index = wp.tid()
|
|
@@ -179,39 +204,28 @@ class NodePartition(SpacePartition):
|
|
|
179
204
|
side_index = self.geo_partition.frontier_side_index(geo_partition_arg, frontier_side_index)
|
|
180
205
|
|
|
181
206
|
for n in range(NODES_PER_SIDE):
|
|
182
|
-
space_nidx =
|
|
207
|
+
space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
|
|
183
208
|
if node_mask[space_nidx] == NodeCategory.EXTERIOR:
|
|
184
209
|
node_mask[space_nidx] = NodeCategory.HALO_OTHER_SIDE
|
|
185
210
|
elif node_mask[space_nidx] == NodeCategory.OWNED_INTERIOR:
|
|
186
211
|
node_mask[space_nidx] = NodeCategory.OWNED_FRONTIER
|
|
187
212
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
)
|
|
192
|
-
node_category_from_owned_sides_kernel = cache.get_kernel(
|
|
193
|
-
node_category_from_owned_sides_fn,
|
|
194
|
-
suffix=f"{self.geo_partition.name}_{self.space.name}",
|
|
195
|
-
)
|
|
196
|
-
node_category_from_frontier_sides_kernel = cache.get_kernel(
|
|
197
|
-
node_category_from_frontier_sides_fn,
|
|
198
|
-
suffix=f"{self.geo_partition.name}_{self.space.name}",
|
|
199
|
-
)
|
|
200
|
-
|
|
201
|
-
node_category = wp.empty(
|
|
202
|
-
shape=(self.space.node_count(),),
|
|
213
|
+
node_category = borrow_temporary(
|
|
214
|
+
temporary_store,
|
|
215
|
+
shape=(self.space_topology.node_count(),),
|
|
203
216
|
dtype=int,
|
|
204
217
|
device=device,
|
|
205
218
|
)
|
|
206
|
-
node_category.fill_(value=NodeCategory.EXTERIOR)
|
|
219
|
+
node_category.array.fill_(value=NodeCategory.EXTERIOR)
|
|
207
220
|
|
|
208
221
|
wp.launch(
|
|
209
222
|
dim=self.geo_partition.cell_count(),
|
|
210
223
|
kernel=node_category_from_cells_kernel,
|
|
211
224
|
inputs=[
|
|
225
|
+
self.geo_partition.geometry.cell_arg_value(device),
|
|
212
226
|
self.geo_partition.cell_arg_value(device),
|
|
213
|
-
self.
|
|
214
|
-
node_category,
|
|
227
|
+
self.space_topology.topo_arg_value(device),
|
|
228
|
+
node_category.array,
|
|
215
229
|
],
|
|
216
230
|
device=device,
|
|
217
231
|
)
|
|
@@ -221,9 +235,10 @@ class NodePartition(SpacePartition):
|
|
|
221
235
|
dim=self.geo_partition.side_count(),
|
|
222
236
|
kernel=node_category_from_owned_sides_kernel,
|
|
223
237
|
inputs=[
|
|
238
|
+
self.geo_partition.geometry.side_arg_value(device),
|
|
224
239
|
self.geo_partition.side_arg_value(device),
|
|
225
|
-
self.
|
|
226
|
-
node_category,
|
|
240
|
+
self.space_topology.topo_arg_value(device),
|
|
241
|
+
node_category.array,
|
|
227
242
|
],
|
|
228
243
|
device=device,
|
|
229
244
|
)
|
|
@@ -232,31 +247,52 @@ class NodePartition(SpacePartition):
|
|
|
232
247
|
dim=self.geo_partition.frontier_side_count(),
|
|
233
248
|
kernel=node_category_from_frontier_sides_kernel,
|
|
234
249
|
inputs=[
|
|
250
|
+
self.geo_partition.geometry.side_arg_value(device),
|
|
235
251
|
self.geo_partition.side_arg_value(device),
|
|
236
|
-
self.
|
|
237
|
-
node_category,
|
|
252
|
+
self.space_topology.topo_arg_value(device),
|
|
253
|
+
node_category.array,
|
|
238
254
|
],
|
|
239
255
|
device=device,
|
|
240
256
|
)
|
|
241
257
|
|
|
242
|
-
self._finalize_node_indices(node_category)
|
|
258
|
+
self._finalize_node_indices(node_category.array, temporary_store)
|
|
243
259
|
|
|
244
|
-
|
|
260
|
+
node_category.release()
|
|
261
|
+
|
|
262
|
+
def _finalize_node_indices(self, node_category: wp.array(dtype=int), temporary_store: TemporaryStore):
|
|
245
263
|
category_offsets, node_indices, _, __ = compress_node_indices(NodeCategory.COUNT, node_category)
|
|
246
|
-
self._category_offsets = category_offsets.numpy()
|
|
247
264
|
|
|
248
|
-
#
|
|
249
|
-
|
|
265
|
+
# Copy offsets to cpu
|
|
266
|
+
device = node_category.device
|
|
267
|
+
self._category_offsets = borrow_temporary(
|
|
268
|
+
temporary_store,
|
|
269
|
+
shape=category_offsets.array.shape,
|
|
270
|
+
dtype=category_offsets.array.dtype,
|
|
271
|
+
pinned=device.is_cuda,
|
|
272
|
+
device="cpu",
|
|
273
|
+
)
|
|
274
|
+
wp.copy(src=category_offsets.array, dest=self._category_offsets.array)
|
|
275
|
+
|
|
276
|
+
if device.is_cuda:
|
|
277
|
+
# TODO switch to synchronize_event once available
|
|
278
|
+
wp.synchronize_stream(wp.get_stream(device))
|
|
279
|
+
|
|
280
|
+
category_offsets.release()
|
|
281
|
+
|
|
282
|
+
# Compute global to local indices
|
|
283
|
+
self._space_to_partition = borrow_temporary_like(node_indices, temporary_store)
|
|
250
284
|
wp.launch(
|
|
251
285
|
kernel=NodePartition._scatter_partition_indices,
|
|
252
|
-
dim=self.
|
|
253
|
-
device=
|
|
254
|
-
inputs=[self.node_count(), node_indices, self._space_to_partition],
|
|
286
|
+
dim=self.space_topology.node_count(),
|
|
287
|
+
device=device,
|
|
288
|
+
inputs=[self.node_count(), node_indices.array, self._space_to_partition.array],
|
|
255
289
|
)
|
|
256
290
|
|
|
257
|
-
# Copy to shrinked-to-fit array
|
|
258
|
-
self._node_indices =
|
|
259
|
-
wp.copy(dest=self._node_indices, src=node_indices, count=self.node_count())
|
|
291
|
+
# Copy to shrinked-to-fit array
|
|
292
|
+
self._node_indices = borrow_temporary(temporary_store, shape=(self.node_count()), dtype=int, device=device)
|
|
293
|
+
wp.copy(dest=self._node_indices.array, src=node_indices.array, count=self.node_count())
|
|
294
|
+
|
|
295
|
+
node_indices.release()
|
|
260
296
|
|
|
261
297
|
@wp.kernel
|
|
262
298
|
def _scatter_partition_indices(
|
|
@@ -274,16 +310,21 @@ class NodePartition(SpacePartition):
|
|
|
274
310
|
|
|
275
311
|
|
|
276
312
|
def make_space_partition(
|
|
277
|
-
space: FunctionSpace,
|
|
313
|
+
space: Optional[FunctionSpace] = None,
|
|
278
314
|
geometry_partition: Optional[GeometryPartition] = None,
|
|
315
|
+
space_topology: Optional[SpaceTopology] = None,
|
|
279
316
|
with_halo: bool = True,
|
|
280
317
|
device=None,
|
|
318
|
+
temporary_store: TemporaryStore = None,
|
|
281
319
|
) -> SpacePartition:
|
|
282
|
-
"""Computes the
|
|
320
|
+
"""Computes the subset of nodes from a function space topology that touch a geometry partition
|
|
321
|
+
|
|
322
|
+
Either `space_topology` or `space` must be provided (and will be considered in that order).
|
|
283
323
|
|
|
284
324
|
Args:
|
|
285
|
-
space: the function space
|
|
325
|
+
space: (deprecated) the function space defining the topology if `space_topology` is ``None``.
|
|
286
326
|
geometry_partition: The subset of the space geometry. If not provided, use the whole geometry.
|
|
327
|
+
space_topology: the topology of the function space to consider. If ``None``, deduced from `space`.
|
|
287
328
|
with_halo: if True, include the halo nodes (nodes from exterior frontier cells to the partition)
|
|
288
329
|
device: Warp device on which to perform and store computations
|
|
289
330
|
|
|
@@ -291,7 +332,19 @@ def make_space_partition(
|
|
|
291
332
|
the resulting space partition
|
|
292
333
|
"""
|
|
293
334
|
|
|
294
|
-
if
|
|
295
|
-
|
|
335
|
+
if space_topology is None:
|
|
336
|
+
space_topology = space.topology
|
|
337
|
+
|
|
338
|
+
space_topology = space_topology.full_space_topology()
|
|
339
|
+
|
|
340
|
+
if geometry_partition is not None:
|
|
341
|
+
if geometry_partition.cell_count() < geometry_partition.geometry.cell_count():
|
|
342
|
+
return NodePartition(
|
|
343
|
+
space_topology=space_topology,
|
|
344
|
+
geo_partition=geometry_partition,
|
|
345
|
+
with_halo=with_halo,
|
|
346
|
+
device=device,
|
|
347
|
+
temporary_store=temporary_store,
|
|
348
|
+
)
|
|
296
349
|
|
|
297
|
-
return WholeSpacePartition(
|
|
350
|
+
return WholeSpacePartition(space_topology)
|