nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,166 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <contiguity.h>
11
+ #include <id_model/id_model.h>
12
+
13
+ namespace nvfuser {
14
+
15
+ // Minimal adaptation of OrderedIdInformation for IdModel. Note that
16
+ // the analysis only propagates forward for now.
17
+ class OrderedIdGroupInformation : public OrderedIdInformation {
18
+ public:
19
+ // Run the ordering analysis from given allocation domains through
20
+ // a given traversal path
21
+ static OrderedIdGroupInformation get(
22
+ const std::vector<IterDomain*>& alloc_domain,
23
+ const ExprPath<ExprGroup>& path_from_alloc,
24
+ const ValGraph& graph) {
25
+ OrderedIdGroupInformation info(alloc_domain, graph);
26
+ info.traverse(path_from_alloc);
27
+ return info;
28
+ }
29
+
30
+ // Traversal is based on the AlmostExact graph, so matching of iter
31
+ // domains also needs to be done with the same graph
32
+ bool isConsistentlyOrdered(IterDomain* id) const override {
33
+ return std::find_if(
34
+ consistently_ordered_ids_.begin(),
35
+ consistently_ordered_ids_.end(),
36
+ [&](IterDomain* consistent_id) -> bool {
37
+ return graph_.disjointValSets().strictAreMapped(
38
+ consistent_id, id);
39
+ }) != consistently_ordered_ids_.end();
40
+ }
41
+
42
+ std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>::
43
+ const_iterator
44
+ findAllocIDs(IterDomain* id) const override {
45
+ // This is a little ugly workaround. id_to_alloc_ids_ is a map of
46
+ // iter domains. If it were a map from ValGroup, this lookup
47
+ // should have been O(1)
48
+ return std::find_if(
49
+ id_to_alloc_ids_.begin(),
50
+ id_to_alloc_ids_.end(),
51
+ [&](const auto& kv) -> bool {
52
+ return graph_.disjointValSets().strictAreMapped(kv.first, id);
53
+ });
54
+ }
55
+
56
+ protected:
57
+ OrderedIdGroupInformation(
58
+ const std::vector<IterDomain*>& alloc_domain,
59
+ const ValGraph& graph)
60
+ : OrderedIdInformation(alloc_domain), graph_(graph) {
61
+ using_id_graph_ = true;
62
+ }
63
+
64
+ // Currently only forward propagation is supported
65
+ void traverse(const ExprPath<ExprGroup>& path_from_alloc) {
66
+ for (const auto& [eg, direction] : path_from_alloc) {
67
+ if (direction == Direction::Backward) {
68
+ // TODO: support Backward prop
69
+ continue;
70
+ }
71
+ dispatch(eg->front());
72
+ }
73
+ }
74
+
75
+ std::vector<IterDomain*>::const_iterator findActiveId(
76
+ IterDomain* id) const override {
77
+ NVF_ERROR(id != nullptr);
78
+ auto it = std::find_if(
79
+ active_ids_.begin(),
80
+ active_ids_.end(),
81
+ [&](IterDomain* active_id) -> bool {
82
+ return active_id != nullptr &&
83
+ graph_.disjointValSets().strictAreMapped(active_id, id);
84
+ });
85
+ return it;
86
+ }
87
+
88
+ private:
89
+ const ValGraph& graph_;
90
+ };
91
+
92
+ // Adapted from ContigIDs
93
+ class ContigIDGroups {
94
+ public:
95
+ ContigIDGroups(
96
+ const std::vector<IterDomain*>& alloc_domains,
97
+ std::vector<bool> contiguity,
98
+ const ExprPath<ExprGroup>& path_from_alloc,
99
+ const ValGraph& graph,
100
+ bool is_predicate_pass);
101
+
102
+ void dispatch(const ExprGroup& eg, Direction direction) {
103
+ NVF_ERROR(!eg->empty());
104
+ Expr* expr = eg->front();
105
+
106
+ // Currently not propagating any contiguity information with
107
+ // swizzles as contiguity is generally not preserved after swizzles.
108
+ // But in follow ups we could gradually add back a few special
109
+ // cases, depending on specific swizzle type and axes.
110
+
111
+ if (auto merge = dynamic_cast<Merge*>(expr)) {
112
+ handle(merge, direction);
113
+ } else if (auto split = dynamic_cast<Split*>(expr)) {
114
+ handle(split, direction);
115
+ } else if (auto resize = dynamic_cast<Resize*>(expr)) {
116
+ handle(resize, direction);
117
+ }
118
+ }
119
+
120
+ void handle(Merge* merge, Direction direction);
121
+
122
+ void handle(Split* split, Direction direction);
123
+
124
+ void handle(Resize* resize, Direction direction);
125
+
126
+ const std::unordered_set<ValGroup>& contigIDs() const {
127
+ return contig_ids_;
128
+ }
129
+
130
+ const std::unordered_map<IterDomain*, ValGroup>& allocToContigIDs() const {
131
+ return alloc_to_contig_ids_;
132
+ }
133
+
134
+ private:
135
+ // Indexing traversal graph.
136
+ const ValGraph& graph_;
137
+ // Domains to analyze contiguity. They are typically allocation
138
+ // domains but if this is a predicate indexing pass, they are
139
+ // likely logical domains.
140
+ const std::vector<IterDomain*> alloc_domains_;
141
+ // Contiguity of alloc_domains_
142
+ const std::vector<bool> alloc_contiguity_;
143
+ const bool is_predicate_pass_;
144
+ std::unique_ptr<const OrderedIdGroupInformation> consistent_transform_info_;
145
+
146
+ // Contig domain groups
147
+ std::unordered_set<ValGroup> contig_ids_;
148
+ // Mapping of allocation domains to contig groups
149
+ std::unordered_map<IterDomain*, ValGroup> alloc_to_contig_ids_;
150
+ // All domains that have dependencies with resize ops
151
+ std::unordered_set<ValGroup> resize_deps_;
152
+ // All domains that have dependencies with non-divisible split ops
153
+ std::unordered_set<ValGroup> non_divisible_deps_;
154
+ };
155
+
156
// Get a contiguous indexing domain for a given allocation domain. If
// no such domain is found, just the allocation domain itself is
// returned.
//
// alloc_domains:     domains to analyze; presumably allocation domains,
//                    or logical domains for predicate passes (see
//                    ContigIDGroups) -- confirm against callers
// alloc_contiguity:  per-domain contiguity flags, parallel to
//                    alloc_domains
// path_from_alloc:   traversal path starting at the allocation domains
// graph:             graph used to match iter domains during traversal
// is_predicate_pass: true when invoked for predicate indexing
std::unordered_map<IterDomain*, ValGroup> getContigDomains(
    const std::vector<IterDomain*>& alloc_domains,
    const std::vector<bool>& alloc_contiguity,
    const ExprPath<ExprGroup>& path_from_alloc,
    const ValGraph& graph,
    bool is_predicate_pass);

} // namespace nvfuser
@@ -0,0 +1,359 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <disjoint_set.h>
11
+ #include <fusion.h>
12
+ #include <ir/all_nodes.h>
13
+ #include <val_graph.h>
14
+
15
+ #include <string>
16
+ #include <unordered_map>
17
+ #include <unordered_set>
18
+ #include <vector>
19
+
20
+ namespace nvfuser {
21
+
22
+ class ValGraph;
23
+ class LoopPromotionMapBuilderCallback;
24
+
25
// Aggregated inlining-related mappings collected from a set of
// expressions (see buildStatefulInliningInfo below). NOTE(review):
// judging from the fields, this feeds the LOOP graph construction --
// confirm against IdModel::initializeLoopGraph.
struct StatefulInliningInfo {
  // All producer ids within (including dependencies of) inlined loop domains,
  // used for deterministic order
  VectorOfUniqueEntries<IterDomain*> ordered_p_ca_ids;

  // p2c mappings through the fusion within (including dependencies of) inlined
  // loop domains.
  std::unordered_map<IterDomain*, VectorOfUniqueEntries<Val*>>
      p2c_ca_permissive_maps;

  // Broadcast resolution map for root domains, including non-inlined
  // root domains
  std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
      p2c_root_broadcast_resolution_map;

  // All IDs of all first siblings
  VectorOfUniqueEntries<IterDomain*> ordered_sibling_ids;

  // Mappings to other sibling IDs from ordered_sibling_ids
  std::unordered_map<IterDomain*, VectorOfUniqueEntries<Val*>> sibling_maps;
};
46
+
47
// Builds a StatefulInliningInfo from the given expressions, using the
// exact and permissive graphs to relate producer and consumer iter
// domains. Implementation lives outside this header.
StatefulInliningInfo buildStatefulInliningInfo(
    const std::vector<Expr*>& exprs,
    const ValGraph& exact_graph,
    const ValGraph& permissive_graph);
51
+
52
+ // A collection of ValGraphs that are built from a fusion or series of
53
+ // expressions. These graphs are related, but have some distinct features based
54
+ // on the IdMappingMode.
55
+ //
56
+ // EXACT/PERMISSIVE mode:
57
+ //
58
+ // consumer[i0, b1] = producer[i0]
59
+ // consumer->merge(0) (consumer will now be [i0 * b1])
60
+ //
61
+ // When producer is replayed as consumer (the direction we use for mapping)
62
+ // with forwarding from ForwardingInfo the producer to consumer map will have
63
+ // both a mapping of consumer(i0) to producer(i0) as well as consumer(i0*b1) to
64
+ // producer(i0). This latter mapping is important for loop nest mappings as the
65
+ // consumer will generate a loop based on i0*b1 and the producer may be
66
+ // computeAt inside this loop nest. However, for indexing we do not want these
67
+ // two iter domains mapped as producer may be indexed as i0*i1 depending on the
68
+ // loop nest structure and how it was built.
69
+ //
70
+ // Exact mode is if the iter domain relationships from producer to consumer are
71
+ // considered the exact same size operating on matching dimensions from the root
72
+ // domain mapping.
73
+ //
74
+ // LOOP mode is important to resolve inlined broadcasts. If we have something
75
+ // like: consumer[i0o, threadIdx.x{i0i}] = producer[i0o,
76
+ // threadIdx.y{i0i}](computeAt = 1) which can easily happen when using shared
77
+ // memory. Loop is actually defined for all iteration domains, and resembles
78
+ // groups of iter domains that are effectively inlined with each other.
79
+ // Therefore iter domain's that are a common dependency of inlined loop domains
80
+ // may be loop mapped together.
81
+ //
82
+ // Loop promotion is a mechanism by which to capture inlined resolved
83
+ // broadcasts. If a consumer resolves a broadcast of a producer, and the
84
+ // producer's broadcast is inlined (in total or partially). Then the producer's
85
+ // iter domain will be "promoted" to the size of the consumers iter domain.
86
+ //
87
+ // IdMappingMode::EXACT
88
+ // Don't map any broadcast axes to non-broadcast axes
89
+ // Do not forward through any broadcast IDs
90
+ // IdMappingMode::BROADCAST
91
+ // Map any broadcast axes to non-broadcast axes
92
+ // Do not forward through any broadcast IDs
93
+ // IdMappingMode::PERMISSIVE
94
+ // Forward broadcast axes in replay
95
+ // Map all iteration domains
96
+ // Always contain root mappings (otherwise they could have been forwarded in
97
+ // broadcast)
98
+ // IdMappingMode::ALMOSTEXACT
99
+ // Forward through broadcast axes, but not through to a non-broadcast axis
100
+ // i.e. id{b1*i0}, id{i0} are mapped
101
+ // id{i1*i0}, id{i0} are not mapped (this part is the difference from
102
+ // PERMISSIVE)
103
+ // Forward through split one axes, i.e. id{ceilDiv(i0, 1)}, id{i0} are mapped
104
+ // IdMappingMode::LOOP
105
+ // Subgraph of the permissive graph. Maps only CA and their
106
+ // dependent domains.
107
+ class IdModel : public PolymorphicBase {
108
+ public:
109
+ // Sometimes fusion inputs or outputs are disconnected from expressions, in
110
+ // those cases we still may want to send in some additional tensor views from
111
+ // the Fusion that don't have expressions associated with them.
112
+ //
113
+ // All graphs are built by default. It can be disabled with
114
+ // build_graphs=false.
115
+ IdModel(
116
+ const std::vector<Expr*>& exprs,
117
+ const std::vector<TensorView*>& additional_tvs = {},
118
+ bool build_graphs = true,
119
+ bool allow_self_mapping = false,
120
+ LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback =
121
+ nullptr);
122
+
123
+ // Same as the above constructor with fusion->exprs() except fusion may have
124
+ // some dangling inputs/outputs that are expected to have IterDomain entries
125
+ // even though there's no possible connections from them.
126
+ //
127
+ // The validate parameter is a temporary option during the
128
+ // transition from the current ComputeAtMap.
129
+ IdModel(
130
+ Fusion* fusion,
131
+ bool build_graphs = true,
132
+ bool allow_self_mapping = false,
133
+ bool validate = false,
134
+ LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback =
135
+ nullptr);
136
+
137
+ bool hasIdGraph(IdMappingMode mode) const {
138
+ return id_graphs_.find(mode) != id_graphs_.end();
139
+ }
140
+
141
+ // Returns iter domain graph of provided mode. The graph must have
142
+ // been already built.
143
+ const ValGraph& idGraph(IdMappingMode mode) const;
144
+ ValGraph& idGraph(IdMappingMode mode);
145
+
146
+ const std::unordered_map<IterDomain*, VectorOfUniqueEntries<Expr*>>& idUses()
147
+ const {
148
+ return id_uses_;
149
+ }
150
+
151
+ const std::unordered_map<IterDomain*, VectorOfUniqueEntries<Expr*>>&
152
+ idDefinitions() const {
153
+ return id_definitions_;
154
+ }
155
+
156
+ // TODO: Seems a bit unfortunate that this isn't IterDomain local information.
157
+ const std::unordered_set<IterDomain*>& viewRfactorIds() const {
158
+ return view_rfactor_ids_;
159
+ }
160
+
161
+ std::string toString() const;
162
+
163
+ bool empty() const {
164
+ return tvs_.empty();
165
+ }
166
+
167
+ const std::vector<TensorView*>& tvs() const {
168
+ return tvs_;
169
+ }
170
+
171
+ Fusion* fusion() const {
172
+ return fusion_;
173
+ }
174
+
175
+ // Build all graphs, i.e., Exact, AlmostExact, Permissive and
176
+ // LOOP. This is by default called from the constructor
177
+ void buildAllGraphs();
178
+
179
+ // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs
180
+ // and first output of expr
181
+ ValGraph& buildExactGraph();
182
+
183
+ // Fills disjoint_ids_[IdMappingMode::ALMOSTEXACT]. Initialize AlmostExact as
184
+ // Exact entries, then map anything that's either merged with a size-1 or
185
+ // split by a size-1 dimension.
186
+ ValGraph& buildAlmostExactGraph();
187
+
188
+ // Fills disjoint_ids_[IdMappingMode::BROADCAST]. Initialize it as
189
+ // Exact entries, then map through broadcasts. Build the Exact graph
190
+ // as well if not yet done.
191
+ ValGraph& buildBroadcastGraph();
192
+
193
+ // Fills disjoint_ids_[IdMappingMode::PERMISSIVE]. Initialize it as
194
+ // BROADCAST entries, then map through forwarded domains. Build the
195
+ // BROADCAST graph as well if not yet done.
196
+ ValGraph& buildPermissiveGraph();
197
+
198
+ // Fills disjoint_ids_[IdMappingMode::LOOP]. Map only inlined
199
+ // domains that are mapped in the permissive graph. Build the Exact
200
+ // and Permissive graphs as well if not yet done.
201
+ //
202
+ // (For debugging only) When force_full_loop_promotion_analysis is
203
+ // true, it always performs the full loop promotion analysis even
204
+ // when it's possible to take a quicker shortcut.
205
+ ValGraph& buildLoopGraph(bool force_full_loop_promotion_analysis = false);
206
+
207
+ // Build a graph. Dependent graphs are also built if not yet done.
208
+ void buildGraph(IdMappingMode mode);
209
+
210
+ // Build a graph if not already built
211
+ void maybeBuildGraph(IdMappingMode mode);
212
+
213
+ // Iterates over all IterDomains in id_definitions_ and calls initializeVal on
214
+ // a new ValGraph and returns it.
215
+ ValGraph initializeIdGraph(bool propagate_through_exprs = true) const;
216
+
217
+ // Returns an IdGraph with all Id's mapped that are mapped both in graph0 and
218
+ // graph1.
219
+ ValGraph buildIntersection(
220
+ const ValGraph& graph0,
221
+ const ValGraph& graph1,
222
+ bool propagate_exprs = true) const;
223
+
224
+ const std::unordered_map<ValGroup, IterDomain*>& loopPromotionMap() const {
225
+ return loop_promotion_map_;
226
+ }
227
+
228
+ // Replay Expr but with the inputs provided. ValGraphs will be updated
229
+ // for all maps that have entries, adding the output iter domains of the
230
+ // replayed expression and adding potential mappings through the expression.
231
+ Expr* addReplayAs(std::vector<IterDomain*> new_inputs, Expr* expr);
232
+
233
+ //! Run through disjoint sets in the LOOP graph, make sure there's only one
234
+ //! non-serial parallel type in each disjoint set, set the parallel type of
235
+ //! all IterDomains in the disjoint set to that PType.
236
+ void validateAndPropagatePType();
237
+
238
+ //! (Copied from ComputeAtMap::allocateIndexVariables)
239
+ //! Run through disjoint sets in the LOOP map and allocate the index
240
+ //! variable for the associated for loop that will be generated
241
+ //! for each disjoint sets in the loop map. This pre-allocation makes
242
+ //! 2 key assumptions about computeAt map that would very likely be
243
+ //! long term invariant:
244
+ //! 1. All kir::forloop created in the lowering pass should belong
245
+ //! to one of the disjoint sets in loop map.
246
+ //! 2. The lowering pass will *never* create a loop nest with 2
247
+ //! different nesting levels mapped together, i.e. the case below
248
+ //! never occurs:
249
+ //! for i in IterDomain1
250
+ //! for j in IterDomain2
251
+ //! ...
252
+ //! With loop_map.areMapped(IterDomain1, IterDomain2) == true.
253
+ //! Under this condition, we can pre-allocate all required index
254
+ //! variable integers before creating any kir::forloop, and this
255
+ //! would help optimizing the generated integer math for indexing.
256
+ void allocateLoopIndexVariables();
257
+
258
+ // Get the index variable assigned for a given loop ID
259
+ Val* getLoopIndexVariable(
260
+ IterDomain* id,
261
+ CircularBufferLoopStage circular_buffer_loop_stage =
262
+ CircularBufferLoopStage::NotApplicable) const;
263
+
264
+ // Get the index variable assigned for a given loop group
265
+ Val* getLoopIndexVariable(
266
+ const ValGroup& loop_group,
267
+ CircularBufferLoopStage circular_buffer_loop_stage =
268
+ CircularBufferLoopStage::NotApplicable) const;
269
+
270
+ protected:
271
+ // Fills id_uses_ and id_definitions_ for all IterDomains active in the
272
+ // fusion.
273
+ void buildIterDomainDefinitionsAndUses();
274
+
275
+ // Start loop map by grouping inlined iter domains
276
+ void initializeLoopGraph(const StatefulInliningInfo& info);
277
+
278
+ // Build a map of loop groups to IterDomains that represent actual
279
+ // loops. The map is built based on the broadcast resolution with
280
+ // root domains between inlined producer and consumer tensors.
281
+ std::unordered_map<ValGroup, IterDomain*> buildLoopPromotionMap(
282
+ const StatefulInliningInfo& info);
283
+
284
+ // Errors if self mapping occurs
285
+ void assertNoSelfMapping();
286
+
287
+ // Loop graph represents the loop structure of the given fusion, so
288
+ // there must not be any mapping between the loop domains of each
289
+ // tensor.
290
+ void validateLoopGraphHasNoSelfMappedLeafDomains() const;
291
+
292
+ protected:
293
+ // Fusion where iter domains belong
294
+ Fusion* fusion_ = nullptr;
295
+
296
+ // All tensor expressions that this model analyzes
297
+ std::vector<Expr*> tv_exprs_;
298
+
299
+ // All tensors that this model analyzes
300
+ std::vector<TensorView*> tvs_;
301
+
302
+ // Tensors should not have domains that are mapped with another
303
+ // domains of the same tensor. This flag disables the check
304
+ bool allow_self_mapping_ = false;
305
+
306
+ // If true, validate graphs by comparing them with ComputeAtMap
307
+ bool validate_ = false;
308
+
309
+ // Optional callback for the loop promotion map builder for
310
+ // debugging and testing
311
+ LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback_ =
312
+ nullptr;
313
+
314
+ // By default, the permissive graph should map compliment domains as
315
+ // well. See the design doc for more details
316
+ bool permissive_graph_map_compliment_ids_ = true;
317
+
318
+ // Keeps ValGraphs containing all IterDomains for all mapping mode types.
319
+ //
320
+ // Using an array here might be nice, but it seems hard to use an enum as an
321
+ // array key
322
+ // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum
323
+ std::unordered_map<IdMappingMode, ValGraph> id_graphs_;
324
+
325
+ // If multiple transformations occur IterDomains could have multiple uses,
326
+ // however only one should be active in the given Fusion. When we resolve loop
327
+ // promotions during lowering, we can generate new iter domains from existing
328
+ // ones, so there can be multiple uses generated. Tracks all the active iter
329
+ // domain uses.
330
+ std::unordered_map<IterDomain*, VectorOfUniqueEntries<Expr*>> id_uses_;
331
+
332
+ // Make sure we don't blindly use definitions as we don't want to grab
333
+ // transformations before a tensor view's root domain. There can be
334
+ // multiple definitions due to replays.
335
+ std::unordered_map<IterDomain*, VectorOfUniqueEntries<Expr*>> id_definitions_;
336
+
337
+ std::unordered_set<IterDomain*> view_rfactor_ids_;
338
+
339
+ // Promotion domain for each loop group
340
+ std::unordered_map<ValGroup, IterDomain*> loop_promotion_map_;
341
+
342
+ // Allocated Loop index variable through the LOOP graph
343
+ std::unordered_map<ValGroup, Val*> loop_index_variable_map_;
344
+
345
+ // Allocated loop indices for circular buffer loops
346
+ std::unordered_map<
347
+ ValGroup,
348
+ std::unique_ptr<std::unordered_map<CircularBufferLoopStage, Val*>>>
349
+ circular_buffered_loop_index_variable_map_;
350
+ };
351
+
352
+ // A utility function to update a map of ValGroups to ID from an old
353
+ // Valgraph to a new ValGraph. The new graph must be a superset of the
354
+ // old graph.
355
+ std::unordered_map<ValGroup, IterDomain*> updateValGroupIdMap(
356
+ const std::unordered_map<ValGroup, IterDomain*>& stale_map,
357
+ ValGraph& new_graph);
358
+
359
+ } // namespace nvfuser
@@ -0,0 +1,81 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <dispatch.h>
11
+ #include <id_model/id_model.h>
12
+
13
+ namespace nvfuser {
14
+
15
+ // Similar to IndexCompute but adapted for the graph-based indexing
16
+ class IdGraphIndexCompute : public OptOutDispatch {
17
+ public:
18
+ IdGraphIndexCompute(
19
+ const ValGraph& traversal_graph,
20
+ std::unordered_map<ValGroup, Val*> initial_index_map)
21
+ : traversal_graph_(traversal_graph),
22
+ index_map_(std::move(initial_index_map)) {}
23
+
24
+ // Propagate the index map through a given expr of a specified
25
+ // direction.
26
+ void propagate(const ExprGroup& expr_group, Direction direction) {
27
+ NVF_ERROR(!expr_group->empty());
28
+ // This looks a little ugly but the dispatch interface doesn't
29
+ // have a way to pass arguments
30
+ current_direction_ = direction;
31
+ dispatch(expr_group->front());
32
+ current_direction_ = Direction::Undefined;
33
+ }
34
+
35
+ const std::unordered_map<ValGroup, Val*> indexMap() const {
36
+ return index_map_;
37
+ }
38
+
39
+ private:
40
+ using OptOutDispatch::handle;
41
+
42
+ void handle(Split* split) override;
43
+
44
+ void handle(Merge* merge) override;
45
+
46
+ void handle(Swizzle* swizzle) override;
47
+
48
+ void handle(Resize* resize) override;
49
+
50
+ bool isForward(Expr* expr) const {
51
+ return current_direction_ == Direction::Forward;
52
+ }
53
+
54
+ bool hasIndex(IterDomain* id) const {
55
+ return indexMap().find(toGroup(id)) != indexMap().end();
56
+ }
57
+
58
+ Val* getIndex(IterDomain* id) const {
59
+ auto it = index_map_.find(toGroup(id));
60
+ NVF_ERROR(it != index_map_.end(), "Index not found: ", id->toString());
61
+ return it->second;
62
+ }
63
+
64
+ void setIndex(IterDomain* id, Val* idx) {
65
+ // May overwrite index. When the graph is cyclic due to, e.g.,
66
+ // resize, the index obtained by traversing most through the
67
+ // indexing path should be used (see also PR #3454)
68
+ index_map_[toGroup(id)] = idx;
69
+ }
70
+
71
+ const ValGroup& toGroup(IterDomain* id) const {
72
+ return traversal_graph_.toGroup(id);
73
+ }
74
+
75
+ private:
76
+ const ValGraph& traversal_graph_;
77
+ std::unordered_map<ValGroup, Val*> index_map_;
78
+ Direction current_direction_ = Direction::Undefined;
79
+ };
80
+
81
+ } // namespace nvfuser