warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1105 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.fem import cache
22
+ from warp.fem.geometry import Grid3D
23
+ from warp.fem.polynomial import Polynomial, is_closed, lagrange_scales, quadrature_1d
24
+ from warp.fem.types import Coords
25
+
26
+ from .shape_function import ShapeFunction
27
+ from .tet_shape_function import TetrahedronPolynomialShapeFunctions
28
+
29
+
30
class CubeShapeFunction(ShapeFunction):
    """Common base for shape functions defined on the reference cube.

    Defines integer tags for the topological type of a node and small ``wp.func``
    helpers that decode the packed vertex/edge/face indices (``type_instance``)
    used by derived classes.
    """

    # Topological node-type tags
    VERTEX = 0
    EDGE = 1
    FACE = 2
    INTERIOR = 3

    @wp.func
    def _vertex_coords(vidx_in_cell: int):
        # Decode a cube vertex index packed as 4*x + 2*y + z into integer
        # coordinates (x, y, z); each component is 0 or 1 when vidx_in_cell
        # lies in [0, 8) (assumed by callers — TODO confirm)
        x = vidx_in_cell // 4
        y = (vidx_in_cell - 4 * x) // 2
        z = vidx_in_cell - 4 * x - 2 * y
        return wp.vec3i(x, y, z)

    @wp.func
    def _edge_coords(type_instance: int, index_in_side: int):
        # Integer coordinates of a node lying on an edge:
        #  - first component: index_in_side shifted by one, i.e. position along the
        #    edge skipping the starting vertex
        #  - second/third components: bits 1 and 0 of type_instance, presumably the
        #    edge's offsets along its two transverse axes — TODO confirm frame ordering
        return wp.vec3i(index_in_side + 1, (type_instance & 2) >> 1, type_instance & 1)

    @wp.func
    def _edge_axis(type_instance: int):
        # Edges are packed 4 per axis: type_instance >> 2 is the edge's axis index
        return type_instance >> 2

    @wp.func
    def _face_offset(type_instance: int):
        # Low bit of the face index: which of the two opposite faces along the axis
        return type_instance & 1

    @wp.func
    def _face_axis(type_instance: int):
        # Faces are packed 2 per axis: type_instance >> 1 is the face's normal axis index
        return type_instance >> 1
59
+
60
+ class CubeTripolynomialShapeFunctions(CubeShapeFunction):
61
def __init__(self, degree: int, family: Polynomial):
    """Set up tri-polynomial (Q ``degree``) shape functions on the reference cube.

    Args:
        degree: Polynomial order along each axis; each axis carries ``degree + 1`` nodes.
        family: 1D polynomial family whose quadrature points are used as node positions.

    When ``is_closed(family)`` holds, nodes are distributed over vertices, edges,
    faces and the cell interior; otherwise every node is counted as interior.
    """
    self.family = family

    nodes_per_axis = degree + 1
    inner_nodes_per_axis = max(0, degree - 1)

    self.ORDER = wp.constant(degree)
    self.NODES_PER_ELEMENT = wp.constant(nodes_per_axis**3)
    self.NODES_PER_SIDE = wp.constant(nodes_per_axis**2)

    if not is_closed(self.family):
        # Open family: no node on the cell boundary, all nodes are interior
        self.VERTEX_NODE_COUNT = wp.constant(0)
        self.EDGE_NODE_COUNT = wp.constant(0)
        self.FACE_NODE_COUNT = wp.constant(0)
        self.INTERIOR_NODE_COUNT = self.NODES_PER_ELEMENT
    else:
        # Closed family: one node per vertex, inner nodes on edges, faces and interior
        self.VERTEX_NODE_COUNT = wp.constant(1)
        self.EDGE_NODE_COUNT = wp.constant(inner_nodes_per_axis)
        self.FACE_NODE_COUNT = wp.constant(inner_nodes_per_axis**2)
        self.INTERIOR_NODE_COUNT = wp.constant(inner_nodes_per_axis**3)

    # 1D node positions / weights, shared by all three axes
    points_1d, weights_1d = quadrature_1d(point_count=nodes_per_axis, family=family)
    scales_1d = lagrange_scales(points_1d)

    vec_type = wp.types.vector(length=nodes_per_axis, dtype=wp.float32)
    self.LOBATTO_COORDS = wp.constant(vec_type(points_1d))
    self.LOBATTO_WEIGHT = wp.constant(vec_type(weights_1d))
    self.LAGRANGE_SCALE = wp.constant(vec_type(scales_1d))
    self.ORDER_PLUS_ONE = wp.constant(self.ORDER + 1)

    # Order matters: the function built by _make_node_type_and_type_index
    # references self._node_ijk
    self._node_ijk = self._make_node_ijk()
    self.node_type_and_type_index = self._make_node_type_and_type_index()

    if degree > 2:
        # Face-local (i, j) node indexing helpers, only built for degree > 2
        self._face_node_ij = self._make_face_node_ij()
        self._linear_face_node_index = self._make_linear_face_node_index()
94
+
95
@property
def name(self) -> str:
    """Identifier of this shape (order and family), used as the cache key/suffix for generated functions."""
    order, family = self.ORDER, self.family
    return f"Cube_Q{order}_{family}"
98
+
99
@wp.func
def _vertex_coords_f(vidx_in_cell: int):
    # Floating-point variant of CubeShapeFunction._vertex_coords: decode a vertex
    # index packed as 4*x + 2*y + z into reference coordinates (x, y, z)
    x = vidx_in_cell // 4
    y = (vidx_in_cell - 4 * x) // 2
    z = vidx_in_cell - 4 * x - 2 * y
    return wp.vec3(float(x), float(y), float(z))
105
+
106
def _make_node_ijk(self):
    """Build a warp function splitting an element node index into per-axis indices.

    Nodes are numbered lexicographically with ``k`` varying fastest:
    ``node_index_in_elt = (i * (ORDER + 1) + j) * (ORDER + 1) + k``.
    The function is registered through ``cache.get_func`` with ``self.name``.
    """
    ORDER_PLUS_ONE = self.ORDER_PLUS_ONE

    def node_ijk(
        node_index_in_elt: int,
    ):
        # Returns the tuple (i, j, k) of 1D node indices along the x, y and z axes
        node_i = node_index_in_elt // (ORDER_PLUS_ONE * ORDER_PLUS_ONE)
        node_jk = node_index_in_elt - ORDER_PLUS_ONE * ORDER_PLUS_ONE * node_i
        node_j = node_jk // ORDER_PLUS_ONE
        node_k = node_jk - ORDER_PLUS_ONE * node_j
        return node_i, node_j, node_k

    return cache.get_func(node_ijk, self.name)
119
+
120
def _make_face_node_ij(self):
    """Build a warp function decoding a face-interior node index into 2D indices.

    Face nodes are numbered ``face_node_index = i * (ORDER - 1) + j``; the
    function returns ``wp.vec2i(i, j)``. It is the inverse of the function built
    by ``_make_linear_face_node_index``.
    """
    ORDER_MINUS_ONE = wp.constant(self.ORDER - 1)

    def face_node_ij(
        face_node_index: int,
    ):
        node_i = face_node_index // ORDER_MINUS_ONE
        node_j = face_node_index - ORDER_MINUS_ONE * node_i
        return wp.vec2i(node_i, node_j)

    return cache.get_func(face_node_ij, self.name)
131
+
132
def _make_linear_face_node_index(self):
    """Build a warp function re-linearizing face-local 2D node indices.

    The function returns ``face_node_ij[0] * (ORDER - 1) + face_node_ij[1]``,
    the inverse of the decoding performed by ``_make_face_node_ij``.
    """
    ORDER_MINUS_ONE = wp.constant(self.ORDER - 1)

    def linear_face_node_index(face_node_index: int, face_node_ij: wp.vec2i):
        # face_node_index is not used here; presumably kept so the signature
        # matches other shape-function implementations — TODO confirm
        return face_node_ij[0] * ORDER_MINUS_ONE + face_node_ij[1]

    return cache.get_func(linear_face_node_index, self.name)
139
+
140
def _make_node_type_and_type_index(self):
    """Build the warp function classifying an element node by topological type.

    The returned function maps ``node_index_in_elt`` to a triple
    ``(node_type, type_instance, type_index)``:

    - ``node_type``: one of the ``VERTEX``/``EDGE``/``FACE``/``INTERIOR`` tags;
    - ``type_instance``: which vertex (0..7), edge (0..11) or face (0..5) the node is on;
    - ``type_index``: index of the node among the nodes of that entity.

    For open (non-closed) families every node is reported as interior.
    """
    ORDER = self.ORDER

    @cache.dynamic_func(suffix=self.name)
    def node_type_and_type_index_open(
        node_index_in_elt: int,
    ):
        # Open families: all nodes are interior, indexed by their element-local index
        return CubeShapeFunction.INTERIOR, 0, node_index_in_elt

    @cache.dynamic_func(suffix=self.name)
    def node_type_and_type_index(
        node_index_in_elt: int,
    ):
        i, j, k = self._node_ijk(node_index_in_elt)

        # z*: 1 when the index is at the lower end (0) of its axis
        zi = wp.where(i == 0, 1, 0)
        zj = wp.where(j == 0, 1, 0)
        zk = wp.where(k == 0, 1, 0)

        # m*: 1 when the index is at the upper end (ORDER) of its axis
        mi = wp.where(i == ORDER, 1, 0)
        mj = wp.where(j == ORDER, 1, 0)
        mk = wp.where(k == ORDER, 1, 0)

        # z* + m* == 1 means the node is at an extremity along that axis.
        # Three extremal axes -> vertex, two -> edge, one -> face, none -> interior.
        if zi + mi == 1:
            if zj + mj == 1:
                if zk + mk == 1:
                    # vertex: packed as 4*x + 2*y + z, same layout as _vertex_coords
                    type_instance = mi * 4 + mj * 2 + mk
                    return CubeTripolynomialShapeFunctions.VERTEX, type_instance, 0

                # z edge (x and y extremal): instances 8..11, low bits = (x side, y side)
                type_instance = 8 + mi * 2 + mj
                type_index = k - 1
                return CubeTripolynomialShapeFunctions.EDGE, type_instance, type_index

            if zk + mk == 1:
                # y edge (z and x extremal): instances 4..7, low bits = (z side, x side)
                type_instance = 4 + mk * 2 + mi
                type_index = j - 1
                return CubeTripolynomialShapeFunctions.EDGE, type_instance, type_index

            # x face (only x extremal): instances 0..1, face-local indices (j, k)
            type_instance = mi
            type_index = (j - 1) * (ORDER - 1) + k - 1
            return CubeTripolynomialShapeFunctions.FACE, type_instance, type_index

        if zj + mj == 1:
            if zk + mk == 1:
                # x edge (y and z extremal): instances 0..3, low bits = (y side, z side)
                type_instance = mj * 2 + mk
                type_index = i - 1
                return CubeTripolynomialShapeFunctions.EDGE, type_instance, type_index

            # y face (only y extremal): instances 2..3, face-local indices (k, i)
            type_instance = 2 + mj
            type_index = (k - 1) * (ORDER - 1) + i - 1
            return CubeTripolynomialShapeFunctions.FACE, type_instance, type_index

        if zk + mk == 1:
            # z face (only z extremal): instances 4..5, face-local indices (i, j)
            type_instance = 4 + mk
            type_index = (i - 1) * (ORDER - 1) + j - 1
            return CubeTripolynomialShapeFunctions.FACE, type_instance, type_index

        # Interior node: lexicographic (i, j, k) over the (ORDER - 1)^3 inner nodes
        type_index = ((i - 1) * (ORDER - 1) + (j - 1)) * (ORDER - 1) + k - 1
        return CubeTripolynomialShapeFunctions.INTERIOR, 0, type_index

    return node_type_and_type_index if is_closed(self.family) else node_type_and_type_index_open
208
+
209
def make_node_coords_in_element(self):
    """Build a warp function returning the reference coordinates of an element node.

    The coordinates are the tensor product of the 1D node positions
    (``LOBATTO_COORDS``) indexed by the node's (i, j, k) indices.
    """
    LOBATTO_COORDS = self.LOBATTO_COORDS

    @cache.dynamic_func(suffix=self.name)
    def node_coords_in_element(
        node_index_in_elt: int,
    ):
        node_i, node_j, node_k = self._node_ijk(node_index_in_elt)
        return Coords(LOBATTO_COORDS[node_i], LOBATTO_COORDS[node_j], LOBATTO_COORDS[node_k])

    return node_coords_in_element
220
+
221
def make_node_quadrature_weight(self):
    """Build a warp function returning the quadrature weight attached to an element node.

    In general the weight is the product of the 1D weights (``LOBATTO_WEIGHT``)
    along the three axes. For ``ORDER == 1`` (closed or open family) the constant
    ``0.125`` (= 0.5**3) is returned instead, which assumes both 1D weights are 0.5
    — TODO confirm for every supported family.
    """
    ORDER = self.ORDER
    LOBATTO_WEIGHT = self.LOBATTO_WEIGHT

    def node_quadrature_weight(
        node_index_in_elt: int,
    ):
        node_i, node_j, node_k = self._node_ijk(node_index_in_elt)
        return LOBATTO_WEIGHT[node_i] * LOBATTO_WEIGHT[node_j] * LOBATTO_WEIGHT[node_k]

    def node_quadrature_weight_linear(
        node_index_in_elt: int,
    ):
        # Constant weight shared by the 8 nodes of the linear element
        return 0.125

    if ORDER == 1:
        return cache.get_func(node_quadrature_weight_linear, self.name)

    return cache.get_func(node_quadrature_weight, self.name)
240
+
241
def make_trace_node_quadrature_weight(self):
    """Build a warp function returning a node's quadrature weight on a cell side (trace).

    For a node whose index is 0 or ORDER along some axis, the weight is the product
    of the 1D weights along the two other axes. Specializations:

    - open families return ``0.0`` (they place no node on the cell boundary);
    - ``ORDER == 1`` returns the constant ``0.25`` (= 0.5**2, assuming 1D weights of 0.5).
    """
    ORDER = self.ORDER
    LOBATTO_WEIGHT = self.LOBATTO_WEIGHT

    def trace_node_quadrature_weight(
        node_index_in_elt: int,
    ):
        # Boundary nodes (on a face, edge or vertex) have at least one index at 0 or ORDER.
        # The first such axis (checked in x, then y order) is treated as the side normal,
        # and the 1D weights along the two remaining axes are multiplied.
        # If neither i nor j is extremal, k is assumed to be.

        node_i, node_j, node_k = self._node_ijk(node_index_in_elt)

        if node_i == 0 or node_i == ORDER:
            return LOBATTO_WEIGHT[node_j] * LOBATTO_WEIGHT[node_k]

        if node_j == 0 or node_j == ORDER:
            return LOBATTO_WEIGHT[node_i] * LOBATTO_WEIGHT[node_k]

        return LOBATTO_WEIGHT[node_i] * LOBATTO_WEIGHT[node_j]

    def trace_node_quadrature_weight_linear(
        node_index_in_elt: int,
    ):
        return 0.25

    def trace_node_quadrature_weight_open(
        node_index_in_elt: int,
    ):
        return 0.0

    if not is_closed(self.family):
        return cache.get_func(trace_node_quadrature_weight_open, self.name)

    if ORDER == 1:
        return cache.get_func(trace_node_quadrature_weight_linear, self.name)

    return cache.get_func(trace_node_quadrature_weight, self.name)
278
+
279
def make_element_inner_weight(self):
    """Build a warp function evaluating a node's basis function at reference ``coords``.

    General case: product over the three axes of the 1D polynomials
    ``prod_{k != n} (c - LOBATTO_COORDS[k])``, multiplied by ``LAGRANGE_SCALE`` for the
    node's index along each axis (presumably the normalization from ``lagrange_scales``
    that makes each 1D Lagrange polynomial equal to 1 at its own node — TODO confirm).

    For a closed family with ``ORDER == 1``, a trilinear closed form is used instead,
    with the node index decoded as a cube vertex.
    """
    ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
    LOBATTO_COORDS = self.LOBATTO_COORDS
    LAGRANGE_SCALE = self.LAGRANGE_SCALE

    def element_inner_weight(
        coords: Coords,
        node_index_in_elt: int,
    ):
        node_i, node_j, node_k = self._node_ijk(node_index_in_elt)

        # A single loop over the 1D nodes accumulates the factors of all three axes,
        # each axis skipping its own node index
        w = float(1.0)
        for k in range(ORDER_PLUS_ONE):
            if k != node_i:
                w *= coords[0] - LOBATTO_COORDS[k]
            if k != node_j:
                w *= coords[1] - LOBATTO_COORDS[k]
            if k != node_k:
                w *= coords[2] - LOBATTO_COORDS[k]

        w *= LAGRANGE_SCALE[node_i] * LAGRANGE_SCALE[node_j] * LAGRANGE_SCALE[node_k]

        return w

    def element_inner_weight_linear(
        coords: Coords,
        node_index_in_elt: int,
    ):
        v = CubeTripolynomialShapeFunctions._vertex_coords_f(node_index_in_elt)

        # Per-axis linear factor: (1 - c) when the vertex coordinate is 0, c when it is 1
        wx = (1.0 - coords[0]) * (1.0 - v[0]) + v[0] * coords[0]
        wy = (1.0 - coords[1]) * (1.0 - v[1]) + v[1] * coords[1]
        wz = (1.0 - coords[2]) * (1.0 - v[2]) + v[2] * coords[2]
        return wx * wy * wz

    if self.ORDER == 1 and is_closed(self.family):
        return cache.get_func(element_inner_weight_linear, self.name)

    return cache.get_func(element_inner_weight, self.name)
318
+
319
def make_element_inner_weight_gradient(self):
    """Build a warp function evaluating the gradient of a node's basis function at ``coords``.

    The gradient is taken with respect to the reference coordinates and is
    consistent with ``make_element_inner_weight``. The general case uses a product
    rule over the 1D Lagrange factors. For a closed family with ``ORDER == 1``,
    a trilinear closed form is used.
    """
    ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
    LOBATTO_COORDS = self.LOBATTO_COORDS
    LAGRANGE_SCALE = self.LAGRANGE_SCALE

    def element_inner_weight_gradient(
        coords: Coords,
        node_index_in_elt: int,
    ):
        node_i, node_j, node_k = self._node_ijk(node_index_in_elt)

        # Per-axis Lagrange factors, named after the gradient components that use them:
        #  prefix_yz: x factor (shared by d/dy and d/dz)
        #  prefix_zx: y factor (shared by d/dz and d/dx)
        #  prefix_xy: z factor (shared by d/dx and d/dy)
        prefix_xy = float(1.0)
        prefix_yz = float(1.0)
        prefix_zx = float(1.0)
        for k in range(ORDER_PLUS_ONE):
            if k != node_i:
                prefix_yz *= coords[0] - LOBATTO_COORDS[k]
            if k != node_j:
                prefix_zx *= coords[1] - LOBATTO_COORDS[k]
            if k != node_k:
                prefix_xy *= coords[2] - LOBATTO_COORDS[k]

        # Product of the two factors that do not depend on the differentiated axis
        prefix_x = prefix_zx * prefix_xy
        prefix_y = prefix_yz * prefix_xy
        prefix_z = prefix_zx * prefix_yz

        grad_x = float(0.0)
        grad_y = float(0.0)
        grad_z = float(0.0)

        # Incremental product rule: when multiplying the running product P by
        # delta = (c - x_k), d(P * delta) = dP * delta + P. Here grad_* holds dP
        # and prefix_* holds P (both times the other axes' factors).
        for k in range(ORDER_PLUS_ONE):
            if k != node_i:
                delta_x = coords[0] - LOBATTO_COORDS[k]
                grad_x = grad_x * delta_x + prefix_x
                prefix_x *= delta_x
            if k != node_j:
                delta_y = coords[1] - LOBATTO_COORDS[k]
                grad_y = grad_y * delta_y + prefix_y
                prefix_y *= delta_y
            if k != node_k:
                delta_z = coords[2] - LOBATTO_COORDS[k]
                grad_z = grad_z * delta_z + prefix_z
                prefix_z *= delta_z

        grad = (
            LAGRANGE_SCALE[node_i]
            * LAGRANGE_SCALE[node_j]
            * LAGRANGE_SCALE[node_k]
            * wp.vec3(
                grad_x,
                grad_y,
                grad_z,
            )
        )

        return grad

    def element_inner_weight_gradient_linear(
        coords: Coords,
        node_index_in_elt: int,
    ):
        v = CubeTripolynomialShapeFunctions._vertex_coords_f(node_index_in_elt)

        wx = (1.0 - coords[0]) * (1.0 - v[0]) + v[0] * coords[0]
        wy = (1.0 - coords[1]) * (1.0 - v[1]) + v[1] * coords[1]
        wz = (1.0 - coords[2]) * (1.0 - v[2]) + v[2] * coords[2]

        # Derivative of each linear factor: -1 when the vertex coordinate is 0, +1 when it is 1
        dx = 2.0 * v[0] - 1.0
        dy = 2.0 * v[1] - 1.0
        dz = 2.0 * v[2] - 1.0

        return wp.vec3(dx * wy * wz, dy * wz * wx, dz * wx * wy)

    if self.ORDER == 1 and is_closed(self.family):
        return cache.get_func(element_inner_weight_gradient_linear, self.name)

    return cache.get_func(element_inner_weight_gradient, self.name)
396
+
397
def element_node_hexes(self):
    """Return a hexahedral sub-division of the element's nodes.

    Delegates to ``warp.fem.utils.grid_to_hexes`` with ``ORDER`` cells along
    each of the three axes.
    """
    from warp.fem.utils import grid_to_hexes

    cells_per_axis = self.ORDER
    return grid_to_hexes(cells_per_axis, cells_per_axis, cells_per_axis)
401
+
402
def element_node_tets(self):
    """Return a tetrahedral sub-division of the element's nodes.

    Delegates to ``warp.fem.utils.grid_to_tets`` with ``ORDER`` cells along
    each of the three axes.
    """
    from warp.fem.utils import grid_to_tets

    cells_per_axis = self.ORDER
    return grid_to_tets(cells_per_axis, cells_per_axis, cells_per_axis)
406
+
407
def element_vtk_cells(self):
    """Return the element connectivity as a single VTK cell.

    For ``ORDER == 1`` the eight corner nodes form a ``VTK_HEXAHEDRON`` (12).
    For higher orders all ``(ORDER + 1)^3`` nodes form a
    ``VTK_LAGRANGE_HEXAHEDRON`` (72), listed in VTK order: corners, edges,
    faces (-x, +x, -y, +y, -z, +z), then interior nodes.

    Returns:
        ``(cell_indices, cell_types)``. ``cell_indices`` has shape ``(1, node_count)``
        and holds element node indices computed as ``i * n^2 + j * n + k`` with
        ``n = ORDER + 1``. ``cell_types`` is an int8 array with the single VTK cell type.
    """
    n = self.ORDER + 1

    # corner vertices, in VTK hexahedron order
    cells = [
        [
            [0, 0, 0],
            [n - 1, 0, 0],
            [n - 1, n - 1, 0],
            [0, n - 1, 0],
            [0, 0, n - 1],
            [n - 1, 0, n - 1],
            [n - 1, n - 1, n - 1],
            [0, n - 1, n - 1],
        ]
    ]

    if self.ORDER == 1:
        cell_type = 12  # vtk_hexahedron
    else:
        middle = np.arange(1, n - 1)
        front = np.zeros(n - 2, dtype=int)
        back = np.full(n - 2, n - 1)

        # edges: bottom face loop, top face loop, then the four vertical edges
        cells.append(np.column_stack((middle, front, front)))
        cells.append(np.column_stack((back, middle, front)))
        cells.append(np.column_stack((middle, back, front)))
        cells.append(np.column_stack((front, middle, front)))

        cells.append(np.column_stack((middle, front, back)))
        cells.append(np.column_stack((back, middle, back)))
        cells.append(np.column_stack((middle, back, back)))
        cells.append(np.column_stack((front, middle, back)))

        cells.append(np.column_stack((front, front, middle)))
        cells.append(np.column_stack((back, front, middle)))
        cells.append(np.column_stack((back, back, middle)))
        cells.append(np.column_stack((front, back, middle)))

        # faces: with the default "xy" indexing, face[0] varies fastest once flattened,
        # matching VTK's first-in-plane-axis-fastest face ordering
        face = np.meshgrid(middle, middle)
        face_fast = face[0].flatten()
        face_slow = face[1].flatten()
        front = np.zeros((n - 2) ** 2, dtype=int)
        back = np.full((n - 2) ** 2, n - 1)

        # YZ
        cells.append(np.column_stack((front, face_fast, face_slow)))
        cells.append(np.column_stack((back, face_fast, face_slow)))
        # XZ
        cells.append(np.column_stack((face_fast, front, face_slow)))
        cells.append(np.column_stack((face_fast, back, face_slow)))
        # XY
        cells.append(np.column_stack((face_fast, face_slow, front)))
        cells.append(np.column_stack((face_fast, face_slow, back)))

        # interior: VTK enumerates interior nodes with i varying fastest, then j, then k.
        # With "ij" indexing the last output varies fastest once raveled, so unpack in reverse.
        interior_k, interior_j, interior_i = np.meshgrid(middle, middle, middle, indexing="ij")
        cells.append(
            np.column_stack((interior_i.ravel(), interior_j.ravel(), interior_k.ravel())),
        )

        cell_type = 72  # vtk_lagrange_hexahedron

    cells = np.concatenate(cells)
    cell_indices = cells[:, 0] * n * n + cells[:, 1] * n + cells[:, 2]

    return cell_indices[np.newaxis, :], np.array([cell_type], dtype=np.int8)
487
+
488
+
489
class CubeSerendipityShapeFunctions(CubeShapeFunction):
    """
    Serendipity element ~ tensor product space without interior nodes.

    Edge shape functions are the usual 1D Lagrange shape functions along the edge times a
    bilinear function in the two normal directions.
    Corner shape functions are trilinear shape functions times a correction of degree d - 1
    in (x, y, z): linear in x + y + z for degree 2, and a quadratic "sphere" term centered at
    the cube center for degree 3.

    Only degrees 2 and 3 are implemented, and ``is_closed(family)`` must hold.
    """

    def __init__(self, degree: int, family: Polynomial):
        if not is_closed(family):
            raise ValueError("A closed polynomial family is required to define serendipity elements")

        if degree not in [2, 3]:
            raise NotImplementedError("Serendipity element only implemented for order 2 or 3")

        self.family = family

        self.ORDER = wp.constant(degree)
        # 8 vertices plus (degree - 1) nodes on each of the 12 edges
        self.NODES_PER_ELEMENT = wp.constant(8 + 12 * (degree - 1))
        # 4 vertices plus (degree - 1) nodes on each of the 4 edges of a side
        self.NODES_PER_SIDE = wp.constant(4 * degree)

        self.VERTEX_NODE_COUNT = wp.constant(1)
        self.EDGE_NODE_COUNT = wp.constant(degree - 1)
        self.FACE_NODE_COUNT = wp.constant(0)
        self.INTERIOR_NODE_COUNT = wp.constant(0)

        lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
        lagrange_scale = lagrange_scales(lobatto_coords)

        NodeVec = wp.types.vector(length=degree + 1, dtype=wp.float32)
        self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
        self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
        self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
        self.ORDER_PLUS_ONE = wp.constant(self.ORDER + 1)

        self.node_type_and_type_index = self._get_node_type_and_type_index()
        self._node_lobatto_indices = self._get_node_lobatto_indices()

    @property
    def name(self) -> str:
        return f"Cube_S{self.ORDER}_{self.family}"

    def _get_node_type_and_type_index(self):
        @cache.dynamic_func(suffix=self.name)
        def node_type_and_index(
            node_index_in_elt: int,
        ):
            # nodes 0-7 are the cube vertices
            if node_index_in_elt < 8:
                return CubeShapeFunction.VERTEX, node_index_in_elt, 0

            # edge nodes are interleaved: consecutive indices cycle through the 3 edge axes,
            # then through the 4 edges of each axis, then along the edge
            edge_index = (node_index_in_elt - 8) // 3
            edge_axis = node_index_in_elt - 8 - 3 * edge_index

            index_in_edge = edge_index // 4
            edge_offset = edge_index - 4 * index_in_edge

            return CubeShapeFunction.EDGE, 4 * edge_axis + edge_offset, index_in_edge

        return node_type_and_index

    def _get_node_lobatto_indices(self):
        ORDER = self.ORDER

        @cache.dynamic_func(suffix=self.name)
        def node_lobatto_indices(node_type: int, type_instance: int, type_index: int):
            # vertices sit at the first or last 1D node along each axis
            if node_type == CubeSerendipityShapeFunctions.VERTEX:
                return CubeSerendipityShapeFunctions._vertex_coords(type_instance) * ORDER

            axis = CubeSerendipityShapeFunctions._edge_axis(type_instance)
            local_coords = CubeSerendipityShapeFunctions._edge_coords(type_instance, type_index)

            local_indices = wp.vec3i(local_coords[0], local_coords[1] * ORDER, local_coords[2] * ORDER)

            return Grid3D._local_to_world(axis, local_indices)

        return node_lobatto_indices

    def _degree_3_sphere_constants(self):
        """Return ``(radius_sq, scale)`` used by the degree-3 corner shape functions.

        The corner modifier is ``|c - 0.5|^2 - radius_sq``; ``scale`` normalizes the
        corner shape function to one at its own vertex (where the modifier equals
        ``0.75 - radius_sq``).
        """
        LOBATTO_COORDS = self.LOBATTO_COORDS
        radius_sq = wp.constant(2 * 0.5**2 + (0.5 - LOBATTO_COORDS[1]) ** 2)
        scale = 1.0 / (0.75 - radius_sq)
        return radius_sq, scale

    def make_node_coords_in_element(self):
        LOBATTO_COORDS = self.LOBATTO_COORDS

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
            node_coords = self._node_lobatto_indices(node_type, type_instance, type_index)
            return Coords(
                LOBATTO_COORDS[node_coords[0]], LOBATTO_COORDS[node_coords[1]], LOBATTO_COORDS[node_coords[2]]
            )

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        ORDER = self.ORDER

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(
            node_index_in_elt: int,
        ):
            # the 8 vertices get 1/ORDER^3 in total; the 12 * (ORDER - 1) edge nodes
            # evenly share the remainder so that all weights sum to one
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
            if node_type == CubeSerendipityShapeFunctions.VERTEX:
                return 1.0 / float(8 * ORDER * ORDER * ORDER)

            return (1.0 - 1.0 / float(ORDER * ORDER * ORDER)) / float(12 * (ORDER - 1))

        return node_quadrature_weight

    def make_trace_node_quadrature_weight(self):
        ORDER = self.ORDER

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(
            node_index_in_elt: int,
        ):
            # on a side, the 4 vertices get 1/ORDER^2 in total; the 4 * (ORDER - 1)
            # edge nodes evenly share the remainder
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
            if node_type == CubeSerendipityShapeFunctions.VERTEX:
                return 0.25 / float(ORDER * ORDER)

            return (0.25 - 0.25 / float(ORDER * ORDER)) / float(ORDER - 1)

        return trace_node_quadrature_weight

    def make_element_inner_weight(self):
        ORDER = self.ORDER
        ORDER_PLUS_ONE = self.ORDER_PLUS_ONE

        LOBATTO_COORDS = self.LOBATTO_COORDS
        LAGRANGE_SCALE = self.LAGRANGE_SCALE

        DEGREE_3_SPHERE_RAD, DEGREE_3_SPHERE_SCALE = self._degree_3_sphere_constants()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            coords: Coords,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)

            if node_type == CubeSerendipityShapeFunctions.VERTEX:
                node_ijk = CubeSerendipityShapeFunctions._vertex_coords(type_instance)

                # coordinates mirrored so that the vertex is at c = (1, 1, 1)
                cx = wp.where(node_ijk[0] == 0, 1.0 - coords[0], coords[0])
                cy = wp.where(node_ijk[1] == 0, 1.0 - coords[1], coords[1])
                cz = wp.where(node_ijk[2] == 0, 1.0 - coords[2], coords[2])

                w = cx * cy * cz

                if wp.static(ORDER == 2):
                    w *= cx + cy + cz - 3.0 + LOBATTO_COORDS[1]
                    return w * LAGRANGE_SCALE[0]
                if wp.static(ORDER == 3):
                    w *= (
                        (cx - 0.5) * (cx - 0.5)
                        + (cy - 0.5) * (cy - 0.5)
                        + (cz - 0.5) * (cz - 0.5)
                        - DEGREE_3_SPHERE_RAD
                    )
                    return w * DEGREE_3_SPHERE_SCALE

            axis = CubeSerendipityShapeFunctions._edge_axis(type_instance)

            node_all = CubeSerendipityShapeFunctions._edge_coords(type_instance, type_index)

            local_coords = Grid3D._world_to_local(axis, coords)

            # linear factors in the two directions normal to the edge
            w = float(1.0)
            w *= wp.where(node_all[1] == 0, 1.0 - local_coords[1], local_coords[1])
            w *= wp.where(node_all[2] == 0, 1.0 - local_coords[2], local_coords[2])

            # 1D Lagrange polynomial along the edge
            for k in range(ORDER_PLUS_ONE):
                if k != node_all[0]:
                    w *= local_coords[0] - LOBATTO_COORDS[k]
            w *= LAGRANGE_SCALE[node_all[0]]

            return w

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        ORDER = self.ORDER
        ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
        LOBATTO_COORDS = self.LOBATTO_COORDS
        LAGRANGE_SCALE = self.LAGRANGE_SCALE

        DEGREE_3_SPHERE_RAD, DEGREE_3_SPHERE_SCALE = self._degree_3_sphere_constants()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            coords: Coords,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)

            if node_type == CubeSerendipityShapeFunctions.VERTEX:
                node_ijk = CubeSerendipityShapeFunctions._vertex_coords(type_instance)

                cx = wp.where(node_ijk[0] == 0, 1.0 - coords[0], coords[0])
                cy = wp.where(node_ijk[1] == 0, 1.0 - coords[1], coords[1])
                cz = wp.where(node_ijk[2] == 0, 1.0 - coords[2], coords[2])

                # derivative of the mirrored coordinates w.r.t. the reference coordinates
                gx = wp.where(node_ijk[0] == 0, -1.0, 1.0)
                gy = wp.where(node_ijk[1] == 0, -1.0, 1.0)
                gz = wp.where(node_ijk[2] == 0, -1.0, 1.0)

                if wp.static(ORDER == 2):
                    w = cx + cy + cz - 3.0 + LOBATTO_COORDS[1]
                    grad_x = cy * cz * gx * (w + cx)
                    grad_y = cz * cx * gy * (w + cy)
                    grad_z = cx * cy * gz * (w + cz)

                    return wp.vec3(grad_x, grad_y, grad_z) * LAGRANGE_SCALE[0]

                if wp.static(ORDER == 3):
                    w = (
                        (cx - 0.5) * (cx - 0.5)
                        + (cy - 0.5) * (cy - 0.5)
                        + (cz - 0.5) * (cz - 0.5)
                        - DEGREE_3_SPHERE_RAD
                    )

                    dw_dcx = 2.0 * cx - 1.0
                    dw_dcy = 2.0 * cy - 1.0
                    dw_dcz = 2.0 * cz - 1.0
                    grad_x = cy * cz * gx * (w + dw_dcx * cx)
                    grad_y = cz * cx * gy * (w + dw_dcy * cy)
                    grad_z = cx * cy * gz * (w + dw_dcz * cz)

                    return wp.vec3(grad_x, grad_y, grad_z) * DEGREE_3_SPHERE_SCALE

            axis = CubeSerendipityShapeFunctions._edge_axis(type_instance)
            node_all = CubeSerendipityShapeFunctions._edge_coords(type_instance, type_index)

            local_coords = Grid3D._world_to_local(axis, coords)

            w_long = wp.where(node_all[1] == 0, 1.0 - local_coords[1], local_coords[1])
            w_lat = wp.where(node_all[2] == 0, 1.0 - local_coords[2], local_coords[2])

            g_long = wp.where(node_all[1] == 0, -1.0, 1.0)
            g_lat = wp.where(node_all[2] == 0, -1.0, 1.0)

            # value and derivative of the scaled 1D Lagrange polynomial along the edge,
            # accumulated with the product rule
            w_alt = LAGRANGE_SCALE[node_all[0]]
            g_alt = float(0.0)
            prefix_alt = LAGRANGE_SCALE[node_all[0]]
            for k in range(ORDER_PLUS_ONE):
                if k != node_all[0]:
                    delta_alt = local_coords[0] - LOBATTO_COORDS[k]
                    w_alt *= delta_alt
                    g_alt = g_alt * delta_alt + prefix_alt
                    prefix_alt *= delta_alt

            local_grad = wp.vec3(g_alt * w_long * w_lat, w_alt * g_long * w_lat, w_alt * w_long * g_lat)

            return Grid3D._local_to_world(axis, local_grad)

        return element_inner_weight_gradient

    def element_node_tets(self):
        """Return a tetrahedral decomposition of the element nodes (degree 2 only).

        Hard-coded tetrahedra around the vertices and edge nodes, followed by the
        tetrahedra obtained by splitting the hexahedron of edge nodes with ``grid_to_tets``.

        Raises:
            NotImplementedError: for degree 3 elements.
        """
        from warp.fem.utils import grid_to_tets

        if self.ORDER == 2:
            element_tets = np.array(
                [
                    [0, 8, 9, 10],
                    [1, 11, 10, 15],
                    [2, 9, 14, 13],
                    [3, 15, 13, 17],
                    [4, 12, 8, 16],
                    [5, 18, 16, 11],
                    [6, 14, 12, 19],
                    [7, 19, 18, 17],
                    [16, 12, 18, 11],
                    [8, 16, 12, 11],
                    [12, 19, 18, 14],
                    [14, 19, 17, 18],
                    [10, 9, 15, 8],
                    [10, 8, 11, 15],
                    [9, 13, 15, 14],
                    [13, 14, 17, 15],
                ]
            )

            middle_hex = np.array([8, 11, 9, 15, 12, 18, 14, 17])
            middle_tets = middle_hex[grid_to_tets(1, 1, 1)]

            return np.concatenate((element_tets, middle_tets))

        raise NotImplementedError(
            f"Tetrahedral decomposition is only implemented for degree 2 serendipity elements, got {self.ORDER}"
        )

    def element_vtk_cells(self):
        """Return the element as VTK_TETRA cells built from ``element_node_tets``."""
        tets = np.array(self.element_node_tets())
        cell_type = 10  # VTK_TETRA

        return tets, np.full(tets.shape[0], cell_type, dtype=np.int8)
790
+
791
+
792
class CubeNonConformingPolynomialShapeFunctions(ShapeFunction):
    """Non-conforming polynomial shape functions on the reference cube.

    Nodes and shape functions are those of a tetrahedral polynomial space of the same
    degree, transported through an affine map that embeds the largest regular
    tetrahedron centered at (0.5, 0.5, 0.5) (with one face parallel to the xy plane)
    into the reference cube.
    """

    # regular tetrahedron dimensions: height, edge length, height of a face triangle
    _tet_height = 2.0 / 3.0
    _tet_side = math.sqrt(3.0 / 2.0) * _tet_height
    _tet_face_height = math.sqrt(3.0) / 2.0 * _tet_side

    # linear part of the embedding; its columns are the images of the tetrahedron
    # reference-space unit vectors
    _tet_to_cube = np.array(
        [
            [_tet_side, _tet_side / 2.0, _tet_side / 2.0],
            [0.0, _tet_face_height, _tet_face_height / 3.0],
            [0.0, 0.0, _tet_height],
        ]
    )

    # translation moving the tetrahedron vertex centroid onto the cube center
    _TET_OFFSET = wp.constant(wp.vec3(0.5 - 0.5 * _tet_side, 0.5 - _tet_face_height / 3.0, 0.5 - 0.25 * _tet_height))

    def __init__(self, degree: int):
        tet_shape = TetrahedronPolynomialShapeFunctions(degree=degree)
        self._tet_shape = tet_shape

        self.ORDER = tet_shape.ORDER
        self.NODES_PER_ELEMENT = tet_shape.NODES_PER_ELEMENT

        # sub-cell decomposition and visualization are those of the embedded tetrahedron
        self.element_node_tets = tet_shape.element_node_tets
        self.element_vtk_cells = tet_shape.element_vtk_cells

    @property
    def name(self) -> str:
        return f"Cube_P{self.ORDER}d"

    def _cube_to_tet_constant(self):
        # inverse of the embedding's linear part: cube coordinates -> tetrahedron coordinates
        return wp.constant(wp.mat33(np.linalg.inv(self._tet_to_cube)))

    def make_node_coords_in_element(self):
        tet_node_coords = self._tet_shape.make_node_coords_in_element()

        TET_TO_CUBE = wp.constant(wp.mat33(self._tet_to_cube))

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            node_index_in_elt: int,
        ):
            # map the tetrahedron node position into the cube
            coords_in_tet = tet_node_coords(node_index_in_elt)
            return TET_TO_CUBE * coords_in_tet + CubeNonConformingPolynomialShapeFunctions._TET_OFFSET

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def node_uniform_quadrature_weight(
            node_index_in_elt: int,
        ):
            # every node gets the same share of the unit cube volume
            return 1.0 / float(NODES_PER_ELEMENT)

        return node_uniform_quadrature_weight

    def make_trace_node_quadrature_weight(self):
        # Nodes are strictly inside the cube: no node lies on a side, so side weights vanish

        @wp.func
        def zero(node_index_in_elt: int):
            return 0.0

        return zero

    def make_element_inner_weight(self):
        tet_weight = self._tet_shape.make_element_inner_weight()

        CUBE_TO_TET = self._cube_to_tet_constant()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            coords: Coords,
            node_index_in_elt: int,
        ):
            # pull the cube point back to tetrahedron coordinates and evaluate there
            coords_in_tet = CUBE_TO_TET * (coords - CubeNonConformingPolynomialShapeFunctions._TET_OFFSET)
            return tet_weight(coords_in_tet, node_index_in_elt)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        tet_weight_gradient = self._tet_shape.make_element_inner_weight_gradient()

        CUBE_TO_TET = self._cube_to_tet_constant()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            coords: Coords,
            node_index_in_elt: int,
        ):
            coords_in_tet = CUBE_TO_TET * (coords - CubeNonConformingPolynomialShapeFunctions._TET_OFFSET)
            tet_gradient = tet_weight_gradient(coords_in_tet, node_index_in_elt)
            # chain rule through the affine pull-back
            return wp.transpose(CUBE_TO_TET) * tet_gradient

        return element_inner_weight_gradient
886
+
887
+
888
class CubeNedelecFirstKindShapeFunctions(CubeShapeFunction):
    """Lowest-order Nédélec (first kind) edge shape functions on the reference cube.

    One covariant vector-valued node per edge (12 total). Each shape function points
    along its edge axis and varies bilinearly in the two transverse directions.
    """

    value = ShapeFunction.Value.CovariantVector

    def __init__(self, degree: int):
        if degree != 1:
            raise NotImplementedError("Only linear Nédélec implemented right now")

        self.ORDER = wp.constant(degree)
        self.NODES_PER_ELEMENT = wp.constant(12)
        self.NODES_PER_SIDE = wp.constant(4)

        # all degrees of freedom live on edges
        self.VERTEX_NODE_COUNT = wp.constant(0)
        self.EDGE_NODE_COUNT = wp.constant(1)
        self.FACE_NODE_COUNT = wp.constant(0)
        self.INTERIOR_NODE_COUNT = wp.constant(0)

        self.node_type_and_type_index = self._get_node_type_and_type_index()

    @property
    def name(self) -> str:
        return f"CubeN1_{self.ORDER}"

    def _get_node_type_and_type_index(self):
        @cache.dynamic_func(suffix=self.name)
        def node_type_and_index(
            node_index_in_elt: int,
        ):
            # node i is the (single) node of edge i
            return CubeShapeFunction.EDGE, node_index_in_elt, 0

        return node_type_and_index

    def make_node_coords_in_element(self):
        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            node_index_in_elt: int,
        ):
            node_type, edge, index_in_edge = self.node_type_and_type_index(node_index_in_elt)
            edge_axis = CubeShapeFunction._edge_axis(edge)
            edge_ijk = CubeShapeFunction._edge_coords(edge, index_in_edge)

            # edge midpoint, expressed in the edge-aligned frame then rotated back
            midpoint_local = wp.vec3f(0.5, float(edge_ijk[1]), float(edge_ijk[2]))
            return Grid3D._local_to_world(edge_axis, midpoint_local)

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(node_index_in_element: int):
            return 1.0 / float(NODES_PER_ELEMENT)

        return node_quadrature_weight

    def make_trace_node_quadrature_weight(self):
        NODES_PER_SIDE = self.NODES_PER_SIDE

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(node_index_in_element: int):
            return 1.0 / float(NODES_PER_SIDE)

        return trace_node_quadrature_weight

    def make_element_inner_weight(self):
        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            coords: Coords,
            node_index_in_elt: int,
        ):
            node_type, edge, index_in_edge = self.node_type_and_type_index(node_index_in_elt)

            edge_axis = CubeShapeFunction._edge_axis(edge)

            coords_local = Grid3D._world_to_local(edge_axis, coords)
            edge_ijk = CubeShapeFunction._edge_coords(edge, index_in_edge)

            # transverse linear factors: 1 - t when the edge sits at 0, t when it sits at 1
            slope_1 = float(2 * edge_ijk[1] - 1)
            slope_2 = float(2 * edge_ijk[2] - 1)
            base_1 = float(1 - edge_ijk[1])
            base_2 = float(1 - edge_ijk[2])

            amplitude = (base_1 + slope_1 * coords_local[1]) * (base_2 + slope_2 * coords_local[2])
            return Grid3D._local_to_world(edge_axis, wp.vec3(amplitude, 0.0, 0.0))

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            coords: Coords,
            node_index_in_elt: int,
        ):
            node_type, edge, index_in_edge = self.node_type_and_type_index(node_index_in_elt)

            edge_axis = CubeShapeFunction._edge_axis(edge)

            coords_local = Grid3D._world_to_local(edge_axis, coords)
            edge_ijk = CubeShapeFunction._edge_coords(edge, index_in_edge)

            slope_1 = float(2 * edge_ijk[1] - 1)
            slope_2 = float(2 * edge_ijk[2] - 1)
            base_1 = float(1 - edge_ijk[1])
            base_2 = float(1 - edge_ijk[2])

            # only the edge-axis component is non-zero; its gradient has no
            # component along the edge itself
            component_gradient = Grid3D._local_to_world(
                edge_axis,
                wp.vec3(0.0, slope_1 * (base_2 + slope_2 * coords_local[2]), (base_1 + slope_1 * coords_local[1]) * slope_2),
            )

            jacobian = wp.mat33(0.0)
            jacobian[edge_axis] = component_gradient
            return jacobian

        return element_inner_weight_gradient
1001
+
1002
+
1003
class CubeRaviartThomasShapeFunctions(CubeShapeFunction):
    """Lowest-order Raviart-Thomas face shape functions on the reference cube.

    One contravariant vector-valued node per face (6 total). Each shape function is
    aligned with its face's normal axis and varies linearly along that axis.
    """

    value = ShapeFunction.Value.ContravariantVector

    def __init__(self, degree: int):
        if degree != 1:
            raise NotImplementedError("Only linear Raviart Thomas implemented right now")

        self.ORDER = wp.constant(degree)
        self.NODES_PER_ELEMENT = wp.constant(6)
        self.NODES_PER_SIDE = wp.constant(1)

        # all degrees of freedom live on faces
        self.VERTEX_NODE_COUNT = wp.constant(0)
        self.EDGE_NODE_COUNT = wp.constant(0)
        self.FACE_NODE_COUNT = wp.constant(1)
        self.INTERIOR_NODE_COUNT = wp.constant(0)

        self.node_type_and_type_index = self._get_node_type_and_type_index()

    @property
    def name(self) -> str:
        return f"CubeRT_{self.ORDER}"

    def _get_node_type_and_type_index(self):
        @cache.dynamic_func(suffix=self.name)
        def node_type_and_index(
            node_index_in_elt: int,
        ):
            # node i is the (single) node of face i
            return CubeShapeFunction.FACE, node_index_in_elt, 0

        return node_type_and_index

    def make_node_coords_in_element(self):
        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            node_index_in_elt: int,
        ):
            node_type, face, index_in_face = self.node_type_and_type_index(node_index_in_elt)
            normal_axis = CubeShapeFunction._face_axis(face)
            face_side = CubeShapeFunction._face_offset(face)

            # face center: 0.5 everywhere except along the normal axis
            center = Coords(0.5)
            center[normal_axis] = float(face_side)
            return center

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(node_index_in_element: int):
            return 1.0 / float(NODES_PER_ELEMENT)

        return node_quadrature_weight

    def make_trace_node_quadrature_weight(self):
        NODES_PER_SIDE = self.NODES_PER_SIDE

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(node_index_in_element: int):
            return 1.0 / float(NODES_PER_SIDE)

        return trace_node_quadrature_weight

    def make_element_inner_weight(self):
        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            coords: Coords,
            node_index_in_elt: int,
        ):
            node_type, face, index_in_face = self.node_type_and_type_index(node_index_in_elt)

            normal_axis = CubeShapeFunction._face_axis(face)
            face_side = CubeShapeFunction._face_offset(face)

            # linear along the normal axis: 1 - t for the face at 0, t for the face at 1
            slope = float(2 * face_side - 1)
            intercept = float(1 - face_side)

            weight = wp.vec3(0.0)
            weight[normal_axis] = intercept + slope * coords[normal_axis]

            return weight

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            coords: Coords,
            node_index_in_elt: int,
        ):
            node_type, face, index_in_face = self.node_type_and_type_index(node_index_in_elt)

            normal_axis = CubeShapeFunction._face_axis(face)
            face_side = CubeShapeFunction._face_offset(face)

            # only the diagonal entry of the normal axis is non-zero (constant slope)
            jacobian = wp.mat33(0.0)
            jacobian[normal_axis, normal_axis] = float(2 * face_side - 1)

            return jacobian

        return element_inner_weight_gradient