warp-lang 1.4.1__py3-none-manylinux2014_x86_64.whl → 1.5.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (164)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1920 -111
  8. warp/codegen.py +186 -62
  9. warp/config.py +2 -2
  10. warp/context.py +322 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/core/example_dem.py +2 -1
  17. warp/examples/core/example_mesh_intersect.py +3 -3
  18. warp/examples/fem/example_adaptive_grid.py +37 -10
  19. warp/examples/fem/example_apic_fluid.py +3 -2
  20. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  21. warp/examples/fem/example_deformed_geometry.py +1 -1
  22. warp/examples/fem/example_diffusion_3d.py +47 -4
  23. warp/examples/fem/example_distortion_energy.py +220 -0
  24. warp/examples/fem/example_magnetostatics.py +127 -85
  25. warp/examples/fem/example_nonconforming_contact.py +5 -5
  26. warp/examples/fem/example_stokes.py +3 -1
  27. warp/examples/fem/example_streamlines.py +12 -19
  28. warp/examples/fem/utils.py +38 -15
  29. warp/examples/optim/example_walker.py +2 -2
  30. warp/examples/sim/example_cloth.py +2 -25
  31. warp/examples/sim/example_jacobian_ik.py +6 -2
  32. warp/examples/sim/example_quadruped.py +2 -1
  33. warp/examples/tile/example_tile_convolution.py +58 -0
  34. warp/examples/tile/example_tile_fft.py +47 -0
  35. warp/examples/tile/example_tile_filtering.py +105 -0
  36. warp/examples/tile/example_tile_matmul.py +79 -0
  37. warp/examples/tile/example_tile_mlp.py +375 -0
  38. warp/fem/__init__.py +8 -0
  39. warp/fem/cache.py +16 -12
  40. warp/fem/dirichlet.py +1 -1
  41. warp/fem/domain.py +44 -1
  42. warp/fem/field/__init__.py +1 -2
  43. warp/fem/field/field.py +31 -19
  44. warp/fem/field/nodal_field.py +101 -49
  45. warp/fem/field/virtual.py +794 -0
  46. warp/fem/geometry/__init__.py +2 -2
  47. warp/fem/geometry/deformed_geometry.py +3 -105
  48. warp/fem/geometry/element.py +13 -0
  49. warp/fem/geometry/geometry.py +165 -5
  50. warp/fem/geometry/grid_2d.py +3 -6
  51. warp/fem/geometry/grid_3d.py +31 -28
  52. warp/fem/geometry/hexmesh.py +3 -46
  53. warp/fem/geometry/nanogrid.py +3 -2
  54. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  55. warp/fem/geometry/tetmesh.py +2 -43
  56. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  57. warp/fem/integrate.py +683 -261
  58. warp/fem/linalg.py +404 -0
  59. warp/fem/operator.py +101 -18
  60. warp/fem/polynomial.py +5 -5
  61. warp/fem/quadrature/quadrature.py +45 -21
  62. warp/fem/space/__init__.py +45 -11
  63. warp/fem/space/basis_function_space.py +451 -0
  64. warp/fem/space/basis_space.py +58 -11
  65. warp/fem/space/function_space.py +146 -5
  66. warp/fem/space/grid_2d_function_space.py +80 -66
  67. warp/fem/space/grid_3d_function_space.py +113 -68
  68. warp/fem/space/hexmesh_function_space.py +96 -108
  69. warp/fem/space/nanogrid_function_space.py +62 -110
  70. warp/fem/space/quadmesh_function_space.py +208 -0
  71. warp/fem/space/shape/__init__.py +45 -7
  72. warp/fem/space/shape/cube_shape_function.py +328 -54
  73. warp/fem/space/shape/shape_function.py +10 -1
  74. warp/fem/space/shape/square_shape_function.py +328 -60
  75. warp/fem/space/shape/tet_shape_function.py +269 -19
  76. warp/fem/space/shape/triangle_shape_function.py +238 -19
  77. warp/fem/space/tetmesh_function_space.py +69 -37
  78. warp/fem/space/topology.py +38 -0
  79. warp/fem/space/trimesh_function_space.py +179 -0
  80. warp/fem/utils.py +6 -331
  81. warp/jax_experimental.py +3 -1
  82. warp/native/array.h +55 -40
  83. warp/native/builtin.h +124 -43
  84. warp/native/bvh.h +4 -0
  85. warp/native/coloring.cpp +600 -0
  86. warp/native/cuda_util.cpp +14 -0
  87. warp/native/cuda_util.h +2 -1
  88. warp/native/fabric.h +8 -0
  89. warp/native/hashgrid.h +4 -0
  90. warp/native/marching.cu +8 -0
  91. warp/native/mat.h +14 -3
  92. warp/native/mathdx.cpp +59 -0
  93. warp/native/mesh.h +4 -0
  94. warp/native/range.h +13 -1
  95. warp/native/reduce.cpp +9 -1
  96. warp/native/reduce.cu +7 -0
  97. warp/native/runlength_encode.cpp +9 -1
  98. warp/native/runlength_encode.cu +7 -1
  99. warp/native/scan.cpp +8 -0
  100. warp/native/scan.cu +8 -0
  101. warp/native/scan.h +8 -1
  102. warp/native/sparse.cpp +8 -0
  103. warp/native/sparse.cu +8 -0
  104. warp/native/temp_buffer.h +7 -0
  105. warp/native/tile.h +1857 -0
  106. warp/native/tile_gemm.h +341 -0
  107. warp/native/tile_reduce.h +210 -0
  108. warp/native/volume_builder.cu +8 -0
  109. warp/native/volume_builder.h +8 -0
  110. warp/native/warp.cpp +10 -2
  111. warp/native/warp.cu +369 -15
  112. warp/native/warp.h +12 -2
  113. warp/optim/adam.py +39 -4
  114. warp/paddle.py +29 -12
  115. warp/render/render_opengl.py +137 -65
  116. warp/sim/graph_coloring.py +292 -0
  117. warp/sim/integrator_euler.py +4 -2
  118. warp/sim/integrator_featherstone.py +115 -44
  119. warp/sim/integrator_vbd.py +6 -0
  120. warp/sim/model.py +90 -17
  121. warp/stubs.py +651 -85
  122. warp/tape.py +12 -7
  123. warp/tests/assets/pixel.npy +0 -0
  124. warp/tests/aux_test_instancing_gc.py +18 -0
  125. warp/tests/test_array.py +207 -48
  126. warp/tests/test_closest_point_edge_edge.py +8 -8
  127. warp/tests/test_codegen.py +120 -1
  128. warp/tests/test_codegen_instancing.py +30 -0
  129. warp/tests/test_collision.py +110 -0
  130. warp/tests/test_coloring.py +241 -0
  131. warp/tests/test_context.py +34 -0
  132. warp/tests/test_examples.py +18 -4
  133. warp/tests/test_fabricarray.py +33 -0
  134. warp/tests/test_fem.py +453 -113
  135. warp/tests/test_func.py +48 -1
  136. warp/tests/test_generics.py +52 -0
  137. warp/tests/test_iter.py +68 -0
  138. warp/tests/test_mat_scalar_ops.py +1 -1
  139. warp/tests/test_mesh_query_point.py +5 -4
  140. warp/tests/test_module_hashing.py +23 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +191 -1
  143. warp/tests/test_spatial.py +1 -1
  144. warp/tests/test_tile.py +700 -0
  145. warp/tests/test_tile_mathdx.py +144 -0
  146. warp/tests/test_tile_mlp.py +383 -0
  147. warp/tests/test_tile_reduce.py +374 -0
  148. warp/tests/test_tile_shared_memory.py +190 -0
  149. warp/tests/test_vbd.py +12 -20
  150. warp/tests/test_volume.py +43 -0
  151. warp/tests/unittest_suites.py +23 -2
  152. warp/tests/unittest_utils.py +4 -0
  153. warp/types.py +339 -73
  154. warp/utils.py +22 -1
  155. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  156. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
  157. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  158. warp/fem/field/test.py +0 -180
  159. warp/fem/field/trial.py +0 -183
  160. warp/fem/space/collocated_function_space.py +0 -102
  161. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  162. warp/fem/space/trimesh_2d_function_space.py +0 -153
  163. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  164. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/fem/cache.py CHANGED
@@ -6,6 +6,7 @@ from copy import copy
6
6
  from typing import Any, Callable, Dict, Optional, Tuple, Union
7
7
 
8
8
  import warp as wp
9
+ from warp.fem.operator import Integrand
9
10
 
10
11
  _kernel_cache = {}
11
12
  _struct_cache = {}
@@ -186,7 +187,7 @@ class ExpandStarredArgumentStruct(ast.NodeTransformer):
186
187
 
187
188
 
188
189
  def get_integrand_function(
189
- integrand: "warp.fem.operator.Integrand", # noqa: F821
190
+ integrand: Integrand,
190
191
  suffix: str,
191
192
  func=None,
192
193
  annotations=None,
@@ -208,27 +209,30 @@ def get_integrand_function(
208
209
 
209
210
 
210
211
  def get_integrand_kernel(
211
- integrand: "warp.fem.operator.Integrand", # noqa: F821
212
+ integrand: Integrand,
212
213
  suffix: str,
213
214
  kernel_fn: Optional[Callable] = None,
214
215
  kernel_options: Dict[str, Any] = None,
215
216
  code_transformers=None,
216
217
  ):
217
- if kernel_options is None:
218
- kernel_options = {}
218
+ options = integrand.module.options.copy()
219
+ options.update(integrand.kernel_options)
220
+ if kernel_options is not None:
221
+ options.update(kernel_options)
219
222
 
220
- key = _make_key(integrand.func, suffix, use_qualified_name=True)
223
+ kernel_key = _make_key(integrand.func, suffix, use_qualified_name=True)
224
+ opts_key = "".join([f"{k}:{v}" for k, v in sorted(options.items())])
225
+ cache_key = kernel_key + opts_key
221
226
 
222
- if key not in _kernel_cache:
227
+ if cache_key not in _kernel_cache:
223
228
  if kernel_fn is None:
224
229
  return None
225
230
 
226
231
  module = wp.get_module(f"{integrand.module.name}.{integrand.name}")
227
- module.options = copy(integrand.module.options)
228
- module.options.update(kernel_options)
229
-
230
- _kernel_cache[key] = wp.Kernel(func=kernel_fn, key=key, module=module, code_transformers=code_transformers)
231
- return _kernel_cache[key]
232
+ _kernel_cache[cache_key] = wp.Kernel(
233
+ func=kernel_fn, key=kernel_key, module=module, code_transformers=code_transformers, options=options
234
+ )
235
+ return _kernel_cache[cache_key]
232
236
 
233
237
 
234
238
  def cached_arg_value(func: Callable):
@@ -478,7 +482,7 @@ def borrow_temporary(
478
482
  if temporary_store is None:
479
483
  temporary_store = TemporaryStore._default_store
480
484
 
481
- if temporary_store is None:
485
+ if temporary_store is None or (requires_grad and wp.context.runtime.tape is not None):
482
486
  return Temporary(
483
487
  array=wp.empty(shape=shape, dtype=dtype, pinned=pinned, device=device, requires_grad=requires_grad)
484
488
  )
warp/fem/dirichlet.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from typing import Any, Optional
2
2
 
3
3
  import warp as wp
4
- from warp.fem.utils import array_axpy, symmetric_eigenvalues_qr
4
+ from warp.fem.linalg import array_axpy, symmetric_eigenvalues_qr
5
5
  from warp.sparse import BsrMatrix, bsr_assign, bsr_axpy, bsr_copy, bsr_mm, bsr_mv
6
6
  from warp.types import type_is_matrix, type_length
7
7
 
warp/fem/domain.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Optional, Union
1
+ from typing import Any, Optional, Set, Union
2
2
 
3
3
  import warp as wp
4
4
  import warp.codegen
@@ -11,6 +11,7 @@ from warp.fem.geometry import (
11
11
  GeometryPartition,
12
12
  WholeGeometryPartition,
13
13
  )
14
+ from warp.fem.operator import Operator
14
15
  from warp.fem.types import ElementKind
15
16
 
16
17
  GeometryOrPartition = Union[Geometry, GeometryPartition]
@@ -94,6 +95,10 @@ class GeometryDomain:
94
95
  element_lookup: wp.Function
95
96
  """Device function returning the sample point corresponding to a world position"""
96
97
 
98
+ def notify_operator_usage(self, ops: Set[Operator]):
99
+ """Makes the Domain aware that the operators `ops` will be applied"""
100
+ pass
101
+
97
102
 
98
103
  class Cells(GeometryDomain):
99
104
  """A Domain containing all cells of the geometry or geometry partition"""
@@ -160,6 +165,17 @@ class Cells(GeometryDomain):
160
165
  def element_lookup(self) -> wp.Function:
161
166
  return self.geometry.cell_lookup
162
167
 
168
+ @property
169
+ def domain_cell_arg(self) -> wp.Function:
170
+ return Cells._identity_fn
171
+
172
+ def cell_domain(self):
173
+ return self
174
+
175
+ @wp.func
176
+ def _identity_fn(x: Any):
177
+ return x
178
+
163
179
 
164
180
  class Sides(GeometryDomain):
165
181
  """A Domain containing all (interior and boundary) sides of the geometry or geometry partition"""
@@ -225,6 +241,33 @@ class Sides(GeometryDomain):
225
241
  def element_normal(self) -> wp.Function:
226
242
  return self.geometry.side_normal
227
243
 
244
+ @property
245
+ def element_inner_cell_index(self) -> wp.Function:
246
+ return self.geometry.side_inner_cell_index
247
+
248
+ @property
249
+ def element_outer_cell_index(self) -> wp.Function:
250
+ return self.geometry.side_outer_cell_index
251
+
252
+ @property
253
+ def element_inner_cell_coords(self) -> wp.Function:
254
+ return self.geometry.side_inner_cell_coords
255
+
256
+ @property
257
+ def element_outer_cell_coords(self) -> wp.Function:
258
+ return self.geometry.side_outer_cell_coords
259
+
260
+ @property
261
+ def cell_to_element_coords(self) -> wp.Function:
262
+ return self.geometry.side_from_cell_coords
263
+
264
+ @property
265
+ def domain_cell_arg(self) -> wp.Function:
266
+ return self.geometry.side_to_cell_arg
267
+
268
+ def cell_domain(self):
269
+ return Cells(self.geometry_partition)
270
+
228
271
 
229
272
  class BoundarySides(Sides):
230
273
  """A Domain containing boundary sides of the geometry or geometry partition"""
@@ -6,8 +6,7 @@ from warp.fem.space import FunctionSpace, SpacePartition, SpaceRestriction, make
6
6
  from .field import DiscreteField, FieldLike, GeometryField, ImplicitField, NonconformingField, SpaceField, UniformField
7
7
  from .nodal_field import NodalField
8
8
  from .restriction import FieldRestriction
9
- from .test import TestField
10
- from .trial import TrialField
9
+ from .virtual import LocalTestField, LocalTrialField, TestField, TrialField
11
10
 
12
11
 
13
12
  def make_restriction(
warp/fem/field/field.py CHANGED
@@ -1,10 +1,10 @@
1
- from typing import Any, Dict, Optional
1
+ from typing import Any, Dict, Optional, Set
2
2
 
3
3
  import warp as wp
4
4
  from warp.fem import cache
5
5
  from warp.fem.domain import GeometryDomain, Sides
6
6
  from warp.fem.geometry import DeformedGeometry, Geometry
7
- from warp.fem.operator import integrand
7
+ from warp.fem.operator import Operator, integrand
8
8
  from warp.fem.space import FunctionSpace, SpacePartition
9
9
  from warp.fem.types import NULL_ELEMENT_INDEX, ElementKind, Sample
10
10
 
@@ -48,32 +48,32 @@ class FieldLike:
48
48
  return False
49
49
 
50
50
  @staticmethod
51
- def eval_inner(args: "ElementEvalArg", s: "Sample"): # noqa: F821
51
+ def eval_inner(args: "ElementEvalArg", s: Sample): # noqa: F821
52
52
  """Device function evaluating the inner field value at a sample point"""
53
53
  raise NotImplementedError
54
54
 
55
55
  @staticmethod
56
- def eval_grad_inner(args: "ElementEvalArg", s: "Sample"): # noqa: F821
56
+ def eval_grad_inner(args: "ElementEvalArg", s: Sample): # noqa: F821
57
57
  """Device function evaluating the inner field gradient at a sample point"""
58
58
  raise NotImplementedError
59
59
 
60
60
  @staticmethod
61
- def eval_div_inner(args: "ElementEvalArg", s: "Sample"): # noqa: F821
61
+ def eval_div_inner(args: "ElementEvalArg", s: Sample): # noqa: F821
62
62
  """Device function evaluating the inner field divergence at a sample point"""
63
63
  raise NotImplementedError
64
64
 
65
65
  @staticmethod
66
- def eval_outer(args: "ElementEvalArg", s: "Sample"): # noqa: F821
66
+ def eval_outer(args: "ElementEvalArg", s: Sample): # noqa: F821
67
67
  """Device function evaluating the outer field value at a sample point"""
68
68
  raise NotImplementedError
69
69
 
70
70
  @staticmethod
71
- def eval_grad_outer(args: "ElementEvalArg", s: "Sample"): # noqa: F821
71
+ def eval_grad_outer(args: "ElementEvalArg", s: Sample): # noqa: F821
72
72
  """Device function evaluating the outer field gradient at a sample point"""
73
73
  raise NotImplementedError
74
74
 
75
75
  @staticmethod
76
- def eval_div_outer(args: "ElementEvalArg", s: "Sample"): # noqa: F821
76
+ def eval_div_outer(args: "ElementEvalArg", s: Sample): # noqa: F821
77
77
  """Device function evaluating the outer field divergence at a sample point"""
78
78
  raise NotImplementedError
79
79
 
@@ -82,6 +82,10 @@ class FieldLike:
82
82
  """Polynomial degree of the field is applicable, or hint for determination of interpolation order"""
83
83
  raise NotImplementedError
84
84
 
85
+ def notify_operator_usage(self, ops: Set[Operator]):
86
+ """Makes the Domain aware that the operators `ops` will be applied"""
87
+ pass
88
+
85
89
 
86
90
  class GeometryField(FieldLike):
87
91
  """Base class for fields defined over a geometry"""
@@ -97,12 +101,12 @@ class GeometryField(FieldLike):
97
101
  raise NotImplementedError
98
102
 
99
103
  @staticmethod
100
- def eval_reference_grad_inner(args: "ElementEvalArg", s: "Sample"): # noqa: F821
104
+ def eval_reference_grad_inner(args: "ElementEvalArg", s: Sample): # noqa: F821
101
105
  """Device function evaluating the inner field gradient with respect to reference element coordinates at a sample point"""
102
106
  raise NotImplementedError
103
107
 
104
108
  @staticmethod
105
- def eval_reference_grad_outer(args: "ElementEvalArg", s: "Sample"): # noqa: F821
109
+ def eval_reference_grad_outer(args: "ElementEvalArg", s: Sample): # noqa: F821
106
110
  """Device function evaluating the outer field gradient with respect to reference element coordinates at a sample point"""
107
111
  raise NotImplementedError
108
112
 
@@ -128,6 +132,9 @@ class SpaceField(GeometryField):
128
132
  self._space = space
129
133
  self._space_partition = space_partition
130
134
 
135
+ self.gradient_valid = self.space.gradient_valid
136
+ self.divergence_valid = self.space.divergence_valid
137
+
131
138
  @property
132
139
  def geometry(self) -> Geometry:
133
140
  return self._space.geometry
@@ -156,17 +163,22 @@ class SpaceField(GeometryField):
156
163
  def dof_dtype(self) -> type:
157
164
  return self.space.dof_dtype
158
165
 
159
- def gradient_valid(self) -> bool:
160
- """Whether gradient operator can be computed. Only for scalar and vector fields as higher-order tensors are not support yet"""
161
- return not wp.types.type_is_matrix(self.dtype)
166
+ @property
167
+ def gradient_dtype(self):
168
+ """Return type of the gradient operator. Assumes self.gradient_valid()"""
169
+ if wp.types.type_is_vector(self.dtype):
170
+ return cache.cached_mat_type(
171
+ shape=(wp.types.type_length(self.dtype), self.geometry.dimension),
172
+ dtype=wp.types.type_scalar_type(self.dtype),
173
+ )
174
+ return cache.cached_vec_type(length=self.geometry.dimension, dtype=wp.types.type_scalar_type(self.dtype))
162
175
 
163
- def divergence_valid(self) -> bool:
164
- """Whether divergence of this field can be computed. Only for vector and tensor fields with same dimension as embedding geometry"""
176
+ @property
177
+ def divergence_dtype(self):
178
+ """Return type of the divergence operator. Assumes self.gradient_valid()"""
165
179
  if wp.types.type_is_vector(self.dtype):
166
- return wp.types.type_length(self.dtype) == self.space.geometry.dimension
167
- if wp.types.type_is_matrix(self.dtype):
168
- return self.dtype._shape_[0] == self.space.geometry.dimension
169
- return False
180
+ return wp.types.type_scalar_type(self.dtype)
181
+ return cache.cached_vec_type(length=self.dtype._shape_[1], dtype=wp.types.type_scalar_type(self.dtype))
170
182
 
171
183
  def _make_eval_degree(self):
172
184
  ORDER = self.space.ORDER
@@ -1,5 +1,7 @@
1
+ from typing import Any
2
+
1
3
  import warp as wp
2
- from warp.fem import cache, utils
4
+ from warp.fem import cache
3
5
  from warp.fem.space import CollocatedFunctionSpace, SpacePartition
4
6
  from warp.fem.types import NULL_NODE_INDEX, ElementIndex, Sample
5
7
 
@@ -56,58 +58,79 @@ class NodalFieldBase(DiscreteField):
56
58
  )
57
59
  pidx = self.space_partition.partition_node_index(args.eval_arg.partition_arg, nidx)
58
60
  if pidx == NULL_NODE_INDEX:
59
- return self.space.dtype(0.0)
61
+ return self.space.dof_dtype(0.0)
60
62
 
61
- return self.space.dof_mapper.dof_to_value(args.eval_arg.dof_values[pidx])
63
+ return args.eval_arg.dof_values[pidx]
62
64
 
63
65
  return read_node_value
64
66
 
65
67
  def _make_eval_inner(self):
66
68
  @cache.dynamic_func(suffix=self.name)
67
69
  def eval_inner(args: self.ElementEvalArg, s: Sample):
68
- res = self.space.element_inner_weight(
69
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
70
- ) * self._read_node_value(args, s.element_index, 0)
70
+ local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
71
+ res = self.space.space_value(
72
+ self._read_node_value(args, s.element_index, 0),
73
+ self.space.element_inner_weight(
74
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
75
+ ),
76
+ local_value_map,
77
+ )
78
+
71
79
  node_count = self.space.topology.element_node_count(
72
80
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
73
81
  )
74
82
  for k in range(1, node_count):
75
- res += self.space.element_inner_weight(
76
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
77
- ) * self._read_node_value(args, s.element_index, k)
83
+ res += self.space.space_value(
84
+ self._read_node_value(args, s.element_index, k),
85
+ self.space.element_inner_weight(
86
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
87
+ ),
88
+ local_value_map,
89
+ )
78
90
  return res
79
91
 
80
92
  return eval_inner
81
93
 
82
94
  def _make_eval_grad_inner(self, world_space: bool):
83
- if not self.gradient_valid():
95
+ if not self.space.gradient_valid():
84
96
  return None
85
97
 
86
98
  @cache.dynamic_func(suffix=self.name)
87
- def eval_grad_inner_ref_space(args: self.ElementEvalArg, s: Sample):
88
- res = utils.generalized_outer(
99
+ def eval_grad_inner(args: self.ElementEvalArg, s: Sample, grad_transform: Any):
100
+ local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
101
+
102
+ res = self.space.space_gradient(
89
103
  self._read_node_value(args, s.element_index, 0),
90
104
  self.space.element_inner_weight_gradient(
91
105
  args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
92
106
  ),
107
+ local_value_map,
108
+ grad_transform,
93
109
  )
110
+
94
111
  node_count = self.space.topology.element_node_count(
95
112
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
96
113
  )
97
114
  for k in range(1, node_count):
98
- res += utils.generalized_outer(
115
+ res += self.space.space_gradient(
99
116
  self._read_node_value(args, s.element_index, k),
100
117
  self.space.element_inner_weight_gradient(
101
118
  args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
102
119
  ),
120
+ local_value_map,
121
+ grad_transform,
103
122
  )
104
123
  return res
105
124
 
125
+ @cache.dynamic_func(suffix=self.name)
126
+ def eval_grad_inner_ref_space(args: self.ElementEvalArg, s: Sample):
127
+ grad_transform = 1.0
128
+ return eval_grad_inner(args, s, grad_transform)
129
+
106
130
  @cache.dynamic_func(suffix=self.name)
107
131
  def eval_grad_inner_world_space(args: self.ElementEvalArg, s: Sample):
108
132
  grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
109
- res = eval_grad_inner_ref_space(args, s)
110
- return res * grad_transform
133
+ return eval_grad_inner(args, s, grad_transform)
111
134
 
112
135
  return eval_grad_inner_world_space if world_space else eval_grad_inner_ref_space
113
136
 
@@ -118,25 +141,28 @@ class NodalFieldBase(DiscreteField):
118
141
  @cache.dynamic_func(suffix=self.name)
119
142
  def eval_div_inner(args: self.ElementEvalArg, s: Sample):
120
143
  grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
144
+ local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
121
145
 
122
- res = utils.generalized_inner(
146
+ res = self.space.space_divergence(
123
147
  self._read_node_value(args, s.element_index, 0),
124
148
  self.space.element_inner_weight_gradient(
125
149
  args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
126
- )
127
- * grad_transform,
150
+ ),
151
+ local_value_map,
152
+ grad_transform,
128
153
  )
129
154
 
130
155
  node_count = self.space.topology.element_node_count(
131
156
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
132
157
  )
133
158
  for k in range(1, node_count):
134
- res += utils.generalized_inner(
159
+ res += self.space.space_divergence(
135
160
  self._read_node_value(args, s.element_index, k),
136
161
  self.space.element_inner_weight_gradient(
137
162
  args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
138
- )
139
- * grad_transform,
163
+ ),
164
+ local_value_map,
165
+ grad_transform,
140
166
  )
141
167
  return res
142
168
 
@@ -145,57 +171,71 @@ class NodalFieldBase(DiscreteField):
145
171
  def _make_eval_outer(self):
146
172
  @cache.dynamic_func(suffix=self.name)
147
173
  def eval_outer(args: self.ElementEvalArg, s: Sample):
148
- res = self.space.element_outer_weight(
149
- args.elt_arg,
150
- args.eval_arg.space_arg,
151
- s.element_index,
152
- s.element_coords,
153
- 0,
154
- ) * self._read_node_value(args, s.element_index, 0)
174
+ local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
175
+ res = self.space.space_value(
176
+ self._read_node_value(args, s.element_index, 0),
177
+ self.space.element_outer_weight(
178
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
179
+ ),
180
+ local_value_map,
181
+ )
182
+
155
183
  node_count = self.space.topology.element_node_count(
156
184
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
157
185
  )
186
+
158
187
  for k in range(1, node_count):
159
- res += self.space.element_outer_weight(
160
- args.elt_arg,
161
- args.eval_arg.space_arg,
162
- s.element_index,
163
- s.element_coords,
164
- k,
165
- ) * self._read_node_value(args, s.element_index, k)
188
+ res += self.space.space_value(
189
+ self._read_node_value(args, s.element_index, k),
190
+ self.space.element_outer_weight(
191
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
192
+ ),
193
+ local_value_map,
194
+ )
166
195
  return res
167
196
 
168
197
  return eval_outer
169
198
 
170
199
  def _make_eval_grad_outer(self, world_space: bool):
171
- if not self.gradient_valid():
200
+ if not self.space.gradient_valid():
172
201
  return None
173
202
 
174
203
  @cache.dynamic_func(suffix=self.name)
175
- def eval_grad_outer_ref_space(args: self.ElementEvalArg, s: Sample):
176
- res = utils.generalized_outer(
204
+ def eval_grad_outer(args: self.ElementEvalArg, s: Sample, grad_transform: Any):
205
+ local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
206
+
207
+ res = self.space.space_gradient(
177
208
  self._read_node_value(args, s.element_index, 0),
178
209
  self.space.element_outer_weight_gradient(
179
210
  args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
180
211
  ),
212
+ local_value_map,
213
+ grad_transform,
181
214
  )
215
+
182
216
  node_count = self.space.topology.element_node_count(
183
217
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
184
218
  )
185
219
  for k in range(1, node_count):
186
- res += utils.generalized_outer(
220
+ res += self.space.space_gradient(
187
221
  self._read_node_value(args, s.element_index, k),
188
222
  self.space.element_outer_weight_gradient(
189
223
  args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
190
224
  ),
225
+ local_value_map,
226
+ grad_transform,
191
227
  )
192
228
  return res
193
229
 
230
+ @cache.dynamic_func(suffix=self.name)
231
+ def eval_grad_outer_ref_space(args: self.ElementEvalArg, s: Sample):
232
+ grad_transform = 1.0
233
+ return eval_grad_outer_ref_space(args, s, grad_transform)
234
+
194
235
  @cache.dynamic_func(suffix=self.name)
195
236
  def eval_grad_outer_world_space(args: self.ElementEvalArg, s: Sample):
196
237
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
197
- res = eval_grad_outer_ref_space(args, s)
198
- return res * grad_transform
238
+ return eval_grad_outer_ref_space(args, s, grad_transform)
199
239
 
200
240
  return eval_grad_outer_world_space if world_space else eval_grad_outer_ref_space
201
241
 
@@ -206,25 +246,28 @@ class NodalFieldBase(DiscreteField):
206
246
  @cache.dynamic_func(suffix=self.name)
207
247
  def eval_div_outer(args: self.ElementEvalArg, s: Sample):
208
248
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
249
+ local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
209
250
 
210
- res = utils.generalized_inner(
251
+ res = self.space.space_divergence(
211
252
  self._read_node_value(args, s.element_index, 0),
212
253
  self.space.element_outer_weight_gradient(
213
254
  args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
214
- )
215
- * grad_transform,
255
+ ),
256
+ local_value_map,
257
+ grad_transform,
216
258
  )
217
259
 
218
260
  node_count = self.space.topology.element_node_count(
219
261
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
220
262
  )
221
263
  for k in range(1, node_count):
222
- res += utils.generalized_inner(
264
+ res += self.space.space_divergence(
223
265
  self._read_node_value(args, s.element_index, k),
224
266
  self.space.element_outer_weight_gradient(
225
267
  args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
226
- )
227
- * grad_transform,
268
+ ),
269
+ local_value_map,
270
+ grad_transform,
228
271
  )
229
272
  return res
230
273
 
@@ -232,8 +275,17 @@ class NodalFieldBase(DiscreteField):
232
275
 
233
276
  def _make_set_node_value(self):
234
277
  @cache.dynamic_func(suffix=self.name)
235
- def set_node_value(args: self.EvalArg, partition_node_index: int, value: self.space.dtype):
236
- args.dof_values[partition_node_index] = self.space.dof_mapper.value_to_dof(value)
278
+ def set_node_value(
279
+ elt_arg: self.space.ElementArg,
280
+ eval_arg: self.EvalArg,
281
+ element_index: ElementIndex,
282
+ node_index_in_element: int,
283
+ partition_node_index: int,
284
+ value: self.space.dtype,
285
+ ):
286
+ eval_arg.dof_values[partition_node_index] = self.space.node_dof_value(
287
+ elt_arg, eval_arg.space_arg, element_index, node_index_in_element, value
288
+ )
237
289
 
238
290
  return set_node_value
239
291