warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,292 @@
1
+ import warp as wp
2
+
3
+ from warp.fem.types import ElementIndex, Coords
4
+ from warp.fem.geometry import Tetmesh
5
+ from warp.fem import cache
6
+
7
+ from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
8
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
9
+
10
+ from .shape import ShapeFunction, ConstantShapeFunction
11
+ from .shape import TetrahedronPolynomialShapeFunctions, TetrahedronNonConformingPolynomialShapeFunctions
12
+
13
+
14
@wp.struct
class TetmeshTopologyArg:
    """Topology data passed to device functions of :class:`TetmeshSpaceTopology`."""

    # --- connectivity ---
    tet_edge_indices: wp.array2d(dtype=int)  # per tet: global edge index of each local edge
    tet_face_indices: wp.array2d(dtype=int)  # per tet: global face index of each local face
    face_vertex_indices: wp.array(dtype=wp.vec3i)  # per face: its three vertex indices

    # --- entity counts (used to offset node indices per entity type) ---
    vertex_count: int
    edge_count: int
    face_count: int
23
+
24
+
25
class TetmeshSpaceTopology(SpaceTopology):
    """Node topology of function spaces defined over a :class:`Tetmesh`.

    Gathers the per-tetrahedron edge and face connectivity that is needed to number
    nodes living on edges and faces shared between tetrahedra.

    Args:
        mesh: The underlying tetrahedral mesh.
        shape: Shape function; provides the number of nodes per element.
        need_tet_edge_indices: If False, an empty tet-to-edge array is stored and the
            edge count is set to zero.
        need_tet_face_indices: If False, the tet-to-face array is not computed and an
            empty array is stored instead.
    """

    TopologyArg = TetmeshTopologyArg

    def __init__(
        self,
        mesh: Tetmesh,
        shape: ShapeFunction,
        need_tet_edge_indices: bool = True,
        need_tet_face_indices: bool = True,
    ):
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh
        self._shape = shape

        if need_tet_edge_indices:
            self._tet_edge_indices = self._mesh.tet_edge_indices
            self._edge_count = self._mesh.edge_count()
        else:
            # Empty placeholder so that topo_arg_value can always transfer the array
            self._tet_edge_indices = wp.empty(shape=(0, 0), dtype=int)
            self._edge_count = 0

        if need_tet_face_indices:
            self._compute_tet_face_indices()
        else:
            self._tet_face_indices = wp.empty(shape=(0, 0), dtype=int)

    @cache.cached_arg_value
    def topo_arg_value(self, device):
        """Build the :class:`TetmeshTopologyArg` for ``device``, copying connectivity arrays there."""
        arg = TetmeshTopologyArg()
        arg.tet_face_indices = self._tet_face_indices.to(device)
        arg.tet_edge_indices = self._tet_edge_indices.to(device)
        arg.face_vertex_indices = self._mesh.face_vertex_indices.to(device)

        arg.vertex_count = self._mesh.vertex_count()
        arg.face_count = self._mesh.side_count()
        # Zero when edge indices were not requested
        arg.edge_count = self._edge_count
        return arg

    def _compute_tet_face_indices(self):
        """Fill ``self._tet_face_indices`` (``cell_count x 4``) with the global index of each local tet face.

        Launches one thread per mesh face. Each thread writes the face index into the
        row of every tet adjacent to that face.
        """
        device = self._mesh.tet_vertex_indices.device
        # Use the public accessor for both the launch size and the kernel input
        face_tet_indices = self._mesh.face_tet_indices

        self._tet_face_indices = wp.empty(dtype=int, device=device, shape=(self._mesh.cell_count(), 4))

        wp.launch(
            kernel=TetmeshSpaceTopology._compute_tet_face_indices_kernel,
            dim=face_tet_indices.shape,
            device=device,
            inputs=[
                face_tet_indices,
                self._mesh.face_vertex_indices,
                self._mesh.tet_vertex_indices,
                self._tet_face_indices,
            ],
        )

    @wp.func
    def _find_face_index_in_tet(
        face_vtx: wp.vec3i,
        tet_vtx: wp.vec4i,
    ):
        # Local face k is made of local vertices (k, k+1, k+2) mod 4.
        # Faces 0..2 are tested explicitly; face 3 is returned by elimination.
        for k in range(3):
            tvk = wp.vec3i(tet_vtx[k], tet_vtx[(k + 1) % 4], tet_vtx[(k + 2) % 4])

            # Relies on face_vtx listing its smallest vertex index first;
            # the two remaining vertices may come in either order
            min_t = wp.min(tvk)
            max_t = wp.max(tvk)
            mid_t = tvk[0] + tvk[1] + tvk[2] - min_t - max_t

            if min_t == face_vtx[0] and (
                (face_vtx[2] == max_t and face_vtx[1] == mid_t) or (face_vtx[1] == max_t and face_vtx[2] == mid_t)
            ):
                return k

        return 3

    @wp.kernel
    def _compute_tet_face_indices_kernel(
        face_tet_indices: wp.array(dtype=wp.vec2i),
        face_vertex_indices: wp.array(dtype=wp.vec3i),
        tet_vertex_indices: wp.array2d(dtype=int),
        tet_face_indices: wp.array2d(dtype=int),
    ):
        # One thread per face
        e = wp.tid()

        face_vtx = face_vertex_indices[e]
        face_tets = face_tet_indices[e]

        t0 = face_tets[0]
        t0_vtx = wp.vec4i(
            tet_vertex_indices[t0, 0], tet_vertex_indices[t0, 1], tet_vertex_indices[t0, 2], tet_vertex_indices[t0, 3]
        )
        t0_face = TetmeshSpaceTopology._find_face_index_in_tet(face_vtx, t0_vtx)
        tet_face_indices[t0, t0_face] = e

        # The second tet is skipped when it equals the first
        # (presumably how boundary faces are encoded -- TODO confirm in Tetmesh)
        t1 = face_tets[1]
        if t1 != t0:
            t1_vtx = wp.vec4i(
                tet_vertex_indices[t1, 0],
                tet_vertex_indices[t1, 1],
                tet_vertex_indices[t1, 2],
                tet_vertex_indices[t1, 3],
            )
            t1_face = TetmeshSpaceTopology._find_face_index_in_tet(face_vtx, t1_vtx)
            tet_face_indices[t1, t1_face] = e
129
+
130
+
131
class TetmeshDiscontinuousSpaceTopology(DiscontinuousSpaceTopologyMixin, SpaceTopology):
    """Discontinuous node topology over a :class:`Tetmesh`.

    The discontinuous behavior comes from :class:`DiscontinuousSpaceTopologyMixin`,
    which precedes :class:`SpaceTopology` in the MRO.
    """

    def __init__(self, mesh: Tetmesh, shape: ShapeFunction):
        nodes_per_element = shape.NODES_PER_ELEMENT
        super().__init__(mesh, nodes_per_element)
137
+
138
+
139
class TetmeshBasisSpace(ShapeBasisSpace):
    """Common base for shape-function basis spaces on a :class:`Tetmesh`.

    Keeps a reference to the topology's geometry as ``self._mesh``.
    """

    def __init__(self, topology: TetmeshSpaceTopology, shape: ShapeFunction):
        super().__init__(topology, shape)

        mesh: Tetmesh = topology.geometry
        self._mesh = mesh
144
+
145
+
146
class TetmeshPiecewiseConstantBasis(TetmeshBasisSpace):
    """Piecewise-constant basis on a :class:`Tetmesh`, built on a discontinuous topology."""

    def __init__(self, mesh: Tetmesh):
        constant_shape = ConstantShapeFunction(mesh.reference_cell(), space_dimension=3)
        discontinuous_topology = TetmeshDiscontinuousSpaceTopology(mesh, constant_shape)
        super().__init__(topology=discontinuous_topology, shape=constant_shape)

    def trace(self):
        """Return the trace of this basis on mesh sides."""
        return TetmeshPiecewiseConstantBasis.Trace(self)

    class Trace(TraceBasisSpace):
        """Trace of the piecewise-constant basis; its single node sits at the side centroid."""

        def make_node_coords_in_element(self):
            return self._node_coords_in_element

        @wp.func
        def _node_coords_in_element(
            side_arg: Tetmesh.SideArg,
            basis_arg: TetmeshBasisSpace.BasisArg,
            element_index: ElementIndex,
            node_index_in_element: int,
        ):
            # Same coordinates regardless of element or node index
            return Coords(1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0)
167
+
168
+
169
class TetmeshPolynomialSpaceTopology(TetmeshSpaceTopology):
    """Conforming node topology for polynomial spaces of a given order on a :class:`Tetmesh`.

    Global node indices are laid out as:
    mesh vertices, then edge-interior nodes, then face-interior nodes,
    then cell-interior nodes.
    """

    def __init__(self, mesh: Tetmesh, shape: TetrahedronPolynomialShapeFunctions):
        # Edge-interior nodes exist from order 2 on; face-interior nodes from order 3 on
        super().__init__(mesh, shape, need_tet_edge_indices=shape.ORDER >= 2, need_tet_face_indices=shape.ORDER >= 3)

        self.element_node_index = self._make_element_node_index()

    @staticmethod
    def _interior_node_counts(order: int):
        """Return ``(per_edge, per_face, per_cell)`` interior node counts for polynomial ``order``.

        These are the number of nodes strictly inside each edge, face and cell.
        They are shared by :meth:`node_count` and the generated ``element_node_index``
        so that both always agree.
        """
        per_edge = max(0, order - 1)
        per_face = max(0, order - 2) * max(0, order - 1) // 2
        per_cell = max(0, order - 3) * max(0, order - 2) * max(0, order - 1) // 6
        return per_edge, per_face, per_cell

    def node_count(self) -> int:
        """Total number of global nodes: vertices plus interior nodes of edges, faces and cells."""
        per_edge, per_face, per_cell = self._interior_node_counts(self._shape.ORDER)

        return (
            self._mesh.vertex_count()
            + self._mesh.edge_count() * per_edge
            + self._mesh.side_count() * per_face
            + self._mesh.cell_count() * per_cell
        )

    def _make_element_node_index(self):
        """Generate the device function mapping (element, local node) to a global node index."""
        per_edge, per_face, per_cell = self._interior_node_counts(self._shape.ORDER)
        INTERIOR_NODES_PER_EDGE = wp.constant(per_edge)
        INTERIOR_NODES_PER_FACE = wp.constant(per_face)
        INTERIOR_NODES_PER_CELL = wp.constant(per_cell)

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            geo_arg: Tetmesh.CellArg,
            topo_arg: TetmeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            if node_type == TetrahedronPolynomialShapeFunctions.VERTEX:
                return geo_arg.tet_vertex_indices[element_index][type_index]

            global_offset = topo_arg.vertex_count

            if node_type == TetrahedronPolynomialShapeFunctions.EDGE:
                edge = type_index // INTERIOR_NODES_PER_EDGE
                edge_node = type_index - INTERIOR_NODES_PER_EDGE * edge

                global_edge_index = topo_arg.tet_edge_indices[element_index][edge]

                # Test if we need to swap edge direction: interior edge nodes are numbered
                # from the lower to the higher global vertex index, so reverse the local
                # ordering when this tet sees the edge the other way round.
                # Local edges 0..2 join vertices (e, (e+1)%3); edges 3..5 join vertex e-3 and vertex 3.
                if INTERIOR_NODES_PER_EDGE > 1:
                    if edge < 3:
                        c1 = edge
                        c2 = (edge + 1) % 3
                    else:
                        c1 = edge - 3
                        c2 = 3

                    if geo_arg.tet_vertex_indices[element_index][c1] > geo_arg.tet_vertex_indices[element_index][c2]:
                        edge_node = INTERIOR_NODES_PER_EDGE - 1 - edge_node

                return global_offset + INTERIOR_NODES_PER_EDGE * global_edge_index + edge_node

            global_offset += INTERIOR_NODES_PER_EDGE * topo_arg.edge_count

            if node_type == TetrahedronPolynomialShapeFunctions.FACE:
                face = type_index // INTERIOR_NODES_PER_FACE
                face_node = type_index - INTERIOR_NODES_PER_FACE * face

                global_face_index = topo_arg.tet_face_indices[element_index][face]

                if INTERIOR_NODES_PER_FACE == 3:
                    # Hard code for P4 case, 3 nodes per face:
                    # match the local face node's vertex against the global face's vertex list.
                    # Higher orders would require rotating triangle coordinates, this is not supported yet
                    # (NOTE(review): such orders get no reordering here -- confirm they are rejected upstream)

                    vidx = geo_arg.tet_vertex_indices[element_index][(face + face_node) % 4]
                    fvi = topo_arg.face_vertex_indices[global_face_index]

                    if vidx == fvi[0]:
                        face_node = 0
                    elif vidx == fvi[1]:
                        face_node = 1
                    else:
                        face_node = 2

                return global_offset + INTERIOR_NODES_PER_FACE * global_face_index + face_node

            global_offset += INTERIOR_NODES_PER_FACE * topo_arg.face_count

            # Remaining case: cell-interior node
            return global_offset + INTERIOR_NODES_PER_CELL * element_index + type_index

        return element_node_index
257
+
258
+
259
class TetmeshPolynomialBasisSpace(TetmeshBasisSpace):
    """Polynomial basis of a given degree on a :class:`Tetmesh`.

    Nodes on vertices, edges and faces are shared between adjacent tetrahedra.
    """

    def __init__(self, mesh: Tetmesh, degree: int):
        poly_shape = TetrahedronPolynomialShapeFunctions(degree)
        shared_topology = forward_base_topology(TetmeshPolynomialSpaceTopology, mesh, poly_shape)
        super().__init__(shared_topology, poly_shape)
269
+
270
+
271
class TetmeshDGPolynomialBasisSpace(TetmeshBasisSpace):
    """Discontinuous polynomial basis on a :class:`Tetmesh`.

    Uses the same shape functions as the conforming space, on a topology
    without node sharing between elements.
    """

    def __init__(self, mesh: Tetmesh, degree: int):
        poly_shape = TetrahedronPolynomialShapeFunctions(degree)
        super().__init__(TetmeshDiscontinuousSpaceTopology(mesh, poly_shape), poly_shape)
281
+
282
+
283
class TetmeshNonConformingPolynomialBasisSpace(TetmeshBasisSpace):
    """Non-conforming polynomial basis on a :class:`Tetmesh`, using a discontinuous topology."""

    def __init__(self, mesh: Tetmesh, degree: int):
        nc_shape = TetrahedronNonConformingPolynomialShapeFunctions(degree)
        super().__init__(TetmeshDiscontinuousSpaceTopology(mesh, nc_shape), nc_shape)
@@ -0,0 +1,295 @@
1
+ from typing import Optional, Type
2
+
3
+ import warp as wp
4
+
5
+ from warp.fem.types import ElementIndex
6
+ from warp.fem.geometry import Geometry, DeformedGeometry
7
+ from warp.fem import cache
8
+
9
+
10
class SpaceTopology:
    """
    Interface class for defining the topology of a function space.

    The topology only considers the indices of the nodes in each element, and as such,
    the connectivity pattern of the function space.
    It does not specify the actual location of the nodes within the elements, or the valuation function.
    """

    dimension: int
    """Embedding dimension of the function space"""

    NODES_PER_ELEMENT: int
    """Number of interpolation nodes per element of the geometry.

    .. note:: This will change to be defined per-element in future versions
    """

    @wp.struct
    class TopologyArg:
        """Structure containing arguments to be passed to device functions"""

        pass

    def __init__(self, geometry: Geometry, nodes_per_element: int):
        self._geometry = geometry
        self.dimension = geometry.dimension
        self.NODES_PER_ELEMENT = wp.constant(nodes_per_element)
        # Device-side argument type of the elements this topology is defined over;
        # full-space topologies use the geometry cells
        self.ElementArg = geometry.CellArg

    @property
    def geometry(self) -> Geometry:
        """Underlying geometry"""
        return self._geometry

    def node_count(self) -> int:
        """Number of nodes in the interpolation basis"""
        raise NotImplementedError

    def topo_arg_value(self, device) -> "TopologyArg":
        """Value of the topology argument structure to be passed to device functions"""
        return SpaceTopology.TopologyArg()

    @property
    def name(self):
        """Identifier of the topology.

        Used as suffix for the generated device functions/kernels and when comparing topologies.
        """
        return f"{self.__class__.__name__}_{self.NODES_PER_ELEMENT}"

    def __str__(self):
        return self.name

    @staticmethod
    def element_node_index(
        geo_arg: "ElementArg", topo_arg: "TopologyArg", element_index: ElementIndex, node_index_in_elt: int
    ):
        """Global node index for a given node in a given element"""
        raise NotImplementedError

    def element_node_indices(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns an array containing the global index for each node of each element.

        Args:
            out: Optional pre-allocated array to fill, of shape ``(geometry.cell_count(), NODES_PER_ELEMENT)``
                and dtype ``int32``. If ``None``, a new array is allocated with ``wp.empty``
                (no device is passed, so warp's default device is used).

        Returns:
            The filled array. The fill kernel is launched on that array's device.

        Raises:
            ValueError: If `out` does not have the expected shape or data type.
        """

        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        @cache.dynamic_kernel(suffix=self.name)
        def fill_element_node_indices(
            geo_cell_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_node_indices: wp.array2d(dtype=int),
        ):
            # One thread per geometry cell
            element_index = wp.tid()
            for n in range(NODES_PER_ELEMENT):
                element_node_indices[element_index, n] = self.element_node_index(
                    geo_cell_arg, topo_arg, element_index, n
                )

        shape = (self.geometry.cell_count(), NODES_PER_ELEMENT)
        if out is None:
            element_node_indices = wp.empty(
                shape=shape,
                dtype=int,
            )
        else:
            if out.shape != shape or out.dtype != wp.int32:
                raise ValueError(f"Out element node indices array must have shape {shape} and data type 'int32'")
            element_node_indices = out

        wp.launch(
            dim=element_node_indices.shape[0],
            kernel=fill_element_node_indices,
            inputs=[
                self.geometry.cell_arg_value(device=element_node_indices.device),
                self.topo_arg_value(device=element_node_indices.device),
                element_node_indices,
            ],
            device=element_node_indices.device,
        )

        return element_node_indices

    # Interface generating trace space topology

    def trace(self) -> "TraceSpaceTopology":
        """Trace of the function space over lower-dimensional elements of the geometry"""

        return TraceSpaceTopology(self)

    @property
    def is_trace(self) -> bool:
        """Whether this topology is defined on the trace of the geometry"""
        return self.dimension == self.geometry.dimension - 1

    def full_space_topology(self) -> "SpaceTopology":
        """Returns the full space topology from which this topology is derived"""
        return self

    def __eq__(self, other: "SpaceTopology") -> bool:
        """Checks whether two topologies are compatible (same geometry and same name)"""
        # Let Python fall back to the other operand / identity for unrelated objects
        # (e.g. ``topo == None``) instead of raising AttributeError
        if not isinstance(other, SpaceTopology):
            return NotImplemented
        return self.geometry == other.geometry and self.name == other.name

    def is_derived_from(self, other: "SpaceTopology") -> bool:
        """Checks whether two topologies are equal, or `self` is the trace of `other`"""
        if self.dimension == other.dimension:
            return self == other
        if self.dimension + 1 == other.dimension:
            return self.full_space_topology() == other
        return False
135
+
136
+
137
class TraceSpaceTopology(SpaceTopology):
    """Auto-generated trace topology defining the node indices associated to the geometry sides

    Each side element exposes ``2 * topo.NODES_PER_ELEMENT`` nodes. The first ``topo.NODES_PER_ELEMENT``
    are the nodes of the side's inner cell, and the remaining ones are the nodes of its outer cell.
    """

    def __init__(self, topo: SpaceTopology):
        # Side elements gather the nodes of both neighboring cells
        super().__init__(topo.geometry, 2 * topo.NODES_PER_ELEMENT)

        self._topo = topo
        self.dimension = topo.dimension - 1
        self.ElementArg = topo.geometry.SideArg

        # Topology arguments are shared with the full-space topology
        self.TopologyArg = topo.TopologyArg
        self.topo_arg_value = topo.topo_arg_value

        self.inner_cell_index = self._make_inner_cell_index()
        self.outer_cell_index = self._make_outer_cell_index()
        self.neighbor_cell_index = self._make_neighbor_cell_index()

        # Uses `neighbor_cell_index` built above
        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        return self._topo.node_count()

    @property
    def name(self):
        return f"{self._topo.name}_Trace"

    def _make_inner_cell_index(self):
        NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def inner_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            # Local index in the inner cell, or -1 if the node belongs to the outer cell
            index_in_inner_cell = wp.select(node_index_in_elt < NODES_PER_ELEMENT, -1, node_index_in_elt)
            return self.geometry.side_inner_cell_index(args, element_index), index_in_inner_cell

        return inner_cell_index

    def _make_outer_cell_index(self):
        NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def outer_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            # Local index in the outer cell; negative for nodes of the inner cell
            return self.geometry.side_outer_cell_index(args, element_index), node_index_in_elt - NODES_PER_ELEMENT

        return outer_cell_index

    def _make_neighbor_cell_index(self):
        NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def neighbor_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            # Maps a side node to the (cell, local node) pair it comes from
            if node_index_in_elt < NODES_PER_ELEMENT:
                return self.geometry.side_inner_cell_index(args, element_index), node_index_in_elt
            else:
                return (
                    self.geometry.side_outer_cell_index(args, element_index),
                    node_index_in_elt - NODES_PER_ELEMENT,
                )

        return neighbor_cell_index

    def _make_element_node_index(self):
        @cache.dynamic_func(suffix=self.name)
        def trace_element_node_index(
            geo_side_arg: self.geometry.SideArg,
            topo_arg: self._topo.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.neighbor_cell_index(geo_side_arg, element_index, node_index_in_elt)

            # Delegate to the full-space topology, which expects cell arguments
            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return self._topo.element_node_index(geo_cell_arg, topo_arg, cell_index, index_in_cell)

        return trace_element_node_index

    def full_space_topology(self) -> SpaceTopology:
        """Returns the full space topology from which this topology is derived"""
        return self._topo

    def __eq__(self, other: "TraceSpaceTopology") -> bool:
        """Checks whether two trace topologies are derived from equal full-space topologies"""
        # Python tries this (subclass) method first even for ``full_topo == trace_topo``.
        # Defer to the other operand / identity for non-trace objects instead of
        # failing on the missing `_topo` attribute.
        if not isinstance(other, TraceSpaceTopology):
            return NotImplemented
        return self._topo == other._topo
218
+
219
+
220
class DiscontinuousSpaceTopologyMixin:
    """Helper for defining discontinuous topologies, in which each element owns its nodes.

    Node ``k`` of element ``e`` receives the global index ``e * NODES_PER_ELEMENT + k``, so no node is
    shared between elements. Intended to be listed before :class:`SpaceTopology` in the base classes
    (as in :class:`DiscontinuousSpaceTopology`) so that its overrides take precedence.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Built once base initialization has set `geometry` and `NODES_PER_ELEMENT`
        self.element_node_index = self._make_element_node_index()

    def node_count(self):
        cell_count = self.geometry.cell_count()
        return cell_count * self.NODES_PER_ELEMENT

    @property
    def name(self):
        return f"{self.geometry.name}_D{self.NODES_PER_ELEMENT}"

    def _make_element_node_index(self):
        nodes_per_element = self.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Consecutive, non-overlapping index ranges per element
            return nodes_per_element * element_index + node_index_in_elt

        return element_node_index
247
+
248
+
249
class DiscontinuousSpaceTopology(DiscontinuousSpaceTopologyMixin, SpaceTopology):
    """Topology for generic discontinuous spaces.

    Combines the per-element node numbering of :class:`DiscontinuousSpaceTopologyMixin` with the
    generic :class:`SpaceTopology` interface. Constructed as ``(geometry, nodes_per_element)``.
    """
253
+
254
+
255
class DeformedGeometrySpaceTopology(SpaceTopology):
    """Topology over a :class:`DeformedGeometry` that forwards node indexing to `base_topology`.

    Node count, topology argument type and topology argument value are those of the base topology.
    Element node indices are computed by the base topology from the ``elt_arg`` field of the
    deformed geometry's cell argument (presumably the base geometry's cell argument; verify in DeformedGeometry).
    """

    def __init__(self, geometry: DeformedGeometry, base_topology: SpaceTopology):
        super().__init__(geometry, base_topology.NODES_PER_ELEMENT)

        self.base = base_topology

        # Host-side interface is taken as-is from the base topology
        self.TopologyArg = base_topology.TopologyArg
        self.topo_arg_value = base_topology.topo_arg_value
        self.node_count = base_topology.node_count

        self.element_node_index = self._make_element_node_index()

    @property
    def name(self):
        return f"{self.base.name}_{self.geometry.field.name}"

    def _make_element_node_index(self):
        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Unwrap the deformed-geometry cell argument before delegating
            return self.base.element_node_index(elt_arg.elt_arg, topo_arg, element_index, node_index_in_elt)

        return element_node_index
281
+
282
+
283
def forward_base_topology(topology_class: Type[SpaceTopology], geometry: Geometry, *args, **kwargs) -> SpaceTopology:
    """
    Builds an instance of `topology_class`, transparently handling deformed geometries.

    If `geometry` is *not* a :class:`DeformedGeometry`, this is simply ``topology_class(geometry, *args, **kwargs)``.

    If `geometry` *is* a :class:`DeformedGeometry`, the topology is constructed over the base (undeformed)
    geometry of `geometry` instead. It is then wrapped in a :class:`DeformedGeometrySpaceTopology` that
    forwards calls to the underlying topology.
    """

    if not isinstance(geometry, DeformedGeometry):
        return topology_class(geometry, *args, **kwargs)

    return DeformedGeometrySpaceTopology(geometry, topology_class(geometry.base, *args, **kwargs))