warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269) hide show
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,18 @@
1
- from typing import Any, Optional
1
+ from typing import Any, Optional, Union
2
2
 
3
3
  import warp as wp
4
-
4
+ from warp.fem.cache import (
5
+ TemporaryStore,
6
+ borrow_temporary,
7
+ borrow_temporary_like,
8
+ cached_arg_value,
9
+ )
5
10
  from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
6
- from warp.fem.utils import compress_node_indices, _iota_kernel
7
11
  from warp.fem.types import NULL_NODE_INDEX
12
+ from warp.fem.utils import _iota_kernel, compress_node_indices
8
13
 
9
14
  from .function_space import FunctionSpace
10
-
15
+ from .topology import SpaceTopology
11
16
 
12
17
  wp.set_module_options({"enable_backward": False})
13
18
 
@@ -16,8 +21,8 @@ class SpacePartition:
16
21
  class PartitionArg:
17
22
  pass
18
23
 
19
- def __init__(self, space: FunctionSpace, geo_partition: GeometryPartition):
20
- self.space = space
24
+ def __init__(self, space_topology: SpaceTopology, geo_partition: GeometryPartition):
25
+ self.space_topology = space_topology
21
26
  self.geo_partition = geo_partition
22
27
 
23
28
  def node_count(self):
@@ -35,7 +40,8 @@ class SpacePartition:
35
40
  def partition_arg_value(self, device):
36
41
  pass
37
42
 
38
- def partition_node_index(args: Any, space_node_index: int):
43
+ @staticmethod
44
+ def partition_node_index(args: "PartitionArg", space_node_index: int):
39
45
  """Returns the index in the partition of a function space node, or -1 if it does not exist"""
40
46
 
41
47
  def __str__(self) -> str:
@@ -51,28 +57,28 @@ class WholeSpacePartition(SpacePartition):
51
57
  class PartitionArg:
52
58
  pass
53
59
 
54
- def __init__(self, space: FunctionSpace):
55
- super().__init__(space, WholeGeometryPartition(space.geometry))
60
+ def __init__(self, space_topology: SpaceTopology):
61
+ super().__init__(space_topology, WholeGeometryPartition(space_topology.geometry))
56
62
  self._node_indices = None
57
63
 
58
64
  def node_count(self):
59
65
  """Returns number of nodes in this partition"""
60
- return self.space.node_count()
66
+ return self.space_topology.node_count()
61
67
 
62
68
  def owned_node_count(self) -> int:
63
69
  """Returns number of nodes in this partition, excluding exterior halo"""
64
- return self.space.node_count()
70
+ return self.space_topology.node_count()
65
71
 
66
72
  def interior_node_count(self) -> int:
67
73
  """Returns number of interior nodes in this partition"""
68
- return self.space.node_count()
74
+ return self.space_topology.node_count()
69
75
 
70
76
  def space_node_indices(self):
71
77
  """Return the global function space indices for nodes in this partition"""
72
78
  if self._node_indices is None:
73
- self._node_indices = wp.empty(shape=(self.node_count(),), dtype=int)
74
- wp.launch(kernel=_iota_kernel, dim=self._node_indices.shape, inputs=[self._node_indices, 1])
75
- return self._node_indices
79
+ self._node_indices = borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
80
+ wp.launch(kernel=_iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array, 1])
81
+ return self._node_indices.array
76
82
 
77
83
  def partition_arg_value(self, device):
78
84
  return WholeSpacePartition.PartitionArg()
@@ -82,7 +88,11 @@ class WholeSpacePartition(SpacePartition):
82
88
  return space_node_index
83
89
 
84
90
  def __eq__(self, other: SpacePartition) -> bool:
85
- return isinstance(other, SpacePartition) and self.space == other.space
91
+ return isinstance(other, SpacePartition) and self.space_topology == other.space_topology
92
+
93
+ @property
94
+ def name(self) -> str:
95
+ return "Whole"
86
96
 
87
97
 
88
98
  class NodeCategory:
@@ -105,46 +115,56 @@ class NodePartition(SpacePartition):
105
115
  class PartitionArg:
106
116
  space_to_partition: wp.array(dtype=int)
107
117
 
108
- def __init__(self, space: FunctionSpace, geo_partition: GeometryPartition, with_halo: bool = True, device=None):
109
- super().__init__(space, geo_partition=geo_partition)
118
+ def __init__(
119
+ self,
120
+ space_topology: SpaceTopology,
121
+ geo_partition: GeometryPartition,
122
+ with_halo: bool = True,
123
+ device=None,
124
+ temporary_store: TemporaryStore = None,
125
+ ):
126
+ super().__init__(space_topology=space_topology, geo_partition=geo_partition)
110
127
 
111
- self._compute_node_indices_from_sides(device, with_halo)
128
+ self._compute_node_indices_from_sides(device, with_halo, temporary_store)
112
129
 
113
130
  def node_count(self) -> int:
114
131
  """Returns number of nodes referenced by this partition, including exterior halo"""
115
- return int(self._category_offsets[NodeCategory.HALO_OTHER_SIDE + 1])
132
+ return int(self._category_offsets.array.numpy()[NodeCategory.HALO_OTHER_SIDE + 1])
116
133
 
117
134
  def owned_node_count(self) -> int:
118
135
  """Returns number of nodes in this partition, excluding exterior halo"""
119
- return int(self._category_offsets[NodeCategory.OWNED_FRONTIER + 1])
136
+ return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_FRONTIER + 1])
120
137
 
121
138
  def interior_node_count(self) -> int:
122
139
  """Returns number of interior nodes in this partition"""
123
- return int(self._category_offsets[NodeCategory.OWNED_INTERIOR + 1])
140
+ return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_INTERIOR + 1])
124
141
 
125
142
  def space_node_indices(self):
126
143
  """Return the global function space indices for nodes in this partition"""
127
- return self._node_indices
144
+ return self._node_indices.array
128
145
 
146
+ @cached_arg_value
129
147
  def partition_arg_value(self, device):
130
148
  arg = NodePartition.PartitionArg()
131
- arg.space_to_partition = self._space_to_partition.to(device)
149
+ arg.space_to_partition = self._space_to_partition.array.to(device)
132
150
  return arg
133
151
 
134
152
  @wp.func
135
153
  def partition_node_index(args: PartitionArg, space_node_index: int):
136
154
  return args.space_to_partition[space_node_index]
137
155
 
138
- def _compute_node_indices_from_sides(self, device, with_halo: bool):
156
+ def _compute_node_indices_from_sides(self, device, with_halo: bool, temporary_store: TemporaryStore):
139
157
  from warp.fem import cache
140
158
 
141
- trace_space = self.space.trace()
142
- NODES_PER_CELL = self.space.NODES_PER_ELEMENT
143
- NODES_PER_SIDE = trace_space.NODES_PER_ELEMENT
159
+ trace_topology = self.space_topology.trace()
160
+ NODES_PER_CELL = self.space_topology.NODES_PER_ELEMENT
161
+ NODES_PER_SIDE = trace_topology.NODES_PER_ELEMENT
144
162
 
145
- def node_category_from_cells_fn(
163
+ @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
164
+ def node_category_from_cells_kernel(
165
+ geo_arg: self.geo_partition.geometry.CellArg,
146
166
  geo_partition_arg: self.geo_partition.CellArg,
147
- space_arg: self.space.SpaceArg,
167
+ space_arg: self.space_topology.TopologyArg,
148
168
  node_mask: wp.array(dtype=int),
149
169
  ):
150
170
  partition_cell_index = wp.tid()
@@ -152,12 +172,14 @@ class NodePartition(SpacePartition):
152
172
  cell_index = self.geo_partition.cell_index(geo_partition_arg, partition_cell_index)
153
173
 
154
174
  for n in range(NODES_PER_CELL):
155
- space_nidx = self.space.element_node_index(space_arg, cell_index, n)
175
+ space_nidx = self.space_topology.element_node_index(geo_arg, space_arg, cell_index, n)
156
176
  node_mask[space_nidx] = NodeCategory.OWNED_INTERIOR
157
177
 
158
- def node_category_from_owned_sides_fn(
178
+ @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
179
+ def node_category_from_owned_sides_kernel(
180
+ geo_arg: self.geo_partition.geometry.SideArg,
159
181
  geo_partition_arg: self.geo_partition.SideArg,
160
- space_arg: trace_space.SpaceArg,
182
+ space_arg: trace_topology.TopologyArg,
161
183
  node_mask: wp.array(dtype=int),
162
184
  ):
163
185
  partition_side_index = wp.tid()
@@ -165,13 +187,16 @@ class NodePartition(SpacePartition):
165
187
  side_index = self.geo_partition.side_index(geo_partition_arg, partition_side_index)
166
188
 
167
189
  for n in range(NODES_PER_SIDE):
168
- space_nidx = trace_space.element_node_index(space_arg, side_index, n)
190
+ space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
191
+
169
192
  if node_mask[space_nidx] == NodeCategory.EXTERIOR:
170
193
  node_mask[space_nidx] = NodeCategory.HALO_LOCAL_SIDE
171
194
 
172
- def node_category_from_frontier_sides_fn(
195
+ @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
196
+ def node_category_from_frontier_sides_kernel(
197
+ geo_arg: self.geo_partition.geometry.SideArg,
173
198
  geo_partition_arg: self.geo_partition.SideArg,
174
- space_arg: trace_space.SpaceArg,
199
+ space_arg: trace_topology.TopologyArg,
175
200
  node_mask: wp.array(dtype=int),
176
201
  ):
177
202
  frontier_side_index = wp.tid()
@@ -179,39 +204,28 @@ class NodePartition(SpacePartition):
179
204
  side_index = self.geo_partition.frontier_side_index(geo_partition_arg, frontier_side_index)
180
205
 
181
206
  for n in range(NODES_PER_SIDE):
182
- space_nidx = trace_space.element_node_index(space_arg, side_index, n)
207
+ space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
183
208
  if node_mask[space_nidx] == NodeCategory.EXTERIOR:
184
209
  node_mask[space_nidx] = NodeCategory.HALO_OTHER_SIDE
185
210
  elif node_mask[space_nidx] == NodeCategory.OWNED_INTERIOR:
186
211
  node_mask[space_nidx] = NodeCategory.OWNED_FRONTIER
187
212
 
188
- node_category_from_cells_kernel = cache.get_kernel(
189
- node_category_from_cells_fn,
190
- suffix=f"{self.geo_partition.name}_{self.space.name}",
191
- )
192
- node_category_from_owned_sides_kernel = cache.get_kernel(
193
- node_category_from_owned_sides_fn,
194
- suffix=f"{self.geo_partition.name}_{self.space.name}",
195
- )
196
- node_category_from_frontier_sides_kernel = cache.get_kernel(
197
- node_category_from_frontier_sides_fn,
198
- suffix=f"{self.geo_partition.name}_{self.space.name}",
199
- )
200
-
201
- node_category = wp.empty(
202
- shape=(self.space.node_count(),),
213
+ node_category = borrow_temporary(
214
+ temporary_store,
215
+ shape=(self.space_topology.node_count(),),
203
216
  dtype=int,
204
217
  device=device,
205
218
  )
206
- node_category.fill_(value=NodeCategory.EXTERIOR)
219
+ node_category.array.fill_(value=NodeCategory.EXTERIOR)
207
220
 
208
221
  wp.launch(
209
222
  dim=self.geo_partition.cell_count(),
210
223
  kernel=node_category_from_cells_kernel,
211
224
  inputs=[
225
+ self.geo_partition.geometry.cell_arg_value(device),
212
226
  self.geo_partition.cell_arg_value(device),
213
- self.space.space_arg_value(device),
214
- node_category,
227
+ self.space_topology.topo_arg_value(device),
228
+ node_category.array,
215
229
  ],
216
230
  device=device,
217
231
  )
@@ -221,9 +235,10 @@ class NodePartition(SpacePartition):
221
235
  dim=self.geo_partition.side_count(),
222
236
  kernel=node_category_from_owned_sides_kernel,
223
237
  inputs=[
238
+ self.geo_partition.geometry.side_arg_value(device),
224
239
  self.geo_partition.side_arg_value(device),
225
- self.space.space_arg_value(device),
226
- node_category,
240
+ self.space_topology.topo_arg_value(device),
241
+ node_category.array,
227
242
  ],
228
243
  device=device,
229
244
  )
@@ -232,31 +247,52 @@ class NodePartition(SpacePartition):
232
247
  dim=self.geo_partition.frontier_side_count(),
233
248
  kernel=node_category_from_frontier_sides_kernel,
234
249
  inputs=[
250
+ self.geo_partition.geometry.side_arg_value(device),
235
251
  self.geo_partition.side_arg_value(device),
236
- self.space.space_arg_value(device),
237
- node_category,
252
+ self.space_topology.topo_arg_value(device),
253
+ node_category.array,
238
254
  ],
239
255
  device=device,
240
256
  )
241
257
 
242
- self._finalize_node_indices(node_category)
258
+ self._finalize_node_indices(node_category.array, temporary_store)
243
259
 
244
- def _finalize_node_indices(self, node_category: wp.array(dtype=int)):
260
+ node_category.release()
261
+
262
+ def _finalize_node_indices(self, node_category: wp.array(dtype=int), temporary_store: TemporaryStore):
245
263
  category_offsets, node_indices, _, __ = compress_node_indices(NodeCategory.COUNT, node_category)
246
- self._category_offsets = category_offsets.numpy()
247
264
 
248
- # Compute globla to local indices
249
- self._space_to_partition = node_category # Reuse array storage
265
+ # Copy offsets to cpu
266
+ device = node_category.device
267
+ self._category_offsets = borrow_temporary(
268
+ temporary_store,
269
+ shape=category_offsets.array.shape,
270
+ dtype=category_offsets.array.dtype,
271
+ pinned=device.is_cuda,
272
+ device="cpu",
273
+ )
274
+ wp.copy(src=category_offsets.array, dest=self._category_offsets.array)
275
+
276
+ if device.is_cuda:
277
+ # TODO switch to synchronize_event once available
278
+ wp.synchronize_stream(wp.get_stream(device))
279
+
280
+ category_offsets.release()
281
+
282
+ # Compute global to local indices
283
+ self._space_to_partition = borrow_temporary_like(node_indices, temporary_store)
250
284
  wp.launch(
251
285
  kernel=NodePartition._scatter_partition_indices,
252
- dim=self.space.node_count(),
253
- device=self._space_to_partition.device,
254
- inputs=[self.node_count(), node_indices, self._space_to_partition],
286
+ dim=self.space_topology.node_count(),
287
+ device=device,
288
+ inputs=[self.node_count(), node_indices.array, self._space_to_partition.array],
255
289
  )
256
290
 
257
- # Copy to shrinked-to-fit array, save on memory
258
- self._node_indices = wp.empty(shape=(self.node_count()), dtype=int, device=node_indices.device)
259
- wp.copy(dest=self._node_indices, src=node_indices, count=self.node_count())
291
+ # Copy to shrinked-to-fit array
292
+ self._node_indices = borrow_temporary(temporary_store, shape=(self.node_count()), dtype=int, device=device)
293
+ wp.copy(dest=self._node_indices.array, src=node_indices.array, count=self.node_count())
294
+
295
+ node_indices.release()
260
296
 
261
297
  @wp.kernel
262
298
  def _scatter_partition_indices(
@@ -274,16 +310,21 @@ class NodePartition(SpacePartition):
274
310
 
275
311
 
276
312
  def make_space_partition(
277
- space: FunctionSpace,
313
+ space: Optional[FunctionSpace] = None,
278
314
  geometry_partition: Optional[GeometryPartition] = None,
315
+ space_topology: Optional[SpaceTopology] = None,
279
316
  with_halo: bool = True,
280
317
  device=None,
318
+ temporary_store: TemporaryStore = None,
281
319
  ) -> SpacePartition:
282
- """Computes the substep of nodes from a function space that touch a geometry partition
320
+ """Computes the subset of nodes from a function space topology that touch a geometry partition
321
+
322
+ Either `space_topology` or `space` must be provided (and will be considered in that order).
283
323
 
284
324
  Args:
285
- space: the function space to consider
325
+ space: (deprecated) the function space defining the topology if `space_topology` is ``None``.
286
326
  geometry_partition: The subset of the space geometry. If not provided, use the whole geometry.
327
+ space_topology: the topology of the function space to consider. If ``None``, deduced from `space`.
287
328
  with_halo: if True, include the halo nodes (nodes from exterior frontier cells to the partition)
288
329
  device: Warp device on which to perform and store computations
289
330
 
@@ -291,7 +332,19 @@ def make_space_partition(
291
332
  the resulting space partition
292
333
  """
293
334
 
294
- if geometry_partition is not None and geometry_partition.cell_count() < geometry_partition.geometry.cell_count():
295
- return NodePartition(space, geometry_partition, with_halo=with_halo, device=device)
335
+ if space_topology is None:
336
+ space_topology = space.topology
337
+
338
+ space_topology = space_topology.full_space_topology()
339
+
340
+ if geometry_partition is not None:
341
+ if geometry_partition.cell_count() < geometry_partition.geometry.cell_count():
342
+ return NodePartition(
343
+ space_topology=space_topology,
344
+ geo_partition=geometry_partition,
345
+ with_halo=with_halo,
346
+ device=device,
347
+ temporary_store=temporary_store,
348
+ )
296
349
 
297
- return WholeSpacePartition(space)
350
+ return WholeSpacePartition(space_topology)