nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,929 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <disjoint_set.h>
11
+ #include <exceptions.h>
12
+
13
+ #include <algorithm>
14
+ #include <deque>
15
+ #include <ostream>
16
+ #include <unordered_map>
17
+ #include <unordered_set>
18
+ #include <variant>
19
+ #include <vector>
20
+
21
namespace nvfuser {

// Orientation in which an expr is traversed. Forward walks from an expr's
// inputs to its outputs; Backward walks from its outputs to its inputs.
// Undefined is used when no single orientation applies (for example, a BFS
// whose allowed direction is Undefined may move both ways).
enum class Direction {
  Forward,
  Backward,
  Undefined,
};

// Ordered sequence of traversal steps; each step pairs an expr with the
// direction in which it is traversed.
template <typename ExprT>
using ExprPath = std::vector<std::pair<ExprT, Direction>>;

} // namespace nvfuser
29
+
30
+ namespace std {
31
+ template <typename ExprT>
32
+ struct hash<pair<ExprT, nvfuser::Direction>> {
33
+ std::size_t operator()(
34
+ const std::pair<ExprT, nvfuser::Direction>& directed_expr) const {
35
+ using std::hash;
36
+ return hash<ExprT>()(directed_expr.first);
37
+ }
38
+ };
39
+ } // namespace std
40
+
41
+ namespace nvfuser {
42
+
43
+ inline std::ostream& operator<<(std::ostream& os, const Direction direction) {
44
+ switch (direction) {
45
+ case Direction::Forward:
46
+ os << "Forward";
47
+ break;
48
+ case Direction::Backward:
49
+ os << "Backward";
50
+ break;
51
+ case Direction::Undefined:
52
+ os << "Undefined";
53
+ break;
54
+ }
55
+ return os;
56
+ }
57
+
58
+ template <typename ExprT>
59
+ std::ostream& operator<<(std::ostream& os, const ExprPath<ExprT>& path) {
60
+ for (const auto& [expr, direction] : path) {
61
+ os << direction << " " << toString(expr);
62
+ }
63
+ return os;
64
+ }
65
+
66
+ inline Direction reverse(Direction direction) {
67
+ if (direction == Direction::Forward) {
68
+ return Direction::Backward;
69
+ } else if (direction == Direction::Backward) {
70
+ return Direction::Forward;
71
+ } else {
72
+ return Direction::Undefined;
73
+ }
74
+ }
75
+
76
+ template <typename ExprT>
77
+ inline ExprPath<ExprT> reverse(const ExprPath<ExprT>& path) {
78
+ auto rev = path;
79
+ std::reverse(rev.begin(), rev.end());
80
+ for (auto& [e, direction] : rev) {
81
+ direction = reverse(direction);
82
+ }
83
+ return rev;
84
+ }
85
+
86
+ template <typename ExprT, typename ValT>
87
+ inline std::string toString(const std::variant<ExprT, ValT>& n) {
88
+ if (auto e = std::get_if<ExprT>(&n)) {
89
+ return toString(*e);
90
+ } else if (auto v = std::get_if<ValT>(&n)) {
91
+ return toString(*v);
92
+ } else {
93
+ NVF_THROW();
94
+ }
95
+ }
96
+
97
// Trait that maps an expr type to the type of the vals it connects (e.g.,
// Val* for Expr* and ValGroup for ExprGroup). Only the primary template is
// declared here. getInputsOfExpr and getOutputsOfExpr use
// `typename GetValType<ExprT>::type`, so every ExprT used with them needs a
// specialization that provides a nested `type`.
template <class ExprT>
struct GetValType;
101
+
102
+ template <typename ExprT, typename InputsT, typename OutputsT>
103
+ std::vector<typename GetValType<ExprT>::type> getInputsOfExpr(
104
+ const ExprT& expr,
105
+ Direction dir,
106
+ InputsT inputs,
107
+ OutputsT outputs) {
108
+ NVF_ERROR(dir == Direction::Forward || dir == Direction::Backward);
109
+ return dir == Direction::Forward ? inputs(expr) : outputs(expr);
110
+ }
111
+
112
+ template <typename ExprT, typename InputsT, typename OutputsT>
113
+ std::vector<typename GetValType<ExprT>::type> getOutputsOfExpr(
114
+ const ExprT& expr,
115
+ Direction dir,
116
+ InputsT inputs,
117
+ OutputsT outputs) {
118
+ return dir == Direction::Forward ? outputs(expr) : inputs(expr);
119
+ }
120
+
121
+ // Traversal for finding the shortest path from given vals to another
122
+ // vals. For now, the vals are either Val* if we want to traverse IR nodes,
123
+ // or ValGroup if we want to traverse ValGraph. However, this algorithm is
124
+ // implemented as a class template so in the future, we can extend it to support
125
+ // other types of vals and exprs. The algorithm is based on the standard BFS
126
+ // traversal, however, the traversal graph is treated as an undirected graph, so
127
+ // the traversal direction can be both forward and backward. The dependencies of
128
+ // vals and exprs need to be satisfied. Specifically, when visiting an expr,
129
+ // either its inputs or outputs must be visited before. Similarly, when visiting
130
+ // a val, there must be at least one defining expr or one use expr that is
131
+ // already visited.
132
+ template <
133
+ typename ExprT,
134
+ typename ValT,
135
+ typename DefinitionT,
136
+ typename UsesT,
137
+ typename InputsT,
138
+ typename OutputsT>
139
+ class BFS {
140
+ public:
141
+ using ExprType = ExprT;
142
+ using ValType = ValT;
143
+ using NodeType = std::variant<ExprT, ValT>;
144
+ using ExprPath = std::vector<std::pair<ExprT, Direction>>;
145
+ using InputsType = InputsT;
146
+ using OutputsType = OutputsT;
147
+
148
+ virtual ~BFS() = default;
149
+
150
+ public:
151
+ BFS(DefinitionT definition,
152
+ UsesT uses,
153
+ InputsT inputs,
154
+ OutputsT outputs,
155
+ std::vector<NodeType> from,
156
+ std::vector<NodeType> to,
157
+ bool require_all_to_visited = true,
158
+ Direction allowed_direction = Direction::Undefined)
159
+ : definition_(std::move(definition)),
160
+ uses_(std::move(uses)),
161
+ inputs_(std::move(inputs)),
162
+ outputs_(std::move(outputs)),
163
+ from_(std::move(from)),
164
+ to_(std::move(to)),
165
+ require_all_to_visited_(require_all_to_visited),
166
+ allowed_direction_(allowed_direction) {}
167
+
168
+ // Traverse from from_ to to_, recording each taken
169
+ // path to generate the shortest path after the traversal
170
+ virtual void traverse() {
171
+ for (const auto& n : from_) {
172
+ setVisited(n);
173
+ addNewNeighbors(n);
174
+ }
175
+
176
+ while (!allToNodesVisited()) {
177
+ bool something_was_processed = false;
178
+ std::deque<NodeType> not_ready;
179
+ while (!allToNodesVisited() && !to_visit_.empty()) {
180
+ const auto n = to_visit_.front();
181
+ to_visit_.pop_front();
182
+
183
+ if (isVisited(n)) {
184
+ continue;
185
+ }
186
+
187
+ auto ready_direction = isReady(n);
188
+ if (!ready_direction.has_value()) {
189
+ // To stop an infinite loop, the not-ready node is not moved
190
+ // back to the to_visit_ queue but kept in the separate
191
+ // queue. This way, if all nodes in to_visit_ are not ready,
192
+ // the queue would eventually become empty, which would then
193
+ // break the inner while loop. The something_was_processed
194
+ // flag is used to remember if there's any progress.
195
+ not_ready.emplace_back(n);
196
+ continue;
197
+ }
198
+
199
+ // Visit this node and add its neighbors to to_visit if not
200
+ // visited yet
201
+ setVisited(n);
202
+ setPrevGroups(n, *ready_direction);
203
+ addNewNeighbors(n);
204
+ something_was_processed = true;
205
+ }
206
+
207
+ // If nothing was processed, break out of the loop
208
+ if (!something_was_processed) {
209
+ break;
210
+ }
211
+
212
+ // Something was processed. Redo the traversal.
213
+ to_visit_.insert(to_visit_.end(), not_ready.begin(), not_ready.end());
214
+ }
215
+
216
+ if (require_all_to_visited_ && !allToNodesVisited()) {
217
+ std::stringstream ss;
218
+ for (const auto& to : to_) {
219
+ if (!isVisited(to)) {
220
+ ss << " " << toString(to);
221
+ if (const ExprT* e = std::get_if<ExprT>(&to)) {
222
+ ss << " " << toString(*e);
223
+ }
224
+ }
225
+ }
226
+ ss << " (from: ";
227
+ for (const auto& from : from_) {
228
+ ss << " " << toString(from);
229
+ if (const ExprT* e = std::get_if<ExprT>(&from)) {
230
+ ss << " " << toString(*e);
231
+ }
232
+ }
233
+ ss << ")";
234
+ ss << ", visited: (";
235
+ for (const auto& visited : visited_) {
236
+ if (const ValT* v = std::get_if<ValT>(&visited)) {
237
+ ss << " " << toString(visited);
238
+ }
239
+ }
240
+ ss << ")";
241
+ NVF_THROW("BFS traversal could not visit some nodes: ", ss.str());
242
+ }
243
+ }
244
+
245
+ // Find the shortest path from the from_ to to_. A boolean value
246
+ // indicating if all nodes are visited is also returned. This
247
+ // must be only used once traversal is completed.
248
+ virtual std::pair<ExprPath, bool> getShortestExprPath() {
249
+ NVF_ERROR(
250
+ !require_all_to_visited_ || allToNodesVisited(),
251
+ "Traveral is either not done or failed");
252
+
253
+ ExprPath path;
254
+
255
+ std::deque<std::pair<NodeType, Direction>> to_visit;
256
+ for (const NodeType& to : to_) {
257
+ to_visit.emplace_back(to, Direction::Undefined);
258
+ }
259
+
260
+ while (!to_visit.empty()) {
261
+ const auto [node, direction] = to_visit.front();
262
+ to_visit.pop_front();
263
+
264
+ if (const ExprT* e = std::get_if<ExprT>(&node)) {
265
+ path.emplace_back(*e, direction);
266
+ }
267
+
268
+ if (std::find(from_.begin(), from_.end(), node) != from_.end()) {
269
+ continue;
270
+ }
271
+
272
+ auto prev_nodes_it = prev_nodes_.find(node);
273
+ NVF_ERROR(!require_all_to_visited_ || prev_nodes_it != prev_nodes_.end());
274
+ if (prev_nodes_it != prev_nodes_.end()) {
275
+ const Direction dir = prev_nodes_it->second.first;
276
+ for (const auto& prev_node : prev_nodes_it->second.second) {
277
+ to_visit.emplace_back(prev_node, dir);
278
+ }
279
+ }
280
+ }
281
+
282
+ // At this point, we have the reverse path, but it may have multiple exprs
283
+ // that need to be filtered out. For example, if we are traversing
284
+ // IterDomain transformations, let's say there are domains 0, 1 and 2, and
285
+ // domains 1 and 2 are merged to produce domain 3, and then domains
286
+ // 0 and 3 are merged to produce domain 4.
287
+ //
288
+ // 0 1 2
289
+ //
290
+ // | | |
291
+ // | | |
292
+ // | +--> <--+
293
+ // | 3
294
+ // | |
295
+ // | |
296
+ // +----> 4 <---+
297
+ //
298
+ // Suppose we want to find the shortest path from {4} to {0, 1,
299
+ // 2}. The correct answer should be:
300
+ //
301
+ // Backward merge of 0, 3 -> 4
302
+ // Backward merge of 1, 2 -> 3
303
+ //
304
+ // However, the above traversal would produce a path of:
305
+ //
306
+ // Backward merge of 0, 3 -> 4
307
+ // Backward merge of 1, 2 -> 3
308
+ // Backward merge of 1, 2 -> 3
309
+ // Backward merge of 0, 3 -> 4
310
+ //
311
+ // This is because, since nodes 0, 1 and 2 are the starting nodes,
312
+ // we would first visit 4 from 0, and then 3 from 1 and again 3 from
313
+ // 2. Since node 3 would be visited twice, the path from 3 to 4
314
+ // would be traversed twice as well. Obviously, just reversing this
315
+ // path wouldn't give the correct path. There are two issues here:
316
+ //
317
+ // - The first visit to node 4 from node 0 should not be taken since
318
+ // node 4 must appear after node 3
319
+ // - Visiting the same node multiple times is redundant and should
320
+ // be removed
321
+ //
322
+ // Both problems could be solved by taking into consideration whether
323
+ // nodes are ready to visit and also are already visited, just like
324
+ // done in the forward traversal. However, there's an additional
325
+ // complexity in this case because the following graph is also valid:
326
+ //
327
+ // 1 2
328
+ //
329
+ // | | |
330
+ // | | |
331
+ // | +--> <--+
332
+ // | 3
333
+ // | |
334
+ // | |
335
+ // +----> 4 <---+
336
+ //
337
+ // Notice that node 0 is missing, meaning the shortest path problem
338
+ // in this case is from node 4 to nodes 1 and 2, and node 0 is not
339
+ // of interest. The correct path is still the same, i.e., first
340
+ // backward merge of 0 and 3 and then another backward merge of 1
341
+ // and 2. It is just node 0 is discarded as it is not of
342
+ // interest. In this case, however, if the
343
+ // traversal was enforced to honor the dependency of each node,
344
+ // it would not be able to visit the backward merge of 0 and 3 as
345
+ // node 0 is missing.
346
+ //
347
+ // A straightforward solution here is simply first generating the
348
+ // path as above and for each node, take the last visit only. Note
349
+ // that the last visit is always guaranteed to satisfy its
350
+ // dependencies.
351
+ //
352
+ // Recall that the final path needs to be reversed, so instead of
353
+ // finding the last appearance of each node, the final path can be
354
+ // obtained by first reversing the current path and then only taking
355
+ // the first appearance of each expr. Or, more simply, we can
356
+ // just use VectorOfUniqueEntries with the reverse iterator.
357
+ //
358
+ // See the BFS2 test for a concrete example.
359
+
360
+ VectorOfUniqueEntries<std::pair<ExprT, Direction>> unique_path(
361
+ path.rbegin(), path.rend());
362
+
363
+ return std::make_pair(unique_path.vector(), allToNodesVisited());
364
+ }
365
+
366
+ // Check if a node is ready to visit. If yes, return the direction
367
+ // and the prev nodes that should be visited before the given node
368
+ // is visited.
369
+ virtual std::optional<std::pair<Direction, std::vector<NodeType>>> isReady(
370
+ const NodeType& node) const {
371
+ if (const ExprT* e = std::get_if<ExprT>(&node)) {
372
+ return isReady(*e);
373
+ } else if (const ValT* v = std::get_if<ValT>(&node)) {
374
+ return isReady(*v);
375
+ } else {
376
+ NVF_THROW();
377
+ }
378
+ }
379
+
380
+ // Check if an ExprT is ready to visit. Either all of its inputs
381
+ // or all of outputs must have their dependencies satisfied. If
382
+ // ready because the inputs are already visited, return
383
+ // Direction::Forward and all the input nodes. If ready because the
384
+ // outputs are ready, return Direction::Backward and all the output nodes.
385
+ virtual std::optional<std::pair<Direction, std::vector<NodeType>>> isReady(
386
+ const ExprT& expr) const {
387
+ // Either all inputs or all outputs must have been visited
388
+ decltype(auto) inputs = inputs_(expr);
389
+ if (!inputs.empty() && allowed_direction_ != Direction::Backward &&
390
+ std::all_of(
391
+ inputs.begin(), inputs.end(), [&](const ValT& input) -> bool {
392
+ return isDependencySatisfied(input);
393
+ })) {
394
+ std::vector<NodeType> prev_nodes;
395
+ std::copy_if(
396
+ inputs.begin(),
397
+ inputs.end(),
398
+ std::back_inserter(prev_nodes),
399
+ [&](const ValT& input) -> bool { return isVisited(input); });
400
+ return std::make_pair(Direction::Forward, prev_nodes);
401
+ }
402
+
403
+ decltype(auto) outputs = outputs_(expr);
404
+ if (!outputs.empty() && allowed_direction_ != Direction::Forward &&
405
+ std::all_of(
406
+ outputs.begin(), outputs.end(), [&](const ValT& output) -> bool {
407
+ return isDependencySatisfied(output);
408
+ })) {
409
+ std::vector<NodeType> prev_nodes;
410
+ std::copy_if(
411
+ outputs.begin(),
412
+ outputs.end(),
413
+ std::back_inserter(prev_nodes),
414
+ [&](const ValT& output) -> bool { return isVisited(output); });
415
+ return std::make_pair(Direction::Backward, prev_nodes);
416
+ }
417
+
418
+ return std::nullopt;
419
+ }
420
+
421
+ // Check if a val is ready to visit. Either its defining or use
422
+ // expr must have its dependency satisfied. If ready because
423
+ // there's a visited defining expr, return Direction::Forward and
424
+ // the defining expr. If ready because there's a visited use expr, return
425
+ // Direction::Backward and the use expr.
426
+ virtual std::optional<std::pair<Direction, std::vector<NodeType>>> isReady(
427
+ const ValT& v) const {
428
+ // In the case of Val, requires just one def or use expr.
429
+ // Check if any use is visited
430
+ decltype(auto) uses = uses_(v);
431
+ if (!uses.empty()) {
432
+ auto it = std::find_if(
433
+ uses.begin(), uses.end(), [&](const ExprT& use_e) -> bool {
434
+ return isDependencySatisfied(use_e);
435
+ });
436
+ if (it != uses.end()) {
437
+ return std::make_pair(Direction::Backward, std::vector<NodeType>{*it});
438
+ }
439
+ }
440
+ // Check if any def is visited
441
+ decltype(auto) def = definition_(v);
442
+ if (!def.empty()) {
443
+ auto it =
444
+ std::find_if(def.begin(), def.end(), [&](const ExprT& def_e) -> bool {
445
+ return isDependencySatisfied(def_e);
446
+ });
447
+ if (it != def.end()) {
448
+ return std::make_pair(Direction::Forward, std::vector<NodeType>{*it});
449
+ }
450
+ }
451
+
452
+ return std::nullopt;
453
+ }
454
+
455
+ // If another node depends on a given node, check if that
456
+ // dependency is considered satisfied. If the given node is already
457
+ // visited, that should mean the dependency is satisfied.
458
+ virtual bool isDependencySatisfied(const NodeType& dependency) const {
459
+ return isVisited(dependency);
460
+ }
461
+
462
+ // Check if a given node is already visited
463
+ virtual bool isVisited(const NodeType& node) const {
464
+ return visited_.find(node) != visited_.end();
465
+ }
466
+
467
+ // Mark a node as visited
468
+ virtual void setVisited(const NodeType& node) {
469
+ visited_.emplace(node);
470
+ }
471
+
472
+ // Add new neighbors of a given node to the to_visit list
473
+ virtual void addNewNeighbors(const NodeType& node) {
474
+ auto add_to_visit_list = [&](const NodeType& n) -> void {
475
+ if (isVisited(n) || excludeFromTraversal(n)) {
476
+ return;
477
+ }
478
+ to_visit_.emplace_back(n);
479
+ };
480
+
481
+ if (const ExprT* e = std::get_if<ExprT>(&node)) {
482
+ if (allowed_direction_ == Direction::Backward ||
483
+ allowed_direction_ == Direction::Undefined) {
484
+ for (const auto& v : inputs_(*e)) {
485
+ add_to_visit_list(v);
486
+ }
487
+ }
488
+ if (allowed_direction_ == Direction::Forward ||
489
+ allowed_direction_ == Direction::Undefined) {
490
+ for (const auto& v : outputs_(*e)) {
491
+ add_to_visit_list(v);
492
+ }
493
+ }
494
+ } else if (const ValT* v = std::get_if<ValT>(&node)) {
495
+ if (allowed_direction_ == Direction::Forward ||
496
+ allowed_direction_ == Direction::Undefined) {
497
+ for (const auto& e : uses_(*v)) {
498
+ add_to_visit_list(e);
499
+ }
500
+ }
501
+ if (allowed_direction_ == Direction::Backward ||
502
+ allowed_direction_ == Direction::Undefined) {
503
+ for (const auto& e : definition_(*v)) {
504
+ add_to_visit_list(e);
505
+ }
506
+ }
507
+ } else {
508
+ NVF_THROW();
509
+ }
510
+ }
511
+
512
+ // Check if all to_ are visited
513
+ virtual bool allToNodesVisited() const {
514
+ return std::all_of(
515
+ to_.begin(), to_.end(), [&](const NodeType& node) -> bool {
516
+ return isVisited(node);
517
+ });
518
+ };
519
+
520
+ // Set the previous nodes of a given node that is visited in a
521
+ // given direction
522
+ virtual void setPrevGroups(
523
+ const NodeType& node,
524
+ const std::pair<Direction, std::vector<NodeType>>& prev_nodes) {
525
+ NVF_ERROR(
526
+ prev_nodes_.emplace(node, prev_nodes).second,
527
+ "Previous node already set for ",
528
+ toString(node));
529
+ }
530
+
531
+ // Hook to exclude certain graph nodes. See IndexingTraversal for a
532
+ // concrete example
533
+ virtual bool excludeFromTraversal(const NodeType& node) const {
534
+ return false;
535
+ }
536
+
537
+ protected:
538
+ const DefinitionT definition_;
539
+ const UsesT uses_;
540
+ const InputsT inputs_;
541
+ const OutputsT outputs_;
542
+ const std::vector<NodeType> from_;
543
+ const std::vector<NodeType> to_;
544
+ std::deque<NodeType> to_visit_;
545
+ std::unordered_set<NodeType> visited_;
546
+ std::unordered_map<NodeType, std::pair<Direction, std::vector<NodeType>>>
547
+ prev_nodes_;
548
+ bool require_all_to_visited_ = true;
549
+ Direction allowed_direction_ = Direction::Undefined;
550
+ };
551
+
552
+ // Unlike the default BFS behavior, Expr is considered ready to
553
+ // visit as long as one of the inputs or outputs has any of its dependencies met
554
+ template <
555
+ typename ExprT,
556
+ typename ValT,
557
+ typename DefinitionT,
558
+ typename UsesT,
559
+ typename InputsT,
560
+ typename OutputsT>
561
+ class BFSWithPermissiveDependence
562
+ : public BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT> {
563
+ public:
564
+ using BFSBaseType = BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>;
565
+ using NodeType = typename BFSBaseType::NodeType;
566
+
567
+ BFSWithPermissiveDependence(
568
+ DefinitionT definition,
569
+ UsesT uses,
570
+ InputsT inputs,
571
+ OutputsT outputs,
572
+ std::vector<NodeType> from,
573
+ std::vector<NodeType> to,
574
+ bool require_all_to_visited = true,
575
+ Direction allowed_direction = Direction::Undefined)
576
+ : BFSBaseType(
577
+ definition,
578
+ uses,
579
+ inputs,
580
+ outputs,
581
+ std::move(from),
582
+ std::move(to),
583
+ require_all_to_visited,
584
+ allowed_direction) {}
585
+
586
+ std::optional<std::pair<Direction, std::vector<NodeType>>> isReady(
587
+ const ExprT& expr) const override {
588
+ // Either any inputs or any outputs must have been visited
589
+ decltype(auto) inputs = this->inputs_(expr);
590
+ if (!inputs.empty() && this->allowed_direction_ != Direction::Backward &&
591
+ std::any_of(
592
+ inputs.begin(), inputs.end(), [&](const ValT& input) -> bool {
593
+ return this->isDependencySatisfied(input);
594
+ })) {
595
+ std::vector<NodeType> prev_nodes;
596
+ std::copy_if(
597
+ inputs.begin(),
598
+ inputs.end(),
599
+ std::back_inserter(prev_nodes),
600
+ [&](const ValT& input) -> bool { return this->isVisited(input); });
601
+ return std::make_pair(Direction::Forward, prev_nodes);
602
+ }
603
+
604
+ decltype(auto) outputs = this->outputs_(expr);
605
+ if (!outputs.empty() && this->allowed_direction_ != Direction::Forward &&
606
+ std::any_of(
607
+ outputs.begin(), outputs.end(), [&](const ValT& output) -> bool {
608
+ return this->isDependencySatisfied(output);
609
+ })) {
610
+ std::vector<NodeType> prev_nodes;
611
+ std::copy_if(
612
+ outputs.begin(),
613
+ outputs.end(),
614
+ std::back_inserter(prev_nodes),
615
+ [&](const ValT& output) -> bool { return this->isVisited(output); });
616
+ return std::make_pair(Direction::Backward, prev_nodes);
617
+ }
618
+ return std::nullopt;
619
+ }
620
+
621
+ // When adding new neighbors of an expr node, if any of inputs is
622
+ // the previous node of this expr, then don't add the remaining
623
+ // inputs to the to-visit list. Similary, if any of the outputs is
624
+ // the previous node of this expr, don't add the remaining
625
+ // outputs. See BFSTest.IRBFSPermissiveTraversal2 for a concrete
626
+ // example.
627
+ void addNewNeighbors(const NodeType& node) override {
628
+ const ExprT* e = std::get_if<ExprT>(&node);
629
+ if (e == nullptr) {
630
+ BFSBaseType::addNewNeighbors(node);
631
+ return;
632
+ }
633
+
634
+ auto add_to_visit_list = [&](const NodeType& n) -> void {
635
+ if (this->isVisited(n) || this->excludeFromTraversal(n)) {
636
+ return;
637
+ }
638
+ this->to_visit_.emplace_back(n);
639
+ };
640
+
641
+ auto prev_nodes_it = this->prev_nodes_.find(node);
642
+
643
+ auto is_any_already_visited = [&](const auto& inputs_or_outputs) -> bool {
644
+ if (prev_nodes_it == this->prev_nodes_.end()) {
645
+ return false;
646
+ }
647
+
648
+ const std::vector<NodeType>& prev_nodes = prev_nodes_it->second.second;
649
+
650
+ return std::any_of(
651
+ inputs_or_outputs.begin(),
652
+ inputs_or_outputs.end(),
653
+ [&](const auto& input_or_output) {
654
+ return std::find(
655
+ prev_nodes.begin(),
656
+ prev_nodes.end(),
657
+ NodeType(input_or_output)) != prev_nodes.end();
658
+ });
659
+ };
660
+
661
+ if (this->allowed_direction_ == Direction::Backward ||
662
+ this->allowed_direction_ == Direction::Undefined) {
663
+ // There's an input node that is marked as a previous node of
664
+ // this node. Since this is permissive traversal, some of the
665
+ // other inputs may not be visited yet, but going back to
666
+ // the input nodes doesn't seem to make sense
667
+ auto input_nodes = this->inputs_(*e);
668
+ if (!is_any_already_visited(input_nodes)) {
669
+ for (const auto& v : input_nodes) {
670
+ add_to_visit_list(v);
671
+ }
672
+ }
673
+ }
674
+ if (this->allowed_direction_ == Direction::Forward ||
675
+ this->allowed_direction_ == Direction::Undefined) {
676
+ auto output_nodes = this->outputs_(*e);
677
+ if (!is_any_already_visited(output_nodes)) {
678
+ for (const auto& v : output_nodes) {
679
+ add_to_visit_list(v);
680
+ }
681
+ }
682
+ }
683
+ }
684
+ };
685
+
686
+ // Find the shortest path from the from vals to the to
687
+ // vals. Dependency between vals and exprs must be satisfied.
688
+ // It is an error if no valid path is found unless
689
+ // require_all_to_visited is false.
690
+ template <typename BFSType, typename... AdditionalArgs>
691
+ static std::pair<typename BFSType::ExprPath, bool> getExprsBetween(
692
+ const std::vector<typename BFSType::ValType>& from,
693
+ const std::vector<typename BFSType::ValType>& to,
694
+ bool require_all_to_visited = true,
695
+ Direction allowed_direction = Direction::Undefined,
696
+ const AdditionalArgs&... additional_args) {
697
+ BFSType bfs(
698
+ additional_args...,
699
+ {from.begin(), from.end()},
700
+ {to.begin(), to.end()},
701
+ require_all_to_visited,
702
+ allowed_direction);
703
+ bfs.traverse();
704
+ return bfs.getShortestExprPath();
705
+ }
706
+
707
+ template <typename ExprT, typename InputsT, typename OutputsT>
708
+ std::vector<typename GetValType<ExprT>::type> getInputsOfExprPath(
709
+ const std::vector<std::pair<ExprT, Direction>>& path,
710
+ InputsT get_inputs,
711
+ OutputsT get_outputs) {
712
+ using ValT = typename GetValType<ExprT>::type;
713
+ std::vector<ValT> inputs;
714
+ std::unordered_set<ValT> all_outputs;
715
+
716
+ for (const auto& [expr, dir] : path) {
717
+ for (const auto& inp :
718
+ getInputsOfExpr(expr, dir, get_inputs, get_outputs)) {
719
+ if (all_outputs.find(inp) == all_outputs.end()) {
720
+ inputs.push_back(inp);
721
+ }
722
+ }
723
+ for (const auto& out :
724
+ getOutputsOfExpr(expr, dir, get_inputs, get_outputs)) {
725
+ all_outputs.emplace(out);
726
+ }
727
+ }
728
+
729
+ return inputs;
730
+ }
731
+
732
+ template <typename ExprT, typename InputsT, typename OutputsT>
733
+ std::vector<typename GetValType<ExprT>::type> getOutputsOfExprPath(
734
+ const std::vector<std::pair<ExprT, Direction>>& path,
735
+ InputsT get_inputs,
736
+ OutputsT get_outputs) {
737
+ return getInputsOfExprPath(reverse(path), get_inputs, get_outputs);
738
+ }
739
+
740
+ // Given a set of exprs and vals, get all reachable ones from another set of
741
+ // nodes
742
+ template <typename BFSType, typename... AdditionalArgs>
743
+ std::vector<typename BFSType::NodeType> getReachableNodesFrom(
744
+ const std::vector<typename BFSType::NodeType>& from,
745
+ const std::vector<typename BFSType::NodeType>& nodes,
746
+ Direction allowed_direction = Direction::Undefined,
747
+ const AdditionalArgs&... additional_args) {
748
+ BFSType bfs(
749
+ additional_args...,
750
+ from,
751
+ nodes,
752
+ /*require_all_to_visited=*/false,
753
+ allowed_direction);
754
+
755
+ bfs.traverse();
756
+
757
+ std::vector<typename BFSType::NodeType> reachable_nodes;
758
+ for (const auto& node : nodes) {
759
+ if (bfs.isVisited(node) ||
760
+ std::find(from.begin(), from.end(), node) != from.end()) {
761
+ reachable_nodes.push_back(node);
762
+ }
763
+ }
764
+
765
+ return reachable_nodes;
766
+ }
767
+
768
+ // Shortcut of getReachableNodesFrom for Vals
769
+ template <typename BFSType, typename... AdditionalArgs>
770
+ std::vector<typename BFSType::ValType> getReachableValsFrom(
771
+ const std::vector<typename BFSType::ValType>& from,
772
+ const std::vector<typename BFSType::ValType>& vals,
773
+ Direction allowed_direction = Direction::Undefined,
774
+ const AdditionalArgs&... additional_args) {
775
+ auto reachable_nodes = getReachableNodesFrom<BFSType, AdditionalArgs...>(
776
+ {from.begin(), from.end()},
777
+ {vals.begin(), vals.end()},
778
+ allowed_direction,
779
+ additional_args...);
780
+
781
+ std::vector<typename BFSType::ValType> reachable_vals;
782
+ reachable_vals.reserve(reachable_nodes.size());
783
+ std::transform(
784
+ reachable_nodes.begin(),
785
+ reachable_nodes.end(),
786
+ std::back_inserter(reachable_vals),
787
+ [](const auto& node) {
788
+ return std::get<typename BFSType::ValType>(node);
789
+ });
790
+
791
+ return reachable_vals;
792
+ }
793
+
794
+ // Traverse from a given set of vals to another set of vals and
795
+ // return all vals between them. Note that if none of the Vals in the
796
+ // second set is reachable, nothing will be returned. For example,
797
+ // if a forward Merge needs to be traversed to get to the target Val
798
+ // set, both of the two inputs must be given or reachable from the
799
+ // given starting Val set.
800
+ //
801
+ // NOTE: getValsBetween(from, to) != getValsBetween(to, from). For
802
+ // example, suppose from={i0}, to={i2}, and merge(i0, i1) =
803
+ // i2. Since i1 is missing, nothing will be returned. However, if
804
+ // from={i2} and to={i0}, then the backward merge can be traversed
805
+ // as its sole input is available, so {i0} would be returned.
806
+ template <typename BFSType, typename... AdditionalArgs>
807
+ std::vector<typename BFSType::ValType> getValsBetween(
808
+ const std::vector<typename BFSType::ValType>& from,
809
+ const std::vector<typename BFSType::ValType>& to,
810
+ const AdditionalArgs&... additional_args) {
811
+ using ValType = typename BFSType::ValType;
812
+ auto path = getExprsBetween<BFSType>(
813
+ from,
814
+ to,
815
+ /*require_all_to_visited=*/false,
816
+ /*allowed_direction=*/Direction::Undefined,
817
+ additional_args...)
818
+ .first;
819
+
820
+ VectorOfUniqueEntries<ValType> unique_vals;
821
+ for (auto [expr, dir] : path) {
822
+ unique_vals.pushBack(getInputsOfExpr(
823
+ expr,
824
+ dir,
825
+ // This assumes get_inputs and get_outputs take the same
826
+ // additional arguments, which is the case with
827
+ // ValGraphBFS. Revisit if needed.
828
+ typename BFSType::InputsType(additional_args...),
829
+ typename BFSType::OutputsType(additional_args...)));
830
+ unique_vals.pushBack(getOutputsOfExpr(
831
+ expr,
832
+ dir,
833
+ typename BFSType::InputsType(additional_args...),
834
+ typename BFSType::OutputsType(additional_args...)));
835
+ }
836
+
837
+ // If a val in from is found in to, just copy it to the returned val
838
+ // set since there's no corresponding expr.
839
+ for (const auto& from_val : from) {
840
+ if (std::find(to.begin(), to.end(), from_val) != to.end()) {
841
+ unique_vals.pushBack(from_val);
842
+ }
843
+ }
844
+
845
+ return unique_vals.vector();
846
+ }
847
+
848
+ // Get all dependencies of to in from.
849
+ template <typename BFSType, typename... AdditionalArgs>
850
+ std::vector<typename BFSType::ValType> getDependenciesTo(
851
+ const std::vector<typename BFSType::ValType>& vals,
852
+ const std::vector<typename BFSType::ValType>& to) {
853
+ using ValType = typename BFSType::ValType;
854
+ auto path = getExprsBetween<BFSType>(
855
+ vals,
856
+ to,
857
+ /*require_all_to_visited=*/true,
858
+ /*allowed_direction=*/Direction::Undefined)
859
+ .first;
860
+
861
+ VectorOfUniqueEntries<ValType> unique_vals;
862
+
863
+ std::unordered_set<ValType> val_set{vals.begin(), vals.end()};
864
+
865
+ for (const auto& [expr, direction] : path) {
866
+ auto inputs =
867
+ (direction == Direction::Forward) ? expr->inputs() : expr->outputs();
868
+ for (auto val : inputs) {
869
+ if (val_set.find(val) != val_set.end()) {
870
+ unique_vals.pushBack(val);
871
+ }
872
+ }
873
+ }
874
+
875
+ return unique_vals.vector();
876
+ }
877
+
878
+ // Given `from`, project it to `to`. This function will return a subset of
879
+ // `to` that is connected to `from`.
880
+ template <typename BFSType, typename... AdditionalArgs>
881
+ std::unordered_set<typename BFSType::ValType> projectTo(
882
+ const typename BFSType::ValType& from,
883
+ const std::vector<typename BFSType::ValType>& to,
884
+ Direction allowed_direction = Direction::Undefined,
885
+ const AdditionalArgs&... additional_args) {
886
+ using ValType = typename BFSType::ValType;
887
+ std::unordered_set<ValType> projection{from};
888
+ // Reverse order
889
+ auto exprs = getExprsBetween<BFSType>(
890
+ {to},
891
+ {from},
892
+ /*require_all_to_visited=*/false,
893
+ allowed_direction,
894
+ additional_args...)
895
+ .first;
896
+ while (!exprs.empty()) {
897
+ const auto& [expr, direction] = exprs.back();
898
+ exprs.pop_back();
899
+ auto from = getOutputsOfExpr(
900
+ expr,
901
+ direction,
902
+ typename BFSType::InputsType(additional_args...),
903
+ typename BFSType::OutputsType(additional_args...));
904
+ auto to = getInputsOfExpr(
905
+ expr,
906
+ direction,
907
+ typename BFSType::InputsType(additional_args...),
908
+ typename BFSType::OutputsType(additional_args...));
909
+
910
+ for (const auto& g : from) {
911
+ if (projection.count(g)) {
912
+ projection.erase(g);
913
+ projection.insert(to.begin(), to.end());
914
+ }
915
+ }
916
+ }
917
+ // Remove items that are not in `to`. This could happen if `from` is not
918
+ // connected to `to`.
919
+ for (auto it = projection.begin(); it != projection.end();) {
920
+ if (std::find(to.begin(), to.end(), *it) == to.end()) {
921
+ it = projection.erase(it);
922
+ } else {
923
+ ++it;
924
+ }
925
+ }
926
+ return projection;
927
+ }
928
+
929
+ } // namespace nvfuser