warp-lang 1.9.0__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the registry's advisory for this release for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2302 -307
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1546 -224
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +3 -3
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +581 -280
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +18 -17
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +580 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
@@ -13,14 +13,18 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from typing import Optional
17
+
16
18
  import warp as wp
17
- from warp.fem.cache import TemporaryStore, borrow_temporary, borrow_temporary_like, cached_arg_value
18
- from warp.fem.domain import GeometryDomain
19
- from warp.fem.types import NULL_NODE_INDEX, NodeElementIndex
20
- from warp.fem.utils import compress_node_indices
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.domain import GeometryDomain
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, NULL_NODE_INDEX, NodeElementIndex
22
+ from warp._src.fem.utils import compress_node_indices, host_read_at_index
21
23
 
22
24
  from .partition import SpacePartition
23
25
 
26
+ _wp_module_name_ = "warp.fem.space.restriction"
27
+
24
28
  wp.set_module_options({"enable_backward": False})
25
29
 
26
30
 
@@ -32,7 +36,7 @@ class SpaceRestriction:
32
36
  space_partition: SpacePartition,
33
37
  domain: GeometryDomain,
34
38
  device=None,
35
- temporary_store: TemporaryStore = None,
39
+ temporary_store: cache.TemporaryStore = None,
36
40
  ):
37
41
  space_topology = space_partition.space_topology
38
42
 
@@ -46,15 +50,25 @@ class SpaceRestriction:
46
50
  self.space_topology = space_topology
47
51
  self.domain = domain
48
52
 
49
- self._compute_node_element_indices(device=device, temporary_store=temporary_store)
53
+ self._node_count_dev: wp.array = None
54
+ """Number of unique partition node indices"""
55
+ self._dof_partition_indices: wp.array = None
56
+ """Array of unique partition node indices"""
50
57
 
51
- def _compute_node_element_indices(self, device, temporary_store: TemporaryStore):
52
- from warp.fem import cache
58
+ self._dof_partition_element_offsets: wp.array = None
59
+ """Mapping from partition node to offset in the per-node element indices array"""
60
+ self._dof_element_indices: wp.array = None
61
+ """Concatenation of neighboring elements indices for each partition node"""
62
+ self._dof_indices_in_element: wp.array = None
63
+ """Concatenation of node index in element for each partition node"""
53
64
 
54
- MAX_NODES_PER_ELEMENT = self.space_topology.MAX_NODES_PER_ELEMENT
65
+ self.rebuild(device=device, temporary_store=temporary_store)
66
+
67
+ def rebuild(self, device: Optional = None, temporary_store: Optional[cache.TemporaryStore] = None):
68
+ max_nodes_per_element = self.space_topology.MAX_NODES_PER_ELEMENT
55
69
 
56
70
  @cache.dynamic_kernel(
57
- suffix=f"{self.domain.name}_{self.space_topology.name}_{self.space_partition.name}",
71
+ suffix=(self.domain.name, self.space_topology.name, self.space_partition.name),
58
72
  kernel_options={"max_unroll": 8},
59
73
  )
60
74
  def fill_element_node_indices(
@@ -66,75 +80,96 @@ class SpaceRestriction:
66
80
  ):
67
81
  domain_element_index = wp.tid()
68
82
  element_index = self.domain.element_index(domain_index_arg, domain_element_index)
69
- element_node_count = self.space_topology.element_node_count(element_arg, topo_arg, element_index)
83
+
84
+ if element_index == NULL_ELEMENT_INDEX:
85
+ element_node_count = 0
86
+ else:
87
+ element_node_count = self.space_topology.element_node_count(element_arg, topo_arg, element_index)
88
+
70
89
  for n in range(element_node_count):
71
90
  space_nidx = self.space_topology.element_node_index(element_arg, topo_arg, element_index, n)
72
91
  partition_nidx = self.space_partition.partition_node_index(partition_arg, space_nidx)
73
92
  element_node_indices[domain_element_index, n] = partition_nidx
74
- for n in range(element_node_count, MAX_NODES_PER_ELEMENT):
93
+ for n in range(element_node_count, element_node_indices.shape[1]):
75
94
  element_node_indices[domain_element_index, n] = NULL_NODE_INDEX
76
95
 
77
- element_node_indices = borrow_temporary(
96
+ element_node_indices = cache.borrow_temporary(
78
97
  temporary_store,
79
- shape=(self.domain.element_count(), MAX_NODES_PER_ELEMENT),
98
+ shape=(self.domain.element_count(), max_nodes_per_element),
80
99
  dtype=int,
81
100
  device=device,
82
101
  )
83
102
  wp.launch(
84
- dim=element_node_indices.array.shape[0],
103
+ dim=element_node_indices.shape[0],
85
104
  kernel=fill_element_node_indices,
86
105
  inputs=[
87
106
  self.domain.element_arg_value(device),
88
107
  self.domain.element_index_arg_value(device),
89
108
  self.space_topology.topo_arg_value(device),
90
109
  self.space_partition.partition_arg_value(device),
91
- element_node_indices.array,
110
+ element_node_indices,
92
111
  ],
93
112
  device=device,
94
113
  )
95
114
 
96
115
  # Build compressed map from node to element indices
97
- flattened_node_indices = element_node_indices.array.flatten()
116
+ flattened_node_indices = element_node_indices.flatten()
98
117
  (
99
118
  self._dof_partition_element_offsets,
100
119
  node_array_indices,
101
- self._node_count,
120
+ self._node_count_dev,
102
121
  self._dof_partition_indices,
103
122
  ) = compress_node_indices(
104
123
  self.space_partition.node_count(),
105
124
  flattened_node_indices,
125
+ node_offsets=self._dof_partition_element_offsets,
126
+ unique_node_count=self._node_count_dev,
127
+ unique_node_indices=self._dof_partition_indices,
106
128
  return_unique_nodes=True,
107
129
  temporary_store=temporary_store,
108
130
  )
109
131
 
110
132
  # Extract element index and index in element
111
- self._dof_element_indices = borrow_temporary_like(flattened_node_indices, temporary_store)
112
- self._dof_indices_in_element = borrow_temporary_like(flattened_node_indices, temporary_store)
133
+ if self._dof_element_indices is None or self._dof_element_indices.shape != flattened_node_indices.shape:
134
+ self._dof_element_indices = cache.borrow_temporary_like(flattened_node_indices, temporary_store)
135
+ self._dof_indices_in_element = cache.borrow_temporary_like(flattened_node_indices, temporary_store)
136
+
113
137
  wp.launch(
114
138
  kernel=SpaceRestriction._split_vertex_element_index,
115
139
  dim=flattened_node_indices.shape,
116
140
  inputs=[
117
- MAX_NODES_PER_ELEMENT,
118
- node_array_indices.array,
119
- self._dof_element_indices.array,
120
- self._dof_indices_in_element.array,
141
+ max_nodes_per_element,
142
+ node_array_indices,
143
+ self._dof_element_indices,
144
+ self._dof_indices_in_element,
121
145
  ],
122
146
  device=flattened_node_indices.device,
123
147
  )
124
148
 
125
149
  node_array_indices.release()
126
150
 
127
- def node_count(self):
151
+ # Upper bound on node count, use `node_count_sync` to get the actual value
152
+ self._node_count = min(self.space_partition.node_count(), self._dof_partition_indices.shape[0])
153
+
154
+ def node_count_sync(self) -> int:
155
+ """Ensures that the node count is synchronized with the device and returns it"""
156
+ if self._node_count_dev is not None:
157
+ self._node_count = int(host_read_at_index(self._node_count_dev, index=0))
158
+ self._node_count_dev = None
159
+ return self.node_count()
160
+
161
+ def node_count(self) -> int:
162
+ """Upper bound for the node count, use `node_count_sync` to get the actual value"""
128
163
  return self._node_count
129
164
 
130
165
  def partition_element_offsets(self):
131
- return self._dof_partition_element_offsets.array
166
+ return self._dof_partition_element_offsets
132
167
 
133
168
  def node_partition_indices(self):
134
- return self._dof_partition_indices.array
169
+ return self._dof_partition_indices
135
170
 
136
171
  def total_node_element_count(self):
137
- return self._dof_element_indices.array.size
172
+ return self._dof_element_indices.size
138
173
 
139
174
  @wp.struct
140
175
  class NodeArg:
@@ -143,17 +178,17 @@ class SpaceRestriction:
143
178
  dof_partition_indices: wp.array(dtype=int)
144
179
  dof_indices_in_element: wp.array(dtype=int)
145
180
 
146
- @cached_arg_value
181
+ @cache.cached_arg_value
147
182
  def node_arg_value(self, device):
148
183
  arg = SpaceRestriction.NodeArg()
149
184
  self.fill_node_arg(arg, device)
150
185
  return arg
151
186
 
152
187
  def fill_node_arg(self, arg: NodeArg, device):
153
- arg.dof_element_offsets = self._dof_partition_element_offsets.array.to(device)
154
- arg.dof_element_indices = self._dof_element_indices.array.to(device)
155
- arg.dof_partition_indices = self._dof_partition_indices.array.to(device)
156
- arg.dof_indices_in_element = self._dof_indices_in_element.array.to(device)
188
+ arg.dof_element_offsets = self._dof_partition_element_offsets.to(device)
189
+ arg.dof_element_indices = self._dof_element_indices.to(device)
190
+ arg.dof_partition_indices = self._dof_partition_indices.to(device)
191
+ arg.dof_indices_in_element = self._dof_indices_in_element.to(device)
157
192
 
158
193
  @wp.func
159
194
  def node_partition_index(args: NodeArg, restriction_node_index: int):
@@ -0,0 +1,152 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ from enum import Enum
18
+ from typing import Optional
19
+
20
+ from warp._src.fem.geometry import Element
21
+ from warp._src.fem.polynomial import Polynomial
22
+
23
+ from .cube_shape_function import (
24
+ CubeNedelecFirstKindShapeFunctions,
25
+ CubeNonConformingPolynomialShapeFunctions,
26
+ CubeRaviartThomasShapeFunctions,
27
+ CubeSerendipityShapeFunctions,
28
+ CubeShapeFunction,
29
+ CubeTripolynomialShapeFunctions,
30
+ )
31
+ from .shape_function import ConstantShapeFunction, ShapeFunction
32
+ from .square_shape_function import (
33
+ SquareBipolynomialShapeFunctions,
34
+ SquareNedelecFirstKindShapeFunctions,
35
+ SquareNonConformingPolynomialShapeFunctions,
36
+ SquareRaviartThomasShapeFunctions,
37
+ SquareSerendipityShapeFunctions,
38
+ SquareShapeFunction,
39
+ )
40
+ from .tet_shape_function import (
41
+ TetrahedronNedelecFirstKindShapeFunctions,
42
+ TetrahedronNonConformingPolynomialShapeFunctions,
43
+ TetrahedronPolynomialShapeFunctions,
44
+ TetrahedronRaviartThomasShapeFunctions,
45
+ TetrahedronShapeFunction,
46
+ )
47
+ from .triangle_shape_function import (
48
+ TriangleNedelecFirstKindShapeFunctions,
49
+ TriangleNonConformingPolynomialShapeFunctions,
50
+ TrianglePolynomialShapeFunctions,
51
+ TriangleRaviartThomasShapeFunctions,
52
+ TriangleShapeFunction,
53
+ )
54
+
55
+
56
+ class ElementBasis(Enum):
57
+ """Choice of basis function to equip individual elements"""
58
+
59
+ LAGRANGE = "P"
60
+ """Lagrange basis functions :math:`P_k` for simplices, tensor products :math:`Q_k` for squares and cubes"""
61
+ SERENDIPITY = "S"
62
+ """Serendipity elements :math:`S_k`, corresponding to Lagrange nodes with interior points removed (for degree <= 3)"""
63
+ NONCONFORMING_POLYNOMIAL = "dP"
64
+ """Simplex Lagrange basis functions :math:`P_{kd}` embedded into non conforming reference elements (e.g. squares or cubes). Discontinuous only."""
65
+ NEDELEC_FIRST_KIND = "N1"
66
+ """Nédélec (first kind) H(curl) shape functions. Should be used with covariant function space."""
67
+ RAVIART_THOMAS = "RT"
68
+ """Raviart-Thomas H(div) shape functions. Should be used with contravariant function space."""
69
+
70
+
71
+ @functools.lru_cache(maxsize=None)
72
+ def make_element_shape_function(
73
+ element: Element,
74
+ degree: int,
75
+ element_basis: Optional[ElementBasis] = None,
76
+ family: Optional[Polynomial] = None,
77
+ ) -> ShapeFunction:
78
+ """
79
+ Equips a reference element with a shape function basis.
80
+
81
+ Args:
82
+ element: the type of reference element on which to build the shape function
83
+ degree: polynomial degree of the per-element shape functions
84
+ element_basis: type of basis function for the individual elements
85
+ family: Polynomial family used to generate the shape function basis. If not provided, a reasonable basis is chosen.
86
+
87
+ Returns:
88
+ the corresponding shape function
89
+
90
+ Raises:
91
+ NotImplementedError: If the shape function is not implemented for the given element type
92
+ """
93
+
94
+ if element_basis is None:
95
+ element_basis = ElementBasis.LAGRANGE
96
+ elif element_basis == ElementBasis.SERENDIPITY and degree == 1:
97
+ # Degree-1 serendipity is always equivalent to Lagrange
98
+ element_basis = ElementBasis.LAGRANGE
99
+
100
+ if degree == 0:
101
+ return ConstantShapeFunction(element)
102
+
103
+ if family is None:
104
+ family = Polynomial.LOBATTO_GAUSS_LEGENDRE
105
+
106
+ if element == Element.SQUARE:
107
+ if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
108
+ return SquareNedelecFirstKindShapeFunctions(degree=degree)
109
+ if element_basis == ElementBasis.RAVIART_THOMAS:
110
+ return SquareRaviartThomasShapeFunctions(degree=degree)
111
+ if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
112
+ return SquareNonConformingPolynomialShapeFunctions(degree=degree)
113
+ if element_basis == ElementBasis.SERENDIPITY and degree > 1:
114
+ return SquareSerendipityShapeFunctions(degree=degree, family=family)
115
+
116
+ return SquareBipolynomialShapeFunctions(degree=degree, family=family)
117
+ if element == Element.TRIANGLE:
118
+ if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
119
+ return TriangleNedelecFirstKindShapeFunctions(degree=degree)
120
+ if element_basis == ElementBasis.RAVIART_THOMAS:
121
+ return TriangleRaviartThomasShapeFunctions(degree=degree)
122
+ if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
123
+ return TriangleNonConformingPolynomialShapeFunctions(degree=degree)
124
+ if element_basis == ElementBasis.SERENDIPITY and degree > 2:
125
+ raise NotImplementedError("Serendipity variant not implemented yet for Triangle elements")
126
+
127
+ return TrianglePolynomialShapeFunctions(degree=degree)
128
+
129
+ if element == Element.CUBE:
130
+ if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
131
+ return CubeNedelecFirstKindShapeFunctions(degree=degree)
132
+ if element_basis == ElementBasis.RAVIART_THOMAS:
133
+ return CubeRaviartThomasShapeFunctions(degree=degree)
134
+ if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
135
+ return CubeNonConformingPolynomialShapeFunctions(degree=degree)
136
+ if element_basis == ElementBasis.SERENDIPITY and degree > 1:
137
+ return CubeSerendipityShapeFunctions(degree=degree, family=family)
138
+
139
+ return CubeTripolynomialShapeFunctions(degree=degree, family=family)
140
+ if element == Element.TETRAHEDRON:
141
+ if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
142
+ return TetrahedronNedelecFirstKindShapeFunctions(degree=degree)
143
+ if element_basis == ElementBasis.RAVIART_THOMAS:
144
+ return TetrahedronRaviartThomasShapeFunctions(degree=degree)
145
+ if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
146
+ return TetrahedronNonConformingPolynomialShapeFunctions(degree=degree)
147
+ if element_basis == ElementBasis.SERENDIPITY and degree > 2:
148
+ raise NotImplementedError("Serendipity variant not implemented yet for Tet elements")
149
+
150
+ return TetrahedronPolynomialShapeFunctions(degree=degree)
151
+
152
+ raise NotImplementedError(f"Unrecognized element type {element}")
@@ -18,14 +18,16 @@ import math
18
18
  import numpy as np
19
19
 
20
20
  import warp as wp
21
- from warp.fem import cache
22
- from warp.fem.geometry import Grid3D
23
- from warp.fem.polynomial import Polynomial, is_closed, lagrange_scales, quadrature_1d
24
- from warp.fem.types import Coords
21
+ from warp._src.fem import cache
22
+ from warp._src.fem.geometry import Grid3D
23
+ from warp._src.fem.polynomial import Polynomial, is_closed, lagrange_scales, quadrature_1d
24
+ from warp._src.fem.types import Coords
25
25
 
26
26
  from .shape_function import ShapeFunction
27
27
  from .tet_shape_function import TetrahedronPolynomialShapeFunctions
28
28
 
29
+ _wp_module_name_ = "warp.fem.space.shape.cube_shape_function"
30
+
29
31
 
30
32
  class CubeShapeFunction(ShapeFunction):
31
33
  VERTEX = 0
@@ -79,7 +81,7 @@ class CubeTripolynomialShapeFunctions(CubeShapeFunction):
79
81
  lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
80
82
  lagrange_scale = lagrange_scales(lobatto_coords)
81
83
 
82
- NodeVec = wp.types.vector(length=degree + 1, dtype=wp.float32)
84
+ NodeVec = cache.cached_vec_type(length=degree + 1, dtype=wp.float32)
83
85
  self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
84
86
  self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
85
87
  self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
@@ -395,12 +397,12 @@ class CubeTripolynomialShapeFunctions(CubeShapeFunction):
395
397
  return cache.get_func(element_inner_weight_gradient, self.name)
396
398
 
397
399
  def element_node_hexes(self):
398
- from warp.fem.utils import grid_to_hexes
400
+ from warp._src.fem.utils import grid_to_hexes
399
401
 
400
402
  return grid_to_hexes(self.ORDER, self.ORDER, self.ORDER)
401
403
 
402
404
  def element_node_tets(self):
403
- from warp.fem.utils import grid_to_tets
405
+ from warp._src.fem.utils import grid_to_tets
404
406
 
405
407
  return grid_to_tets(self.ORDER, self.ORDER, self.ORDER)
406
408
 
@@ -514,7 +516,7 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
514
516
  lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
515
517
  lagrange_scale = lagrange_scales(lobatto_coords)
516
518
 
517
- NodeVec = wp.types.vector(length=degree + 1, dtype=wp.float32)
519
+ NodeVec = cache.cached_vec_type(length=degree + 1, dtype=wp.float32)
518
520
  self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
519
521
  self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
520
522
  self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
@@ -751,7 +753,7 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
751
753
  return element_inner_weight_gradient
752
754
 
753
755
  def element_node_tets(self):
754
- from warp.fem.utils import grid_to_tets
756
+ from warp._src.fem.utils import grid_to_tets
755
757
 
756
758
  if self.ORDER == 2:
757
759
  element_tets = np.array(
@@ -18,9 +18,11 @@ from enum import Enum
18
18
  import numpy as np
19
19
 
20
20
  import warp as wp
21
- from warp.fem import cache
22
- from warp.fem.geometry import Element
23
- from warp.fem.types import Coords
21
+ from warp._src.fem import cache
22
+ from warp._src.fem.geometry import Element
23
+ from warp._src.fem.types import Coords
24
+
25
+ _wp_module_name_ = "warp.fem.space.shape.shape_function"
24
26
 
25
27
 
26
28
  class ShapeFunction:
@@ -68,19 +70,18 @@ class ShapeFunction:
68
70
  class ConstantShapeFunction(ShapeFunction):
69
71
  """Shape function that is constant over the element"""
70
72
 
71
- def __init__(self, element: Element, space_dimension: int):
72
- self._element = element
73
- self._dimension = space_dimension
73
+ def __init__(self, element: Element):
74
+ self._element_prototype = element.prototype
74
75
 
75
76
  self.ORDER = wp.constant(0)
76
77
  self.NODES_PER_ELEMENT = wp.constant(1)
77
78
 
78
- coords, _ = element.instantiate_quadrature(order=0, family=None)
79
+ coords, _ = self._element_prototype.instantiate_quadrature(order=0, family=None)
79
80
  self.COORDS = wp.constant(coords[0])
80
81
 
81
82
  @property
82
83
  def name(self) -> str:
83
- return f"{self._element.__class__.__name__}{self._dimension}"
84
+ return f"{self._element_prototype.__name__}"
84
85
 
85
86
  def make_node_coords_in_element(self):
86
87
  COORDS = self.COORDS
@@ -116,7 +117,7 @@ class ConstantShapeFunction(ShapeFunction):
116
117
  return ConstantShapeFunction._element_inner_weight
117
118
 
118
119
  def make_element_inner_weight_gradient(self):
119
- grad_type = wp.vec(length=self._dimension, dtype=float)
120
+ grad_type = cache.cached_vec_type(length=self._element_prototype.dimension, dtype=float)
120
121
 
121
122
  @cache.dynamic_func(suffix=self.name)
122
123
  def element_inner_weight_gradient(
@@ -18,13 +18,15 @@ import math
18
18
  import numpy as np
19
19
 
20
20
  import warp as wp
21
- from warp.fem import cache
22
- from warp.fem.polynomial import Polynomial, is_closed, lagrange_scales, quadrature_1d
23
- from warp.fem.types import Coords
21
+ from warp._src.fem import cache
22
+ from warp._src.fem.polynomial import Polynomial, is_closed, lagrange_scales, quadrature_1d
23
+ from warp._src.fem.types import Coords
24
24
 
25
25
  from .shape_function import ShapeFunction
26
26
  from .triangle_shape_function import TrianglePolynomialShapeFunctions
27
27
 
28
+ _wp_module_name_ = "warp.fem.space.shape.square_shape_function"
29
+
28
30
 
29
31
  class SquareShapeFunction(ShapeFunction):
30
32
  VERTEX = 0
@@ -68,7 +70,7 @@ class SquareBipolynomialShapeFunctions(SquareShapeFunction):
68
70
  lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
69
71
  lagrange_scale = lagrange_scales(lobatto_coords)
70
72
 
71
- NodeVec = wp.types.vector(length=degree + 1, dtype=wp.float32)
73
+ NodeVec = cache.cached_vec_type(length=degree + 1, dtype=wp.float32)
72
74
  self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
73
75
  self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
74
76
  self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
@@ -291,7 +293,7 @@ class SquareBipolynomialShapeFunctions(SquareShapeFunction):
291
293
  return cache.get_func(element_inner_weight_gradient, self.name)
292
294
 
293
295
  def element_node_triangulation(self):
294
- from warp.fem.utils import grid_to_tris
296
+ from warp._src.fem.utils import grid_to_tris
295
297
 
296
298
  return grid_to_tris(self.ORDER, self.ORDER)
297
299
 
@@ -348,7 +350,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
348
350
  lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
349
351
  lagrange_scale = lagrange_scales(lobatto_coords)
350
352
 
351
- NodeVec = wp.types.vector(length=degree + 1, dtype=wp.float32)
353
+ NodeVec = cache.cached_vec_type(length=degree + 1, dtype=wp.float32)
352
354
  self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
353
355
  self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
354
356
  self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
@@ -16,11 +16,13 @@
16
16
  import numpy as np
17
17
 
18
18
  import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.types import Coords
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.types import Coords
21
21
 
22
22
  from .shape_function import ShapeFunction
23
23
 
24
+ _wp_module_name_ = "warp.fem.space.shape.tet_shape_function"
25
+
24
26
 
25
27
  def _tet_node_index(tx: int, ty: int, tz: int, degree: int):
26
28
  from .triangle_shape_function import _triangle_node_index
@@ -182,7 +184,7 @@ class TetrahedronPolynomialShapeFunctions(TetrahedronShapeFunction):
182
184
  index = _tet_node_index(tx, ty, tz, degree)
183
185
  tet_coords[index] = [tx, ty, tz]
184
186
 
185
- CoordTypeVec = wp.mat(dtype=int, shape=(self.NODES_PER_ELEMENT, 3))
187
+ CoordTypeVec = cache.cached_mat_type(dtype=int, shape=(self.NODES_PER_ELEMENT, 3))
186
188
  self.NODE_TET_COORDS = wp.constant(CoordTypeVec(tet_coords))
187
189
 
188
190
  self.node_type_and_type_index = self._get_node_type_and_type_index()
@@ -16,11 +16,13 @@
16
16
  import numpy as np
17
17
 
18
18
  import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.types import Coords
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.types import Coords
21
21
 
22
22
  from .shape_function import ShapeFunction
23
23
 
24
+ _wp_module_name_ = "warp.fem.space.shape.triangle_shape_function"
25
+
24
26
 
25
27
  def _triangle_node_index(tx: int, ty: int, degree: int):
26
28
  VERTEX_NODE_COUNT = 3
@@ -95,7 +97,7 @@ class TrianglePolynomialShapeFunctions(TriangleShapeFunction):
95
97
  index = _triangle_node_index(tx, ty, degree)
96
98
  triangle_coords[index] = [tx, ty]
97
99
 
98
- CoordTypeVec = wp.mat(dtype=int, shape=(self.NODES_PER_ELEMENT, 2))
100
+ CoordTypeVec = cache.cached_mat_type(dtype=int, shape=(self.NODES_PER_ELEMENT, 2))
99
101
  self.NODE_TRIANGLE_COORDS = wp.constant(CoordTypeVec(triangle_coords))
100
102
 
101
103
  self.node_type_and_type_index = self._get_node_type_and_type_index()
@@ -14,9 +14,9 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import warp as wp
17
- from warp.fem import cache
18
- from warp.fem.geometry import Tetmesh
19
- from warp.fem.types import ElementIndex
17
+ from warp._src.fem import cache
18
+ from warp._src.fem.geometry import Tetmesh
19
+ from warp._src.fem.types import ElementIndex
20
20
 
21
21
  from .shape import (
22
22
  ShapeFunction,
@@ -25,6 +25,8 @@ from .shape import (
25
25
  )
26
26
  from .topology import SpaceTopology, forward_base_topology
27
27
 
28
+ _wp_module_name_ = "warp.fem.space.tetmesh_function_space"
29
+
28
30
 
29
31
  @wp.struct
30
32
  class TetmeshTopologyArg:
@@ -72,12 +74,6 @@ class TetmeshSpaceTopology(SpaceTopology):
72
74
  def name(self):
73
75
  return f"{self.geometry.name}_{self._shape.name}"
74
76
 
75
- @cache.cached_arg_value
76
- def topo_arg_value(self, device):
77
- arg = TetmeshTopologyArg()
78
- self.fill_topo_arg(arg, device)
79
- return arg
80
-
81
77
  def fill_topo_arg(self, arg: TetmeshTopologyArg, device):
82
78
  arg.tet_face_indices = self._tet_face_indices.to(device)
83
79
  arg.tet_edge_indices = self._tet_edge_indices.to(device)