warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.1__py3-none-macosx_10_13_universal2.whl

This diff shows the changes between publicly available versions of a package released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. The original page linked to further details at this point; that link was not preserved in this copy.

Files changed (192)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +130 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +272 -104
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +770 -238
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_callable.py +34 -4
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/interop/example_jax_kernel.py +27 -1
  37. warp/examples/optim/example_drone.py +1 -1
  38. warp/examples/sim/example_cloth.py +1 -1
  39. warp/examples/sim/example_cloth_self_contact.py +48 -54
  40. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  41. warp/examples/tile/example_tile_cholesky.py +2 -1
  42. warp/examples/tile/example_tile_convolution.py +1 -1
  43. warp/examples/tile/example_tile_filtering.py +1 -1
  44. warp/examples/tile/example_tile_matmul.py +1 -1
  45. warp/examples/tile/example_tile_mlp.py +2 -0
  46. warp/fabric.py +7 -7
  47. warp/fem/__init__.py +5 -0
  48. warp/fem/adaptivity.py +1 -1
  49. warp/fem/cache.py +152 -63
  50. warp/fem/dirichlet.py +2 -2
  51. warp/fem/domain.py +136 -6
  52. warp/fem/field/field.py +141 -99
  53. warp/fem/field/nodal_field.py +85 -39
  54. warp/fem/field/virtual.py +99 -52
  55. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  56. warp/fem/geometry/closest_point.py +13 -0
  57. warp/fem/geometry/deformed_geometry.py +102 -40
  58. warp/fem/geometry/element.py +56 -2
  59. warp/fem/geometry/geometry.py +323 -22
  60. warp/fem/geometry/grid_2d.py +157 -62
  61. warp/fem/geometry/grid_3d.py +116 -20
  62. warp/fem/geometry/hexmesh.py +86 -20
  63. warp/fem/geometry/nanogrid.py +166 -86
  64. warp/fem/geometry/partition.py +59 -25
  65. warp/fem/geometry/quadmesh.py +86 -135
  66. warp/fem/geometry/tetmesh.py +47 -119
  67. warp/fem/geometry/trimesh.py +77 -270
  68. warp/fem/integrate.py +181 -95
  69. warp/fem/linalg.py +25 -58
  70. warp/fem/operator.py +124 -27
  71. warp/fem/quadrature/pic_quadrature.py +36 -14
  72. warp/fem/quadrature/quadrature.py +40 -16
  73. warp/fem/space/__init__.py +1 -1
  74. warp/fem/space/basis_function_space.py +66 -46
  75. warp/fem/space/basis_space.py +17 -4
  76. warp/fem/space/dof_mapper.py +1 -1
  77. warp/fem/space/function_space.py +2 -2
  78. warp/fem/space/grid_2d_function_space.py +4 -1
  79. warp/fem/space/hexmesh_function_space.py +4 -2
  80. warp/fem/space/nanogrid_function_space.py +3 -1
  81. warp/fem/space/partition.py +11 -2
  82. warp/fem/space/quadmesh_function_space.py +4 -1
  83. warp/fem/space/restriction.py +5 -2
  84. warp/fem/space/shape/__init__.py +10 -8
  85. warp/fem/space/tetmesh_function_space.py +4 -1
  86. warp/fem/space/topology.py +52 -21
  87. warp/fem/space/trimesh_function_space.py +4 -1
  88. warp/fem/utils.py +53 -8
  89. warp/jax.py +1 -2
  90. warp/jax_experimental/ffi.py +210 -67
  91. warp/jax_experimental/xla_ffi.py +37 -24
  92. warp/math.py +171 -1
  93. warp/native/array.h +103 -4
  94. warp/native/builtin.h +182 -35
  95. warp/native/coloring.cpp +6 -2
  96. warp/native/cuda_util.cpp +1 -1
  97. warp/native/exports.h +118 -63
  98. warp/native/intersect.h +5 -5
  99. warp/native/mat.h +8 -13
  100. warp/native/mathdx.cpp +11 -5
  101. warp/native/matnn.h +1 -123
  102. warp/native/mesh.h +1 -1
  103. warp/native/quat.h +34 -6
  104. warp/native/rand.h +7 -7
  105. warp/native/sparse.cpp +121 -258
  106. warp/native/sparse.cu +181 -274
  107. warp/native/spatial.h +305 -17
  108. warp/native/svd.h +23 -8
  109. warp/native/tile.h +603 -73
  110. warp/native/tile_radix_sort.h +1112 -0
  111. warp/native/tile_reduce.h +239 -13
  112. warp/native/tile_scan.h +240 -0
  113. warp/native/tuple.h +189 -0
  114. warp/native/vec.h +10 -20
  115. warp/native/warp.cpp +36 -4
  116. warp/native/warp.cu +588 -52
  117. warp/native/warp.h +47 -74
  118. warp/optim/linear.py +5 -1
  119. warp/paddle.py +7 -8
  120. warp/py.typed +0 -0
  121. warp/render/render_opengl.py +110 -80
  122. warp/render/render_usd.py +124 -62
  123. warp/sim/__init__.py +9 -0
  124. warp/sim/collide.py +253 -80
  125. warp/sim/graph_coloring.py +8 -1
  126. warp/sim/import_mjcf.py +4 -3
  127. warp/sim/import_usd.py +11 -7
  128. warp/sim/integrator.py +5 -2
  129. warp/sim/integrator_euler.py +1 -1
  130. warp/sim/integrator_featherstone.py +1 -1
  131. warp/sim/integrator_vbd.py +761 -322
  132. warp/sim/integrator_xpbd.py +1 -1
  133. warp/sim/model.py +265 -260
  134. warp/sim/utils.py +10 -7
  135. warp/sparse.py +303 -166
  136. warp/tape.py +54 -51
  137. warp/tests/cuda/test_conditional_captures.py +1046 -0
  138. warp/tests/cuda/test_streams.py +1 -1
  139. warp/tests/geometry/test_volume.py +2 -2
  140. warp/tests/interop/test_dlpack.py +9 -9
  141. warp/tests/interop/test_jax.py +0 -1
  142. warp/tests/run_coverage_serial.py +1 -1
  143. warp/tests/sim/disabled_kinematics.py +2 -2
  144. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  145. warp/tests/sim/test_collision.py +159 -51
  146. warp/tests/sim/test_coloring.py +91 -2
  147. warp/tests/test_array.py +254 -2
  148. warp/tests/test_array_reduce.py +2 -2
  149. warp/tests/test_assert.py +53 -0
  150. warp/tests/test_atomic_cas.py +312 -0
  151. warp/tests/test_codegen.py +142 -19
  152. warp/tests/test_conditional.py +47 -1
  153. warp/tests/test_ctypes.py +0 -20
  154. warp/tests/test_devices.py +8 -0
  155. warp/tests/test_fabricarray.py +4 -2
  156. warp/tests/test_fem.py +58 -25
  157. warp/tests/test_func.py +42 -1
  158. warp/tests/test_grad.py +1 -1
  159. warp/tests/test_lerp.py +1 -3
  160. warp/tests/test_map.py +481 -0
  161. warp/tests/test_mat.py +23 -24
  162. warp/tests/test_quat.py +28 -15
  163. warp/tests/test_rounding.py +10 -38
  164. warp/tests/test_runlength_encode.py +7 -7
  165. warp/tests/test_smoothstep.py +1 -1
  166. warp/tests/test_sparse.py +83 -2
  167. warp/tests/test_spatial.py +507 -1
  168. warp/tests/test_static.py +48 -0
  169. warp/tests/test_struct.py +2 -2
  170. warp/tests/test_tape.py +38 -0
  171. warp/tests/test_tuple.py +265 -0
  172. warp/tests/test_types.py +2 -2
  173. warp/tests/test_utils.py +24 -18
  174. warp/tests/test_vec.py +38 -408
  175. warp/tests/test_vec_constructors.py +325 -0
  176. warp/tests/tile/test_tile.py +438 -131
  177. warp/tests/tile/test_tile_mathdx.py +518 -14
  178. warp/tests/tile/test_tile_matmul.py +179 -0
  179. warp/tests/tile/test_tile_reduce.py +307 -5
  180. warp/tests/tile/test_tile_shared_memory.py +136 -7
  181. warp/tests/tile/test_tile_sort.py +121 -0
  182. warp/tests/unittest_suites.py +14 -6
  183. warp/types.py +462 -308
  184. warp/utils.py +647 -86
  185. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +189 -175
  187. warp/stubs.py +0 -3381
  188. warp/tests/sim/test_xpbd.py +0 -399
  189. warp/tests/test_mlp.py +0 -282
  190. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
@@ -13,9 +13,12 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from typing import ClassVar
17
+
16
18
  import warp as wp
17
19
  from warp.fem import cache
18
- from warp.fem.types import Coords, ElementIndex, Sample
20
+ from warp.fem.polynomial import Polynomial
21
+ from warp.fem.types import Coords, ElementIndex, Sample, make_free_sample
19
22
 
20
23
  from .geometry import Geometry
21
24
 
@@ -23,7 +26,31 @@ _mat32 = wp.mat(shape=(3, 2), dtype=float)
23
26
 
24
27
 
25
28
  class DeformedGeometry(Geometry):
26
- def __init__(self, field: "wp.fem.field.GeometryField", relative: bool = True):
29
+ _dynamic_attribute_constructors_phase_1: ClassVar = {
30
+ "CellArg": lambda obj: obj._make_cell_arg(),
31
+ "SideArg": lambda obj: obj._make_side_arg(),
32
+ "cell_position": lambda obj: obj._make_cell_position(),
33
+ "cell_deformation_gradient": lambda obj: obj._make_cell_deformation_gradient(),
34
+ "side_to_cell_arg": lambda obj: obj._make_side_to_cell_arg(),
35
+ "side_position": lambda obj: obj._make_side_position(),
36
+ "side_deformation_gradient": lambda obj: obj._make_side_deformation_gradient(),
37
+ "side_inner_cell_index": lambda obj: obj._make_side_inner_cell_index(),
38
+ "side_outer_cell_index": lambda obj: obj._make_side_outer_cell_index(),
39
+ "side_inner_cell_coords": lambda obj: obj._make_side_inner_cell_coords(),
40
+ "side_outer_cell_coords": lambda obj: obj._make_side_outer_cell_coords(),
41
+ "side_from_cell_coords": lambda obj: obj._make_side_from_cell_coords(),
42
+ "cell_bvh_id": lambda obj: obj._make_cell_bvh_id(),
43
+ "cell_bounds": lambda obj: obj._make_cell_bounds(),
44
+ }
45
+
46
+ _dynamic_attribute_constructors_phase_2: ClassVar = {
47
+ "cell_closest_point": lambda obj: obj._make_cell_closest_point(),
48
+ "side_closest_point": lambda obj: obj._make_side_closest_point(),
49
+ "cell_coordinates": lambda obj: obj._make_cell_coordinates(),
50
+ "side_coordinates": lambda obj: obj._make_side_coordinates(),
51
+ }
52
+
53
+ def __init__(self, field: "wp.fem.field.GeometryField", relative: bool = True, build_bvh: bool = False):
27
54
  """Constructs a Deformed Geometry from a displacement or absolute position field defined over a base geometry.
28
55
  The deformation field does not need to be isoparameteric.
29
56
 
@@ -33,10 +60,7 @@ class DeformedGeometry(Geometry):
33
60
  from warp.fem.field import DiscreteField, GeometryField
34
61
 
35
62
  if isinstance(field, DiscreteField):
36
- if (
37
- not wp.types.type_is_vector(field.dtype)
38
- or wp.types.type_length(field.dtype) != field.geometry.dimension
39
- ):
63
+ if not wp.types.type_is_vector(field.dtype) or wp.types.type_size(field.dtype) != field.geometry.dimension:
40
64
  raise ValueError(
41
65
  "Invalid value type for position field, must be vector-valued with same dimension as underlying geometry"
42
66
  )
@@ -46,14 +70,10 @@ class DeformedGeometry(Geometry):
46
70
  self._relative = relative
47
71
 
48
72
  self.field: GeometryField = field
73
+ self.field_trace = field.trace()
49
74
  self.dimension = self.base.dimension
50
75
 
51
- self.CellArg = self.field.ElementEvalArg
52
-
53
- self.field_trace = field.trace()
54
- self.SideArg = self._make_side_arg()
55
76
  self.SideIndexArg = self.base.SideIndexArg
56
-
57
77
  self.cell_count = self.base.cell_count
58
78
  self.vertex_count = self.base.vertex_count
59
79
  self.side_count = self.base.side_count
@@ -62,23 +82,18 @@ class DeformedGeometry(Geometry):
62
82
  self.reference_side = self.base.reference_side
63
83
 
64
84
  self.side_index_arg_value = self.base.side_index_arg_value
65
-
66
- self.cell_position = self._make_cell_position()
67
- self.cell_deformation_gradient = self._make_cell_deformation_gradient()
68
-
85
+ self.fill_side_index_arg = self.base.fill_side_index_arg
69
86
  self.boundary_side_index = self.base.boundary_side_index
70
87
 
71
- self.side_to_cell_arg = self._make_side_to_cell_arg()
72
- self.side_position = self._make_side_position()
73
- self.side_deformation_gradient = self._make_side_deformation_gradient()
74
- self.side_inner_cell_index = self._make_side_inner_cell_index()
75
- self.side_outer_cell_index = self._make_side_outer_cell_index()
76
- self.side_inner_cell_coords = self._make_side_inner_cell_coords()
77
- self.side_outer_cell_coords = self._make_side_outer_cell_coords()
78
- self.side_from_cell_coords = self._make_side_from_cell_coords()
88
+ cache.setup_dynamic_attributes(self, constructors=self._dynamic_attribute_constructors_phase_1)
79
89
 
80
90
  self._make_default_dependent_implementations()
81
91
 
92
+ cache.setup_dynamic_attributes(self, constructors=self._dynamic_attribute_constructors_phase_2)
93
+
94
+ if build_bvh:
95
+ self.build_bvh(self.field.dof_values.device)
96
+
82
97
  @property
83
98
  def name(self) -> str:
84
99
  return f"DefGeo_{self.field.name}_{'rel' if self._relative else 'abs'}"
@@ -89,35 +104,49 @@ class DeformedGeometry(Geometry):
89
104
 
90
105
  # Geometry device interface
91
106
 
92
- @cache.cached_arg_value
93
- def cell_arg_value(self, device) -> "DeformedGeometry.CellArg":
94
- args = self.CellArg()
107
+ def _make_cell_arg(self):
108
+ @cache.dynamic_struct(suffix=self.name)
109
+ class CellArg:
110
+ base_arg: self.base.CellArg
111
+ field_arg: self.field.EvalArg
112
+ cell_bvh: wp.uint64
95
113
 
96
- args.elt_arg = self.base.cell_arg_value(device)
97
- args.eval_arg = self.field.eval_arg_value(device)
114
+ return CellArg
98
115
 
116
+ def cell_arg_value(self, device) -> "DeformedGeometry.CellArg":
117
+ args = self.CellArg()
118
+ self.fill_cell_arg(args, device)
99
119
  return args
100
120
 
121
+ def fill_cell_arg(self, args: "DeformedGeometry.CellArg", device):
122
+ self.base.fill_cell_arg(args.base_arg, device)
123
+ self.field.fill_eval_arg(args.field_arg, device)
124
+ args.cell_bvh = self.bvh_id(device)
125
+
101
126
  def _make_cell_position(self):
102
127
  @cache.dynamic_func(suffix=self.name)
103
128
  def cell_position_absolute(cell_arg: self.CellArg, s: Sample):
104
- return self.field.eval_inner(cell_arg, s)
129
+ field_arg = self.field.ElementEvalArg(cell_arg.base_arg, cell_arg.field_arg)
130
+ return self.field.eval_inner(field_arg, s)
105
131
 
106
132
  @cache.dynamic_func(suffix=self.name)
107
133
  def cell_position(cell_arg: self.CellArg, s: Sample):
108
- return self.field.eval_inner(cell_arg, s) + self.base.cell_position(cell_arg.elt_arg, s)
134
+ field_arg = self.field.ElementEvalArg(cell_arg.base_arg, cell_arg.field_arg)
135
+ return self.field.eval_inner(field_arg, s) + self.base.cell_position(cell_arg.base_arg, s)
109
136
 
110
137
  return cell_position if self._relative else cell_position_absolute
111
138
 
112
139
  def _make_cell_deformation_gradient(self):
113
140
  @cache.dynamic_func(suffix=self.name)
114
141
  def cell_deformation_gradient_absolute(cell_arg: self.CellArg, s: Sample):
115
- return self.field.eval_reference_grad_inner(cell_arg, s)
142
+ field_arg = self.field.ElementEvalArg(cell_arg.base_arg, cell_arg.field_arg)
143
+ return self.field.eval_reference_grad_inner(field_arg, s)
116
144
 
117
145
  @cache.dynamic_func(suffix=self.name)
118
146
  def cell_deformation_gradient(cell_arg: self.CellArg, s: Sample):
119
- return self.field.eval_reference_grad_inner(cell_arg, s) + self.base.cell_deformation_gradient(
120
- cell_arg.elt_arg, s
147
+ field_arg = self.field.ElementEvalArg(cell_arg.base_arg, cell_arg.field_arg)
148
+ return self.field.eval_reference_grad_inner(field_arg, s) + self.base.cell_deformation_gradient(
149
+ cell_arg.base_arg, s
121
150
  )
122
151
 
123
152
  return cell_deformation_gradient if self._relative else cell_deformation_gradient_absolute
@@ -128,19 +157,21 @@ class DeformedGeometry(Geometry):
128
157
  base_arg: self.base.SideArg
129
158
  trace_arg: self.field_trace.EvalArg
130
159
  field_arg: self.field.EvalArg
160
+ cell_bvh: wp.uint64
131
161
 
132
162
  return SideArg
133
163
 
134
- @cache.cached_arg_value
135
164
  def side_arg_value(self, device) -> "DeformedGeometry.SideArg":
136
165
  args = self.SideArg()
137
-
138
- args.base_arg = self.base.side_arg_value(device)
139
- args.field_arg = self.field.eval_arg_value(device)
140
- args.trace_arg = self.field_trace.eval_arg_value(device)
141
-
166
+ self.fill_side_arg(args, device)
142
167
  return args
143
168
 
169
+ def fill_side_arg(self, args: "DeformedGeometry.SideArg", device):
170
+ self.base.fill_side_arg(args.base_arg, device)
171
+ self.field.fill_eval_arg(args.field_arg, device)
172
+ self.field_trace.fill_eval_arg(args.trace_arg, device)
173
+ args.cell_bvh = self.bvh_id(device)
174
+
144
175
  def _make_side_position(self):
145
176
  @cache.dynamic_func(suffix=self.name)
146
177
  def side_position_absolute(args: self.SideArg, s: Sample):
@@ -216,6 +247,37 @@ class DeformedGeometry(Geometry):
216
247
  def _make_side_to_cell_arg(self):
217
248
  @cache.dynamic_func(suffix=self.name)
218
249
  def side_to_cell_arg(side_arg: self.SideArg):
219
- return self.CellArg(self.base.side_to_cell_arg(side_arg.base_arg), side_arg.field_arg)
250
+ return self.CellArg(self.base.side_to_cell_arg(side_arg.base_arg), side_arg.field_arg, side_arg.cell_bvh)
220
251
 
221
252
  return side_to_cell_arg
253
+
254
+ def _make_cell_bvh_id(self):
255
+ @cache.dynamic_func(suffix=self.name)
256
+ def cell_bvh_id(cell_arg: self.CellArg):
257
+ return cell_arg.cell_bvh
258
+
259
+ return cell_bvh_id
260
+
261
+ def _make_cell_bounds(self):
262
+ points, _weights = self.reference_cell().instantiate_quadrature(
263
+ order=self.field.degree, family=Polynomial.LOBATTO_GAUSS_LEGENDRE
264
+ )
265
+
266
+ points = cache.cached_mat_type((len(points), 3), dtype=float)(points)
267
+ point_count = len(points)
268
+
269
+ @cache.dynamic_func(suffix=self.name)
270
+ def cell_bounds(cell_arg: self.CellArg, cell_index: ElementIndex):
271
+ lower = wp.vec3(1.0e8)
272
+ upper = wp.vec3(-1.0e8)
273
+ for k in range(point_count):
274
+ pos = self.cell_position(cell_arg, make_free_sample(cell_index, points[k]))
275
+ lower = wp.min(lower, pos)
276
+ upper = wp.max(upper, pos)
277
+
278
+ # pad the BBox to account for potential overflows
279
+ pad = 0.25 * (upper - lower)
280
+
281
+ return lower - pad, upper + pad
282
+
283
+ return cell_bounds
@@ -15,14 +15,18 @@
15
15
 
16
16
  from typing import List, Tuple
17
17
 
18
+ import warp as wp
18
19
  from warp.fem.polynomial import Polynomial, quadrature_1d
19
20
  from warp.fem.types import Coords
20
21
 
22
+ _vec1 = wp.types.vector(length=1, dtype=float)
23
+
21
24
 
22
25
  class Element:
23
26
  dimension = 0
24
27
  """Intrinsic dimension of the element"""
25
28
 
29
+ @staticmethod
26
30
  def measure() -> float:
27
31
  """Measure (area, volume, ...) of the reference element"""
28
32
  raise NotImplementedError
@@ -32,10 +36,32 @@ class Element:
32
36
  """Returns a quadrature of a given order for a prototypical element"""
33
37
  raise NotImplementedError
34
38
 
35
- def center(self) -> Tuple[float]:
36
- coords, _ = self.instantiate_quadrature(order=0, family=None)
39
+ @classmethod
40
+ def center(cls) -> Coords:
41
+ """Returns the coordinates for the center of the element"""
42
+ coords, _ = cls.instantiate_quadrature(order=0, family=None)
37
43
  return coords[0]
38
44
 
45
+ @wp.func
46
+ def project(v: Coords):
47
+ """project coordinates so that they belong to the element"""
48
+ return wp.min(wp.max(v, Coords(0.0)), Coords(1.0))
49
+
50
+ @wp.func
51
+ def coord_delta(ref_delta: wp.vec3):
52
+ """Transform a delta in reference space to element coords"""
53
+ return ref_delta
54
+
55
+ @wp.func
56
+ def coord_delta(ref_delta: wp.vec2):
57
+ """Transform a delta in reference space to element coords"""
58
+ return Coords(ref_delta[0], ref_delta[1], 0.0)
59
+
60
+ @wp.func
61
+ def coord_delta(ref_delta: _vec1):
62
+ """Transform a delta in reference space to element coords"""
63
+ return Coords(ref_delta[0], 0.0, 0.0)
64
+
39
65
 
40
66
  def _point_count_from_order(order: int, family: Polynomial):
41
67
  if family == Polynomial.GAUSS_LEGENDRE:
@@ -454,6 +480,17 @@ class Triangle(Element):
454
480
 
455
481
  return coords, weights
456
482
 
483
+ @wp.func
484
+ def project(v: Coords):
485
+ n = wp.max(0.0, (v[0] + v[1] - 1.0) * 0.5)
486
+ a = wp.clamp(v[0] - n, 0.0, 1.0)
487
+ b = wp.clamp(v[1] - n, 0.0, 1.0)
488
+ return Coords(a, b, 1.0 - (a + b))
489
+
490
+ @wp.func
491
+ def coord_delta(ref_delta: wp.vec2):
492
+ return Coords(-ref_delta[0] - ref_delta[1], ref_delta[0], ref_delta[1])
493
+
457
494
 
458
495
  class Tetrahedron(Element):
459
496
  dimension = 3
@@ -774,3 +811,20 @@ class Tetrahedron(Element):
774
811
  raise NotImplementedError
775
812
 
776
813
  return coords, weights
814
+
815
+ @wp.func
816
+ def project(v: Coords):
817
+ # project on 1-2-3 half-space
818
+ n = wp.max(0.0, (v[0] + v[1] + v[2] - 1.0) / 3.0)
819
+ c = v - Coords(n)
820
+
821
+ # project on 1-2, 2-3, 3-1 half-spaces
822
+ n = wp.max(0.0, (c[0] + c[1] - 1.0) * 0.5)
823
+ c = c - Coords(n, n, 0.0)
824
+ n = wp.max(0.0, (c[1] + c[2] - 1.0) * 0.5)
825
+ c = c - Coords(0.0, n, n)
826
+ n = wp.max(0.0, (c[2] + c[0] - 1.0) * 0.5)
827
+ c = c - Coords(n, 0.0, n)
828
+
829
+ # project on cube
830
+ return Element.project(c)