warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.1__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package registry page for more details.

Files changed (193)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +130 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +272 -104
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +770 -238
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_callable.py +34 -4
  36. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  37. warp/examples/interop/example_jax_kernel.py +27 -1
  38. warp/examples/optim/example_drone.py +1 -1
  39. warp/examples/sim/example_cloth.py +1 -1
  40. warp/examples/sim/example_cloth_self_contact.py +48 -54
  41. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  42. warp/examples/tile/example_tile_cholesky.py +2 -1
  43. warp/examples/tile/example_tile_convolution.py +1 -1
  44. warp/examples/tile/example_tile_filtering.py +1 -1
  45. warp/examples/tile/example_tile_matmul.py +1 -1
  46. warp/examples/tile/example_tile_mlp.py +2 -0
  47. warp/fabric.py +7 -7
  48. warp/fem/__init__.py +5 -0
  49. warp/fem/adaptivity.py +1 -1
  50. warp/fem/cache.py +152 -63
  51. warp/fem/dirichlet.py +2 -2
  52. warp/fem/domain.py +136 -6
  53. warp/fem/field/field.py +141 -99
  54. warp/fem/field/nodal_field.py +85 -39
  55. warp/fem/field/virtual.py +99 -52
  56. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  57. warp/fem/geometry/closest_point.py +13 -0
  58. warp/fem/geometry/deformed_geometry.py +102 -40
  59. warp/fem/geometry/element.py +56 -2
  60. warp/fem/geometry/geometry.py +323 -22
  61. warp/fem/geometry/grid_2d.py +157 -62
  62. warp/fem/geometry/grid_3d.py +116 -20
  63. warp/fem/geometry/hexmesh.py +86 -20
  64. warp/fem/geometry/nanogrid.py +166 -86
  65. warp/fem/geometry/partition.py +59 -25
  66. warp/fem/geometry/quadmesh.py +86 -135
  67. warp/fem/geometry/tetmesh.py +47 -119
  68. warp/fem/geometry/trimesh.py +77 -270
  69. warp/fem/integrate.py +181 -95
  70. warp/fem/linalg.py +25 -58
  71. warp/fem/operator.py +124 -27
  72. warp/fem/quadrature/pic_quadrature.py +36 -14
  73. warp/fem/quadrature/quadrature.py +40 -16
  74. warp/fem/space/__init__.py +1 -1
  75. warp/fem/space/basis_function_space.py +66 -46
  76. warp/fem/space/basis_space.py +17 -4
  77. warp/fem/space/dof_mapper.py +1 -1
  78. warp/fem/space/function_space.py +2 -2
  79. warp/fem/space/grid_2d_function_space.py +4 -1
  80. warp/fem/space/hexmesh_function_space.py +4 -2
  81. warp/fem/space/nanogrid_function_space.py +3 -1
  82. warp/fem/space/partition.py +11 -2
  83. warp/fem/space/quadmesh_function_space.py +4 -1
  84. warp/fem/space/restriction.py +5 -2
  85. warp/fem/space/shape/__init__.py +10 -8
  86. warp/fem/space/tetmesh_function_space.py +4 -1
  87. warp/fem/space/topology.py +52 -21
  88. warp/fem/space/trimesh_function_space.py +4 -1
  89. warp/fem/utils.py +53 -8
  90. warp/jax.py +1 -2
  91. warp/jax_experimental/ffi.py +210 -67
  92. warp/jax_experimental/xla_ffi.py +37 -24
  93. warp/math.py +171 -1
  94. warp/native/array.h +103 -4
  95. warp/native/builtin.h +182 -35
  96. warp/native/coloring.cpp +6 -2
  97. warp/native/cuda_util.cpp +1 -1
  98. warp/native/exports.h +118 -63
  99. warp/native/intersect.h +5 -5
  100. warp/native/mat.h +8 -13
  101. warp/native/mathdx.cpp +11 -5
  102. warp/native/matnn.h +1 -123
  103. warp/native/mesh.h +1 -1
  104. warp/native/quat.h +34 -6
  105. warp/native/rand.h +7 -7
  106. warp/native/sparse.cpp +121 -258
  107. warp/native/sparse.cu +181 -274
  108. warp/native/spatial.h +305 -17
  109. warp/native/svd.h +23 -8
  110. warp/native/tile.h +603 -73
  111. warp/native/tile_radix_sort.h +1112 -0
  112. warp/native/tile_reduce.h +239 -13
  113. warp/native/tile_scan.h +240 -0
  114. warp/native/tuple.h +189 -0
  115. warp/native/vec.h +10 -20
  116. warp/native/warp.cpp +36 -4
  117. warp/native/warp.cu +588 -52
  118. warp/native/warp.h +47 -74
  119. warp/optim/linear.py +5 -1
  120. warp/paddle.py +7 -8
  121. warp/py.typed +0 -0
  122. warp/render/render_opengl.py +110 -80
  123. warp/render/render_usd.py +124 -62
  124. warp/sim/__init__.py +9 -0
  125. warp/sim/collide.py +253 -80
  126. warp/sim/graph_coloring.py +8 -1
  127. warp/sim/import_mjcf.py +4 -3
  128. warp/sim/import_usd.py +11 -7
  129. warp/sim/integrator.py +5 -2
  130. warp/sim/integrator_euler.py +1 -1
  131. warp/sim/integrator_featherstone.py +1 -1
  132. warp/sim/integrator_vbd.py +761 -322
  133. warp/sim/integrator_xpbd.py +1 -1
  134. warp/sim/model.py +265 -260
  135. warp/sim/utils.py +10 -7
  136. warp/sparse.py +303 -166
  137. warp/tape.py +54 -51
  138. warp/tests/cuda/test_conditional_captures.py +1046 -0
  139. warp/tests/cuda/test_streams.py +1 -1
  140. warp/tests/geometry/test_volume.py +2 -2
  141. warp/tests/interop/test_dlpack.py +9 -9
  142. warp/tests/interop/test_jax.py +0 -1
  143. warp/tests/run_coverage_serial.py +1 -1
  144. warp/tests/sim/disabled_kinematics.py +2 -2
  145. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  146. warp/tests/sim/test_collision.py +159 -51
  147. warp/tests/sim/test_coloring.py +91 -2
  148. warp/tests/test_array.py +254 -2
  149. warp/tests/test_array_reduce.py +2 -2
  150. warp/tests/test_assert.py +53 -0
  151. warp/tests/test_atomic_cas.py +312 -0
  152. warp/tests/test_codegen.py +142 -19
  153. warp/tests/test_conditional.py +47 -1
  154. warp/tests/test_ctypes.py +0 -20
  155. warp/tests/test_devices.py +8 -0
  156. warp/tests/test_fabricarray.py +4 -2
  157. warp/tests/test_fem.py +58 -25
  158. warp/tests/test_func.py +42 -1
  159. warp/tests/test_grad.py +1 -1
  160. warp/tests/test_lerp.py +1 -3
  161. warp/tests/test_map.py +481 -0
  162. warp/tests/test_mat.py +23 -24
  163. warp/tests/test_quat.py +28 -15
  164. warp/tests/test_rounding.py +10 -38
  165. warp/tests/test_runlength_encode.py +7 -7
  166. warp/tests/test_smoothstep.py +1 -1
  167. warp/tests/test_sparse.py +83 -2
  168. warp/tests/test_spatial.py +507 -1
  169. warp/tests/test_static.py +48 -0
  170. warp/tests/test_struct.py +2 -2
  171. warp/tests/test_tape.py +38 -0
  172. warp/tests/test_tuple.py +265 -0
  173. warp/tests/test_types.py +2 -2
  174. warp/tests/test_utils.py +24 -18
  175. warp/tests/test_vec.py +38 -408
  176. warp/tests/test_vec_constructors.py +325 -0
  177. warp/tests/tile/test_tile.py +438 -131
  178. warp/tests/tile/test_tile_mathdx.py +518 -14
  179. warp/tests/tile/test_tile_matmul.py +179 -0
  180. warp/tests/tile/test_tile_reduce.py +307 -5
  181. warp/tests/tile/test_tile_shared_memory.py +136 -7
  182. warp/tests/tile/test_tile_sort.py +121 -0
  183. warp/tests/unittest_suites.py +14 -6
  184. warp/types.py +462 -308
  185. warp/utils.py +647 -86
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  187. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
  188. warp/stubs.py +0 -3381
  189. warp/tests/sim/test_xpbd.py +0 -399
  190. warp/tests/test_mlp.py +0 -282
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  193. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/fem/field/field.py CHANGED
@@ -13,7 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Any, Dict, Optional, Set
16
+ from typing import Any, ClassVar, Dict, Optional, Set
17
17
 
18
18
  import warp as wp
19
19
  from warp.fem import cache
@@ -22,6 +22,7 @@ from warp.fem.geometry import DeformedGeometry, Geometry
22
22
  from warp.fem.operator import Operator, integrand
23
23
  from warp.fem.space import FunctionSpace, SpacePartition
24
24
  from warp.fem.types import NULL_ELEMENT_INDEX, ElementKind, Sample
25
+ from warp.fem.utils import type_zero_element
25
26
 
26
27
 
27
28
  class FieldLike:
@@ -37,6 +38,10 @@ class FieldLike:
37
38
  """Value of the field-level arguments to be passed to device functions"""
38
39
  raise NotImplementedError
39
40
 
41
+ def fill_eval_arg(self, arg: "FieldLike.EvalArg", device):
42
+ """Fill the field-level arguments to be passed to device functions"""
43
+ raise NotImplementedError
44
+
40
45
  @property
41
46
  def degree(self) -> int:
42
47
  """Polynomial degree of the field, used to estimate necessary quadrature order"""
@@ -50,10 +55,6 @@ class FieldLike:
50
55
  def __str__(self) -> str:
51
56
  return self.name
52
57
 
53
- def eval_arg_value(self, device):
54
- """Value of arguments to be passed to device functions"""
55
- raise NotImplementedError
56
-
57
58
  def gradient_valid(self) -> bool:
58
59
  """Whether the gradient operator is implemented for this field."""
59
60
  return False
@@ -139,6 +140,51 @@ class GeometryField(FieldLike):
139
140
  """
140
141
  return DeformedGeometry(self, relative=relative)
141
142
 
143
+ @property
144
+ def gradient_dtype(self):
145
+ """Return type of the (world space) gradient operator. Assumes self.gradient_valid()"""
146
+ if wp.types.type_is_matrix(self.dtype):
147
+ return None
148
+
149
+ if wp.types.type_is_vector(self.dtype):
150
+ return cache.cached_mat_type(
151
+ shape=(wp.types.type_size(self.dtype), self.geometry.dimension),
152
+ dtype=wp.types.type_scalar_type(self.dtype),
153
+ )
154
+ if wp.types.type_is_quaternion(self.dtype):
155
+ return cache.cached_mat_type(
156
+ shape=(4, self.geometry.dimension),
157
+ dtype=wp.types.type_scalar_type(self.dtype),
158
+ )
159
+ return cache.cached_vec_type(length=self.geometry.dimension, dtype=wp.types.type_scalar_type(self.dtype))
160
+
161
+ @property
162
+ def reference_gradient_dtype(self):
163
+ """Return type of the reference space gradient operator. Assumes self.gradient_valid()"""
164
+ if wp.types.type_is_matrix(self.dtype):
165
+ return None
166
+
167
+ if wp.types.type_is_vector(self.dtype):
168
+ return cache.cached_mat_type(
169
+ shape=(wp.types.type_size(self.dtype), self.geometry.cell_dimension),
170
+ dtype=wp.types.type_scalar_type(self.dtype),
171
+ )
172
+ if wp.types.type_is_quaternion(self.dtype):
173
+ return cache.cached_mat_type(
174
+ shape=(4, self.geometry.cell_dimension),
175
+ dtype=wp.types.type_scalar_type(self.dtype),
176
+ )
177
+ return cache.cached_vec_type(length=self.geometry.cell_dimension, dtype=wp.types.type_scalar_type(self.dtype))
178
+
179
+ @property
180
+ def divergence_dtype(self):
181
+ """Return type of the divergence operator. Assumes self.divergence_valid()"""
182
+ if wp.types.type_is_vector(self.dtype):
183
+ return wp.types.type_scalar_type(self.dtype)
184
+ if wp.types.type_is_matrix(self.dtype):
185
+ return cache.cached_vec_type(length=self.dtype._shape_[1], dtype=wp.types.type_scalar_type(self.dtype))
186
+ return None
187
+
142
188
 
143
189
  class SpaceField(GeometryField):
144
190
  """Base class for fields defined over a function space"""
@@ -178,33 +224,6 @@ class SpaceField(GeometryField):
178
224
  def dof_dtype(self) -> type:
179
225
  return self.space.dof_dtype
180
226
 
181
- @property
182
- def gradient_dtype(self):
183
- """Return type of the (world space) gradient operator. Assumes self.gradient_valid()"""
184
- if wp.types.type_is_vector(self.dtype):
185
- return cache.cached_mat_type(
186
- shape=(wp.types.type_length(self.dtype), self.geometry.dimension),
187
- dtype=wp.types.type_scalar_type(self.dtype),
188
- )
189
- return cache.cached_vec_type(length=self.geometry.dimension, dtype=wp.types.type_scalar_type(self.dtype))
190
-
191
- @property
192
- def reference_gradient_dtype(self):
193
- """Return type of the reference space gradient operator. Assumes self.gradient_valid()"""
194
- if wp.types.type_is_vector(self.dtype):
195
- return cache.cached_mat_type(
196
- shape=(wp.types.type_length(self.dtype), self.geometry.cell_dimension),
197
- dtype=wp.types.type_scalar_type(self.dtype),
198
- )
199
- return cache.cached_vec_type(length=self.geometry.cell_dimension, dtype=wp.types.type_scalar_type(self.dtype))
200
-
201
- @property
202
- def divergence_dtype(self):
203
- """Return type of the divergence operator. Assumes self.gradient_valid()"""
204
- if wp.types.type_is_vector(self.dtype):
205
- return wp.types.type_scalar_type(self.dtype)
206
- return cache.cached_vec_type(length=self.dtype._shape_[1], dtype=wp.types.type_scalar_type(self.dtype))
207
-
208
227
  def _make_eval_degree(self):
209
228
  ORDER = self.space.ORDER
210
229
 
@@ -251,6 +270,19 @@ class ImplicitField(GeometryField):
251
270
  degree: Optional hint for automatic determination of quadrature orders when integrating this field
252
271
  """
253
272
 
273
+ _dynamic_attribute_constructors: ClassVar = {
274
+ "ElementEvalArg": lambda obj: obj._make_element_eval_arg(),
275
+ "eval_degree": lambda obj: obj._make_eval_degree(),
276
+ "eval_inner": lambda obj: obj._make_eval_func(obj._func),
277
+ "eval_grad_inner": lambda obj: obj._make_eval_func(obj._grad_func),
278
+ "eval_div_inner": lambda obj: obj._make_eval_func(obj._div_func),
279
+ "eval_reference_grad_inner": lambda obj: obj._make_eval_reference_grad(),
280
+ "eval_outer": lambda obj: obj.eval_inner,
281
+ "eval_grad_outer": lambda obj: obj.eval_grad_inner,
282
+ "eval_div_outer": lambda obj: obj.eval_div_inner,
283
+ "eval_reference_grad_outer": lambda obj: obj.eval_reference_grad_inner,
284
+ }
285
+
254
286
  def __init__(
255
287
  self,
256
288
  domain: GeometryDomain,
@@ -284,18 +316,7 @@ class ImplicitField(GeometryField):
284
316
  self.EvalArg = cache.get_argument_struct(arg_types)
285
317
  self.values = values
286
318
 
287
- self.ElementEvalArg = self._make_element_eval_arg()
288
- self.eval_degree = self._make_eval_degree()
289
-
290
- self.eval_inner = self._make_eval_func(func)
291
- self.eval_grad_inner = self._make_eval_func(grad_func)
292
- self.eval_div_inner = self._make_eval_func(div_func)
293
- self.eval_reference_grad_inner = self._make_eval_reference_grad()
294
-
295
- self.eval_outer = self.eval_inner
296
- self.eval_grad_outer = self.eval_grad_inner
297
- self.eval_div_outer = self.eval_div_inner
298
- self.eval_reference_grad_outer = self.eval_reference_grad_inner
319
+ cache.setup_dynamic_attributes(self)
299
320
 
300
321
  @property
301
322
  def values(self):
@@ -303,6 +324,7 @@ class ImplicitField(GeometryField):
303
324
 
304
325
  @values.setter
305
326
  def values(self, v):
327
+ self._values = v
306
328
  self._func_arg = cache.populate_argument_struct(self.EvalArg, v, self._func.func.__name__)
307
329
 
308
330
  @property
@@ -316,6 +338,9 @@ class ImplicitField(GeometryField):
316
338
  def eval_arg_value(self, device):
317
339
  return self._func_arg
318
340
 
341
+ def fill_eval_arg(self, arg, device):
342
+ cache.populate_argument_struct(self.EvalArg, self._values, self._func.func.__name__, arg)
343
+
319
344
  @property
320
345
  def degree(self) -> int:
321
346
  return self._degree
@@ -324,6 +349,12 @@ class ImplicitField(GeometryField):
324
349
  def name(self) -> str:
325
350
  return f"Implicit_{self.domain.name}_{self.degree}_{self.EvalArg.key}"
326
351
 
352
+ def gradient_valid(self) -> bool:
353
+ return self._grad_func is not None
354
+
355
+ def divergence_valid(self) -> bool:
356
+ return self._div_func is not None
357
+
327
358
  def _make_eval_func(self, func):
328
359
  if func is None:
329
360
  return None
@@ -387,6 +418,20 @@ class UniformField(GeometryField):
387
418
  value: Uniform value over the domain
388
419
  """
389
420
 
421
+ _dynamic_attribute_constructors: ClassVar = {
422
+ "EvalArg": lambda obj: obj._make_eval_arg(),
423
+ "ElementEvalArg": lambda obj: obj._make_element_eval_arg(),
424
+ "eval_degree": lambda obj: obj._make_eval_degree(),
425
+ "eval_inner": lambda obj: obj._make_eval_inner(),
426
+ "eval_grad_inner": lambda obj: obj._make_eval_zero(obj.gradient_dtype),
427
+ "eval_div_inner": lambda obj: obj._make_eval_zero(obj.divergence_dtype),
428
+ "eval_reference_grad_inner": lambda obj: obj._make_eval_zero(obj.reference_gradient_dtype),
429
+ "eval_outer": lambda obj: obj.eval_inner,
430
+ "eval_grad_outer": lambda obj: obj.eval_grad_inner,
431
+ "eval_div_outer": lambda obj: obj.eval_div_inner,
432
+ "eval_reference_grad_outer": lambda obj: obj.eval_reference_grad_inner,
433
+ }
434
+
390
435
  def __init__(self, domain: GeometryDomain, value: Any):
391
436
  self.domain = domain
392
437
 
@@ -396,30 +441,7 @@ class UniformField(GeometryField):
396
441
  self.dtype = wp.types.type_to_warp(type(value))
397
442
  self._value = self.dtype(value)
398
443
 
399
- scalar_type = wp.types.type_scalar_type(self.dtype)
400
- if wp.types.type_is_vector(self.dtype):
401
- grad_type = wp.mat(shape=(wp.types.type_length(self.dtype), self.geometry.dimension), dtype=scalar_type)
402
- div_type = scalar_type
403
- elif wp.types.type_is_matrix(self.dtype):
404
- grad_type = None
405
- div_type = wp.vec(length=(wp.types.type_length(self.dtype) // self.geometry.dimension), dtype=scalar_type)
406
- else:
407
- div_type = None
408
- grad_type = wp.vec(length=self.geometry.dimension, dtype=scalar_type)
409
-
410
- self.EvalArg = self._make_eval_arg()
411
- self.ElementEvalArg = self._make_element_eval_arg()
412
- self.eval_degree = self._make_eval_degree()
413
-
414
- self.eval_inner = self._make_eval_inner()
415
- self.eval_grad_inner = self._make_eval_zero(grad_type)
416
- self.eval_div_inner = self._make_eval_zero(div_type)
417
- self.eval_reference_grad_inner = self.eval_grad_inner
418
-
419
- self.eval_outer = self.eval_inner
420
- self.eval_grad_outer = self.eval_grad_inner
421
- self.eval_div_outer = self.eval_div_inner
422
- self.eval_reference_grad_outer = self.eval_reference_grad_inner
444
+ cache.setup_dynamic_attributes(self)
423
445
 
424
446
  @property
425
447
  def value(self):
@@ -444,10 +466,19 @@ class UniformField(GeometryField):
444
466
  arg.value = self.value
445
467
  return arg
446
468
 
469
+ def fill_eval_arg(self, arg, device):
470
+ arg.value = self.value
471
+
447
472
  @property
448
473
  def degree(self) -> int:
449
474
  return 0
450
475
 
476
+ def gradient_valid(self) -> bool:
477
+ return self.gradient_dtype is not None
478
+
479
+ def divergence_valid(self) -> bool:
480
+ return self.divergence_dtype is not None
481
+
451
482
  @property
452
483
  def name(self) -> str:
453
484
  return f"Uniform{self.domain.name}_{wp.types.get_type_code(self.dtype)}"
@@ -463,11 +494,11 @@ class UniformField(GeometryField):
463
494
  if dtype is None:
464
495
  return None
465
496
 
466
- scalar_type = wp.types.type_scalar_type(dtype)
497
+ zero_element = type_zero_element(dtype)
467
498
 
468
499
  @cache.dynamic_func(suffix=f"{self.name}_{wp.types.get_type_code(dtype)}")
469
500
  def eval_zero(args: self.ElementEvalArg, s: Sample):
470
- return dtype(scalar_type(0.0))
501
+ return zero_element()
471
502
 
472
503
  return eval_zero
473
504
 
@@ -511,6 +542,20 @@ class NonconformingField(GeometryField):
511
542
 
512
543
  _LOOKUP_EPS = wp.constant(1.0e-6)
513
544
 
545
+ _dynamic_attribute_constructors: ClassVar = {
546
+ "EvalArg": lambda obj: obj._make_eval_arg(),
547
+ "ElementEvalArg": lambda obj: obj._make_element_eval_arg(),
548
+ "eval_degree": lambda obj: obj._make_eval_degree(),
549
+ "eval_inner": lambda obj: obj._make_nonconforming_eval("eval_inner"),
550
+ "eval_grad_inner": lambda obj: obj._make_nonconforming_eval("eval_grad_inner"),
551
+ "eval_div_inner": lambda obj: obj._make_nonconforming_eval("eval_div_inner"),
552
+ "eval_reference_grad_inner": lambda obj: obj._make_eval_reference_grad(),
553
+ "eval_outer": lambda obj: obj.eval_inner,
554
+ "eval_grad_outer": lambda obj: obj.eval_grad_inner,
555
+ "eval_div_outer": lambda obj: obj.eval_div_inner,
556
+ "eval_reference_grad_outer": lambda obj: obj.eval_reference_grad_inner,
557
+ }
558
+
514
559
  def __init__(self, domain: GeometryDomain, field: DiscreteField, background: Any = 0.0):
515
560
  self.domain = domain
516
561
 
@@ -523,20 +568,7 @@ class NonconformingField(GeometryField):
523
568
  raise ValueError("Background field must be conforming to the domain")
524
569
  self.background = background
525
570
 
526
- self.EvalArg = self._make_eval_arg()
527
- self.ElementEvalArg = self._make_element_eval_arg()
528
- self.eval_degree = self._make_eval_degree()
529
-
530
- self.eval_inner = self._make_nonconforming_eval("eval_inner")
531
- self.eval_grad_inner = self._make_nonconforming_eval("eval_grad_inner")
532
- self.eval_div_inner = self._make_nonconforming_eval("eval_div_inner")
533
- self.eval_reference_grad_inner = self._make_eval_reference_grad()
534
-
535
- # Nonconforming evaluation is position based, does not handle discontinuous fields
536
- self.eval_outer = self.eval_inner
537
- self.eval_grad_outer = self.eval_grad_inner
538
- self.eval_div_outer = self.eval_div_inner
539
- self.eval_reference_grad_outer = self.eval_reference_grad_inner
571
+ cache.setup_dynamic_attributes(self)
540
572
 
541
573
  @property
542
574
  def geometry(self) -> Geometry:
@@ -546,19 +578,26 @@ class NonconformingField(GeometryField):
546
578
  def element_kind(self) -> ElementKind:
547
579
  return self.domain.element_kind
548
580
 
549
- @cache.cached_arg_value
550
581
  def eval_arg_value(self, device):
551
582
  arg = self.EvalArg()
552
- arg.field_cell_eval_arg = self.field.ElementEvalArg()
553
- arg.field_cell_eval_arg.elt_arg = self.field.geometry.cell_arg_value(device)
554
- arg.field_cell_eval_arg.eval_arg = self.field.eval_arg_value(device)
555
- arg.background_arg = self.background.eval_arg_value(device)
583
+ self.fill_eval_arg(arg, device)
556
584
  return arg
557
585
 
586
+ def fill_eval_arg(self, arg, device):
587
+ self.field.fill_eval_arg(arg.field_cell_eval_arg.eval_arg, device)
588
+ self.field.geometry.fill_cell_arg(arg.field_cell_eval_arg.elt_arg, device)
589
+ self.background.fill_eval_arg(arg.background_arg, device)
590
+
558
591
  @property
559
592
  def degree(self) -> int:
560
593
  return self.field.degree
561
594
 
595
+ def gradient_valid(self) -> bool:
596
+ return self.field.gradient_valid() and self.background.gradient_valid()
597
+
598
+ def divergence_valid(self) -> bool:
599
+ return self.field.divergence_valid() and self.background.divergence_valid()
600
+
562
601
  @property
563
602
  def name(self) -> str:
564
603
  return f"{self.domain.name}_{self.field.name}_{self.background.name}"
@@ -570,20 +609,23 @@ class NonconformingField(GeometryField):
570
609
  if field_eval is None or bg_eval is None:
571
610
  return None
572
611
 
612
+ cell_lookup = self.field.geometry.cell_lookup
613
+
573
614
  @cache.dynamic_func(suffix=f"{eval_func_name}_{self.name}")
574
615
  def eval_nc(args: self.ElementEvalArg, s: Sample):
575
616
  pos = self.domain.element_position(args.elt_arg, s)
576
617
  cell_arg = args.eval_arg.field_cell_eval_arg.elt_arg
577
- nonconforming_s = self.field.geometry.cell_lookup(cell_arg, pos)
578
- if (
579
- nonconforming_s.element_index == NULL_ELEMENT_INDEX
580
- or wp.length_sq(pos - self.field.geometry.cell_position(cell_arg, nonconforming_s))
581
- > NonconformingField._LOOKUP_EPS
582
- ):
583
- return bg_eval(self.background.ElementEvalArg(args.elt_arg, args.eval_arg.background_arg), s)
584
- return field_eval(
585
- self.field.ElementEvalArg(cell_arg, args.eval_arg.field_cell_eval_arg.eval_arg), nonconforming_s
586
- )
618
+ nonconforming_s = cell_lookup(cell_arg, pos, NonconformingField._LOOKUP_EPS)
619
+ if nonconforming_s.element_index != NULL_ELEMENT_INDEX:
620
+ if (
621
+ wp.length_sq(pos - self.field.geometry.cell_position(cell_arg, nonconforming_s))
622
+ <= NonconformingField._LOOKUP_EPS
623
+ ):
624
+ return field_eval(
625
+ self.field.ElementEvalArg(cell_arg, args.eval_arg.field_cell_eval_arg.eval_arg), nonconforming_s
626
+ )
627
+
628
+ return bg_eval(self.background.ElementEvalArg(args.elt_arg, args.eval_arg.background_arg), s)
587
629
 
588
630
  return eval_nc
589
631
 
warp/fem/field/nodal_field.py CHANGED
@@ -13,12 +13,13 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Any
16
+ from typing import Any, ClassVar
17
17
 
18
18
  import warp as wp
19
19
  from warp.fem import cache
20
20
  from warp.fem.space import CollocatedFunctionSpace, SpacePartition
21
21
  from warp.fem.types import NULL_NODE_INDEX, ElementIndex, Sample
22
+ from warp.fem.utils import type_zero_element
22
23
 
23
24
  from .field import DiscreteField
24
25
 
@@ -26,26 +27,29 @@ from .field import DiscreteField
26
27
  class NodalFieldBase(DiscreteField):
27
28
  """Base class for nodal field and nodal field traces. Does not hold values"""
28
29
 
30
+ _dynamic_attribute_constructors: ClassVar = {
31
+ "EvalArg": lambda obj: obj._make_eval_arg(),
32
+ "ElementEvalArg": lambda obj: obj._make_element_eval_arg(),
33
+ "eval_degree": lambda obj: DiscreteField._make_eval_degree(obj),
34
+ "_read_node_value": lambda obj: obj._make_read_node_value(),
35
+ "eval_inner": lambda obj: obj._make_eval_inner(),
36
+ "eval_outer": lambda obj: obj._make_eval_outer(),
37
+ "eval_grad_inner": lambda obj: obj._make_eval_grad_inner(world_space=True),
38
+ "eval_grad_outer": lambda obj: obj._make_eval_grad_outer(world_space=True),
39
+ "eval_reference_grad_inner": lambda obj: obj._make_eval_grad_inner(world_space=False),
40
+ "eval_reference_grad_outer": lambda obj: obj._make_eval_grad_outer(world_space=False),
41
+ "eval_div_inner": lambda obj: obj._make_eval_div_inner(),
42
+ "eval_div_outer": lambda obj: obj._make_eval_div_outer(),
43
+ "set_node_value": lambda obj: obj._make_set_node_value(),
44
+ "node_partition_index": lambda obj: obj._make_node_partition_index(),
45
+ "node_count": lambda obj: obj._make_node_count(),
46
+ "node_index": lambda obj: obj._make_node_index(),
47
+ "at_node": lambda obj: obj._make_at_node(),
48
+ }
49
+
29
50
  def __init__(self, space: CollocatedFunctionSpace, space_partition: SpacePartition):
30
51
  super().__init__(space, space_partition)
31
-
32
- self.EvalArg = self._make_eval_arg()
33
- self.ElementEvalArg = self._make_element_eval_arg()
34
- self.eval_degree = DiscreteField._make_eval_degree(self)
35
-
36
- self._read_node_value = self._make_read_node_value()
37
-
38
- self.eval_inner = self._make_eval_inner()
39
- self.eval_outer = self._make_eval_outer()
40
- self.eval_grad_inner = self._make_eval_grad_inner(world_space=True)
41
- self.eval_grad_outer = self._make_eval_grad_outer(world_space=True)
42
- self.eval_reference_grad_inner = self._make_eval_grad_inner(world_space=False)
43
- self.eval_reference_grad_outer = self._make_eval_grad_outer(world_space=False)
44
- self.eval_div_inner = self._make_eval_div_inner()
45
- self.eval_div_outer = self._make_eval_div_outer()
46
-
47
- self.set_node_value = self._make_set_node_value()
48
- self.node_partition_index = self._make_node_partition_index()
52
+ cache.setup_dynamic_attributes(self)
49
53
 
50
54
  def _make_eval_arg(self):
51
55
  @cache.dynamic_struct(suffix=self.name)
@@ -66,6 +70,8 @@ class NodalFieldBase(DiscreteField):
66
70
  return ElementEvalArg
67
71
 
68
72
  def _make_read_node_value(self):
73
+ zero_element = type_zero_element(self.dof_dtype)
74
+
69
75
  @cache.dynamic_func(suffix=self.name)
70
76
  def read_node_value(args: self.ElementEvalArg, geo_element_index: ElementIndex, node_index_in_elt: int):
71
77
  nidx = self.space.topology.element_node_index(
@@ -73,26 +79,29 @@ class NodalFieldBase(DiscreteField):
73
79
  )
74
80
  pidx = self.space_partition.partition_node_index(args.eval_arg.partition_arg, nidx)
75
81
  if pidx == NULL_NODE_INDEX:
76
- return self.space.dof_dtype(0.0)
82
+ return zero_element()
77
83
 
78
84
  return args.eval_arg.dof_values[pidx]
79
85
 
80
86
  return read_node_value
81
87
 
82
88
  def _make_eval_inner(self):
89
+ zero_element = type_zero_element(self.dtype)
90
+
83
91
  @cache.dynamic_func(suffix=self.name)
84
92
  def eval_inner(args: self.ElementEvalArg, s: Sample):
85
93
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
86
94
  node_count = self.space.topology.element_node_count(
87
95
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
88
96
  )
89
- res = self.space.dtype(0.0)
97
+ res = zero_element()
90
98
  for k in range(node_count):
99
+ w = self.space.element_inner_weight(
100
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
101
+ )
91
102
  res += self.space.space_value(
92
103
  self._read_node_value(args, s.element_index, k),
93
- self.space.element_inner_weight(
94
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
95
- ),
104
+ w,
96
105
  local_value_map,
97
106
  )
98
107
  return res
@@ -104,6 +113,7 @@ class NodalFieldBase(DiscreteField):
104
113
  return None
105
114
 
106
115
  gradient_dtype = self.gradient_dtype if world_space else self.reference_gradient_dtype
116
+ zero_element = type_zero_element(gradient_dtype)
107
117
 
108
118
  @cache.dynamic_func(suffix=f"{self.name}{world_space}")
109
119
  def eval_grad_inner(args: self.ElementEvalArg, s: Sample, grad_transform: Any):
@@ -112,7 +122,7 @@ class NodalFieldBase(DiscreteField):
112
122
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
113
123
  )
114
124
 
115
- res = gradient_dtype(0.0)
125
+ res = zero_element()
116
126
  for k in range(node_count):
117
127
  res += self.space.space_gradient(
118
128
  self._read_node_value(args, s.element_index, k),
@@ -144,6 +154,7 @@ class NodalFieldBase(DiscreteField):
144
154
  def _make_eval_div_inner(self):
145
155
  if not self.divergence_valid():
146
156
  return None
157
+ zero_element = type_zero_element(self.divergence_dtype)
147
158
 
148
159
  @cache.dynamic_func(suffix=self.name)
149
160
  def eval_div_inner(args: self.ElementEvalArg, s: Sample):
@@ -153,7 +164,7 @@ class NodalFieldBase(DiscreteField):
153
164
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
154
165
  )
155
166
 
156
- res = self.divergence_dtype(0.0)
167
+ res = zero_element()
157
168
  for k in range(node_count):
158
169
  res += self.space.space_divergence(
159
170
  self._read_node_value(args, s.element_index, k),
@@ -168,6 +179,8 @@ class NodalFieldBase(DiscreteField):
168
179
  return eval_div_inner
169
180
 
170
181
  def _make_eval_outer(self):
182
+ zero_element = type_zero_element(self.dtype)
183
+
171
184
  @cache.dynamic_func(suffix=self.name)
172
185
  def eval_outer(args: self.ElementEvalArg, s: Sample):
173
186
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
@@ -175,7 +188,7 @@ class NodalFieldBase(DiscreteField):
175
188
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
176
189
  )
177
190
 
178
- res = self.dtype(0.0)
191
+ res = zero_element()
179
192
  for k in range(node_count):
180
193
  res += self.space.space_value(
181
194
  self._read_node_value(args, s.element_index, k),
@@ -193,6 +206,7 @@ class NodalFieldBase(DiscreteField):
193
206
  return None
194
207
 
195
208
  gradient_dtype = self.gradient_dtype if world_space else self.reference_gradient_dtype
209
+ zero_element = type_zero_element(gradient_dtype)
196
210
 
197
211
  @cache.dynamic_func(suffix=f"{self.name}{world_space}")
198
212
  def eval_grad_outer(args: self.ElementEvalArg, s: Sample, grad_transform: Any):
@@ -201,7 +215,7 @@ class NodalFieldBase(DiscreteField):
201
215
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
202
216
  )
203
217
 
204
- res = gradient_dtype(0.0)
218
+ res = zero_element()
205
219
  for k in range(node_count):
206
220
  res += self.space.space_gradient(
207
221
  self._read_node_value(args, s.element_index, k),
@@ -234,6 +248,8 @@ class NodalFieldBase(DiscreteField):
234
248
  if not self.divergence_valid():
235
249
  return None
236
250
 
251
+ zero_element = type_zero_element(self.divergence_dtype)
252
+
237
253
  @cache.dynamic_func(suffix=self.name)
238
254
  def eval_div_outer(args: self.ElementEvalArg, s: Sample):
239
255
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
@@ -242,7 +258,7 @@ class NodalFieldBase(DiscreteField):
242
258
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
243
259
  )
244
260
 
245
- res = self.divergence_dtype(0.0)
261
+ res = zero_element()
246
262
  for k in range(node_count):
247
263
  res += self.space.space_divergence(
248
264
  self._read_node_value(args, s.element_index, k),
@@ -279,6 +295,32 @@ class NodalFieldBase(DiscreteField):
279
295
 
280
296
  return node_partition_index
281
297
 
298
+ def _make_node_count(self):
299
+ @cache.dynamic_func(suffix=self.name)
300
+ def node_count(args: self.ElementEvalArg, s: Sample):
301
+ return self.space.topology.element_node_count(args.elt_arg, args.eval_arg.topology_arg, s.element_index)
302
+
303
+ return node_count
304
+
305
+ def _make_at_node(self):
306
+ @cache.dynamic_func(suffix=self.name)
307
+ def at_node(args: self.ElementEvalArg, s: Sample, node_index_in_elt: int):
308
+ node_coords = self.space.node_coords_in_element(
309
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, node_index_in_elt
310
+ )
311
+ return Sample(s.element_index, node_coords, s.qp_index, s.qp_weight, s.test_dof, s.trial_dof)
312
+
313
+ return at_node
314
+
315
+ def _make_node_index(self):
316
+ @cache.dynamic_func(suffix=self.name)
317
+ def node_index(args: self.ElementEvalArg, s: Sample, node_index_in_elt: int):
318
+ return self.space.topology.element_node_index(
319
+ args.elt_arg, args.eval_arg.topology_arg, s.element_index, node_index_in_elt
320
+ )
321
+
322
+ return node_index
323
+
282
324
 
283
325
  class NodalField(NodalFieldBase):
284
326
  """A field holding values for all degrees of freedom at each node of the underlying function space partition
@@ -296,13 +338,15 @@ class NodalField(NodalFieldBase):
296
338
 
297
339
  def eval_arg_value(self, device):
298
340
  arg = self.EvalArg()
299
- arg.dof_values = self._dof_values.to(device)
300
- arg.space_arg = self.space.space_arg_value(device)
301
- arg.partition_arg = self.space_partition.partition_arg_value(device)
302
- arg.topology_arg = self.space.topology.topo_arg_value(device)
303
-
341
+ self.fill_eval_arg(arg, device)
304
342
  return arg
305
343
 
344
+ def fill_eval_arg(self, arg, device):
345
+ arg.dof_values = self._dof_values.to(device)
346
+ self.space.fill_space_arg(arg.space_arg, device)
347
+ self.space_partition.fill_partition_arg(arg.partition_arg, device)
348
+ self.space.topology.fill_topo_arg(arg.topology_arg, device)
349
+
306
350
  @property
307
351
  def dof_values(self) -> wp.array:
308
352
  """Returns a warp array containing the values at all degrees of freedom of the underlying space partition"""
@@ -328,13 +372,15 @@ class NodalField(NodalFieldBase):
328
372
 
329
373
  def eval_arg_value(self, device):
330
374
  arg = self.EvalArg()
331
- arg.dof_values = self._field.dof_values.to(device)
332
- arg.space_arg = self.space.space_arg_value(device)
333
- arg.partition_arg = self.space_partition.partition_arg_value(device)
334
- arg.topology_arg = self.space.topology.topo_arg_value(device)
335
-
375
+ self.fill_eval_arg(arg, device)
336
376
  return arg
337
377
 
378
+ def fill_eval_arg(self, arg, device):
379
+ arg.dof_values = self._field.dof_values.to(device)
380
+ self.space.fill_space_arg(arg.space_arg, device)
381
+ self.space_partition.fill_partition_arg(arg.partition_arg, device)
382
+ self.space.topology.fill_topo_arg(arg.topology_arg, device)
383
+
338
384
  def trace(self) -> Trace:
339
385
  trace_field = NodalField.Trace(self)
340
386
  return trace_field