warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -16,13 +16,12 @@
16
16
  from typing import Optional
17
17
 
18
18
  import warp as wp
19
- from warp.fem.cache import (
19
+ from warp._src.fem.cache import (
20
20
  TemporaryStore,
21
21
  borrow_temporary,
22
22
  borrow_temporary_like,
23
- cached_arg_value,
24
23
  )
25
- from warp.fem.types import (
24
+ from warp._src.fem.types import (
26
25
  OUTSIDE,
27
26
  Coords,
28
27
  ElementIndex,
@@ -30,7 +29,7 @@ from warp.fem.types import (
30
29
  )
31
30
 
32
31
  from .closest_point import project_on_tet_at_origin, project_on_tri_at_origin
33
- from .element import Tetrahedron, Triangle
32
+ from .element import Element
34
33
  from .geometry import Geometry
35
34
 
36
35
 
@@ -108,11 +107,11 @@ class Tetmesh(Geometry):
108
107
  def boundary_side_count(self):
109
108
  return self._boundary_face_indices.shape[0]
110
109
 
111
- def reference_cell(self) -> Tetrahedron:
112
- return Tetrahedron()
110
+ def reference_cell(self) -> Element:
111
+ return Element.TETRAHEDRON
113
112
 
114
- def reference_side(self) -> Triangle:
115
- return Triangle()
113
+ def reference_side(self) -> Element:
114
+ return Element.TRIANGLE
116
115
 
117
116
  @property
118
117
  def tet_edge_indices(self) -> wp.array:
@@ -137,11 +136,6 @@ class Tetmesh(Geometry):
137
136
 
138
137
  # Geometry device interface
139
138
 
140
- def cell_arg_value(self, device) -> CellArg:
141
- args = self.CellArg()
142
- self.fill_cell_arg(args, device)
143
- return args
144
-
145
139
  def fill_cell_arg(self, args: CellArg, device):
146
140
  args.tet_vertex_indices = self.tet_vertex_indices.to(device)
147
141
  args.positions = self.positions.to(device)
@@ -183,12 +177,6 @@ class Tetmesh(Geometry):
183
177
  dist, coords = project_on_tet_at_origin(q, e1, e2, e3)
184
178
  return coords, dist
185
179
 
186
- @cached_arg_value
187
- def side_index_arg_value(self, device) -> SideIndexArg:
188
- args = self.SideIndexArg()
189
- self.fill_side_index_arg(args, device)
190
- return args
191
-
192
180
  def fill_side_index_arg(self, args: SideIndexArg, device):
193
181
  args.boundary_face_indices = self._boundary_face_indices.to(device)
194
182
 
@@ -198,11 +186,6 @@ class Tetmesh(Geometry):
198
186
 
199
187
  return args.boundary_face_indices[boundary_side_index]
200
188
 
201
- def side_arg_value(self, device) -> CellArg:
202
- args = self.SideArg()
203
- self.fill_side_arg(args, device)
204
- return args
205
-
206
189
  def fill_side_arg(self, args: SideArg, device):
207
190
  self.fill_cell_arg(args.cell_arg, device)
208
191
  args.face_vertex_indices = self._face_vertex_indices.to(device)
@@ -325,8 +308,8 @@ class Tetmesh(Geometry):
325
308
  return side_arg.cell_arg
326
309
 
327
310
  def _build_topology(self, temporary_store: TemporaryStore):
328
- from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
329
- from warp.utils import array_scan
311
+ from warp._src.fem.utils import compress_node_indices, host_read_at_index, masked_indices
312
+ from warp._src.utils import array_scan
330
313
 
331
314
  device = self.tet_vertex_indices.device
332
315
 
@@ -337,7 +320,7 @@ class Tetmesh(Geometry):
337
320
  self._vertex_tet_indices = vertex_tet_indices.detach()
338
321
 
339
322
  vertex_start_face_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
340
- vertex_start_face_count.array.zero_()
323
+ vertex_start_face_count.zero_()
341
324
  vertex_start_face_offsets = borrow_temporary_like(vertex_start_face_count, temporary_store=temporary_store)
342
325
 
343
326
  vertex_face_other_vs = borrow_temporary(
@@ -350,10 +333,10 @@ class Tetmesh(Geometry):
350
333
  kernel=Tetmesh._count_starting_faces_kernel,
351
334
  device=device,
352
335
  dim=self.cell_count(),
353
- inputs=[self.tet_vertex_indices, vertex_start_face_count.array],
336
+ inputs=[self.tet_vertex_indices, vertex_start_face_count],
354
337
  )
355
338
 
356
- array_scan(in_array=vertex_start_face_count.array, out_array=vertex_start_face_offsets.array, inclusive=False)
339
+ array_scan(in_array=vertex_start_face_count, out_array=vertex_start_face_offsets, inclusive=False)
357
340
 
358
341
  # Count number of unique edges (deduplicate across faces)
359
342
  vertex_unique_face_count = vertex_start_face_count
@@ -365,21 +348,19 @@ class Tetmesh(Geometry):
365
348
  self._vertex_tet_offsets,
366
349
  self._vertex_tet_indices,
367
350
  self.tet_vertex_indices,
368
- vertex_start_face_offsets.array,
369
- vertex_unique_face_count.array,
370
- vertex_face_other_vs.array,
371
- vertex_face_tets.array,
351
+ vertex_start_face_offsets,
352
+ vertex_unique_face_count,
353
+ vertex_face_other_vs,
354
+ vertex_face_tets,
372
355
  ],
373
356
  )
374
357
 
375
358
  vertex_unique_face_offsets = borrow_temporary_like(vertex_start_face_offsets, temporary_store=temporary_store)
376
- array_scan(in_array=vertex_start_face_count.array, out_array=vertex_unique_face_offsets.array, inclusive=False)
359
+ array_scan(in_array=vertex_start_face_count, out_array=vertex_unique_face_offsets, inclusive=False)
377
360
 
378
361
  # Get back edge count to host
379
362
  face_count = int(
380
- host_read_at_index(
381
- vertex_unique_face_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
382
- )
363
+ host_read_at_index(vertex_unique_face_offsets, self.vertex_count() - 1, temporary_store=temporary_store)
383
364
  )
384
365
 
385
366
  self._face_vertex_indices = wp.empty(shape=(face_count,), dtype=wp.vec3i, device=device)
@@ -393,14 +374,14 @@ class Tetmesh(Geometry):
393
374
  device=device,
394
375
  dim=self.vertex_count(),
395
376
  inputs=[
396
- vertex_start_face_offsets.array,
397
- vertex_unique_face_offsets.array,
398
- vertex_unique_face_count.array,
399
- vertex_face_other_vs.array,
400
- vertex_face_tets.array,
377
+ vertex_start_face_offsets,
378
+ vertex_unique_face_offsets,
379
+ vertex_unique_face_count,
380
+ vertex_face_other_vs,
381
+ vertex_face_tets,
401
382
  self._face_vertex_indices,
402
383
  self._face_tet_indices,
403
- boundary_mask.array,
384
+ boundary_mask,
404
385
  ],
405
386
  )
406
387
 
@@ -418,17 +399,17 @@ class Tetmesh(Geometry):
418
399
  inputs=[self._face_vertex_indices, self._face_tet_indices, self.tet_vertex_indices, self.positions],
419
400
  )
420
401
 
421
- boundary_face_indices, _ = masked_indices(boundary_mask.array)
402
+ boundary_face_indices, _ = masked_indices(boundary_mask)
422
403
  self._boundary_face_indices = boundary_face_indices.detach()
423
404
 
424
405
  def _compute_tet_edges(self, temporary_store: Optional[TemporaryStore] = None):
425
- from warp.fem.utils import host_read_at_index
426
- from warp.utils import array_scan
406
+ from warp._src.fem.utils import host_read_at_index
407
+ from warp._src.utils import array_scan
427
408
 
428
409
  device = self.tet_vertex_indices.device
429
410
 
430
411
  vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
431
- vertex_start_edge_count.array.zero_()
412
+ vertex_start_edge_count.zero_()
432
413
  vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)
433
414
 
434
415
  vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(6 * self.cell_count()))
@@ -438,10 +419,10 @@ class Tetmesh(Geometry):
438
419
  kernel=Tetmesh._count_starting_edges_kernel,
439
420
  device=device,
440
421
  dim=self.cell_count(),
441
- inputs=[self.tet_vertex_indices, vertex_start_edge_count.array],
422
+ inputs=[self.tet_vertex_indices, vertex_start_edge_count],
442
423
  )
443
424
 
444
- array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)
425
+ array_scan(in_array=vertex_start_edge_count, out_array=vertex_start_edge_offsets, inclusive=False)
445
426
 
446
427
  # Count number of unique edges (deduplicate across faces)
447
428
  vertex_unique_edge_count = vertex_start_edge_count
@@ -453,22 +434,18 @@ class Tetmesh(Geometry):
453
434
  self._vertex_tet_offsets,
454
435
  self._vertex_tet_indices,
455
436
  self.tet_vertex_indices,
456
- vertex_start_edge_offsets.array,
457
- vertex_unique_edge_count.array,
458
- vertex_edge_ends.array,
437
+ vertex_start_edge_offsets,
438
+ vertex_unique_edge_count,
439
+ vertex_edge_ends,
459
440
  ],
460
441
  )
461
442
 
462
- vertex_unique_edge_offsets = borrow_temporary_like(
463
- vertex_start_edge_offsets.array, temporary_store=temporary_store
464
- )
465
- array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)
443
+ vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
444
+ array_scan(in_array=vertex_start_edge_count, out_array=vertex_unique_edge_offsets, inclusive=False)
466
445
 
467
446
  # Get back edge count to host
468
447
  self._edge_count = int(
469
- host_read_at_index(
470
- vertex_unique_edge_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
471
- )
448
+ host_read_at_index(vertex_unique_edge_offsets, self.vertex_count() - 1, temporary_store=temporary_store)
472
449
  )
473
450
 
474
451
  self._tet_edge_indices = wp.empty(
@@ -484,10 +461,10 @@ class Tetmesh(Geometry):
484
461
  self._vertex_tet_offsets,
485
462
  self._vertex_tet_indices,
486
463
  self.tet_vertex_indices,
487
- vertex_start_edge_offsets.array,
488
- vertex_unique_edge_offsets.array,
489
- vertex_unique_edge_count.array,
490
- vertex_edge_ends.array,
464
+ vertex_start_edge_offsets,
465
+ vertex_unique_edge_offsets,
466
+ vertex_unique_edge_count,
467
+ vertex_edge_ends,
491
468
  self._tet_edge_indices,
492
469
  ],
493
470
  )
@@ -16,13 +16,12 @@
16
16
  from typing import Any, Optional
17
17
 
18
18
  import warp as wp
19
- from warp.fem.cache import (
19
+ from warp._src.fem.cache import (
20
20
  TemporaryStore,
21
21
  borrow_temporary,
22
22
  borrow_temporary_like,
23
- cached_arg_value,
24
23
  )
25
- from warp.fem.types import (
24
+ from warp._src.fem.types import (
26
25
  OUTSIDE,
27
26
  Coords,
28
27
  ElementIndex,
@@ -30,7 +29,7 @@ from warp.fem.types import (
30
29
  )
31
30
 
32
31
  from .closest_point import project_on_seg_at_origin, project_on_tri_at_origin
33
- from .element import LinearEdge, Triangle
32
+ from .element import Element
34
33
  from .geometry import Geometry
35
34
 
36
35
 
@@ -103,11 +102,11 @@ class Trimesh(Geometry):
103
102
  def boundary_side_count(self):
104
103
  return self._boundary_edge_indices.shape[0]
105
104
 
106
- def reference_cell(self) -> Triangle:
107
- return Triangle()
105
+ def reference_cell(self) -> Element:
106
+ return Element.TRIANGLE
108
107
 
109
- def reference_side(self) -> LinearEdge:
110
- return LinearEdge()
108
+ def reference_side(self) -> Element:
109
+ return Element.LINE_SEGMENT
111
110
 
112
111
  @property
113
112
  def edge_tri_indices(self) -> wp.array:
@@ -130,30 +129,14 @@ class Trimesh(Geometry):
130
129
  args.edge_vertex_indices = self._edge_vertex_indices.to(device)
131
130
  args.edge_tri_indices = self._edge_tri_indices.to(device)
132
131
 
133
- def cell_arg_value(self, device):
134
- args = self.CellArg()
135
- self.fill_cell_arg(args, device)
136
- return args
137
-
138
132
  def fill_cell_arg(self, args: TrimeshCellArg, device):
139
133
  self._fill_cell_topo_arg(args.topology, device)
140
134
  args.positions = self.positions.to(device)
141
135
 
142
- def side_arg_value(self, device):
143
- args = self.SideArg()
144
- self.fill_side_arg(args, device)
145
- return args
146
-
147
136
  def fill_side_arg(self, args: TrimeshSideArg, device):
148
137
  self._fill_side_topo_arg(args.topology, device)
149
138
  args.positions = self.positions.to(device)
150
139
 
151
- @cached_arg_value
152
- def side_index_arg_value(self, device) -> SideIndexArg:
153
- args = self.SideIndexArg()
154
- self.fill_side_index_arg(args, device)
155
- return args
156
-
157
140
  def fill_side_index_arg(self, args: SideIndexArg, device):
158
141
  args.boundary_edge_indices = self._boundary_edge_indices.to(device)
159
142
 
@@ -216,8 +199,8 @@ class Trimesh(Geometry):
216
199
  return wp.where(tri_coords[start] + tri_coords[end] > 0.999, Coords(tri_coords[end], 0.0, 0.0), Coords(OUTSIDE))
217
200
 
218
201
  def _build_topology(self, temporary_store: TemporaryStore):
219
- from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
220
- from warp.utils import array_scan
202
+ from warp._src.fem.utils import compress_node_indices, host_read_at_index, masked_indices
203
+ from warp._src.utils import array_scan
221
204
 
222
205
  device = self.tri_vertex_indices.device
223
206
 
@@ -228,7 +211,7 @@ class Trimesh(Geometry):
228
211
  self._vertex_tri_indices = vertex_tri_indices.detach()
229
212
 
230
213
  vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
231
- vertex_start_edge_count.array.zero_()
214
+ vertex_start_edge_count.zero_()
232
215
  vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)
233
216
 
234
217
  vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(3 * self.cell_count()))
@@ -239,10 +222,10 @@ class Trimesh(Geometry):
239
222
  kernel=Trimesh._count_starting_edges_kernel,
240
223
  device=device,
241
224
  dim=self.cell_count(),
242
- inputs=[self.tri_vertex_indices, vertex_start_edge_count.array],
225
+ inputs=[self.tri_vertex_indices, vertex_start_edge_count],
243
226
  )
244
227
 
245
- array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)
228
+ array_scan(in_array=vertex_start_edge_count, out_array=vertex_start_edge_offsets, inclusive=False)
246
229
 
247
230
  # Count number of unique edges (deduplicate across faces)
248
231
  vertex_unique_edge_count = vertex_start_edge_count
@@ -254,21 +237,19 @@ class Trimesh(Geometry):
254
237
  self._vertex_tri_offsets,
255
238
  self._vertex_tri_indices,
256
239
  self.tri_vertex_indices,
257
- vertex_start_edge_offsets.array,
258
- vertex_unique_edge_count.array,
259
- vertex_edge_ends.array,
260
- vertex_edge_tris.array,
240
+ vertex_start_edge_offsets,
241
+ vertex_unique_edge_count,
242
+ vertex_edge_ends,
243
+ vertex_edge_tris,
261
244
  ],
262
245
  )
263
246
 
264
247
  vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
265
- array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)
248
+ array_scan(in_array=vertex_start_edge_count, out_array=vertex_unique_edge_offsets, inclusive=False)
266
249
 
267
250
  # Get back edge count to host
268
251
  edge_count = int(
269
- host_read_at_index(
270
- vertex_unique_edge_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
271
- )
252
+ host_read_at_index(vertex_unique_edge_offsets, self.vertex_count() - 1, temporary_store=temporary_store)
272
253
  )
273
254
 
274
255
  self._edge_vertex_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)
@@ -282,14 +263,14 @@ class Trimesh(Geometry):
282
263
  device=device,
283
264
  dim=self.vertex_count(),
284
265
  inputs=[
285
- vertex_start_edge_offsets.array,
286
- vertex_unique_edge_offsets.array,
287
- vertex_unique_edge_count.array,
288
- vertex_edge_ends.array,
289
- vertex_edge_tris.array,
266
+ vertex_start_edge_offsets,
267
+ vertex_unique_edge_offsets,
268
+ vertex_unique_edge_count,
269
+ vertex_edge_ends,
270
+ vertex_edge_tris,
290
271
  self._edge_vertex_indices,
291
272
  self._edge_tri_indices,
292
- boundary_mask.array,
273
+ boundary_mask,
293
274
  ],
294
275
  )
295
276
 
@@ -299,7 +280,7 @@ class Trimesh(Geometry):
299
280
  vertex_edge_ends.release()
300
281
  vertex_edge_tris.release()
301
282
 
302
- boundary_edge_indices, _ = masked_indices(boundary_mask.array, temporary_store=temporary_store)
283
+ boundary_edge_indices, _ = masked_indices(boundary_mask, temporary_store=temporary_store)
303
284
  self._boundary_edge_indices = boundary_edge_indices.detach()
304
285
 
305
286
  boundary_mask.release()
@@ -467,7 +448,7 @@ class Trimesh(Geometry):
467
448
  q = pos - p0
468
449
  e = args.positions[edge_idx[1]] - p0
469
450
 
470
- dist, t = project_on_seg_at_origin(q, e, wp.lengh_sq(e))
451
+ dist, t = project_on_seg_at_origin(q, e, wp.length_sq(e))
471
452
  return Coords(t, 0.0, 0.0), dist
472
453
 
473
454
  @wp.func