warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,179 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import numpy as np
17
+
18
+ import warp as wp
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import Grid2D
21
+ from warp._src.fem.polynomial import is_closed
22
+ from warp._src.fem.types import NULL_NODE_INDEX, ElementIndex
23
+
24
+ from .shape import SquareBipolynomialShapeFunctions, SquareShapeFunction
25
+ from .topology import SpaceTopology, forward_base_topology
26
+
27
+ _wp_module_name_ = "warp.fem.space.grid_2d_function_space"
28
+
29
+
30
class Grid2DSpaceTopology(SpaceTopology):
    """Node topology of a function space on a :class:`Grid2D`, driven by a square shape function.

    Global node indices are laid out in three consecutive ranges -- vertex nodes,
    then cell-interior nodes, then edge nodes -- whose sizes are the three terms
    summed by :meth:`node_count`.
    """

    def __init__(self, grid: Grid2D, shape: SquareShapeFunction):
        """Build the topology for ``grid`` using the per-element node layout of ``shape``."""
        # Assigned before the base constructor runs.
        # NOTE(review): presumably because the base class reads ``name``, which needs the shape -- confirm.
        self._shape = shape
        super().__init__(grid, shape.NODES_PER_ELEMENT)

        self.element_node_index = self._make_element_node_index()

    # Edge nodes are numbered from grid side indices, so the topology argument is the grid's side argument.
    TopologyArg = Grid2D.SideArg

    @property
    def name(self):
        """Identifier combining the geometry name and the shape function name."""
        return f"{self.geometry.name}_{self._shape.name}"

    def fill_topo_arg(self, arg: Grid2D.SideArg, device):
        """Fill ``arg`` for ``device`` by delegating to the geometry's side-argument filler."""
        self.geometry.fill_side_arg(arg, device)

    def node_count(self) -> int:
        """Return the total number of nodes: vertex, edge and cell-interior contributions."""
        return (
            self.geometry.vertex_count() * self._shape.VERTEX_NODE_COUNT
            + self.geometry.side_count() * self._shape.EDGE_NODE_COUNT
            + self.geometry.cell_count() * self._shape.INTERIOR_NODE_COUNT
        )

    def _make_element_node_index(self):
        """Create the Warp function mapping (element, node-in-element) to a global node index.

        The per-entity node counts are captured as compile-time constants so that
        branches for absent node types are removed through ``wp.static``.
        """
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
        INTERIOR_NODE_COUNT = self._shape.INTERIOR_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Grid2D.CellArg,
            topo_arg: Grid2D.SideArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Range 1: vertex nodes, VERTEX_NODE_COUNT consecutive indices per grid vertex
            if wp.static(VERTEX_NODE_COUNT > 0):
                if node_type == SquareShapeFunction.VERTEX:
                    return (
                        Grid2DSpaceTopology._vertex_index(cell_arg, element_index, type_instance) * VERTEX_NODE_COUNT
                        + type_index
                    )

            res = cell_arg.res
            vertex_count = (res[0] + 1) * (res[1] + 1)
            # The vertex range spans vertex_count * VERTEX_NODE_COUNT indices (see node_count());
            # skipping exactly that many keeps indices contiguous and below node_count(),
            # including for shapes that carry no vertex nodes at all.
            global_offset = vertex_count * VERTEX_NODE_COUNT

            # Range 2: interior nodes, INTERIOR_NODE_COUNT consecutive indices per cell
            if wp.static(INTERIOR_NODE_COUNT > 0):
                if node_type == SquareShapeFunction.INTERIOR:
                    return global_offset + element_index * INTERIOR_NODE_COUNT + type_index

            cell_count = res[0] * res[1]
            global_offset += INTERIOR_NODE_COUNT * cell_count

            # Range 3: edge nodes, EDGE_NODE_COUNT consecutive indices per grid side
            if wp.static(EDGE_NODE_COUNT > 0):
                # EDGE_X maps to axis 1 and the next edge type to axis 0
                axis = 1 - (node_type - SquareShapeFunction.EDGE_X)

                cell = Grid2D.get_cell(cell_arg.res, element_index)
                # NOTE(review): type_instance presumably selects the near (0) or far (1) side along `axis` -- confirm
                origin = Grid2D.orient(axis, cell) + wp.vec2i(type_instance, 0)

                side = Grid2D.Side(axis, origin)
                side_index = Grid2D.side_index(topo_arg, side)

                return global_offset + EDGE_NODE_COUNT * side_index + type_index

            return NULL_NODE_INDEX  # unreachable

        return element_node_index

    # Decodes a vertex index within a cell into its (x, y) offset: x = vidx // 2, y = vidx - 2 * x.
    @wp.func
    def _vertex_coords(vidx_in_cell: int):
        x = vidx_in_cell // 2
        y = vidx_in_cell - 2 * x
        return wp.vec2i(x, y)

    # Global grid vertex index of a cell corner, flattened with a stride of (res[1] + 1) along x.
    @wp.func
    def _vertex_index(cell_arg: Grid2D.CellArg, cell_index: ElementIndex, vidx_in_cell: int):
        res = cell_arg.res
        x_stride = res[1] + 1

        corner = Grid2D.get_cell(res, cell_index) + Grid2DSpaceTopology._vertex_coords(vidx_in_cell)
        return Grid2D._from_2d_index(x_stride, corner)
115
+
116
+
117
class GridBipolynomialSpaceTopology(SpaceTopology):
    """Node topology for bipolynomial shape functions on a :class:`Grid2D`.

    Nodes form a regular lattice with ``res[k] * ORDER + 1`` nodes along axis ``k``;
    a node's global index is its flattened lattice coordinate, with the second
    axis varying fastest.
    """

    def __init__(self, grid: Grid2D, shape: SquareBipolynomialShapeFunctions):
        """Build the topology for ``grid`` using the node layout of ``shape``."""
        super().__init__(grid, shape.NODES_PER_ELEMENT)
        self._shape = shape
        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Return the number of lattice nodes over the whole grid."""
        nodes_along_x = self.geometry.res[0] * self._shape.ORDER + 1
        nodes_along_y = self.geometry.res[1] * self._shape.ORDER + 1
        return nodes_along_x * nodes_along_y

    def _make_element_node_index(self):
        """Create the Warp function mapping (element, node-in-element) to a global node index."""
        ORDER = self._shape.ORDER

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Grid2D.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            grid_res = cell_arg.res
            cell_ij = Grid2D.get_cell(grid_res, element_index)

            # Split the flat in-element index into (i, j), with ORDER + 1 nodes per row
            local_i = node_index_in_elt // (ORDER + 1)
            local_j = node_index_in_elt - (ORDER + 1) * local_i

            # Lattice coordinates: each cell advances the lattice by ORDER nodes per axis
            lattice_i = ORDER * cell_ij[0] + local_i
            lattice_j = ORDER * cell_ij[1] + local_j

            # Number of lattice nodes along the second axis
            lattice_pitch = (grid_res[1] * ORDER) + 1
            return lattice_pitch * lattice_i + lattice_j

        return element_node_index

    def node_grid(self):
        """Return the node positions as ``np.meshgrid`` arrays ``(X, Y)`` using ``"ij"`` indexing."""
        geo = self.geometry
        res = geo.res

        # Drop the last Lobatto coordinate of each cell.
        # NOTE(review): presumably it coincides with the first node of the next cell -- confirm.
        local_coords = np.array(self._shape.LOBATTO_COORDS)[:-1]
        nodes_per_cell = len(local_coords)

        def axis_positions(axis):
            # Cell start index repeated for each local node, plus the local coordinate within the cell
            cell_starts = np.repeat(np.arange(0, res[axis], dtype=float), nodes_per_cell)
            coords = cell_starts + np.tile(local_coords, reps=res[axis])
            # Closing node at the far end of the grid
            coords = np.append(coords, res[axis])
            return coords * geo.cell_size[axis] + geo.origin[axis]

        X = axis_positions(0)
        Y = axis_positions(1)

        return np.meshgrid(X, Y, indexing="ij")
170
+
171
+
172
def make_grid_2d_space_topology(grid: Grid2D, shape: SquareShapeFunction):
    """Create the space topology matching ``shape`` on the 2D grid ``grid``.

    Bipolynomial shape functions whose polynomial family is closed use the
    lattice-based ``GridBipolynomialSpaceTopology``; any other square shape
    function uses the generic ``Grid2DSpaceTopology``.

    Args:
        grid: The 2D grid geometry.
        shape: The shape function defining the per-element node layout.

    Returns:
        The result of ``forward_base_topology`` for the selected topology class.

    Raises:
        ValueError: If ``shape`` is not a ``SquareShapeFunction``.
    """
    if isinstance(shape, SquareBipolynomialShapeFunctions) and is_closed(shape.family):
        return forward_base_topology(GridBipolynomialSpaceTopology, grid, shape)

    if isinstance(shape, SquareShapeFunction):
        return forward_base_topology(Grid2DSpaceTopology, grid, shape)

    # An object reaching this point is not a square shape function and may have no
    # ``name``; fall back to its type name so the intended ValueError is raised
    # rather than an AttributeError.
    shape_name = getattr(shape, "name", type(shape).__name__)
    raise ValueError(f"Unsupported shape function {shape_name}")
@@ -0,0 +1,229 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import numpy as np
17
+
18
+ import warp as wp
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import Grid3D
21
+ from warp._src.fem.polynomial import is_closed
22
+ from warp._src.fem.types import ElementIndex
23
+
24
+ from .shape import (
25
+ CubeShapeFunction,
26
+ CubeTripolynomialShapeFunctions,
27
+ )
28
+ from .topology import SpaceTopology, forward_base_topology
29
+
30
+ _wp_module_name_ = "warp.fem.space.grid_3d_function_space"  # NOTE(review): presumably the module name Warp registers this file's functions under (public path rather than warp._src) - confirm
31
+
32
+
33
class Grid3DSpaceTopology(SpaceTopology):
    """Node topology of a ``CubeShapeFunction`` on a ``Grid3D``.

    Global node indices are laid out in consecutive ranges: vertex nodes
    first, then edge nodes, then face nodes, then cell-interior nodes.
    """

    def __init__(self, grid: Grid3D, shape: CubeShapeFunction):
        # the shape is stored before the base constructor runs
        # NOTE(review): presumably because the base constructor may read `name` - confirm
        self._shape = shape
        super().__init__(grid, shape.NODES_PER_ELEMENT)
        self.element_node_index = self._make_element_node_index()

    @property
    def name(self):
        """Topology name, built from the geometry name and the shape name."""
        geometry_name = self.geometry.name
        shape_name = self._shape.name
        return f"{geometry_name}_{shape_name}"

    @wp.func
    def _vertex_coords(vidx_in_cell: int):
        # decode the corner index as three binary digits: x is the 4s digit, y the 2s, z the 1s
        x = vidx_in_cell // 4
        yz = vidx_in_cell - 4 * x
        y = yz // 2
        z = yz - 2 * y
        return wp.vec3i(x, y, z)

    @wp.func
    def _vertex_index(cell_arg: Grid3D.CellArg, cell_index: ElementIndex, vidx_in_cell: int):
        res = cell_arg.res

        # lattice coordinates of the requested corner of the cell
        cell = Grid3D.get_cell(res, cell_index)
        corner = cell + Grid3DSpaceTopology._vertex_coords(vidx_in_cell)

        # the vertex lattice has (res + 1) points along each axis
        strides = wp.vec2i((res[1] + 1) * (res[2] + 1), res[2] + 1)
        return Grid3D._from_3d_index(strides, corner)

    def node_count(self) -> int:
        """Total number of nodes: per-entity node counts times entity counts."""
        geometry = self.geometry
        shape = self._shape

        vertex_nodes = geometry.vertex_count() * shape.VERTEX_NODE_COUNT
        edge_nodes = geometry.edge_count() * shape.EDGE_NODE_COUNT
        face_nodes = geometry.side_count() * shape.FACE_NODE_COUNT
        interior_nodes = geometry.cell_count() * shape.INTERIOR_NODE_COUNT

        return vertex_nodes + edge_nodes + face_nodes + interior_nodes

    def _make_element_node_index(self):
        """Build the Warp function mapping an element-local node to a global node index.

        The per-entity node counts are captured as constants so that branches
        for entity types without nodes can be guarded with ``wp.static``.
        """
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
        FACE_NODE_COUNT = self._shape.FACE_NODE_COUNT
        INTERIOR_NODE_COUNT = self._shape.INTERIOR_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Grid3D.CellArg,
            topo_arg: Grid3DSpaceTopology.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            res = cell_arg.res
            cell = Grid3D.get_cell(res, element_index)

            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            if wp.static(VERTEX_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.VERTEX:
                    vertex = Grid3DSpaceTopology._vertex_index(cell_arg, element_index, type_instance)
                    return vertex * VERTEX_NODE_COUNT + type_index

            # all vertex nodes precede the edge nodes
            vertex_count = (res[0] + 1) * (res[1] + 1) * (res[2] + 1)
            global_offset = vertex_count * VERTEX_NODE_COUNT

            if wp.static(EDGE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.EDGE:
                    axis = CubeShapeFunction._edge_axis(type_instance)
                    edge_coords = CubeShapeFunction._edge_coords(type_instance, type_index)

                    # edges are grouped by axis: skip the groups of the lower axes
                    edge_index = 0
                    if axis > 0:
                        edge_index += (res[1] + 1) * (res[2] + 1) * res[0]
                    if axis > 1:
                        edge_index += (res[0] + 1) * (res[2] + 1) * res[1]

                    # NOTE(review): _world_to_local presumably reorders components so that
                    # index 0 is the edge axis - confirm against Grid3D
                    local_res = Grid3D._world_to_local(axis, res)
                    local_cell = Grid3D._world_to_local(axis, cell)

                    edge_index += (local_res[1] + 1) * (local_res[2] + 1) * local_cell[0]
                    edge_index += (local_res[2] + 1) * (local_cell[1] + edge_coords[1])
                    edge_index += local_cell[2] + edge_coords[2]

                    return global_offset + EDGE_NODE_COUNT * edge_index + type_index

                # all edge nodes precede the face nodes
                edge_count = (
                    (res[0] + 1) * (res[1] + 1) * (res[2])
                    + (res[0]) * (res[1] + 1) * (res[2] + 1)
                    + (res[0] + 1) * (res[1]) * (res[2] + 1)
                )
                global_offset += edge_count * EDGE_NODE_COUNT

            if wp.static(FACE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.FACE:
                    axis = CubeShapeFunction._face_axis(type_instance)
                    face_offset = CubeShapeFunction._face_offset(type_instance)

                    # faces are grouped by axis: skip the groups of the lower axes
                    face_index = 0
                    if axis > 0:
                        face_index += (res[0] + 1) * res[1] * res[2]
                    if axis > 1:
                        face_index += (res[1] + 1) * res[2] * res[0]

                    local_res = Grid3D._world_to_local(axis, res)
                    local_cell = Grid3D._world_to_local(axis, cell)

                    face_index += local_res[1] * local_res[2] * (local_cell[0] + face_offset)
                    face_index += local_res[2] * local_cell[1]
                    face_index += local_cell[2]

                    return global_offset + FACE_NODE_COUNT * face_index + type_index

                # all face nodes precede the interior nodes
                face_count = (
                    (res[0] + 1) * res[1] * res[2] + res[0] * (res[1] + 1) * res[2] + res[0] * res[1] * (res[2] + 1)
                )
                global_offset += face_count * FACE_NODE_COUNT

            # interior
            return global_offset + element_index * INTERIOR_NODE_COUNT + type_index

        return element_node_index
153
+
154
+
155
class GridTripolynomialSpaceTopology(SpaceTopology):
    """Node topology of tripolynomial shape functions on a ``Grid3D``.

    Nodes form a single regular lattice with ``res[i] * ORDER + 1`` points
    along axis ``i``; the global node index is the linearized lattice position
    with the last axis varying fastest.
    """

    def __init__(self, grid: Grid3D, shape: CubeTripolynomialShapeFunctions):
        super().__init__(grid, shape.NODES_PER_ELEMENT)
        self._shape = shape

        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Number of points of the grid-wide node lattice."""
        nodes_x = self.geometry.res[0] * self._shape.ORDER + 1
        nodes_y = self.geometry.res[1] * self._shape.ORDER + 1
        nodes_z = self.geometry.res[2] * self._shape.ORDER + 1
        return nodes_x * nodes_y * nodes_z

    def _make_element_node_index(self):
        """Build the Warp function mapping an element-local node to a global node index."""
        ORDER = self._shape.ORDER

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Grid3D.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            res = cell_arg.res
            cell = Grid3D.get_cell(res, element_index)

            local_i, local_j, local_k = self._shape._node_ijk(node_index_in_elt)

            # position of the node on the grid-wide node lattice
            global_i = ORDER * cell[0] + local_i
            global_j = ORDER * cell[1] + local_j
            global_k = ORDER * cell[2] + local_k

            # linearize with the last axis varying fastest
            stride_j = (res[2] * ORDER) + 1
            stride_i = stride_j * ((res[1] * ORDER) + 1)
            return stride_i * global_i + stride_j * global_j + global_k

        return element_node_index

    def node_grid(self):
        """Return the node positions of the whole grid as ``np.meshgrid`` arrays.

        Along each axis, the per-cell node coordinates are the shape's
        ``LOBATTO_COORDS`` without their last entry, repeated for every cell;
        one closing coordinate equal to the cell count is appended at the end.
        The lattice is then scaled by the cell size and shifted by the origin.

        Returns:
            The arrays produced by ``np.meshgrid(X, Y, Z, indexing="ij")``.
        """
        geometry = self.geometry
        res = geometry.res

        # last local coordinate is dropped so that it is not emitted twice per cell boundary
        local_coords = np.array(self._shape.LOBATTO_COORDS)[:-1]

        def axis_positions(axis):
            # integer cell origin for each local node, plus the local offset inside the cell
            cell_starts = np.repeat(np.arange(0, res[axis], dtype=float), len(local_coords))
            lattice = cell_starts + np.tile(local_coords, reps=res[axis])
            # closing node at the far end of the last cell
            lattice = np.append(lattice, res[axis])
            return lattice * geometry.cell_size[axis] + geometry.origin[axis]

        X = axis_positions(0)
        Y = axis_positions(1)
        Z = axis_positions(2)

        return np.meshgrid(X, Y, Z, indexing="ij")
220
+
221
+
222
def make_grid_3d_space_topology(grid: Grid3D, shape: CubeShapeFunction):
    """Create the space topology matching ``shape`` on a 3D grid.

    Tripolynomial shapes whose family satisfies ``is_closed`` get the
    ``GridTripolynomialSpaceTopology``; any other ``CubeShapeFunction`` gets
    the generic ``Grid3DSpaceTopology``.

    Raises:
        ValueError: if ``shape`` is not a supported shape function.
    """
    if isinstance(shape, CubeTripolynomialShapeFunctions) and is_closed(shape.family):
        topology_class = GridTripolynomialSpaceTopology
    elif isinstance(shape, CubeShapeFunction):
        topology_class = Grid3DSpaceTopology
    else:
        raise ValueError(f"Unsupported shape function {shape.name}")

    return forward_base_topology(topology_class, grid, shape)
@@ -0,0 +1,255 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warp as wp
17
+ from warp._src.fem import cache
18
+ from warp._src.fem.geometry import Hexmesh
19
+ from warp._src.fem.geometry.hexmesh import (
20
+ EDGE_VERTEX_INDICES,
21
+ FACE_ORIENTATION,
22
+ FACE_TRANSLATION,
23
+ )
24
+ from warp._src.fem.types import ElementIndex
25
+
26
+ from .shape import CubeShapeFunction
27
+ from .topology import SpaceTopology, forward_base_topology
28
+
29
+ _wp_module_name_ = "warp.fem.space.hexmesh_function_space"  # NOTE(review): presumably the module name Warp registers this file's functions under (public path rather than warp._src) - confirm
30
+
31
+ _FACE_ORIENTATION_I = wp.constant(wp.mat(shape=(16, 2), dtype=int)(FACE_ORIENTATION))  # 16x2 int matrix; _rotate_face_coordinates reads rows 2*ori and 2*ori+1
32
+ _FACE_TRANSLATION_I = wp.constant(wp.mat(shape=(4, 2), dtype=int)(FACE_TRANSLATION))  # 4x2 int matrix; _rotate_face_coordinates adds row ori // 2 after the rotation
33
+
34
+ # map from shape function vertex indexing to hexmesh vertex indexing
35
+ _CUBE_TO_HEX_VERTEX = wp.constant(wp.vec(length=8, dtype=int)([0, 4, 3, 7, 1, 5, 2, 6]))
36
+
37
+ # map from shape function edge indexing to hexmesh edge indexing
38
+ _CUBE_TO_HEX_EDGE = wp.constant(wp.vec(length=12, dtype=int)([0, 4, 2, 6, 3, 1, 7, 5, 8, 11, 9, 10]))
39
+
40
+
41
+ @wp.struct
42
+ class HexmeshTopologyArg:  # device-side arguments populated by HexmeshSpaceTopology.fill_topo_arg
43
+ hex_edge_indices: wp.array2d(dtype=int)  # read as [element_index, hex_edge] to get an edge index
44
+ hex_face_indices: wp.array2d(dtype=wp.vec2i)  # read as [element_index, local face]; holds (face index, orientation code)
45
+
46
+ vertex_count: int  # set from mesh.vertex_count()
47
+ edge_count: int  # set from the topology's _edge_count (0 when the shape has no edge nodes)
48
+ face_count: int  # set from mesh.side_count()
49
+
50
+
51
class HexmeshSpaceTopology(SpaceTopology):
    """Node topology of a ``CubeShapeFunction`` on a ``Hexmesh``.

    Global node indices are laid out in consecutive ranges: vertex nodes
    first, then edge nodes, then face nodes, then cell-interior nodes.
    Edge and face lookup tables are only built when the shape actually has
    nodes on those entities.
    """

    TopologyArg = HexmeshTopologyArg

    def __init__(
        self,
        mesh: Hexmesh,
        shape: CubeShapeFunction,
    ):
        # the shape is stored before the base constructor runs
        # NOTE(review): presumably because the base constructor may read `name` - confirm
        self._shape = shape
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh

        need_edge_indices = shape.EDGE_NODE_COUNT > 0
        need_face_indices = shape.FACE_NODE_COUNT > 0

        if need_edge_indices:
            self._hex_edge_indices = self._mesh.hex_edge_indices
            self._edge_count = self._mesh.edge_count()
        else:
            self._hex_edge_indices = wp.empty(shape=(0, 0), dtype=int)
            self._edge_count = 0

        # The per-hex face table is read only under `FACE_NODE_COUNT > 0` guards,
        # so it is built once when needed and left empty otherwise.
        if need_face_indices:
            self._compute_hex_face_indices()
        else:
            self._hex_face_indices = wp.empty(shape=(0, 0), dtype=wp.vec2i)

        self.element_node_index = self._make_element_node_index()
        self.element_node_sign = self._make_element_node_sign()

    @property
    def name(self):
        """Topology name, built from the geometry name and the shape name."""
        return f"{self.geometry.name}_{self._shape.name}"

    def fill_topo_arg(self, arg: HexmeshTopologyArg, device):
        """Populate ``arg`` with the lookup tables (moved to ``device``) and entity counts."""
        arg.hex_edge_indices = self._hex_edge_indices.to(device)
        arg.hex_face_indices = self._hex_face_indices.to(device)
        arg.vertex_count = self._mesh.vertex_count()
        arg.face_count = self._mesh.side_count()
        arg.edge_count = self._edge_count

    def _compute_hex_face_indices(self):
        """Build the ``(cell_count, 6)`` table of (face index, orientation code) per hex face.

        The table is allocated on the device of the mesh's vertex index array
        and filled by one kernel thread per mesh side.
        """
        self._hex_face_indices = wp.empty(
            dtype=wp.vec2i, device=self._mesh.hex_vertex_indices.device, shape=(self._mesh.cell_count(), 6)
        )

        wp.launch(
            kernel=HexmeshSpaceTopology._compute_hex_face_indices_kernel,
            dim=self._mesh.side_count(),
            device=self._mesh.hex_vertex_indices.device,
            inputs=[
                self._mesh.face_hex_indices,
                self._mesh._face_hex_face_orientation,
                self._hex_face_indices,
            ],
        )

    @wp.kernel
    def _compute_hex_face_indices_kernel(
        face_hex_indices: wp.array(dtype=wp.vec2i),
        face_hex_face_ori: wp.array(dtype=wp.vec4i),
        hex_face_indices: wp.array2d(dtype=wp.vec2i),
    ):
        f = wp.tid()

        # face indices from CubeShapeFunction always have positive orientation,
        # while Hexmesh faces are oriented to point "outside"
        # We need to flip orientation for faces at offset 0

        # first hex adjacent to the face: components 0/1 hold its local face and orientation
        hx0 = face_hex_indices[f][0]
        local_face_0 = face_hex_face_ori[f][0]
        ori_0 = face_hex_face_ori[f][1]

        local_face_offset_0 = CubeShapeFunction._face_offset(local_face_0)
        flip_0 = ori_0 ^ (1 - local_face_offset_0)

        hex_face_indices[hx0, local_face_0] = wp.vec2i(f, flip_0)

        # second hex adjacent to the face: components 2/3 hold its local face and orientation
        hx1 = face_hex_indices[f][1]
        local_face_1 = face_hex_face_ori[f][2]
        ori_1 = face_hex_face_ori[f][3]

        local_face_offset_1 = CubeShapeFunction._face_offset(local_face_1)
        flip_1 = ori_1 ^ (1 - local_face_offset_1)

        hex_face_indices[hx1, local_face_1] = wp.vec2i(f, flip_1)

    def node_count(self) -> int:
        """Total number of nodes: per-entity node counts times entity counts."""
        return (
            self._mesh.vertex_count() * self._shape.VERTEX_NODE_COUNT
            + self._mesh.edge_count() * self._shape.EDGE_NODE_COUNT
            + self._mesh.side_count() * self._shape.FACE_NODE_COUNT
            + self._mesh.cell_count() * self._shape.INTERIOR_NODE_COUNT
        )

    @wp.func
    def _rotate_face_coordinates(ori: int, offset: int, coords: wp.vec2i):
        # `offset` is not read here; it is kept so existing call sites stay valid
        fv = ori // 2

        # rows 2*ori and 2*ori+1 of the orientation table form the 2x2 integer transform
        rot_i = wp.dot(_FACE_ORIENTATION_I[2 * ori], coords)
        rot_j = wp.dot(_FACE_ORIENTATION_I[2 * ori + 1], coords)

        return wp.vec2i(rot_i, rot_j) + _FACE_TRANSLATION_I[fv]

    def _make_element_node_index(self):
        """Build the Warp function mapping an element-local node to a global node index.

        The per-entity node counts are captured as constants so that branches
        for entity types without nodes can be guarded with ``wp.static``.
        """
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
        FACE_NODE_COUNT = self._shape.FACE_NODE_COUNT
        INTERIOR_NODE_COUNT = self._shape.INTERIOR_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            geo_arg: Hexmesh.CellArg,
            topo_arg: HexmeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            if wp.static(VERTEX_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.VERTEX:
                    return (
                        geo_arg.hex_vertex_indices[element_index, _CUBE_TO_HEX_VERTEX[type_instance]]
                        * VERTEX_NODE_COUNT
                        + type_index
                    )

            # all vertex nodes precede the edge nodes
            offset = topo_arg.vertex_count * VERTEX_NODE_COUNT

            if wp.static(EDGE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.EDGE:
                    hex_edge = _CUBE_TO_HEX_EDGE[type_instance]
                    edge_index = topo_arg.hex_edge_indices[element_index, hex_edge]

                    v0 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 0]]
                    v1 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 1]]

                    # number the nodes of a shared edge from its lower-index vertex,
                    # so the order is reversed when this element sees the edge backwards
                    if v0 > v1:
                        type_index = EDGE_NODE_COUNT - 1 - type_index

                    return offset + EDGE_NODE_COUNT * edge_index + type_index

                # all edge nodes precede the face nodes
                offset += EDGE_NODE_COUNT * topo_arg.edge_count

            if wp.static(FACE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.FACE:
                    face_index_and_ori = topo_arg.hex_face_indices[element_index, type_instance]
                    face_index = face_index_and_ori[0]
                    face_orientation = face_index_and_ori[1]

                    face_offset = CubeShapeFunction._face_offset(type_instance)

                    # with several nodes per face, remap the node through the face orientation
                    if wp.static(FACE_NODE_COUNT > 1):
                        face_coords = self._shape._face_node_ij(type_index)
                        rot_face_coords = HexmeshSpaceTopology._rotate_face_coordinates(
                            face_orientation, face_offset, face_coords
                        )
                        type_index = self._shape._linear_face_node_index(type_index, rot_face_coords)

                    return offset + FACE_NODE_COUNT * face_index + type_index

                # all face nodes precede the interior nodes
                offset += FACE_NODE_COUNT * topo_arg.face_count

            return offset + INTERIOR_NODE_COUNT * element_index + type_index

        return element_node_index

    def _make_element_node_sign(self):
        """Build the Warp function returning the sign (+1.0 or -1.0) of an element-local node.

        Edge nodes are negative when the element's first edge vertex has the
        larger index; face nodes are negative when the lowest bit of the stored
        orientation code is set; every other node is positive.
        """
        EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
        FACE_NODE_COUNT = self._shape.FACE_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_sign(
            geo_arg: self.geometry.CellArg,
            topo_arg: HexmeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            if wp.static(EDGE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.EDGE:
                    hex_edge = _CUBE_TO_HEX_EDGE[type_instance]
                    v0 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 0]]
                    v1 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 1]]
                    return wp.where(v0 > v1, -1.0, 1.0)

            if wp.static(FACE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.FACE:
                    face_index_and_ori = topo_arg.hex_face_indices[element_index, type_instance]
                    flip = face_index_and_ori[1] & 1
                    return wp.where(flip == 0, 1.0, -1.0)

            return 1.0

        return element_node_sign
249
+
250
+
251
def make_hexmesh_space_topology(mesh: Hexmesh, shape: CubeShapeFunction):
    """Create the ``HexmeshSpaceTopology`` for ``shape`` on ``mesh``.

    Raises:
        ValueError: if ``shape`` is not a ``CubeShapeFunction``.
    """
    if not isinstance(shape, CubeShapeFunction):
        raise ValueError(f"Unsupported shape function {shape.name}")

    return forward_base_topology(HexmeshSpaceTopology, mesh, shape)