warp-lang 1.4.1__py3-none-macosx_10_13_universal2.whl → 1.5.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's release page for more details.

Files changed (164)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1920 -111
  8. warp/codegen.py +186 -62
  9. warp/config.py +2 -2
  10. warp/context.py +322 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/core/example_dem.py +2 -1
  17. warp/examples/core/example_mesh_intersect.py +3 -3
  18. warp/examples/fem/example_adaptive_grid.py +37 -10
  19. warp/examples/fem/example_apic_fluid.py +3 -2
  20. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  21. warp/examples/fem/example_deformed_geometry.py +1 -1
  22. warp/examples/fem/example_diffusion_3d.py +47 -4
  23. warp/examples/fem/example_distortion_energy.py +220 -0
  24. warp/examples/fem/example_magnetostatics.py +127 -85
  25. warp/examples/fem/example_nonconforming_contact.py +5 -5
  26. warp/examples/fem/example_stokes.py +3 -1
  27. warp/examples/fem/example_streamlines.py +12 -19
  28. warp/examples/fem/utils.py +38 -15
  29. warp/examples/optim/example_walker.py +2 -2
  30. warp/examples/sim/example_cloth.py +2 -25
  31. warp/examples/sim/example_jacobian_ik.py +6 -2
  32. warp/examples/sim/example_quadruped.py +2 -1
  33. warp/examples/tile/example_tile_convolution.py +58 -0
  34. warp/examples/tile/example_tile_fft.py +47 -0
  35. warp/examples/tile/example_tile_filtering.py +105 -0
  36. warp/examples/tile/example_tile_matmul.py +79 -0
  37. warp/examples/tile/example_tile_mlp.py +375 -0
  38. warp/fem/__init__.py +8 -0
  39. warp/fem/cache.py +16 -12
  40. warp/fem/dirichlet.py +1 -1
  41. warp/fem/domain.py +44 -1
  42. warp/fem/field/__init__.py +1 -2
  43. warp/fem/field/field.py +31 -19
  44. warp/fem/field/nodal_field.py +101 -49
  45. warp/fem/field/virtual.py +794 -0
  46. warp/fem/geometry/__init__.py +2 -2
  47. warp/fem/geometry/deformed_geometry.py +3 -105
  48. warp/fem/geometry/element.py +13 -0
  49. warp/fem/geometry/geometry.py +165 -5
  50. warp/fem/geometry/grid_2d.py +3 -6
  51. warp/fem/geometry/grid_3d.py +31 -28
  52. warp/fem/geometry/hexmesh.py +3 -46
  53. warp/fem/geometry/nanogrid.py +3 -2
  54. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  55. warp/fem/geometry/tetmesh.py +2 -43
  56. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  57. warp/fem/integrate.py +683 -261
  58. warp/fem/linalg.py +404 -0
  59. warp/fem/operator.py +101 -18
  60. warp/fem/polynomial.py +5 -5
  61. warp/fem/quadrature/quadrature.py +45 -21
  62. warp/fem/space/__init__.py +45 -11
  63. warp/fem/space/basis_function_space.py +451 -0
  64. warp/fem/space/basis_space.py +58 -11
  65. warp/fem/space/function_space.py +146 -5
  66. warp/fem/space/grid_2d_function_space.py +80 -66
  67. warp/fem/space/grid_3d_function_space.py +113 -68
  68. warp/fem/space/hexmesh_function_space.py +96 -108
  69. warp/fem/space/nanogrid_function_space.py +62 -110
  70. warp/fem/space/quadmesh_function_space.py +208 -0
  71. warp/fem/space/shape/__init__.py +45 -7
  72. warp/fem/space/shape/cube_shape_function.py +328 -54
  73. warp/fem/space/shape/shape_function.py +10 -1
  74. warp/fem/space/shape/square_shape_function.py +328 -60
  75. warp/fem/space/shape/tet_shape_function.py +269 -19
  76. warp/fem/space/shape/triangle_shape_function.py +238 -19
  77. warp/fem/space/tetmesh_function_space.py +69 -37
  78. warp/fem/space/topology.py +38 -0
  79. warp/fem/space/trimesh_function_space.py +179 -0
  80. warp/fem/utils.py +6 -331
  81. warp/jax_experimental.py +3 -1
  82. warp/native/array.h +55 -40
  83. warp/native/builtin.h +124 -43
  84. warp/native/bvh.h +4 -0
  85. warp/native/coloring.cpp +600 -0
  86. warp/native/cuda_util.cpp +14 -0
  87. warp/native/cuda_util.h +2 -1
  88. warp/native/fabric.h +8 -0
  89. warp/native/hashgrid.h +4 -0
  90. warp/native/marching.cu +8 -0
  91. warp/native/mat.h +14 -3
  92. warp/native/mathdx.cpp +59 -0
  93. warp/native/mesh.h +4 -0
  94. warp/native/range.h +13 -1
  95. warp/native/reduce.cpp +9 -1
  96. warp/native/reduce.cu +7 -0
  97. warp/native/runlength_encode.cpp +9 -1
  98. warp/native/runlength_encode.cu +7 -1
  99. warp/native/scan.cpp +8 -0
  100. warp/native/scan.cu +8 -0
  101. warp/native/scan.h +8 -1
  102. warp/native/sparse.cpp +8 -0
  103. warp/native/sparse.cu +8 -0
  104. warp/native/temp_buffer.h +7 -0
  105. warp/native/tile.h +1857 -0
  106. warp/native/tile_gemm.h +341 -0
  107. warp/native/tile_reduce.h +210 -0
  108. warp/native/volume_builder.cu +8 -0
  109. warp/native/volume_builder.h +8 -0
  110. warp/native/warp.cpp +10 -2
  111. warp/native/warp.cu +369 -15
  112. warp/native/warp.h +12 -2
  113. warp/optim/adam.py +39 -4
  114. warp/paddle.py +29 -12
  115. warp/render/render_opengl.py +137 -65
  116. warp/sim/graph_coloring.py +292 -0
  117. warp/sim/integrator_euler.py +4 -2
  118. warp/sim/integrator_featherstone.py +115 -44
  119. warp/sim/integrator_vbd.py +6 -0
  120. warp/sim/model.py +90 -17
  121. warp/stubs.py +651 -85
  122. warp/tape.py +12 -7
  123. warp/tests/assets/pixel.npy +0 -0
  124. warp/tests/aux_test_instancing_gc.py +18 -0
  125. warp/tests/test_array.py +207 -48
  126. warp/tests/test_closest_point_edge_edge.py +8 -8
  127. warp/tests/test_codegen.py +120 -1
  128. warp/tests/test_codegen_instancing.py +30 -0
  129. warp/tests/test_collision.py +110 -0
  130. warp/tests/test_coloring.py +241 -0
  131. warp/tests/test_context.py +34 -0
  132. warp/tests/test_examples.py +18 -4
  133. warp/tests/test_fabricarray.py +33 -0
  134. warp/tests/test_fem.py +453 -113
  135. warp/tests/test_func.py +48 -1
  136. warp/tests/test_generics.py +52 -0
  137. warp/tests/test_iter.py +68 -0
  138. warp/tests/test_mat_scalar_ops.py +1 -1
  139. warp/tests/test_mesh_query_point.py +5 -4
  140. warp/tests/test_module_hashing.py +23 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +191 -1
  143. warp/tests/test_spatial.py +1 -1
  144. warp/tests/test_tile.py +700 -0
  145. warp/tests/test_tile_mathdx.py +144 -0
  146. warp/tests/test_tile_mlp.py +383 -0
  147. warp/tests/test_tile_reduce.py +374 -0
  148. warp/tests/test_tile_shared_memory.py +190 -0
  149. warp/tests/test_vbd.py +12 -20
  150. warp/tests/test_volume.py +43 -0
  151. warp/tests/unittest_suites.py +23 -2
  152. warp/tests/unittest_utils.py +4 -0
  153. warp/types.py +339 -73
  154. warp/utils.py +22 -1
  155. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  156. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
  157. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  158. warp/fem/field/test.py +0 -180
  159. warp/fem/field/trial.py +0 -183
  160. warp/fem/space/collocated_function_space.py +0 -102
  161. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  162. warp/fem/space/trimesh_2d_function_space.py +0 -153
  163. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  164. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,794 @@
1
+ from typing import Any, Set
2
+
3
+ import warp as wp
4
+ import warp.fem.operator as operator
5
+ from warp.fem import cache
6
+ from warp.fem.domain import GeometryDomain
7
+ from warp.fem.linalg import basis_coefficient, generalized_inner, generalized_outer
8
+ from warp.fem.quadrature import Quadrature
9
+ from warp.fem.space import FunctionSpace, SpacePartition, SpaceRestriction
10
+ from warp.fem.types import NULL_NODE_INDEX, DofIndex, Sample, get_node_coord, get_node_index_in_element
11
+
12
+ from .field import SpaceField
13
+
14
+
15
class AdjointField(SpaceField):
    """Adjoint of a discrete field with respect to its degrees of freedom.

    Each generated evaluation function returns the contribution of the single degree of
    freedom selected from the evaluation ``Sample`` by ``self._get_dof``. This warp function is
    not defined here; subclasses provide it (``TestField`` reads ``s.test_dof``, ``TrialField``
    reads ``s.trial_dof``).
    """

    def __init__(self, space: FunctionSpace, space_partition: SpacePartition):
        """
        Args:
            space: function space whose basis functions are evaluated
            space_partition: partition of the space's nodes, forwarded to ``SpaceField``
        """
        super().__init__(space, space_partition=space_partition)

        self.node_dof_count = self.space.NODE_DOF_COUNT
        self.value_dof_count = self.space.VALUE_DOF_COUNT

        # The field carries no data of its own: its evaluation arguments are the space's arguments
        self.EvalArg = self.space.SpaceArg
        self.ElementEvalArg = self._make_element_eval_arg()

        self.eval_arg_value = self.space.space_arg_value

        # Generate the warp evaluation functions; the gradient/divergence variants are None
        # when the space reports them as invalid (see the _make_* helpers below)
        self.eval_degree = self._make_eval_degree()
        self.eval_inner = self._make_eval_inner()
        self.eval_grad_inner = self._make_eval_grad_inner()
        self.eval_div_inner = self._make_eval_div_inner()
        self.eval_outer = self._make_eval_outer()
        self.eval_grad_outer = self._make_eval_grad_outer()
        self.eval_div_outer = self._make_eval_div_outer()
        self.at_node = self._make_at_node()

    @property
    def name(self) -> str:
        """Unique name (class, space and partition names); used as suffix for generated warp code."""
        return f"{self.__class__.__name__}{self.space.name}{self._space_partition.name}"

    def _make_element_eval_arg(self):
        """Build the struct bundling the topology element argument with the field evaluation argument."""
        from warp.fem import cache

        @cache.dynamic_struct(suffix=self.name)
        class ElementEvalArg:
            elt_arg: self.space.topology.ElementArg
            eval_arg: self.EvalArg

        return ElementEvalArg

    def _make_eval_inner(self):
        """Generate the warp function evaluating the selected basis function with inner-side weights and value map."""

        @cache.dynamic_func(suffix=self.name)
        def eval_test_inner(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            node_weight = self.space.element_inner_weight(
                args.elt_arg, args.eval_arg, s.element_index, s.element_coords, get_node_index_in_element(dof)
            )
            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            # Unit value for the dof's coordinate within its node
            dof_value = self.space.node_basis_element(get_node_coord(dof))
            return self.space.space_value(dof_value, node_weight, local_value_map)

        return eval_test_inner

    def _make_eval_grad_inner(self):
        """Generate the warp function for the gradient of the selected basis function (inner side).

        Returns None if the space reports gradients as invalid.
        """
        if not self.space.gradient_valid():
            return None

        @cache.dynamic_func(suffix=self.name)
        def eval_grad_inner(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            nabla_weight = self.space.element_inner_weight_gradient(
                args.elt_arg,
                args.eval_arg,
                s.element_index,
                s.element_coords,
                get_node_index_in_element(dof),
            )
            # Maps reference-element gradients to the evaluation frame
            grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.node_basis_element(get_node_coord(dof))
            return self.space.space_gradient(dof_value, nabla_weight, local_value_map, grad_transform)

        return eval_grad_inner

    def _make_eval_div_inner(self):
        """Generate the warp function for the divergence of the selected basis function (inner side).

        Returns None if the space reports divergence as invalid.
        """
        if not self.space.divergence_valid():
            return None

        @cache.dynamic_func(suffix=self.name)
        def eval_div_inner(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            nabla_weight = self.space.element_inner_weight_gradient(
                args.elt_arg,
                args.eval_arg,
                s.element_index,
                s.element_coords,
                get_node_index_in_element(dof),
            )
            grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.node_basis_element(get_node_coord(dof))
            return self.space.space_divergence(dof_value, nabla_weight, local_value_map, grad_transform)

        return eval_div_inner

    def _make_eval_outer(self):
        """Generate the warp function evaluating the selected basis function with outer-side weights and value map."""

        @cache.dynamic_func(suffix=self.name)
        def eval_test_outer(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            node_weight = self.space.element_outer_weight(
                args.elt_arg, args.eval_arg, s.element_index, s.element_coords, get_node_index_in_element(dof)
            )
            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.node_basis_element(get_node_coord(dof))
            return self.space.space_value(dof_value, node_weight, local_value_map)

        return eval_test_outer

    def _make_eval_grad_outer(self):
        """Generate the warp function for the gradient of the selected basis function (outer side).

        Returns None if the space reports gradients as invalid.
        """
        if not self.space.gradient_valid():
            return None

        @cache.dynamic_func(suffix=self.name)
        def eval_grad_outer(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            nabla_weight = self.space.element_outer_weight_gradient(
                args.elt_arg,
                args.eval_arg,
                s.element_index,
                s.element_coords,
                get_node_index_in_element(dof),
            )
            grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.node_basis_element(get_node_coord(dof))
            return self.space.space_gradient(dof_value, nabla_weight, local_value_map, grad_transform)

        return eval_grad_outer

    def _make_eval_div_outer(self):
        """Generate the warp function for the divergence of the selected basis function (outer side).

        Returns None if the space reports divergence as invalid.
        """
        if not self.space.divergence_valid():
            return None

        @cache.dynamic_func(suffix=self.name)
        def eval_div_outer(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            nabla_weight = self.space.element_outer_weight_gradient(
                args.elt_arg,
                args.eval_arg,
                s.element_index,
                s.element_coords,
                get_node_index_in_element(dof),
            )
            grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.node_basis_element(get_node_coord(dof))
            return self.space.space_divergence(dof_value, nabla_weight, local_value_map, grad_transform)

        return eval_div_outer

    def _make_at_node(self):
        """Generate the warp function returning the sample moved to the element coordinates of the selected dof's node.

        All other sample fields (element, quadrature point, test/trial dofs) are copied unchanged.
        """

        @cache.dynamic_func(suffix=self.name)
        def at_node(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            node_coords = self.space.node_coords_in_element(
                args.elt_arg, args.eval_arg, s.element_index, get_node_index_in_element(dof)
            )
            return Sample(s.element_index, node_coords, s.qp_index, s.qp_weight, s.test_dof, s.trial_dof)

        return at_node
172
+
173
+
174
class TestField(AdjointField):
    """Adjoint field defined over a space restriction, usable as a test function.

    To reuse computations, the space restriction may have been created for a value type
    different from the test function's value type, provided the node topology is similar.
    """

    def __init__(self, space_restriction: SpaceRestriction, space: FunctionSpace):
        domain = space_restriction.domain

        # Codimension-one domain: evaluate on the trace of the space instead
        if space.dimension - 1 == domain.dimension:
            space = space.trace()

        if space.dimension != domain.dimension:
            raise ValueError("Incompatible space and domain dimensions")

        topologies_match = space.topology == space_restriction.space_topology
        if not topologies_match:
            raise ValueError("Incompatible space and space partition topologies")

        super().__init__(space, space_restriction.space_partition)

        self.space_restriction = space_restriction
        self.domain = domain

    # Test fields are evaluated at the sample's test degree of freedom
    @wp.func
    def _get_dof(s: Sample):
        return s.test_dof
199
+
200
+
201
class TrialField(AdjointField):
    """Adjoint field defined over a geometry domain, usable as a trial function."""

    def __init__(
        self,
        space: FunctionSpace,
        space_partition: SpacePartition,
        domain: GeometryDomain,
    ):
        # Codimension-one domain: evaluate on the trace of the space instead
        if space.dimension - 1 == domain.dimension:
            space = space.trace()

        if space.dimension != domain.dimension:
            raise ValueError("Incompatible space and domain dimensions")

        partition_topology = space_partition.space_topology
        if not space.topology.is_derived_from(partition_topology):
            raise ValueError("Incompatible space and space partition topologies")

        super().__init__(space, space_partition)
        self.domain = domain

    def partition_node_count(self) -> int:
        """Number of nodes in the space topology partition this trial field was built with."""
        partition = self.space_partition
        return partition.node_count()

    # Trial fields are evaluated at the sample's trial degree of freedom
    @wp.func
    def _get_dof(s: Sample):
        return s.trial_dof
229
+
230
+
231
class LocalAdjointField(SpaceField):
    """
    A custom field specially for dispatched assembly.
    Stores adjoint and gradient adjoint at quadrature point locations.

    Each value degree of freedom is expanded into ``TAYLOR_DOF_COUNT`` "Taylor" slots, laid out as
    consecutive ranges, one per dof type (inner value, outer value, inner gradient, outer gradient).
    Only the dof types required by the operators reported to :meth:`notify_operator_usage` get a
    non-empty range. The evaluation functions (``eval_inner``, ``eval_grad_inner``, ...) are
    created by that method.
    """

    # Indices of the Taylor dof types
    INNER_DOF = wp.constant(0)
    OUTER_DOF = wp.constant(1)
    INNER_GRAD_DOF = wp.constant(2)
    OUTER_GRAD_DOF = wp.constant(3)
    DOF_TYPE_COUNT = wp.constant(4)

    # Continuous spaces: outer-side operators reuse the inner-side slots
    _OP_DOF_MAP_CONTINUOUS = {
        operator.inner: INNER_DOF,
        operator.outer: INNER_DOF,
        operator.grad: INNER_GRAD_DOF,
        operator.grad_outer: INNER_GRAD_DOF,
        operator.div: INNER_GRAD_DOF,
        operator.div_outer: INNER_GRAD_DOF,
    }

    # Discontinuous spaces: inner and outer sides get separate slots
    _OP_DOF_MAP_DISCONTINUOUS = {
        operator.inner: INNER_DOF,
        operator.outer: OUTER_DOF,
        operator.grad: INNER_GRAD_DOF,
        operator.grad_outer: OUTER_GRAD_DOF,
        operator.div: INNER_GRAD_DOF,
        operator.div_outer: OUTER_GRAD_DOF,
    }

    # Per-dof-type integer vector, used both for range offsets and range lengths
    DofOffsets = wp.vec(length=DOF_TYPE_COUNT, dtype=int)

    @wp.struct
    class EvalArg:
        pass

    def __init__(self, field: AdjointField):
        """
        Args:
            field: global adjoint (test or trial) field this local field stands in for
        """
        super().__init__(field.space, space_partition=field.space_partition)
        self.global_field = field

        self.domain = self.global_field.domain
        self.node_dof_count = self.space.NODE_DOF_COUNT
        self.value_dof_count = self.space.VALUE_DOF_COUNT

        # Empty until notify_operator_usage() defines the Taylor dof layout
        self._dof_suffix = ""

        self.ElementEvalArg = self._make_element_eval_arg()
        self.eval_degree = self._make_eval_degree()
        self.at_node = None

        # Inner and outer sides need separate slots when their weight functions differ
        self._is_discontinuous = (self.space.element_inner_weight != self.space.element_outer_weight) or (
            self.space.element_inner_weight_gradient != self.space.element_outer_weight_gradient
        )

        self._TAYLOR_DOF_OFFSETS = LocalAdjointField.DofOffsets(0)
        self._TAYLOR_DOF_COUNTS = LocalAdjointField.DofOffsets(0)
        self.TAYLOR_DOF_COUNT = 0

    def notify_operator_usage(self, ops: Set[operator.Operator]):
        """Rebuild the Taylor dof layout from the set of operators used, then regenerate the evaluation functions.

        Args:
            ops: operators applied to this field; operators absent from the dof maps are ignored
        """
        operators_dof_map = (
            LocalAdjointField._OP_DOF_MAP_DISCONTINUOUS
            if self._is_discontinuous
            else LocalAdjointField._OP_DOF_MAP_CONTINUOUS
        )

        # One slot for each dof type required by a used operator
        dof_counts = LocalAdjointField.DofOffsets(0)
        for op in ops:
            if op in operators_dof_map:
                dof_counts[operators_dof_map[op]] = 1

        # Gradient dof types need one slot per cell dimension
        grad_dim = self.geometry.cell_dimension
        dof_counts[LocalAdjointField.INNER_GRAD_DOF] *= grad_dim
        dof_counts[LocalAdjointField.OUTER_GRAD_DOF] *= grad_dim

        # Exclusive prefix sum of the counts gives the start of each dof type range
        dof_offsets = LocalAdjointField.DofOffsets(0)
        for k in range(1, LocalAdjointField.DOF_TYPE_COUNT):
            dof_offsets[k] = dof_offsets[k - 1] + dof_counts[k - 1]

        # Total slot count is the end of the last range; index it explicitly rather than
        # relying on the loop variable leaking out of the loop above
        last = LocalAdjointField.DOF_TYPE_COUNT - 1
        self.TAYLOR_DOF_COUNT = wp.constant(dof_offsets[last] + dof_counts[last])

        self._TAYLOR_DOF_OFFSETS = dof_offsets
        self._TAYLOR_DOF_COUNTS = dof_counts

        # Layout-dependent name suffix, so differently laid out fields generate distinct warp code
        self._dof_suffix = "".join(str(c) for c in dof_counts)

        self._split_dof = self._make_split_dof()

        self.eval_inner = self._make_eval_inner()
        self.eval_grad_inner = self._make_eval_grad_inner()
        self.eval_div_inner = self._make_eval_div_inner()

        if self._is_discontinuous:
            self.eval_outer = self._make_eval_outer()
            self.eval_grad_outer = self._make_eval_grad_outer()
            self.eval_div_outer = self._make_eval_div_outer()
        else:
            self.eval_outer = self.eval_inner
            self.eval_grad_outer = self.eval_grad_inner
            self.eval_div_outer = self.eval_div_inner

    @property
    def name(self) -> str:
        return f"{self.global_field.name}_Taylor{self._dof_suffix}"

    def eval_arg_value(self, device):
        """The local field needs no evaluation data; returns an empty argument struct."""
        return LocalAdjointField.EvalArg()

    def _make_element_eval_arg(self):
        """Build the struct bundling the topology element argument with the (empty) evaluation argument."""

        @cache.dynamic_struct(suffix=self.name)
        class ElementEvalArg:
            elt_arg: self.space.topology.ElementArg
            eval_arg: self.EvalArg

        return ElementEvalArg

    def _make_split_dof(self):
        """Generate the warp function splitting a dof index into (value dof, Taylor slot relative to ``dof_begin``).

        The node coordinate of the dof index is interpreted as ``value_dof * TAYLOR_DOF_COUNT + slot``.
        """
        TAYLOR_DOF_COUNT = self.TAYLOR_DOF_COUNT

        @cache.dynamic_func(suffix=str(TAYLOR_DOF_COUNT))
        def split_dof(dof_index: DofIndex, dof_begin: int):
            dof = get_node_coord(dof_index)
            value_dof = dof // TAYLOR_DOF_COUNT
            taylor_dof = dof - value_dof * TAYLOR_DOF_COUNT - dof_begin
            return value_dof, taylor_dof

        return split_dof

    def _make_eval_inner(self):
        """Generate the value evaluator: the value basis element for the inner-value slot, zero otherwise."""
        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF])

        @cache.dynamic_func(suffix=self.name)
        def eval_test_inner(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            # wp.select(cond, if_false, if_true)
            return wp.select(taylor_dof == 0, self.dtype(0.0), dof_value)

        return eval_test_inner

    def _make_eval_grad_inner(self):
        """Generate the gradient evaluator for the inner-gradient slots; returns None if gradients are invalid."""
        if not self.gradient_valid():
            return None

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF])
        DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF])

        @cache.dynamic_func(suffix=self.name)
        def eval_nabla_test_inner(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
                return self.gradient_dtype(0.0)

            grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return generalized_outer(dof_value, grad_transform[taylor_dof])

        return eval_nabla_test_inner

    def _make_eval_div_inner(self):
        """Generate the divergence evaluator for the inner-gradient slots; returns None if divergence is invalid."""
        if not self.divergence_valid():
            return None

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF])
        DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF])

        @cache.dynamic_func(suffix=self.name)
        def eval_div_test_inner(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
                return self.divergence_dtype(0.0)

            grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return generalized_inner(dof_value, grad_transform[taylor_dof])

        return eval_div_test_inner

    def _make_eval_outer(self):
        """Generate the value evaluator for the outer-value slot (discontinuous spaces only)."""
        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF])

        @cache.dynamic_func(suffix=self.name)
        def eval_test_outer(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return wp.select(taylor_dof == 0, self.dtype(0.0), dof_value)

        return eval_test_outer

    def _make_eval_grad_outer(self):
        """Generate the gradient evaluator for the outer-gradient slots; returns None if gradients are invalid."""
        if not self.gradient_valid():
            return None

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF])
        DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF])

        @cache.dynamic_func(suffix=self.name)
        def eval_nabla_test_outer(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
                return self.gradient_dtype(0.0)

            grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return generalized_outer(dof_value, grad_transform[taylor_dof])

        return eval_nabla_test_outer

    def _make_eval_div_outer(self):
        """Generate the divergence evaluator for the outer-gradient slots; returns None if divergence is invalid."""
        if not self.divergence_valid():
            return None

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF])
        DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF])

        @cache.dynamic_func(suffix=self.name)
        def eval_div_test_outer(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
                return self.divergence_dtype(0.0)

            grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return generalized_inner(dof_value, grad_transform[taylor_dof])

        return eval_div_test_outer
474
+
475
+
476
class LocalTestField(LocalAdjointField):
    """Quadrature-point-local counterpart of a :class:`TestField`, used for dispatched assembly."""

    def __init__(self, test_field: TestField):
        restriction = test_field.space_restriction
        super().__init__(field=test_field)
        # Expose the global test field's space restriction on the local field
        self.space_restriction = restriction

    # Local test fields are evaluated at the sample's test degree of freedom
    @wp.func
    def _get_dof(s: Sample):
        return s.test_dof
484
+
485
+
486
class LocalTrialField(LocalAdjointField):
    """Quadrature-point-local counterpart of a :class:`TrialField`, used for dispatched assembly."""

    def __init__(self, trial_field: TrialField):
        super().__init__(field=trial_field)

    # Local trial fields are evaluated at the sample's trial degree of freedom
    @wp.func
    def _get_dof(s: Sample):
        return s.trial_dof
493
+
494
+
495
def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, accumulate_dtype: type):
    """Build the warp kernel scattering per-quadrature-point local results onto test nodes (linear forms).

    The kernel reads ``local_result`` indexed as ``[qp_index, taylor_dof, value_dof]``. It contracts
    these entries with the test basis weights (and weight gradients) of each node for every Taylor
    dof type that is in use. The sum is accumulated into ``result[node_index, node_dof]``.
    NOTE(review): ``local_result`` is presumably produced by evaluating the integrand with ``test``;
    confirm against the caller.

    Args:
        test: local test field whose Taylor dof layout has been set by ``notify_operator_usage``
        quadrature: quadrature whose points index ``local_result``
        accumulate_dtype: scalar type used to accumulate a node's contributions before writing to ``result``

    Returns:
        The generated warp kernel.
    """
    global_test: TestField = test.global_field
    space_restriction = global_test.space_restriction
    domain = global_test.domain

    # Layout of the Taylor dof ranges; captured as compile-time constants by the kernel
    TEST_INNER_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_DOF]
    TEST_OUTER_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_DOF]
    TEST_INNER_GRAD_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF]
    TEST_OUTER_GRAD_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF]

    TEST_INNER_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF]
    TEST_OUTER_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF]
    TEST_INNER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF]
    TEST_OUTER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF]

    # Number of value dofs per node dof
    TEST_NODE_DOF_DIM = test.value_dof_count // test.node_dof_count

    @cache.dynamic_kernel(f"{test.name}_{quadrature.name}_{wp.types.get_type_code(accumulate_dtype)}")
    def dispatch_linear_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: space_restriction.NodeArg,
        test_space_arg: test.space.SpaceArg,
        local_result: wp.array3d(dtype=Any),
        result: wp.array2d(dtype=Any),
    ):
        # One thread per (restriction-local node, node dof) pair
        local_node_index, test_node_dof = wp.tid()
        node_index = space_restriction.node_partition_index(test_arg, local_node_index)
        element_beg, element_end = space_restriction.node_element_range(test_arg, node_index)

        val_sum = accumulate_dtype(0.0)

        # Loop over the elements touching this node, then over their quadrature points
        for n in range(element_beg, element_end):
            test_element_index = space_restriction.node_element_index(test_arg, n)
            element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)

            qp_point_count = quadrature.point_count(
                domain_arg, qp_arg, test_element_index.domain_element_index, element_index
            )
            for k in range(qp_point_count):
                qp_index = quadrature.point_index(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )
                coords = quadrature.point_coords(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )

                qp_result = local_result[qp_index]

                # NOTE(review): no quadrature weight is applied here; presumably already folded into local_result
                qp_sum = float(0.0)

                # Each branch is resolved at compile time and only emitted for dof types in use
                if wp.static(0 != TEST_INNER_COUNT):
                    w = test.space.element_inner_weight(
                        domain_arg, test_space_arg, element_index, coords, test_element_index.node_index_in_element
                    )
                    for val_dof in range(TEST_NODE_DOF_DIM):
                        test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
                        qp_sum += basis_coefficient(w, val_dof) * qp_result[TEST_INNER_BEGIN, test_dof]

                if wp.static(0 != TEST_OUTER_COUNT):
                    w = test.space.element_outer_weight(
                        domain_arg, test_space_arg, element_index, coords, test_element_index.node_index_in_element
                    )
                    for val_dof in range(TEST_NODE_DOF_DIM):
                        test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
                        qp_sum += basis_coefficient(w, val_dof) * qp_result[TEST_OUTER_BEGIN, test_dof]

                if wp.static(0 != TEST_INNER_GRAD_COUNT):
                    w_grad = test.space.element_inner_weight_gradient(
                        domain_arg, test_space_arg, element_index, coords, test_element_index.node_index_in_element
                    )
                    for val_dof in range(TEST_NODE_DOF_DIM):
                        test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
                        # One Taylor slot per gradient component
                        for grad_dof in range(TEST_INNER_GRAD_COUNT):
                            qp_sum += (
                                basis_coefficient(w_grad, val_dof, grad_dof)
                                * qp_result[grad_dof + TEST_INNER_GRAD_BEGIN, test_dof]
                            )

                if wp.static(0 != TEST_OUTER_GRAD_COUNT):
                    w_grad = test.space.element_outer_weight_gradient(
                        domain_arg, test_space_arg, element_index, coords, test_element_index.node_index_in_element
                    )
                    for val_dof in range(TEST_NODE_DOF_DIM):
                        test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
                        for grad_dof in range(TEST_OUTER_GRAD_COUNT):
                            qp_sum += (
                                basis_coefficient(w_grad, val_dof, grad_dof)
                                * qp_result[grad_dof + TEST_OUTER_GRAD_BEGIN, test_dof]
                            )

                val_sum += accumulate_dtype(qp_sum)

        # Add (not overwrite) the node's accumulated contribution, converted to the result dtype
        result[node_index, test_node_dof] += result.dtype(val_sum)

    return dispatch_linear_kernel_fn
592
+
593
+
594
def make_bilinear_dispatch_kernel(
    test: LocalTestField, trial: LocalTrialField, quadrature: Quadrature, accumulate_dtype: type
):
    """Build the kernel that writes a bilinear form's per-element blocks as (row, col, values) triplets.

    The returned kernel reads per-quadrature-point data from ``local_result``. Presumably these
    are integrand values computed by an earlier pass; confirm with the caller. It contracts that
    data with the test and trial basis weights (and weight gradients) of each element adjacent to
    a test node and sums over the element's quadrature points. It then writes one value block,
    plus its row/column node indices, per (element, trial node) pair.

    Args:
        test: Local test field. Provides the test space, its node restriction, and the
            Taylor-dof counts/offsets of the INNER/OUTER value and gradient terms.
        trial: Local trial field. Plays the same role for the trial space and partition.
        quadrature: Quadrature used to enumerate the points of each element.
        accumulate_dtype: Scalar type used to sum contributions over quadrature points
            before the result is cast to ``triplet_values.dtype``.

    Returns:
        The kernel produced by ``cache.dynamic_kernel``. It is keyed on the trial, test and
        quadrature names and the accumulation type code.
    """
    global_test: TestField = test.global_field
    space_restriction = global_test.space_restriction
    domain = global_test.domain

    # Count and offset of each term (INNER/OUTER value, INNER/OUTER gradient) within each
    # field's Taylor-dof layout. A count of 0 makes the matching `wp.static` guard below
    # false, so that term's code is omitted from the generated kernel.
    TEST_INNER_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_DOF]
    TEST_OUTER_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_DOF]
    TEST_INNER_GRAD_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF]
    TEST_OUTER_GRAD_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF]

    TEST_INNER_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF]
    TEST_OUTER_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF]
    TEST_INNER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF]
    TEST_OUTER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF]

    TRIAL_INNER_COUNT = trial._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_DOF]
    TRIAL_OUTER_COUNT = trial._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_DOF]
    TRIAL_INNER_GRAD_COUNT = trial._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF]
    TRIAL_OUTER_GRAD_COUNT = trial._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF]

    TRIAL_INNER_BEGIN = trial._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF]
    TRIAL_OUTER_BEGIN = trial._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF]
    TRIAL_INNER_GRAD_BEGIN = trial._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF]
    TRIAL_OUTER_GRAD_BEGIN = trial._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF]

    # Number of scalar value dofs per node dof. A flat dof index is
    # `node_dof * NODE_DOF_DIM + val_dof` (see the kernel body).
    TEST_NODE_DOF_DIM = test.value_dof_count // test.node_dof_count
    TRIAL_NODE_DOF_DIM = trial.value_dof_count // trial.node_dof_count

    # Stride between consecutive elements in the triplet arrays (see `block_offset`).
    MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT

    # Vector with one component per trial Taylor dof; this is the dtype of `local_result`.
    trial_dof_vec = cache.cached_vec_type(length=trial.TAYLOR_DOF_COUNT, dtype=float)

    @cache.dynamic_kernel(f"{trial.name}_{test.name}_{quadrature.name}{wp.types.get_type_code(accumulate_dtype)}")
    def dispatch_bilinear_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        test_space_arg: test.space.SpaceArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        trial_space_arg: trial.space.SpaceArg,
        local_result: wp.array4d(dtype=trial_dof_vec),
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=Any),
    ):
        # 4D launch: one thread per (restricted test node, test node dof, trial node dof,
        # trial node slot within an element). The last launch dimension is presumably
        # MAX_NODES_PER_ELEMENT; confirm against the launch site.
        test_local_node_index, test_node_dof, trial_node_dof, trial_node = wp.tid()

        test_node_index = space_restriction.node_partition_index(test_arg, test_local_node_index)
        element_beg, element_end = space_restriction.node_element_range(test_arg, test_node_index)

        # Visit every element entry attached to this test node in the space restriction.
        for element in range(element_beg, element_end):
            test_element_index = space_restriction.node_element_index(test_arg, element)
            element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
            test_node = test_element_index.node_index_in_element

            element_trial_node_count = trial.space.topology.element_node_count(
                domain_arg, trial_topology_arg, element_index
            )

            # Trial node slots at or beyond this element's node count are padding. Give them
            # zero quadrature points so that their value block stays zero.
            qp_point_count = wp.select(
                trial_node < element_trial_node_count,
                0,
                quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
            )

            val_sum = accumulate_dtype(0.0)

            for k in range(qp_point_count):
                qp_index = quadrature.point_index(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )
                coords = quadrature.point_coords(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )

                # qp_result[test_dof, trial_dof, test_taylor_offset] is a trial_dof_vec.
                qp_result = local_result[qp_index]
                trial_result = float(0.0)

                # Evaluate only the basis weights/gradients used by the form. Each w_* variable
                # is defined, and referenced below, only under the same compile-time guard.
                if wp.static(0 != TEST_INNER_COUNT):
                    w_test_inner = test.space.element_inner_weight(
                        domain_arg, test_space_arg, element_index, coords, test_node
                    )

                if wp.static(0 != TEST_OUTER_COUNT):
                    w_test_outer = test.space.element_outer_weight(
                        domain_arg, test_space_arg, element_index, coords, test_node
                    )

                if wp.static(0 != TEST_INNER_GRAD_COUNT):
                    w_test_grad_inner = test.space.element_inner_weight_gradient(
                        domain_arg, test_space_arg, element_index, coords, test_node
                    )

                if wp.static(0 != TEST_OUTER_GRAD_COUNT):
                    w_test_grad_outer = test.space.element_outer_weight_gradient(
                        domain_arg, test_space_arg, element_index, coords, test_node
                    )

                if wp.static(0 != TRIAL_INNER_COUNT):
                    w_trial_inner = trial.space.element_inner_weight(
                        domain_arg, trial_space_arg, element_index, coords, trial_node
                    )

                if wp.static(0 != TRIAL_OUTER_COUNT):
                    w_trial_outer = trial.space.element_outer_weight(
                        domain_arg, trial_space_arg, element_index, coords, trial_node
                    )

                if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
                    w_trial_grad_inner = trial.space.element_inner_weight_gradient(
                        domain_arg, trial_space_arg, element_index, coords, trial_node
                    )

                if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
                    w_trial_grad_outer = trial.space.element_outer_weight_gradient(
                        domain_arg, trial_space_arg, element_index, coords, trial_node
                    )

                for trial_val_dof in range(TRIAL_NODE_DOF_DIM):
                    trial_dof = trial_node_dof * TRIAL_NODE_DOF_DIM + trial_val_dof

                    # Step 1: contract with the test basis. Component j of test_result is the
                    # coefficient of trial Taylor dof j; it is indexed with TRIAL_*_BEGIN below.
                    test_result = trial_dof_vec(0.0)

                    if wp.static(0 != TEST_INNER_COUNT):
                        for test_val_dof in range(TEST_NODE_DOF_DIM):
                            test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
                            test_result += (
                                basis_coefficient(w_test_inner, test_val_dof)
                                * qp_result[test_dof, trial_dof, TEST_INNER_BEGIN]
                            )

                    if wp.static(0 != TEST_OUTER_COUNT):
                        for test_val_dof in range(TEST_NODE_DOF_DIM):
                            test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
                            test_result += (
                                basis_coefficient(w_test_outer, test_val_dof)
                                * qp_result[test_dof, trial_dof, TEST_OUTER_BEGIN]
                            )

                    if wp.static(0 != TEST_INNER_GRAD_COUNT):
                        for test_val_dof in range(TEST_NODE_DOF_DIM):
                            test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
                            for grad_dof in range(TEST_INNER_GRAD_COUNT):
                                test_result += (
                                    basis_coefficient(w_test_grad_inner, test_val_dof, grad_dof)
                                    * qp_result[test_dof, trial_dof, grad_dof + TEST_INNER_GRAD_BEGIN]
                                )

                    if wp.static(0 != TEST_OUTER_GRAD_COUNT):
                        for test_val_dof in range(TEST_NODE_DOF_DIM):
                            test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
                            for grad_dof in range(TEST_OUTER_GRAD_COUNT):
                                test_result += (
                                    basis_coefficient(w_test_grad_outer, test_val_dof, grad_dof)
                                    * qp_result[test_dof, trial_dof, grad_dof + TEST_OUTER_GRAD_BEGIN]
                                )

                    # Step 2: contract the test-contracted vector with the trial basis.
                    if wp.static(0 != TRIAL_INNER_COUNT):
                        trial_result += basis_coefficient(w_trial_inner, trial_val_dof) * test_result[TRIAL_INNER_BEGIN]

                    if wp.static(0 != TRIAL_OUTER_COUNT):
                        trial_result += basis_coefficient(w_trial_outer, trial_val_dof) * test_result[TRIAL_OUTER_BEGIN]

                    if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
                        for grad_dof in range(TRIAL_INNER_GRAD_COUNT):
                            trial_result += (
                                basis_coefficient(w_trial_grad_inner, trial_val_dof, grad_dof)
                                * test_result[grad_dof + TRIAL_INNER_GRAD_BEGIN]
                            )

                    if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
                        for grad_dof in range(TRIAL_OUTER_GRAD_COUNT):
                            trial_result += (
                                basis_coefficient(w_trial_grad_outer, trial_val_dof, grad_dof)
                                * test_result[grad_dof + TRIAL_OUTER_GRAD_BEGIN]
                            )

                # trial_result is reset for each quadrature point, so this adds exactly one
                # quadrature point's contribution.
                val_sum += accumulate_dtype(trial_result)

            # One value block per (entry in the node's element range, trial node slot).
            # `element` is the restriction-local entry index, not `element_index`.
            block_offset = element * MAX_NODES_PER_ELEMENT + trial_node
            triplet_values[block_offset, test_node_dof, trial_node_dof] = triplet_values.dtype(val_sum)

            # Set row and column indices.
            # Only the (test_node_dof == 0, trial_node_dof == 0) thread of each block writes them,
            # since the indices are shared by every dof pair of the block.
            if test_node_dof == 0 and trial_node_dof == 0:
                if trial_node < element_trial_node_count:
                    trial_node_index = trial.space_partition.partition_node_index(
                        trial_partition_arg,
                        trial.space.topology.element_node_index(
                            domain_arg, trial_topology_arg, element_index, trial_node
                        ),
                    )
                else:
                    trial_node_index = NULL_NODE_INDEX  # will get ignored when converting to bsr

                triplet_rows[block_offset] = test_node_index
                triplet_cols[block_offset] = trial_node_index

    return dispatch_bilinear_kernel_fn