nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/iter_visitor.h
@@ -0,0 +1,661 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <visibility.h>
+
+ #include <bfs.h>
+ #include <dispatch.h>
+ #include <ir/base_nodes.h>
+ #include <type.h>
+
+ #include <deque>
+ #include <unordered_set>
+ #include <vector>
+
+ namespace nvfuser {
+
+ class Fusion;
+
+ /*
+  * IterVisitor starts from leaf nodes, fusion outputs, or the provided values.
+  * It walks the DAG backwards from the starting nodes to the roots. Each node
+  * in the DAG will be called with handle(Statement*) in topological order,
+  * from inputs of the fusion to outputs of the fusion.
+  *
+  * TODO: We may want a BFS version of this code to extract ILP, not implemented
+  * yet.
+  *
+  * TODO: We may want to have ordering of outputs to inputs. I'm not sure why we
+  * would want this, but seems like it would be a reasonable request.
+  */
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+ class NVF_API IterVisitor : public OptOutDispatch {
+  public:
+   ~IterVisitor() override = default;
+
+   IterVisitor() = default;
+
+   IterVisitor(const IterVisitor& other) = default;
+   IterVisitor& operator=(const IterVisitor& other) = default;
+
+   IterVisitor(IterVisitor&& other) = default;
+   IterVisitor& operator=(IterVisitor&& other) = default;
+
+  protected:
+   // Functions return nodes in reverse order to be added to the to_visit
+   // queue. These functions will start at outputs and propagate up through the
+   // DAG to inputs based on depth-first traversal. Next could be called on a
+   // node multiple times.
+   virtual std::vector<Statement*> next(Statement* stmt);
+
+   virtual std::vector<Statement*> next(Val* v);
+
+   virtual std::vector<Statement*> next(Expr* expr);
+
+   using OptOutDispatch::handle;
+
+   // This dispatch function is called on every Statement* in topological
+   // order, starting from outputs to inputs.
+   void dispatch(Statement* s) override;
+
+   // This dispatch function is called on every Expr* in topological order,
+   // starting from outputs to inputs.
+   void dispatch(Expr* e) override;
+
+   // This dispatch function is called on every Val* in topological order,
+   // starting from outputs to inputs.
+   void dispatch(Val* v) override;
+
+   // The entire stack during traversal. stmt_stack.back().back() is the node
+   // that is being called in handle(). stmt_stack.back() contains siblings
+   // (not guaranteed to be all siblings throughout traversal).
+   // stmt_stack.front() contains the outputs we started with (not guaranteed
+   // to be all outputs throughout traversal).
+   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
+   std::vector<std::vector<Statement*>> stmt_stack;
+
+   void traverseHelper(Fusion* fusion, bool traverse_all_paths = false);
+
+  public:
+   //! Traverses nodes in Fusion from inputs in topological order to "to", i.e.
+   //! from inputs towards outputs.
+   //! \param traverse_all_paths = false only calls handle on each Statement*
+   //! once; traverse_all_paths = true traverses all paths between
+   //! expressions/values and calls handle on a Statement* for every path from
+   //! inputs to "to".
+   //! \param traverse_into_members When hitting nodes like TensorView,
+   //! TensorDomain, or IterDomain where there are members of the nodes that
+   //! are Vals, a value of "true" will also traverse into those member Vals; a
+   //! value of "false" will not traverse into the members.
+   //! \param traverse_attributes When true, traverse into expr
+   //! attributes. Note that attributes of template type Attribute are
+   //! not traversed as there's no dispatch support.
+   //! \param traverse_siblings When true, traverse all outputs of
+   //! active multi-output expressions, even if those Expr outputs are not used
+   //! in paths to Fusion outputs.
+   void traverseTo(
+       const std::vector<Val*>& to,
+       bool traverse_all_paths = false,
+       bool traverse_into_members = false,
+       bool traverse_attributes = false,
+       bool traverse_siblings = false);
+
+   //! Traverses nodes in Fusion from inputs in topological order to "to", i.e.
+   //! from inputs towards outputs.
+   //! \param traverse_all_paths = false only calls handle on each Statement*
+   //! once; traverse_all_paths = true traverses all paths between
+   //! expressions/values and calls handle on a Statement* for every path from
+   //! inputs to "to".
+   //! \param traverse_into_members When hitting nodes like TensorView,
+   //! TensorDomain, or IterDomain where there are members of the nodes that
+   //! are Vals, a value of "true" will also traverse into those member Vals; a
+   //! value of "false" will not traverse into the members.
+   //! \param from Specified values to start traversing. If a "from" Val is not
+   //! on a path from inputs to "to" it will not be visited. If there's a path
+   //! from inputs to "to" that doesn't go through "from", that input and the
+   //! path from it will also be traversed.
+   //! \param traverse_attributes When true, traverse into expr
+   //! attributes. Note that attributes of template type Attribute are
+   //! not traversed as there's no dispatch support.
+   //! \param traverse_siblings When true, traverse all outputs of
+   //! active multi-output expressions, even if those Expr outputs are not used
+   //! in paths to Fusion outputs.
+   void traverseBetween(
+       const std::unordered_set<Val*>& from,
+       const std::vector<Val*>& to,
+       bool traverse_all_paths = false,
+       bool traverse_into_members = false,
+       bool traverse_attributes = false,
+       bool traverse_siblings = false);
+
+   // Iterates from terminating outputs registered with the fusion. Terminating
+   // means the value is not used to generate any other value used in producing
+   // registered outputs.
+   void traverse(Fusion* fusion);
+
+   // Same as traverse but it traverses every edge, meaning it will traverse
+   // values more than once.
+   void traverseAllPaths(Fusion* fusion);
+
+   //! Get inputs to vals. Possible input vals can be optionally
+   //! given. If not, vals with no producers are returned.
+   //
+   // TODO: This doesn't seem to fit with IterVisitor. Should probably be moved
+   // out of the class.
+   static std::vector<Val*> getInputsTo(
+       const std::vector<Val*>& vals,
+       const std::vector<Val*>& inputs = {});
+ };
+
+ /*
+  * Backward visitor calls handle in reverse order from outputs to inputs.
+  * It would be really nice to unify this with IterVisitor, however,
+  * the challenge there is that we specify traversal from outputs towards
+  * inputs because it implicitly provides DCE. However, if users are not
+  * careful, they could miss necessary outputs to do a backward traversal.
+  *
+  * BackwardVisitor checks that all outputs of an Expr are visited before
+  * visiting the Expr. If we don't provide nodes to start from on all backward
+  * paths of those outputs, we will never visit the Expr.
+  *
+  * The first step of BackwardVisitor is to make sure we've specified enough
+  * outputs to guarantee that we will traverse all outputs of all exprs during
+  * the backward traversal. For cases where we don't require visiting all
+  * outputs of some exprs (an example being the `N` output of welford ops),
+  * `must_cover_all_expr_outputs` is added to disable the check, and in this
+  * case the visitor pass needs to be aware that:
+  * 1. Exprs in the `from` list with any output that has a use chain that
+  *    ends with a final consumer `will be` visited.
+  * 2. Vals in the `from` list that don't have a use chain that ends with
+  *    a final consumer `will not be` visited, even though their
+  *    definition expr might be visited. An example is if the `N` output
+  *    of a welford op is unused but other outputs are, the welford op
+  *    will be visited but the `N` output will not.
+  */
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+ class BackwardVisitor : public OptOutDispatch {
+  public:
+   // clang-tidy: cppcoreguidelines-virtual-class-destructor
+   ~BackwardVisitor() override = default;
+
+  protected:
+   BackwardVisitor(bool must_cover_all_expr_outputs = true)
+       : must_cover_all_expr_outputs_(must_cover_all_expr_outputs) {}
+
+   BackwardVisitor(const BackwardVisitor& other) = default;
+   BackwardVisitor& operator=(const BackwardVisitor& other) = default;
+
+   BackwardVisitor(BackwardVisitor&& other) = default;
+   BackwardVisitor& operator=(BackwardVisitor&& other) = default;
+
+   // Functions return nodes in reverse order to be added to the to_visit
+   // queue. These functions will start at outputs and propagate up through the
+   // DAG to inputs based on depth-first traversal. Next could be called on a
+   // node multiple times.
+   virtual std::vector<Statement*> next(Statement* stmt);
+
+   virtual std::vector<Statement*> next(Expr* expr);
+
+   virtual std::vector<Statement*> next(Val* val);
+
+   using OptOutDispatch::handle;
+
+   // This dispatch function is called on every Statement* in topological
+   // order, starting from outputs to inputs.
+   // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
+   virtual void dispatch(Statement* stmt) override;
+
+   // This dispatch function is called on every Expr* in topological order,
+   // starting from outputs to inputs.
+   // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
+   virtual void dispatch(Expr* expr) override;
+
+   // This dispatch function is called on every Val* in topological order,
+   // starting from outputs to inputs.
+   // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
+   virtual void dispatch(Val* val) override;
+
+   // All exprs that need to be visited in this traversal. Labeled in
+   // topological order (size_t).
+   std::unordered_map<Expr*, size_t> traversal_exprs_;
+
+   // The entire stack during traversal. stmt_stack.back().back() is the node
+   // that is being called in handle(). stmt_stack.back() contains siblings
+   // (not guaranteed to be all siblings throughout traversal).
+   // stmt_stack.front() contains the nodes we started with (not guaranteed to
+   // be all outputs throughout traversal).
+   std::deque<std::deque<Statement*>> stmt_stack_;
+
+   // Starts at nodes provided in from, traverses from these nodes to inputs.
+   // Calls handle on all Statement*s in topologically sorted order.
+   // traverseAllPaths = false only calls handle on each Statement* once;
+   // traverseAllPaths = true traverses all paths from nodes in from to inputs
+   // and calls handle on a Statement* for every path from "from" nodes to
+   // inputs.
+   void traverseTo(const std::vector<Val*>& from, bool traverseAllPaths = false);
+
+   bool must_cover_all_expr_outputs_ = true;
+ };
+
+ class DependencyCheck {
+  public:
+   // Returns whether "dependency" is a dependency of "of".
+   NVF_API static bool isDependencyOf(Val* dependency, Val* of);
+
+   // Finds a Val* path from "of" to "dependency". Returns that path.
+   // deque.back() is "of", deque[0] is "dependency" if a chain exists.
+   NVF_API static std::deque<Val*> getSingleDependencyChain(
+       Val* dependency,
+       Val* of);
+
+   // Finds all Val* paths from "of" to "dependency". Returns those paths.
+   // deque[i].back() is "of", and deque[i][0] is "dependency". Returns an
+   // empty deque if no dependency is found.
+   static std::deque<std::deque<Val*>> getAllDependencyChains(
+       Val* dependency,
+       Val* of);
+
+   // Finds all Val* paths from all leaf nodes to "dependency". Returns those
+   // paths. deque[i].back() are leaf nodes, and deque[i][0] is "dependency".
+   // Returns an empty deque if there are no uses of dependency found.
+   static std::deque<std::deque<Val*>> getAllUseChains(Val* dependency);
+
+   // Grab all values that exist between and including the provided
+   // vals. Returned values are topologically ordered and unique.
+   NVF_API static std::vector<Val*> getAllValsBetween(
+       const std::unordered_set<Val*>& dependencies,
+       const std::vector<Val*>& of);
+
+   // Returns all dependent exprs that exist between
+   // the provided vals.
+   static std::vector<Expr*> getAllExprsBetween(
+       const std::unordered_set<Val*>& dependencies,
+       const std::vector<Val*>& of);
+
+   // Return registered outputs of the fusion that are a dependency of any of
+   // the vals in "of".
+   static std::unordered_set<Val*> getAllOutputsOf(
+       const std::unordered_set<Val*>& of);
+
+   // Return all Vals that depend on the given Vals.
+   static std::unordered_set<Val*> getAllDependentVals(
+       const std::unordered_set<Val*>& of);
+ };
+
+ // Expr sort will take a fusion and return a topologically sorted list of
+ // expressions.
+ class StmtSort : public IterVisitor {
+  protected:
+   StmtSort() = default;
+
+   std::vector<Statement*> stmts;
+
+   using IterVisitor::handle;
+
+   void dispatch(Statement* stmt) override;
+
+  public:
+   // If traverse_members is true, it will also extract all member nodes in the
+   // sorted statement list in the fusion, i.e. all IterDomains, extents, and
+   // their associated expressions. Similarly, if traverse_attributes is true,
+   // it will grab all nodes associated as Expr attributes.
+   NVF_API static std::vector<Statement*> getStmts(
+       Fusion* fusion,
+       bool traverse_members = false,
+       bool traverse_attributes = false,
+       bool traverse_siblings = false);
+
+   // Returns ordered Statements required to produce 'to', including 'to'.
+   NVF_API static std::vector<Statement*> getStmtsTo(
+       const std::vector<Val*>& to,
+       bool traverse_members = false,
+       bool traverse_attributes = false,
+       bool traverse_siblings = false);
+
+   // Returns all ordered Statements of a given fusion. Unlike
+   // getStmts, for TensorDomain, all of its iter domains and exprs are
+   // grabbed and returned in a topological order.
+   NVF_API static std::vector<Statement*> getAllStmts(
+       Fusion* fusion,
+       bool traverse_members = false,
+       bool traverse_attributes = false,
+       bool traverse_siblings = false);
+
+   // Returns ordered Statements required to produce 'to', including
+   // 'to'. Unlike getStmtsTo, for TensorDomain, all of its iter domains and
+   // exprs are grabbed and returned in a topological order, regardless of
+   // `traverse_members`.
+   //
+   // The to vals are assumed to be either TensorView or scalar
+   // Val. This assumption could be removed if desired.
+   NVF_API static std::vector<Statement*> getAllStmtsTo(
+       const std::vector<Val*>& to,
+       bool traverse_members = false,
+       bool traverse_attributes = false,
+       bool traverse_siblings = false);
+
+   // Returns ordered Statements required to produce from, including from.
+   // Stops traversal once hitting any Statements in to. Includes Statements in
+   // to.
+   //
+   // Warning: this doesn't necessarily prevent statements before `to` from
+   // being returned. e.g.
+   //   i1 = i0
+   //   i2 = i1
+   //   i3 = i2
+   //   i4 = i3 + i1
+   //   getExprs(fusion, {i4}, {i3})
+   // will return the definition and values {i0, i1, i4}.
+   // i3 is dependent on i1, but since i4 also is, the traversal will go down
+   // the i4->i1->i0 path, even though the i4->i3-//>i2->i1 path is blocked.
+   //
+   // If traverse_members is true, it will also extract all member nodes in the
+   // sorted expr list in the fusion, i.e. all expressions on IterDomains,
+   // extents, etc.
+   NVF_API static std::vector<Statement*> getStmtsBetween(
+       const std::vector<Val*>& from,
+       const std::vector<Val*>& to,
+       bool traverse_members = false,
+       bool traverse_attributes = false,
+       bool traverse_siblings = false);
+
+   // Same as the getStmts version but filters to only return the Expr*s
+   static std::vector<Expr*> getExprs(
+       const Fusion* fusion,
+       bool traverse_members = false,
+       bool traverse_attributes = false,
+       bool traverse_siblings = false);
+
+   // Same as the getStmtsTo version but filters to only return the Expr*s
+   NVF_API static std::vector<Expr*> getExprsTo(
+       const std::vector<Val*>& to,
+       bool traverse_members = false,
+       bool traverse_attributes = false,
+       bool traverse_siblings = false);
+
+   // Same as the getStmtsBetween version but filters to only return the Expr*s
+   NVF_API static std::vector<Expr*> getExprsBetween(
+       const std::vector<Val*>& from,
+       const std::vector<Val*>& to,
+       bool traverse_members = false,
+       bool traverse_attributes = false,
+       bool traverse_siblings = false);
+ };
+
+ class InputsOf : public IterVisitor {
+  private:
+   std::unordered_set<Val*> grabbed_inputs;
+   std::vector<Val*> ordered_inputs;
+
+   using IterVisitor::handle;
+
+   void dispatch(Val* v) final;
+
+  public:
+   NVF_API static std::vector<Val*> output(Val* output_);
+   static std::vector<Val*> outputs(const std::vector<Val*>& outputs_);
+ };
+
+ //! This is a generic traversal class that is used to modify a Fusion graph by
+ //! replacing Vals to simplify computation or remove dead code. This differs
+ //! from OptOutMutator, which is built for mutating TensorViews in-place in a
+ //! graph by altering the associated IterDomains, and which does not easily
+ //! handle modifying TensorView definitions and Fusion outputs during
+ //! traversal.
+ //!
+ //! Derived classes should override handle() for relevant Exprs and they
+ //! should make use of registerReplacement() to change the definitions of Vals
+ //! in the graph. Note that if replacements are made using
+ //! registerReplacement(old_val, new_val), then neither new_val nor any new
+ //! Statements produced in creating it will be traversed by this class. Also
+ //! note that any Vals or Exprs that are previously marked dead will not be
+ //! processed by handle().
+ class DeadCodeRemover : BackwardVisitor {
+  public:
+   DeadCodeRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {}
+
+   DeadCodeRemover(const DeadCodeRemover& other) = default;
+   DeadCodeRemover& operator=(const DeadCodeRemover& other) = default;
+
+   DeadCodeRemover(DeadCodeRemover&& other) = default;
+   DeadCodeRemover& operator=(DeadCodeRemover&& other) = default;
+
+   //! Instead of traverseTo, run() is the entry point for this class, and we
+   //! always traverse from outputs backward to their inputs.
+   //!
+   //! Returns a bool indicating whether the Fusion was modified or not.
+   bool run();
+
+   inline Fusion* fusion() const {
+     return fusion_;
+   }
+
+  protected:
+   using BackwardVisitor::handle;
+
+   void dispatch(Statement* stmt) override;
+   void dispatch(Expr* expr) override;
+
+   //! We implement this in order to remove dangling TensorViews whose uses are
+   //! all dead. Note that we do not remove other ValTypes like Scalars since
+   //! they might still be used as attributes or members of other objects,
+   //! which is not reflected by Val::uses().
+   void handle(TensorView* tv) override;
+
+   //! Registers a Val for replacement in outputs and in all its uses.
+   //!
+   //! Note that replacement does not occur immediately, but will be done after
+   //! the traversal is completed. This is so that any Val* and Expr* pointers
+   //! may be safely dereferenced during traversal.
+   //!
+   //! The argument old_val is always marked Dead by this method. If old_val is
+   //! a Fusion input, we do not replace it. If old_val's definition is
+   //! non-null and has other outputs which are not dead, we do not remove
+   //! old_val.
+   //!
+   //! Returns whether old_val was registered for removal from the Fusion.
+   bool registerReplacement(Val* old_val, Val* new_val);
+
+   //! Find whether a statement is not marked as live code.
+   inline bool isDead(Statement* stmt) const {
+     return live_statements_.find(stmt) == live_statements_.end();
+   }
+
+   //! Find whether a statement is marked as live code.
+   inline bool isLive(Statement* stmt) const {
+     return !isDead(stmt);
+   }
+
+   //! Check whether all outputs of an expression have been marked dead
+   inline bool allOutputsDead(Expr* expr) const {
+     return std::all_of(
+         expr->outputs().begin(), expr->outputs().end(), [&](Val* outp) {
+           return isDead(outp);
+         });
+   }
+
+   //! Check whether all uses have been marked dead
+   inline bool allUsesDead(Val* val) const {
+     auto fu_it = future_uses_.find(val);
+     if (fu_it != future_uses_.end() && !fu_it->second.empty()) {
+       // Regardless of whether current uses are marked dead, this appears in a
+       // replacement expression, so it has a future live use and we should
+       // keep it.
+       return false;
+     }
+
+     return std::all_of(val->uses().begin(), val->uses().end(), [&](Expr* use) {
+       return isDead(use);
+     });
+   }
+
+  private:
+   //! Removes an Expr* from the Fusion, if possible.
+   //!
+   //! The Expr will _only_ be marked dead and removed if all of its outputs
+   //! are already marked dead. In this case all the outputs will also be
+   //! removed from the Fusion.
+   //!
+   //! Returns whether the Expr was marked dead and removed from the Fusion.
+   bool maybeRemoveExpr(Expr* expr);
+
+   //! Mark a single Statement as being alive.
+   inline void markLive(Statement* stmt) {
+     live_statements_.insert(stmt);
+     if (auto e = dynamic_cast<Expr*>(stmt)) {
+       // Check if this expression is already in uses() for each of its inputs
+       // and if not, record it in future_uses_
+       for (Val* inp : e->inputs()) {
+         if (std::find(inp->uses().begin(), inp->uses().end(), e) ==
+             inp->uses().end()) {
+           auto fu_it = future_uses_.find(inp);
+           if (fu_it == future_uses_.end()) {
+             future_uses_.emplace(inp, std::unordered_set<Expr*>({e}));
+           } else {
+             fu_it->second.insert(e);
+           }
+         }
+       }
+     }
+   }
+
+   //! Ensure that a Statement and its upstream Statements are alive. If it is
+   //! an Expr, ensure all its inputs are alive. If it's a Val with a
+   //! definition, recurse to the definition. Newly-created Statements default
+   //! to being dead, so this method is called when adding a Statement to the
+   //! active path of the Fusion inside registerReplacement.
+   void markLiveRecursive(Statement* stmt);
+
+   //! Mark a single Statement as being dead. This does not remove stmt from
+   //! the Fusion. It is an error to call this on a Fusion output.
+   //!
+   //! Returns true if the statement was previously live, and false otherwise.
+   bool markDead(Statement* stmt);
+
+   //! Register a Val for later removal.
+   void registerRemoval(Val* val);
+
+   //! Register an Expr for later removal.
+   //!
+   //! Note that if any of its outputs are removed, expr will be removed even
+   //! if it is not marked for removal, and all its outputs will have their
+   //! definitions set to nullptr.
+   inline void registerRemoval(Expr* expr) {
+     exprs_to_remove_.push_back(expr);
+   }
+
+   //! All modifications to the Fusion are registered during traversal, then
+   //! committed later by this method. For safety, this should only be run
+   //! after traversing the graph.
+   //!
+   //! Returns a bool indicating whether any modifications were performed.
+   bool modifyFusion() const;
+
+  private:
+   //! The Fusion associated with live_statements_
+   Fusion* fusion_;
+
+   //! Statements are marked dead by removing them from this set
+   std::unordered_set<Statement*> live_statements_;
+
+   //! Vals to be replaced in outputs and with replaceValInExprInputs in all
+   //! uses.
+   std::vector<std::pair<Val*, Val*>> vals_to_replace_;
+
+   //! Statements that will be removed. We remove Vals before Exprs, so we
+   //! track them separately here.
+   std::vector<Val*> vals_to_remove_;
+   std::vector<Expr*> exprs_to_remove_;
+
+   //! This holds additional _future_ uses of each val. val->uses() only
+   //! returns currently live uses, so until we have finalized all
+   //! replacements, new uses will not appear there. The mapping below gets
+   //! populated whenever we mark an expression as live, if that expression is
+   //! not already in inp->uses() for any of its inputs.
+   std::unordered_map<Val*, std::unordered_set<Expr*>> future_uses_;
+ };
+
+ struct IRDefinitions {
+   decltype(auto) operator()(Val* val) const {
+     auto def = val->definition();
+     if (def == nullptr) {
+       return std::vector<Expr*>{};
+     }
+     return std::vector<Expr*>{val->definition()};
+   }
+ };
+
+ struct IRUses {
+   decltype(auto) operator()(Val* val) const {
+     return val->uses();
+   }
+ };
+
+ struct IRInputs {
+   decltype(auto) operator()(Expr* expr) const {
+     return expr->inputs();
+   }
+ };
+
+ struct IROutputs {
+   decltype(auto) operator()(Expr* expr) const {
+     return expr->outputs();
+   }
+ };
+
+ template <>
+ struct GetValType<Expr*> {
+   using type = Val*;
+ };
+
+ class IRBFS
+     : public BFS<Expr*, Val*, IRDefinitions, IRUses, IRInputs, IROutputs> {
+  public:
+   IRBFS(
+       std::vector<NodeType> from_groups,
+       std::vector<NodeType> to_groups,
+       bool require_all_to_visited,
+       Direction allowed_direction = Direction::Undefined)
+       : BFS(IRDefinitions{},
+             IRUses{},
+             IRInputs{},
+             IROutputs{},
+             std::move(from_groups),
+             std::move(to_groups),
+             require_all_to_visited,
+             allowed_direction) {}
+ };
+
+ inline std::vector<Val*> getInputsOfExpr(Expr* expr, Direction dir) {
+   return getInputsOfExpr<Expr*>(expr, dir, IRInputs(), IROutputs());
+ }
+
+ inline std::vector<Val*> getOutputsOfExpr(Expr* expr, Direction dir) {
+   return getOutputsOfExpr<Expr*>(expr, dir, IRInputs(), IROutputs());
+ }
+
+ class IRPermissiveBFS : public BFSWithPermissiveDependence<
+                             Expr*,
+                             Val*,
+                             IRDefinitions,
+                             IRUses,
+                             IRInputs,
+                             IROutputs> {
+  public:
+   IRPermissiveBFS(
+       std::vector<NodeType> from_groups,
+       std::vector<NodeType> to_groups,
+       bool require_all_to_visited,
+       Direction allowed_direction = Direction::Undefined)
+       : BFSWithPermissiveDependence(
+             IRDefinitions{},
+             IRUses{},
+             IRInputs{},
+             IROutputs{},
+             std::move(from_groups),
+             std::move(to_groups),
+             require_all_to_visited,
+             allowed_direction) {}
+ };
+
+ } // namespace nvfuser
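
The IterVisitor comment above describes an inputs-to-outputs traversal where dispatch()/handle() is invoked once per node. The following is a minimal illustrative sketch (not part of the wheel) of how a downstream pass might subclass it, mirroring the dispatch(Val*) override pattern that InputsOf uses in this header. The class name ValCollector, the helper get(), and the include paths are assumptions; construction of the Fusion is not shown.

// Illustrative only; assumes headers are on the include path as installed
// under nvfuser/include/nvfuser in this wheel.
#include <fusion.h>
#include <iter_visitor.h>

#include <vector>

namespace nvfuser {

// Collects every Val reachable from the fusion's terminating outputs, in
// topological (inputs-to-outputs) order.
class ValCollector : private IterVisitor {
 public:
  static std::vector<Val*> get(Fusion* fusion) {
    ValCollector collector;
    // traverse() starts from terminating outputs registered with the fusion.
    collector.traverse(fusion);
    return collector.vals_;
  }

 private:
  using IterVisitor::handle;

  // Called on every Val* in topological order.
  void dispatch(Val* v) override {
    vals_.push_back(v);
  }

  std::vector<Val*> vals_;
};

} // namespace nvfuser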
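DependencyCheck is a purely static query interface. A hedged sketch of the intended call pattern, using only the declarations shown above and following the documented deque ordering (the function name dependencyChain and the assumption that both Vals belong to the same Fusion are mine):

// Illustrative only; not part of the wheel contents.
#include <iter_visitor.h>

#include <deque>

namespace nvfuser {

// Returns the Val* chain from `producer` to `consumer`, or an empty deque if
// `producer` is not actually a dependency of `consumer`.
std::deque<Val*> dependencyChain(Val* producer, Val* consumer) {
  if (!DependencyCheck::isDependencyOf(producer, consumer)) {
    return {};
  }
  // Per the comment above, deque[0] is the dependency (`producer`) and
  // deque.back() is the value it feeds (`consumer`).
  return DependencyCheck::getSingleDependencyChain(producer, consumer);
}

} // namespace nvfuser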
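StmtSort's static helpers are the usual way to obtain a topologically ordered expression list for a whole fusion. A small sketch under the same assumptions (an already-built Fusion, installed include paths; countExprs is a hypothetical helper):

// Illustrative only; not part of the wheel contents.
#include <fusion.h>
#include <iter_visitor.h>

#include <cstddef>
#include <vector>

namespace nvfuser {

// Counts the expressions between the fusion inputs and outputs; when
// include_members is true, IterDomain/extent expressions are included via the
// traverse_members flag documented above.
size_t countExprs(Fusion* fusion, bool include_members) {
  std::vector<Expr*> exprs =
      StmtSort::getExprs(fusion, /*traverse_members=*/include_members);
  return exprs.size();
}

} // namespace nvfuser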
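IterVisitor::getInputsTo and the InputsOf helper both answer the common "what does this value ultimately depend on" question. A hedged one-liner sketch (producersOf is a hypothetical name):

// Illustrative only; not part of the wheel contents.
#include <iter_visitor.h>

#include <vector>

namespace nvfuser {

// Leaf producers (vals with no definition, typically fusion inputs and
// constants) that `out` depends on. InputsOf::output(out) provides a similar
// query via the visitor class above.
std::vector<Val*> producersOf(Val* out) {
  return IterVisitor::getInputsTo({out});
}

} // namespace nvfuser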
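DeadCodeRemover's documentation prescribes the subclassing pattern: override handle() for the Exprs of interest, call registerReplacement(), and let run() commit the changes after the backward traversal. The sketch below follows that pattern but is only an assumption-laden illustration: the class name NoopCastRemover and the simplification condition are mine, and the UnaryOp accessors in(), out(), getUnaryOpType(), the UnaryOpType::Cast enumerator, and Val::dtype() are assumed to be available from ir/all_nodes.h and type.h as shipped in this wheel.

// Illustrative only; not part of the wheel contents.
#include <fusion.h>
#include <ir/all_nodes.h>
#include <iter_visitor.h>

namespace nvfuser {

// Redirects consumers of a cast whose input and output dtypes already match,
// then relies on DeadCodeRemover to erase the now-dead cast expression.
class NoopCastRemover : public DeadCodeRemover {
 public:
  NoopCastRemover(Fusion* fusion) : DeadCodeRemover(fusion) {}

 protected:
  using DeadCodeRemover::handle;

  void handle(UnaryOp* uop) override {
    // Assumed accessors; the condition is a placeholder for a real pass.
    if (uop->getUnaryOpType() == UnaryOpType::Cast &&
        uop->in()->dtype() == uop->out()->dtype()) {
      // Replacement is deferred and committed when run() finishes traversal.
      registerReplacement(uop->out(), uop->in());
    }
  }
};

// Usage sketch: NoopCastRemover(fusion).run() returns true if the Fusion
// was modified.

} // namespace nvfuser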