warp-lang 1.2.1 → 1.3.0 (warp_lang-1.2.1-py3-none-win_amd64.whl → warp_lang-1.3.0-py3-none-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. (The link to further details was not preserved in this copy.)

Files changed (194)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +6 -2
  6. warp/builtins.py +1410 -886
  7. warp/codegen.py +503 -166
  8. warp/config.py +48 -18
  9. warp/context.py +401 -199
  10. warp/dlpack.py +8 -0
  11. warp/examples/assets/bunny.usd +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  13. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  14. warp/examples/benchmarks/benchmark_launches.py +1 -1
  15. warp/examples/core/example_cupy.py +78 -0
  16. warp/examples/fem/example_apic_fluid.py +17 -36
  17. warp/examples/fem/example_burgers.py +9 -18
  18. warp/examples/fem/example_convection_diffusion.py +7 -17
  19. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  20. warp/examples/fem/example_deformed_geometry.py +11 -22
  21. warp/examples/fem/example_diffusion.py +7 -18
  22. warp/examples/fem/example_diffusion_3d.py +24 -28
  23. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  24. warp/examples/fem/example_magnetostatics.py +190 -0
  25. warp/examples/fem/example_mixed_elasticity.py +111 -80
  26. warp/examples/fem/example_navier_stokes.py +30 -34
  27. warp/examples/fem/example_nonconforming_contact.py +290 -0
  28. warp/examples/fem/example_stokes.py +17 -32
  29. warp/examples/fem/example_stokes_transfer.py +12 -21
  30. warp/examples/fem/example_streamlines.py +350 -0
  31. warp/examples/fem/utils.py +936 -0
  32. warp/fabric.py +5 -2
  33. warp/fem/__init__.py +13 -3
  34. warp/fem/cache.py +161 -11
  35. warp/fem/dirichlet.py +37 -28
  36. warp/fem/domain.py +105 -14
  37. warp/fem/field/__init__.py +14 -3
  38. warp/fem/field/field.py +454 -11
  39. warp/fem/field/nodal_field.py +33 -18
  40. warp/fem/geometry/deformed_geometry.py +50 -15
  41. warp/fem/geometry/hexmesh.py +12 -24
  42. warp/fem/geometry/nanogrid.py +106 -31
  43. warp/fem/geometry/quadmesh_2d.py +6 -11
  44. warp/fem/geometry/tetmesh.py +103 -61
  45. warp/fem/geometry/trimesh_2d.py +98 -47
  46. warp/fem/integrate.py +231 -186
  47. warp/fem/operator.py +14 -9
  48. warp/fem/quadrature/pic_quadrature.py +35 -9
  49. warp/fem/quadrature/quadrature.py +119 -32
  50. warp/fem/space/basis_space.py +98 -22
  51. warp/fem/space/collocated_function_space.py +3 -1
  52. warp/fem/space/function_space.py +7 -2
  53. warp/fem/space/grid_2d_function_space.py +3 -3
  54. warp/fem/space/grid_3d_function_space.py +4 -4
  55. warp/fem/space/hexmesh_function_space.py +3 -2
  56. warp/fem/space/nanogrid_function_space.py +12 -14
  57. warp/fem/space/partition.py +45 -47
  58. warp/fem/space/restriction.py +19 -16
  59. warp/fem/space/shape/cube_shape_function.py +91 -3
  60. warp/fem/space/shape/shape_function.py +7 -0
  61. warp/fem/space/shape/square_shape_function.py +32 -0
  62. warp/fem/space/shape/tet_shape_function.py +11 -7
  63. warp/fem/space/shape/triangle_shape_function.py +10 -1
  64. warp/fem/space/topology.py +116 -42
  65. warp/fem/types.py +8 -1
  66. warp/fem/utils.py +301 -83
  67. warp/native/array.h +16 -0
  68. warp/native/builtin.h +0 -15
  69. warp/native/cuda_util.cpp +14 -6
  70. warp/native/exports.h +1348 -1308
  71. warp/native/quat.h +79 -0
  72. warp/native/rand.h +27 -4
  73. warp/native/sparse.cpp +83 -81
  74. warp/native/sparse.cu +381 -453
  75. warp/native/vec.h +64 -0
  76. warp/native/volume.cpp +40 -49
  77. warp/native/volume_builder.cu +2 -3
  78. warp/native/volume_builder.h +12 -17
  79. warp/native/warp.cu +3 -3
  80. warp/native/warp.h +69 -59
  81. warp/render/render_opengl.py +17 -9
  82. warp/sim/articulation.py +117 -17
  83. warp/sim/collide.py +35 -29
  84. warp/sim/model.py +123 -18
  85. warp/sim/render.py +3 -1
  86. warp/sparse.py +867 -203
  87. warp/stubs.py +312 -541
  88. warp/tape.py +29 -1
  89. warp/tests/disabled_kinematics.py +1 -1
  90. warp/tests/test_adam.py +1 -1
  91. warp/tests/test_arithmetic.py +1 -1
  92. warp/tests/test_array.py +58 -1
  93. warp/tests/test_array_reduce.py +1 -1
  94. warp/tests/test_async.py +1 -1
  95. warp/tests/test_atomic.py +1 -1
  96. warp/tests/test_bool.py +1 -1
  97. warp/tests/test_builtins_resolution.py +1 -1
  98. warp/tests/test_bvh.py +6 -1
  99. warp/tests/test_closest_point_edge_edge.py +1 -1
  100. warp/tests/test_codegen.py +66 -1
  101. warp/tests/test_compile_consts.py +1 -1
  102. warp/tests/test_conditional.py +1 -1
  103. warp/tests/test_copy.py +1 -1
  104. warp/tests/test_ctypes.py +1 -1
  105. warp/tests/test_dense.py +1 -1
  106. warp/tests/test_devices.py +1 -1
  107. warp/tests/test_dlpack.py +1 -1
  108. warp/tests/test_examples.py +33 -4
  109. warp/tests/test_fabricarray.py +5 -2
  110. warp/tests/test_fast_math.py +1 -1
  111. warp/tests/test_fem.py +213 -6
  112. warp/tests/test_fp16.py +1 -1
  113. warp/tests/test_func.py +1 -1
  114. warp/tests/test_future_annotations.py +90 -0
  115. warp/tests/test_generics.py +1 -1
  116. warp/tests/test_grad.py +1 -1
  117. warp/tests/test_grad_customs.py +1 -1
  118. warp/tests/test_grad_debug.py +247 -0
  119. warp/tests/test_hash_grid.py +6 -1
  120. warp/tests/test_implicit_init.py +354 -0
  121. warp/tests/test_import.py +1 -1
  122. warp/tests/test_indexedarray.py +1 -1
  123. warp/tests/test_intersect.py +1 -1
  124. warp/tests/test_jax.py +1 -1
  125. warp/tests/test_large.py +1 -1
  126. warp/tests/test_launch.py +1 -1
  127. warp/tests/test_lerp.py +1 -1
  128. warp/tests/test_linear_solvers.py +1 -1
  129. warp/tests/test_lvalue.py +1 -1
  130. warp/tests/test_marching_cubes.py +5 -2
  131. warp/tests/test_mat.py +34 -35
  132. warp/tests/test_mat_lite.py +2 -1
  133. warp/tests/test_mat_scalar_ops.py +1 -1
  134. warp/tests/test_math.py +1 -1
  135. warp/tests/test_matmul.py +20 -16
  136. warp/tests/test_matmul_lite.py +1 -1
  137. warp/tests/test_mempool.py +1 -1
  138. warp/tests/test_mesh.py +5 -2
  139. warp/tests/test_mesh_query_aabb.py +1 -1
  140. warp/tests/test_mesh_query_point.py +1 -1
  141. warp/tests/test_mesh_query_ray.py +1 -1
  142. warp/tests/test_mlp.py +1 -1
  143. warp/tests/test_model.py +1 -1
  144. warp/tests/test_module_hashing.py +77 -1
  145. warp/tests/test_modules_lite.py +1 -1
  146. warp/tests/test_multigpu.py +1 -1
  147. warp/tests/test_noise.py +1 -1
  148. warp/tests/test_operators.py +1 -1
  149. warp/tests/test_options.py +1 -1
  150. warp/tests/test_overwrite.py +542 -0
  151. warp/tests/test_peer.py +1 -1
  152. warp/tests/test_pinned.py +1 -1
  153. warp/tests/test_print.py +1 -1
  154. warp/tests/test_quat.py +15 -1
  155. warp/tests/test_rand.py +1 -1
  156. warp/tests/test_reload.py +1 -1
  157. warp/tests/test_rounding.py +1 -1
  158. warp/tests/test_runlength_encode.py +1 -1
  159. warp/tests/test_scalar_ops.py +95 -0
  160. warp/tests/test_sim_grad.py +1 -1
  161. warp/tests/test_sim_kinematics.py +1 -1
  162. warp/tests/test_smoothstep.py +1 -1
  163. warp/tests/test_sparse.py +82 -15
  164. warp/tests/test_spatial.py +1 -1
  165. warp/tests/test_special_values.py +2 -11
  166. warp/tests/test_streams.py +11 -1
  167. warp/tests/test_struct.py +1 -1
  168. warp/tests/test_tape.py +1 -1
  169. warp/tests/test_torch.py +194 -1
  170. warp/tests/test_transient_module.py +1 -1
  171. warp/tests/test_types.py +1 -1
  172. warp/tests/test_utils.py +1 -1
  173. warp/tests/test_vec.py +15 -63
  174. warp/tests/test_vec_lite.py +2 -1
  175. warp/tests/test_vec_scalar_ops.py +122 -39
  176. warp/tests/test_verify_fp.py +1 -1
  177. warp/tests/test_volume.py +28 -2
  178. warp/tests/test_volume_write.py +1 -1
  179. warp/tests/unittest_serial.py +1 -1
  180. warp/tests/unittest_suites.py +9 -1
  181. warp/tests/walkthrough_debug.py +1 -1
  182. warp/thirdparty/unittest_parallel.py +2 -5
  183. warp/torch.py +103 -41
  184. warp/types.py +344 -227
  185. warp/utils.py +11 -2
  186. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
  187. warp_lang-1.3.0.dist-info/RECORD +368 -0
  188. warp/examples/fem/bsr_utils.py +0 -378
  189. warp/examples/fem/mesh_utils.py +0 -133
  190. warp/examples/fem/plot_utils.py +0 -292
  191. warp_lang-1.2.1.dist-info/RECORD +0 -359
  192. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
  193. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
  194. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
warp/fabric.py CHANGED
@@ -105,11 +105,14 @@ class fabricarray(noncontiguous_array_base[T]):
105
105
  # (initialized when needed)
106
106
  _vars = None
107
107
 
108
+ def __new__(cls, *args, **kwargs):
109
+ instance = super(fabricarray, cls).__new__(cls)
110
+ instance.deleter = None
111
+ return instance
112
+
108
113
  def __init__(self, data=None, attrib=None, dtype=Any, ndim=None):
109
114
  super().__init__(ARRAY_TYPE_FABRIC)
110
115
 
111
- self.deleter = None
112
-
113
116
  if data is not None:
114
117
  from .context import runtime
115
118
 
warp/fem/__init__.py CHANGED
@@ -1,7 +1,17 @@
1
1
  from .cache import TemporaryStore, borrow_temporary, borrow_temporary_like, set_default_temporary_store
2
2
  from .dirichlet import normalize_dirichlet_projector, project_linear_system
3
- from .domain import BoundarySides, Cells, FrontierSides, GeometryDomain, Sides
4
- from .field import DiscreteField, FieldLike, make_restriction, make_test, make_trial
3
+ from .domain import BoundarySides, Cells, FrontierSides, GeometryDomain, Sides, Subdomain
4
+ from .field import (
5
+ DiscreteField,
6
+ FieldLike,
7
+ ImplicitField,
8
+ NonconformingField,
9
+ UniformField,
10
+ make_discrete_field,
11
+ make_restriction,
12
+ make_test,
13
+ make_trial,
14
+ )
5
15
  from .geometry import (
6
16
  ExplicitGeometryPartition,
7
17
  Geometry,
@@ -58,4 +68,4 @@ from .space import (
58
68
  make_space_partition,
59
69
  make_space_restriction,
60
70
  )
61
- from .types import Coords, Domain, ElementIndex, Field, Sample
71
+ from .types import NULL_ELEMENT_INDEX, Coords, Domain, ElementIndex, Field, Sample, make_free_sample
warp/fem/cache.py CHANGED
@@ -1,5 +1,7 @@
1
+ import ast
1
2
  import bisect
2
3
  import re
4
+ import weakref
3
5
  from copy import copy
4
6
  from typing import Any, Callable, Dict, Optional, Tuple, Union
5
7
 
@@ -17,7 +19,7 @@ def _make_key(obj, suffix: str, use_qualified_name):
17
19
  return _key_re.sub("", f"{base_name}_{suffix}")
18
20
 
19
21
 
20
- def get_func(func, suffix: str, use_qualified_name: bool = False):
22
+ def get_func(func, suffix: str, use_qualified_name: bool = False, code_transformers=None):
21
23
  key = _make_key(func, suffix, use_qualified_name)
22
24
 
23
25
  if key not in _func_cache:
@@ -28,14 +30,15 @@ def get_func(func, suffix: str, use_qualified_name: bool = False):
28
30
  module=wp.get_module(
29
31
  func.__module__,
30
32
  ),
33
+ code_transformers=code_transformers,
31
34
  )
32
35
 
33
36
  return _func_cache[key]
34
37
 
35
38
 
36
- def dynamic_func(suffix: str, use_qualified_name=False):
39
+ def dynamic_func(suffix: str, use_qualified_name=False, code_transformers=None):
37
40
  def wrap_func(func: Callable):
38
- return get_func(func, suffix=suffix, use_qualified_name=use_qualified_name)
41
+ return get_func(func, suffix=suffix, use_qualified_name=use_qualified_name, code_transformers=code_transformers)
39
42
 
40
43
  return wrap_func
41
44
 
@@ -96,6 +99,92 @@ def dynamic_struct(suffix: str, use_qualified_name=False):
96
99
  return wrap_struct
97
100
 
98
101
 
102
+ def get_argument_struct(arg_types: Dict[str, type]):
103
+ class Args:
104
+ pass
105
+
106
+ annotations = wp.codegen.get_annotations(Args)
107
+
108
+ for name, arg_type in arg_types.items():
109
+ setattr(Args, name, None)
110
+ annotations[name] = arg_type
111
+
112
+ def arg_type_name(arg_type):
113
+ return wp.types.get_type_code(wp.types.type_to_warp(arg_type))
114
+
115
+ try:
116
+ Args.__annotations__ = annotations
117
+ except AttributeError:
118
+ Args.__dict__.__annotations__ = annotations
119
+
120
+ suffix = "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])
121
+
122
+ return get_struct(Args, suffix=suffix)
123
+
124
+
125
+ def populate_argument_struct(Args: wp.codegen.Struct, values: Dict[str, Any], func_name: str):
126
+ if values is None:
127
+ values = {}
128
+
129
+ value_struct_values = Args()
130
+ for k, v in values.items():
131
+ try:
132
+ setattr(value_struct_values, k, v)
133
+ except Exception as err:
134
+ if k not in Args.vars:
135
+ raise ValueError(
136
+ f"Passed value argument '{k}' does not match any of the function '{func_name}' parameters"
137
+ ) from err
138
+ raise ValueError(
139
+ f"Passed value argument '{k}' of type '{wp.types.type_repr(v)}' is incompatible with the function '{func_name}' parameter of type '{wp.types.type_repr(Args.vars[k].type)}'"
140
+ ) from err
141
+
142
+ missing_values = Args.vars.keys() - values.keys()
143
+ if missing_values:
144
+ wp.utils.warn(
145
+ f"Missing values for parameter(s) '{', '.join(missing_values)}' of the function '{func_name}', will be zero-initialized"
146
+ )
147
+
148
+ return value_struct_values
149
+
150
+
151
+ class ExpandStarredArgumentStruct(ast.NodeTransformer):
152
+ def __init__(
153
+ self,
154
+ structs: Dict[str, wp.codegen.Struct],
155
+ ):
156
+ self._structs = structs
157
+
158
+ @staticmethod
159
+ def _build_path(path, node):
160
+ if isinstance(node, ast.Attribute):
161
+ ExpandStarredArgumentStruct._build_path(path, node.value)
162
+ path.append(node.attr)
163
+ if isinstance(node, ast.Name):
164
+ path.append(node.id)
165
+ return path
166
+
167
+ def _get_expanded_struct(self, arg_node):
168
+ if not isinstance(arg_node, ast.Starred):
169
+ return None
170
+ path = ".".join(ExpandStarredArgumentStruct._build_path([], arg_node.value))
171
+ return self._structs.get(path, None)
172
+
173
+ def visit_Call(self, call: ast.Call):
174
+ call = self.generic_visit(call)
175
+
176
+ expanded_args = []
177
+ for arg in call.args:
178
+ struct = self._get_expanded_struct(arg)
179
+ if struct is None:
180
+ expanded_args.append(arg)
181
+ else:
182
+ expanded_args += [ast.Attribute(value=arg.value, attr=field) for field in struct.vars.keys()]
183
+ call.args = expanded_args
184
+
185
+ return call
186
+
187
+
99
188
  def get_integrand_function(
100
189
  integrand: "warp.fem.operator.Integrand", # noqa: F821
101
190
  suffix: str,
@@ -103,9 +192,6 @@ def get_integrand_function(
103
192
  annotations=None,
104
193
  code_transformers=None,
105
194
  ):
106
- if code_transformers is None:
107
- code_transformers = []
108
-
109
195
  key = _make_key(integrand.func, suffix, use_qualified_name=True)
110
196
 
111
197
  if key not in _func_cache:
@@ -131,9 +217,6 @@ def get_integrand_kernel(
131
217
  if kernel_options is None:
132
218
  kernel_options = {}
133
219
 
134
- if code_transformers is None:
135
- code_transformers = []
136
-
137
220
  key = _make_key(integrand.func, suffix, use_qualified_name=True)
138
221
 
139
222
  if key not in _kernel_cache:
@@ -198,11 +281,22 @@ class Temporary:
198
281
  The temporary may also be explicitly returned to the pool before destruction using :meth:`release`.
199
282
  """
200
283
 
284
+ def __new__(cls, *args, **kwargs):
285
+ instance = super(Temporary, cls).__new__(cls)
286
+ instance._pool = None
287
+ return instance
288
+
201
289
  def __init__(self, array: wp.array, pool: Optional["TemporaryStore.Pool"] = None, shape=None, dtype=None):
202
290
  self._raw_array = array
203
291
  self._array_view = array
204
292
  self._pool = pool
205
293
 
294
+ if pool is not None and wp.context.runtime.tape is not None:
295
+ # Extend lifetime to that of Tape (or Pool if shorter)
296
+ # This is to prevent temporary arrays held in tape launch parameters to be redeemed
297
+ pool.hold(self)
298
+ weakref.finalize(wp.context.runtime.tape, TemporaryStore.Pool.stop_holding, pool, self)
299
+
206
300
  if shape is not None or dtype is not None:
207
301
  self._view_as(shape=shape, dtype=dtype)
208
302
 
@@ -270,6 +364,8 @@ class TemporaryStore:
270
364
  self._pool_sizes = [] # Sizes of available arrays for borrowing, ascending
271
365
  self._allocs = {} # All allocated arrays, including borrowed ones
272
366
 
367
+ self._held_temporaries = set() # Temporaries that are prevented from going out of scope
368
+
273
369
  def borrow(self, shape, dtype, requires_grad: bool):
274
370
  size = 1
275
371
  if isinstance(shape, int):
@@ -285,8 +381,12 @@ class TemporaryStore:
285
381
  # Big enough array found, remove from pool
286
382
  array = self._pool.pop(index)
287
383
  self._pool_sizes.pop(index)
288
- if requires_grad and array.grad is None:
289
- array.requires_grad = True
384
+ if requires_grad:
385
+ if array.grad is None:
386
+ array.requires_grad = True
387
+ else:
388
+ # Zero-out existing gradient to mimic semantics of wp.empty()
389
+ array.grad.zero_()
290
390
  return Temporary(pool=self, array=array, shape=shape, dtype=dtype)
291
391
 
292
392
  # No big enough array found, allocate new one
@@ -312,6 +412,12 @@ class TemporaryStore:
312
412
  def detach(self, array):
313
413
  del self._allocs[array.ptr]
314
414
 
415
+ def hold(self, temp: Temporary):
416
+ self._held_temporaries.add(temp)
417
+
418
+ def stop_holding(self, temp: Temporary):
419
+ self._held_temporaries.remove(temp)
420
+
315
421
  def __init__(self):
316
422
  self.clear()
317
423
 
@@ -401,3 +507,47 @@ def borrow_temporary_like(
401
507
  device=array.device,
402
508
  requires_grad=array.requires_grad,
403
509
  )
510
+
511
+
512
+ _device_events = {}
513
+
514
+
515
+ def capture_event(device=None):
516
+ """
517
+ Records a CUDA event on the current stream and returns it,
518
+ reusing previously created events if possible.
519
+
520
+ If the current device is not a CUDA device, returns ``None``.
521
+
522
+ The event can be returned to the shared per-device pool for future reuse by
523
+ calling :func:`synchronize_event`
524
+ """
525
+
526
+ device = wp.get_device(device)
527
+ if not device.is_cuda:
528
+ return None
529
+
530
+ try:
531
+ device_events = _device_events[device.ordinal]
532
+ except KeyError:
533
+ device_events = []
534
+ _device_events[device.ordinal] = device_events
535
+
536
+ with wp.ScopedDevice(device):
537
+ if not device_events:
538
+ return wp.record_event()
539
+
540
+ return wp.record_event(device_events.pop())
541
+
542
+
543
+ def synchronize_event(event: Union[wp.Event, None]):
544
+ """
545
+ Synchronize an event created with :func:`capture_event` and returns it to the
546
+ per-device event pool.
547
+
548
+ If `event` is ``None``, do nothing.
549
+ """
550
+
551
+ if event is not None:
552
+ wp.synchronize_event(event)
553
+ _device_events[event.device.ordinal].append(event)
warp/fem/dirichlet.py CHANGED
@@ -1,11 +1,10 @@
1
1
  from typing import Any, Optional
2
2
 
3
3
  import warp as wp
4
+ from warp.fem.utils import array_axpy, symmetric_eigenvalues_qr
4
5
  from warp.sparse import BsrMatrix, bsr_assign, bsr_axpy, bsr_copy, bsr_mm, bsr_mv
5
6
  from warp.types import type_is_matrix, type_length
6
7
 
7
- from .utils import array_axpy
8
-
9
8
 
10
9
  def normalize_dirichlet_projector(projector_matrix: BsrMatrix, fixed_value: Optional[wp.array] = None):
11
10
  """
@@ -115,11 +114,34 @@ def project_linear_system(
115
114
  project_system_matrix(system_matrix, projector_matrix)
116
115
 
117
116
 
117
+ @wp.func
118
+ def _normalize_projector_and_value(A: Any, b: Any):
119
+ # Do a modal decomposition of the left and right hand side,
120
+ # Make lhs an orthogonal projection and apply corresponding scaling to righ-hand-side
121
+
122
+ eps = wp.trace(A) * A.dtype(1.0e-6)
123
+
124
+ diag, ev = symmetric_eigenvalues_qr(A, eps * eps)
125
+
126
+ zero = A.dtype(0)
127
+ A_unitary = type(A)(zero)
128
+ b_normalized = type(b)(zero)
129
+
130
+ for k in range(b.length):
131
+ if diag[k] > eps: # ignore small eigenmodes
132
+ v = ev[k]
133
+ A_unitary += wp.outer(v, v)
134
+ b_normalized += wp.dot(v, b) / diag[k] * v
135
+
136
+ return A_unitary, b_normalized
137
+
138
+
118
139
  @wp.kernel
119
- def _normalize_dirichlet_projector_kernel(
140
+ def _normalize_dirichlet_projector_and_values_kernel(
120
141
  offsets: wp.array(dtype=int),
121
142
  columns: wp.array(dtype=int),
122
143
  block_values: wp.array(dtype=Any),
144
+ fixed_values: wp.array(dtype=Any),
123
145
  ):
124
146
  row = wp.tid()
125
147
 
@@ -131,26 +153,21 @@ def _normalize_dirichlet_projector_kernel(
131
153
 
132
154
  diag = wp.lower_bound(columns, beg, end, row)
133
155
 
134
- if diag < end and columns[diag] == row:
135
- P = block_values[diag]
156
+ if diag < end:
157
+ if columns[diag] == row:
158
+ P = block_values[diag]
136
159
 
137
- P_sq = P * P
138
- trace_P = wp.trace(P)
139
- trace_P_sq = wp.trace(P_sq)
160
+ P_norm, v_norm = _normalize_projector_and_value(P, fixed_values[row])
140
161
 
141
- if wp.nonzero(trace_P_sq):
142
- scale = trace_P / trace_P_sq
143
- block_values[diag] = scale * P
144
- else:
145
- block_values[diag] = P - P
162
+ block_values[diag] = P_norm
163
+ fixed_values[row] = v_norm
146
164
 
147
165
 
148
166
  @wp.kernel
149
- def _normalize_dirichlet_projector_and_values_kernel(
167
+ def _normalize_dirichlet_projector_kernel(
150
168
  offsets: wp.array(dtype=int),
151
169
  columns: wp.array(dtype=int),
152
170
  block_values: wp.array(dtype=Any),
153
- fixed_values: wp.array(dtype=Any),
154
171
  ):
155
172
  row = wp.tid()
156
173
 
@@ -162,17 +179,9 @@ def _normalize_dirichlet_projector_and_values_kernel(
162
179
 
163
180
  diag = wp.lower_bound(columns, beg, end, row)
164
181
 
165
- if diag < end and columns[diag] == row:
166
- P = block_values[diag]
167
-
168
- P_sq = P * P
169
- trace_P = wp.trace(P)
170
- trace_P_sq = wp.trace(P_sq)
182
+ if diag < end:
183
+ if columns[diag] == row:
184
+ P = block_values[diag]
171
185
 
172
- if wp.nonzero(trace_P_sq):
173
- scale = trace_P / trace_P_sq
174
- block_values[diag] = scale * P
175
- fixed_values[row] = scale * fixed_values[row]
176
- else:
177
- block_values[diag] = P - P
178
- fixed_values[row] = fixed_values[row] - fixed_values[row]
186
+ P_norm, v_norm = _normalize_projector_and_value(P, type(P[0])())
187
+ block_values[diag] = P_norm
warp/fem/domain.py CHANGED
@@ -1,15 +1,17 @@
1
- from enum import Enum
2
- from typing import Union
1
+ from typing import Optional, Union
3
2
 
4
3
  import warp as wp
5
4
  import warp.codegen
6
5
  import warp.context
6
+ import warp.fem.cache as cache
7
+ import warp.fem.utils as utils
7
8
  from warp.fem.geometry import (
8
9
  Element,
9
10
  Geometry,
10
11
  GeometryPartition,
11
12
  WholeGeometryPartition,
12
13
  )
14
+ from warp.fem.types import ElementKind
13
15
 
14
16
  GeometryOrPartition = Union[Geometry, GeometryPartition]
15
17
 
@@ -17,12 +19,6 @@ GeometryOrPartition = Union[Geometry, GeometryPartition]
17
19
  class GeometryDomain:
18
20
  """Interface class for domains, i.e. (partial) views of elements in a Geometry"""
19
21
 
20
- class ElementKind(Enum):
21
- """Possible kinds of elements contained in a domain"""
22
-
23
- CELL = 0
24
- SIDE = 1
25
-
26
22
  def __init__(self, geometry: GeometryOrPartition):
27
23
  if isinstance(geometry, GeometryPartition):
28
24
  self.geometry_partition = geometry
@@ -106,8 +102,8 @@ class Cells(GeometryDomain):
106
102
  super().__init__(geometry)
107
103
 
108
104
  @property
109
- def element_kind(self) -> GeometryDomain.ElementKind:
110
- return GeometryDomain.ElementKind.CELL
105
+ def element_kind(self) -> ElementKind:
106
+ return ElementKind.CELL
111
107
 
112
108
  @property
113
109
  def dimension(self) -> int:
@@ -157,7 +153,7 @@ class Cells(GeometryDomain):
157
153
  return self.geometry.cell_measure_ratio
158
154
 
159
155
  @property
160
- def eval_normal(self) -> wp.Function:
156
+ def element_normal(self) -> wp.Function:
161
157
  return self.geometry.cell_normal
162
158
 
163
159
  @property
@@ -172,9 +168,11 @@ class Sides(GeometryDomain):
172
168
  self.geometry = geometry
173
169
  super().__init__(geometry)
174
170
 
171
+ self.element_lookup = None
172
+
175
173
  @property
176
- def element_kind(self) -> GeometryDomain.ElementKind:
177
- return GeometryDomain.ElementKind.SIDE
174
+ def element_kind(self) -> ElementKind:
175
+ return ElementKind.SIDE
178
176
 
179
177
  @property
180
178
  def dimension(self) -> int:
@@ -224,7 +222,7 @@ class Sides(GeometryDomain):
224
222
  return self.geometry.side_measure_ratio
225
223
 
226
224
  @property
227
- def eval_normal(self) -> wp.Function:
225
+ def element_normal(self) -> wp.Function:
228
226
  return self.geometry.side_normal
229
227
 
230
228
 
@@ -260,3 +258,96 @@ class FrontierSides(Sides):
260
258
  @property
261
259
  def element_index(self) -> wp.Function:
262
260
  return self.geometry_partition.frontier_side_index
261
+
262
+
263
+ class Subdomain(GeometryDomain):
264
+ """Subdomain -- restriction of domain to a subset of its elements"""
265
+
266
+ def __init__(
267
+ self,
268
+ domain: GeometryDomain,
269
+ element_mask: Optional[wp.array] = None,
270
+ element_indices: Optional[wp.array] = None,
271
+ temporary_store: Optional[cache.TemporaryStore] = None,
272
+ ):
273
+ """
274
+ Create a subdomain from a subset of elements.
275
+
276
+ Exactly one of `element_mask` and `element_indices` should be provided.
277
+
278
+ Args:
279
+ domain: the containing domain
280
+ element_mask: Array of length ``domain.element_count()`` indicating which elements should be included. Array values must be either ``1`` (selected) or ``0`` (not selected).
281
+ element_indices: Explicit array of element indices to include
282
+ """
283
+
284
+ super().__init__(domain.geometry_partition)
285
+
286
+ if element_indices is None:
287
+ if element_mask is None:
288
+ raise ValueError("Either 'element_mask' or 'element_indices' should be provided")
289
+ element_indices, _ = utils.masked_indices(mask=element_mask, temporary_store=temporary_store)
290
+ element_indices = element_indices.detach()
291
+ elif element_mask is not None:
292
+ raise ValueError("Only one of 'element_mask' and 'element_indices' should be provided")
293
+
294
+ self._domain = domain
295
+ self._element_indices = element_indices
296
+ self.ElementIndexArg = self._make_element_index_arg()
297
+ self.element_index = self._make_element_index()
298
+
299
+ # forward
300
+ self.ElementArg = self._domain.ElementArg
301
+ self.geometry_element_count = self._domain.geometry_element_count
302
+ self.reference_element = self._domain.reference_element
303
+ self.element_arg_value = self._domain.element_arg_value
304
+ self.element_measure = self._domain.element_measure
305
+ self.element_measure_ratio = self._domain.element_measure_ratio
306
+ self.element_position = self._domain.element_position
307
+ self.element_deformation_gradient = self._domain.element_deformation_gradient
308
+ self.element_lookup = self._domain.element_lookup
309
+ self.element_normal = self._domain.element_normal
310
+
311
+ @property
312
+ def name(self) -> str:
313
+ return f"{self._domain.name}_Subdomain"
314
+
315
+ def __eq__(self, other) -> bool:
316
+ return (
317
+ self.__class__ == other.__class__
318
+ and self.geometry_partition == other.geometry_partition
319
+ and self._element_indices == other._element_indices
320
+ )
321
+
322
+ @property
323
+ def element_kind(self) -> ElementKind:
324
+ return self._domain.element_kind
325
+
326
+ @property
327
+ def dimension(self) -> int:
328
+ return self._domain.dimension
329
+
330
+ def element_count(self) -> int:
331
+ return self._element_indices.shape[0]
332
+
333
+ def _make_element_index_arg(self):
334
+ @cache.dynamic_struct(suffix=self.name)
335
+ class ElementIndexArg:
336
+ domain_arg: self._domain.ElementIndexArg
337
+ element_indices: wp.array(dtype=int)
338
+
339
+ return ElementIndexArg
340
+
341
+ @cache.cached_arg_value
342
+ def element_index_arg_value(self, device: warp.context.Devicelike):
343
+ arg = self.ElementIndexArg()
344
+ arg.domain_arg = self._domain.element_index_arg_value(device)
345
+ arg.element_indices = self._element_indices.to(device)
346
+ return arg
347
+
348
+ def _make_element_index(self) -> wp.Function:
349
+ @cache.dynamic_func(suffix=self.name)
350
+ def element_index(arg: self.ElementIndexArg, index: int):
351
+ return self._domain.element_index(arg.domain_arg, arg.element_indices[index])
352
+
353
+ return element_index
warp/fem/field/__init__.py CHANGED
@@ -3,7 +3,7 @@ from typing import Optional, Union
3
3
  from warp.fem.domain import Cells, GeometryDomain
4
4
  from warp.fem.space import FunctionSpace, SpacePartition, SpaceRestriction, make_space_partition, make_space_restriction
5
5
 
6
- from .field import DiscreteField, FieldLike, SpaceField
6
+ from .field import DiscreteField, FieldLike, GeometryField, ImplicitField, NonconformingField, SpaceField, UniformField
7
7
  from .nodal_field import NodalField
8
8
  from .restriction import FieldRestriction
9
9
  from .test import TestField
@@ -85,8 +85,8 @@ def make_trial(
85
85
  """
86
86
 
87
87
  if space_restriction is not None:
88
- domain = space.domain
89
- space_partition = space.space_partition
88
+ domain = space_restriction.domain
89
+ space_partition = space_restriction.space_partition
90
90
 
91
91
  if space_partition is None:
92
92
  if domain is None:
@@ -98,3 +98,14 @@ def make_trial(
98
98
  domain = Cells(geometry=space_partition.geo_partition)
99
99
 
100
100
  return TrialField(space, space_partition, domain)
101
+
102
+
103
+ def make_discrete_field(
104
+ space: FunctionSpace,
105
+ space_partition: Optional[SpacePartition] = None,
106
+ ) -> DiscreteField:
107
+ """Constructs a zero-initialized discrete field over a function space or partition
108
+
109
+ See also: :meth:`warp.fem.FunctionSpace.make_field`
110
+ """
111
+ return space.make_field(space_partition=space_partition)