warp-lang 1.0.2__py3-none-win_amd64.whl → 1.2.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +88 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3693 -3354
- warp/codegen.py +2925 -2792
- warp/config.py +40 -36
- warp/constants.py +49 -45
- warp/context.py +5409 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +381 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -277
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +145 -146
- warp/examples/benchmarks/benchmark_launches.py +293 -295
- warp/examples/browse.py +29 -29
- warp/examples/core/example_dem.py +232 -219
- warp/examples/core/example_fluid.py +291 -267
- warp/examples/core/example_graph_capture.py +142 -126
- warp/examples/core/example_marching_cubes.py +186 -174
- warp/examples/core/example_mesh.py +172 -155
- warp/examples/core/example_mesh_intersect.py +203 -193
- warp/examples/core/example_nvdb.py +174 -170
- warp/examples/core/example_raycast.py +103 -90
- warp/examples/core/example_raymarch.py +197 -178
- warp/examples/core/example_render_opengl.py +183 -141
- warp/examples/core/example_sph.py +403 -387
- warp/examples/core/example_torch.py +219 -181
- warp/examples/core/example_wave.py +261 -248
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +432 -389
- warp/examples/fem/example_burgers.py +262 -0
- warp/examples/fem/example_convection_diffusion.py +180 -168
- warp/examples/fem/example_convection_diffusion_dg.py +217 -209
- warp/examples/fem/example_deformed_geometry.py +175 -159
- warp/examples/fem/example_diffusion.py +199 -173
- warp/examples/fem/example_diffusion_3d.py +178 -152
- warp/examples/fem/example_diffusion_mgpu.py +219 -214
- warp/examples/fem/example_mixed_elasticity.py +242 -222
- warp/examples/fem/example_navier_stokes.py +257 -243
- warp/examples/fem/example_stokes.py +218 -192
- warp/examples/fem/example_stokes_transfer.py +263 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +258 -246
- warp/examples/optim/example_cloth_throw.py +220 -209
- warp/examples/optim/example_diffray.py +564 -536
- warp/examples/optim/example_drone.py +862 -835
- warp/examples/optim/example_inverse_kinematics.py +174 -168
- warp/examples/optim/example_inverse_kinematics_torch.py +183 -169
- warp/examples/optim/example_spring_cage.py +237 -231
- warp/examples/optim/example_trajectory.py +221 -199
- warp/examples/optim/example_walker.py +304 -293
- warp/examples/sim/example_cartpole.py +137 -129
- warp/examples/sim/example_cloth.py +194 -186
- warp/examples/sim/example_granular.py +122 -111
- warp/examples/sim/example_granular_collision_sdf.py +195 -186
- warp/examples/sim/example_jacobian_ik.py +234 -214
- warp/examples/sim/example_particle_chain.py +116 -105
- warp/examples/sim/example_quadruped.py +191 -180
- warp/examples/sim/example_rigid_chain.py +195 -187
- warp/examples/sim/example_rigid_contact.py +187 -177
- warp/examples/sim/example_rigid_force.py +125 -125
- warp/examples/sim/example_rigid_gyroscopic.py +107 -95
- warp/examples/sim/example_rigid_soft_contact.py +132 -122
- warp/examples/sim/example_soft_body.py +188 -177
- warp/fabric.py +337 -335
- warp/fem/__init__.py +61 -27
- warp/fem/cache.py +403 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +16 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +748 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +437 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/nanogrid.py +455 -0
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1684 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +179 -292
- warp/fem/space/basis_space.py +522 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +148 -267
- warp/fem/space/grid_3d_function_space.py +167 -306
- warp/fem/space/hexmesh_function_space.py +253 -352
- warp/fem/space/nanogrid_function_space.py +202 -0
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +261 -369
- warp/fem/space/restriction.py +161 -160
- warp/fem/space/shape/__init__.py +90 -15
- warp/fem/space/shape/cube_shape_function.py +728 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +224 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +153 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1081 -1025
- warp/native/builtin.h +1603 -1560
- warp/native/bvh.cpp +402 -398
- warp/native/bvh.cu +533 -525
- warp/native/bvh.h +430 -429
- warp/native/clang/clang.cpp +496 -464
- warp/native/crt.cpp +42 -32
- warp/native/crt.h +352 -335
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/exports.h +187 -0
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1545 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +292 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -4782
- warp/native/nanovdb/PNanoVDB.h +3390 -2553
- warp/native/noise.h +850 -850
- warp/native/quat.h +1112 -1085
- warp/native/rand.h +303 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1177 -1133
- warp/native/volume.cpp +529 -297
- warp/native/volume.cu +58 -32
- warp/native/volume.h +960 -538
- warp/native/volume_builder.cu +446 -425
- warp/native/volume_builder.h +34 -19
- warp/native/volume_impl.h +61 -0
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2949 -2828
- warp/native/warp.h +321 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3356 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1917 -1991
- warp/sim/integrator_xpbd.py +3288 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1289 -1227
- warp/stubs.py +2192 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +20 -22
- warp/tests/aux_test_grad_customs.py +21 -23
- warp/tests/aux_test_reference.py +9 -11
- warp/tests/aux_test_reference_reference.py +8 -10
- warp/tests/aux_test_square.py +15 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +237 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +155 -157
- warp/tests/test_arithmetic.py +1088 -1124
- warp/tests/test_array.py +2415 -2326
- warp/tests/test_array_reduce.py +148 -150
- warp/tests/test_async.py +666 -656
- warp/tests/test_atomic.py +139 -141
- warp/tests/test_bool.py +212 -149
- warp/tests/test_builtins_resolution.py +1290 -1292
- warp/tests/test_bvh.py +162 -171
- warp/tests/test_closest_point_edge_edge.py +227 -228
- warp/tests/test_codegen.py +562 -553
- warp/tests/test_compile_consts.py +217 -101
- warp/tests/test_conditional.py +244 -246
- warp/tests/test_copy.py +230 -215
- warp/tests/test_ctypes.py +630 -632
- warp/tests/test_dense.py +65 -67
- warp/tests/test_devices.py +89 -98
- warp/tests/test_dlpack.py +528 -529
- warp/tests/test_examples.py +403 -378
- warp/tests/test_fabricarray.py +952 -955
- warp/tests/test_fast_math.py +60 -54
- warp/tests/test_fem.py +1298 -1278
- warp/tests/test_fp16.py +128 -130
- warp/tests/test_func.py +336 -337
- warp/tests/test_generics.py +596 -571
- warp/tests/test_grad.py +885 -640
- warp/tests/test_grad_customs.py +331 -336
- warp/tests/test_hash_grid.py +208 -164
- warp/tests/test_import.py +37 -39
- warp/tests/test_indexedarray.py +1132 -1134
- warp/tests/test_intersect.py +65 -67
- warp/tests/test_jax.py +305 -307
- warp/tests/test_large.py +169 -164
- warp/tests/test_launch.py +352 -354
- warp/tests/test_lerp.py +217 -261
- warp/tests/test_linear_solvers.py +189 -171
- warp/tests/test_lvalue.py +419 -493
- warp/tests/test_marching_cubes.py +63 -65
- warp/tests/test_mat.py +1799 -1827
- warp/tests/test_mat_lite.py +113 -115
- warp/tests/test_mat_scalar_ops.py +2905 -2889
- warp/tests/test_math.py +124 -193
- warp/tests/test_matmul.py +498 -499
- warp/tests/test_matmul_lite.py +408 -410
- warp/tests/test_mempool.py +186 -190
- warp/tests/test_mesh.py +281 -324
- warp/tests/test_mesh_query_aabb.py +226 -241
- warp/tests/test_mesh_query_point.py +690 -702
- warp/tests/test_mesh_query_ray.py +290 -303
- warp/tests/test_mlp.py +274 -276
- warp/tests/test_model.py +108 -110
- warp/tests/test_module_hashing.py +111 -0
- warp/tests/test_modules_lite.py +36 -39
- warp/tests/test_multigpu.py +161 -163
- warp/tests/test_noise.py +244 -248
- warp/tests/test_operators.py +248 -250
- warp/tests/test_options.py +121 -125
- warp/tests/test_peer.py +131 -137
- warp/tests/test_pinned.py +76 -78
- warp/tests/test_print.py +52 -54
- warp/tests/test_quat.py +2084 -2086
- warp/tests/test_rand.py +324 -288
- warp/tests/test_reload.py +207 -217
- warp/tests/test_rounding.py +177 -179
- warp/tests/test_runlength_encode.py +188 -190
- warp/tests/test_sim_grad.py +241 -0
- warp/tests/test_sim_kinematics.py +89 -97
- warp/tests/test_smoothstep.py +166 -168
- warp/tests/test_snippet.py +303 -266
- warp/tests/test_sparse.py +466 -460
- warp/tests/test_spatial.py +2146 -2148
- warp/tests/test_special_values.py +362 -0
- warp/tests/test_streams.py +484 -473
- warp/tests/test_struct.py +708 -675
- warp/tests/test_tape.py +171 -148
- warp/tests/test_torch.py +741 -743
- warp/tests/test_transient_module.py +85 -87
- warp/tests/test_types.py +554 -659
- warp/tests/test_utils.py +488 -499
- warp/tests/test_vec.py +1262 -1268
- warp/tests/test_vec_lite.py +71 -73
- warp/tests/test_vec_scalar_ops.py +2097 -2099
- warp/tests/test_verify_fp.py +92 -94
- warp/tests/test_volume.py +961 -736
- warp/tests/test_volume_write.py +338 -265
- warp/tests/unittest_serial.py +38 -37
- warp/tests/unittest_suites.py +367 -359
- warp/tests/unittest_utils.py +434 -578
- warp/tests/unused_test_misc.py +69 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +563 -561
- warp/torch.py +321 -295
- warp/types.py +4941 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/METADATA +365 -400
- warp_lang-1.2.0.dist-info/RECORD +359 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp/examples/fem/example_convection_diffusion_dg0.py +0 -194
- warp/native/nanovdb/PNanoVDBWrite.h +0 -295
- warp_lang-1.0.2.dist-info/RECORD +0 -352
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/top_level.txt +0 -0
warp/fem/space/partition.py
CHANGED
|
@@ -1,350 +1,350 @@
|
|
|
1
|
-
from typing import Any, Optional
|
|
2
|
-
|
|
3
|
-
import warp as wp
|
|
4
|
-
from warp.fem.cache import (
|
|
5
|
-
TemporaryStore,
|
|
6
|
-
borrow_temporary,
|
|
7
|
-
borrow_temporary_like,
|
|
8
|
-
cached_arg_value,
|
|
9
|
-
)
|
|
10
|
-
from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
|
|
11
|
-
from warp.fem.types import NULL_NODE_INDEX
|
|
12
|
-
from warp.fem.utils import _iota_kernel, compress_node_indices
|
|
13
|
-
|
|
14
|
-
from .function_space import FunctionSpace
|
|
15
|
-
from .topology import SpaceTopology
|
|
16
|
-
|
|
17
|
-
wp.set_module_options({"enable_backward": False})
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
class SpacePartition:
|
|
21
|
-
class PartitionArg:
|
|
22
|
-
pass
|
|
23
|
-
|
|
24
|
-
def __init__(self, space_topology: SpaceTopology, geo_partition: GeometryPartition):
|
|
25
|
-
self.space_topology = space_topology
|
|
26
|
-
self.geo_partition = geo_partition
|
|
27
|
-
|
|
28
|
-
def node_count(self):
|
|
29
|
-
"""Returns number of nodes in this partition"""
|
|
30
|
-
|
|
31
|
-
def owned_node_count(self) -> int:
|
|
32
|
-
"""Returns number of nodes in this partition, excluding exterior halo"""
|
|
33
|
-
|
|
34
|
-
def interior_node_count(self) -> int:
|
|
35
|
-
"""Returns number of interior nodes in this partition"""
|
|
36
|
-
|
|
37
|
-
def space_node_indices(self) -> wp.array:
|
|
38
|
-
"""Return the global function space indices for nodes in this partition"""
|
|
39
|
-
|
|
40
|
-
def partition_arg_value(self, device):
|
|
41
|
-
pass
|
|
42
|
-
|
|
43
|
-
@staticmethod
|
|
44
|
-
def partition_node_index(args: "PartitionArg", space_node_index: int):
|
|
45
|
-
"""Returns the index in the partition of a function space node, or -1 if it does not exist"""
|
|
46
|
-
|
|
47
|
-
def __str__(self) -> str:
|
|
48
|
-
return self.name
|
|
49
|
-
|
|
50
|
-
@property
|
|
51
|
-
def name(self) -> str:
|
|
52
|
-
return f"{self.__class__.__name__}"
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
class WholeSpacePartition(SpacePartition):
|
|
56
|
-
@wp.struct
|
|
57
|
-
class PartitionArg:
|
|
58
|
-
pass
|
|
59
|
-
|
|
60
|
-
def __init__(self, space_topology: SpaceTopology):
|
|
61
|
-
super().__init__(space_topology, WholeGeometryPartition(space_topology.geometry))
|
|
62
|
-
self._node_indices = None
|
|
63
|
-
|
|
64
|
-
def node_count(self):
|
|
65
|
-
"""Returns number of nodes in this partition"""
|
|
66
|
-
return self.space_topology.node_count()
|
|
67
|
-
|
|
68
|
-
def owned_node_count(self) -> int:
|
|
69
|
-
"""Returns number of nodes in this partition, excluding exterior halo"""
|
|
70
|
-
return self.space_topology.node_count()
|
|
71
|
-
|
|
72
|
-
def interior_node_count(self) -> int:
|
|
73
|
-
"""Returns number of interior nodes in this partition"""
|
|
74
|
-
return self.space_topology.node_count()
|
|
75
|
-
|
|
76
|
-
def space_node_indices(self):
|
|
77
|
-
"""Return the global function space indices for nodes in this partition"""
|
|
78
|
-
if self._node_indices is None:
|
|
79
|
-
self._node_indices = borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
|
|
80
|
-
wp.launch(kernel=_iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array, 1])
|
|
81
|
-
return self._node_indices.array
|
|
82
|
-
|
|
83
|
-
def partition_arg_value(self, device):
|
|
84
|
-
return WholeSpacePartition.PartitionArg()
|
|
85
|
-
|
|
86
|
-
@wp.func
|
|
87
|
-
def partition_node_index(args: Any, space_node_index: int):
|
|
88
|
-
return space_node_index
|
|
89
|
-
|
|
90
|
-
def __eq__(self, other: SpacePartition) -> bool:
|
|
91
|
-
return isinstance(other, SpacePartition) and self.space_topology == other.space_topology
|
|
92
|
-
|
|
93
|
-
@property
|
|
94
|
-
def name(self) -> str:
|
|
95
|
-
return "Whole"
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
class NodeCategory:
|
|
99
|
-
OWNED_INTERIOR = wp.constant(0)
|
|
100
|
-
"""Node is touched exclusively by this partition, not touched by frontier side"""
|
|
101
|
-
OWNED_FRONTIER = wp.constant(1)
|
|
102
|
-
"""Node is touched by a frontier side, but belongs to an element of this partition"""
|
|
103
|
-
HALO_LOCAL_SIDE = wp.constant(2)
|
|
104
|
-
"""Node belongs to an element of another partition, but is touched by one of our frontier side"""
|
|
105
|
-
HALO_OTHER_SIDE = wp.constant(3)
|
|
106
|
-
"""Node belongs to an element of another partition, and is not touched by one of our frontier side"""
|
|
107
|
-
EXTERIOR = wp.constant(4)
|
|
108
|
-
"""Node is never referenced by this partition"""
|
|
109
|
-
|
|
110
|
-
COUNT = 5
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
class NodePartition(SpacePartition):
|
|
114
|
-
@wp.struct
|
|
115
|
-
class PartitionArg:
|
|
116
|
-
space_to_partition: wp.array(dtype=int)
|
|
117
|
-
|
|
118
|
-
def __init__(
|
|
119
|
-
self,
|
|
120
|
-
space_topology: SpaceTopology,
|
|
121
|
-
geo_partition: GeometryPartition,
|
|
122
|
-
with_halo: bool = True,
|
|
123
|
-
device=None,
|
|
124
|
-
temporary_store: TemporaryStore = None,
|
|
125
|
-
):
|
|
126
|
-
super().__init__(space_topology=space_topology, geo_partition=geo_partition)
|
|
127
|
-
|
|
128
|
-
self._compute_node_indices_from_sides(device, with_halo, temporary_store)
|
|
129
|
-
|
|
130
|
-
def node_count(self) -> int:
|
|
131
|
-
"""Returns number of nodes referenced by this partition, including exterior halo"""
|
|
132
|
-
return int(self._category_offsets.array.numpy()[NodeCategory.HALO_OTHER_SIDE + 1])
|
|
133
|
-
|
|
134
|
-
def owned_node_count(self) -> int:
|
|
135
|
-
"""Returns number of nodes in this partition, excluding exterior halo"""
|
|
136
|
-
return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_FRONTIER + 1])
|
|
137
|
-
|
|
138
|
-
def interior_node_count(self) -> int:
|
|
139
|
-
"""Returns number of interior nodes in this partition"""
|
|
140
|
-
return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_INTERIOR + 1])
|
|
141
|
-
|
|
142
|
-
def space_node_indices(self):
|
|
143
|
-
"""Return the global function space indices for nodes in this partition"""
|
|
144
|
-
return self._node_indices.array
|
|
145
|
-
|
|
146
|
-
@cached_arg_value
|
|
147
|
-
def partition_arg_value(self, device):
|
|
148
|
-
arg = NodePartition.PartitionArg()
|
|
149
|
-
arg.space_to_partition = self._space_to_partition.array.to(device)
|
|
150
|
-
return arg
|
|
151
|
-
|
|
152
|
-
@wp.func
|
|
153
|
-
def partition_node_index(args: PartitionArg, space_node_index: int):
|
|
154
|
-
return args.space_to_partition[space_node_index]
|
|
155
|
-
|
|
156
|
-
def _compute_node_indices_from_sides(self, device, with_halo: bool, temporary_store: TemporaryStore):
|
|
157
|
-
from warp.fem import cache
|
|
158
|
-
|
|
159
|
-
trace_topology = self.space_topology.trace()
|
|
160
|
-
NODES_PER_CELL = self.space_topology.NODES_PER_ELEMENT
|
|
161
|
-
NODES_PER_SIDE = trace_topology.NODES_PER_ELEMENT
|
|
162
|
-
|
|
163
|
-
@cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
|
|
164
|
-
def node_category_from_cells_kernel(
|
|
165
|
-
geo_arg: self.geo_partition.geometry.CellArg,
|
|
166
|
-
geo_partition_arg: self.geo_partition.CellArg,
|
|
167
|
-
space_arg: self.space_topology.TopologyArg,
|
|
168
|
-
node_mask: wp.array(dtype=int),
|
|
169
|
-
):
|
|
170
|
-
partition_cell_index = wp.tid()
|
|
171
|
-
|
|
172
|
-
cell_index = self.geo_partition.cell_index(geo_partition_arg, partition_cell_index)
|
|
173
|
-
|
|
174
|
-
for n in range(NODES_PER_CELL):
|
|
175
|
-
space_nidx = self.space_topology.element_node_index(geo_arg, space_arg, cell_index, n)
|
|
176
|
-
node_mask[space_nidx] = NodeCategory.OWNED_INTERIOR
|
|
177
|
-
|
|
178
|
-
@cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
|
|
179
|
-
def node_category_from_owned_sides_kernel(
|
|
180
|
-
geo_arg: self.geo_partition.geometry.SideArg,
|
|
181
|
-
geo_partition_arg: self.geo_partition.SideArg,
|
|
182
|
-
space_arg: trace_topology.TopologyArg,
|
|
183
|
-
node_mask: wp.array(dtype=int),
|
|
184
|
-
):
|
|
185
|
-
partition_side_index = wp.tid()
|
|
186
|
-
|
|
187
|
-
side_index = self.geo_partition.side_index(geo_partition_arg, partition_side_index)
|
|
188
|
-
|
|
189
|
-
for n in range(NODES_PER_SIDE):
|
|
190
|
-
space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
|
|
191
|
-
|
|
192
|
-
if node_mask[space_nidx] == NodeCategory.EXTERIOR:
|
|
193
|
-
node_mask[space_nidx] = NodeCategory.HALO_LOCAL_SIDE
|
|
194
|
-
|
|
195
|
-
@cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
|
|
196
|
-
def node_category_from_frontier_sides_kernel(
|
|
197
|
-
geo_arg: self.geo_partition.geometry.SideArg,
|
|
198
|
-
geo_partition_arg: self.geo_partition.SideArg,
|
|
199
|
-
space_arg: trace_topology.TopologyArg,
|
|
200
|
-
node_mask: wp.array(dtype=int),
|
|
201
|
-
):
|
|
202
|
-
frontier_side_index = wp.tid()
|
|
203
|
-
|
|
204
|
-
side_index = self.geo_partition.frontier_side_index(geo_partition_arg, frontier_side_index)
|
|
205
|
-
|
|
206
|
-
for n in range(NODES_PER_SIDE):
|
|
207
|
-
space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
|
|
208
|
-
if node_mask[space_nidx] == NodeCategory.EXTERIOR:
|
|
209
|
-
node_mask[space_nidx] = NodeCategory.HALO_OTHER_SIDE
|
|
210
|
-
elif node_mask[space_nidx] == NodeCategory.OWNED_INTERIOR:
|
|
211
|
-
node_mask[space_nidx] = NodeCategory.OWNED_FRONTIER
|
|
212
|
-
|
|
213
|
-
node_category = borrow_temporary(
|
|
214
|
-
temporary_store,
|
|
215
|
-
shape=(self.space_topology.node_count(),),
|
|
216
|
-
dtype=int,
|
|
217
|
-
device=device,
|
|
218
|
-
)
|
|
219
|
-
node_category.array.fill_(value=NodeCategory.EXTERIOR)
|
|
220
|
-
|
|
221
|
-
wp.launch(
|
|
222
|
-
dim=self.geo_partition.cell_count(),
|
|
223
|
-
kernel=node_category_from_cells_kernel,
|
|
224
|
-
inputs=[
|
|
225
|
-
self.geo_partition.geometry.cell_arg_value(device),
|
|
226
|
-
self.geo_partition.cell_arg_value(device),
|
|
227
|
-
self.space_topology.topo_arg_value(device),
|
|
228
|
-
node_category.array,
|
|
229
|
-
],
|
|
230
|
-
device=device,
|
|
231
|
-
)
|
|
232
|
-
|
|
233
|
-
if with_halo:
|
|
234
|
-
wp.launch(
|
|
235
|
-
dim=self.geo_partition.side_count(),
|
|
236
|
-
kernel=node_category_from_owned_sides_kernel,
|
|
237
|
-
inputs=[
|
|
238
|
-
self.geo_partition.geometry.side_arg_value(device),
|
|
239
|
-
self.geo_partition.side_arg_value(device),
|
|
240
|
-
self.space_topology.topo_arg_value(device),
|
|
241
|
-
node_category.array,
|
|
242
|
-
],
|
|
243
|
-
device=device,
|
|
244
|
-
)
|
|
245
|
-
|
|
246
|
-
wp.launch(
|
|
247
|
-
dim=self.geo_partition.frontier_side_count(),
|
|
248
|
-
kernel=node_category_from_frontier_sides_kernel,
|
|
249
|
-
inputs=[
|
|
250
|
-
self.geo_partition.geometry.side_arg_value(device),
|
|
251
|
-
self.geo_partition.side_arg_value(device),
|
|
252
|
-
self.space_topology.topo_arg_value(device),
|
|
253
|
-
node_category.array,
|
|
254
|
-
],
|
|
255
|
-
device=device,
|
|
256
|
-
)
|
|
257
|
-
|
|
258
|
-
self._finalize_node_indices(node_category.array, temporary_store)
|
|
259
|
-
|
|
260
|
-
node_category.release()
|
|
261
|
-
|
|
262
|
-
def _finalize_node_indices(self, node_category: wp.array(dtype=int), temporary_store: TemporaryStore):
|
|
263
|
-
category_offsets, node_indices, _, __ = compress_node_indices(NodeCategory.COUNT, node_category)
|
|
264
|
-
|
|
265
|
-
# Copy offsets to cpu
|
|
266
|
-
device = node_category.device
|
|
267
|
-
self._category_offsets = borrow_temporary(
|
|
268
|
-
temporary_store,
|
|
269
|
-
shape=category_offsets.array.shape,
|
|
270
|
-
dtype=category_offsets.array.dtype,
|
|
271
|
-
pinned=device.is_cuda,
|
|
272
|
-
device="cpu",
|
|
273
|
-
)
|
|
274
|
-
wp.copy(src=category_offsets.array, dest=self._category_offsets.array)
|
|
275
|
-
|
|
276
|
-
if device.is_cuda:
|
|
277
|
-
# TODO switch to synchronize_event once available
|
|
278
|
-
wp.synchronize_stream(wp.get_stream(device))
|
|
279
|
-
|
|
280
|
-
category_offsets.release()
|
|
281
|
-
|
|
282
|
-
# Compute global to local indices
|
|
283
|
-
self._space_to_partition = borrow_temporary_like(node_indices, temporary_store)
|
|
284
|
-
wp.launch(
|
|
285
|
-
kernel=NodePartition._scatter_partition_indices,
|
|
286
|
-
dim=self.space_topology.node_count(),
|
|
287
|
-
device=device,
|
|
288
|
-
inputs=[self.node_count(), node_indices.array, self._space_to_partition.array],
|
|
289
|
-
)
|
|
290
|
-
|
|
291
|
-
# Copy to shrinked-to-fit array
|
|
292
|
-
self._node_indices = borrow_temporary(temporary_store, shape=(self.node_count()), dtype=int, device=device)
|
|
293
|
-
wp.copy(dest=self._node_indices.array, src=node_indices.array, count=self.node_count())
|
|
294
|
-
|
|
295
|
-
node_indices.release()
|
|
296
|
-
|
|
297
|
-
@wp.kernel
|
|
298
|
-
def _scatter_partition_indices(
|
|
299
|
-
local_node_count: int,
|
|
300
|
-
node_indices: wp.array(dtype=int),
|
|
301
|
-
space_to_partition_indices: wp.array(dtype=int),
|
|
302
|
-
):
|
|
303
|
-
local_idx = wp.tid()
|
|
304
|
-
space_idx = node_indices[local_idx]
|
|
305
|
-
|
|
306
|
-
if local_idx < local_node_count:
|
|
307
|
-
space_to_partition_indices[space_idx] = local_idx
|
|
308
|
-
else:
|
|
309
|
-
space_to_partition_indices[space_idx] = NULL_NODE_INDEX
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
def make_space_partition(
|
|
313
|
-
space: Optional[FunctionSpace] = None,
|
|
314
|
-
geometry_partition: Optional[GeometryPartition] = None,
|
|
315
|
-
space_topology: Optional[SpaceTopology] = None,
|
|
316
|
-
with_halo: bool = True,
|
|
317
|
-
device=None,
|
|
318
|
-
temporary_store: TemporaryStore = None,
|
|
319
|
-
) -> SpacePartition:
|
|
320
|
-
"""Computes the subset of nodes from a function space topology that touch a geometry partition
|
|
321
|
-
|
|
322
|
-
Either `space_topology` or `space` must be provided (and will be considered in that order).
|
|
323
|
-
|
|
324
|
-
Args:
|
|
325
|
-
space: (deprecated) the function space defining the topology if `space_topology` is ``None``.
|
|
326
|
-
geometry_partition: The subset of the space geometry. If not provided, use the whole geometry.
|
|
327
|
-
space_topology: the topology of the function space to consider. If ``None``, deduced from `space`.
|
|
328
|
-
with_halo: if True, include the halo nodes (nodes from exterior frontier cells to the partition)
|
|
329
|
-
device: Warp device on which to perform and store computations
|
|
330
|
-
|
|
331
|
-
Returns:
|
|
332
|
-
the resulting space partition
|
|
333
|
-
"""
|
|
334
|
-
|
|
335
|
-
if space_topology is None:
|
|
336
|
-
space_topology = space.topology
|
|
337
|
-
|
|
338
|
-
space_topology = space_topology.full_space_topology()
|
|
339
|
-
|
|
340
|
-
if geometry_partition is not None:
|
|
341
|
-
if geometry_partition.cell_count() < geometry_partition.geometry.cell_count():
|
|
342
|
-
return NodePartition(
|
|
343
|
-
space_topology=space_topology,
|
|
344
|
-
geo_partition=geometry_partition,
|
|
345
|
-
with_halo=with_halo,
|
|
346
|
-
device=device,
|
|
347
|
-
temporary_store=temporary_store,
|
|
348
|
-
)
|
|
349
|
-
|
|
350
|
-
return WholeSpacePartition(space_topology)
|
|
1
|
+
from typing import Any, Optional
|
|
2
|
+
|
|
3
|
+
import warp as wp
|
|
4
|
+
from warp.fem.cache import (
|
|
5
|
+
TemporaryStore,
|
|
6
|
+
borrow_temporary,
|
|
7
|
+
borrow_temporary_like,
|
|
8
|
+
cached_arg_value,
|
|
9
|
+
)
|
|
10
|
+
from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
|
|
11
|
+
from warp.fem.types import NULL_NODE_INDEX
|
|
12
|
+
from warp.fem.utils import _iota_kernel, compress_node_indices
|
|
13
|
+
|
|
14
|
+
from .function_space import FunctionSpace
|
|
15
|
+
from .topology import SpaceTopology
|
|
16
|
+
|
|
17
|
+
wp.set_module_options({"enable_backward": False})
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SpacePartition:
    """Common interface for partitions of the nodes of a function space topology.

    The base class only records the space topology and the geometry partition
    it was built from. The node queries declared below have no implementation
    here (they return ``None``); concrete partitions are expected to override
    them.
    """

    class PartitionArg:
        """Placeholder type for the argument passed to :meth:`partition_node_index`."""

    def __init__(self, space_topology: SpaceTopology, geo_partition: GeometryPartition):
        self.geo_partition = geo_partition
        self.space_topology = space_topology

    def node_count(self):
        """Returns number of nodes in this partition"""

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo"""

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition"""

    def space_node_indices(self) -> wp.array:
        """Return the global function space indices for nodes in this partition"""

    def partition_arg_value(self, device):
        """Returns the :class:`PartitionArg` value to use on `device` (not implemented in the base class)"""

    @staticmethod
    def partition_node_index(args: "PartitionArg", space_node_index: int):
        """Returns the index in the partition of a function space node, or -1 if it does not exist"""

    @property
    def name(self) -> str:
        """Name of the partition; defaults to the name of the concrete class"""
        return type(self).__name__

    def __str__(self) -> str:
        return self.name
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class WholeSpacePartition(SpacePartition):
    """Partition that spans the entire space topology.

    All node counts equal the node count of the topology, and the index of a
    node in the partition is its index in the function space.
    """

    @wp.struct
    class PartitionArg:
        # No data required: `partition_node_index` is the identity
        pass

    def __init__(self, space_topology: SpaceTopology):
        whole_geometry = WholeGeometryPartition(space_topology.geometry)
        super().__init__(space_topology, whole_geometry)

        # Array of node indices, allocated on first call to `space_node_indices`
        self._node_indices = None

    def node_count(self):
        """Returns number of nodes in this partition"""
        return self.space_topology.node_count()

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo"""
        return self.space_topology.node_count()

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition"""
        return self.space_topology.node_count()

    def space_node_indices(self):
        """Return the global function space indices for nodes in this partition"""
        if self._node_indices is not None:
            return self._node_indices.array

        # First request: allocate the array, then fill it on the device.
        # NOTE(review): `_iota_kernel` with a step of 1 presumably writes
        # indices[i] = i -- confirm against warp.fem.utils.
        self._node_indices = borrow_temporary(
            temporary_store=None,
            shape=(self.node_count(),),
            dtype=int,
        )
        wp.launch(
            kernel=_iota_kernel,
            dim=self.node_count(),
            inputs=[self._node_indices.array, 1],
        )
        return self._node_indices.array

    def partition_arg_value(self, device):
        """Returns an empty :class:`PartitionArg`; `device` is not used"""
        return WholeSpacePartition.PartitionArg()

    # Identity mapping: partition indices coincide with space indices
    @wp.func
    def partition_node_index(args: Any, space_node_index: int):
        return space_node_index

    def __eq__(self, other: SpacePartition) -> bool:
        if not isinstance(other, SpacePartition):
            return False
        return self.space_topology == other.space_topology

    @property
    def name(self) -> str:
        return "Whole"
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class NodeCategory:
    """Tags used by ``NodePartition`` to classify each node of a function space.

    The numeric order of the values is significant: ``NodePartition`` reads its
    node counts from an offsets array indexed by ``category + 1``.
    """

    OWNED_INTERIOR = wp.constant(0)
    """Node is touched exclusively by this partition, not touched by frontier side"""
    OWNED_FRONTIER = wp.constant(1)
    """Node is touched by a frontier side, but belongs to an element of this partition"""
    HALO_LOCAL_SIDE = wp.constant(2)
    """Node belongs to an element of another partition, but is touched by one of our frontier side"""
    HALO_OTHER_SIDE = wp.constant(3)
    """Node belongs to an element of another partition, and is not touched by one of our frontier side"""
    EXTERIOR = wp.constant(4)
    """Node is never referenced by this partition"""

    # Number of categories defined above; passed to `compress_node_indices`
    # by `NodePartition._finalize_node_indices`.
    COUNT = 5
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class NodePartition(SpacePartition):
    """Space partition built from the nodes associated with a geometry partition.

    At construction, every node of the space topology is tagged with a
    :class:`NodeCategory`, then the tags are compressed into three arrays:

    * ``_category_offsets``: per-category offsets, stored on the CPU and read
      back by :meth:`node_count`, :meth:`owned_node_count` and
      :meth:`interior_node_count`,
    * ``_node_indices``: space node indices of the nodes of this partition,
    * ``_space_to_partition``: for each space node, its index in this
      partition, or ``NULL_NODE_INDEX`` if it is not part of it.
    """

    @wp.struct
    class PartitionArg:
        # Maps a space node index to its index in the partition
        space_to_partition: wp.array(dtype=int)

    def __init__(
        self,
        space_topology: SpaceTopology,
        geo_partition: GeometryPartition,
        with_halo: bool = True,
        device=None,
        temporary_store: TemporaryStore = None,
    ):
        """Builds the node partition.

        Args:
            space_topology: topology of the function space to partition
            geo_partition: subset of the geometry defining the partition
            with_halo: if True, also tag nodes reached through the sides and frontier sides of the partition
            device: Warp device on which the tagging kernels are launched
            temporary_store: store from which temporary arrays are borrowed
        """
        super().__init__(space_topology=space_topology, geo_partition=geo_partition)

        self._compute_node_indices_from_sides(device, with_halo, temporary_store)

    # NOTE(review): the three counts below assume that entry `c + 1` of the
    # offsets returned by `compress_node_indices` is the number of nodes whose
    # category is <= c -- confirm against warp.fem.utils.

    def node_count(self) -> int:
        """Returns number of nodes referenced by this partition, including exterior halo"""
        return int(self._category_offsets.array.numpy()[NodeCategory.HALO_OTHER_SIDE + 1])

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo"""
        return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_FRONTIER + 1])

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition"""
        return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_INTERIOR + 1])

    def space_node_indices(self):
        """Return the global function space indices for nodes in this partition"""
        return self._node_indices.array

    @cached_arg_value
    def partition_arg_value(self, device):
        """Returns the :class:`PartitionArg` holding the space-to-partition index map moved to `device`"""
        arg = NodePartition.PartitionArg()
        arg.space_to_partition = self._space_to_partition.array.to(device)
        return arg

    # Looks up the partition index of a space node (NULL_NODE_INDEX if absent)
    @wp.func
    def partition_node_index(args: PartitionArg, space_node_index: int):
        return args.space_to_partition[space_node_index]

    def _compute_node_indices_from_sides(self, device, with_halo: bool, temporary_store: TemporaryStore):
        """Tags every space node with a :class:`NodeCategory`, then builds the partition index arrays.

        Args:
            device: Warp device on which the kernels are launched
            with_halo: if True, also run the side and frontier-side tagging passes
            temporary_store: store from which temporary arrays are borrowed
        """
        from warp.fem import cache

        trace_topology = self.space_topology.trace()
        NODES_PER_CELL = self.space_topology.NODES_PER_ELEMENT
        NODES_PER_SIDE = trace_topology.NODES_PER_ELEMENT

        # Pass 1: every node of every cell of the geometry partition is owned
        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_cells_kernel(
            geo_arg: self.geo_partition.geometry.CellArg,
            geo_partition_arg: self.geo_partition.CellArg,
            space_arg: self.space_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            partition_cell_index = wp.tid()

            cell_index = self.geo_partition.cell_index(geo_partition_arg, partition_cell_index)

            for n in range(NODES_PER_CELL):
                space_nidx = self.space_topology.element_node_index(geo_arg, space_arg, cell_index, n)
                node_mask[space_nidx] = NodeCategory.OWNED_INTERIOR

        # Pass 2: nodes on sides of the partition that no owned cell touched
        # are halo nodes reachable from one of our own sides
        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_owned_sides_kernel(
            geo_arg: self.geo_partition.geometry.SideArg,
            geo_partition_arg: self.geo_partition.SideArg,
            space_arg: trace_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            partition_side_index = wp.tid()

            side_index = self.geo_partition.side_index(geo_partition_arg, partition_side_index)

            for n in range(NODES_PER_SIDE):
                space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)

                if node_mask[space_nidx] == NodeCategory.EXTERIOR:
                    node_mask[space_nidx] = NodeCategory.HALO_LOCAL_SIDE

        # Pass 3: on frontier sides, still-exterior nodes become halo nodes and
        # owned interior nodes are promoted to owned frontier nodes
        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_frontier_sides_kernel(
            geo_arg: self.geo_partition.geometry.SideArg,
            geo_partition_arg: self.geo_partition.SideArg,
            space_arg: trace_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            frontier_side_index = wp.tid()

            side_index = self.geo_partition.frontier_side_index(geo_partition_arg, frontier_side_index)

            for n in range(NODES_PER_SIDE):
                space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
                if node_mask[space_nidx] == NodeCategory.EXTERIOR:
                    node_mask[space_nidx] = NodeCategory.HALO_OTHER_SIDE
                elif node_mask[space_nidx] == NodeCategory.OWNED_INTERIOR:
                    node_mask[space_nidx] = NodeCategory.OWNED_FRONTIER

        # One category per space node; anything not reached by a pass stays EXTERIOR
        node_category = borrow_temporary(
            temporary_store,
            shape=(self.space_topology.node_count(),),
            dtype=int,
            device=device,
        )
        node_category.array.fill_(value=NodeCategory.EXTERIOR)

        wp.launch(
            dim=self.geo_partition.cell_count(),
            kernel=node_category_from_cells_kernel,
            inputs=[
                self.geo_partition.geometry.cell_arg_value(device),
                self.geo_partition.cell_arg_value(device),
                self.space_topology.topo_arg_value(device),
                node_category.array,
            ],
            device=device,
        )

        if with_halo:
            # The owned-sides pass is launched before the frontier pass, so a
            # node it tags HALO_LOCAL_SIDE is left untouched by the latter
            wp.launch(
                dim=self.geo_partition.side_count(),
                kernel=node_category_from_owned_sides_kernel,
                inputs=[
                    self.geo_partition.geometry.side_arg_value(device),
                    self.geo_partition.side_arg_value(device),
                    self.space_topology.topo_arg_value(device),
                    node_category.array,
                ],
                device=device,
            )

            wp.launch(
                dim=self.geo_partition.frontier_side_count(),
                kernel=node_category_from_frontier_sides_kernel,
                inputs=[
                    self.geo_partition.geometry.side_arg_value(device),
                    self.geo_partition.side_arg_value(device),
                    self.space_topology.topo_arg_value(device),
                    node_category.array,
                ],
                device=device,
            )

        self._finalize_node_indices(node_category.array, temporary_store)

        node_category.release()

    def _finalize_node_indices(self, node_category: wp.array(dtype=int), temporary_store: TemporaryStore):
        """Turns per-node categories into the offsets and index arrays stored on the partition.

        Args:
            node_category: :class:`NodeCategory` value of each space node
            temporary_store: store from which temporary arrays are borrowed
        """
        category_offsets, node_indices, _, __ = compress_node_indices(NodeCategory.COUNT, node_category)

        # Copy offsets to cpu
        device = node_category.device
        self._category_offsets = borrow_temporary(
            temporary_store,
            shape=category_offsets.array.shape,
            dtype=category_offsets.array.dtype,
            pinned=device.is_cuda,
            device="cpu",
        )
        wp.copy(src=category_offsets.array, dest=self._category_offsets.array)

        if device.is_cuda:
            # TODO switch to synchronize_event once available
            wp.synchronize_stream(wp.get_stream(device))

        category_offsets.release()

        # The offsets are now readable on the host: query the count once rather
        # than converting the offsets array to numpy for every use below
        node_count = self.node_count()

        # Compute global to local indices
        self._space_to_partition = borrow_temporary_like(node_indices, temporary_store)
        wp.launch(
            kernel=NodePartition._scatter_partition_indices,
            dim=self.space_topology.node_count(),
            device=device,
            inputs=[node_count, node_indices.array, self._space_to_partition.array],
        )

        # Copy to shrunk-to-fit array
        self._node_indices = borrow_temporary(temporary_store, shape=(node_count,), dtype=int, device=device)
        wp.copy(dest=self._node_indices.array, src=node_indices.array, count=node_count)

        node_indices.release()

    # For each entry of the compressed index list, records its position as the
    # partition index of the matching space node; entries at or beyond
    # `local_node_count` get NULL_NODE_INDEX instead
    @wp.kernel
    def _scatter_partition_indices(
        local_node_count: int,
        node_indices: wp.array(dtype=int),
        space_to_partition_indices: wp.array(dtype=int),
    ):
        local_idx = wp.tid()
        space_idx = node_indices[local_idx]

        if local_idx < local_node_count:
            space_to_partition_indices[space_idx] = local_idx
        else:
            space_to_partition_indices[space_idx] = NULL_NODE_INDEX
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def make_space_partition(
    space: Optional[FunctionSpace] = None,
    geometry_partition: Optional[GeometryPartition] = None,
    space_topology: Optional[SpaceTopology] = None,
    with_halo: bool = True,
    device=None,
    temporary_store: Optional[TemporaryStore] = None,
) -> SpacePartition:
    """Computes the subset of nodes from a function space topology that touch a geometry partition

    Either `space_topology` or `space` must be provided (and will be considered in that order).

    If `geometry_partition` is ``None`` or contains as many cells as its geometry, a
    :class:`WholeSpacePartition` is returned and `with_halo`, `device` and `temporary_store` are not used.

    Args:
        space: (deprecated) the function space defining the topology if `space_topology` is ``None``.
        geometry_partition: The subset of the space geometry.  If not provided, use the whole geometry.
        space_topology: the topology of the function space to consider. If ``None``, deduced from `space`.
        with_halo: if True, include the halo nodes (nodes from exterior frontier cells to the partition)
        device: Warp device on which to perform and store computations
        temporary_store: store from which temporary arrays are borrowed when building a node partition

    Returns:
        the resulting space partition

    Raises:
        ValueError: if neither `space_topology` nor `space` is provided
    """

    if space_topology is None:
        # Fail with an explicit message rather than an AttributeError on `None.topology`
        if space is None:
            raise ValueError("Either `space_topology` or `space` must be provided")
        space_topology = space.topology

    space_topology = space_topology.full_space_topology()

    # Only a strict subset of the geometry requires an actual node partition
    if geometry_partition is not None:
        if geometry_partition.cell_count() < geometry_partition.geometry.cell_count():
            return NodePartition(
                space_topology=space_topology,
                geo_partition=geometry_partition,
                with_halo=with_halo,
                device=device,
                temporary_store=temporary_store,
            )

    return WholeSpacePartition(space_topology)
|