nvfuser-cu121-torch25 0.2.25.dev20250201__cp310-cp310-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/ir/interface_nodes.h
@@ -0,0 +1,957 @@
+ // clang-format off
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+
+ #include <fusion.h>
+ #include <ir/builder_passkey.h>
+ #include <ir/internal_base_nodes.h>
+ #include <ir/internal_nodes.h>
+ #include <mma_type.h>
+ #include <multidevice/device_mesh.h>
+ #include <type.h>
+ #include <visibility.h>
+
+ #include <torch/csrc/jit/ir/ir.h>
+
+ #include <complex>
+ #include <limits>
+ #include <sstream>
+
+ //! Nodes in here are intended to be "user facing", users in this sense being
+ //! those that want to be able to generate CUDA code.
+
+ //! IR header hierarchy
+ //! 1. utils.h - PolymorphicBase and NonCopyable
+ //! 2. ir/base_nodes.h - Statement, Expr, and Val
+ //! 3. ir/internal_base_nodes.h - IterDomain and TensorDomain
+ //! 4. ** ir/interface_nodes.h ** - TensorView and Scalar
+ //! 5. ir/internal_nodes.h - Any internal-only IR nodes
+
+ namespace nvfuser {
+
+ class WelfordResult;
+ class ViewTransform;
+
+ class IrCloner;
+
+ namespace ir_utils {
+ std::string varName(const Val* val);
+ }
+
+ template <typename T>
+ T& Expr::attribute(size_t index) const {
+ if constexpr (PolymorphicValue::is_candidate_type<T>) {
+ return attributeVal(index)->value().as<T>();
+ } else {
+ return attributeVal(index)->value().as<Opaque>().as<T>();
+ }
+ }
+
+ //! Mode during propagation of computeAt: Standard will throw an error if the
+ //! computeAt position provided can't be satisfied; BestEffort will lower the
+ //! computeAt position as needed during traversal; MostInlined will increase
+ //! the compute-at position to the maximum possible through traversal.
+ enum class ComputeAtMode { Standard, BestEffort, MostInlined };
+
+ class TransformPropagator;
+ struct MostInlinedTransformPropagator;
+ class TransformIter;
+ class TransformReplay;
+ class OptOutMutator;
+ class TensorDomain;
+
+ class MaxPosCalculator;
+
+ namespace ir_utils {
+ class TVDomainGuard;
+ }
+
+ // [Circular buffering]
+ //
+ // A non-circle-buffered loop looks like below (assuming both the load and the
+ // compute are async ops):
+ // for i in range(data.size):
+ // load data[i] to buffer
+ // wait buffer to be ready (RAW sync)
+ // compute buffer
+ // wait compute to be done (WAR sync)
+ //
+ // Circular buffering allows removing RAW and WAR hazards to maximize
+ // overlapping of memory load and compute. Both the load and compute operations
+ // are pipelined. In order to pipeline the load operations, the RAW hazards need
+ // to be removed, so that at every iteration, the data needed for computation is
+ // already prefetched a few iterations ago. In order to pipeline the compute,
+ // WAR hazards need to be removed, so that each iteration's compute is not
+ // required to be completed immediately in this iteration to avoid next
+ // iteration's load overwriting the current iteration's operand for the compute
+ // operation.
+ //
+ // With circular buffering, we want to prefetch a few iterations ahead of the
+ // compute, and defer the load to the just-used buffer a few iterations, so that
+ // both the load and the compute can be pipelined, minimizing the idle time.
+ //
+ // Circular buffering is controlled by two parameters: stage and prefetch. The
+ // stage parameter determines the size of the circular buffer, and the prefetch
+ // parameter determines how many iterations ahead of the compute the data is
+ // prefetched. Note that prefetch must be < stage. Both the removal of RAW and
+ // WAR hazards require additional storage space. The prefetch parameter
+ // determines how buffers are partitioned between RAW and WAR hazards. If we are
+ // not interested in pipelining the compute, then use prefetch = stage - 1, so
+ // that all buffers are used for RAW removal.
+ //
+ // The figure below illustrates the timeline of a circular buffered loop, where
+ // each row represents an iteration:
+
+ // clang-format off
+
+ //
+ // /load 0;\ \.
+ // / load 1; [prefetch = 3] | [prefetching]
+ // [stage] load 2;/ /'
+ // [ = 6 ] load 3; wait load 0; compute 0; \.
+ // \ load 4; wait load 1; compute 1; |
+ // \load 5; wait load 2; compute 2; wait compute 0; |
+ // load 0; wait load 3; compute 3; wait compute 1; |
+ // load 1; wait load 4; compute 4; wait compute 2; |
+ // load 2; wait load 5; compute 5; wait compute 3; |
+ // load 3; wait load 0; compute 0; wait compute 4; |
+ // load 4; wait load 1; compute 1; wait compute 5; | [main]
+ // load 5; wait load 2; compute 2; wait compute 0; |
+ // .................................................. |
+ // .................................................. |
+ // .................................................. |
+ // load ; wait load ; compute ; wait compute ; |
+ // load ; wait load ; compute ; wait compute ; |
+ // load ; wait load ; compute ; wait compute ; |
+ // load ; wait load ; compute ; /'
+ // /wait load ; compute ; \.
+ // [same number as prefetch] wait load ; compute ; | [draining]
+ // \wait load ; compute ; wait all computes; /'
+
+ // clang-format on
+
+ // In the above figure, we have:
+ // storage required = stage * tile_size
+ // load pipeline depth = prefetch + 1
+ // compute pipeline depth = stage - prefetch
+ //
+ // There are two ways to implement the above timeline: pipelined, and
+ // warp-specialization.
+ //
+ // In the pipelined way, the prefetching stage is implemented as a prologue
+ // loop, and main stage is implemented as a main loop, and the draining stage is
+ // implemented as an epilogue loop. That is, we will have the following loop
+ // structure:
+ //
+ // Prologue loop:
+ // for i in range(prefetch):
+ // load data[i] to buffer[i]
+ //
+ // Main loop (using syncthreads to avoid WAR hazard):
+ // for i in range(data.size - prefetch):
+ // load data[i + prefetch] to buffer[(i + prefetch) % stage]
+ // wait buffer[i % stage] to be loaded
+ // compute buffer[i % stage]
+ // wait until the first compute in the queue is done
+ // (i.e. stage - prefetch - 1 in flight computes remaining)
+ // __syncthreads();
+
+ // Main loop (using mbarrier to avoid WAR hazard):
+ // for i in range(data.size - prefetch):
+ // wait buffer[(i + prefetch) % stage] to be empty
+ // load data[i + prefetch] to buffer[(i + prefetch) % stage]
+ // wait buffer[i % stage] to be loaded
+ // compute buffer[i % stage]
+ // wait until the first compute in the queue is done
+ // (i.e. stage - prefetch - 1 in flight computes remaining)
+ // signal that buffer (i + prefetch + 1) % stage is empty and ready to be
+ // loaded again
+ //
+ // Epilogue loop:
+ // for i in range(data.size - prefetch, data.size):
+ // wait buffer[i % stage] to be ready
+ // compute buffer[i % stage]
+ // wait until all computes are done
+ //
+ // Note that in the above loop structure, the "wait compute" in the first
+ // stage - prefetch - 1 iterations and last iteration of the main loop is
+ // redundant. We can remove them to further optimize the performance, but
+ // we decide to keep them for simplicity.
+ //
+ // In the warp-specialized approach, we will use different warp/warp-group
+ // for loading and computing. We will generate code like below (assuming warp
+ // specialized on TIDy):
+ //
+ // if (threadIdx.y == blockDim.y - 1) {
+ // // If we use warp specialization on TIDy, then the blockDim.y of the
+ // // kernel will be (whatever_value_inferred_from_schedule + 1), and the
+ // // last threadIdx.y will be used as load warp
+ // for i in range(data.size):
+ // wait buffer[i % stage] to be empty
+ // load data[i] to buffer[i % stage]
+ // } else {
+ // // Every threadIdx.y other than the last will be used for compute
+ // for i in range(prefetch + 1):
+ // signal that buffer i % stage is empty and ready to load
+ // for i in range(data.size):
+ // wait buffer[i % stage] to be loaded
+ // compute buffer[i % stage]
+ // wait until the first compute in the queue is done
+ // (i.e. stage - prefetch - 1 in flight computes remaining)
+ // signal that buffer (i + prefetch + 1) % stage is empty and ready to be
+ // loaded again
+ // }
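// [Editorial note, not part of the original header] Plugging the figure's
// parameters (stage = 6, prefetch = 3) into the formulas above gives:
//   storage required       = 6 * tile_size     (6 buffers live at once)
//   load pipeline depth    = prefetch + 1 = 4  (loads 0..3 are issued before the first compute)
//   compute pipeline depth = stage - prefetch = 3
// Setting prefetch = stage - 1 = 5 instead would dedicate every buffer to hiding
// load latency (RAW removal) and leave no slack for pipelining the compute.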
+
+ struct Pipelined {
+ bool uses_mbarrier_for_war = false;
+ explicit Pipelined(bool uses_mbarrier_for_war)
+ : uses_mbarrier_for_war(uses_mbarrier_for_war) {}
+ Pipelined() = default;
+ bool operator==(const Pipelined& other) const {
+ return uses_mbarrier_for_war == other.uses_mbarrier_for_war;
+ }
+ };
+
+ inline std::ostream& operator<<(std::ostream& os, const Pipelined& pipelined) {
+ if (pipelined.uses_mbarrier_for_war) {
+ return os << "PipelinedMBarrierForWAR";
+ }
+ return os << "Pipelined";
+ }
+
+ struct WarpSpecialized {
+ ParallelType on = ParallelType::Serial;
+ // The number of registers for load and compute warps respectively.
+ std::optional<std::pair<int64_t, int64_t>> num_registers = std::nullopt;
+
+ explicit WarpSpecialized(
+ ParallelType on,
+ std::pair<int64_t, int64_t> num_registers)
+ : on(on), num_registers(num_registers) {
+ validateRegisterSharing();
+ }
+ explicit WarpSpecialized(ParallelType on)
+ : on(on), num_registers(std::nullopt) {}
+ WarpSpecialized() = default;
+
+ void validateRegisterSharing() {
+ // short-circuit: register sharing is not used.
+ if (!num_registers.has_value()) {
+ return;
+ }
+ auto validate_num_registers = [](int64_t a) {
+ NVF_ERROR(
+ a >= 24 && a <= 256 && a % 8 == 0,
+ "The number of registers for setmaxnreg must be between 24 and",
+ " 256 (inclusive) and be a multiple of 8.");
+ };
+ validate_num_registers(num_registers.value().first);
+ validate_num_registers(num_registers.value().second);
+ NVF_ERROR(
+ num_registers.value().first <= num_registers.value().second,
+ "The number of registers for load warp group must be <= to the number",
+ " of registers for the compute warp groups.");
+ }
+
+ bool operator==(const WarpSpecialized& other) const {
+ return on == other.on && num_registers == other.num_registers;
+ }
+ };
+
+ inline std::ostream& operator<<(
+ std::ostream& os,
+ const WarpSpecialized& warp_specialized) {
+ std::string parallel_type_str = "";
+ switch (warp_specialized.on) {
+ case ParallelType::TIDx:
+ parallel_type_str = "TIDx";
+ break;
+ case ParallelType::TIDy:
+ parallel_type_str = "TIDy";
+ break;
+ case ParallelType::TIDz:
+ parallel_type_str = "TIDz";
+ break;
+ default:
+ NVF_THROW("Invalid parallel type");
+ }
+ std::string num_registers = "RegisterSharing_None";
+ if (warp_specialized.num_registers.has_value()) {
+ auto&& [decrease_num_reg, increase_num_reg] =
+ warp_specialized.num_registers.value();
+ std::stringstream s;
+ s << "RegisterSharing_" << decrease_num_reg << "_" << increase_num_reg;
+ num_registers = s.str();
+ }
+ return os << "WarpSpecializedOn" << parallel_type_str << num_registers;
+ }
+
+ using CircularBufferType = std::variant<Pipelined, WarpSpecialized>;
+
+ inline std::ostream& operator<<(
+ std::ostream& os,
+ const CircularBufferType& type) {
+ return std::visit(
+ [&os](const auto& t) -> std::ostream& { return os << t; }, type);
+ }
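// [Editorial example, not part of the original header] A minimal sketch of how
// these circular-buffer type tags can be constructed and printed, using only
// the declarations above (the register counts are arbitrary valid values):
//
//   CircularBufferType a = Pipelined(/*uses_mbarrier_for_war=*/true);
//   CircularBufferType b = WarpSpecialized(ParallelType::TIDy);
//   CircularBufferType c =
//       WarpSpecialized(ParallelType::TIDy, /*num_registers=*/{40, 232});
//   std::cout << a << "\n"   // PipelinedMBarrierForWAR
//             << b << "\n"   // WarpSpecializedOnTIDyRegisterSharing_None
//             << c << "\n";  // WarpSpecializedOnTIDyRegisterSharing_40_232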
+
+ struct CircularBufferOptions {
+ CircularBufferType type =
+ Pipelined(false); // Type of circular buffer. Currently supports:
+ // - pipelined using syncthreads for WAR hazards
+ // - pipelined using mbarrier for WAR hazards.
+ int64_t stage = 0; // Size of the circular buffer (number of buffers)
+ int64_t prefetch = 0; // Number of iterations ahead of the compute to
+ // prefetch, can only be < stage.
+
+ bool isEnable() const {
+ return stage > 1;
+ }
+
+ bool usesMBarrierForWAR() const {
+ return (std::holds_alternative<Pipelined>(type) &&
+ std::get<Pipelined>(type).uses_mbarrier_for_war) ||
+ std::holds_alternative<WarpSpecialized>(type);
+ return false;
+ }
+
+ bool operator==(const CircularBufferOptions& other) const {
+ return type == other.type && stage == other.stage &&
+ prefetch == other.prefetch;
+ }
+ };
+
+ inline std::ostream& operator<<(
+ std::ostream& os,
+ const CircularBufferOptions& options) {
+ return os << "CircularBufferOptions{ stage=" << options.stage
+ << ", prefetch=" << options.prefetch << ", type=" << options.type
+ << " }";
+ }
+
+ //! TensorView is our primitive Tensor Type used in code generation. It can be
+ //! thought of as representing physical memory, however, its dimensionality is
+ //! modified as split/merge/computeAt functions are called. The history of
+ //! these transformations is kept and used for generating actual code
+ //! referencing physical memory. Generally when users are thinking of code
+ //! generation in reference to a Tensor, this is the class they should be
+ //! interacting with.
+ //!
+ //! The reason we need both TensorView and TensorDomain is that we need to have
+ //! a record of both what is being computed and how it is being computed. For
+ //! example we may have the operation:
+ //!
+ //! TV3[I, J, K] = TV2[I, J, K] + TV1[I, J, K]
+ //!
+ //! The mathematical operations here are on the tensor views TV1, TV2, and
+ //! TV3. This operation is a pointwise operation. To compute this pointwise
+ //! operation we iterate over the 3D TensorDomain [I, J, K], where K is the
+ //! fastest changing dimension.
+ //!
+ //! \todo Need to work on the const model for TensorView, making all functions
+ //! that should be const, const. Gave this a try but expanded really quickly.
+ //! getComputeAtAxis not being const because it can return a TV that some expect
+ //! to be non-const is the biggest headache.
+ //!
+ class NVF_API TensorView : public Val {
+ public:
+ TensorView(
+ IrBuilderPasskey passkey,
+ TensorDomain* domain,
+ DataType dtype,
+ MemoryType mtype = MemoryType::Local);
+
+ TensorView(const TensorView* src, IrCloner* ir_cloner);
+
+ NVFUSER_DECLARE_CLONE
+
+ std::string toString(int indent_size = 0) const override;
+
+ std::string toInlineString(int indent_size = 0) const override;
+
+ void printTransforms() const;
+
+ TensorDomain* domain() const {
+ return domain_;
+ }
+
+ void setContiguity(const std::vector<std::optional<bool>>& contig) {
+ domain()->setContiguity(contig);
+ }
+
+ void setContiguity(bool contig) {
+ setContiguity(TensorDomain::getContiguityFilledWith(
+ getMaybeAllocationDomain(), contig));
+ }
+
+ const std::vector<std::optional<bool>>& getContiguity() const {
+ return domain()->contiguity();
+ }
+
+ bool hasReduction() const {
+ return domain()->hasReduction();
+ }
+
+ bool hasBlockReduction() const {
+ return domain()->hasBlockReduction();
+ }
+
+ bool hasGridReduction() const {
+ return domain()->hasGridReduction();
+ }
+
+ bool hasBroadcast() const {
+ return domain()->hasBroadcast();
+ }
+
+ bool hasRoot() const {
+ return domain()->hasRoot();
+ }
+
+ bool hasAllocation() const {
+ return domain()->hasAllocation();
+ }
+
+ //! Returns true if this tensor is zero dimensional,
+ //! i.e. a wrapped scalar or an empty placeholder.
+ bool isZeroDim() const {
+ return nDims() == 0;
+ }
+
+ //! Returns true if this tensor does not contain
+ //! any value.
+ bool isEmptyTensor() const;
+
+ std::optional<int64_t> getReductionAxis() const {
+ return domain()->getReductionAxis();
+ }
+
+ const std::vector<IterDomain*>& getRootDomain() const {
+ return domain()->root();
+ };
+
+ const std::vector<IterDomain*>& getMaybeRootDomain() const {
+ return domain()->maybeRoot();
+ };
+
+ const std::vector<IterDomain*>& getLogicalDomain() const {
+ return domain()->logical();
+ };
+
+ const std::vector<IterDomain*>& getAllocationDomain() const {
+ return domain()->allocation();
+ };
+
+ const std::vector<IterDomain*>& getLoopDomain() const {
+ return domain()->loop();
+ };
+
+ const std::vector<IterDomain*>& getInitialLoopDomain() const {
+ return domain()->initialLoop();
+ };
+
+ // If allocation domain exists in domain() return it, otherwise return
+ // logical domain
+ const std::vector<IterDomain*>& getMaybeAllocationDomain() const {
+ return domain()->maybeAllocation();
+ };
+
+ void setLoopDomain(std::vector<IterDomain*> new_loop_domain) {
+ domain()->setLoopDomain(std::move(new_loop_domain));
+ }
+
+ void setAllocationDomain(
+ std::vector<IterDomain*> new_allocation_domain,
+ std::vector<std::optional<bool>> new_contiguity) {
+ domain()->setAllocationDomain(
+ std::move(new_allocation_domain), std::move(new_contiguity));
+ }
+
+ void setAllocationDomain(
+ std::vector<IterDomain*> new_allocation_domain,
+ bool new_contiguity) {
+ domain()->setAllocationDomain(
+ std::move(new_allocation_domain), new_contiguity);
+ }
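// [Editorial example, not part of the original header] A hypothetical sketch of
// describing a column-major layout for a 2-D TensorView by reordering its
// allocation domain (assumes `tv` has logical domain [I0, I1]; axis() is
// declared just below):
//
//   std::vector<IterDomain*> allocation = {tv->axis(1), tv->axis(0)};
//   tv->setAllocationDomain(allocation, /*new_contiguity=*/true);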
+
+ IterDomain* axis(int64_t pos) const;
+
+ // Does it share outer axes with other tensors?
+ bool hasComputeAt() const {
+ return compute_at_pos_ > 0;
+ }
+
+ bool hasMaxProducerPosition() const {
+ return max_producer_pos_ > 0;
+ }
+
+ int64_t nDims() const {
+ return (int64_t)domain()->nDims();
+ }
+
+ // sets cpu_scalar_ value, which is special handling for CPU based zero-dim
+ // tensors (i.e. CPU Tensors that only have one value). This is only used if
+ // on an input value, otherwise ignored. This is important as special handling
+ // because these "scalars" should be type promoted as a tensor, but we want to
+ // avoid explicit copying of the data, so we want to pass the data value as a
+ // standard kernel argument value.
+ void setCpuScalar(bool is_cpu_scalar);
+
+ // returns cpu_scalar_ value, which is special handling for CPU based zero-dim
+ // tensors (i.e. CPU Tensors that only have one value). This is only used if
+ // on an input value, otherwise ignored. This is important as special handling
+ // because these "scalars" should be type promoted as a tensor, but we want to
+ // avoid explicit copying of the data, so we want to pass the data value as a
+ // standard kernel argument value.
+ bool isCpuScalar() const {
+ return cpu_scalar_;
+ }
+
+ // Returns the position that this tensor is produced at relative to its axes.
+ int64_t getComputeAtPosition() const {
+ return compute_at_pos_;
+ }
+
+ // Returns the maximum position that producers are being computed at relative
+ // to this tensor. This position dictates the clear expectations of producers.
+ int64_t getMaxProducerPosition() const {
+ return max_producer_pos_;
+ }
+
+ int64_t getMaybeMaxProducerPosition() const {
+ return maybe_max_producer_pos_;
+ }
+
+ //! This is used when we disconnect a tensorview from a reduction
+ //! operation and connect it to a non-reduction operator. We need
+ //! to remove the reduction ids on the tv in this case.
+ //! Currently only used in translate welford, and this function may
+ //! be refactored or extended if any more use cases appear.
+ void clearReductionIterDomains();
+
+ //! Compute this TensorView relative to a consumer position, -1 will
+ //! compute tensors inline with each other, 0 doesn't share
+ //! any loop nests between the tensors. It's an error when the given
+ //! position is not legally viable. Alternatively, when the mode
+ //! parameter is ComputeAtMode::BestEffort, the position is lowered
+ //! one by one until a valid position is found. When
+ //! ComputeAtMode::MostInlined is given, the position parameter is
+ //! ignored, and the deepest possible position is searched.
+ TensorView* computeAt(
+ TensorView* consumer,
+ int64_t position,
+ ComputeAtMode mode = ComputeAtMode::Standard);
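// [Editorial example, not part of the original header] A minimal sketch of the
// computeAt modes described above (assumes tv0 is a producer of tv1 in the
// same Fusion):
//
//   tv0->computeAt(tv1, /*position=*/-1);                // fully inline tv0 into tv1
//   tv0->computeAt(tv1, 2, ComputeAtMode::BestEffort);   // lower the position if 2 is not viable
//   tv0->computeAt(tv1, 0, ComputeAtMode::MostInlined);  // position ignored, deepest legal position used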
+
+ //! Create a new broadcast IterDomain with the given extent in the loop domain
+ TensorView* broadcast(int64_t axis, int64_t extent = 1);
+ TensorView* broadcast(int64_t axis, Val* extent);
+
+ // Split "axis" into 2 axes
+ //! inner_split dictates if the factor section of the split should be inside
+ //! the
+ //! remainder or outside.
+ //! e.g. split(0, 4, inner_split = true) will result in:
+ //! tv[id{extent}] -> tv[id{ceilDiv(extent, factor)}, id{factor}]
+ //! e.g. split(0, 4, inner_split = false) will result in:
+ //! tv[id{extent}] -> tv[id{factor}, id{ceilDiv(extent, factor)}]
+ TensorView* split(int64_t axis, int64_t factor, bool inner_split = true);
+
+ // Split "axis" into 2 axes where the inner axis is size of "factor"
+ // and outer axis is size axis.size() / factor. Factor can be a symbolic
+ // value instead of constant. This requires setting the symbolic value as an
+ // input, or using a parallel dim from NamedScalar::getParallelDim
+ TensorView* split(int64_t axis, Val* factor, bool inner_split = true);
+
+ // Merge axis_o and axis_i into 1 IterDomain
+ TensorView* merge(int64_t axis_o, int64_t axis_i);
+
+ // Merge axis and axis+1 into 1 IterDomain
+ TensorView* merge(int64_t axis) {
+ return merge(axis, axis + 1);
+ }
+
+ // Flatten the axis from `from` to `to` into a single axis.
+ // Both `from` and `to` are inclusive.
+ TensorView* flatten(int64_t from = 0, int64_t to = -1);
+
+ // Reorder axes according to old2new[old_pos] = new_pos
+ TensorView* reorder(const std::unordered_map<int64_t, int64_t>& old2new);
+ TensorView* reorder(
+ const std::initializer_list<std::pair<const int64_t, int64_t>>& old2new);
+
+ // Reorder axes based on the vector permutation.
+ // In terms of the function above, this can be seen as old2new[index] =
+ // permutation[index]
+ TensorView* reorder(const std::vector<int64_t>& permutation);
+ TensorView* reorder(const std::initializer_list<int64_t>& permutation);
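// [Editorial example, not part of the original header] A sketch of how the
// scheduling primitives above compose on a 2-D TensorView `tv` with loop
// domain [I0, I1] (extents are symbolic; the trailing comments show the
// resulting loop domain, with I1/32 standing for ceilDiv(I1, 32)):
//
//   tv->split(1, 32);               // [I0, I1/32, 32]
//   tv->merge(0);                   // [I0*(I1/32), 32]
//   tv->reorder({{0, 1}, {1, 0}});  // [32, I0*(I1/32)]
//   tv->flatten();                  // [32*(I0*(I1/32))]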
+
+ //! Swizzle the rectangular tile defined by the iterdomains corresponding
+ //! to the 2 given indices.
+ TensorView* swizzle(SwizzleType swizzle_type, int64_t x, int64_t y);
+ TensorView* swizzle(
+ Swizzle2DType swizzle_type,
+ int64_t x,
+ int64_t y,
+ SwizzleMode swizzle_mode = SwizzleMode::Data);
+
+ // WARNING: rFactor does not return this TensorView, it returns a new
+ // tensorview consumed by this!
+ //
+ // Take reduction axes out of this domain, and create a new
+ // domain. New domain will be used to create this domain.
+ //
+ // For example:
+ // TV1[I0, R1, R2, I3] = TV0[I0, I1, I2, I3]
+ //
+ // After:
+ // TV1->rfactor({1}), TV1 is transformed to -> TV1[I0, R2, I3]
+ //
+ // The TensorView returned is: TV2[I0, R1, I2, I3]
+ //
+ // The reduction will now be set as:
+ // TV2[I0, R1, I2, I3] = TV0[I0, I1, I2, I3]
+ // TV1[I0, R2, I3] = TV2[I0, R1, I2, I3]
+ //
+ TensorView* rFactor(const std::vector<int64_t>& axes);
+
+ //! Multi-output version of rFactor, semantically similar with
+ //! the reduction version except that the rfactor is done
+ //! for all outputs in a consistent way
+ std::vector<TensorView*> rFactor(
+ const std::vector<int64_t>& axes,
+ const std::vector<TensorView*>& tvs);
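// [Editorial example, not part of the original header] A sketch of the rFactor
// pattern above expressed as scheduling calls (assumes tv1 is defined by a
// reduction such as sum(tv0, {1}), so its loop domain starts as [I0, R1]):
//
//   tv1->split(1, 128);                   // [I0, R1/128, R128]
//   TensorView* tv2 = tv1->rFactor({1});  // tv2: [I0, R1/128, I128], tv1: [I0, R128]
//   // tv2 now performs the partial (rfactored) reduction; tv1 reduces the remainder.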
+
+ //! Create a TensorView before the original tensor. A common use case is to
+ //! write results into shared memory or registers before moving to global
+ //! memory. Analogous to TVM Cache_Write
+ //!
+ //! @param op_type: memory operator to use for the inserted op between
+ //! the data tensor and the cache tensor
+ TensorView* cacheBefore(LoadStoreOpType op_type = LoadStoreOpType::Set);
+
+ //! Create a TensorView after the original tensor. A common use case is to
+ //! read tensor into shared memory or registers. Analogous to TVM Cache_Read
+ //!
+ //! @param op_type: memory operator to use for the inserted op between
+ //! the data tensor and the cache tensor
+ //! @param cache_op: cache operator, see enum class CacheOp
+ //! @param propagate_allocation_domain: replay allocation domain on cached
+ //! load
+ //! @param cached_uses: if empty, cache all uses; otherwise, only try to cache
+ //! uses in cached_uses.
+ TensorView* cacheAfter(
+ LoadStoreOpType op_type = LoadStoreOpType::Set,
+ CacheOp cache_op = CacheOp::Unspecified,
+ bool propagate_allocation_domain = true,
+ std::vector<Expr*> cached_uses = {});
+
+ // For a fusion output with other uses, we want to avoid writing to global
+ // memory and then reading the output again. We write to global memory
+ // separately after an operation. We replace this fusion output with the
+ // direct write TensorView.
+ TensorView* cacheFork();
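// [Editorial example, not part of the original header] A common caching sketch
// under the declarations above (assumes tv_in is a fusion input and tv_out a
// fusion output of the current Fusion; setMemoryType is declared below):
//
//   TensorView* tv_in_smem = tv_in->cacheAfter();   // stage the read of tv_in
//   tv_in_smem->setMemoryType(MemoryType::Shared);  // keep the staged copy in shared memory
//   TensorView* tv_out_cache = tv_out->cacheBefore();  // compute into the cache,
//                                                      // then write tv_out to global memory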
+
+ MemoryType getMemoryType() const {
+ return memory_type_;
+ }
+
+ void setMemoryType(MemoryType mt);
+
+ // Apply circular buffering transformation. Negative prefetch_distance
+ // means "all but", for example, -1 means number_of_stages - 1.
+ void circularBuffer(
+ int64_t number_of_stages,
+ int64_t prefetch_distance = -1,
+ CircularBufferType type = Pipelined(false));
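// [Editorial example, not part of the original header] A sketch matching the
// [Circular buffering] note earlier in this file (assumes tv_smem is a shared
// memory tensor produced by an asynchronous copy and already scheduled
// appropriately):
//
//   tv_smem->circularBuffer(/*number_of_stages=*/6, /*prefetch_distance=*/3);
//   // or, using the warp-specialized variant on TIDy:
//   tv_smem->circularBuffer(6, 3, WarpSpecialized(ParallelType::TIDy));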
+
+ // Returns true if this tensor is circular buffered.
+ bool isCircularBuffered() const {
+ return circular_buffer_options_.isEnable();
+ }
+
+ const CircularBufferOptions& circularBufferOptions() const {
+ return circular_buffer_options_;
+ }
+
+ //! Transforms the innermost iterdomains according to the given mma swizzle,
+ //! this should be used on the tvs that are either inputs/outputs of an
+ //! MmaOp, or any tv's that are involved in prolog/epilog fusions and need to
+ //! have a matching thread swizzle with the mma operand/result.
+ //! More detail on usage see [MmaSwizzler] in scheduler/mma_utils.h .
+ void applyMmaSwizzle(MmaOperand operand);
+ void applyMmaSwizzle(MmaInputSmemSwizzle swizzle);
+
+ //! Function to schedule the swizzled TMA box.
+ //! This function works on the assumption that the TMA box is 2D
+ //! and the inner-dimension is less than or equal to the swizzle size.
+ //! This doesn't work for the swizzle none mode. For more details
+ //! refer to the figure doc/dev/tma/swizzle.svg
+ void swizzleTMABox(MmaInputSmemSwizzle swizzle);
+
+ //! Transforms the innermost iterdomains according to the given mma swizzle,
+ //! this should be used on the tvs that are inputs of a MmaOp or are loaded
+ //! using TMA.
+ void applyMmaSwizzleForTMALoad(MmaInputSmemSwizzle swizzle);
+
+ //! Returns if this tensor view has swizzle operator on its tensor domain.
+ //! This is the temporary flag for indicating that the new swizzle
+ //! implementation is used and will be removed in follow ups.
+ bool hasSwizzleOp() const {
+ return has_swizzle_op_;
+ }
+
+ //! A temporary helper function for the transition from Swizzle2D to Swizzle
+ void setHasSwizzleOp() {
+ has_swizzle_op_ = true;
+ }
+
+ friend TransformPropagator;
+ friend MostInlinedTransformPropagator;
+ friend TransformReplay;
+ friend OptOutMutator;
+ friend class InlineBatchingGuard;
+ friend class ir_utils::TVDomainGuard;
+
+ // Inline the computation of this tensor into its consumer at the given
+ // position. If this tensor is already inlined in a higher position, then this
+ // call is a no-op. If the rightmost dimensions before `pos` are
+ // broadcasts, then this will not inline into those broadcast dimensions. If
+ // best_effort, then will inline into the highest allowed position that is <=
+ // `pos`.
+ void inlineAt(
+ int64_t pos,
+ bool best_effort = false,
+ MaxPosCalculator* calc = nullptr);
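// [Editorial example, not part of the original header] A sketch of inlineAt
// (assumes `producer` is a 3-D TensorView feeding another tensor in the same
// Fusion):
//
//   producer->inlineAt(2);                        // share the two outermost loops with its consumer
//   producer->inlineAt(3, /*best_effort=*/true);  // back off to the deepest legal position <= 3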
+
+ //! Inline the computation of this tensor into a consumer at the given
+ //! position. The consumer to compute with is determined when the
+ //! fusion is lowered. Specifically, it is the first consumer tensor
+ //! in the topologically ordered dependency graph. Before the
+ //! lowering, its compute-with consumer is considered unresolved,
+ //! which is then resolved by resolveComputeWith below.
+ //!
+ //! The position is relative to its own domain. It is an
+ //! error if the position is smaller than the compute-at position. If this
+ //! tensor is already inlined in a higher position with the same
+ //! consumer, then this call is a no-op. The actual position is
+ //! computed in the same way as inlineAt, except that computeWith
+ //! does not have the constraint of the persistent data-dependency pattern.
+ void computeWith(int64_t pos, bool best_effort = false);
+
+ //! Set the actual consumer tensors that this tensor is
+ //! computed with. Requires a topologically sorted list of expressions,
+ //! which can be obtained from reorderExprsForComputeAt. Return true if
+ //! resolution is actually done. This should only be done in the
+ //! Kernel container.
+ bool resolveComputeWith(const std::vector<Expr*>& sorted_exprs);
+
+ bool hasComputeWith() const {
+ return getComputeWithPosition() > getComputeAtPosition();
+ }
+
+ bool hasResolvedComputeWith() const {
+ return !compute_with_consumers_.empty();
+ }
+
+ //! Query if this tensor is computed with a given consumer.
+ bool isComputedWith(const TensorView* consumer) const;
+
+ //! Return the tensors with which this tensor is computed. It is an
+ //! error to use this function without first resolving computeWith.
+ const std::vector<TensorView*>& getComputeWithConsumers() const;
+
+ int64_t getComputeWithPosition() const {
+ return compute_with_pos_;
+ }
+
+ int64_t getMaxComputePosition() const {
+ return std::max(getComputeWithPosition(), getComputeAtPosition());
+ }
+
+ //! Returns the position that this tensor is produced at for a given
+ //! consumer. If this tensor is computed with the given consumer,
+ //! which also means its computeWith needs to have been resolved, the
+ //! computeWith position is returned. Otherwise, the default computeAt
+ //! position is returned.
+ int64_t getComputePosition(const TensorView* consumer) const;
+
+ // Update the max producer position of the current tensor. This is required
+ // when we modify producer-consumer relationship of a scheduled tensor, for
+ // example, grouping multiple reductions.
+ void updateMaxProducerPosition();
+
+ // Commit the current changes in loop domain into rFactor domain. This
+ // function can be used to do implicit transpose and view, but today, only
+ // implicit transpose is being tested. This function can be dangerous: it
+ // changes the semantics of the current tensor without updating its
+ // consumers consistently, and there is no reliable way to detect this
+ // inconsistency. It is the responsibility of the caller of this function to
+ // ensure consistency.
+ void commitLeafToLogical();
+
+ //! Request that we reclaim the memory of this tv before any subsequent
+ //! tensors are allocated.
+ //!
+ //! This method influences the shared memory allocator that assigns shared
+ //! memory addresses at lowering. It ensures that the proper synchronization
+ //! is present in the kernel to reuse memory and inserts new block
+ //! synchronizations if necessary.
+ void promoteReuse(bool b = true) {
+ NVF_CHECK(
+ memory_type_ == MemoryType::Shared,
+ "promoteReuse should only be called on shared memory tensors");
+ promote_reuse_ = b;
+ }
+
+ //! Returns whether we should insert syncs if needed in order to reuse the
+ //! memory of this tensor.
+ bool shouldPromoteReuse() const {
+ return promote_reuse_;
+ }
+
+ void setDeviceMesh(const DeviceMesh& mesh) {
+ mesh_ = mesh;
+ }
+
+ const DeviceMesh& getDeviceMesh() const {
+ return mesh_;
+ }
+
+ bool hasDeviceMesh() const {
+ return !mesh_.vector().empty();
+ }
+
+ protected:
+ void setDomain(TensorDomain* td) {
+ domain_ = td;
+ }
+
+ private:
+ int64_t wrapDim(int64_t dim) const {
+ return nvfuser::wrapDim(dim, nDims());
+ }
+
+ //! A helper function to maintain the consistency of schedules of
+ //! multiple outputs when doing rfactor on multi-output reduction ops.
+ TensorView* multiOutputRFactorHelper(
+ TensorView* tv,
+ const std::vector<int64_t>& axes);
+
+ void clearComputeWith();
+
+ private:
+ TensorDomain* domain_ = nullptr;
+ int64_t compute_at_pos_ = 0;
+ int64_t max_producer_pos_ = 0;
+ MemoryType memory_type_ = MemoryType::Local;
+
+ //! Indicates the circular buffering options if applicable.
+ CircularBufferOptions circular_buffer_options_;
+
+ // special handling for CPU based zero-dim tensors (i.e. CPU Tensors that
+ // only have one value). This is only used if on an input value, otherwise
+ // ignored. This is important as special handling because these "scalars"
+ // should be type promoted as a tensor, but we want to avoid explicit
+ // copying of the data, so we want to pass the data value as a standard
+ // kernel argument value.
+ bool cpu_scalar_ = false;
+
+ //! Indicates if this tensor view has swizzle operator on its tensor domain.
+ //! This is the temporary flag for indicating that the new swizzle
+ //! implementation is used and will be removed in follow ups.
+ bool has_swizzle_op_ = false;
+
+ //! Direct consumer tensors that this tensor is computed with
+ std::vector<TensorView*> compute_with_consumers_;
+
+ //! Position where this tensor is computed with the compute-with
+ //! consumer tensors. It should always be equal to or greater than
+ //! the computeAt position
+ int64_t compute_with_pos_ = 0;
+
+ //! Maximum position where producers may be computed at, including
+ //! unresolved computeWith. This is equal to max_producer_pos_ when
+ //! no producer has unresolved computeWith. It is only used before
+ //! resolving computeWith so that no IterDomain is
+ //! transformed when there may actually be a producer tensor that
+ //! may be computed at.
+ int64_t maybe_max_producer_pos_ = 0;
+
+ //! When this is true, it indicates, if this is a shared memory tensor and
+ //! there are other shared memory tensors whose lifetimes do not overlap and
+ //! come later than this tensor's lifetime, that we should ensure that thread
+ //! blocks are synchronized such that all threads have performed their last
+ //! read of this tensor (or any tensors aliasing it) before writing to the
+ //! current tensor. This will then allow us to safely reuse the memory
+ //! allocated to this tensor.
+ bool promote_reuse_ = false;
+
+ // Device Mesh on which the Tensor is sharded
+ DeviceMesh mesh_;
+ };
+
+ //! A simple TensorView builder
+ //!
+ //! Example usage:
+ //!
+ //! auto tv = TensorViewBuilder()
+ //! .ndims(ndims)
+ //! .dtype(dtype)
+ //! .contiguity(contiguity)
+ //! .build();
+ //!
+ class NVF_API TensorViewBuilder {
+ public:
+ //! Set the number of dimensions of the tensor (default 0, meaning scalar)
+ TensorViewBuilder& ndims(int64_t ndims);
+
+ //! Set the data type of the tensor (default DataType::Float)
+ TensorViewBuilder& dtype(DataType dtype);
+
+ //! Set the contiguity information (default non-contiguous)
+ TensorViewBuilder& contiguity(std::vector<std::optional<bool>> contiguity);
+ TensorViewBuilder& contiguity(bool contiguity);
+
+ //! Set the shape (default 0 dimensional, i.e. scalar)
+ TensorViewBuilder& shape(std::vector<Val*> shape);
+ TensorViewBuilder& shape(const std::vector<int64_t>& shape);
+
+ //! Set if a dimension is expanded
+ TensorViewBuilder& expanded(std::vector<bool> expanded);
+
+ //! Set the permutation from allocation domain on root domain
+ TensorViewBuilder& strideOrder(std::vector<int64_t> stride_order);
+
+ //! Creates a new TensorView with the specified options
+ TensorView* build() const;
+
+ private:
+ int64_t ndims_ = 0;
+ DataType dtype_ = DataType::Float;
+
+ // contiguity_ is the vector that you will pass to the constructor of
+ // TensorDomain. However, constructing this vector can be non-trivial, because
+ // it is required to be nullopt for broadcast dimensions. We often want to
+ // create contiguity vector that represents all contiguous or all
+ // discontiguous. uniform_contiguity_ is there to make this use case more
+ // convenient. If set, then TensorViewBuilder will automatically fill the
+ // contiguity with the value of uniform_contiguity_ where it is not required
+ // to be nullopt. Note that you can only set one of contiguity_ or
+ // uniform_contiguity_.
+ std::vector<std::optional<bool>> contiguity_;
+ std::optional<bool> uniform_contiguity_ = std::nullopt;
+
+ std::vector<Val*> shape_;
+
+ std::vector<int64_t> stride_order_;
+ std::vector<bool> expanded_;
+ };
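// [Editorial example, not part of the original header] A slightly more concrete
// builder sketch than the one in the class comment (assumes it runs inside an
// active Fusion, e.g. under a FusionGuard; the values are arbitrary):
//
//   Fusion fusion;
//   FusionGuard fg(&fusion);
//   TensorView* tv = TensorViewBuilder()
//                        .ndims(2)
//                        .dtype(DataType::Half)
//                        .contiguity(true)  // uniform contiguity; broadcast dims still get nullopt
//                        .build();
//   fusion.addInput(tv);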
+
+ } // namespace nvfuser