warp-lang 1.9.0__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2302 -307
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1546 -224
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +3 -3
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +581 -280
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +18 -17
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +580 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,681 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import cached_property
17
+ from typing import ClassVar, Optional
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp._src.fem import cache
23
+ from warp._src.fem.geometry import Geometry
24
+ from warp._src.fem.quadrature import Quadrature
25
+ from warp._src.fem.types import (
26
+ NULL_ELEMENT_INDEX,
27
+ NULL_QP_INDEX,
28
+ Coords,
29
+ ElementIndex,
30
+ QuadraturePointIndex,
31
+ make_free_sample,
32
+ )
33
+ from warp._src.types import type_repr, types_equal
34
+
35
+ from .shape import ShapeFunction
36
+ from .topology import SpaceTopology
37
+
38
+ _wp_module_name_ = "warp.fem.space.basis_space"
39
+
40
+
41
class BasisSpace:
    """Interface class for defining a shape function space over a geometry.

    A basis space makes it easy to define multiple function spaces sharing the same basis (and thus nodes) but with different valuation functions;
    however, it is not a required component of a function space.

    Subclasses provide the :attr:`value` type, the ``make_*`` device-function factories and a ``name``
    attribute, which is used as a suffix for the kernels generated by this class.

    See also: :func:`make_polynomial_basis_space`, :func:`make_collocated_function_space`
    """

    @wp.struct
    class BasisArg:
        """Argument structure to be passed to device functions"""

        pass

    def __init__(self, topology: SpaceTopology):
        self._topology = topology

    @property
    def topology(self) -> SpaceTopology:
        """Underlying topology of the basis space"""
        return self._topology

    @property
    def geometry(self) -> Geometry:
        """Underlying geometry of the basis space"""
        return self._topology.geometry

    @property
    def value(self) -> ShapeFunction.Value:
        """Value type for the underlying shape functions"""
        raise NotImplementedError()

    @cache.cached_arg_value
    def basis_arg_value(self, device) -> "BasisArg":
        """Value for the argument structure to be passed to device functions"""
        arg = self.BasisArg()
        self.fill_basis_arg(arg, device)
        return arg

    def fill_basis_arg(self, arg, device):
        """Populates ``arg`` for ``device``. The base :class:`BasisArg` has no fields, so this is a no-op."""
        pass

    # Helpers for generating node positions

    def node_positions(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns a temporary array containing the world position for each node

        Args:
            out: Optional destination array. It must have shape ``(topology.node_count(),)`` and a
              float vector dtype of length ``geometry.dimension``. When provided, the fill kernel
              is launched on ``out``'s device.

        Returns:
            The array of node positions (``out`` when it was provided).

        Raises:
            ValueError: If ``out`` has an incompatible shape or data type.
        """

        pos_type = cache.cached_vec_type(length=self.geometry.dimension, dtype=float)

        node_coords_in_element = self.make_node_coords_in_element()

        @cache.dynamic_kernel(suffix=self.name, kernel_options={"max_unroll": 4, "enable_backward": False})
        def fill_node_positions(
            geo_cell_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            topo_arg: self.topology.TopologyArg,
            node_positions: wp.array(dtype=pos_type),
        ):
            element_index = wp.tid()

            element_node_count = self.topology.element_node_count(geo_cell_arg, topo_arg, element_index)
            for n in range(element_node_count):
                node_index = self.topology.element_node_index(geo_cell_arg, topo_arg, element_index, n)
                coords = node_coords_in_element(geo_cell_arg, basis_arg, element_index, n)

                sample = make_free_sample(element_index, coords)
                pos = self.geometry.cell_position(geo_cell_arg, sample)

                node_positions[node_index] = pos

        shape = (self.topology.node_count(),)
        if out is None:
            node_positions = wp.empty(
                shape=shape,
                dtype=pos_type,
            )
        else:
            if out.shape != shape or not types_equal(pos_type, out.dtype):
                raise ValueError(
                    f"Out node positions array must have shape {shape} and data type {type_repr(pos_type)}"
                )
            node_positions = out

        # The argument structs below are built for the output array's device, so the launch must
        # target that device too. Otherwise an `out` array living on a device other than the
        # current default one would be rejected by `wp.launch`.
        device = node_positions.device
        wp.launch(
            dim=self.geometry.cell_count(),
            kernel=fill_node_positions,
            inputs=[
                self.geometry.cell_arg_value(device=device),
                self.basis_arg_value(device=device),
                self.topology.topo_arg_value(device=device),
                node_positions,
            ],
            device=device,
        )

        return node_positions

    def make_node_coords_in_element(self):
        """Returns a device function giving the element-local coordinates of a node. Implemented by subclasses."""
        raise NotImplementedError()

    def make_node_quadrature_weight(self):
        """Returns a device function giving the quadrature weight of a node. Implemented by subclasses."""
        raise NotImplementedError()

    def make_element_inner_weight(self):
        """Returns a device function evaluating basis function values. Implemented by subclasses."""
        raise NotImplementedError()

    def make_element_outer_weight(self):
        """Returns a device function evaluating outer basis values. Defaults to the inner weight."""
        return self.make_element_inner_weight()

    def make_element_inner_weight_gradient(self):
        """Returns a device function evaluating basis function gradients. Implemented by subclasses."""
        raise NotImplementedError()

    def make_element_outer_weight_gradient(self):
        """Returns a device function evaluating outer basis gradients. Defaults to the inner gradient."""
        return self.make_element_inner_weight_gradient()

    def make_trace_node_quadrature_weight(self, trace_basis: Optional["BasisSpace"] = None):
        """Returns a side-based node quadrature weight device function for ``trace_basis``.

        :class:`TraceBasisSpace` calls this with itself as ``trace_basis``. The parameter is
        accepted here so that bases which do not support traces raise ``NotImplementedError``
        instead of a ``TypeError`` about the argument count.
        """
        raise NotImplementedError()

    def trace(self) -> "TraceBasisSpace":
        """Returns a basis space evaluating this basis on the geometry sides"""
        return TraceBasisSpace(self)

    @property
    def weight_type(self):
        """Type of basis function values: ``float`` for scalar shapes, otherwise a ``cell_dimension`` vector"""
        if self.value is ShapeFunction.Value.Scalar:
            return float

        return cache.cached_vec_type(length=self.geometry.cell_dimension, dtype=float)

    @property
    def weight_gradient_type(self):
        """Type of basis function gradients: a ``cell_dimension`` vector for scalar shapes, otherwise a square matrix"""
        if self.value is ShapeFunction.Value.Scalar:
            # Use the cached vector type (as `weight_type` does) rather than building a new type per access
            return cache.cached_vec_type(length=self.geometry.cell_dimension, dtype=float)

        return cache.cached_mat_type(shape=(self.geometry.cell_dimension, self.geometry.cell_dimension), dtype=float)
176
+
177
class ShapeBasisSpace(BasisSpace):
    """Basis space whose element-local basis functions are given by a :class:`ShapeFunction`.

    Node coordinates and quadrature weights are forwarded from the shape function. For non-scalar
    shapes, basis values and gradients are multiplied by the node sign provided by the topology,
    and the topology argument structure is reused as the basis argument structure.
    """

    def __init__(self, topology: SpaceTopology, shape: ShapeFunction):
        super().__init__(topology)
        self._shape = shape

        # Node signs are evaluated from the topology argument, so non-scalar bases
        # share the topology's argument structure and value providers.
        if self.value is not ShapeFunction.Value.Scalar:
            topo = self.topology
            self.BasisArg = topo.TopologyArg
            self.basis_arg_value = topo.topo_arg_value
            self.fill_basis_arg = topo.fill_topo_arg

        self.ORDER = shape.ORDER

        # Expose each optional cell-decomposition helper only if the shape function supports it
        optional_exports = (
            ("element_node_triangulation", "node_triangulation"),
            ("element_node_tets", "node_tets"),
            ("element_node_hexes", "node_hexes"),
            ("element_vtk_cells", "vtk_cells"),
        )
        for shape_method, public_name in optional_exports:
            if hasattr(shape, shape_method):
                setattr(self, public_name, getattr(self, f"_{public_name}"))

        if hasattr(topology, "node_grid"):
            self.node_grid = topology.node_grid

    @property
    def shape(self) -> ShapeFunction:
        """Shape functions used for defining individual element basis"""
        return self._shape

    @property
    def value(self) -> ShapeFunction.Value:
        """Value type of the wrapped shape function"""
        return self.shape.value

    @cached_property
    def name(self):
        """Topology name and shape name joined with an underscore"""
        return "{}_{}".format(self.topology.name, self._shape.name)

    def make_node_coords_in_element(self):
        """Builds a device function returning the element-local coordinates of a node."""
        shape_node_coords = self._shape.make_node_coords_in_element()

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Only the local node index matters; the element and argument structs are unused
            return shape_node_coords(node_index_in_elt)

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Builds a device function returning node quadrature weights, or ``None`` if the shape has none."""
        shape_node_weight = self._shape.make_node_quadrature_weight()
        if shape_node_weight is None:
            return None

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return shape_node_weight(node_index_in_elt)

        return node_quadrature_weight

    def make_element_inner_weight(self):
        """Builds a device function evaluating a basis function at element coordinates."""
        shape_weight = self._shape.make_element_inner_weight()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            # The static branch keeps the sign lookup out of the generated code for scalar bases
            if wp.static(self.value == ShapeFunction.Value.Scalar):
                return shape_weight(coords, node_index_in_elt)
            else:
                sign = self.topology.element_node_sign(elt_arg, basis_arg, element_index, node_index_in_elt)
                return sign * shape_weight(coords, node_index_in_elt)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Builds a device function evaluating a basis function gradient at element coordinates."""
        shape_weight_gradient = self._shape.make_element_inner_weight_gradient()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            # Same static dispatch as for the values: only non-scalar bases apply the node sign
            if wp.static(self.value == ShapeFunction.Value.Scalar):
                return shape_weight_gradient(coords, node_index_in_elt)
            else:
                sign = self.topology.element_node_sign(elt_arg, basis_arg, element_index, node_index_in_elt)
                return sign * shape_weight_gradient(coords, node_index_in_elt)

        return element_inner_weight_gradient

    def make_trace_node_quadrature_weight(self, trace_basis):
        """Builds a side-based node quadrature weight function for ``trace_basis``, or ``None``."""
        shape_trace_weight = self._shape.make_trace_node_quadrature_weight()
        if shape_trace_weight is None:
            return None

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(
            geo_side_arg: trace_basis.geometry.SideArg,
            basis_arg: trace_basis.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Only the node's index within the neighboring cell is needed by the shape function
            _neighbour_elem, index_in_neighbour = trace_basis.topology.neighbor_cell_index(
                geo_side_arg, element_index, node_index_in_elt
            )
            return shape_trace_weight(index_in_neighbour)

        return trace_node_quadrature_weight

    def trace(self) -> "TraceBasisSpace":
        """Returns the trace of this basis; piecewise-constant (order 0) bases use a dedicated trace."""
        return PiecewiseConstantBasisSpaceTrace(self) if self.ORDER == 0 else TraceBasisSpace(self)

    @staticmethod
    def _gather_sub_cells(element_node_indices, local_cells, nodes_per_cell):
        """Maps element-local sub-cell node indices to global node indices, one sub-cell per row."""
        return element_node_indices[:, local_cells].reshape(-1, nodes_per_cell)

    def _node_triangulation(self):
        element_node_indices = self._topology.element_node_indices().numpy()
        return self._gather_sub_cells(element_node_indices, self._shape.element_node_triangulation(), 3)

    def _node_tets(self):
        element_node_indices = self._topology.element_node_indices().numpy()
        return self._gather_sub_cells(element_node_indices, self._shape.element_node_tets(), 4)

    def _node_hexes(self):
        element_node_indices = self._topology.element_node_indices().numpy()
        return self._gather_sub_cells(element_node_indices, self._shape.element_node_hexes(), 8)

    def _vtk_cells(self):
        element_node_indices = self._topology.element_node_indices().numpy()
        local_cells, local_cell_types = self._shape.element_vtk_cells()

        nodes_per_cell = local_cells.shape[1]
        cell_nodes = self._gather_sub_cells(element_node_indices, local_cells, nodes_per_cell)

        # Each connectivity row is prefixed with the number of nodes of the cell
        node_counts = np.full((cell_nodes.shape[0], 1), nodes_per_cell)
        connectivity = np.hstack((node_counts, cell_nodes))

        element_count = element_node_indices.shape[0]
        return connectivity.flatten(), np.tile(local_cell_types, element_count)
342
+
343
+
344
+ class TraceBasisSpace(BasisSpace):
345
+ """Auto-generated trace space evaluating the cell-defined basis on the geometry sides"""
346
+
347
+ def __init__(self, basis: BasisSpace):
348
+ super().__init__(basis.topology.trace())
349
+
350
+ self.ORDER = basis.ORDER
351
+
352
+ self._basis = basis
353
+ self.BasisArg = self._basis.BasisArg
354
+ self.basis_arg_value = self._basis.basis_arg_value
355
+ self.fill_basis_arg = self._basis.fill_basis_arg
356
+
357
+ @property
358
+ def name(self):
359
+ return f"{self._basis.name}_Trace"
360
+
361
+ @property
362
+ def value(self) -> ShapeFunction.Value:
363
+ return self._basis.value
364
+
365
+ def make_node_coords_in_element(self):
366
+ node_coords_in_cell = self._basis.make_node_coords_in_element()
367
+
368
+ @cache.dynamic_func(suffix=self._basis.name)
369
+ def trace_node_coords_in_element(
370
+ geo_side_arg: self.geometry.SideArg,
371
+ basis_arg: self.BasisArg,
372
+ element_index: ElementIndex,
373
+ node_index_in_elt: int,
374
+ ):
375
+ neighbour_elem, index_in_neighbour = self.topology.neighbor_cell_index(
376
+ geo_side_arg, element_index, node_index_in_elt
377
+ )
378
+ geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
379
+ neighbour_coords = node_coords_in_cell(
380
+ geo_cell_arg,
381
+ basis_arg,
382
+ neighbour_elem,
383
+ index_in_neighbour,
384
+ )
385
+
386
+ return self.geometry.side_from_cell_coords(geo_side_arg, element_index, neighbour_elem, neighbour_coords)
387
+
388
+ return trace_node_coords_in_element
389
+
390
+ def make_node_quadrature_weight(self):
391
+ return self._basis.make_trace_node_quadrature_weight(self)
392
+
393
+ def make_element_inner_weight(self):
394
+ cell_inner_weight = self._basis.make_element_inner_weight()
395
+
396
+ @cache.dynamic_func(suffix=self._basis.name)
397
+ def trace_element_inner_weight(
398
+ geo_side_arg: self.geometry.SideArg,
399
+ basis_arg: self.BasisArg,
400
+ element_index: ElementIndex,
401
+ coords: Coords,
402
+ node_index_in_elt: int,
403
+ qp_index: QuadraturePointIndex,
404
+ ):
405
+ cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
406
+ if cell_index == NULL_ELEMENT_INDEX:
407
+ return self.weight_type(0.0)
408
+
409
+ cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
410
+
411
+ geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
412
+ return cell_inner_weight(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX)
413
+
414
+ return trace_element_inner_weight
415
+
416
def make_element_outer_weight(self):
    """Build the device function evaluating a trace node's weight as seen from the side's outer cell.

    Side coordinates are converted to outer-cell coordinates, and the underlying cell
    basis weight is evaluated there. A node with no outer cell gets a zero weight.
    """
    outer_cell_weight = self._basis.make_element_outer_weight()

    @cache.dynamic_func(suffix=self._basis.name)
    def trace_element_outer_weight(
        geo_side_arg: self.geometry.SideArg,
        basis_arg: self.BasisArg,
        element_index: ElementIndex,
        coords: Coords,
        node_index_in_elt: int,
        qp_index: QuadraturePointIndex,
    ):
        outer_cell, node_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
        if outer_cell == NULL_ELEMENT_INDEX:
            # No outer cell for this node: it contributes nothing
            return self.weight_type(0.0)

        coords_in_cell = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)
        cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)

        # The side's qp_index is not forwarded; the cell basis receives NULL_QP_INDEX instead
        return outer_cell_weight(cell_arg, basis_arg, outer_cell, coords_in_cell, node_in_cell, NULL_QP_INDEX)

    return trace_element_outer_weight
438
+
439
def make_element_inner_weight_gradient(self):
    """Build the device function evaluating a trace node's weight gradient from the side's inner cell.

    Side coordinates are converted to inner-cell coordinates, and the underlying cell
    basis gradient is evaluated there. A node with no inner cell gets a zero gradient.
    """
    inner_cell_gradient = self._basis.make_element_inner_weight_gradient()

    @cache.dynamic_func(suffix=self._basis.name)
    def trace_element_inner_weight_gradient(
        geo_side_arg: self.geometry.SideArg,
        basis_arg: self.BasisArg,
        element_index: ElementIndex,
        coords: Coords,
        node_index_in_elt: int,
        qp_index: QuadraturePointIndex,
    ):
        inner_cell, node_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
        if inner_cell == NULL_ELEMENT_INDEX:
            # No inner cell for this node: zero gradient
            return self.weight_gradient_type(0.0)

        coords_in_cell = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
        cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)

        # The side's qp_index is not forwarded; the cell basis receives NULL_QP_INDEX instead
        return inner_cell_gradient(cell_arg, basis_arg, inner_cell, coords_in_cell, node_in_cell, NULL_QP_INDEX)

    return trace_element_inner_weight_gradient
462
+
463
def make_element_outer_weight_gradient(self):
    """Build the device function evaluating a trace node's weight gradient from the side's outer cell.

    Side coordinates are converted to outer-cell coordinates, and the underlying cell
    basis gradient is evaluated there. A node with no outer cell gets a zero gradient.
    """
    outer_cell_gradient = self._basis.make_element_outer_weight_gradient()

    @cache.dynamic_func(suffix=self._basis.name)
    def trace_element_outer_weight_gradient(
        geo_side_arg: self.geometry.SideArg,
        basis_arg: self.BasisArg,
        element_index: ElementIndex,
        coords: Coords,
        node_index_in_elt: int,
        qp_index: QuadraturePointIndex,
    ):
        outer_cell, node_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
        if outer_cell == NULL_ELEMENT_INDEX:
            # No outer cell for this node: zero gradient
            return self.weight_gradient_type(0.0)

        coords_in_cell = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)
        cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)

        # The side's qp_index is not forwarded; the cell basis receives NULL_QP_INDEX instead
        return outer_cell_gradient(cell_arg, basis_arg, outer_cell, coords_in_cell, node_in_cell, NULL_QP_INDEX)

    return trace_element_outer_weight_gradient
486
+
487
def __eq__(self, other: "TraceBasisSpace") -> bool:
    """Two trace spaces compare equal when their ``_topo`` attributes compare equal."""
    own_topo = self._topo
    return own_topo == other._topo
489
+
490
+
491
class PiecewiseConstantBasisSpaceTrace(TraceBasisSpace):
    """Trace of a piecewise-constant basis space."""

    def make_node_coords_in_element(self):
        """Build a device function placing the single node at the center of the reference side.

        This makes the node visible to every side, which is useful when interpolating on
        boundaries. For higher-order non-conforming elements, direct interpolation on the
        boundary is not possible. Those need proper integration followed by a mass-matrix solve.
        """
        reference_side_center = self.geometry.reference_side().prototype.center()
        CENTER_COORDS = Coords(reference_side_center)

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_node_coords_in_element(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Same coordinates regardless of side or node index
            return CENTER_COORDS

        return trace_node_coords_in_element
509
+
510
+
511
class UnstructuredPointTopology(SpaceTopology):
    """Topology for unstructured points defined from a quadrature formula. See :class:`PointBasisSpace`.

    Each quadrature point is one node. The nodes of an element are that element's
    quadrature points.
    """

    _dynamic_attribute_constructors: ClassVar = {
        "element_node_index": lambda obj: obj._make_element_node_index(),
        "element_node_count": lambda obj: obj._make_element_node_count(),
        "side_neighbor_node_counts": lambda obj: obj._make_side_neighbor_node_counts(),
    }

    def __init__(self, quadrature: Quadrature):
        """Create a point topology from ``quadrature``.

        Raises:
            ValueError: if the quadrature does not define a maximum number of points per element,
                or if its domain does not span every element of the geometry.
        """
        if quadrature.max_points_per_element() is None:
            raise ValueError("Quadrature must define a maximum number of points per element")

        if quadrature.domain.element_count() != quadrature.domain.geometry_element_count():
            raise ValueError("Point topology may only be defined on quadrature domains that span the whole geometry")

        self._quadrature = quadrature

        # The topology argument structure is the quadrature's own argument structure.
        # NOTE: this instance attribute shadows the `topo_arg_value` method defined on the class;
        # both forward to the quadrature's arg_value.
        self.TopologyArg = quadrature.Arg
        self.topo_arg_value = quadrature.arg_value
        self.fill_topo_arg = quadrature.fill_arg

        super().__init__(quadrature.domain.geometry, max_nodes_per_element=quadrature.max_points_per_element())

        cache.setup_dynamic_attributes(self, cls=__class__)

    def node_count(self):
        """Total number of nodes, i.e. the total number of quadrature points."""
        return self._quadrature.total_point_count()

    @property
    def name(self):
        return f"PointTopology_{self._quadrature}"

    def topo_arg_value(self, device) -> SpaceTopology.TopologyArg:
        """Value of the topology argument structure to be passed to device functions"""
        return self._quadrature.arg_value(device)

    def _make_element_node_index(self):
        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # element_index is passed twice; presumably as (domain index, geometry index), which
            # coincide because __init__ requires the domain to span the whole geometry. TODO confirm
            return self._quadrature.point_index(elt_arg, topo_arg, element_index, element_index, node_index_in_elt)

        return element_node_index

    def _make_element_node_count(self):
        @cache.dynamic_func(suffix=self.name)
        def element_node_count(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
        ):
            return self._quadrature.point_count(elt_arg, topo_arg, element_index, element_index)

        return element_node_count

    def _make_side_neighbor_node_counts(self):
        MAX_NODES_PER_ELEMENT = self.MAX_NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def side_neighbor_node_counts(
            side_arg: self.geometry.SideArg,
            element_index: ElementIndex,
        ):
            # Per-side counts are not computed; the per-element maximum is reported for both neighbors
            return MAX_NODES_PER_ELEMENT, MAX_NODES_PER_ELEMENT

        return side_neighbor_node_counts
581
+
582
+
583
class PointBasisSpace(BasisSpace):
    """An unstructured :class:`BasisSpace` that is non-zero at a finite set of points only.

    The node locations and nodal quadrature weights are defined by a :class:`Quadrature` formula.
    """

    def __init__(self, quadrature: Quadrature):
        self._quadrature = quadrature

        topology = UnstructuredPointTopology(quadrature)
        super().__init__(topology)

        # The basis argument structure is the quadrature's own argument structure
        self.BasisArg = quadrature.Arg
        self.basis_arg_value = quadrature.arg_value
        self.fill_basis_arg = quadrature.fill_arg

        self.ORDER = 0

        # Inner and outer weights (and their gradients) coincide for point bases:
        # alias each outer builder to its inner counterpart.
        self.make_element_outer_weight = self.make_element_inner_weight
        self.make_element_outer_weight_gradient = self.make_element_inner_weight_gradient

    @property
    def name(self):
        return f"{self._quadrature.name}_Point"

    @property
    def value(self) -> ShapeFunction.Value:
        """Point bases are scalar-valued."""
        return ShapeFunction.Value.Scalar

    def make_node_coords_in_element(self):
        """Build a device function returning the coordinates of the quadrature point backing a node."""

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self._quadrature.point_coords(elt_arg, basis_arg, element_index, element_index, node_index_in_elt)

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Build a device function returning the quadrature weight of the point backing a node."""

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self._quadrature.point_weight(elt_arg, basis_arg, element_index, element_index, node_index_in_elt)

        return node_quadrature_weight

    def make_element_inner_weight(self):
        """Build a Dirac-like weight: 1 at the node's quadrature point, 0 elsewhere."""
        # Compared against the *squared* coordinate distance below, so the effective
        # matching radius is sqrt(1e-6) = 1e-3 in element coordinates.
        _DIRAC_INTEGRATION_RADIUS = wp.constant(1.0e-6)

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            qp_coord = self._quadrature.point_coords(
                elt_arg, basis_arg, element_index, element_index, node_index_in_elt
            )
            return wp.where(wp.length_sq(coords - qp_coord) < _DIRAC_INTEGRATION_RADIUS, 1.0, 0.0)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Build a weight-gradient function that is identically zero."""
        gradient_vec = cache.cached_vec_type(length=self.geometry.cell_dimension, dtype=float)

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            return gradient_vec(0.0)

        return element_inner_weight_gradient

    def make_trace_node_quadrature_weight(self, trace_basis):
        """Build the trace node weight function; point-basis nodes carry zero weight on sides."""

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(
            elt_arg: trace_basis.geometry.SideArg,
            basis_arg: trace_basis.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return 0.0

        return trace_node_quadrature_weight
@@ -18,9 +18,11 @@ from enum import Enum
18
18
  from typing import Any
19
19
 
20
20
  import warp as wp
21
- import warp.types
21
+ from warp._src.types import type_size
22
22
 
23
- vec6 = wp.types.vector(length=6, dtype=wp.float32)
23
+ _wp_module_name_ = "warp.fem.space.dof_mapper"
24
+
25
+ vec6 = wp.vec(length=6, dtype=wp.float32)
24
26
 
25
27
  _SQRT_2 = wp.constant(math.sqrt(2.0))
26
28
  _SQRT_3 = wp.constant(math.sqrt(3.0))
@@ -57,7 +59,7 @@ class IdentityMapper(DofMapper):
57
59
  self.value_dtype = dtype
58
60
  self.dof_dtype = dtype
59
61
 
60
- size = warp.types.type_size(dtype)
62
+ size = type_size(dtype)
61
63
  self.DOF_SIZE = wp.constant(size)
62
64
 
63
65
  @wp.func