warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,223 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warp as wp
17
+ from warp.fem import cache
18
+ from warp.fem.geometry import Quadmesh2D
19
+ from warp.fem.polynomial import is_closed
20
+ from warp.fem.types import NULL_NODE_INDEX, ElementIndex
21
+
22
+ from .shape import SquareShapeFunction
23
+ from .topology import SpaceTopology, forward_base_topology
24
+
25
+
26
@wp.struct
class Quadmesh2DTopologyArg:
    # Device-side topology data for QuadmeshSpaceTopology, passed to the
    # node index/sign functions as `topo_arg`.

    # Pair of vertex indices for each mesh edge
    edge_vertex_indices: wp.array(dtype=wp.vec2i)
    # Mesh edge index for each of the 4 local edges of every quad, shape (cell_count, 4)
    quad_edge_indices: wp.array2d(dtype=int)

    # Entity counts of the underlying mesh
    vertex_count: int
    edge_count: int
    cell_count: int
+
35
+
36
class QuadmeshSpaceTopology(SpaceTopology):
    """Space topology for function spaces defined over a :class:`Quadmesh2D`.

    Nodes are attached to mesh vertices, edges and cells according to the
    ``VERTEX_NODE_COUNT``, ``EDGE_NODE_COUNT`` and ``INTERIOR_NODE_COUNT`` of the
    given :class:`SquareShapeFunction`. Global node indices are laid out with all
    vertex nodes first, then all interior (cell) nodes, then all edge nodes.
    """

    TopologyArg = Quadmesh2DTopologyArg

    def __init__(self, mesh: Quadmesh2D, shape: SquareShapeFunction):
        # Scalar-valued spaces must use a closed polynomial family to be continuous
        scalar_valued = shape.value == SquareShapeFunction.Value.Scalar
        if scalar_valued and not is_closed(shape.family):
            raise ValueError("A closed polynomial family is required to define a continuous function space")

        self._shape = shape
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh

        # Precompute per-quad edge indices, then build the node index and sign device functions
        self._compute_quad_edge_indices()
        self.element_node_index = self._make_element_node_index()
        self.element_node_sign = self._make_element_node_sign()

    @property
    def name(self):
        """Unique name combining the geometry name and the shape function name."""
        return f"{self.geometry.name}_{self._shape.name}"

    @cache.cached_arg_value
    def topo_arg_value(self, device):
        """Build a :class:`Quadmesh2DTopologyArg` whose arrays live on ``device``."""
        mesh = self._mesh

        arg = Quadmesh2DTopologyArg()
        arg.quad_edge_indices = self._quad_edge_indices.to(device)
        arg.edge_vertex_indices = mesh.edge_vertex_indices.to(device)

        arg.vertex_count = mesh.vertex_count()
        arg.edge_count = mesh.side_count()
        arg.cell_count = mesh.cell_count()
        return arg

    def _compute_quad_edge_indices(self):
        """Fill ``self._quad_edge_indices`` (shape ``(cell_count, 4)``) with the mesh
        edge index of every local quad edge, on the device of the quad vertex array."""
        mesh = self._mesh
        target_device = mesh.quad_vertex_indices.device

        self._quad_edge_indices = wp.empty(
            shape=(mesh.cell_count(), 4),
            dtype=int,
            device=target_device,
        )

        # One thread per mesh edge
        wp.launch(
            kernel=QuadmeshSpaceTopology._compute_quad_edge_indices_kernel,
            dim=mesh.edge_quad_indices.shape,
            device=target_device,
            inputs=[
                mesh.edge_quad_indices,
                mesh.edge_vertex_indices,
                mesh.quad_vertex_indices,
                self._quad_edge_indices,
            ],
        )

    # Return the local index k of `edge_vtx` within `quad_vtx`, where local edge k joins
    # quad vertices k and k + 1 (either orientation). When none of the first three
    # local edges match, 3 is returned, i.e. the closing edge (vertices 3 and 0) is assumed.
    @wp.func
    def _find_edge_index_in_quad(
        edge_vtx: wp.vec2i,
        quad_vtx: wp.vec4i,
    ):
        for local_edge in range(3):
            start = quad_vtx[local_edge]
            end = quad_vtx[local_edge + 1]
            if (edge_vtx[0] == start and edge_vtx[1] == end) or (edge_vtx[1] == start and edge_vtx[0] == end):
                return local_edge
        return 3

    @wp.kernel
    def _compute_quad_edge_indices_kernel(
        edge_quad_indices: wp.array(dtype=wp.vec2i),
        edge_vertex_indices: wp.array(dtype=wp.vec2i),
        quad_vertex_indices: wp.array2d(dtype=int),
        quad_edge_indices: wp.array2d(dtype=int),
    ):
        # Each thread registers one mesh edge in the quad(s) adjacent to it
        edge = wp.tid()

        edge_vtx = edge_vertex_indices[edge]
        adjacent_quads = edge_quad_indices[edge]

        first_quad = adjacent_quads[0]
        first_quad_vtx = wp.vec4i(
            quad_vertex_indices[first_quad, 0],
            quad_vertex_indices[first_quad, 1],
            quad_vertex_indices[first_quad, 2],
            quad_vertex_indices[first_quad, 3],
        )
        first_local_edge = QuadmeshSpaceTopology._find_edge_index_in_quad(edge_vtx, first_quad_vtx)
        quad_edge_indices[first_quad, first_local_edge] = edge

        # The second adjacent quad is only processed when it differs from the first
        # (presumably both entries are equal for boundary edges)
        second_quad = adjacent_quads[1]
        if second_quad != first_quad:
            second_quad_vtx = wp.vec4i(
                quad_vertex_indices[second_quad, 0],
                quad_vertex_indices[second_quad, 1],
                quad_vertex_indices[second_quad, 2],
                quad_vertex_indices[second_quad, 3],
            )
            second_local_edge = QuadmeshSpaceTopology._find_edge_index_in_quad(edge_vtx, second_quad_vtx)
            quad_edge_indices[second_quad, second_local_edge] = edge

    def node_count(self) -> int:
        """Total number of nodes: vertex, edge and interior nodes over the whole mesh."""
        geo = self.geometry
        shape = self._shape

        vertex_nodes = geo.vertex_count() * shape.VERTEX_NODE_COUNT
        edge_nodes = geo.side_count() * shape.EDGE_NODE_COUNT
        interior_nodes = geo.cell_count() * shape.INTERIOR_NODE_COUNT
        return vertex_nodes + edge_nodes + interior_nodes

    def _make_element_node_index(self):
        """Build the device function mapping (element, local node index) to a global node index."""
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
        INTERIOR_NODE_COUNT = self._shape.INTERIOR_NODE_COUNT

        # Maps the shape function's vertex instance to the column in quad_vertex_indices
        SHAPE_TO_QUAD_IDX = wp.constant(wp.vec4i([0, 3, 1, 2]))

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: self._mesh.CellArg,
            topo_arg: QuadmeshSpaceTopology.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Vertex nodes are numbered first, VERTEX_NODE_COUNT per mesh vertex
            if wp.static(VERTEX_NODE_COUNT > 0):
                if node_type == SquareShapeFunction.VERTEX:
                    quad_vertex = SHAPE_TO_QUAD_IDX[type_instance]
                    mesh_vertex = cell_arg.topology.quad_vertex_indices[element_index, quad_vertex]
                    return mesh_vertex * VERTEX_NODE_COUNT + type_index

            offset = topo_arg.vertex_count * VERTEX_NODE_COUNT

            # Interior nodes come next, INTERIOR_NODE_COUNT per cell
            if wp.static(INTERIOR_NODE_COUNT > 0):
                if node_type == SquareShapeFunction.INTERIOR:
                    return offset + element_index * INTERIOR_NODE_COUNT + type_index

            offset += INTERIOR_NODE_COUNT * topo_arg.cell_count

            # Edge nodes come last, EDGE_NODE_COUNT per mesh edge
            if wp.static(EDGE_NODE_COUNT > 0):
                # Local quad edge holding the node:
                # EDGE_X instance 0/1 -> edge 0/2, EDGE_Y instance 0/1 -> edge 3/1
                local_edge = wp.where(
                    node_type == SquareShapeFunction.EDGE_X,
                    wp.where(type_instance == 0, 0, 2),
                    wp.where(type_instance == 0, 3, 1),
                )

                side_index = topo_arg.quad_edge_indices[element_index, local_edge]
                local_start = cell_arg.topology.quad_vertex_indices[element_index, local_edge]
                global_start = topo_arg.edge_vertex_indices[side_index][0]

                # Walk the edge nodes backwards when the element's view of the edge is
                # reversed relative to the global edge (local edges 2 and 3 start reversed)
                flipped = int(local_edge >= 2) ^ int(local_start != global_start)
                index_in_side = wp.where(flipped, EDGE_NODE_COUNT - 1 - type_index, type_index)

                return offset + EDGE_NODE_COUNT * side_index + index_in_side

            return NULL_NODE_INDEX  # should never happen

        return element_node_index

    def _make_element_node_sign(self):
        """Build the device function giving the orientation sign (+1 or -1) of each element node."""

        @cache.dynamic_func(suffix=self.name)
        def element_node_sign(
            cell_arg: self._mesh.CellArg,
            topo_arg: QuadmeshSpaceTopology.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Only edge nodes carry an orientation; every other node has a positive sign
            if node_type != SquareShapeFunction.EDGE_X and node_type != SquareShapeFunction.EDGE_Y:
                return 1.0

            local_edge = wp.where(
                node_type == SquareShapeFunction.EDGE_X,
                wp.where(type_instance == 0, 0, 2),
                wp.where(type_instance == 0, 3, 1),
            )

            side_index = topo_arg.quad_edge_indices[element_index, local_edge]
            local_start = cell_arg.topology.quad_vertex_indices[element_index, local_edge]
            global_start = topo_arg.edge_vertex_indices[side_index][0]

            # Same orientation test as element_node_index: negative when reversed
            flipped = int(local_edge >= 2) ^ int(local_start != global_start)
            return wp.where(flipped, -1.0, 1.0)

        return element_node_sign
217
+
218
+
219
def make_quadmesh_space_topology(mesh: Quadmesh2D, shape: SquareShapeFunction):
    """Create the space topology for ``shape`` on the quadrilateral ``mesh``.

    Raises:
        ValueError: if ``shape`` is not a :class:`SquareShapeFunction`.
    """
    if not isinstance(shape, SquareShapeFunction):
        raise ValueError(f"Unsupported shape function {shape.name}")

    return forward_base_topology(QuadmeshSpaceTopology, mesh, shape)
@@ -0,0 +1,179 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warp as wp
17
+ from warp.fem.cache import TemporaryStore, borrow_temporary, borrow_temporary_like, cached_arg_value
18
+ from warp.fem.domain import GeometryDomain
19
+ from warp.fem.types import NULL_NODE_INDEX, NodeElementIndex
20
+ from warp.fem.utils import compress_node_indices
21
+
22
+ from .partition import SpacePartition
23
+
24
# Turn off backward (adjoint) code generation for the Warp kernels defined in this module
wp.set_module_options({"enable_backward": False})
25
+
26
+
27
class SpaceRestriction:
    """Restriction of a :class:`SpacePartition` to a :class:`GeometryDomain`.

    For every element of the domain, the partition node indices touched by that
    element are gathered and compressed into a node -> (domain element, index in
    element) map, which kernels can query through :class:`SpaceRestriction.NodeArg`.
    """

    def __init__(
        self,
        space_partition: SpacePartition,
        domain: GeometryDomain,
        device=None,
        temporary_store: TemporaryStore = None,
    ):
        topology = space_partition.space_topology

        # A domain one dimension lower than the space is handled through the topology's trace
        if domain.dimension == topology.dimension - 1:
            topology = topology.trace()

        if domain.dimension != topology.dimension:
            raise ValueError("Incompatible space and domain dimensions")

        self.space_partition = space_partition
        self.space_topology = topology
        self.domain = domain

        self._compute_node_element_indices(device=device, temporary_store=temporary_store)

    def _compute_node_element_indices(self, device, temporary_store: TemporaryStore):
        """Populate the node -> element lookup arrays using temporaries from ``temporary_store``."""
        from warp.fem import cache

        domain = self.domain
        topology = self.space_topology
        partition = self.space_partition

        MAX_NODES_PER_ELEMENT = topology.MAX_NODES_PER_ELEMENT

        @cache.dynamic_kernel(
            suffix=f"{self.domain.name}_{self.space_topology.name}_{self.space_partition.name}",
            kernel_options={"max_unroll": 8},
        )
        def fill_element_node_indices(
            element_arg: self.domain.ElementArg,
            domain_index_arg: self.domain.ElementIndexArg,
            topo_arg: self.space_topology.TopologyArg,
            partition_arg: self.space_partition.PartitionArg,
            element_node_indices: wp.array2d(dtype=int),
        ):
            # One thread per domain element; row = that element's partition node indices
            row = wp.tid()
            cell = self.domain.element_index(domain_index_arg, row)
            used_slots = self.space_topology.element_node_count(element_arg, topo_arg, cell)

            for slot in range(used_slots):
                space_node = self.space_topology.element_node_index(element_arg, topo_arg, cell, slot)
                element_node_indices[row, slot] = self.space_partition.partition_node_index(partition_arg, space_node)

            # Pad the rest of the row so every row has MAX_NODES_PER_ELEMENT entries
            for slot in range(used_slots, MAX_NODES_PER_ELEMENT):
                element_node_indices[row, slot] = NULL_NODE_INDEX

        element_node_indices = borrow_temporary(
            temporary_store,
            shape=(domain.element_count(), MAX_NODES_PER_ELEMENT),
            dtype=int,
            device=device,
        )
        element_rows = element_node_indices.array.shape[0]
        wp.launch(
            kernel=fill_element_node_indices,
            dim=element_rows,
            inputs=[
                domain.element_arg_value(device),
                domain.element_index_arg_value(device),
                topology.topo_arg_value(device),
                partition.partition_arg_value(device),
                element_node_indices.array,
            ],
            device=device,
        )

        # Compress the flattened (element, slot) -> node table into a node -> element map.
        # NOTE(review): output meanings inferred from names (per-node offsets, sorted flat
        # positions, unique node count, unique node list) — confirm against compress_node_indices.
        flat_node_indices = element_node_indices.array.flatten()
        (
            self._dof_partition_element_offsets,
            node_array_indices,
            self._node_count,
            self._dof_partition_indices,
        ) = compress_node_indices(
            partition.node_count(),
            flat_node_indices,
            return_unique_nodes=True,
            temporary_store=temporary_store,
        )

        # Decode each flat position into (domain element index, slot within the element)
        self._dof_element_indices = borrow_temporary_like(flat_node_indices, temporary_store)
        self._dof_indices_in_element = borrow_temporary_like(flat_node_indices, temporary_store)
        wp.launch(
            kernel=SpaceRestriction._split_vertex_element_index,
            dim=flat_node_indices.shape,
            inputs=[
                MAX_NODES_PER_ELEMENT,
                node_array_indices.array,
                self._dof_element_indices.array,
                self._dof_indices_in_element.array,
            ],
            device=flat_node_indices.device,
        )

        node_array_indices.release()

    def node_count(self):
        """Node count reported by ``compress_node_indices`` for this restriction."""
        return self._node_count

    def partition_element_offsets(self):
        """Array of element-range offsets, indexed by node."""
        return self._dof_partition_element_offsets.array

    def node_partition_indices(self):
        """Partition node index of each restriction node."""
        return self._dof_partition_indices.array

    def total_node_element_count(self):
        """Total number of (node, element) entries, including padding slots."""
        return self._dof_element_indices.array.size

    @wp.struct
    class NodeArg:
        # Start offset of each node's element range (range end is read at index + 1)
        dof_element_offsets: wp.array(dtype=int)
        # Domain element index for each (node, element) entry
        dof_element_indices: wp.array(dtype=int)
        # Partition node index for each restriction node
        dof_partition_indices: wp.array(dtype=int)
        # Index of the node within the element for each (node, element) entry
        dof_indices_in_element: wp.array(dtype=int)

    @cached_arg_value
    def node_arg(self, device):
        """Build a :class:`SpaceRestriction.NodeArg` whose arrays live on ``device``."""
        arg = SpaceRestriction.NodeArg()
        arg.dof_element_offsets = self._dof_partition_element_offsets.array.to(device)
        arg.dof_element_indices = self._dof_element_indices.array.to(device)
        arg.dof_partition_indices = self._dof_partition_indices.array.to(device)
        arg.dof_indices_in_element = self._dof_indices_in_element.array.to(device)
        return arg

    # Partition node index of the given restriction node
    @wp.func
    def node_partition_index(args: NodeArg, restriction_node_index: int):
        partition_index = args.dof_partition_indices[restriction_node_index]
        return partition_index

    # Half-open [begin, end) range of (node, element) entries for the given node
    @wp.func
    def node_element_range(args: NodeArg, partition_node_index: int):
        range_begin = args.dof_element_offsets[partition_node_index]
        range_end = args.dof_element_offsets[partition_node_index + 1]
        return range_begin, range_end

    # Domain element and index-in-element stored at a given (node, element) entry
    @wp.func
    def node_element_index(args: NodeArg, node_element_offset: int):
        element = args.dof_element_indices[node_element_offset]
        slot = args.dof_indices_in_element[node_element_offset]
        return NodeElementIndex(element, slot)

    # Split each flat position (element * vertex_per_element + slot) into element and slot
    @wp.kernel
    def _split_vertex_element_index(
        vertex_per_element: int,
        sorted_indices: wp.array(dtype=int),
        vertex_element_index: wp.array(dtype=int),
        vertex_index_in_element: wp.array(dtype=int),
    ):
        i = wp.tid()
        flat_position = sorted_indices[i]
        element = flat_position // vertex_per_element
        vertex_element_index[i] = element
        vertex_index_in_element[i] = flat_position - vertex_per_element * element
@@ -0,0 +1,143 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from enum import Enum
17
+ from typing import Optional
18
+
19
+ from warp.fem.geometry import element as _element
20
+ from warp.fem.polynomial import Polynomial
21
+
22
+ from .cube_shape_function import (
23
+ CubeNedelecFirstKindShapeFunctions,
24
+ CubeNonConformingPolynomialShapeFunctions,
25
+ CubeRaviartThomasShapeFunctions,
26
+ CubeSerendipityShapeFunctions,
27
+ CubeShapeFunction,
28
+ CubeTripolynomialShapeFunctions,
29
+ )
30
+ from .shape_function import ConstantShapeFunction, ShapeFunction
31
+ from .square_shape_function import (
32
+ SquareBipolynomialShapeFunctions,
33
+ SquareNedelecFirstKindShapeFunctions,
34
+ SquareNonConformingPolynomialShapeFunctions,
35
+ SquareRaviartThomasShapeFunctions,
36
+ SquareSerendipityShapeFunctions,
37
+ SquareShapeFunction,
38
+ )
39
+ from .tet_shape_function import (
40
+ TetrahedronNedelecFirstKindShapeFunctions,
41
+ TetrahedronNonConformingPolynomialShapeFunctions,
42
+ TetrahedronPolynomialShapeFunctions,
43
+ TetrahedronRaviartThomasShapeFunctions,
44
+ TetrahedronShapeFunction,
45
+ )
46
+ from .triangle_shape_function import (
47
+ TriangleNedelecFirstKindShapeFunctions,
48
+ TriangleNonConformingPolynomialShapeFunctions,
49
+ TrianglePolynomialShapeFunctions,
50
+ TriangleRaviartThomasShapeFunctions,
51
+ TriangleShapeFunction,
52
+ )
53
+
54
+
55
class ElementBasis(Enum):
    """Choice of basis function to equip individual elements.

    Each member's value is a short string tag. :func:`get_shape_function` combines a
    member with a reference element and a polynomial degree to build a concrete
    shape function.
    """

    LAGRANGE = "P"
    """Lagrange basis functions :math:`P_k` for simplices, tensor products :math:`Q_k` for squares and cubes"""
    # get_shape_function falls back to the Lagrange basis for degree-1 serendipity
    # on squares/cubes and for degree <= 2 serendipity on triangles/tetrahedra.
    SERENDIPITY = "S"
    """Serendipity elements :math:`S_k`, corresponding to Lagrange nodes with interior points removed (for degree <= 3)"""
    NONCONFORMING_POLYNOMIAL = "dP"
    """Simplex Lagrange basis functions :math:`P_{kd}` embedded into non conforming reference elements (e.g. squares or cubes). Discontinuous only."""
    NEDELEC_FIRST_KIND = "N1"
    """Nédélec (first kind) H(curl) shape functions. Should be used with covariant function space."""
    RAVIART_THOMAS = "RT"
    """Raviart-Thomas H(div) shape functions. Should be used with contravariant function space."""
68
+
69
+
70
def get_shape_function(
    element: _element.Element,
    space_dimension: int,
    degree: int,
    element_basis: ElementBasis,
    family: Optional[Polynomial] = None,
):
    """
    Equips a reference element with a shape function basis.

    Args:
        element: the reference element on which to build the shape function
        space_dimension: the dimension of the embedding space. Only used here when
            ``degree == 0``, to build a :class:`ConstantShapeFunction`.
        degree: polynomial degree of the per-element shape functions
        element_basis: type of basis function for the individual elements
        family: Polynomial family used to generate the shape function basis. If not
            provided, ``Polynomial.LOBATTO_GAUSS_LEGENDRE`` is used. It is only forwarded
            to the Lagrange and serendipity bases on squares and cubes.

    Returns:
        the corresponding shape function

    Raises:
        NotImplementedError: if a serendipity basis of degree greater than 2 is requested
            on a triangle or tetrahedron, or if ``element`` is not a recognized reference
            element type.
    """

    if degree == 0:
        return ConstantShapeFunction(element, space_dimension)

    if family is None:
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE

    if isinstance(element, _element.Square):
        if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
            return SquareNedelecFirstKindShapeFunctions(degree=degree)
        if element_basis == ElementBasis.RAVIART_THOMAS:
            return SquareRaviartThomasShapeFunctions(degree=degree)
        if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            return SquareNonConformingPolynomialShapeFunctions(degree=degree)
        if element_basis == ElementBasis.SERENDIPITY and degree > 1:
            return SquareSerendipityShapeFunctions(degree=degree, family=family)

        # Lagrange, and degree-1 serendipity, use the tensor-product basis
        return SquareBipolynomialShapeFunctions(degree=degree, family=family)

    if isinstance(element, _element.Triangle):
        if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
            return TriangleNedelecFirstKindShapeFunctions(degree=degree)
        if element_basis == ElementBasis.RAVIART_THOMAS:
            return TriangleRaviartThomasShapeFunctions(degree=degree)
        if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            return TriangleNonConformingPolynomialShapeFunctions(degree=degree)
        if element_basis == ElementBasis.SERENDIPITY and degree > 2:
            raise NotImplementedError("Serendipity variant not implemented yet for Triangle elements")

        # Lagrange, and serendipity of degree <= 2, use the full polynomial basis
        return TrianglePolynomialShapeFunctions(degree=degree)

    if isinstance(element, _element.Cube):
        if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
            return CubeNedelecFirstKindShapeFunctions(degree=degree)
        if element_basis == ElementBasis.RAVIART_THOMAS:
            return CubeRaviartThomasShapeFunctions(degree=degree)
        if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            return CubeNonConformingPolynomialShapeFunctions(degree=degree)
        if element_basis == ElementBasis.SERENDIPITY and degree > 1:
            return CubeSerendipityShapeFunctions(degree=degree, family=family)

        # Lagrange, and degree-1 serendipity, use the tensor-product basis
        return CubeTripolynomialShapeFunctions(degree=degree, family=family)

    if isinstance(element, _element.Tetrahedron):
        if element_basis == ElementBasis.NEDELEC_FIRST_KIND:
            return TetrahedronNedelecFirstKindShapeFunctions(degree=degree)
        if element_basis == ElementBasis.RAVIART_THOMAS:
            return TetrahedronRaviartThomasShapeFunctions(degree=degree)
        if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            return TetrahedronNonConformingPolynomialShapeFunctions(degree=degree)
        if element_basis == ElementBasis.SERENDIPITY and degree > 2:
            raise NotImplementedError("Serendipity variant not implemented yet for Tet elements")

        # Lagrange, and serendipity of degree <= 2, use the full polynomial basis
        return TetrahedronPolynomialShapeFunctions(degree=degree)

    # Previously this *returned* the exception instance, silently handing callers an
    # exception object in place of a shape function; it must be raised.
    raise NotImplementedError("Unrecognized element type")