warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,350 @@
1
+ from typing import Any, Optional, Union
2
+
3
+ import warp as wp
4
+ from warp.fem.cache import (
5
+ TemporaryStore,
6
+ borrow_temporary,
7
+ borrow_temporary_like,
8
+ cached_arg_value,
9
+ )
10
+ from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
11
+ from warp.fem.types import NULL_NODE_INDEX
12
+ from warp.fem.utils import _iota_kernel, compress_node_indices
13
+
14
+ from .function_space import FunctionSpace
15
+ from .topology import SpaceTopology
16
+
17
+ wp.set_module_options({"enable_backward": False})
18
+
19
+
20
class SpacePartition:
    """Base interface tying a function space topology to a geometry partition.

    The query methods declared here have no implementation (they implicitly
    return ``None``); the concrete partitions defined in this module override them.
    """

    class PartitionArg:
        """Placeholder for the argument type consumed by :meth:`partition_node_index`."""

    def __init__(self, space_topology: SpaceTopology, geo_partition: GeometryPartition):
        # Keep both halves of the partition definition available to subclasses
        self.geo_partition = geo_partition
        self.space_topology = space_topology

    def node_count(self):
        """Number of nodes in this partition."""

    def owned_node_count(self) -> int:
        """Number of nodes in this partition, not counting the exterior halo."""

    def interior_node_count(self) -> int:
        """Number of interior nodes in this partition."""

    def space_node_indices(self) -> wp.array:
        """Global function space indices of the nodes in this partition."""

    def partition_arg_value(self, device):
        """Value of the partition argument for `device`; not implemented in the base class."""

    @staticmethod
    def partition_node_index(args: "PartitionArg", space_node_index: int):
        """Index in the partition of a function space node, or -1 if it does not exist."""

    @property
    def name(self) -> str:
        """Display name of the partition: the name of its concrete class."""
        return self.__class__.__name__

    def __str__(self) -> str:
        return self.name
53
+
54
+
55
class WholeSpacePartition(SpacePartition):
    """Partition covering every node of a space topology.

    It is built over a :class:`WholeGeometryPartition` of the topology's geometry.
    All node counts equal the node count of the topology, and partition node
    indices are the space node indices themselves.
    """

    @wp.struct
    class PartitionArg:
        pass

    def __init__(self, space_topology: SpaceTopology):
        whole_geometry = WholeGeometryPartition(space_topology.geometry)
        super().__init__(space_topology, whole_geometry)

        # Index array is only created on the first call to space_node_indices()
        self._node_indices = None

    def node_count(self):
        """Number of nodes in this partition (all nodes of the topology)."""
        return self.space_topology.node_count()

    def owned_node_count(self) -> int:
        """Number of nodes excluding the exterior halo (all nodes of the topology)."""
        return self.space_topology.node_count()

    def interior_node_count(self) -> int:
        """Number of interior nodes (all nodes of the topology)."""
        return self.space_topology.node_count()

    def space_node_indices(self):
        """Global function space indices of the nodes in this partition.

        The array is borrowed and filled by ``_iota_kernel`` on first use, then
        reused by later calls.
        """
        if self._node_indices is not None:
            return self._node_indices.array

        self._node_indices = borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
        wp.launch(
            kernel=_iota_kernel,
            dim=self.node_count(),
            inputs=[self._node_indices.array, 1],
        )
        return self._node_indices.array

    def partition_arg_value(self, device):
        """Return a fresh (field-less) ``PartitionArg``; `device` is not used."""
        return WholeSpacePartition.PartitionArg()

    @wp.func
    def partition_node_index(args: Any, space_node_index: int):
        return space_node_index

    def __eq__(self, other: SpacePartition) -> bool:
        """Equal to any :class:`SpacePartition` whose space topology compares equal."""
        if not isinstance(other, SpacePartition):
            return False
        return self.space_topology == other.space_topology

    @property
    def name(self) -> str:
        """Fixed display name of the whole-space partition."""
        return "Whole"
96
+
97
+
98
class NodeCategory:
    """Integer categories assigned to space nodes while a :class:`NodePartition` is built."""

    #: Node is touched exclusively by this partition, not touched by frontier side
    OWNED_INTERIOR = wp.constant(0)

    #: Node is touched by a frontier side, but belongs to an element of this partition
    OWNED_FRONTIER = wp.constant(1)

    #: Node belongs to an element of another partition, but is touched by one of our frontier side
    HALO_LOCAL_SIDE = wp.constant(2)

    #: Node belongs to an element of another partition, and is not touched by one of our frontier side
    HALO_OTHER_SIDE = wp.constant(3)

    #: Node is never referenced by this partition
    EXTERIOR = wp.constant(4)

    #: Total number of categories listed above
    COUNT = 5
111
+
112
+
113
class NodePartition(SpacePartition):
    """Space partition restricted to the nodes referenced by a geometry partition.

    On construction every node of the space topology is tagged with a
    :class:`NodeCategory`. The tags are then compressed into per-category
    offsets (kept on the host) and a node index array, from which a
    space-to-partition index map is derived.
    """

    @wp.struct
    class PartitionArg:
        space_to_partition: wp.array(dtype=int)

    def __init__(
        self,
        space_topology: SpaceTopology,
        geo_partition: GeometryPartition,
        with_halo: bool = True,
        device=None,
        temporary_store: TemporaryStore = None,
    ):
        super().__init__(space_topology=space_topology, geo_partition=geo_partition)

        self._compute_node_indices_from_sides(device, with_halo, temporary_store)

    def _category_end(self, last_category) -> int:
        """Read the host-side offset stored just past `last_category`."""
        return int(self._category_offsets.array.numpy()[last_category + 1])

    def node_count(self) -> int:
        """Number of nodes referenced by this partition, including exterior halo."""
        return self._category_end(NodeCategory.HALO_OTHER_SIDE)

    def owned_node_count(self) -> int:
        """Number of nodes in this partition, excluding exterior halo."""
        return self._category_end(NodeCategory.OWNED_FRONTIER)

    def interior_node_count(self) -> int:
        """Number of interior nodes in this partition."""
        return self._category_end(NodeCategory.OWNED_INTERIOR)

    def space_node_indices(self):
        """Global function space indices of the nodes in this partition."""
        return self._node_indices.array

    @cached_arg_value
    def partition_arg_value(self, device):
        """Build the ``PartitionArg`` holding the space-to-partition map moved to `device`."""
        partition_arg = NodePartition.PartitionArg()
        partition_arg.space_to_partition = self._space_to_partition.array.to(device)
        return partition_arg

    @wp.func
    def partition_node_index(args: PartitionArg, space_node_index: int):
        return args.space_to_partition[space_node_index]

    def _compute_node_indices_from_sides(self, device, with_halo: bool, temporary_store: TemporaryStore):
        """Tag every space node with a :class:`NodeCategory`, then finalize the index arrays.

        Cells of the geometry partition mark their nodes as owned-interior. When
        `with_halo` is set, the partition's sides mark still-exterior nodes as
        local-side halo. Frontier sides are always processed last: they turn
        exterior nodes into other-side halo and owned-interior nodes into
        owned-frontier.
        """
        from warp.fem import cache

        trace_topology = self.space_topology.trace()
        NODES_PER_CELL = self.space_topology.NODES_PER_ELEMENT
        NODES_PER_SIDE = trace_topology.NODES_PER_ELEMENT

        # All three dynamic kernels share the same name suffix
        kernel_suffix = f"{self.geo_partition.name}_{self.space_topology.name}"

        @cache.dynamic_kernel(suffix=kernel_suffix)
        def node_category_from_cells_kernel(
            geo_arg: self.geo_partition.geometry.CellArg,
            geo_partition_arg: self.geo_partition.CellArg,
            space_arg: self.space_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            partition_cell_index = wp.tid()

            cell_index = self.geo_partition.cell_index(geo_partition_arg, partition_cell_index)

            for n in range(NODES_PER_CELL):
                space_nidx = self.space_topology.element_node_index(geo_arg, space_arg, cell_index, n)
                node_mask[space_nidx] = NodeCategory.OWNED_INTERIOR

        @cache.dynamic_kernel(suffix=kernel_suffix)
        def node_category_from_owned_sides_kernel(
            geo_arg: self.geo_partition.geometry.SideArg,
            geo_partition_arg: self.geo_partition.SideArg,
            space_arg: trace_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            partition_side_index = wp.tid()

            side_index = self.geo_partition.side_index(geo_partition_arg, partition_side_index)

            for n in range(NODES_PER_SIDE):
                space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)

                if node_mask[space_nidx] == NodeCategory.EXTERIOR:
                    node_mask[space_nidx] = NodeCategory.HALO_LOCAL_SIDE

        @cache.dynamic_kernel(suffix=kernel_suffix)
        def node_category_from_frontier_sides_kernel(
            geo_arg: self.geo_partition.geometry.SideArg,
            geo_partition_arg: self.geo_partition.SideArg,
            space_arg: trace_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            frontier_side_index = wp.tid()

            side_index = self.geo_partition.frontier_side_index(geo_partition_arg, frontier_side_index)

            for n in range(NODES_PER_SIDE):
                space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
                if node_mask[space_nidx] == NodeCategory.EXTERIOR:
                    node_mask[space_nidx] = NodeCategory.HALO_OTHER_SIDE
                elif node_mask[space_nidx] == NodeCategory.OWNED_INTERIOR:
                    node_mask[space_nidx] = NodeCategory.OWNED_FRONTIER

        partition = self.geo_partition
        topology = self.space_topology

        # One category per node of the full space topology, all exterior to begin with
        node_category = borrow_temporary(
            temporary_store,
            shape=(topology.node_count(),),
            dtype=int,
            device=device,
        )
        node_category.array.fill_(value=NodeCategory.EXTERIOR)

        def side_kernel_inputs():
            # Inputs shared by the two side-based kernels, rebuilt for each launch
            return [
                partition.geometry.side_arg_value(device),
                partition.side_arg_value(device),
                topology.topo_arg_value(device),
                node_category.array,
            ]

        wp.launch(
            dim=partition.cell_count(),
            kernel=node_category_from_cells_kernel,
            inputs=[
                partition.geometry.cell_arg_value(device),
                partition.cell_arg_value(device),
                topology.topo_arg_value(device),
                node_category.array,
            ],
            device=device,
        )

        if with_halo:
            wp.launch(
                dim=partition.side_count(),
                kernel=node_category_from_owned_sides_kernel,
                inputs=side_kernel_inputs(),
                device=device,
            )

        wp.launch(
            dim=partition.frontier_side_count(),
            kernel=node_category_from_frontier_sides_kernel,
            inputs=side_kernel_inputs(),
            device=device,
        )

        self._finalize_node_indices(node_category.array, temporary_store)

        node_category.release()

    def _finalize_node_indices(self, node_category: wp.array(dtype=int), temporary_store: TemporaryStore):
        """Turn per-node categories into category offsets, partition node indices and the reverse map.

        Stores ``_category_offsets`` (host copy), ``_space_to_partition`` and
        ``_node_indices`` on the instance.
        """
        category_offsets, node_indices, _, __ = compress_node_indices(NodeCategory.COUNT, node_category)

        device = node_category.device

        # Keep a host copy of the offsets (pinned when the source lives on a CUDA device)
        device_offsets = category_offsets.array
        self._category_offsets = borrow_temporary(
            temporary_store,
            shape=device_offsets.shape,
            dtype=device_offsets.dtype,
            pinned=device.is_cuda,
            device="cpu",
        )
        wp.copy(src=device_offsets, dest=self._category_offsets.array)

        if device.is_cuda:
            # TODO switch to synchronize_event once available
            wp.synchronize_stream(wp.get_stream(device))

        category_offsets.release()

        # Map global (space) indices to local (partition) indices
        self._space_to_partition = borrow_temporary_like(node_indices, temporary_store)
        partition_node_count = self.node_count()
        wp.launch(
            kernel=NodePartition._scatter_partition_indices,
            dim=self.space_topology.node_count(),
            device=device,
            inputs=[partition_node_count, node_indices.array, self._space_to_partition.array],
        )

        # Keep only the leading `partition_node_count` entries in a shrunk-to-fit array
        self._node_indices = borrow_temporary(temporary_store, shape=partition_node_count, dtype=int, device=device)
        wp.copy(dest=self._node_indices.array, src=node_indices.array, count=partition_node_count)

        node_indices.release()

    @wp.kernel
    def _scatter_partition_indices(
        local_node_count: int,
        node_indices: wp.array(dtype=int),
        space_to_partition_indices: wp.array(dtype=int),
    ):
        local_idx = wp.tid()
        space_idx = node_indices[local_idx]

        if local_idx < local_node_count:
            space_to_partition_indices[space_idx] = local_idx
        else:
            space_to_partition_indices[space_idx] = NULL_NODE_INDEX
310
+
311
+
312
def make_space_partition(
    space: Optional[FunctionSpace] = None,
    geometry_partition: Optional[GeometryPartition] = None,
    space_topology: Optional[SpaceTopology] = None,
    with_halo: bool = True,
    device=None,
    temporary_store: TemporaryStore = None,
) -> SpacePartition:
    """Computes the subset of nodes from a function space topology that touch a geometry partition

    Either `space_topology` or `space` must be provided (and will be considered in that order).

    A :class:`NodePartition` is only built when `geometry_partition` is given and has
    fewer cells than its geometry; otherwise a :class:`WholeSpacePartition` is returned
    and `with_halo`, `device` and `temporary_store` are not used.

    Args:
        space: (deprecated) the function space defining the topology if `space_topology` is ``None``.
        geometry_partition: The subset of the space geometry. If not provided, use the whole geometry.
        space_topology: the topology of the function space to consider. If ``None``, deduced from `space`.
        with_halo: if True, include the halo nodes (nodes from exterior frontier cells to the partition)
        device: Warp device on which to perform and store computations
        temporary_store: forwarded to :class:`NodePartition`, which borrows its temporary arrays from it

    Returns:
        the resulting space partition

    Raises:
        ValueError: if neither `space_topology` nor `space` is provided.
    """

    if space_topology is None:
        if space is None:
            # Fail early with a clear message instead of an AttributeError on `None.topology`
            raise ValueError("Either `space_topology` or `space` must be provided")
        space_topology = space.topology

    space_topology = space_topology.full_space_topology()

    # Only a strict subset of the geometry's cells needs an explicit node partition
    is_strict_subset = (
        geometry_partition is not None
        and geometry_partition.cell_count() < geometry_partition.geometry.cell_count()
    )
    if is_strict_subset:
        return NodePartition(
            space_topology=space_topology,
            geo_partition=geometry_partition,
            with_halo=with_halo,
            device=device,
            temporary_store=temporary_store,
        )

    return WholeSpacePartition(space_topology)