warp-lang 1.2.2__py3-none-manylinux2014_aarch64.whl → 1.3.1__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the registry's release page for more details.

Files changed (193):
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +6 -2
  5. warp/builtins.py +1412 -888
  6. warp/codegen.py +503 -166
  7. warp/config.py +48 -18
  8. warp/context.py +400 -198
  9. warp/dlpack.py +8 -0
  10. warp/examples/assets/bunny.usd +0 -0
  11. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  12. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  13. warp/examples/benchmarks/benchmark_launches.py +1 -1
  14. warp/examples/core/example_cupy.py +78 -0
  15. warp/examples/fem/example_apic_fluid.py +17 -36
  16. warp/examples/fem/example_burgers.py +9 -18
  17. warp/examples/fem/example_convection_diffusion.py +7 -17
  18. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  19. warp/examples/fem/example_deformed_geometry.py +11 -22
  20. warp/examples/fem/example_diffusion.py +7 -18
  21. warp/examples/fem/example_diffusion_3d.py +24 -28
  22. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  23. warp/examples/fem/example_magnetostatics.py +190 -0
  24. warp/examples/fem/example_mixed_elasticity.py +111 -80
  25. warp/examples/fem/example_navier_stokes.py +30 -34
  26. warp/examples/fem/example_nonconforming_contact.py +290 -0
  27. warp/examples/fem/example_stokes.py +17 -32
  28. warp/examples/fem/example_stokes_transfer.py +12 -21
  29. warp/examples/fem/example_streamlines.py +350 -0
  30. warp/examples/fem/utils.py +936 -0
  31. warp/fabric.py +5 -2
  32. warp/fem/__init__.py +13 -3
  33. warp/fem/cache.py +161 -11
  34. warp/fem/dirichlet.py +37 -28
  35. warp/fem/domain.py +105 -14
  36. warp/fem/field/__init__.py +14 -3
  37. warp/fem/field/field.py +454 -11
  38. warp/fem/field/nodal_field.py +33 -18
  39. warp/fem/geometry/deformed_geometry.py +50 -15
  40. warp/fem/geometry/hexmesh.py +12 -24
  41. warp/fem/geometry/nanogrid.py +106 -31
  42. warp/fem/geometry/quadmesh_2d.py +6 -11
  43. warp/fem/geometry/tetmesh.py +103 -61
  44. warp/fem/geometry/trimesh_2d.py +98 -47
  45. warp/fem/integrate.py +231 -186
  46. warp/fem/operator.py +14 -9
  47. warp/fem/quadrature/pic_quadrature.py +35 -9
  48. warp/fem/quadrature/quadrature.py +119 -32
  49. warp/fem/space/basis_space.py +98 -22
  50. warp/fem/space/collocated_function_space.py +3 -1
  51. warp/fem/space/function_space.py +7 -2
  52. warp/fem/space/grid_2d_function_space.py +3 -3
  53. warp/fem/space/grid_3d_function_space.py +4 -4
  54. warp/fem/space/hexmesh_function_space.py +3 -2
  55. warp/fem/space/nanogrid_function_space.py +12 -14
  56. warp/fem/space/partition.py +45 -47
  57. warp/fem/space/restriction.py +19 -16
  58. warp/fem/space/shape/cube_shape_function.py +91 -3
  59. warp/fem/space/shape/shape_function.py +7 -0
  60. warp/fem/space/shape/square_shape_function.py +32 -0
  61. warp/fem/space/shape/tet_shape_function.py +11 -7
  62. warp/fem/space/shape/triangle_shape_function.py +10 -1
  63. warp/fem/space/topology.py +116 -42
  64. warp/fem/types.py +8 -1
  65. warp/fem/utils.py +301 -83
  66. warp/native/array.h +16 -0
  67. warp/native/builtin.h +0 -15
  68. warp/native/cuda_util.cpp +14 -6
  69. warp/native/exports.h +1348 -1308
  70. warp/native/quat.h +79 -0
  71. warp/native/rand.h +27 -4
  72. warp/native/sparse.cpp +83 -81
  73. warp/native/sparse.cu +381 -453
  74. warp/native/vec.h +64 -0
  75. warp/native/volume.cpp +40 -49
  76. warp/native/volume_builder.cu +2 -3
  77. warp/native/volume_builder.h +12 -17
  78. warp/native/warp.cu +3 -3
  79. warp/native/warp.h +69 -59
  80. warp/render/render_opengl.py +17 -9
  81. warp/sim/articulation.py +117 -17
  82. warp/sim/collide.py +35 -29
  83. warp/sim/model.py +123 -18
  84. warp/sim/render.py +3 -1
  85. warp/sparse.py +867 -203
  86. warp/stubs.py +312 -541
  87. warp/tape.py +29 -1
  88. warp/tests/disabled_kinematics.py +1 -1
  89. warp/tests/test_adam.py +1 -1
  90. warp/tests/test_arithmetic.py +1 -1
  91. warp/tests/test_array.py +58 -1
  92. warp/tests/test_array_reduce.py +1 -1
  93. warp/tests/test_async.py +1 -1
  94. warp/tests/test_atomic.py +1 -1
  95. warp/tests/test_bool.py +1 -1
  96. warp/tests/test_builtins_resolution.py +1 -1
  97. warp/tests/test_bvh.py +6 -1
  98. warp/tests/test_closest_point_edge_edge.py +1 -1
  99. warp/tests/test_codegen.py +91 -1
  100. warp/tests/test_compile_consts.py +1 -1
  101. warp/tests/test_conditional.py +1 -1
  102. warp/tests/test_copy.py +1 -1
  103. warp/tests/test_ctypes.py +1 -1
  104. warp/tests/test_dense.py +1 -1
  105. warp/tests/test_devices.py +1 -1
  106. warp/tests/test_dlpack.py +1 -1
  107. warp/tests/test_examples.py +33 -4
  108. warp/tests/test_fabricarray.py +5 -2
  109. warp/tests/test_fast_math.py +1 -1
  110. warp/tests/test_fem.py +213 -6
  111. warp/tests/test_fp16.py +1 -1
  112. warp/tests/test_func.py +1 -1
  113. warp/tests/test_future_annotations.py +90 -0
  114. warp/tests/test_generics.py +1 -1
  115. warp/tests/test_grad.py +1 -1
  116. warp/tests/test_grad_customs.py +1 -1
  117. warp/tests/test_grad_debug.py +247 -0
  118. warp/tests/test_hash_grid.py +6 -1
  119. warp/tests/test_implicit_init.py +354 -0
  120. warp/tests/test_import.py +1 -1
  121. warp/tests/test_indexedarray.py +1 -1
  122. warp/tests/test_intersect.py +1 -1
  123. warp/tests/test_jax.py +1 -1
  124. warp/tests/test_large.py +1 -1
  125. warp/tests/test_launch.py +1 -1
  126. warp/tests/test_lerp.py +1 -1
  127. warp/tests/test_linear_solvers.py +1 -1
  128. warp/tests/test_lvalue.py +1 -1
  129. warp/tests/test_marching_cubes.py +5 -2
  130. warp/tests/test_mat.py +34 -35
  131. warp/tests/test_mat_lite.py +2 -1
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_math.py +1 -1
  134. warp/tests/test_matmul.py +20 -16
  135. warp/tests/test_matmul_lite.py +1 -1
  136. warp/tests/test_mempool.py +1 -1
  137. warp/tests/test_mesh.py +5 -2
  138. warp/tests/test_mesh_query_aabb.py +1 -1
  139. warp/tests/test_mesh_query_point.py +1 -1
  140. warp/tests/test_mesh_query_ray.py +1 -1
  141. warp/tests/test_mlp.py +1 -1
  142. warp/tests/test_model.py +1 -1
  143. warp/tests/test_module_hashing.py +77 -1
  144. warp/tests/test_modules_lite.py +1 -1
  145. warp/tests/test_multigpu.py +1 -1
  146. warp/tests/test_noise.py +1 -1
  147. warp/tests/test_operators.py +1 -1
  148. warp/tests/test_options.py +1 -1
  149. warp/tests/test_overwrite.py +542 -0
  150. warp/tests/test_peer.py +1 -1
  151. warp/tests/test_pinned.py +1 -1
  152. warp/tests/test_print.py +1 -1
  153. warp/tests/test_quat.py +15 -1
  154. warp/tests/test_rand.py +1 -1
  155. warp/tests/test_reload.py +1 -1
  156. warp/tests/test_rounding.py +1 -1
  157. warp/tests/test_runlength_encode.py +1 -1
  158. warp/tests/test_scalar_ops.py +95 -0
  159. warp/tests/test_sim_grad.py +1 -1
  160. warp/tests/test_sim_kinematics.py +1 -1
  161. warp/tests/test_smoothstep.py +1 -1
  162. warp/tests/test_sparse.py +82 -15
  163. warp/tests/test_spatial.py +1 -1
  164. warp/tests/test_special_values.py +2 -11
  165. warp/tests/test_streams.py +11 -1
  166. warp/tests/test_struct.py +1 -1
  167. warp/tests/test_tape.py +1 -1
  168. warp/tests/test_torch.py +194 -1
  169. warp/tests/test_transient_module.py +1 -1
  170. warp/tests/test_types.py +1 -1
  171. warp/tests/test_utils.py +1 -1
  172. warp/tests/test_vec.py +15 -63
  173. warp/tests/test_vec_lite.py +2 -1
  174. warp/tests/test_vec_scalar_ops.py +65 -1
  175. warp/tests/test_verify_fp.py +1 -1
  176. warp/tests/test_volume.py +28 -2
  177. warp/tests/test_volume_write.py +1 -1
  178. warp/tests/unittest_serial.py +1 -1
  179. warp/tests/unittest_suites.py +9 -1
  180. warp/tests/walkthrough_debug.py +1 -1
  181. warp/thirdparty/unittest_parallel.py +2 -5
  182. warp/torch.py +103 -41
  183. warp/types.py +341 -224
  184. warp/utils.py +11 -2
  185. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/METADATA +99 -46
  186. warp_lang-1.3.1.dist-info/RECORD +368 -0
  187. warp/examples/fem/bsr_utils.py +0 -378
  188. warp/examples/fem/mesh_utils.py +0 -133
  189. warp/examples/fem/plot_utils.py +0 -292
  190. warp_lang-1.2.2.dist-info/RECORD +0 -359
  191. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/LICENSE.md +0 -0
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/WHEEL +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,15 @@
1
1
  from typing import Optional
2
2
 
3
+ import numpy as np
4
+
3
5
  import warp as wp
4
6
  from warp.fem import cache
5
7
  from warp.fem.geometry import Geometry
6
8
  from warp.fem.quadrature import Quadrature
7
- from warp.fem.types import Coords, ElementIndex, make_free_sample
9
+ from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, make_free_sample
8
10
 
9
11
  from .shape import ShapeFunction
10
- from .topology import DiscontinuousSpaceTopology, SpaceTopology
12
+ from .topology import RegularDiscontinuousSpaceTopology, SpaceTopology
11
13
 
12
14
 
13
15
  class BasisSpace:
@@ -28,8 +30,6 @@ class BasisSpace:
28
30
  def __init__(self, topology: SpaceTopology):
29
31
  self._topology = topology
30
32
 
31
- self.NODES_PER_ELEMENT = self._topology.NODES_PER_ELEMENT
32
-
33
33
  @property
34
34
  def topology(self) -> SpaceTopology:
35
35
  """Underlying topology of the basis space"""
@@ -49,8 +49,6 @@ class BasisSpace:
49
49
  def node_positions(self, out: Optional[wp.array] = None) -> wp.array:
50
50
  """Returns a temporary array containing the world position for each node"""
51
51
 
52
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
53
-
54
52
  pos_type = cache.cached_vec_type(length=self.geometry.dimension, dtype=float)
55
53
 
56
54
  node_coords_in_element = self.make_node_coords_in_element()
@@ -64,7 +62,8 @@ class BasisSpace:
64
62
  ):
65
63
  element_index = wp.tid()
66
64
 
67
- for n in range(NODES_PER_ELEMENT):
65
+ element_node_count = self.topology.element_node_count(geo_cell_arg, topo_arg, element_index)
66
+ for n in range(element_node_count):
68
67
  node_index = self.topology.element_node_index(geo_cell_arg, topo_arg, element_index, n)
69
68
  coords = node_coords_in_element(geo_cell_arg, basis_arg, element_index, n)
70
69
 
@@ -139,6 +138,10 @@ class ShapeBasisSpace(BasisSpace):
139
138
  self.node_tets = self._node_tets
140
139
  if hasattr(shape, "element_node_hexes"):
141
140
  self.node_hexes = self._node_hexes
141
+ if hasattr(shape, "element_vtk_cells"):
142
+ self.vtk_cells = self._vtk_cells
143
+ if hasattr(topology, "node_grid"):
144
+ self.node_grid = topology.node_grid
142
145
 
143
146
  @property
144
147
  def shape(self) -> ShapeFunction:
@@ -245,6 +248,16 @@ class ShapeBasisSpace(BasisSpace):
245
248
  hex_indices = element_node_indices[:, element_hexes].reshape(-1, 8)
246
249
  return hex_indices
247
250
 
251
+ def _vtk_cells(self):
252
+ element_node_indices = self._topology.element_node_indices().numpy()
253
+ element_vtk_cells, element_vtk_cell_types = self._shape.element_vtk_cells()
254
+
255
+ idx_per_cell = element_vtk_cells.shape[1]
256
+ cell_indices = element_node_indices[:, element_vtk_cells].reshape(-1, idx_per_cell)
257
+ cells = np.hstack((np.full((cell_indices.shape[0], 1), idx_per_cell), cell_indices))
258
+
259
+ return cells.flatten(), np.tile(element_vtk_cell_types, element_node_indices.shape[0])
260
+
248
261
 
249
262
  class TraceBasisSpace(BasisSpace):
250
263
  """Auto-generated trace space evaluating the cell-defined basis on the geometry sides"""
@@ -302,7 +315,7 @@ class TraceBasisSpace(BasisSpace):
302
315
  node_index_in_elt: int,
303
316
  ):
304
317
  cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
305
- if index_in_cell < 0:
318
+ if cell_index == NULL_ELEMENT_INDEX:
306
319
  return 0.0
307
320
 
308
321
  cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
@@ -330,7 +343,7 @@ class TraceBasisSpace(BasisSpace):
330
343
  node_index_in_elt: int,
331
344
  ):
332
345
  cell_index, index_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
333
- if index_in_cell < 0:
346
+ if cell_index == NULL_ELEMENT_INDEX:
334
347
  return 0.0
335
348
 
336
349
  cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)
@@ -359,7 +372,7 @@ class TraceBasisSpace(BasisSpace):
359
372
  node_index_in_elt: int,
360
373
  ):
361
374
  cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
362
- if index_in_cell < 0:
375
+ if cell_index == NULL_ELEMENT_INDEX:
363
376
  return grad_vec_type(0.0)
364
377
 
365
378
  cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
@@ -381,7 +394,7 @@ class TraceBasisSpace(BasisSpace):
381
394
  node_index_in_elt: int,
382
395
  ):
383
396
  cell_index, index_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
384
- if index_in_cell < 0:
397
+ if cell_index == NULL_ELEMENT_INDEX:
385
398
  return grad_vec_type(0.0)
386
399
 
387
400
  cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)
@@ -419,7 +432,7 @@ class PiecewiseConstantBasisSpace(ShapeBasisSpace):
419
432
 
420
433
 
421
434
  def make_discontinuous_basis_space(geometry: Geometry, shape: ShapeFunction):
422
- topology = DiscontinuousSpaceTopology(geometry, shape.NODES_PER_ELEMENT)
435
+ topology = RegularDiscontinuousSpaceTopology(geometry, shape.NODES_PER_ELEMENT)
423
436
 
424
437
  if shape.NODES_PER_ELEMENT == 1:
425
438
  # piecewise-constant space
@@ -428,6 +441,70 @@ def make_discontinuous_basis_space(geometry: Geometry, shape: ShapeFunction):
428
441
  return ShapeBasisSpace(topology=topology, shape=shape)
429
442
 
430
443
 
444
+ class UnstructuredPointTopology(SpaceTopology):
445
+ """Topology for unstructured points defined from quadrature formula. See :class:`PointBasisSpace`"""
446
+
447
+ def __init__(self, quadrature: Quadrature):
448
+ if quadrature.max_points_per_element() is None:
449
+ raise ValueError("Quadrature must define a maximum number of points per element")
450
+
451
+ if quadrature.domain.element_count() != quadrature.domain.geometry_element_count():
452
+ raise ValueError("Point topology may only be defined on quadrature domains than span the whole geometry")
453
+
454
+ self._quadrature = quadrature
455
+ self.TopologyArg = quadrature.Arg
456
+
457
+ super().__init__(quadrature.domain.geometry, max_nodes_per_element=quadrature.max_points_per_element())
458
+
459
+ self.element_node_index = self._make_element_node_index()
460
+ self.element_node_count = self._make_element_node_count()
461
+ self.side_neighbor_node_counts = self._make_side_neighbor_node_counts()
462
+
463
+ def node_count(self):
464
+ return self._quadrature.total_point_count()
465
+
466
+ @property
467
+ def name(self):
468
+ return f"PointTopology_{self._quadrature}"
469
+
470
+ def topo_arg_value(self, device) -> SpaceTopology.TopologyArg:
471
+ """Value of the topology argument structure to be passed to device functions"""
472
+ return self._quadrature.arg_value(device)
473
+
474
+ def _make_element_node_index(self):
475
+ @cache.dynamic_func(suffix=self.name)
476
+ def element_node_index(
477
+ elt_arg: self.geometry.CellArg,
478
+ topo_arg: self.TopologyArg,
479
+ element_index: ElementIndex,
480
+ node_index_in_elt: int,
481
+ ):
482
+ return self._quadrature.point_index(elt_arg, topo_arg, element_index, element_index, node_index_in_elt)
483
+
484
+ return element_node_index
485
+
486
+ def _make_element_node_count(self):
487
+ @cache.dynamic_func(suffix=self.name)
488
+ def element_node_count(
489
+ elt_arg: self.geometry.CellArg,
490
+ topo_arg: self.TopologyArg,
491
+ element_index: ElementIndex,
492
+ ):
493
+ return self._quadrature.point_count(elt_arg, topo_arg, element_index, element_index)
494
+
495
+ return element_node_count
496
+
497
+ def _make_side_neighbor_node_counts(self):
498
+ @cache.dynamic_func(suffix=self.name)
499
+ def side_neighbor_node_counts(
500
+ side_arg: self.geometry.SideArg,
501
+ element_index: ElementIndex,
502
+ ):
503
+ return 0, 0
504
+
505
+ return side_neighbor_node_counts
506
+
507
+
431
508
  class PointBasisSpace(BasisSpace):
432
509
  """An unstructured :class:`BasisSpace` that is non-zero at a finite set of points only.
433
510
 
@@ -437,12 +514,7 @@ class PointBasisSpace(BasisSpace):
437
514
  def __init__(self, quadrature: Quadrature):
438
515
  self._quadrature = quadrature
439
516
 
440
- if quadrature.points_per_element() is None:
441
- raise NotImplementedError("Varying number of points per element is not supported yet")
442
-
443
- topology = DiscontinuousSpaceTopology(
444
- geometry=quadrature.domain.geometry, nodes_per_element=quadrature.points_per_element()
445
- )
517
+ topology = UnstructuredPointTopology(quadrature)
446
518
  super().__init__(topology)
447
519
 
448
520
  self.BasisArg = quadrature.Arg
@@ -464,7 +536,7 @@ class PointBasisSpace(BasisSpace):
464
536
  element_index: ElementIndex,
465
537
  node_index_in_elt: int,
466
538
  ):
467
- return self._quadrature.point_coords(elt_arg, basis_arg, element_index, node_index_in_elt)
539
+ return self._quadrature.point_coords(elt_arg, basis_arg, element_index, element_index, node_index_in_elt)
468
540
 
469
541
  return node_coords_in_element
470
542
 
@@ -476,11 +548,13 @@ class PointBasisSpace(BasisSpace):
476
548
  element_index: ElementIndex,
477
549
  node_index_in_elt: int,
478
550
  ):
479
- return self._quadrature.point_weight(elt_arg, basis_arg, element_index, node_index_in_elt)
551
+ return self._quadrature.point_weight(elt_arg, basis_arg, element_index, element_index, node_index_in_elt)
480
552
 
481
553
  return node_quadrature_weight
482
554
 
483
555
  def make_element_inner_weight(self):
556
+ _DIRAC_INTEGRATION_RADIUS = wp.constant(1.0e-6)
557
+
484
558
  @cache.dynamic_func(suffix=self.name)
485
559
  def element_inner_weight(
486
560
  elt_arg: self._quadrature.domain.ElementArg,
@@ -489,8 +563,10 @@ class PointBasisSpace(BasisSpace):
489
563
  coords: Coords,
490
564
  node_index_in_elt: int,
491
565
  ):
492
- qp_coord = self._quadrature.point_coords(elt_arg, basis_arg, element_index, node_index_in_elt)
493
- return wp.select(wp.length_sq(coords - qp_coord) < 0.001, 0.0, 1.0)
566
+ qp_coord = self._quadrature.point_coords(
567
+ elt_arg, basis_arg, element_index, element_index, node_index_in_elt
568
+ )
569
+ return wp.select(wp.length_sq(coords - qp_coord) < _DIRAC_INTEGRATION_RADIUS, 0.0, 1.0)
494
570
 
495
571
  return element_inner_weight
496
572
 
@@ -36,7 +36,7 @@ class CollocatedFunctionSpace(FunctionSpace):
36
36
  self.element_outer_weight_gradient = self._basis.make_element_outer_weight_gradient()
37
37
 
38
38
  # For backward compatibility
39
- if hasattr(basis.topology, "node_grid"):
39
+ if hasattr(basis, "node_grid"):
40
40
  self.node_grid = basis.node_grid
41
41
  if hasattr(basis, "node_triangulation"):
42
42
  self.node_triangulation = basis.node_triangulation
@@ -44,6 +44,8 @@ class CollocatedFunctionSpace(FunctionSpace):
44
44
  self.node_tets = basis.node_tets
45
45
  if hasattr(basis, "node_hexes"):
46
46
  self.node_hexes = basis.node_hexes
47
+ if hasattr(basis, "vtk_cells"):
48
+ self.vtk_cells = basis.vtk_cells
47
49
 
48
50
  def space_arg_value(self, device):
49
51
  return self._basis.basis_arg_value(device)
@@ -1,6 +1,6 @@
1
1
  import warp as wp
2
2
  from warp.fem.geometry import Geometry
3
- from warp.fem.types import Coords, DofIndex, ElementIndex
3
+ from warp.fem.types import Coords, DofIndex, ElementIndex, ElementKind
4
4
 
5
5
  from .topology import SpaceTopology
6
6
 
@@ -47,6 +47,11 @@ class FunctionSpace:
47
47
  """Underlying geometry"""
48
48
  return self.topology.geometry
49
49
 
50
+ @property
51
+ def element_kind(self) -> ElementKind:
52
+ """Kind of element the function space is expressed over"""
53
+ return ElementKind.CELL if self.dimension == self.geometry.dimension else ElementKind.SIDE
54
+
50
55
  @property
51
56
  def dimension(self) -> int:
52
57
  """Function space embedding dimension"""
@@ -71,7 +76,7 @@ class FunctionSpace:
71
76
  def make_field(self, space_partition=None):
72
77
  """Creates a zero-initialized discrete field over the function space holding values for all degrees of freedom of nodes in a space partition
73
78
 
74
- space_arg:
79
+ Args:
75
80
  space_partition: If provided, the subset of nodes to consider
76
81
 
77
82
  See also: :func:`make_space_partition`
@@ -72,7 +72,7 @@ class GridBipolynomialSpaceTopology(Grid2DSpaceTopology):
72
72
 
73
73
  return element_node_index
74
74
 
75
- def _node_grid(self):
75
+ def node_grid(self):
76
76
  res = self.geometry.res
77
77
 
78
78
  cell_coords = np.array(self._shape.LOBATTO_COORDS)[:-1]
@@ -81,13 +81,13 @@ class GridBipolynomialSpaceTopology(Grid2DSpaceTopology):
81
81
  cell_coords, reps=res[0]
82
82
  )
83
83
  grid_coords_x = np.append(grid_coords_x, res[0])
84
- X = grid_coords_x * self._grid.cell_size[0] + self._grid.origin[0]
84
+ X = grid_coords_x * self.geometry.cell_size[0] + self.geometry.origin[0]
85
85
 
86
86
  grid_coords_y = np.repeat(np.arange(0, res[1], dtype=float), len(cell_coords)) + np.tile(
87
87
  cell_coords, reps=res[1]
88
88
  )
89
89
  grid_coords_y = np.append(grid_coords_y, res[1])
90
- Y = grid_coords_y * self._grid.cell_size[1] + self._grid.origin[1]
90
+ Y = grid_coords_y * self.geometry.cell_size[1] + self.geometry.origin[1]
91
91
 
92
92
  return np.meshgrid(X, Y, indexing="ij")
93
93
 
@@ -79,7 +79,7 @@ class GridTripolynomialSpaceTopology(Grid3DSpaceTopology):
79
79
 
80
80
  return element_node_index
81
81
 
82
- def _node_grid(self):
82
+ def node_grid(self):
83
83
  res = self.geometry.res
84
84
 
85
85
  cell_coords = np.array(self._shape.LOBATTO_COORDS)[:-1]
@@ -88,19 +88,19 @@ class GridTripolynomialSpaceTopology(Grid3DSpaceTopology):
88
88
  cell_coords, reps=res[0]
89
89
  )
90
90
  grid_coords_x = np.append(grid_coords_x, res[0])
91
- X = grid_coords_x * self._grid.cell_size[0] + self._grid.origin[0]
91
+ X = grid_coords_x * self.geometry.cell_size[0] + self.geometry.origin[0]
92
92
 
93
93
  grid_coords_y = np.repeat(np.arange(0, res[1], dtype=float), len(cell_coords)) + np.tile(
94
94
  cell_coords, reps=res[1]
95
95
  )
96
96
  grid_coords_y = np.append(grid_coords_y, res[1])
97
- Y = grid_coords_y * self._grid.cell_size[1] + self._grid.origin[1]
97
+ Y = grid_coords_y * self.geometry.cell_size[1] + self.geometry.origin[1]
98
98
 
99
99
  grid_coords_z = np.repeat(np.arange(0, res[2], dtype=float), len(cell_coords)) + np.tile(
100
100
  cell_coords, reps=res[2]
101
101
  )
102
102
  grid_coords_z = np.append(grid_coords_z, res[2])
103
- Z = grid_coords_z * self._grid.cell_size[2] + self._grid.origin[2]
103
+ Z = grid_coords_z * self.geometry.cell_size[2] + self.geometry.origin[2]
104
104
 
105
105
  return np.meshgrid(X, Y, Z, indexing="ij")
106
106
 
@@ -142,8 +142,9 @@ class HexmeshTripolynomialSpaceTopology(HexmeshSpaceTopology):
142
142
 
143
143
  fv = ori // 2
144
144
 
145
- rot_i = wp.dot(_FACE_ORIENTATION_I[2 * ori], coords) + _FACE_TRANSLATION_I[fv, 0]
146
- rot_j = wp.dot(_FACE_ORIENTATION_I[2 * ori + 1], coords) + _FACE_TRANSLATION_I[fv, 1]
145
+ # face indices from shape function always have positive orientation, drop `ori % 2`
146
+ rot_i = wp.dot(_FACE_ORIENTATION_I[4 * fv], coords) + _FACE_TRANSLATION_I[fv, 0]
147
+ rot_j = wp.dot(_FACE_ORIENTATION_I[4 * fv + 1], coords) + _FACE_TRANSLATION_I[fv, 1]
147
148
 
148
149
  return rot_i * size + rot_j
149
150
 
@@ -41,25 +41,23 @@ class NanogridSpaceTopology(SpaceTopology):
41
41
  self._grid = grid
42
42
  self._shape = shape
43
43
 
44
- if need_edge_indices:
45
- self._edge_count = self._grid.edge_count()
46
- else:
47
- self._edge_count = 0
44
+ self._vertex_grid = grid.vertex_grid.id
48
45
 
49
- self._vertex_grid = grid._node_grid
50
- self._face_grid = grid._face_grid
51
- self._edge_grid = grid._edge_grid
46
+ self._edge_grid = grid.edge_grid.id if need_edge_indices else -1
47
+ self._face_grid = grid.face_grid.id if need_face_indices else -1
48
+ self._edge_count = grid.edge_count() if need_edge_indices else 0
49
+ self._face_count = grid.side_count() if need_face_indices else 0
52
50
 
53
51
  @cache.cached_arg_value
54
52
  def topo_arg_value(self, device):
55
53
  arg = NanogridTopologyArg()
56
54
 
57
- arg.vertex_grid = self._vertex_grid.id
58
- arg.face_grid = self._face_grid.id
59
- arg.edge_grid = -1 if self._edge_grid is None else self._edge_grid.id
55
+ arg.vertex_grid = self._vertex_grid
56
+ arg.face_grid = self._face_grid
57
+ arg.edge_grid = self._edge_grid
60
58
 
61
59
  arg.vertex_count = self._grid.vertex_count()
62
- arg.face_count = self._grid.side_count()
60
+ arg.face_count = self._face_count
63
61
  arg.edge_count = self._edge_count
64
62
  return arg
65
63
 
@@ -98,8 +96,8 @@ class NanogridTripolynomialSpaceTopology(NanogridSpaceTopology):
98
96
 
99
97
  return (
100
98
  self._grid.vertex_count()
101
- + self._grid.edge_count() * INTERIOR_NODES_PER_EDGE
102
- + self._grid.side_count() * INTERIOR_NODES_PER_FACE
99
+ + self._edge_count * INTERIOR_NODES_PER_EDGE
100
+ + self._face_count * INTERIOR_NODES_PER_FACE
103
101
  + self._grid.cell_count() * INTERIOR_NODES_PER_CELL
104
102
  )
105
103
 
@@ -160,7 +158,7 @@ class NanogridSerendipitySpaceTopology(NanogridSpaceTopology):
160
158
  self.element_node_index = self._make_element_node_index()
161
159
 
162
160
  def node_count(self) -> int:
163
- return self.geometry.vertex_count() + (self._shape.ORDER - 1) * self.geometry.edge_count()
161
+ return self.geometry.vertex_count() + (self._shape.ORDER - 1) * self._edge_count
164
162
 
165
163
  def _make_element_node_index(self):
166
164
  ORDER = self._shape.ORDER
@@ -1,12 +1,7 @@
1
1
  from typing import Any, Optional
2
2
 
3
3
  import warp as wp
4
- from warp.fem.cache import (
5
- TemporaryStore,
6
- borrow_temporary,
7
- borrow_temporary_like,
8
- cached_arg_value,
9
- )
4
+ import warp.fem.cache as cache
10
5
  from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
11
6
  from warp.fem.types import NULL_NODE_INDEX
12
7
  from warp.fem.utils import _iota_kernel, compress_node_indices
@@ -42,7 +37,7 @@ class SpacePartition:
42
37
 
43
38
  @staticmethod
44
39
  def partition_node_index(args: "PartitionArg", space_node_index: int):
45
- """Returns the index in the partition of a function space node, or -1 if it does not exist"""
40
+ """Returns the index in the partition of a function space node, or ``NULL_NODE_INDEX`` if it does not exist"""
46
41
 
47
42
  def __str__(self) -> str:
48
43
  return self.name
@@ -76,7 +71,7 @@ class WholeSpacePartition(SpacePartition):
76
71
  def space_node_indices(self):
77
72
  """Return the global function space indices for nodes in this partition"""
78
73
  if self._node_indices is None:
79
- self._node_indices = borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
74
+ self._node_indices = cache.borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
80
75
  wp.launch(kernel=_iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array, 1])
81
76
  return self._node_indices.array
82
77
 
@@ -121,7 +116,7 @@ class NodePartition(SpacePartition):
121
116
  geo_partition: GeometryPartition,
122
117
  with_halo: bool = True,
123
118
  device=None,
124
- temporary_store: TemporaryStore = None,
119
+ temporary_store: cache.TemporaryStore = None,
125
120
  ):
126
121
  super().__init__(space_topology=space_topology, geo_partition=geo_partition)
127
122
 
@@ -143,7 +138,7 @@ class NodePartition(SpacePartition):
143
138
  """Return the global function space indices for nodes in this partition"""
144
139
  return self._node_indices.array
145
140
 
146
- @cached_arg_value
141
+ @cache.cached_arg_value
147
142
  def partition_arg_value(self, device):
148
143
  arg = NodePartition.PartitionArg()
149
144
  arg.space_to_partition = self._space_to_partition.array.to(device)
@@ -153,12 +148,10 @@ class NodePartition(SpacePartition):
153
148
  def partition_node_index(args: PartitionArg, space_node_index: int):
154
149
  return args.space_to_partition[space_node_index]
155
150
 
156
- def _compute_node_indices_from_sides(self, device, with_halo: bool, temporary_store: TemporaryStore):
151
+ def _compute_node_indices_from_sides(self, device, with_halo: bool, temporary_store: cache.TemporaryStore):
157
152
  from warp.fem import cache
158
153
 
159
154
  trace_topology = self.space_topology.trace()
160
- NODES_PER_CELL = self.space_topology.NODES_PER_ELEMENT
161
- NODES_PER_SIDE = trace_topology.NODES_PER_ELEMENT
162
155
 
163
156
  @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
164
157
  def node_category_from_cells_kernel(
@@ -171,7 +164,8 @@ class NodePartition(SpacePartition):
171
164
 
172
165
  cell_index = self.geo_partition.cell_index(geo_partition_arg, partition_cell_index)
173
166
 
174
- for n in range(NODES_PER_CELL):
167
+ cell_node_count = self.space_topology.element_node_count(geo_arg, space_arg, cell_index)
168
+ for n in range(cell_node_count):
175
169
  space_nidx = self.space_topology.element_node_index(geo_arg, space_arg, cell_index, n)
176
170
  node_mask[space_nidx] = NodeCategory.OWNED_INTERIOR
177
171
 
@@ -186,7 +180,8 @@ class NodePartition(SpacePartition):
186
180
 
187
181
  side_index = self.geo_partition.side_index(geo_partition_arg, partition_side_index)
188
182
 
189
- for n in range(NODES_PER_SIDE):
183
+ side_node_count = trace_topology.element_node_count(geo_arg, space_arg, side_index)
184
+ for n in range(side_node_count):
190
185
  space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
191
186
 
192
187
  if node_mask[space_nidx] == NodeCategory.EXTERIOR:
@@ -203,14 +198,15 @@ class NodePartition(SpacePartition):
203
198
 
204
199
  side_index = self.geo_partition.frontier_side_index(geo_partition_arg, frontier_side_index)
205
200
 
206
- for n in range(NODES_PER_SIDE):
201
+ side_node_count = trace_topology.element_node_count(geo_arg, space_arg, side_index)
202
+ for n in range(side_node_count):
207
203
  space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
208
204
  if node_mask[space_nidx] == NodeCategory.EXTERIOR:
209
205
  node_mask[space_nidx] = NodeCategory.HALO_OTHER_SIDE
210
206
  elif node_mask[space_nidx] == NodeCategory.OWNED_INTERIOR:
211
207
  node_mask[space_nidx] = NodeCategory.OWNED_FRONTIER
212
208
 
213
- node_category = borrow_temporary(
209
+ node_category = cache.borrow_temporary(
214
210
  temporary_store,
215
211
  shape=(self.space_topology.node_count(),),
216
212
  dtype=int,
@@ -259,50 +255,52 @@ class NodePartition(SpacePartition):
259
255
 
260
256
  node_category.release()
261
257
 
262
- def _finalize_node_indices(self, node_category: wp.array(dtype=int), temporary_store: TemporaryStore):
263
- category_offsets, node_indices, _, __ = compress_node_indices(NodeCategory.COUNT, node_category)
258
+ def _finalize_node_indices(self, node_category: wp.array(dtype=int), temporary_store: cache.TemporaryStore):
259
+ category_offsets, node_indices = compress_node_indices(
260
+ NodeCategory.COUNT, node_category, temporary_store=temporary_store
261
+ )
264
262
 
265
263
  # Copy offsets to cpu
266
264
  device = node_category.device
267
- self._category_offsets = borrow_temporary(
268
- temporary_store,
269
- shape=category_offsets.array.shape,
270
- dtype=category_offsets.array.dtype,
271
- pinned=device.is_cuda,
272
- device="cpu",
273
- )
274
- wp.copy(src=category_offsets.array, dest=self._category_offsets.array)
275
-
276
- if device.is_cuda:
277
- # TODO switch to synchronize_event once available
278
- wp.synchronize_stream(wp.get_stream(device))
279
-
280
- category_offsets.release()
265
+ with wp.ScopedDevice(device):
266
+ self._category_offsets = cache.borrow_temporary(
267
+ temporary_store,
268
+ shape=category_offsets.array.shape,
269
+ dtype=category_offsets.array.dtype,
270
+ pinned=device.is_cuda,
271
+ device="cpu",
272
+ )
273
+ wp.copy(src=category_offsets.array, dest=self._category_offsets.array)
274
+ copy_event = cache.capture_event()
281
275
 
282
- # Compute global to local indices
283
- self._space_to_partition = borrow_temporary_like(node_indices, temporary_store)
284
- wp.launch(
285
- kernel=NodePartition._scatter_partition_indices,
286
- dim=self.space_topology.node_count(),
287
- device=device,
288
- inputs=[self.node_count(), node_indices.array, self._space_to_partition.array],
289
- )
276
+ # Compute global to local indices
277
+ self._space_to_partition = cache.borrow_temporary_like(node_indices, temporary_store)
278
+ wp.launch(
279
+ kernel=NodePartition._scatter_partition_indices,
280
+ dim=self.space_topology.node_count(),
281
+ device=device,
282
+ inputs=[category_offsets.array, node_indices.array, self._space_to_partition.array],
283
+ )
290
284
 
291
- # Copy to shrinked-to-fit array
292
- self._node_indices = borrow_temporary(temporary_store, shape=(self.node_count()), dtype=int, device=device)
293
- wp.copy(dest=self._node_indices.array, src=node_indices.array, count=self.node_count())
285
+ # Copy to shrinked-to-fit array
286
+ cache.synchronize_event(copy_event) # Transfer to host must be finished to access node_count()
287
+ self._node_indices = cache.borrow_temporary(
288
+ temporary_store, shape=(self.node_count()), dtype=int, device=device
289
+ )
290
+ wp.copy(dest=self._node_indices.array, src=node_indices.array, count=self.node_count())
294
291
 
295
- node_indices.release()
292
+ node_indices.release()
296
293
 
297
294
  @wp.kernel
298
295
  def _scatter_partition_indices(
299
- local_node_count: int,
296
+ category_offsets: wp.array(dtype=int),
300
297
  node_indices: wp.array(dtype=int),
301
298
  space_to_partition_indices: wp.array(dtype=int),
302
299
  ):
303
300
  local_idx = wp.tid()
304
301
  space_idx = node_indices[local_idx]
305
302
 
303
+ local_node_count = category_offsets[NodeCategory.EXTERIOR] # all but exterior nodes
306
304
  if local_idx < local_node_count:
307
305
  space_to_partition_indices[space_idx] = local_idx
308
306
  else:
@@ -315,7 +313,7 @@ def make_space_partition(
315
313
  space_topology: Optional[SpaceTopology] = None,
316
314
  with_halo: bool = True,
317
315
  device=None,
318
- temporary_store: TemporaryStore = None,
316
+ temporary_store: cache.TemporaryStore = None,
319
317
  ) -> SpacePartition:
320
318
  """Computes the subset of nodes from a function space topology that touch a geometry partition
321
319