warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- docs/conf.py +17 -5
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/env/env_usd.py +4 -1
- examples/env/environment.py +8 -9
- examples/example_dem.py +34 -33
- examples/example_diffray.py +364 -337
- examples/example_fluid.py +32 -23
- examples/example_jacobian_ik.py +97 -93
- examples/example_marching_cubes.py +6 -16
- examples/example_mesh.py +6 -16
- examples/example_mesh_intersect.py +16 -14
- examples/example_nvdb.py +14 -16
- examples/example_raycast.py +14 -13
- examples/example_raymarch.py +16 -23
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +82 -78
- examples/example_sim_cloth.py +45 -48
- examples/example_sim_fk_grad.py +51 -44
- examples/example_sim_fk_grad_torch.py +47 -40
- examples/example_sim_grad_bounce.py +108 -133
- examples/example_sim_grad_cloth.py +99 -113
- examples/example_sim_granular.py +5 -6
- examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
- examples/example_sim_neo_hookean.py +51 -55
- examples/example_sim_particle_chain.py +4 -4
- examples/example_sim_quadruped.py +126 -81
- examples/example_sim_rigid_chain.py +54 -61
- examples/example_sim_rigid_contact.py +66 -70
- examples/example_sim_rigid_fem.py +3 -3
- examples/example_sim_rigid_force.py +1 -1
- examples/example_sim_rigid_gyroscopic.py +3 -4
- examples/example_sim_rigid_kinematics.py +28 -39
- examples/example_sim_trajopt.py +112 -110
- examples/example_sph.py +9 -8
- examples/example_wave.py +7 -7
- examples/fem/bsr_utils.py +30 -17
- examples/fem/example_apic_fluid.py +85 -69
- examples/fem/example_convection_diffusion.py +97 -93
- examples/fem/example_convection_diffusion_dg.py +142 -149
- examples/fem/example_convection_diffusion_dg0.py +141 -136
- examples/fem/example_deformed_geometry.py +146 -0
- examples/fem/example_diffusion.py +115 -84
- examples/fem/example_diffusion_3d.py +116 -86
- examples/fem/example_diffusion_mgpu.py +102 -79
- examples/fem/example_mixed_elasticity.py +139 -100
- examples/fem/example_navier_stokes.py +175 -162
- examples/fem/example_stokes.py +143 -111
- examples/fem/example_stokes_transfer.py +186 -157
- examples/fem/mesh_utils.py +59 -97
- examples/fem/plot_utils.py +138 -17
- tools/ci/publishing/build_nodes_info.py +54 -0
- warp/__init__.py +4 -3
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +836 -492
- warp/codegen.py +864 -553
- warp/config.py +3 -1
- warp/context.py +389 -172
- warp/fem/__init__.py +24 -6
- warp/fem/cache.py +318 -25
- warp/fem/dirichlet.py +7 -3
- warp/fem/domain.py +14 -0
- warp/fem/field/__init__.py +30 -38
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +244 -138
- warp/fem/field/restriction.py +8 -6
- warp/fem/field/test.py +127 -59
- warp/fem/field/trial.py +117 -60
- warp/fem/geometry/__init__.py +5 -1
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +24 -1
- warp/fem/geometry/geometry.py +86 -14
- warp/fem/geometry/grid_2d.py +112 -54
- warp/fem/geometry/grid_3d.py +134 -65
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +85 -33
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +451 -115
- warp/fem/geometry/trimesh_2d.py +197 -92
- warp/fem/integrate.py +534 -268
- warp/fem/operator.py +58 -31
- warp/fem/polynomial.py +11 -0
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +150 -58
- warp/fem/quadrature/quadrature.py +209 -57
- warp/fem/space/__init__.py +230 -53
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +49 -2
- warp/fem/space/function_space.py +90 -39
- warp/fem/space/grid_2d_function_space.py +149 -496
- warp/fem/space/grid_3d_function_space.py +173 -538
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +129 -76
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +46 -34
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +132 -1039
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +104 -742
- warp/fem/types.py +13 -11
- warp/fem/utils.py +335 -60
- warp/native/array.h +120 -34
- warp/native/builtin.h +101 -72
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +22 -40
- warp/native/clang/clang.cpp +1 -0
- warp/native/crt.h +2 -0
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1522 -1243
- warp/native/intersect.h +19 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +76 -17
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -18
- warp/native/mesh.h +395 -40
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +44 -34
- warp/native/reduce.cpp +1 -1
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +163 -155
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +18 -14
- warp/native/vec.h +103 -21
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +28 -3
- warp/native/warp.h +4 -3
- warp/render/render_opengl.py +261 -109
- warp/sim/__init__.py +1 -2
- warp/sim/articulation.py +385 -185
- warp/sim/import_mjcf.py +59 -48
- warp/sim/import_urdf.py +15 -15
- warp/sim/import_usd.py +174 -102
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_xpbd.py +4 -3
- warp/sim/model.py +330 -250
- warp/sim/render.py +1 -1
- warp/sparse.py +625 -152
- warp/stubs.py +341 -309
- warp/tape.py +9 -6
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +94 -74
- warp/tests/test_array.py +82 -101
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +22 -12
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +18 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +165 -134
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +237 -0
- warp/tests/test_fabricarray.py +22 -24
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1034 -124
- warp/tests/test_fp16.py +23 -16
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +123 -181
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +35 -34
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +24 -25
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +304 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +60 -22
- warp/tests/test_mesh_query_aabb.py +21 -25
- warp/tests/test_mesh_query_point.py +111 -22
- warp/tests/test_mesh_query_ray.py +12 -24
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +90 -86
- warp/tests/test_transient_module.py +10 -12
- warp/tests/test_types.py +363 -0
- warp/tests/test_utils.py +451 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +418 -376
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/unittest_utils.py +342 -0
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +589 -0
- warp/types.py +622 -211
- warp/utils.py +54 -393
- warp_lang-1.0.0b6.dist-info/METADATA +238 -0
- warp_lang-1.0.0b6.dist-info/RECORD +409 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- examples/example_cache_management.py +0 -40
- examples/example_multigpu.py +0 -54
- examples/example_struct.py +0 -65
- examples/fem/example_stokes_transfer_3d.py +0 -210
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/fem/field/discrete_field.py +0 -80
- warp/fem/space/nodal_function_space.py +0 -233
- warp/tests/test_all.py +0 -223
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-1.0.0b2.dist-info/METADATA +0 -26
- warp_lang-1.0.0b2.dist-info/RECORD +0 -380
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/fem/geometry/partition.py
CHANGED
|
@@ -4,6 +4,7 @@ import warp as wp
|
|
|
4
4
|
|
|
5
5
|
from warp.fem.types import ElementIndex, NULL_ELEMENT_INDEX
|
|
6
6
|
from warp.fem.utils import masked_indices
|
|
7
|
+
from warp.fem.cache import cached_arg_value, TemporaryStore, borrow_temporary
|
|
7
8
|
|
|
8
9
|
from .geometry import Geometry
|
|
9
10
|
|
|
@@ -12,9 +13,14 @@ wp.set_module_options({"enable_backward": False})
|
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
class GeometryPartition:
|
|
15
|
-
|
|
16
16
|
"""Base class for geometry partitions, i.e. subset of cells and sides"""
|
|
17
17
|
|
|
18
|
+
class CellArg:
|
|
19
|
+
pass
|
|
20
|
+
|
|
21
|
+
class SideArg:
|
|
22
|
+
pass
|
|
23
|
+
|
|
18
24
|
def __init__(self, geometry: Geometry):
|
|
19
25
|
self.geometry = geometry
|
|
20
26
|
|
|
@@ -41,6 +47,37 @@ class GeometryPartition:
|
|
|
41
47
|
def __str__(self) -> str:
|
|
42
48
|
return self.name
|
|
43
49
|
|
|
50
|
+
def cell_arg_value(self, device):
|
|
51
|
+
raise NotImplementedError()
|
|
52
|
+
|
|
53
|
+
def side_arg_value(self, device):
|
|
54
|
+
raise NotImplementedError()
|
|
55
|
+
|
|
56
|
+
@staticmethod
|
|
57
|
+
def cell_index(args: CellArg, partition_cell_index: int):
|
|
58
|
+
"""Index in the geometry of a partition cell"""
|
|
59
|
+
raise NotImplementedError()
|
|
60
|
+
|
|
61
|
+
@staticmethod
|
|
62
|
+
def partition_cell_index(args: CellArg, cell_index: int):
|
|
63
|
+
"""Index of a geometry cell in the partition (or ``NULL_ELEMENT_INDEX``)"""
|
|
64
|
+
raise NotImplementedError()
|
|
65
|
+
|
|
66
|
+
@staticmethod
|
|
67
|
+
def side_index(args: SideArg, partition_side_index: int):
|
|
68
|
+
"""Partition side to side index"""
|
|
69
|
+
raise NotImplementedError()
|
|
70
|
+
|
|
71
|
+
@staticmethod
|
|
72
|
+
def boundary_side_index(args: SideArg, boundary_side_index: int):
|
|
73
|
+
"""Boundary side to side index"""
|
|
74
|
+
raise NotImplementedError()
|
|
75
|
+
|
|
76
|
+
@staticmethod
|
|
77
|
+
def frontier_side_index(args: SideArg, frontier_side_index: int):
|
|
78
|
+
"""Frontier side to side index"""
|
|
79
|
+
raise NotImplementedError()
|
|
80
|
+
|
|
44
81
|
|
|
45
82
|
class WholeGeometryPartition(GeometryPartition):
|
|
46
83
|
"""Trivial (NOP) partition"""
|
|
@@ -89,6 +126,10 @@ class WholeGeometryPartition(GeometryPartition):
|
|
|
89
126
|
def _identity_element_index(args: Any, idx: ElementIndex):
|
|
90
127
|
return idx
|
|
91
128
|
|
|
129
|
+
@property
|
|
130
|
+
def name(self) -> str:
|
|
131
|
+
return self.geometry.name
|
|
132
|
+
|
|
92
133
|
|
|
93
134
|
class CellBasedGeometryPartition(GeometryPartition):
|
|
94
135
|
"""Geometry partition based on a subset of cells. Interior, boundary and frontier sides are automatically categorized."""
|
|
@@ -107,19 +148,20 @@ class CellBasedGeometryPartition(GeometryPartition):
|
|
|
107
148
|
frontier_side_indices: wp.array(dtype=int)
|
|
108
149
|
|
|
109
150
|
def side_count(self) -> int:
|
|
110
|
-
return self._partition_side_indices.shape[0]
|
|
151
|
+
return self._partition_side_indices.array.shape[0]
|
|
111
152
|
|
|
112
153
|
def boundary_side_count(self) -> int:
|
|
113
|
-
return self._boundary_side_indices.shape[0]
|
|
154
|
+
return self._boundary_side_indices.array.shape[0]
|
|
114
155
|
|
|
115
156
|
def frontier_side_count(self) -> int:
|
|
116
|
-
return self._frontier_side_indices.shape[0]
|
|
157
|
+
return self._frontier_side_indices.array.shape[0]
|
|
117
158
|
|
|
159
|
+
@cached_arg_value
|
|
118
160
|
def side_arg_value(self, device):
|
|
119
161
|
arg = LinearGeometryPartition.SideArg()
|
|
120
|
-
arg.partition_side_indices = self._partition_side_indices.to(device)
|
|
121
|
-
arg.boundary_side_indices = self._boundary_side_indices.to(device)
|
|
122
|
-
arg.frontier_side_indices = self._frontier_side_indices.to(device)
|
|
162
|
+
arg.partition_side_indices = self._partition_side_indices.array.to(device)
|
|
163
|
+
arg.boundary_side_indices = self._boundary_side_indices.array.to(device)
|
|
164
|
+
arg.frontier_side_indices = self._frontier_side_indices.array.to(device)
|
|
123
165
|
return arg
|
|
124
166
|
|
|
125
167
|
@wp.func
|
|
@@ -138,16 +180,16 @@ class CellBasedGeometryPartition(GeometryPartition):
|
|
|
138
180
|
return args.frontier_side_indices[frontier_side_index]
|
|
139
181
|
|
|
140
182
|
def compute_side_indices_from_cells(
|
|
141
|
-
self,
|
|
142
|
-
cell_arg_value: Any,
|
|
143
|
-
cell_inclusion_test_func: wp.Function,
|
|
144
|
-
device,
|
|
183
|
+
self, cell_arg_value: Any, cell_inclusion_test_func: wp.Function, device, temporary_store: TemporaryStore = None
|
|
145
184
|
):
|
|
146
185
|
from warp.fem import cache
|
|
147
186
|
|
|
148
|
-
|
|
187
|
+
cell_arg_type = next(iter(cell_inclusion_test_func.input_types.values()))
|
|
188
|
+
|
|
189
|
+
@cache.dynamic_kernel(suffix=f"{self.geometry.name}_{cell_inclusion_test_func.key}")
|
|
190
|
+
def count_sides(
|
|
149
191
|
geo_arg: self.geometry.SideArg,
|
|
150
|
-
cell_arg_value:
|
|
192
|
+
cell_arg_value: cell_arg_type,
|
|
151
193
|
partition_side_mask: wp.array(dtype=int),
|
|
152
194
|
boundary_side_mask: wp.array(dtype=int),
|
|
153
195
|
frontier_side_mask: wp.array(dtype=int),
|
|
@@ -171,44 +213,50 @@ class CellBasedGeometryPartition(GeometryPartition):
|
|
|
171
213
|
# Exactly one neighbor in partition; count as frontier side
|
|
172
214
|
frontier_side_mask[side_index] = 1
|
|
173
215
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
suffix=f"{self.geometry.name}_{cell_inclusion_test_func.key}",
|
|
177
|
-
)
|
|
178
|
-
|
|
179
|
-
partition_side_mask = wp.zeros(
|
|
216
|
+
partition_side_mask = borrow_temporary(
|
|
217
|
+
temporary_store,
|
|
180
218
|
shape=(self.geometry.side_count(),),
|
|
181
219
|
dtype=int,
|
|
182
220
|
device=device,
|
|
183
221
|
)
|
|
184
|
-
boundary_side_mask =
|
|
222
|
+
boundary_side_mask = borrow_temporary(
|
|
223
|
+
temporary_store,
|
|
185
224
|
shape=(self.geometry.side_count(),),
|
|
186
225
|
dtype=int,
|
|
187
226
|
device=device,
|
|
188
227
|
)
|
|
189
|
-
frontier_side_mask =
|
|
228
|
+
frontier_side_mask = borrow_temporary(
|
|
229
|
+
temporary_store,
|
|
190
230
|
shape=(self.geometry.side_count(),),
|
|
191
231
|
dtype=int,
|
|
192
232
|
device=device,
|
|
193
233
|
)
|
|
194
234
|
|
|
235
|
+
partition_side_mask.array.zero_()
|
|
236
|
+
boundary_side_mask.array.zero_()
|
|
237
|
+
frontier_side_mask.array.zero_()
|
|
238
|
+
|
|
195
239
|
wp.launch(
|
|
196
|
-
dim=partition_side_mask.shape[0],
|
|
240
|
+
dim=partition_side_mask.array.shape[0],
|
|
197
241
|
kernel=count_sides,
|
|
198
242
|
inputs=[
|
|
199
243
|
self.geometry.side_arg_value(device),
|
|
200
244
|
cell_arg_value,
|
|
201
|
-
partition_side_mask,
|
|
202
|
-
boundary_side_mask,
|
|
203
|
-
frontier_side_mask,
|
|
245
|
+
partition_side_mask.array,
|
|
246
|
+
boundary_side_mask.array,
|
|
247
|
+
frontier_side_mask.array,
|
|
204
248
|
],
|
|
205
249
|
device=device,
|
|
206
250
|
)
|
|
207
251
|
|
|
208
252
|
# Convert counts to indices
|
|
209
|
-
self._partition_side_indices, _ = masked_indices(partition_side_mask)
|
|
210
|
-
self._boundary_side_indices, _ = masked_indices(boundary_side_mask)
|
|
211
|
-
self._frontier_side_indices, _ = masked_indices(frontier_side_mask)
|
|
253
|
+
self._partition_side_indices, _ = masked_indices(partition_side_mask.array, temporary_store=temporary_store)
|
|
254
|
+
self._boundary_side_indices, _ = masked_indices(boundary_side_mask.array, temporary_store=temporary_store)
|
|
255
|
+
self._frontier_side_indices, _ = masked_indices(frontier_side_mask.array, temporary_store=temporary_store)
|
|
256
|
+
|
|
257
|
+
partition_side_mask.release()
|
|
258
|
+
boundary_side_mask.release()
|
|
259
|
+
frontier_side_mask.release()
|
|
212
260
|
|
|
213
261
|
|
|
214
262
|
class LinearGeometryPartition(CellBasedGeometryPartition):
|
|
@@ -218,6 +266,7 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
|
|
|
218
266
|
partition_rank: int,
|
|
219
267
|
partition_count: int,
|
|
220
268
|
device=None,
|
|
269
|
+
temporary_store: TemporaryStore = None,
|
|
221
270
|
):
|
|
222
271
|
"""Creates a geometry partition by uniformly partionning cell indices
|
|
223
272
|
|
|
@@ -239,6 +288,7 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
|
|
|
239
288
|
self.cell_arg_value(device),
|
|
240
289
|
LinearGeometryPartition._cell_inclusion_test,
|
|
241
290
|
device,
|
|
291
|
+
temporary_store=temporary_store,
|
|
242
292
|
)
|
|
243
293
|
|
|
244
294
|
def cell_count(self) -> int:
|
|
@@ -278,7 +328,7 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
|
|
|
278
328
|
|
|
279
329
|
|
|
280
330
|
class ExplicitGeometryPartition(CellBasedGeometryPartition):
|
|
281
|
-
def __init__(self, geometry: Geometry, cell_mask: "wp.array(dtype=int)"):
|
|
331
|
+
def __init__(self, geometry: Geometry, cell_mask: "wp.array(dtype=int)", temporary_store: TemporaryStore = None):
|
|
282
332
|
"""Creates a geometry partition by uniformly partionning cell indices
|
|
283
333
|
|
|
284
334
|
Args:
|
|
@@ -289,26 +339,28 @@ class ExplicitGeometryPartition(CellBasedGeometryPartition):
|
|
|
289
339
|
super().__init__(geometry)
|
|
290
340
|
|
|
291
341
|
self._cell_mask = cell_mask
|
|
292
|
-
self._cells, self._partition_cells = masked_indices(self._cell_mask)
|
|
342
|
+
self._cells, self._partition_cells = masked_indices(self._cell_mask, temporary_store=temporary_store)
|
|
293
343
|
|
|
294
344
|
super().compute_side_indices_from_cells(
|
|
295
345
|
self._cell_mask,
|
|
296
346
|
ExplicitGeometryPartition._cell_inclusion_test,
|
|
297
347
|
self._cell_mask.device,
|
|
348
|
+
temporary_store=temporary_store,
|
|
298
349
|
)
|
|
299
350
|
|
|
300
351
|
def cell_count(self) -> int:
|
|
301
|
-
return self._cells.shape[0]
|
|
352
|
+
return self._cells.array.shape[0]
|
|
302
353
|
|
|
303
354
|
@wp.struct
|
|
304
355
|
class CellArg:
|
|
305
356
|
cell_index: wp.array(dtype=int)
|
|
306
357
|
partition_cell_index: wp.array(dtype=int)
|
|
307
358
|
|
|
359
|
+
@cached_arg_value
|
|
308
360
|
def cell_arg_value(self, device):
|
|
309
361
|
arg = ExplicitGeometryPartition.CellArg()
|
|
310
|
-
arg.cell_index = self._cells.to(device)
|
|
311
|
-
arg.partition_cell_index = self._partition_cells.to(device)
|
|
362
|
+
arg.cell_index = self._cells.array.to(device)
|
|
363
|
+
arg.partition_cell_index = self._partition_cells.array.to(device)
|
|
312
364
|
return arg
|
|
313
365
|
|
|
314
366
|
@wp.func
|