warp-lang 1.4.2__py3-none-macosx_10_13_universal2.whl → 1.5.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (165)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +21 -2
  5. warp/build_dll.py +23 -6
  6. warp/builtins.py +1819 -7
  7. warp/codegen.py +197 -61
  8. warp/config.py +2 -2
  9. warp/context.py +379 -107
  10. warp/examples/assets/pixel.jpg +0 -0
  11. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  14. warp/examples/benchmarks/benchmark_tile.py +179 -0
  15. warp/examples/fem/example_adaptive_grid.py +37 -10
  16. warp/examples/fem/example_apic_fluid.py +3 -2
  17. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  18. warp/examples/fem/example_deformed_geometry.py +1 -1
  19. warp/examples/fem/example_diffusion_3d.py +47 -4
  20. warp/examples/fem/example_distortion_energy.py +220 -0
  21. warp/examples/fem/example_magnetostatics.py +127 -85
  22. warp/examples/fem/example_nonconforming_contact.py +5 -5
  23. warp/examples/fem/example_stokes.py +3 -1
  24. warp/examples/fem/example_streamlines.py +12 -19
  25. warp/examples/fem/utils.py +38 -15
  26. warp/examples/sim/example_cloth.py +4 -25
  27. warp/examples/sim/example_quadruped.py +2 -1
  28. warp/examples/tile/example_tile_convolution.py +58 -0
  29. warp/examples/tile/example_tile_fft.py +47 -0
  30. warp/examples/tile/example_tile_filtering.py +105 -0
  31. warp/examples/tile/example_tile_matmul.py +79 -0
  32. warp/examples/tile/example_tile_mlp.py +375 -0
  33. warp/fem/__init__.py +8 -0
  34. warp/fem/cache.py +16 -12
  35. warp/fem/dirichlet.py +1 -1
  36. warp/fem/domain.py +44 -1
  37. warp/fem/field/__init__.py +1 -2
  38. warp/fem/field/field.py +31 -19
  39. warp/fem/field/nodal_field.py +101 -49
  40. warp/fem/field/virtual.py +794 -0
  41. warp/fem/geometry/__init__.py +2 -2
  42. warp/fem/geometry/deformed_geometry.py +3 -105
  43. warp/fem/geometry/element.py +13 -0
  44. warp/fem/geometry/geometry.py +165 -7
  45. warp/fem/geometry/grid_2d.py +3 -6
  46. warp/fem/geometry/grid_3d.py +31 -28
  47. warp/fem/geometry/hexmesh.py +3 -46
  48. warp/fem/geometry/nanogrid.py +3 -2
  49. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  50. warp/fem/geometry/tetmesh.py +2 -43
  51. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  52. warp/fem/integrate.py +683 -261
  53. warp/fem/linalg.py +404 -0
  54. warp/fem/operator.py +101 -18
  55. warp/fem/polynomial.py +5 -5
  56. warp/fem/quadrature/quadrature.py +45 -21
  57. warp/fem/space/__init__.py +45 -11
  58. warp/fem/space/basis_function_space.py +451 -0
  59. warp/fem/space/basis_space.py +58 -11
  60. warp/fem/space/function_space.py +146 -5
  61. warp/fem/space/grid_2d_function_space.py +80 -66
  62. warp/fem/space/grid_3d_function_space.py +113 -68
  63. warp/fem/space/hexmesh_function_space.py +96 -108
  64. warp/fem/space/nanogrid_function_space.py +62 -110
  65. warp/fem/space/quadmesh_function_space.py +208 -0
  66. warp/fem/space/shape/__init__.py +45 -7
  67. warp/fem/space/shape/cube_shape_function.py +328 -54
  68. warp/fem/space/shape/shape_function.py +10 -1
  69. warp/fem/space/shape/square_shape_function.py +328 -60
  70. warp/fem/space/shape/tet_shape_function.py +269 -19
  71. warp/fem/space/shape/triangle_shape_function.py +238 -19
  72. warp/fem/space/tetmesh_function_space.py +69 -37
  73. warp/fem/space/topology.py +38 -0
  74. warp/fem/space/trimesh_function_space.py +179 -0
  75. warp/fem/utils.py +6 -331
  76. warp/jax_experimental.py +3 -1
  77. warp/native/array.h +15 -0
  78. warp/native/builtin.h +66 -26
  79. warp/native/bvh.h +4 -0
  80. warp/native/coloring.cpp +604 -0
  81. warp/native/cuda_util.cpp +68 -51
  82. warp/native/cuda_util.h +2 -1
  83. warp/native/fabric.h +8 -0
  84. warp/native/hashgrid.h +4 -0
  85. warp/native/marching.cu +8 -0
  86. warp/native/mat.h +14 -3
  87. warp/native/mathdx.cpp +59 -0
  88. warp/native/mesh.h +4 -0
  89. warp/native/range.h +13 -1
  90. warp/native/reduce.cpp +9 -1
  91. warp/native/reduce.cu +7 -0
  92. warp/native/runlength_encode.cpp +9 -1
  93. warp/native/runlength_encode.cu +7 -1
  94. warp/native/scan.cpp +8 -0
  95. warp/native/scan.cu +8 -0
  96. warp/native/scan.h +8 -1
  97. warp/native/sparse.cpp +8 -0
  98. warp/native/sparse.cu +8 -0
  99. warp/native/temp_buffer.h +7 -0
  100. warp/native/tile.h +1854 -0
  101. warp/native/tile_gemm.h +341 -0
  102. warp/native/tile_reduce.h +210 -0
  103. warp/native/volume_builder.cu +8 -0
  104. warp/native/volume_builder.h +8 -0
  105. warp/native/warp.cpp +10 -2
  106. warp/native/warp.cu +369 -15
  107. warp/native/warp.h +12 -2
  108. warp/optim/adam.py +39 -4
  109. warp/paddle.py +29 -12
  110. warp/render/render_opengl.py +140 -67
  111. warp/sim/graph_coloring.py +292 -0
  112. warp/sim/import_urdf.py +8 -8
  113. warp/sim/integrator_euler.py +4 -2
  114. warp/sim/integrator_featherstone.py +115 -44
  115. warp/sim/integrator_vbd.py +6 -0
  116. warp/sim/model.py +109 -32
  117. warp/sparse.py +1 -1
  118. warp/stubs.py +569 -4
  119. warp/tape.py +12 -7
  120. warp/tests/assets/pixel.npy +0 -0
  121. warp/tests/aux_test_instancing_gc.py +18 -0
  122. warp/tests/test_array.py +39 -0
  123. warp/tests/test_codegen.py +81 -1
  124. warp/tests/test_codegen_instancing.py +30 -0
  125. warp/tests/test_collision.py +110 -0
  126. warp/tests/test_coloring.py +251 -0
  127. warp/tests/test_context.py +34 -0
  128. warp/tests/test_examples.py +21 -5
  129. warp/tests/test_fem.py +453 -113
  130. warp/tests/test_func.py +34 -4
  131. warp/tests/test_generics.py +52 -0
  132. warp/tests/test_iter.py +68 -0
  133. warp/tests/test_lerp.py +13 -87
  134. warp/tests/test_mat_scalar_ops.py +1 -1
  135. warp/tests/test_matmul.py +6 -9
  136. warp/tests/test_matmul_lite.py +6 -11
  137. warp/tests/test_mesh_query_point.py +1 -1
  138. warp/tests/test_module_hashing.py +23 -0
  139. warp/tests/test_overwrite.py +45 -0
  140. warp/tests/test_paddle.py +27 -87
  141. warp/tests/test_print.py +56 -1
  142. warp/tests/test_smoothstep.py +17 -83
  143. warp/tests/test_spatial.py +1 -1
  144. warp/tests/test_static.py +3 -3
  145. warp/tests/test_tile.py +744 -0
  146. warp/tests/test_tile_mathdx.py +144 -0
  147. warp/tests/test_tile_mlp.py +383 -0
  148. warp/tests/test_tile_reduce.py +374 -0
  149. warp/tests/test_tile_shared_memory.py +190 -0
  150. warp/tests/test_vbd.py +12 -20
  151. warp/tests/test_volume.py +43 -0
  152. warp/tests/unittest_suites.py +19 -2
  153. warp/tests/unittest_utils.py +4 -2
  154. warp/types.py +340 -74
  155. warp/utils.py +23 -3
  156. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +160 -133
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  159. warp/fem/field/test.py +0 -180
  160. warp/fem/field/trial.py +0 -183
  161. warp/fem/space/collocated_function_space.py +0 -102
  162. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  163. warp/fem/space/trimesh_2d_function_space.py +0 -153
  164. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,6 @@ from .partition import (
12
12
  LinearGeometryPartition,
13
13
  WholeGeometryPartition,
14
14
  )
15
- from .quadmesh_2d import Quadmesh2D
15
+ from .quadmesh import Quadmesh, Quadmesh2D, Quadmesh3D
16
16
  from .tetmesh import Tetmesh
17
- from .trimesh_2d import Trimesh2D
17
+ from .trimesh import Trimesh, Trimesh2D, Trimesh3D
@@ -1,8 +1,6 @@
1
- from typing import Any
2
-
3
1
  import warp as wp
4
2
  from warp.fem import cache
5
- from warp.fem.types import Coords, ElementIndex, Sample, make_free_sample
3
+ from warp.fem.types import Coords, ElementIndex, Sample
6
4
 
7
5
  from .geometry import Geometry
8
6
 
@@ -53,8 +51,6 @@ class DeformedGeometry(Geometry):
53
51
 
54
52
  self.cell_position = self._make_cell_position()
55
53
  self.cell_deformation_gradient = self._make_cell_deformation_gradient()
56
- self.cell_inverse_deformation_gradient = self._make_cell_inverse_deformation_gradient()
57
- self.cell_measure = self._make_cell_measure()
58
54
 
59
55
  self.boundary_side_index = self.base.boundary_side_index
60
56
 
@@ -66,11 +62,8 @@ class DeformedGeometry(Geometry):
66
62
  self.side_inner_cell_coords = self._make_side_inner_cell_coords()
67
63
  self.side_outer_cell_coords = self._make_side_outer_cell_coords()
68
64
  self.side_from_cell_coords = self._make_side_from_cell_coords()
69
- self.side_inner_inverse_deformation_gradient = self._make_side_inner_inverse_deformation_gradient()
70
- self.side_outer_inverse_deformation_gradient = self._make_side_outer_inverse_deformation_gradient()
71
- self.side_measure = self._make_side_measure()
72
- self.side_measure_ratio = self._make_side_measure_ratio()
73
- self.side_normal = self._make_side_normal()
65
+
66
+ self._make_default_dependent_implementations()
74
67
 
75
68
  @property
76
69
  def name(self):
@@ -111,26 +104,6 @@ class DeformedGeometry(Geometry):
111
104
 
112
105
  return cell_deformation_gradient if self._relative else cell_deformation_gradient_absolute
113
106
 
114
- def _make_cell_inverse_deformation_gradient(self):
115
- @cache.dynamic_func(suffix=self.name)
116
- def cell_inverse_deformation_gradient(cell_arg: self.CellArg, s: Sample):
117
- return wp.inverse(self.cell_deformation_gradient(cell_arg, s))
118
-
119
- return cell_inverse_deformation_gradient
120
-
121
- def _make_cell_measure(self):
122
- REF_MEASURE = wp.constant(self.reference_cell().measure())
123
-
124
- @cache.dynamic_func(suffix=self.name)
125
- def cell_measure(args: self.CellArg, s: Sample):
126
- return wp.abs(wp.determinant(self.cell_deformation_gradient(args, s))) * REF_MEASURE
127
-
128
- return cell_measure
129
-
130
- @wp.func
131
- def cell_normal(args: Any, s: Sample):
132
- return wp.vec2(0.0)
133
-
134
107
  def _make_side_arg(self):
135
108
  @cache.dynamic_struct(suffix=self.name)
136
109
  class SideArg:
@@ -182,81 +155,6 @@ class DeformedGeometry(Geometry):
182
155
 
183
156
  return side_deformation_gradient if self._relative else side_deformation_gradient_absolute
184
157
 
185
- def _make_side_inner_inverse_deformation_gradient(self):
186
- @cache.dynamic_func(suffix=self.name)
187
- def side_inner_inverse_deformation_gradient(args: self.SideArg, s: Sample):
188
- cell_index = self.side_inner_cell_index(args, s.element_index)
189
- cell_coords = self.side_inner_cell_coords(args, s.element_index, s.element_coords)
190
- cell_arg = self.side_to_cell_arg(args)
191
- return self.cell_inverse_deformation_gradient(cell_arg, make_free_sample(cell_index, cell_coords))
192
-
193
- def _make_side_outer_inverse_deformation_gradient(self):
194
- @cache.dynamic_func(suffix=self.name)
195
- def side_outer_inverse_deformation_gradient(args: self.SideArg, s: Sample):
196
- cell_index = self.side_outer_cell_index(args, s.element_index)
197
- cell_coords = self.side_outer_cell_coords(args, s.element_index, s.element_coords)
198
- cell_arg = self.side_to_cell_arg(args)
199
- return self.cell_inverse_deformation_gradient(cell_arg, make_free_sample(cell_index, cell_coords))
200
-
201
- @wp.func
202
- def _side_measure(F: wp.vec2):
203
- return wp.length(F)
204
-
205
- @wp.func
206
- def _side_measure(F: _mat32):
207
- Fcross = wp.vec3(
208
- F[1, 0] * F[2, 1] - F[2, 0] * F[1, 1],
209
- F[2, 0] * F[0, 1] - F[0, 0] * F[2, 1],
210
- F[0, 0] * F[1, 1] - F[1, 0] * F[0, 1],
211
- )
212
- return wp.length(Fcross)
213
-
214
- @wp.func
215
- def _side_normal(F: wp.vec2):
216
- return wp.normalize(wp.vec2(-F[1], F[0]))
217
-
218
- @wp.func
219
- def _side_normal(F: _mat32):
220
- Fcross = wp.vec3(
221
- F[1, 0] * F[2, 1] - F[2, 0] * F[1, 1],
222
- F[2, 0] * F[0, 1] - F[0, 0] * F[2, 1],
223
- F[0, 0] * F[1, 1] - F[1, 0] * F[0, 1],
224
- )
225
- return wp.normalize(Fcross)
226
-
227
- def _make_side_measure(self):
228
- REF_MEASURE = wp.constant(self.reference_side().measure())
229
-
230
- @cache.dynamic_func(suffix=self.name)
231
- def side_measure(args: self.SideArg, s: Sample):
232
- F = self.side_deformation_gradient(args, s)
233
- return DeformedGeometry._side_measure(F) * REF_MEASURE
234
-
235
- return side_measure
236
-
237
- def _make_side_measure_ratio(self):
238
- @cache.dynamic_func(suffix=self.name)
239
- def side_measure_ratio(args: self.SideArg, s: Sample):
240
- inner = self.side_inner_cell_index(args, s.element_index)
241
- outer = self.side_outer_cell_index(args, s.element_index)
242
- inner_coords = self.side_inner_cell_coords(args, s.element_index, s.element_coords)
243
- outer_coords = self.side_outer_cell_coords(args, s.element_index, s.element_coords)
244
- cell_arg = self.side_to_cell_arg(args)
245
- return self.side_measure(args, s) / wp.min(
246
- self.cell_measure(cell_arg, make_free_sample(inner, inner_coords)),
247
- self.cell_measure(cell_arg, make_free_sample(outer, outer_coords)),
248
- )
249
-
250
- return side_measure_ratio
251
-
252
- def _make_side_normal(self):
253
- @cache.dynamic_func(suffix=self.name)
254
- def side_normal(args: self.SideArg, s: Sample):
255
- F = self.side_deformation_gradient(args, s)
256
- return DeformedGeometry._side_normal(F)
257
-
258
- return side_normal
259
-
260
158
  def _make_side_inner_cell_index(self):
261
159
  @cache.dynamic_func(suffix=self.name)
262
160
  def side_inner_cell_index(args: self.SideArg, side_index: ElementIndex):
@@ -5,6 +5,9 @@ from warp.fem.types import Coords
5
5
 
6
6
 
7
7
  class Element:
8
+ dimension = 0
9
+ """Intrinsic dimension of the element"""
10
+
8
11
  def measure() -> float:
9
12
  """Measure (area, volume, ...) of the reference element"""
10
13
  raise NotImplementedError
@@ -33,6 +36,8 @@ def _point_count_from_order(order: int, family: Polynomial):
33
36
 
34
37
 
35
38
  class Cube(Element):
39
+ dimension = 3
40
+
36
41
  @staticmethod
37
42
  def measure() -> float:
38
43
  return 1.0
@@ -52,6 +57,8 @@ class Cube(Element):
52
57
 
53
58
 
54
59
  class Square(Element):
60
+ dimension = 2
61
+
55
62
  @staticmethod
56
63
  def measure() -> float:
57
64
  return 1.0
@@ -71,6 +78,8 @@ class Square(Element):
71
78
 
72
79
 
73
80
  class LinearEdge(Element):
81
+ dimension = 1
82
+
74
83
  @staticmethod
75
84
  def measure() -> float:
76
85
  return 1.0
@@ -88,6 +97,8 @@ class LinearEdge(Element):
88
97
 
89
98
 
90
99
  class Triangle(Element):
100
+ dimension = 2
101
+
91
102
  @staticmethod
92
103
  def measure() -> float:
93
104
  return 0.5
@@ -430,6 +441,8 @@ class Triangle(Element):
430
441
 
431
442
 
432
443
  class Tetrahedron(Element):
444
+ dimension = 3
445
+
433
446
  @staticmethod
434
447
  def measure() -> float:
435
448
  return 1.0 / 6.0
@@ -1,10 +1,13 @@
1
1
  from typing import Any
2
2
 
3
3
  import warp as wp
4
- from warp.fem.types import Coords, ElementIndex, Sample
4
+ from warp.fem import cache
5
+ from warp.fem.types import Coords, ElementIndex, Sample, make_free_sample
5
6
 
6
7
  from .element import Element
7
8
 
9
+ _mat32 = wp.mat(shape=(3, 2), dtype=float)
10
+
8
11
 
9
12
  class Geometry:
10
13
  """
@@ -35,6 +38,11 @@ class Geometry:
35
38
  """Prototypical element for a side"""
36
39
  raise NotImplementedError
37
40
 
41
+ @property
42
+ def cell_dimension(self) -> int:
43
+ """Manifold dimension of the geometry cells"""
44
+ return self.reference_cell().dimension
45
+
38
46
  @property
39
47
  def name(self) -> str:
40
48
  return self.__class__.__name__
@@ -51,7 +59,6 @@ class Geometry:
51
59
  SideIndexArg: wp.codegen.Struct
52
60
  """Structure containing arguments to be passed to device functions for indexing sides"""
53
61
 
54
- @staticmethod
55
62
  def cell_arg_value(self, device) -> "Geometry.CellArg":
56
63
  """Value of the arguments to be passed to cell-related device functions"""
57
64
  raise NotImplementedError
@@ -67,7 +74,7 @@ class Geometry:
67
74
  raise NotImplementedError
68
75
 
69
76
  @staticmethod
70
- def cell_inverse_deformation_gradient(args: "Geometry.CellArg", cell_index: ElementIndex, coords: Coords):
77
+ def cell_inverse_deformation_gradient(args: "Geometry.CellArg", s: "Sample"):
71
78
  """Device function returning the matrix right-transforming a gradient w.r.t. cell space to a gradient w.r.t. world space
72
79
  (i.e. the inverse deformation gradient)
73
80
  """
@@ -99,7 +106,6 @@ class Geometry:
99
106
  For elements with the same dimension as the embedding space, this will be zero."""
100
107
  raise NotImplementedError
101
108
 
102
- @staticmethod
103
109
  def side_arg_value(self, device) -> "Geometry.SideArg":
104
110
  """Value of the arguments to be passed to side-related device functions"""
105
111
  raise NotImplementedError
@@ -115,12 +121,12 @@ class Geometry:
115
121
  raise NotImplementedError
116
122
 
117
123
  @staticmethod
118
- def side_deformation_gradient(args: "Geometry.CellArg", s: "Sample"):
119
- """Device function returning the gradient of world position with respect to reference cell"""
124
+ def side_deformation_gradient(args: "Geometry.SideArg", s: "Sample"):
125
+ """Device function returning the gradient of world position with respect to reference side"""
120
126
  raise NotImplementedError
121
127
 
122
128
  @staticmethod
123
- def side_inner_inverse_deformation_gradient(args: "Geometry.CellArg", side_index: ElementIndex, coords: Coords):
129
+ def side_inner_inverse_deformation_gradient(args: "Geometry.Siderg", side_index: ElementIndex, coords: Coords):
124
130
  """Device function returning the matrix right-transforming a gradient w.r.t. inner cell space to a gradient w.r.t. world space
125
131
  (i.e. the inverse deformation gradient)
126
132
  """
@@ -182,3 +188,155 @@ class Geometry:
182
188
  def side_to_cell_arg(side_arg: "Geometry.SideArg"):
183
189
  """Device function converting a side-related argument value to a cell-related argument value, for promoting trace samples to the full space"""
184
190
  raise NotImplementedError
191
+
192
+ # Default implementations for dependent quantities
193
+ # Can be overridden in derived classes if more efficient implementations exist
194
+
195
+ def _make_default_dependent_implementations(self):
196
+ self.cell_inverse_deformation_gradient = self._make_cell_inverse_deformation_gradient()
197
+ self.cell_measure = self._make_cell_measure()
198
+ self.cell_normal = self._make_cell_normal()
199
+
200
+ self.side_inner_inverse_deformation_gradient = self._make_side_inner_inverse_deformation_gradient()
201
+ self.side_outer_inverse_deformation_gradient = self._make_side_outer_inverse_deformation_gradient()
202
+ self.side_measure = self._make_side_measure()
203
+ self.side_measure_ratio = self._make_side_measure_ratio()
204
+ self.side_normal = self._make_side_normal()
205
+
206
+ @wp.func
207
+ def _element_measure(F: wp.vec2):
208
+ return wp.length(F)
209
+
210
+ @wp.func
211
+ def _element_measure(F: wp.vec3):
212
+ return wp.length(F)
213
+
214
+ @wp.func
215
+ def _element_measure(F: _mat32):
216
+ Ft = wp.transpose(F)
217
+ Fcross = wp.cross(Ft[0], Ft[1])
218
+ return wp.length(Fcross)
219
+
220
+ @wp.func
221
+ def _element_measure(F: wp.mat33):
222
+ return wp.abs(wp.determinant(F))
223
+
224
+ @wp.func
225
+ def _element_measure(F: wp.mat22):
226
+ return wp.abs(wp.determinant(F))
227
+
228
+ @wp.func
229
+ def _element_normal(F: wp.vec2):
230
+ return wp.normalize(wp.vec2(-F[1], F[0]))
231
+
232
+ @wp.func
233
+ def _element_normal(F: _mat32):
234
+ Ft = wp.transpose(F)
235
+ Fcross = wp.cross(Ft[0], Ft[1])
236
+ return wp.normalize(Fcross)
237
+
238
+ def _make_cell_measure(self):
239
+ REF_MEASURE = wp.constant(self.reference_cell().measure())
240
+
241
+ @cache.dynamic_func(suffix=self.name)
242
+ def cell_measure(args: self.CellArg, s: Sample):
243
+ F = self.cell_deformation_gradient(args, s)
244
+ return Geometry._element_measure(F) * REF_MEASURE
245
+
246
+ return cell_measure
247
+
248
+ def _make_cell_normal(self):
249
+ cell_dim = self.reference_cell().dimension
250
+ geo_dim = self.dimension
251
+ normal_vec = wp.vec(length=geo_dim, dtype=float)
252
+
253
+ @cache.dynamic_func(suffix=self.name)
254
+ def zero_normal(args: self.CellArg, s: Sample):
255
+ return normal_vec(0.0)
256
+
257
+ @cache.dynamic_func(suffix=self.name)
258
+ def cell_hyperplane_normal(args: self.CellArg, s: Sample):
259
+ F = self.cell_deformation_gradient(args, s)
260
+ return Geometry._element_normal(F)
261
+
262
+ if cell_dim == geo_dim:
263
+ return zero_normal
264
+ if cell_dim == geo_dim - 1:
265
+ return cell_hyperplane_normal
266
+
267
+ return None
268
+
269
+ def _make_cell_inverse_deformation_gradient(self):
270
+ cell_dim = self.reference_cell().dimension
271
+ geo_dim = self.dimension
272
+
273
+ @cache.dynamic_func(suffix=self.name)
274
+ def cell_inverse_deformation_gradient(cell_arg: self.CellArg, s: Sample):
275
+ return wp.inverse(self.cell_deformation_gradient(cell_arg, s))
276
+
277
+ @cache.dynamic_func(suffix=self.name)
278
+ def cell_pseudoinverse_deformation_gradient(cell_arg: self.CellArg, s: Sample):
279
+ F = self.cell_deformation_gradient(cell_arg, s)
280
+ Ft = wp.transpose(F)
281
+ return wp.inverse(Ft * F) * Ft
282
+
283
+ return cell_inverse_deformation_gradient if cell_dim == geo_dim else cell_pseudoinverse_deformation_gradient
284
+
285
+ def _make_side_measure(self):
286
+ REF_MEASURE = wp.constant(self.reference_side().measure())
287
+
288
+ @cache.dynamic_func(suffix=self.name)
289
+ def side_measure(args: self.SideArg, s: Sample):
290
+ F = self.side_deformation_gradient(args, s)
291
+ return Geometry._element_measure(F) * REF_MEASURE
292
+
293
+ return side_measure
294
+
295
+ def _make_side_measure_ratio(self):
296
+ @cache.dynamic_func(suffix=self.name)
297
+ def side_measure_ratio(args: self.SideArg, s: Sample):
298
+ inner = self.side_inner_cell_index(args, s.element_index)
299
+ outer = self.side_outer_cell_index(args, s.element_index)
300
+ inner_coords = self.side_inner_cell_coords(args, s.element_index, s.element_coords)
301
+ outer_coords = self.side_outer_cell_coords(args, s.element_index, s.element_coords)
302
+ cell_arg = self.side_to_cell_arg(args)
303
+ return self.side_measure(args, s) / wp.min(
304
+ self.cell_measure(cell_arg, make_free_sample(inner, inner_coords)),
305
+ self.cell_measure(cell_arg, make_free_sample(outer, outer_coords)),
306
+ )
307
+
308
+ return side_measure_ratio
309
+
310
+ def _make_side_normal(self):
311
+ side_dim = self.reference_side().dimension
312
+ geo_dim = self.dimension
313
+
314
+ @cache.dynamic_func(suffix=self.name)
315
+ def hyperplane_normal(args: self.SideArg, s: Sample):
316
+ F = self.side_deformation_gradient(args, s)
317
+ return Geometry._element_normal(F)
318
+
319
+ if side_dim == geo_dim - 1:
320
+ return hyperplane_normal
321
+
322
+ return None
323
+
324
+ def _make_side_inner_inverse_deformation_gradient(self):
325
+ @cache.dynamic_func(suffix=self.name)
326
+ def side_inner_inverse_deformation_gradient(args: self.SideArg, s: Sample):
327
+ cell_index = self.side_inner_cell_index(args, s.element_index)
328
+ cell_coords = self.side_inner_cell_coords(args, s.element_index, s.element_coords)
329
+ cell_arg = self.side_to_cell_arg(args)
330
+ return self.cell_inverse_deformation_gradient(cell_arg, make_free_sample(cell_index, cell_coords))
331
+
332
+ return side_inner_inverse_deformation_gradient
333
+
334
+ def _make_side_outer_inverse_deformation_gradient(self):
335
+ @cache.dynamic_func(suffix=self.name)
336
+ def side_outer_inverse_deformation_gradient(args: self.SideArg, s: Sample):
337
+ cell_index = self.side_outer_cell_index(args, s.element_index)
338
+ cell_coords = self.side_outer_cell_coords(args, s.element_index, s.element_coords)
339
+ cell_arg = self.side_to_cell_arg(args)
340
+ return self.cell_inverse_deformation_gradient(cell_arg, make_free_sample(cell_index, cell_coords))
341
+
342
+ return side_outer_inverse_deformation_gradient
@@ -29,7 +29,7 @@ class Grid2D(Geometry):
29
29
  Args:
30
30
  res: Resolution of the grid along each dimension
31
31
  bounds_lo: Position of the lower bound of the axis-aligned grid
32
- bounds_up: Position of the upper bound of the axis-aligned grid
32
+ bounds_hi: Position of the upper bound of the axis-aligned grid
33
33
  """
34
34
 
35
35
  if bounds_lo is None:
@@ -158,10 +158,7 @@ class Grid2D(Geometry):
158
158
  return Grid2D.Side(axis, origin)
159
159
 
160
160
  axis_side_index = side_index - 2 * arg.cell_count
161
- if axis_side_index < arg.axis_offsets[1]:
162
- axis = 0
163
- else:
164
- axis = 1
161
+ axis = wp.select(axis_side_index < arg.axis_offsets[1], 1, 0)
165
162
 
166
163
  altitude = arg.cell_arg.res[Grid2D.ROTATION[axis, 0]]
167
164
  longitude = axis_side_index - arg.axis_offsets[axis]
@@ -230,7 +227,7 @@ class Grid2D(Geometry):
230
227
 
231
228
  args.axis_offsets = wp.vec2i(
232
229
  0,
233
- self.res[0],
230
+ self.res[1],
234
231
  )
235
232
  args.cell_count = self.cell_count()
236
233
  args.cell_arg = self.cell_arg_value(device)
@@ -23,17 +23,13 @@ class Grid3D(Geometry):
23
23
 
24
24
  dimension = 3
25
25
 
26
- Permutation = wp.types.matrix(shape=(3, 3), dtype=int)
27
- LOC_TO_WORLD = wp.constant(Permutation(0, 1, 2, 1, 2, 0, 2, 0, 1))
28
- WORLD_TO_LOC = wp.constant(Permutation(0, 1, 2, 2, 0, 1, 1, 2, 0))
29
-
30
26
  def __init__(self, res: wp.vec3i, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None):
31
27
  """Constructs a dense 3D grid
32
28
 
33
29
  Args:
34
30
  res: Resolution of the grid along each dimension
35
31
  bounds_lo: Position of the lower bound of the axis-aligned grid
36
- bounds_up: Position of the upper bound of the axis-aligned grid
32
+ bounds_hi: Position of the upper bound of the axis-aligned grid
37
33
  """
38
34
 
39
35
  if bounds_lo is None:
@@ -148,28 +144,32 @@ class Grid3D(Geometry):
148
144
  @wp.func
149
145
  def _world_to_local(axis: int, vec: Any):
150
146
  return type(vec)(
151
- vec[Grid3D.LOC_TO_WORLD[axis, 0]],
152
- vec[Grid3D.LOC_TO_WORLD[axis, 1]],
153
- vec[Grid3D.LOC_TO_WORLD[axis, 2]],
147
+ vec[axis],
148
+ vec[(axis + 1) % 3],
149
+ vec[(axis + 2) % 3],
154
150
  )
155
151
 
156
152
  @wp.func
157
153
  def _local_to_world(axis: int, vec: Any):
158
154
  return type(vec)(
159
- vec[Grid3D.WORLD_TO_LOC[axis, 0]],
160
- vec[Grid3D.WORLD_TO_LOC[axis, 1]],
161
- vec[Grid3D.WORLD_TO_LOC[axis, 2]],
155
+ vec[(2 * axis) % 3],
156
+ vec[(2 * axis + 1) % 3],
157
+ vec[(2 * axis + 2) % 3],
162
158
  )
163
159
 
160
+ @wp.func
161
+ def _local_to_world_axis(axis: int, loc_index: Any):
162
+ return (axis + loc_index) % 3
163
+
164
164
  @wp.func
165
165
  def side_index(arg: SideArg, side: Side):
166
- alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
166
+ alt_axis = Grid3D._local_to_world_axis(side.axis, 0)
167
167
  if side.origin[0] == arg.cell_arg.res[alt_axis]:
168
168
  # Upper-boundary side
169
169
  longitude = side.origin[1]
170
170
  latitude = side.origin[2]
171
171
 
172
- latitude_res = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[side.axis, 2]]
172
+ latitude_res = arg.cell_arg.res[Grid3D._local_to_world_axis(side.axis, 2)]
173
173
  lat_long = latitude_res * longitude + latitude
174
174
 
175
175
  return 3 * arg.cell_count + arg.axis_offsets[side.axis] + lat_long
@@ -179,24 +179,27 @@ class Grid3D(Geometry):
179
179
 
180
180
  @wp.func
181
181
  def get_side(arg: SideArg, side_index: ElementIndex):
182
+ res = arg.cell_arg.res
183
+
182
184
  if side_index < 3 * arg.cell_count:
183
185
  axis = side_index // arg.cell_count
184
186
  cell_index = side_index - axis * arg.cell_count
185
- origin = Grid3D._world_to_local(axis, Grid3D.get_cell(arg.cell_arg.res, cell_index))
186
- return Grid3D.Side(axis, origin)
187
+ origin_loc = Grid3D._world_to_local(axis, Grid3D.get_cell(res, cell_index))
188
+ return Grid3D.Side(axis, origin_loc)
187
189
 
190
+ axis_offsets = arg.axis_offsets
188
191
  axis_side_index = side_index - 3 * arg.cell_count
189
- if axis_side_index < arg.axis_offsets[1]:
192
+ if axis_side_index < axis_offsets[1]:
190
193
  axis = 0
191
- elif axis_side_index < arg.axis_offsets[2]:
194
+ elif axis_side_index < axis_offsets[2]:
192
195
  axis = 1
193
196
  else:
194
197
  axis = 2
195
198
 
196
- altitude = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 0]]
199
+ altitude = res[Grid3D._local_to_world_axis(axis, 0)]
197
200
 
198
- lat_long = axis_side_index - arg.axis_offsets[axis]
199
- latitude_res = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 2]]
201
+ lat_long = axis_side_index - axis_offsets[axis]
202
+ latitude_res = res[Grid3D._local_to_world_axis(axis, 2)]
200
203
 
201
204
  longitude = lat_long // latitude_res
202
205
  latitude = lat_long - longitude * latitude_res
@@ -299,7 +302,7 @@ class Grid3D(Geometry):
299
302
  axis = 2
300
303
 
301
304
  lat_long = axis_side_index - args.axis_offsets[axis]
302
- latitude_res = args.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 2]]
305
+ latitude_res = args.cell_arg.res[Grid3D._local_to_world_axis(axis, 2)]
303
306
 
304
307
  longitude = lat_long // latitude_res
305
308
  latitude = lat_long - longitude * latitude_res
@@ -347,14 +350,14 @@ class Grid3D(Geometry):
347
350
  @wp.func
348
351
  def side_measure(args: SideArg, s: Sample):
349
352
  side = Grid3D.get_side(args, s.element_index)
350
- long_axis = Grid3D.LOC_TO_WORLD[side.axis, 1]
351
- lat_axis = Grid3D.LOC_TO_WORLD[side.axis, 2]
353
+ long_axis = Grid3D._local_to_world_axis(side.axis, 1)
354
+ lat_axis = Grid3D._local_to_world_axis(side.axis, 2)
352
355
  return args.cell_arg.cell_size[long_axis] * args.cell_arg.cell_size[lat_axis]
353
356
 
354
357
  @wp.func
355
358
  def side_measure_ratio(args: SideArg, s: Sample):
356
359
  side = Grid3D.get_side(args, s.element_index)
357
- alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
360
+ alt_axis = Grid3D._local_to_world_axis(side.axis, 0)
358
361
  return 1.0 / args.cell_arg.cell_size[alt_axis]
359
362
 
360
363
  @wp.func
@@ -381,7 +384,7 @@ class Grid3D(Geometry):
381
384
  def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
382
385
  side = Grid3D.get_side(arg, side_index)
383
386
 
384
- alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
387
+ alt_axis = Grid3D._local_to_world_axis(side.axis, 0)
385
388
 
386
389
  outer_alt = wp.select(
387
390
  side.origin[0] == arg.cell_arg.res[alt_axis], side.origin[0], arg.cell_arg.res[alt_axis] - 1
@@ -406,7 +409,7 @@ class Grid3D(Geometry):
406
409
  def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
407
410
  side = Grid3D.get_side(args, side_index)
408
411
 
409
- alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
412
+ alt_axis = Grid3D._local_to_world_axis(side.axis, 0)
410
413
  outer_alt = wp.select(side.origin[0] == args.cell_arg.res[alt_axis], 0.0, 1.0)
411
414
 
412
415
  side_coord0 = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0])
@@ -424,8 +427,8 @@ class Grid3D(Geometry):
424
427
  cell = Grid3D.get_cell(args.cell_arg.res, element_index)
425
428
 
426
429
  if float(side.origin[0] - cell[side.axis]) == element_coords[side.axis]:
427
- long_axis = Grid3D.LOC_TO_WORLD[side.axis, 1]
428
- lat_axis = Grid3D.LOC_TO_WORLD[side.axis, 2]
430
+ long_axis = Grid3D._local_to_world_axis(side.axis, 1)
431
+ lat_axis = Grid3D._local_to_world_axis(side.axis, 2)
429
432
  long_coord = element_coords[long_axis]
430
433
  long_coord = wp.select(side.origin[0] == 0, long_coord, 1.0 - long_coord)
431
434
  return Coords(long_coord, element_coords[lat_axis], 0.0)
@@ -7,7 +7,7 @@ from warp.fem.cache import (
7
7
  borrow_temporary_like,
8
8
  cached_arg_value,
9
9
  )
10
- from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
10
+ from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample
11
11
 
12
12
  from .element import Cube, Square
13
13
  from .geometry import Geometry
@@ -142,6 +142,8 @@ class Hexmesh(Geometry):
142
142
  self._edge_count = 0
143
143
  self._build_topology(temporary_store)
144
144
 
145
+ self._make_default_dependent_implementations()
146
+
145
147
  def cell_count(self):
146
148
  return self.hex_vertex_indices.shape[0]
147
149
 
@@ -246,18 +248,6 @@ class Hexmesh(Geometry):
246
248
  + wp.outer(cell_arg.positions[hex_idx[7]], wp.vec3(-w_p[1] * w_p[2], w_m[0] * w_p[2], w_m[0] * w_p[1]))
247
249
  )
248
250
 
249
- @wp.func
250
- def cell_inverse_deformation_gradient(cell_arg: CellArg, s: Sample):
251
- return wp.inverse(Hexmesh.cell_deformation_gradient(cell_arg, s))
252
-
253
- @wp.func
254
- def cell_measure(args: CellArg, s: Sample):
255
- return wp.abs(wp.determinant(Hexmesh.cell_deformation_gradient(args, s)))
256
-
257
- @wp.func
258
- def cell_normal(args: CellArg, s: Sample):
259
- return wp.vec3(0.0)
260
-
261
251
  @cached_arg_value
262
252
  def side_index_arg_value(self, device) -> SideIndexArg:
263
253
  args = self.SideIndexArg()
@@ -319,39 +309,6 @@ class Hexmesh(Geometry):
319
309
  v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
320
310
  return _mat32(v1, v2)
321
311
 
322
- @wp.func
323
- def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
324
- cell_index = Hexmesh.side_inner_cell_index(args, s.element_index)
325
- cell_coords = Hexmesh.side_inner_cell_coords(args, s.element_index, s.element_coords)
326
- return Hexmesh.cell_inverse_deformation_gradient(args.cell_arg, make_free_sample(cell_index, cell_coords))
327
-
328
- @wp.func
329
- def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample):
330
- cell_index = Hexmesh.side_outer_cell_index(args, s.element_index)
331
- cell_coords = Hexmesh.side_outer_cell_coords(args, s.element_index, s.element_coords)
332
- return Hexmesh.cell_inverse_deformation_gradient(args.cell_arg, make_free_sample(cell_index, cell_coords))
333
-
334
- @wp.func
335
- def side_measure(args: SideArg, s: Sample):
336
- v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
337
- return wp.length(wp.cross(v1, v2))
338
-
339
- @wp.func
340
- def side_measure_ratio(args: SideArg, s: Sample):
341
- inner = Hexmesh.side_inner_cell_index(args, s.element_index)
342
- outer = Hexmesh.side_outer_cell_index(args, s.element_index)
343
- inner_coords = Hexmesh.side_inner_cell_coords(args, s.element_index, s.element_coords)
344
- outer_coords = Hexmesh.side_outer_cell_coords(args, s.element_index, s.element_coords)
345
- return Hexmesh.side_measure(args, s) / wp.min(
346
- Hexmesh.cell_measure(args.cell_arg, make_free_sample(inner, inner_coords)),
347
- Hexmesh.cell_measure(args.cell_arg, make_free_sample(outer, outer_coords)),
348
- )
349
-
350
- @wp.func
351
- def side_normal(args: SideArg, s: Sample):
352
- v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
353
- return wp.normalize(wp.cross(v1, v2))
354
-
355
312
  @wp.func
356
313
  def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
357
314
  return arg.face_hex_indices[side_index][0]