nvfuser_cu121_torch25-0.2.25.dev20250201-cp310-cp310-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/ir/internal_base_nodes.h
@@ -0,0 +1,744 @@
+ // clang-format off
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <ir/base_nodes.h>
+ #include <optional>
+
+ //! IR header hierarchy
+ //! 1. utils.h - PolymorphicBase and NonCopyable
+ //! 2. ir/base_nodes.h - Statement, Expr, and Val
+ //! 3. ** ir/internal_base_nodes.h ** - IterDomain and TensorDomain
+ //! 4. ir/interface_nodes.h - TensorView and Scalar
+ //! 5. ir/internal_nodes.h - Any internal-only IR nodes
+
+ namespace nvfuser {
+
+ // Friends for direct access to split
+ class TensorDomain;
+ class IterDomain;
+ class ReplayTransformations;
+ class IndexReferenceReplay;
+ class ViewTransform;
+ class Scope;
+ class IrCloner;
+ struct AnalyzeViewResult;
+
+ // Convenience utility to initialize IterDomains without having to sort through
+ // all the default values. Intended to be used with
+ // IterDomain::IterDomain(IrBuilderPasskey, IterDomainBuilder).
+ class IterDomainBuilder {
+ public:
+ // Match legacy constructor
+ NVF_API IterDomainBuilder(Val* _start, Val* _extent);
+
+ // Grab all the parameters from id to set the IterDomainBuilder
+ NVF_API IterDomainBuilder(const IterDomain* id);
+
+ // Resets defaults for rfactor, is padded dim, padded to size, and is mma
+ // swizzle, which should only be set during scheduling.
+ IterDomainBuilder& resetSchedulingParams();
+
+ // Resets is_rfactor_domain
+ IterDomainBuilder& resetRfactor();
+
+ IterDomainBuilder& start(Val* _start);
+ IterDomainBuilder& extent(Val* _extent);
+ NVF_API IterDomainBuilder& expanded_extent(Val* _expanded_extent);
+ IterDomainBuilder& stop_offset(Val* _stop_offset);
+ IterDomainBuilder& parallel_type(ParallelType _parallel_type);
+ NVF_API IterDomainBuilder& iter_type(IterType _iter_type);
+ IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain);
+ IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension);
+ IterDomainBuilder& padded_to_size(std::optional<int64_t> _padded_to_size);
+
+ NVF_API IterDomain* build() const;
+
+ // Must have start and extent at least
+ IterDomainBuilder() = delete;
+
+ Val* start_ = nullptr;
+ Val* extent_ = nullptr;
+ Val* expanded_extent_ = nullptr;
+ Val* stop_offset_ = nullptr;
+ ParallelType parallel_type_ = ParallelType::Serial;
+ IterType iter_type_ = IterType::Iteration;
+
+ // Only relevant at scheduling time or compile time.
+ bool is_rfactor_domain_ = false;
+ bool is_padded_dimension_ = false;
+ std::optional<int64_t> padded_to_size_ = std::nullopt;
+ };
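// A minimal usage sketch of the builder above (assumed usage inferred from
// these declarations, not taken from nvfuser documentation; `start` and
// `extent` stand for Vals previously created via IrBuilder):
//
//   IterDomain* id = IterDomainBuilder(start, extent)
//       .parallel_type(ParallelType::TIDx)
//       .iter_type(IterType::Iteration)
//       .build();
//
// Each setter returns the builder by reference, so configuration chains
// fluently before the final build() call.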
+
+ //! Simply a representation of an annotated 1D iterable from start to extent.
+ //! TensorDomains, which represent how to iterate over a tensor, are made up
+ //! of IterDomains to form an ND iterable. We directly set parallelization
+ //! strategies on IterDomains.
+ class NVF_API IterDomain : public Val {
+ public:
+ IterDomain(IrBuilderPasskey, const IterDomainBuilder& args);
+
+ // Legacy constructor. TODO: should start moving to use the IterDomainBuilder
+ // constructor. Same as the above but can set the offset of the stop point.
+ IterDomain(
+ IrBuilderPasskey,
+ Val* start,
+ Val* extent,
+ Val* expanded_extent,
+ Val* stop_offset,
+ ParallelType parallel_type,
+ IterType iter_type,
+ bool is_rfactor_domain,
+ bool is_padded_dimension,
+ std::optional<int64_t> padded_to_size);
+
+ IterDomain(const IterDomain* src, IrCloner* ir_cloner);
+
+ NVFUSER_DECLARE_CLONE
+
+ bool sameAs(const Statement* other) const override;
+
+ std::string toString(int indent_size = 0) const override;
+
+ std::string toInlineString(int indent_size = 0) const override;
+
+ //! Returns a new IterDomain matching properties of this
+ //!
+ //! This does NOT copy the is_rfactor_domain flag.
+ //!
+ //! When map_with_original is true, the clone of the original is
+ //! mapped in the Exact graph.
+ IterDomain* cloneWithoutRFactor(bool map_with_original = false);
+
+ //! Clone a vector of domains
+ static std::vector<IterDomain*> clone(
+ const std::vector<IterDomain*>& domains);
+
+ //! The optional parameters of rfactor_domain and iter_type can be
+ //! used to override the default behavior.
+ static IterDomain* merge(
+ IterDomain* outer,
+ IterDomain* inner,
+ std::optional<bool> rfactor_domain = std::nullopt,
+ std::optional<IterType> iter_type = std::nullopt);
+
+ //! The optional parameters of rfactor_domain, outer_iter_type and
+ //! inner_iter_type can be used to override the default behavior.
+ static std::pair<IterDomain*, IterDomain*> split(
+ IterDomain* in,
+ Val* factor,
+ bool inner_split,
+ std::optional<bool> rfactor_domain = std::nullopt,
+ std::optional<IterType> outer_iter_type = std::nullopt,
+ std::optional<IterType> inner_iter_type = std::nullopt);
+
+ //! Resize an IterDomain by expanding both the left and right sides
+ //! by given widths. The resulting IterDomain has an extent of
+ //! (left_expansion + in->extent() + right_expansion). Note that the
+ //! expansion factors can be negative, meaning the input IterDomain
+ //! is shrunk. This is the case when resize is used to represent
+ //! slice.
+ //!
+ //! When mark_as_rfactor is true, the output IterDomain
+ //! is marked as an rfactor domain. For example, expressions such as
+ //! PadOp and SliceOp resize IterDomains and generate rfactor
+ //! resized domains.
+ //!
+ //! Usually, the IterType of the output IterDomain will be Symbolic. This is
+ //! because unless the left and right expansions are known at Fusion
+ //! definition, we cannot be sure that the output will have an extent != 1. In
+ //! case the output extent is in fact 1, we will set the IterType to
+ //! Broadcast. If the left and right expansions are constant, and sum to at
+ //! least two, then even an empty input will result in an Iteration IterType.
+ //! In these cases, we will set the output IterType to Iteration at
+ //! definition. Otherwise, it will be set to Symbolic and will be resolved
+ //! when concretization is performed by FusionExecutorCache.
+ //!
+ //! The optional iter_type argument can be used to force the output IterType,
+ //! but for safety its use should typically be confined to concretization.
+ static IterDomain* resize(
+ IterDomain* in,
+ Val* left_expansion,
+ Val* right_expansion,
+ bool mark_as_rfactor = false,
+ std::optional<IterType> iter_type = std::nullopt);
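// A hedged sketch of how the static transforms above compose (assumed usage
// inferred from the declarations, not from nvfuser documentation; `id` is an
// existing IterDomain* and `factor` a Val*):
//
//   // Inner split: id{extent} -> outer{ceilDiv(extent, factor)}, inner{factor}
//   auto [outer, inner] = IterDomain::split(id, factor, /*inner_split=*/true);
//
//   // Merging the pair back produces a single domain covering the original
//   // iteration space.
//   IterDomain* merged = IterDomain::merge(outer, inner);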
+
+ bool isReduction() const {
+ return getIterType() == IterType::Reduction;
+ }
+
+ bool isIteration() const {
+ return getIterType() == IterType::Iteration;
+ }
+
+ bool isRFactorProduct() const {
+ return is_rfactor_domain_;
+ }
+
+ bool isBroadcast() const {
+ return getIterType() == IterType::Broadcast;
+ }
+
+ bool isSymbolic() const {
+ return getIterType() == IterType::Symbolic;
+ }
+
+ bool isGatherScatter() const {
+ return getIterType() == IterType::GatherScatter;
+ }
+
+ bool isStride() const {
+ return getIterType() == IterType::Stride;
+ }
+
+ bool isVectorComponent() const {
+ return getIterType() == IterType::VectorComponent;
+ }
+
+ bool isParallelized() const {
+ return getParallelType() != ParallelType::Serial;
+ }
+
+ //! Return if this iter domain is mapped to a grid dimension
+ bool isBlockDim() const {
+ return isParallelTypeBlockDim(getParallelType());
+ }
+
+ //! Return if this iter domain is mapped to a block dimension
+ bool isThreadDim() const {
+ return isParallelTypeThreadDim(getParallelType());
+ }
+
+ //! Return if this iter domain is either mapped to a block or grid dimension
+ bool isThread() const {
+ return (isBlockDim() || isThreadDim());
+ }
+
+ bool isDeviceDim() const {
+ return isParallelTypeDeviceDim(getParallelType());
+ }
+
+ void parallelize(ParallelType t);
+
+ ParallelType getParallelType() const {
+ return parallel_type_;
+ }
+
+ IterType getIterType() const {
+ return iter_type_;
+ }
+
+ Val* start() const {
+ return start_;
+ }
+
+ Val* stop() const;
+
+ Val* stopOffset() const;
+
+ Val* extent() const {
+ NVF_ERROR(extent_ != nullptr);
+ return extent_;
+ }
+
+ bool hasExpandedExtent() const {
+ return expanded_extent_ != nullptr;
+ }
+
+ // Returns the expanded extent of a strided broadcast entry.
+ Val* expandedExtent() const {
+ NVF_ERROR(
+ hasExpandedExtent(),
+ "Requested expanded extent, but none found on this dimension.");
+ return expanded_extent_;
+ }
+
+ Val* getMaybeExpandedExtent() const {
+ if (hasExpandedExtent()) {
+ return expandedExtent();
+ }
+ return extent();
+ }
+
+ //! Dimension padding interface:
+ //! 2 modes are currently supported:
+ //!
+ //! - mode 1: if to_size is given as a positive number,
+ //! the dimension will be padded to that size so that
+ //! this IterDomain will have a compile-time constant
+ //! size, and it is the scheduler's responsibility
+ //! to ensure no input larger than the padded size
+ //! will be observed
+ //!
+ //! - mode 2: if no to_size is given, this dimension
+ //! is "dynamically" padded to the next multiple
+ //! of the warp size, i.e. 17 padded to 32, 33 padded to 64,
+ //! based on the given input.
+ void padToMultipleOfWarp(std::optional<int64_t> maybe_to_size = {}) {
+ // Currently only restricted to TIDx to generate warp reduce
+ NVF_CHECK(
+ parallel_type_ == ParallelType::TIDx,
+ "padToMultipleOfWarp : warp padding only supported on TIDx parallel dimension");
+ is_padded_dimension_ = true;
+ if (maybe_to_size.has_value()) {
+ if (maybe_to_size.value() > 0) {
+ padded_to_size_ = maybe_to_size.value();
+ }
+ }
+ }
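// An assumed usage sketch for the padding interface above (hypothetical
// scheduling snippet, not from nvfuser documentation; `id` is an IterDomain*):
//
//   id->parallelize(ParallelType::TIDx); // padding requires TIDx (see check)
//   id->padToMultipleOfWarp();     // mode 2: dynamic padding to the warp size
//   id->padToMultipleOfWarp(128);  // mode 1: pad to a static size of 128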
+
+ //! Indicates whether this IterDomain has been padded,
+ //! dynamically or statically
+ bool hasPaddingToMultipleOfWarp() const {
+ return is_padded_dimension_;
+ }
+
+ //! Returns a concrete value if this IterDomain
+ //! has been padded to a static size.
+ std::optional<int64_t> getMaybeSizeAfterPadding() const {
+ return padded_to_size_;
+ }
+
+ //! True if the range of the iteration domain doesn't span the full extent
+ bool maybePartial() const;
+
+ //! Check if IterDomain is a broadcast axis with compile-time
+ //! known extent. This is the case with all size-1 IterDomains on
+ //! a TensorView's root domain when the TensorView is created.
+ bool isImplicitBroadcast() const {
+ return isBroadcast() && extent()->isOneInt();
+ }
+
+ //! Split for stride by a given factor. It effectively does an inner
+ //! split by the factor and sets the inner domain as a Stride
+ //! domain.
+ std::pair<IterDomain*, IterDomain*> stridedSplit(int64_t factor);
+
+ //! Marks that this id represents an
+ //! instruction loop, mma use only.
+ //!
+ //! An instruction loop can be considered a generalization of
+ //! vectorization. It also represents a loop that's implemented
+ //! by an instruction, should not be realized by codegen, and
+ //! cannot be inlined with other loops.
+ //! As an example, if an mma macro, call it mma_eg, implements:
+ //! for m in M
+ //! for n in N
+ //! for k in K
+ //! C[m,n] += A[m,k]*B[k,n],
+ //! the generated code should simply be:
+ //! mma_eg(C,A,B)
+ //! without the 3-level loop nest, i.e. they're instruction loops.
+ //!
+ //! In the actual mma macros, the loop nests they implement are
+ //! transformed versions of the above to match the mma swizzle.
+ //! So the implicit loop nest differs from macro to macro.
+ //! MmaSwizzler will label the instruction loops case-by-case.
+ bool isMma() const {
+ return parallel_type_ == ParallelType::Mma;
+ }
+
+ //! Marks that this id represents an instruction loop, cp.async.bulk use only.
+ bool isBulk() const {
+ return parallel_type_ == ParallelType::Bulk;
+ }
+
+ //! Applies 2D swizzle on a rectangular tile defined by
+ //! a pair of iterdomains.
+ static std::pair<IterDomain*, IterDomain*> swizzle(
+ SwizzleType swizzle_type,
+ IterDomain* in_x,
+ IterDomain* in_y);
+ static std::pair<IterDomain*, IterDomain*> swizzle(
+ Swizzle2DType swizzle_type,
+ IterDomain* in_x,
+ IterDomain* in_y,
+ SwizzleMode swizzle_mode = SwizzleMode::Data);
+
+ protected:
+ friend TensorDomain;
+ friend ReplayTransformations;
+ friend IndexReferenceReplay;
+
+ private:
+ //! Valid range is defined as [start:-stop_offset]
+ Val* const start_ = nullptr;
+ Val* const extent_ = nullptr;
+
+ // Broadcast dimensions are assumed to be size 1 for the sake of code
+ // generation. If a user then calls `expand` on a tensor, that dimension is
+ // still considered a broadcast dimension. However, if we ever output that
+ // dimension, it should have a size dictated by the `expand` operation and a
+ // stride of zero. Since this extent is important to track, but not
+ // necessarily generate code for (still want loops on broadcast to be of size
+ // 0), we simply store it separately from extent_. Having an expanded_extent_
+ // is only allowed with broadcast dimensions. Only in this instance does it
+ // make sense to have an expanded_extent_, because it's used when users are
+ // expecting return tensors to have a physical domain. If a user simply
+ // "broadcasts" an operation
+ Val* const expanded_extent_ = nullptr;
+
+ //! Distance of stop from the end
+ Val* const stop_offset_ = nullptr;
+ ParallelType parallel_type_ = ParallelType::Serial;
+ IterType iter_type_ = IterType::Iteration;
+ bool is_rfactor_domain_ = false;
+ bool is_padded_dimension_ = false;
+ std::optional<int64_t> padded_to_size_ = std::nullopt;
+ };
+
+ //! TensorDomain holds a vector of IterDomains. It holds an IterDomain for every
+ //! logical axis in its associated tensor. TensorDomain does not directly hold
+ //! the Tensor it is associated with, and in theory could be associated with
+ //! multiple tensors. TensorDomain's primary responsibility is to provide a
+ //! mechanism to access the history of transformations that were used to
+ //! generate it. This is done through the normal interaction of Expr/Val in
+ //! Fusion. i.e. if we want to know the previous operation generating a
+ //! particular TensorDomain we can simply call:
+ //!
+ //! FusionGuard::getCurFusion()->definition(a_tensor_domain)
+ //!
+ //! which should give us an operation in the list [split, merge] or similar
+ //! operations that take in a TensorDomain, apply a transformation, and output
+ //! a new TensorDomain.
+ class TensorDomain : public Val {
+ public:
+ NVF_API explicit TensorDomain(
+ IrBuilderPasskey,
+ std::vector<IterDomain*> logical_domain,
+ std::vector<std::optional<bool>> contiguity = {});
+
+ // See notes [ Note stride order and contiguity vector ] in
+ // python_bindings.cpp
+ TensorDomain(
+ IrBuilderPasskey,
+ std::vector<IterDomain*> logical_domain,
+ std::vector<int64_t> stride_order,
+ std::vector<std::optional<bool>> contiguity = {});
+
+ TensorDomain(
+ IrBuilderPasskey,
+ std::vector<IterDomain*> logical_domain,
+ std::vector<IterDomain*> loop_domain,
+ std::vector<std::optional<bool>> contiguity = {});
+
+ TensorDomain(
+ IrBuilderPasskey,
+ std::vector<IterDomain*> root_domain,
+ std::vector<IterDomain*> logical_domain,
+ std::vector<IterDomain*> loop_domain,
+ std::vector<std::optional<bool>> contiguity = {});
+
+ TensorDomain(
+ IrBuilderPasskey,
+ std::vector<IterDomain*> root_domain,
+ std::vector<IterDomain*> logical_domain,
+ std::vector<IterDomain*> allocation,
+ std::vector<IterDomain*> loop_domain,
+ std::vector<std::optional<bool>> contiguity = {},
+ std::vector<IterDomain*> additional_ids = {});
+
+ TensorDomain(IrBuilderPasskey, const TensorDomain* src);
+
+ TensorDomain(const TensorDomain* src, IrCloner* ir_cloner);
+
+ NVFUSER_DECLARE_CLONE
+
+ bool operator==(const TensorDomain& other) const;
+ bool operator!=(const TensorDomain& other) const {
+ return !(*this == other);
+ }
+
+ int64_t nDims() const {
+ return (int64_t)loop_domain_.size();
+ }
+
+ bool sameAs(const Statement* other) const override;
+
+ static bool sameAs(
+ const std::vector<IterDomain*>& lhs,
+ const std::vector<IterDomain*>& rhs);
+
+ // When `loop_only` is false, also prints the root, logical, and allocation
+ // domains if they are not empty.
+ std::string toString(int indent_size, bool loop_only) const;
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ // Note: [Contiguity]
+ // Contiguity is a vector of optional<bool> which has the same number of
+ // elements as logical_domain_. The contiguity of a broadcast dimension is
+ // meaningless, so it has to be nullopt. The contiguity of a non-broadcast
+ // dimension is true if and only if it is memory dense with the next
+ // non-broadcast dimension.
+ // For example, if I have a tensor torch.zeros(4, 1, 3).expand(-1, 10, -1),
+ // the contiguity will be (true, nullopt, true), which means 4 is memory dense
+ // with 3.
+ const std::vector<std::optional<bool>>& contiguity() const {
+ return contiguity_;
+ }
+
+ // The python frontend has a stride_order argument in the define_tensor
+ // function. This argument allows the user to specify the allocation domain
+ // for the TensorView. When translating the CPP Fusion into a Python
+ // FusionDefinition, the stride_order argument is required if this
+ // TensorDomain's allocation domain is a permutation of the logical domain.
+ // This function generates the stride_order argument for this TensorDomain.
+ std::vector<int64_t> strideOrder() const;
+
+ NVF_API void setContiguity(const std::vector<std::optional<bool>>& contig);
+
+ std::string getContiguityString() const {
+ return toDelimitedString(contiguity(), /*delim=*/" ");
+ }
+
+ bool hasReduction() const {
+ return has_reduction_;
+ }
+
+ bool hasBlockReduction() const;
+ bool hasGridReduction() const;
+ bool hasBlockBroadcast() const;
+ bool hasGridBroadcast() const;
+
+ bool hasBroadcast() const {
+ return no_bcast_domain_.size() != loop_domain_.size();
+ }
+
+ bool hasRoot() const {
+ return !root_domain_.empty();
+ }
+
+ bool hasAllocation() const {
+ return !allocation_domain_.empty();
+ }
+
+ // Returns true if the rfactor domain consists only of IDs of Iteration iter
+ // type.
+ bool hasViewLikeRFactor() const;
+
+ bool hasVectorize() const;
+
+ NVF_API bool hasSymbolicAxis() const;
+
+ std::optional<int64_t> getReductionAxis() const;
+
+ const std::vector<IterDomain*>& noReductions() const {
+ return no_reduction_domain_;
+ }
+
+ const std::vector<IterDomain*>& noBroadcasts() const {
+ return no_bcast_domain_;
+ }
+
+ // The input logical domain. The root domain of a consumer should equal the
+ // logical domain of its producer ignoring reduction dimensions.
+ const std::vector<IterDomain*>& root() const {
+ return root_domain_;
+ };
+
+ const std::vector<IterDomain*>& maybeRoot() const {
+ return root_domain_.empty() ? logical_domain_ : root_domain_;
+ };
+
+ // Check if id is a root ID. Always return false if there's no root
+ // domain.
+ bool isRoot(const IterDomain* id) const {
+ return hasRoot() &&
+ std::find(root().begin(), root().end(), id) != root().end();
+ }
+
+ bool isMaybeRoot(const IterDomain* id) const {
+ return (hasRoot() && isRoot(id)) || (!hasRoot() && isLogical(id));
+ }
+
+ // The output logical domain.
+ const std::vector<IterDomain*>& logical() const {
+ return logical_domain_;
+ };
+
+ // Check if id is a logical ID.
+ bool isLogical(const IterDomain* id) const {
+ return std::find(logical().begin(), logical().end(), id) != logical().end();
+ }
+
+ // The allocation domain. This describes how data is stored in memory in
+ // outer-to-inner order.
+ const std::vector<IterDomain*>& allocation() const {
+ return allocation_domain_;
+ }
+
+ // Check if id is an allocation ID. Always return false if there's
+ // no allocation domain.
+ bool isAllocation(const IterDomain* id) const {
+ return hasAllocation() &&
+ std::find(allocation().begin(), allocation().end(), id) !=
+ allocation().end();
+ }
+
+ // The loop domain after scheduling. This defines loop nests and loop indices.
+ const std::vector<IterDomain*>& loop() const {
+ return loop_domain_;
+ }
+
+ const std::vector<IterDomain*>& initialLoop() const {
+ return initial_loop_domain_;
+ }
+
+ // Check if id is a loop ID.
+ bool isLoop(const IterDomain* id) const {
+ return std::find(loop().begin(), loop().end(), id) != loop().end();
+ }
+
+ // Check if id is an initial loop ID.
+ bool isInitialLoop(const IterDomain* id) const {
+ return std::find(initialLoop().begin(), initialLoop().end(), id) !=
+ initialLoop().end();
+ }
+
+ // Get all IDs that are on the shortest path between any of the domains
+ // (logical domain, root domain, loop domain, allocation domain) following
+ // the definition and uses path. Return values are topologically ordered and
+ // unique.
+ std::vector<IterDomain*> allIDs() const;
+
+ // Similar to allIDs but returns all ID expressions.
+ std::vector<Expr*> allExprs() const;
+
+ // Combine allIDs and allExprs
+ std::vector<Statement*> allStatements() const;
+
+ const std::vector<IterDomain*>& maybeAllocation() const {
+ return hasAllocation() ? allocation_domain_ : logical();
+ };
+
+ // Additional IDs that are not on the path from one of
+ // root/logical/allocation/loop domain to another. We need to keep track of
+ // these IDs to ensure that we can find all paths/IDs of interest.
+ const std::vector<IterDomain*>& additionalIDs() const {
+ return additional_ids_;
+ }
+
+ // Set the loop domain of this TensorDomain.
+ NVF_API void setLoopDomain(std::vector<IterDomain*> new_loop_domain);
+
+ // Set the allocation domain of this TensorDomain. Because contiguity is
+ // always defined w.r.t. the allocation domain, the contiguity must be updated
+ // accordingly.
+ NVF_API void setAllocationDomain(
+ std::vector<IterDomain*> new_allocation_domain,
+ std::vector<std::optional<bool>> new_contiguity);
+
+ // Similar to the previous one, but with new contiguity filled with all true
+ // or all false.
+ void setAllocationDomain(
+ std::vector<IterDomain*> new_allocation_domain,
+ bool new_contiguity) {
+ auto contiguity_flags =
+ getContiguityFilledWith(new_allocation_domain, new_contiguity);
+ setAllocationDomain(
+ std::move(new_allocation_domain), std::move(contiguity_flags));
+ }
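// A hedged usage sketch (assumed usage inferred from the declarations above;
// `td` is a TensorDomain* whose logical domain is {i0, i1}):
//
//   // Store i1 as the outermost allocation dimension, fully contiguous.
//   td->setAllocationDomain({i1, i0}, /*new_contiguity=*/true);
//
// The bool overload fills the contiguity vector via getContiguityFilledWith
// (declared further below), leaving nullopt for any broadcast IDs.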
+
+ void resetDomains() {
+ no_reduction_domain_ = noReductions(loop_domain_);
+ no_bcast_domain_ = noBroadcasts(loop_domain_);
+ has_reduction_ = hasReduction(loop_domain_);
+ }
+
+ // i here is a signed integer, as we want to accept negative values and
+ // ::size_type can be unsigned.
+ IterDomain* axis(int64_t i) const;
+
+ int64_t posOf(IterDomain* id) const;
+
+ //! Returns a position of a root domain
+ int64_t rootPosOf(IterDomain* id) const;
+
+ //! Create a new broadcast IterDomain with the given extent in the loop domain
+ void broadcast(int64_t axis, Val* extent);
+
+ // Split "axis" into 2 axes
+ //! inner_split dictates if the factor section of the split should be inside
+ //! the remainder or outside.
+ //! e.g. split(0, 4, inner_split = true) will result in:
+ //! tv[id{extent}] -> tv[id{ceilDiv(extent, factor)}, id{factor}]
+ //! e.g. split(0, 4, inner_split = false) will result in:
+ //! tv[id{extent}] -> tv[id{factor}, id{ceilDiv(extent, factor)}]
+ void split(int64_t axis_, Val* factor, bool inner_split);
+
+ // Merge axis_o and axis_i. axis_i is the fast changing dimension. Resulting
+ // axis is by default placed at original position axis_o
+ void merge(int64_t axis_o, int64_t axis_i);
+
+ // Reorder axes according to map[old_pos] = new_pos
+ void reorder(const std::unordered_map<int64_t, int64_t>& old2new);
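// A small assumed example of these in-place scheduling calls (hypothetical,
// not from nvfuser documentation; `td` is a TensorDomain* with loop domain
// {I0, I1} and `factor` a Val*):
//
//   td->split(/*axis_=*/0, factor, /*inner_split=*/true); // {I0/f, f, I1}
//   td->merge(/*axis_o=*/1, /*axis_i=*/2);                // {I0/f, f*I1}
//   td->reorder({{0, 1}, {1, 0}});                        // {f*I1, I0/f}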
+
+ //! Applies 2D swizzle on a rectangular tile defined by
+ //! a pair of iterdomains contained in this domain.
+ void swizzle(SwizzleType swizzle_type, int64_t x, int64_t y);
+ void swizzle(
+ Swizzle2DType swizzle_type,
+ int64_t x,
+ int64_t y,
+ SwizzleMode swizzle_mode = SwizzleMode::Data);
+
+ // Transform TensorView according to merge and split transformations
+ TensorDomain* view(const AnalyzeViewResult& view_analysis);
+
+ TensorDomain* flatten(int64_t start_dim, int64_t end_dim);
+
+ static std::vector<IterDomain*> orderedAs(
+ const std::vector<IterDomain*>& td,
+ const std::unordered_map<int64_t, int64_t>& old2new);
+
+ NVF_API static std::vector<IterDomain*> noReductions(
+ const std::vector<IterDomain*>&);
+ NVF_API static std::vector<IterDomain*> noBroadcasts(
+ const std::vector<IterDomain*>&);
+ NVF_API static std::vector<IterDomain*> noDevices(
+ const std::vector<IterDomain*>&);
+
+ static bool hasBroadcast(const std::vector<IterDomain*>&);
+ static bool hasReduction(const std::vector<IterDomain*>&);
+
+ // Get a vector whose size is the number of IDs in the given logical_domain
+ // filled with fill_value or nullopt depending on whether its corresponding ID
+ // is broadcast.
+ NVF_API static std::vector<std::optional<bool>> getContiguityFilledWith(
+ const std::vector<IterDomain*>& logical_domain,
+ bool fill_value);
+
+ // The returned pair is ordered so that the second domain is the consumer of
+ // the first.
+ std::pair<TensorDomain*, TensorDomain*> rFactor(
+ const std::vector<int64_t>& axes);
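// A hedged sketch of rFactor (assumed usage; `td` is a TensorDomain* whose
// reduction axis has been split beforehand):
//
//   auto [producer_td, consumer_td] = td->rFactor({0});
//
// This splits the reduction into two stages: the producer performs a partial
// reduction and the consumer completes it. Which listed axes land in which
// stage follows nvfuser's rFactor convention; treat this only as a sketch.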
+
+ private:
+ int64_t wrapDim(int64_t dim) const {
+ return nvfuser::wrapDim(dim, nDims());
+ }
+
+ private:
+ const std::vector<IterDomain*> root_domain_;
+ const std::vector<IterDomain*> logical_domain_;
+ std::vector<IterDomain*> allocation_domain_;
+ std::vector<IterDomain*> loop_domain_;
+ // Initial loop domain. Loop domain is updated with transformations
+ // such as split, but the initial loop domain can only change with
+ // setLoopDomain
+ std::vector<IterDomain*> initial_loop_domain_;
+ std::vector<IterDomain*> additional_ids_;
+
+ std::vector<IterDomain*> no_bcast_domain_;
+ std::vector<IterDomain*> no_reduction_domain_;
+ std::vector<std::optional<bool>> contiguity_;
+ bool has_reduction_ = false;
+ };
+
+ } // namespace nvfuser