warp-lang 1.2.2__py3-none-win_amd64.whl → 1.3.1__py3-none-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (194)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +6 -2
  6. warp/builtins.py +1412 -888
  7. warp/codegen.py +503 -166
  8. warp/config.py +48 -18
  9. warp/context.py +400 -198
  10. warp/dlpack.py +8 -0
  11. warp/examples/assets/bunny.usd +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  13. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  14. warp/examples/benchmarks/benchmark_launches.py +1 -1
  15. warp/examples/core/example_cupy.py +78 -0
  16. warp/examples/fem/example_apic_fluid.py +17 -36
  17. warp/examples/fem/example_burgers.py +9 -18
  18. warp/examples/fem/example_convection_diffusion.py +7 -17
  19. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  20. warp/examples/fem/example_deformed_geometry.py +11 -22
  21. warp/examples/fem/example_diffusion.py +7 -18
  22. warp/examples/fem/example_diffusion_3d.py +24 -28
  23. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  24. warp/examples/fem/example_magnetostatics.py +190 -0
  25. warp/examples/fem/example_mixed_elasticity.py +111 -80
  26. warp/examples/fem/example_navier_stokes.py +30 -34
  27. warp/examples/fem/example_nonconforming_contact.py +290 -0
  28. warp/examples/fem/example_stokes.py +17 -32
  29. warp/examples/fem/example_stokes_transfer.py +12 -21
  30. warp/examples/fem/example_streamlines.py +350 -0
  31. warp/examples/fem/utils.py +936 -0
  32. warp/fabric.py +5 -2
  33. warp/fem/__init__.py +13 -3
  34. warp/fem/cache.py +161 -11
  35. warp/fem/dirichlet.py +37 -28
  36. warp/fem/domain.py +105 -14
  37. warp/fem/field/__init__.py +14 -3
  38. warp/fem/field/field.py +454 -11
  39. warp/fem/field/nodal_field.py +33 -18
  40. warp/fem/geometry/deformed_geometry.py +50 -15
  41. warp/fem/geometry/hexmesh.py +12 -24
  42. warp/fem/geometry/nanogrid.py +106 -31
  43. warp/fem/geometry/quadmesh_2d.py +6 -11
  44. warp/fem/geometry/tetmesh.py +103 -61
  45. warp/fem/geometry/trimesh_2d.py +98 -47
  46. warp/fem/integrate.py +231 -186
  47. warp/fem/operator.py +14 -9
  48. warp/fem/quadrature/pic_quadrature.py +35 -9
  49. warp/fem/quadrature/quadrature.py +119 -32
  50. warp/fem/space/basis_space.py +98 -22
  51. warp/fem/space/collocated_function_space.py +3 -1
  52. warp/fem/space/function_space.py +7 -2
  53. warp/fem/space/grid_2d_function_space.py +3 -3
  54. warp/fem/space/grid_3d_function_space.py +4 -4
  55. warp/fem/space/hexmesh_function_space.py +3 -2
  56. warp/fem/space/nanogrid_function_space.py +12 -14
  57. warp/fem/space/partition.py +45 -47
  58. warp/fem/space/restriction.py +19 -16
  59. warp/fem/space/shape/cube_shape_function.py +91 -3
  60. warp/fem/space/shape/shape_function.py +7 -0
  61. warp/fem/space/shape/square_shape_function.py +32 -0
  62. warp/fem/space/shape/tet_shape_function.py +11 -7
  63. warp/fem/space/shape/triangle_shape_function.py +10 -1
  64. warp/fem/space/topology.py +116 -42
  65. warp/fem/types.py +8 -1
  66. warp/fem/utils.py +301 -83
  67. warp/native/array.h +16 -0
  68. warp/native/builtin.h +0 -15
  69. warp/native/cuda_util.cpp +14 -6
  70. warp/native/exports.h +1348 -1308
  71. warp/native/quat.h +79 -0
  72. warp/native/rand.h +27 -4
  73. warp/native/sparse.cpp +83 -81
  74. warp/native/sparse.cu +381 -453
  75. warp/native/vec.h +64 -0
  76. warp/native/volume.cpp +40 -49
  77. warp/native/volume_builder.cu +2 -3
  78. warp/native/volume_builder.h +12 -17
  79. warp/native/warp.cu +3 -3
  80. warp/native/warp.h +69 -59
  81. warp/render/render_opengl.py +17 -9
  82. warp/sim/articulation.py +117 -17
  83. warp/sim/collide.py +35 -29
  84. warp/sim/model.py +123 -18
  85. warp/sim/render.py +3 -1
  86. warp/sparse.py +867 -203
  87. warp/stubs.py +312 -541
  88. warp/tape.py +29 -1
  89. warp/tests/disabled_kinematics.py +1 -1
  90. warp/tests/test_adam.py +1 -1
  91. warp/tests/test_arithmetic.py +1 -1
  92. warp/tests/test_array.py +58 -1
  93. warp/tests/test_array_reduce.py +1 -1
  94. warp/tests/test_async.py +1 -1
  95. warp/tests/test_atomic.py +1 -1
  96. warp/tests/test_bool.py +1 -1
  97. warp/tests/test_builtins_resolution.py +1 -1
  98. warp/tests/test_bvh.py +6 -1
  99. warp/tests/test_closest_point_edge_edge.py +1 -1
  100. warp/tests/test_codegen.py +91 -1
  101. warp/tests/test_compile_consts.py +1 -1
  102. warp/tests/test_conditional.py +1 -1
  103. warp/tests/test_copy.py +1 -1
  104. warp/tests/test_ctypes.py +1 -1
  105. warp/tests/test_dense.py +1 -1
  106. warp/tests/test_devices.py +1 -1
  107. warp/tests/test_dlpack.py +1 -1
  108. warp/tests/test_examples.py +33 -4
  109. warp/tests/test_fabricarray.py +5 -2
  110. warp/tests/test_fast_math.py +1 -1
  111. warp/tests/test_fem.py +213 -6
  112. warp/tests/test_fp16.py +1 -1
  113. warp/tests/test_func.py +1 -1
  114. warp/tests/test_future_annotations.py +90 -0
  115. warp/tests/test_generics.py +1 -1
  116. warp/tests/test_grad.py +1 -1
  117. warp/tests/test_grad_customs.py +1 -1
  118. warp/tests/test_grad_debug.py +247 -0
  119. warp/tests/test_hash_grid.py +6 -1
  120. warp/tests/test_implicit_init.py +354 -0
  121. warp/tests/test_import.py +1 -1
  122. warp/tests/test_indexedarray.py +1 -1
  123. warp/tests/test_intersect.py +1 -1
  124. warp/tests/test_jax.py +1 -1
  125. warp/tests/test_large.py +1 -1
  126. warp/tests/test_launch.py +1 -1
  127. warp/tests/test_lerp.py +1 -1
  128. warp/tests/test_linear_solvers.py +1 -1
  129. warp/tests/test_lvalue.py +1 -1
  130. warp/tests/test_marching_cubes.py +5 -2
  131. warp/tests/test_mat.py +34 -35
  132. warp/tests/test_mat_lite.py +2 -1
  133. warp/tests/test_mat_scalar_ops.py +1 -1
  134. warp/tests/test_math.py +1 -1
  135. warp/tests/test_matmul.py +20 -16
  136. warp/tests/test_matmul_lite.py +1 -1
  137. warp/tests/test_mempool.py +1 -1
  138. warp/tests/test_mesh.py +5 -2
  139. warp/tests/test_mesh_query_aabb.py +1 -1
  140. warp/tests/test_mesh_query_point.py +1 -1
  141. warp/tests/test_mesh_query_ray.py +1 -1
  142. warp/tests/test_mlp.py +1 -1
  143. warp/tests/test_model.py +1 -1
  144. warp/tests/test_module_hashing.py +77 -1
  145. warp/tests/test_modules_lite.py +1 -1
  146. warp/tests/test_multigpu.py +1 -1
  147. warp/tests/test_noise.py +1 -1
  148. warp/tests/test_operators.py +1 -1
  149. warp/tests/test_options.py +1 -1
  150. warp/tests/test_overwrite.py +542 -0
  151. warp/tests/test_peer.py +1 -1
  152. warp/tests/test_pinned.py +1 -1
  153. warp/tests/test_print.py +1 -1
  154. warp/tests/test_quat.py +15 -1
  155. warp/tests/test_rand.py +1 -1
  156. warp/tests/test_reload.py +1 -1
  157. warp/tests/test_rounding.py +1 -1
  158. warp/tests/test_runlength_encode.py +1 -1
  159. warp/tests/test_scalar_ops.py +95 -0
  160. warp/tests/test_sim_grad.py +1 -1
  161. warp/tests/test_sim_kinematics.py +1 -1
  162. warp/tests/test_smoothstep.py +1 -1
  163. warp/tests/test_sparse.py +82 -15
  164. warp/tests/test_spatial.py +1 -1
  165. warp/tests/test_special_values.py +2 -11
  166. warp/tests/test_streams.py +11 -1
  167. warp/tests/test_struct.py +1 -1
  168. warp/tests/test_tape.py +1 -1
  169. warp/tests/test_torch.py +194 -1
  170. warp/tests/test_transient_module.py +1 -1
  171. warp/tests/test_types.py +1 -1
  172. warp/tests/test_utils.py +1 -1
  173. warp/tests/test_vec.py +15 -63
  174. warp/tests/test_vec_lite.py +2 -1
  175. warp/tests/test_vec_scalar_ops.py +65 -1
  176. warp/tests/test_verify_fp.py +1 -1
  177. warp/tests/test_volume.py +28 -2
  178. warp/tests/test_volume_write.py +1 -1
  179. warp/tests/unittest_serial.py +1 -1
  180. warp/tests/unittest_suites.py +9 -1
  181. warp/tests/walkthrough_debug.py +1 -1
  182. warp/thirdparty/unittest_parallel.py +2 -5
  183. warp/torch.py +103 -41
  184. warp/types.py +341 -224
  185. warp/utils.py +11 -2
  186. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/METADATA +99 -46
  187. warp_lang-1.3.1.dist-info/RECORD +368 -0
  188. warp/examples/fem/bsr_utils.py +0 -378
  189. warp/examples/fem/mesh_utils.py +0 -133
  190. warp/examples/fem/plot_utils.py +0 -292
  191. warp_lang-1.2.2.dist-info/RECORD +0 -359
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/LICENSE.md +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/WHEEL +0 -0
  194. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
1
- from typing import Optional, Type
1
+ from typing import Optional, Tuple, Type
2
2
 
3
3
  import warp as wp
4
4
  from warp.fem import cache
5
5
  from warp.fem.geometry import DeformedGeometry, Geometry
6
- from warp.fem.types import ElementIndex
6
+ from warp.fem.types import NULL_ELEMENT_INDEX, NULL_NODE_INDEX, ElementIndex
7
7
 
8
8
 
9
9
  class SpaceTopology:
@@ -18,8 +18,8 @@ class SpaceTopology:
18
18
  dimension: int
19
19
  """Embedding dimension of the function space"""
20
20
 
21
- NODES_PER_ELEMENT: int
22
- """Number of interpolation nodes per element of the geometry.
21
+ MAX_NODES_PER_ELEMENT: int
22
+ """maximum number of interpolation nodes per element of the geometry.
23
23
 
24
24
  .. note:: This will change to be defined per-element in future versions
25
25
  """
@@ -30,12 +30,14 @@ class SpaceTopology:
30
30
 
31
31
  pass
32
32
 
33
- def __init__(self, geometry: Geometry, nodes_per_element: int):
33
+ def __init__(self, geometry: Geometry, max_nodes_per_element: int):
34
34
  self._geometry = geometry
35
35
  self.dimension = geometry.dimension
36
- self.NODES_PER_ELEMENT = wp.constant(nodes_per_element)
36
+ self.MAX_NODES_PER_ELEMENT = wp.constant(max_nodes_per_element)
37
37
  self.ElementArg = geometry.CellArg
38
38
 
39
+ self._make_constant_element_node_count()
40
+
39
41
  @property
40
42
  def geometry(self) -> Geometry:
41
43
  """Underlying geometry"""
@@ -51,25 +53,42 @@ class SpaceTopology:
51
53
 
52
54
  @property
53
55
  def name(self):
54
- return f"{self.__class__.__name__}_{self.NODES_PER_ELEMENT}"
56
+ return f"{self.__class__.__name__}_{self.MAX_NODES_PER_ELEMENT}"
55
57
 
56
58
  def __str__(self):
57
59
  return self.name
58
60
 
61
+ @staticmethod
62
+ def element_node_count(
63
+ geo_arg: "ElementArg", # noqa: F821
64
+ topo_arg: "TopologyArg",
65
+ element_index: ElementIndex,
66
+ ) -> int:
67
+ """Returns the actual number of nodes in a given element"""
68
+ raise NotImplementedError
69
+
59
70
  @staticmethod
60
71
  def element_node_index(
61
72
  geo_arg: "ElementArg", # noqa: F821
62
73
  topo_arg: "TopologyArg",
63
74
  element_index: ElementIndex,
64
75
  node_index_in_elt: int,
65
- ):
76
+ ) -> int:
66
77
  """Global node index for a given node in a given element"""
67
78
  raise NotImplementedError
68
79
 
80
+ @staticmethod
81
+ def side_neighbor_node_counts(
82
+ side_arg: "ElementArg", # noqa: F821
83
+ side_index: ElementIndex,
84
+ ) -> Tuple[int, int]:
85
+ """Returns the number of nodes for both the inner and outer cells of a given sides"""
86
+ raise NotImplementedError
87
+
69
88
  def element_node_indices(self, out: Optional[wp.array] = None) -> wp.array:
70
89
  """Returns a temporary array containing the global index for each node of each element"""
71
90
 
72
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
91
+ MAX_NODES_PER_ELEMENT = self.MAX_NODES_PER_ELEMENT
73
92
 
74
93
  @cache.dynamic_kernel(suffix=self.name)
75
94
  def fill_element_node_indices(
@@ -78,12 +97,13 @@ class SpaceTopology:
78
97
  element_node_indices: wp.array2d(dtype=int),
79
98
  ):
80
99
  element_index = wp.tid()
81
- for n in range(NODES_PER_ELEMENT):
100
+ element_node_count = self.element_node_count(geo_cell_arg, topo_arg, element_index)
101
+ for n in range(element_node_count):
82
102
  element_node_indices[element_index, n] = self.element_node_index(
83
103
  geo_cell_arg, topo_arg, element_index, n
84
104
  )
85
105
 
86
- shape = (self.geometry.cell_count(), NODES_PER_ELEMENT)
106
+ shape = (self.geometry.cell_count(), MAX_NODES_PER_ELEMENT)
87
107
  if out is None:
88
108
  element_node_indices = wp.empty(
89
109
  shape=shape,
@@ -135,14 +155,36 @@ class SpaceTopology:
135
155
  return self.full_space_topology() == other
136
156
  return False
137
157
 
158
+ def _make_constant_element_node_count(self):
159
+ NODES_PER_ELEMENT = wp.constant(self.MAX_NODES_PER_ELEMENT)
160
+
161
+ @cache.dynamic_func(suffix=self.name)
162
+ def constant_element_node_count(
163
+ geo_arg: self.geometry.CellArg,
164
+ topo_arg: self.TopologyArg,
165
+ element_index: ElementIndex,
166
+ ):
167
+ return NODES_PER_ELEMENT
168
+
169
+ @cache.dynamic_func(suffix=self.name)
170
+ def constant_side_neighbor_node_counts(
171
+ side_arg: self.geometry.SideArg,
172
+ element_index: ElementIndex,
173
+ ):
174
+ return NODES_PER_ELEMENT, NODES_PER_ELEMENT
175
+
176
+ self.element_node_count = constant_element_node_count
177
+ self.side_neighbor_node_counts = constant_side_neighbor_node_counts
178
+
138
179
 
139
180
  class TraceSpaceTopology(SpaceTopology):
140
181
  """Auto-generated trace topology defining the node indices associated to the geometry sides"""
141
182
 
142
183
  def __init__(self, topo: SpaceTopology):
143
- super().__init__(topo.geometry, 2 * topo.NODES_PER_ELEMENT)
144
-
145
184
  self._topo = topo
185
+
186
+ super().__init__(topo.geometry, 2 * topo.MAX_NODES_PER_ELEMENT)
187
+
146
188
  self.dimension = topo.dimension - 1
147
189
  self.ElementArg = topo.geometry.SideArg
148
190
 
@@ -154,6 +196,8 @@ class TraceSpaceTopology(SpaceTopology):
154
196
  self.neighbor_cell_index = self._make_neighbor_cell_index()
155
197
 
156
198
  self.element_node_index = self._make_element_node_index()
199
+ self.element_node_count = self._make_element_node_count()
200
+ self.side_neighbor_node_counts = None
157
201
 
158
202
  def node_count(self) -> int:
159
203
  return self._topo.node_count()
@@ -163,39 +207,51 @@ class TraceSpaceTopology(SpaceTopology):
163
207
  return f"{self._topo.name}_Trace"
164
208
 
165
209
  def _make_inner_cell_index(self):
166
- NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT
167
-
168
210
  @cache.dynamic_func(suffix=self.name)
169
- def inner_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
170
- index_in_inner_cell = wp.select(node_index_in_elt < NODES_PER_ELEMENT, -1, node_index_in_elt)
171
- return self.geometry.side_inner_cell_index(args, element_index), index_in_inner_cell
211
+ def inner_cell_index(side_arg: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
212
+ inner_count, outer_count = self._topo.side_neighbor_node_counts(side_arg, element_index)
213
+ if node_index_in_elt >= inner_count:
214
+ return NULL_ELEMENT_INDEX, NULL_NODE_INDEX
215
+ return self.geometry.side_inner_cell_index(side_arg, element_index), node_index_in_elt
172
216
 
173
217
  return inner_cell_index
174
218
 
175
219
  def _make_outer_cell_index(self):
176
- NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT
177
-
178
220
  @cache.dynamic_func(suffix=self.name)
179
- def outer_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
180
- return self.geometry.side_outer_cell_index(args, element_index), node_index_in_elt - NODES_PER_ELEMENT
221
+ def outer_cell_index(side_arg: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
222
+ inner_count, outer_count = self._topo.side_neighbor_node_counts(side_arg, element_index)
223
+ if node_index_in_elt < inner_count:
224
+ return NULL_ELEMENT_INDEX, NULL_NODE_INDEX
225
+ return self.geometry.side_outer_cell_index(side_arg, element_index), node_index_in_elt - inner_count
181
226
 
182
227
  return outer_cell_index
183
228
 
184
229
  def _make_neighbor_cell_index(self):
185
- NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT
186
-
187
230
  @cache.dynamic_func(suffix=self.name)
188
- def neighbor_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
189
- if node_index_in_elt < NODES_PER_ELEMENT:
190
- return self.geometry.side_inner_cell_index(args, element_index), node_index_in_elt
191
- else:
192
- return (
193
- self.geometry.side_outer_cell_index(args, element_index),
194
- node_index_in_elt - NODES_PER_ELEMENT,
195
- )
231
+ def neighbor_cell_index(side_arg: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
232
+ inner_count, outer_count = self._topo.side_neighbor_node_counts(side_arg, element_index)
233
+ if node_index_in_elt < inner_count:
234
+ return self.geometry.side_inner_cell_index(side_arg, element_index), node_index_in_elt
235
+
236
+ return (
237
+ self.geometry.side_outer_cell_index(side_arg, element_index),
238
+ node_index_in_elt - inner_count,
239
+ )
196
240
 
197
241
  return neighbor_cell_index
198
242
 
243
+ def _make_element_node_count(self):
244
+ @cache.dynamic_func(suffix=self.name)
245
+ def trace_element_node_count(
246
+ geo_side_arg: self.geometry.SideArg,
247
+ topo_arg: self._topo.TopologyArg,
248
+ element_index: ElementIndex,
249
+ ):
250
+ inner_count, outer_count = self._topo.side_neighbor_node_counts(geo_side_arg, element_index)
251
+ return inner_count + outer_count
252
+
253
+ return trace_element_node_count
254
+
199
255
  def _make_element_node_index(self):
200
256
  @cache.dynamic_func(suffix=self.name)
201
257
  def trace_element_node_index(
@@ -219,7 +275,7 @@ class TraceSpaceTopology(SpaceTopology):
219
275
  return self._topo == other._topo
220
276
 
221
277
 
222
- class DiscontinuousSpaceTopologyMixin:
278
+ class RegularDiscontinuousSpaceTopologyMixin:
223
279
  """Helper for defining discontinuous topologies (per-element nodes)"""
224
280
 
225
281
  def __init__(self, *args, **kwargs):
@@ -227,14 +283,14 @@ class DiscontinuousSpaceTopologyMixin:
227
283
  self.element_node_index = self._make_element_node_index()
228
284
 
229
285
  def node_count(self):
230
- return self.geometry.cell_count() * self.NODES_PER_ELEMENT
286
+ return self.geometry.cell_count() * self.MAX_NODES_PER_ELEMENT
231
287
 
232
288
  @property
233
289
  def name(self):
234
- return f"{self.geometry.name}_D{self.NODES_PER_ELEMENT}"
290
+ return f"{self.geometry.name}_D{self.MAX_NODES_PER_ELEMENT}"
235
291
 
236
292
  def _make_element_node_index(self):
237
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
293
+ NODES_PER_ELEMENT = self.MAX_NODES_PER_ELEMENT
238
294
 
239
295
  @cache.dynamic_func(suffix=self.name)
240
296
  def element_node_index(
@@ -248,7 +304,7 @@ class DiscontinuousSpaceTopologyMixin:
248
304
  return element_node_index
249
305
 
250
306
 
251
- class DiscontinuousSpaceTopology(DiscontinuousSpaceTopologyMixin, SpaceTopology):
307
+ class RegularDiscontinuousSpaceTopology(RegularDiscontinuousSpaceTopologyMixin, SpaceTopology):
252
308
  """Topology for generic discontinuous spaces"""
253
309
 
254
310
  pass
@@ -256,20 +312,20 @@ class DiscontinuousSpaceTopology(DiscontinuousSpaceTopologyMixin, SpaceTopology)
256
312
 
257
313
  class DeformedGeometrySpaceTopology(SpaceTopology):
258
314
  def __init__(self, geometry: DeformedGeometry, base_topology: SpaceTopology):
259
- super().__init__(geometry, base_topology.NODES_PER_ELEMENT)
260
-
261
315
  self.base = base_topology
316
+ super().__init__(geometry, base_topology.MAX_NODES_PER_ELEMENT)
317
+
262
318
  self.node_count = self.base.node_count
263
319
  self.topo_arg_value = self.base.topo_arg_value
264
320
  self.TopologyArg = self.base.TopologyArg
265
321
 
266
- self.element_node_index = self._make_element_node_index()
322
+ self._make_passthrough_functions()
267
323
 
268
324
  @property
269
325
  def name(self):
270
326
  return f"{self.base.name}_{self.geometry.field.name}"
271
327
 
272
- def _make_element_node_index(self):
328
+ def _make_passthrough_functions(self):
273
329
  @cache.dynamic_func(suffix=self.name)
274
330
  def element_node_index(
275
331
  elt_arg: self.geometry.CellArg,
@@ -279,7 +335,25 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
279
335
  ):
280
336
  return self.base.element_node_index(elt_arg.elt_arg, topo_arg, element_index, node_index_in_elt)
281
337
 
282
- return element_node_index
338
+ @cache.dynamic_func(suffix=self.name)
339
+ def element_node_count(
340
+ elt_arg: self.geometry.CellArg,
341
+ topo_arg: self.TopologyArg,
342
+ element_count: ElementIndex,
343
+ ):
344
+ return self.base.element_node_count(elt_arg.elt_arg, topo_arg, element_count)
345
+
346
+ @cache.dynamic_func(suffix=self.name)
347
+ def side_neighbor_node_counts(
348
+ side_arg: self.geometry.SideArg,
349
+ element_index: ElementIndex,
350
+ ):
351
+ inner_count, outer_count = self.base.side_neighbor_node_counts(side_arg.base_arg, element_index)
352
+ return inner_count, outer_count
353
+
354
+ self.element_node_index = element_node_index
355
+ self.element_node_count = element_node_count
356
+ self.side_neighbor_node_counts = side_neighbor_node_counts
283
357
 
284
358
 
285
359
  def forward_base_topology(topology_class: Type[SpaceTopology], geometry: Geometry, *args, **kwargs) -> SpaceTopology:
warp/fem/types.py CHANGED
@@ -1,3 +1,5 @@
1
+ from enum import Enum
2
+
1
3
  import warp as wp
2
4
 
3
5
  # kept to avoid breaking existing example code, no longer used internally
@@ -14,7 +16,7 @@ NodeIndex = int
14
16
 
15
17
  NULL_ELEMENT_INDEX = wp.constant(-1)
16
18
  NULL_QP_INDEX = wp.constant(-1)
17
- NULL_NODE_INDEX = wp.constant(-1)
19
+ NULL_NODE_INDEX = wp.constant((1 << 31) - 1) # this should be larger than normal nodes when sorting
18
20
 
19
21
  DofIndex = wp.vec2i
20
22
  """Opaque descriptor for indexing degrees of freedom within elements"""
@@ -31,6 +33,11 @@ def get_node_coord(dof_idx: DofIndex):
31
33
  return dof_idx[1]
32
34
 
33
35
 
36
+ class ElementKind(Enum):
37
+ CELL = 0
38
+ SIDE = 1
39
+
40
+
34
41
  @wp.struct
35
42
  class NodeElementIndex:
36
43
  domain_element_index: ElementIndex