warp-lang 1.7.2__py3-none-macosx_10_13_universal2.whl → 1.8.0__py3-none-macosx_10_13_universal2.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp-clang.dylib +0 -0
  5. warp/bin/libwarp.dylib +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/fem/field/nodal_field.py CHANGED
@@ -13,12 +13,13 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Any
16
+ from typing import Any, ClassVar
17
17
 
18
18
  import warp as wp
19
19
  from warp.fem import cache
20
20
  from warp.fem.space import CollocatedFunctionSpace, SpacePartition
21
21
  from warp.fem.types import NULL_NODE_INDEX, ElementIndex, Sample
22
+ from warp.fem.utils import type_zero_element
22
23
 
23
24
  from .field import DiscreteField
24
25
 
@@ -26,26 +27,29 @@ from .field import DiscreteField
26
27
  class NodalFieldBase(DiscreteField):
27
28
  """Base class for nodal field and nodal field traces. Does not hold values"""
28
29
 
30
+ _dynamic_attribute_constructors: ClassVar = {
31
+ "EvalArg": lambda obj: obj._make_eval_arg(),
32
+ "ElementEvalArg": lambda obj: obj._make_element_eval_arg(),
33
+ "eval_degree": lambda obj: DiscreteField._make_eval_degree(obj),
34
+ "_read_node_value": lambda obj: obj._make_read_node_value(),
35
+ "eval_inner": lambda obj: obj._make_eval_inner(),
36
+ "eval_outer": lambda obj: obj._make_eval_outer(),
37
+ "eval_grad_inner": lambda obj: obj._make_eval_grad_inner(world_space=True),
38
+ "eval_grad_outer": lambda obj: obj._make_eval_grad_outer(world_space=True),
39
+ "eval_reference_grad_inner": lambda obj: obj._make_eval_grad_inner(world_space=False),
40
+ "eval_reference_grad_outer": lambda obj: obj._make_eval_grad_outer(world_space=False),
41
+ "eval_div_inner": lambda obj: obj._make_eval_div_inner(),
42
+ "eval_div_outer": lambda obj: obj._make_eval_div_outer(),
43
+ "set_node_value": lambda obj: obj._make_set_node_value(),
44
+ "node_partition_index": lambda obj: obj._make_node_partition_index(),
45
+ "node_count": lambda obj: obj._make_node_count(),
46
+ "node_index": lambda obj: obj._make_node_index(),
47
+ "at_node": lambda obj: obj._make_at_node(),
48
+ }
49
+
29
50
  def __init__(self, space: CollocatedFunctionSpace, space_partition: SpacePartition):
30
51
  super().__init__(space, space_partition)
31
-
32
- self.EvalArg = self._make_eval_arg()
33
- self.ElementEvalArg = self._make_element_eval_arg()
34
- self.eval_degree = DiscreteField._make_eval_degree(self)
35
-
36
- self._read_node_value = self._make_read_node_value()
37
-
38
- self.eval_inner = self._make_eval_inner()
39
- self.eval_outer = self._make_eval_outer()
40
- self.eval_grad_inner = self._make_eval_grad_inner(world_space=True)
41
- self.eval_grad_outer = self._make_eval_grad_outer(world_space=True)
42
- self.eval_reference_grad_inner = self._make_eval_grad_inner(world_space=False)
43
- self.eval_reference_grad_outer = self._make_eval_grad_outer(world_space=False)
44
- self.eval_div_inner = self._make_eval_div_inner()
45
- self.eval_div_outer = self._make_eval_div_outer()
46
-
47
- self.set_node_value = self._make_set_node_value()
48
- self.node_partition_index = self._make_node_partition_index()
52
+ cache.setup_dynamic_attributes(self)
49
53
 
50
54
  def _make_eval_arg(self):
51
55
  @cache.dynamic_struct(suffix=self.name)
@@ -66,6 +70,8 @@ class NodalFieldBase(DiscreteField):
66
70
  return ElementEvalArg
67
71
 
68
72
  def _make_read_node_value(self):
73
+ zero_element = type_zero_element(self.dof_dtype)
74
+
69
75
  @cache.dynamic_func(suffix=self.name)
70
76
  def read_node_value(args: self.ElementEvalArg, geo_element_index: ElementIndex, node_index_in_elt: int):
71
77
  nidx = self.space.topology.element_node_index(
@@ -73,26 +79,29 @@ class NodalFieldBase(DiscreteField):
73
79
  )
74
80
  pidx = self.space_partition.partition_node_index(args.eval_arg.partition_arg, nidx)
75
81
  if pidx == NULL_NODE_INDEX:
76
- return self.space.dof_dtype(0.0)
82
+ return zero_element()
77
83
 
78
84
  return args.eval_arg.dof_values[pidx]
79
85
 
80
86
  return read_node_value
81
87
 
82
88
  def _make_eval_inner(self):
89
+ zero_element = type_zero_element(self.dtype)
90
+
83
91
  @cache.dynamic_func(suffix=self.name)
84
92
  def eval_inner(args: self.ElementEvalArg, s: Sample):
85
93
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
86
94
  node_count = self.space.topology.element_node_count(
87
95
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
88
96
  )
89
- res = self.space.dtype(0.0)
97
+ res = zero_element()
90
98
  for k in range(node_count):
99
+ w = self.space.element_inner_weight(
100
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
101
+ )
91
102
  res += self.space.space_value(
92
103
  self._read_node_value(args, s.element_index, k),
93
- self.space.element_inner_weight(
94
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
95
- ),
104
+ w,
96
105
  local_value_map,
97
106
  )
98
107
  return res
@@ -104,6 +113,7 @@ class NodalFieldBase(DiscreteField):
104
113
  return None
105
114
 
106
115
  gradient_dtype = self.gradient_dtype if world_space else self.reference_gradient_dtype
116
+ zero_element = type_zero_element(gradient_dtype)
107
117
 
108
118
  @cache.dynamic_func(suffix=f"{self.name}{world_space}")
109
119
  def eval_grad_inner(args: self.ElementEvalArg, s: Sample, grad_transform: Any):
@@ -112,7 +122,7 @@ class NodalFieldBase(DiscreteField):
112
122
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
113
123
  )
114
124
 
115
- res = gradient_dtype(0.0)
125
+ res = zero_element()
116
126
  for k in range(node_count):
117
127
  res += self.space.space_gradient(
118
128
  self._read_node_value(args, s.element_index, k),
@@ -144,6 +154,7 @@ class NodalFieldBase(DiscreteField):
144
154
  def _make_eval_div_inner(self):
145
155
  if not self.divergence_valid():
146
156
  return None
157
+ zero_element = type_zero_element(self.divergence_dtype)
147
158
 
148
159
  @cache.dynamic_func(suffix=self.name)
149
160
  def eval_div_inner(args: self.ElementEvalArg, s: Sample):
@@ -153,7 +164,7 @@ class NodalFieldBase(DiscreteField):
153
164
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
154
165
  )
155
166
 
156
- res = self.divergence_dtype(0.0)
167
+ res = zero_element()
157
168
  for k in range(node_count):
158
169
  res += self.space.space_divergence(
159
170
  self._read_node_value(args, s.element_index, k),
@@ -168,6 +179,8 @@ class NodalFieldBase(DiscreteField):
168
179
  return eval_div_inner
169
180
 
170
181
  def _make_eval_outer(self):
182
+ zero_element = type_zero_element(self.dtype)
183
+
171
184
  @cache.dynamic_func(suffix=self.name)
172
185
  def eval_outer(args: self.ElementEvalArg, s: Sample):
173
186
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
@@ -175,7 +188,7 @@ class NodalFieldBase(DiscreteField):
175
188
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
176
189
  )
177
190
 
178
- res = self.dtype(0.0)
191
+ res = zero_element()
179
192
  for k in range(node_count):
180
193
  res += self.space.space_value(
181
194
  self._read_node_value(args, s.element_index, k),
@@ -193,6 +206,7 @@ class NodalFieldBase(DiscreteField):
193
206
  return None
194
207
 
195
208
  gradient_dtype = self.gradient_dtype if world_space else self.reference_gradient_dtype
209
+ zero_element = type_zero_element(gradient_dtype)
196
210
 
197
211
  @cache.dynamic_func(suffix=f"{self.name}{world_space}")
198
212
  def eval_grad_outer(args: self.ElementEvalArg, s: Sample, grad_transform: Any):
@@ -201,7 +215,7 @@ class NodalFieldBase(DiscreteField):
201
215
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
202
216
  )
203
217
 
204
- res = gradient_dtype(0.0)
218
+ res = zero_element()
205
219
  for k in range(node_count):
206
220
  res += self.space.space_gradient(
207
221
  self._read_node_value(args, s.element_index, k),
@@ -234,6 +248,8 @@ class NodalFieldBase(DiscreteField):
234
248
  if not self.divergence_valid():
235
249
  return None
236
250
 
251
+ zero_element = type_zero_element(self.divergence_dtype)
252
+
237
253
  @cache.dynamic_func(suffix=self.name)
238
254
  def eval_div_outer(args: self.ElementEvalArg, s: Sample):
239
255
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
@@ -242,7 +258,7 @@ class NodalFieldBase(DiscreteField):
242
258
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
243
259
  )
244
260
 
245
- res = self.divergence_dtype(0.0)
261
+ res = zero_element()
246
262
  for k in range(node_count):
247
263
  res += self.space.space_divergence(
248
264
  self._read_node_value(args, s.element_index, k),
@@ -279,6 +295,32 @@ class NodalFieldBase(DiscreteField):
279
295
 
280
296
  return node_partition_index
281
297
 
298
+ def _make_node_count(self):
299
+ @cache.dynamic_func(suffix=self.name)
300
+ def node_count(args: self.ElementEvalArg, s: Sample):
301
+ return self.space.topology.element_node_count(args.elt_arg, args.eval_arg.topology_arg, s.element_index)
302
+
303
+ return node_count
304
+
305
+ def _make_at_node(self):
306
+ @cache.dynamic_func(suffix=self.name)
307
+ def at_node(args: self.ElementEvalArg, s: Sample, node_index_in_elt: int):
308
+ node_coords = self.space.node_coords_in_element(
309
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, node_index_in_elt
310
+ )
311
+ return Sample(s.element_index, node_coords, s.qp_index, s.qp_weight, s.test_dof, s.trial_dof)
312
+
313
+ return at_node
314
+
315
+ def _make_node_index(self):
316
+ @cache.dynamic_func(suffix=self.name)
317
+ def node_index(args: self.ElementEvalArg, s: Sample, node_index_in_elt: int):
318
+ return self.space.topology.element_node_index(
319
+ args.elt_arg, args.eval_arg.topology_arg, s.element_index, node_index_in_elt
320
+ )
321
+
322
+ return node_index
323
+
282
324
 
283
325
  class NodalField(NodalFieldBase):
284
326
  """A field holding values for all degrees of freedom at each node of the underlying function space partition
@@ -296,13 +338,15 @@ class NodalField(NodalFieldBase):
296
338
 
297
339
  def eval_arg_value(self, device):
298
340
  arg = self.EvalArg()
299
- arg.dof_values = self._dof_values.to(device)
300
- arg.space_arg = self.space.space_arg_value(device)
301
- arg.partition_arg = self.space_partition.partition_arg_value(device)
302
- arg.topology_arg = self.space.topology.topo_arg_value(device)
303
-
341
+ self.fill_eval_arg(arg, device)
304
342
  return arg
305
343
 
344
+ def fill_eval_arg(self, arg, device):
345
+ arg.dof_values = self._dof_values.to(device)
346
+ self.space.fill_space_arg(arg.space_arg, device)
347
+ self.space_partition.fill_partition_arg(arg.partition_arg, device)
348
+ self.space.topology.fill_topo_arg(arg.topology_arg, device)
349
+
306
350
  @property
307
351
  def dof_values(self) -> wp.array:
308
352
  """Returns a warp array containing the values at all degrees of freedom of the underlying space partition"""
@@ -328,13 +372,15 @@ class NodalField(NodalFieldBase):
328
372
 
329
373
  def eval_arg_value(self, device):
330
374
  arg = self.EvalArg()
331
- arg.dof_values = self._field.dof_values.to(device)
332
- arg.space_arg = self.space.space_arg_value(device)
333
- arg.partition_arg = self.space_partition.partition_arg_value(device)
334
- arg.topology_arg = self.space.topology.topo_arg_value(device)
335
-
375
+ self.fill_eval_arg(arg, device)
336
376
  return arg
337
377
 
378
+ def fill_eval_arg(self, arg, device):
379
+ arg.dof_values = self._field.dof_values.to(device)
380
+ self.space.fill_space_arg(arg.space_arg, device)
381
+ self.space_partition.fill_partition_arg(arg.partition_arg, device)
382
+ self.space.topology.fill_topo_arg(arg.topology_arg, device)
383
+
338
384
  def trace(self) -> Trace:
339
385
  trace_field = NodalField.Trace(self)
340
386
  return trace_field
warp/fem/field/virtual.py CHANGED
@@ -13,7 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Any, Set
16
+ from typing import Any, ClassVar, Dict, Set
17
17
 
18
18
  import warp as wp
19
19
  import warp.fem.operator as operator
@@ -23,6 +23,7 @@ from warp.fem.linalg import basis_coefficient, generalized_inner, generalized_ou
23
23
  from warp.fem.quadrature import Quadrature
24
24
  from warp.fem.space import FunctionSpace, SpacePartition, SpaceRestriction
25
25
  from warp.fem.types import NULL_NODE_INDEX, DofIndex, Sample, get_node_coord, get_node_index_in_element
26
+ from warp.fem.utils import type_zero_element
26
27
 
27
28
  from .field import SpaceField
28
29
 
@@ -30,33 +31,52 @@ from .field import SpaceField
30
31
  class AdjointField(SpaceField):
31
32
  """Adjoint of a discrete field with respect to its degrees of freedom"""
32
33
 
33
- def __init__(self, space: FunctionSpace, space_partition: SpaceRestriction):
34
+ _dynamic_attribute_constructors: ClassVar = {
35
+ "EvalArg": lambda obj: obj._make_eval_arg(),
36
+ "ElementEvalArg": lambda obj: obj._make_element_eval_arg(),
37
+ "eval_degree": lambda obj: obj._make_eval_degree(),
38
+ "eval_inner": lambda obj: obj._make_eval_inner(),
39
+ "eval_grad_inner": lambda obj: obj._make_eval_grad_inner(),
40
+ "eval_div_inner": lambda obj: obj._make_eval_div_inner(),
41
+ "eval_outer": lambda obj: obj._make_eval_outer(),
42
+ "eval_grad_outer": lambda obj: obj._make_eval_grad_outer(),
43
+ "eval_div_outer": lambda obj: obj._make_eval_div_outer(),
44
+ "node_count": lambda obj: obj._make_node_count(),
45
+ "at_node": lambda obj: obj._make_at_node(),
46
+ "node_index": lambda obj: obj._make_node_index(),
47
+ }
48
+
49
+ def __init__(self, space: FunctionSpace, space_partition: SpacePartition):
34
50
  super().__init__(space, space_partition=space_partition)
35
51
 
36
52
  self.node_dof_count = self.space.NODE_DOF_COUNT
37
53
  self.value_dof_count = self.space.VALUE_DOF_COUNT
38
54
 
39
- self.EvalArg = self.space.SpaceArg
40
- self.ElementEvalArg = self._make_element_eval_arg()
41
-
42
- self.eval_arg_value = self.space.space_arg_value
43
-
44
- self.eval_degree = self._make_eval_degree()
45
- self.eval_inner = self._make_eval_inner()
46
- self.eval_grad_inner = self._make_eval_grad_inner()
47
- self.eval_div_inner = self._make_eval_div_inner()
48
- self.eval_outer = self._make_eval_outer()
49
- self.eval_grad_outer = self._make_eval_grad_outer()
50
- self.eval_div_outer = self._make_eval_div_outer()
51
- self.at_node = self._make_at_node()
55
+ cache.setup_dynamic_attributes(self)
52
56
 
53
57
  @property
54
58
  def name(self) -> str:
55
59
  return f"{self.__class__.__name__}{self.space.name}{self._space_partition.name}"
56
60
 
57
- def _make_element_eval_arg(self):
58
- from warp.fem import cache
61
+ @cache.cached_arg_value
62
+ def eval_arg_value(self, device):
63
+ arg = self.EvalArg()
64
+ self.fill_eval_arg(arg, device)
65
+ return arg
66
+
67
+ def fill_eval_arg(self, arg, device):
68
+ self.space.fill_space_arg(arg.space_arg, device)
69
+ self.space.topology.fill_topo_arg(arg.topo_arg, device)
59
70
 
71
+ def _make_eval_arg(self):
72
+ @cache.dynamic_struct(suffix=self.name)
73
+ class EvalArg:
74
+ space_arg: self.space.SpaceArg
75
+ topo_arg: self.space.topology.TopologyArg
76
+
77
+ return EvalArg
78
+
79
+ def _make_element_eval_arg(self):
60
80
  @cache.dynamic_struct(suffix=self.name)
61
81
  class ElementEvalArg:
62
82
  elt_arg: self.space.topology.ElementArg
@@ -70,7 +90,7 @@ class AdjointField(SpaceField):
70
90
  dof = self._get_dof(s)
71
91
  node_weight = self.space.element_inner_weight(
72
92
  args.elt_arg,
73
- args.eval_arg,
93
+ args.eval_arg.space_arg,
74
94
  s.element_index,
75
95
  s.element_coords,
76
96
  get_node_index_in_element(dof),
@@ -91,7 +111,7 @@ class AdjointField(SpaceField):
91
111
  dof = self._get_dof(s)
92
112
  nabla_weight = self.space.element_inner_weight_gradient(
93
113
  args.elt_arg,
94
- args.eval_arg,
114
+ args.eval_arg.space_arg,
95
115
  s.element_index,
96
116
  s.element_coords,
97
117
  get_node_index_in_element(dof),
@@ -113,7 +133,7 @@ class AdjointField(SpaceField):
113
133
  dof = self._get_dof(s)
114
134
  nabla_weight = self.space.element_inner_weight_gradient(
115
135
  args.elt_arg,
116
- args.eval_arg,
136
+ args.eval_arg.space_arg,
117
137
  s.element_index,
118
138
  s.element_coords,
119
139
  get_node_index_in_element(dof),
@@ -132,7 +152,7 @@ class AdjointField(SpaceField):
132
152
  dof = self._get_dof(s)
133
153
  node_weight = self.space.element_outer_weight(
134
154
  args.elt_arg,
135
- args.eval_arg,
155
+ args.eval_arg.space_arg,
136
156
  s.element_index,
137
157
  s.element_coords,
138
158
  get_node_index_in_element(dof),
@@ -153,7 +173,7 @@ class AdjointField(SpaceField):
153
173
  dof = self._get_dof(s)
154
174
  nabla_weight = self.space.element_outer_weight_gradient(
155
175
  args.elt_arg,
156
- args.eval_arg,
176
+ args.eval_arg.space_arg,
157
177
  s.element_index,
158
178
  s.element_coords,
159
179
  get_node_index_in_element(dof),
@@ -175,7 +195,7 @@ class AdjointField(SpaceField):
175
195
  dof = self._get_dof(s)
176
196
  nabla_weight = self.space.element_outer_weight_gradient(
177
197
  args.elt_arg,
178
- args.eval_arg,
198
+ args.eval_arg.space_arg,
179
199
  s.element_index,
180
200
  s.element_coords,
181
201
  get_node_index_in_element(dof),
@@ -193,12 +213,30 @@ class AdjointField(SpaceField):
193
213
  def at_node(args: self.ElementEvalArg, s: Sample):
194
214
  dof = self._get_dof(s)
195
215
  node_coords = self.space.node_coords_in_element(
196
- args.elt_arg, args.eval_arg, s.element_index, get_node_index_in_element(dof)
216
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, get_node_index_in_element(dof)
197
217
  )
198
218
  return Sample(s.element_index, node_coords, s.qp_index, s.qp_weight, s.test_dof, s.trial_dof)
199
219
 
200
220
  return at_node
201
221
 
222
+ def _make_node_index(self):
223
+ @cache.dynamic_func(suffix=self.name)
224
+ def node_index(args: self.ElementEvalArg, s: Sample):
225
+ dof = self._get_dof(s)
226
+ node_idx = self.space.topology.element_node_index(
227
+ args.elt_arg, args.eval_arg.topo_arg, s.element_index, get_node_index_in_element(dof)
228
+ )
229
+ return node_idx
230
+
231
+ return node_index
232
+
233
+ def _make_node_count(self):
234
+ @cache.dynamic_func(suffix=self.name)
235
+ def node_count(args: self.ElementEvalArg, s: Sample):
236
+ return self.space.topology.element_node_count(args.elt_arg, args.eval_arg.topo_arg, s.element_index)
237
+
238
+ return node_count
239
+
202
240
 
203
241
  class TestField(AdjointField):
204
242
  """Field defined over a space restriction that can be used as a test function.
@@ -269,16 +307,16 @@ class LocalAdjointField(SpaceField):
269
307
  OUTER_GRAD_DOF = wp.constant(3)
270
308
  DOF_TYPE_COUNT = wp.constant(4)
271
309
 
272
- _OP_DOF_MAP_CONTINUOUS = {
310
+ _OP_DOF_MAP_CONTINUOUS: ClassVar[Dict[operator.Operator, int]] = {
273
311
  operator.inner: INNER_DOF,
274
312
  operator.outer: INNER_DOF,
275
313
  operator.grad: INNER_GRAD_DOF,
276
- operator.grad_outer: INNER_GRAD_DOF,
314
+ operator.grad_outer: OUTER_GRAD_DOF,
277
315
  operator.div: INNER_GRAD_DOF,
278
- operator.div_outer: INNER_GRAD_DOF,
316
+ operator.div_outer: OUTER_GRAD_DOF,
279
317
  }
280
318
 
281
- _OP_DOF_MAP_DISCONTINUOUS = {
319
+ _OP_DOF_MAP_DISCONTINUOUS: ClassVar[Dict[operator.Operator, int]] = {
282
320
  operator.inner: INNER_DOF,
283
321
  operator.outer: OUTER_DOF,
284
322
  operator.grad: INNER_GRAD_DOF,
@@ -293,6 +331,18 @@ class LocalAdjointField(SpaceField):
293
331
  class EvalArg:
294
332
  pass
295
333
 
334
+ _dynamic_attribute_constructors: ClassVar = {
335
+ "ElementEvalArg": lambda obj: obj._make_element_eval_arg(),
336
+ "eval_degree": lambda obj: obj._make_eval_degree(),
337
+ "_split_dof": lambda obj: obj._make_split_dof(),
338
+ "eval_inner": lambda obj: obj._make_eval_inner(),
339
+ "eval_grad_inner": lambda obj: obj._make_eval_grad_inner(),
340
+ "eval_div_inner": lambda obj: obj._make_eval_div_inner(),
341
+ "eval_outer": lambda obj: obj._make_eval_outer(),
342
+ "eval_grad_outer": lambda obj: obj._make_eval_grad_outer(),
343
+ "eval_div_outer": lambda obj: obj._make_eval_div_outer(),
344
+ }
345
+
296
346
  def __init__(self, field: AdjointField):
297
347
  # if not isinstance(field.space, CollocatedFunctionSpace):
298
348
  # raise NotImplementedError("Local assembly only implemented for collocated function spaces")
@@ -305,9 +355,6 @@ class LocalAdjointField(SpaceField):
305
355
  self.value_dof_count = self.space.VALUE_DOF_COUNT
306
356
 
307
357
  self._dof_suffix = ""
308
-
309
- self.ElementEvalArg = self._make_element_eval_arg()
310
- self.eval_degree = self._make_eval_degree()
311
358
  self.at_node = None
312
359
 
313
360
  self._is_discontinuous = (self.space.element_inner_weight != self.space.element_outer_weight) or (
@@ -346,21 +393,7 @@ class LocalAdjointField(SpaceField):
346
393
  self._TAYLOR_DOF_COUNTS = dof_counts
347
394
 
348
395
  self._dof_suffix = "".join(str(c) for c in dof_counts)
349
-
350
- self._split_dof = self._make_split_dof()
351
-
352
- self.eval_inner = self._make_eval_inner()
353
- self.eval_grad_inner = self._make_eval_grad_inner()
354
- self.eval_div_inner = self._make_eval_div_inner()
355
-
356
- if self._is_discontinuous:
357
- self.eval_outer = self._make_eval_outer()
358
- self.eval_grad_outer = self._make_eval_grad_outer()
359
- self.eval_div_outer = self._make_eval_div_outer()
360
- else:
361
- self.eval_outer = self.eval_inner
362
- self.eval_grad_outer = self.eval_grad_inner
363
- self.eval_div_outer = self.eval_div_inner
396
+ cache.setup_dynamic_attributes(self)
364
397
 
365
398
  @property
366
399
  def name(self) -> str:
@@ -369,6 +402,9 @@ class LocalAdjointField(SpaceField):
369
402
  def eval_arg_value(self, device):
370
403
  return LocalAdjointField.EvalArg()
371
404
 
405
+ def fill_eval_arg(self, arg, device):
406
+ pass
407
+
372
408
  def _make_element_eval_arg(self):
373
409
  from warp.fem import cache
374
410
 
@@ -392,6 +428,7 @@ class LocalAdjointField(SpaceField):
392
428
 
393
429
  def _make_eval_inner(self):
394
430
  DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF])
431
+ zero_element = type_zero_element(self.dtype)
395
432
 
396
433
  @cache.dynamic_func(suffix=self.name)
397
434
  def eval_test_inner(args: self.ElementEvalArg, s: Sample):
@@ -399,7 +436,7 @@ class LocalAdjointField(SpaceField):
399
436
 
400
437
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
401
438
  dof_value = self.space.value_basis_element(value_dof, local_value_map)
402
- return wp.where(taylor_dof == 0, dof_value, self.dtype(0.0))
439
+ return wp.where(taylor_dof == 0, dof_value, zero_element())
403
440
 
404
441
  return eval_test_inner
405
442
 
@@ -409,13 +446,14 @@ class LocalAdjointField(SpaceField):
409
446
 
410
447
  DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF])
411
448
  DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF])
449
+ zero_element = type_zero_element(self.gradient_dtype)
412
450
 
413
451
  @cache.dynamic_func(suffix=self.name)
414
452
  def eval_nabla_test_inner(args: self.ElementEvalArg, s: Sample):
415
453
  value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)
416
454
 
417
455
  if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
418
- return self.gradient_dtype(0.0)
456
+ return zero_element()
419
457
 
420
458
  grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
421
459
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
@@ -430,13 +468,14 @@ class LocalAdjointField(SpaceField):
430
468
 
431
469
  DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF])
432
470
  DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF])
471
+ zero_element = type_zero_element(self.divergence_dtype)
433
472
 
434
473
  @cache.dynamic_func(suffix=self.name)
435
474
  def eval_div_test_inner(args: self.ElementEvalArg, s: Sample):
436
475
  value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)
437
476
 
438
477
  if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
439
- return self.divergence_dtype(0.0)
478
+ return zero_element()
440
479
 
441
480
  grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
442
481
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
@@ -446,7 +485,11 @@ class LocalAdjointField(SpaceField):
446
485
  return eval_div_test_inner
447
486
 
448
487
  def _make_eval_outer(self):
488
+ if not self._is_discontinuous:
489
+ return self.eval_inner
490
+
449
491
  DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF])
492
+ zero_element = type_zero_element(self.dtype)
450
493
 
451
494
  @cache.dynamic_func(suffix=self.name)
452
495
  def eval_test_outer(args: self.ElementEvalArg, s: Sample):
@@ -454,7 +497,7 @@ class LocalAdjointField(SpaceField):
454
497
 
455
498
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
456
499
  dof_value = self.space.value_basis_element(value_dof, local_value_map)
457
- return wp.where(taylor_dof == 0, dof_value, self.dtype(0.0))
500
+ return wp.where(taylor_dof == 0, dof_value, zero_element())
458
501
 
459
502
  return eval_test_outer
460
503
 
@@ -464,13 +507,14 @@ class LocalAdjointField(SpaceField):
464
507
 
465
508
  DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF])
466
509
  DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF])
510
+ zero_element = type_zero_element(self.gradient_dtype)
467
511
 
468
512
  @cache.dynamic_func(suffix=self.name)
469
513
  def eval_nabla_test_outer(args: self.ElementEvalArg, s: Sample):
470
514
  value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)
471
515
 
472
516
  if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
473
- return self.gradient_dtype(0.0)
517
+ return zero_element()
474
518
 
475
519
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
476
520
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
@@ -485,13 +529,14 @@ class LocalAdjointField(SpaceField):
485
529
 
486
530
  DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF])
487
531
  DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF])
532
+ zero_element = type_zero_element(self.divergence_dtype)
488
533
 
489
534
  @cache.dynamic_func(suffix=self.name)
490
535
  def eval_div_test_outer(args: self.ElementEvalArg, s: Sample):
491
536
  value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)
492
537
 
493
538
  if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
494
- return self.divergence_dtype(0.0)
539
+ return zero_element()
495
540
 
496
541
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
497
542
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)