nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,446 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
#include <disjoint_set.h>
#include <ir/all_nodes.h>

#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
18
+
19
+ namespace nvfuser {
20
+
21
+ // ValGraph is a DAG of Vals and Exprs connected by their input and
22
+ // output dependencies. Each graph node is a collection of
23
+ // either Vals or Exprs that are grouped together through mapVals and
24
+ // mapExprs, respectively.
25
+ //
26
+ // The primary use case of ValGraph is for representing groupings and
27
+ // dependencies of iteration domains. For example, given a fusion as
28
+ // shown below:
29
+ //
30
+ // T1 = set(T0);
31
+ // T2 = set(T1);
32
+ //
33
+ // T0: root [I0, I1], loop [I0, I1]
34
+ // T1: root [I2, I3], loop [I2*I3/4, 4]
35
+ // T2: root [I4, I5], loop [I4*I5/4, 4]
36
+ //
37
+ // The Exact ValGraph consists of ValGroups of:
38
+ //
39
+ // - {I0, I2, I4}
40
+ // - {I1, I3, I5}
41
+ // - {I2*I3, I4*I5}
42
+ // - {I2*I3/4, I4*I5/4}
43
+ // - {4, 4}
44
+ //
45
+ // and ExprGroups of:
46
+ //
47
+ // - {merge of I2 and I3, merge of I4 and I5}
48
+ // - {split of I2*I3, split of I4*I5}
49
+ //
50
+ // ValGraph can be used with any Val types, however, it's currenty
51
+ // only tested with IterDomain. Some of the routines might need to be
52
+ // extended for other Val types.
53
+
54
+ using ValGroup = std::shared_ptr<VectorOfUniqueEntries<Val*>>;
55
+ using ValGroups = VectorOfUniqueEntries<ValGroup>;
56
+ using ExprGroup = std::shared_ptr<VectorOfUniqueEntries<Expr*>>;
57
+ using ExprGroups = VectorOfUniqueEntries<ExprGroup>;
58
+
59
+ class ValGraph {
60
+ public:
61
+ ValGraph() = default;
62
+
63
+ ValGraph(const ValGraph& other);
64
+ ValGraph(ValGraph&& other) = default;
65
+
66
+ ValGraph& operator=(const ValGraph& other);
67
+ ValGraph& operator=(ValGraph&& other) = default;
68
+
69
+ ValGraph(bool propagate_through_exprs)
70
+ : propagate_through_exprs_(propagate_through_exprs) {}
71
+
72
+ // Returns the disjoint val set.
73
+ const DisjointSets<Val*>& disjointValSets() const {
74
+ return disjoint_vals_;
75
+ }
76
+
77
+ // Returns the disjoint Expr set.
78
+ const DisjointSets<Expr*>& disjointExprSets() const {
79
+ return disjoint_exprs_;
80
+ }
81
+
82
+ // Return if there's a group entry in the graph for this expr
83
+ bool hasGroup(Expr* expr) const;
84
+
85
+ // Return if there's a group entry in the graph for this val
86
+ bool hasGroup(Val* val) const;
87
+
88
+ // Convert expr to its exprGroup, assert that it exists.
89
+ const ExprGroup& toGroup(Expr* expr) const;
90
+
91
+ // Convert Val to its ValGroup, assert that it exists.
92
+ const ValGroup& toGroup(Val* val) const;
93
+
94
+ // Convert a vector-like container of Val* or Expr* to their
95
+ // ValGroups or ExprGroups. The vector-like container type must
96
+ // define the element type as value_type
97
+ template <
98
+ typename ContainerType,
99
+ typename ElementType = typename std::remove_pointer<
100
+ typename ContainerType::value_type>::type,
101
+ typename RetType = typename std::conditional<
102
+ std::is_base_of<Val, ElementType>::value,
103
+ ValGroups,
104
+ ExprGroups>::type,
105
+ typename = std::enable_if_t<
106
+ std::is_base_of<Val, ElementType>::value ||
107
+ std::is_base_of<Expr, ElementType>::value>>
108
+ RetType toGroups(const ContainerType& entries) const {
109
+ RetType groups;
110
+ for (auto entry : entries) {
111
+ groups.pushBack(toGroup(entry));
112
+ }
113
+ return groups;
114
+ }
115
+
116
+ // Return output/input Val groups of provided expr
117
+ // Note that the same ValGroup can show up multiple times, so the
118
+ // output type cannot be VectorOfUniqueEntries
119
+ std::vector<ValGroup> outputGroups(const ExprGroup& expr) const;
120
+ std::vector<ValGroup> inputGroups(const ExprGroup& expr) const;
121
+
122
+ // Return Val groups that have no definition.
123
+ ValGroups getTerminatingInputs() const;
124
+
125
+ // Recursively traverses uses of the IdGroups in 'of' and returns all
126
+ // ExprGroups that have a use in their definition of provided of IdGroups.
127
+ ExprGroups allUsesOf(const ValGroups& of) const;
128
+
129
+ // Recursively traverses definitions of the IdGroups in 'of' and returns all
130
+ // ExprGroups used in this history of defining the 'of' IdGroups.
131
+ ExprGroups allDefinitionsOf(const ValGroups& of) const;
132
+
133
+ //! Returns the expressions associated with the
134
+ //! definitions of the provided ValGroup.
135
+ //!
136
+ //! Each ExprGroup of the returned ExprGroup vector is proven to be
137
+ //! equivalent. The ExprGroup vector holds expression groups that are not
138
+ //! equivalent, but produce one of the ValGroups within the same disjoint Val
139
+ //! set.
140
+ const ExprGroups& getDefinitions(const ValGroup& val_group) const;
141
+
142
+ //! Same as getDefinitions but for uses instead of
143
+ //! definitions
144
+ const ExprGroups& getUses(const ValGroup& val_group) const;
145
+
146
+ bool hasDefinitions(const ValGroup& val_group) const;
147
+
148
+ bool hasUses(const ValGroup& val_group) const;
149
+
150
+ // Uses the Valgraph to produce mappings between from and to.
151
+ // Supports one to many mappings. If a single Val in from maps to
152
+ // multiple Vals in to, the order of the Vals in value of
153
+ // the map is preserved to be the order provided in to.
154
+ //
155
+ // Example:
156
+ // tv0: [i0, b1]
157
+ // tv1: [i2, i3]
158
+ // tv2: [i4, i5]
159
+ // tv2 = tv0 + tv1
160
+ //
161
+ // tv0: [i0*b1] CA(1)
162
+ // tv1: [i2*i3] CA(1)
163
+ // tv2: [i4*i5] CA(1)
164
+ //
165
+ // Between tv0 and tv2, the Permissive graph would map:
166
+ // {i0, i4}
167
+ // {b1, i5}
168
+ // {i0*b1, i4*i5}
169
+ //
170
+ // Here, buildMapBetween with:
171
+ // from: {i0, b1, i0*b1}
172
+ // to: {i4, i5, i4*i5}
173
+ // will return a map of:
174
+ // i0: {i4}
175
+ // b1: {i5}
176
+ // i0*b1: {i4*i5}
177
+ std::unordered_map<Val*, VectorOfUniqueEntries<Val*>> buildMapBetween(
178
+ const std::vector<Val*>& from,
179
+ const std::vector<Val*>& to) const;
180
+
181
+ // Alias of the above on unique vector entries
182
+ std::unordered_map<Val*, VectorOfUniqueEntries<Val*>> buildMapBetween(
183
+ const VectorOfUniqueEntries<Val*>& from,
184
+ const VectorOfUniqueEntries<Val*>& to) const;
185
+
186
+ std::string toString() const;
187
+
188
+ std::string toGraphvizDotGraph() const;
189
+
190
+ // Initializes entries for the provided Val with its definitions and
191
+ // uses. The provided Val will have its own new ValGroup, each item in the
192
+ // definitions and uses will become a new ExprGroup, and these new ExprGroups
193
+ // will be the definitions and uses of the new ValGroup.
194
+ void initializeVal(
195
+ Val* val,
196
+ const VectorOfUniqueEntries<Expr*>& definitions,
197
+ const VectorOfUniqueEntries<Expr*>& uses);
198
+
199
+ // Same as the above exept val->definition() and val->uses() are
200
+ // used
201
+ void initializeVal(Val* val);
202
+
203
+ // Initializes entries for the provided Val. The provided Val will be added to
204
+ // the provided existing ValGroup. There will be no changes on the definitions
205
+ // and uses of the provided ValGroup.
206
+ void initializeVal(Val* v, ValGroup vg) {
207
+ disjoint_vals_.appendToSet(v, vg);
208
+ }
209
+
210
+ // Add expr to the disjoint sets as a sole group. Used for
211
+ // registering replayed domains and exprs. Error if the expr is
212
+ // already registered.
213
+ void registerExpr(Expr* expr);
214
+
215
+ // Returns true if first and second are expressions through which
216
+ // this ValGraph has matching inputs (if forward), or outputs (if not
217
+ // forward). Returning true means the expressions are "the same", in terms
218
+ // they modify matching original inputs by the same amount.
219
+ bool exprsMap(Expr* first, Expr* second, bool forward) const;
220
+
221
+ // Check basic consistencies of val and expr groups and their
222
+ // mappings.
223
+ void validateConsistency() const;
224
+
225
+ void addUniqueUses(const ValGroup& id_group, const ExprGroup& uses) {
226
+ unique_uses_.at(id_group).pushBack(uses);
227
+ }
228
+
229
+ void addUniqueDefinitions(const ValGroup& id_group, const ExprGroup& defs) {
230
+ unique_definitions_.at(id_group).pushBack(defs);
231
+ }
232
+
233
+ // Set val0 and val1 to mapped in this graph, attempt to propagate
234
+ // new mapping through val0/val1 definitions/uses.
235
+ void mapVals(Val* val0, Val* val1);
236
+
237
+ // Checks if expr0 and expr1 should map together, maps them together, and if
238
+ // expression propagation is on, propagates mapping through
239
+ // them. The forward parameter determines the direction of the
240
+ // propagation. The expressions are mapped if the inputs are mapped
241
+ // when the forward parameter is true. This should
242
+ // be the only call in ValGraph to mapThroughExpr.
243
+ void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward);
244
+
245
+ // Can't back prop through merge without making sure one input actually
246
+ // matches. This can be done on a map or extent basis.
247
+ // TODO: Move this to val_graph.cpp once validation_utils.cpp is
248
+ // retired.
249
+ template <typename T>
250
+ static bool shouldMapMergeBackward(
251
+ Merge* merge0,
252
+ Merge* merge1,
253
+ const DisjointSets<T*>& id_sets) {
254
+ auto extent_match = [](IterDomain* id0, IterDomain* id1) -> bool {
255
+ return id0->extent()->sameAs(id1->extent()) ||
256
+ (id0->extent()->isConstInt() && id1->extent()->isConstInt() &&
257
+ id0->extent()->evaluate().as<int64_t>() ==
258
+ id1->extent()->evaluate().as<int64_t>());
259
+ };
260
+
261
+ // If one pair of the domains are mapped in the given graph, the
262
+ // backward merge is considered mapped
263
+ if (id_sets.permissiveAreMapped(merge0->outer(), merge1->outer()) ||
264
+ id_sets.permissiveAreMapped(merge0->inner(), merge1->inner())) {
265
+ return true;
266
+ }
267
+
268
+ // Considered mapped if the extents are equal
269
+ if (extent_match(merge0->outer(), merge1->outer()) ||
270
+ extent_match(merge0->inner(), merge1->inner())) {
271
+ return true;
272
+ }
273
+
274
+ // The mapped ID group may have different extents depending on the
275
+ // mapping conditions. For example, the Permissive graph may have a
276
+ // symbolic extent as well as an extent of 1 for broadcast
277
+ // domains. Those other mapped domains need to be checked as well.
278
+
279
+ // First, the outer groups
280
+ auto outer0_group = id_sets.mappingExists(merge0->outer())
281
+ ? id_sets.disjointSetMap().at(merge0->outer())
282
+ : std::make_shared<VectorOfUniqueEntries<T*>>(
283
+ VectorOfUniqueEntries<T*>{merge0->outer()});
284
+ auto outer1_group = id_sets.mappingExists(merge1->outer())
285
+ ? id_sets.disjointSetMap().at(merge1->outer())
286
+ : std::make_shared<VectorOfUniqueEntries<T*>>(
287
+ VectorOfUniqueEntries<T*>{merge1->outer()});
288
+
289
+ for (T* outer0 : *outer0_group) {
290
+ for (T* outer1 : *outer1_group) {
291
+ if (extent_match(
292
+ outer0->template as<IterDomain>(),
293
+ outer1->template as<IterDomain>())) {
294
+ return true;
295
+ }
296
+ }
297
+ }
298
+
299
+ // Check the inner groups as well if not already matched
300
+ auto inner0_group = id_sets.mappingExists(merge0->inner())
301
+ ? id_sets.disjointSetMap().at(merge0->inner())
302
+ : std::make_shared<VectorOfUniqueEntries<T*>>(
303
+ VectorOfUniqueEntries<T*>{merge0->inner()});
304
+ auto inner1_group = id_sets.mappingExists(merge1->inner())
305
+ ? id_sets.disjointSetMap().at(merge1->inner())
306
+ : std::make_shared<VectorOfUniqueEntries<T*>>(
307
+ VectorOfUniqueEntries<T*>{merge1->inner()});
308
+
309
+ for (T* inner0 : *inner0_group) {
310
+ for (T* inner1 : *inner1_group) {
311
+ if (extent_match(
312
+ inner0->template as<IterDomain>(),
313
+ inner1->template as<IterDomain>())) {
314
+ return true;
315
+ }
316
+ }
317
+ }
318
+
319
+ return false;
320
+ }
321
+
322
+ private:
323
+ // Map expr0 and expr1 with each other, update unique_definitions_
324
+ // unique_uses_
325
+ // TODO: Make this variant hidden?
326
+ void mapExprs(Expr* expr0, Expr* expr1);
327
+
328
+ // Checks if expr's are considered "the same" where sameness is
329
+ // defined as inputs and outputs in the same position across
330
+ // expressions are mapped. If the expressions are determined the
331
+ // same then
332
+ //
333
+ // if forward
334
+ // will map outputs
335
+ // else
336
+ // will map inputs
337
+ //
338
+ // Returns true if expressions were mapped through.
339
+ bool mapThroughExpr(Expr* first, Expr* second, bool forward);
340
+
341
+ private:
342
+ // If propagate_through_exprs_ = false, then mapThroughExpr will not be called
343
+ // as a consequence of calling mapVals. As well as mapThroughExpr will not be
344
+ // called (again) as a result of calling mapThroughExpr.
345
+ //
346
+ // Note: For the second sentence of above... mapThroughExpr can call mapVals
347
+ // which could in return call mapThoughExpr again, but
348
+ // propagate_through_exprs_ as mentioned above prevents that from happening.
349
+ bool propagate_through_exprs_ = true;
350
+
351
+ // Keeps a disjoint set entry for all Vals.
352
+ DisjointSets<Val*> disjoint_vals_;
353
+
354
+ // Keeps a disjoint set entry for all Exprs.
355
+ DisjointSets<Expr*> disjoint_exprs_;
356
+
357
+ // Definitions of ValGroup. There can be multiple definitions due to
358
+ // replays.
359
+ std::unordered_map<ValGroup, ExprGroups> unique_definitions_;
360
+
361
+ std::unordered_map<ValGroup, ExprGroups> unique_uses_;
362
+ };
363
+
364
+ struct ValGroupAndItsGraph {
365
+ ValGroup group;
366
+ ValGraph* graph;
367
+ bool operator==(const ValGroupAndItsGraph& other) const {
368
+ return group == other.group && graph == other.graph;
369
+ }
370
+ bool operator!=(const ValGroupAndItsGraph& other) const {
371
+ return !operator==(other);
372
+ }
373
+ operator const ValGroup&() const {
374
+ return group;
375
+ }
376
+ };
377
+
378
+ inline std::ostream& operator<<(
379
+ std::ostream& os,
380
+ const ValGroupAndItsGraph& g) {
381
+ return os << g.group;
382
+ }
383
+
384
+ // Returns the first pair of id's in ids detected to match each other on the
385
+ // given ID graph. TODO: what this is really looking for is if
386
+ // there's any overlapping between the iter domains in the provided set.
387
+ //
388
+ // i.e. if we have:
389
+ // tv0 = arange(6).reshape({3, 2})
390
+ // tv1 = tv0[3, 2].t()
391
+ // tv2 = tv0[3, 2].reshape({2, 3})
392
+ // tv3 = tv1 + tv2
393
+ //
394
+ // Then we can see this overlap in the tv3 expression as:
395
+ //
396
+ // tv0 = { {0, 1, 2},
397
+ // {3, 4, 5} }
398
+ //
399
+ // tv1 = { {0, 3},
400
+ // {1, 4},
401
+ // {2, 5} }
402
+ //
403
+ // tv2 = { {0, 1},
404
+ // {2, 3},
405
+ // {4, 5} }
406
+ //
407
+ // The elements in tv1 {3, 1, 4, 2}, map respectively to the elements in tv2
408
+ // {1, 2, 3, 4}. The reason this is so important is it means that generating
409
+ // tv3 is no longer a trivially parallelizable problem (if we include the dag
410
+ // all the way to tv0). So tv0's axes cannot be inlined across both the tv0
411
+ // and tv1 path. This breaks some assumptions we have today in schedulers that
412
+ // will assume tv2 can be trivially inlined/parallelized. Instead we'd need to
413
+ // take into consideration the effective communication going on here, so that
414
+ // we pull multiple values of tv0 to compute tv3.
415
+ //
416
+ // Note, however, that the above example is not detectable at this
417
+ // moment as the self mapping is partial through reshape. The analysis
418
+ // below would need to be extended to consider producer and consumers
419
+ // of domains as well rather than just root, logical and loop domains.
420
+ std::optional<std::pair<IterDomain*, IterDomain*>> detectSelfMapping(
421
+ const std::vector<IterDomain*>& ids,
422
+ const ValGraph& id_graph);
423
+
424
+ struct SelfMapping {
425
+ IterDomain* id1;
426
+ IterDomain* id2;
427
+ // For debugging, records which domain `id1` and `id2` belong to. This value
428
+ // is either "Root", "Logical", or "Leaf". Consider making it an enum.
429
+ std::string where;
430
+ };
431
+
432
+ // Returns if a self mapping was detected that would invalidate assumptions of
433
+ // the overall lowering system.
434
+ //
435
+ // It is assumed that for any tensor represented by a list of domains,
436
+ // those domains should never be mapped with each other. It may be
437
+ // possible to lift this assumption, but it's unclear if it could
438
+ // matter in practice.
439
+ //
440
+ // TODO: Can we make this more of an alias analysis?
441
+ // Ref: https://github.com/csarofeen/pytorch/pull/1954#discussion_r961940498
442
+ std::optional<SelfMapping> hasSelfMapping(
443
+ const TensorView* tv,
444
+ const ValGraph& id_graph);
445
+
446
+ } // namespace nvfuser
@@ -0,0 +1,259 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <bfs.h>
11
+ #include <disjoint_set.h>
12
+ #include <id_model/to_string.h>
13
+ #include <ir/all_nodes.h>
14
+ #include <val_graph.h>
15
+
16
+ namespace nvfuser {
17
+
18
+ // Iterates through a Val Graph in topological order, calling handle on
19
+ // all Val and all Expr groups in a forward topological order.
20
+ //
21
+ // Warning: A ValGraph is not guaranteed to be a DAG. In fact, the
22
+ // AlmostExact and Permissive graphs would have cycles with a ValGroup
23
+ // and an ExprGroup. For example:
24
+ //
25
+ // [i0, 1]
26
+ // merge
27
+ // [i0*1]
28
+ // Current ValGroups: {{i0}, {1}, {i0*1}}
29
+ // map i0 and i0*1 as they effectively have the same extent
30
+ // Final ValGroups: {{i0, i0*1}, {1}}
31
+ //
32
+ // Here, the merge expr is the user of i0 and the definition of
33
+ // i0*1. Since i0 and i0*1 are mapped, the dependency chain looks
34
+ // like:
35
+ //
36
+ // {i0, i0*1} ----> {merge} ----> {i0, i0*1}
37
+ // use def
38
+ //
39
+ // These ExprGroups are called trivial ExprGroups (see also
40
+ // ValGraph::isTrivialExprGroup).
41
+ //
42
+ // Strictly speaking, these cycles mean there's no valid topological
43
+ // order anymore. In our use cases for IdModel, however, it's likely
44
+ // sufficient to return an ordering such as:
45
+ //
46
+ // {i0, i0*1} -> {merge}
47
+ //
48
+ // I.e., we visit {i0, i0*1} first even though {merge} is technically
49
+ // a definition.
50
+ //
51
+ // Another alternative may be simply giving up when such a cycle is
52
+ // detected, which may be more preferrable as it would be less
53
+ // confusing. At this moment, this visitor is only used with graphs
54
+ // with no such cycle. Should be revisited when necessary.
55
+ //
56
+ // Warning: This is not a great iterator if there's a desire to minimize paths
57
+ // traveled to simply visit all ValGroups in order. See ExprsBetween to see how
58
+ // we might minimize paths.
59
+ class ValGraphVisitor {
60
+ public:
61
+ ValGraphVisitor() = delete;
62
+
63
+ ValGraphVisitor& operator=(const ValGraphVisitor& other) = delete;
64
+
65
+ ValGraphVisitor& operator=(ValGraphVisitor&& other) = delete;
66
+
67
+ virtual ~ValGraphVisitor() = default;
68
+
69
+ protected:
70
+ ValGraphVisitor(const ValGraph& val_graph, bool allow_cycle = true)
71
+ : val_graph_(val_graph), allow_cycle_(allow_cycle) {}
72
+
73
+ ValGraphVisitor(const ValGraphVisitor& other) = default;
74
+
75
+ ValGraphVisitor(ValGraphVisitor&& other) = default;
76
+
77
+ virtual void handle(const ValGroup& val_group) = 0;
78
+ virtual void handle(const ExprGroup& expr_group) = 0;
79
+
80
+ // Returns if the traversal was successful. If false, error_message_
81
+ // should be populated.
82
+ bool traverse();
83
+
84
+ const ValGraph& graph() {
85
+ return val_graph_;
86
+ };
87
+
88
+ const std::string& errorMessage() const {
89
+ return error_message_;
90
+ }
91
+
92
+ private:
93
+ const ValGraph& val_graph_;
94
+ bool allow_cycle_ = true;
95
+ std::string error_message_;
96
+ };
97
+
98
+ // Statement sorting based on ValGraphVisitor, see warnings to ValGraph Visitor.
99
+ class ValGraphStmtSort : public ValGraphVisitor {
100
+ public:
101
+ ValGraphStmtSort(const ValGraph& val_graph, bool allow_cycle = true)
102
+ : ValGraphVisitor(val_graph, allow_cycle) {
103
+ NVF_ERROR(ValGraphVisitor::traverse(), errorMessage());
104
+ }
105
+
106
+ // Return non-reference so that code like below can work
107
+ // for (auto expr_group: ValGraphStmtSort(graph).exprs())
108
+ ExprGroups exprs() const {
109
+ return sorted_exprs_;
110
+ }
111
+
112
+ ValGroups vals() const {
113
+ return sorted_vals_;
114
+ }
115
+
116
+ ~ValGraphStmtSort() override = default;
117
+
118
+ protected:
119
+ using ValGraphVisitor::handle;
120
+
121
+ void handle(const ValGroup& val_group) override {
122
+ sorted_vals_.pushBack(val_group);
123
+ }
124
+
125
+ void handle(const ExprGroup& expr_group) override {
126
+ sorted_exprs_.pushBack(expr_group);
127
+ }
128
+
129
+ ExprGroups sorted_exprs_;
130
+ ValGroups sorted_vals_;
131
+ };
132
+
133
+ bool isCyclic(const ValGraph& graph);
134
+
135
+ class ValGraphDefinitions {
136
+ const ValGraph& graph_;
137
+
138
+ public:
139
+ ValGraphDefinitions(const ValGraph& graph) : graph_(graph) {}
140
+ decltype(auto) operator()(const ValGroup& val_group) const {
141
+ return graph_.getDefinitions(val_group);
142
+ }
143
+ };
144
+
145
+ class ValGraphUses {
146
+ const ValGraph& graph_;
147
+
148
+ public:
149
+ ValGraphUses(const ValGraph& graph) : graph_(graph) {}
150
+ decltype(auto) operator()(const ValGroup& val_group) const {
151
+ return graph_.getUses(val_group);
152
+ }
153
+ };
154
+
155
+ class ValGraphInputs {
156
+ const ValGraph& graph_;
157
+
158
+ public:
159
+ ValGraphInputs(const ValGraph& graph) : graph_(graph) {}
160
+ decltype(auto) operator()(const ExprGroup& expr_group) const {
161
+ return graph_.inputGroups(expr_group);
162
+ }
163
+ };
164
+
165
+ class ValGraphOutputs {
166
+ const ValGraph& graph_;
167
+
168
+ public:
169
+ ValGraphOutputs(const ValGraph& graph) : graph_(graph) {}
170
+ decltype(auto) operator()(const ExprGroup& expr_group) const {
171
+ return graph_.outputGroups(expr_group);
172
+ }
173
+ };
174
+
175
+ template <>
176
+ struct GetValType<ExprGroup> {
177
+ using type = ValGroup;
178
+ };
179
+
180
+ class ValGraphBFS : public BFS<
181
+ ExprGroup,
182
+ ValGroup,
183
+ ValGraphDefinitions,
184
+ ValGraphUses,
185
+ ValGraphInputs,
186
+ ValGraphOutputs> {
187
+ public:
188
+ ValGraphBFS(
189
+ const ValGraph& graph,
190
+ std::vector<NodeType> from_groups,
191
+ std::vector<NodeType> to_groups,
192
+ bool require_all_to_visited = true,
193
+ Direction allowed_direction = Direction::Undefined)
194
+ : BFS(ValGraphDefinitions(graph),
195
+ ValGraphUses(graph),
196
+ ValGraphInputs(graph),
197
+ ValGraphOutputs(graph),
198
+ std::move(from_groups),
199
+ std::move(to_groups),
200
+ require_all_to_visited,
201
+ allowed_direction) {}
202
+
203
+ // Just a shortcut to the generic getExprsBetween
204
+ static std::pair<ValGraphBFS::ExprPath, bool> getExprGroupsBetween(
205
+ const ValGraph& graph,
206
+ const ValGroups& from,
207
+ const ValGroups& to,
208
+ bool require_all_to_visited = true,
209
+ Direction allowed_direction = Direction::Undefined) {
210
+ return getExprsBetween<ValGraphBFS>(
211
+ from.vector(),
212
+ to.vector(),
213
+ require_all_to_visited,
214
+ allowed_direction,
215
+ graph);
216
+ }
217
+ };
218
+
219
+ class ValGraphPermissiveBFS : public BFSWithPermissiveDependence<
220
+ ExprGroup,
221
+ ValGroup,
222
+ ValGraphDefinitions,
223
+ ValGraphUses,
224
+ ValGraphInputs,
225
+ ValGraphOutputs> {
226
+ public:
227
+ ValGraphPermissiveBFS(
228
+ const ValGraph& graph,
229
+ std::vector<NodeType> from_groups,
230
+ std::vector<NodeType> to_groups,
231
+ bool require_all_to_visited = true,
232
+ Direction allowed_direction = Direction::Undefined)
233
+ : BFSWithPermissiveDependence(
234
+ ValGraphDefinitions(graph),
235
+ ValGraphUses(graph),
236
+ ValGraphInputs(graph),
237
+ ValGraphOutputs(graph),
238
+ std::move(from_groups),
239
+ std::move(to_groups),
240
+ require_all_to_visited,
241
+ allowed_direction) {}
242
+
243
+ // Just a shortcut to the generic getExprsBetween
244
+ static std::pair<ValGraphPermissiveBFS::ExprPath, bool> getExprGroupsBetween(
245
+ const ValGraph& graph,
246
+ const ValGroups& from,
247
+ const ValGroups& to,
248
+ bool require_all_to_visited = true,
249
+ Direction allowed_direction = Direction::Undefined) {
250
+ return getExprsBetween<ValGraphPermissiveBFS>(
251
+ from.vector(),
252
+ to.vector(),
253
+ require_all_to_visited,
254
+ allowed_direction,
255
+ graph);
256
+ }
257
+ };
258
+
259
+ } // namespace nvfuser