warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,953 @@
1
+ from typing import Optional
2
+
3
+ import warp as wp
4
+ from warp.fem.cache import (
5
+ TemporaryStore,
6
+ borrow_temporary,
7
+ borrow_temporary_like,
8
+ cached_arg_value,
9
+ )
10
+ from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
11
+
12
+ from .element import Cube, Square
13
+ from .geometry import Geometry
14
+
15
+
16
@wp.struct
class HexmeshCellArg:
    """Device-side arguments for cell (hexahedron) queries."""

    # Per-hex local-to-global vertex indices, indexed as [hex, local_vertex]
    # (the Hexmesh constructor documents a (num_hexes, 8) array)
    hex_vertex_indices: wp.array2d(dtype=int)
    # One vec3 position per mesh vertex
    positions: wp.array(dtype=wp.vec3)

    # for neighbor cell lookup
    # Vertex -> hex adjacency as returned by compress_node_indices in _build_topology:
    # per-vertex offsets into the flattened list of hex indices (presumably CSR-like; verify)
    vertex_hex_offsets: wp.array(dtype=int)
    vertex_hex_indices: wp.array(dtype=int)
+
25
+
26
@wp.struct
class HexmeshSideArg:
    """Device-side arguments for side (quadrilateral face) queries."""

    # Cell arguments, used to reach vertex positions and the hexes adjacent to each face
    cell_arg: HexmeshCellArg
    # Four global vertex indices per face; side_position interpolates them at
    # face coords (0,0), (1,0), (1,1), (0,1)
    face_vertex_indices: wp.array(dtype=wp.vec4i)
    # (inner hex, outer hex) per face, read by side_inner/outer_cell_index
    face_hex_indices: wp.array(dtype=wp.vec2i)
    # (inner local face, inner orientation, outer local face, outer orientation) per face,
    # read by side_inner/outer_cell_coords and side_from_cell_coords
    face_hex_face_orientation: wp.array(dtype=wp.vec4i)
+
33
+
34
# 3x2 float matrix type, used for side (face) deformation gradients
_mat32 = wp.mat(shape=(3, 2), dtype=float)

# Local vertex indices of the 6 hexahedron faces, one row per face.
# Vertex numbering follows Hexmesh.cell_position:
# 0: (0,0,0), 1: (1,0,0), 2: (1,1,0), 3: (0,1,0), 4: (0,0,1), 5: (1,0,1), 6: (1,1,1), 7: (0,1,1)
FACE_VERTEX_INDICES = wp.constant(
    wp.mat(shape=(6, 4), dtype=int)(
        [
            [0, 4, 7, 3],  # x = 0
            [1, 2, 6, 5],  # x = 1
            [0, 1, 5, 4],  # y = 0
            [3, 7, 6, 2],  # y = 1
            [0, 3, 2, 1],  # z = 0
            [4, 5, 6, 7],  # z = 1
        ]
    )
)

# Local vertex index pairs of the 12 hexahedron edges, one row per edge:
# 4 edges of the z = 0 face, 4 edges of the z = 1 face, then the 4 edges along z
EDGE_VERTEX_INDICES = wp.constant(
    wp.mat(shape=(12, 2), dtype=int)(
        [
            [0, 1],
            [1, 2],
            [3, 2],
            [0, 3],
            [4, 5],
            [5, 6],
            [7, 6],
            [4, 7],
            [0, 4],
            [1, 5],
            [2, 6],
            [3, 7],
        ]
    )
)

# orthogonal transform for face coordinates given first vertex + winding
# (two rows per entry)
# Entry `ori` occupies rows 2 * ori and 2 * ori + 1, with ori = 2 * FV + (0 for det +, 1 for det -);
# see _local_from_oriented_face_coords / _local_to_oriented_face_coords

FACE_ORIENTATION = [
    [1, 0],  # FV: 0, det: +
    [0, 1],
    [0, 1],  # FV: 0, det: -
    [1, 0],
    [0, -1],  # FV: 1, det: +
    [1, 0],
    [-1, 0],  # FV: 1, det: -
    [0, 1],
    [-1, 0],  # FV: 2, det: +
    [0, -1],
    [0, -1],  # FV: 2, det: -
    [-1, 0],
    [0, 1],  # FV: 3, det: +
    [-1, 0],
    [1, 0],  # FV: 3, det: -
    [0, -1],
]

# Face-coordinate translation for each first vertex FV (indexed by ori // 2):
# the corners (0,0), (1,0), (1,1), (0,1) of the unit square
FACE_TRANSLATION = [
    [0, 0],
    [1, 0],
    [1, 1],
    [0, 1],
]

# local face coordinate system
# One row per local face (same order as FACE_VERTEX_INDICES):
# [first in-face axis, second in-face axis, normal axis, 0 if the face lies at normal coord 0 / 1 if at 1]
_FACE_COORD_INDICES = wp.constant(
    wp.mat(shape=(6, 4), dtype=int)(
        [
            [2, 1, 0, 0],  # 0: z y -x
            [1, 2, 0, 1],  # 1: y z x-1
            [0, 2, 1, 0],  # 2: x z -y
            [2, 0, 1, 1],  # 3: z x y-1
            [1, 0, 2, 0],  # 4: y x -z
            [0, 1, 2, 1],  # 5: x y z-1
        ]
    )
)

# Float-typed copies of the orientation/translation tables, for use in device functions
_FACE_ORIENTATION_F = wp.constant(wp.mat(shape=(16, 2), dtype=float)(FACE_ORIENTATION))
_FACE_TRANSLATION_F = wp.constant(wp.mat(shape=(4, 2), dtype=float)(FACE_TRANSLATION))
113
+
114
+
115
+ class Hexmesh(Geometry):
116
+ """Hexahedral mesh geometry"""
117
+
118
+ dimension = 3
119
+
120
def __init__(
    self, hex_vertex_indices: wp.array, positions: wp.array, temporary_store: Optional[TemporaryStore] = None
):
    """
    Constructs a hexahedral mesh.

    Args:
        hex_vertex_indices: warp array of shape (num_hexes, 8) containing vertex indices for each hex
            following standard ordering (bottom face vertices in counter-clockwise order, then similarly for upper face)
        positions: warp array of shape (num_vertices, 3) containing 3d position for each vertex
        temporary_store: shared pool from which to allocate temporary arrays
    """

    self.hex_vertex_indices = hex_vertex_indices
    self.positions = positions

    # Topology arrays; all except the edge data are filled by _build_topology below
    self._face_vertex_indices: wp.array = None
    self._face_hex_indices: wp.array = None
    self._face_hex_face_orientation: wp.array = None
    self._boundary_face_indices: wp.array = None
    self._vertex_hex_offsets: wp.array = None
    self._vertex_hex_indices: wp.array = None
    # Edge data is computed lazily (see edge_count / hex_edge_indices)
    self._hex_edge_indices: wp.array = None
    self._edge_count = 0
    self._build_topology(temporary_store)
144
+
145
def cell_count(self):
    """Number of hexahedral cells in the mesh."""
    num_hexes = self.hex_vertex_indices.shape[0]
    return num_hexes

def vertex_count(self):
    """Number of mesh vertices."""
    num_vertices = self.positions.shape[0]
    return num_vertices

def side_count(self):
    """Number of faces (boundary and interior)."""
    num_faces = self._face_vertex_indices.shape[0]
    return num_faces

def edge_count(self):
    """Number of edges; the edge topology is computed on first request."""
    edges_ready = self._hex_edge_indices is not None
    if not edges_ready:
        self._compute_hex_edges()
    return self._edge_count

def boundary_side_count(self):
    """Number of boundary faces."""
    num_boundary_faces = self._boundary_face_indices.shape[0]
    return num_boundary_faces
161
+
162
def reference_cell(self) -> Cube:
    """Reference element shared by all cells (a `Cube`)."""
    cube = Cube()
    return cube

def reference_side(self) -> Square:
    """Reference element shared by all sides (a `Square`)."""
    square = Square()
    return square
167
+
168
@property
def hex_edge_indices(self) -> wp.array:
    """Per-hex edge indices, computed lazily on first access."""
    if self._hex_edge_indices is not None:
        return self._hex_edge_indices
    self._compute_hex_edges()
    return self._hex_edge_indices

@property
def face_hex_indices(self) -> wp.array:
    """Per-face (inner hex, outer hex) indices."""
    indices = self._face_hex_indices
    return indices

@property
def face_vertex_indices(self) -> wp.array:
    """Per-face vertex indices."""
    indices = self._face_vertex_indices
    return indices
181
+
182
# Device argument struct types exposed through the Geometry interface
CellArg = HexmeshCellArg
SideArg = HexmeshSideArg

@wp.struct
class SideIndexArg:
    """Device-side arguments for mapping boundary side numbers to side (face) indices."""

    # Indices of the faces flagged in the boundary mask built by _build_topology
    boundary_face_indices: wp.array(dtype=int)
188
+
189
+ # Geometry device interface
190
+
191
@cached_arg_value
def cell_arg_value(self, device) -> CellArg:
    """Builds the cell argument struct with every array moved to `device` (memoized by cached_arg_value)."""
    cell_args = self.CellArg()

    # Geometry data
    cell_args.hex_vertex_indices = self.hex_vertex_indices.to(device)
    cell_args.positions = self.positions.to(device)
    # Vertex -> hex adjacency
    cell_args.vertex_hex_offsets = self._vertex_hex_offsets.to(device)
    cell_args.vertex_hex_indices = self._vertex_hex_indices.to(device)

    return cell_args
201
+
202
@wp.func
def cell_position(args: CellArg, s: Sample):
    """Position of sample `s`, by trilinear interpolation of the 8 hex vertex positions."""
    hex_idx = args.hex_vertex_indices[s.element_index]

    # Per-axis weights toward coordinate 1 ("p") and coordinate 0 ("m")
    w_p = s.element_coords
    w_m = Coords(1.0) - s.element_coords

    # Local vertex -> per-axis weight pattern (x y z):
    # 0 : m m m
    # 1 : p m m
    # 2 : p p m
    # 3 : m p m
    # 4 : m m p
    # 5 : p m p
    # 6 : p p p
    # 7 : m p p

    return (
        w_m[0] * w_m[1] * w_m[2] * args.positions[hex_idx[0]]
        + w_p[0] * w_m[1] * w_m[2] * args.positions[hex_idx[1]]
        + w_p[0] * w_p[1] * w_m[2] * args.positions[hex_idx[2]]
        + w_m[0] * w_p[1] * w_m[2] * args.positions[hex_idx[3]]
        + w_m[0] * w_m[1] * w_p[2] * args.positions[hex_idx[4]]
        + w_p[0] * w_m[1] * w_p[2] * args.positions[hex_idx[5]]
        + w_p[0] * w_p[1] * w_p[2] * args.positions[hex_idx[6]]
        + w_m[0] * w_p[1] * w_p[2] * args.positions[hex_idx[7]]
    )
228
+
229
@wp.func
def cell_deformation_gradient(cell_arg: CellArg, s: Sample):
    """Deformation gradient at the sample's element coordinates.

    Returns the 3x3 Jacobian of `cell_position` with respect to the reference
    coordinates: the sum over the 8 vertices of outer(position_i, grad N_i), so
    column j holds the derivative of the position along reference axis j.
    """
    hex_idx = cell_arg.hex_vertex_indices[s.element_index]

    w_p = s.element_coords
    w_m = Coords(1.0) - s.element_coords

    # Each vec3 is the reference-space gradient of the trilinear shape function
    # of the corresponding local vertex (see the weight pattern in cell_position)
    return (
        wp.outer(cell_arg.positions[hex_idx[0]], wp.vec3(-w_m[1] * w_m[2], -w_m[0] * w_m[2], -w_m[0] * w_m[1]))
        + wp.outer(cell_arg.positions[hex_idx[1]], wp.vec3(w_m[1] * w_m[2], -w_p[0] * w_m[2], -w_p[0] * w_m[1]))
        + wp.outer(cell_arg.positions[hex_idx[2]], wp.vec3(w_p[1] * w_m[2], w_p[0] * w_m[2], -w_p[0] * w_p[1]))
        + wp.outer(cell_arg.positions[hex_idx[3]], wp.vec3(-w_p[1] * w_m[2], w_m[0] * w_m[2], -w_m[0] * w_p[1]))
        + wp.outer(cell_arg.positions[hex_idx[4]], wp.vec3(-w_m[1] * w_p[2], -w_m[0] * w_p[2], w_m[0] * w_m[1]))
        + wp.outer(cell_arg.positions[hex_idx[5]], wp.vec3(w_m[1] * w_p[2], -w_p[0] * w_p[2], w_p[0] * w_m[1]))
        + wp.outer(cell_arg.positions[hex_idx[6]], wp.vec3(w_p[1] * w_p[2], w_p[0] * w_p[2], w_p[0] * w_p[1]))
        + wp.outer(cell_arg.positions[hex_idx[7]], wp.vec3(-w_p[1] * w_p[2], w_m[0] * w_p[2], w_m[0] * w_p[1]))
    )
248
+
249
@wp.func
def cell_inverse_deformation_gradient(cell_arg: CellArg, s: Sample):
    """Inverse of the cell deformation gradient at `s`."""
    F = Hexmesh.cell_deformation_gradient(cell_arg, s)
    return wp.inverse(F)

@wp.func
def cell_measure(args: CellArg, s: Sample):
    """Volume scaling at `s`: absolute value of the deformation gradient determinant."""
    jac_det = wp.determinant(Hexmesh.cell_deformation_gradient(args, s))
    return wp.abs(jac_det)

@wp.func
def cell_normal(args: CellArg, s: Sample):
    """Volume cells carry no normal; always the zero vector."""
    zero_normal = wp.vec3(0.0)
    return zero_normal
260
+
261
@cached_arg_value
def side_index_arg_value(self, device) -> SideIndexArg:
    """Builds the boundary side index struct with its array on `device`."""
    index_args = self.SideIndexArg()
    index_args.boundary_face_indices = self._boundary_face_indices.to(device)
    return index_args

@wp.func
def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
    """Boundary side to side index"""
    face_index = args.boundary_face_indices[boundary_side_index]
    return face_index
274
+
275
@cached_arg_value
def side_arg_value(self, device) -> SideArg:
    """Builds the side argument struct (a SideArg) with every array moved to `device`.

    The embedded cell arguments come from `cell_arg_value`, so they share its per-device caching.
    """
    args = self.SideArg()

    args.cell_arg = self.cell_arg_value(device)
    args.face_vertex_indices = self._face_vertex_indices.to(device)
    args.face_hex_indices = self._face_hex_indices.to(device)
    args.face_hex_face_orientation = self._face_hex_face_orientation.to(device)

    return args
285
+
286
@wp.func
def side_position(args: SideArg, s: Sample):
    """Position of side sample `s`, by bilinear interpolation of the 4 face vertex positions."""
    face_idx = args.face_vertex_indices[s.element_index]

    # Weights toward coordinate 1 (w_p) and coordinate 0 (w_m); only the first two coordinates are used
    w_p = s.element_coords
    w_m = Coords(1.0) - s.element_coords

    # Face vertices 0..3 sit at face coordinates (0,0), (1,0), (1,1), (0,1)
    return (
        w_m[0] * w_m[1] * args.cell_arg.positions[face_idx[0]]
        + w_p[0] * w_m[1] * args.cell_arg.positions[face_idx[1]]
        + w_p[0] * w_p[1] * args.cell_arg.positions[face_idx[2]]
        + w_m[0] * w_p[1] * args.cell_arg.positions[face_idx[3]]
    )
299
+
300
@wp.func
def _side_deformation_vecs(args: SideArg, side_index: ElementIndex, coords: Coords):
    """Tangent vectors of the bilinear face map used by side_position.

    Returns (v1, v2): derivatives of the face position with respect to the first
    and second face coordinates, evaluated at `coords`.
    """
    face_idx = args.face_vertex_indices[side_index]

    p0 = args.cell_arg.positions[face_idx[0]]
    p1 = args.cell_arg.positions[face_idx[1]]
    p2 = args.cell_arg.positions[face_idx[2]]
    p3 = args.cell_arg.positions[face_idx[3]]

    w_p = coords
    w_m = Coords(1.0) - coords

    # d/du: blend the u-direction edges (p0->p1 at v=0, p3->p2 at v=1)
    v1 = w_m[1] * (p1 - p0) + w_p[1] * (p2 - p3)
    # d/dv: blend the v-direction edges (p0->p3 at u=0, p1->p2 at u=1)
    v2 = w_p[0] * (p2 - p1) + w_m[0] * (p3 - p0)
    return v1, v2
315
+
316
@wp.func
def side_deformation_gradient(args: SideArg, s:Sample):
    """Side deformation gradient at the sample's coordinates.

    A 3x2 matrix whose columns are the face tangent vectors v1 = d(position)/du and
    v2 = d(position)/dv (two vec3 can only fill the columns of a 3x2 matrix).
    """
    v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
    return _mat32(v1, v2)
321
+
322
@wp.func
def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
    """Inverse deformation gradient of the inner hex, evaluated at side sample `s`."""
    inner_sample = make_free_sample(
        Hexmesh.side_inner_cell_index(args, s.element_index),
        Hexmesh.side_inner_cell_coords(args, s.element_index, s.element_coords),
    )
    return Hexmesh.cell_inverse_deformation_gradient(args.cell_arg, inner_sample)

@wp.func
def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample):
    """Inverse deformation gradient of the outer hex, evaluated at side sample `s`."""
    outer_sample = make_free_sample(
        Hexmesh.side_outer_cell_index(args, s.element_index),
        Hexmesh.side_outer_cell_coords(args, s.element_index, s.element_coords),
    )
    return Hexmesh.cell_inverse_deformation_gradient(args.cell_arg, outer_sample)
333
+
334
@wp.func
def side_measure(args: SideArg, s: Sample):
    """Area scaling of the face at `s`: norm of the cross product of its tangents."""
    tangent_u, tangent_v = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
    area_vec = wp.cross(tangent_u, tangent_v)
    return wp.length(area_vec)

@wp.func
def side_measure_ratio(args: SideArg, s: Sample):
    """Ratio of the side measure to the smaller of the two adjacent cell measures."""
    side_idx = s.element_index
    inner_sample = make_free_sample(
        Hexmesh.side_inner_cell_index(args, side_idx),
        Hexmesh.side_inner_cell_coords(args, side_idx, s.element_coords),
    )
    outer_sample = make_free_sample(
        Hexmesh.side_outer_cell_index(args, side_idx),
        Hexmesh.side_outer_cell_coords(args, side_idx, s.element_coords),
    )
    smallest_cell_measure = wp.min(
        Hexmesh.cell_measure(args.cell_arg, inner_sample),
        Hexmesh.cell_measure(args.cell_arg, outer_sample),
    )
    return Hexmesh.side_measure(args, s) / smallest_cell_measure

@wp.func
def side_normal(args: SideArg, s: Sample):
    """Unit normal of the face at `s` (normalized cross product of the tangents)."""
    tangent_u, tangent_v = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
    area_vec = wp.cross(tangent_u, tangent_v)
    return wp.normalize(area_vec)
354
+
355
@wp.func
def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
    """Index of the inner hex adjacent to face `side_index`."""
    adjacent_hexes = arg.face_hex_indices[side_index]
    return adjacent_hexes[0]

@wp.func
def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
    """Index of the outer hex adjacent to face `side_index`."""
    adjacent_hexes = arg.face_hex_indices[side_index]
    return adjacent_hexes[1]
362
+
363
@wp.func
def _hex_local_face_coords(hex_coords: Coords, face_index: int):
    # Coordinates in local face coordinates system
    # Sign of last coordinate (out of face)
    # Returns (face_coords, normal_coord): the two in-face coordinates of `hex_coords`
    # for local face `face_index` (axes from _FACE_COORD_INDICES), and a normal
    # coordinate that is 0 on the face and <= 0 inside the reference hex.

    face_coords = wp.vec2(
        hex_coords[_FACE_COORD_INDICES[face_index, 0]], hex_coords[_FACE_COORD_INDICES[face_index, 1]]
    )

    normal_coord = hex_coords[_FACE_COORD_INDICES[face_index, 2]]
    # wp.select(cond, a, b) yields b when cond is true:
    # face at coord 0 -> -c, face at coord 1 -> c - 1
    normal_coord = wp.select(_FACE_COORD_INDICES[face_index, 3] == 0, normal_coord - 1.0, -normal_coord)

    return face_coords, normal_coord

@wp.func
def _local_face_hex_coords(face_coords: wp.vec2, face_index: int):
    # Coordinates in hex from local face coordinates system
    # Inverse of _hex_local_face_coords for points lying on the face: the two
    # in-face coordinates are scattered back, and the normal coordinate is set to
    # 0.0 or 1.0 depending on which side of the hex the face lies.

    hex_coords = Coords()
    hex_coords[_FACE_COORD_INDICES[face_index, 0]] = face_coords[0]
    hex_coords[_FACE_COORD_INDICES[face_index, 1]] = face_coords[1]
    # wp.select(cond, a, b) yields b when cond is true: face at coord 0 -> 0.0, otherwise 1.0
    hex_coords[_FACE_COORD_INDICES[face_index, 2]] = wp.select(_FACE_COORD_INDICES[face_index, 3] == 0, 1.0, 0.0)

    return hex_coords

@wp.func
def _local_from_oriented_face_coords(ori: int, oriented_coords: Coords):
    # Oriented side coordinates -> hex-local face coordinates.
    # Computes R^T (c - t), with R the two rows 2*ori, 2*ori+1 of _FACE_ORIENTATION_F
    # and t the translation row ori // 2 of _FACE_TRANSLATION_F.
    # Only the first two components of `oriented_coords` are used.
    fv = ori // 2
    return (oriented_coords[0] - _FACE_TRANSLATION_F[fv, 0]) * _FACE_ORIENTATION_F[2 * ori] + (
        oriented_coords[1] - _FACE_TRANSLATION_F[fv, 1]
    ) * _FACE_ORIENTATION_F[2 * ori + 1]

@wp.func
def _local_to_oriented_face_coords(ori: int, coords: wp.vec2):
    # Hex-local face coordinates -> oriented side coordinates: R c + t.
    # Inverse of _local_from_oriented_face_coords because the R entries are orthogonal;
    # the third coordinate of the result is 0.
    fv = ori // 2
    return Coords(
        wp.dot(_FACE_ORIENTATION_F[2 * ori], coords) + _FACE_TRANSLATION_F[fv, 0],
        wp.dot(_FACE_ORIENTATION_F[2 * ori + 1], coords) + _FACE_TRANSLATION_F[fv, 1],
        0.0,
    )
403
+
404
@wp.func
def face_to_hex_coords(local_face_index: int, face_orientation: int, side_coords: Coords):
    """Maps oriented side coordinates to hex coordinates on local face `local_face_index`."""
    return Hexmesh._local_face_hex_coords(
        Hexmesh._local_from_oriented_face_coords(face_orientation, side_coords), local_face_index
    )

@wp.func
def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
    """Hex coordinates, in the inner hex, of the side point `side_coords`."""
    face_ori = args.face_hex_face_orientation[side_index]
    return Hexmesh.face_to_hex_coords(face_ori[0], face_ori[1], side_coords)

@wp.func
def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
    """Hex coordinates, in the outer hex, of the side point `side_coords`."""
    face_ori = args.face_hex_face_orientation[side_index]
    return Hexmesh.face_to_hex_coords(face_ori[2], face_ori[3], side_coords)
422
+
423
@wp.func
def side_from_cell_coords(args: SideArg, side_index: ElementIndex, hex_index: ElementIndex, hex_coords: Coords):
    """Converts hex coordinates into oriented side coordinates.

    Uses the inner hex's local face data when `hex_index` is the side's inner cell,
    and the outer hex's otherwise. Returns Coords(OUTSIDE) unless `hex_coords`
    lie exactly on the face.
    """
    if Hexmesh.side_inner_cell_index(args, side_index) == hex_index:
        local_face_index = args.face_hex_face_orientation[side_index][0]
        face_orientation = args.face_hex_face_orientation[side_index][1]
    else:
        local_face_index = args.face_hex_face_orientation[side_index][2]
        face_orientation = args.face_hex_face_orientation[side_index][3]

    face_coords, normal_coord = Hexmesh._hex_local_face_coords(hex_coords, local_face_index)
    # wp.select(cond, a, b) yields b when cond is true: oriented face coordinates
    # when the point lies on the face (normal_coord == 0.0), OUTSIDE otherwise
    return wp.select(
        normal_coord == 0.0, Coords(OUTSIDE), Hexmesh._local_to_oriented_face_coords(face_orientation, face_coords)
    )
436
+
437
@wp.func
def side_to_cell_arg(side_arg: SideArg):
    """Cell arguments embedded in the side arguments."""
    cell_args = side_arg.cell_arg
    return cell_args
440
+
441
def _build_topology(self, temporary_store: Optional[TemporaryStore]):
    """Builds the vertex->hex adjacency, the unique faces, their orientations and the boundary faces.

    Fills ``_vertex_hex_offsets``, ``_vertex_hex_indices``, ``_face_vertex_indices``,
    ``_face_hex_indices``, ``_face_hex_face_orientation`` and ``_boundary_face_indices``.

    Args:
        temporary_store: shared pool from which to allocate temporary arrays (may be None)
    """
    from warp.fem.utils import compress_node_indices, masked_indices
    from warp.utils import array_scan

    device = self.hex_vertex_indices.device

    # Vertex -> hex adjacency
    vertex_hex_offsets, vertex_hex_indices, _, __ = compress_node_indices(
        self.vertex_count(), self.hex_vertex_indices, temporary_store=temporary_store
    )
    self._vertex_hex_offsets = vertex_hex_offsets.detach()
    self._vertex_hex_indices = vertex_hex_indices.detach()

    vertex_start_face_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
    vertex_start_face_count.array.zero_()
    vertex_start_face_offsets = borrow_temporary_like(vertex_start_face_count, temporary_store=temporary_store)

    vertex_face_other_vs = borrow_temporary(
        temporary_store, dtype=wp.vec3i, device=device, shape=(8 * self.cell_count())
    )
    vertex_face_hexes = borrow_temporary(
        temporary_store, dtype=int, device=device, shape=(8 * self.cell_count(), 2)
    )

    # Count faces starting at each vertex
    wp.launch(
        kernel=Hexmesh._count_starting_faces_kernel,
        device=device,
        dim=self.cell_count(),
        inputs=[self.hex_vertex_indices, vertex_start_face_count.array],
    )

    array_scan(in_array=vertex_start_face_count.array, out_array=vertex_start_face_offsets.array, inclusive=False)

    # Count number of unique faces (deduplicate across hexes).
    # The unique counts are written in place over the start counts (same temporary).
    vertex_unique_face_count = vertex_start_face_count
    wp.launch(
        kernel=Hexmesh._count_unique_starting_faces_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            self._vertex_hex_offsets,
            self._vertex_hex_indices,
            self.hex_vertex_indices,
            vertex_start_face_offsets.array,
            vertex_unique_face_count.array,
            vertex_face_other_vs.array,
            vertex_face_hexes.array,
        ],
    )

    vertex_unique_face_offsets = borrow_temporary_like(vertex_start_face_offsets, temporary_store=temporary_store)
    array_scan(in_array=vertex_start_face_count.array, out_array=vertex_unique_face_offsets.array, inclusive=False)

    # Get back face count to host.
    # Assumes (as the original implementation does) that the last vertex owns no face,
    # so its exclusive prefix sum is the total face count.
    if device.is_cuda:
        face_count_tmp = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
        wp.copy(
            dest=face_count_tmp.array, src=vertex_unique_face_offsets.array, src_offset=self.vertex_count() - 1, count=1
        )
        wp.synchronize_stream(wp.get_stream(device))
        face_count = int(face_count_tmp.array.numpy()[0])
        face_count_tmp.release()
    else:
        face_count = int(vertex_unique_face_offsets.array.numpy()[self.vertex_count() - 1])

    self._face_vertex_indices = wp.empty(shape=(face_count,), dtype=wp.vec4i, device=device)
    self._face_hex_indices = wp.empty(shape=(face_count,), dtype=wp.vec2i, device=device)
    self._face_hex_face_orientation = wp.empty(shape=(face_count,), dtype=wp.vec4i, device=device)

    boundary_mask = borrow_temporary(temporary_store, shape=(face_count,), dtype=int, device=device)

    # Compress face data
    wp.launch(
        kernel=Hexmesh._compress_faces_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            vertex_start_face_offsets.array,
            vertex_unique_face_offsets.array,
            vertex_unique_face_count.array,
            vertex_face_other_vs.array,
            vertex_face_hexes.array,
            self._face_vertex_indices,
            self._face_hex_indices,
            boundary_mask.array,
        ],
    )

    vertex_start_face_offsets.release()
    vertex_unique_face_offsets.release()
    vertex_unique_face_count.release()
    vertex_face_other_vs.release()
    vertex_face_hexes.release()

    # Flip normals if necessary
    wp.launch(
        kernel=Hexmesh._flip_face_normals,
        device=device,
        dim=self.side_count(),
        inputs=[self._face_vertex_indices, self._face_hex_indices, self.hex_vertex_indices, self.positions],
    )

    # Compute and store face orientation
    wp.launch(
        kernel=Hexmesh._compute_face_orientation,
        device=device,
        dim=self.side_count(),
        inputs=[
            self._face_vertex_indices,
            self._face_hex_indices,
            self.hex_vertex_indices,
            self._face_hex_face_orientation,
        ],
    )

    boundary_face_indices, _ = masked_indices(boundary_mask.array)
    self._boundary_face_indices = boundary_face_indices.detach()
    boundary_mask.release()
+
559
def _compute_hex_edges(self, temporary_store: Optional[TemporaryStore] = None):
    """Build the unique edge list of the mesh and fill ``self._hex_edge_indices``.

    Each edge is owned by its lowest-index endpoint. The method:

    1. counts, for every vertex, the hex edges it owns (duplicates included),
    2. prefix-sums those counts into per-vertex scratch ranges,
    3. deduplicates the edge end points within each range (counts overwritten in place),
    4. prefix-sums the unique counts to get contiguous global edge ids,
    5. writes, for every hex, the global id of each of its 12 edges.

    Sets ``self._edge_count`` and ``self._hex_edge_indices`` (shape ``(cell_count, 12)``).

    Args:
        temporary_store: optional pool that scratch arrays are borrowed from; every
            borrowed temporary is released before returning.
    """
    from warp.utils import array_scan

    device = self.hex_vertex_indices.device
    vertex_count = self.vertex_count()
    cell_count = self.cell_count()

    vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=vertex_count)
    vertex_start_edge_count.array.zero_()
    vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)

    # Each hex contributes 12 candidate edges, so this bounds the total number of entries
    vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(12 * cell_count))

    # Count hex edges starting at each vertex (duplicates shared between hexes included)
    wp.launch(
        kernel=Hexmesh._count_starting_edges_kernel,
        device=device,
        dim=cell_count,
        inputs=[self.hex_vertex_indices, vertex_start_edge_count.array],
    )

    array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)

    # Count number of unique edges (deduplicate across hexes).
    # The kernel overwrites the per-vertex counts in place with the unique counts.
    vertex_unique_edge_count = vertex_start_edge_count
    wp.launch(
        kernel=Hexmesh._count_unique_starting_edges_kernel,
        device=device,
        dim=vertex_count,
        inputs=[
            self._vertex_hex_offsets,
            self._vertex_hex_indices,
            self.hex_vertex_indices,
            vertex_start_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
        ],
    )

    vertex_unique_edge_offsets = borrow_temporary_like(
        vertex_start_edge_offsets.array, temporary_store=temporary_store
    )
    array_scan(in_array=vertex_unique_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)

    # Get back edge count to host.
    # The last vertex cannot be the lowest endpoint of any edge, so its unique count is zero
    # and its exclusive prefix sum equals the total number of edges.
    if device.is_cuda:
        edge_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
        wp.copy(
            dest=edge_count.array,
            src=vertex_unique_edge_offsets.array,
            src_offset=vertex_count - 1,
            count=1,
        )
        wp.synchronize_stream(wp.get_stream(device))
        self._edge_count = int(edge_count.array.numpy()[0])
        # Return the pinned staging buffer to the store once the value has been read
        edge_count.release()
    else:
        self._edge_count = int(vertex_unique_edge_offsets.array.numpy()[vertex_count - 1])

    self._hex_edge_indices = wp.empty(dtype=int, device=device, shape=(cell_count, 12))

    # Compress edge data: map every (hex, local edge) pair to its global edge id
    wp.launch(
        kernel=Hexmesh._compress_edges_kernel,
        device=device,
        dim=vertex_count,
        inputs=[
            self._vertex_hex_offsets,
            self._vertex_hex_indices,
            self.hex_vertex_indices,
            vertex_start_edge_offsets.array,
            vertex_unique_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
            self._hex_edge_indices,
        ],
    )

    vertex_start_edge_offsets.release()
    vertex_unique_edge_offsets.release()
    vertex_unique_edge_count.release()
    vertex_edge_ends.release()
641
+
642
@wp.kernel
def _count_starting_faces_kernel(
    hex_vertex_indices: wp.array2d(dtype=int), vertex_start_face_count: wp.array(dtype=int)
):
    # One thread per hex. Each of the 6 hex faces is credited to its lowest-index vertex.
    # The per-corner loop adds one for every corner equal to that minimum, which is a
    # single increment for a face with four distinct vertices.
    hx = wp.tid()
    for face in range(6):
        corners = wp.vec4i(
            hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 0]],
            hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 1]],
            hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 2]],
            hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 3]],
        )
        owner = wp.min(corners)

        for c in range(4):
            if corners[c] == owner:
                wp.atomic_add(vertex_start_face_count, owner, 1)
659
+
660
@wp.func
def _face_sort(vidx: wp.vec4i, min_k: int):
    # Canonical key for a quad face, given the position `min_k` of its lowest vertex.
    # Returns the three remaining vertices in cyclic order starting after the lowest one.
    # The two neighbors of the lowest vertex are swapped when needed, so the smaller
    # neighbor comes first and both windings of the same face give the same key.
    nxt = vidx[(min_k + 1) % 4]
    opposite = vidx[(min_k + 2) % 4]
    prv = vidx[(min_k + 3) % 4]

    if prv < nxt:
        return wp.vec3i(prv, opposite, nxt)
    return wp.vec3i(nxt, opposite, prv)
669
+
670
@wp.func
def _find_face(
    needle: wp.vec3i,
    values: wp.array(dtype=wp.vec3i),
    beg: int,
    end: int,
):
    # Linear scan of values[beg:end] for `needle`.
    # Returns the first matching index, or -1 when the key is absent.
    for slot in range(beg, end):
        if values[slot] == needle:
            return slot

    return -1
682
+
683
@wp.kernel
def _count_unique_starting_faces_kernel(
    vertex_hex_offsets: wp.array(dtype=int),
    vertex_hex_indices: wp.array(dtype=int),
    hex_vertex_indices: wp.array2d(dtype=int),
    vertex_start_face_offsets: wp.array(dtype=int),
    vertex_start_face_count: wp.array(dtype=int),
    face_other_vs: wp.array(dtype=wp.vec3i),
    face_hexes: wp.array2d(dtype=int),
):
    # One thread per vertex `v`. For every hex adjacent to `v`, consider the faces whose
    # lowest vertex is `v`. Faces are deduplicated by their canonical key (see _face_sort)
    # inside this vertex's scratch range, starting at vertex_start_face_offsets[v].
    # For each unique face, face_hexes records the first and last hex seen; both slots
    # hold the same hex when only one hex touches the face.
    # On exit, vertex_start_face_count[v] is overwritten with the number of unique faces.
    v = wp.tid()

    first_slot = vertex_start_face_offsets[v]

    adj_beg = vertex_hex_offsets[v]
    adj_end = vertex_hex_offsets[v + 1]

    next_slot = first_slot

    for entry in range(adj_beg, adj_end):
        hx = vertex_hex_indices[entry]

        for face in range(6):
            corners = wp.vec4i(
                hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 0]],
                hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 1]],
                hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 2]],
                hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 3]],
            )
            lowest = int(wp.argmin(corners))

            if corners[lowest] == v:
                key = Hexmesh._face_sort(corners, lowest)

                # Already recorded in this vertex's range? Then it is the second adjacent hex
                seen = Hexmesh._find_face(key, face_other_vs, first_slot, next_slot)

                if seen != -1:
                    face_hexes[seen, 1] = hx
                else:
                    face_other_vs[next_slot] = key
                    face_hexes[next_slot, 0] = hx
                    face_hexes[next_slot, 1] = hx
                    next_slot += 1

    vertex_start_face_count[v] = next_slot - first_slot
729
+
730
@wp.kernel
def _compress_faces_kernel(
    vertex_start_face_offsets: wp.array(dtype=int),
    vertex_unique_face_offsets: wp.array(dtype=int),
    vertex_unique_face_count: wp.array(dtype=int),
    uncompressed_face_other_vs: wp.array(dtype=wp.vec3i),
    uncompressed_face_hexes: wp.array2d(dtype=int),
    face_vertex_indices: wp.array(dtype=wp.vec4i),
    face_hex_indices: wp.array(dtype=wp.vec2i),
    boundary_mask: wp.array(dtype=int),
):
    # One thread per vertex `v`: copy the unique faces owned by `v` from its scratch range
    # into their final contiguous slots. `v` is the first face vertex, followed by the
    # other three from the canonical key. A face is marked as boundary (1) when both of
    # its recorded hexes are the same.
    v = wp.tid()

    src_beg = vertex_start_face_offsets[v]
    dst_beg = vertex_unique_face_offsets[v]

    for j in range(vertex_unique_face_count[v]):
        src = src_beg + j
        dst = dst_beg + j

        others = uncompressed_face_other_vs[src]
        face_vertex_indices[dst] = wp.vec4i(v, others[0], others[1], others[2])

        first_hex = uncompressed_face_hexes[src, 0]
        second_hex = uncompressed_face_hexes[src, 1]
        face_hex_indices[dst] = wp.vec2i(first_hex, second_hex)

        if first_hex != second_hex:
            boundary_mask[dst] = 0
        else:
            boundary_mask[dst] = 1
765
+
766
@wp.kernel
def _flip_face_normals(
    face_vertex_indices: wp.array(dtype=wp.vec4i),
    face_hex_indices: wp.array(dtype=wp.vec2i),
    hex_vertex_indices: wp.array2d(dtype=int),
    positions: wp.array(dtype=wp.vec3),
):
    # One thread per face: reorder the face vertices so the quad normal (cross product of
    # its diagonals) points away from the centroid of the first adjacent hex.
    f = wp.tid()

    owner = face_hex_indices[f][0]

    owner_vidx = hex_vertex_indices[owner]
    corners = face_vertex_indices[f]

    hex_center = (
        positions[owner_vidx[0]]
        + positions[owner_vidx[1]]
        + positions[owner_vidx[2]]
        + positions[owner_vidx[3]]
        + positions[owner_vidx[4]]
        + positions[owner_vidx[5]]
        + positions[owner_vidx[6]]
        + positions[owner_vidx[7]]
    ) / 8.0

    p0 = positions[corners[0]]
    p1 = positions[corners[1]]
    p2 = positions[corners[2]]
    p3 = positions[corners[3]]

    face_center = (p1 + p0 + p2 + p3) / 4.0
    normal = wp.cross(p2 - p0, p3 - p1)
    to_hex_center = hex_center - face_center

    # Normal points toward the hex centroid: reverse the winding, keeping the first vertex
    if wp.dot(to_hex_center, normal) > 0.0:
        face_vertex_indices[f] = wp.vec4i(corners[0], corners[3], corners[2], corners[1])
802
+
803
@wp.func
def _find_face_orientation(face_vidx: wp.vec4i, hex_index: int, hex_vertex_indices: wp.array2d(dtype=int)):
    # Locate the face `face_vidx` among the 6 local faces of hex `hex_index`.
    # Returns (local_face_index, face_orientation), where:
    #   - local_face_index is the hex-local face k whose canonical key matches the face;
    #   - face_orientation = 2 * j, where j is the position in face_vidx of that local
    #     face's first vertex, plus 1 when the next vertex of face_vidx is not the local
    #     face's second vertex.
    # NOTE(review): if no local face matches (inconsistent connectivity), neither output
    # is assigned in this function -- confirm the callers only pass adjacent hexes.
    hex_corners = hex_vertex_indices[hex_index]

    # Canonical key of the queried face
    face_key = Hexmesh._face_sort(face_vidx, int(wp.argmin(face_vidx)))

    for k in range(6):
        candidate = wp.vec4i(
            hex_corners[FACE_VERTEX_INDICES[k, 0]],
            hex_corners[FACE_VERTEX_INDICES[k, 1]],
            hex_corners[FACE_VERTEX_INDICES[k, 2]],
            hex_corners[FACE_VERTEX_INDICES[k, 3]],
        )
        candidate_key = Hexmesh._face_sort(candidate, int(wp.argmin(candidate)))

        if candidate_key == face_key:
            local_face_index = k
            break

    # Position of the matched local face's first vertex within face_vidx, and winding
    for j in range(4):
        if face_vidx[j] == candidate[0]:
            face_orientation = 2 * j
            if face_vidx[(j + 1) % 4] != candidate[1]:
                face_orientation += 1

    return local_face_index, face_orientation
834
+
835
@wp.kernel
def _compute_face_orientation(
    face_vertex_indices: wp.array(dtype=wp.vec4i),
    face_hex_indices: wp.array(dtype=wp.vec2i),
    hex_vertex_indices: wp.array2d(dtype=int),
    face_hex_face_ori: wp.array(dtype=wp.vec4i),
):
    # One thread per face. Stores (local face, orientation) relative to the first adjacent
    # hex in components 0-1, and relative to the second adjacent hex in components 2-3.
    # A boundary face has a single hex, so both pairs are identical.
    f = wp.tid()

    face_vidx = face_vertex_indices[f]
    adjacent = face_hex_indices[f]
    hx0 = adjacent[0]
    hx1 = adjacent[1]

    local_face_0, ori_0 = Hexmesh._find_face_orientation(face_vidx, hx0, hex_vertex_indices)

    if hx0 != hx1:
        local_face_1, ori_1 = Hexmesh._find_face_orientation(face_vidx, hx1, hex_vertex_indices)
        face_hex_face_ori[f] = wp.vec4i(local_face_0, ori_0, local_face_1, ori_1)
    else:
        face_hex_face_ori[f] = wp.vec4i(local_face_0, ori_0, local_face_0, ori_0)
855
+
856
@wp.kernel
def _count_starting_edges_kernel(
    hex_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int)
):
    # One thread per hex: credit each of its 12 edges to the lower-index endpoint
    # (edges shared between hexes are counted once per hex here).
    hx = wp.tid()
    for e in range(12):
        a = hex_vertex_indices[hx, EDGE_VERTEX_INDICES[e, 0]]
        b = hex_vertex_indices[hx, EDGE_VERTEX_INDICES[e, 1]]

        wp.atomic_add(vertex_start_edge_count, wp.min(a, b), 1)
869
+
870
@wp.func
def _find_edge(
    needle: int,
    values: wp.array(dtype=int),
    beg: int,
    end: int,
):
    # Linear scan of values[beg:end] for the end vertex `needle`.
    # Returns the first matching index, or -1 when it is absent.
    for slot in range(beg, end):
        if values[slot] == needle:
            return slot

    return -1
882
+
883
@wp.kernel
def _count_unique_starting_edges_kernel(
    vertex_hex_offsets: wp.array(dtype=int),
    vertex_hex_indices: wp.array(dtype=int),
    hex_vertex_indices: wp.array2d(dtype=int),
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_start_edge_count: wp.array(dtype=int),
    edge_ends: wp.array(dtype=int),
):
    # One thread per vertex `v`: gather the distinct far endpoints of all edges owned by `v`
    # (edges whose lower endpoint is `v`) from the hexes adjacent to `v`. They are stored in
    # this vertex's scratch range, starting at vertex_start_edge_offsets[v].
    # On exit, vertex_start_edge_count[v] is overwritten with the number of unique edges.
    v = wp.tid()

    first_slot = vertex_start_edge_offsets[v]

    adj_beg = vertex_hex_offsets[v]
    adj_end = vertex_hex_offsets[v + 1]

    next_slot = first_slot

    for entry in range(adj_beg, adj_end):
        hx = vertex_hex_indices[entry]

        for e in range(12):
            a = hex_vertex_indices[hx, EDGE_VERTEX_INDICES[e, 0]]
            b = hex_vertex_indices[hx, EDGE_VERTEX_INDICES[e, 1]]

            if wp.min(a, b) == v:
                far_end = wp.max(a, b)
                seen = Hexmesh._find_edge(far_end, edge_ends, first_slot, next_slot)
                if seen == -1:
                    edge_ends[next_slot] = far_end
                    next_slot += 1

    vertex_start_edge_count[v] = next_slot - first_slot
915
+
916
+ @wp.kernel
917
+ def _compress_edges_kernel(
918
+ vertex_hex_offsets: wp.array(dtype=int),
919
+ vertex_hex_indices: wp.array(dtype=int),
920
+ hex_vertex_indices: wp.array2d(dtype=int),
921
+ vertex_start_edge_offsets: wp.array(dtype=int),
922
+ vertex_unique_edge_offsets: wp.array(dtype=int),
923
+ vertex_unique_edge_count: wp.array(dtype=int),
924
+ uncompressed_edge_ends: wp.array(dtype=int),
925
+ hex_edge_indices: wp.array2d(dtype=int),
926
+ ):
927
+ v = wp.tid()
928
+
929
+ uncompressed_beg = vertex_start_edge_offsets[v]
930
+
931
+ unique_beg = vertex_unique_edge_offsets[v]
932
+ unique_count = vertex_unique_edge_count[v]
933
+
934
+ hex_beg = vertex_hex_offsets[v]
935
+ hex_end = vertex_hex_offsets[v + 1]
936
+
937
+ for tet in range(hex_beg, hex_end):
938
+ t = vertex_hex_indices[tet]
939
+
940
+ for k in range(12):
941
+ v0 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 0]]
942
+ v1 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 1]]
943
+
944
+ if v == wp.min(v0, v1):
945
+ other_v = wp.max(v0, v1)
946
+ edge_id = (
947
+ Hexmesh._find_edge(
948
+ other_v, uncompressed_edge_ends, uncompressed_beg, uncompressed_beg + unique_count
949
+ )
950
+ - uncompressed_beg
951
+ + unique_beg
952
+ )
953
+ hex_edge_indices[t][k] = edge_id