warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (180):
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
@@ -13,13 +13,14 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Any, Optional
16
+ from typing import Any, ClassVar, Optional
17
17
 
18
18
  import warp as wp
19
19
  from warp.fem import cache
20
20
  from warp.fem.geometry import Geometry
21
- from warp.fem.linalg import basis_element, generalized_inner, generalized_outer
21
+ from warp.fem.linalg import generalized_inner, generalized_outer
22
22
  from warp.fem.types import NULL_QP_INDEX, Coords, ElementIndex, make_free_sample
23
+ from warp.fem.utils import type_basis_element
23
24
 
24
25
  from .basis_space import BasisSpace
25
26
  from .dof_mapper import DofMapper, IdentityMapper
@@ -30,6 +31,21 @@ from .partition import SpacePartition, make_space_partition
30
31
  class CollocatedFunctionSpace(FunctionSpace):
31
32
  """Function space where values are collocated at nodes"""
32
33
 
34
+ _dynamic_attribute_constructors: ClassVar = {
35
+ "node_basis_element": lambda obj: obj._make_node_basis_element(),
36
+ "value_basis_element": lambda obj: obj._make_value_basis_element(),
37
+ "node_coords_in_element": lambda obj: obj._basis.make_node_coords_in_element(),
38
+ "node_quadrature_weight": lambda obj: obj._basis.make_node_quadrature_weight(),
39
+ "element_inner_weight": lambda obj: obj._basis.make_element_inner_weight(),
40
+ "element_inner_weight_gradient": lambda obj: obj._basis.make_element_inner_weight_gradient(),
41
+ "element_outer_weight": lambda obj: obj._basis.make_element_outer_weight(),
42
+ "element_outer_weight_gradient": lambda obj: obj._basis.make_element_outer_weight_gradient(),
43
+ "space_value": lambda obj: obj._make_space_value(),
44
+ "space_gradient": lambda obj: obj._make_space_gradient(),
45
+ "space_divergence": lambda obj: obj._make_space_divergence(),
46
+ "node_dof_value": lambda obj: obj._make_node_dof_value(),
47
+ }
48
+
33
49
  @wp.struct
34
50
  class LocalValueMap:
35
51
  pass
@@ -47,24 +63,11 @@ class CollocatedFunctionSpace(FunctionSpace):
47
63
 
48
64
  self.SpaceArg = self._basis.BasisArg
49
65
  self.space_arg_value = self._basis.basis_arg_value
66
+ self.fill_space_arg = self._basis.fill_basis_arg
50
67
 
51
68
  self.ORDER = self._basis.ORDER
52
69
 
53
- self.node_basis_element = self._make_node_basis_element()
54
- self.value_basis_element = self._make_value_basis_element()
55
-
56
- self.node_coords_in_element = self._basis.make_node_coords_in_element()
57
- self.node_quadrature_weight = self._basis.make_node_quadrature_weight()
58
- self.element_inner_weight = self._basis.make_element_inner_weight()
59
- self.element_inner_weight_gradient = self._basis.make_element_inner_weight_gradient()
60
- self.element_outer_weight = self._basis.make_element_outer_weight()
61
- self.element_outer_weight_gradient = self._basis.make_element_outer_weight_gradient()
62
-
63
- self.space_value = self._make_space_value()
64
- self.space_gradient = self._make_space_gradient()
65
- self.space_divergence = self._make_space_divergence()
66
-
67
- self.node_dof_value = self._make_node_dof_value()
70
+ cache.setup_dynamic_attributes(self)
68
71
 
69
72
  # For backward compatibility
70
73
  if hasattr(basis, "node_grid"):
@@ -100,11 +103,8 @@ class CollocatedFunctionSpace(FunctionSpace):
100
103
  return CollocatedFunctionSpaceTrace(self)
101
104
 
102
105
  def _make_node_basis_element(self):
103
- @cache.dynamic_func(suffix=self.name)
104
- def node_basis_element(dof_coord: int):
105
- return basis_element(self.dof_dtype(0.0), dof_coord)
106
-
107
- return node_basis_element
106
+ basis_element = type_basis_element(self.dof_dtype)
107
+ return basis_element
108
108
 
109
109
  def _make_value_basis_element(self):
110
110
  @cache.dynamic_func(suffix=self.name)
@@ -196,19 +196,34 @@ class CollocatedFunctionSpaceTrace(CollocatedFunctionSpace):
196
196
  class VectorValuedFunctionSpace(FunctionSpace):
197
197
  """Function space whose values are vectors"""
198
198
 
199
+ _dynamic_attribute_constructors: ClassVar = {
200
+ "value_basis_element": lambda obj: obj._make_value_basis_element(),
201
+ "node_coords_in_element": lambda obj: obj._basis.make_node_coords_in_element(),
202
+ "node_quadrature_weight": lambda obj: obj._basis.make_node_quadrature_weight(),
203
+ "element_inner_weight": lambda obj: obj._basis.make_element_inner_weight(),
204
+ "element_inner_weight_gradient": lambda obj: obj._basis.make_element_inner_weight_gradient(),
205
+ "element_outer_weight": lambda obj: obj._basis.make_element_outer_weight(),
206
+ "element_outer_weight_gradient": lambda obj: obj._basis.make_element_outer_weight_gradient(),
207
+ "space_value": lambda obj: obj._make_space_value(),
208
+ "space_gradient": lambda obj: obj._make_space_gradient(),
209
+ "space_divergence": lambda obj: obj._make_space_divergence(),
210
+ "node_dof_value": lambda obj: obj._make_node_dof_value(),
211
+ }
212
+
199
213
  def __init__(self, basis: BasisSpace):
200
214
  self._basis = basis
201
215
 
202
216
  super().__init__(topology=basis.topology)
203
217
 
204
218
  self.dtype = cache.cached_vec_type(self.geometry.dimension, dtype=float)
205
- self.dof_dtype = float
219
+ self.dof_dtype = wp.float32
206
220
 
207
221
  self.VALUE_DOF_COUNT = self.geometry.dimension
208
222
  self.NODE_DOF_COUNT = 1
209
223
 
210
224
  self.SpaceArg = self._basis.BasisArg
211
225
  self.space_arg_value = self._basis.basis_arg_value
226
+ self.fill_space_arg = self._basis.fill_basis_arg
212
227
 
213
228
  self.ORDER = self._basis.ORDER
214
229
 
@@ -216,20 +231,7 @@ class VectorValuedFunctionSpace(FunctionSpace):
216
231
  shape=(self.geometry.dimension, self.geometry.cell_dimension), dtype=float
217
232
  )
218
233
 
219
- self.value_basis_element = self._make_value_basis_element()
220
-
221
- self.node_coords_in_element = self._basis.make_node_coords_in_element()
222
- self.node_quadrature_weight = self._basis.make_node_quadrature_weight()
223
- self.element_inner_weight = self._basis.make_element_inner_weight()
224
- self.element_inner_weight_gradient = self._basis.make_element_inner_weight_gradient()
225
- self.element_outer_weight = self._basis.make_element_outer_weight()
226
- self.element_outer_weight_gradient = self._basis.make_element_outer_weight_gradient()
227
-
228
- self.space_value = self._make_space_value()
229
- self.space_gradient = self._make_space_gradient()
230
- self.space_divergence = self._make_space_divergence()
231
-
232
- self.node_dof_value = self._make_node_dof_value()
234
+ cache.setup_dynamic_attributes(self, cls=__class__)
233
235
 
234
236
  @property
235
237
  def name(self):
@@ -254,9 +256,11 @@ class VectorValuedFunctionSpace(FunctionSpace):
254
256
  return 1.0
255
257
 
256
258
  def _make_value_basis_element(self):
259
+ basis_element = type_basis_element(self.dtype)
260
+
257
261
  @cache.dynamic_func(suffix=self.name)
258
262
  def value_basis_element(dof_coord: int, value_map: Any):
259
- return value_map * basis_element(self.dtype(0.0), dof_coord)
263
+ return value_map * basis_element(dof_coord)
260
264
 
261
265
  return value_basis_element
262
266
 
@@ -319,11 +323,15 @@ class VectorValuedFunctionSpace(FunctionSpace):
319
323
  class CovariantFunctionSpace(VectorValuedFunctionSpace):
320
324
  """Function space whose values are covariant vectors"""
321
325
 
326
+ _dynamic_attribute_constructors: ClassVar = {
327
+ "local_value_map_inner": lambda obj: obj._make_local_value_map(),
328
+ "local_value_map_outer": lambda obj: obj.local_value_map_inner,
329
+ }
330
+
322
331
  def __init__(self, basis: BasisSpace):
323
332
  super().__init__(basis)
324
333
 
325
- self.local_value_map_inner = self._make_local_value_map()
326
- self.local_value_map_outer = self.local_value_map_inner
334
+ cache.setup_dynamic_attributes(self, cls=__class__)
327
335
 
328
336
  def trace(self) -> "CovariantFunctionSpaceTrace":
329
337
  return CovariantFunctionSpaceTrace(self)
@@ -348,12 +356,16 @@ class CovariantFunctionSpace(VectorValuedFunctionSpace):
348
356
  class CovariantFunctionSpaceTrace(VectorValuedFunctionSpace):
349
357
  """Trace of a :class:`CovariantFunctionSpace`"""
350
358
 
359
+ _dynamic_attribute_constructors: ClassVar = {
360
+ "local_value_map_inner": lambda obj: obj._make_local_value_map_inner(),
361
+ "local_value_map_outer": lambda obj: obj._make_local_value_map_outer(),
362
+ }
363
+
351
364
  def __init__(self, space: VectorValuedFunctionSpace):
352
365
  self._space = space
353
366
  super().__init__(space._basis.trace())
354
367
 
355
- self.local_value_map_inner = self._make_local_value_map_inner()
356
- self.local_value_map_outer = self._make_local_value_map_outer()
368
+ cache.setup_dynamic_attributes(self, cls=__class__)
357
369
 
358
370
  @property
359
371
  def name(self):
@@ -396,11 +408,15 @@ class CovariantFunctionSpaceTrace(VectorValuedFunctionSpace):
396
408
  class ContravariantFunctionSpace(VectorValuedFunctionSpace):
397
409
  """Function space whose values are contravariant vectors"""
398
410
 
411
+ _dynamic_attribute_constructors: ClassVar = {
412
+ "local_value_map_inner": lambda obj: obj._make_local_value_map(),
413
+ "local_value_map_outer": lambda obj: obj.local_value_map_inner,
414
+ }
415
+
399
416
  def __init__(self, basis: BasisSpace):
400
417
  super().__init__(basis)
401
418
 
402
- self.local_value_map_inner = self._make_local_value_map()
403
- self.local_value_map_outer = self.local_value_map_inner
419
+ cache.setup_dynamic_attributes(self, cls=__class__)
404
420
 
405
421
  def trace(self) -> "ContravariantFunctionSpaceTrace":
406
422
  return ContravariantFunctionSpaceTrace(self)
@@ -421,12 +437,16 @@ class ContravariantFunctionSpace(VectorValuedFunctionSpace):
421
437
  class ContravariantFunctionSpaceTrace(VectorValuedFunctionSpace):
422
438
  """Trace of a :class:`ContravariantFunctionSpace`"""
423
439
 
440
+ _dynamic_attribute_constructors: ClassVar = {
441
+ "local_value_map_inner": lambda obj: obj._make_local_value_map_inner(),
442
+ "local_value_map_outer": lambda obj: obj._make_local_value_map_outer(),
443
+ }
444
+
424
445
  def __init__(self, space: ContravariantFunctionSpace):
425
446
  self._space = space
426
447
  super().__init__(space._basis.trace())
427
448
 
428
- self.local_value_map_inner = self._make_local_value_map_inner()
429
- self.local_value_map_outer = self._make_local_value_map_outer()
449
+ cache.setup_dynamic_attributes(self, cls=__class__)
430
450
 
431
451
  @property
432
452
  def name(self):
@@ -13,7 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Optional
16
+ from typing import ClassVar, Optional
17
17
 
18
18
  import numpy as np
19
19
 
@@ -71,6 +71,9 @@ class BasisSpace:
71
71
  """Value for the argument structure to be passed to device functions"""
72
72
  return BasisSpace.BasisArg()
73
73
 
74
+ def fill_basis_arg(self, arg, device):
75
+ pass
76
+
74
77
  # Helpers for generating node positions
75
78
 
76
79
  def node_positions(self, out: Optional[wp.array] = None) -> wp.array:
@@ -174,6 +177,7 @@ class ShapeBasisSpace(BasisSpace):
174
177
  if self.value is not ShapeFunction.Value.Scalar:
175
178
  self.BasisArg = self.topology.TopologyArg
176
179
  self.basis_arg_value = self.topology.topo_arg_value
180
+ self.fill_basis_arg = self.topology.fill_topo_arg
177
181
 
178
182
  self.ORDER = self._shape.ORDER
179
183
 
@@ -335,6 +339,7 @@ class TraceBasisSpace(BasisSpace):
335
339
  self._basis = basis
336
340
  self.BasisArg = self._basis.BasisArg
337
341
  self.basis_arg_value = self._basis.basis_arg_value
342
+ self.fill_basis_arg = self._basis.fill_basis_arg
338
343
 
339
344
  @property
340
345
  def name(self):
@@ -507,6 +512,12 @@ def make_discontinuous_basis_space(geometry: Geometry, shape: ShapeFunction):
507
512
  class UnstructuredPointTopology(SpaceTopology):
508
513
  """Topology for unstructured points defined from quadrature formula. See :class:`PointBasisSpace`"""
509
514
 
515
+ _dynamic_attribute_constructors: ClassVar = {
516
+ "element_node_index": lambda obj: obj._make_element_node_index(),
517
+ "element_node_count": lambda obj: obj._make_element_node_count(),
518
+ "side_neighbor_node_counts": lambda obj: obj._make_side_neighbor_node_counts(),
519
+ }
520
+
510
521
  def __init__(self, quadrature: Quadrature):
511
522
  if quadrature.max_points_per_element() is None:
512
523
  raise ValueError("Quadrature must define a maximum number of points per element")
@@ -516,12 +527,12 @@ class UnstructuredPointTopology(SpaceTopology):
516
527
 
517
528
  self._quadrature = quadrature
518
529
  self.TopologyArg = quadrature.Arg
530
+ self.topo_arg_value = quadrature.arg_value
531
+ self.fill_topo_arg = quadrature.fill_arg
519
532
 
520
533
  super().__init__(quadrature.domain.geometry, max_nodes_per_element=quadrature.max_points_per_element())
521
534
 
522
- self.element_node_index = self._make_element_node_index()
523
- self.element_node_count = self._make_element_node_count()
524
- self.side_neighbor_node_counts = self._make_side_neighbor_node_counts()
535
+ cache.setup_dynamic_attributes(self, cls=__class__)
525
536
 
526
537
  def node_count(self):
527
538
  return self._quadrature.total_point_count()
@@ -582,6 +593,8 @@ class PointBasisSpace(BasisSpace):
582
593
 
583
594
  self.BasisArg = quadrature.Arg
584
595
  self.basis_arg_value = quadrature.arg_value
596
+ self.fill_basis_arg = quadrature.fill_arg
597
+
585
598
  self.ORDER = 0
586
599
 
587
600
  self.make_element_outer_weight = self.make_element_inner_weight
@@ -57,7 +57,7 @@ class IdentityMapper(DofMapper):
57
57
  self.value_dtype = dtype
58
58
  self.dof_dtype = dtype
59
59
 
60
- size = warp.types.type_length(dtype)
60
+ size = warp.types.type_size(dtype)
61
61
  self.DOF_SIZE = wp.constant(size)
62
62
 
63
63
  @wp.func
@@ -122,13 +122,13 @@ class FunctionSpace:
122
122
  raise NotImplementedError
123
123
 
124
124
  def gradient_valid(self) -> bool:
125
- """Whether gradient operator can be computed. Only for scalar and vector fields as higher-order tensors are not support yet"""
125
+ """Whether gradient operator can be computed. Only for scalar and vector fields as higher-order tensors are not supported yet"""
126
126
  return not wp.types.type_is_matrix(self.dtype)
127
127
 
128
128
  def divergence_valid(self) -> bool:
129
129
  """Whether divergence of this field can be computed. Only for vector and tensor fields with same dimension as embedding geometry"""
130
130
  if wp.types.type_is_vector(self.dtype):
131
- return wp.types.type_length(self.dtype) == self.geometry.dimension
131
+ return wp.types.type_size(self.dtype) == self.geometry.dimension
132
132
  if wp.types.type_is_matrix(self.dtype):
133
133
  return self.dtype._shape_[0] == self.geometry.dimension
134
134
  return False
@@ -41,6 +41,9 @@ class Grid2DSpaceTopology(SpaceTopology):
41
41
  def topo_arg_value(self, device):
42
42
  return self.geometry.side_arg_value(device)
43
43
 
44
+ def fill_topo_arg(self, arg: Grid2D.SideArg, device):
45
+ self.geometry.fill_side_arg(arg, device)
46
+
44
47
  def node_count(self) -> int:
45
48
  return (
46
49
  self.geometry.vertex_count() * self._shape.VERTEX_NODE_COUNT
@@ -84,7 +87,7 @@ class Grid2DSpaceTopology(SpaceTopology):
84
87
  axis = 1 - (node_type - SquareShapeFunction.EDGE_X)
85
88
 
86
89
  cell = Grid2D.get_cell(cell_arg.res, element_index)
87
- origin = wp.vec2i(cell[Grid2D.ROTATION[axis, 0]] + type_instance, cell[Grid2D.ROTATION[axis, 1]])
90
+ origin = Grid2D.orient(axis, cell) + wp.vec2i(type_instance, 0)
88
91
 
89
92
  side = Grid2D.Side(axis, origin)
90
93
  side_index = Grid2D.side_index(topo_arg, side)
@@ -85,13 +85,15 @@ class HexmeshSpaceTopology(SpaceTopology):
85
85
  @cache.cached_arg_value
86
86
  def topo_arg_value(self, device):
87
87
  arg = HexmeshTopologyArg()
88
+ self.fill_topo_arg(arg, device)
89
+ return arg
90
+
91
+ def fill_topo_arg(self, arg: HexmeshTopologyArg, device):
88
92
  arg.hex_edge_indices = self._hex_edge_indices.to(device)
89
93
  arg.hex_face_indices = self._hex_face_indices.to(device)
90
-
91
94
  arg.vertex_count = self._mesh.vertex_count()
92
95
  arg.face_count = self._mesh.side_count()
93
96
  arg.edge_count = self._edge_count
94
- return arg
95
97
 
96
98
  def _compute_hex_face_indices(self):
97
99
  self._hex_face_indices = wp.empty(
@@ -72,7 +72,10 @@ class NanogridSpaceTopology(SpaceTopology):
72
72
  @cache.cached_arg_value
73
73
  def topo_arg_value(self, device):
74
74
  arg = NanogridTopologyArg()
75
+ self.fill_topo_arg(arg, device)
76
+ return arg
75
77
 
78
+ def fill_topo_arg(self, arg, device):
76
79
  arg.vertex_grid = self._vertex_grid
77
80
  arg.face_grid = self._face_grid
78
81
  arg.edge_grid = self._edge_grid
@@ -80,7 +83,6 @@ class NanogridSpaceTopology(SpaceTopology):
80
83
  arg.vertex_count = self._grid.vertex_count()
81
84
  arg.face_count = self._face_count
82
85
  arg.edge_count = self._edge_count
83
- return arg
84
86
 
85
87
  def _make_element_node_index(self):
86
88
  element_node_index_generic = self._make_element_node_index_generic()
@@ -50,6 +50,9 @@ class SpacePartition:
50
50
  def partition_arg_value(self, device):
51
51
  pass
52
52
 
53
+ def fill_partition_arg(self, arg, device):
54
+ pass
55
+
53
56
  @staticmethod
54
57
  def partition_node_index(args: "PartitionArg", space_node_index: int):
55
58
  """Returns the index in the partition of a function space node, or ``NULL_NODE_INDEX`` if it does not exist"""
@@ -93,12 +96,15 @@ class WholeSpacePartition(SpacePartition):
93
96
  def partition_arg_value(self, device):
94
97
  return WholeSpacePartition.PartitionArg()
95
98
 
99
+ def fill_partition_arg(self, arg, device):
100
+ pass
101
+
96
102
  @wp.func
97
103
  def partition_node_index(args: Any, space_node_index: int):
98
104
  return space_node_index
99
105
 
100
106
  def __eq__(self, other: SpacePartition) -> bool:
101
- return isinstance(other, SpacePartition) and self.space_topology == other.space_topology
107
+ return isinstance(other, WholeSpacePartition) and self.space_topology == other.space_topology
102
108
 
103
109
  @property
104
110
  def name(self) -> str:
@@ -160,9 +166,12 @@ class NodePartition(SpacePartition):
160
166
  @cache.cached_arg_value
161
167
  def partition_arg_value(self, device):
162
168
  arg = NodePartition.PartitionArg()
163
- arg.space_to_partition = self._space_to_partition.array.to(device)
169
+ self.fill_partition_arg(arg, device)
164
170
  return arg
165
171
 
172
+ def fill_partition_arg(self, arg, device):
173
+ arg.space_to_partition = self._space_to_partition.array.to(device)
174
+
166
175
  @wp.func
167
176
  def partition_node_index(args: PartitionArg, space_node_index: int):
168
177
  return args.space_to_partition[space_node_index]
@@ -55,13 +55,16 @@ class QuadmeshSpaceTopology(SpaceTopology):
55
55
  @cache.cached_arg_value
56
56
  def topo_arg_value(self, device):
57
57
  arg = Quadmesh2DTopologyArg()
58
+ self.fill_topo_arg(arg, device)
59
+ return arg
60
+
61
+ def fill_topo_arg(self, arg: Quadmesh2DTopologyArg, device):
58
62
  arg.quad_edge_indices = self._quad_edge_indices.to(device)
59
63
  arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)
60
64
 
61
65
  arg.vertex_count = self._mesh.vertex_count()
62
66
  arg.edge_count = self._mesh.side_count()
63
67
  arg.cell_count = self._mesh.cell_count()
64
- return arg
65
68
 
66
69
  def _compute_quad_edge_indices(self):
67
70
  self._quad_edge_indices = wp.empty(
@@ -144,13 +144,16 @@ class SpaceRestriction:
144
144
  dof_indices_in_element: wp.array(dtype=int)
145
145
 
146
146
  @cached_arg_value
147
- def node_arg(self, device):
147
+ def node_arg_value(self, device):
148
148
  arg = SpaceRestriction.NodeArg()
149
+ self.fill_node_arg(arg, device)
150
+ return arg
151
+
152
+ def fill_node_arg(self, arg: NodeArg, device):
149
153
  arg.dof_element_offsets = self._dof_partition_element_offsets.array.to(device)
150
154
  arg.dof_element_indices = self._dof_element_indices.array.to(device)
151
155
  arg.dof_partition_indices = self._dof_partition_indices.array.to(device)
152
156
  arg.dof_indices_in_element = self._dof_indices_in_element.array.to(device)
153
- return arg
154
157
 
155
158
  @wp.func
156
159
  def node_partition_index(args: NodeArg, restriction_node_index: int):
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import functools
16
17
  from enum import Enum
17
18
  from typing import Optional
18
19
 
@@ -67,8 +68,9 @@ class ElementBasis(Enum):
67
68
  """Raviart-Thomas H(div) shape functions. Should be used with contravariant function space."""
68
69
 
69
70
 
71
+ @functools.lru_cache(maxsize=None)
70
72
  def get_shape_function(
71
- element: _element.Element,
73
+ element_class: type,
72
74
  space_dimension: int,
73
75
  degree: int,
74
76
  element_basis: ElementBasis,
@@ -78,7 +80,7 @@ def get_shape_function(
78
80
  Equips a reference element with a shape function basis.
79
81
 
80
82
  Args:
81
- element: the reference element on which to build the shape function
83
+ element_class: the type of reference element on which to build the shape function
82
84
  space_dimension: the dimension of the embedding space
83
85
  degree: polynomial degree of the per-element shape functions
84
86
  element_basis: type of basis function for the individual elements
@@ -89,12 +91,12 @@ def get_shape_function(
89
91
  """
90
92
 
91
93
  if degree == 0:
92
- return ConstantShapeFunction(element, space_dimension)
94
+ return ConstantShapeFunction(element_class(), space_dimension)
93
95
 
94
96
  if family is None:
95
97
  family = Polynomial.LOBATTO_GAUSS_LEGENDRE
96
98
 
97
- if isinstance(element, _element.Square):
99
+ if issubclass(element_class, _element.Square):
98
100
  if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
99
101
  return SquareNedelecFirstKindShapeFunctions(degree=degree)
100
102
  if element_basis == ElementBasis.RAVIART_THOMAS:
@@ -105,7 +107,7 @@ def get_shape_function(
105
107
  return SquareSerendipityShapeFunctions(degree=degree, family=family)
106
108
 
107
109
  return SquareBipolynomialShapeFunctions(degree=degree, family=family)
108
- if isinstance(element, _element.Triangle):
110
+ if issubclass(element_class, _element.Triangle):
109
111
  if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
110
112
  return TriangleNedelecFirstKindShapeFunctions(degree=degree)
111
113
  if element_basis == ElementBasis.RAVIART_THOMAS:
@@ -117,7 +119,7 @@ def get_shape_function(
117
119
 
118
120
  return TrianglePolynomialShapeFunctions(degree=degree)
119
121
 
120
- if isinstance(element, _element.Cube):
122
+ if issubclass(element_class, _element.Cube):
121
123
  if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
122
124
  return CubeNedelecFirstKindShapeFunctions(degree=degree)
123
125
  if element_basis == ElementBasis.RAVIART_THOMAS:
@@ -128,7 +130,7 @@ def get_shape_function(
128
130
  return CubeSerendipityShapeFunctions(degree=degree, family=family)
129
131
 
130
132
  return CubeTripolynomialShapeFunctions(degree=degree, family=family)
131
- if isinstance(element, _element.Tetrahedron):
133
+ if issubclass(element_class, _element.Tetrahedron):
132
134
  if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
133
135
  return TetrahedronNedelecFirstKindShapeFunctions(degree=degree)
134
136
  if element_basis == ElementBasis.RAVIART_THOMAS:
@@ -140,4 +142,4 @@ def get_shape_function(
140
142
 
141
143
  return TetrahedronPolynomialShapeFunctions(degree=degree)
142
144
 
143
- return NotImplementedError("Unrecognized element type")
145
+ raise NotImplementedError(f"Unrecognized element type {element_class.__name__}")
@@ -75,6 +75,10 @@ class TetmeshSpaceTopology(SpaceTopology):
75
75
  @cache.cached_arg_value
76
76
  def topo_arg_value(self, device):
77
77
  arg = TetmeshTopologyArg()
78
+ self.fill_topo_arg(arg, device)
79
+ return arg
80
+
81
+ def fill_topo_arg(self, arg: TetmeshTopologyArg, device):
78
82
  arg.tet_face_indices = self._tet_face_indices.to(device)
79
83
  arg.tet_edge_indices = self._tet_edge_indices.to(device)
80
84
  arg.face_vertex_indices = self._mesh.face_vertex_indices.to(device)
@@ -83,7 +87,6 @@ class TetmeshSpaceTopology(SpaceTopology):
83
87
  arg.vertex_count = self._mesh.vertex_count()
84
88
  arg.face_count = self._mesh.side_count()
85
89
  arg.edge_count = self._edge_count
86
- return arg
87
90
 
88
91
  def _compute_tet_face_indices(self):
89
92
  self._tet_face_indices = wp.empty(