nvfuser_cu121_torch25-0.2.25.dev20250201-cp310-cp310-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,771 @@
+ // clang-format off
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+ // clang-format on
+ #pragma once
+
+ #include <device_lower/pass/loop_rotation.h>
+ #include <disjoint_set.h>
+ #include <exceptions.h>
+ #include <fusion.h>
+ #include <ir/all_nodes.h>
+ #include <ir/cloner.h>
+ #include <scheduler/reduction_heuristic.h>
+ #include <scheduler/tools/maxinfo_propagator.h>
+ #include <visibility.h>
+
+ namespace nvfuser {
+
+ class ComputeAtMap;
+ class SchedulerRuntimeInfo;
+ class HeuristicDataCache;
+
+ namespace scheduler_utils {
+
+ // Assume only half of the register file is available to spend on buffers;
+ // this is because when we allocate a buffer in registers it has to be
+ // accessed with a compile-time constant index. Unfortunately nvcc seems to
+ // use many registers for indexing. This is a rough estimate of the extra
+ // register use, but it's hard to get a better one.
+ constexpr int64_t register_file_size_full = (int64_t)256 * 1024;
+ constexpr int64_t register_file_size = register_file_size_full / 2;
+ constexpr int64_t register_file_size_56k = (int64_t)56 * 4 * 1024;
+
+ // Empirically observed number. Not guaranteed to be a good estimate
+ constexpr int64_t register_overhead = 40l;
+ constexpr int64_t max_registers_per_thread = 255l;
+ constexpr int64_t bytes_per_register = 4l;
+
+ constexpr int64_t x_grid_limit = ((int64_t)1 << (int64_t)31) - (int64_t)1;
+ constexpr int64_t y_grid_limit = 65535;
+ constexpr int64_t z_grid_limit = 65535;
+ constexpr int64_t z_block_limit = 64;
+
+ // Find largest power of 2 that is a factor of n. If n==0, return largest power
+ // of 2 representable by int64_t
+ constexpr int64_t maxVectorizationWidth(int64_t n) {
+   if (n == 0) {
+     // Max representable int has null sign bit then all ones. Shift right then
+     // xor to preserve only the most significant bit.
+     int64_t m = std::numeric_limits<int64_t>::max();
+     return m ^ (m >> 1);
+   }
+   // For example
+   //   n = b101101000
+   //   n - 1 = b101100111
+   //   ~ (n - 1) = b010011000
+   //   n & (~ (n - 1)) = b000001000
+   // The key is that subtracting one flips all trailing 0s as well as the least
+   // significant 1, so all of the other bits will fail the &, leaving
+   // only that 1.
+   return n & (~(n - 1));
+ }
+
+ // Largest power of 2 less than or equal to n
+ constexpr int64_t lastPow2(int64_t n) {
+   NVF_ERROR(n >= 0);
+   n |= (n >> 1);
+   n |= (n >> 2);
+   n |= (n >> 4);
+   n |= (n >> 8); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+   n |= (n >> 16); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+   n |= (n >> 32); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+   return std::max((int64_t)1, n - (n >> 1));
+ }
+
+ // Round up to a multiple of 8 or to the next power of 2, whichever is smaller
+ constexpr int64_t roundUpPow2Or8(const int64_t x) {
+   auto round_up_pow2 = lastPow2(x);
+   if (round_up_pow2 < x) {
+     round_up_pow2 *= 2;
+   }
+   constexpr int64_t kEight = 8;
+   auto round_up_8 = x % kEight == 0 ? x : x + (kEight - x % kEight);
+   return std::min(round_up_8, round_up_pow2);
+ }
+
+ constexpr int64_t roundUpPow2(const int64_t x) {
+   auto round_up_pow2 = scheduler_utils::lastPow2(x);
+   if (round_up_pow2 < x) {
+     round_up_pow2 *= 2;
+   }
+   return round_up_pow2;
+ }
+
+ constexpr int64_t roundUpToN(const int64_t x, const int64_t n) {
+   return x % n == 0 ? x : x + (n - x % n);
+ }
+
+ // Divide x by y, but clamp the result to a minimum of 1
+ inline int64_t safeDiv(const int64_t x, const int64_t y) {
+   return std::max(x / y, (int64_t)1);
+ }
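+
+ // For illustration, a few sample values implied by the helpers above (a
+ // sketch of expected results, not an exhaustive specification):
+ //   maxVectorizationWidth(24) == 8   // largest power-of-2 factor of 24
+ //   lastPow2(24)              == 16  // largest power of 2 <= 24
+ //   roundUpPow2Or8(5)         == 8   // pow2 and multiple-of-8 both give 8
+ //   roundUpPow2Or8(20)        == 24  // multiple of 8 (24) < next pow2 (32)
+ //   roundUpToN(20, 6)         == 24
+ //   safeDiv(3, 8)             == 1   // clamped to a minimum of 1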
+
+ // Split the given dimensions in `to_split`. Also update the dimensions in
+ // `to_update` to the positions in the split tensor. Splitting one dimension
+ // multiple times is supported, and if this is the case, then the order of
+ // `to_split` matters. All given dimensions are positions before any split.
+ void splitDims(
+     TensorView* tv,
+     std::vector<std::pair<int64_t, int64_t>> to_split, // (dim, size)
+     std::vector<int64_t>& to_update);
+
+ inline void splitDims(
+     TensorView* tv,
+     std::vector<std::pair<int64_t, int64_t>> to_split) { // (dim, size)
+   std::vector<int64_t> unused;
+   splitDims(tv, std::move(to_split), unused);
+ }
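+
+ // Illustrative use of splitDims on a hypothetical tv with loop domain
+ // [I0, I1, I2] (a sketch, assuming TensorView's default inner split):
+ //   std::vector<int64_t> tracked{2};   // track the position of I2
+ //   splitDims(tv, {{0, 4}}, tracked);  // domain becomes [I0/4, 4, I1, I2]
+ //   // tracked is updated from {2} to {3}, since the split inserted a
+ //   // dimension in front of I2.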
+
+ // Merge all the given dimensions in `to_merge` into a single dimension. Also
+ // update the dimensions in `to_update` to the positions in the merged tensor.
+ // Returns the merged dimension. All given dimensions are positions before any
+ // merge.
+ // NOTE: the merge is done in the order of the entries in `to_merge`, assuming
+ // an order from inner to outer
+ std::optional<int64_t> mergeDims(
+     TensorView* tv,
+     std::vector<int64_t> to_merge,
+     std::vector<int64_t>& to_update);
+
+ inline std::optional<int64_t> mergeDims(
+     TensorView* tv,
+     std::vector<int64_t> to_merge) {
+   std::vector<int64_t> unused;
+   return mergeDims(tv, std::move(to_merge), unused);
+ }
+
+ // Merge all reduction axes to the right side and return the total number of
+ // reduction axes.
+ int64_t mergeReduction(TensorView* tv);
+
+ // Merge all non-reduction axes to the left side and return the total number of
+ // iteration axes.
+ int64_t mergeNonReduction(TensorView* tv);
+
+ // Propagate the parallelization from the selected dimensions of the reference
+ // tensor to their corresponding dimensions in all selected tensors in the DAG.
+ // Position `pos` means selecting all the dimensions [0, 1, ..., pos - 1]. pos =
+ // -1 means selecting all dimensions. `selected_tvs` are selected tensors in the
+ // DAG. Empty `selected_tvs` means selecting all tensors in the fusion of
+ // `reference_tv`. `selected_parallel_types` are the selected parallel types.
+ // Empty `selected_parallel_types` means selecting all parallel types.
+ void parallelizeAllLike(
+     TensorView* reference_tv,
+     int64_t pos = -1,
+     std::vector<TensorView*> selected_tvs = {},
+     const std::unordered_set<ParallelType>& selected_parallel_types = {},
+     bool propagate_padding = true);
+
+ inline void parallelizeAllLike(
+     TensorView* reference_tv,
+     std::vector<TensorView*> selected_tvs,
+     const std::unordered_set<ParallelType>& selected_parallel_types = {},
+     bool propagate_padding = true) {
+   parallelizeAllLike(
+       reference_tv,
+       -1,
+       std::move(selected_tvs),
+       selected_parallel_types,
+       propagate_padding);
+ }
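+
+ // Illustrative call sequence (hypothetical reference tensor):
+ //   reference_tv->axis(0)->parallelize(ParallelType::BIDx);
+ //   reference_tv->axis(1)->parallelize(ParallelType::TIDx);
+ //   // Propagate BIDx/TIDx from all dimensions of reference_tv to the
+ //   // corresponding dimensions of every tensor in its fusion.
+ //   parallelizeAllLike(reference_tv);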
+
+ // Common hyperparameters used in heuristic scheduler. These hyperparameters
+ // are passed to SchedulerEntry::computeHeuristics through the
+ // HeuristicDataCache. These hyperparameters alter the generation of the
+ // HeuristicParams for the scheduler.
+ struct SchedulerHyperParameters {
+   SchedulerHyperParameters(
+       int64_t vectorize_factor_,
+       int64_t unroll_factor_,
+       int64_t threads_per_block_min_,
+       int64_t threads_per_block_max_)
+       : vectorize_factor(vectorize_factor_),
+         unroll_factor(unroll_factor_),
+         threads_per_block_min(threads_per_block_min_),
+         threads_per_block_max(threads_per_block_max_) {}
+
+   //! Number of elements to load per vectorize load.
+   int64_t vectorize_factor = 1;
+
+   //! Number of iterations to unroll for-loop.
+   int64_t unroll_factor = 1;
+
+   //! Minimum number of threads per block.
+   int64_t threads_per_block_min = 1;
+
+   //! Maximum number of threads per block.
+   int64_t threads_per_block_max = 1;
+ };
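+
+ // For illustration, hyperparameters requesting a vectorization factor of 4,
+ // an unroll factor of 2, and 128-512 threads per block (hypothetical values):
+ //   SchedulerHyperParameters hparams(
+ //       /*vectorize_factor=*/4,
+ //       /*unroll_factor=*/2,
+ //       /*threads_per_block_min=*/128,
+ //       /*threads_per_block_max=*/512);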
+
+ struct PersistentBufferInfo {
+   std::vector<TensorView*> persistent_buffers;
+   std::unordered_set<IterDomain*> unmappable_dims;
+
+   // Persistent buffers are needed until the path through the reduction -
+   // broadcast chain is resolved by any other chain using the persistent buffer
+   // that is not going through a reduction. This assumes all reduction paths
+   // have the same reduction pattern. Order is the same as persistent_buffers
+   std::vector<std::vector<TensorView*>> persistent_buffer_resolution_points;
+
+   // Not all persistent buffers can be projected to inputs. If a buffer can be
+   // projected to the inputs, which may reduce the persistent buffer size (BN
+   // backwards specifically), then keep track of it here. Persistent buffers
+   // that have a persistent buffer/reduction before them should not be
+   // projected through that.
+   std::vector<TensorView*> projectable_persistent_buffers;
+
+   // Track inputs of input projectable buffers
+   std::vector<TensorView*> projectable_buffer_inputs;
+
+   // Map unmappable dims to projectable_buffer_inputs
+   std::unordered_set<IterDomain*> unamppable_dims_projected_to_inputs;
+
+   // Some parameters used in
+   // normalization_scheduler_utils::isProjectBufferToInput
+   bool has_view_ops = false;
+   bool projection_with_exp_op = false;
+   bool projection_with_rng_op = false;
+ };
+
+ // Buffers whose roots can't map to all producer roots based on compute at.
+ // These are the buffers we would make persistent in a persistent kernel, or
+ // would have to recompute if we can't make a persistent kernel. This function
+ // will also return inputs as being marked persistent if they follow this
+ // pattern. It is important to note, however, that inputs don't strictly have
+ // to be persistent as they can simply be read multiple times from GMEM in the
+ // same kernel.
+ PersistentBufferInfo persistentBuffers(Fusion* fusion);
+
+ // A persistent tv can be projected to its producers when all the producers are
+ // persistent tvs and there is no reduction op.
+ bool canProjectToPersistentProducer(
+     TensorView* buffer,
+     const std::vector<TensorView*>& producers,
+     const std::unordered_set<TensorView*>& persistent_buffer_set);
+
+ //! Evaluates if a persistent buffer can be projected to input tvs without
+ //! dependency on reduction tvs. Returns a std::pair with a boolean indicating
+ //! whether projection is feasible and a vector of projectable tvs.
+ //!
+ //! The function operates in two main steps:
+ //! (1) Checks if the persistent buffer has dependencies on any of the given
+ //!     reduction tvs. If no dependencies are found, it returns true with an
+ //!     empty vector of target broadcast tvs.
+ //! (2) If there are dependencies, it examines each reduction tv for an
+ //!     associated broadcast tv that can be projected to. If all reduction tvs
+ //!     have corresponding broadcast tvs, true is returned along with these tvs.
+ //!     If any reduction tv lacks a corresponding broadcast tv, false is
+ //!     returned with the current list of identified broadcast tvs.
+ std::pair<bool, std::vector<TensorView*>> canProjectToInputsWithoutReduction(
+     const std::vector<TensorView*> reduction_tvs,
+     TensorView* persistent_buffer);
+
+ struct ReductionTvProperties {
+   // How many elements in the tensor view are there to reduce.
+   int64_t total_reduction_numel = 1;
+
+   // How many reductions do we need to perform, i.e. how many iter dimension
+   // elements are there.
+   int64_t total_iteration_numel = 1;
+
+   // Is the inner most dimension a reduction? If there are no reductions, mark
+   // true.
+   bool fastest_dim_reduction = true;
+
+   // How many elements in the inner most dimension, merging surrounding domains
+   // that match in type. This is used for 3D schedulers in
+   // reduction/normalization.
+   int64_t inner_most_dimension_numel = 1;
+
+   // Same thing as above, but the number of dimensions instead of the numel.
+   int64_t inner_most_dimension_ndims = 1;
+
+   // Merging neighboring iteration domains and reduction domains, what's the
+   // resulting dimensionality of the problem.
+   int64_t dimensionality = 1;
+ };
+
+ // Fill the ReductionTvProperties structure for tv
+ ReductionTvProperties getReductionProperties(
+     Fusion* fusion,
+     SchedulerRuntimeInfo& runtime_info,
+     TensorView* tv);
+
+ // Struct to store persistent buffer sizes. Also holds the persistent buffer
+ // size when the buffers are projected to the inputs.
+ struct PersistentBufferSizeReturn {
+   int64_t persistent_buffer_size = 0;
+   int64_t projected_persistent_buffer_size = 0;
+ };
+
+ // Compute the amount of register space that would be needed to perform this
+ // kernel persistently, based only on buffers that must be persistent and on
+ // the maximum of all minimum size requirements, i.e. if a buffer must be
+ // persistent, only hold the persistent dimension.
+ PersistentBufferSizeReturn persistentBufferSize(
+     Fusion* fusion,
+     SchedulerRuntimeInfo& runtime_info,
+     const PersistentBufferInfo& persistent_buffers,
+     HeuristicDataCache* data_cache = nullptr);
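+
+ // For illustration, a scheduler might compare the returned sizes against the
+ // register_file_size constant above to decide whether a persistent kernel
+ // fits in registers (hypothetical usage):
+ //   auto size_info =
+ //       persistentBufferSize(fusion, runtime_info, persistent_buffers);
+ //   bool fits = size_info.persistent_buffer_size <= register_file_size;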
+
+ // Merges the tensor view to the form:
+ //   [IterationDomain, ReductionDomain]
+ // Returns whether there are <iteration dimensions, reduction dimensions>.
+ std::pair<bool, bool> canonicalDimReduction(
+     Fusion* fusion,
+     TensorView* tv,
+     bool schedule_3D = false);
+
+ // Return a list of tensor views that are outputs of reduction operations,
+ // excluding resharding reduce expressions. If multiple outputs of an expression
+ // are found, only include one in the list.
+ std::vector<TensorView*> getReductionTvs(Fusion* fusion);
+
+ // Returns a list of TensorViews that are the consumer tv for a view operation.
+ std::vector<TensorView*> getViewTVs(Fusion* fusion);
+
+ // Returns a list of non-reduction TensorViews that have a root domain.
+ std::vector<TensorView*> getTVsWithNonReductionRFactor(Fusion* fusion);
+
+ // Reset inputs and outputs to global memory, everything else to local.
+ void clearMemorySpace(Fusion* fusion);
+
+ // Returns the cached-after tensors of the fusion inputs if unrolled. Otherwise
+ // returns an empty vector.
+ std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll);
+
+ // Returns the pairs of <cache of each fusion output, corresponding output> for
+ // all outputs.
+ std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
+     Fusion* fusion,
+     bool unroll);
+
+ // Ignores broadcast and reduction, returns the iter domain in the allocation
+ // domain that's "inner most".
+ IterDomain* innerMostAllocDim(TensorView* tv);
+
+ // Looks through the fusion and finds all dims that match to the one provided
+ // in the tensorview provided. The iter domain must be a root domain. If
+ // inner_only, will only map dimensions if they're in the inner most position.
+ // This is important when projecting a dimension between an rfactor position
+ // and its root position when mapping from consumer to producer. If
+ // inner_only=true, takes the rfactor/root dimension that maps, projects it to
+ // the root/rfactor domain, but only follows the inner most path when
+ // encountering a split/merge. When propagating backward, for split it will
+ // only propagate backwards if the mapped dimension is the inner portion of the
+ // split. For merge, inner_only doesn't make a difference and will propagate
+ // through the inner portion of the merge. When propagating forward, the logic
+ // is symmetric with the backward case.
+ class FindAllMappedDims : public MaxInfoSpanningTree::Propagator {
+   std::unordered_map<TensorView*, IterDomain*> mapped_root_ids_;
+   std::unordered_map<TensorView*, IterDomain*> mapped_logical_ids_;
+   TensorView* starting_tv_ = nullptr;
+   IterDomain* starting_id_ = nullptr;
+   bool inner_only_;
+   bool vectorize_pass_;
+
+  public:
+   FindAllMappedDims(
+       TensorView* from,
+       IterDomain* starting_id,
+       bool inner_only,
+       bool vectorize_pass);
+   void setUp() override;
+   void propagateC2P(TensorView* from, TensorView* to) override;
+   void propagateP2C(TensorView* from, TensorView* to) override;
+   void propagateSibling(TensorView* from, TensorView* to) override;
+   std::unordered_set<IterDomain*> get() const;
+ };
+
+ // Checks if the tensor view has an iteration domain in vector_dims in its
+ // inner most root position (excluding broadcast and reduction), and checks if
+ // it is a contiguous dimension.
+ bool hasInnerDim(
+     TensorView* tv,
+     std::unordered_set<IterDomain*> vector_dims,
+     bool should_vectorize);
+
+ // Returns all inputs and outputs that share the inner most dimension of the
+ // provided reference. If the reference is an input it ignores reduction axes;
+ // broadcast axes are always ignored. If inner_only, will require inner->inner
+ // mapping in view, otherwise it allows all inner->any mapping. If
+ // vectorize_pass, will check contiguity for vectorization, otherwise it just
+ // checks it has that inner dim.
+ std::vector<TensorView*> getInputsOutputsWithInnerDim(
+     TensorView* reference_tv,
+     bool inner_only,
+     bool vectorize_pass);
+
+ // Holder return struct for the function below.
+ struct DisjointLogicalSetInfo {
+   // const* to the disjoint set in the disjoint_rfactor_set passed in to
+   // getDisjointLogicalSetsOf that each iterdomain in the rfactor of ref is
+   // mapped to.
+   //
+   // WARNING: these pointers are relative to the disjoint_rfactor_set reference
+   // passed into getDisjointLogicalSetsOf; it's the user's responsibility to
+   // maintain the lifetime of that reference to match this vector.
+   std::vector<const VectorOfUniqueEntries<IterDomain*>*> disjoint_sets_of_ref;
+
+   // Unique ID associated with the disjoint view group the logical id belongs
+   // to in disjoint_sets_of_ref. It's straightforward to map from
+   // disjoint_sets_of_ref to the vector, but not the other way around.
+   std::vector<int64_t> disjoint_set_ids;
+
+   // TensorView reference the above vectors are relative to.
+   TensorView* ref;
+ };
+
+ // Returns disjoint rfactor sets mapped onto the given reference. Returns a pair
+ // of vectors of size rfactorDomain of reference. Vector of
+ // VectorOfUniqueEntries returns a const* to the disjoint set in
+ // disjoint_rfactor_set the iterdomain is mapped to. Integer vector represents
+ // which disjoint rfactor group the logical id belongs to. It's straightforward
+ // to map from the former to the latter, but not the latter to former.
+ //
+ // Since we return a const* to entries in disjoint_rfactor_set, it must be
+ // passed in as a reference. Algorithm is N^2 based on number of dims in
+ // reference, but generating the disjoint rfactor set is likely the limiter on
+ // perf of this function.
+ //
+ // logical_reorder_map is provided to assume TensorView `of` will be reordered
+ // per the map
+ DisjointLogicalSetInfo getDisjointLogicalSetsOf(
+     Fusion* fusion,
+     TensorView* of,
+     DisjointSets<IterDomain*>& disjoint_rfactor_set,
+     const std::unordered_map<int64_t, int64_t>& logical_reorder_map = {});
+
+ // Structure to hold byte multiples for break points. I.e. if we have the
+ // tensors:
+ //   T0[I0, I1] float
+ //   T1[I0, I1] bool
+ //   T2[I0]     half
+ //   T3[I1]     double
+ // and a break point of 1 the multiples would be:
+ //   lhs_multiple = 4 + 1 + 2 = 7
+ //   rhs_multiple = 4 + 1 + 8 = 13
+ struct BroadcastMultiple {
+   int64_t rhs_multiple = 0;
+   int64_t lhs_multiple = 0;
+ };
+
+ struct BroadcastMultipleInformation {
+   std::vector<int64_t> view_disjoint_set_ids;
+   std::vector<BroadcastMultiple> broadcast_multiples;
+ };
+
+ // Returns a vector of size reference_tv->getLogicalDomain().size() which
+ // is a view disjoint set id of each of those iter domains. If entries share
+ // the same value, they undergo view transformations in the fusion together.
+ // Broadcast multiples are also of size
+ // reference_tv->getLogicalDomain().size(); each entry [i] is the number of
+ // inputs/outputs that have a non-broadcast dimension mapped to the
+ // corresponding dimension in reference_tv. Broadcast multiples includes
+ // reference_tv if reference_tv is an input or output. Broadcast multiples is
+ // multiplied by data type size. In the case of view operations the broadcast
+ // multiple is the full multiple size if any domain in the group maps to a
+ // non-broadcast dimension in the given input/output. Otherwise, if all
+ // dimensions are broadcast, that input/output will not contribute to the
+ // multiple.
+ //
+ // logical_reorder_map is provided to assume reference_tv will be reordered per
+ // the map
+ BroadcastMultipleInformation getBroadcastMultiples(
+     TensorView* reference_tv,
+     DataType index_type,
+     const std::unordered_map<int64_t, int64_t>& logical_reorder_map = {});
+
+ //! Propagate current transformations on from_tv up to the given
+ //! position, to all tensorviews on the owning fusion that have
+ //! a connection with `from_tv` on the fusion graph.
+ void transformPropagateToAllFrom(TensorView* from_tv, int64_t pos);
+
+ //! A type of custom transform propagator that propagates iterdomain
+ //! transforms from a source tv to all tvs that are selected
+ //! using a "direction" and a "boundary".
+ //!
+ //! The propagation model always assumes a `from_tv`, a `direction` and a
+ //! `boundary`.
+ //!
+ //! This propagator will only transform producers and consumers
+ //! of `from_tv`, and all propagation modes **require** a boundary to be
+ //! specified to signify where the propagation should stop.
+ //!
+ //! There are currently three modes of propagation: forward, backward and
+ //! both-way, see comment on the interface functions for details.
+ struct BoundedDirectionalTransformPropagator {
+   //! Custom option container for configuring
+   //! the transform propagation actions.
+   //! All option values default to false unless
+   //! the corresponding setter is called.
+   struct Options {
+     //! If true, the transform propagator will
+     //! also propagate parallel types from
+     //! `from_tv` to all selected tvs.
+     bool propagate_parallel_type = false;
+
+     //! If true, the specified boundary tvs
+     //! will also be replayed as `from_tv`.
+     //! If false, they will not be affected
+     //! by the propagation pass.
+     bool transform_boundary = false;
+
+     //! Sets the position boundary in parallel
+     //! type propagation, see comment on
+     //! scheduler_utils::parallelizeAllLike.
+     //! Only used if propagate_parallel_type==true.
+     int64_t parallel_propagation_pos = -1;
+
+     //! Setter for enabling parallel type
+     //! propagation. see comment on the variable.
+     //!
+     //! \param up_to_pos, sets the parallel type
+     //! propagation boundary. see comment on
+     //! scheduler_utils::parallelizeAllLike.
+     Options propagateParallelType(int64_t up_to_pos = -1) {
+       propagate_parallel_type = true;
+       parallel_propagation_pos = up_to_pos;
+       return *this;
+     }
+
+     //! Setter for enabling propagation to
+     //! boundary tvs. see comment on the variable
+     Options propagateToBoundary() {
+       transform_boundary = true;
+       return *this;
+     }
+   };
+
+   //! Replay transforms from tensorview `from`
+   //! to the tensorviews that are consumers
+   //! of boundary tensorviews in `to` and producers of `from`.
+   static void backward(
+       TensorView* from,
+       int64_t pos,
+       std::vector<TensorView*> to,
+       std::optional<Options> options = std::nullopt);
+
+   //! Replay transforms from tensorview `from`
+   //! to the tensorviews that are producers
+   //! of boundary tensorviews in `to` and consumers of `from`.
+   static void forward(
+       TensorView* from,
+       int64_t pos,
+       std::vector<TensorView*> to,
+       std::optional<Options> options = std::nullopt);
+
+   //! Replay transforms from tensorview `from`
+   //! to all the tensorviews that are consumers
+   //! of tensorviews in `backward_to` and producers
+   //! of tensorviews in `forward_to` while being
+   //! either a producer or a consumer of tensorview `from`.
+   static void bothWays(
+       TensorView* from,
+       int64_t pos,
+       std::vector<TensorView*> backward_to,
+       std::vector<TensorView*> forward_to,
+       std::optional<Options> options = std::nullopt);
+
+  private:
+   //! Utility function:
+   //! Will realize the transform propagation to the
+   //! tensorviews in `included_tvs`.
+   //! Assumes that all tvs in included_tvs are either
+   //! a producer or a consumer of from_tv.
+   static void propagate(
+       TensorView* from_tv,
+       int64_t pos,
+       std::unordered_set<TensorView*> included_tvs,
+       Options options);
+ };
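+
+ // Illustrative use (hypothetical tensors tv_out and tv_in): replay tv_out's
+ // transforms onto its producers up to the boundary tv_in, also propagating
+ // parallel types and transforming the boundary itself:
+ //   BoundedDirectionalTransformPropagator::backward(
+ //       tv_out,
+ //       -1,
+ //       {tv_in},
+ //       BoundedDirectionalTransformPropagator::Options()
+ //           .propagateParallelType()
+ //           .propagateToBoundary());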
+
+ // Schedulers typically start by merging some axes together then splitting,
+ // and propagating those transformations through the DAG. What we want to
+ // understand is if these merges can be supported through view operations.
+ // For example, it could be problematic to support a reduction fusion:
+ //
+ //   tv0[2, 3, 4]
+ //   tv1 = sum(tv0, {1, 2})
+ //   tv2 = view(tv0, {6, 4})
+ //
+ // The first step of the reduction scheduler would be tv1->merge(1, 2). If we
+ // tried to propagate this transformation through the view it would make the
+ // view invalid. If we tried to propagate the view through the reduction, it
+ // would attempt to merge a reduction and non-reduction dimension. So for
+ // these types of fusions we would like to understand that the view considers
+ // axes 1 and 2 of tv1 as "non-separable" axes.
+ //
+ // If IterDomains are disjoint in the returned set, then they are considered
+ // "separable".
+ // Warning: This pass generates the IdGraphs, not intended for use at runtime.
+ DisjointSets<IterDomain*> disjointLogicalSets(Fusion* fusion);
+
+ // Makes sure that there are no group ids left of pos that match right of pos.
+ // e.g.
+ //   [1, 0, 0] pos 2 would return false
+ //   [1, 0, 0] pos 1 would return true
+ bool breakIsDisjoint(std::vector<int64_t> group_ids, int64_t pos);
+
+ // Generates an old to new map to reorder tv's domain as the logical order.
+ // Priority is given to inner most dimensions, for example:
+ //   logical [i0, i1, i2]
+ //   domain  [i0*i2, i1]
+ // will produce the map {{0, 1}, {1, 0}}
+ // This is somewhat similar to orderTiledConcreteIdAsRoot
+ std::unordered_map<int64_t, int64_t> domainReorderAsLogicalMap(TensorView* tv);
+
+ // Generates an old to new map to reorder tv's domain as the logical order.
+ // This only handles the simple case where allocation is a permutation of the
+ // logical domain; otherwise, the function returns an empty container.
+ std::unordered_map<int64_t, int64_t> maybeLogicalReorderAsAllocationMap(
+     TensorView* tv);
+
+ // Assumes views are consistent as detected by
+ // registry.cpp::requiresForwardViewReplay returning false
+ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map);
+
+ //! Check if tv is an output of a fastest-dim reduction
+ bool isFastestDimReduction(TensorView* tv);
+
+ // A wrapper for Fusion::rotateLoop that provides a more consistent interface
+ inline void rotateLoop(
+     TensorView* loop_tv,
+     int64_t axis,
+     std::unordered_set<Statement*> selection) {
+   auto fusion = loop_tv->fusion();
+   if (!fusion->hasManaged("loop_rotation")) {
+     fusion->manage("loop_rotation", LoopRotationParam{});
+   }
+   fusion->getManaged<LoopRotationParam>("loop_rotation")
+       .emplace_back(loop_tv, axis, std::move(selection));
+ }
+
+ //! Certain tensors may need to be placed on shared or global memory
+ //! due to data dependencies caused by resize operations. Create
+ //! caches of those tensors so that the original operations producing
+ //! them keep using the same memory. This avoids, for example,
+ //! reductions to global memory.
+ //!
+ //! Example:
+ //!
+ //!   tv1 = sum(tv0)
+ //!   tv2 = some_resize_op(tv1);
+ //!   tv3 = some_other_op(tv1);
+ //!
+ //! When tv1 is promoted to Global, we want to avoid reducing to a
+ //! global memory tensor. After the transformation by this function,
+ //! the fusion should look like:
+ //!
+ //!   tv1 = sum(tv0);
+ //!   tv4 = tv1
+ //!   tv4->setMemoryType(Global)
+ //!   tv2 = some_resize_op(tv4)
+ //!   tv3 = some_other_op(tv1);
+ //!
+ //! Note that the sum reduction is done using a Local buffer, i.e.,
+ //! tv1, but the data dependency for the resize op is still satisfied
+ //! by having a copy of tv1, i.e., tv4. Note that the other op using
+ //! tv1 still uses tv1.
+ void prepareForMemoryTypePromotion(Fusion* fusion);
+
+ //! If a consumer tensor induces a data dependency between threads,
+ //! move its producer to a shared memory that is sufficient to satisfy
+ //! the dependency. For example, if the domain is parallelized
+ //! with blockIdx, the producer memory type will be changed to
+ //! Global. A proper RAW sync will be automatically inserted when the
+ //! fusion is lowered.
+ void promoteProducerMemoryTypes(
+     Fusion* fusion,
+     const std::vector<TensorView*>& input_caches);
+
+ //! Get all tensors that are connected to from_tvs without going through
+ //! any tvs in the cutoff_tv_set.
+ std::unordered_set<TensorView*> getAllTvsFrom(
+     const std::vector<TensorView*>& from_tvs,
+     const std::unordered_set<TensorView*>& cutoff_tv_set);
+
+ //! Get the persistent buffer size of a tensor
+ int64_t getPersistentBufferSizeOfTensor(
+     const TensorView* buffer,
+     SchedulerRuntimeInfo& runtime_info,
+     const PersistentBufferInfo& persistent_buffer_info);
+
+ //! The required shared memory size for a block includes two parts: (1) smem
+ //! for persistent buffers and (2) overhead. The overhead includes space
+ //! reserved by the CUDA driver and the reduction workspace, which depends on
+ //! the number of threads per block specified by the parameter
+ //! threads_per_block. By default, the function uses the maximum allowed number
+ //! of threads per block (threads_per_block = -1) to calculate the overhead.
+ //! The caller can specify a different value if they are sure about the max
+ //! value used at runtime.
+ int64_t getSharedMemoryOverheadPerBlock(
+     Fusion* fusion,
+     const std::vector<TensorView*>& reduction_tvs,
+     int64_t threads_per_block = -1);
+
+ // Returns true if any Expr in `fusion` is resharding.
+ bool isResharding(Fusion* fusion);
+
+ // Move non-concretized broadcast domains to innermost
+ // positions. Broadcast domains mapped with any domains of the given tvs
+ // are ignored.
+ //
+ // The goal here is to find domains that are not scheduled by
+ // propagation from reference tensors (i.e., ignored_tvs). All
+ // schedulers make sure to include only schedulable domains, but they
+ // may also allow non-concretized broadcast domains that have
+ // no mapping with any of the reference tensors. Since they are
+ // non-concretized, they should be safe to ignore. Ideally, they
+ // should just be removed from the fusion. For now, they are moved to
+ // innermost positions to prevent them from interfering with
+ // inlining. If they happened to be at the
+ // outermost position, the tensor wouldn't be inlined at all. See
+ // issue #2686 and PR #2799.
+ void moveNonConcretizedBroadcastInnermost(
+     Fusion* fusion,
+     const std::unordered_set<TensorView*>& ignored_tvs = {});
+
+ // Returns a factor that represents the computation cost of the given fusion.
+ // Estimated using the number of MUFU operations, each weighted with a
+ // predefined factor.
+ int64_t getComputationCostFactor(Fusion* fusion);
+
+ // Returns the required bytes in flight to saturate the memory bandwidth.
+ int64_t getRequiredBytesInFlight();
+
+ // Returns true if the device has a high bandwidth-to-compute ratio.
+ bool isHighBandwidthFlopsRatio();
+
+ // Return true if the fusion has computation that requires Floating-Point
+ // Multi-Function (MUFU) units, e.g. exponent, logarithm, sine, cosine, square
+ // root, hyperbolic tangent. Currently, we have only tested tanh, exp, and
+ // reciprocal. Note that, if compiled with fast math (not supported yet) or
+ // directly lowered with inlined ptx, the inner reduction heuristics, which use
+ // this function to set the optimal unroll factor, need to be revised.
+ bool hasExpensiveMUFUops(Fusion* fusion);
+
+ // Reorder DID parallelized axes to outermost positions. Returns
+ // the position of the outermost non-DID axis.
+ int64_t reorderDevicesToOuter(TensorView* tv);
+
+ // Returns the number of non-reduction/non-broadcast/non-device dims in the
+ // logical domain.
+ inline int64_t nLogicalDims(const TensorView* tv) {
+   auto logical_dom = tv->getLogicalDomain();
+   int64_t tv_n_dims = 0;
+   for (auto dim : logical_dom) {
+     if (!dim->isReduction() && !dim->isBroadcast() && !dim->isDeviceDim()) {
+       tv_n_dims++;
+     }
+   }
+   return tv_n_dims;
+ }
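+
+ // For example, a hypothetical tensor with logical domain
+ // [iDID{d0}, i1, r2, b3] would report nLogicalDims == 1, since device,
+ // reduction, and broadcast dims are all skipped.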
765
+
766
+ // Reorer the loop domain of a given tensor to align with a given list of
767
+ // reference IDs. Non-matching loop IDs are placed outermost positions.
768
+ void reorderTensorLike(TensorView* tv, const std::vector<IterDomain*>& ref);
769
+
770
+ } // namespace scheduler_utils
771
+ } // namespace nvfuser