warp-lang 1.7.2rc1__py3-none-win_amd64.whl → 1.8.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (193) hide show
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +130 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +272 -104
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +770 -238
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_callable.py +34 -4
  36. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  37. warp/examples/interop/example_jax_kernel.py +27 -1
  38. warp/examples/optim/example_drone.py +1 -1
  39. warp/examples/sim/example_cloth.py +1 -1
  40. warp/examples/sim/example_cloth_self_contact.py +48 -54
  41. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  42. warp/examples/tile/example_tile_cholesky.py +2 -1
  43. warp/examples/tile/example_tile_convolution.py +1 -1
  44. warp/examples/tile/example_tile_filtering.py +1 -1
  45. warp/examples/tile/example_tile_matmul.py +1 -1
  46. warp/examples/tile/example_tile_mlp.py +2 -0
  47. warp/fabric.py +7 -7
  48. warp/fem/__init__.py +5 -0
  49. warp/fem/adaptivity.py +1 -1
  50. warp/fem/cache.py +152 -63
  51. warp/fem/dirichlet.py +2 -2
  52. warp/fem/domain.py +136 -6
  53. warp/fem/field/field.py +141 -99
  54. warp/fem/field/nodal_field.py +85 -39
  55. warp/fem/field/virtual.py +99 -52
  56. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  57. warp/fem/geometry/closest_point.py +13 -0
  58. warp/fem/geometry/deformed_geometry.py +102 -40
  59. warp/fem/geometry/element.py +56 -2
  60. warp/fem/geometry/geometry.py +323 -22
  61. warp/fem/geometry/grid_2d.py +157 -62
  62. warp/fem/geometry/grid_3d.py +116 -20
  63. warp/fem/geometry/hexmesh.py +86 -20
  64. warp/fem/geometry/nanogrid.py +166 -86
  65. warp/fem/geometry/partition.py +59 -25
  66. warp/fem/geometry/quadmesh.py +86 -135
  67. warp/fem/geometry/tetmesh.py +47 -119
  68. warp/fem/geometry/trimesh.py +77 -270
  69. warp/fem/integrate.py +181 -95
  70. warp/fem/linalg.py +25 -58
  71. warp/fem/operator.py +124 -27
  72. warp/fem/quadrature/pic_quadrature.py +36 -14
  73. warp/fem/quadrature/quadrature.py +40 -16
  74. warp/fem/space/__init__.py +1 -1
  75. warp/fem/space/basis_function_space.py +66 -46
  76. warp/fem/space/basis_space.py +17 -4
  77. warp/fem/space/dof_mapper.py +1 -1
  78. warp/fem/space/function_space.py +2 -2
  79. warp/fem/space/grid_2d_function_space.py +4 -1
  80. warp/fem/space/hexmesh_function_space.py +4 -2
  81. warp/fem/space/nanogrid_function_space.py +3 -1
  82. warp/fem/space/partition.py +11 -2
  83. warp/fem/space/quadmesh_function_space.py +4 -1
  84. warp/fem/space/restriction.py +5 -2
  85. warp/fem/space/shape/__init__.py +10 -8
  86. warp/fem/space/tetmesh_function_space.py +4 -1
  87. warp/fem/space/topology.py +52 -21
  88. warp/fem/space/trimesh_function_space.py +4 -1
  89. warp/fem/utils.py +53 -8
  90. warp/jax.py +1 -2
  91. warp/jax_experimental/ffi.py +210 -67
  92. warp/jax_experimental/xla_ffi.py +37 -24
  93. warp/math.py +171 -1
  94. warp/native/array.h +103 -4
  95. warp/native/builtin.h +182 -35
  96. warp/native/coloring.cpp +6 -2
  97. warp/native/cuda_util.cpp +1 -1
  98. warp/native/exports.h +118 -63
  99. warp/native/intersect.h +5 -5
  100. warp/native/mat.h +8 -13
  101. warp/native/mathdx.cpp +11 -5
  102. warp/native/matnn.h +1 -123
  103. warp/native/mesh.h +1 -1
  104. warp/native/quat.h +34 -6
  105. warp/native/rand.h +7 -7
  106. warp/native/sparse.cpp +121 -258
  107. warp/native/sparse.cu +181 -274
  108. warp/native/spatial.h +305 -17
  109. warp/native/svd.h +23 -8
  110. warp/native/tile.h +603 -73
  111. warp/native/tile_radix_sort.h +1112 -0
  112. warp/native/tile_reduce.h +239 -13
  113. warp/native/tile_scan.h +240 -0
  114. warp/native/tuple.h +189 -0
  115. warp/native/vec.h +10 -20
  116. warp/native/warp.cpp +36 -4
  117. warp/native/warp.cu +588 -52
  118. warp/native/warp.h +47 -74
  119. warp/optim/linear.py +5 -1
  120. warp/paddle.py +7 -8
  121. warp/py.typed +0 -0
  122. warp/render/render_opengl.py +110 -80
  123. warp/render/render_usd.py +124 -62
  124. warp/sim/__init__.py +9 -0
  125. warp/sim/collide.py +253 -80
  126. warp/sim/graph_coloring.py +8 -1
  127. warp/sim/import_mjcf.py +4 -3
  128. warp/sim/import_usd.py +11 -7
  129. warp/sim/integrator.py +5 -2
  130. warp/sim/integrator_euler.py +1 -1
  131. warp/sim/integrator_featherstone.py +1 -1
  132. warp/sim/integrator_vbd.py +761 -322
  133. warp/sim/integrator_xpbd.py +1 -1
  134. warp/sim/model.py +265 -260
  135. warp/sim/utils.py +10 -7
  136. warp/sparse.py +303 -166
  137. warp/tape.py +54 -51
  138. warp/tests/cuda/test_conditional_captures.py +1046 -0
  139. warp/tests/cuda/test_streams.py +1 -1
  140. warp/tests/geometry/test_volume.py +2 -2
  141. warp/tests/interop/test_dlpack.py +9 -9
  142. warp/tests/interop/test_jax.py +0 -1
  143. warp/tests/run_coverage_serial.py +1 -1
  144. warp/tests/sim/disabled_kinematics.py +2 -2
  145. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  146. warp/tests/sim/test_collision.py +159 -51
  147. warp/tests/sim/test_coloring.py +91 -2
  148. warp/tests/test_array.py +254 -2
  149. warp/tests/test_array_reduce.py +2 -2
  150. warp/tests/test_assert.py +53 -0
  151. warp/tests/test_atomic_cas.py +312 -0
  152. warp/tests/test_codegen.py +142 -19
  153. warp/tests/test_conditional.py +47 -1
  154. warp/tests/test_ctypes.py +0 -20
  155. warp/tests/test_devices.py +8 -0
  156. warp/tests/test_fabricarray.py +4 -2
  157. warp/tests/test_fem.py +58 -25
  158. warp/tests/test_func.py +42 -1
  159. warp/tests/test_grad.py +1 -1
  160. warp/tests/test_lerp.py +1 -3
  161. warp/tests/test_map.py +481 -0
  162. warp/tests/test_mat.py +23 -24
  163. warp/tests/test_quat.py +28 -15
  164. warp/tests/test_rounding.py +10 -38
  165. warp/tests/test_runlength_encode.py +7 -7
  166. warp/tests/test_smoothstep.py +1 -1
  167. warp/tests/test_sparse.py +83 -2
  168. warp/tests/test_spatial.py +507 -1
  169. warp/tests/test_static.py +48 -0
  170. warp/tests/test_struct.py +2 -2
  171. warp/tests/test_tape.py +38 -0
  172. warp/tests/test_tuple.py +265 -0
  173. warp/tests/test_types.py +2 -2
  174. warp/tests/test_utils.py +24 -18
  175. warp/tests/test_vec.py +38 -408
  176. warp/tests/test_vec_constructors.py +325 -0
  177. warp/tests/tile/test_tile.py +438 -131
  178. warp/tests/tile/test_tile_mathdx.py +518 -14
  179. warp/tests/tile/test_tile_matmul.py +179 -0
  180. warp/tests/tile/test_tile_reduce.py +307 -5
  181. warp/tests/tile/test_tile_shared_memory.py +136 -7
  182. warp/tests/tile/test_tile_sort.py +121 -0
  183. warp/tests/unittest_suites.py +14 -6
  184. warp/types.py +462 -308
  185. warp/utils.py +647 -86
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  187. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
  188. warp/stubs.py +0 -3381
  189. warp/tests/sim/test_xpbd.py +0 -399
  190. warp/tests/test_mlp.py +0 -282
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  193. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
@@ -13,9 +13,12 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from typing import ClassVar
17
+
16
18
  import warp as wp
17
19
  from warp.fem import cache
18
- from warp.fem.types import Coords, ElementIndex, Sample
20
+ from warp.fem.polynomial import Polynomial
21
+ from warp.fem.types import Coords, ElementIndex, Sample, make_free_sample
19
22
 
20
23
  from .geometry import Geometry
21
24
 
@@ -23,7 +26,31 @@ _mat32 = wp.mat(shape=(3, 2), dtype=float)
23
26
 
24
27
 
25
28
  class DeformedGeometry(Geometry):
26
- def __init__(self, field: "wp.fem.field.GeometryField", relative: bool = True):
29
+ _dynamic_attribute_constructors_phase_1: ClassVar = {
30
+ "CellArg": lambda obj: obj._make_cell_arg(),
31
+ "SideArg": lambda obj: obj._make_side_arg(),
32
+ "cell_position": lambda obj: obj._make_cell_position(),
33
+ "cell_deformation_gradient": lambda obj: obj._make_cell_deformation_gradient(),
34
+ "side_to_cell_arg": lambda obj: obj._make_side_to_cell_arg(),
35
+ "side_position": lambda obj: obj._make_side_position(),
36
+ "side_deformation_gradient": lambda obj: obj._make_side_deformation_gradient(),
37
+ "side_inner_cell_index": lambda obj: obj._make_side_inner_cell_index(),
38
+ "side_outer_cell_index": lambda obj: obj._make_side_outer_cell_index(),
39
+ "side_inner_cell_coords": lambda obj: obj._make_side_inner_cell_coords(),
40
+ "side_outer_cell_coords": lambda obj: obj._make_side_outer_cell_coords(),
41
+ "side_from_cell_coords": lambda obj: obj._make_side_from_cell_coords(),
42
+ "cell_bvh_id": lambda obj: obj._make_cell_bvh_id(),
43
+ "cell_bounds": lambda obj: obj._make_cell_bounds(),
44
+ }
45
+
46
+ _dynamic_attribute_constructors_phase_2: ClassVar = {
47
+ "cell_closest_point": lambda obj: obj._make_cell_closest_point(),
48
+ "side_closest_point": lambda obj: obj._make_side_closest_point(),
49
+ "cell_coordinates": lambda obj: obj._make_cell_coordinates(),
50
+ "side_coordinates": lambda obj: obj._make_side_coordinates(),
51
+ }
52
+
53
+ def __init__(self, field: "wp.fem.field.GeometryField", relative: bool = True, build_bvh: bool = False):
27
54
  """Constructs a Deformed Geometry from a displacement or absolute position field defined over a base geometry.
28
55
  The deformation field does not need to be isoparameteric.
29
56
 
@@ -33,10 +60,7 @@ class DeformedGeometry(Geometry):
33
60
  from warp.fem.field import DiscreteField, GeometryField
34
61
 
35
62
  if isinstance(field, DiscreteField):
36
- if (
37
- not wp.types.type_is_vector(field.dtype)
38
- or wp.types.type_length(field.dtype) != field.geometry.dimension
39
- ):
63
+ if not wp.types.type_is_vector(field.dtype) or wp.types.type_size(field.dtype) != field.geometry.dimension:
40
64
  raise ValueError(
41
65
  "Invalid value type for position field, must be vector-valued with same dimension as underlying geometry"
42
66
  )
@@ -46,14 +70,10 @@ class DeformedGeometry(Geometry):
46
70
  self._relative = relative
47
71
 
48
72
  self.field: GeometryField = field
73
+ self.field_trace = field.trace()
49
74
  self.dimension = self.base.dimension
50
75
 
51
- self.CellArg = self.field.ElementEvalArg
52
-
53
- self.field_trace = field.trace()
54
- self.SideArg = self._make_side_arg()
55
76
  self.SideIndexArg = self.base.SideIndexArg
56
-
57
77
  self.cell_count = self.base.cell_count
58
78
  self.vertex_count = self.base.vertex_count
59
79
  self.side_count = self.base.side_count
@@ -62,23 +82,18 @@ class DeformedGeometry(Geometry):
62
82
  self.reference_side = self.base.reference_side
63
83
 
64
84
  self.side_index_arg_value = self.base.side_index_arg_value
65
-
66
- self.cell_position = self._make_cell_position()
67
- self.cell_deformation_gradient = self._make_cell_deformation_gradient()
68
-
85
+ self.fill_side_index_arg = self.base.fill_side_index_arg
69
86
  self.boundary_side_index = self.base.boundary_side_index
70
87
 
71
- self.side_to_cell_arg = self._make_side_to_cell_arg()
72
- self.side_position = self._make_side_position()
73
- self.side_deformation_gradient = self._make_side_deformation_gradient()
74
- self.side_inner_cell_index = self._make_side_inner_cell_index()
75
- self.side_outer_cell_index = self._make_side_outer_cell_index()
76
- self.side_inner_cell_coords = self._make_side_inner_cell_coords()
77
- self.side_outer_cell_coords = self._make_side_outer_cell_coords()
78
- self.side_from_cell_coords = self._make_side_from_cell_coords()
88
+ cache.setup_dynamic_attributes(self, constructors=self._dynamic_attribute_constructors_phase_1)
79
89
 
80
90
  self._make_default_dependent_implementations()
81
91
 
92
+ cache.setup_dynamic_attributes(self, constructors=self._dynamic_attribute_constructors_phase_2)
93
+
94
+ if build_bvh:
95
+ self.build_bvh(self.field.dof_values.device)
96
+
82
97
  @property
83
98
  def name(self) -> str:
84
99
  return f"DefGeo_{self.field.name}_{'rel' if self._relative else 'abs'}"
@@ -89,35 +104,49 @@ class DeformedGeometry(Geometry):
89
104
 
90
105
  # Geometry device interface
91
106
 
92
- @cache.cached_arg_value
93
- def cell_arg_value(self, device) -> "DeformedGeometry.CellArg":
94
- args = self.CellArg()
107
+ def _make_cell_arg(self):
108
+ @cache.dynamic_struct(suffix=self.name)
109
+ class CellArg:
110
+ base_arg: self.base.CellArg
111
+ field_arg: self.field.EvalArg
112
+ cell_bvh: wp.uint64
95
113
 
96
- args.elt_arg = self.base.cell_arg_value(device)
97
- args.eval_arg = self.field.eval_arg_value(device)
114
+ return CellArg
98
115
 
116
+ def cell_arg_value(self, device) -> "DeformedGeometry.CellArg":
117
+ args = self.CellArg()
118
+ self.fill_cell_arg(args, device)
99
119
  return args
100
120
 
121
+ def fill_cell_arg(self, args: "DeformedGeometry.CellArg", device):
122
+ self.base.fill_cell_arg(args.base_arg, device)
123
+ self.field.fill_eval_arg(args.field_arg, device)
124
+ args.cell_bvh = self.bvh_id(device)
125
+
101
126
  def _make_cell_position(self):
102
127
  @cache.dynamic_func(suffix=self.name)
103
128
  def cell_position_absolute(cell_arg: self.CellArg, s: Sample):
104
- return self.field.eval_inner(cell_arg, s)
129
+ field_arg = self.field.ElementEvalArg(cell_arg.base_arg, cell_arg.field_arg)
130
+ return self.field.eval_inner(field_arg, s)
105
131
 
106
132
  @cache.dynamic_func(suffix=self.name)
107
133
  def cell_position(cell_arg: self.CellArg, s: Sample):
108
- return self.field.eval_inner(cell_arg, s) + self.base.cell_position(cell_arg.elt_arg, s)
134
+ field_arg = self.field.ElementEvalArg(cell_arg.base_arg, cell_arg.field_arg)
135
+ return self.field.eval_inner(field_arg, s) + self.base.cell_position(cell_arg.base_arg, s)
109
136
 
110
137
  return cell_position if self._relative else cell_position_absolute
111
138
 
112
139
  def _make_cell_deformation_gradient(self):
113
140
  @cache.dynamic_func(suffix=self.name)
114
141
  def cell_deformation_gradient_absolute(cell_arg: self.CellArg, s: Sample):
115
- return self.field.eval_reference_grad_inner(cell_arg, s)
142
+ field_arg = self.field.ElementEvalArg(cell_arg.base_arg, cell_arg.field_arg)
143
+ return self.field.eval_reference_grad_inner(field_arg, s)
116
144
 
117
145
  @cache.dynamic_func(suffix=self.name)
118
146
  def cell_deformation_gradient(cell_arg: self.CellArg, s: Sample):
119
- return self.field.eval_reference_grad_inner(cell_arg, s) + self.base.cell_deformation_gradient(
120
- cell_arg.elt_arg, s
147
+ field_arg = self.field.ElementEvalArg(cell_arg.base_arg, cell_arg.field_arg)
148
+ return self.field.eval_reference_grad_inner(field_arg, s) + self.base.cell_deformation_gradient(
149
+ cell_arg.base_arg, s
121
150
  )
122
151
 
123
152
  return cell_deformation_gradient if self._relative else cell_deformation_gradient_absolute
@@ -128,19 +157,21 @@ class DeformedGeometry(Geometry):
128
157
  base_arg: self.base.SideArg
129
158
  trace_arg: self.field_trace.EvalArg
130
159
  field_arg: self.field.EvalArg
160
+ cell_bvh: wp.uint64
131
161
 
132
162
  return SideArg
133
163
 
134
- @cache.cached_arg_value
135
164
  def side_arg_value(self, device) -> "DeformedGeometry.SideArg":
136
165
  args = self.SideArg()
137
-
138
- args.base_arg = self.base.side_arg_value(device)
139
- args.field_arg = self.field.eval_arg_value(device)
140
- args.trace_arg = self.field_trace.eval_arg_value(device)
141
-
166
+ self.fill_side_arg(args, device)
142
167
  return args
143
168
 
169
+ def fill_side_arg(self, args: "DeformedGeometry.SideArg", device):
170
+ self.base.fill_side_arg(args.base_arg, device)
171
+ self.field.fill_eval_arg(args.field_arg, device)
172
+ self.field_trace.fill_eval_arg(args.trace_arg, device)
173
+ args.cell_bvh = self.bvh_id(device)
174
+
144
175
  def _make_side_position(self):
145
176
  @cache.dynamic_func(suffix=self.name)
146
177
  def side_position_absolute(args: self.SideArg, s: Sample):
@@ -216,6 +247,37 @@ class DeformedGeometry(Geometry):
216
247
  def _make_side_to_cell_arg(self):
217
248
  @cache.dynamic_func(suffix=self.name)
218
249
  def side_to_cell_arg(side_arg: self.SideArg):
219
- return self.CellArg(self.base.side_to_cell_arg(side_arg.base_arg), side_arg.field_arg)
250
+ return self.CellArg(self.base.side_to_cell_arg(side_arg.base_arg), side_arg.field_arg, side_arg.cell_bvh)
220
251
 
221
252
  return side_to_cell_arg
253
+
254
+ def _make_cell_bvh_id(self):
255
+ @cache.dynamic_func(suffix=self.name)
256
+ def cell_bvh_id(cell_arg: self.CellArg):
257
+ return cell_arg.cell_bvh
258
+
259
+ return cell_bvh_id
260
+
261
+ def _make_cell_bounds(self):
262
+ points, _weights = self.reference_cell().instantiate_quadrature(
263
+ order=self.field.degree, family=Polynomial.LOBATTO_GAUSS_LEGENDRE
264
+ )
265
+
266
+ points = cache.cached_mat_type((len(points), 3), dtype=float)(points)
267
+ point_count = len(points)
268
+
269
+ @cache.dynamic_func(suffix=self.name)
270
+ def cell_bounds(cell_arg: self.CellArg, cell_index: ElementIndex):
271
+ lower = wp.vec3(1.0e8)
272
+ upper = wp.vec3(-1.0e8)
273
+ for k in range(point_count):
274
+ pos = self.cell_position(cell_arg, make_free_sample(cell_index, points[k]))
275
+ lower = wp.min(lower, pos)
276
+ upper = wp.max(upper, pos)
277
+
278
+ # pad the BBox to account for potential overflows
279
+ pad = 0.25 * (upper - lower)
280
+
281
+ return lower - pad, upper + pad
282
+
283
+ return cell_bounds
@@ -15,14 +15,18 @@
15
15
 
16
16
  from typing import List, Tuple
17
17
 
18
+ import warp as wp
18
19
  from warp.fem.polynomial import Polynomial, quadrature_1d
19
20
  from warp.fem.types import Coords
20
21
 
22
+ _vec1 = wp.types.vector(length=1, dtype=float)
23
+
21
24
 
22
25
  class Element:
23
26
  dimension = 0
24
27
  """Intrinsic dimension of the element"""
25
28
 
29
+ @staticmethod
26
30
  def measure() -> float:
27
31
  """Measure (area, volume, ...) of the reference element"""
28
32
  raise NotImplementedError
@@ -32,10 +36,32 @@ class Element:
32
36
  """Returns a quadrature of a given order for a prototypical element"""
33
37
  raise NotImplementedError
34
38
 
35
- def center(self) -> Tuple[float]:
36
- coords, _ = self.instantiate_quadrature(order=0, family=None)
39
+ @classmethod
40
+ def center(cls) -> Coords:
41
+ """Returns the coordinates for the center of the element"""
42
+ coords, _ = cls.instantiate_quadrature(order=0, family=None)
37
43
  return coords[0]
38
44
 
45
+ @wp.func
46
+ def project(v: Coords):
47
+ """project coordinates so that they belong to the element"""
48
+ return wp.min(wp.max(v, Coords(0.0)), Coords(1.0))
49
+
50
+ @wp.func
51
+ def coord_delta(ref_delta: wp.vec3):
52
+ """Transform a delta in reference space to element coords"""
53
+ return ref_delta
54
+
55
+ @wp.func
56
+ def coord_delta(ref_delta: wp.vec2):
57
+ """Transform a delta in reference space to element coords"""
58
+ return Coords(ref_delta[0], ref_delta[1], 0.0)
59
+
60
+ @wp.func
61
+ def coord_delta(ref_delta: _vec1):
62
+ """Transform a delta in reference space to element coords"""
63
+ return Coords(ref_delta[0], 0.0, 0.0)
64
+
39
65
 
40
66
  def _point_count_from_order(order: int, family: Polynomial):
41
67
  if family == Polynomial.GAUSS_LEGENDRE:
@@ -454,6 +480,17 @@ class Triangle(Element):
454
480
 
455
481
  return coords, weights
456
482
 
483
+ @wp.func
484
+ def project(v: Coords):
485
+ n = wp.max(0.0, (v[0] + v[1] - 1.0) * 0.5)
486
+ a = wp.clamp(v[0] - n, 0.0, 1.0)
487
+ b = wp.clamp(v[1] - n, 0.0, 1.0)
488
+ return Coords(a, b, 1.0 - (a + b))
489
+
490
+ @wp.func
491
+ def coord_delta(ref_delta: wp.vec2):
492
+ return Coords(-ref_delta[0] - ref_delta[1], ref_delta[0], ref_delta[1])
493
+
457
494
 
458
495
  class Tetrahedron(Element):
459
496
  dimension = 3
@@ -774,3 +811,20 @@ class Tetrahedron(Element):
774
811
  raise NotImplementedError
775
812
 
776
813
  return coords, weights
814
+
815
+ @wp.func
816
+ def project(v: Coords):
817
+ # project on 1-2-3 half-space
818
+ n = wp.max(0.0, (v[0] + v[1] + v[2] - 1.0) / 3.0)
819
+ c = v - Coords(n)
820
+
821
+ # project on 1-2, 2-3, 3-1 half-spaces
822
+ n = wp.max(0.0, (c[0] + c[1] - 1.0) * 0.5)
823
+ c = c - Coords(n, n, 0.0)
824
+ n = wp.max(0.0, (c[1] + c[2] - 1.0) * 0.5)
825
+ c = c - Coords(0.0, n, n)
826
+ n = wp.max(0.0, (c[2] + c[0] - 1.0) * 0.5)
827
+ c = c - Coords(n, 0.0, n)
828
+
829
+ # project on cube
830
+ return Element.project(c)