nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/ir/utils.h
@@ -0,0 +1,801 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <disjoint_set.h>
+ #include <exceptions.h>
+ #include <ir/all_nodes.h>
+ #include <type.h>
+ #include <visibility.h>
+
+ #include <algorithm>
+ #include <iterator>
+ #include <unordered_map>
+ #include <vector>
+
+ namespace nvfuser::MmaOpUtils {
+
+ // The expected number of concrete domains for gemm
+ constexpr size_t expected_gemm_cdomains = 2;
+
+ void verifyMmaOpForEvaluation(MmaOp* mma_op, DataType expected_input_dtype);
+
+ struct MatmulInputs {
+   Val* mma_lhs = nullptr;
+   Val* mma_rhs = nullptr;
+   Val* bias = nullptr;
+   Val* alpha = nullptr;
+   Val* beta = nullptr;
+   // Ordering of dimensions M,N,K in MmaOp's output TensorView's root domain.
+   // Determined based on position of iterdomains.
+   // For addmm/matmul ([M,K] x [K,N]): M=0, N=2, K=1
+   // For linear ([M,K] x [N,K]): M=0, N=1, K=2
+   // mma_dims_pos = {m_pos, n_pos, k_pos}
+   std::tuple<int, int, int> mma_dims_pos = {};
+   // The elements denote if the corresponding iterdomain in the bias was a new
+   // broadcast dimension. This is used to broadcast the bias for matmul/addmm
+   // during evaluation.
+   std::vector<bool> bias_bcast_flags = {};
+ };
+
+ } // namespace nvfuser::MmaOpUtils
+
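
As a quick illustration of the mma_dims_pos encoding described in the struct comment above (a hypothetical sketch; the values are taken directly from the comment, not from any particular call site):

  using namespace nvfuser;
  MmaOpUtils::MatmulInputs matmul_inputs;
  // addmm/matmul ([M,K] x [K,N]) orders the output root domain [M, K, N]:
  matmul_inputs.mma_dims_pos = {0, 2, 1}; // {m_pos, n_pos, k_pos}
  // linear ([M,K] x [N,K]) orders it [M, N, K]:
  // matmul_inputs.mma_dims_pos = {0, 1, 2};
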
+ namespace nvfuser::ir_utils {
+
+ // Replace values in fusion using ValReplacementMutator. It also updates fusion
+ // outputs according to the replacement_map. Returns the final
+ // replacement map, which includes the given replacement entries as
+ // well as those that are the results of the replacement.
+ std::unordered_map<Val*, Val*> replaceValue(
+     Fusion*,
+     const std::unordered_map<Val*, Val*>& replacement_map);
+
+ //! Checks whether this is a simple Set of a TensorView. If not, then this
+ //! might represent a scalar set or a segment_set.
+ bool isSimpleTVSet(Expr* expr);
+
+ template <typename FilterType, typename Iterator>
+ class FilterIterator {
+  public:
+   using iterator_category = std::forward_iterator_tag;
+   using difference_type = std::ptrdiff_t;
+   using value_type = FilterType*;
+   using pointer = value_type*;
+   using reference = value_type&;
+
+   FilterIterator(Iterator begin, Iterator end) : current_(begin), end_(end) {
+     advance();
+   }
+
+   FilterType* operator*() const {
+     return (*current_)->template as<FilterType>();
+   }
+
+   FilterType* operator->() const {
+     return **this;
+   }
+
+   FilterIterator& operator++() {
+     ++current_;
+     advance();
+     return *this;
+   }
+
+   FilterIterator operator++(int) {
+     const auto before_increment = *this;
+     ++current_;
+     advance();
+     return before_increment;
+   }
+
+   bool operator==(const FilterIterator& other) const {
+     NVF_ERROR(
+         end_ == other.end_,
+         "Comparing two FilteredViews that originate from different containers");
+     return current_ == other.current_;
+   }
+
+   bool operator!=(const FilterIterator& other) const {
+     return !(*this == other);
+   }
+
+  private:
+   // Skip ahead to the next element that is of FilterType.
+   void advance() {
+     current_ = std::find_if(current_, end_, [](const auto& val) {
+       return dynamic_cast<const FilterType*>(val) != nullptr;
+     });
+   }
+
+   Iterator current_;
+   Iterator end_;
+ };
+
+ // An iterable view to a given container of Val pointers. Only returns
+ // Vals of a given Val type.
+ // NOTE: Add a non-const iterator if needed.
+ template <typename FilterType, typename InputIt>
+ class FilteredView {
+  public:
+   using value_type = FilterType*;
+   using const_iterator = FilterIterator<FilterType, InputIt>;
+
+   FilteredView(InputIt first, InputIt last) : input_it_(first), last_(last) {}
+
+   const_iterator cbegin() const {
+     return const_iterator(input_it_, last_);
+   }
+
+   const_iterator begin() const {
+     return cbegin();
+   }
+
+   const_iterator cend() const {
+     return const_iterator(last_, last_);
+   }
+
+   const_iterator end() const {
+     return cend();
+   }
+
+   bool empty() const {
+     return begin() == end();
+   }
+
+   std::vector<value_type> vector() const {
+     return std::vector<value_type>(begin(), end());
+   }
+
+   size_t size() const {
+     size_t s = 0;
+     for (auto it = cbegin(); it != cend(); ++it) {
+       ++s;
+     }
+     return s;
+   }
+
+  private:
+   const InputIt input_it_;
+   const InputIt last_;
+ };
+
+ template <typename FilterType, typename InputIt>
+ auto filterByType(InputIt first, InputIt last) {
+   return FilteredView<FilterType, InputIt>(first, last);
+ }
+
+ // Deleted for rvalue containers: FilteredView only stores iterators, so a
+ // temporary argument would leave them dangling.
+ template <typename FilterType, typename ContainerType>
+ auto filterByType(const ContainerType&& inputs) = delete;
+
+ template <typename FilterType, typename ContainerType>
+ auto filterByType(const ContainerType& inputs) {
+   return filterByType<FilterType>(inputs.cbegin(), inputs.cend());
+ }
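
filterByType is the workhorse for iterating over just one IR node type in a mixed Val* container. A minimal usage sketch, assuming a populated Fusion* named fusion (the loop body is illustrative):

  for (TensorView* tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
    // Visits only the TensorView inputs; scalar inputs are skipped.
  }
  // Or materialize the filtered view into a vector:
  std::vector<TensorView*> input_tvs =
      ir_utils::filterByType<TensorView>(fusion->inputs()).vector();
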
+
+ //! Returns a list of new-to-old mappings.
+ //!
+ //! This function canonicalizes the dimensions and validates that no two old
+ //! dimensions are mapped to the same new dimension.
+ std::vector<int64_t> normalizeNew2Old(
+     const std::vector<int64_t>& new2old_in,
+     int64_t ndims);
+
+ //! Returns a list of new-to-old mappings.
+ //!
+ //! The input map does not need to be complete. Missing axes are
+ //! assumed not to be affected.
+ //!
+ //! This is used to preprocess broadcast and transpose arguments.
+ //!
+ //! Example: (N := ndims)
+ //!   {{0, 1}} -> [1, 0, ...., N-1]
+ //!   Transposes the first two axes with no other change.
+ //!
+ //!   {{0, -1}} -> [N-1, ...., 0]
+ //!   Swaps the first and last axes.
+ std::vector<int64_t> normalizeOld2New(
+     const std::unordered_map<int64_t, int64_t>& old2new_in,
+     int64_t ndims);
+
+ //! Replaces reference Val with substitute in all Expr inputs and attributes.
+ //! Warning: Invalidates provided Expr.
+ //! Warning: Removes connection of reference through provided Expr.
+ //! Warning: Creates new Expr defining substitute.
+ NVF_API Expr* replaceValInExprInputs(
+     Expr* expr,
+     Val* reference,
+     Val* substitute);
+
+ //! Replace old_val with new_val in all active uses as well as in fusion
+ //! outputs.
+ void replaceValInAllExprInputsAndFusionOutputs(Val* old_val, Val* new_val);
+
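
A short sketch of how these replacement helpers are typically called (hypothetical names; expr, old_operand, and new_operand stand for pre-existing IR nodes):

  // replaceValInExprInputs invalidates `expr` and returns a freshly created
  // expression, so callers must switch to the returned pointer.
  Expr* new_expr =
      ir_utils::replaceValInExprInputs(expr, old_operand, new_operand);
  // To rewire every use of a Val across the whole fusion, outputs included:
  ir_utils::replaceValInAllExprInputsAndFusionOutputs(old_operand, new_operand);
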
+ //! Removes the given expression and creates a new expression that is identical
+ //! to expr, but whose outputs are given by the new_outputs argument. It is an
+ //! error for Vals in new_outputs that are not equal to their old equivalents to
+ //! have a definition, as these should be freshly-created Vals that are not yet
+ //! defined.
+ //!
+ //! Warning: Invalidates provided Expr.
+ //! Warning: Creates new Expr defining substitutes.
+ Expr* transferDefinitionToNewOutputs(
+     Expr* expr,
+     const std::vector<Val*>& new_outputs);
+
+ //! Recursively goes to the definition of the given Val and replaces the Vals
+ //! as specified by replacement_map while cloning the given Val.
+ //!
+ //! This is similar to replaceValInExprInputs but is different as Vals are
+ //! cloned such that no other exprs using the same leaf Vals are
+ //! modified. TODO: Consider cleaning up the multiple replacement
+ //! routines.
+ Val* replaceValRecursively(
+     Val* val,
+     const std::unordered_map<Val*, Val*>& replacement_map);
+
+ // Makes rfactor generic with reduction ops and Welford
+ NVF_API TensorView* rFactorHelper(
+     TensorView* red_tv,
+     const std::vector<int64_t>& axes);
+
+ // Return immediate producers of val. This function can be used on any Val and
+ // will return producers through Exprs.
+ //
+ // Warning: returned vals are not guaranteed to be between fusion inputs and
+ // outputs. This function simply uses val->definition() or val->uses(), which
+ // is limited to not go through fusion inputs/outputs, but if on a path that
+ // isn't strictly between fusion inputs/outputs, it could effectively return
+ // dead code.
+ std::vector<Val*> producerValsOf(const Val* val);
+
+ // Return immediate consumers of val. This function can be used on any Val and
+ // will return consumers through Exprs.
+ //
+ // Warning: returned vals are not guaranteed to be between fusion inputs and
+ // outputs. This function simply uses val->definition() or val->uses(), which
+ // is limited to not go through fusion inputs/outputs, but if on a path that
+ // isn't strictly between fusion inputs/outputs, it could effectively return
+ // dead code.
+ std::vector<Val*> consumerValsOf(const Val* val);
+
+ // Return immediate siblings of val. This function can be used on any Val and
+ // will return siblings through Exprs.
+ //
+ // Warning: returned vals are not guaranteed to be between fusion inputs and
+ // outputs. This function simply uses val->definition() or val->uses(), which
+ // is limited to not go through fusion inputs/outputs, but if on a path that
+ // isn't strictly between fusion inputs/outputs, it could effectively return
+ // dead code.
+ std::vector<Val*> siblingValsOf(const Val* val);
+
+ // Return immediate producers of vals. This function can be used on any vals
+ // and will return producers through Exprs.
+ //
+ // Warning: returned vals are not guaranteed to be between fusion inputs and
+ // outputs. This function simply uses val->definition() or val->uses(), which
+ // is limited to not go through fusion inputs/outputs, but if on a path that
+ // isn't strictly between fusion inputs/outputs, it could effectively return
+ // dead code.
+ std::vector<Val*> producerValsOf(const std::vector<Val*>& vals);
+
+ // Return immediate consumers of vals. This function can be used on any vals
+ // and will return consumers through Exprs.
+ //
+ // Warning: returned vals are not guaranteed to be between fusion inputs and
+ // outputs. This function simply uses val->definition() or val->uses(), which
+ // is limited to not go through fusion inputs/outputs, but if on a path that
+ // isn't strictly between fusion inputs/outputs, it could effectively return
+ // dead code.
+ std::vector<Val*> consumerValsOf(const std::vector<Val*>& vals);
+
+ // Return immediate producers of tv. This function will return all immediate
+ // producers of tv through Exprs.
+ //
+ // Warning: returned tvs are not guaranteed to be between fusion inputs and
+ // outputs. This function simply uses tv->definition() or tv->uses(), which
+ // is limited to not go through fusion inputs/outputs, but if on a path that
+ // isn't strictly between fusion inputs/outputs, it could effectively return
+ // dead code.
+ NVF_API std::vector<TensorView*> producerTvsOf(const TensorView* tv);
+
+ // Return immediate consumers of tv. This function will return all immediate
+ // consumers of tv through Exprs.
+ //
+ // Warning: returned tvs are not guaranteed to be between fusion inputs and
+ // outputs. This function simply uses tv->definition() or tv->uses(), which
+ // is limited to not go through fusion inputs/outputs, but if on a path that
+ // isn't strictly between fusion inputs/outputs, it could effectively return
+ // dead code.
+ NVF_API std::vector<TensorView*> consumerTvsOf(const TensorView* tv);
+
+ // Return immediate siblings of tv. This function will return all immediate
+ // siblings of tv through Exprs.
+ //
+ // Warning: returned tvs are not guaranteed to be between fusion inputs and
+ // outputs. This function simply uses tv->definition() or tv->uses(), which
+ // is limited to not go through fusion inputs/outputs, but if on a path that
+ // isn't strictly between fusion inputs/outputs, it could effectively return
+ // dead code.
+ std::vector<TensorView*> siblingTvsOf(const TensorView* tv);
+
+ // Return immediate producers of tvs. This function will return all immediate
+ // producers of tvs through Exprs.
+ //
+ // Warning: returned tvs are not guaranteed to be between fusion inputs and
+ // outputs. This function simply uses tv->definition() or tv->uses(), which
+ // is limited to not go through fusion inputs/outputs, but if on a path that
+ // isn't strictly between fusion inputs/outputs, it could effectively return
+ // dead code.
+ std::vector<TensorView*> producerTvsOf(const std::vector<TensorView*>& tvs);
+
+ // Return immediate consumers of tvs. This function will return all immediate
+ // consumers of tvs through Exprs.
+ //
+ // Warning: returned tvs are not guaranteed to be between fusion inputs and
+ // outputs. This function simply uses tv->definition() or tv->uses(), which
+ // is limited to not go through fusion inputs/outputs, but if on a path that
+ // isn't strictly between fusion inputs/outputs, it could effectively return
+ // dead code.
+ std::vector<TensorView*> consumerTvsOf(const std::vector<TensorView*>& tvs);
+
+ // Returns producers of tv that are inputs of fusion
+ std::vector<TensorView*> inputTvsOf(TensorView* tv);
+
+ // Returns consumers of tv that are outputs of fusion
+ std::vector<TensorView*> outputTvsOf(TensorView* tv);
+
+ // Returns producers of tvs that are inputs of fusion
+ std::vector<TensorView*> inputTvsOf(std::vector<TensorView*> tvs);
+
+ // Returns consumers of tvs that are outputs of fusion
+ std::vector<TensorView*> outputTvsOf(std::vector<TensorView*> tvs);
+
+ // Returns all tensor views used in the provided expressions
+ VectorOfUniqueEntries<TensorView*> allTvsOfExprs(
+     const std::vector<Expr*>& exprs);
+
+ // Returns all tensor views in fusion that are used between outputs and
+ // inputs, except the specified set.
+ NVF_API std::vector<TensorView*> allTvsExcept(
+     Fusion* fusion,
+     const std::unordered_set<TensorView*>& except);
+
+ // Returns the initialization value of tv or nullptr if not initialized.
+ Val* getReductionInitValOf(TensorView* tv);
+
+ // Returns if Expr is a reduction op
+ bool isReductionOp(const Expr*);
+
+ // Returns if Expr is a reduction op with TensorView or TensorIndex
+ NVF_API bool isReductionTvOp(const Expr*);
+
+ // Returns if Expr is a pointwise op with TensorView or TensorIndex
+ bool isPointwiseTvOp(const Expr* expr);
+
+ bool isSegmentSet(const Expr* e);
+
+ // Returns all non-trivial view operations. We shouldn't have trivial view
+ // operations, but this function simply makes sure that if we ever do, we
+ // don't pull them in.
+ std::vector<ViewOp*> getViewOps(Fusion*);
+
+ template <typename T>
+ std::string toString(const T& nodes) {
+   std::stringstream ss;
+   for (auto stmt : nodes) {
+     if (ss.tellp() != 0) {
+       ss << ", ";
+     }
+     ss << stmt->toString();
+   }
+   return ss.str();
+ }
+
+ template <typename T>
+ std::string toInlineString(const T& nodes) {
+   std::stringstream ss;
+   for (auto stmt : nodes) {
+     if (ss.tellp() != 0) {
+       ss << ", ";
+     }
+     ss << stmt->toInlineString();
+   }
+   return ss.str();
+ }
+
+ // Test if the given tensor is an input of squeeze op
+ bool isSqueezeInput(const TensorView* tv);
+
+ // Test if the given ID in the given tensor is squeezed
+ bool isSqueezedID(const TensorView* tv, const IterDomain* id);
+
+ // Test if the given ID in the given tensor is indirectly accessed by,
+ // e.g., indexSelect, torchGather and scatter
+ bool isIndexedID(const TensorView* tv, const IterDomain* id);
+
+ // Test if the given ID in the given tensor is indirectly read by,
+ // e.g., indexSelect and torchGather
+ bool isIndexedProducerID(const TensorView* tv, const IterDomain* id);
+
+ // Test if the given ID in the given tensor is indirectly written to by,
+ // e.g., scatter
+ bool isIndexedConsumerID(const TensorView* tv, const IterDomain* id);
+
+ // Return a producer ID, if any, that is indirectly accessed by, e.g.,
+ // indexSelect and torchGather.
+ IterDomain* getIndexedProducerID(const Expr* expr);
+
+ // Return the corresponding consumer ID of a producer ID that is
+ // indirectly accessed.
+ IterDomain* getConsumerOfIndexedProducerID(const Expr* expr);
+
+ // Check if the given tv is the first argument of
+ // indexSelect(lookup, dim, indices)
+ bool isIndexSelectLookupTv(const TensorView* tv);
+
+ // Check if the given tv is the third argument of
+ // indexSelect(lookup, dim, indices)
+ bool isIndexSelectIndicesTv(const TensorView* tv);
+
+ bool isTorchGatherLookupTv(const Val* tv);
+
+ std::string varName(const Val* val);
+
+ // Check if a tensor is resized as part of its root to logical transformations
+ bool hasResizedRfactor(const TensorView* tv);
+
+ // Returns tvs that have symbolic axes
+ std::vector<TensorView*> getTVsWithDynamicTransform(Fusion* fusion);
+
+ //! Check if dom0 and dom1 completely cover each other with no
+ //! redundancy. When they are equivalent, we can consider them as different
+ //! views of each other with affine transformations.
+ //!
+ //! For example, if we have
+ //!   I0 I1 I2 I3
+ //!    \ /   \ /
+ //!    I4    I5
+ //! then [I0, I1, I2, I3] is equivalent to [I4, I5], but [I1, I2, I3] is not
+ //! equivalent to [I4, I5].
+ //!
+ //! Another example, if we have
+ //!   I0 I1 I2 I3
+ //!    \ /   \ /
+ //!    I4    I5
+ //!    / \   / \.
+ //!   I6 I7 I8 I9
+ //! Then [I0, I1, I8, I9] is equivalent to [I6, I7, I2, I3]. [I0, I1, I2, I3]
+ //! is equivalent to [I6, I7, I8, I9]. But [I0, I1, I8, I3] is NOT equivalent
+ //! to [I6, I7, I2, I9].
+ //!
+ //! Broadcast IterDomains are ignored in this check, because we consider them
+ //! as placeholders and allow them to be created (and annihilated?) arbitrarily
+ //! as needed for convenience.
+ //!
+ //! Returns whether each domain has unreachable IDs. It is an error if
+ //! redundant IDs are detected.
+ struct CompareDomainResult {
+   bool dom0_has_unreachable_ids = false;
+   bool dom1_has_unreachable_ids = false;
+ };
+
+ // TODO: Completely replace this with compareDomainWithReference
+ CompareDomainResult compareDomains(
+     std::vector<IterDomain*> dom0,
+     const std::vector<IterDomain*>& dom1,
+     const std::vector<IterDomain*>& additional_ids = {},
+     bool ignore_broadcast = true);
+
+ //! Validate dom0 and dom1 are equivalent
+ void validateDomainEquivalence(
+     std::vector<IterDomain*> dom0,
+     const std::vector<IterDomain*>& dom1,
+     const std::vector<IterDomain*>& additional_ids = {});
+
+ struct CompareDomainWithReferenceResult {
+   // Redundant IDs found in the given domain
+   std::vector<IterDomain*> redundant_ids;
+   // IDs found in the given domain that are not connected with the
+   // reference domain
+   std::vector<IterDomain*> additional_ids;
+   // Reference IDs that are not reachable from the given domain
+   std::vector<IterDomain*> unreachable_reference_ids;
+
+   bool empty() const {
+     return redundant_ids.empty() && additional_ids.empty() &&
+         unreachable_reference_ids.empty();
+   }
+
+   std::string toString() const {
+     std::stringstream ss;
+     ss << "{redundant_ids: " << toDelimitedString(redundant_ids)
+        << ", additional_ids: " << toDelimitedString(additional_ids)
+        << ", unreachable_reference_ids: "
+        << toDelimitedString(unreachable_reference_ids) << "}";
+     return ss.str();
+   }
+ };
+
+ // Given a reference domain that has no redundancy, check if a given
+ // domain completely covers the reference domain with no
+ // redundancy. Redundant or extra IDs will be identified, if any.
+ //
+ // One caveat is that if any of the IDs of the reference domain is not
+ // reachable, it may not identify all of the redundant or extra IDs, as
+ // it is unclear if they should be considered redundant or extra. For example,
+ // suppose we have a domain of {i0}, and there's an expression
+ // of `merge(i0, i1) -> i2`. Comparing {i0} with {i2} as the reference,
+ // i2 will be returned as an unreachable ID. In this case, i0 is
+ // not used since i1 is missing, but it doesn't seem right to consider
+ // it an extra ID since it would have been used if i1 were not
+ // missing. Similarly, it should not be considered redundant. This
+ // should not be a concern in practice since if any of the reference
+ // IDs is unreachable, it should be considered an error.
+ CompareDomainWithReferenceResult compareDomainWithReference(
+     const std::vector<IterDomain*>& domain,
+     const std::vector<IterDomain*>& reference);
+
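
A hypothetical sketch of one way such a check could be used, assuming a scheduled TensorView* tv (getLoopDomain() and getLogicalDomain() are the standard TensorView accessors):

  // Verify the scheduled loop domain still fully covers the logical domain,
  // with no redundant or extra IterDomains.
  auto cmp = ir_utils::compareDomainWithReference(
      tv->getLoopDomain(), tv->getLogicalDomain());
  NVF_ERROR(cmp.empty(), "Unexpected loop domain: ", cmp.toString());
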
+ //! Check if all the inputs required to compute needed_val are known
+ template <
+     typename ValOrVectorOfVal,
+     typename SetOfVal = std::unordered_set<const Val*>>
+ inline bool dependenciesSatisfied(
+     // const Val*, Val*, std::vector<const Val*>, std::vector<Val*> or any
+     // other container that has back(), pop_back(), empty() and emplace_back()
+     ValOrVectorOfVal needed_vals,
+     // std::unordered_set<const Val*>, std::unordered_map<const Val*, T> or
+     // any other container that has count()
+     const SetOfVal& known_vals = {}) {
+   if constexpr (
+       std::is_same_v<ValOrVectorOfVal, const Val*> ||
+       std::is_same_v<ValOrVectorOfVal, Val*>) {
+     // convert a single const Val* or Val* to a vector
+     return dependenciesSatisfied(
+         std::vector<const Val*>{needed_vals}, known_vals);
+   } else {
+     while (!needed_vals.empty()) {
+       auto needed_val = needed_vals.back();
+       needed_vals.pop_back();
+       if (known_vals.count(needed_val) > 0 || needed_val->isConst()) {
+         continue;
+       }
+       auto def = needed_val->definition();
+       if (def == nullptr) {
+         return false;
+       }
+       for (auto input : def->inputs()) {
+         needed_vals.emplace_back(input);
+       }
+     }
+   }
+   return true;
+ }
+
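
A small usage sketch (hypothetical setup; tv is a TensorView* and `known` holds whichever scalars have already been bound):

  // Can the extent of tv's first axis be computed from the known scalars?
  std::unordered_set<const Val*> known = {/* bound scalar inputs */};
  Val* extent = tv->axis(0)->extent();
  if (ir_utils::dependenciesSatisfied(extent, known)) {
    // Safe to evaluate `extent`, e.g. with an ExpressionEvaluator.
  }
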
+ //! Check if a conditional scope, i.e., ForLoop or IfThenElse, is
+ //! guaranteed not to cause thread divergence
+ bool isAlignedScopeExpr(const Expr* expr);
+
+ //! Get the only producer of a tensor view. If there are multiple producers,
+ //! then throw an error.
+ inline TensorView* getSoleProducerTv(const TensorView* tv) {
+   auto producers = producerTvsOf(tv);
+   NVF_ERROR(
+       producers.size() == 1,
+       "Expected only one producer of ",
+       tv->toString(),
+       ", but found ",
+       producers.size(),
+       " producers.");
+   return producers[0];
+ }
+
+ //! Check and return a cycle found in fusion; search starts from `to` and ends
+ //! at `from`
+ NVF_API std::vector<Statement*> checkCycle(
+     Fusion* fusion,
+     const std::unordered_set<Statement*>& from,
+     const std::vector<Val*>& to);
+
+ //! Check and return a cycle found in fusion
+ NVF_API std::vector<Statement*> checkCycle(Fusion* fusion);
+
+ //! Check if a Val is a tensor size
+ NVF_API bool isTensorSize(const Val* val);
+
+ //! Check if a Val is a tensor stride
+ bool isTensorStride(const Val* val);
+
+ //! Returns a vector of ops of the given type, or of Exprs if multiple types
+ //! are given.
+ template <typename... OpTypes>
+ auto getOpsOfType(Fusion* fusion) {
+   using FirstOpType = std::tuple_element_t<0, std::tuple<OpTypes...>>;
+   using ExprType =
+       std::conditional_t<sizeof...(OpTypes) == 1, FirstOpType, Expr>;
+   std::vector<ExprType*> ops;
+   for (auto expr : fusion->exprs()) {
+     if (expr->isOneOf<OpTypes...>()) {
+       ops.push_back(expr->as<ExprType>());
+     }
+   }
+   return ops;
+ }
+
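
The return type of getOpsOfType follows from the template above: a single op type yields a vector of that type, while multiple types fall back to Expr*. A brief sketch, assuming a populated Fusion* fusion:

  std::vector<MmaOp*> mma_ops = ir_utils::getOpsOfType<MmaOp>(fusion);
  std::vector<Expr*> reduction_like =
      ir_utils::getOpsOfType<ReductionOp, WelfordOp>(fusion);
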
+ //! Returns true if fusion has any ops of the given type.
+ template <typename... OpTypes>
+ bool hasOpsOfType(Fusion* fusion) {
+   for (auto expr : fusion->exprs()) {
+     if (expr->isOneOf<OpTypes...>()) {
+       return true;
+     }
+   }
+   return false;
+ }
+
+ //! Returns true if tv is used by any ops of the given type.
+ template <typename... OpTypes>
+ bool isTvUsedByOpsOfType(TensorView* tv) {
+   for (auto expr : tv->uses()) {
+     if (expr->isOneOf<OpTypes...>()) {
+       return true;
+     }
+   }
+   return false;
+ }
+
+ //! Returns expressions that are of type ReductionOp, GroupedReductionOp, or
+ //! WelfordOp.
+ std::vector<Expr*> getAllTypesOfReductionOps(Fusion* fusion);
+
+ //! Returns true if fusion has any reduction ops.
+ bool hasAnyReductionOps(Fusion* fusion);
+
+ int64_t getVectorizeSize(const TensorView* tv);
+
+ // Returns the permutation from `in` to `out`, i.e., `out[i]==in[perm[i]]`. If
+ // `out` is not a permutation of `in`, returns nullopt.
+ template <typename T>
+ std::optional<std::vector<int64_t>> computePermutation(
+     const std::vector<T>& in,
+     const std::vector<T>& out) {
+   if (!std::is_permutation(in.begin(), in.end(), out.begin())) {
+     return std::nullopt;
+   }
+
+   std::vector<int64_t> permutation;
+   permutation.reserve(out.size());
+   // O(n^2) is totally fine for the current use case of computing the
+   // root-to-rfactor permutation. If needed, this can be improved by making T
+   // hashable and/or comparable.
+   for (const T& out_element : out) {
+     permutation.push_back(std::distance(
+         in.begin(), std::find(in.begin(), in.end(), out_element)));
+   }
+   return permutation;
+ }
+
+ template <typename T>
+ std::vector<T> applyPermutation(
+     const std::vector<T>& in,
+     const std::vector<int64_t>& permutation) {
+   NVF_CHECK(in.size() == permutation.size());
+
+   std::vector<int64_t> identity(permutation.size());
+   std::iota(identity.begin(), identity.end(), 0);
+   NVF_CHECK(std::is_permutation(
+       permutation.begin(), permutation.end(), identity.begin()));
+
+   std::vector<T> out;
+   out.reserve(permutation.size());
+   for (auto i : permutation) {
+     out.push_back(in[i]);
+   }
+   return out;
+ }
+
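
A self-contained illustration of the `out[i] == in[perm[i]]` contract shared by the two helpers above (plain standard C++, independent of this header):

  #include <cassert>
  #include <cstdint>
  #include <string>
  #include <vector>

  int main() {
    const std::vector<std::string> in = {"a", "b", "c"};
    const std::vector<int64_t> perm = {2, 0, 1}; // out = {in[2], in[0], in[1]}
    std::vector<std::string> out;
    for (int64_t i : perm) {
      out.push_back(in.at(i));
    }
    assert((out == std::vector<std::string>{"c", "a", "b"}));
    // computePermutation(in, out) would recover {2, 0, 1}, and
    // applyPermutation(in, {2, 0, 1}) would rebuild `out`.
    return 0;
  }
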
+ bool hasTrivialAllocationDomain(const TensorView* tv);
+
+ // Returns true if all expr outputs should be mapped unconditionally
+ bool hasUniformSiblings(Expr* expr);
+
+ // Returns true if memory_type is partitioned in parallel_type. See also
+ // isMemorySharedAcross. Note that the two predicates are not complements of
+ // each other: for example, Local memory with no parallelization is neither
+ // partitioned nor shared.
+ inline bool isMemoryPartitionedAcross(
+     MemoryType memory_type,
+     ParallelType parallel_type) {
+   switch (memory_type) {
+     case MemoryType::Local:
+       return isParallelTypeThread(parallel_type) ||
+           isParallelTypeDeviceDim(parallel_type);
+     case MemoryType::Shared:
+     case MemoryType::Tensor:
+       return isParallelTypeBlockDim(parallel_type) ||
+           isParallelTypeDeviceDim(parallel_type);
+     case MemoryType::Global:
+       return isParallelTypeDeviceDim(parallel_type);
+     default:
+       NVF_THROW("Unknown MemoryType: ", memory_type);
+   }
+ }
+
+ // Returns true if memory_type is shared in parallel_type. See also
+ // isMemoryPartitionedAcross.
+ inline bool isMemorySharedAcross(
+     MemoryType memory_type,
+     ParallelType parallel_type) {
+   switch (memory_type) {
+     case MemoryType::Local:
+       // Nothing is shared if it's Local
+       return false;
+     case MemoryType::Shared:
+     case MemoryType::Tensor:
+       // Only TID parallelized domains are shared if it's Shared or Tensor
+       return isParallelTypeThreadDim(parallel_type);
+     case MemoryType::Global:
+       // Only TID and BID parallelized domains are shared if it's Global
+       return isParallelTypeThreadDim(parallel_type) ||
+           isParallelTypeBlockDim(parallel_type);
+     default:
+       NVF_THROW("Unknown MemoryType: ", memory_type);
+   }
+ }
+
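
A few concrete cases, spelled out as assertions that follow directly from the two switch statements above (assert is used for brevity):

  // Shared memory is visible to all threads of a block (TIDx), but each
  // block (BIDx) gets its own copy.
  assert(ir_utils::isMemorySharedAcross(MemoryType::Shared, ParallelType::TIDx));
  assert(!ir_utils::isMemorySharedAcross(MemoryType::Shared, ParallelType::BIDx));
  assert(ir_utils::isMemoryPartitionedAcross(MemoryType::Shared, ParallelType::BIDx));
  // Local (register) storage is partitioned across threads.
  assert(ir_utils::isMemoryPartitionedAcross(MemoryType::Local, ParallelType::TIDx));
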
+ //! Check if the given tv has a root domain -> loop domain linear
+ //! transformation. This is a temporary check used to incrementally enable
+ //! IdModel. Eventually, this should be removed.
+ bool hasRootToLoopLinearTransformations(const TensorView* tv);
+
+ //! In addition to the above hasRootToLoopLinearTransformations, it
+ //! also checks whether the loop domain has any extra domains
+ bool isLoopDomainFullyDerivedFromLogicalDomain(TensorView* tv);
+
+ AsyncOpType getAsyncOpType(const Expr* expr);
+
+ //! If the given statement is nullptr, return "nullptr", otherwise return its
+ //! toString()
+ std::string nullOrToString(const Statement* stmt);
+
+ //! If the given statement is nullptr, return "nullptr", otherwise return its
+ //! toInlineString()
+ std::string nullOrToInlineString(const Statement* stmt);
+
+ //! Check if the given value is functional. A functional value is one that
+ //! always returns the same result when called with the same inputs.
+ bool isFunctional(const Val* v);
+
+ // Check if the given val is recursively defined, which is invalid in
+ // the Fusion IR but may not necessarily be the case in other IRs
+ // such as the Kernel IR
+ bool isRecursivelyDefined(Val* val);
+
+ // Return the number of operations that are used to define val. One
+ // instance of Expr is counted as a single operation.
+ int64_t getOperationCount(Val* val);
+
+ // Create a ForLoop IR node that represents:
+ //   for (int i = 0; i < size; i++)
+ ForLoop* createRangeLoop(int64_t size);
+
+ // Returns the first output of Expr that is a TensorView
+ TensorView* getTvOutput(const Expr*);
+
+ // Returns the first input of Expr that is a TensorView
+ TensorView* getTvInput(const Expr*);
+
+ // Generates the allocation domain for the given logical domain based on the
+ // stride order.
+ std::vector<IterDomain*> strideOrderToAllocation(
+     const std::vector<IterDomain*>& logical_domain,
+     const std::vector<int64_t>& stride_order);
+
+ // Returns the sizes in bytes of the data types of the producer and
+ // consumer tensors of a cast unary op
+ std::optional<std::pair<int64_t, int64_t>> getPrecisionOfProducerConsumerTensors(
+     UnaryOp* cast_op);
+
+ } // namespace nvfuser::ir_utils