nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,751 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <debug.h>
11
+ #include <exceptions.h>
12
+ #include <fusion.h>
13
+ #include <ir/base_nodes.h>
14
+ #include <options.h>
15
+ #include <scheduler/all_schedulers.h>
16
+ #include <scheduler/registry.h>
17
+ #include <scheduler/runtime_info.h>
18
+ #include <utils.h>
19
+ #include <visibility.h>
20
+
21
+ #include <deque>
22
+ #include <list>
23
+ #include <unordered_set>
24
+ #include <vector>
25
+
26
+ namespace nvfuser {
27
+
28
// Forward declarations. SegmentedFusion is declared here explicitly because
// SegmentedGroup (defined below) refers to SegmentedFusion* before the
// SegmentedFusion class definition appears later in this header; declaring it
// locally keeps the header from depending on a transitive include for it.
class SegmentedGroup;
class SegmentedFusion;
class SegmentCandidateFinder;
30
+
31
+ // A directed edge on DAG,
32
+ // Wrapper for values, edges between segmented groups which are made up
33
+ // of Exprs. Multiple edges can exist between segmented groups.
34
+ struct SegmentedEdge {
35
+ SegmentedEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val)
36
+ : from(from), to(to), val(val) {}
37
+
38
+ SegmentedGroup* from;
39
+ SegmentedGroup* to;
40
+ Val* val;
41
+
42
+ void print() const;
43
+ };
44
+
45
+ std::ostream& operator<<(std::ostream& os, const SegmentedEdge* edge);
46
+
47
+ //! Groups together expressions which create a segmented group
48
+ //! Can be used to produce fusions
49
+ class SegmentedGroup {
50
+ public:
51
+ //! Utility struct to represent a group connection
52
+ //! both the group to connect with and the edge
53
+ //! to connect through
54
+ struct NeighborGroup {
55
+ NeighborGroup(SegmentedGroup* g, SegmentedEdge* e) : group(g), edge(e) {}
56
+ SegmentedGroup* group;
57
+ SegmentedEdge* edge;
58
+ };
59
+
60
+ SegmentedGroup(SegmentedFusion* segmented_fusion)
61
+ : segmented_fusion_(segmented_fusion) {}
62
+
63
+ SegmentedGroup(Expr* expr, SegmentedFusion* segmented_fusion)
64
+ : segmented_fusion_(segmented_fusion) {
65
+ exprs_.push_back(expr);
66
+ }
67
+
68
+ //! Create a temporary group to signify a fusion input, which can be
69
+ //! an original fusion input or a forwarded input with unary-only
70
+ //! use chains
71
+ SegmentedGroup(SegmentedFusion* segmented_fusion, bool is_fusion_input)
72
+ : is_fusion_input_(is_fusion_input),
73
+ segmented_fusion_(segmented_fusion) {}
74
+
75
+ //! Serialize SegmentedGroup using flatbuffers
76
+ flatbuffers::Offset<serde::SegmentedGroup> serialize(
77
+ flatbuffers::FlatBufferBuilder& builder,
78
+ const std::unordered_map<Val*, int64_t>& vals_map,
79
+ const std::unordered_map<Expr*, int64_t>& exprs_map,
80
+ const std::unordered_map<SegmentedGroup*, int64_t>& groups_map,
81
+ const std::unordered_map<SegmentedEdge*, int64_t>& edges_map) const;
82
+
83
+ //! Deserialize SegmentedGroup using flatbuffers
84
+ void deserialize(
85
+ const serde::SegmentedGroup* buffer,
86
+ const std::deque<Val*>& vals,
87
+ const std::deque<Expr*>& exprs,
88
+ const std::vector<SegmentedGroup*>& groups,
89
+ const std::vector<SegmentedEdge*>& edges);
90
+
91
+ //! Checks if this group takes original fusion's input
92
+ bool isInputGroup() {
93
+ return !input_vals.empty();
94
+ };
95
+
96
+ //! Checks if this group is used any where in the segmented fusion
97
+ bool isConnected() const {
98
+ return !producer_edges.empty() || !consumer_edges.empty() ||
99
+ !output_vals.empty();
100
+ }
101
+
102
+ //! returns the id assigned by segment pass
103
+ int groupId() const {
104
+ return group_id_;
105
+ }
106
+
107
+ //! Returns inputs that this group shares with the original fusion
108
+ const auto& inputs() const {
109
+ return input_vals;
110
+ }
111
+
112
+ //! Returns outputs that this group shares with the original fusion
113
+ const auto& outputs() const {
114
+ return output_vals;
115
+ }
116
+
117
+ //! Returns the schedule heuristic associated with this group
118
+ SchedulerType schedulerType() const {
119
+ return scheduler_type_;
120
+ }
121
+
122
+ //! Returns the exprs that make up this group
123
+ const std::vector<Expr*>& exprs() const {
124
+ return exprs_;
125
+ }
126
+
127
+ //! Returns the complete fusion inputs mapped to this segmented group's fusion
128
+ const auto& getCompleteFusionInputs() const {
129
+ return original_inputs_in_cloned_fusion_;
130
+ }
131
+
132
+ //! Returns cloned fusion for this segmented group.
133
+ //! TODO Replace read-only uses of makeFusion with cached getFusion
134
+ Fusion* getFusion() {
135
+ // Build cloned fusion for this segmented group
136
+ if (cloned_fusion_ == nullptr) {
137
+ makeClonedFusion();
138
+ }
139
+ return cloned_fusion_.get();
140
+ }
141
+
142
+ //! Debug print function
143
+ void print() const;
144
+
145
+ //! Utility to re-collect the operators included in this
146
+ //! segmented group after updating the group boundary.
147
+ void resetExprList();
148
+
149
+ //! Try to get a scheduler entry for this group with
150
+ //! the given runtime info.
151
+ //! Returns a new scheduler with the same heuristics
152
+ //! for this group if possible.
153
+ //! Note that the schedule params can be different.
154
+ //! Returns a nullopt if this group cannot be scheduled
155
+ //! with the same heuristics.
156
+ std::optional<std::unique_ptr<HeuristicParams>> getMaybeHeuristicParams(
157
+ SchedulerRuntimeInfo& runtime_info);
158
+
159
+ //! Query if this is a group for a fusion input
160
+ bool isFusionInputGroup() const;
161
+
162
+ public:
163
+ //! "Ancestor nodes", towards inputs of segmentedDAG
164
+ std::vector<SegmentedEdge*> producer_edges;
165
+
166
+ //! "Descendent nodes", towards outputs of segmentedDAG
167
+ std::vector<SegmentedEdge*> consumer_edges;
168
+
169
+ //! Composite Fusion inputs in this group
170
+ std::vector<Val*> input_vals;
171
+
172
+ //! Composite Fusion outputs in this group
173
+ std::vector<Val*> output_vals;
174
+
175
+ bool isMerged() const {
176
+ return merged_;
177
+ }
178
+
179
+ private:
180
+ friend class SegmentCandidateFinder;
181
+ friend class SegmentedFusion;
182
+ friend class FusionKernelRuntime;
183
+ friend class TranslateApplicableWelford;
184
+
185
+ //! unique identifier of group in the segmented fusion
186
+ int group_id_ = -1;
187
+
188
+ //! The scheduler to use for compiling this group
189
+ SchedulerType scheduler_type_ = SchedulerType::None;
190
+
191
+ //! Exprs that make up the group
192
+ std::vector<Expr*> exprs_;
193
+
194
+ //! Maximum path distance from an input segmented group required for
195
+ //! Theorem 4.2
196
+ int level_ = -1;
197
+
198
+ //! traversal marker, has this node already been processed
199
+ bool visited_ = false;
200
+
201
+ //! Did we select another group to merge with
202
+ SegmentedGroup* merge_with_ = nullptr;
203
+
204
+ //! if we selected another group to merge, which edge is to be contracted
205
+ SegmentedEdge* merge_through_ = nullptr;
206
+
207
+ //! Has this node been merged?
208
+ bool merged_ = false;
209
+
210
+ //! Is a group for a fusion input?
211
+ bool is_fusion_input_ = false;
212
+
213
+ private:
214
+ //! Utility to convert edge vector to value vector
215
+ std::vector<Val*> edgesToVals(const std::vector<SegmentedEdge*>& se_v);
216
+
217
+ //! Reset method to call at begining of each
218
+ //! merge node iteration
219
+ void clearTraversalInfo();
220
+
221
+ //! To be called at the very end of segment fusion
222
+ //! no more segment merging should be done beyond
223
+ void finalize();
224
+
225
+ //! Make the cloned fusion for this segmented group
226
+ void makeClonedFusion();
227
+
228
+ //! Return all segmented groups connected with *this
229
+ std::vector<SegmentedGroup*> getNeighbors();
230
+
231
+ //! TODO: May want to sort this based on size of connections between this and
232
+ //! neighbors as well as if the connection is an output of the fusion (has to
233
+ //! be saved to gmem anyways)
234
+ std::vector<NeighborGroup> getNeighborGroups();
235
+
236
+ //! Look at all neighbors of this and return who this could merge with based
237
+ //! on level values of this, neighbors, and merged neighbors of neighbors
238
+ std::vector<NeighborGroup> getMergeCandidates();
239
+
240
+ //! Assign scheduler type to this group
241
+ void setSchedulerType(SchedulerType scheduler_type) {
242
+ scheduler_type_ = scheduler_type;
243
+ }
244
+
245
+ //! Assign Id for this group
246
+ void setID(int id) {
247
+ NVF_ERROR(group_id_ == -1);
248
+ group_id_ = id;
249
+ }
250
+
251
+ //! SegmentedFusion this group belongs to
252
+ SegmentedFusion* segmented_fusion_;
253
+
254
+ //! The cloned segmented fusion
255
+ std::unique_ptr<Fusion> cloned_fusion_;
256
+
257
+ //! These are the complete fusion's inputs mapped to the cloned fusion
258
+ std::vector<Val*> original_inputs_in_cloned_fusion_;
259
+ };
260
+
261
+ std::ostream& operator<<(std::ostream& os, const SegmentedGroup* group);
262
+
263
+ //! Exported Interface for representing segmented fusion graph
264
+ //! this class owns the segmented groups
265
+ class SegmentedFusion {
266
+ public:
267
+ explicit SegmentedFusion(std::unique_ptr<Fusion> fusion);
268
+
269
+ //! Factory function for the un-segmented case, directly
270
+ //! constructs a "SegmentedFusion", with the given Fusion
271
+ //! as the only group.
272
+ static std::unique_ptr<SegmentedFusion> fromCompleteFusion(
273
+ std::unique_ptr<Fusion> fusion,
274
+ SchedulerType scheduler_type,
275
+ const KernelArgumentHolder& runtime_inputs);
276
+
277
+ //! Is the fusion segmented?
278
+ bool isSegmented() const {
279
+ return !groups_.empty();
280
+ }
281
+
282
+ std::vector<SegmentedGroup*>& groups() {
283
+ return groups_;
284
+ }
285
+
286
+ const std::vector<SegmentedGroup*>& groups() const {
287
+ return groups_;
288
+ }
289
+
290
+ std::vector<SegmentedEdge*>& edges() {
291
+ return edges_;
292
+ }
293
+
294
+ const std::vector<SegmentedGroup*>& cgroups() const {
295
+ return groups_;
296
+ }
297
+
298
+ const std::vector<SegmentedEdge*>& cedges() const {
299
+ return edges_;
300
+ }
301
+
302
+ //! Returns the original un-segmented fusion
303
+ Fusion* completeFusion() const {
304
+ return complete_fusion_.get();
305
+ }
306
+
307
+ const auto& inputs() const {
308
+ return complete_fusion_->inputs();
309
+ }
310
+
311
+ const auto& outputs() const {
312
+ return complete_fusion_->outputs();
313
+ }
314
+
315
+ //! Get the fusion for the segmented group and return the IrCloner used to
316
+ //! clone the complete fusion
317
+ std::pair<IrCloner, std::unique_ptr<Fusion>> makeFusion(SegmentedGroup* sg);
318
+
319
+ //! Make a heuristics entry for a group and parameters
320
+ std::unique_ptr<HeuristicParams> makeInitialHeuristicParams(
321
+ SegmentedGroup* sg,
322
+ SchedulerRuntimeInfo& runtime_info);
323
+
324
+ //! Debug drawing for graphviz
325
+ void draw();
326
+
327
+ //! Debug print for segmented fusions
328
+ void print() const;
329
+
330
+ //! API for adding groups
331
+ SegmentedGroup* newGroup();
332
+
333
+ //! API shortcut for adding a singleton group
334
+ SegmentedGroup* newGroup(Expr* expr);
335
+
336
+ //! API shortcut for adding a new group for a fusion input
337
+ SegmentedGroup* newFusionInputGroup();
338
+
339
+ //! API for adding edges
340
+ SegmentedEdge* newEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val);
341
+
342
+ HeuristicDataCache* getCachedHeuristicDataFor(SegmentedGroup* group);
343
+
344
+ //! Lower FP precision of inputs and outputs specified by the given
345
+ //! edges.
346
+ //!
347
+ //! This function is used in two scenarios. One is when testing a
348
+ //! merge of groups during the segmentation time. At that time,
349
+ //! those groups are not yet merged, but we want to consider them as
350
+ //! merged and see if there's a valid scheduler. So, we treat the
351
+ //! groups given by groups_to_merge as a single group and insert
352
+ //! cast ops into the group. No other group is modified unless it
353
+ //! has an edge to any of the merged groups.
354
+ //!
355
+ //! The second scenario is when inserting cast ops to a whole
356
+ //! segmented fusion. All groups are considered separate groups with
357
+ //! no (temporary) merging. Each edge is considered a potential
358
+ //! place to insert cast. In this case, groups_to_merge should be
359
+ //! empty.
360
+ std::vector<SegmentedEdge*> castInputOutputToLowerPrecision(
361
+ const std::vector<SegmentedEdge*>& edges,
362
+ const std::vector<SegmentedGroup*>& groups_to_merge = {});
363
+
364
+ //! Revert the changes made by castInputOutputToLowerPrecision to the given
365
+ //! edges
366
+ void revertInputOutputPrecisionChanges(
367
+ const std::vector<SegmentedEdge*>& edges);
368
+
369
+ //! Grab edges with val
370
+ std::vector<SegmentedEdge*> getEdgesByVal(Val* val) const;
371
+
372
+ //! Make sure it's a DAG and optionally disjoint
373
+ void validate(bool require_disjoint = true) const;
374
+
375
+ //! Same as validate but only enabled when NDEBUG is undefined
376
+ void validateIfDebug(bool require_disjoint = true) const;
377
+
378
+ //! Serialize SegmentedFusion using flatbuffers
379
+ flatbuffers::Offset<serde::SegmentedFusion> serialize(
380
+ flatbuffers::FlatBufferBuilder& builder) const;
381
+
382
+ //! Deserialize SegmentedFusion using flatbuffers
383
+ void deserialize(const serde::SegmentedFusion* buffer);
384
+
385
+ private:
386
+ void validateDAG() const;
387
+ void validateDisjoint() const;
388
+
389
+ //! Serialize SegmentedEdge using flatbuffers
390
+ flatbuffers::Offset<serde::SegmentedEdge> serialize(
391
+ flatbuffers::FlatBufferBuilder& builder,
392
+ const nvfuser::SegmentedEdge* edge,
393
+ const std::unordered_map<Val*, int64_t>& vals_map,
394
+ const std::unordered_map<SegmentedGroup*, int64_t>& groups_map) const;
395
+
396
+ //! Deserialize SegmentedEdge using flatbuffers
397
+ nvfuser::SegmentedEdge deserialize(
398
+ const serde::SegmentedEdge* buffer,
399
+ const std::deque<Val*>& vals);
400
+
401
+ private:
402
+ //! Unique name for segmented fusion
403
+ size_t segmented_fusion_name_;
404
+
405
+ //! States representing segmentation
406
+ std::vector<SegmentedEdge*> edges_;
407
+ std::vector<SegmentedGroup*> groups_;
408
+
409
+ //! Owning object to explicitly manage groups and edges
410
+ class Impl {
411
+ public:
412
+ explicit Impl(SegmentedFusion* sf) : owning_fusion_(sf) {}
413
+
414
+ SegmentedGroup* makeGroup();
415
+ SegmentedGroup* makeGroup(Expr*);
416
+ SegmentedGroup* makeFusionInputGroup();
417
+ SegmentedEdge* makeEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val);
418
+ void cleanUnused();
419
+ std::unordered_map<SegmentedGroup*, int64_t> groups_map() const;
420
+ std::unordered_map<SegmentedEdge*, int64_t> edges_map() const;
421
+
422
+ private:
423
+ using GroupPtr = std::unique_ptr<SegmentedGroup>;
424
+ using EdgePtr = std::unique_ptr<SegmentedEdge>;
425
+ std::vector<GroupPtr> groups_;
426
+ std::vector<EdgePtr> edges_;
427
+ SegmentedFusion* owning_fusion_;
428
+ };
429
+ Impl impl_;
430
+
431
+ //! A Copy of original full fusion
432
+ std::unique_ptr<Fusion> complete_fusion_;
433
+
434
+ //! A set of intermediate tensors that need to be cast to fp16
435
+ std::unordered_set<TensorView*> force_fp16_tv_set_;
436
+
437
+ DataType force_half_precision_type_;
438
+
439
+ //! Static traversal information to be used for fast heuristics lookup
440
+ std::unordered_map<SegmentedGroup*, std::unique_ptr<HeuristicDataCache>>
441
+ heuristic_data_cache_;
442
+
443
+ //! The number of values in fusion after constructing segmented fusion.
444
+ //! Used for checking state during deserialization.
445
+ size_t initial_vals_size_;
446
+
447
+ //! The number of expressions in fusion after constructing segmented fusion.
448
+ //! Used for checking state during deserialization.
449
+ size_t initial_exprs_size_;
450
+
451
+ // TODO: this class needs cleanup
452
+ protected:
453
+ friend class SegmentCandidateFinder;
454
+
455
+ //! Cleanup function to be call at the end of fusion
456
+ //! segment pass
457
+ void finalize();
458
+
459
+ //! Collect all the intermediate tensors between segmented
460
+ //! groups that will cast to fp16
461
+ void annotateFP16IntermediateTensors();
462
+
463
+ //! Keep heuristic checking intermediate data
464
+ void setCachedHeuristicDataFor(
465
+ SegmentedGroup* group,
466
+ std::unique_ptr<HeuristicDataCache> data);
467
+
468
+ //! Utility to give unique name for each segmented fusion
469
+ static size_t segmentedFusionName() {
470
+ static size_t counter = 0;
471
+ return counter++;
472
+ }
473
+ };
474
+
475
+ std::ostream& operator<<(
476
+ std::ostream& os,
477
+ const SegmentedFusion* segmented_fusion);
478
+
479
+ //! This is a base class for segmenter analysis
480
+ //! provides the minimal implementation on header so that
481
+ //! a unique_ptr can use this base class
482
+ //! actual implementations of analyses are in the .cpp files
483
+ //! TODO: In the next refactor PR, should put segment candidate
484
+ //! finder in .cpp file completely since API doesn't require these
485
+ //! details
486
+ class SegmenterAnalysis : public PolymorphicBase {};
487
+ class GroupDependencyAnalysis;
488
+
489
+ // Manual node merging passes
490
+ class CombineReductions;
491
+ class MergeUpAndDownCast;
492
+
493
//! Switches used to configure or debug the segment candidate finder.
//! Every run_* flag defaults to true, and only_segment_resharding_exprs
//! defaults to false. The struct stays an aggregate, so brace initialization
//! in declaration order still works.
struct SegmentCandidateFinderOptions {
  //! Presumably gates the Welford translation step (cf.
  //! translateWelfordInFusion). Confirm against the .cpp.
  bool run_translate_welford{true};
  //! Presumably gates the CombineReductions merging pass. Confirm.
  bool run_combine_reductions{true};
  //! Presumably gates merging based on Herrmann et al. Confirm.
  bool run_herrmann_merge{true};
  //! Presumably gates the brute-force final merge (cf. finalMerge). Confirm.
  bool run_final_merge{true};
  //! NOTE(review): judging by the name, this restricts segmentation to
  //! resharding expressions (multidevice). Verify in the .cpp.
  bool only_segment_resharding_exprs{false};
};
501
+
502
+ //! SegmentCandidateFinder
503
+ //! Responsible for going through DAG and proposing things we could try to
504
+ //! fuse together, calls "canGenerateCode" on these proposed segments to see
505
+ //! if they are valid and we can generate code for them.
506
+ //! FusionSegment
507
+ //! A group of exprs that are segmented together
508
+ //! FusionSegmentConnections
509
+ //! Holds vals and what they connect. In other words it's a val that is an
510
+ //! output of a FusionSegment "from" and an input of FusionSegment "to".
511
+ //! There's nothing preventing from a val being between segments twice.
512
+ //! TODO: make sure there's nothing wrong with segmentation on nodes that
513
+ //! have the same value input twice. i.e. (B = A*A)
514
+ //! Selecting segments to propose is based on the theorem 4.2 in the paper which
515
+ //! makes sure when segment the segmented graph will be a DAG (assumes Fusion is
516
+ //! already a DAG). The segmentation code relies on assumptions of DAG-ness
517
+ //! during segmentation, meaning proposed merging of groups must maintain the
518
+ //! DAG property of the graph.
519
+ //!
520
+ //! Julien Herrmann, Yusuf Özkaya, Bora Uçar, Kamer Kaya, Umit Catalyurek.
521
+ //! Multilevel Algorithms for Acyclic Partitioning of Directed Acyclic Graphs.
522
+ //! SIAM Journal on Scientific Computing, Society for Industrial and Applied
523
+ //! Mathematics, 2019, 41 (4), pp.A2117-A2145. ff10.1137/18M1176865ff.
524
+ //! ffhal02306566f
525
+ class SegmentCandidateFinder {
526
+ public:
527
+ // Perform segmentation on a copy of the given fusion
528
+ static std::unique_ptr<SegmentedFusion> segment(
529
+ const Fusion* fusion,
530
+ const KernelArgumentHolder* inputs,
531
+ SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions());
532
+
533
+ // Perform segmentation on and take ownership of the given fusion
534
+ static std::unique_ptr<SegmentedFusion> segment(
535
+ std::unique_ptr<Fusion> fusion,
536
+ const KernelArgumentHolder* inputs,
537
+ SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions());
538
+
539
+ static std::unique_ptr<SegmentedFusion> segment(
540
+ std::unique_ptr<Fusion> fusion,
541
+ const KernelArgumentHolder* inputs,
542
+ SchedulerRuntimeInfo& runtime_info);
543
+
544
+ static bool hasSegmentHints(Fusion* fusion);
545
+
546
+ NVF_API static bool translateWelfordInFusion(
547
+ Fusion* fusion,
548
+ const KernelArgumentHolder& runtime_inputs);
549
+
550
+ private:
551
+ // Perform segmentation on and take ownership of the given fusion
552
+ NVF_API SegmentCandidateFinder(
553
+ std::unique_ptr<Fusion> fusion,
554
+ const KernelArgumentHolder* inputs,
555
+ SegmentCandidateFinderOptions options);
556
+
557
+ void resetTraversal();
558
+
559
+ void resetLevels();
560
+
561
+ SegmentedGroup* mergeNodes();
562
+
563
+ bool codeGenSupportedMerge(SegmentedGroup* group1, SegmentedGroup* group2);
564
+
565
+ void buildInitialSegments();
566
+
567
+ void findSegments();
568
+
569
+ //! Find a group found in candidates that can be merged with the
570
+ //! given group and set them to be merged if found. When no
571
+ //! candidate is given, SegmentedGroup::getMergeCandidates is used
572
+ //! to get candidates.
573
+ void trySetUpMerge(
574
+ SegmentedGroup* group,
575
+ std::vector<SegmentedGroup::NeighborGroup> candidates = {});
576
+
577
+ std::unordered_set<SegmentedEdge*> disconnectGroup(SegmentedGroup* group);
578
+
579
+ std::vector<SegmentedGroup*>& groups() {
580
+ NVF_ERROR(
581
+ segmented_fusion_ != nullptr, "Segment finder not owinging any fusion");
582
+ return segmented_fusion_->groups();
583
+ }
584
+
585
+ std::vector<SegmentedEdge*>& edges() {
586
+ NVF_ERROR(
587
+ segmented_fusion_ != nullptr, "Segment finder not owinging any fusion");
588
+ return segmented_fusion_->edges();
589
+ }
590
+
591
+ Fusion* completeFusion() {
592
+ NVF_ERROR(
593
+ segmented_fusion_ != nullptr, "Segment finder not owinging any fusion");
594
+ return segmented_fusion_->completeFusion();
595
+ }
596
+
597
+ SchedulerRuntimeInfo& runtimeInfo() {
598
+ NVF_ERROR(runtime_info_.has_value(), "needs runtime info");
599
+ return runtime_info_.value();
600
+ }
601
+
602
+ ExpressionEvaluator& expressionEvaluator() {
603
+ return runtimeInfo().expressionEvaluator();
604
+ }
605
+
606
+ //! Additional merging iteration, clean up the rest of
607
+ //! the merging opportunities
608
+ //! Herrmann et al. is a fast and safe algorithm for finding merge candidates
609
+ //! but can become too conservative in our use cases because we place
610
+ //! additional qualifiers on valid merges other than having to generate DAGs,
611
+ //! i.e. canSchedule. So we need a bruteforce final merging iteration as a
612
+ //! clean up pass. Cost isn't expected to be high since the graph at this
613
+ //! stage is already quite merged. Example cf. test_gpu.cpp:
614
+ //! FusionDAGMerging_CUDA
615
+ //!
616
+ //! This merging algorithm is based on Theorem 4.1 of Herrmann et al.,
617
+ //! to check if a producer-consumer pair can be merged into one group,
618
+ //! it's enough to check if any other consumer of the producer also
619
+ //! produces the consumer.
620
+ void finalMerge();
621
+
622
+ //! Duplicate and add all exprs producing the used
623
+ //! scalar values in group
624
+ void resolveScalarsInGroup(SegmentedGroup* group);
625
+
626
+ //! Duplicate and add all exprs from fusion inputs to `forwarded_input` into
627
+ //! the group, to complete inputs. These expressions are simply unary ops of
628
+ //! inputs that we want to recompute for each segment, instead of computing
629
+ //! and producing a segmented val. For example if we have:
630
+ //!
631
+ //! tv1 = tv0 * 2;
632
+ //! tv3 = tv1 + tv2;
633
+ //! tv4 = tv1 + tv4
634
+ //!
635
+ //! If we segmented on tv1, we would be producing an output for tv1 for 2
636
+ //! groups that have tv3 or tv4, instead we could easily recompute tv1 from
637
+ //! tv0.
638
+ void resolveNonscalarForwardedInput(Val* forwarded_input);
639
+
640
+ void resolveForwardedInputs();
641
+
642
+ // Creates the input group that ends at `forwarded_input`, i.e., the region
643
+ // between fusion inputs and `forwarded_input`.
644
+ SegmentedGroup* createInputGroup(Val* forwarded_input);
645
+
646
+ //! Remove all scalar edges in group
647
+ //! (TODO: need structure better so we don't have to do this)
648
+ void removeScalarEdges();
649
+
650
+ //! Utility function to merge a vector of groups in one step,
651
+ //! need to check for DAG condition before using this method
652
+ SegmentedGroup* mergeAllGivenGroups(
653
+ const std::vector<SegmentedGroup*>& groups);
654
+
655
+ //! Utility to remove a group and corresponding edges
656
+ //! TODO: remove inline versions of this as much as possible
657
+ void eraseGroups(std::unordered_set<SegmentedGroup*>& groups_to_erase);
658
+
659
+ void finalize();
660
+
661
+ //! Return the resulting SchedulerType corresponding to the merged
662
+ //! group built by merging the two groups connected by edge
663
+ SchedulerType deriveSchedulerType(SegmentedGroup* edge);
664
+
665
+ GroupDependencyAnalysis* getGroupDependency();
666
+
667
+ //! Find all expresions that are simply unary ops from
668
+ //! inputs. Don't segment
669
+ //! these as they're easy targets for recomputation. Only go until the first
670
+ //! expression that has multiple uses.
671
+ //!
672
+ //! The ending tensors, or the forwarded tensors, are considered
673
+ //! fusion inputs for the sake of segmentation, and the expressions
674
+ //! between the real inputs and the forwarded tensors are excluded
675
+ //! from the segmentation steps until the finalization, at which
676
+ //! point they are simply prepended to each final segment using the
677
+ //! forwarded inputs.
678
+ void forwardInputs();
679
+
680
+ void cleanupForwardedInputs();
681
+
682
+ //! Query if a val is a fusion input or a forwarded input
683
+ bool isFusionInput(Val* val) const {
684
+ return std::find(
685
+ forwarded_fusion_inputs_.begin(),
686
+ forwarded_fusion_inputs_.end(),
687
+ val) != forwarded_fusion_inputs_.end();
688
+ };
689
+
690
+ protected:
691
+ //! These are the merge node heuristic passes, should
692
+ //! eventually should have a dedicated interface
693
+ //! instead of keeping adding friends
694
+ friend class CombineReductions;
695
+ friend class MergeUpAndDownCast;
696
+
697
+ //! options to configure and debug the segment process
698
+ SegmentCandidateFinderOptions options_;
699
+
700
+ std::deque<SegmentedGroup*> to_visit_;
701
+ std::vector<SegmentedGroup*> next_to_visit_;
702
+
703
+ std::unordered_set<SegmentedGroup*> clean_up_groups_;
704
+ std::unordered_set<SegmentedEdge*> clean_up_edges_;
705
+
706
+ std::vector<SegmentedGroup*> to_merge_;
707
+
708
+ std::unique_ptr<SegmentedFusion> segmented_fusion_;
709
+
710
+ std::unique_ptr<SegmenterAnalysis> group_dependency_;
711
+
712
+ //! List of vals to treat as complete fusion inputs for segmentation
713
+ std::vector<Val*> forwarded_fusion_inputs_;
714
+
715
+ //! Keep track of complete fusion input use
716
+ std::unordered_map<Val*, SegmentedGroup*> input2group_;
717
+
718
+ // Expressions to exclude from segmentation because they're just derived from
719
+ // unary ops on inputs to the complete fusion
720
+ VectorOfUniqueEntries<Expr*> excluded_inp_unary_exprs_;
721
+
722
+ // This is allowed to be null in the multidevice case where the segmenter is
723
+ // used for breaking the fusion into compute and communication segments
724
+ std::optional<SchedulerRuntimeInfo> runtime_info_;
725
+
726
+ //! Note:
727
+ //! Segmenter should eventually rely only on runtime_info_ for
728
+ //! safe caching. runtime_inputs_ is only used in translateWelford
729
+ //! to initialize expression evaluators on copies of the original
730
+ //! fusion, which doesn't use any un-cached info and is safe.
731
+ //!
732
+ //! Directly using runtime_inputs_ in other cases is in general
733
+ //! risky.
734
+ //!
735
+ //! To get rid of runtime_inputs_ we need mechanisms
736
+ //! to copy expression evaluator values from fusion
737
+ //! to a copy, or even better to a copy of a
738
+ //! sub-graph of original fusion.
739
+ //! TODO:
740
+ //! implement the expression evaluator transfer and
741
+ //! remove runtime_inputs_ in a follow up.
742
+ const KernelArgumentHolder* runtime_inputs_;
743
+ };
744
+
745
+ // TODO: Make as member functions on classes instead of global scope
746
+ std::string toString(const SegmentedGroup* group);
747
+ std::string toString(const SegmentedEdge* edge);
748
+ std::string toString(const SegmentedFusion* segmented_fusion);
749
+ std::string toString(const SegmentCandidateFinderOptions& segment_options);
750
+
751
+ } // namespace nvfuser