warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,271 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warp as wp
17
+ from warp.fem import cache
18
+ from warp.fem.geometry import Tetmesh
19
+ from warp.fem.types import ElementIndex
20
+
21
+ from .shape import (
22
+ ShapeFunction,
23
+ TetrahedronPolynomialShapeFunctions,
24
+ TetrahedronShapeFunction,
25
+ )
26
+ from .topology import SpaceTopology, forward_base_topology
27
+
28
+
29
# Device-side argument bundle for tetmesh function-space topologies.
# An instance is filled in by `TetmeshSpaceTopology.topo_arg_value` and read by the
# `element_node_index` / `element_node_sign` device functions.
# NOTE: field order is part of the struct definition; do not reorder.
@wp.struct
class TetmeshTopologyArg:
    # Global edge index for each (tet, local edge) pair, read as `tet_edge_indices[tet][edge]`.
    # An empty (0, 0) array when the shape function has no edge nodes.
    tet_edge_indices: wp.array2d(dtype=int)
    # Global face index for each (tet, local face) pair, read as `tet_face_indices[tet][face]`.
    # An empty (0, 0) array when the shape function has no face nodes.
    tet_face_indices: wp.array2d(dtype=int)
    # Vertex indices of each face, taken from the mesh's `face_vertex_indices`.
    face_vertex_indices: wp.array(dtype=wp.vec3i)
    # Indices of the two tets associated with each face, taken from the mesh's `face_tet_indices`.
    face_tet_indices: wp.array(dtype=wp.vec2i)

    # Number of mesh vertices (`mesh.vertex_count()`).
    vertex_count: int
    # Number of mesh edges, or 0 when the shape function has no edge nodes.
    edge_count: int
    # Number of mesh faces (`mesh.side_count()`).
    face_count: int
39
+
40
+
41
class TetmeshSpaceTopology(SpaceTopology):
    """Node topology of a function space defined over a :class:`Tetmesh`.

    Global node indices are laid out in four consecutive ranges: vertex nodes,
    then edge nodes, then face nodes, then cell-interior nodes. The number of
    nodes attached to each kind of entity is read from the shape function
    (``VERTEX_NODE_COUNT``, ``EDGE_NODE_COUNT``, ``FACE_NODE_COUNT``,
    ``INTERIOR_NODE_COUNT``).
    """

    TopologyArg = TetmeshTopologyArg

    def __init__(
        self,
        mesh: Tetmesh,
        shape: TetrahedronShapeFunction,
    ):
        """Build the topology for ``mesh`` using the node layout of ``shape``.

        Args:
            mesh: Tetrahedral mesh providing the vertex/edge/face/cell connectivity.
            shape: Shape function providing the per-entity node counts.
        """
        self._shape = shape
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh

        need_tet_edge_indices = self._shape.EDGE_NODE_COUNT > 0
        need_tet_face_indices = self._shape.FACE_NODE_COUNT > 0

        # Edge data is only requested from the mesh when the shape actually has edge nodes;
        # otherwise an empty placeholder array and a zero count are stored.
        if need_tet_edge_indices:
            self._tet_edge_indices = self._mesh.tet_edge_indices
            self._edge_count = self._mesh.edge_count()
        else:
            self._tet_edge_indices = wp.empty(shape=(0, 0), dtype=int)
            self._edge_count = 0

        # Same policy for the tet -> face table, which is computed here rather than read from the mesh.
        if need_tet_face_indices:
            self._compute_tet_face_indices()
        else:
            self._tet_face_indices = wp.empty(shape=(0, 0), dtype=int)

        self.element_node_index = self._make_element_node_index()
        self.element_node_sign = self._make_element_node_sign()

    @property
    def name(self):
        """Name combining the geometry name and the shape function name."""
        return f"{self.geometry.name}_{self._shape.name}"

    @cache.cached_arg_value
    def topo_arg_value(self, device):
        """Return a :class:`TetmeshTopologyArg` whose arrays have been moved to ``device``."""
        arg = TetmeshTopologyArg()
        arg.tet_face_indices = self._tet_face_indices.to(device)
        arg.tet_edge_indices = self._tet_edge_indices.to(device)
        arg.face_vertex_indices = self._mesh.face_vertex_indices.to(device)
        arg.face_tet_indices = self._mesh.face_tet_indices.to(device)

        arg.vertex_count = self._mesh.vertex_count()
        arg.face_count = self._mesh.side_count()
        arg.edge_count = self._edge_count
        return arg

    def _compute_tet_face_indices(self):
        """Allocate and fill ``self._tet_face_indices``, a ``(cell_count, 4)`` tet -> face table.

        The table is allocated uninitialized on the device holding the mesh's
        ``tet_vertex_indices`` and filled by launching one kernel thread per face.
        """
        device = self._mesh.tet_vertex_indices.device
        face_tet_indices = self._mesh.face_tet_indices

        self._tet_face_indices = wp.empty(dtype=int, device=device, shape=(self._mesh.cell_count(), 4))

        wp.launch(
            kernel=TetmeshSpaceTopology._compute_tet_face_indices_kernel,
            # One thread per face. The launch size comes from the same public array that is
            # passed as the first kernel input, rather than from the mesh's private attribute.
            dim=face_tet_indices.shape,
            device=device,
            inputs=[
                face_tet_indices,
                self._mesh.face_vertex_indices,
                self._mesh.tet_vertex_indices,
                self._tet_face_indices,
            ],
        )

    # Device helper: returns the local index (0..3) of the face of a tet matching `face_vtx`.
    # Local face k is formed by tet vertices k, k+1 and k+2 (modulo 4). Faces 0..2 are tested
    # explicitly; when none of them matches, 3 is returned without further checking.
    # A candidate matches when its smallest vertex equals face_vtx[0] and its two other
    # vertices equal face_vtx[1] and face_vtx[2] in either order.
    @wp.func
    def _find_face_index_in_tet(
        face_vtx: wp.vec3i,
        tet_vtx: wp.vec4i,
    ):
        for k in range(3):
            tvk = wp.vec3i(tet_vtx[k], tet_vtx[(k + 1) % 4], tet_vtx[(k + 2) % 4])

            # Use fact that face always start with min vertex
            min_t = wp.min(tvk)
            max_t = wp.max(tvk)
            mid_t = tvk[0] + tvk[1] + tvk[2] - min_t - max_t

            if min_t == face_vtx[0] and (
                (face_vtx[2] == max_t and face_vtx[1] == mid_t) or (face_vtx[1] == max_t and face_vtx[2] == mid_t)
            ):
                return k

        return 3

    # Kernel, one thread per face `e`: writes `e` into the tet -> face table at the local face
    # slot of the first tet listed for that face and, when it differs from the first, of the
    # second tet as well.
    # NOTE(review): a face listing the same tet twice is presumably a boundary face -- confirm
    # against the Tetmesh implementation.
    @wp.kernel
    def _compute_tet_face_indices_kernel(
        face_tet_indices: wp.array(dtype=wp.vec2i),
        face_vertex_indices: wp.array(dtype=wp.vec3i),
        tet_vertex_indices: wp.array2d(dtype=int),
        tet_face_indices: wp.array2d(dtype=int),
    ):
        e = wp.tid()

        face_vtx = face_vertex_indices[e]
        face_tets = face_tet_indices[e]

        t0 = face_tets[0]
        t0_vtx = wp.vec4i(
            tet_vertex_indices[t0, 0], tet_vertex_indices[t0, 1], tet_vertex_indices[t0, 2], tet_vertex_indices[t0, 3]
        )
        t0_face = TetmeshSpaceTopology._find_face_index_in_tet(face_vtx, t0_vtx)
        tet_face_indices[t0, t0_face] = e

        t1 = face_tets[1]
        if t1 != t0:
            t1_vtx = wp.vec4i(
                tet_vertex_indices[t1, 0],
                tet_vertex_indices[t1, 1],
                tet_vertex_indices[t1, 2],
                tet_vertex_indices[t1, 3],
            )
            t1_face = TetmeshSpaceTopology._find_face_index_in_tet(face_vtx, t1_vtx)
            tet_face_indices[t1, t1_face] = e

    def node_count(self) -> int:
        """Return the total number of nodes: vertex + edge + face + cell-interior nodes."""
        return (
            self._mesh.vertex_count() * self._shape.VERTEX_NODE_COUNT
            # Use the edge count stored by __init__ (0 when the shape has no edge nodes), as
            # topo_arg_value does, so the mesh is not asked for edge data it was never asked
            # for at construction time. The product is unchanged in both cases.
            + self._edge_count * self._shape.EDGE_NODE_COUNT
            + self._mesh.side_count() * self._shape.FACE_NODE_COUNT
            + self._mesh.cell_count() * self._shape.INTERIOR_NODE_COUNT
        )

    def _make_element_node_index(self):
        """Create the device function mapping (element, local node) to a global node index.

        The per-entity node counts are captured from the shape function as constants of the
        generated function, which is registered through ``cache.dynamic_func`` with this
        topology's name as suffix.
        """
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        INTERIOR_NODES_PER_EDGE = self._shape.EDGE_NODE_COUNT
        INTERIOR_NODES_PER_FACE = self._shape.FACE_NODE_COUNT
        INTERIOR_NODES_PER_CELL = self._shape.INTERIOR_NODE_COUNT

        # Global numbering produced below, in order:
        #   1. vertex nodes:   mesh vertex index * VERTEX_NODE_COUNT + node within vertex
        #   2. edge nodes:     offset by all vertex nodes; the node order along an edge is
        #                      reversed when the edge's first tet vertex has the larger index
        #   3. face nodes:     offset by all vertex and edge nodes; for 3 nodes per face the
        #                      node is identified by matching a tet vertex against the face's vertices
        #   4. interior nodes: offset by all of the above, INTERIOR_NODES_PER_CELL per tet
        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            geo_arg: Tetmesh.CellArg,
            topo_arg: TetmeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            if node_type == TetrahedronPolynomialShapeFunctions.VERTEX:
                vertex = type_index // VERTEX_NODE_COUNT
                vertex_node = type_index - VERTEX_NODE_COUNT * vertex
                return geo_arg.tet_vertex_indices[element_index][vertex] * VERTEX_NODE_COUNT + vertex_node

            global_offset = topo_arg.vertex_count * VERTEX_NODE_COUNT

            if node_type == TetrahedronPolynomialShapeFunctions.EDGE:
                edge = type_index // INTERIOR_NODES_PER_EDGE
                edge_node = type_index - INTERIOR_NODES_PER_EDGE * edge

                global_edge_index = topo_arg.tet_edge_indices[element_index][edge]

                # Test if we need to swap edge direction
                if wp.static(INTERIOR_NODES_PER_EDGE > 1):
                    c1, c2 = TetrahedronShapeFunction.edge_vidx(edge)
                    if geo_arg.tet_vertex_indices[element_index][c1] > geo_arg.tet_vertex_indices[element_index][c2]:
                        edge_node = INTERIOR_NODES_PER_EDGE - 1 - edge_node

                return global_offset + INTERIOR_NODES_PER_EDGE * global_edge_index + edge_node

            global_offset += INTERIOR_NODES_PER_EDGE * topo_arg.edge_count

            if node_type == TetrahedronPolynomialShapeFunctions.FACE:
                face = type_index // INTERIOR_NODES_PER_FACE
                face_node = type_index - INTERIOR_NODES_PER_FACE * face

                global_face_index = topo_arg.tet_face_indices[element_index][face]

                if wp.static(INTERIOR_NODES_PER_FACE == 3):
                    # Hard code for P4 case, 3 nodes per face
                    # Higher orders would require rotating triangle coordinates, this is not supported yet

                    vidx = geo_arg.tet_vertex_indices[element_index][(face + face_node) % 4]
                    fvi = topo_arg.face_vertex_indices[global_face_index]

                    if vidx == fvi[0]:
                        face_node = 0
                    elif vidx == fvi[1]:
                        face_node = 1
                    else:
                        face_node = 2

                return global_offset + INTERIOR_NODES_PER_FACE * global_face_index + face_node

            global_offset += INTERIOR_NODES_PER_FACE * topo_arg.face_count

            return global_offset + INTERIOR_NODES_PER_CELL * element_index + type_index

        return element_node_index

    def _make_element_node_sign(self):
        """Create the device function returning the sign (+1.0 or -1.0) of a node within an element.

        The generated function is registered through ``cache.dynamic_func`` with this
        topology's name as suffix.
        """
        INTERIOR_NODES_PER_EDGE = self._shape.EDGE_NODE_COUNT
        INTERIOR_NODES_PER_FACE = self._shape.FACE_NODE_COUNT

        # Sign rules implemented below:
        #   - edge node: -1.0 when the edge's first tet vertex has a larger global index than
        #     its second one, +1.0 otherwise
        #   - face node: +1.0 when this element is the first tet listed for the face, -1.0 otherwise
        #   - any other node: +1.0
        # The edge/face branches are only generated when the shape has nodes of that kind.
        @cache.dynamic_func(suffix=self.name)
        def element_node_sign(
            geo_arg: self.geometry.CellArg,
            topo_arg: TetmeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            if wp.static(INTERIOR_NODES_PER_EDGE > 0):
                if node_type == TetrahedronShapeFunction.EDGE:
                    edge = type_index // INTERIOR_NODES_PER_EDGE
                    c1, c2 = TetrahedronShapeFunction.edge_vidx(edge)

                    return wp.where(
                        geo_arg.tet_vertex_indices[element_index][c1] > geo_arg.tet_vertex_indices[element_index][c2],
                        -1.0,
                        1.0,
                    )

            if wp.static(INTERIOR_NODES_PER_FACE > 0):
                if node_type == TetrahedronShapeFunction.FACE:
                    face = type_index // INTERIOR_NODES_PER_FACE

                    global_face_index = topo_arg.tet_face_indices[element_index][face]
                    inner = topo_arg.face_tet_indices[global_face_index][0]

                    return wp.where(inner == element_index, 1.0, -1.0)

            return 1.0

        return element_node_sign
265
+
266
+
267
def make_tetmesh_space_topology(mesh: Tetmesh, shape: ShapeFunction):
    """Create the space topology for ``shape`` over the tetrahedral mesh ``mesh``.

    Only shape functions deriving from ``TetrahedronShapeFunction`` are accepted;
    construction is delegated to ``forward_base_topology``.

    Raises:
        ValueError: if ``shape`` is not a ``TetrahedronShapeFunction``.
    """
    supported = isinstance(shape, TetrahedronShapeFunction)
    if not supported:
        raise ValueError(f"Unsupported shape function {shape.name}")

    return forward_base_topology(TetmeshSpaceTopology, mesh, shape)
@@ -0,0 +1,424 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Tuple, Type
17
+
18
+ import warp as wp
19
+ from warp.fem import cache
20
+ from warp.fem.geometry import DeformedGeometry, Geometry
21
+ from warp.fem.types import NULL_ELEMENT_INDEX, NULL_NODE_INDEX, ElementIndex
22
+
23
+
24
class SpaceTopology:
    """
    Interface class for defining the topology of a function space.

    The topology only considers the indices of the nodes in each element, and as such,
    the connectivity pattern of the function space.
    It does not specify the actual location of the nodes within the elements, or the valuation function.
    """

    dimension: int
    """Embedding dimension of the function space"""

    MAX_NODES_PER_ELEMENT: int
    """Maximum number of interpolation nodes per element of the geometry.

    .. note:: This will change to be defined per-element in future versions
    """

    @wp.struct
    class TopologyArg:
        """Structure containing arguments to be passed to device functions"""

        pass

    def __init__(self, geometry: Geometry, max_nodes_per_element: int):
        """Initializes the topology over `geometry`.

        Args:
            geometry: geometry over which the function space is defined
            max_nodes_per_element: upper bound on the number of nodes of any element
        """
        self._geometry = geometry
        self.dimension = geometry.dimension
        self.MAX_NODES_PER_ELEMENT = wp.constant(max_nodes_per_element)
        self.ElementArg = geometry.CellArg

        # Install per-instance default device functions: a constant node count
        # (``MAX_NODES_PER_ELEMENT`` for every element) and a constant sign of 1.0.
        # These shadow the static placeholders declared on the class below.
        self._make_constant_element_node_count()
        self._make_constant_element_node_sign()

    @property
    def geometry(self) -> Geometry:
        """Underlying geometry"""
        return self._geometry

    def node_count(self) -> int:
        """Number of nodes in the interpolation basis"""
        raise NotImplementedError

    def topo_arg_value(self, device) -> "TopologyArg":
        """Value of the topology argument structure to be passed to device functions"""
        return SpaceTopology.TopologyArg()

    @property
    def name(self):
        """Identifier of the topology; also used as suffix for the generated device functions"""
        return f"{self.__class__.__name__}_{self.MAX_NODES_PER_ELEMENT}"

    def __str__(self):
        return self.name

    @staticmethod
    def element_node_count(
        geo_arg: "ElementArg",  # noqa: F821
        topo_arg: "TopologyArg",
        element_index: ElementIndex,
    ) -> int:
        """Returns the actual number of nodes in a given element"""
        raise NotImplementedError

    @staticmethod
    def element_node_index(
        geo_arg: "ElementArg",  # noqa: F821
        topo_arg: "TopologyArg",
        element_index: ElementIndex,
        node_index_in_elt: int,
    ) -> int:
        """Global node index for a given node in a given element"""
        raise NotImplementedError

    @staticmethod
    def side_neighbor_node_counts(
        side_arg: "ElementArg",  # noqa: F821
        side_index: ElementIndex,
    ) -> Tuple[int, int]:
        """Returns the number of nodes for both the inner and outer cells of a given side"""
        raise NotImplementedError

    def element_node_indices(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns a temporary array containing the global index for each node of each element

        Args:
            out: optional array in which to store the result; must have shape
                ``(cell_count, MAX_NODES_PER_ELEMENT)`` and data type ``int32``

        Raises:
            ValueError: if `out` is provided with a mismatching shape or data type

        Only the first ``element_node_count`` entries of each row are written by the kernel;
        the remaining entries keep whatever the array contained beforehand.
        """

        MAX_NODES_PER_ELEMENT = self.MAX_NODES_PER_ELEMENT

        @cache.dynamic_kernel(suffix=self.name)
        def fill_element_node_indices(
            geo_cell_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_node_indices: wp.array2d(dtype=int),
        ):
            # One thread per element
            element_index = wp.tid()
            element_node_count = self.element_node_count(geo_cell_arg, topo_arg, element_index)
            for n in range(element_node_count):
                element_node_indices[element_index, n] = self.element_node_index(
                    geo_cell_arg, topo_arg, element_index, n
                )

        shape = (self.geometry.cell_count(), MAX_NODES_PER_ELEMENT)
        if out is None:
            element_node_indices = wp.empty(
                shape=shape,
                dtype=int,
            )
        else:
            if out.shape != shape or out.dtype != wp.int32:
                raise ValueError(f"Out element node indices array must have shape {shape} and data type 'int32'")
            element_node_indices = out

        # Argument structures are evaluated on, and the kernel launched on, the device of the output array
        wp.launch(
            dim=element_node_indices.shape[0],
            kernel=fill_element_node_indices,
            inputs=[
                self.geometry.cell_arg_value(device=element_node_indices.device),
                self.topo_arg_value(device=element_node_indices.device),
                element_node_indices,
            ],
            device=element_node_indices.device,
        )

        return element_node_indices

    # Interface generating trace space topology

    def trace(self) -> "TraceSpaceTopology":
        """Trace of the function space over lower-dimensional elements of the geometry"""

        return TraceSpaceTopology(self)

    @property
    def is_trace(self) -> bool:
        """Whether this topology is defined on the trace of the geometry"""
        return self.dimension == self.geometry.dimension - 1

    def full_space_topology(self) -> "SpaceTopology":
        """Returns the full space topology from which this topology is derived"""
        return self

    def __eq__(self, other: "SpaceTopology") -> bool:
        """Checks whether two topologies are compatible

        Two topologies are compatible when they are defined over equal geometries and share
        the same name. Comparison with an object that is not a :class:`SpaceTopology`
        yields ``NotImplemented`` so that Python falls back to its default handling
        (``topology == None`` evaluates to ``False`` instead of raising ``AttributeError``).
        """
        if not isinstance(other, SpaceTopology):
            return NotImplemented
        return self.geometry == other.geometry and self.name == other.name

    def is_derived_from(self, other: "SpaceTopology") -> bool:
        """Checks whether two topologies are equal, or `self` is the trace of `other`"""
        if self.dimension == other.dimension:
            return self == other
        if self.dimension + 1 == other.dimension:
            return self.full_space_topology() == other
        return False

    def _make_constant_element_node_count(self):
        """Installs device functions reporting ``MAX_NODES_PER_ELEMENT`` nodes for every element,
        and for both neighboring cells of every side"""
        NODES_PER_ELEMENT = wp.constant(self.MAX_NODES_PER_ELEMENT)

        @cache.dynamic_func(suffix=self.name)
        def constant_element_node_count(
            geo_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
        ):
            return NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def constant_side_neighbor_node_counts(
            side_arg: self.geometry.SideArg,
            element_index: ElementIndex,
        ):
            return NODES_PER_ELEMENT, NODES_PER_ELEMENT

        self.element_node_count = constant_element_node_count
        self.side_neighbor_node_counts = constant_side_neighbor_node_counts

    def _make_constant_element_node_sign(self):
        """Installs a device function returning a sign of ``1.0`` for every node of every element"""

        @cache.dynamic_func(suffix=self.name)
        def constant_element_node_sign(
            geo_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_element: int,
        ):
            return 1.0

        self.element_node_sign = constant_element_node_sign
206
+
207
+
208
class TraceSpaceTopology(SpaceTopology):
    """Auto-generated trace topology defining the node indices associated to the geometry sides

    The nodes of a side are the concatenation of the nodes of its inner cell followed by
    the nodes of its outer cell: node indices below the inner cell's node count map to the
    inner cell, the following ones map to the outer cell.
    """

    def __init__(self, topo: SpaceTopology):
        """Builds the trace of the full-space topology `topo`"""
        # Must be set before calling the base constructor, which reads ``self.name``
        self._topo = topo

        # A side may carry the nodes of both its neighboring cells
        super().__init__(topo.geometry, 2 * topo.MAX_NODES_PER_ELEMENT)

        self.dimension = topo.dimension - 1
        self.ElementArg = topo.geometry.SideArg

        # The topology argument structure is shared with the full-space topology
        self.TopologyArg = topo.TopologyArg
        self.topo_arg_value = topo.topo_arg_value

        self.inner_cell_index = self._make_inner_cell_index()
        self.outer_cell_index = self._make_outer_cell_index()
        self.neighbor_cell_index = self._make_neighbor_cell_index()

        self.element_node_index = self._make_element_node_index()
        self.element_node_count = self._make_element_node_count()
        self.side_neighbor_node_counts = None

        # NOTE(review): `_make_element_node_sign` below is never called here, so
        # `self.element_node_sign` remains the constant function installed by the base
        # constructor (typed on the cell argument) -- confirm whether this is intended.

    def node_count(self) -> int:
        """Number of nodes, identical to that of the full-space topology"""
        return self._topo.node_count()

    @property
    def name(self):
        """Name of the full-space topology with a ``_Trace`` suffix"""
        return f"{self._topo.name}_Trace"

    def _make_inner_cell_index(self):
        """Builds the device function mapping a side node to ``(inner cell index, node index in that cell)``

        Nodes that do not belong to the inner cell map to ``(NULL_ELEMENT_INDEX, NULL_NODE_INDEX)``.
        """

        @cache.dynamic_func(suffix=self.name)
        def inner_cell_index(side_arg: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            inner_count, outer_count = self._topo.side_neighbor_node_counts(side_arg, element_index)
            if node_index_in_elt >= inner_count:
                return NULL_ELEMENT_INDEX, NULL_NODE_INDEX
            return self.geometry.side_inner_cell_index(side_arg, element_index), node_index_in_elt

        return inner_cell_index

    def _make_outer_cell_index(self):
        """Builds the device function mapping a side node to ``(outer cell index, node index in that cell)``

        Nodes that do not belong to the outer cell map to ``(NULL_ELEMENT_INDEX, NULL_NODE_INDEX)``.
        """

        @cache.dynamic_func(suffix=self.name)
        def outer_cell_index(side_arg: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            inner_count, outer_count = self._topo.side_neighbor_node_counts(side_arg, element_index)
            if node_index_in_elt < inner_count:
                return NULL_ELEMENT_INDEX, NULL_NODE_INDEX
            # Outer-cell nodes are stored after the inner-cell ones
            return self.geometry.side_outer_cell_index(side_arg, element_index), node_index_in_elt - inner_count

        return outer_cell_index

    def _make_neighbor_cell_index(self):
        """Builds the device function mapping a side node to the ``(cell index, node index in that cell)``
        of whichever neighboring cell (inner or outer) the node belongs to"""

        @cache.dynamic_func(suffix=self.name)
        def neighbor_cell_index(side_arg: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            inner_count, outer_count = self._topo.side_neighbor_node_counts(side_arg, element_index)
            if node_index_in_elt < inner_count:
                return self.geometry.side_inner_cell_index(side_arg, element_index), node_index_in_elt

            return (
                self.geometry.side_outer_cell_index(side_arg, element_index),
                node_index_in_elt - inner_count,
            )

        return neighbor_cell_index

    def _make_element_node_count(self):
        """Builds the device function returning the number of nodes of a side,
        i.e. the sum of the node counts of its inner and outer cells"""

        @cache.dynamic_func(suffix=self.name)
        def trace_element_node_count(
            geo_side_arg: self.geometry.SideArg,
            topo_arg: self._topo.TopologyArg,
            element_index: ElementIndex,
        ):
            inner_count, outer_count = self._topo.side_neighbor_node_counts(geo_side_arg, element_index)
            return inner_count + outer_count

        return trace_element_node_count

    def _make_element_node_index(self):
        """Builds the device function returning the global index of a side node,
        looked up through the full-space topology of the owning neighbor cell"""

        @cache.dynamic_func(suffix=self.name)
        def trace_element_node_index(
            geo_side_arg: self.geometry.SideArg,
            topo_arg: self._topo.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.neighbor_cell_index(geo_side_arg, element_index, node_index_in_elt)

            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return self._topo.element_node_index(geo_cell_arg, topo_arg, cell_index, index_in_cell)

        return trace_element_node_index

    def _make_element_node_sign(self):
        """Builds the device function returning the sign of a side node,
        looked up through the full-space topology of the owning neighbor cell"""

        @cache.dynamic_func(suffix=self.name)
        def trace_element_node_sign(
            geo_side_arg: self.geometry.SideArg,
            topo_arg: self._topo.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.neighbor_cell_index(geo_side_arg, element_index, node_index_in_elt)

            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return self._topo.element_node_sign(geo_cell_arg, topo_arg, cell_index, index_in_cell)

        return trace_element_node_sign

    def full_space_topology(self) -> SpaceTopology:
        """Returns the full space topology from which this topology is derived"""
        return self._topo

    def __eq__(self, other: "TraceSpaceTopology") -> bool:
        """Checks whether two trace topologies derive from equal full-space topologies

        Because this subclass overrides ``__eq__``, Python evaluates it first even for
        ``full_topology == trace_topology``. Operands that are not trace topologies have no
        ``_topo`` attribute, so ``NotImplemented`` is returned for them and Python falls back
        to the other operand's comparison instead of raising ``AttributeError``.
        """
        if not isinstance(other, TraceSpaceTopology):
            return NotImplemented
        return self._topo == other._topo
319
+
320
+
321
class RegularDiscontinuousSpaceTopologyMixin:
    """Mixin for discontinuous topologies, in which each element owns its nodes.

    Every element is given ``MAX_NODES_PER_ELEMENT`` consecutive global node indices.
    """

    def __init__(self, *args, **kwargs):
        """Forwards all arguments to the next class in the MRO, then installs the node indexing function"""
        super().__init__(*args, **kwargs)
        self.element_node_index = self._make_element_node_index()

    def node_count(self):
        """Total number of nodes: one full set of nodes per cell of the geometry"""
        cell_count = self.geometry.cell_count()
        return cell_count * self.MAX_NODES_PER_ELEMENT

    @property
    def name(self):
        """Geometry name followed by ``_D`` and the number of nodes per element"""
        geometry_name = self.geometry.name
        return f"{geometry_name}_D{self.MAX_NODES_PER_ELEMENT}"

    def _make_element_node_index(self):
        """Builds the device function computing the global index of an element node"""
        NODE_STRIDE = self.MAX_NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Nodes of an element are stored contiguously, element after element
            first_node_of_element = NODE_STRIDE * element_index
            return first_node_of_element + node_index_in_elt

        return element_node_index
348
+
349
+
350
class RegularDiscontinuousSpaceTopology(RegularDiscontinuousSpaceTopologyMixin, SpaceTopology):
    """Topology for generic discontinuous spaces.

    Combines :class:`RegularDiscontinuousSpaceTopologyMixin` with :class:`SpaceTopology`
    without adding any behavior of its own.
    """
354
+
355
+
356
class DeformedGeometrySpaceTopology(SpaceTopology):
    """Topology over a deformed geometry that forwards every query to a base topology.

    The device functions installed here unpack the base geometry argument from the
    deformed geometry argument structure and call the corresponding function of ``base``.
    """

    def __init__(self, geometry: DeformedGeometry, base_topology: SpaceTopology):
        """Wraps `base_topology` so that it can be used with the deformed `geometry`"""
        # `base` has to exist before the parent constructor runs: the parent builds
        # device functions whose suffix is `self.name`, which reads `self.base`.
        self.base = base_topology
        super().__init__(geometry, base_topology.MAX_NODES_PER_ELEMENT)

        # Host-side queries and the topology argument structure come straight from the base topology
        self.TopologyArg = self.base.TopologyArg
        self.topo_arg_value = self.base.topo_arg_value
        self.node_count = self.base.node_count

        self._make_passthrough_functions()

    @property
    def name(self):
        """Name of the base topology followed by the name of the deformation field"""
        base_name = self.base.name
        return f"{base_name}_{self.geometry.field.name}"

    def _make_passthrough_functions(self):
        """Installs the device functions delegating to the base topology"""

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self.base.element_node_index(elt_arg.elt_arg, topo_arg, element_index, node_index_in_elt)

        self.element_node_index = element_node_index

        @cache.dynamic_func(suffix=self.name)
        def element_node_count(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_count: ElementIndex,
        ):
            return self.base.element_node_count(elt_arg.elt_arg, topo_arg, element_count)

        self.element_node_count = element_node_count

        @cache.dynamic_func(suffix=self.name)
        def side_neighbor_node_counts(
            side_arg: self.geometry.SideArg,
            element_index: ElementIndex,
        ):
            # Side arguments expose the base geometry argument as `base_arg`
            n_inner, n_outer = self.base.side_neighbor_node_counts(side_arg.base_arg, element_index)
            return n_inner, n_outer

        self.side_neighbor_node_counts = side_neighbor_node_counts

        @cache.dynamic_func(suffix=self.name)
        def element_node_sign(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self.base.element_node_sign(elt_arg.elt_arg, topo_arg, element_index, node_index_in_elt)

        self.element_node_sign = element_node_sign
410
+
411
+
412
def forward_base_topology(topology_class: Type[SpaceTopology], geometry: Geometry, *args, **kwargs) -> SpaceTopology:
    """
    Constructs a topology of type `topology_class` over `geometry`, handling deformed geometries.

    For a regular geometry, `topology_class` is instantiated directly over `geometry`, forwarding the additional arguments.

    For a :class:`DeformedGeometry`, `topology_class` is instantiated over the base (undeformed) geometry instead, and the result
    is wrapped in a :class:`DeformedGeometrySpaceTopology` which forwards the calls to that underlying topology.
    """

    if not isinstance(geometry, DeformedGeometry):
        return topology_class(geometry, *args, **kwargs)

    undeformed_topology = topology_class(geometry.base, *args, **kwargs)
    return DeformedGeometrySpaceTopology(geometry, undeformed_topology)