warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,222 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warp as wp
17
+ from warp._src.fem import cache
18
+ from warp._src.fem.geometry import Quadmesh2D
19
+ from warp._src.fem.polynomial import is_closed
20
+ from warp._src.fem.types import NULL_NODE_INDEX, ElementIndex
21
+
22
+ from .shape import SquareShapeFunction
23
+ from .topology import SpaceTopology, forward_base_topology
24
+
25
+ _wp_module_name_ = "warp.fem.space.quadmesh_function_space"
26
+
27
+
28
@wp.struct
class Quadmesh2DTopologyArg:
    """Device-side arguments describing the edge topology of a 2D quadmesh.

    Instances are populated by ``QuadmeshSpaceTopology.fill_topo_arg``.
    """

    # Endpoint vertex indices of each edge (copied from the mesh's
    # ``edge_vertex_indices`` array).
    edge_vertex_indices: wp.array(dtype=wp.vec2i)
    # For each quad, the global edge index stored in each of its 4 edge slots;
    # allocated with shape (cell_count, 4).
    quad_edge_indices: wp.array2d(dtype=int)

    # Entity counts, read from the mesh's vertex_count(), side_count() and
    # cell_count() respectively.
    vertex_count: int
    edge_count: int
    cell_count: int
36
+
37
+
38
class QuadmeshSpaceTopology(SpaceTopology):
    """Space topology assigning global node indices on a ``Quadmesh2D``.

    Global node indices are laid out in three consecutive ranges: vertex nodes
    first, then cell-interior nodes, then edge nodes (see
    ``_make_element_node_index``).
    """

    TopologyArg = Quadmesh2DTopologyArg

    def __init__(self, mesh: Quadmesh2D, shape: SquareShapeFunction):
        """Create the topology for ``mesh`` using the node layout of ``shape``.

        Raises:
            ValueError: if ``shape`` is scalar-valued and its polynomial family
                is not closed.
        """
        is_scalar = shape.value == SquareShapeFunction.Value.Scalar
        if is_scalar and not is_closed(shape.family):
            raise ValueError("A closed polynomial family is required to define a continuous function space")

        # The shape is stored before the base-class constructor runs and the
        # mesh only afterwards; this ordering is kept exactly as-is.
        # NOTE(review): presumably the base constructor reads `self.name`,
        # which needs `self._shape` -- confirm against SpaceTopology.
        self._shape = shape
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh

        # The quad -> edge table must exist before the device functions that
        # read it (through the topology argument) are used.
        self._compute_quad_edge_indices()
        self.element_node_index = self._make_element_node_index()
        self.element_node_sign = self._make_element_node_sign()

    @property
    def name(self):
        """Identifier built from the geometry name and the shape name."""
        return f"{self.geometry.name}_{self._shape.name}"

    def fill_topo_arg(self, arg: Quadmesh2DTopologyArg, device):
        """Populate ``arg`` with the topology arrays (moved to ``device``) and entity counts."""
        mesh = self._mesh

        arg.quad_edge_indices = self._quad_edge_indices.to(device)
        arg.edge_vertex_indices = mesh.edge_vertex_indices.to(device)

        arg.vertex_count = mesh.vertex_count()
        arg.edge_count = mesh.side_count()
        arg.cell_count = mesh.cell_count()

    def _compute_quad_edge_indices(self):
        """Build ``self._quad_edge_indices``, a (cell_count, 4) table of edge indices per quad."""
        mesh = self._mesh
        device = mesh.quad_vertex_indices.device

        # Allocated with wp.empty (no initialization); entries are written by
        # the kernel launched below.
        # NOTE(review): assumes every quad slot is reached by some edge -- confirm.
        self._quad_edge_indices = wp.empty(dtype=int, device=device, shape=(mesh.cell_count(), 4))

        # One thread per entry of the mesh's edge_quad_indices array.
        wp.launch(
            kernel=QuadmeshSpaceTopology._compute_quad_edge_indices_kernel,
            dim=mesh.edge_quad_indices.shape,
            device=device,
            inputs=[
                mesh.edge_quad_indices,
                mesh.edge_vertex_indices,
                mesh.quad_vertex_indices,
                self._quad_edge_indices,
            ],
        )

    @wp.func
    def _find_edge_index_in_quad(
        edge_vtx: wp.vec2i,
        quad_vtx: wp.vec4i,
    ):
        # Returns the slot k such that the edge joins quad vertices k and k + 1,
        # matching the endpoints in either orientation.
        for slot in range(3):
            if (edge_vtx[0] == quad_vtx[slot] and edge_vtx[1] == quad_vtx[slot + 1]) or (
                edge_vtx[1] == quad_vtx[slot] and edge_vtx[0] == quad_vtx[slot + 1]
            ):
                return slot
        # Slots 0..2 did not match: slot 3 is returned without further checking.
        return 3

    @wp.kernel
    def _compute_quad_edge_indices_kernel(
        edge_quad_indices: wp.array(dtype=wp.vec2i),
        edge_vertex_indices: wp.array(dtype=wp.vec2i),
        quad_vertex_indices: wp.array2d(dtype=int),
        quad_edge_indices: wp.array2d(dtype=int),
    ):
        # Each thread handles one edge and records its index in the matching
        # slot of the quad(s) listed for that edge.
        e = wp.tid()

        endpoints = edge_vertex_indices[e]
        adjacent_quads = edge_quad_indices[e]

        first_quad = adjacent_quads[0]
        first_quad_vtx = wp.vec4i(
            quad_vertex_indices[first_quad, 0],
            quad_vertex_indices[first_quad, 1],
            quad_vertex_indices[first_quad, 2],
            quad_vertex_indices[first_quad, 3],
        )
        first_slot = QuadmeshSpaceTopology._find_edge_index_in_quad(endpoints, first_quad_vtx)
        quad_edge_indices[first_quad, first_slot] = e

        # The second entry is only processed when it names a different quad.
        # NOTE(review): presumably boundary edges repeat the same quad -- confirm.
        second_quad = adjacent_quads[1]
        if second_quad != first_quad:
            second_quad_vtx = wp.vec4i(
                quad_vertex_indices[second_quad, 0],
                quad_vertex_indices[second_quad, 1],
                quad_vertex_indices[second_quad, 2],
                quad_vertex_indices[second_quad, 3],
            )
            second_slot = QuadmeshSpaceTopology._find_edge_index_in_quad(endpoints, second_quad_vtx)
            quad_edge_indices[second_quad, second_slot] = e

    def node_count(self) -> int:
        """Return the total number of nodes: vertex, edge and cell-interior nodes summed."""
        geometry = self.geometry
        shape = self._shape

        vertex_nodes = geometry.vertex_count() * shape.VERTEX_NODE_COUNT
        edge_nodes = geometry.side_count() * shape.EDGE_NODE_COUNT
        interior_nodes = geometry.cell_count() * shape.INTERIOR_NODE_COUNT

        return vertex_nodes + edge_nodes + interior_nodes

    def _make_element_node_index(self):
        """Create the device function mapping (element, local node) to a global node index."""
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
        INTERIOR_NODE_COUNT = self._shape.INTERIOR_NODE_COUNT

        # Lookup from the shape function's vertex instance to the column of
        # quad_vertex_indices that is read for it.
        SHAPE_TO_QUAD_IDX = wp.constant(wp.vec4i([0, 3, 1, 2]))

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: self._mesh.CellArg,
            topo_arg: QuadmeshSpaceTopology.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Range 1: vertex nodes, indexed by global vertex index.
            if wp.static(VERTEX_NODE_COUNT > 0):
                if node_type == SquareShapeFunction.VERTEX:
                    return (
                        cell_arg.topology.quad_vertex_indices[element_index, SHAPE_TO_QUAD_IDX[type_instance]]
                        * VERTEX_NODE_COUNT
                        + type_index
                    )

            node_offset = topo_arg.vertex_count * VERTEX_NODE_COUNT

            # Range 2: interior nodes, indexed by element index.
            if wp.static(INTERIOR_NODE_COUNT > 0):
                if node_type == SquareShapeFunction.INTERIOR:
                    return node_offset + element_index * INTERIOR_NODE_COUNT + type_index

                node_offset += INTERIOR_NODE_COUNT * topo_arg.cell_count

            # Range 3: edge nodes (EDGE_X, EDGE_Y), indexed by global edge index.
            if wp.static(EDGE_NODE_COUNT > 0):
                # Quad edge slot for this node: EDGE_X uses slots 0 / 2, otherwise 3 / 1.
                local_side = wp.where(
                    node_type == SquareShapeFunction.EDGE_X,
                    wp.where(type_instance == 0, 0, 2),
                    wp.where(type_instance == 0, 3, 1),
                )

                edge_index = topo_arg.quad_edge_indices[element_index, local_side]
                quad_start_vtx = cell_arg.topology.quad_vertex_indices[element_index, local_side]
                edge_start_vtx = topo_arg.edge_vertex_indices[edge_index][0]

                # Flip indexing direction
                reverse = int(local_side >= 2) ^ int(quad_start_vtx != edge_start_vtx)
                index_in_edge = wp.where(reverse, EDGE_NODE_COUNT - 1 - type_index, type_index)

                return node_offset + EDGE_NODE_COUNT * edge_index + index_in_edge

            return NULL_NODE_INDEX  # should never happen

        return element_node_index

    def _make_element_node_sign(self):
        """Create the device function giving the sign (+1.0 / -1.0) of an element node."""

        @cache.dynamic_func(suffix=self.name)
        def element_node_sign(
            cell_arg: self._mesh.CellArg,
            topo_arg: QuadmeshSpaceTopology.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Only edge nodes can get a negative sign; it uses the same
            # direction test as the edge branch of element_node_index.
            if node_type == SquareShapeFunction.EDGE_X or node_type == SquareShapeFunction.EDGE_Y:
                local_side = wp.where(
                    node_type == SquareShapeFunction.EDGE_X,
                    wp.where(type_instance == 0, 0, 2),
                    wp.where(type_instance == 0, 3, 1),
                )

                edge_index = topo_arg.quad_edge_indices[element_index, local_side]
                quad_start_vtx = cell_arg.topology.quad_vertex_indices[element_index, local_side]
                edge_start_vtx = topo_arg.edge_vertex_indices[edge_index][0]

                # Flip indexing direction
                reverse = int(local_side >= 2) ^ int(quad_start_vtx != edge_start_vtx)
                return wp.where(reverse, -1.0, 1.0)

            return 1.0

        return element_node_sign
216
+
217
+
218
def make_quadmesh_space_topology(mesh: Quadmesh2D, shape: SquareShapeFunction):
    """Build the space topology for a quadmesh and a square shape function.

    Args:
        mesh: The quadrilateral mesh the topology is defined on.
        shape: The shape function providing the per-element node layout.

    Returns:
        The result of ``forward_base_topology(QuadmeshSpaceTopology, mesh, shape)``.

    Raises:
        ValueError: if ``shape`` is not a ``SquareShapeFunction``.
    """
    if not isinstance(shape, SquareShapeFunction):
        # An unsupported object is not guaranteed to expose `.name` (e.g. None);
        # fall back to its repr so the documented ValueError is raised rather
        # than an AttributeError from building the message.
        shape_label = getattr(shape, "name", repr(shape))
        raise ValueError(f"Unsupported shape function {shape_label}")

    return forward_base_topology(QuadmeshSpaceTopology, mesh, shape)
@@ -0,0 +1,221 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ import warp as wp
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.domain import GeometryDomain
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, NULL_NODE_INDEX, NodeElementIndex
22
+ from warp._src.fem.utils import compress_node_indices, host_read_at_index
23
+
24
+ from .partition import SpacePartition
25
+
26
+ _wp_module_name_ = "warp.fem.space.restriction"
27
+
28
+ wp.set_module_options({"enable_backward": False})
29
+
30
+
31
class SpaceRestriction:
    """Restriction of a space partition to the elements of a given GeometryDomain.

    For each partition node referenced by a domain element, keeps a compressed
    list of the (domain element, node index in element) pairs that reference it.
    """

    def __init__(
        self,
        space_partition: SpacePartition,
        domain: GeometryDomain,
        device=None,
        temporary_store: cache.TemporaryStore = None,
    ):
        """Set up the restriction and build its index arrays.

        Args:
            space_partition: the space partition to restrict
            domain: the geometry domain to restrict the partition to
            device: forwarded to :meth:`rebuild`
            temporary_store: forwarded to :meth:`rebuild`

        Raises:
            ValueError: if the domain dimension matches neither that of the
                space topology nor that of its trace
        """
        topology = space_partition.space_topology

        # When the domain is one dimension lower than the space, work with the trace topology
        if domain.dimension == topology.dimension - 1:
            topology = topology.trace()

        if domain.dimension != topology.dimension:
            raise ValueError("Incompatible space and domain dimensions")

        self.space_partition = space_partition
        self.space_topology = topology
        self.domain = domain

        # Number of unique partition node indices
        self._node_count_dev: wp.array = None
        # Array of unique partition node indices
        self._dof_partition_indices: wp.array = None

        # Mapping from partition node to offset in the per-node element indices array
        self._dof_partition_element_offsets: wp.array = None
        # Concatenation of neighboring elements indices for each partition node
        self._dof_element_indices: wp.array = None
        # Concatenation of node index in element for each partition node
        self._dof_indices_in_element: wp.array = None

        self.rebuild(device=device, temporary_store=temporary_store)

    def rebuild(self, device: Optional = None, temporary_store: Optional[cache.TemporaryStore] = None):
        """Recompute the node-to-element arrays of the restriction.

        Args:
            device: device passed to the temporary allocation, the argument
                values and the kernel launch
            temporary_store: store passed to the temporary-array helpers
        """
        domain = self.domain
        topology = self.space_topology
        partition = self.space_partition
        nodes_per_element = topology.MAX_NODES_PER_ELEMENT

        # One row per domain element, listing the partition index of each of its
        # nodes and padded with NULL_NODE_INDEX
        @cache.dynamic_kernel(
            suffix=(self.domain.name, self.space_topology.name, self.space_partition.name),
            kernel_options={"max_unroll": 8},
        )
        def fill_element_node_indices(
            element_arg: self.domain.ElementArg,
            domain_index_arg: self.domain.ElementIndexArg,
            topo_arg: self.space_topology.TopologyArg,
            partition_arg: self.space_partition.PartitionArg,
            element_node_indices: wp.array2d(dtype=int),
        ):
            row = wp.tid()
            elt = self.domain.element_index(domain_index_arg, row)

            if elt != NULL_ELEMENT_INDEX:
                node_count = self.space_topology.element_node_count(element_arg, topo_arg, elt)
            else:
                node_count = 0

            for n in range(node_count):
                space_node = self.space_topology.element_node_index(element_arg, topo_arg, elt, n)
                element_node_indices[row, n] = self.space_partition.partition_node_index(partition_arg, space_node)
            for n in range(node_count, element_node_indices.shape[1]):
                element_node_indices[row, n] = NULL_NODE_INDEX

        element_node_indices = cache.borrow_temporary(
            temporary_store,
            shape=(domain.element_count(), nodes_per_element),
            dtype=int,
            device=device,
        )
        launch_inputs = [
            domain.element_arg_value(device),
            domain.element_index_arg_value(device),
            topology.topo_arg_value(device),
            partition.partition_arg_value(device),
            element_node_indices,
        ]
        wp.launch(
            dim=element_node_indices.shape[0],
            kernel=fill_element_node_indices,
            inputs=launch_inputs,
            device=device,
        )

        # Build compressed map from node to element indices, passing the arrays
        # from the previous build as arguments
        flat_indices = element_node_indices.flatten()
        compressed = compress_node_indices(
            partition.node_count(),
            flat_indices,
            node_offsets=self._dof_partition_element_offsets,
            unique_node_count=self._node_count_dev,
            unique_node_indices=self._dof_partition_indices,
            return_unique_nodes=True,
            temporary_store=temporary_store,
        )
        (
            self._dof_partition_element_offsets,
            sorted_array_indices,
            self._node_count_dev,
            self._dof_partition_indices,
        ) = compressed

        # Split each sorted flat index into element index and index in element;
        # output arrays are only reallocated when the flattened shape changed
        stale = self._dof_element_indices is None or self._dof_element_indices.shape != flat_indices.shape
        if stale:
            self._dof_element_indices = cache.borrow_temporary_like(flat_indices, temporary_store)
            self._dof_indices_in_element = cache.borrow_temporary_like(flat_indices, temporary_store)

        wp.launch(
            kernel=SpaceRestriction._split_vertex_element_index,
            dim=flat_indices.shape,
            inputs=[
                nodes_per_element,
                sorted_array_indices,
                self._dof_element_indices,
                self._dof_indices_in_element,
            ],
            device=flat_indices.device,
        )

        sorted_array_indices.release()

        # Upper bound on node count, use `node_count_sync` to get the actual value
        upper_bound = min(partition.node_count(), self._dof_partition_indices.shape[0])
        self._node_count = upper_bound

    def node_count_sync(self) -> int:
        """Read the node count back from the device if pending, then return it."""
        pending = self._node_count_dev
        if pending is not None:
            self._node_count = int(host_read_at_index(pending, index=0))
            # The value is now cached on the host; drop the device-side counter
            self._node_count_dev = None
        return self.node_count()

    def node_count(self) -> int:
        """Return an upper bound for the node count; `node_count_sync` yields the actual value."""
        return self._node_count

    def partition_element_offsets(self):
        """Return the array mapping each partition node to its offset in the per-node element arrays."""
        return self._dof_partition_element_offsets

    def node_partition_indices(self):
        """Return the array of unique partition node indices."""
        return self._dof_partition_indices

    def total_node_element_count(self):
        """Return the size of the concatenated per-node element index array."""
        return self._dof_element_indices.size

    # Device-side view of the restriction arrays
    @wp.struct
    class NodeArg:
        dof_element_offsets: wp.array(dtype=int)
        dof_element_indices: wp.array(dtype=int)
        dof_partition_indices: wp.array(dtype=int)
        dof_indices_in_element: wp.array(dtype=int)

    @cache.cached_arg_value
    def node_arg_value(self, device):
        """Create a ``NodeArg`` struct and fill it for ``device``."""
        node_arg = SpaceRestriction.NodeArg()
        self.fill_node_arg(node_arg, device)
        return node_arg

    def fill_node_arg(self, arg: NodeArg, device):
        """Assign the restriction arrays, moved to ``device``, to the fields of ``arg``."""
        arg.dof_element_offsets = self._dof_partition_element_offsets.to(device)
        arg.dof_element_indices = self._dof_element_indices.to(device)
        arg.dof_partition_indices = self._dof_partition_indices.to(device)
        arg.dof_indices_in_element = self._dof_indices_in_element.to(device)

    # Partition node index stored for a given restriction node
    @wp.func
    def node_partition_index(args: NodeArg, restriction_node_index: int):
        partition_index = args.dof_partition_indices[restriction_node_index]
        return partition_index

    # Looks up the node owning a given offset by searching the offsets array
    @wp.func
    def node_partition_index_from_element_offset(args: NodeArg, element_offset: int):
        upper = wp.lower_bound(args.dof_element_offsets, element_offset + 1)
        return upper - 1

    # Begin and end offsets of the element entries of a node
    @wp.func
    def node_element_range(args: NodeArg, partition_node_index: int):
        range_begin = args.dof_element_offsets[partition_node_index]
        range_end = args.dof_element_offsets[partition_node_index + 1]
        return range_begin, range_end

    # Domain element index and index within that element, for a given offset
    @wp.func
    def node_element_index(args: NodeArg, node_element_offset: int):
        return NodeElementIndex(
            args.dof_element_indices[node_element_offset],
            args.dof_indices_in_element[node_element_offset],
        )

    # Splits each flat index into (element index, index within the element)
    @wp.kernel
    def _split_vertex_element_index(
        vertex_per_element: int,
        sorted_indices: wp.array(dtype=int),
        vertex_element_index: wp.array(dtype=int),
        vertex_index_in_element: wp.array(dtype=int),
    ):
        tid = wp.tid()
        flat_index = sorted_indices[tid]
        elt = flat_index // vertex_per_element
        vertex_element_index[tid] = elt
        vertex_index_in_element[tid] = flat_index - vertex_per_element * elt
@@ -0,0 +1,152 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ from enum import Enum
18
+ from typing import Optional
19
+
20
+ from warp._src.fem.geometry import Element
21
+ from warp._src.fem.polynomial import Polynomial
22
+
23
+ from .cube_shape_function import (
24
+ CubeNedelecFirstKindShapeFunctions,
25
+ CubeNonConformingPolynomialShapeFunctions,
26
+ CubeRaviartThomasShapeFunctions,
27
+ CubeSerendipityShapeFunctions,
28
+ CubeShapeFunction,
29
+ CubeTripolynomialShapeFunctions,
30
+ )
31
+ from .shape_function import ConstantShapeFunction, ShapeFunction
32
+ from .square_shape_function import (
33
+ SquareBipolynomialShapeFunctions,
34
+ SquareNedelecFirstKindShapeFunctions,
35
+ SquareNonConformingPolynomialShapeFunctions,
36
+ SquareRaviartThomasShapeFunctions,
37
+ SquareSerendipityShapeFunctions,
38
+ SquareShapeFunction,
39
+ )
40
+ from .tet_shape_function import (
41
+ TetrahedronNedelecFirstKindShapeFunctions,
42
+ TetrahedronNonConformingPolynomialShapeFunctions,
43
+ TetrahedronPolynomialShapeFunctions,
44
+ TetrahedronRaviartThomasShapeFunctions,
45
+ TetrahedronShapeFunction,
46
+ )
47
+ from .triangle_shape_function import (
48
+ TriangleNedelecFirstKindShapeFunctions,
49
+ TriangleNonConformingPolynomialShapeFunctions,
50
+ TrianglePolynomialShapeFunctions,
51
+ TriangleRaviartThomasShapeFunctions,
52
+ TriangleShapeFunction,
53
+ )
54
+
55
+
56
class ElementBasis(Enum):
    """Family of basis functions with which individual elements may be equipped"""

    LAGRANGE = "P"
    """Lagrange basis: :math:`P_k` on simplices, tensor-product :math:`Q_k` on squares and cubes"""

    SERENDIPITY = "S"
    """Serendipity elements :math:`S_k`: Lagrange nodes with the interior points removed (for degree <= 3)"""

    NONCONFORMING_POLYNOMIAL = "dP"
    """Simplex Lagrange basis :math:`P_{kd}` embedded into non conforming reference elements (e.g. squares or cubes). Discontinuous only."""

    NEDELEC_FIRST_KIND = "N1"
    """Nédélec (first kind) H(curl) shape functions. To be used with a covariant function space."""

    RAVIART_THOMAS = "RT"
    """Raviart-Thomas H(div) shape functions. To be used with a contravariant function space."""
69
+
70
+
71
@functools.lru_cache(maxsize=None)
def make_element_shape_function(
    element: Element,
    degree: int,
    element_basis: Optional[ElementBasis] = None,
    family: Optional[Polynomial] = None,
) -> ShapeFunction:
    """
    Builds the shape function basis for a reference element.

    Results are memoized per argument tuple.

    Args:
        element: reference element type on which the shape function is defined
        degree: polynomial degree of the per-element shape functions; degree 0 always yields a constant shape function
        element_basis: kind of basis function for the individual elements, defaults to Lagrange
        family: polynomial family used to generate the basis, defaults to Lobatto-Gauss-Legendre.
            Only forwarded to the square and cube Lagrange and serendipity variants.

    Returns:
        the shape function matching the requested element, degree and basis

    Raises:
        NotImplementedError: If the element type is not recognized, or if a serendipity basis
            of degree above 2 is requested on a triangle or tetrahedron
    """

    # Degree-1 serendipity is always equivalent to Lagrange
    basis = element_basis
    if basis is None or (basis == ElementBasis.SERENDIPITY and degree == 1):
        basis = ElementBasis.LAGRANGE

    if degree == 0:
        return ConstantShapeFunction(element)

    poly_family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

    if element == Element.SQUARE:
        if basis == ElementBasis.NEDELEC_FIRST_KIND:
            shape = SquareNedelecFirstKindShapeFunctions(degree=degree)
        elif basis == ElementBasis.RAVIART_THOMAS:
            shape = SquareRaviartThomasShapeFunctions(degree=degree)
        elif basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            shape = SquareNonConformingPolynomialShapeFunctions(degree=degree)
        elif basis == ElementBasis.SERENDIPITY and degree > 1:
            shape = SquareSerendipityShapeFunctions(degree=degree, family=poly_family)
        else:
            shape = SquareBipolynomialShapeFunctions(degree=degree, family=poly_family)
        return shape

    if element == Element.TRIANGLE:
        if basis == ElementBasis.NEDELEC_FIRST_KIND:
            shape = TriangleNedelecFirstKindShapeFunctions(degree=degree)
        elif basis == ElementBasis.RAVIART_THOMAS:
            shape = TriangleRaviartThomasShapeFunctions(degree=degree)
        elif basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            shape = TriangleNonConformingPolynomialShapeFunctions(degree=degree)
        elif basis == ElementBasis.SERENDIPITY and degree > 2:
            raise NotImplementedError("Serendipity variant not implemented yet for Triangle elements")
        else:
            # Also reached for serendipity of degree 2 or lower
            shape = TrianglePolynomialShapeFunctions(degree=degree)
        return shape

    if element == Element.CUBE:
        if basis == ElementBasis.NEDELEC_FIRST_KIND:
            shape = CubeNedelecFirstKindShapeFunctions(degree=degree)
        elif basis == ElementBasis.RAVIART_THOMAS:
            shape = CubeRaviartThomasShapeFunctions(degree=degree)
        elif basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            shape = CubeNonConformingPolynomialShapeFunctions(degree=degree)
        elif basis == ElementBasis.SERENDIPITY and degree > 1:
            shape = CubeSerendipityShapeFunctions(degree=degree, family=poly_family)
        else:
            shape = CubeTripolynomialShapeFunctions(degree=degree, family=poly_family)
        return shape

    if element == Element.TETRAHEDRON:
        if basis == ElementBasis.NEDELEC_FIRST_KIND:
            shape = TetrahedronNedelecFirstKindShapeFunctions(degree=degree)
        elif basis == ElementBasis.RAVIART_THOMAS:
            shape = TetrahedronRaviartThomasShapeFunctions(degree=degree)
        elif basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            shape = TetrahedronNonConformingPolynomialShapeFunctions(degree=degree)
        elif basis == ElementBasis.SERENDIPITY and degree > 2:
            raise NotImplementedError("Serendipity variant not implemented yet for Tet elements")
        else:
            # Also reached for serendipity of degree 2 or lower
            shape = TetrahedronPolynomialShapeFunctions(degree=degree)
        return shape

    raise NotImplementedError(f"Unrecognized element type {element}")