warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,350 @@
1
+ from typing import Any, Optional, Union
2
+
3
+ import warp as wp
4
+ from warp.fem.cache import (
5
+ TemporaryStore,
6
+ borrow_temporary,
7
+ borrow_temporary_like,
8
+ cached_arg_value,
9
+ )
10
+ from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
11
+ from warp.fem.types import NULL_NODE_INDEX
12
+ from warp.fem.utils import _iota_kernel, compress_node_indices
13
+
14
+ from .function_space import FunctionSpace
15
+ from .topology import SpaceTopology
16
+
17
# Module-wide Warp option: do not generate backward (adjoint) code for kernels/functions declared here.
wp.set_module_options(dict(enable_backward=False))
18
+
19
+
20
class SpacePartition:
    """Base class for the subset of a function space's nodes associated with a geometry partition.

    Concrete partitions (such as ``WholeSpacePartition`` and ``NodePartition`` in this module) must override
    the node-count queries, :meth:`space_node_indices`, :meth:`partition_arg_value` and
    :meth:`partition_node_index`; the base implementations raise ``NotImplementedError``.
    """

    class PartitionArg:
        """Placeholder for the kernel-side argument type; concrete partitions declare their own ``wp.struct``."""

        pass

    def __init__(self, space_topology: SpaceTopology, geo_partition: GeometryPartition):
        self.space_topology = space_topology
        self.geo_partition = geo_partition

    def node_count(self) -> int:
        """Returns number of nodes in this partition"""
        raise NotImplementedError

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo"""
        raise NotImplementedError

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition"""
        raise NotImplementedError

    def space_node_indices(self) -> wp.array:
        """Return the global function space indices for nodes in this partition"""
        raise NotImplementedError

    def partition_arg_value(self, device):
        """Returns the :class:`PartitionArg` value to pass to kernels running on `device`"""
        raise NotImplementedError

    @staticmethod
    def partition_node_index(args: "PartitionArg", space_node_index: int):
        """Returns the index in the partition of a function space node, or ``NULL_NODE_INDEX`` if it does not exist"""
        raise NotImplementedError

    def __str__(self) -> str:
        return self.name

    @property
    def name(self) -> str:
        return f"{self.__class__.__name__}"
53
+
54
+
55
class WholeSpacePartition(SpacePartition):
    """Partition containing every node of the space topology.

    Partition-local node indices are identical to the function space node indices, so no per-node data is
    needed on the kernel side.
    """

    @wp.struct
    class PartitionArg:
        # Empty: the space-to-partition mapping is the identity.
        pass

    def __init__(self, space_topology: SpaceTopology):
        super().__init__(space_topology, WholeGeometryPartition(space_topology.geometry))
        # Identity index array, allocated on first call to space_node_indices()
        self._node_indices = None

    def node_count(self):
        """Returns number of nodes in this partition (all nodes of the space topology)"""
        return self.space_topology.node_count()

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo (all nodes of the topology)"""
        return self.space_topology.node_count()

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition (all nodes of the topology)"""
        return self.space_topology.node_count()

    def space_node_indices(self):
        """Return the global function space indices for nodes in this partition.

        The array is created lazily on the default device, with no temporary store, and reused afterwards.
        """
        if self._node_indices is not None:
            return self._node_indices.array

        total = self.node_count()
        self._node_indices = borrow_temporary(temporary_store=None, shape=(total,), dtype=int)
        # _iota_kernel with a step of 1 -- presumably writes 0, 1, ..., total - 1
        wp.launch(kernel=_iota_kernel, dim=total, inputs=[self._node_indices.array, 1])
        return self._node_indices.array

    def partition_arg_value(self, device):
        return WholeSpacePartition.PartitionArg()

    # Identity mapping: a space node index is also its partition index
    @wp.func
    def partition_node_index(args: Any, space_node_index: int):
        return space_node_index

    def __eq__(self, other: SpacePartition) -> bool:
        # NOTE(review): equal to *any* SpacePartition sharing the same topology (a NodePartition included),
        # whereas NodePartition falls back to identity comparison, so the relation is asymmetric.
        # Defining __eq__ without __hash__ also makes instances unhashable.
        if not isinstance(other, SpacePartition):
            return False
        return self.space_topology == other.space_topology

    @property
    def name(self) -> str:
        return "Whole"
96
+
97
+
98
class NodeCategory:
    """Classification of function space nodes relative to a geometry partition.

    Used by ``NodePartition``, which groups nodes by category in increasing value order: owned nodes first,
    then halo nodes, while ``EXTERIOR`` nodes are left out of the partition.
    """

    OWNED_INTERIOR = wp.constant(0)
    """Node belongs to a cell of this partition and is not touched by any frontier side"""
    OWNED_FRONTIER = wp.constant(1)
    """Node belongs to a cell of this partition and is also touched by a frontier side"""
    HALO_LOCAL_SIDE = wp.constant(2)
    """Node does not belong to a cell of this partition, but is touched by one of the partition's own sides"""
    HALO_OTHER_SIDE = wp.constant(3)
    """Node does not belong to a cell of this partition and is touched only by a frontier side, not by the partition's own sides"""
    EXTERIOR = wp.constant(4)
    """Node is never referenced by this partition"""

    COUNT = 5
    """Number of node categories"""
111
+
112
+
113
class NodePartition(SpacePartition):
    """Space partition restricted to the nodes referenced by the cells (and optionally sides) of a geometry partition.

    Every space node is assigned a :class:`NodeCategory`, and nodes are grouped by category (via
    ``compress_node_indices``). The partition keeps the ``OWNED_INTERIOR``, ``OWNED_FRONTIER``,
    ``HALO_LOCAL_SIDE`` and ``HALO_OTHER_SIDE`` groups, in that order, and drops ``EXTERIOR`` nodes.
    """

    @wp.struct
    class PartitionArg:
        space_to_partition: wp.array(dtype=int)

    def __init__(
        self,
        space_topology: SpaceTopology,
        geo_partition: GeometryPartition,
        with_halo: bool = True,
        device=None,
        temporary_store: TemporaryStore = None,
    ):
        super().__init__(space_topology=space_topology, geo_partition=geo_partition)

        self._compute_node_indices_from_sides(device, with_halo, temporary_store)

    def _category_end(self, category: int) -> int:
        """Returns the partition-local index one past the last node of `category`, read from the host-side offsets"""
        return int(self._category_offsets.array.numpy()[category + 1])

    def node_count(self) -> int:
        """Returns number of nodes referenced by this partition, including exterior halo"""
        return self._category_end(NodeCategory.HALO_OTHER_SIDE)

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo"""
        return self._category_end(NodeCategory.OWNED_FRONTIER)

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition"""
        return self._category_end(NodeCategory.OWNED_INTERIOR)

    def space_node_indices(self):
        """Return the global function space indices for nodes in this partition"""
        return self._node_indices.array

    @cached_arg_value
    def partition_arg_value(self, device):
        arg = NodePartition.PartitionArg()
        arg.space_to_partition = self._space_to_partition.array.to(device)
        return arg

    # Maps a space node index to its partition-local index (NULL_NODE_INDEX for nodes outside the partition)
    @wp.func
    def partition_node_index(args: PartitionArg, space_node_index: int):
        return args.space_to_partition[space_node_index]

    def _compute_node_indices_from_sides(self, device, with_halo: bool, temporary_store: TemporaryStore):
        """Classifies every space node into a NodeCategory, then builds the partition index arrays from it.

        Cells of the geometry partition mark their nodes OWNED_INTERIOR. When `with_halo` is set, the
        partition's sides then mark untouched nodes HALO_LOCAL_SIDE, and frontier sides mark remaining
        untouched nodes HALO_OTHER_SIDE and promote OWNED_INTERIOR nodes to OWNED_FRONTIER.
        """
        from warp.fem import cache

        trace_topology = self.space_topology.trace()
        NODES_PER_CELL = self.space_topology.NODES_PER_ELEMENT
        NODES_PER_SIDE = trace_topology.NODES_PER_ELEMENT

        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_cells_kernel(
            geo_arg: self.geo_partition.geometry.CellArg,
            geo_partition_arg: self.geo_partition.CellArg,
            space_arg: self.space_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            partition_cell_index = wp.tid()

            cell_index = self.geo_partition.cell_index(geo_partition_arg, partition_cell_index)

            for n in range(NODES_PER_CELL):
                space_nidx = self.space_topology.element_node_index(geo_arg, space_arg, cell_index, n)
                node_mask[space_nidx] = NodeCategory.OWNED_INTERIOR

        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_owned_sides_kernel(
            geo_arg: self.geo_partition.geometry.SideArg,
            geo_partition_arg: self.geo_partition.SideArg,
            space_arg: trace_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            partition_side_index = wp.tid()

            side_index = self.geo_partition.side_index(geo_partition_arg, partition_side_index)

            for n in range(NODES_PER_SIDE):
                space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)

                if node_mask[space_nidx] == NodeCategory.EXTERIOR:
                    node_mask[space_nidx] = NodeCategory.HALO_LOCAL_SIDE

        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_frontier_sides_kernel(
            geo_arg: self.geo_partition.geometry.SideArg,
            geo_partition_arg: self.geo_partition.SideArg,
            space_arg: trace_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            frontier_side_index = wp.tid()

            side_index = self.geo_partition.frontier_side_index(geo_partition_arg, frontier_side_index)

            for n in range(NODES_PER_SIDE):
                space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
                if node_mask[space_nidx] == NodeCategory.EXTERIOR:
                    node_mask[space_nidx] = NodeCategory.HALO_OTHER_SIDE
                elif node_mask[space_nidx] == NodeCategory.OWNED_INTERIOR:
                    node_mask[space_nidx] = NodeCategory.OWNED_FRONTIER

        node_category = borrow_temporary(
            temporary_store,
            shape=(self.space_topology.node_count(),),
            dtype=int,
            device=device,
        )

        # Always hand the temporary back, even if a launch or the finalization step fails
        try:
            node_category.array.fill_(value=NodeCategory.EXTERIOR)

            wp.launch(
                dim=self.geo_partition.cell_count(),
                kernel=node_category_from_cells_kernel,
                inputs=[
                    self.geo_partition.geometry.cell_arg_value(device),
                    self.geo_partition.cell_arg_value(device),
                    self.space_topology.topo_arg_value(device),
                    node_category.array,
                ],
                device=device,
            )

            if with_halo:
                # Owned sides must run first: they claim EXTERIOR nodes before frontier sides see them
                wp.launch(
                    dim=self.geo_partition.side_count(),
                    kernel=node_category_from_owned_sides_kernel,
                    inputs=[
                        self.geo_partition.geometry.side_arg_value(device),
                        self.geo_partition.side_arg_value(device),
                        self.space_topology.topo_arg_value(device),
                        node_category.array,
                    ],
                    device=device,
                )

                wp.launch(
                    dim=self.geo_partition.frontier_side_count(),
                    kernel=node_category_from_frontier_sides_kernel,
                    inputs=[
                        self.geo_partition.geometry.side_arg_value(device),
                        self.geo_partition.side_arg_value(device),
                        self.space_topology.topo_arg_value(device),
                        node_category.array,
                    ],
                    device=device,
                )

            self._finalize_node_indices(node_category.array, temporary_store)
        finally:
            node_category.release()

    def _finalize_node_indices(self, node_category: wp.array(dtype=int), temporary_store: TemporaryStore):
        """Builds the category offsets, the space-to-partition map and the compact partition node list"""
        category_offsets, node_indices, _, __ = compress_node_indices(NodeCategory.COUNT, node_category)

        # Copy offsets to cpu so that node counts can be queried from the host
        device = node_category.device
        self._category_offsets = borrow_temporary(
            temporary_store,
            shape=category_offsets.array.shape,
            dtype=category_offsets.array.dtype,
            pinned=device.is_cuda,
            device="cpu",
        )
        wp.copy(src=category_offsets.array, dest=self._category_offsets.array)

        if device.is_cuda:
            # Wait for the device-to-host copy before offsets are read on the host
            # TODO switch to synchronize_event once available
            wp.synchronize_stream(wp.get_stream(device))

        category_offsets.release()

        # Host-side read of the offsets; hoisted since it is used three times below
        partition_node_count = self.node_count()

        # Compute global to local indices
        self._space_to_partition = borrow_temporary_like(node_indices, temporary_store)
        wp.launch(
            kernel=NodePartition._scatter_partition_indices,
            dim=self.space_topology.node_count(),
            device=device,
            inputs=[partition_node_count, node_indices.array, self._space_to_partition.array],
        )

        # Copy to shrunk-to-fit array
        self._node_indices = borrow_temporary(
            temporary_store, shape=(partition_node_count,), dtype=int, device=device
        )
        wp.copy(dest=self._node_indices.array, src=node_indices.array, count=partition_node_count)

        node_indices.release()

    # For each category-sorted position, record the partition-local index of the corresponding space node;
    # positions past `local_node_count` (EXTERIOR nodes) map to NULL_NODE_INDEX.
    @wp.kernel
    def _scatter_partition_indices(
        local_node_count: int,
        node_indices: wp.array(dtype=int),
        space_to_partition_indices: wp.array(dtype=int),
    ):
        local_idx = wp.tid()
        space_idx = node_indices[local_idx]

        if local_idx < local_node_count:
            space_to_partition_indices[space_idx] = local_idx
        else:
            space_to_partition_indices[space_idx] = NULL_NODE_INDEX
310
+
311
+
312
def make_space_partition(
    space: Optional[FunctionSpace] = None,
    geometry_partition: Optional[GeometryPartition] = None,
    space_topology: Optional[SpaceTopology] = None,
    with_halo: bool = True,
    device=None,
    temporary_store: TemporaryStore = None,
) -> SpacePartition:
    """Computes the subset of nodes from a function space topology that touch a geometry partition

    Either `space_topology` or `space` must be provided (and will be considered in that order).

    Args:
        space: (deprecated) the function space defining the topology if `space_topology` is ``None``.
        geometry_partition: The subset of the space geometry. If not provided, use the whole geometry.
        space_topology: the topology of the function space to consider. If ``None``, deduced from `space`.
        with_halo: if True, include the halo nodes (nodes from exterior frontier cells to the partition)
        device: Warp device on which to perform and store computations
        temporary_store: pool from which temporary and result arrays are borrowed (``None`` for plain allocations)

    Returns:
        the resulting space partition. When `geometry_partition` is ``None`` or covers every cell of its
        geometry, a ``WholeSpacePartition`` is returned and `with_halo`, `device` and `temporary_store`
        are not used.

    Raises:
        ValueError: if neither `space_topology` nor `space` is provided.
    """

    if space_topology is None:
        if space is None:
            raise ValueError("Either `space_topology` or `space` must be provided")
        space_topology = space.topology

    space_topology = space_topology.full_space_topology()

    # Only build an explicit node partition for a strict subset of the cells; otherwise the
    # whole-space partition is equivalent and skips the node classification work.
    if geometry_partition is not None and geometry_partition.cell_count() < geometry_partition.geometry.cell_count():
        return NodePartition(
            space_topology=space_topology,
            geo_partition=geometry_partition,
            with_halo=with_halo,
            device=device,
            temporary_store=temporary_store,
        )

    return WholeSpacePartition(space_topology)