warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468):
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1107 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp._src.fem import cache
22
+ from warp._src.fem.geometry import Grid3D
23
+ from warp._src.fem.polynomial import Polynomial, is_closed, lagrange_scales, quadrature_1d
24
+ from warp._src.fem.types import Coords
25
+
26
+ from .shape_function import ShapeFunction
27
+ from .tet_shape_function import TetrahedronPolynomialShapeFunctions
28
+
29
+ _wp_module_name_ = "warp.fem.space.shape.cube_shape_function"
30
+
31
+
32
class CubeShapeFunction(ShapeFunction):
    """Common base for shape functions defined on a cube element.

    Provides the node-type tags and the Warp helpers that decode packed
    vertex, edge and face indices.
    """

    # Tags identifying which kind of cube entity a node is attached to
    VERTEX = 0
    EDGE = 1
    FACE = 2
    INTERIOR = 3

    @wp.func
    def _vertex_coords(vidx_in_cell: int):
        # Decode vidx_in_cell == 4 * x + 2 * y + z into its (x, y, z) components
        ix = vidx_in_cell // 4
        remainder = vidx_in_cell - 4 * ix
        iy = remainder // 2
        iz = remainder - 2 * iy
        return wp.vec3i(ix, iy, iz)

    @wp.func
    def _edge_coords(type_instance: int, index_in_side: int):
        # First component: position along the edge, shifted by one.
        # Remaining components: bit 1 and bit 0 of the edge instance.
        along_edge = index_in_side + 1
        bit_one = (type_instance & 2) >> 1
        bit_zero = type_instance & 1
        return wp.vec3i(along_edge, bit_one, bit_zero)

    @wp.func
    def _edge_axis(type_instance: int):
        # Edge instances are grouped by four; the group index is the axis
        axis = type_instance >> 2
        return axis

    @wp.func
    def _face_offset(type_instance: int):
        # Lowest bit of the face instance
        offset = type_instance & 1
        return offset

    @wp.func
    def _face_axis(type_instance: int):
        # Face instances are grouped by two; the group index is the axis
        axis = type_instance >> 1
        return axis
60
+
61
+
62
+ class CubeTripolynomialShapeFunctions(CubeShapeFunction):
63
def __init__(self, degree: int, family: Polynomial):
    """Set up the constants and Warp helpers for tri-polynomial cube shape functions.

    Args:
        degree: polynomial degree along each axis; ``degree + 1`` nodes are used per axis.
        family: polynomial family passed to ``quadrature_1d`` to get the 1D node
            coordinates and weights.
    """
    self.family = family

    nodes_per_axis = degree + 1

    self.ORDER = wp.constant(degree)
    self.NODES_PER_ELEMENT = wp.constant(nodes_per_axis**3)
    self.NODES_PER_SIDE = wp.constant(nodes_per_axis**2)

    if is_closed(self.family):
        # Closed family: nodes are split between vertices, edges, faces and the interior
        inner_nodes_per_axis = max(0, degree - 1)
        self.VERTEX_NODE_COUNT = wp.constant(1)
        self.EDGE_NODE_COUNT = wp.constant(inner_nodes_per_axis)
        self.FACE_NODE_COUNT = wp.constant(inner_nodes_per_axis**2)
        self.INTERIOR_NODE_COUNT = wp.constant(inner_nodes_per_axis**3)
    else:
        # Otherwise every node of the element is counted as interior
        self.VERTEX_NODE_COUNT = wp.constant(0)
        self.EDGE_NODE_COUNT = wp.constant(0)
        self.FACE_NODE_COUNT = wp.constant(0)
        self.INTERIOR_NODE_COUNT = self.NODES_PER_ELEMENT

    # 1D node coordinates and weights, plus the per-node Lagrange scaling factors
    coords_1d, weights_1d = quadrature_1d(point_count=nodes_per_axis, family=family)
    scales_1d = lagrange_scales(coords_1d)

    NodeVec = cache.cached_vec_type(length=nodes_per_axis, dtype=wp.float32)
    self.LOBATTO_COORDS = wp.constant(NodeVec(coords_1d))
    self.LOBATTO_WEIGHT = wp.constant(NodeVec(weights_1d))
    self.LAGRANGE_SCALE = wp.constant(NodeVec(scales_1d))
    self.ORDER_PLUS_ONE = wp.constant(self.ORDER + 1)

    # The helpers below read the constants defined above, so they are built last
    self._node_ijk = self._make_node_ijk()
    self.node_type_and_type_index = self._make_node_type_and_type_index()

    # Face-node helpers are only created above degree 2
    if degree > 2:
        self._face_node_ij = self._make_face_node_ij()
        self._linear_face_node_index = self._make_linear_face_node_index()
96
+
97
@property
def name(self) -> str:
    """Identifier built from the polynomial order and family."""
    order, family = self.ORDER, self.family
    return f"Cube_Q{order}_{family}"
100
+
101
@wp.func
def _vertex_coords_f(vidx_in_cell: int):
    # Decode vidx_in_cell == 4 * x + 2 * y + z and return the components as floats
    ix = vidx_in_cell // 4
    remainder = vidx_in_cell - 4 * ix
    iy = remainder // 2
    iz = remainder - 2 * iy
    return wp.vec3(float(ix), float(iy), float(iz))
107
+
108
def _make_node_ijk(self):
    """Build the Warp function splitting a node index within the element into ``(i, j, k)``."""
    ORDER_PLUS_ONE = self.ORDER_PLUS_ONE

    def node_ijk(
        node_index_in_elt: int,
    ):
        # node_index_in_elt == (i * ORDER_PLUS_ONE + j) * ORDER_PLUS_ONE + k
        plane_size = ORDER_PLUS_ONE * ORDER_PLUS_ONE
        i = node_index_in_elt // plane_size
        in_plane = node_index_in_elt - plane_size * i
        j = in_plane // ORDER_PLUS_ONE
        k = in_plane - ORDER_PLUS_ONE * j
        return i, j, k

    return cache.get_func(node_ijk, self.name)
121
+
122
def _make_face_node_ij(self):
    """Build the Warp function splitting a face node index into its ``(i, j)`` pair."""
    ORDER_MINUS_ONE = wp.constant(self.ORDER - 1)

    def face_node_ij(
        face_node_index: int,
    ):
        # face_node_index == i * ORDER_MINUS_ONE + j
        i = face_node_index // ORDER_MINUS_ONE
        j = face_node_index - ORDER_MINUS_ONE * i
        return wp.vec2i(i, j)

    return cache.get_func(face_node_ij, self.name)
133
+
134
def _make_linear_face_node_index(self):
    """Build the Warp function recombining a face node ``(i, j)`` pair into a linear index."""
    ORDER_MINUS_ONE = wp.constant(self.ORDER - 1)

    def linear_face_node_index(face_node_index: int, face_node_ij: wp.vec2i):
        # Only the (i, j) pair is used; face_node_index is not read
        i = face_node_ij[0]
        j = face_node_ij[1]
        return i * ORDER_MINUS_ONE + j

    return cache.get_func(linear_face_node_index, self.name)
141
+
142
def _make_node_type_and_type_index(self):
    """Build the Warp function classifying a node of the element.

    The returned function maps a node index within the element to a
    ``(node type, type instance, index within instance)`` triplet. When the
    family is not closed, every node is reported as interior.
    """
    ORDER = self.ORDER

    @cache.dynamic_func(suffix=self.name)
    def node_type_and_type_index_open(
        node_index_in_elt: int,
    ):
        # All nodes are interior; the node index is reused as the index within the type
        return CubeShapeFunction.INTERIOR, 0, node_index_in_elt

    @cache.dynamic_func(suffix=self.name)
    def node_type_and_type_index(
        node_index_in_elt: int,
    ):
        i, j, k = self._node_ijk(node_index_in_elt)

        # lo_* is 1 when the index is 0 along that axis
        lo_i = wp.where(i == 0, 1, 0)
        lo_j = wp.where(j == 0, 1, 0)
        lo_k = wp.where(k == 0, 1, 0)

        # hi_* is 1 when the index is ORDER along that axis
        hi_i = wp.where(i == ORDER, 1, 0)
        hi_j = wp.where(j == ORDER, 1, 0)
        hi_k = wp.where(k == ORDER, 1, 0)

        if lo_i + hi_i == 1:
            if lo_j + hi_j == 1:
                if lo_k + hi_k == 1:
                    # extremal along all three axes: vertex
                    return CubeTripolynomialShapeFunctions.VERTEX, hi_i * 4 + hi_j * 2 + hi_k, 0

                # extremal along i and j only: z edge
                return CubeTripolynomialShapeFunctions.EDGE, 8 + hi_i * 2 + hi_j, k - 1

            if lo_k + hi_k == 1:
                # extremal along i and k only: y edge
                return CubeTripolynomialShapeFunctions.EDGE, 4 + hi_k * 2 + hi_i, j - 1

            # extremal along i only: x face
            return CubeTripolynomialShapeFunctions.FACE, hi_i, (j - 1) * (ORDER - 1) + k - 1

        if lo_j + hi_j == 1:
            if lo_k + hi_k == 1:
                # extremal along j and k only: x edge
                return CubeTripolynomialShapeFunctions.EDGE, hi_j * 2 + hi_k, i - 1

            # extremal along j only: y face
            return CubeTripolynomialShapeFunctions.FACE, 2 + hi_j, (k - 1) * (ORDER - 1) + i - 1

        if lo_k + hi_k == 1:
            # extremal along k only: z face
            return CubeTripolynomialShapeFunctions.FACE, 4 + hi_k, (i - 1) * (ORDER - 1) + j - 1

        # not extremal along any axis: interior node
        interior_index = ((i - 1) * (ORDER - 1) + (j - 1)) * (ORDER - 1) + k - 1
        return CubeTripolynomialShapeFunctions.INTERIOR, 0, interior_index

    if is_closed(self.family):
        return node_type_and_type_index
    return node_type_and_type_index_open
210
+
211
def make_node_coords_in_element(self):
    """Build the Warp function returning the element coordinates of a node."""
    LOBATTO_COORDS = self.LOBATTO_COORDS

    @cache.dynamic_func(suffix=self.name)
    def node_coords_in_element(
        node_index_in_elt: int,
    ):
        # Look up the 1D coordinate of the node along each axis
        i, j, k = self._node_ijk(node_index_in_elt)
        cx = LOBATTO_COORDS[i]
        cy = LOBATTO_COORDS[j]
        cz = LOBATTO_COORDS[k]
        return Coords(cx, cy, cz)

    return node_coords_in_element
222
+
223
def make_node_quadrature_weight(self):
    """Build the Warp function returning the quadrature weight attached to a node.

    Order 1 uses a constant weight; any other order multiplies the 1D weights
    of the node along the three axes.
    """
    ORDER = self.ORDER
    LOBATTO_WEIGHT = self.LOBATTO_WEIGHT

    def node_quadrature_weight(
        node_index_in_elt: int,
    ):
        i, j, k = self._node_ijk(node_index_in_elt)
        return LOBATTO_WEIGHT[i] * LOBATTO_WEIGHT[j] * LOBATTO_WEIGHT[k]

    def node_quadrature_weight_linear(
        node_index_in_elt: int,
    ):
        # Same weight for each of the nodes
        return 0.125

    weight_func = node_quadrature_weight_linear if ORDER == 1 else node_quadrature_weight
    return cache.get_func(weight_func, self.name)
242
+
243
def make_trace_node_quadrature_weight(self):
    """Build the Warp function returning the quadrature weight of a node on the element sides.

    Families that are not closed get a zero weight, order 1 gets a constant
    weight, and the remaining cases multiply the two 1D weights of the axes
    spanning the side.
    """
    ORDER = self.ORDER
    LOBATTO_WEIGHT = self.LOBATTO_WEIGHT

    def trace_node_quadrature_weight(
        node_index_in_elt: int,
    ):
        # The node lies either inside a side or at a vertex: the first axis
        # found at an extremum is dropped and the two others give the weight.
        i, j, k = self._node_ijk(node_index_in_elt)

        w_i = LOBATTO_WEIGHT[i]
        w_j = LOBATTO_WEIGHT[j]
        w_k = LOBATTO_WEIGHT[k]

        if i == 0 or i == ORDER:
            return w_j * w_k

        if j == 0 or j == ORDER:
            return w_i * w_k

        return w_i * w_j

    def trace_node_quadrature_weight_linear(
        node_index_in_elt: int,
    ):
        return 0.25

    def trace_node_quadrature_weight_open(
        node_index_in_elt: int,
    ):
        return 0.0

    if is_closed(self.family):
        weight_func = trace_node_quadrature_weight_linear if ORDER == 1 else trace_node_quadrature_weight
    else:
        weight_func = trace_node_quadrature_weight_open

    return cache.get_func(weight_func, self.name)
280
+
281
def make_element_inner_weight(self):
    """Build the Warp function evaluating a node's shape function at given element coordinates.

    The general version multiplies, along each axis, the factors
    ``coords[axis] - LOBATTO_COORDS[n]`` over every 1D node ``n`` other than
    the node's own, then applies the node's ``LAGRANGE_SCALE`` factors. A
    closed family of order 1 uses a dedicated trilinear formula instead.
    """
    ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
    LOBATTO_COORDS = self.LOBATTO_COORDS
    LAGRANGE_SCALE = self.LAGRANGE_SCALE

    def element_inner_weight(
        coords: Coords,
        node_index_in_elt: int,
    ):
        i, j, k = self._node_ijk(node_index_in_elt)

        # Single accumulator; factors are applied in x, y, z order for each 1D node
        weight = float(1.0)
        for n in range(ORDER_PLUS_ONE):
            if n != i:
                weight *= coords[0] - LOBATTO_COORDS[n]
            if n != j:
                weight *= coords[1] - LOBATTO_COORDS[n]
            if n != k:
                weight *= coords[2] - LOBATTO_COORDS[n]

        scale = LAGRANGE_SCALE[i] * LAGRANGE_SCALE[j] * LAGRANGE_SCALE[k]
        weight *= scale

        return weight

    def element_inner_weight_linear(
        coords: Coords,
        node_index_in_elt: int,
    ):
        # Corner of the node, each component being 0.0 or 1.0
        corner = CubeTripolynomialShapeFunctions._vertex_coords_f(node_index_in_elt)

        # Along each axis: (1 - coord) for a corner at 0, coord for a corner at 1
        weight_x = (1.0 - coords[0]) * (1.0 - corner[0]) + corner[0] * coords[0]
        weight_y = (1.0 - coords[1]) * (1.0 - corner[1]) + corner[1] * coords[1]
        weight_z = (1.0 - coords[2]) * (1.0 - corner[2]) + corner[2] * coords[2]
        return weight_x * weight_y * weight_z

    use_linear = self.ORDER == 1 and is_closed(self.family)
    weight_func = element_inner_weight_linear if use_linear else element_inner_weight
    return cache.get_func(weight_func, self.name)
320
+
321
def make_element_inner_weight_gradient(self):
    """Build the function evaluating the gradient of a node's shape function.

    * General variant: differentiates the tensor-product form with the product
      rule, one axis at a time, and scales the result by the three matching
      ``LAGRANGE_SCALE`` entries.
    * Linear variant (used only when ``ORDER == 1`` and ``is_closed(self.family)``):
      closed-form gradient of the product of three per-axis linear factors.

    Returns:
        Whatever ``cache.get_func`` returns for the selected variant,
        registered with ``self.name``.
    """
    ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
    LOBATTO_COORDS = self.LOBATTO_COORDS
    LAGRANGE_SCALE = self.LAGRANGE_SCALE

    def element_inner_weight_gradient(
        coords: Coords,
        node_index_in_elt: int,
    ):
        node_i, node_j, node_k = self._node_ijk(node_index_in_elt)

        # First pass: full (unnormalized) 1D polynomial along each axis.
        # Naming: prefix_yz holds the x polynomial, i.e. the factor shared by
        # the y and z gradient components; likewise for prefix_zx and prefix_xy.
        prefix_xy = float(1.0)
        prefix_yz = float(1.0)
        prefix_zx = float(1.0)
        for k in range(ORDER_PLUS_ONE):
            if k != node_i:
                prefix_yz *= coords[0] - LOBATTO_COORDS[k]
            if k != node_j:
                prefix_zx *= coords[1] - LOBATTO_COORDS[k]
            if k != node_k:
                prefix_xy *= coords[2] - LOBATTO_COORDS[k]

        # Each gradient component starts from the product of the two other axes' polynomials
        prefix_x = prefix_zx * prefix_xy
        prefix_y = prefix_yz * prefix_xy
        prefix_z = prefix_zx * prefix_yz

        grad_x = float(0.0)
        grad_y = float(0.0)
        grad_z = float(0.0)

        # Second pass: product-rule recurrence. With P the product of the factors
        # seen so far, (P * delta)' = P' * delta + P; grad_* tracks the derivative
        # and prefix_* tracks P (including the other axes' contribution).
        # The update order (grad before prefix) matters.
        for k in range(ORDER_PLUS_ONE):
            if k != node_i:
                delta_x = coords[0] - LOBATTO_COORDS[k]
                grad_x = grad_x * delta_x + prefix_x
                prefix_x *= delta_x
            if k != node_j:
                delta_y = coords[1] - LOBATTO_COORDS[k]
                grad_y = grad_y * delta_y + prefix_y
                prefix_y *= delta_y
            if k != node_k:
                delta_z = coords[2] - LOBATTO_COORDS[k]
                grad_z = grad_z * delta_z + prefix_z
                prefix_z *= delta_z

        # Same normalization as for the shape function value
        grad = (
            LAGRANGE_SCALE[node_i]
            * LAGRANGE_SCALE[node_j]
            * LAGRANGE_SCALE[node_k]
            * wp.vec3(
                grad_x,
                grad_y,
                grad_z,
            )
        )

        return grad

    def element_inner_weight_gradient_linear(
        coords: Coords,
        node_index_in_elt: int,
    ):
        v = CubeTripolynomialShapeFunctions._vertex_coords_f(node_index_in_elt)

        # Per-axis linear factors, same as in the linear shape function value
        wx = (1.0 - coords[0]) * (1.0 - v[0]) + v[0] * coords[0]
        wy = (1.0 - coords[1]) * (1.0 - v[1]) + v[1] * coords[1]
        wz = (1.0 - coords[2]) * (1.0 - v[2]) + v[2] * coords[2]

        # Derivative of each factor w.r.t. its coordinate: 2 * v - 1
        # (-1 or +1 if the components of v are 0.0 or 1.0 -- confirm in _vertex_coords_f)
        dx = 2.0 * v[0] - 1.0
        dy = 2.0 * v[1] - 1.0
        dz = 2.0 * v[2] - 1.0

        return wp.vec3(dx * wy * wz, dy * wz * wx, dz * wx * wy)

    if self.ORDER == 1 and is_closed(self.family):
        return cache.get_func(element_inner_weight_gradient_linear, self.name)

    return cache.get_func(element_inner_weight_gradient, self.name)
398
+
399
def element_node_hexes(self):
    """Return the result of ``grid_to_hexes`` called with ``ORDER`` for each of its three arguments."""
    from warp._src.fem.utils import grid_to_hexes

    order = self.ORDER
    return grid_to_hexes(order, order, order)
403
+
404
def element_node_tets(self):
    """Return the result of ``grid_to_tets`` called with ``ORDER`` for each of its three arguments."""
    from warp._src.fem.utils import grid_to_tets

    order = self.ORDER
    return grid_to_tets(order, order, order)
408
+
409
def element_vtk_cells(self):
    """Return the VTK connectivity and cell type describing one element.

    Nodes are listed as the 8 vertices first and, when ``ORDER`` is not 1,
    followed by the edge nodes, the face nodes (YZ, XZ, then XY faces) and the
    interior nodes. Each ``(i, j, k)`` triple is flattened to
    ``i * n * n + j * n + k`` with ``n = ORDER + 1``.

    Returns:
        A tuple ``(cell_indices, cell_types)`` where ``cell_indices`` has shape
        ``(1, node_count)`` and ``cell_types`` is a one-element ``int8`` array
        holding 12 (``ORDER == 1``) or 72 (otherwise).
    """
    order = self.ORDER
    n = order + 1
    last = n - 1

    # The eight corners, in the order they are emitted
    corners = [
        [0, 0, 0],
        [last, 0, 0],
        [last, last, 0],
        [0, last, 0],
        [0, 0, last],
        [last, 0, last],
        [last, last, last],
        [0, last, last],
    ]
    groups = [corners]

    if order != 1:
        inner = np.arange(1, last)
        lo = np.zeros(n - 2, dtype=int)
        hi = np.full(n - 2, last)

        # Twelve edges: bottom ring, top ring, then the four vertical edges
        edge_columns = (
            (inner, lo, lo),
            (hi, inner, lo),
            (inner, hi, lo),
            (lo, inner, lo),
            (inner, lo, hi),
            (hi, inner, hi),
            (inner, hi, hi),
            (lo, inner, hi),
            (lo, lo, inner),
            (hi, lo, inner),
            (hi, hi, inner),
            (lo, hi, inner),
        )
        for columns in edge_columns:
            groups.append(np.column_stack(columns))

        # Face-interior nodes: flattened 2D grid of inner indices
        grid_u, grid_v = np.meshgrid(inner, inner)
        u = grid_u.flatten()
        v = grid_v.flatten()
        face_lo = np.zeros((n - 2) ** 2, dtype=int)
        face_hi = np.full((n - 2) ** 2, last)

        face_columns = (
            (face_lo, u, v),  # YZ
            (face_hi, u, v),
            (u, face_lo, v),  # XZ
            (u, face_hi, v),
            (u, v, face_lo),  # XY
            (u, v, face_hi),
        )
        for columns in face_columns:
            groups.append(np.column_stack(columns))

        # Interior nodes: flattened 3D grid of inner indices
        grid_i, grid_j, grid_k = np.meshgrid(inner, inner, inner)
        groups.append(np.column_stack((grid_i.flatten(), grid_j.flatten(), grid_k.flatten())))

        cell_type = 72  # vtk_lagrange_hexahedron
    else:
        cell_type = 12  # vtk_hexahedron

    ijk = np.concatenate(groups)
    flat_indices = ijk[:, 0] * n * n + ijk[:, 1] * n + ijk[:, 2]

    return flat_indices[np.newaxis, :], np.array([cell_type], dtype=np.int8)
489
+
490
+
491
class CubeSerendipityShapeFunctions(CubeShapeFunction):
    """
    Serendipity element ~ tensor product space without face or interior nodes.

    Edge shape functions are usual Lagrange shape functions along the edge times a
    bilinear function in the two normal directions.
    Corner shape functions are trilinear shape functions times a symmetric polynomial
    of degree ``d - 1`` in the three coordinates (linear for degree 2, quadratic for
    degree 3).

    Nodes are numbered with the 8 vertices first, followed by the edge nodes.
    """

    def __init__(self, degree: int, family: Polynomial):
        """Set up node counts and 1D interpolation data.

        Args:
            degree: Polynomial degree; only 2 and 3 are accepted.
            family: Polynomial family; must satisfy ``is_closed``.

        Raises:
            ValueError: If ``family`` is not closed.
            NotImplementedError: If ``degree`` is neither 2 nor 3.
        """
        if not is_closed(family):
            raise ValueError("A closed polynomial family is required to define serendipity elements")

        if degree not in [2, 3]:
            raise NotImplementedError("Serendipity element only implemented for order 2 or 3")

        self.family = family

        self.ORDER = wp.constant(degree)
        # 8 vertices plus (degree - 1) nodes on each of the 12 edges
        self.NODES_PER_ELEMENT = wp.constant(8 + 12 * (degree - 1))
        self.NODES_PER_SIDE = wp.constant(4 * degree)

        self.VERTEX_NODE_COUNT = wp.constant(1)
        self.EDGE_NODE_COUNT = wp.constant(degree - 1)
        self.FACE_NODE_COUNT = wp.constant(0)
        self.INTERIOR_NODE_COUNT = wp.constant(0)

        lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
        lagrange_scale = lagrange_scales(lobatto_coords)

        NodeVec = cache.cached_vec_type(length=degree + 1, dtype=wp.float32)
        self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
        self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
        self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
        self.ORDER_PLUS_ONE = wp.constant(self.ORDER + 1)

        self.node_type_and_type_index = self._get_node_type_and_type_index()
        self._node_lobatto_indices = self._get_node_lobatto_indices()

    @property
    def name(self) -> str:
        """Identifier built from the order and the polynomial family."""
        return f"Cube_S{self.ORDER}_{self.family}"

    def _degree_3_sphere_constants(self):
        """Return ``(radius_sq, scale)`` used by the degree-3 vertex shape functions.

        ``radius_sq`` is the squared distance from the cube center to a point with two
        coordinates on the boundary and the third at ``LOBATTO_COORDS[1]``;
        ``scale`` normalizes the quadratic factor to 1 at the vertex itself, where the
        squared distance to the center is 0.75.
        """
        radius_sq = wp.constant(2 * 0.5**2 + (0.5 - self.LOBATTO_COORDS[1]) ** 2)
        scale = 1.0 / (0.75 - radius_sq)
        return radius_sq, scale

    def _get_node_type_and_type_index(self):
        """Build the function mapping a node index to ``(node type, type instance, index in type)``."""

        @cache.dynamic_func(suffix=self.name)
        def node_type_and_index(
            node_index_in_elt: int,
        ):
            # The first 8 nodes are the vertices
            if node_index_in_elt < 8:
                return CubeShapeFunction.VERTEX, node_index_in_elt, 0

            # Remaining nodes come in groups of three, one per edge axis
            edge_index = (node_index_in_elt - 8) // 3
            edge_axis = node_index_in_elt - 8 - 3 * edge_index

            # Each axis has 4 edges; groups beyond the first 4 are further nodes on the same edges
            index_in_edge = edge_index // 4
            edge_offset = edge_index - 4 * index_in_edge

            return CubeShapeFunction.EDGE, 4 * edge_axis + edge_offset, index_in_edge

        return node_type_and_index

    def _get_node_lobatto_indices(self):
        """Build the function mapping a node to its per-axis indices into the 1D node arrays."""
        ORDER = self.ORDER

        @cache.dynamic_func(suffix=self.name)
        def node_lobatto_indices(node_type: int, type_instance: int, type_index: int):
            if node_type == CubeSerendipityShapeFunctions.VERTEX:
                return CubeSerendipityShapeFunctions._vertex_coords(type_instance) * ORDER

            axis = CubeSerendipityShapeFunctions._edge_axis(type_instance)
            local_coords = CubeSerendipityShapeFunctions._edge_coords(type_instance, type_index)

            # First component is used as-is; the two transverse ones are scaled by ORDER
            local_indices = wp.vec3i(local_coords[0], local_coords[1] * ORDER, local_coords[2] * ORDER)

            return Grid3D._local_to_world(axis, local_indices)

        return node_lobatto_indices

    def make_node_coords_in_element(self):
        """Build the function returning the element coordinates of a node."""
        LOBATTO_COORDS = self.LOBATTO_COORDS

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
            node_coords = self._node_lobatto_indices(node_type, type_instance, type_index)
            return Coords(
                LOBATTO_COORDS[node_coords[0]], LOBATTO_COORDS[node_coords[1]], LOBATTO_COORDS[node_coords[2]]
            )

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Build the function returning the quadrature weight of a node.

        The 8 vertex weights sum to ``1 / ORDER**3`` and the ``12 * (ORDER - 1)`` edge-node
        weights share the remainder equally, so all weights sum to 1.
        """
        ORDER = self.ORDER

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
            if node_type == CubeSerendipityShapeFunctions.VERTEX:
                return 1.0 / float(8 * ORDER * ORDER * ORDER)

            return (1.0 - 1.0 / float(ORDER * ORDER * ORDER)) / float(12 * (ORDER - 1))

        return node_quadrature_weight

    def make_trace_node_quadrature_weight(self):
        """Build the function returning the quadrature weight of a node on a side.

        Vertex nodes get ``0.25 / ORDER**2``; every other node gets
        ``(0.25 - 0.25 / ORDER**2) / (ORDER - 1)``.
        """
        ORDER = self.ORDER

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
            if node_type == CubeSerendipityShapeFunctions.VERTEX:
                return 0.25 / float(ORDER * ORDER)

            return (0.25 - 0.25 / float(ORDER * ORDER)) / float(ORDER - 1)

        return trace_node_quadrature_weight

    def make_element_inner_weight(self):
        """Build the function evaluating a node's shape function at element coordinates."""
        ORDER = self.ORDER
        ORDER_PLUS_ONE = self.ORDER_PLUS_ONE

        LOBATTO_COORDS = self.LOBATTO_COORDS
        LAGRANGE_SCALE = self.LAGRANGE_SCALE

        DEGREE_3_SPHERE_RAD, DEGREE_3_SPHERE_SCALE = self._degree_3_sphere_constants()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            coords: Coords,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)

            if node_type == CubeSerendipityShapeFunctions.VERTEX:
                node_ijk = CubeSerendipityShapeFunctions._vertex_coords(type_instance)

                # Per-axis coordinate measured so that it equals 1 at this vertex
                cx = wp.where(node_ijk[0] == 0, 1.0 - coords[0], coords[0])
                cy = wp.where(node_ijk[1] == 0, 1.0 - coords[1], coords[1])
                cz = wp.where(node_ijk[2] == 0, 1.0 - coords[2], coords[2])

                # Trilinear part
                w = cx * cy * cz

                if wp.static(ORDER == 2):
                    # Linear correction factor
                    w *= cx + cy + cz - 3.0 + LOBATTO_COORDS[1]
                    return w * LAGRANGE_SCALE[0]
                if wp.static(ORDER == 3):
                    # Quadratic correction factor: squared distance to the cube center minus a constant
                    w *= (
                        (cx - 0.5) * (cx - 0.5)
                        + (cy - 0.5) * (cy - 0.5)
                        + (cz - 0.5) * (cz - 0.5)
                        - DEGREE_3_SPHERE_RAD
                    )
                    return w * DEGREE_3_SPHERE_SCALE

            axis = CubeSerendipityShapeFunctions._edge_axis(type_instance)

            node_all = CubeSerendipityShapeFunctions._edge_coords(type_instance, type_index)

            local_coords = Grid3D._world_to_local(axis, coords)

            # Bilinear part in the two directions normal to the edge
            w = float(1.0)
            w *= wp.where(node_all[1] == 0, 1.0 - local_coords[1], local_coords[1])
            w *= wp.where(node_all[2] == 0, 1.0 - local_coords[2], local_coords[2])

            # 1D Lagrange polynomial along the edge
            for k in range(ORDER_PLUS_ONE):
                if k != node_all[0]:
                    w *= local_coords[0] - LOBATTO_COORDS[k]
            w *= LAGRANGE_SCALE[node_all[0]]

            return w

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Build the function evaluating the gradient of a node's shape function."""
        ORDER = self.ORDER
        ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
        LOBATTO_COORDS = self.LOBATTO_COORDS
        LAGRANGE_SCALE = self.LAGRANGE_SCALE

        DEGREE_3_SPHERE_RAD, DEGREE_3_SPHERE_SCALE = self._degree_3_sphere_constants()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            coords: Coords,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)

            if node_type == CubeSerendipityShapeFunctions.VERTEX:
                node_ijk = CubeSerendipityShapeFunctions._vertex_coords(type_instance)

                # Per-axis coordinate measured so that it equals 1 at this vertex
                cx = wp.where(node_ijk[0] == 0, 1.0 - coords[0], coords[0])
                cy = wp.where(node_ijk[1] == 0, 1.0 - coords[1], coords[1])
                cz = wp.where(node_ijk[2] == 0, 1.0 - coords[2], coords[2])

                # Derivative of cx, cy, cz w.r.t. the element coordinates
                gx = wp.where(node_ijk[0] == 0, -1.0, 1.0)
                gy = wp.where(node_ijk[1] == 0, -1.0, 1.0)
                gz = wp.where(node_ijk[2] == 0, -1.0, 1.0)

                if wp.static(ORDER == 2):
                    # d/dcx (cx * cy * cz * w) = cy * cz * (w + cx) since dw/dcx = 1
                    w = cx + cy + cz - 3.0 + LOBATTO_COORDS[1]
                    grad_x = cy * cz * gx * (w + cx)
                    grad_y = cz * cx * gy * (w + cy)
                    grad_z = cx * cy * gz * (w + cz)

                    return wp.vec3(grad_x, grad_y, grad_z) * LAGRANGE_SCALE[0]

                if wp.static(ORDER == 3):
                    w = (
                        (cx - 0.5) * (cx - 0.5)
                        + (cy - 0.5) * (cy - 0.5)
                        + (cz - 0.5) * (cz - 0.5)
                        - DEGREE_3_SPHERE_RAD
                    )

                    # Product rule with dw/dcx = 2 * cx - 1 (and similarly for y, z)
                    dw_dcx = 2.0 * cx - 1.0
                    dw_dcy = 2.0 * cy - 1.0
                    dw_dcz = 2.0 * cz - 1.0
                    grad_x = cy * cz * gx * (w + dw_dcx * cx)
                    grad_y = cz * cx * gy * (w + dw_dcy * cy)
                    grad_z = cx * cy * gz * (w + dw_dcz * cz)

                    return wp.vec3(grad_x, grad_y, grad_z) * DEGREE_3_SPHERE_SCALE

            axis = CubeSerendipityShapeFunctions._edge_axis(type_instance)
            node_all = CubeSerendipityShapeFunctions._edge_coords(type_instance, type_index)

            local_coords = Grid3D._world_to_local(axis, coords)

            # Linear factors in the two directions normal to the edge, and their derivatives
            w_long = wp.where(node_all[1] == 0, 1.0 - local_coords[1], local_coords[1])
            w_lat = wp.where(node_all[2] == 0, 1.0 - local_coords[2], local_coords[2])

            g_long = wp.where(node_all[1] == 0, -1.0, 1.0)
            g_lat = wp.where(node_all[2] == 0, -1.0, 1.0)

            # 1D Lagrange polynomial along the edge (w_alt) and its derivative (g_alt),
            # accumulated with the product-rule recurrence (g must be updated before prefix)
            w_alt = LAGRANGE_SCALE[node_all[0]]
            g_alt = float(0.0)
            prefix_alt = LAGRANGE_SCALE[node_all[0]]
            for k in range(ORDER_PLUS_ONE):
                if k != node_all[0]:
                    delta_alt = local_coords[0] - LOBATTO_COORDS[k]
                    w_alt *= delta_alt
                    g_alt = g_alt * delta_alt + prefix_alt
                    prefix_alt *= delta_alt

            local_grad = wp.vec3(g_alt * w_long * w_lat, w_alt * g_long * w_lat, w_alt * w_long * g_lat)

            return Grid3D._local_to_world(axis, local_grad)

        return element_inner_weight_gradient

    def element_node_tets(self):
        """Return a tetrahedral decomposition of the element nodes.

        Only implemented for order 2: 8 corner tets, 8 tets filling the region around
        them, plus the tets of a central hexahedron built from 8 of the edge nodes.

        Raises:
            NotImplementedError: If the order is not 2.
        """
        from warp._src.fem.utils import grid_to_tets

        if self.ORDER == 2:
            element_tets = np.array(
                [
                    [0, 8, 9, 10],
                    [1, 11, 10, 15],
                    [2, 9, 14, 13],
                    [3, 15, 13, 17],
                    [4, 12, 8, 16],
                    [5, 18, 16, 11],
                    [6, 14, 12, 19],
                    [7, 19, 18, 17],
                    [16, 12, 18, 11],
                    [8, 16, 12, 11],
                    [12, 19, 18, 14],
                    [14, 19, 17, 18],
                    [10, 9, 15, 8],
                    [10, 8, 11, 15],
                    [9, 13, 15, 14],
                    [13, 14, 17, 15],
                ]
            )

            # Remap the tets of a single-cell grid onto the 8 nodes of the middle hexahedron
            middle_hex = np.array([8, 11, 9, 15, 12, 18, 14, 17])
            middle_tets = middle_hex[grid_to_tets(1, 1, 1)]

            return np.concatenate((element_tets, middle_tets))

        raise NotImplementedError(
            "Tetrahedral decomposition of serendipity cube elements is only implemented for order 2"
        )

    def element_vtk_cells(self):
        """Return the element's tetrahedra and a matching array of VTK cell types (all 10)."""
        tets = np.array(self.element_node_tets())
        cell_type = 10  # VTK_TETRA

        return tets, np.full(tets.shape[0], cell_type, dtype=np.int8)
792
+
793
+
794
class CubeNonConformingPolynomialShapeFunctions(ShapeFunction):
    """Cube shape functions obtained by delegating to tetrahedron polynomial shape functions.

    Cube coordinates are mapped to tetrahedron coordinates through an affine map
    (``_tet_to_cube`` / ``_TET_OFFSET``) and all evaluations are forwarded to a
    ``TetrahedronPolynomialShapeFunctions`` instance.
    """

    # embeds the largest regular tet centered at (0.5, 0.5, 0.5) into the reference cube

    # Regular tetrahedron dimensions, all derived from the chosen height
    _tet_height = 2.0 / 3.0
    _tet_side = math.sqrt(3.0 / 2.0) * _tet_height
    _tet_face_height = math.sqrt(3.0) / 2.0 * _tet_side

    # Linear part of the tet -> cube map; applied as ``TET_TO_CUBE * tet_coords + _TET_OFFSET``
    _tet_to_cube = np.array(
        [
            [_tet_side, _tet_side / 2.0, _tet_side / 2.0],
            [0.0, _tet_face_height, _tet_face_height / 3.0],
            [0.0, 0.0, _tet_height],
        ]
    )

    # Translation part of the tet -> cube map
    _TET_OFFSET = wp.constant(wp.vec3(0.5 - 0.5 * _tet_side, 0.5 - _tet_face_height / 3.0, 0.5 - 0.25 * _tet_height))

    def __init__(self, degree: int):
        """Create the underlying tetrahedron shape functions of the given degree.

        ``ORDER``, ``NODES_PER_ELEMENT``, ``element_node_tets`` and ``element_vtk_cells``
        are taken directly from the tetrahedron shape functions.
        """
        self._tet_shape = TetrahedronPolynomialShapeFunctions(degree=degree)
        self.ORDER = self._tet_shape.ORDER
        self.NODES_PER_ELEMENT = self._tet_shape.NODES_PER_ELEMENT

        self.element_node_tets = self._tet_shape.element_node_tets
        self.element_vtk_cells = self._tet_shape.element_vtk_cells

    @property
    def name(self) -> str:
        """Identifier built from the order."""
        return f"Cube_P{self.ORDER}d"

    def make_node_coords_in_element(self):
        """Build the function returning a node's coordinates in the cube.

        The tetrahedron node coordinates are mapped through the tet -> cube affine map.
        """
        node_coords_in_tet = self._tet_shape.make_node_coords_in_element()

        TET_TO_CUBE = wp.constant(wp.mat33(self._tet_to_cube))

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            node_index_in_elt: int,
        ):
            tet_coords = node_coords_in_tet(node_index_in_elt)
            return TET_TO_CUBE * tet_coords + CubeNonConformingPolynomialShapeFunctions._TET_OFFSET

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Build the function returning the same weight, ``1 / NODES_PER_ELEMENT``, for every node."""
        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def node_uniform_quadrature_weight(
            node_index_in_elt: int,
        ):
            return 1.0 / float(NODES_PER_ELEMENT)

        return node_uniform_quadrature_weight

    def make_trace_node_quadrature_weight(self):
        """Build the function returning a zero side weight for every node."""
        # Non-conforming, zero measure on sides

        @wp.func
        def zero(node_index_in_elt: int):
            return 0.0

        return zero

    def make_element_inner_weight(self):
        """Build the function evaluating a node's shape function at cube coordinates.

        Cube coordinates are converted to tetrahedron coordinates with the inverse
        affine map before calling the tetrahedron shape function.
        """
        tet_inner_weight = self._tet_shape.make_element_inner_weight()

        # Inverse of the linear part of the tet -> cube map
        CUBE_TO_TET = wp.constant(wp.mat33(np.linalg.inv(self._tet_to_cube)))

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            coords: Coords,
            node_index_in_elt: int,
        ):
            tet_coords = CUBE_TO_TET * (coords - CubeNonConformingPolynomialShapeFunctions._TET_OFFSET)

            return tet_inner_weight(tet_coords, node_index_in_elt)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Build the function evaluating the gradient of a node's shape function w.r.t. cube coordinates."""
        tet_inner_weight_gradient = self._tet_shape.make_element_inner_weight_gradient()

        # Inverse of the linear part of the tet -> cube map
        CUBE_TO_TET = wp.constant(wp.mat33(np.linalg.inv(self._tet_to_cube)))

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            coords: Coords,
            node_index_in_elt: int,
        ):
            tet_coords = CUBE_TO_TET * (coords - CubeNonConformingPolynomialShapeFunctions._TET_OFFSET)
            grad = tet_inner_weight_gradient(tet_coords, node_index_in_elt)
            # Chain rule: the gradient in cube coordinates is the transposed Jacobian times the tet gradient
            return wp.transpose(CUBE_TO_TET) * grad

        return element_inner_weight_gradient
888
+
889
+
890
class CubeNedelecFirstKindShapeFunctions(CubeShapeFunction):
    """Vector-valued cube shape functions with one node per edge (12 per element).

    Only degree 1 is accepted. Each shape function has a single non-zero local
    component, placed first in the edge-local frame, whose magnitude varies
    bilinearly in the two other local coordinates.
    """

    # Shape function values are tagged as covariant vectors
    value = ShapeFunction.Value.CovariantVector

    def __init__(self, degree: int):
        """Set up node counts.

        Raises:
            NotImplementedError: If ``degree`` is not 1.
        """
        if degree != 1:
            raise NotImplementedError("Only linear Nédélec implemented right now")

        self.ORDER = wp.constant(degree)
        self.NODES_PER_ELEMENT = wp.constant(12)
        self.NODES_PER_SIDE = wp.constant(4)

        # A single node per edge, none on vertices, faces or in the interior
        self.VERTEX_NODE_COUNT = wp.constant(0)
        self.EDGE_NODE_COUNT = wp.constant(1)
        self.FACE_NODE_COUNT = wp.constant(0)
        self.INTERIOR_NODE_COUNT = wp.constant(0)

        self.node_type_and_type_index = self._get_node_type_and_type_index()

    @property
    def name(self) -> str:
        """Identifier built from the order."""
        return f"CubeN1_{self.ORDER}"

    def _get_node_type_and_type_index(self):
        """Build the function mapping a node index to ``(EDGE, node index, 0)``."""

        @cache.dynamic_func(suffix=self.name)
        def node_type_and_index(
            node_index_in_elt: int,
        ):
            return CubeShapeFunction.EDGE, node_index_in_elt, 0

        return node_type_and_index

    def make_node_coords_in_element(self):
        """Build the function returning a node's element coordinates.

        The first local coordinate is 0.5; the other two come from ``_edge_coords``.
        """

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
            axis = CubeShapeFunction._edge_axis(type_instance)
            local_indices = CubeShapeFunction._edge_coords(type_instance, type_index)

            local_coords = wp.vec3f(0.5, float(local_indices[1]), float(local_indices[2]))
            return Grid3D._local_to_world(axis, local_coords)

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Build the function returning the same weight, ``1 / NODES_PER_ELEMENT``, for every node."""
        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(node_index_in_element: int):
            return 1.0 / float(NODES_PER_ELEMENT)

        return node_quadrature_weight

    def make_trace_node_quadrature_weight(self):
        """Build the function returning the same weight, ``1 / NODES_PER_SIDE``, for every node."""
        NODES_PER_SIDE = self.NODES_PER_SIDE

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(node_index_in_element: int):
            return 1.0 / float(NODES_PER_SIDE)

        return trace_node_quadrature_weight

    def make_element_inner_weight(self):
        """Build the function evaluating a node's vector-valued shape function at element coordinates."""

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            coords: Coords,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)

            axis = CubeShapeFunction._edge_axis(type_instance)

            local_coords = Grid3D._world_to_local(axis, coords)
            edge_coords = CubeShapeFunction._edge_coords(type_instance, type_index)

            # b + a * t equals (1 - t) when the edge coordinate is 0 and t when it is 1
            # (assumes edge_coords[1] and edge_coords[2] are 0 or 1 -- confirm in _edge_coords)
            a1 = float(2 * edge_coords[1] - 1)
            a2 = float(2 * edge_coords[2] - 1)
            b1 = float(1 - edge_coords[1])
            b2 = float(1 - edge_coords[2])

            # Only the first local component is non-zero
            local_w = wp.vec3((b1 + a1 * local_coords[1]) * (b2 + a2 * local_coords[2]), 0.0, 0.0)
            return Grid3D._local_to_world(axis, local_w)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Build the function evaluating the gradient (a 3x3 matrix) of a node's shape function."""

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            coords: Coords,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)

            axis = CubeShapeFunction._edge_axis(type_instance)

            local_coords = Grid3D._world_to_local(axis, coords)
            edge_coords = CubeShapeFunction._edge_coords(type_instance, type_index)

            # Same linear coefficients as in the shape function value
            a1 = float(2 * edge_coords[1] - 1)
            a2 = float(2 * edge_coords[2] - 1)
            b1 = float(1 - edge_coords[1])
            b2 = float(1 - edge_coords[2])

            # Gradient of the scalar magnitude: zero along the first local direction,
            # product rule along the other two
            local_gw = Grid3D._local_to_world(
                axis, wp.vec3(0.0, a1 * (b2 + a2 * local_coords[2]), (b1 + a1 * local_coords[1]) * a2)
            )

            # Only row ``axis`` of the matrix is set; all other entries stay zero
            grad = wp.mat33(0.0)
            grad[axis] = local_gw
            return grad

        return element_inner_weight_gradient
1003
+
1004
+
1005
class CubeRaviartThomasShapeFunctions(CubeShapeFunction):
    """Vector-valued cube shape functions with one node per face (6 per element).

    Only degree 1 is accepted. Each shape function has a single non-zero
    component, along the face axis, varying linearly with the coordinate on that axis.
    """

    # Shape function values are tagged as contravariant vectors
    value = ShapeFunction.Value.ContravariantVector

    def __init__(self, degree: int):
        """Set up node counts.

        Raises:
            NotImplementedError: If ``degree`` is not 1.
        """
        if degree != 1:
            raise NotImplementedError("Only linear Raviart Thomas implemented right now")

        self.ORDER = wp.constant(degree)
        self.NODES_PER_ELEMENT = wp.constant(6)
        self.NODES_PER_SIDE = wp.constant(1)

        # A single node per face, none on vertices, edges or in the interior
        self.VERTEX_NODE_COUNT = wp.constant(0)
        self.EDGE_NODE_COUNT = wp.constant(0)
        self.FACE_NODE_COUNT = wp.constant(1)
        self.INTERIOR_NODE_COUNT = wp.constant(0)

        self.node_type_and_type_index = self._get_node_type_and_type_index()

    @property
    def name(self) -> str:
        """Identifier built from the order."""
        return f"CubeRT_{self.ORDER}"

    def _get_node_type_and_type_index(self):
        """Build the function mapping a node index to ``(FACE, node index, 0)``."""

        @cache.dynamic_func(suffix=self.name)
        def node_type_and_index(
            node_index_in_elt: int,
        ):
            return CubeShapeFunction.FACE, node_index_in_elt, 0

        return node_type_and_index

    def make_node_coords_in_element(self):
        """Build the function returning a node's element coordinates.

        All coordinates are 0.5 except the one along the face axis, which is the face offset.
        """

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
            axis = CubeShapeFunction._face_axis(type_instance)
            offset = CubeShapeFunction._face_offset(type_instance)

            coords = Coords(0.5)
            coords[axis] = float(offset)
            return coords

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Build the function returning the same weight, ``1 / NODES_PER_ELEMENT``, for every node."""
        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(node_index_in_element: int):
            return 1.0 / float(NODES_PER_ELEMENT)

        return node_quadrature_weight

    def make_trace_node_quadrature_weight(self):
        """Build the function returning the same weight, ``1 / NODES_PER_SIDE``, for every node."""
        NODES_PER_SIDE = self.NODES_PER_SIDE

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(node_index_in_element: int):
            return 1.0 / float(NODES_PER_SIDE)

        return trace_node_quadrature_weight

    def make_element_inner_weight(self):
        """Build the function evaluating a node's vector-valued shape function at element coordinates."""

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            coords: Coords,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)

            axis = CubeShapeFunction._face_axis(type_instance)
            offset = CubeShapeFunction._face_offset(type_instance)

            # b + a * t equals (1 - t) when offset is 0 and t when offset is 1
            # (assumes offset is 0 or 1 -- confirm in _face_offset)
            a = float(2 * offset - 1)
            b = float(1 - offset)

            # Only the component along the face axis is non-zero
            w = wp.vec3(0.0)
            w[axis] = b + a * coords[axis]

            return w

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Build the function evaluating the gradient (a 3x3 matrix) of a node's shape function."""

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            coords: Coords,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)

            axis = CubeShapeFunction._face_axis(type_instance)
            offset = CubeShapeFunction._face_offset(type_instance)

            # The single non-zero component depends only on coords[axis], with slope a
            a = float(2 * offset - 1)
            grad = wp.mat33(0.0)
            grad[axis, axis] = a

            return grad

        return element_inner_weight_gradient