warp-lang 1.4.2__py3-none-macosx_10_13_universal2.whl → 1.5.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang has been flagged as potentially problematic; see the registry's advisory page for details.

Files changed (158)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1783 -2
  8. warp/codegen.py +177 -45
  9. warp/config.py +2 -2
  10. warp/context.py +321 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +2 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -5
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +600 -0
  82. warp/native/cuda_util.cpp +14 -0
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1857 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +137 -65
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/integrator_euler.py +4 -2
  114. warp/sim/integrator_featherstone.py +115 -44
  115. warp/sim/integrator_vbd.py +6 -0
  116. warp/sim/model.py +88 -15
  117. warp/stubs.py +569 -4
  118. warp/tape.py +12 -7
  119. warp/tests/assets/pixel.npy +0 -0
  120. warp/tests/aux_test_instancing_gc.py +18 -0
  121. warp/tests/test_array.py +39 -0
  122. warp/tests/test_codegen.py +81 -1
  123. warp/tests/test_codegen_instancing.py +30 -0
  124. warp/tests/test_collision.py +110 -0
  125. warp/tests/test_coloring.py +241 -0
  126. warp/tests/test_context.py +34 -0
  127. warp/tests/test_examples.py +18 -4
  128. warp/tests/test_fem.py +453 -113
  129. warp/tests/test_func.py +13 -0
  130. warp/tests/test_generics.py +52 -0
  131. warp/tests/test_iter.py +68 -0
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_mesh_query_point.py +1 -1
  134. warp/tests/test_module_hashing.py +23 -0
  135. warp/tests/test_paddle.py +27 -87
  136. warp/tests/test_print.py +56 -1
  137. warp/tests/test_spatial.py +1 -1
  138. warp/tests/test_tile.py +700 -0
  139. warp/tests/test_tile_mathdx.py +144 -0
  140. warp/tests/test_tile_mlp.py +383 -0
  141. warp/tests/test_tile_reduce.py +374 -0
  142. warp/tests/test_tile_shared_memory.py +190 -0
  143. warp/tests/test_vbd.py +12 -20
  144. warp/tests/test_volume.py +43 -0
  145. warp/tests/unittest_suites.py +19 -2
  146. warp/tests/unittest_utils.py +4 -0
  147. warp/types.py +338 -72
  148. warp/utils.py +22 -1
  149. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  150. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/RECORD +153 -126
  151. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  152. warp/fem/field/test.py +0 -180
  153. warp/fem/field/trial.py +0 -183
  154. warp/fem/space/collocated_function_space.py +0 -102
  155. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  156. warp/fem/space/trimesh_2d_function_space.py +0 -153
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ from typing import Any
2
2
 
3
3
  import warp as wp
4
4
  from warp.fem import cache, domain
5
+ from warp.fem.geometry import Element
5
6
  from warp.fem.space import FunctionSpace
6
7
  from warp.fem.types import Coords, ElementIndex
7
8
 
@@ -77,6 +78,38 @@ class Quadrature:
77
78
  class RegularQuadrature(Quadrature):
78
79
  """Regular quadrature formula, using a constant set of quadrature points per element"""
79
80
 
81
@wp.struct
class Arg:
    # Device-side quadrature data for a RegularQuadrature.
    # Quadrature points and weights used to be passed as Warp constants,
    # but this tended to incur register spilling for high point counts,
    # so they are now passed as arrays instead.
    points: wp.array(dtype=Coords)  # one Coords per quadrature point (reference-element coordinates)
    weights: wp.array(dtype=float)  # one weight per quadrature point
87
+
88
# Cache common formulas so we do not have to do a host-to-device transfer for each call
class CachedFormula:
    """Quadrature points and weights of a reference element for a given order and family.

    Instances are shared through :meth:`get`, so that every quadrature using the same
    element type, order and polynomial family reuses the same host data and device arrays.
    """

    # (element class, order, family) -> CachedFormula
    _cache = {}

    def __init__(self, element: Element, order: int, family: Polynomial):
        self.points, self.weights = element.instantiate_quadrature(order, family)
        # Number of points per element, wrapped as a Warp constant so kernels can capture it
        self.count = wp.constant(len(self.points))

    @cache.cached_arg_value
    def arg_value(self, device):
        """Builds the device-side :class:`RegularQuadrature.Arg` holding the points and weights arrays."""
        arg = RegularQuadrature.Arg()
        arg.points = wp.array(self.points, device=device, dtype=Coords)
        arg.weights = wp.array(self.weights, device=device, dtype=float)
        return arg

    @staticmethod
    def get(element: Element, order: int, family: Polynomial):
        """Returns the shared formula for ``element``'s type, ``order`` and ``family``, creating it on first use.

        The key uses the element class itself rather than its name, so that two distinct
        element classes sharing a name cannot alias each other's quadrature formulas.
        """
        key = (type(element), order, family)
        formulas = RegularQuadrature.CachedFormula._cache
        formula = formulas.get(key)
        if formula is None:
            formula = RegularQuadrature.CachedFormula(element, order, family)
            formulas[key] = formula
        return formula
112
+
80
113
  def __init__(
81
114
  self,
82
115
  domain: domain.GeometryDomain,
@@ -88,15 +121,7 @@ class RegularQuadrature(Quadrature):
88
121
  self.family = family
89
122
  self.order = order
90
123
 
91
- self._element_quadrature = domain.reference_element().instantiate_quadrature(order, family)
92
-
93
- self._N = wp.constant(len(self.points))
94
-
95
- WeightVec = wp.vec(length=self._N, dtype=wp.float32)
96
- CoordMat = wp.mat(shape=(self._N, 3), dtype=wp.float32)
97
-
98
- self._POINTS = wp.constant(CoordMat(self.points))
99
- self._WEIGHTS = wp.constant(WeightVec(self.weights))
124
+ self._formula = RegularQuadrature.CachedFormula.get(domain.reference_element(), order, family)
100
125
 
101
126
  self.point_count = self._make_point_count()
102
127
  self.point_index = self._make_point_index()
@@ -108,21 +133,24 @@ class RegularQuadrature(Quadrature):
108
133
  return f"{self.__class__.__name__}_{self.domain.name}_{self.family}_{self.order}"
109
134
 
def total_point_count(self):
    """Total number of quadrature points over all elements of the domain."""
    points_per_element = self._formula.count
    return points_per_element * self.domain.element_count()

def max_points_per_element(self):
    """Number of quadrature points per element (the same for every element)."""
    return self._formula.count

@property
def points(self):
    """Quadrature point coordinates of the shared formula."""
    formula = self._formula
    return formula.points

@property
def weights(self):
    """Quadrature weights of the shared formula, one per point."""
    formula = self._formula
    return formula.weights

def arg_value(self, device):
    """Device-side quadrature argument (points and weights arrays) for ``device``."""
    formula = self._formula
    return formula.arg_value(device)
123
151
 
124
152
  def _make_point_count(self):
125
- N = self._N
153
+ N = self._formula.count
126
154
 
127
155
  @cache.dynamic_func(suffix=self.name)
128
156
  def point_count(
@@ -136,8 +164,6 @@ class RegularQuadrature(Quadrature):
136
164
  return point_count
137
165
 
138
166
  def _make_point_coords(self):
139
- POINTS = self._POINTS
140
-
141
167
  @cache.dynamic_func(suffix=self.name)
142
168
  def point_coords(
143
169
  elt_arg: self.domain.ElementArg,
@@ -146,13 +172,11 @@ class RegularQuadrature(Quadrature):
146
172
  element_index: ElementIndex,
147
173
  qp_index: int,
148
174
  ):
149
- return Coords(POINTS[qp_index, 0], POINTS[qp_index, 1], POINTS[qp_index, 2])
175
+ return qp_arg.points[qp_index]
150
176
 
151
177
  return point_coords
152
178
 
153
179
  def _make_point_weight(self):
154
- WEIGHTS = self._WEIGHTS
155
-
156
180
  @cache.dynamic_func(suffix=self.name)
157
181
  def point_weight(
158
182
  elt_arg: self.domain.ElementArg,
@@ -161,12 +185,12 @@ class RegularQuadrature(Quadrature):
161
185
  element_index: ElementIndex,
162
186
  qp_index: int,
163
187
  ):
164
- return WEIGHTS[qp_index]
188
+ return qp_arg.weights[qp_index]
165
189
 
166
190
  return point_weight
167
191
 
168
192
  def _make_point_index(self):
169
- N = self._N
193
+ N = self._formula.count
170
194
 
171
195
  @cache.dynamic_func(suffix=self.name)
172
196
  def point_index(
@@ -8,20 +8,20 @@ import warp.fem.geometry as _geometry
8
8
  import warp.fem.polynomial as _polynomial
9
9
 
10
10
  from .function_space import FunctionSpace
11
+ from .basis_function_space import CollocatedFunctionSpace, ContravariantFunctionSpace, CovariantFunctionSpace
11
12
  from .topology import SpaceTopology
12
13
  from .basis_space import BasisSpace, PointBasisSpace, ShapeBasisSpace, make_discontinuous_basis_space
13
- from .collocated_function_space import CollocatedFunctionSpace
14
- from .shape import ElementBasis, get_shape_function
14
+ from .shape import ElementBasis, get_shape_function, ShapeFunction
15
15
 
16
16
  from .grid_2d_function_space import make_grid_2d_space_topology
17
17
 
18
18
  from .grid_3d_function_space import make_grid_3d_space_topology
19
19
 
20
- from .trimesh_2d_function_space import make_trimesh_2d_space_topology
20
+ from .trimesh_function_space import make_trimesh_space_topology
21
21
 
22
22
  from .tetmesh_function_space import make_tetmesh_space_topology
23
23
 
24
- from .quadmesh_2d_function_space import make_quadmesh_2d_space_topology
24
+ from .quadmesh_function_space import make_quadmesh_space_topology
25
25
 
26
26
  from .hexmesh_function_space import make_hexmesh_space_topology
27
27
 
@@ -115,12 +115,12 @@ def make_polynomial_basis_space(
115
115
  topology = make_grid_2d_space_topology(geo, shape)
116
116
  elif isinstance(base_geo, _geometry.Grid3D):
117
117
  topology = make_grid_3d_space_topology(geo, shape)
118
- elif isinstance(base_geo, _geometry.Trimesh2D):
119
- topology = make_trimesh_2d_space_topology(geo, shape)
118
+ elif isinstance(base_geo, _geometry.Trimesh):
119
+ topology = make_trimesh_space_topology(geo, shape)
120
120
  elif isinstance(base_geo, _geometry.Tetmesh):
121
121
  topology = make_tetmesh_space_topology(geo, shape)
122
- elif isinstance(base_geo, _geometry.Quadmesh2D):
123
- topology = make_quadmesh_2d_space_topology(geo, shape)
122
+ elif isinstance(base_geo, _geometry.Quadmesh):
123
+ topology = make_quadmesh_space_topology(geo, shape)
124
124
  elif isinstance(base_geo, _geometry.Hexmesh):
125
125
  topology = make_hexmesh_space_topology(geo, shape)
126
126
  elif isinstance(base_geo, _geometry.Nanogrid) or isinstance(base_geo, _geometry.AdaptiveNanogrid):
@@ -136,7 +136,7 @@ def make_collocated_function_space(
136
136
  basis_space: BasisSpace, dtype: type = float, dof_mapper: Optional[DofMapper] = None
137
137
  ) -> CollocatedFunctionSpace:
138
138
  """
139
- Constructs a function space from a basis space and a value type, such that all degrees of freedom of the value type are stored at each of the basis nodes.
139
+ Constructs a function space from a scalar-valued basis space and a value type, such that all degrees of freedom of the value type are stored at each of the basis nodes.
140
140
 
141
141
  Args:
142
142
  geo: the Geometry on which to build the space
@@ -146,9 +146,37 @@ def make_collocated_function_space(
146
146
  Returns:
147
147
  the constructed function space
148
148
  """
149
+
150
+ if basis_space.value != ShapeFunction.Value.Scalar:
151
+ raise ValueError("Collocated function spaces may only be constructed from scalar-valued basis")
152
+
149
153
  return CollocatedFunctionSpace(basis_space, dtype=dtype, dof_mapper=dof_mapper)
150
154
 
151
155
 
156
def make_covariant_function_space(
    basis_space: BasisSpace,
) -> CovariantFunctionSpace:
    """
    Constructs a covariant function space from a covariant vector-valued basis space.

    Args:
        basis_space: the basis space; its ``value`` must be ``ShapeFunction.Value.CovariantVector``

    Returns:
        the constructed function space

    Raises:
        ValueError: if ``basis_space`` is not covariant vector-valued
    """

    if basis_space.value != ShapeFunction.Value.CovariantVector:
        raise ValueError("Covariant function spaces may only be constructed from covariant vector-valued basis")
    return CovariantFunctionSpace(basis_space)
166
+
167
+
168
def make_contravariant_function_space(
    basis_space: BasisSpace,
) -> ContravariantFunctionSpace:
    """
    Constructs a contravariant function space from a contravariant vector-valued basis space.

    Args:
        basis_space: the basis space; its ``value`` must be ``ShapeFunction.Value.ContravariantVector``

    Returns:
        the constructed function space

    Raises:
        ValueError: if ``basis_space`` is not contravariant vector-valued
    """

    if basis_space.value != ShapeFunction.Value.ContravariantVector:
        raise ValueError("Contravariant function spaces may only be constructed from contravariant vector-valued basis")
    return ContravariantFunctionSpace(basis_space)
178
+
179
+
152
180
  def make_polynomial_space(
153
181
  geo: _geometry.Geometry,
154
182
  dtype: type = float,
@@ -160,7 +188,7 @@ def make_polynomial_space(
160
188
  ) -> CollocatedFunctionSpace:
161
189
  """
162
190
  Equips a geometry with a collocated, polynomial function space.
163
- Equivalent to successive calls to :func:`make_polynomial_basis_space` and `make_collocated_function_space`.
191
+ Equivalent to successive calls to :func:`make_polynomial_basis_space` then `make_collocated_function_space`, `make_covariant_function_space` or `make_contravariant_function_space`.
164
192
 
165
193
  Args:
166
194
  geo: the Geometry on which to build the space
@@ -176,4 +204,10 @@ def make_polynomial_space(
176
204
  """
177
205
 
178
206
  basis_space = make_polynomial_basis_space(geo, degree, element_basis, discontinuous, family)
179
- return CollocatedFunctionSpace(basis_space, dtype=dtype, dof_mapper=dof_mapper)
207
+
208
+ if basis_space.value == ShapeFunction.Value.CovariantVector:
209
+ return make_covariant_function_space(basis_space)
210
+ if basis_space.value == ShapeFunction.Value.ContravariantVector:
211
+ return make_contravariant_function_space(basis_space)
212
+
213
+ return make_collocated_function_space(basis_space, dtype=dtype, dof_mapper=dof_mapper)
@@ -0,0 +1,451 @@
1
+ from typing import Any, Optional
2
+
3
+ import warp as wp
4
+ from warp.fem import cache
5
+ from warp.fem.geometry import Geometry
6
+ from warp.fem.linalg import basis_element, generalized_inner, generalized_outer
7
+ from warp.fem.types import Coords, ElementIndex, make_free_sample
8
+
9
+ from .basis_space import BasisSpace
10
+ from .dof_mapper import DofMapper, IdentityMapper
11
+ from .function_space import FunctionSpace
12
+ from .partition import SpacePartition, make_space_partition
13
+
14
+
15
class CollocatedFunctionSpace(FunctionSpace):
    """Function space where values are collocated at nodes

    Every node of the underlying (scalar-valued) basis stores all degrees of freedom of the
    value type; the conversion between stored dofs and values is delegated to ``dof_mapper``.
    """

    # Collocated spaces need no per-sample value transformation, so the local value map is empty
    @wp.struct
    class LocalValueMap:
        pass

    def __init__(self, basis: BasisSpace, dtype: type = float, dof_mapper: Optional[DofMapper] = None):
        """
        Args:
            basis: scalar-valued basis space providing the nodes and weights
            dtype: value type; only used to build the default ``IdentityMapper`` when ``dof_mapper`` is None
            dof_mapper: optional mapping between node dofs and values
        """
        self.dof_mapper = IdentityMapper(dtype) if dof_mapper is None else dof_mapper
        self._basis = basis

        super().__init__(topology=basis.topology)

        # Value and dof types both come from the dof mapper, which may differ from `dtype`
        self.dtype = self.dof_mapper.value_dtype
        self.dof_dtype = self.dof_mapper.dof_dtype
        self.VALUE_DOF_COUNT = self.dof_mapper.DOF_SIZE
        self.NODE_DOF_COUNT = self.dof_mapper.DOF_SIZE

        self.SpaceArg = self._basis.BasisArg
        self.space_arg_value = self._basis.basis_arg_value

        self.ORDER = self._basis.ORDER

        # value_basis_element calls node_basis_element, so build it first
        self.node_basis_element = self._make_node_basis_element()
        self.value_basis_element = self._make_value_basis_element()

        # Node locations and weights are provided directly by the basis space
        self.node_coords_in_element = self._basis.make_node_coords_in_element()
        self.node_quadrature_weight = self._basis.make_node_quadrature_weight()
        self.element_inner_weight = self._basis.make_element_inner_weight()
        self.element_inner_weight_gradient = self._basis.make_element_inner_weight_gradient()
        self.element_outer_weight = self._basis.make_element_outer_weight()
        self.element_outer_weight_gradient = self._basis.make_element_outer_weight_gradient()

        self.space_value = self._make_space_value()
        self.space_gradient = self._make_space_gradient()
        self.space_divergence = self._make_space_divergence()

        self.node_dof_value = self._make_node_dof_value()

        # For backward compatibility
        # (re-expose optional node-connectivity helpers of the basis on the space itself)
        if hasattr(basis, "node_grid"):
            self.node_grid = basis.node_grid
        if hasattr(basis, "node_triangulation"):
            self.node_triangulation = basis.node_triangulation
        if hasattr(basis, "node_tets"):
            self.node_tets = basis.node_tets
        if hasattr(basis, "node_hexes"):
            self.node_hexes = basis.node_hexes
        if hasattr(basis, "vtk_cells"):
            self.vtk_cells = basis.vtk_cells

    @property
    def name(self):
        """Basis name combined with the dof mapper, with dots replaced by underscores.

        Used as the suffix of the generated Warp functions below.
        """
        return f"{self._basis.name}_{self.dof_mapper}".replace(".", "_")

    def node_positions(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns the positions of the space nodes, as computed by the basis space."""
        return self._basis.node_positions(out=out)

    def make_field(
        self,
        space_partition: Optional[SpacePartition] = None,
    ) -> "wp.fem.field.NodalField":
        """Creates a :class:`NodalField` on this space.

        Args:
            space_partition: partition of the space nodes; defaults to a partition of the whole topology
        """
        from warp.fem.field import NodalField

        if space_partition is None:
            space_partition = make_space_partition(space_topology=self.topology)

        return NodalField(space=self, space_partition=space_partition)

    def trace(self) -> "CollocatedFunctionSpace":
        """Returns the trace of this space (built on the trace of the basis space)."""
        return CollocatedFunctionSpaceTrace(self)

    def _make_node_basis_element(self):
        """Builds a Warp function returning the unit dof vector along coordinate ``dof_coord``."""

        @cache.dynamic_func(suffix=self.name)
        def node_basis_element(dof_coord: int):
            return basis_element(self.dof_dtype(0.0), dof_coord)

        return node_basis_element

    def _make_value_basis_element(self):
        """Builds a Warp function mapping a node basis element to its value; ``value_map`` is unused."""

        @cache.dynamic_func(suffix=self.name)
        def value_basis_element(dof_coord: int, value_map: CollocatedFunctionSpace.LocalValueMap):
            return self.dof_mapper.dof_to_value(self.node_basis_element(dof_coord))

        return value_basis_element

    # Local value maps carry no data for collocated spaces; inner and outer both return an empty map
    @wp.func
    def local_value_map_inner(
        elt_arg: Any,
        element_index: ElementIndex,
        element_coords: Coords,
    ):
        return CollocatedFunctionSpace.LocalValueMap()

    @wp.func
    def local_value_map_outer(
        elt_arg: Any,
        element_index: ElementIndex,
        element_coords: Coords,
    ):
        return CollocatedFunctionSpace.LocalValueMap()

    def _make_space_value(self):
        """Builds a Warp function: node weight times the value of the node dofs."""

        @cache.dynamic_func(suffix=self.name)
        def value_func(
            dof_value: self.dof_dtype,
            node_weight: self._basis.weight_type,
            local_value_map: self.LocalValueMap,
        ):
            return node_weight * self.dof_mapper.dof_to_value(dof_value)

        return value_func

    def _make_space_gradient(self):
        """Builds a Warp function: outer product of the node value with the transformed weight gradient."""

        @cache.dynamic_func(suffix=self.name)
        def gradient_func(
            dof_value: self.dof_dtype,
            node_weight_gradient: self._basis.weight_gradient_type,
            local_value_map: self.LocalValueMap,
            grad_transform: Any,
        ):
            return generalized_outer(self.dof_mapper.dof_to_value(dof_value), node_weight_gradient * grad_transform)

        return gradient_func

    def _make_space_divergence(self):
        """Builds a Warp function: inner product of the node value with the transformed weight gradient."""

        @cache.dynamic_func(suffix=self.name)
        def divergence_func(
            dof_value: self.dof_dtype,
            node_weight_gradient: self._basis.weight_gradient_type,
            local_value_map: self.LocalValueMap,
            grad_transform: Any,
        ):
            return generalized_inner(self.dof_mapper.dof_to_value(dof_value), node_weight_gradient * grad_transform)

        return divergence_func

    def _make_node_dof_value(self):
        """Builds a Warp function converting a space value to node dofs via the dof mapper.

        Element and node arguments are accepted for interface compatibility but unused.
        """

        @cache.dynamic_func(suffix=self.name)
        def node_dof_value(
            elt_arg: self.ElementArg,
            space_arg: self.SpaceArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
            space_value: self.dtype,
        ):
            return self.dof_mapper.value_to_dof(space_value)

        return node_dof_value
164
+
165
+
166
class CollocatedFunctionSpaceTrace(CollocatedFunctionSpace):
    """Trace of a :class:`CollocatedFunctionSpace`"""

    def __init__(self, space: CollocatedFunctionSpace):
        # Must be set before the base constructor runs: `name` (used as the suffix of the
        # generated Warp functions) depends on it
        self._space = space
        super().__init__(space._basis.trace(), space.dtype, space.dof_mapper)

    @property
    def name(self):
        return f"{self._space.name}_Trace"

    def __eq__(self, other: object) -> bool:
        # Two traces are equal when they are traces of equal spaces; defer to Python's
        # default handling for unrelated types instead of failing on a missing `_space`
        if not isinstance(other, CollocatedFunctionSpaceTrace):
            return NotImplemented
        return self._space == other._space

    def __hash__(self) -> int:
        # Overriding __eq__ alone would set __hash__ to None; keep hashing consistent with equality
        return hash(self._space)
179
+
180
+
181
class VectorValuedFunctionSpace(FunctionSpace):
    """Function space whose values are vectors

    Each node carries a single scalar dof (``NODE_DOF_COUNT = 1``); the vector value is
    obtained by applying a per-sample ``local_value_map`` matrix to the scalar-weighted
    basis value. Subclasses must provide ``local_value_map_inner`` and ``local_value_map_outer``.
    """

    def __init__(self, basis: BasisSpace):
        """
        Args:
            basis: vector-valued basis space providing the nodes and weights
        """
        self._basis = basis

        super().__init__(topology=basis.topology)

        # Values are vectors of the geometry dimension; each node stores one float dof
        self.dtype = cache.cached_vec_type(self.geometry.dimension, dtype=float)
        self.dof_dtype = float

        self.VALUE_DOF_COUNT = self.geometry.dimension
        self.NODE_DOF_COUNT = 1

        self.SpaceArg = self._basis.BasisArg
        self.space_arg_value = self._basis.basis_arg_value

        self.ORDER = self._basis.ORDER

        # (dimension x cell_dimension) matrix; presumably maps reference-cell vectors to
        # world vectors -- NOTE(review): confirm against the subclasses' local value maps
        self.LocalValueMap = cache.cached_mat_type(
            shape=(self.geometry.dimension, self.geometry.cell_dimension), dtype=float
        )

        self.value_basis_element = self._make_value_basis_element()

        self.node_coords_in_element = self._basis.make_node_coords_in_element()
        self.node_quadrature_weight = self._basis.make_node_quadrature_weight()
        self.element_inner_weight = self._basis.make_element_inner_weight()
        self.element_inner_weight_gradient = self._basis.make_element_inner_weight_gradient()
        self.element_outer_weight = self._basis.make_element_outer_weight()
        self.element_outer_weight_gradient = self._basis.make_element_outer_weight_gradient()

        self.space_value = self._make_space_value()
        self.space_gradient = self._make_space_gradient()
        self.space_divergence = self._make_space_divergence()

        # Note: node_dof_value references self.local_value_map_inner, which subclasses only
        # assign after this constructor returns (presumably resolved when Warp compiles the function)
        self.node_dof_value = self._make_node_dof_value()

    @property
    def name(self):
        """Name of the underlying basis space, used as the suffix of the generated Warp functions."""
        return self._basis.name

    def node_positions(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns the positions of the space nodes, as computed by the basis space."""
        return self._basis.node_positions(out=out)

    def make_field(
        self,
        space_partition: Optional[SpacePartition] = None,
    ) -> "wp.fem.field.NodalField":
        """Creates a :class:`NodalField` on this space.

        Args:
            space_partition: partition of the space nodes; defaults to a partition of the whole topology
        """
        from warp.fem.field import NodalField

        if space_partition is None:
            space_partition = make_space_partition(space_topology=self.topology)

        return NodalField(space=self, space_partition=space_partition)

    # Single scalar dof per node, so its basis element is always 1
    @wp.func
    def node_basis_element(dof_coord: int):
        return 1.0

    def _make_value_basis_element(self):
        """Builds a Warp function: ``value_map`` applied to the unit value vector along ``dof_coord``."""

        @cache.dynamic_func(suffix=self.name)
        def value_basis_element(dof_coord: int, value_map: Any):
            return value_map * basis_element(self.dtype(0.0), dof_coord)

        return value_basis_element

    def _make_space_value(self):
        """Builds a Warp function: local value map applied to the weighted scalar dof."""

        @cache.dynamic_func(suffix=self.name)
        def value_func(
            dof_value: self.dof_dtype,
            node_weight: self._basis.weight_type,
            local_value_map: self.LocalValueMap,
        ):
            return local_value_map * (node_weight * dof_value)

        return value_func

    def _make_space_gradient(self):
        """Builds a Warp function: dof times local value map times transformed weight gradient."""

        @cache.dynamic_func(suffix=self.name)
        def gradient_func(
            dof_value: self.dof_dtype,
            node_weight_gradient: self._basis.weight_gradient_type,
            local_value_map: self.LocalValueMap,
            grad_transform: Any,
        ):
            return dof_value * local_value_map * node_weight_gradient * grad_transform

        return gradient_func

    def _make_space_divergence(self):
        """Builds a Warp function: dof times the trace of the value gradient matrix."""

        @cache.dynamic_func(suffix=self.name)
        def divergence_func(
            dof_value: self.dof_dtype,
            node_weight_gradient: self._basis.weight_gradient_type,
            local_value_map: self.LocalValueMap,
            grad_transform: Any,
        ):
            return dof_value * wp.trace(local_value_map * node_weight_gradient * grad_transform)

        return divergence_func

    def _make_node_dof_value(self):
        """Builds a Warp function projecting a vector value onto a node's scalar dof.

        The dof is the least-squares coefficient ``<v, u> / |u|^2``, where ``u`` is the node's
        inner weight mapped through the inner local value map at the node coordinates.
        NOTE(review): the division is unguarded and yields inf/nan if ``u`` vanishes.
        """

        @cache.dynamic_func(suffix=self.name)
        def node_dof_value(
            elt_arg: self.ElementArg,
            space_arg: self.SpaceArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
            space_value: self.dtype,
        ):
            coords = self.node_coords_in_element(elt_arg, space_arg, element_index, node_index_in_elt)
            weight = self.element_inner_weight(elt_arg, space_arg, element_index, coords, node_index_in_elt)
            local_value_map = self.local_value_map_inner(elt_arg, element_index, coords)

            unit_value = local_value_map * weight
            return wp.dot(space_value, unit_value) / wp.length_sq(unit_value)

        return node_dof_value
300
+
301
+
302
class CovariantFunctionSpace(VectorValuedFunctionSpace):
    """Function space whose values are covariant vectors

    Reference-element values are mapped to the cell through the transpose of the
    cell's inverse deformation gradient.
    """

    def __init__(self, basis: BasisSpace):
        super().__init__(basis)

        # Cell-based space: the inner and outer maps are one and the same function
        self.local_value_map_inner = self.local_value_map_outer = self._make_local_value_map()

    def trace(self) -> "CovariantFunctionSpaceTrace":
        """Returns the trace of this space."""
        return CovariantFunctionSpaceTrace(self)

    def _make_local_value_map(self):
        """Builds the Warp function evaluating the covariant value map at a cell sample."""

        @cache.dynamic_func(suffix=self.name)
        def local_value_map(
            elt_arg: self.ElementArg,
            element_index: ElementIndex,
            element_coords: Coords,
        ):
            inv_grad = self.geometry.cell_inverse_deformation_gradient(
                elt_arg, make_free_sample(element_index, element_coords)
            )
            return wp.transpose(inv_grad)

        return local_value_map
329
+
330
+
331
class CovariantFunctionSpaceTrace(VectorValuedFunctionSpace):
    """Trace of a :class:`CovariantFunctionSpace`"""

    def __init__(self, space: CovariantFunctionSpace):
        # Must be set before the base constructor runs: `name` (used as the suffix of the
        # generated Warp functions) depends on it
        self._space = space
        super().__init__(space._basis.trace())

        self.local_value_map_inner = self._make_local_value_map_inner()
        self.local_value_map_outer = self._make_local_value_map_outer()

    @property
    def name(self):
        return f"{self._space.name}_Trace"

    def __eq__(self, other: object) -> bool:
        # Two traces are equal when they are traces of equal spaces; defer to Python's
        # default handling for unrelated types instead of failing on a missing `_space`
        if not isinstance(other, CovariantFunctionSpaceTrace):
            return NotImplemented
        return self._space == other._space

    def __hash__(self) -> int:
        # Overriding __eq__ alone would set __hash__ to None; keep hashing consistent with equality
        return hash(self._space)

    def _make_local_value_map_inner(self):
        """Builds the Warp function: transposed inverse deformation gradient on the inner side."""

        @cache.dynamic_func(suffix=self.name)
        def local_value_map_inner(
            elt_arg: self.ElementArg,
            element_index: ElementIndex,
            element_coords: Coords,
        ):
            return wp.transpose(
                self.geometry.side_inner_inverse_deformation_gradient(
                    elt_arg, make_free_sample(element_index, element_coords)
                )
            )

        return local_value_map_inner

    def _make_local_value_map_outer(self):
        """Builds the Warp function: transposed inverse deformation gradient on the outer side."""

        @cache.dynamic_func(suffix=self.name)
        def local_value_map_outer(
            elt_arg: self.ElementArg,
            element_index: ElementIndex,
            element_coords: Coords,
        ):
            return wp.transpose(
                self.geometry.side_outer_inverse_deformation_gradient(
                    elt_arg, make_free_sample(element_index, element_coords)
                )
            )

        return local_value_map_outer
377
+
378
+
379
class ContravariantFunctionSpace(VectorValuedFunctionSpace):
    """Function space whose values are contravariant vectors

    Reference-element values are mapped to the cell through the cell deformation
    gradient, divided by its element measure.
    """

    def __init__(self, basis: BasisSpace):
        super().__init__(basis)

        # Cell-based space: the inner and outer maps are one and the same function
        self.local_value_map_inner = self.local_value_map_outer = self._make_local_value_map()

    def trace(self) -> "ContravariantFunctionSpaceTrace":
        """Returns the trace of this space."""
        return ContravariantFunctionSpaceTrace(self)

    def _make_local_value_map(self):
        """Builds the Warp function evaluating the contravariant value map at a cell sample."""

        @cache.dynamic_func(suffix=self.name)
        def local_value_map(
            elt_arg: self.ElementArg,
            element_index: ElementIndex,
            element_coords: Coords,
        ):
            deform_grad = self.geometry.cell_deformation_gradient(
                elt_arg, make_free_sample(element_index, element_coords)
            )
            return deform_grad / Geometry._element_measure(deform_grad)

        return local_value_map
402
+
403
+
404
class ContravariantFunctionSpaceTrace(VectorValuedFunctionSpace):
    """Trace of a :class:`ContravariantFunctionSpace`"""

    def __init__(self, space: ContravariantFunctionSpace):
        # Must be set before the base constructor runs: `name` (used as the suffix of the
        # generated Warp functions) depends on it
        self._space = space
        super().__init__(space._basis.trace())

        self.local_value_map_inner = self._make_local_value_map_inner()
        self.local_value_map_outer = self._make_local_value_map_outer()

    @property
    def name(self):
        return f"{self._space.name}_Trace"

    def __eq__(self, other: object) -> bool:
        # Two traces are equal when they are traces of equal spaces; defer to Python's
        # default handling for unrelated types instead of failing on a missing `_space`
        if not isinstance(other, ContravariantFunctionSpaceTrace):
            return NotImplemented
        return self._space == other._space

    def __hash__(self) -> int:
        # Overriding __eq__ alone would set __hash__ to None; keep hashing consistent with equality
        return hash(self._space)

    def _make_local_value_map_inner(self):
        """Builds the Warp function: inner cell's deformation gradient divided by its element measure."""

        @cache.dynamic_func(suffix=self.name)
        def local_value_map_inner(
            elt_arg: self.ElementArg,
            element_index: ElementIndex,
            element_coords: Coords,
        ):
            # Evaluate the map in the cell on the inner side of this side element
            cell_index = self.geometry.side_inner_cell_index(elt_arg, element_index)
            cell_coords = self.geometry.side_inner_cell_coords(elt_arg, element_index, element_coords)
            cell_arg = self.geometry.side_to_cell_arg(elt_arg)

            F = self.geometry.cell_deformation_gradient(cell_arg, make_free_sample(cell_index, cell_coords))
            return F / Geometry._element_measure(F)

        return local_value_map_inner

    def _make_local_value_map_outer(self):
        """Builds the Warp function: outer cell's deformation gradient divided by its element measure."""

        @cache.dynamic_func(suffix=self.name)
        def local_value_map_outer(
            elt_arg: self.ElementArg,
            element_index: ElementIndex,
            element_coords: Coords,
        ):
            # Evaluate the map in the cell on the outer side of this side element
            cell_index = self.geometry.side_outer_cell_index(elt_arg, element_index)
            cell_coords = self.geometry.side_outer_cell_coords(elt_arg, element_index, element_coords)
            cell_arg = self.geometry.side_to_cell_arg(elt_arg)

            F = self.geometry.cell_deformation_gradient(cell_arg, make_free_sample(cell_index, cell_coords))
            return F / Geometry._element_measure(F)

        return local_value_map_outer