warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,376 @@
1
+ from typing import Any
2
+
3
+ import warp as wp
4
+
5
+ from warp.fem.types import ElementIndex, NULL_ELEMENT_INDEX
6
+ from warp.fem.utils import masked_indices
7
+ from warp.fem.cache import cached_arg_value, TemporaryStore, borrow_temporary
8
+
9
+ from .geometry import Geometry
10
+
11
+
12
# Module-wide Warp option: turn off backward (adjoint) code generation for the kernels
# and functions defined in this module.
wp.set_module_options(dict(enable_backward=False))
13
+
14
+
15
class GeometryPartition:
    """Base class for geometry partitions, i.e. a subset of the cells and sides of a geometry.

    Every query declared here is a placeholder raising :class:`NotImplementedError`; concrete
    partitions provide the element counts, the ``*_arg_value`` factories and the index mappings.
    """

    class CellArg:
        # Stand-in for the argument struct consumed by the cell index mappings
        pass

    class SideArg:
        # Stand-in for the argument struct consumed by the side index mappings
        pass

    def __init__(self, geometry: Geometry):
        self.geometry = geometry

    @property
    def name(self) -> str:
        # Geometry name qualified by the concrete partition type
        return f"{self.geometry.name}_{type(self).__name__}"

    def __str__(self) -> str:
        return self.name

    # --- Element counts -------------------------------------------------------

    def cell_count(self) -> int:
        """Number of cells that are 'owned' by this partition"""
        raise NotImplementedError

    def side_count(self) -> int:
        """Number of sides that are 'owned' by this partition"""
        raise NotImplementedError

    def boundary_side_count(self) -> int:
        """Number of geometry-boundary sides that are 'owned' by this partition"""
        raise NotImplementedError

    def frontier_side_count(self) -> int:
        """Number of sides with neighbors owned by this and another partition"""
        raise NotImplementedError

    # --- Argument values passed to the index mappings ------------------------

    def cell_arg_value(self, device):
        """Build the cell argument value for ``device``"""
        raise NotImplementedError

    def side_arg_value(self, device):
        """Build the side argument value for ``device``"""
        raise NotImplementedError

    # --- Index mappings -------------------------------------------------------

    @staticmethod
    def cell_index(args: CellArg, partition_cell_index: int):
        """Geometry index of a partition cell"""
        raise NotImplementedError

    @staticmethod
    def partition_cell_index(args: CellArg, cell_index: int):
        """Partition index of a geometry cell (or ``NULL_ELEMENT_INDEX``)"""
        raise NotImplementedError

    @staticmethod
    def side_index(args: SideArg, partition_side_index: int):
        """Geometry side index of a partition side"""
        raise NotImplementedError

    @staticmethod
    def boundary_side_index(args: SideArg, boundary_side_index: int):
        """Geometry side index of a partition boundary side"""
        raise NotImplementedError

    @staticmethod
    def frontier_side_index(args: SideArg, frontier_side_index: int):
        """Geometry side index of a partition frontier side"""
        raise NotImplementedError
80
+
81
+
82
class WholeGeometryPartition(GeometryPartition):
    """Trivial (NOP) partition covering the whole geometry.

    Cell and side indices map to themselves, boundary side queries are forwarded to the
    geometry, and the partition has no frontier sides.
    """

    def __init__(
        self,
        geometry: Geometry,
    ):
        super().__init__(geometry)

        # Side arguments are the geometry's own side-index arguments
        self.SideArg = geometry.SideIndexArg
        self.side_arg_value = geometry.side_index_arg_value

        # Partition indices coincide with geometry indices
        self.cell_index = WholeGeometryPartition._identity_element_index
        self.partition_cell_index = WholeGeometryPartition._identity_element_index

        self.side_index = WholeGeometryPartition._identity_element_index
        self.boundary_side_index = geometry.boundary_side_index
        self.frontier_side_index = WholeGeometryPartition._identity_element_index

    def __eq__(self, other: GeometryPartition) -> bool:
        # Ensures that two whole partition instances of the same geometry are considered equal
        return isinstance(other, WholeGeometryPartition) and self.geometry == other.geometry

    def __hash__(self) -> int:
        # Defining __eq__ alone implicitly sets __hash__ to None (unhashable instances).
        # Hash on the geometry so that partitions comparing equal also hash equal.
        return hash(self.geometry)

    def cell_count(self) -> int:
        return self.geometry.cell_count()

    def side_count(self) -> int:
        return self.geometry.side_count()

    def boundary_side_count(self) -> int:
        return self.geometry.boundary_side_count()

    def frontier_side_count(self) -> int:
        # A partition covering everything has no neighbor partition
        return 0

    @wp.struct
    class CellArg:
        # No data needed: cell mappings are identities
        pass

    def cell_arg_value(self, device):
        arg = WholeGeometryPartition.CellArg()
        return arg

    @wp.func
    def _identity_element_index(args: Any, idx: ElementIndex):
        """Identity mapping shared by the cell and side index queries"""
        return idx

    @property
    def name(self) -> str:
        # The whole partition is named after its geometry alone
        return self.geometry.name
132
+
133
+
134
class CellBasedGeometryPartition(GeometryPartition):
    """Geometry partition based on a subset of cells. Interior, boundary and frontier sides are automatically categorized.

    Subclasses are expected to call :meth:`compute_side_indices_from_cells` on construction;
    the side counts and :meth:`side_arg_value` read the index arrays stored by that method.
    """

    def __init__(
        self,
        geometry: Geometry,
        device=None,
    ):
        # ``device`` is not used by this base initializer
        super().__init__(geometry)

    @wp.struct
    class SideArg:
        # Geometry side indices of the sides owned by the partition
        partition_side_indices: wp.array(dtype=int)
        # Geometry side indices of the owned sides lying on the geometry boundary
        boundary_side_indices: wp.array(dtype=int)
        # Geometry side indices of sides with exactly one neighbor cell in the partition
        frontier_side_indices: wp.array(dtype=int)

    def side_count(self) -> int:
        return self._partition_side_indices.array.shape[0]

    def boundary_side_count(self) -> int:
        return self._boundary_side_indices.array.shape[0]

    def frontier_side_count(self) -> int:
        return self._frontier_side_indices.array.shape[0]

    @cached_arg_value
    def side_arg_value(self, device):
        # Instantiate this class's own struct rather than referring to a subclass
        arg = CellBasedGeometryPartition.SideArg()
        arg.partition_side_indices = self._partition_side_indices.array.to(device)
        arg.boundary_side_indices = self._boundary_side_indices.array.to(device)
        arg.frontier_side_indices = self._frontier_side_indices.array.to(device)
        return arg

    @wp.func
    def side_index(args: SideArg, partition_side_index: int):
        """Partition side to side index"""
        return args.partition_side_indices[partition_side_index]

    @wp.func
    def boundary_side_index(args: SideArg, boundary_side_index: int):
        """Boundary side to side index"""
        return args.boundary_side_indices[boundary_side_index]

    @wp.func
    def frontier_side_index(args: SideArg, frontier_side_index: int):
        """Frontier side to side index"""
        return args.frontier_side_indices[frontier_side_index]

    def compute_side_indices_from_cells(
        self, cell_arg_value: Any, cell_inclusion_test_func: wp.Function, device, temporary_store: TemporaryStore = None
    ):
        """Categorize the geometry sides according to which of their neighbor cells belong to the partition.

        A side is a partition side if its inner cell is included, a boundary side if additionally its
        inner and outer cells coincide, and a frontier side if exactly one of its two cells is included.
        Results are stored in ``self._partition_side_indices``, ``self._boundary_side_indices`` and
        ``self._frontier_side_indices``.

        Args:
            cell_arg_value: value passed as first argument to ``cell_inclusion_test_func``
            cell_inclusion_test_func: Warp function ``(cell_arg, cell_index)`` telling whether a geometry
              cell belongs to the partition; the type of its first input is used as the kernel argument
              type for ``cell_arg_value``
            device: device on which the categorization kernel is launched and temporaries are allocated
            temporary_store: optional store from which the temporary mask arrays are borrowed
        """
        from warp.fem import cache

        # Kernel argument type matching the first input of the inclusion test
        cell_arg_type = next(iter(cell_inclusion_test_func.input_types.values()))

        @cache.dynamic_kernel(suffix=f"{self.geometry.name}_{cell_inclusion_test_func.key}")
        def count_sides(
            geo_arg: self.geometry.SideArg,
            cell_arg_value: cell_arg_type,
            partition_side_mask: wp.array(dtype=int),
            boundary_side_mask: wp.array(dtype=int),
            frontier_side_mask: wp.array(dtype=int),
        ):
            side_index = wp.tid()
            inner_cell_index = self.geometry.side_inner_cell_index(geo_arg, side_index)
            outer_cell_index = self.geometry.side_outer_cell_index(geo_arg, side_index)

            inner_in = cell_inclusion_test_func(cell_arg_value, inner_cell_index)
            outer_in = cell_inclusion_test_func(cell_arg_value, outer_cell_index)

            if inner_in:
                # Inner neighbor in partition; count as partition side
                partition_side_mask[side_index] = 1

                # Inner and outer elements are the same -- this is a boundary side
                if inner_cell_index == outer_cell_index:
                    boundary_side_mask[side_index] = 1

            if inner_in != outer_in:
                # Exactly one neighbor in partition; count as frontier side
                frontier_side_mask[side_index] = 1

        # One int flag per geometry side for each category
        geo_side_count = self.geometry.side_count()
        partition_side_mask = borrow_temporary(
            temporary_store,
            shape=(geo_side_count,),
            dtype=int,
            device=device,
        )
        boundary_side_mask = borrow_temporary(
            temporary_store,
            shape=(geo_side_count,),
            dtype=int,
            device=device,
        )
        frontier_side_mask = borrow_temporary(
            temporary_store,
            shape=(geo_side_count,),
            dtype=int,
            device=device,
        )

        # Release the borrowed masks even if the launch or the compaction fails
        try:
            partition_side_mask.array.zero_()
            boundary_side_mask.array.zero_()
            frontier_side_mask.array.zero_()

            wp.launch(
                dim=partition_side_mask.array.shape[0],
                kernel=count_sides,
                inputs=[
                    self.geometry.side_arg_value(device),
                    cell_arg_value,
                    partition_side_mask.array,
                    boundary_side_mask.array,
                    frontier_side_mask.array,
                ],
                device=device,
            )

            # Compact each flag array into the list of flagged side indices
            self._partition_side_indices, _ = masked_indices(partition_side_mask.array, temporary_store=temporary_store)
            self._boundary_side_indices, _ = masked_indices(boundary_side_mask.array, temporary_store=temporary_store)
            self._frontier_side_indices, _ = masked_indices(frontier_side_mask.array, temporary_store=temporary_store)
        finally:
            partition_side_mask.release()
            boundary_side_mask.release()
            frontier_side_mask.release()
260
+
261
+
262
class LinearGeometryPartition(CellBasedGeometryPartition):
    """Partition owning the contiguous range of geometry cells ``[cell_begin, cell_end)``"""

    def __init__(
        self,
        geometry: Geometry,
        partition_rank: int,
        partition_count: int,
        device=None,
        temporary_store: TemporaryStore = None,
    ):
        """Creates a geometry partition by uniformly partitioning cell indices

        Args:
            geometry: the geometry to partition
            partition_rank: the index of the partition being created
            partition_count: the number of partitions that will be created over the geometry
            device: Warp device on which to perform and store computations
            temporary_store: optional store from which temporary arrays are borrowed
        """
        super().__init__(geometry)

        total_cell_count = geometry.cell_count()

        # Ceil division, so that ``partition_count`` ranges cover every cell
        cells_per_partition = (total_cell_count + partition_count - 1) // partition_count
        # Clamp the range start: with ceil division, trailing ranks may start past the last cell
        # (e.g. 5 cells over 4 partitions gives rank 3 a start of 6), which would otherwise yield
        # cell_begin > cell_end and a negative cell_count(). Such ranks now get an empty range.
        self.cell_begin = min(cells_per_partition * partition_rank, total_cell_count)
        self.cell_end = min(self.cell_begin + cells_per_partition, total_cell_count)

        super().compute_side_indices_from_cells(
            self.cell_arg_value(device),
            LinearGeometryPartition._cell_inclusion_test,
            device,
            temporary_store=temporary_store,
        )

    def cell_count(self) -> int:
        return self.cell_end - self.cell_begin

    @wp.struct
    class CellArg:
        # First geometry cell index owned by the partition
        cell_begin: int
        # One past the last owned geometry cell index (exclusive bound)
        cell_end: int

    def cell_arg_value(self, device):
        arg = LinearGeometryPartition.CellArg()
        arg.cell_begin = self.cell_begin
        arg.cell_end = self.cell_end
        return arg

    @wp.func
    def cell_index(args: CellArg, partition_cell_index: int):
        """Partition cell to cell index"""
        return args.cell_begin + partition_cell_index

    @wp.func
    def partition_cell_index(args: CellArg, cell_index: int):
        """Cell to partition cell index (or ``NULL_ELEMENT_INDEX`` if the cell is outside the partition)"""
        # cell_end is exclusive, so cell_end itself is already outside the partition
        if cell_index >= args.cell_end:
            return NULL_ELEMENT_INDEX

        partition_cell_index = cell_index - args.cell_begin
        if partition_cell_index < 0:
            return NULL_ELEMENT_INDEX

        return partition_cell_index

    @wp.func
    def _cell_inclusion_test(arg: CellArg, cell_index: int):
        return cell_index >= arg.cell_begin and cell_index < arg.cell_end
328
+
329
+
330
class ExplicitGeometryPartition(CellBasedGeometryPartition):
    """Partition owning an explicitly selected subset of geometry cells, given as a per-cell mask"""

    def __init__(self, geometry: Geometry, cell_mask: "wp.array(dtype=int)", temporary_store: TemporaryStore = None):
        """Creates a geometry partition from an explicit cell selection mask

        Computations are performed and results stored on ``cell_mask.device``.

        Args:
            geometry: the geometry to partition
            cell_mask: warp array of length ``geometry.cell_count()`` indicating which cells are selected. Array values must be either ``1`` (selected) or ``0`` (not selected).
            temporary_store: optional store from which temporary arrays are borrowed
        """

        super().__init__(geometry)

        self._cell_mask = cell_mask
        # _cells: partition cell -> geometry cell; _partition_cells: geometry cell -> partition cell
        self._cells, self._partition_cells = masked_indices(self._cell_mask, temporary_store=temporary_store)

        # The mask itself is the argument of the inclusion test
        super().compute_side_indices_from_cells(
            self._cell_mask,
            ExplicitGeometryPartition._cell_inclusion_test,
            self._cell_mask.device,
            temporary_store=temporary_store,
        )

    def cell_count(self) -> int:
        return self._cells.array.shape[0]

    @wp.struct
    class CellArg:
        # Geometry cell index of each partition cell
        cell_index: wp.array(dtype=int)
        # Partition cell index of each geometry cell
        partition_cell_index: wp.array(dtype=int)

    @cached_arg_value
    def cell_arg_value(self, device):
        arg = ExplicitGeometryPartition.CellArg()
        arg.cell_index = self._cells.array.to(device)
        arg.partition_cell_index = self._partition_cells.array.to(device)
        return arg

    @wp.func
    def cell_index(args: CellArg, partition_cell_index: int):
        """Partition cell to cell index"""
        return args.cell_index[partition_cell_index]

    @wp.func
    def partition_cell_index(args: CellArg, cell_index: int):
        """Cell to partition cell index"""
        # NOTE(review): value for unselected cells is whatever masked_indices stores -- confirm it is NULL_ELEMENT_INDEX
        return args.partition_cell_index[cell_index]

    @wp.func
    def _cell_inclusion_test(mask: wp.array(dtype=int), cell_index: int):
        return mask[cell_index] > 0