warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for this release for more details.

Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,7 @@ import warp as wp
4
4
 
5
5
  from warp.fem.types import ElementIndex, NULL_ELEMENT_INDEX
6
6
  from warp.fem.utils import masked_indices
7
+ from warp.fem.cache import cached_arg_value, TemporaryStore, borrow_temporary
7
8
 
8
9
  from .geometry import Geometry
9
10
 
@@ -12,9 +13,14 @@ wp.set_module_options({"enable_backward": False})
12
13
 
13
14
 
14
15
  class GeometryPartition:
15
-
16
16
  """Base class for geometry partitions, i.e. subset of cells and sides"""
17
17
 
18
+ class CellArg:
19
+ pass
20
+
21
+ class SideArg:
22
+ pass
23
+
18
24
  def __init__(self, geometry: Geometry):
19
25
  self.geometry = geometry
20
26
 
@@ -41,6 +47,37 @@ class GeometryPartition:
41
47
  def __str__(self) -> str:
42
48
  return self.name
43
49
 
50
+ def cell_arg_value(self, device):
51
+ raise NotImplementedError()
52
+
53
+ def side_arg_value(self, device):
54
+ raise NotImplementedError()
55
+
56
+ @staticmethod
57
+ def cell_index(args: CellArg, partition_cell_index: int):
58
+ """Index in the geometry of a partition cell"""
59
+ raise NotImplementedError()
60
+
61
+ @staticmethod
62
+ def partition_cell_index(args: CellArg, cell_index: int):
63
+ """Index of a geometry cell in the partition (or ``NULL_ELEMENT_INDEX``)"""
64
+ raise NotImplementedError()
65
+
66
+ @staticmethod
67
+ def side_index(args: SideArg, partition_side_index: int):
68
+ """Partition side to side index"""
69
+ raise NotImplementedError()
70
+
71
+ @staticmethod
72
+ def boundary_side_index(args: SideArg, boundary_side_index: int):
73
+ """Boundary side to side index"""
74
+ raise NotImplementedError()
75
+
76
+ @staticmethod
77
+ def frontier_side_index(args: SideArg, frontier_side_index: int):
78
+ """Frontier side to side index"""
79
+ raise NotImplementedError()
80
+
44
81
 
45
82
  class WholeGeometryPartition(GeometryPartition):
46
83
  """Trivial (NOP) partition"""
@@ -89,6 +126,10 @@ class WholeGeometryPartition(GeometryPartition):
89
126
  def _identity_element_index(args: Any, idx: ElementIndex):
90
127
  return idx
91
128
 
129
+ @property
130
+ def name(self) -> str:
131
+ return self.geometry.name
132
+
92
133
 
93
134
  class CellBasedGeometryPartition(GeometryPartition):
94
135
  """Geometry partition based on a subset of cells. Interior, boundary and frontier sides are automatically categorized."""
@@ -107,19 +148,20 @@ class CellBasedGeometryPartition(GeometryPartition):
107
148
  frontier_side_indices: wp.array(dtype=int)
108
149
 
109
150
  def side_count(self) -> int:
110
- return self._partition_side_indices.shape[0]
151
+ return self._partition_side_indices.array.shape[0]
111
152
 
112
153
  def boundary_side_count(self) -> int:
113
- return self._boundary_side_indices.shape[0]
154
+ return self._boundary_side_indices.array.shape[0]
114
155
 
115
156
  def frontier_side_count(self) -> int:
116
- return self._frontier_side_indices.shape[0]
157
+ return self._frontier_side_indices.array.shape[0]
117
158
 
159
+ @cached_arg_value
118
160
  def side_arg_value(self, device):
119
161
  arg = LinearGeometryPartition.SideArg()
120
- arg.partition_side_indices = self._partition_side_indices.to(device)
121
- arg.boundary_side_indices = self._boundary_side_indices.to(device)
122
- arg.frontier_side_indices = self._frontier_side_indices.to(device)
162
+ arg.partition_side_indices = self._partition_side_indices.array.to(device)
163
+ arg.boundary_side_indices = self._boundary_side_indices.array.to(device)
164
+ arg.frontier_side_indices = self._frontier_side_indices.array.to(device)
123
165
  return arg
124
166
 
125
167
  @wp.func
@@ -138,16 +180,16 @@ class CellBasedGeometryPartition(GeometryPartition):
138
180
  return args.frontier_side_indices[frontier_side_index]
139
181
 
140
182
  def compute_side_indices_from_cells(
141
- self,
142
- cell_arg_value: Any,
143
- cell_inclusion_test_func: wp.Function,
144
- device,
183
+ self, cell_arg_value: Any, cell_inclusion_test_func: wp.Function, device, temporary_store: TemporaryStore = None
145
184
  ):
146
185
  from warp.fem import cache
147
186
 
148
- def count_side_fn(
187
+ cell_arg_type = next(iter(cell_inclusion_test_func.input_types.values()))
188
+
189
+ @cache.dynamic_kernel(suffix=f"{self.geometry.name}_{cell_inclusion_test_func.key}")
190
+ def count_sides(
149
191
  geo_arg: self.geometry.SideArg,
150
- cell_arg_value: Any,
192
+ cell_arg_value: cell_arg_type,
151
193
  partition_side_mask: wp.array(dtype=int),
152
194
  boundary_side_mask: wp.array(dtype=int),
153
195
  frontier_side_mask: wp.array(dtype=int),
@@ -171,44 +213,50 @@ class CellBasedGeometryPartition(GeometryPartition):
171
213
  # Exactly one neighbor in partition; count as frontier side
172
214
  frontier_side_mask[side_index] = 1
173
215
 
174
- count_sides = cache.get_kernel(
175
- count_side_fn,
176
- suffix=f"{self.geometry.name}_{cell_inclusion_test_func.key}",
177
- )
178
-
179
- partition_side_mask = wp.zeros(
216
+ partition_side_mask = borrow_temporary(
217
+ temporary_store,
180
218
  shape=(self.geometry.side_count(),),
181
219
  dtype=int,
182
220
  device=device,
183
221
  )
184
- boundary_side_mask = wp.zeros(
222
+ boundary_side_mask = borrow_temporary(
223
+ temporary_store,
185
224
  shape=(self.geometry.side_count(),),
186
225
  dtype=int,
187
226
  device=device,
188
227
  )
189
- frontier_side_mask = wp.zeros(
228
+ frontier_side_mask = borrow_temporary(
229
+ temporary_store,
190
230
  shape=(self.geometry.side_count(),),
191
231
  dtype=int,
192
232
  device=device,
193
233
  )
194
234
 
235
+ partition_side_mask.array.zero_()
236
+ boundary_side_mask.array.zero_()
237
+ frontier_side_mask.array.zero_()
238
+
195
239
  wp.launch(
196
- dim=partition_side_mask.shape[0],
240
+ dim=partition_side_mask.array.shape[0],
197
241
  kernel=count_sides,
198
242
  inputs=[
199
243
  self.geometry.side_arg_value(device),
200
244
  cell_arg_value,
201
- partition_side_mask,
202
- boundary_side_mask,
203
- frontier_side_mask,
245
+ partition_side_mask.array,
246
+ boundary_side_mask.array,
247
+ frontier_side_mask.array,
204
248
  ],
205
249
  device=device,
206
250
  )
207
251
 
208
252
  # Convert counts to indices
209
- self._partition_side_indices, _ = masked_indices(partition_side_mask)
210
- self._boundary_side_indices, _ = masked_indices(boundary_side_mask)
211
- self._frontier_side_indices, _ = masked_indices(frontier_side_mask)
253
+ self._partition_side_indices, _ = masked_indices(partition_side_mask.array, temporary_store=temporary_store)
254
+ self._boundary_side_indices, _ = masked_indices(boundary_side_mask.array, temporary_store=temporary_store)
255
+ self._frontier_side_indices, _ = masked_indices(frontier_side_mask.array, temporary_store=temporary_store)
256
+
257
+ partition_side_mask.release()
258
+ boundary_side_mask.release()
259
+ frontier_side_mask.release()
212
260
 
213
261
 
214
262
  class LinearGeometryPartition(CellBasedGeometryPartition):
@@ -218,6 +266,7 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
218
266
  partition_rank: int,
219
267
  partition_count: int,
220
268
  device=None,
269
+ temporary_store: TemporaryStore = None,
221
270
  ):
222
271
  """Creates a geometry partition by uniformly partionning cell indices
223
272
 
@@ -239,6 +288,7 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
239
288
  self.cell_arg_value(device),
240
289
  LinearGeometryPartition._cell_inclusion_test,
241
290
  device,
291
+ temporary_store=temporary_store,
242
292
  )
243
293
 
244
294
  def cell_count(self) -> int:
@@ -278,7 +328,7 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
278
328
 
279
329
 
280
330
  class ExplicitGeometryPartition(CellBasedGeometryPartition):
281
- def __init__(self, geometry: Geometry, cell_mask: "wp.array(dtype=int)"):
331
+ def __init__(self, geometry: Geometry, cell_mask: "wp.array(dtype=int)", temporary_store: TemporaryStore = None):
282
332
  """Creates a geometry partition by uniformly partionning cell indices
283
333
 
284
334
  Args:
@@ -289,26 +339,28 @@ class ExplicitGeometryPartition(CellBasedGeometryPartition):
289
339
  super().__init__(geometry)
290
340
 
291
341
  self._cell_mask = cell_mask
292
- self._cells, self._partition_cells = masked_indices(self._cell_mask)
342
+ self._cells, self._partition_cells = masked_indices(self._cell_mask, temporary_store=temporary_store)
293
343
 
294
344
  super().compute_side_indices_from_cells(
295
345
  self._cell_mask,
296
346
  ExplicitGeometryPartition._cell_inclusion_test,
297
347
  self._cell_mask.device,
348
+ temporary_store=temporary_store,
298
349
  )
299
350
 
300
351
  def cell_count(self) -> int:
301
- return self._cells.shape[0]
352
+ return self._cells.array.shape[0]
302
353
 
303
354
  @wp.struct
304
355
  class CellArg:
305
356
  cell_index: wp.array(dtype=int)
306
357
  partition_cell_index: wp.array(dtype=int)
307
358
 
359
+ @cached_arg_value
308
360
  def cell_arg_value(self, device):
309
361
  arg = ExplicitGeometryPartition.CellArg()
310
- arg.cell_index = self._cells.to(device)
311
- arg.partition_cell_index = self._partition_cells.to(device)
362
+ arg.cell_index = self._cells.array.to(device)
363
+ arg.partition_cell_index = self._partition_cells.array.to(device)
312
364
  return arg
313
365
 
314
366
  @wp.func