warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,679 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import cached_property
17
+ from typing import ClassVar, Optional
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp._src.fem import cache
23
+ from warp._src.fem.geometry import Geometry
24
+ from warp._src.fem.quadrature import Quadrature
25
+ from warp._src.fem.types import (
26
+ NULL_ELEMENT_INDEX,
27
+ NULL_QP_INDEX,
28
+ Coords,
29
+ ElementIndex,
30
+ QuadraturePointIndex,
31
+ make_free_sample,
32
+ )
33
+ from warp._src.types import type_repr, types_equal
34
+
35
+ from .shape import ShapeFunction
36
+ from .topology import SpaceTopology
37
+
38
+
39
class BasisSpace:
    """Interface class for defining a shape function space over a geometry.

    A basis space makes it easy to define multiple function spaces sharing the same basis (and thus nodes) but with different valuation functions;
    however, it is not a required component of a function space.

    Subclasses are expected to provide a ``name`` attribute (used as a suffix for generated kernels)
    and to implement the ``make_*`` factory methods, which return device functions.

    See also: :func:`make_polynomial_basis_space`, :func:`make_collocated_function_space`
    """

    @wp.struct
    class BasisArg:
        """Argument structure to be passed to device functions"""

        pass

    def __init__(self, topology: SpaceTopology):
        self._topology = topology

    @property
    def topology(self) -> SpaceTopology:
        """Underlying topology of the basis space"""
        return self._topology

    @property
    def geometry(self) -> Geometry:
        """Underlying geometry of the basis space"""
        return self._topology.geometry

    @property
    def value(self) -> ShapeFunction.Value:
        """Value type for the underlying shape functions"""
        raise NotImplementedError()

    @cache.cached_arg_value
    def basis_arg_value(self, device) -> "BasisArg":
        """Value for the argument structure to be passed to device functions"""
        arg = self.BasisArg()
        self.fill_basis_arg(arg, device)
        return arg

    def fill_basis_arg(self, arg, device):
        """Populates ``arg`` for ``device``. The base ``BasisArg`` is empty, so there is nothing to fill here."""
        pass

    # Helpers for generating node positions

    def node_positions(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns a temporary array containing the world position for each node

        Args:
            out: Optional destination array. It must have shape ``(topology.node_count(),)`` and a
              float vector dtype of length ``geometry.dimension``. If ``None``, a new array is allocated.

        Returns:
            The array of node positions (``out`` itself when it is provided).

        Raises:
            ValueError: If ``out`` does not have the expected shape or data type.
        """

        pos_type = cache.cached_vec_type(length=self.geometry.dimension, dtype=float)

        node_coords_in_element = self.make_node_coords_in_element()

        @cache.dynamic_kernel(suffix=self.name, kernel_options={"max_unroll": 4, "enable_backward": False})
        def fill_node_positions(
            geo_cell_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            topo_arg: self.topology.TopologyArg,
            node_positions: wp.array(dtype=pos_type),
        ):
            # One thread per geometry cell
            element_index = wp.tid()

            # NOTE(review): a node shared by several cells is written once per incident cell; this presumably
            # relies on all incident cells mapping the node to the same world position.
            element_node_count = self.topology.element_node_count(geo_cell_arg, topo_arg, element_index)
            for n in range(element_node_count):
                node_index = self.topology.element_node_index(geo_cell_arg, topo_arg, element_index, n)
                coords = node_coords_in_element(geo_cell_arg, basis_arg, element_index, n)

                sample = make_free_sample(element_index, coords)
                pos = self.geometry.cell_position(geo_cell_arg, sample)

                node_positions[node_index] = pos

        shape = (self.topology.node_count(),)
        if out is None:
            node_positions = wp.empty(
                shape=shape,
                dtype=pos_type,
            )
        else:
            if out.shape != shape or not types_equal(pos_type, out.dtype):
                raise ValueError(
                    f"Out node positions array must have shape {shape} and data type {type_repr(pos_type)}"
                )
            node_positions = out

        # Argument structures are built for the device the output array lives on
        wp.launch(
            dim=self.geometry.cell_count(),
            kernel=fill_node_positions,
            inputs=[
                self.geometry.cell_arg_value(device=node_positions.device),
                self.basis_arg_value(device=node_positions.device),
                self.topology.topo_arg_value(device=node_positions.device),
                node_positions,
            ],
        )

        return node_positions

    def make_node_coords_in_element(self):
        """Returns a device function giving the element coordinates of a node. Must be implemented by subclasses."""
        raise NotImplementedError()

    def make_node_quadrature_weight(self):
        """Returns a device function giving the quadrature weight of a node. Must be implemented by subclasses."""
        raise NotImplementedError()

    def make_element_inner_weight(self):
        """Returns a device function evaluating a node's weight from the inner side. Must be implemented by subclasses."""
        raise NotImplementedError()

    def make_element_outer_weight(self):
        """Returns a device function evaluating a node's weight from the outer side; defaults to the inner weight."""
        return self.make_element_inner_weight()

    def make_element_inner_weight_gradient(self):
        """Returns a device function evaluating a node's weight gradient from the inner side. Must be implemented by subclasses."""
        raise NotImplementedError()

    def make_element_outer_weight_gradient(self):
        """Returns a device function evaluating a node's weight gradient from the outer side; defaults to the inner gradient."""
        return self.make_element_inner_weight_gradient()

    def make_trace_node_quadrature_weight(self, trace_basis=None):
        """Returns a device function giving node quadrature weights on the trace space ``trace_basis``.

        Must be implemented by subclasses. ``trace_basis`` is accepted here so that callers passing the
        trace space (as ``TraceBasisSpace`` does) get ``NotImplementedError`` rather than a ``TypeError``.
        """
        raise NotImplementedError()

    def trace(self) -> "TraceBasisSpace":
        """Returns a basis space evaluating this basis on the geometry sides"""
        return TraceBasisSpace(self)

    @property
    def weight_type(self):
        """Type of a node weight: ``float`` for scalar shape functions, else a float vector of length ``geometry.cell_dimension``"""
        if self.value is ShapeFunction.Value.Scalar:
            return float

        return cache.cached_vec_type(length=self.geometry.cell_dimension, dtype=float)

    @property
    def weight_gradient_type(self):
        """Type of a node weight gradient: a float vector for scalar shape functions, else a square float matrix (size ``geometry.cell_dimension``)"""
        if self.value is ShapeFunction.Value.Scalar:
            # Use the cached vector type, consistent with ``weight_type`` and the matrix branch below
            return cache.cached_vec_type(length=self.geometry.cell_dimension, dtype=float)

        return cache.cached_mat_type(shape=(self.geometry.cell_dimension, self.geometry.cell_dimension), dtype=float)
173
+
174
+
175
class ShapeBasisSpace(BasisSpace):
    """Basis space whose per-element basis functions come from a :class:`ShapeFunction`.

    The generated device functions delegate to the shape function. For non-scalar shape functions,
    the weights and gradients are additionally multiplied by the per-node sign reported by the topology.
    """

    def __init__(self, topology: SpaceTopology, shape: ShapeFunction):
        super().__init__(topology)
        self._shape = shape

        if self.value is not ShapeFunction.Value.Scalar:
            # Non-scalar weights need ``topology.element_node_sign`` on device, so the topology
            # argument structure doubles as the basis argument structure.
            topo = self.topology
            self.BasisArg = topo.TopologyArg
            self.basis_arg_value = topo.topo_arg_value
            self.fill_basis_arg = topo.fill_topo_arg

        self.ORDER = self._shape.ORDER

        # Only expose the node-connectivity helpers that the shape function can support
        optional_helpers = (
            ("element_node_triangulation", "node_triangulation", self._node_triangulation),
            ("element_node_tets", "node_tets", self._node_tets),
            ("element_node_hexes", "node_hexes", self._node_hexes),
            ("element_vtk_cells", "vtk_cells", self._vtk_cells),
        )
        for shape_attr, public_name, helper in optional_helpers:
            if hasattr(shape, shape_attr):
                setattr(self, public_name, helper)

        if hasattr(topology, "node_grid"):
            self.node_grid = topology.node_grid

    @property
    def shape(self) -> ShapeFunction:
        """Shape functions used for defining individual element basis"""
        return self._shape

    @property
    def value(self) -> ShapeFunction.Value:
        """Value type of the underlying shape function"""
        return self.shape.value

    @cached_property
    def name(self):
        """Unique name built from the topology and shape function names"""
        return f"{self.topology.name}_{self._shape.name}"

    def make_node_coords_in_element(self):
        """Builds a device function returning a node's reference coordinates within its element."""
        shape_coords = self._shape.make_node_coords_in_element()

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Coordinates only depend on the local node index
            return shape_coords(node_index_in_elt)

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Builds a device function returning a node's quadrature weight, or ``None`` if the shape has none."""
        shape_weight = self._shape.make_node_quadrature_weight()
        if shape_weight is None:
            return None

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return shape_weight(node_index_in_elt)

        return node_quadrature_weight

    def make_element_inner_weight(self):
        """Builds a device function evaluating a node's shape-function weight at element coordinates."""
        shape_weight = self._shape.make_element_inner_weight()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            # Compile-time branch: only non-scalar values carry a per-node sign
            if wp.static(self.value == ShapeFunction.Value.Scalar):
                return shape_weight(coords, node_index_in_elt)
            else:
                sign = self.topology.element_node_sign(elt_arg, basis_arg, element_index, node_index_in_elt)
                return sign * shape_weight(coords, node_index_in_elt)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Builds a device function evaluating a node's shape-function weight gradient at element coordinates."""
        shape_gradient = self._shape.make_element_inner_weight_gradient()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            # Compile-time branch: only non-scalar values carry a per-node sign
            if wp.static(self.value == ShapeFunction.Value.Scalar):
                return shape_gradient(coords, node_index_in_elt)
            else:
                sign = self.topology.element_node_sign(elt_arg, basis_arg, element_index, node_index_in_elt)
                return sign * shape_gradient(coords, node_index_in_elt)

        return element_inner_weight_gradient

    def make_trace_node_quadrature_weight(self, trace_basis):
        """Builds a device function giving trace-node quadrature weights, or ``None`` if the shape has none.

        The trace node is first mapped to its index within the neighboring cell.
        """
        shape_trace_weight = self._shape.make_trace_node_quadrature_weight()
        if shape_trace_weight is None:
            return None

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(
            geo_side_arg: trace_basis.geometry.SideArg,
            basis_arg: trace_basis.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            neighbor_cell, node_in_neighbor = trace_basis.topology.neighbor_cell_index(
                geo_side_arg, element_index, node_index_in_elt
            )
            return shape_trace_weight(node_in_neighbor)

        return trace_node_quadrature_weight

    def trace(self) -> "TraceBasisSpace":
        """Returns the trace of this basis; order-0 bases use a dedicated piecewise-constant trace"""
        return PiecewiseConstantBasisSpaceTrace(self) if self.ORDER == 0 else TraceBasisSpace(self)

    def _element_pattern_indices(self, local_pattern, width):
        """Maps a per-element local node pattern to global node indices, one row of ``width`` nodes per sub-cell."""
        global_nodes = self._topology.element_node_indices().numpy()
        return global_nodes[:, local_pattern].reshape(-1, width)

    def _node_triangulation(self):
        return self._element_pattern_indices(self._shape.element_node_triangulation(), 3)

    def _node_tets(self):
        return self._element_pattern_indices(self._shape.element_node_tets(), 4)

    def _node_hexes(self):
        return self._element_pattern_indices(self._shape.element_node_hexes(), 8)

    def _vtk_cells(self):
        element_nodes = self._topology.element_node_indices().numpy()
        local_cells, local_cell_types = self._shape.element_vtk_cells()

        nodes_per_cell = local_cells.shape[1]
        connectivity = element_nodes[:, local_cells].reshape(-1, nodes_per_cell)

        # Each cell row is prefixed with its node count (VTK-style flat cell array)
        counts = np.full((connectivity.shape[0], 1), nodes_per_cell)
        cells = np.hstack((counts, connectivity))

        element_count = element_nodes.shape[0]
        return cells.flatten(), np.tile(local_cell_types, element_count)
340
+
341
+
342
+ class TraceBasisSpace(BasisSpace):
343
+ """Auto-generated trace space evaluating the cell-defined basis on the geometry sides"""
344
+
345
+ def __init__(self, basis: BasisSpace):
346
+ super().__init__(basis.topology.trace())
347
+
348
+ self.ORDER = basis.ORDER
349
+
350
+ self._basis = basis
351
+ self.BasisArg = self._basis.BasisArg
352
+ self.basis_arg_value = self._basis.basis_arg_value
353
+ self.fill_basis_arg = self._basis.fill_basis_arg
354
+
355
+ @property
356
+ def name(self):
357
+ return f"{self._basis.name}_Trace"
358
+
359
+ @property
360
+ def value(self) -> ShapeFunction.Value:
361
+ return self._basis.value
362
+
363
+ def make_node_coords_in_element(self):
364
+ node_coords_in_cell = self._basis.make_node_coords_in_element()
365
+
366
+ @cache.dynamic_func(suffix=self._basis.name)
367
+ def trace_node_coords_in_element(
368
+ geo_side_arg: self.geometry.SideArg,
369
+ basis_arg: self.BasisArg,
370
+ element_index: ElementIndex,
371
+ node_index_in_elt: int,
372
+ ):
373
+ neighbour_elem, index_in_neighbour = self.topology.neighbor_cell_index(
374
+ geo_side_arg, element_index, node_index_in_elt
375
+ )
376
+ geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
377
+ neighbour_coords = node_coords_in_cell(
378
+ geo_cell_arg,
379
+ basis_arg,
380
+ neighbour_elem,
381
+ index_in_neighbour,
382
+ )
383
+
384
+ return self.geometry.side_from_cell_coords(geo_side_arg, element_index, neighbour_elem, neighbour_coords)
385
+
386
+ return trace_node_coords_in_element
387
+
388
+ def make_node_quadrature_weight(self):
389
+ return self._basis.make_trace_node_quadrature_weight(self)
390
+
391
+ def make_element_inner_weight(self):
392
+ cell_inner_weight = self._basis.make_element_inner_weight()
393
+
394
+ @cache.dynamic_func(suffix=self._basis.name)
395
+ def trace_element_inner_weight(
396
+ geo_side_arg: self.geometry.SideArg,
397
+ basis_arg: self.BasisArg,
398
+ element_index: ElementIndex,
399
+ coords: Coords,
400
+ node_index_in_elt: int,
401
+ qp_index: QuadraturePointIndex,
402
+ ):
403
+ cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
404
+ if cell_index == NULL_ELEMENT_INDEX:
405
+ return self.weight_type(0.0)
406
+
407
+ cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
408
+
409
+ geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
410
+ return cell_inner_weight(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX)
411
+
412
+ return trace_element_inner_weight
413
+
414
def make_element_outer_weight(self):
    """Build the device function evaluating the basis weight from a side's outer cell.

    The generated function returns ``self.weight_type(0.0)`` when
    ``self.topology.outer_cell_index`` reports ``NULL_ELEMENT_INDEX`` for the node.
    Otherwise it evaluates the cell basis outer weight at the side point mapped
    into the outer cell.
    """
    outer_cell_weight = self._basis.make_element_outer_weight()

    @cache.dynamic_func(suffix=self._basis.name)
    def trace_element_outer_weight(
        geo_side_arg: self.geometry.SideArg,
        basis_arg: self.BasisArg,
        element_index: ElementIndex,
        coords: Coords,
        node_index_in_elt: int,
        qp_index: QuadraturePointIndex,
    ):
        outer_cell, node_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
        if outer_cell == NULL_ELEMENT_INDEX:
            # The topology reports no outer cell for this node
            return self.weight_type(0.0)

        # The side quadrature-point index is not forwarded; NULL_QP_INDEX is passed instead
        return outer_cell_weight(
            self.geometry.side_to_cell_arg(geo_side_arg),
            basis_arg,
            outer_cell,
            self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords),
            node_in_cell,
            NULL_QP_INDEX,
        )

    return trace_element_outer_weight
436
+
437
def make_element_inner_weight_gradient(self):
    """Build the device function evaluating the basis weight gradient from a side's inner cell.

    The generated function returns ``self.weight_gradient_type(0.0)`` when the
    topology reports no inner cell for the node (``NULL_ELEMENT_INDEX``).
    """
    inner_cell_gradient = self._basis.make_element_inner_weight_gradient()

    @cache.dynamic_func(suffix=self._basis.name)
    def trace_element_inner_weight_gradient(
        geo_side_arg: self.geometry.SideArg,
        basis_arg: self.BasisArg,
        element_index: ElementIndex,
        coords: Coords,
        node_index_in_elt: int,
        qp_index: QuadraturePointIndex,
    ):
        inner_cell, node_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
        if inner_cell == NULL_ELEMENT_INDEX:
            # The topology reports no inner cell for this node
            return self.weight_gradient_type(0.0)

        # The side quadrature-point index is not forwarded; NULL_QP_INDEX is passed instead
        return inner_cell_gradient(
            self.geometry.side_to_cell_arg(geo_side_arg),
            basis_arg,
            inner_cell,
            self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords),
            node_in_cell,
            NULL_QP_INDEX,
        )

    return trace_element_inner_weight_gradient
460
+
461
def make_element_outer_weight_gradient(self):
    """Build the device function evaluating the basis weight gradient from a side's outer cell.

    The generated function returns ``self.weight_gradient_type(0.0)`` when the
    topology reports no outer cell for the node (``NULL_ELEMENT_INDEX``).
    """
    outer_cell_gradient = self._basis.make_element_outer_weight_gradient()

    @cache.dynamic_func(suffix=self._basis.name)
    def trace_element_outer_weight_gradient(
        geo_side_arg: self.geometry.SideArg,
        basis_arg: self.BasisArg,
        element_index: ElementIndex,
        coords: Coords,
        node_index_in_elt: int,
        qp_index: QuadraturePointIndex,
    ):
        outer_cell, node_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
        if outer_cell == NULL_ELEMENT_INDEX:
            # The topology reports no outer cell for this node
            return self.weight_gradient_type(0.0)

        # The side quadrature-point index is not forwarded; NULL_QP_INDEX is passed instead
        return outer_cell_gradient(
            self.geometry.side_to_cell_arg(geo_side_arg),
            basis_arg,
            outer_cell,
            self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords),
            node_in_cell,
            NULL_QP_INDEX,
        )

    return trace_element_outer_weight_gradient
484
+
485
def __eq__(self, other: "TraceBasisSpace") -> bool:
    """Return whether two trace spaces trace the same underlying cell basis.

    Objects that are not trace basis spaces yield ``NotImplemented``. Python then
    falls back to its default comparison instead of raising ``AttributeError``.
    """
    if not isinstance(other, TraceBasisSpace):
        return NotImplemented
    # Every device-function factory of this class delegates to ``self._basis``, so two
    # traces behave identically exactly when their traced bases compare equal.
    # NOTE(review): the previous version compared ``self._topo``, which this class does
    # not appear to define (it mirrors the trace *topology* equality) -- confirm.
    return self._basis == other._basis
487
+
488
+
489
class PiecewiseConstantBasisSpaceTrace(TraceBasisSpace):
    """Trace of a piecewise-constant basis: its single node is seen at the center of every side."""

    def make_node_coords_in_element(self):
        # Exposing the single node on every side allows direct interpolation on boundaries.
        # Higher-order non-conforming elements cannot be interpolated directly on the
        # boundary; they require proper integration followed by a mass-matrix solve.
        reference_side = self.geometry.reference_side()
        CENTER_COORDS = Coords(reference_side.prototype.center())

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_node_coords_in_element(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Same constant location regardless of side or node index
            return CENTER_COORDS

        return trace_node_coords_in_element
507
+
508
+
509
class UnstructuredPointTopology(SpaceTopology):
    """Topology for unstructured points defined from quadrature formula. See :class:`PointBasisSpace`

    Each quadrature point of the formula is a node. The quadrature's argument
    structure (``quadrature.Arg``) serves as the topology argument structure.
    """

    # Device-function attributes installed by ``cache.setup_dynamic_attributes`` (see __init__)
    _dynamic_attribute_constructors: ClassVar = {
        "element_node_index": lambda obj: obj._make_element_node_index(),
        "element_node_count": lambda obj: obj._make_element_node_count(),
        "side_neighbor_node_counts": lambda obj: obj._make_side_neighbor_node_counts(),
    }

    def __init__(self, quadrature: Quadrature):
        """Create a point topology whose nodes are the points of ``quadrature``.

        Raises:
            ValueError: if ``quadrature.max_points_per_element()`` is ``None``, or if the
                quadrature domain does not contain every element of its geometry.
        """
        if quadrature.max_points_per_element() is None:
            raise ValueError("Quadrature must define a maximum number of points per element")

        # Device functions below pass the domain element index as the geometry element
        # index too, which is only valid when the domain spans the whole geometry.
        if quadrature.domain.element_count() != quadrature.domain.geometry_element_count():
            raise ValueError("Point topology may only be defined on quadrature domains that span the whole geometry")

        self._quadrature = quadrature
        self.TopologyArg = quadrature.Arg
        # NOTE(review): this instance attribute shadows the ``topo_arg_value`` method below;
        # both forward to ``quadrature.arg_value``.
        self.topo_arg_value = quadrature.arg_value
        self.fill_topo_arg = quadrature.fill_arg

        super().__init__(quadrature.domain.geometry, max_nodes_per_element=quadrature.max_points_per_element())

        cache.setup_dynamic_attributes(self, cls=__class__)

    def node_count(self):
        """Total number of nodes, i.e. the total number of quadrature points."""
        return self._quadrature.total_point_count()

    @property
    def name(self):
        return f"PointTopology_{self._quadrature}"

    def topo_arg_value(self, device) -> SpaceTopology.TopologyArg:
        """Value of the topology argument structure to be passed to device functions"""
        return self._quadrature.arg_value(device)

    def _make_element_node_index(self):
        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # element_index is passed twice: domain and geometry element indices coincide
            # (enforced in __init__) -- presumably the two index arguments of point_index
            return self._quadrature.point_index(elt_arg, topo_arg, element_index, element_index, node_index_in_elt)

        return element_node_index

    def _make_element_node_count(self):
        @cache.dynamic_func(suffix=self.name)
        def element_node_count(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
        ):
            # Number of quadrature points in this element
            return self._quadrature.point_count(elt_arg, topo_arg, element_index, element_index)

        return element_node_count

    def _make_side_neighbor_node_counts(self):
        MAX_NODES_PER_ELEMENT = self.MAX_NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def side_neighbor_node_counts(
            side_arg: self.geometry.SideArg,
            element_index: ElementIndex,
        ):
            # Both neighbors report the per-element maximum rather than an exact count
            return MAX_NODES_PER_ELEMENT, MAX_NODES_PER_ELEMENT

        return side_neighbor_node_counts
579
+
580
+
581
class PointBasisSpace(BasisSpace):
    """An unstructured :class:`BasisSpace` that is non-zero at a finite set of points only.

    The node locations and nodal quadrature weights are defined by a :class:`Quadrature` formula.
    """

    def __init__(self, quadrature: Quadrature):
        self._quadrature = quadrature

        topology = UnstructuredPointTopology(quadrature)
        super().__init__(topology)

        # The basis argument structure is the quadrature's own argument structure
        self.BasisArg = quadrature.Arg
        self.basis_arg_value = quadrature.arg_value
        self.fill_basis_arg = quadrature.fill_arg

        self.ORDER = 0

        # Outer weights and their gradients use the same factories as the inner ones.
        # (The gradient alias previously assigned the attribute to itself, a no-op.)
        self.make_element_outer_weight = self.make_element_inner_weight
        self.make_element_outer_weight_gradient = self.make_element_inner_weight_gradient

    @property
    def name(self):
        return f"{self._quadrature.name}_Point"

    @property
    def value(self) -> ShapeFunction.Value:
        return ShapeFunction.Value.Scalar

    def make_node_coords_in_element(self):
        """Build the device function returning a node's coordinates: those of its quadrature point."""

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self._quadrature.point_coords(elt_arg, basis_arg, element_index, element_index, node_index_in_elt)

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Build the device function returning a node's weight: that of its quadrature point."""

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self._quadrature.point_weight(elt_arg, basis_arg, element_index, element_index, node_index_in_elt)

        return node_quadrature_weight

    def make_element_inner_weight(self):
        """Build the device function for the Dirac-like basis weight.

        The weight is 1.0 when ``coords`` match the node's quadrature point up to a
        small tolerance, and 0.0 otherwise.
        """
        # NOTE: compared against the *squared* coordinate distance below
        _DIRAC_INTEGRATION_RADIUS = wp.constant(1.0e-6)

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            qp_coord = self._quadrature.point_coords(
                elt_arg, basis_arg, element_index, element_index, node_index_in_elt
            )
            return wp.where(wp.length_sq(coords - qp_coord) < _DIRAC_INTEGRATION_RADIUS, 1.0, 0.0)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Build the device function for the weight gradient, which is always zero."""
        gradient_vec = cache.cached_vec_type(length=self.geometry.cell_dimension, dtype=float)

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            return gradient_vec(0.0)

        return element_inner_weight_gradient

    def make_trace_node_quadrature_weight(self, trace_basis):
        """Build the side-level nodal quadrature weight for ``trace_basis``; always 0.0."""

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(
            elt_arg: trace_basis.geometry.SideArg,
            basis_arg: trace_basis.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return 0.0

        return trace_node_quadrature_weight
@@ -18,9 +18,9 @@ from enum import Enum
18
18
  from typing import Any
19
19
 
20
20
  import warp as wp
21
- import warp.types
21
+ from warp._src.types import type_size
22
22
 
23
- vec6 = wp.types.vector(length=6, dtype=wp.float32)
23
+ vec6 = wp.vec(length=6, dtype=wp.float32)
24
24
 
25
25
  _SQRT_2 = wp.constant(math.sqrt(2.0))
26
26
  _SQRT_3 = wp.constant(math.sqrt(3.0))
@@ -57,7 +57,7 @@ class IdentityMapper(DofMapper):
57
57
  self.value_dtype = dtype
58
58
  self.dof_dtype = dtype
59
59
 
60
- size = warp.types.type_size(dtype)
60
+ size = type_size(dtype)
61
61
  self.DOF_SIZE = wp.constant(size)
62
62
 
63
63
  @wp.func