warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,667 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.fem import cache
22
+ from warp.fem.geometry import Geometry
23
+ from warp.fem.quadrature import Quadrature
24
+ from warp.fem.types import (
25
+ NULL_ELEMENT_INDEX,
26
+ NULL_QP_INDEX,
27
+ Coords,
28
+ ElementIndex,
29
+ QuadraturePointIndex,
30
+ make_free_sample,
31
+ )
32
+
33
+ from .shape import ShapeFunction
34
+ from .topology import RegularDiscontinuousSpaceTopology, SpaceTopology
35
+
36
+
37
class BasisSpace:
    """Interface class for defining a shape function space over a geometry.

    A basis space makes it easy to define multiple function spaces sharing the same basis (and thus nodes) but with different valuation functions;
    however, it is not a required component of a function space.

    Subclasses provide the actual basis by overriding the ``make_*`` factory hooks, which return
    device functions, and the :attr:`value` property.

    See also: :func:`make_polynomial_basis_space`, :func:`make_collocated_function_space`
    """

    @wp.struct
    class BasisArg:
        """Argument structure to be passed to device functions"""

        pass

    def __init__(self, topology: SpaceTopology):
        self._topology = topology

    @property
    def topology(self) -> SpaceTopology:
        """Underlying topology of the basis space"""
        return self._topology

    @property
    def geometry(self) -> Geometry:
        """Underlying geometry of the basis space"""
        return self._topology.geometry

    @property
    def value(self) -> ShapeFunction.Value:
        """Value type for the underlying shape functions"""
        raise NotImplementedError()

    def basis_arg_value(self, device) -> "BasisArg":
        """Value for the argument structure to be passed to device functions"""
        return BasisSpace.BasisArg()

    # Helpers for generating node positions

    def node_positions(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns a temporary array containing the world position for each node

        Args:
            out: Optional destination array. If provided, it must have shape ``(topology.node_count(),)``
                and a ``vec`` dtype whose length equals ``geometry.dimension``. Positions are written into
                it, and the kernel runs on that array's device. If ``None``, a new array is allocated on
                the current default device.

        Returns:
            The array of node positions (``out`` itself when it was provided).

        Raises:
            ValueError: If ``out`` has an unexpected shape or data type.
        """

        pos_type = cache.cached_vec_type(length=self.geometry.dimension, dtype=float)

        node_coords_in_element = self.make_node_coords_in_element()

        @cache.dynamic_kernel(suffix=self.name, kernel_options={"max_unroll": 4, "enable_backward": False})
        def fill_node_positions(
            geo_cell_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            topo_arg: self.topology.TopologyArg,
            node_positions: wp.array(dtype=pos_type),
        ):
            element_index = wp.tid()

            element_node_count = self.topology.element_node_count(geo_cell_arg, topo_arg, element_index)
            for n in range(element_node_count):
                node_index = self.topology.element_node_index(geo_cell_arg, topo_arg, element_index, n)
                coords = node_coords_in_element(geo_cell_arg, basis_arg, element_index, n)

                sample = make_free_sample(element_index, coords)
                pos = self.geometry.cell_position(geo_cell_arg, sample)

                node_positions[node_index] = pos

        shape = (self.topology.node_count(),)
        if out is None:
            node_positions = wp.empty(
                shape=shape,
                dtype=pos_type,
            )
        else:
            if out.shape != shape or not wp.types.types_equal(pos_type, out.dtype):
                raise ValueError(
                    f"Out node positions array must have shape {shape} and data type {wp.types.type_repr(pos_type)}"
                )
            node_positions = out

        # Launch on the destination array's device, which is also where the argument structs are built.
        # Without an explicit device, wp.launch targets the current default device, and that fails
        # when `out` was allocated on a different device.
        device = node_positions.device
        wp.launch(
            dim=self.geometry.cell_count(),
            kernel=fill_node_positions,
            inputs=[
                self.geometry.cell_arg_value(device=device),
                self.basis_arg_value(device=device),
                self.topology.topo_arg_value(device=device),
                node_positions,
            ],
            device=device,
        )

        return node_positions

    def make_node_coords_in_element(self):
        """Returns a device function giving the element-local coordinates of a node.

        `node_positions` calls the result as ``(cell_arg, basis_arg, element_index, node_index_in_elt)``.
        Must be overridden by subclasses.
        """
        raise NotImplementedError()

    def make_node_quadrature_weight(self):
        """Returns a device function giving the quadrature weight associated with a node. Must be overridden."""
        raise NotImplementedError()

    def make_element_inner_weight(self):
        """Returns a device function evaluating a node's basis weight inside an element. Must be overridden."""
        raise NotImplementedError()

    def make_element_outer_weight(self):
        """Returns a device function for the outer weight; defaults to the inner weight."""
        return self.make_element_inner_weight()

    def make_element_inner_weight_gradient(self):
        """Returns a device function evaluating the gradient of a node's inner weight. Must be overridden."""
        raise NotImplementedError()

    def make_element_outer_weight_gradient(self):
        """Returns a device function for the outer weight gradient; defaults to the inner weight gradient."""
        return self.make_element_inner_weight_gradient()

    def make_trace_node_quadrature_weight(self, trace_basis=None):
        """Returns a device function giving node quadrature weights on geometry sides.

        Args:
            trace_basis: The trace basis space the function is generated for. It is optional so that
                the base signature matches overrides, which receive it from the trace space.

        Must be overridden by subclasses.
        """
        raise NotImplementedError()

    def trace(self) -> "TraceBasisSpace":
        """Returns a basis space evaluating this basis on the sides of the geometry."""
        return TraceBasisSpace(self)

    @property
    def weight_type(self):
        """Type of a basis weight: ``float`` for scalar bases, otherwise a vector of length ``cell_dimension``."""
        if self.value is ShapeFunction.Value.Scalar:
            return float

        return cache.cached_vec_type(length=self.geometry.cell_dimension, dtype=float)

    @property
    def weight_gradient_type(self):
        """Type of a basis weight gradient: a ``cell_dimension`` vector for scalar bases, otherwise a square matrix."""
        if self.value is ShapeFunction.Value.Scalar:
            # Use the cached vector type, as `weight_type` does, so repeated queries return the same type object
            return cache.cached_vec_type(length=self.geometry.cell_dimension, dtype=float)

        return cache.cached_mat_type(shape=(self.geometry.cell_dimension, self.geometry.cell_dimension), dtype=float)
165
+
166
+
167
class ShapeBasisSpace(BasisSpace):
    """Base class for defining shape-function-based basis spaces.

    Pairs a space topology with a :class:`ShapeFunction`. The device functions generated by the
    ``make_*`` methods forward to the shape function's element-local implementations. For
    non-scalar shape functions, they also multiply by a per-node sign queried from the topology.
    """

    def __init__(self, topology: SpaceTopology, shape: ShapeFunction):
        super().__init__(topology)
        self._shape = shape

        # Non-scalar bases need the topology argument inside device functions (to call
        # `element_node_sign`), so the topology arg struct and its value getter are reused
        # as this basis's argument struct and value.
        if self.value is not ShapeFunction.Value.Scalar:
            self.BasisArg = self.topology.TopologyArg
            self.basis_arg_value = self.topology.topo_arg_value

        self.ORDER = self._shape.ORDER

        # Optional host-side export helpers are exposed only when the shape function
        # (or the topology, for `node_grid`) provides the underlying data.
        if hasattr(shape, "element_node_triangulation"):
            self.node_triangulation = self._node_triangulation
        if hasattr(shape, "element_node_tets"):
            self.node_tets = self._node_tets
        if hasattr(shape, "element_node_hexes"):
            self.node_hexes = self._node_hexes
        if hasattr(shape, "element_vtk_cells"):
            self.vtk_cells = self._vtk_cells
        if hasattr(topology, "node_grid"):
            self.node_grid = topology.node_grid

    @property
    def shape(self) -> ShapeFunction:
        """Shape functions used for defining individual element basis"""
        return self._shape

    @property
    def value(self) -> ShapeFunction.Value:
        """Value type of the underlying shape function"""
        return self.shape.value

    @property
    def name(self):
        """Identifier combining topology and shape names; used as a suffix for generated device functions"""
        return f"{self.topology.name}_{self._shape.name}"

    def make_node_coords_in_element(self):
        """Returns a device function giving the element-local coordinates of a node.

        The element and basis arguments are part of the common signature but are not used here:
        the result depends only on the node's index within the element.
        """
        shape_node_coords_in_element = self._shape.make_node_coords_in_element()

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return shape_node_coords_in_element(node_index_in_elt)

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Returns a device function giving a node's quadrature weight, or ``None`` if the shape function has none."""
        shape_node_quadrature_weight = self._shape.make_node_quadrature_weight()

        if shape_node_quadrature_weight is None:
            return None

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return shape_node_quadrature_weight(node_index_in_elt)

        return node_quadrature_weight

    def make_element_inner_weight(self):
        """Returns a device function evaluating a node's basis weight at element-local ``coords``.

        Scalar bases return the shape function weight directly. Other value types scale it by the
        node sign reported by the topology.
        """
        shape_element_inner_weight = self._shape.make_element_inner_weight()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            # wp.static: the branch is selected when the device function is generated
            if wp.static(self.value == ShapeFunction.Value.Scalar):
                return shape_element_inner_weight(coords, node_index_in_elt)
            else:
                # For non-scalar bases, basis_arg is the topology arg struct (see __init__)
                sign = self.topology.element_node_sign(elt_arg, basis_arg, element_index, node_index_in_elt)
                return sign * shape_element_inner_weight(coords, node_index_in_elt)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Returns a device function evaluating the gradient of a node's inner weight.

        Uses the same scalar / signed split as :meth:`make_element_inner_weight`.
        """
        shape_element_inner_weight_gradient = self._shape.make_element_inner_weight_gradient()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            # Branch selected at generation time, as in element_inner_weight
            if wp.static(self.value == ShapeFunction.Value.Scalar):
                return shape_element_inner_weight_gradient(coords, node_index_in_elt)
            else:
                sign = self.topology.element_node_sign(elt_arg, basis_arg, element_index, node_index_in_elt)
                return sign * shape_element_inner_weight_gradient(coords, node_index_in_elt)

        return element_inner_weight_gradient

    def make_trace_node_quadrature_weight(self, trace_basis):
        """Returns a device function giving node quadrature weights on sides, or ``None`` if unsupported.

        Args:
            trace_basis: The trace basis space requesting the function. Its side argument and basis
                argument types define the generated function's signature, and its topology maps a
                side-local node to the node's index within a neighbouring cell.
        """
        shape_trace_node_quadrature_weight = self._shape.make_trace_node_quadrature_weight()

        if shape_trace_node_quadrature_weight is None:
            return None

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(
            geo_side_arg: trace_basis.geometry.SideArg,
            basis_arg: trace_basis.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Only the node's index within the neighbouring cell is needed for the shape lookup
            neighbour_elem, index_in_neighbour = trace_basis.topology.neighbor_cell_index(
                geo_side_arg, element_index, node_index_in_elt
            )
            return shape_trace_node_quadrature_weight(index_in_neighbour)

        return trace_node_quadrature_weight

    def _node_triangulation(self):
        """Host-side helper: global node-index triples for all elements, as an array of shape ``(-1, 3)``.

        Expands the shape's per-element triangle pattern using each element's node indices.
        """
        element_node_indices = self._topology.element_node_indices().numpy()
        element_triangles = self._shape.element_node_triangulation()

        tri_indices = element_node_indices[:, element_triangles].reshape(-1, 3)
        return tri_indices

    def _node_tets(self):
        """Host-side helper: global node-index quadruples for all elements, as an array of shape ``(-1, 4)``."""
        element_node_indices = self._topology.element_node_indices().numpy()
        element_tets = self._shape.element_node_tets()

        tet_indices = element_node_indices[:, element_tets].reshape(-1, 4)
        return tet_indices

    def _node_hexes(self):
        """Host-side helper: global node-index octuples for all elements, as an array of shape ``(-1, 8)``."""
        element_node_indices = self._topology.element_node_indices().numpy()
        element_hexes = self._shape.element_node_hexes()

        hex_indices = element_node_indices[:, element_hexes].reshape(-1, 8)
        return hex_indices

    def _vtk_cells(self):
        """Host-side helper building VTK-style cell connectivity for all elements.

        Returns:
            A tuple ``(cells, cell_types)``:

            - ``cells`` is flat and count-prefixed, i.e. ``[n, i0, ..., i{n-1}, n, ...]``, where ``n`` is
              the number of indices per sub-cell.
            - ``cell_types`` repeats the shape's per-element sub-cell types once per element.
        """
        element_node_indices = self._topology.element_node_indices().numpy()
        element_vtk_cells, element_vtk_cell_types = self._shape.element_vtk_cells()

        idx_per_cell = element_vtk_cells.shape[1]
        cell_indices = element_node_indices[:, element_vtk_cells].reshape(-1, idx_per_cell)
        # Prefix each sub-cell row with its index count
        cells = np.hstack((np.full((cell_indices.shape[0], 1), idx_per_cell), cell_indices))

        return cells.flatten(), np.tile(element_vtk_cell_types, element_node_indices.shape[0])
325
+
326
+
327
class TraceBasisSpace(BasisSpace):
    """Auto-generated trace space evaluating the cell-defined basis on the geometry sides.

    The trace reuses the wrapped cell basis' argument structure (``BasisArg`` and
    ``basis_arg_value``) and builds side-level device functions that locate the
    adjacent cell(s) of a side node, map coordinates to that cell, and delegate to the
    cell basis' own device functions.
    """

    def __init__(self, basis: BasisSpace):
        super().__init__(basis.topology.trace())

        self.ORDER = basis.ORDER

        self._basis = basis
        # The trace shares the cell basis' device-side argument structure.
        self.BasisArg = self._basis.BasisArg
        self.basis_arg_value = self._basis.basis_arg_value

    @property
    def name(self):
        return f"{self._basis.name}_Trace"

    @property
    def value(self) -> ShapeFunction.Value:
        return self._basis.value

    def make_node_coords_in_element(self):
        """Build a device function returning side coordinates of a trace node.

        The node is looked up in its neighbor cell, its cell coordinates are obtained
        from the cell basis, and those are converted back to side coordinates.
        """
        node_coords_in_cell = self._basis.make_node_coords_in_element()

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_node_coords_in_element(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            neighbour_elem, index_in_neighbour = self.topology.neighbor_cell_index(
                geo_side_arg, element_index, node_index_in_elt
            )
            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            neighbour_coords = node_coords_in_cell(
                geo_cell_arg,
                basis_arg,
                neighbour_elem,
                index_in_neighbour,
            )

            return self.geometry.side_from_cell_coords(geo_side_arg, element_index, neighbour_elem, neighbour_coords)

        return trace_node_coords_in_element

    def make_node_quadrature_weight(self):
        """Delegate construction of the side nodal quadrature weight to the cell basis."""
        return self._basis.make_trace_node_quadrature_weight(self)

    def make_element_inner_weight(self):
        """Build a device function evaluating the inner cell's basis weight on a side.

        Returns a zero weight when the topology reports no inner cell for the node
        (``NULL_ELEMENT_INDEX``).
        """
        cell_inner_weight = self._basis.make_element_inner_weight()

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_element_inner_weight(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
            if cell_index == NULL_ELEMENT_INDEX:
                return self.weight_type(0.0)

            cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)

            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return cell_inner_weight(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX)

        return trace_element_inner_weight

    def make_element_outer_weight(self):
        """Build a device function evaluating the outer cell's basis weight on a side.

        Returns a zero weight when the topology reports no outer cell for the node
        (``NULL_ELEMENT_INDEX``).
        """
        cell_outer_weight = self._basis.make_element_outer_weight()

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_element_outer_weight(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            cell_index, index_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
            if cell_index == NULL_ELEMENT_INDEX:
                return self.weight_type(0.0)

            cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)

            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return cell_outer_weight(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX)

        return trace_element_outer_weight

    def make_element_inner_weight_gradient(self):
        """Build a device function evaluating the inner cell's weight gradient on a side.

        Returns a zero gradient when there is no inner cell for the node.
        """
        cell_inner_weight_gradient = self._basis.make_element_inner_weight_gradient()

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_element_inner_weight_gradient(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
            if cell_index == NULL_ELEMENT_INDEX:
                return self.weight_gradient_type(0.0)

            cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return cell_inner_weight_gradient(
                geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX
            )

        return trace_element_inner_weight_gradient

    def make_element_outer_weight_gradient(self):
        """Build a device function evaluating the outer cell's weight gradient on a side.

        Returns a zero gradient when there is no outer cell for the node.
        """
        cell_outer_weight_gradient = self._basis.make_element_outer_weight_gradient()

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_element_outer_weight_gradient(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            cell_index, index_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
            if cell_index == NULL_ELEMENT_INDEX:
                return self.weight_gradient_type(0.0)

            cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)
            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return cell_outer_weight_gradient(
                geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX
            )

        return trace_element_outer_weight_gradient

    def __eq__(self, other: "TraceBasisSpace") -> bool:
        """Trace spaces compare equal when they are generated from equal cell bases."""
        # The previous implementation read a non-existent ``_topo`` attribute and raised
        # AttributeError; the defining state of a trace space is the wrapped basis.
        if not isinstance(other, TraceBasisSpace):
            return NotImplemented
        return self._basis == other._basis
471
+
472
+
473
class PiecewiseConstantBasisSpace(ShapeBasisSpace):
    """Shape-based basis space whose trace exposes its node at the center of every side.

    Constructed by ``make_discontinuous_basis_space`` when the shape has a single node
    per element.
    """

    class Trace(TraceBasisSpace):
        """Trace of a piecewise-constant space, placing the node at the side center."""

        def make_node_coords_in_element(self):
            # The lone cell node is made visible from every side so that values can be
            # interpolated directly on boundaries. Higher-order non-conforming elements
            # cannot be handled this way; they require integrating and then solving
            # with a mass matrix.
            side_center = self.geometry.reference_side().center()
            CENTER_COORDS = Coords(side_center)

            @cache.dynamic_func(suffix=self._basis.name)
            def trace_node_coords_in_element(
                geo_side_arg: self.geometry.SideArg,
                basis_arg: self.BasisArg,
                element_index: ElementIndex,
                node_index_in_elt: int,
            ):
                return CENTER_COORDS

            return trace_node_coords_in_element

    def trace(self):
        """Return the specialized trace space for this piecewise-constant basis."""
        trace_cls = PiecewiseConstantBasisSpace.Trace
        return trace_cls(self)
495
+
496
+
497
def make_discontinuous_basis_space(geometry: Geometry, shape: ShapeFunction):
    """Create a discontinuous basis space for ``shape`` on ``geometry``.

    The space uses a regular discontinuous topology with ``shape.NODES_PER_ELEMENT``
    nodes per element. Single-node shapes produce a ``PiecewiseConstantBasisSpace``
    (which provides a specialized trace); all others produce a ``ShapeBasisSpace``.
    """
    topology = RegularDiscontinuousSpaceTopology(geometry, shape.NODES_PER_ELEMENT)

    is_piecewise_constant = shape.NODES_PER_ELEMENT == 1
    space_cls = PiecewiseConstantBasisSpace if is_piecewise_constant else ShapeBasisSpace
    return space_cls(topology=topology, shape=shape)
505
+
506
+
507
class UnstructuredPointTopology(SpaceTopology):
    """Topology for unstructured points defined from quadrature formula. See :class:`PointBasisSpace`"""

    def __init__(self, quadrature: Quadrature):
        """Build the topology from a quadrature formula.

        Raises:
            ValueError: if the quadrature does not report a maximum number of points per
                element, or if its domain does not cover every geometry element.
        """
        # Queried once and reused for both validation and the base-class setup.
        max_points = quadrature.max_points_per_element()
        if max_points is None:
            raise ValueError("Quadrature must define a maximum number of points per element")

        if quadrature.domain.element_count() != quadrature.domain.geometry_element_count():
            raise ValueError("Point topology may only be defined on quadrature domains that span the whole geometry")

        self._quadrature = quadrature
        # The quadrature's device-side argument doubles as the topology argument.
        self.TopologyArg = quadrature.Arg

        super().__init__(quadrature.domain.geometry, max_nodes_per_element=max_points)

        self.element_node_index = self._make_element_node_index()
        self.element_node_count = self._make_element_node_count()
        self.side_neighbor_node_counts = self._make_side_neighbor_node_counts()

    def node_count(self):
        """Total number of nodes, i.e. the quadrature's total point count."""
        return self._quadrature.total_point_count()

    @property
    def name(self):
        return f"PointTopology_{self._quadrature}"

    def topo_arg_value(self, device) -> SpaceTopology.TopologyArg:
        """Value of the topology argument structure to be passed to device functions"""
        return self._quadrature.arg_value(device)

    def _make_element_node_index(self):
        # Node indices are the quadrature's global point indices.
        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self._quadrature.point_index(elt_arg, topo_arg, element_index, element_index, node_index_in_elt)

        return element_node_index

    def _make_element_node_count(self):
        # Each element holds as many nodes as it has quadrature points.
        @cache.dynamic_func(suffix=self.name)
        def element_node_count(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
        ):
            return self._quadrature.point_count(elt_arg, topo_arg, element_index, element_index)

        return element_node_count

    def _make_side_neighbor_node_counts(self):
        # Point nodes are never shared with sides: both neighbor counts are zero.
        @cache.dynamic_func(suffix=self.name)
        def side_neighbor_node_counts(
            side_arg: self.geometry.SideArg,
            element_index: ElementIndex,
        ):
            return 0, 0

        return side_neighbor_node_counts
569
+
570
+
571
class PointBasisSpace(BasisSpace):
    """An unstructured :class:`BasisSpace` that is non-zero at a finite set of points only.

    The node locations and nodal quadrature weights are defined by a :class:`Quadrature` formula.
    """

    def __init__(self, quadrature: Quadrature):
        self._quadrature = quadrature

        topology = UnstructuredPointTopology(quadrature)
        super().__init__(topology)

        # The quadrature's device-side argument doubles as the basis argument.
        self.BasisArg = quadrature.Arg
        self.basis_arg_value = quadrature.arg_value
        self.ORDER = 0

        # Inner and outer evaluations coincide for point bases. The gradient alias
        # previously assigned the outer factory to itself (a no-op); it now mirrors
        # the weight alias above and reuses the inner factory.
        self.make_element_outer_weight = self.make_element_inner_weight
        self.make_element_outer_weight_gradient = self.make_element_inner_weight_gradient

    @property
    def name(self):
        return f"{self._quadrature.name}_Point"

    @property
    def value(self) -> ShapeFunction.Value:
        return ShapeFunction.Value.Scalar

    def make_node_coords_in_element(self):
        """Build a device function returning the coordinates of the matching quadrature point."""

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self._quadrature.point_coords(elt_arg, basis_arg, element_index, element_index, node_index_in_elt)

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Build a device function returning the weight of the matching quadrature point."""

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self._quadrature.point_weight(elt_arg, basis_arg, element_index, element_index, node_index_in_elt)

        return node_quadrature_weight

    def make_element_inner_weight(self):
        """Build a device function that is 1 at the node's point and 0 elsewhere.

        A sample counts as "at" the node when its squared coordinate distance to the
        node's quadrature point is below ``_DIRAC_INTEGRATION_RADIUS``.
        """
        # NOTE(review): this threshold is compared against a *squared* distance.
        _DIRAC_INTEGRATION_RADIUS = wp.constant(1.0e-6)

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            qp_coord = self._quadrature.point_coords(
                elt_arg, basis_arg, element_index, element_index, node_index_in_elt
            )
            return wp.where(wp.length_sq(coords - qp_coord) < _DIRAC_INTEGRATION_RADIUS, 1.0, 0.0)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Build a device function returning a zero gradient of cell dimension."""
        gradient_vec = cache.cached_vec_type(length=self.geometry.cell_dimension, dtype=float)

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
            qp_index: QuadraturePointIndex,
        ):
            return gradient_vec(0.0)

        return element_inner_weight_gradient

    def make_trace_node_quadrature_weight(self, trace_basis):
        """Build a side nodal quadrature weight that is always zero."""

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(
            elt_arg: trace_basis.geometry.SideArg,
            basis_arg: trace_basis.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return 0.0

        return trace_node_quadrature_weight