warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,177 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import numpy as np
17
+
18
+ import warp as wp
19
+ from warp.fem import cache
20
+ from warp.fem.geometry import Grid2D
21
+ from warp.fem.polynomial import is_closed
22
+ from warp.fem.types import NULL_NODE_INDEX, ElementIndex
23
+
24
+ from .shape import SquareBipolynomialShapeFunctions, SquareShapeFunction
25
+ from .topology import SpaceTopology, forward_base_topology
26
+
27
+
28
class Grid2DSpaceTopology(SpaceTopology):
    """Space topology for generic square shape functions on a regular :class:`Grid2D`.

    Global node numbering is laid out in three consecutive sections:
    all vertex nodes, then all cell-interior nodes, then all edge (side) nodes.
    Each section is sized by the shape function's per-entity node count, so the
    total matches :meth:`node_count`.
    """

    def __init__(self, grid: Grid2D, shape: SquareShapeFunction):
        self._shape = shape
        super().__init__(grid, shape.NODES_PER_ELEMENT)

        self.element_node_index = self._make_element_node_index()

    # Edge nodes are located through the grid's side indexing, so the topology
    # argument is the geometry's side argument.
    TopologyArg = Grid2D.SideArg

    @property
    def name(self):
        """Unique name combining the geometry and shape function names."""
        return f"{self.geometry.name}_{self._shape.name}"

    def topo_arg_value(self, device):
        """Return the topology argument (the grid's side argument) for ``device``."""
        return self.geometry.side_arg_value(device)

    def node_count(self) -> int:
        """Total number of global nodes: vertex, side and cell counts weighted by per-entity node counts."""
        return (
            self.geometry.vertex_count() * self._shape.VERTEX_NODE_COUNT
            + self.geometry.side_count() * self._shape.EDGE_NODE_COUNT
            + self.geometry.cell_count() * self._shape.INTERIOR_NODE_COUNT
        )

    def _make_element_node_index(self):
        """Build the warp function mapping (cell, local node index) to a global node index."""
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
        INTERIOR_NODE_COUNT = self._shape.INTERIOR_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Grid2D.CellArg,
            topo_arg: Grid2D.SideArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Section 1: vertex nodes
            if wp.static(VERTEX_NODE_COUNT > 0):
                if node_type == SquareShapeFunction.VERTEX:
                    return (
                        Grid2DSpaceTopology._vertex_index(cell_arg, element_index, type_instance) * VERTEX_NODE_COUNT
                        + type_index
                    )

            res = cell_arg.res
            vertex_count = (res[0] + 1) * (res[1] + 1)
            # The vertex section holds vertex_count * VERTEX_NODE_COUNT nodes (see node_count()).
            # Offsetting by vertex_count alone would push interior/edge indices past
            # node_count() whenever VERTEX_NODE_COUNT != 1 (e.g. edge-only shapes).
            global_offset = vertex_count * VERTEX_NODE_COUNT

            # Section 2: cell-interior nodes
            if wp.static(INTERIOR_NODE_COUNT > 0):
                if node_type == SquareShapeFunction.INTERIOR:
                    return global_offset + element_index * INTERIOR_NODE_COUNT + type_index

            cell_count = res[0] * res[1]
            global_offset += INTERIOR_NODE_COUNT * cell_count

            # Section 3: edge nodes, numbered by the grid's global side index
            if wp.static(EDGE_NODE_COUNT > 0):
                # NOTE(review): assumes edge node types are EDGE_X and EDGE_X + 1; the mapping
                # below turns them into side axes 1 and 0 respectively.
                axis = 1 - (node_type - SquareShapeFunction.EDGE_X)

                cell = Grid2D.get_cell(cell_arg.res, element_index)
                # type_instance offsets the side origin along the rotated first coordinate,
                # presumably selecting one of the cell's two parallel sides.
                origin = wp.vec2i(cell[Grid2D.ROTATION[axis, 0]] + type_instance, cell[Grid2D.ROTATION[axis, 1]])

                side = Grid2D.Side(axis, origin)
                side_index = Grid2D.side_index(topo_arg, side)

                return global_offset + EDGE_NODE_COUNT * side_index + type_index

            return NULL_NODE_INDEX  # unreachable

        return element_node_index

    @wp.func
    def _vertex_coords(vidx_in_cell: int):
        # Decode a cell-local vertex index (0..3) into its (x, y) corner offset; x is the high bit.
        x = vidx_in_cell // 2
        y = vidx_in_cell - 2 * x
        return wp.vec2i(x, y)

    @wp.func
    def _vertex_index(cell_arg: Grid2D.CellArg, cell_index: ElementIndex, vidx_in_cell: int):
        # Global index of a cell corner within the (res[0]+1) x (res[1]+1) vertex lattice,
        # using res[1] + 1 as the stride of the x coordinate.
        res = cell_arg.res
        x_stride = res[1] + 1

        corner = Grid2D.get_cell(res, cell_index) + Grid2DSpaceTopology._vertex_coords(vidx_in_cell)
        return Grid2D._from_2d_index(x_stride, corner)
+
114
+
115
class GridBipolynomialSpaceTopology(SpaceTopology):
    """Tensor-product node topology for bipolynomial square shapes on a :class:`Grid2D`.

    All nodes of all cells form one lattice of
    ``(res[0] * ORDER + 1) x (res[1] * ORDER + 1)`` points, numbered with the
    y lattice coordinate varying fastest.
    """

    def __init__(self, grid: Grid2D, shape: SquareBipolynomialShapeFunctions):
        super().__init__(grid, shape.NODES_PER_ELEMENT)
        self._shape = shape
        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Number of lattice points, i.e. the product of the per-axis lattice sizes."""
        order = self._shape.ORDER
        per_axis = [self.geometry.res[axis] * order + 1 for axis in range(2)]
        return per_axis[0] * per_axis[1]

    def _make_element_node_index(self):
        """Build the warp function mapping (cell, local node index) to a lattice node index."""
        order = self._shape.ORDER
        nodes_per_axis = order + 1

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Grid2D.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            grid_res = cell_arg.res
            cell_coords = Grid2D.get_cell(grid_res, element_index)

            # (i, j) position of the node within the cell's (ORDER+1) x (ORDER+1) local lattice
            local_i = node_index_in_elt // nodes_per_axis
            local_j = node_index_in_elt - nodes_per_axis * local_i

            # Neighbouring cells share their boundary lattice lines, hence the ORDER stride
            lattice_x = order * cell_coords[0] + local_i
            lattice_y = order * cell_coords[1] + local_j

            y_count = (grid_res[1] * order) + 1
            return lattice_x * y_count + lattice_y

        return element_node_index

    def node_grid(self):
        """Return world-space node positions as ``np.meshgrid(X, Y, indexing="ij")``."""
        res = self.geometry.res

        # Per-cell Lobatto coordinates without the last one; that point is covered by the
        # next cell's first coordinate, or by the final value appended below.
        per_cell = np.array(self._shape.LOBATTO_COORDS)[:-1]

        def axis_positions(axis):
            cell_total = res[axis]
            cell_starts = np.repeat(np.arange(0, cell_total, dtype=float), len(per_cell))
            local_offsets = np.tile(per_cell, reps=cell_total)
            lattice = np.append(cell_starts + local_offsets, cell_total)
            return lattice * self.geometry.cell_size[axis] + self.geometry.origin[axis]

        return np.meshgrid(axis_positions(0), axis_positions(1), indexing="ij")
168
+
169
+
170
def make_grid_2d_space_topology(grid: Grid2D, shape: SquareShapeFunction):
    """Create the space topology matching ``shape`` on the 2D grid ``grid``.

    Bipolynomial shapes whose polynomial family is closed use the tensor-product
    :class:`GridBipolynomialSpaceTopology`. Any other :class:`SquareShapeFunction`
    uses the generic :class:`Grid2DSpaceTopology`.

    Raises:
        ValueError: if ``shape`` matches neither case.
    """
    if isinstance(shape, SquareBipolynomialShapeFunctions) and is_closed(shape.family):
        topology_cls = GridBipolynomialSpaceTopology
    elif isinstance(shape, SquareShapeFunction):
        topology_cls = Grid2DSpaceTopology
    else:
        raise ValueError(f"Unsupported shape function {shape.name}")

    return forward_base_topology(topology_cls, grid, shape)
@@ -0,0 +1,227 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import numpy as np
17
+
18
+ import warp as wp
19
+ from warp.fem import cache
20
+ from warp.fem.geometry import Grid3D
21
+ from warp.fem.polynomial import is_closed
22
+ from warp.fem.types import ElementIndex
23
+
24
+ from .shape import (
25
+ CubeShapeFunction,
26
+ CubeTripolynomialShapeFunctions,
27
+ )
28
+ from .topology import SpaceTopology, forward_base_topology
29
+
30
+
31
class Grid3DSpaceTopology(SpaceTopology):
    """Space topology for generic cube shape functions on a regular :class:`Grid3D`.

    Global node numbering is laid out in consecutive sections: vertex nodes,
    edge nodes, face nodes, then cell-interior nodes. Each section is sized by the
    shape function's per-entity node count.
    """

    def __init__(self, grid: Grid3D, shape: CubeShapeFunction):
        self._shape = shape
        super().__init__(grid, shape.NODES_PER_ELEMENT)
        self.element_node_index = self._make_element_node_index()

    @property
    def name(self):
        """Unique name combining the geometry and shape function names."""
        return f"{self.geometry.name}_{self._shape.name}"

    @wp.func
    def _vertex_coords(vidx_in_cell: int):
        # Decode a cell-local vertex index (0..7) into its (x, y, z) corner offset; x is the high bit.
        x = vidx_in_cell // 4
        y = (vidx_in_cell - 4 * x) // 2
        z = vidx_in_cell - 4 * x - 2 * y
        return wp.vec3i(x, y, z)

    @wp.func
    def _vertex_index(cell_arg: Grid3D.CellArg, cell_index: ElementIndex, vidx_in_cell: int):
        # Global index of a cell corner in the (res+1)^3 vertex lattice, z varying fastest.
        res = cell_arg.res
        strides = wp.vec2i((res[1] + 1) * (res[2] + 1), res[2] + 1)

        corner = Grid3D.get_cell(res, cell_index) + Grid3DSpaceTopology._vertex_coords(vidx_in_cell)
        return Grid3D._from_3d_index(strides, corner)

    def node_count(self) -> int:
        """Total number of global nodes: entity counts weighted by the shape's per-entity node counts."""
        return (
            self.geometry.vertex_count() * self._shape.VERTEX_NODE_COUNT
            + self.geometry.edge_count() * self._shape.EDGE_NODE_COUNT
            + self.geometry.side_count() * self._shape.FACE_NODE_COUNT
            + self.geometry.cell_count() * self._shape.INTERIOR_NODE_COUNT
        )

    def _make_element_node_index(self):
        """Build the warp function mapping (cell, local node index) to a global node index."""
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
        FACE_NODE_COUNT = self._shape.FACE_NODE_COUNT
        INTERIOR_NODE_COUNT = self._shape.INTERIOR_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Grid3D.CellArg,
            topo_arg: Grid3DSpaceTopology.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Grid resolution and cell coordinates are read once and reused by every section below
            res = cell_arg.res
            cell = Grid3D.get_cell(res, element_index)

            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Section 1: vertex nodes
            if wp.static(VERTEX_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.VERTEX:
                    return (
                        Grid3DSpaceTopology._vertex_index(cell_arg, element_index, type_instance) * VERTEX_NODE_COUNT
                        + type_index
                    )

            vertex_count = (res[0] + 1) * (res[1] + 1) * (res[2] + 1)
            global_offset = vertex_count * VERTEX_NODE_COUNT

            # Section 2: edge nodes. Edges are grouped by axis: the axis-0 group holds
            # res[0] * (res[1]+1) * (res[2]+1) edges, then come axis 1 and axis 2.
            if wp.static(EDGE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.EDGE:
                    axis = CubeShapeFunction._edge_axis(type_instance)
                    node_all = CubeShapeFunction._edge_coords(type_instance, type_index)

                    edge_index = 0
                    if axis > 0:
                        edge_index += (res[1] + 1) * (res[2] + 1) * res[0]
                    if axis > 1:
                        edge_index += (res[0] + 1) * (res[2] + 1) * res[1]

                    # Index within the axis group, computed in the axis-local frame
                    res_loc = Grid3D._world_to_local(axis, res)
                    cell_loc = Grid3D._world_to_local(axis, cell)

                    edge_index += (res_loc[1] + 1) * (res_loc[2] + 1) * cell_loc[0]
                    edge_index += (res_loc[2] + 1) * (cell_loc[1] + node_all[1])
                    edge_index += cell_loc[2] + node_all[2]

                    return global_offset + EDGE_NODE_COUNT * edge_index + type_index

                edge_count = (
                    (res[0] + 1) * (res[1] + 1) * (res[2])
                    + (res[0]) * (res[1] + 1) * (res[2] + 1)
                    + (res[0] + 1) * (res[1]) * (res[2] + 1)
                )
                global_offset += edge_count * EDGE_NODE_COUNT

            # Section 3: face nodes, grouped by face normal axis like edges
            if wp.static(FACE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.FACE:
                    axis = CubeShapeFunction._face_axis(type_instance)
                    face_offset = CubeShapeFunction._face_offset(type_instance)

                    face_index = 0
                    if axis > 0:
                        face_index += (res[0] + 1) * res[1] * res[2]
                    if axis > 1:
                        face_index += (res[1] + 1) * res[2] * res[0]

                    res_loc = Grid3D._world_to_local(axis, res)
                    cell_loc = Grid3D._world_to_local(axis, cell)

                    face_index += res_loc[1] * res_loc[2] * (cell_loc[0] + face_offset)
                    face_index += res_loc[2] * cell_loc[1]
                    face_index += cell_loc[2]

                    return global_offset + FACE_NODE_COUNT * face_index + type_index

                face_count = (
                    (res[0] + 1) * res[1] * res[2] + res[0] * (res[1] + 1) * res[2] + res[0] * res[1] * (res[2] + 1)
                )
                global_offset += face_count * FACE_NODE_COUNT

            # Section 4: interior nodes
            return global_offset + element_index * INTERIOR_NODE_COUNT + type_index

        return element_node_index
151
+
152
+
153
class GridTripolynomialSpaceTopology(SpaceTopology):
    """Tensor-product node topology for tripolynomial cube shapes on a :class:`Grid3D`.

    All nodes of all cells form one lattice of
    ``(res[0]*ORDER+1) x (res[1]*ORDER+1) x (res[2]*ORDER+1)`` points, numbered with
    the z lattice coordinate varying fastest and x slowest.
    """

    def __init__(self, grid: Grid3D, shape: CubeTripolynomialShapeFunctions):
        super().__init__(grid, shape.NODES_PER_ELEMENT)
        self._shape = shape

        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Number of lattice points, i.e. the product of the three per-axis lattice sizes."""
        order = self._shape.ORDER
        per_axis = [self.geometry.res[axis] * order + 1 for axis in range(3)]
        return per_axis[0] * per_axis[1] * per_axis[2]

    def _make_element_node_index(self):
        """Build the warp function mapping (cell, local node index) to a lattice node index."""
        order = self._shape.ORDER

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Grid3D.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            grid_res = cell_arg.res
            cell_coords = Grid3D.get_cell(grid_res, element_index)

            local_i, local_j, local_k = self._shape._node_ijk(node_index_in_elt)

            # Neighbouring cells share their boundary lattice planes, hence the ORDER stride
            lattice_x = order * cell_coords[0] + local_i
            lattice_y = order * cell_coords[1] + local_j
            lattice_z = order * cell_coords[2] + local_k

            z_count = (grid_res[2] * order) + 1
            yz_count = z_count * ((grid_res[1] * order) + 1)
            return yz_count * lattice_x + z_count * lattice_y + lattice_z

        return element_node_index

    def node_grid(self):
        """Return world-space node positions as ``np.meshgrid(X, Y, Z, indexing="ij")``."""
        res = self.geometry.res

        # Per-cell Lobatto coordinates without the last one; that point is covered by the
        # next cell's first coordinate, or by the final value appended below.
        per_cell = np.array(self._shape.LOBATTO_COORDS)[:-1]

        def axis_positions(axis):
            cell_total = res[axis]
            cell_starts = np.repeat(np.arange(0, cell_total, dtype=float), len(per_cell))
            local_offsets = np.tile(per_cell, reps=cell_total)
            lattice = np.append(cell_starts + local_offsets, cell_total)
            return lattice * self.geometry.cell_size[axis] + self.geometry.origin[axis]

        return np.meshgrid(axis_positions(0), axis_positions(1), axis_positions(2), indexing="ij")
218
+
219
+
220
def make_grid_3d_space_topology(grid: Grid3D, shape: CubeShapeFunction):
    """Create the space topology matching ``shape`` on the 3D grid ``grid``.

    Tripolynomial shapes whose polynomial family is closed use the tensor-product
    :class:`GridTripolynomialSpaceTopology`. Any other :class:`CubeShapeFunction`
    uses the generic :class:`Grid3DSpaceTopology`.

    Raises:
        ValueError: if ``shape`` matches neither case.
    """
    if isinstance(shape, CubeTripolynomialShapeFunctions) and is_closed(shape.family):
        topology_cls = GridTripolynomialSpaceTopology
    elif isinstance(shape, CubeShapeFunction):
        topology_cls = Grid3DSpaceTopology
    else:
        raise ValueError(f"Unsupported shape function {shape.name}")

    return forward_base_topology(topology_cls, grid, shape)
@@ -0,0 +1,257 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warp as wp
17
+ from warp.fem import cache
18
+ from warp.fem.geometry import Hexmesh
19
+ from warp.fem.geometry.hexmesh import (
20
+ EDGE_VERTEX_INDICES,
21
+ FACE_ORIENTATION,
22
+ FACE_TRANSLATION,
23
+ )
24
+ from warp.fem.types import ElementIndex
25
+
26
+ from .shape import CubeShapeFunction
27
+ from .topology import SpaceTopology, forward_base_topology
28
+
29
# Integer face-orientation transforms imported from the hexmesh module, stored as a
# 16x2 matrix: rows 2*ori and 2*ori + 1 are dotted with face-local node coordinates
# for orientation `ori` (see HexmeshSpaceTopology._rotate_face_coordinates),
# which gives 8 orientations in total.
_FACE_ORIENTATION_I = wp.constant(wp.mat(shape=(16, 2), dtype=int)(FACE_ORIENTATION))
# Integer translations added after the transform above; row `ori // 2` is used.
_FACE_TRANSLATION_I = wp.constant(wp.mat(shape=(4, 2), dtype=int)(FACE_TRANSLATION))

# map from shape function vertex indexing to hexmesh vertex indexing
_CUBE_TO_HEX_VERTEX = wp.constant(wp.vec(length=8, dtype=int)([0, 4, 3, 7, 1, 5, 2, 6]))

# map from shape function edge indexing to hexmesh edge indexing
_CUBE_TO_HEX_EDGE = wp.constant(wp.vec(length=12, dtype=int)([0, 4, 2, 6, 3, 1, 7, 5, 8, 11, 9, 10]))
37
+
38
+
39
@wp.struct
class HexmeshTopologyArg:
    """Device-side topology data consumed by HexmeshSpaceTopology's element functions."""

    # Global edge index per (hex, hexmesh-local edge); a (0, 0) placeholder when the
    # shape function has no edge nodes
    hex_edge_indices: wp.array2d(dtype=int)
    # (global face index, face orientation) per (hex, shape-function-local face);
    # the orientation is adjusted so that shape-function faces are positively oriented
    hex_face_indices: wp.array2d(dtype=wp.vec2i)

    # Number of mesh vertices
    vertex_count: int
    # Number of mesh edges (0 when the shape function has no edge nodes)
    edge_count: int
    # Number of mesh faces (the mesh's side count)
    face_count: int
47
+
48
+
49
class HexmeshSpaceTopology(SpaceTopology):
    """Space topology mapping cube shape-function nodes onto a :class:`Hexmesh`.

    Global node indices are laid out in four consecutive ranges: nodes attached to
    mesh vertices, then to edges, then to faces (mesh sides), then to cell interiors.
    Edge and face node indices are re-ordered according to the orientation of the
    shared entity, so that every hex adjacent to an edge or face maps its local nodes
    to the same global nodes.
    """

    TopologyArg = HexmeshTopologyArg

    def __init__(
        self,
        mesh: Hexmesh,
        shape: CubeShapeFunction,
    ):
        self._shape = shape
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh

        need_edge_indices = shape.EDGE_NODE_COUNT > 0
        need_face_indices = shape.FACE_NODE_COUNT > 0

        if need_edge_indices:
            self._hex_edge_indices = self._mesh.hex_edge_indices
            self._edge_count = self._mesh.edge_count()
        else:
            # Empty placeholder so that topo_arg_value() always has an array to transfer
            self._hex_edge_indices = wp.empty(shape=(0, 0), dtype=int)
            self._edge_count = 0

        # Per-hex face indices are only read by element_node_index / element_node_sign
        # when the shape has face nodes, so the kernel is only launched in that case.
        if need_face_indices:
            self._compute_hex_face_indices()
        else:
            self._hex_face_indices = wp.empty(shape=(0, 0), dtype=wp.vec2i)

        self.element_node_index = self._make_element_node_index()
        self.element_node_sign = self._make_element_node_sign()

    @property
    def name(self):
        return f"{self.geometry.name}_{self._shape.name}"

    @cache.cached_arg_value
    def topo_arg_value(self, device):
        """Build the :class:`HexmeshTopologyArg` for kernels running on ``device``."""
        arg = HexmeshTopologyArg()
        arg.hex_edge_indices = self._hex_edge_indices.to(device)
        arg.hex_face_indices = self._hex_face_indices.to(device)

        arg.vertex_count = self._mesh.vertex_count()
        arg.face_count = self._mesh.side_count()
        arg.edge_count = self._edge_count
        return arg

    def _compute_hex_face_indices(self):
        """Fill ``self._hex_face_indices`` (shape ``(cell_count, 6)``) with ``(face, orientation)`` pairs."""
        self._hex_face_indices = wp.empty(
            dtype=wp.vec2i, device=self._mesh.hex_vertex_indices.device, shape=(self._mesh.cell_count(), 6)
        )

        # One thread per mesh face; each thread writes the entries of both hexes
        # listed for that face in face_hex_indices
        wp.launch(
            kernel=HexmeshSpaceTopology._compute_hex_face_indices_kernel,
            dim=self._mesh.side_count(),
            device=self._mesh.hex_vertex_indices.device,
            inputs=[
                self._mesh.face_hex_indices,
                self._mesh._face_hex_face_orientation,
                self._hex_face_indices,
            ],
        )

    @wp.kernel
    def _compute_hex_face_indices_kernel(
        face_hex_indices: wp.array(dtype=wp.vec2i),
        face_hex_face_ori: wp.array(dtype=wp.vec4i),
        hex_face_indices: wp.array2d(dtype=wp.vec2i),
    ):
        f = wp.tid()

        # face indices from CubeShapeFunction always have positive orientation,
        # while Hexmesh faces are oriented to point "outside"
        # We need to flip orientation for faces at offset 0

        hx0 = face_hex_indices[f][0]
        local_face_0 = face_hex_face_ori[f][0]
        ori_0 = face_hex_face_ori[f][1]

        local_face_offset_0 = CubeShapeFunction._face_offset(local_face_0)
        flip_0 = ori_0 ^ (1 - local_face_offset_0)

        hex_face_indices[hx0, local_face_0] = wp.vec2i(f, flip_0)

        hx1 = face_hex_indices[f][1]
        local_face_1 = face_hex_face_ori[f][2]
        ori_1 = face_hex_face_ori[f][3]

        local_face_offset_1 = CubeShapeFunction._face_offset(local_face_1)
        flip_1 = ori_1 ^ (1 - local_face_offset_1)

        hex_face_indices[hx1, local_face_1] = wp.vec2i(f, flip_1)

    def node_count(self) -> int:
        """Total node count: vertex, edge, face, then cell-interior nodes."""
        return (
            self._mesh.vertex_count() * self._shape.VERTEX_NODE_COUNT
            # Same edge count as passed to kernels in topo_arg_value() (0 when no edge nodes)
            + self._edge_count * self._shape.EDGE_NODE_COUNT
            + self._mesh.side_count() * self._shape.FACE_NODE_COUNT
            + self._mesh.cell_count() * self._shape.INTERIOR_NODE_COUNT
        )

    @wp.func
    def _rotate_face_coordinates(ori: int, offset: int, coords: wp.vec2i):
        # Apply the integer transform for orientation `ori` to face-local node coordinates.
        # NOTE: `offset` is currently unused; it is kept for signature compatibility.
        fv = ori // 2

        rot_i = wp.dot(_FACE_ORIENTATION_I[2 * ori], coords)
        rot_j = wp.dot(_FACE_ORIENTATION_I[2 * ori + 1], coords)

        return wp.vec2i(rot_i, rot_j) + _FACE_TRANSLATION_I[fv]

    def _make_element_node_index(self):
        """Generate the device function mapping (element, local node) to a global node index."""
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
        FACE_NODE_COUNT = self._shape.FACE_NODE_COUNT
        INTERIOR_NODE_COUNT = self._shape.INTERIOR_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            geo_arg: Hexmesh.CellArg,
            topo_arg: HexmeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Vertex nodes: translate the shape's vertex numbering to the hexmesh's
            if wp.static(VERTEX_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.VERTEX:
                    return (
                        geo_arg.hex_vertex_indices[element_index, _CUBE_TO_HEX_VERTEX[type_instance]]
                        * VERTEX_NODE_COUNT
                        + type_index
                    )

            offset = topo_arg.vertex_count * VERTEX_NODE_COUNT

            # Edge nodes: order them from the lower to the higher global vertex index,
            # so both hexes sharing an edge agree on node numbering
            if wp.static(EDGE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.EDGE:
                    hex_edge = _CUBE_TO_HEX_EDGE[type_instance]
                    edge_index = topo_arg.hex_edge_indices[element_index, hex_edge]

                    v0 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 0]]
                    v1 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 1]]

                    if v0 > v1:
                        type_index = EDGE_NODE_COUNT - 1 - type_index

                    return offset + EDGE_NODE_COUNT * edge_index + type_index

                offset += EDGE_NODE_COUNT * topo_arg.edge_count

            # Face nodes: remap face-local node coordinates through the face orientation
            if wp.static(FACE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.FACE:
                    face_index_and_ori = topo_arg.hex_face_indices[element_index, type_instance]
                    face_index = face_index_and_ori[0]
                    face_orientation = face_index_and_ori[1]

                    face_offset = CubeShapeFunction._face_offset(type_instance)

                    if wp.static(FACE_NODE_COUNT > 1):
                        face_coords = self._shape._face_node_ij(type_index)
                        rot_face_coords = HexmeshSpaceTopology._rotate_face_coordinates(
                            face_orientation, face_offset, face_coords
                        )
                        type_index = self._shape._linear_face_node_index(type_index, rot_face_coords)

                    return offset + FACE_NODE_COUNT * face_index + type_index

                offset += FACE_NODE_COUNT * topo_arg.face_count

            # Interior nodes belong to a single hex; no reordering needed
            return offset + INTERIOR_NODE_COUNT * element_index + type_index

        return element_node_index

    def _make_element_node_sign(self):
        """Generate the device function returning the +1/-1 sign of an element's local node."""
        EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
        FACE_NODE_COUNT = self._shape.FACE_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_sign(
            geo_arg: self.geometry.CellArg,
            topo_arg: HexmeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Edge nodes: negative when the local edge runs from the higher to the lower global vertex
            if wp.static(EDGE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.EDGE:
                    hex_edge = _CUBE_TO_HEX_EDGE[type_instance]
                    v0 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 0]]
                    v1 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 1]]
                    return wp.where(v0 > v1, -1.0, 1.0)

            # Face nodes: negative when the lowest bit of the stored face orientation is set
            if wp.static(FACE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.FACE:
                    face_index_and_ori = topo_arg.hex_face_indices[element_index, type_instance]
                    flip = face_index_and_ori[1] & 1
                    return wp.where(flip == 0, 1.0, -1.0)

            return 1.0

        return element_node_sign
251
+
252
+
253
def make_hexmesh_space_topology(mesh: Hexmesh, shape: CubeShapeFunction):
    """Create a :class:`HexmeshSpaceTopology` for ``shape`` on ``mesh``.

    Raises:
        ValueError: if ``shape`` is not a ``CubeShapeFunction``.
    """
    if not isinstance(shape, CubeShapeFunction):
        raise ValueError(f"Unsupported shape function {shape.name}")

    return forward_base_topology(HexmeshSpaceTopology, mesh, shape)