warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff shows the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang may be problematic; see the package's registry page for more details.

Files changed (271):
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -1,38 +1,61 @@
1
- import warp as wp
1
+ from typing import Optional
2
2
 
3
- from warp.fem.types import ElementIndex, Coords, vec2i, Sample
4
- from warp.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, NULL_DOF_INDEX, NULL_QP_INDEX
3
+ import warp as wp
4
+ from warp.fem.cache import (
5
+ TemporaryStore,
6
+ borrow_temporary,
7
+ borrow_temporary_like,
8
+ cached_arg_value,
9
+ )
10
+ from warp.fem.types import (
11
+ NULL_ELEMENT_INDEX,
12
+ OUTSIDE,
13
+ Coords,
14
+ ElementIndex,
15
+ Sample,
16
+ make_free_sample,
17
+ )
5
18
 
6
- from .geometry import Geometry
7
- from .element import Triangle, LinearEdge
8
19
  from .closest_point import project_on_tri_at_origin
20
+ from .element import LinearEdge, Triangle
21
+ from .geometry import Geometry
9
22
 
10
23
 
11
24
  @wp.struct
12
- class Trimesh2DArg:
25
+ class Trimesh2DCellArg:
13
26
  tri_vertex_indices: wp.array2d(dtype=int)
14
27
  positions: wp.array(dtype=wp.vec2)
15
28
 
29
+ # for neighbor cell lookup
16
30
  vertex_tri_offsets: wp.array(dtype=int)
17
31
  vertex_tri_indices: wp.array(dtype=int)
18
32
 
19
- edge_vertex_indices: wp.array(dtype=vec2i)
20
- edge_tri_indices: wp.array(dtype=vec2i)
33
+ deformation_gradients: wp.array(dtype=wp.mat22f)
34
+
35
+
36
+ @wp.struct
37
+ class Trimesh2DSideArg:
38
+ cell_arg: Trimesh2DCellArg
39
+ edge_vertex_indices: wp.array(dtype=wp.vec2i)
40
+ edge_tri_indices: wp.array(dtype=wp.vec2i)
21
41
 
22
42
 
23
43
  class Trimesh2D(Geometry):
24
44
  """Two-dimensional triangular mesh geometry"""
25
45
 
26
- def __init__(self, tri_vertex_indices: wp.array, positions: wp.array):
46
+ dimension = 2
47
+
48
+ def __init__(
49
+ self, tri_vertex_indices: wp.array, positions: wp.array, temporary_store: Optional[TemporaryStore] = None
50
+ ):
27
51
  """
28
52
  Constructs a two-dimensional triangular mesh.
29
53
 
30
54
  Args:
31
55
  tri_vertex_indices: warp array of shape (num_tris, 3) containing vertex indices for each tri
32
56
  positions: warp array of shape (num_vertices, 2) containing 2d position for each vertex
33
-
57
+ temporary_store: shared pool from which to allocate temporary arrays
34
58
  """
35
- self.dimension = 2
36
59
 
37
60
  self.tri_vertex_indices = tri_vertex_indices
38
61
  self.positions = positions
@@ -41,8 +64,10 @@ class Trimesh2D(Geometry):
41
64
  self._edge_tri_indices: wp.array = None
42
65
  self._vertex_tri_offsets: wp.array = None
43
66
  self._vertex_tri_indices: wp.array = None
67
+ self._build_topology(temporary_store)
44
68
 
45
- self._build_topology()
69
+ self._deformation_gradients: wp.array = None
70
+ self._compute_deformation_gradients()
46
71
 
47
72
  def cell_count(self):
48
73
  return self.tri_vertex_indices.shape[0]
@@ -62,8 +87,16 @@ class Trimesh2D(Geometry):
62
87
  def reference_side(self) -> LinearEdge:
63
88
  return LinearEdge()
64
89
 
65
- CellArg = Trimesh2DArg
66
- SideArg = Trimesh2DArg
90
+ @property
91
+ def edge_tri_indices(self) -> wp.array:
92
+ return self._edge_tri_indices
93
+
94
+ @property
95
+ def edge_vertex_indices(self) -> wp.array:
96
+ return self._edge_vertex_indices
97
+
98
+ CellArg = Trimesh2DCellArg
99
+ SideArg = Trimesh2DSideArg
67
100
 
68
101
  @wp.struct
69
102
  class SideIndexArg:
@@ -71,15 +104,15 @@ class Trimesh2D(Geometry):
71
104
 
72
105
  # Geometry device interface
73
106
 
107
+ @cached_arg_value
74
108
  def cell_arg_value(self, device) -> CellArg:
75
109
  args = self.CellArg()
76
110
 
77
111
  args.tri_vertex_indices = self.tri_vertex_indices.to(device)
78
112
  args.positions = self.positions.to(device)
79
- args.edge_vertex_indices = self._edge_vertex_indices.to(device)
80
- args.edge_tri_indices = self._edge_tri_indices.to(device)
81
113
  args.vertex_tri_offsets = self._vertex_tri_offsets.to(device)
82
114
  args.vertex_tri_indices = self._vertex_tri_indices.to(device)
115
+ args.deformation_gradients = self._deformation_gradients.to(device)
83
116
 
84
117
  return args
85
118
 
@@ -92,6 +125,14 @@ class Trimesh2D(Geometry):
92
125
  + s.element_coords[2] * args.positions[tri_idx[2]]
93
126
  )
94
127
 
128
+ @wp.func
129
+ def cell_deformation_gradient(args: CellArg, s: Sample):
130
+ return args.deformation_gradients[s.element_index]
131
+
132
+ @wp.func
133
+ def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
134
+ return wp.inverse(args.deformation_gradients[s.element_index])
135
+
95
136
  @wp.func
96
137
  def _project_on_tri(args: CellArg, pos: wp.vec2, tri_index: int):
97
138
  p0 = args.positions[args.tri_vertex_indices[tri_index, 0]]
@@ -122,36 +163,17 @@ class Trimesh2D(Geometry):
122
163
  closest_tri = tri
123
164
  closest_coords = coords
124
165
 
125
- return Sample(closest_tri, closest_coords, NULL_QP_INDEX, 0.0, NULL_DOF_INDEX, NULL_DOF_INDEX)
126
-
127
- @wp.func
128
- def cell_measure(args: CellArg, cell_index: ElementIndex, coords: Coords):
129
- tri_idx = args.tri_vertex_indices[cell_index]
130
-
131
- v0 = args.positions[tri_idx[0]]
132
- v1 = args.positions[tri_idx[1]]
133
- v2 = args.positions[tri_idx[2]]
134
-
135
- e1 = v1 - v0
136
- e2 = v2 - v0
137
-
138
- return 0.5 * wp.abs(e1[0] * e2[1] - e1[1] * e2[0])
166
+ return make_free_sample(closest_tri, closest_coords)
139
167
 
140
168
  @wp.func
141
169
  def cell_measure(args: CellArg, s: Sample):
142
- return Trimesh2D.cell_measure(args, s.element_index, s.element_coords)
143
-
144
- @wp.func
145
- def cell_measure_ratio(args: CellArg, s: Sample):
146
- return 1.0
170
+ return 0.5 * wp.abs(wp.determinant(args.deformation_gradients[s.element_index]))
147
171
 
148
172
  @wp.func
149
173
  def cell_normal(args: CellArg, s: Sample):
150
174
  return wp.vec2(0.0)
151
175
 
152
- def side_arg_value(self, device) -> SideArg:
153
- return self.cell_arg_value(device)
154
-
176
+ @cached_arg_value
155
177
  def side_index_arg_value(self, device) -> SideIndexArg:
156
178
  args = self.SideIndexArg()
157
179
 
@@ -165,38 +187,61 @@ class Trimesh2D(Geometry):
165
187
 
166
188
  return args.boundary_edge_indices[boundary_side_index]
167
189
 
190
+ @cached_arg_value
191
+ def side_arg_value(self, device) -> CellArg:
192
+ args = self.SideArg()
193
+
194
+ args.cell_arg = self.cell_arg_value(device)
195
+ args.edge_vertex_indices = self._edge_vertex_indices.to(device)
196
+ args.edge_tri_indices = self._edge_tri_indices.to(device)
197
+
198
+ return args
199
+
168
200
  @wp.func
169
201
  def side_position(args: SideArg, s: Sample):
170
202
  edge_idx = args.edge_vertex_indices[s.element_index]
171
- return (1.0 - s.element_coords[0]) * args.positions[edge_idx[0]] + s.element_coords[0] * args.positions[
172
- edge_idx[1]
173
- ]
203
+ return (1.0 - s.element_coords[0]) * args.cell_arg.positions[edge_idx[0]] + s.element_coords[
204
+ 0
205
+ ] * args.cell_arg.positions[edge_idx[1]]
174
206
 
175
207
  @wp.func
176
- def side_measure(args: SideArg, side_index: ElementIndex, coords: Coords):
177
- edge_idx = args.edge_vertex_indices[side_index]
178
- v0 = args.positions[edge_idx[0]]
179
- v1 = args.positions[edge_idx[1]]
180
- return wp.length(v1 - v0)
208
+ def side_deformation_gradient(args: SideArg, s: Sample):
209
+ edge_idx = args.edge_vertex_indices[s.element_index]
210
+ v0 = args.cell_arg.positions[edge_idx[0]]
211
+ v1 = args.cell_arg.positions[edge_idx[1]]
212
+ return v1 - v0
213
+
214
+ @wp.func
215
+ def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
216
+ cell_index = Trimesh2D.side_inner_cell_index(args, s.element_index)
217
+ return wp.inverse(args.cell_arg.deformation_gradients[cell_index])
218
+
219
+ @wp.func
220
+ def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample):
221
+ cell_index = Trimesh2D.side_outer_cell_index(args, s.element_index)
222
+ return wp.inverse(args.cell_arg.deformation_gradients[cell_index])
181
223
 
182
224
  @wp.func
183
225
  def side_measure(args: SideArg, s: Sample):
184
- return Trimesh2D.side_measure(args, s.element_index, s.element_coords)
226
+ edge_idx = args.edge_vertex_indices[s.element_index]
227
+ v0 = args.cell_arg.positions[edge_idx[0]]
228
+ v1 = args.cell_arg.positions[edge_idx[1]]
229
+ return wp.length(v1 - v0)
185
230
 
186
231
  @wp.func
187
232
  def side_measure_ratio(args: SideArg, s: Sample):
188
233
  inner = Trimesh2D.side_inner_cell_index(args, s.element_index)
189
234
  outer = Trimesh2D.side_outer_cell_index(args, s.element_index)
190
235
  return Trimesh2D.side_measure(args, s) / wp.min(
191
- Trimesh2D.cell_measure(args, inner, Coords()),
192
- Trimesh2D.cell_measure(args, outer, Coords()),
236
+ Trimesh2D.cell_measure(args.cell_arg, make_free_sample(inner, Coords())),
237
+ Trimesh2D.cell_measure(args.cell_arg, make_free_sample(outer, Coords())),
193
238
  )
194
239
 
195
240
  @wp.func
196
241
  def side_normal(args: SideArg, s: Sample):
197
242
  edge_idx = args.edge_vertex_indices[s.element_index]
198
- v0 = args.positions[edge_idx[0]]
199
- v1 = args.positions[edge_idx[1]]
243
+ v0 = args.cell_arg.positions[edge_idx[0]]
244
+ v1 = args.cell_arg.positions[edge_idx[1]]
200
245
  e = v1 - v0
201
246
 
202
247
  return wp.normalize(wp.vec2(-e[1], e[0]))
@@ -212,7 +257,7 @@ class Trimesh2D(Geometry):
212
257
  @wp.func
213
258
  def edge_to_tri_coords(args: SideArg, side_index: ElementIndex, tri_index: ElementIndex, side_coords: Coords):
214
259
  edge_vidx = args.edge_vertex_indices[side_index]
215
- tri_vidx = args.tri_vertex_indices[tri_index]
260
+ tri_vidx = args.cell_arg.tri_vertex_indices[tri_index]
216
261
 
217
262
  v0 = tri_vidx[0]
218
263
  v1 = tri_vidx[1]
@@ -238,9 +283,24 @@ class Trimesh2D(Geometry):
238
283
  return Coords(cx, cy, cz)
239
284
 
240
285
  @wp.func
241
- def tri_to_edge_coords(args: SideArg, side_index: ElementIndex, tri_index: ElementIndex, tri_coords: Coords):
286
+ def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
287
+ inner_cell_index = Trimesh2D.side_inner_cell_index(args, side_index)
288
+ return Trimesh2D.edge_to_tri_coords(args, side_index, inner_cell_index, side_coords)
289
+
290
+ @wp.func
291
+ def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
292
+ outer_cell_index = Trimesh2D.side_outer_cell_index(args, side_index)
293
+ return Trimesh2D.edge_to_tri_coords(args, side_index, outer_cell_index, side_coords)
294
+
295
+ @wp.func
296
+ def side_from_cell_coords(
297
+ args: SideArg,
298
+ side_index: ElementIndex,
299
+ tri_index: ElementIndex,
300
+ tri_coords: Coords,
301
+ ):
242
302
  edge_vidx = args.edge_vertex_indices[side_index]
243
- tri_vidx = args.tri_vertex_indices[tri_index]
303
+ tri_vidx = args.cell_arg.tri_vertex_indices[tri_index]
244
304
 
245
305
  start = int(2)
246
306
  end = int(2)
@@ -256,31 +316,38 @@ class Trimesh2D(Geometry):
256
316
  tri_coords[start] + tri_coords[end] > 0.999, Coords(OUTSIDE), Coords(tri_coords[end], 0.0, 0.0)
257
317
  )
258
318
 
259
- def _build_topology(self):
260
- from warp.fem.utils import compress_node_indices, masked_indices, _get_pinned_temp_count_buffer
319
+ @wp.func
320
+ def side_to_cell_arg(side_arg: SideArg):
321
+ return side_arg.cell_arg
322
+
323
+ def _build_topology(self, temporary_store: TemporaryStore):
324
+ from warp.fem.utils import compress_node_indices, masked_indices
261
325
  from warp.utils import array_scan
262
326
 
263
327
  device = self.tri_vertex_indices.device
264
328
 
265
- self._vertex_tri_offsets, self._vertex_tri_indices, _, __ = compress_node_indices(
266
- self.vertex_count(), self.tri_vertex_indices
329
+ vertex_tri_offsets, vertex_tri_indices, _, __ = compress_node_indices(
330
+ self.vertex_count(), self.tri_vertex_indices, temporary_store=temporary_store
267
331
  )
332
+ self._vertex_tri_offsets = vertex_tri_offsets.detach()
333
+ self._vertex_tri_indices = vertex_tri_indices.detach()
268
334
 
269
- vertex_start_edge_count = wp.zeros(dtype=int, device=device, shape=self.vertex_count())
270
- vertex_start_edge_offsets = wp.empty_like(vertex_start_edge_count)
335
+ vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
336
+ vertex_start_edge_count.array.zero_()
337
+ vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)
271
338
 
272
- vertex_edge_ends = wp.empty(dtype=int, device=device, shape=(3 * self.cell_count()))
273
- vertex_edge_tris = wp.empty(dtype=int, device=device, shape=(3 * self.cell_count(), 2))
339
+ vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(3 * self.cell_count()))
340
+ vertex_edge_tris = borrow_temporary(temporary_store, dtype=int, device=device, shape=(3 * self.cell_count(), 2))
274
341
 
275
342
  # Count face edges starting at each vertex
276
343
  wp.launch(
277
344
  kernel=Trimesh2D._count_starting_edges_kernel,
278
345
  device=device,
279
346
  dim=self.cell_count(),
280
- inputs=[self.tri_vertex_indices, vertex_start_edge_count],
347
+ inputs=[self.tri_vertex_indices, vertex_start_edge_count.array],
281
348
  )
282
349
 
283
- array_scan(in_array=vertex_start_edge_count, out_array=vertex_start_edge_offsets, inclusive=False)
350
+ array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)
284
351
 
285
352
  # Count number of unique edges (deduplicate across faces)
286
353
  vertex_unique_edge_count = vertex_start_edge_count
@@ -292,30 +359,32 @@ class Trimesh2D(Geometry):
292
359
  self._vertex_tri_offsets,
293
360
  self._vertex_tri_indices,
294
361
  self.tri_vertex_indices,
295
- vertex_start_edge_offsets,
296
- vertex_unique_edge_count,
297
- vertex_edge_ends,
298
- vertex_edge_tris,
362
+ vertex_start_edge_offsets.array,
363
+ vertex_unique_edge_count.array,
364
+ vertex_edge_ends.array,
365
+ vertex_edge_tris.array,
299
366
  ],
300
367
  )
301
368
 
302
- vertex_unique_edge_offsets = wp.empty_like(vertex_start_edge_offsets)
303
- array_scan(in_array=vertex_start_edge_count, out_array=vertex_unique_edge_offsets, inclusive=False)
369
+ vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
370
+ array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)
304
371
 
305
372
  # Get back edge count to host
306
373
  if device.is_cuda:
307
- edge_count = _get_pinned_temp_count_buffer(device)
374
+ edge_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
308
375
  # Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok
309
- wp.copy(dest=edge_count, src=vertex_unique_edge_offsets, src_offset=self.vertex_count() - 1, count=1)
310
- wp.synchronize_stream(wp.get_stream())
311
- edge_count = int(edge_count.numpy()[0])
376
+ wp.copy(
377
+ dest=edge_count.array, src=vertex_unique_edge_offsets.array, src_offset=self.vertex_count() - 1, count=1
378
+ )
379
+ wp.synchronize_stream(wp.get_stream(device))
380
+ edge_count = int(edge_count.array.numpy()[0])
312
381
  else:
313
- edge_count = int(vertex_unique_edge_offsets.numpy()[self.vertex_count() - 1])
382
+ edge_count = int(vertex_unique_edge_offsets.array.numpy()[self.vertex_count() - 1])
314
383
 
315
- self._edge_vertex_indices = wp.empty(shape=(edge_count,), dtype=vec2i, device=device)
316
- self._edge_tri_indices = wp.empty(shape=(edge_count,), dtype=vec2i, device=device)
384
+ self._edge_vertex_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)
385
+ self._edge_tri_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)
317
386
 
318
- boundary_mask = wp.empty(shape=(edge_count,), dtype=int, device=device)
387
+ boundary_mask = borrow_temporary(temporary_store=temporary_store, shape=(edge_count,), dtype=int, device=device)
319
388
 
320
389
  # Compress edge data
321
390
  wp.launch(
@@ -323,17 +392,23 @@ class Trimesh2D(Geometry):
323
392
  device=device,
324
393
  dim=self.vertex_count(),
325
394
  inputs=[
326
- vertex_start_edge_offsets,
327
- vertex_unique_edge_offsets,
328
- vertex_unique_edge_count,
329
- vertex_edge_ends,
330
- vertex_edge_tris,
395
+ vertex_start_edge_offsets.array,
396
+ vertex_unique_edge_offsets.array,
397
+ vertex_unique_edge_count.array,
398
+ vertex_edge_ends.array,
399
+ vertex_edge_tris.array,
331
400
  self._edge_vertex_indices,
332
401
  self._edge_tri_indices,
333
- boundary_mask,
402
+ boundary_mask.array,
334
403
  ],
335
404
  )
336
405
 
406
+ vertex_start_edge_offsets.release()
407
+ vertex_unique_edge_offsets.release()
408
+ vertex_unique_edge_count.release()
409
+ vertex_edge_ends.release()
410
+ vertex_edge_tris.release()
411
+
337
412
  # Flip normals if necessary
338
413
  wp.launch(
339
414
  kernel=Trimesh2D._flip_edge_normals,
@@ -342,7 +417,20 @@ class Trimesh2D(Geometry):
342
417
  inputs=[self._edge_vertex_indices, self._edge_tri_indices, self.tri_vertex_indices, self.positions],
343
418
  )
344
419
 
345
- self._boundary_edge_indices, _ = masked_indices(boundary_mask)
420
+ boundary_edge_indices, _ = masked_indices(boundary_mask.array, temporary_store=temporary_store)
421
+ self._boundary_edge_indices = boundary_edge_indices.detach()
422
+
423
+ boundary_mask.release()
424
+
425
+ def _compute_deformation_gradients(self):
426
+ self._deformation_gradients = wp.empty(dtype=wp.mat22f, device=self.positions.device, shape=(self.cell_count()))
427
+
428
+ wp.launch(
429
+ kernel=Trimesh2D._compute_deformation_gradients_kernel,
430
+ dim=self._deformation_gradients.shape,
431
+ device=self._deformation_gradients.device,
432
+ inputs=[self.tri_vertex_indices, self.positions, self._deformation_gradients],
433
+ )
346
434
 
347
435
  @wp.kernel
348
436
  def _count_starting_edges_kernel(
@@ -420,8 +508,8 @@ class Trimesh2D(Geometry):
420
508
  vertex_unique_edge_count: wp.array(dtype=int),
421
509
  uncompressed_edge_ends: wp.array(dtype=int),
422
510
  uncompressed_edge_tris: wp.array2d(dtype=int),
423
- edge_vertex_indices: wp.array(dtype=vec2i),
424
- edge_tri_indices: wp.array(dtype=vec2i),
511
+ edge_vertex_indices: wp.array(dtype=wp.vec2i),
512
+ edge_tri_indices: wp.array(dtype=wp.vec2i),
425
513
  boundary_mask: wp.array(dtype=int),
426
514
  ):
427
515
  v = wp.tid()
@@ -434,11 +522,11 @@ class Trimesh2D(Geometry):
434
522
  src_index = start_beg + e
435
523
  edge_index = unique_beg + e
436
524
 
437
- edge_vertex_indices[edge_index] = vec2i(v, uncompressed_edge_ends[src_index])
525
+ edge_vertex_indices[edge_index] = wp.vec2i(v, uncompressed_edge_ends[src_index])
438
526
 
439
527
  t0 = uncompressed_edge_tris[src_index, 0]
440
528
  t1 = uncompressed_edge_tris[src_index, 1]
441
- edge_tri_indices[edge_index] = vec2i(t0, t1)
529
+ edge_tri_indices[edge_index] = wp.vec2i(t0, t1)
442
530
  if t0 == t1:
443
531
  boundary_mask[edge_index] = 1
444
532
  else:
@@ -446,8 +534,8 @@ class Trimesh2D(Geometry):
446
534
 
447
535
  @wp.kernel
448
536
  def _flip_edge_normals(
449
- edge_vertex_indices: wp.array(dtype=vec2i),
450
- edge_tri_indices: wp.array(dtype=vec2i),
537
+ edge_vertex_indices: wp.array(dtype=wp.vec2i),
538
+ edge_tri_indices: wp.array(dtype=wp.vec2i),
451
539
  tri_vertex_indices: wp.array2d(dtype=int),
452
540
  positions: wp.array(dtype=wp.vec2),
453
541
  ):
@@ -469,4 +557,21 @@ class Trimesh2D(Geometry):
469
557
 
470
558
  # if edge normal points toward first triangle centroid, flip indices
471
559
  if wp.dot(tri_centroid - edge_center, edge_normal) > 0.0:
472
- edge_vertex_indices[e] = vec2i(edge_vidx[1], edge_vidx[0])
560
+ edge_vertex_indices[e] = wp.vec2i(edge_vidx[1], edge_vidx[0])
561
+
562
+ @wp.kernel
563
+ def _compute_deformation_gradients_kernel(
564
+ tri_vertex_indices: wp.array2d(dtype=int),
565
+ positions: wp.array(dtype=wp.vec2f),
566
+ transforms: wp.array(dtype=wp.mat22f),
567
+ ):
568
+ t = wp.tid()
569
+
570
+ p0 = positions[tri_vertex_indices[t, 0]]
571
+ p1 = positions[tri_vertex_indices[t, 1]]
572
+ p2 = positions[tri_vertex_indices[t, 2]]
573
+
574
+ e1 = p1 - p0
575
+ e2 = p2 - p0
576
+
577
+ transforms[t] = wp.mat22(e1, e2)