warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (271) hide show
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -1,38 +1,65 @@
1
- import warp as wp
1
+ from typing import Optional
2
2
 
3
- from warp.fem.types import ElementIndex, Coords, vec2i, vec3i, Sample
4
- from warp.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, NULL_DOF_INDEX, NULL_QP_INDEX
3
+ import warp as wp
4
+ from warp.fem.cache import (
5
+ TemporaryStore,
6
+ borrow_temporary,
7
+ borrow_temporary_like,
8
+ cached_arg_value,
9
+ )
10
+ from warp.fem.types import (
11
+ NULL_ELEMENT_INDEX,
12
+ OUTSIDE,
13
+ Coords,
14
+ ElementIndex,
15
+ Sample,
16
+ make_free_sample,
17
+ )
5
18
 
6
- from .geometry import Geometry
7
- from .element import Triangle, Tetrahedron
8
19
  from .closest_point import project_on_tet_at_origin
20
+ from .element import Tetrahedron, Triangle
21
+ from .geometry import Geometry
9
22
 
10
23
 
11
24
  @wp.struct
12
- class TetmeshArg:
25
+ class TetmeshCellArg:
13
26
  tet_vertex_indices: wp.array2d(dtype=int)
14
27
  positions: wp.array(dtype=wp.vec3)
15
28
 
29
+ # for neighbor cell lookup
16
30
  vertex_tet_offsets: wp.array(dtype=int)
17
31
  vertex_tet_indices: wp.array(dtype=int)
18
32
 
19
- face_vertex_indices: wp.array(dtype=vec3i)
20
- face_tet_indices: wp.array(dtype=vec2i)
33
+ # for transforming reference gradient
34
+ deformation_gradients: wp.array(dtype=wp.mat33f)
35
+
36
+
37
+ @wp.struct
38
+ class TetmeshSideArg:
39
+ cell_arg: TetmeshCellArg
40
+ face_vertex_indices: wp.array(dtype=wp.vec3i)
41
+ face_tet_indices: wp.array(dtype=wp.vec2i)
42
+
43
+
44
+ _mat32 = wp.mat(shape=(3, 2), dtype=float)
21
45
 
22
46
 
23
47
  class Tetmesh(Geometry):
24
48
  """Tetrahedral mesh geometry"""
25
49
 
26
- def __init__(self, tet_vertex_indices: wp.array, positions: wp.array):
50
+ dimension = 3
51
+
52
+ def __init__(
53
+ self, tet_vertex_indices: wp.array, positions: wp.array, temporary_store: Optional[TemporaryStore] = None
54
+ ):
27
55
  """
28
56
  Constructs a tetrahedral mesh.
29
57
 
30
58
  Args:
31
59
  tet_vertex_indices: warp array of shape (num_tets, 4) containing vertex indices for each tet
32
60
  positions: warp array of shape (num_vertices, 3) containing 3d position for each vertex
33
-
61
+ temporary_store: shared pool from which to allocate temporary arrays
34
62
  """
35
- self.dimension = 3
36
63
 
37
64
  self.tet_vertex_indices = tet_vertex_indices
38
65
  self.positions = positions
@@ -41,8 +68,12 @@ class Tetmesh(Geometry):
41
68
  self._face_tet_indices: wp.array = None
42
69
  self._vertex_tet_offsets: wp.array = None
43
70
  self._vertex_tet_indices: wp.array = None
71
+ self._tet_edge_indices: wp.array = None
72
+ self._edge_count = 0
73
+ self._build_topology(temporary_store)
44
74
 
45
- self._build_topology()
75
+ self._deformation_gradients: wp.array = None
76
+ self._compute_deformation_gradients()
46
77
 
47
78
  def cell_count(self):
48
79
  return self.tet_vertex_indices.shape[0]
@@ -53,17 +84,36 @@ class Tetmesh(Geometry):
53
84
  def side_count(self):
54
85
  return self._face_vertex_indices.shape[0]
55
86
 
87
+ def edge_count(self):
88
+ if self._tet_edge_indices is None:
89
+ self._compute_tet_edges()
90
+ return self._edge_count
91
+
56
92
  def boundary_side_count(self):
57
93
  return self._boundary_face_indices.shape[0]
58
94
 
59
- def reference_cell(self) -> Triangle:
95
+ def reference_cell(self) -> Tetrahedron:
60
96
  return Tetrahedron()
61
97
 
62
98
  def reference_side(self) -> Triangle:
63
99
  return Triangle()
64
100
 
65
- CellArg = TetmeshArg
66
- SideArg = TetmeshArg
101
+ @property
102
+ def tet_edge_indices(self) -> wp.array:
103
+ if self._tet_edge_indices is None:
104
+ self._compute_tet_edges()
105
+ return self._tet_edge_indices
106
+
107
+ @property
108
+ def face_tet_indices(self) -> wp.array:
109
+ return self._face_tet_indices
110
+
111
+ @property
112
+ def face_vertex_indices(self) -> wp.array:
113
+ return self._face_vertex_indices
114
+
115
+ CellArg = TetmeshCellArg
116
+ SideArg = TetmeshSideArg
67
117
 
68
118
  @wp.struct
69
119
  class SideIndexArg:
@@ -71,15 +121,15 @@ class Tetmesh(Geometry):
71
121
 
72
122
  # Geometry device interface
73
123
 
124
+ @cached_arg_value
74
125
  def cell_arg_value(self, device) -> CellArg:
75
126
  args = self.CellArg()
76
127
 
77
128
  args.tet_vertex_indices = self.tet_vertex_indices.to(device)
78
129
  args.positions = self.positions.to(device)
79
- args.face_vertex_indices = self._face_vertex_indices.to(device)
80
- args.face_tet_indices = self._face_tet_indices.to(device)
81
130
  args.vertex_tet_offsets = self._vertex_tet_offsets.to(device)
82
131
  args.vertex_tet_indices = self._vertex_tet_indices.to(device)
132
+ args.deformation_gradients = self._deformation_gradients.to(device)
83
133
 
84
134
  return args
85
135
 
@@ -94,6 +144,14 @@ class Tetmesh(Geometry):
94
144
  + s.element_coords[2] * args.positions[tet_idx[3]]
95
145
  )
96
146
 
147
+ @wp.func
148
+ def cell_deformation_gradient(args: CellArg, s: Sample):
149
+ return args.deformation_gradients[s.element_index]
150
+
151
+ @wp.func
152
+ def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
153
+ return wp.inverse(args.deformation_gradients[s.element_index])
154
+
97
155
  @wp.func
98
156
  def _project_on_tet(args: CellArg, pos: wp.vec3, tet_index: int):
99
157
  p0 = args.positions[args.tet_vertex_indices[tet_index, 0]]
@@ -125,28 +183,11 @@ class Tetmesh(Geometry):
125
183
  closest_tet = tet
126
184
  closest_coords = coords
127
185
 
128
- return Sample(closest_tet, closest_coords, NULL_QP_INDEX, 0.0, NULL_DOF_INDEX, NULL_DOF_INDEX)
129
-
130
- @wp.func
131
- def cell_measure(args: CellArg, cell_index: ElementIndex, coords: Coords):
132
- tet_idx = args.tet_vertex_indices[cell_index]
133
-
134
- v0 = args.positions[tet_idx[0]]
135
- v1 = args.positions[tet_idx[1]]
136
- v2 = args.positions[tet_idx[2]]
137
- v3 = args.positions[tet_idx[3]]
138
-
139
- mat = wp.mat33(
140
- v1 - v0,
141
- v2 - v0,
142
- v3 - v0,
143
- )
144
-
145
- return wp.abs(wp.determinant(mat)) / 6.0
186
+ return make_free_sample(closest_tet, closest_coords)
146
187
 
147
188
  @wp.func
148
189
  def cell_measure(args: CellArg, s: Sample):
149
- return Tetmesh.cell_measure(args, s.element_index, s.element_coords)
190
+ return wp.abs(wp.determinant(args.deformation_gradients[s.element_index])) / 6.0
150
191
 
151
192
  @wp.func
152
193
  def cell_measure_ratio(args: CellArg, s: Sample):
@@ -156,9 +197,7 @@ class Tetmesh(Geometry):
156
197
  def cell_normal(args: CellArg, s: Sample):
157
198
  return wp.vec3(0.0)
158
199
 
159
- def side_arg_value(self, device) -> SideArg:
160
- return self.cell_arg_value(device)
161
-
200
+ @cached_arg_value
162
201
  def side_index_arg_value(self, device) -> SideIndexArg:
163
202
  args = self.SideIndexArg()
164
203
 
@@ -172,53 +211,83 @@ class Tetmesh(Geometry):
172
211
 
173
212
  return args.boundary_face_indices[boundary_side_index]
174
213
 
214
+ @cached_arg_value
215
+ def side_arg_value(self, device) -> CellArg:
216
+ args = self.SideArg()
217
+
218
+ args.cell_arg = self.cell_arg_value(device)
219
+ args.face_vertex_indices = self._face_vertex_indices.to(device)
220
+ args.face_tet_indices = self._face_tet_indices.to(device)
221
+
222
+ return args
223
+
175
224
  @wp.func
176
225
  def side_position(args: SideArg, s: Sample):
177
226
  face_idx = args.face_vertex_indices[s.element_index]
178
227
  return (
179
- s.element_coords[0] * args.positions[face_idx[0]]
180
- + s.element_coords[1] * args.positions[face_idx[1]]
181
- + s.element_coords[2] * args.positions[face_idx[2]]
228
+ s.element_coords[0] * args.cell_arg.positions[face_idx[0]]
229
+ + s.element_coords[1] * args.cell_arg.positions[face_idx[1]]
230
+ + s.element_coords[2] * args.cell_arg.positions[face_idx[2]]
182
231
  )
183
232
 
184
233
  @wp.func
185
- def side_measure(args: SideArg, side_index: ElementIndex, coords: Coords):
234
+ def _side_vecs(args: SideArg, side_index: ElementIndex):
186
235
  face_idx = args.face_vertex_indices[side_index]
187
- v0 = args.positions[face_idx[0]]
188
- v1 = args.positions[face_idx[1]]
189
- v2 = args.positions[face_idx[2]]
236
+ v0 = args.cell_arg.positions[face_idx[0]]
237
+ v1 = args.cell_arg.positions[face_idx[1]]
238
+ v2 = args.cell_arg.positions[face_idx[2]]
190
239
 
191
- return 0.5 * wp.length(wp.cross(v1 - v0, v2 - v0))
240
+ return v1 - v0, v2 - v0
241
+
242
+ @wp.func
243
+ def side_deformation_gradient(args: SideArg, s: Sample):
244
+ e1, e2 = Tetmesh._side_vecs(args, s.element_index)
245
+ return _mat32(e1, e2)
246
+
247
+ @wp.func
248
+ def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
249
+ cell_index = Tetmesh.side_inner_cell_index(args, s.element_index)
250
+ return wp.inverse(args.cell_arg.deformation_gradients[cell_index])
251
+
252
+ @wp.func
253
+ def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample):
254
+ cell_index = Tetmesh.side_outer_cell_index(args, s.element_index)
255
+ return wp.inverse(args.cell_arg.deformation_gradients[cell_index])
192
256
 
193
257
  @wp.func
194
258
  def side_measure(args: SideArg, s: Sample):
195
- return Tetmesh.side_measure(args, s.element_index, s.element_coords)
259
+ e1, e2 = Tetmesh._side_vecs(args, s.element_index)
260
+ return 0.5 * wp.length(wp.cross(e1, e2))
196
261
 
197
262
  @wp.func
198
263
  def side_measure_ratio(args: SideArg, s: Sample):
199
264
  inner = Tetmesh.side_inner_cell_index(args, s.element_index)
200
265
  outer = Tetmesh.side_outer_cell_index(args, s.element_index)
201
266
  return Tetmesh.side_measure(args, s) / wp.min(
202
- Tetmesh.cell_measure(args, inner, Coords()),
203
- Tetmesh.cell_measure(args, outer, Coords()),
267
+ Tetmesh.cell_measure(args.cell_arg, make_free_sample(inner, Coords())),
268
+ Tetmesh.cell_measure(args.cell_arg, make_free_sample(outer, Coords())),
204
269
  )
205
270
 
206
271
  @wp.func
207
272
  def side_normal(args: SideArg, s: Sample):
208
- face_idx = args.face_vertex_indices[s.element_index]
209
- v0 = args.positions[face_idx[0]]
210
- v1 = args.positions[face_idx[1]]
211
- v2 = args.positions[face_idx[2]]
273
+ e1, e2 = Tetmesh._side_vecs(args, s.element_index)
274
+ return wp.normalize(wp.cross(e1, e2))
275
+
276
+ @wp.func
277
+ def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
278
+ return arg.face_tet_indices[side_index][0]
212
279
 
213
- return wp.normalize(wp.cross(v1 - v0, v2 - v0))
280
+ @wp.func
281
+ def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
282
+ return arg.face_tet_indices[side_index][1]
214
283
 
215
284
  @wp.func
216
285
  def face_to_tet_coords(args: SideArg, side_index: ElementIndex, tet_index: ElementIndex, side_coords: Coords):
217
286
  fvi = args.face_vertex_indices[side_index]
218
287
 
219
- tv1 = args.tet_vertex_indices[tet_index, 1]
220
- tv2 = args.tet_vertex_indices[tet_index, 2]
221
- tv3 = args.tet_vertex_indices[tet_index, 3]
288
+ tv1 = args.cell_arg.tet_vertex_indices[tet_index, 1]
289
+ tv2 = args.cell_arg.tet_vertex_indices[tet_index, 2]
290
+ tv3 = args.cell_arg.tet_vertex_indices[tet_index, 3]
222
291
 
223
292
  c1 = float(0.0)
224
293
  c2 = float(0.0)
@@ -235,12 +304,22 @@ class Tetmesh(Geometry):
235
304
  return Coords(c1, c2, c3)
236
305
 
237
306
  @wp.func
238
- def tet_to_face_coords(args: SideArg, side_index: ElementIndex, tet_index: ElementIndex, tet_coords: Coords):
307
+ def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
308
+ inner_cell_index = Tetmesh.side_inner_cell_index(args, side_index)
309
+ return Tetmesh.face_to_tet_coords(args, side_index, inner_cell_index, side_coords)
310
+
311
+ @wp.func
312
+ def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
313
+ outer_cell_index = Tetmesh.side_outer_cell_index(args, side_index)
314
+ return Tetmesh.face_to_tet_coords(args, side_index, outer_cell_index, side_coords)
315
+
316
+ @wp.func
317
+ def side_from_cell_coords(args: SideArg, side_index: ElementIndex, tet_index: ElementIndex, tet_coords: Coords):
239
318
  fvi = args.face_vertex_indices[side_index]
240
319
 
241
- tv1 = args.tet_vertex_indices[tet_index, 1]
242
- tv2 = args.tet_vertex_indices[tet_index, 2]
243
- tv3 = args.tet_vertex_indices[tet_index, 3]
320
+ tv1 = args.cell_arg.tet_vertex_indices[tet_index, 1]
321
+ tv2 = args.cell_arg.tet_vertex_indices[tet_index, 2]
322
+ tv3 = args.cell_arg.tet_vertex_indices[tet_index, 3]
244
323
 
245
324
  if tv1 == fvi[0]:
246
325
  c0 = tet_coords[0]
@@ -272,38 +351,39 @@ class Tetmesh(Geometry):
272
351
  return wp.select(c0 + c1 + c2 > 0.999, Coords(OUTSIDE), Coords(c0, c1, c2))
273
352
 
274
353
  @wp.func
275
- def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
276
- return arg.face_tet_indices[side_index][0]
277
-
278
- @wp.func
279
- def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
280
- return arg.face_tet_indices[side_index][1]
354
+ def side_to_cell_arg(side_arg: SideArg):
355
+ return side_arg.cell_arg
281
356
 
282
- def _build_topology(self):
283
- from warp.fem.utils import compress_node_indices, masked_indices, _get_pinned_temp_count_buffer
357
+ def _build_topology(self, temporary_store: TemporaryStore):
358
+ from warp.fem.utils import compress_node_indices, masked_indices
284
359
  from warp.utils import array_scan
285
360
 
286
361
  device = self.tet_vertex_indices.device
287
362
 
288
- self._vertex_tet_offsets, self._vertex_tet_indices, _, __ = compress_node_indices(
289
- self.vertex_count(), self.tet_vertex_indices
363
+ vertex_tet_offsets, vertex_tet_indices, _, __ = compress_node_indices(
364
+ self.vertex_count(), self.tet_vertex_indices, temporary_store=temporary_store
290
365
  )
366
+ self._vertex_tet_offsets = vertex_tet_offsets.detach()
367
+ self._vertex_tet_indices = vertex_tet_indices.detach()
291
368
 
292
- vertex_start_face_count = wp.zeros(dtype=int, device=device, shape=self.vertex_count())
293
- vertex_start_face_offsets = wp.empty_like(vertex_start_face_count)
369
+ vertex_start_face_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
370
+ vertex_start_face_count.array.zero_()
371
+ vertex_start_face_offsets = borrow_temporary_like(vertex_start_face_count, temporary_store=temporary_store)
294
372
 
295
- vertex_face_other_vs = wp.empty(dtype=vec2i, device=device, shape=(4 * self.cell_count()))
296
- vertex_face_tets = wp.empty(dtype=int, device=device, shape=(4 * self.cell_count(), 2))
373
+ vertex_face_other_vs = borrow_temporary(
374
+ temporary_store, dtype=wp.vec2i, device=device, shape=(4 * self.cell_count())
375
+ )
376
+ vertex_face_tets = borrow_temporary(temporary_store, dtype=int, device=device, shape=(4 * self.cell_count(), 2))
297
377
 
298
378
  # Count face edges starting at each vertex
299
379
  wp.launch(
300
380
  kernel=Tetmesh._count_starting_faces_kernel,
301
381
  device=device,
302
382
  dim=self.cell_count(),
303
- inputs=[self.tet_vertex_indices, vertex_start_face_count],
383
+ inputs=[self.tet_vertex_indices, vertex_start_face_count.array],
304
384
  )
305
385
 
306
- array_scan(in_array=vertex_start_face_count, out_array=vertex_start_face_offsets, inclusive=False)
386
+ array_scan(in_array=vertex_start_face_count.array, out_array=vertex_start_face_offsets.array, inclusive=False)
307
387
 
308
388
  # Count number of unique edges (deduplicate across faces)
309
389
  vertex_unique_face_count = vertex_start_face_count
@@ -315,30 +395,32 @@ class Tetmesh(Geometry):
315
395
  self._vertex_tet_offsets,
316
396
  self._vertex_tet_indices,
317
397
  self.tet_vertex_indices,
318
- vertex_start_face_offsets,
319
- vertex_unique_face_count,
320
- vertex_face_other_vs,
321
- vertex_face_tets,
398
+ vertex_start_face_offsets.array,
399
+ vertex_unique_face_count.array,
400
+ vertex_face_other_vs.array,
401
+ vertex_face_tets.array,
322
402
  ],
323
403
  )
324
404
 
325
- vertex_unique_face_offsets = wp.empty_like(vertex_start_face_offsets)
326
- array_scan(in_array=vertex_start_face_count, out_array=vertex_unique_face_offsets, inclusive=False)
405
+ vertex_unique_face_offsets = borrow_temporary_like(vertex_start_face_offsets, temporary_store=temporary_store)
406
+ array_scan(in_array=vertex_start_face_count.array, out_array=vertex_unique_face_offsets.array, inclusive=False)
327
407
 
328
408
  # Get back edge count to host
329
409
  if device.is_cuda:
330
- face_count = _get_pinned_temp_count_buffer(device)
410
+ face_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
331
411
  # Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok
332
- wp.copy(dest=face_count, src=vertex_unique_face_offsets, src_offset=self.vertex_count() - 1, count=1)
333
- wp.synchronize_stream(wp.get_stream())
334
- face_count = int(face_count.numpy()[0])
412
+ wp.copy(
413
+ dest=face_count.array, src=vertex_unique_face_offsets.array, src_offset=self.vertex_count() - 1, count=1
414
+ )
415
+ wp.synchronize_stream(wp.get_stream(device))
416
+ face_count = int(face_count.array.numpy()[0])
335
417
  else:
336
- face_count = int(vertex_unique_face_offsets.numpy()[self.vertex_count() - 1])
418
+ face_count = int(vertex_unique_face_offsets.array.numpy()[self.vertex_count() - 1])
337
419
 
338
- self._face_vertex_indices = wp.empty(shape=(face_count,), dtype=vec3i, device=device)
339
- self._face_tet_indices = wp.empty(shape=(face_count,), dtype=vec2i, device=device)
420
+ self._face_vertex_indices = wp.empty(shape=(face_count,), dtype=wp.vec3i, device=device)
421
+ self._face_tet_indices = wp.empty(shape=(face_count,), dtype=wp.vec2i, device=device)
340
422
 
341
- boundary_mask = wp.empty(shape=(face_count,), dtype=int, device=device)
423
+ boundary_mask = borrow_temporary(temporary_store, shape=(face_count,), dtype=int, device=device)
342
424
 
343
425
  # Compress edge data
344
426
  wp.launch(
@@ -346,17 +428,23 @@ class Tetmesh(Geometry):
346
428
  device=device,
347
429
  dim=self.vertex_count(),
348
430
  inputs=[
349
- vertex_start_face_offsets,
350
- vertex_unique_face_offsets,
351
- vertex_unique_face_count,
352
- vertex_face_other_vs,
353
- vertex_face_tets,
431
+ vertex_start_face_offsets.array,
432
+ vertex_unique_face_offsets.array,
433
+ vertex_unique_face_count.array,
434
+ vertex_face_other_vs.array,
435
+ vertex_face_tets.array,
354
436
  self._face_vertex_indices,
355
437
  self._face_tet_indices,
356
- boundary_mask,
438
+ boundary_mask.array,
357
439
  ],
358
440
  )
359
441
 
442
+ vertex_start_face_offsets.release()
443
+ vertex_unique_face_offsets.release()
444
+ vertex_unique_face_count.release()
445
+ vertex_face_other_vs.release()
446
+ vertex_face_tets.release()
447
+
360
448
  # Flip normals if necessary
361
449
  wp.launch(
362
450
  kernel=Tetmesh._flip_face_normals,
@@ -365,7 +453,101 @@ class Tetmesh(Geometry):
365
453
  inputs=[self._face_vertex_indices, self._face_tet_indices, self.tet_vertex_indices, self.positions],
366
454
  )
367
455
 
368
- self._boundary_face_indices, _ = masked_indices(boundary_mask)
456
+ boundary_face_indices, _ = masked_indices(boundary_mask.array)
457
+ self._boundary_face_indices = boundary_face_indices.detach()
458
+
459
def _compute_tet_edges(self, temporary_store: Optional[TemporaryStore] = None):
    """Build the unique edge list of the tet mesh and the per-tet edge indices.

    Every edge is attributed to its lower-index vertex. Per-vertex edge slots are
    first counted with duplicates (one slot per tet edge). Then they are
    deduplicated across the tets sharing each edge, and finally compacted into
    a contiguous global numbering.

    Sets ``self._edge_count`` (number of unique edges) and
    ``self._tet_edge_indices`` (int array of shape ``(cell_count, 6)``).

    Args:
        temporary_store: Optional pool from which scratch buffers are borrowed.
    """
    from warp.utils import array_scan

    device = self.tet_vertex_indices.device
    vertex_count = self.vertex_count()
    cell_count = self.cell_count()

    vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=vertex_count)
    vertex_start_edge_count.array.zero_()
    vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)

    # Each tet contributes 6 edge slots; duplicates are compacted per vertex below
    vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(6 * cell_count))

    # Count (non-unique) tet edges starting at each vertex
    wp.launch(
        kernel=Tetmesh._count_starting_edges_kernel,
        device=device,
        dim=cell_count,
        inputs=[self.tet_vertex_indices, vertex_start_edge_count.array],
    )

    array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)

    # Count number of unique edges per vertex (deduplicate across tets sharing an edge).
    # The count buffer is reused in place: after this launch it holds unique counts.
    vertex_unique_edge_count = vertex_start_edge_count
    wp.launch(
        kernel=Tetmesh._count_unique_starting_edges_kernel,
        device=device,
        dim=vertex_count,
        inputs=[
            self._vertex_tet_offsets,
            self._vertex_tet_indices,
            self.tet_vertex_indices,
            vertex_start_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
        ],
    )

    vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
    array_scan(in_array=vertex_unique_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)

    # Get back edge count to host
    if device.is_cuda:
        edge_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
        # Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok
        wp.copy(
            dest=edge_count.array,
            src=vertex_unique_edge_offsets.array,
            src_offset=vertex_count - 1,
            count=1,
        )
        wp.synchronize_stream(wp.get_stream(device))
        # int() copies the value out, so the pinned buffer can be returned right away
        self._edge_count = int(edge_count.array.numpy()[0])
        edge_count.release()
    else:
        self._edge_count = int(vertex_unique_edge_offsets.array.numpy()[vertex_count - 1])

    self._tet_edge_indices = wp.empty(dtype=int, device=device, shape=(cell_count, 6))

    # Compress edge data
    wp.launch(
        kernel=Tetmesh._compress_edges_kernel,
        device=device,
        dim=vertex_count,
        inputs=[
            self._vertex_tet_offsets,
            self._vertex_tet_indices,
            self.tet_vertex_indices,
            vertex_start_edge_offsets.array,
            vertex_unique_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
            self._tet_edge_indices,
        ],
    )

    vertex_start_edge_offsets.release()
    vertex_unique_edge_offsets.release()
    # Also releases vertex_start_edge_count (same Temporary object)
    vertex_unique_edge_count.release()
    vertex_edge_ends.release()
541
+
542
def _compute_deformation_gradients(self):
    """Compute and cache one 3x3 matrix per tetrahedron.

    The result is stored in ``self._deformation_gradients``, allocated on the
    device that holds ``self.positions``. The matrices are filled by
    ``Tetmesh._compute_deformation_gradients_kernel`` from the three edge
    vectors leaving each tet's first vertex.
    """
    cell_total = self.cell_count()
    target_device = self.positions.device

    self._deformation_gradients = wp.empty(shape=(cell_total,), dtype=wp.mat33f, device=target_device)
    gradients = self._deformation_gradients

    wp.launch(
        kernel=Tetmesh._compute_deformation_gradients_kernel,
        dim=gradients.shape,
        device=gradients.device,
        inputs=[self.tet_vertex_indices, self.positions, gradients],
    )
369
551
 
370
552
  @wp.kernel
371
553
  def _count_starting_faces_kernel(
@@ -373,7 +555,9 @@ class Tetmesh(Geometry):
373
555
  ):
374
556
  t = wp.tid()
375
557
  for k in range(4):
376
- vi = vec3i(tet_vertex_indices[t, k], tet_vertex_indices[t, (k + 1) % 4], tet_vertex_indices[t, (k + 2) % 4])
558
+ vi = wp.vec3i(
559
+ tet_vertex_indices[t, k], tet_vertex_indices[t, (k + 1) % 4], tet_vertex_indices[t, (k + 2) % 4]
560
+ )
377
561
  vm = wp.min(vi)
378
562
 
379
563
  for i in range(3):
@@ -381,9 +565,9 @@ class Tetmesh(Geometry):
381
565
  wp.atomic_add(vertex_start_face_count, vm, 1)
382
566
 
383
567
  @wp.func
384
- def _find(
385
- needle: vec2i,
386
- values: wp.array(dtype=vec2i),
568
+ def _find_face(
569
+ needle: wp.vec2i,
570
+ values: wp.array(dtype=wp.vec2i),
387
571
  beg: int,
388
572
  end: int,
389
573
  ):
@@ -400,7 +584,7 @@ class Tetmesh(Geometry):
400
584
  tet_vertex_indices: wp.array2d(dtype=int),
401
585
  vertex_start_face_offsets: wp.array(dtype=int),
402
586
  vertex_start_face_count: wp.array(dtype=int),
403
- face_other_vs: wp.array(dtype=vec2i),
587
+ face_other_vs: wp.array(dtype=wp.vec2i),
404
588
  face_tets: wp.array2d(dtype=int),
405
589
  ):
406
590
  v = wp.tid()
@@ -416,7 +600,7 @@ class Tetmesh(Geometry):
416
600
  t = vertex_tet_indices[tet]
417
601
 
418
602
  for k in range(4):
419
- vi = vec3i(
603
+ vi = wp.vec3i(
420
604
  tet_vertex_indices[t, k], tet_vertex_indices[t, (k + 1) % 4], tet_vertex_indices[t, (k + 2) % 4]
421
605
  )
422
606
  min_v = wp.min(vi)
@@ -424,10 +608,10 @@ class Tetmesh(Geometry):
424
608
  if v == min_v:
425
609
  max_v = wp.max(vi)
426
610
  mid_v = vi[0] + vi[1] + vi[2] - min_v - max_v
427
- other_v = vec2i(mid_v, max_v)
611
+ other_v = wp.vec2i(mid_v, max_v)
428
612
 
429
613
  # Check if other_v has been seen
430
- seen_idx = Tetmesh._find(other_v, face_other_vs, face_beg, face_cur)
614
+ seen_idx = Tetmesh._find_face(other_v, face_other_vs, face_beg, face_cur)
431
615
 
432
616
  if seen_idx == -1:
433
617
  face_other_vs[face_cur] = other_v
@@ -444,10 +628,10 @@ class Tetmesh(Geometry):
444
628
  vertex_start_face_offsets: wp.array(dtype=int),
445
629
  vertex_unique_face_offsets: wp.array(dtype=int),
446
630
  vertex_unique_face_count: wp.array(dtype=int),
447
- uncompressed_face_other_vs: wp.array(dtype=vec2i),
631
+ uncompressed_face_other_vs: wp.array(dtype=wp.vec2i),
448
632
  uncompressed_face_tets: wp.array2d(dtype=int),
449
- face_vertex_indices: wp.array(dtype=vec3i),
450
- face_tet_indices: wp.array(dtype=vec2i),
633
+ face_vertex_indices: wp.array(dtype=wp.vec3i),
634
+ face_tet_indices: wp.array(dtype=wp.vec2i),
451
635
  boundary_mask: wp.array(dtype=int),
452
636
  ):
453
637
  v = wp.tid()
@@ -460,7 +644,7 @@ class Tetmesh(Geometry):
460
644
  src_index = start_beg + f
461
645
  face_index = unique_beg + f
462
646
 
463
- face_vertex_indices[face_index] = vec3i(
647
+ face_vertex_indices[face_index] = wp.vec3i(
464
648
  v,
465
649
  uncompressed_face_other_vs[src_index][0],
466
650
  uncompressed_face_other_vs[src_index][1],
@@ -468,7 +652,7 @@ class Tetmesh(Geometry):
468
652
 
469
653
  t0 = uncompressed_face_tets[src_index, 0]
470
654
  t1 = uncompressed_face_tets[src_index, 1]
471
- face_tet_indices[face_index] = vec2i(t0, t1)
655
+ face_tet_indices[face_index] = wp.vec2i(t0, t1)
472
656
  if t0 == t1:
473
657
  boundary_mask[face_index] = 1
474
658
  else:
@@ -476,8 +660,8 @@ class Tetmesh(Geometry):
476
660
 
477
661
  @wp.kernel
478
662
  def _flip_face_normals(
479
- face_vertex_indices: wp.array(dtype=vec3i),
480
- face_tet_indices: wp.array(dtype=vec2i),
663
+ face_vertex_indices: wp.array(dtype=wp.vec3i),
664
+ face_tet_indices: wp.array(dtype=wp.vec2i),
481
665
  tet_vertex_indices: wp.array2d(dtype=int),
482
666
  positions: wp.array(dtype=wp.vec3),
483
667
  ):
@@ -501,4 +685,156 @@ class Tetmesh(Geometry):
501
685
 
502
686
  # if face normal points toward first tet centroid, flip indices
503
687
  if wp.dot(tet_centroid - face_center, face_normal) > 0.0:
504
- face_vertex_indices[e] = vec3i(face_vidx[0], face_vidx[2], face_vidx[1])
688
+ face_vertex_indices[e] = wp.vec3i(face_vidx[0], face_vidx[2], face_vidx[1])
689
+
690
@wp.kernel
def _count_starting_edges_kernel(
    tet_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int)
):
    # One thread per tet. For each of the tet's 6 edges, increment the counter of
    # the edge's lower-index vertex (the vertex that "owns" the edge). Counts are
    # not deduplicated: an edge shared by several tets is counted once per tet.
    t = wp.tid()

    # Edges of the (0, 1, 2) triangle: (0,1), (1,2), (2,0)
    for k in range(3):
        v0 = tet_vertex_indices[t, k]
        v1 = tet_vertex_indices[t, (k + 1) % 3]
        wp.atomic_add(vertex_start_edge_count, wp.min(v0, v1), 1)

    # Edges joining vertex 3 to the others: (0,3), (1,3), (2,3)
    for k in range(3):
        v0 = tet_vertex_indices[t, k]
        v1 = tet_vertex_indices[t, 3]
        wp.atomic_add(vertex_start_edge_count, wp.min(v0, v1), 1)
712
+
713
@wp.func
def _find_edge(
    needle: int,
    values: wp.array(dtype=int),
    beg: int,
    end: int,
):
    # Linear search of values[beg:end] for needle.
    # Returns the index (into values) of the first match, or -1 if absent.
    for idx in range(beg, end):
        if needle == values[idx]:
            return idx

    return -1
725
+
726
@wp.kernel
def _count_unique_starting_edges_kernel(
    vertex_tet_offsets: wp.array(dtype=int),
    vertex_tet_indices: wp.array(dtype=int),
    tet_vertex_indices: wp.array2d(dtype=int),
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_start_edge_count: wp.array(dtype=int),
    edge_ends: wp.array(dtype=int),
):
    # One thread per vertex v. Walk the tets incident to v (CSR ranges in
    # vertex_tet_offsets / vertex_tet_indices). For every tet edge whose lower
    # vertex is v, append its other endpoint to v's segment of edge_ends
    # (starting at vertex_start_edge_offsets[v]) unless already present.
    # The resulting unique count overwrites vertex_start_edge_count[v].
    v = wp.tid()

    seg_beg = vertex_start_edge_offsets[v]
    write_pos = seg_beg

    incident_beg = vertex_tet_offsets[v]
    incident_end = vertex_tet_offsets[v + 1]

    for incident in range(incident_beg, incident_end):
        tet = vertex_tet_indices[incident]

        # Edges of the (0, 1, 2) triangle: (0,1), (1,2), (2,0)
        for k in range(3):
            a = tet_vertex_indices[tet, k]
            b = tet_vertex_indices[tet, (k + 1) % 3]

            if wp.min(a, b) == v:
                end_v = wp.max(a, b)
                if Tetmesh._find_edge(end_v, edge_ends, seg_beg, write_pos) == -1:
                    edge_ends[write_pos] = end_v
                    write_pos += 1

        # Edges joining vertex 3 to the others: (0,3), (1,3), (2,3)
        for k in range(3):
            a = tet_vertex_indices[tet, k]
            b = tet_vertex_indices[tet, 3]

            if wp.min(a, b) == v:
                end_v = wp.max(a, b)
                if Tetmesh._find_edge(end_v, edge_ends, seg_beg, write_pos) == -1:
                    edge_ends[write_pos] = end_v
                    write_pos += 1

    vertex_start_edge_count[v] = write_pos - seg_beg
768
+
769
@wp.kernel
def _compress_edges_kernel(
    vertex_tet_offsets: wp.array(dtype=int),
    vertex_tet_indices: wp.array(dtype=int),
    tet_vertex_indices: wp.array2d(dtype=int),
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_count: wp.array(dtype=int),
    uncompressed_edge_ends: wp.array(dtype=int),
    tet_edge_indices: wp.array2d(dtype=int),
):
    # One thread per vertex v. For every tet incident to v and every tet edge
    # whose lower vertex is v, locate the edge's other endpoint among the
    # unique entries of v's uncompressed segment. Then translate that position
    # into the compacted global edge numbering starting at
    # vertex_unique_edge_offsets[v], and store it in tet_edge_indices.
    # Local edge slots: 0..2 -> (0,1), (1,2), (2,0); 3..5 -> (0,3), (1,3), (2,3).
    v = wp.tid()

    uncompressed_beg = vertex_start_edge_offsets[v]

    unique_beg = vertex_unique_edge_offsets[v]
    unique_count = vertex_unique_edge_count[v]
    # Unique endpoints occupy the first unique_count slots of v's segment
    uncompressed_end = uncompressed_beg + unique_count

    tet_beg = vertex_tet_offsets[v]
    tet_end = vertex_tet_offsets[v + 1]

    for tet in range(tet_beg, tet_end):
        t = vertex_tet_indices[tet]

        for k in range(3):
            v0 = tet_vertex_indices[t, k]
            v1 = tet_vertex_indices[t, (k + 1) % 3]

            if v == wp.min(v0, v1):
                other_v = wp.max(v0, v1)
                edge_id = (
                    Tetmesh._find_edge(other_v, uncompressed_edge_ends, uncompressed_beg, uncompressed_end)
                    - uncompressed_beg
                    + unique_beg
                )
                # Direct 2-D store, consistent with the indexing used everywhere else
                tet_edge_indices[t, k] = edge_id

        for k in range(3):
            v0 = tet_vertex_indices[t, k]
            v1 = tet_vertex_indices[t, 3]

            if v == wp.min(v0, v1):
                other_v = wp.max(v0, v1)
                edge_id = (
                    Tetmesh._find_edge(other_v, uncompressed_edge_ends, uncompressed_beg, uncompressed_end)
                    - uncompressed_beg
                    + unique_beg
                )
                tet_edge_indices[t, k + 3] = edge_id
822
+
823
@wp.kernel
def _compute_deformation_gradients_kernel(
    tet_vertex_indices: wp.array2d(dtype=int),
    positions: wp.array(dtype=wp.vec3f),
    transforms: wp.array(dtype=wp.mat33f),
):
    # One thread per tet: store the matrix built by wp.mat33 from the three
    # edge vectors leaving the tet's first vertex (p1 - p0, p2 - p0, p3 - p0).
    tet = wp.tid()

    origin = positions[tet_vertex_indices[tet, 0]]
    edge_a = positions[tet_vertex_indices[tet, 1]] - origin
    edge_b = positions[tet_vertex_indices[tet, 2]] - origin
    edge_c = positions[tet_vertex_indices[tet, 3]] - origin

    transforms[tet] = wp.mat33(edge_a, edge_b, edge_c)