warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. (The original page linked to further details, but the link was not preserved.)

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -13,11 +13,13 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from typing import Optional
17
+
16
18
  import warp as wp
17
- from warp.fem.cache import TemporaryStore, borrow_temporary, borrow_temporary_like, cached_arg_value
18
- from warp.fem.domain import GeometryDomain
19
- from warp.fem.types import NULL_NODE_INDEX, NodeElementIndex
20
- from warp.fem.utils import compress_node_indices
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.domain import GeometryDomain
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, NULL_NODE_INDEX, NodeElementIndex
22
+ from warp._src.fem.utils import compress_node_indices, host_read_at_index
21
23
 
22
24
  from .partition import SpacePartition
23
25
 
@@ -32,7 +34,7 @@ class SpaceRestriction:
32
34
  space_partition: SpacePartition,
33
35
  domain: GeometryDomain,
34
36
  device=None,
35
- temporary_store: TemporaryStore = None,
37
+ temporary_store: cache.TemporaryStore = None,
36
38
  ):
37
39
  space_topology = space_partition.space_topology
38
40
 
@@ -46,15 +48,25 @@ class SpaceRestriction:
46
48
  self.space_topology = space_topology
47
49
  self.domain = domain
48
50
 
49
- self._compute_node_element_indices(device=device, temporary_store=temporary_store)
51
+ self._node_count_dev: wp.array = None
52
+ """Number of unique partition node indices"""
53
+ self._dof_partition_indices: wp.array = None
54
+ """Array of unique partition node indices"""
55
+
56
+ self._dof_partition_element_offsets: wp.array = None
57
+ """Mapping from partition node to offset in the per-node element indices array"""
58
+ self._dof_element_indices: wp.array = None
59
+ """Concatenation of neighboring elements indices for each partition node"""
60
+ self._dof_indices_in_element: wp.array = None
61
+ """Concatenation of node index in element for each partition node"""
50
62
 
51
- def _compute_node_element_indices(self, device, temporary_store: TemporaryStore):
52
- from warp.fem import cache
63
+ self.rebuild(device=device, temporary_store=temporary_store)
53
64
 
54
- MAX_NODES_PER_ELEMENT = self.space_topology.MAX_NODES_PER_ELEMENT
65
+ def rebuild(self, device: Optional = None, temporary_store: Optional[cache.TemporaryStore] = None):
66
+ max_nodes_per_element = self.space_topology.MAX_NODES_PER_ELEMENT
55
67
 
56
68
  @cache.dynamic_kernel(
57
- suffix=f"{self.domain.name}_{self.space_topology.name}_{self.space_partition.name}",
69
+ suffix=(self.domain.name, self.space_topology.name, self.space_partition.name),
58
70
  kernel_options={"max_unroll": 8},
59
71
  )
60
72
  def fill_element_node_indices(
@@ -66,75 +78,96 @@ class SpaceRestriction:
66
78
  ):
67
79
  domain_element_index = wp.tid()
68
80
  element_index = self.domain.element_index(domain_index_arg, domain_element_index)
69
- element_node_count = self.space_topology.element_node_count(element_arg, topo_arg, element_index)
81
+
82
+ if element_index == NULL_ELEMENT_INDEX:
83
+ element_node_count = 0
84
+ else:
85
+ element_node_count = self.space_topology.element_node_count(element_arg, topo_arg, element_index)
86
+
70
87
  for n in range(element_node_count):
71
88
  space_nidx = self.space_topology.element_node_index(element_arg, topo_arg, element_index, n)
72
89
  partition_nidx = self.space_partition.partition_node_index(partition_arg, space_nidx)
73
90
  element_node_indices[domain_element_index, n] = partition_nidx
74
- for n in range(element_node_count, MAX_NODES_PER_ELEMENT):
91
+ for n in range(element_node_count, element_node_indices.shape[1]):
75
92
  element_node_indices[domain_element_index, n] = NULL_NODE_INDEX
76
93
 
77
- element_node_indices = borrow_temporary(
94
+ element_node_indices = cache.borrow_temporary(
78
95
  temporary_store,
79
- shape=(self.domain.element_count(), MAX_NODES_PER_ELEMENT),
96
+ shape=(self.domain.element_count(), max_nodes_per_element),
80
97
  dtype=int,
81
98
  device=device,
82
99
  )
83
100
  wp.launch(
84
- dim=element_node_indices.array.shape[0],
101
+ dim=element_node_indices.shape[0],
85
102
  kernel=fill_element_node_indices,
86
103
  inputs=[
87
104
  self.domain.element_arg_value(device),
88
105
  self.domain.element_index_arg_value(device),
89
106
  self.space_topology.topo_arg_value(device),
90
107
  self.space_partition.partition_arg_value(device),
91
- element_node_indices.array,
108
+ element_node_indices,
92
109
  ],
93
110
  device=device,
94
111
  )
95
112
 
96
113
  # Build compressed map from node to element indices
97
- flattened_node_indices = element_node_indices.array.flatten()
114
+ flattened_node_indices = element_node_indices.flatten()
98
115
  (
99
116
  self._dof_partition_element_offsets,
100
117
  node_array_indices,
101
- self._node_count,
118
+ self._node_count_dev,
102
119
  self._dof_partition_indices,
103
120
  ) = compress_node_indices(
104
121
  self.space_partition.node_count(),
105
122
  flattened_node_indices,
123
+ node_offsets=self._dof_partition_element_offsets,
124
+ unique_node_count=self._node_count_dev,
125
+ unique_node_indices=self._dof_partition_indices,
106
126
  return_unique_nodes=True,
107
127
  temporary_store=temporary_store,
108
128
  )
109
129
 
110
130
  # Extract element index and index in element
111
- self._dof_element_indices = borrow_temporary_like(flattened_node_indices, temporary_store)
112
- self._dof_indices_in_element = borrow_temporary_like(flattened_node_indices, temporary_store)
131
+ if self._dof_element_indices is None or self._dof_element_indices.shape != flattened_node_indices.shape:
132
+ self._dof_element_indices = cache.borrow_temporary_like(flattened_node_indices, temporary_store)
133
+ self._dof_indices_in_element = cache.borrow_temporary_like(flattened_node_indices, temporary_store)
134
+
113
135
  wp.launch(
114
136
  kernel=SpaceRestriction._split_vertex_element_index,
115
137
  dim=flattened_node_indices.shape,
116
138
  inputs=[
117
- MAX_NODES_PER_ELEMENT,
118
- node_array_indices.array,
119
- self._dof_element_indices.array,
120
- self._dof_indices_in_element.array,
139
+ max_nodes_per_element,
140
+ node_array_indices,
141
+ self._dof_element_indices,
142
+ self._dof_indices_in_element,
121
143
  ],
122
144
  device=flattened_node_indices.device,
123
145
  )
124
146
 
125
147
  node_array_indices.release()
126
148
 
127
- def node_count(self):
149
+ # Upper bound on node count, use `node_count_sync` to get the actual value
150
+ self._node_count = min(self.space_partition.node_count(), self._dof_partition_indices.shape[0])
151
+
152
+ def node_count_sync(self) -> int:
153
+ """Ensures that the node count is synchronized with the device and returns it"""
154
+ if self._node_count_dev is not None:
155
+ self._node_count = int(host_read_at_index(self._node_count_dev, index=0))
156
+ self._node_count_dev = None
157
+ return self.node_count()
158
+
159
+ def node_count(self) -> int:
160
+ """Upper bound for the node count, use `node_count_sync` to get the actual value"""
128
161
  return self._node_count
129
162
 
130
163
  def partition_element_offsets(self):
131
- return self._dof_partition_element_offsets.array
164
+ return self._dof_partition_element_offsets
132
165
 
133
166
  def node_partition_indices(self):
134
- return self._dof_partition_indices.array
167
+ return self._dof_partition_indices
135
168
 
136
169
  def total_node_element_count(self):
137
- return self._dof_element_indices.array.size
170
+ return self._dof_element_indices.size
138
171
 
139
172
  @wp.struct
140
173
  class NodeArg:
@@ -143,17 +176,17 @@ class SpaceRestriction:
143
176
  dof_partition_indices: wp.array(dtype=int)
144
177
  dof_indices_in_element: wp.array(dtype=int)
145
178
 
146
- @cached_arg_value
179
+ @cache.cached_arg_value
147
180
  def node_arg_value(self, device):
148
181
  arg = SpaceRestriction.NodeArg()
149
182
  self.fill_node_arg(arg, device)
150
183
  return arg
151
184
 
152
185
  def fill_node_arg(self, arg: NodeArg, device):
153
- arg.dof_element_offsets = self._dof_partition_element_offsets.array.to(device)
154
- arg.dof_element_indices = self._dof_element_indices.array.to(device)
155
- arg.dof_partition_indices = self._dof_partition_indices.array.to(device)
156
- arg.dof_indices_in_element = self._dof_indices_in_element.array.to(device)
186
+ arg.dof_element_offsets = self._dof_partition_element_offsets.to(device)
187
+ arg.dof_element_indices = self._dof_element_indices.to(device)
188
+ arg.dof_partition_indices = self._dof_partition_indices.to(device)
189
+ arg.dof_indices_in_element = self._dof_indices_in_element.to(device)
157
190
 
158
191
  @wp.func
159
192
  def node_partition_index(args: NodeArg, restriction_node_index: int):
@@ -0,0 +1,152 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ from enum import Enum
18
+ from typing import Optional
19
+
20
+ from warp._src.fem.geometry import Element
21
+ from warp._src.fem.polynomial import Polynomial
22
+
23
+ from .cube_shape_function import (
24
+ CubeNedelecFirstKindShapeFunctions,
25
+ CubeNonConformingPolynomialShapeFunctions,
26
+ CubeRaviartThomasShapeFunctions,
27
+ CubeSerendipityShapeFunctions,
28
+ CubeShapeFunction,
29
+ CubeTripolynomialShapeFunctions,
30
+ )
31
+ from .shape_function import ConstantShapeFunction, ShapeFunction
32
+ from .square_shape_function import (
33
+ SquareBipolynomialShapeFunctions,
34
+ SquareNedelecFirstKindShapeFunctions,
35
+ SquareNonConformingPolynomialShapeFunctions,
36
+ SquareRaviartThomasShapeFunctions,
37
+ SquareSerendipityShapeFunctions,
38
+ SquareShapeFunction,
39
+ )
40
+ from .tet_shape_function import (
41
+ TetrahedronNedelecFirstKindShapeFunctions,
42
+ TetrahedronNonConformingPolynomialShapeFunctions,
43
+ TetrahedronPolynomialShapeFunctions,
44
+ TetrahedronRaviartThomasShapeFunctions,
45
+ TetrahedronShapeFunction,
46
+ )
47
+ from .triangle_shape_function import (
48
+ TriangleNedelecFirstKindShapeFunctions,
49
+ TriangleNonConformingPolynomialShapeFunctions,
50
+ TrianglePolynomialShapeFunctions,
51
+ TriangleRaviartThomasShapeFunctions,
52
+ TriangleShapeFunction,
53
+ )
54
+
55
+
56
+ class ElementBasis(Enum):
57
+ """Choice of basis function to equip individual elements"""
58
+
59
+ LAGRANGE = "P"
60
+ """Lagrange basis functions :math:`P_k` for simplices, tensor products :math:`Q_k` for squares and cubes"""
61
+ SERENDIPITY = "S"
62
+ """Serendipity elements :math:`S_k`, corresponding to Lagrange nodes with interior points removed (for degree <= 3)"""
63
+ NONCONFORMING_POLYNOMIAL = "dP"
64
+ """Simplex Lagrange basis functions :math:`P_{kd}` embedded into non conforming reference elements (e.g. squares or cubes). Discontinuous only."""
65
+ NEDELEC_FIRST_KIND = "N1"
66
+ """Nédélec (first kind) H(curl) shape functions. Should be used with covariant function space."""
67
+ RAVIART_THOMAS = "RT"
68
+ """Raviart-Thomas H(div) shape functions. Should be used with contravariant function space."""
69
+
70
+
71
+ @functools.lru_cache(maxsize=None)
72
+ def make_element_shape_function(
73
+ element: Element,
74
+ degree: int,
75
+ element_basis: Optional[ElementBasis] = None,
76
+ family: Optional[Polynomial] = None,
77
+ ) -> ShapeFunction:
78
+ """
79
+ Equips a reference element with a shape function basis.
80
+
81
+ Args:
82
+ element: the type of reference element on which to build the shape function
83
+ degree: polynomial degree of the per-element shape functions
84
+ element_basis: type of basis function for the individual elements
85
+ family: Polynomial family used to generate the shape function basis. If not provided, a reasonable basis is chosen.
86
+
87
+ Returns:
88
+ the corresponding shape function
89
+
90
+ Raises:
91
+ NotImplementedError: If the shape function is not implemented for the given element type
92
+ """
93
+
94
+ if element_basis is None:
95
+ element_basis = ElementBasis.LAGRANGE
96
+ elif element_basis == ElementBasis.SERENDIPITY and degree == 1:
97
+ # Degree-1 serendipity is always equivalent to Lagrange
98
+ element_basis = ElementBasis.LAGRANGE
99
+
100
+ if degree == 0:
101
+ return ConstantShapeFunction(element)
102
+
103
+ if family is None:
104
+ family = Polynomial.LOBATTO_GAUSS_LEGENDRE
105
+
106
+ if element == Element.SQUARE:
107
+ if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
108
+ return SquareNedelecFirstKindShapeFunctions(degree=degree)
109
+ if element_basis == ElementBasis.RAVIART_THOMAS:
110
+ return SquareRaviartThomasShapeFunctions(degree=degree)
111
+ if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
112
+ return SquareNonConformingPolynomialShapeFunctions(degree=degree)
113
+ if element_basis == ElementBasis.SERENDIPITY and degree > 1:
114
+ return SquareSerendipityShapeFunctions(degree=degree, family=family)
115
+
116
+ return SquareBipolynomialShapeFunctions(degree=degree, family=family)
117
+ if element == Element.TRIANGLE:
118
+ if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
119
+ return TriangleNedelecFirstKindShapeFunctions(degree=degree)
120
+ if element_basis == ElementBasis.RAVIART_THOMAS:
121
+ return TriangleRaviartThomasShapeFunctions(degree=degree)
122
+ if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
123
+ return TriangleNonConformingPolynomialShapeFunctions(degree=degree)
124
+ if element_basis == ElementBasis.SERENDIPITY and degree > 2:
125
+ raise NotImplementedError("Serendipity variant not implemented yet for Triangle elements")
126
+
127
+ return TrianglePolynomialShapeFunctions(degree=degree)
128
+
129
+ if element == Element.CUBE:
130
+ if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
131
+ return CubeNedelecFirstKindShapeFunctions(degree=degree)
132
+ if element_basis == ElementBasis.RAVIART_THOMAS:
133
+ return CubeRaviartThomasShapeFunctions(degree=degree)
134
+ if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
135
+ return CubeNonConformingPolynomialShapeFunctions(degree=degree)
136
+ if element_basis == ElementBasis.SERENDIPITY and degree > 1:
137
+ return CubeSerendipityShapeFunctions(degree=degree, family=family)
138
+
139
+ return CubeTripolynomialShapeFunctions(degree=degree, family=family)
140
+ if element == Element.TETRAHEDRON:
141
+ if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
142
+ return TetrahedronNedelecFirstKindShapeFunctions(degree=degree)
143
+ if element_basis == ElementBasis.RAVIART_THOMAS:
144
+ return TetrahedronRaviartThomasShapeFunctions(degree=degree)
145
+ if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
146
+ return TetrahedronNonConformingPolynomialShapeFunctions(degree=degree)
147
+ if element_basis == ElementBasis.SERENDIPITY and degree > 2:
148
+ raise NotImplementedError("Serendipity variant not implemented yet for Tet elements")
149
+
150
+ return TetrahedronPolynomialShapeFunctions(degree=degree)
151
+
152
+ raise NotImplementedError(f"Unrecognized element type {element}")
@@ -18,10 +18,10 @@ import math
18
18
  import numpy as np
19
19
 
20
20
  import warp as wp
21
- from warp.fem import cache
22
- from warp.fem.geometry import Grid3D
23
- from warp.fem.polynomial import Polynomial, is_closed, lagrange_scales, quadrature_1d
24
- from warp.fem.types import Coords
21
+ from warp._src.fem import cache
22
+ from warp._src.fem.geometry import Grid3D
23
+ from warp._src.fem.polynomial import Polynomial, is_closed, lagrange_scales, quadrature_1d
24
+ from warp._src.fem.types import Coords
25
25
 
26
26
  from .shape_function import ShapeFunction
27
27
  from .tet_shape_function import TetrahedronPolynomialShapeFunctions
@@ -79,7 +79,7 @@ class CubeTripolynomialShapeFunctions(CubeShapeFunction):
79
79
  lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
80
80
  lagrange_scale = lagrange_scales(lobatto_coords)
81
81
 
82
- NodeVec = wp.types.vector(length=degree + 1, dtype=wp.float32)
82
+ NodeVec = cache.cached_vec_type(length=degree + 1, dtype=wp.float32)
83
83
  self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
84
84
  self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
85
85
  self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
@@ -395,12 +395,12 @@ class CubeTripolynomialShapeFunctions(CubeShapeFunction):
395
395
  return cache.get_func(element_inner_weight_gradient, self.name)
396
396
 
397
397
  def element_node_hexes(self):
398
- from warp.fem.utils import grid_to_hexes
398
+ from warp._src.fem.utils import grid_to_hexes
399
399
 
400
400
  return grid_to_hexes(self.ORDER, self.ORDER, self.ORDER)
401
401
 
402
402
  def element_node_tets(self):
403
- from warp.fem.utils import grid_to_tets
403
+ from warp._src.fem.utils import grid_to_tets
404
404
 
405
405
  return grid_to_tets(self.ORDER, self.ORDER, self.ORDER)
406
406
 
@@ -514,7 +514,7 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
514
514
  lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
515
515
  lagrange_scale = lagrange_scales(lobatto_coords)
516
516
 
517
- NodeVec = wp.types.vector(length=degree + 1, dtype=wp.float32)
517
+ NodeVec = cache.cached_vec_type(length=degree + 1, dtype=wp.float32)
518
518
  self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
519
519
  self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
520
520
  self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
@@ -751,7 +751,7 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
751
751
  return element_inner_weight_gradient
752
752
 
753
753
  def element_node_tets(self):
754
- from warp.fem.utils import grid_to_tets
754
+ from warp._src.fem.utils import grid_to_tets
755
755
 
756
756
  if self.ORDER == 2:
757
757
  element_tets = np.array(
@@ -18,9 +18,9 @@ from enum import Enum
18
18
  import numpy as np
19
19
 
20
20
  import warp as wp
21
- from warp.fem import cache
22
- from warp.fem.geometry import Element
23
- from warp.fem.types import Coords
21
+ from warp._src.fem import cache
22
+ from warp._src.fem.geometry import Element
23
+ from warp._src.fem.types import Coords
24
24
 
25
25
 
26
26
  class ShapeFunction:
@@ -68,19 +68,18 @@ class ShapeFunction:
68
68
  class ConstantShapeFunction(ShapeFunction):
69
69
  """Shape function that is constant over the element"""
70
70
 
71
- def __init__(self, element: Element, space_dimension: int):
72
- self._element = element
73
- self._dimension = space_dimension
71
+ def __init__(self, element: Element):
72
+ self._element_prototype = element.prototype
74
73
 
75
74
  self.ORDER = wp.constant(0)
76
75
  self.NODES_PER_ELEMENT = wp.constant(1)
77
76
 
78
- coords, _ = element.instantiate_quadrature(order=0, family=None)
77
+ coords, _ = self._element_prototype.instantiate_quadrature(order=0, family=None)
79
78
  self.COORDS = wp.constant(coords[0])
80
79
 
81
80
  @property
82
81
  def name(self) -> str:
83
- return f"{self._element.__class__.__name__}{self._dimension}"
82
+ return f"{self._element_prototype.__name__}"
84
83
 
85
84
  def make_node_coords_in_element(self):
86
85
  COORDS = self.COORDS
@@ -116,7 +115,7 @@ class ConstantShapeFunction(ShapeFunction):
116
115
  return ConstantShapeFunction._element_inner_weight
117
116
 
118
117
  def make_element_inner_weight_gradient(self):
119
- grad_type = wp.vec(length=self._dimension, dtype=float)
118
+ grad_type = cache.cached_vec_type(length=self._element_prototype.dimension, dtype=float)
120
119
 
121
120
  @cache.dynamic_func(suffix=self.name)
122
121
  def element_inner_weight_gradient(
@@ -18,9 +18,9 @@ import math
18
18
  import numpy as np
19
19
 
20
20
  import warp as wp
21
- from warp.fem import cache
22
- from warp.fem.polynomial import Polynomial, is_closed, lagrange_scales, quadrature_1d
23
- from warp.fem.types import Coords
21
+ from warp._src.fem import cache
22
+ from warp._src.fem.polynomial import Polynomial, is_closed, lagrange_scales, quadrature_1d
23
+ from warp._src.fem.types import Coords
24
24
 
25
25
  from .shape_function import ShapeFunction
26
26
  from .triangle_shape_function import TrianglePolynomialShapeFunctions
@@ -68,7 +68,7 @@ class SquareBipolynomialShapeFunctions(SquareShapeFunction):
68
68
  lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
69
69
  lagrange_scale = lagrange_scales(lobatto_coords)
70
70
 
71
- NodeVec = wp.types.vector(length=degree + 1, dtype=wp.float32)
71
+ NodeVec = cache.cached_vec_type(length=degree + 1, dtype=wp.float32)
72
72
  self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
73
73
  self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
74
74
  self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
@@ -291,7 +291,7 @@ class SquareBipolynomialShapeFunctions(SquareShapeFunction):
291
291
  return cache.get_func(element_inner_weight_gradient, self.name)
292
292
 
293
293
  def element_node_triangulation(self):
294
- from warp.fem.utils import grid_to_tris
294
+ from warp._src.fem.utils import grid_to_tris
295
295
 
296
296
  return grid_to_tris(self.ORDER, self.ORDER)
297
297
 
@@ -348,7 +348,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
348
348
  lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
349
349
  lagrange_scale = lagrange_scales(lobatto_coords)
350
350
 
351
- NodeVec = wp.types.vector(length=degree + 1, dtype=wp.float32)
351
+ NodeVec = cache.cached_vec_type(length=degree + 1, dtype=wp.float32)
352
352
  self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
353
353
  self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
354
354
  self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
@@ -16,8 +16,8 @@
16
16
  import numpy as np
17
17
 
18
18
  import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.types import Coords
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.types import Coords
21
21
 
22
22
  from .shape_function import ShapeFunction
23
23
 
@@ -182,7 +182,7 @@ class TetrahedronPolynomialShapeFunctions(TetrahedronShapeFunction):
182
182
  index = _tet_node_index(tx, ty, tz, degree)
183
183
  tet_coords[index] = [tx, ty, tz]
184
184
 
185
- CoordTypeVec = wp.mat(dtype=int, shape=(self.NODES_PER_ELEMENT, 3))
185
+ CoordTypeVec = cache.cached_mat_type(dtype=int, shape=(self.NODES_PER_ELEMENT, 3))
186
186
  self.NODE_TET_COORDS = wp.constant(CoordTypeVec(tet_coords))
187
187
 
188
188
  self.node_type_and_type_index = self._get_node_type_and_type_index()
@@ -16,8 +16,8 @@
16
16
  import numpy as np
17
17
 
18
18
  import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.types import Coords
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.types import Coords
21
21
 
22
22
  from .shape_function import ShapeFunction
23
23
 
@@ -95,7 +95,7 @@ class TrianglePolynomialShapeFunctions(TriangleShapeFunction):
95
95
  index = _triangle_node_index(tx, ty, degree)
96
96
  triangle_coords[index] = [tx, ty]
97
97
 
98
- CoordTypeVec = wp.mat(dtype=int, shape=(self.NODES_PER_ELEMENT, 2))
98
+ CoordTypeVec = cache.cached_mat_type(dtype=int, shape=(self.NODES_PER_ELEMENT, 2))
99
99
  self.NODE_TRIANGLE_COORDS = wp.constant(CoordTypeVec(triangle_coords))
100
100
 
101
101
  self.node_type_and_type_index = self._get_node_type_and_type_index()
@@ -14,9 +14,9 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import warp as wp
17
- from warp.fem import cache
18
- from warp.fem.geometry import Tetmesh
19
- from warp.fem.types import ElementIndex
17
+ from warp._src.fem import cache
18
+ from warp._src.fem.geometry import Tetmesh
19
+ from warp._src.fem.types import ElementIndex
20
20
 
21
21
  from .shape import (
22
22
  ShapeFunction,
@@ -72,12 +72,6 @@ class TetmeshSpaceTopology(SpaceTopology):
72
72
  def name(self):
73
73
  return f"{self.geometry.name}_{self._shape.name}"
74
74
 
75
- @cache.cached_arg_value
76
- def topo_arg_value(self, device):
77
- arg = TetmeshTopologyArg()
78
- self.fill_topo_arg(arg, device)
79
- return arg
80
-
81
75
  def fill_topo_arg(self, arg: TetmeshTopologyArg, device):
82
76
  arg.tet_face_indices = self._tet_face_indices.to(device)
83
77
  arg.tet_edge_indices = self._tet_edge_indices.to(device)