nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/ir/internal_nodes.h
@@ -0,0 +1,2792 @@
+ // clang-format off
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <ir/interface_nodes.h>
+
+ #include <fusion.h>
+ #include <ir/base_nodes.h>
+ #include <mma_type.h>
+ #include <parallel_type_bitmap.h>
+ #include <visibility.h>
+
+ //! Nodes in here should generally not be used by users. They should be behind
+ //! the scenes and users shouldn't have to be aware of what they do to use the
+ //! code generator
+ //!
+ //! \todo improve implementation bool IterDomain::sameAs(const IterDomain*)
+ //! \todo Add testing of sameAs functions for these nodes
+ //!
+
+ //! IR header hierarchy
+ //! 1. utils.h - PolymorphicBase and NonCopyable
+ //! 2. ir/base_nodes.h - Statement, Expr, and Val
+ //! 3. ir/internal_base_nodes.h - IterDomain and TensorDomain
+ //! 4. ir/interface_nodes.h - TensorView and Scalar
+ //! 5. ** ir/internal_nodes.h ** - Any internal-only IR nodes
+
+ namespace nvfuser {
+
+ class ViewTransform;
+ class Scope;
+ class IrCloner;
+ struct AnalyzeViewResult;
+
+ class NVF_API FullOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ FullOp(IrBuilderPasskey, Val* out, Val* fill_value);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "FullOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ Val* getFillValue() const {
+ return inputs().back();
+ }
+ };
+
+ class SelectOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ SelectOp(IrBuilderPasskey, Val* out, Val* in, int64_t dim, Val* index);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "SelectOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ TensorView* lookupTv() const {
+ return input(0)->as<TensorView>();
+ }
+
+ int64_t dim() const {
+ return attribute<int64_t>(0);
+ }
+
+ IterDomain* getIndexedID() const;
+
+ std::unordered_map<IterDomain*, Val*> getIndexOverridingMap() const {
+ return {{getIndexedID(), input(1)}};
+ }
+ };
+
+ class IndexSelectOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ IndexSelectOp(IrBuilderPasskey, Val* out, Val* in1, int64_t dim, Val* in3);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "IndexSelectOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ TensorView* lookupTv() const {
+ return input(0)->as<TensorView>();
+ }
+
+ TensorView* indexTv() const {
+ return input(1)->as<TensorView>();
+ }
+
+ IterDomain* getIndexedID() const;
+
+ IterDomain* getConsumerOfIndexedID() const;
+
+ int64_t dim() const {
+ return attribute<int64_t>(0);
+ }
+ };
+
+ class NVF_API TorchGatherOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ //! Parameter exact_sizes indicates whether the non-indexed domains
+ //! of the index tensor have the same extents of those of the input
+ //! tensor. It's true in the case of torch.take_along_dim and
+ //! numpy_take_along_axis. torch.gather does not guarantee
+ //! they are the same.
+ TorchGatherOp(
+ IrBuilderPasskey,
+ Val* out,
+ Val* in,
+ int64_t dim,
+ Val* index,
+ bool exact_sizes);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "TorchGatherOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ TensorView* lookupTv() const {
+ return input(0)->as<TensorView>();
+ }
+
+ TensorView* indexTv() const {
+ return input(1)->as<TensorView>();
+ }
+
+ int64_t dim() const {
+ return attribute<int64_t>(0);
+ }
+
+ IterDomain* getIndexedID() const;
+
+ IterDomain* getConsumerOfIndexedID() const;
+
+ bool exactSizes() const {
+ return attribute<bool>(1);
+ }
+ };
+
+ class ScatterOp : public Expr {
+ public:
+ using Expr::Expr;
+ ScatterOp(
+ IrBuilderPasskey,
+ ScatterOpType type,
+ Val* out,
+ Val* self,
+ int64_t dim,
+ Val* index,
+ Val* src);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "ScatterOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ TensorView* selfTv() const {
+ return input(0)->as<TensorView>();
+ }
+
+ TensorView* indexTv() const {
+ return input(1)->as<TensorView>();
+ }
+
+ TensorView* srcTv() const {
+ return input(2)->as<TensorView>();
+ }
+
+ int64_t dim() const {
+ return attribute<int64_t>(0);
+ }
+
+ IterDomain* getIndexedID() const;
+
+ ScatterOpType getScatterOpType() const {
+ return attribute<ScatterOpType>(1);
+ }
+ };
+
+ class IotaOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ IotaOp(IrBuilderPasskey, Val* out, Val* length, Val* start, Val* step);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "IotaOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ DataType dtype() const {
+ return *start()->getDataType();
+ }
+
+ Val* length() const {
+ return input(0);
+ }
+
+ Val* start() const {
+ return input(1);
+ }
+
+ Val* step() const {
+ return input(2);
+ }
+ };
+
+ // Tensor factory for generating identity matrices like
+ //
+ // [[1, 0, 0],
+ // [0, 1, 0],
+ // [0, 0, 1]]
+ //
+ // or
+ //
+ // [[1, 0, 0],
+ // [0, 1, 0],
+ // [0, 0, 1],
+ // [0, 0, 0]]
+ //
+ // or
+ //
+ // [[1, 0, 0, 0],
+ // [0, 1, 0, 0],
+ // [0, 0, 1, 0]]
+ class EyeOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ EyeOp(IrBuilderPasskey, Val* out, DataType dtype);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "EyeOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ DataType dtype() const {
+ return attribute<DataType>(0);
+ }
+ };
+
+ //! A specialization for Unary operations. Unary operations take in a single
+ //! input and produce a single output. Examples include:
+ //! 1) Casting operation i.e. float(a_val)
+ //! 2) Negation i.e. val * -1
+ //! 3) Reduction across a dimension i.e. val.sum(axis=2)
+ //! 4) split/merge
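+ //!
+ //! For illustration only: front-end helpers declared in ops/arith.h such as
+ //! neg(tv) or castOp(DataType::Float, tv) typically lower to a UnaryOp node
+ //! with the corresponding UnaryOpType.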
+ class NVF_API UnaryOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ UnaryOp(IrBuilderPasskey, UnaryOpType type, Val* out, Val* in);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "UnaryOp";
+ }
+
+ std::string getGraphvizLabel() const override;
+
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ Val* out() const {
+ return output(0);
+ }
+ Val* in() const {
+ return input(0);
+ }
+
+ UnaryOpType getUnaryOpType() const {
+ return attribute<UnaryOpType>(0);
+ }
+
+ private:
+ void printHelper(std::stringstream& ss, std::string input) const;
+ };
+
+ //! A specialization for Binary operations. Binary operations take in two inputs
+ //! and produce a single output. Examples include:
+ //! 1) Add/mul/div/mod/sub (A * B)
+ //! 2) LT (A < B)
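+ //!
+ //! For illustration only: helpers such as add(lhs, rhs) or lt(lhs, rhs) in
+ //! ops/arith.h typically produce a BinaryOp with the matching BinaryOpType.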
+ class NVF_API BinaryOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ BinaryOp(IrBuilderPasskey, BinaryOpType type, Val* out, Val* lhs, Val* rhs);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "BinaryOp";
+ }
+
+ std::string getGraphvizLabel() const override;
+
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ Val* out() const {
+ return output(0);
+ }
+ Val* lhs() const {
+ return input(0);
+ }
+ Val* rhs() const {
+ return input(1);
+ }
+
+ BinaryOpType getBinaryOpType() const {
+ return attribute<BinaryOpType>(0);
+ }
+
+ private:
+ void printHelper(
+ std::stringstream& ss,
+ int indent_size,
+ std::string lhs,
+ std::string rhs) const;
+ };
+
+ class TernaryOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ TernaryOp(
+ IrBuilderPasskey,
+ TernaryOpType type,
+ Val* out,
+ Val* in1,
+ Val* in2,
+ Val* in3);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "TernaryOp";
+ }
+
+ std::string getGraphvizLabel() const override;
+
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ Val* out() const {
+ return output(0);
+ }
+
+ Val* in1() const {
+ return input(0);
+ }
+ Val* in2() const {
+ return input(1);
+ }
+ Val* in3() const {
+ return input(2);
+ }
+
+ TernaryOpType getTernaryOpType() const {
+ return attribute<TernaryOpType>(0);
+ }
+
+ private:
+ void printHelper(
+ std::stringstream& ss,
+ int indent_size,
+ std::string in1,
+ std::string in2,
+ std::string in3) const;
+ };
+
+ // construct an array from a list of values
+ class ArrayConstruct : public Expr {
+ public:
+ using Expr::Expr;
+
+ NVF_API ArrayConstruct(
+ IrBuilderPasskey,
+ Val* output,
+ std::vector<Val*> inputs);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "ArrayConstruct";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ Val* out() const {
+ return output(0);
+ }
+ };
+
+ class ReverseArray : public Expr {
+ public:
+ using Expr::Expr;
+
+ ReverseArray(IrBuilderPasskey, Val* output, Val* input);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "ReverseArray";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ Val* out() const {
+ return output(0);
+ }
+
+ Val* in() const {
+ return input(0);
+ }
+ };
+
+ // Get an item from an array, array[index]
+ class GetItem : public Expr {
+ public:
+ using Expr::Expr;
+
+ GetItem(IrBuilderPasskey, Val* output, Val* array, Val* index);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "GetItem";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ Val* out() const {
+ return output(0);
+ }
+
+ Val* array() const {
+ return input(0);
+ }
+
+ Val* index() const {
+ return input(1);
+ }
+ };
+
+ // construct a struct from a list of values
+ class StructConstruct : public Expr {
+ public:
+ using Expr::Expr;
+
+ NVF_API StructConstruct(
+ IrBuilderPasskey,
+ Val* output,
+ const std::vector<std::pair<std::string, Val*>>& fields);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "StructConstruct";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ std::string fieldName(size_t i) const {
+ return attribute<std::string>(i);
+ }
+
+ Val* out() const {
+ return output(0);
+ }
+ };
+
+ // Get an attribute from a struct, struct.attr
+ class GetAttr : public Expr {
+ public:
+ using Expr::Expr;
+
+ GetAttr(IrBuilderPasskey, Val* output, Val* struct_, std::string attr);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "GetAttr";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ Val* out() const {
+ return output(0);
+ }
+
+ Val* struct_() const {
+ return input(0);
+ }
+
+ std::string attr() const {
+ return attribute<std::string>(0);
+ }
+ };
+
+ // Get the metadata of a value, e.g. the data pointer, sizes and strides of a tensor
+ class GetMetaData : public Expr {
+ public:
+ using Expr::Expr;
+
+ GetMetaData(IrBuilderPasskey, Val* output, Val* input);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "GetMetaData";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ bool sameAs(const Statement* other) const override {
+ auto other_meta = dynamic_cast<const GetMetaData*>(other);
+ if (other_meta == nullptr) {
+ return false;
+ }
+ // Do not recursively check input, because if we have
+ // T1 = set(T0)
+ // T2 = set(T0)
+ // Then even if T1->sameAs(T2), they should not have the same metadata.
+ // For example, T1 and T2 may be different fusion outputs, so their data
+ // pointers are different.
+ return other_meta->in() == in();
+ }
+
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ Val* out() const {
+ return output(0);
+ }
+
+ Val* in() const {
+ return input(0);
+ }
+ };
+
+ // Construct a tensor from an array
+ class TensorConstruct : public Expr {
+ public:
+ using Expr::Expr;
+
+ TensorConstruct(IrBuilderPasskey, TensorView* output, Val* input);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "TensorConstruct";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ TensorView* out() const {
+ return output(0)->as<TensorView>();
+ }
+
+ Val* in() const {
+ return input(0);
+ }
+ };
+
+ //! A specialization for random number generator (RNG) operations. RNG
+ //! operations take in no tensor input and produce a single output.
+ class RNGOp : public Expr {
+ int64_t getOutputDims() const;
+
+ public:
+ struct Attributes {
+ // default initialization for clang-tidy
+ // cppcoreguidelines-pro-type-member-init
+ RNGOpType rtype = RNGOpType::Undefined;
+ DataType dtype;
+ size_t num_parameters = 0;
+
+ // TODO: Enable the following in C++20:
+ // bool operator==(const Attributes &other) const = default;
+ bool operator==(const Attributes& other) const {
+ // Note: we do not need to explicitly compare num_parameters since it is
+ // tied to rtype
+ return rtype == other.rtype && dtype == other.dtype;
+ }
+ };
+
+ using Expr::Expr;
+
+ //! Note that if philox_offset is provided, then rng_offset will be ignored.
+ RNGOp(
+ IrBuilderPasskey,
+ RNGOpType type,
+ Val* out,
+ DataType dtype,
+ std::vector<Val*> parameters = {},
+ Val* philox_seed = nullptr,
+ Val* philox_offset = nullptr,
+ Val* philox_index = nullptr);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "RNGOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ RNGOpType getRNGOpType() const {
+ return attribute<Attributes>(0).rtype;
+ }
+
+ DataType dtype() const {
+ return attribute<Attributes>(0).dtype;
+ }
+
+ size_t getNumParameters() const {
+ return attribute<Attributes>(0).num_parameters;
+ }
+
+ std::vector<Val*> getParameters() const {
+ return {
+ inputs().begin() + getOutputDims(),
+ inputs().begin() + (int64_t)(getOutputDims() + getNumParameters())};
+ }
+
+ std::vector<Val*> getShape() const {
+ return {inputs().begin(), inputs().begin() + getOutputDims()};
+ }
+
+ Val* getRNGSeedVal() const {
+ // Note that inputs() consists of:
+ // output dims | parameters | philox seed | philox_offset
+ auto seed_index = getOutputDims() + getNumParameters();
+ return (inputs().size() > seed_index) ? inputs().at(seed_index) : nullptr;
+ }
+
+ Val* getRNGOffsetVal() const {
+ // Note that inputs() consists of:
+ // output dims | parameters | philox seed | philox_offset
+ auto offset_index = getOutputDims() + getNumParameters() + 1;
+ return (inputs().size() > offset_index) ? inputs().at(offset_index)
+ : nullptr;
+ }
+
+ bool isDeterministic() const {
+ return inputs().size() == getOutputDims() + getNumParameters() + 2;
+ }
+
+ void setSeedAndOffset(Val* seed, Val* offset) {
+ NVF_ERROR(!isDeterministic());
+ addInput(seed);
+ addInput(offset);
+ }
+
+ Val* getPhiloxIndex() const {
+ return attributeVal(1);
+ }
+
+ int getPhiloxMultiple() const {
+ return dtype() == DataType::Double ? 2 : 4;
+ }
+ };
+
+ //! Broadcast in to match out. The semantics are identical to torch.unsqueeze.
+ //! is_broadcast_dims are relative to out. Where
+ //! is_broadcast_dims.size() == out->nDims().
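+ //!
+ //! For illustration only: on a 1-D input [i0], broadcast(tv, {false, true})
+ //! typically yields an output shaped [i0, b1], where the true flag marks the
+ //! newly inserted broadcast domain.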
+ class NVF_API BroadcastOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ //! \param out The output tensor
+ //! \param in The input tensor
+ //! \param is_broadcast_dims True when output dim is a new broadcast domain
+ BroadcastOp(
+ IrBuilderPasskey,
+ Val* out,
+ Val* in,
+ std::vector<bool> is_broadcast_dims);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "BroadcastOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ Val* out() const {
+ return output(0);
+ }
+ Val* in() const {
+ return input(0);
+ }
+
+ bool isBroadcastDim(size_t dim) const {
+ return getBroadcastDimFlags().at(dim);
+ }
+
+ //! The same list passed to the broadcast arithmetic op. Each
+ //! element corresponds to an IterDomain of the output tensor and is
+ //! true when the IterDomain is a new broadcast domain. Note
+ //! that the output tensor may have other broadcast domains whose
+ //! flags are false because the input tensor may already have
+ //! broadcast domains.
+ const std::vector<bool>& getBroadcastDimFlags() const {
+ return attribute<std::vector<bool>>(0);
+ }
+ };
+
+ //! Squeeze in to match out. is_squeeze_dims are relative to in. Where
+ //! is_squeeze_dims.size() == in->nDims(). Squeeze is the opposite of
+ //! broadcast.
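+ //!
+ //! For illustration only: on an input [i0, b1], a call along the lines of
+ //! squeeze(tv, std::vector<bool>{false, true}) typically removes the broadcast
+ //! domain b1 and yields a 1-D output [i0].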
+ class NVF_API SqueezeOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ //! \param out The output tensor
+ //! \param in The input tensor
+ //! \param is_squeeze_dims True when input dim is a removed broadcast domain
+ SqueezeOp(
+ IrBuilderPasskey,
+ Val* out,
+ Val* in,
+ std::vector<bool> is_broadcast_dims);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "SqueezeOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ Val* out() const {
+ return output(0);
+ }
+ Val* in() const {
+ return input(0);
+ }
+
+ bool isSqueezeDim(size_t dim) const {
+ return getSqueezeDimFlags().at(dim);
+ }
+
+ //! The same list passed to the squeeze arithmetic op. Each
+ //! element corresponds to an IterDomain of the input tensor and is
+ //! true when the IterDomain is a broadcast domain that is removed in the
+ //! output. Note that the output tensor may still contain broadcast domains
+ //! because the input tensor may have broadcast domains that we don't want to
+ //! remove (false flag).
+ const std::vector<bool>& getSqueezeDimFlags() const {
+ return attribute<std::vector<bool>>(0);
+ }
+
+ //! Check that squeezed IDs in old_tv concretize to Broadcast IterType
+ void checkConcretization(Val* old_tv, Val* new_tv) const override;
+ };
+
+ //! Reduction operation. Out is first initialized to _init. Then
+ //! reduction_op_type is used to update out as out = reductionOp(out, in).
+ //! Output's axes marked as reduction will be reduced to produce an output
+ //! tensor. The output tensor's size will be the size of all
+ //! non-reduction/non-broadcast dimensions.
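+ //!
+ //! For illustration only: sum(tv, {1}) typically lowers to a ReductionOp with
+ //! BinaryOpType::Add, an init value of 0, and axis 1 marked as a reduction domain.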
+ class NVF_API ReductionOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ ReductionOp(
+ IrBuilderPasskey,
+ BinaryOpType reduction_op_type,
+ Val* init,
+ Val* out,
+ Val* in,
+ bool is_allreduce = false);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "ReductionOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ Val* out() const {
+ return output(0);
+ }
+ Val* in() const {
+ return input(0);
+ }
+ Val* init() const {
+ return attributeVal(0);
+ }
+
+ BinaryOpType getReductionOpType() const {
+ return attribute<BinaryOpType>(1);
+ }
+
+ bool isAllreduce() const {
+ return attribute<bool>(2);
+ }
+
+ //! Scheduling method to request that this reduction be performed as a
+ //! serial grid reduction. Note that it is an error to use this method on a
+ //! reduction whose output has any of its reduction axes parallelized with a
+ //! threadIdx, even if that parallelization occurs after this method call.
+ //!
+ //! Also note that this operation should not be inlined with other reductions
+ //! unless they use the same parallelization pattern and they are also serial
+ //! grid reductions.
+ void requestSerialGridReduction(bool value = true) {
+ attribute<bool>(3) = value;
+ }
+
+ bool serialGridReductionRequested() const {
+ return attribute<bool>(3);
+ }
+ };
+
+ //! Grouped reduction operation for horizontal fusions. It works like
+ //! batched GEMMs in the sense that multiple independent reductions are
+ //! performed together. The main benefit is when reducing tensors across thread
+ //! blocks, a single grid sync can be done for all individual
+ //! reductions. As grid sync is very expensive, this can be a
+ //! significant performance impact.
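+ //!
+ //! For illustration only: two independent reduction outputs can typically be
+ //! grouped via groupReductions(...) (declared in grouped_reduction.h) so that
+ //! their cross-block reductions share a single grid synchronization.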
+ class GroupedReductionOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ GroupedReductionOp(
+ IrBuilderPasskey,
+ std::vector<BinaryOpType> reduction_op_types,
+ std::vector<Val*> init,
+ std::vector<Val*> out,
+ std::vector<Val*> in,
+ bool is_allreduce = false);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "GroupedReductionOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ //! Number of expressions grouped horizontally. It does not reflect
+ //! iteration grouping.
+ size_t numHorizontallyGroupedExprs() const {
+ return getReductionOpTypes().size();
+ }
+
+ std::vector<Val*> initVals() const {
+ auto size = numHorizontallyGroupedExprs();
+ std::vector<Val*> result;
+ result.reserve(size);
+ for (auto i : c10::irange(2, 2 + size)) {
+ result.emplace_back(attribute(i)->as<Val>());
+ }
+ return result;
+ }
+
+ Val* initVal(size_t index) const {
+ return attributeVal(2 + index);
+ }
+
+ const std::vector<BinaryOpType>& getReductionOpTypes() const {
+ return attribute<std::vector<BinaryOpType>>(0);
+ }
+
+ BinaryOpType getReductionOpType(size_t index) const {
+ return getReductionOpTypes().at(index);
+ }
+
+ bool isAllreduce() const {
+ return attribute<bool>(1);
+ }
+
+ //! Return the index of the corresponding reduction expression for
+ //! a given output val.
+ int getExprIndexOfOutput(Val* output_val) const;
+ };
+
+ //! Average, variance and N (count) vals for Welford
+ class WelfordTriplet {
+ public:
+ //! Names of the Welford triplet vals
+ enum class ValName { Avg, Var, N };
+
+ WelfordTriplet() = default;
+
+ WelfordTriplet(Val* avg, Val* var, Val* N) : vals_({avg, var, N}) {}
+
+ Val* const& avg() const {
+ return get(ValName::Avg);
+ }
+
+ Val*& avg() {
+ return get(ValName::Avg);
+ }
+
+ TensorView* avgTv() const {
+ NVF_ERROR(avg()->isA<TensorView>());
+ return avg()->as<TensorView>();
+ }
+
+ Val* const& var() const {
+ return get(ValName::Var);
+ }
+
+ Val*& var() {
+ return get(ValName::Var);
+ }
+
+ TensorView* varTv() const {
+ NVF_ERROR(var()->isA<TensorView>());
+ return var()->as<TensorView>();
+ }
+
+ Val* const& N() const {
+ return get(ValName::N);
+ }
+
+ Val*& N() {
+ return get(ValName::N);
+ }
+
+ TensorView* NTv() const {
+ NVF_ERROR(N()->isA<TensorView>());
+ return N()->as<TensorView>();
+ }
+
+ //! Get the i-th val. Ordering is defined by ValName.
+ Val* const& get(int i) const {
+ return vals_.at(i);
+ }
+
+ //! Get the i-th val. Ordering is defined by ValName.
+ Val*& get(int i) {
+ return vals_.at(i);
+ }
+
+ Val* const& get(ValName name) const {
+ return get(valNameToIndex(name));
+ }
+
+ Val*& get(ValName name) {
+ return get(valNameToIndex(name));
+ }
+
+ //! Get the name of a given val in this triplet. None is returned if
+ //! not found.
+ std::optional<ValName> getNameOf(Val* val) const;
+
+ //! Return a new triplet with outputs produced by a function applied
+ //! to each of this triplet
+ template <typename Func>
+ WelfordTriplet transform(Func func) const {
+ return WelfordTriplet(func(avg()), func(var()), func(N()));
+ }
+
+ bool sameAs(const WelfordTriplet& other) const;
+
+ WelfordTriplet clone(IrCloner* ir_cloner) const;
+
+ //! Clone a vector of triplets
+ static std::vector<WelfordTriplet> clone(
+ const std::vector<WelfordTriplet>& src,
+ IrCloner* ir_cloner);
+
+ auto begin() {
+ return vals_.begin();
+ }
+
+ auto begin() const {
+ return vals_.begin();
+ }
+
+ auto end() {
+ return vals_.end();
+ }
+
+ auto end() const {
+ return vals_.end();
+ }
+
+ private:
+ //! Convert a given val name to an index
+ static int valNameToIndex(ValName name) {
+ return static_cast<int>(name);
+ }
+
+ //! Convert a given index to a name
+ static ValName indexToValName(int index) {
+ NVF_ERROR(index >= 0 && index < 3, "Invalid index: ", index);
+ return static_cast<ValName>(index);
+ }
+
+ private:
+ //! Holds avg, var and N in this order
+ std::array<Val*, 3> vals_ = {{nullptr, nullptr, nullptr}};
+ };
+
+ //! Welford Scan operation.
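+ //!
+ //! A WelfordOp maintains a running (avg, var, N) triplet per output element.
+ //! As an illustrative sketch of the usual parallel combine step (assuming the
+ //! var slot carries the unnormalized sum of squared deviations, i.e. M2):
+ //!   delta = b.avg - a.avg;  N = a.N + b.N;
+ //!   avg   = a.avg + delta * b.N / N;
+ //!   var   = a.var + b.var + delta * delta * a.N * b.N / N;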
+ class NVF_API WelfordOp : public Expr {
+ public:
+ using Expr::Expr;
+ static constexpr int kNumAttrs = 4;
+
+ WelfordOp(
+ IrBuilderPasskey,
+ const WelfordTriplet& output,
+ const WelfordTriplet& input,
+ const WelfordTriplet& init,
+ bool is_fused = false);
+
+ WelfordOp(
+ IrBuilderPasskey,
+ Val* out_avg,
+ Val* out_var,
+ Val* out_N,
+ Val* in_avg,
+ Val* in_var,
+ Val* in_N,
+ Val* init_avg,
+ Val* init_var,
+ Val* init_N,
+ bool is_fused = false);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "WelfordOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+ std::vector<PolymorphicValue> evaluate(
+ const ExpressionEvaluator& ee,
+ const std::vector<PolymorphicValue>& inputs) const override;
+
+ Val* out() const {
+ return outputTriplet().avg();
+ }
+
+ Val* in() const {
+ return inputTriplet().avg();
+ }
+
+ WelfordTriplet outputTriplet() const {
+ return WelfordTriplet(outAvg(), outVar(), outN());
+ }
+
+ Val* outAvg() const {
+ return output(0);
+ }
+
+ Val* outVar() const {
+ return output(1);
+ }
+
+ Val* outN() const {
+ return output(2);
+ }
+
+ WelfordTriplet inputTriplet() const {
+ return WelfordTriplet(inAvg(), inVar(), inN());
+ }
+
+ Val* inAvg() const {
+ return input(0);
+ }
+
+ Val* inVar() const {
+ return input(1);
+ }
+
+ Val* inN() const {
+ return input(2);
+ }
+
+ WelfordTriplet initTriplet() const {
+ return WelfordTriplet(initAvg(), initVar(), initN());
+ }
+
+ Val* initAvg() const {
+ return attributeVal(0);
+ }
+
+ Val* initVar() const {
+ return attributeVal(1);
+ }
+
+ Val* initN() const {
+ return attributeVal(2);
+ }
+
+ bool singleValue() const {
+ return inN()->isOneInt();
+ }
+
+ bool hasInit() const {
+ return !initN()->isZeroInt();
+ }
+
+ //! True if using the fused reduction kernel (not implemented yet)
+ bool isAllreduce() const {
+ return attribute<bool>(3);
+ }
+
+ std::vector<Val*> getInitVals() const;
+
+ //! Return the init val for an output val
+ Val* getInitValOfOutput(Val* output_val) const;
+ };
+
+ class GroupedWelfordOp : public Expr {
+ public:
+ using Expr::Expr;
+
+ GroupedWelfordOp(
+ IrBuilderPasskey,
+ std::vector<WelfordTriplet> output_vals,
+ std::vector<WelfordTriplet> input_vals,
+ std::vector<WelfordTriplet> init_vals,
+ bool is_allreduce = false);
+
+ NVFUSER_DECLARE_CLONE_AND_CREATE
+
+ const char* getOpString() const override {
+ return "GroupedWelfordOp";
+ }
+
+ std::string toString(int indent_size = 0) const override;
+ std::string toInlineString(int indent_size = 0) const override;
+
+ //! Number of expressions grouped horizontally. It does not reflect
+ //! iteration grouping. As horizontal grouping is not supported,
+ //! this always returns 1.
+ size_t numHorizontallyGroupedExprs() const {
+ return 1;
+ }
+
+ Val* out(size_t index) const {
+ return outAvg(index);
+ }
+
+ Val* in(size_t index) const {
+ return inAvg(index);
+ }
+
+ std::vector<WelfordTriplet> outputVals() const {
+ std::vector<WelfordTriplet> result;
+ auto size = outputs().size() / 3;
+ result.reserve(size);
+ for (auto i : c10::irange(size)) {
+ result.emplace_back(outAvg(i), outVar(i), outN(i));
+ }
+ return result;
+ }
+
+ std::vector<WelfordTriplet> inputVals() const {
+ std::vector<WelfordTriplet> result;
+ auto size = inputs().size() / 3;
+ result.reserve(size);
+ for (auto i : c10::irange(size)) {
+ result.emplace_back(inAvg(i), inVar(i), inN(i));
+ }
+ return result;
+ }
+
+ std::vector<WelfordTriplet> initVals() const {
+ std::vector<WelfordTriplet> result;
+ auto size = inputs().size() / 3;
+ result.reserve(size);
+ for (auto i : c10::irange(size)) {
+ result.emplace_back(initAvg(i), initVar(i), initN(i));
+ }
+ return result;
+ }
+
+ Val* outAvg(size_t index) const {
+ return output(index * 3);
+ }
+
+ Val* outVar(size_t index) const {
+ return output(index * 3 + 1);
+ }
+
+ Val* outN(size_t index) const {
+ return output(index * 3 + 2);
+ }
+
+ Val* inAvg(size_t index) const {
+ return input(index * 3);
+ }
+
+ Val* inVar(size_t index) const {
+ return input(index * 3 + 1);
+ }
+
+ Val* inN(size_t index) const {
+ return input(index * 3 + 2);
+ }
+
+ Val* initAvg(size_t index) const {
+ return attributeVal(1 + index * 3);
+ }
+
+ Val* initVar(size_t index) const {
+ return attributeVal(2 + index * 3);
+ }
+
+ Val* initN(size_t index) const {
+ return attributeVal(3 + index * 3);
+ }
+
+ //! Return the index of the corresponding welford expression for
+ //! a given output val
+ int getExprIndexOfOutput(Val* output_val) const;
+
+ //! Return the init val for an output val
+ Val* getInitValOfOutput(Val* output_val) const;
+
+ bool singleValue(size_t index) const {
+ return inN(index)->isOneInt();
+ }
+
+ bool hasInit(size_t index) const {
+ return !initN(index)->isZeroInt();
+ }
+
+ bool isAllreduce() const {
+ return attribute<bool>(0);
+ }
+ };
1363
+
1364
+ //! Fused Matmul operation
1365
+ class NVF_API MmaOp : public Expr {
1366
+ public:
1367
+ using AxesData = std::vector<int64_t>;
1368
+ // AxisMapping denotes the pairing of two input dimensions to produce an
1369
+ // output dimension. It holds two vectors of integers indicating the
1370
+ // corresponding position of each output axis in either the A or B input.
1371
+ // Positions refer to the noReductions logical domain of each input.
1372
+ // NOTE: Axis positions are absolute, meaning you cannot specify them
1373
+ // relative to the last dimension since -1 has special meaning.
1374
+ // NOTE: -1 indicates that the axis does not exist, so Broadcast input
1375
+ // domains should be listed with their actual position and not -1.
1376
+ //
1377
+ // Example 1:
1378
+ // a [ K, 1, M ]
1379
+ // b [ 1, N, K ]
1380
+ // out [ M, N, rK ]
1381
+ // axisMapping:
1382
+ // a_axes = [ 2, 1, 0 ]
1383
+ // b_axes = [ 0, 1, 2 ]
1384
+ // This results in the following groups of mapped axes:
1385
+ // { tv_a->axis(2), tv_b->axis(0), out->axis(0) }
1386
+ // { tv_a->axis(1), tv_b->axis(1), out->axis(1) }
1387
+ // { tv_a->axis(0), tv_b->axis(2), out->axis(2) }
1388
+ //
1389
+ // Example 1:
1390
+ // a [ K, M ]
1391
+ // b [ 1, N, K ]
1392
+ // out [ M, N, rK ]
1393
+ // axisMapping:
1394
+ // a_axes = [ 1, -1, 0 ]
1395
+ // b_axes = [ 0, 1, 2 ]
1396
+ // This results in the following groups of mapped axes:
1397
+ // { tv_a->axis(1), tv_b->axis(0), out->axis(0) }
1398
+ // { tv_b->axis(1), out->axis(1) }
1399
+ // { tv_a->axis(0), tv_b->axis(2), out->axis(2) }
1400
+ struct AxisMapping {
1401
+ AxesData a_axes;
1402
+ AxesData b_axes;
1403
+
1404
+ static AxisMapping trivialMapping(size_t dimension);
1405
+ };
1406
+ using Expr::Expr;
1407
+
1408
+ MmaOp(
1409
+ IrBuilderPasskey,
1410
+ Val* out,
1411
+ Val* in_a,
1412
+ Val* in_b,
1413
+ Val* init,
1414
+ const AxisMapping& axis_mapping);
1415
+
1416
+ MmaOp(
1417
+ IrBuilderPasskey,
1418
+ Val* out,
1419
+ Val* in_a,
1420
+ Val* in_b,
1421
+ Val* init,
1422
+ const AxisMapping& axis_mapping,
1423
+ const MmaMacro& options);
1424
+
1425
+ NVFUSER_DECLARE_CLONE_AND_CREATE
1426
+
1427
+ const char* getOpString() const override {
1428
+ return "MmaOp";
1429
+ }
1430
+
1431
+ std::string toString(int indent_size = 0) const override;
1432
+ std::string toInlineString(int indent_size = 0) const override;
1433
+
1434
+ Val* out() const {
1435
+ return output(0);
1436
+ }
1437
+
1438
+ Val* inA() const {
1439
+ return input(0);
1440
+ }
1441
+
1442
+ Val* inB() const {
1443
+ return input(1);
1444
+ }
1445
+
1446
+ Val* init() const {
1447
+ return attributeVal(0);
1448
+ }
1449
+
1450
+ const auto& macro() const {
1451
+ return attribute<MmaMacro>(ATTR_POS_MACRO);
1452
+ }
1453
+
1454
+ int64_t m() const {
1455
+ return getM(macro());
1456
+ }
1457
+
1458
+ int64_t n() const {
1459
+ return getN(macro());
1460
+ }
1461
+
1462
+ int64_t k() const {
1463
+ return getK(macro());
1464
+ }
1465
+
1466
+ bool isTuring() const {
1467
+ return nvfuser::isTuring(macro());
1468
+ }
1469
+
1470
+ bool isAmpere() const {
1471
+ return nvfuser::isAmpere(macro());
1472
+ }
1473
+
1474
+ bool isHopper() const {
1475
+ return nvfuser::isHopper(macro());
1476
+ }
1477
+
1478
+ void setMacro(MmaMacro options);
1479
+
1480
+ const AxisMapping& axisMapping() const {
1481
+ return attribute<AxisMapping>(ATTR_POS_AXIS_MAPPING);
1482
+ }
1483
+
1484
+ private:
1485
+ // Predefined indices of attributes stored for this IR node, to avoid
1486
+ // magic numbers, based on the order in which attributes are initialized
1487
+ // in constructor
1488
+ static constexpr size_t ATTR_POS_INIT = 0;
1489
+ static constexpr size_t ATTR_POS_MACRO = 1;
1490
+ static constexpr size_t ATTR_POS_AXIS_MAPPING = 2;
1491
+ };
1492
+
1493
+ //! The semantics are identical to torch.broadcast_to.
1494
+ class ExpandOp : public Expr {
1495
+ public:
1496
+ using Expr::Expr;
1497
+
1498
+ ExpandOp(
1499
+ IrBuilderPasskey,
1500
+ TensorView* out,
1501
+ TensorView* in,
1502
+ std::vector<Val*> _expanded_extents);
1503
+
1504
+ NVFUSER_DECLARE_CLONE_AND_CREATE
1505
+
1506
+ const char* getOpString() const override {
1507
+ return "ExpandOp";
1508
+ }
1509
+
1510
+ std::string toString(int indent_size = 0) const override;
1511
+ std::string toInlineString(int indent_size = 0) const override;
1512
+
1513
+ TensorView* out() const {
1514
+ return output(0)->as<TensorView>();
1515
+ }
1516
+
1517
+ TensorView* in() const {
1518
+ return input(0)->as<TensorView>();
1519
+ }
1520
+
1521
+ std::vector<Val*> expanded_extents() const {
1522
+ return {inputs().begin() + 1, inputs().end()};
1523
+ }
1524
+
1525
+ std::vector<PolymorphicValue> evaluate(
1526
+ const ExpressionEvaluator& ee,
1527
+ const std::vector<PolymorphicValue>& inputs) const override;
1528
+ };
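+
+ // A shape-level sketch of ExpandOp, assuming one expanded-extent Val is
+ // passed per output dimension: broadcasting an input of shape [1, 8] to an
+ // output of shape [4, 8] would make expanded_extents() return the two extent
+ // Vals {4, 8}, with input(0) being the original [1, 8] TensorView.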
1529
+
1530
+ // Represents a repetition of broadcast IDs. Repetitions of
1531
+ // non-broadcast IDs are represented using the broadcast, expand and
1532
+ // reshape pattern. See the repeat op implementation in ops/alias.cpp
1533
+ // as well as the TranslateRepeatToExpand preseg pass.
1534
+ class RepeatOp : public Expr {
1535
+ public:
1536
+ using Expr::Expr;
1537
+
1538
+ // in: Input tensor that has broadcast logical IDs.
1539
+ // out: Output tensor where some of the input broadcast logical IDs
1540
+ // are converted to concrete IDs. Their extents represent the
1541
+ // repetition factor of each ID.
1542
+ RepeatOp(IrBuilderPasskey, TensorView* out, TensorView* in);
1543
+
1544
+ NVFUSER_DECLARE_CLONE_AND_CREATE
1545
+
1546
+ const char* getOpString() const override {
1547
+ return "RepeatOp";
1548
+ }
1549
+
1550
+ std::string toString(int indent_size = 0) const override;
1551
+ std::string toInlineString(int indent_size = 0) const override;
1552
+
1553
+ TensorView* out() const {
1554
+ return output(0)->as<TensorView>();
1555
+ }
1556
+
1557
+ TensorView* in() const {
1558
+ return input(0)->as<TensorView>();
1559
+ }
1560
+
1561
+ std::vector<PolymorphicValue> evaluate(
1562
+ const ExpressionEvaluator& ee,
1563
+ const std::vector<PolymorphicValue>& inputs) const override;
1564
+ };
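+
+ // A shape-level sketch of RepeatOp: an input with logical domain
+ // [ b(1), i(I0) ] and an output with logical domain [ i(2), i(I0) ]
+ // represents repeating the tensor twice along the first axis; the extent of
+ // the formerly-broadcast ID (here 2) is its repetition factor.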
1565
+
1566
+ class ViewAsScalar : public Expr {
1567
+ public:
1568
+ using Expr::Expr;
1569
+
1570
+ ViewAsScalar(IrBuilderPasskey, Val* out, Val* in, IterDomain* vector_id);
1571
+
1572
+ NVFUSER_DECLARE_CLONE_AND_CREATE
1573
+
1574
+ const char* getOpString() const override {
1575
+ return "ViewAsScalar";
1576
+ }
1577
+
1578
+ std::string toString(int indent_size = 0) const override;
1579
+ std::string toInlineString(int indent_size = 0) const override;
1580
+ std::vector<PolymorphicValue> evaluate(
1581
+ const ExpressionEvaluator& ee,
1582
+ const std::vector<PolymorphicValue>& inputs) const override;
1583
+
1584
+ Val* out() const {
1585
+ return output(0);
1586
+ }
1587
+
1588
+ Val* in() const {
1589
+ return input(0);
1590
+ }
1591
+
1592
+ // The IterDomain of type VectorComponent newly appended to the output
1593
+ IterDomain* vector_id() const {
1594
+ return attribute(0)->as<IterDomain>();
1595
+ }
1596
+ };
1597
+
1598
+ class NVF_API ViewOp : public Expr {
1599
+ public:
1600
+ using Expr::Expr;
1601
+
1602
+ ViewOp(IrBuilderPasskey, Val* out, Val* in);
1603
+
1604
+ NVFUSER_DECLARE_CLONE_AND_CREATE
1605
+
1606
+ const char* getOpString() const override {
1607
+ return "ViewOp";
1608
+ }
1609
+
1610
+ std::string toString(int indent_size = 0) const override;
1611
+ std::string toInlineString(int indent_size = 0) const override;
1612
+
1613
+ TensorView* out() const {
1614
+ return output(0)->as<TensorView>();
1615
+ }
1616
+
1617
+ TensorView* in() const {
1618
+ return input(0)->as<TensorView>();
1619
+ }
1620
+
1621
+ std::vector<PolymorphicValue> evaluate(
1622
+ const ExpressionEvaluator& ee,
1623
+ const std::vector<PolymorphicValue>& inputs) const override;
1624
+ };
1625
+
1626
+ //! This operator explicitly models data movement between
1627
+ //! state spaces on GPU. Currently the modeled state spaces include
1628
+ //! global memory, shared memory and register.
1629
+ //!
1630
+ //! The main usage of this op is to facilitate generation of hardware
1631
+ //! accelerated memory ops, e.g., ldmatrix, cp.async and more to come.
1632
+ class NVF_API LoadStoreOp : public Expr {
1633
+ public:
1634
+ using Expr::Expr;
1635
+
1636
+ LoadStoreOp(
1637
+ IrBuilderPasskey,
1638
+ LoadStoreOpType op_type,
1639
+ Val* out,
1640
+ Val* in,
1641
+ CacheOp cache_op = CacheOp::Unspecified);
1642
+
1643
+ NVFUSER_DECLARE_CLONE_AND_CREATE
1644
+
1645
+ const char* getOpString() const override {
1646
+ return "LoadStoreOp";
1647
+ }
1648
+
1649
+ std::vector<PolymorphicValue> evaluate(
1650
+ const ExpressionEvaluator& ee,
1651
+ const std::vector<PolymorphicValue>& inputs) const override;
1652
+
1653
+ std::string toString(int indent_size = 0) const override;
1654
+ std::string toInlineString(int indent_size = 0) const override;
1655
+
1656
+ Val* out() const {
1657
+ return output(0);
1658
+ }
1659
+
1660
+ Val* in() const {
1661
+ return input(0);
1662
+ }
1663
+
1664
+ LoadStoreOpType opType() const {
1665
+ return attribute<LoadStoreOpType>(0);
1666
+ }
1667
+
1668
+ CacheOp cacheOp() const {
1669
+ return attribute<CacheOp>(1);
1670
+ }
1671
+
1672
+ void setOpType(LoadStoreOpType op) {
1673
+ attribute<LoadStoreOpType>(0) = op;
1674
+ if (op != LoadStoreOpType::Set && op != LoadStoreOpType::CpAsync) {
1675
+ attribute<CacheOp>(1) = CacheOp::Unspecified;
1676
+ }
1677
+ }
1678
+
1679
+ void setCacheOp(CacheOp cache_op) {
1680
+ attribute<CacheOp>(1) = cache_op;
1681
+ }
1682
+ };
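+
+ // A minimal sketch of the opType()/cacheOp() interaction above, assuming
+ // "ls" is a LoadStoreOp* whose cache operator was set explicitly
+ // (AllLevels and LdMatrix are assumed enum values here):
+ //
+ //   ls->setCacheOp(CacheOp::AllLevels);
+ //   ls->setOpType(LoadStoreOpType::LdMatrix);  // neither Set nor CpAsync
+ //   // cacheOp() is now reset to CacheOp::Unspecified by setOpType()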
1683
+
1684
+ //! Represents a split of an IterDomain by "factor".
1685
+ //! inner_split dictates if the factor section of the split should be inside the
1686
+ //! remainder or outside.
1687
+ class NVF_API Split : public Expr {
1688
+ public:
1689
+ using Expr::Expr;
1690
+
1691
+ Split(
1692
+ IrBuilderPasskey,
1693
+ IterDomain* outer,
1694
+ IterDomain* inner,
1695
+ IterDomain* in,
1696
+ Val* factor,
1697
+ bool inner_split = true);
1698
+
1699
+ NVFUSER_DECLARE_CLONE_AND_CREATE
1700
+
1701
+ const char* getOpString() const override {
1702
+ return "Split";
1703
+ }
1704
+
1705
+ std::string toString(int indent_size = 0) const override;
1706
+ std::string toInlineString(int indent_size = 0) const override;
1707
+
1708
+ IterDomain* outer() const {
1709
+ return output(0)->as<IterDomain>();
1710
+ }
1711
+ IterDomain* inner() const {
1712
+ return output(1)->as<IterDomain>();
1713
+ }
1714
+ IterDomain* in() const {
1715
+ return input(0)->as<IterDomain>();
1716
+ }
1717
+ Val* factor() const {
1718
+ return attributeVal(0);
1719
+ }
1720
+ Val* isDivisible() const;
1721
+
1722
+ bool innerSplit() const {
1723
+ return attribute<bool>(1);
1724
+ }
1725
+ };
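+
+ // A worked example of Split: splitting an IterDomain of extent 12 by
+ // factor 4 with inner_split == true gives inner()->extent() == 4 and
+ // outer()->extent() == ceilDiv(12, 4) == 3; with inner_split == false the
+ // factor sizes the outer domain instead. isDivisible() should evaluate to
+ // true here since 12 % 4 == 0.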
1726
+
1727
+ //! Merge the IterDomains outer and inner into one domain, outer and inner
1728
+ //! dictate which will be traversed first (inner). Both IterDomains must be of
1729
+ //! the same iter or reduction type, as well as the same parallelization
1730
+ //! strategy if there is one
1731
+ class NVF_API Merge : public Expr {
1732
+ public:
1733
+ using Expr::Expr;
1734
+
1735
+ Merge(
1736
+ IrBuilderPasskey,
1737
+ IterDomain* out,
1738
+ IterDomain* outer,
1739
+ IterDomain* inner);
1740
+
1741
+ NVFUSER_DECLARE_CLONE_AND_CREATE
1742
+
1743
+ const char* getOpString() const override {
1744
+ return "Merge";
1745
+ }
1746
+
1747
+ std::string toString(int indent_size = 0) const override;
1748
+ std::string toInlineString(int indent_size = 0) const override;
1749
+
1750
+ IterDomain* out() const {
1751
+ return output(0)->as<IterDomain>();
1752
+ }
1753
+ IterDomain* outer() const {
1754
+ return input(0)->as<IterDomain>();
1755
+ }
1756
+ IterDomain* inner() const {
1757
+ return input(1)->as<IterDomain>();
1758
+ }
1759
+ };
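+
+ // A worked example of Merge: merging an outer IterDomain of extent 3 with an
+ // inner IterDomain of extent 4 gives out()->extent() == 3 * 4 == 12, with the
+ // inner domain traversed fastest, i.e. the merged index corresponds to
+ // outer_index * 4 + inner_index.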
1760
+
1761
+ class Swizzle : public Expr {
1762
+ public:
1763
+ using Expr::Expr;
1764
+
1765
+ Swizzle(
1766
+ IrBuilderPasskey,
1767
+ IterDomain* out_x,
1768
+ IterDomain* out_y,
1769
+ IterDomain* in_x,
1770
+ IterDomain* in_y,
1771
+ SwizzleType swizzle_type = SwizzleType::NoSwizzle);
1772
+
1773
+ NVFUSER_DECLARE_CLONE_AND_CREATE
1774
+
1775
+ const char* getOpString() const override {
1776
+ return "Swizzle";
1777
+ }
1778
+
1779
+ std::string toString(int indent_size = 0) const override;
1780
+ std::string toInlineString(int indent_size = 0) const override;
1781
+
1782
+ // Output iterdomain pair corresponding
1783
+ // to the original input iterdomain pair.
1784
+ IterDomain* outX() const {
1785
+ return output(0)->as<IterDomain>();
1786
+ }
1787
+
1788
+ IterDomain* outY() const {
1789
+ return output(1)->as<IterDomain>();
1790
+ }
1791
+
1792
+ // Input iterdomain pair.
1793
+ IterDomain* inX() const {
1794
+ return input(0)->as<IterDomain>();
1795
+ }
1796
+
1797
+ IterDomain* inY() const {
1798
+ return input(1)->as<IterDomain>();
1799
+ }
1800
+
1801
+ // The type of predefined 1-to-1 functions
1802
+ // used for swizzling math.
1803
+ auto swizzleType() const {
1804
+ return attribute<SwizzleType>(0);
1805
+ }
1806
+ };
1807
+
1808
+ //! Applies 2D swizzles on a rectangular tile defined by 2 iterdomains.
1809
+ class NVF_API Swizzle2D : public Expr {
1810
+ public:
1811
+ using Expr::Expr;
1812
+
1813
+ Swizzle2D(
1814
+ IrBuilderPasskey,
1815
+ IterDomain* out_x,
1816
+ IterDomain* out_y,
1817
+ IterDomain* in_x,
1818
+ IterDomain* in_y,
1819
+ Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle,
1820
+ SwizzleMode swizzle_mode = SwizzleMode::Data);
1821
+
1822
+ NVFUSER_DECLARE_CLONE_AND_CREATE
1823
+
1824
+ const char* getOpString() const override {
1825
+ return "Swizzle2D";
1826
+ }
1827
+
1828
+ std::string toString(int indent_size = 0) const override;
1829
+ std::string toInlineString(int indent_size = 0) const override;
1830
+
1831
+ // Output iterdomain pair corresponding
1832
+ // to the original input iterdomain pair.
1833
+ IterDomain* outX() const {
1834
+ return output(0)->as<IterDomain>();
1835
+ }
1836
+
1837
+ IterDomain* outY() const {
1838
+ return output(1)->as<IterDomain>();
1839
+ }
1840
+
1841
+ // Input iterdomain pair.
1842
+ IterDomain* inX() const {
1843
+ return input(0)->as<IterDomain>();
1844
+ }
1845
+
1846
+ IterDomain* inY() const {
1847
+ return input(1)->as<IterDomain>();
1848
+ }
1849
+
1850
+ // The type of predefined 1-to-1 functions
1851
+ // used for swizzling math.
1852
+ auto swizzleType() const {
1853
+ return attribute<Swizzle2DType>(0);
1854
+ }
1855
+
1856
+ // Swizzle mode of this swizzle instance.
1857
+ // [Note on swizzle mode]
1858
+ // In the current implementation we support two modes of
1859
+ // swizzle math, namely, data mode and loop mode.
1860
+ // `Data` mode swizzling is a swizzle that will change the
1861
+ // data layout in shared memory, likely in global memory buffers
1862
+ // as well in the future. See also IndexSwizzle in index_compute.cpp.
1863
+ //
1864
+ // Most important use cases are transpose bank conflict removal, and mma
1865
+ // swizzled shared memory layout. Example illustrated in 1D case:
1866
+ //
1867
+ // for (int i = 0; i<I; i++){
1868
+ // # This is a `Data` mode swizzle.
1869
+ // Tshared [swizzled(i)] = Tin[i];
1870
+ // }
1871
+ // # Now Tshared holds swizzled data, i.e. the data layout of
1872
+ // Tshared does not map to Tin with affine relationships.
1873
+ //
1874
+ // for(int i=0;i<I;i++){
1875
+ // Tout = Tshared[swizzled(i)];
1876
+ // }
1877
+ //
1878
+ // `Loop` mode swizzling does not affect the data layout of any buffer
1879
+ // but only permutes the iteration order of serial or parallel loop.
1880
+ // This is useful when we want to designate non-affine mapping of thread
1881
+ // to data or we want to generate non-affine loops.
1882
+ // Example illustrated in 1D case:
1883
+ // for (int i = 0; i<I; i++){
1884
+ // # This is a `Loop` mode swizzle
1885
+ // Tshared [swizzled(i)] = Tin[swizzled(i)];
1886
+ // }
1887
+ // # Now Tshared holds normal data, i.e. it still has
1888
+ // the same data layout as if the swizzle wasn't there.
1889
+ //
1890
+ // # Consumers of Tshared do not need to know about the
1891
+ // loop swizzle at previous op if not inlined.
1892
+ // for(int i=0;i<I;i++){
1893
+ // Tout = Tshared[i];
1894
+ // }
1895
+ // TODO: Loop swizzles eventually will be piped through in all mappings
1896
+ // and replay of the fusion IR infrastructure.
1897
+ auto swizzleMode() const {
1898
+ return attribute<SwizzleMode>(1);
1899
+ }
1900
+ };
1901
+
1902
+ //! IterDomain expression to resize
1903
+ class NVF_API Resize : public Expr {
1904
+ public:
1905
+ using Expr::Expr;
1906
+
1907
+ // Expand the input domain by left_expand and right_expand for each
1908
+ // of the start and end sides, respectively
1909
+ Resize(
1910
+ IrBuilderPasskey,
1911
+ IterDomain* out,
1912
+ IterDomain* in,
1913
+ Val* left_expand,
1914
+ Val* right_expand);
1915
+
1916
+ NVFUSER_DECLARE_CLONE_AND_CREATE
1917
+
1918
+ const char* getOpString() const override {
1919
+ return "Resize";
1920
+ }
1921
+
1922
+ std::string toString(int indent_size = 0) const override;
1923
+ std::string toInlineString(int indent_size = 0) const override;
1924
+
1925
+ IterDomain* out() const {
1926
+ return output(0)->as<IterDomain>();
1927
+ }
1928
+
1929
+ IterDomain* in() const {
1930
+ return input(0)->as<IterDomain>();
1931
+ }
1932
+
1933
+ Val* leftExpand() const {
1934
+ return attributeVal(0);
1935
+ }
1936
+
1937
+ Val* rightExpand() const {
1938
+ return attributeVal(1);
1939
+ }
1940
+ };
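+
+ // A shape-level sketch of Resize: with leftExpand() == 2 and
+ // rightExpand() == 3, an input IterDomain of extent N resizes to an output
+ // extent of N + 5; negative expand values shrink the domain instead, which is
+ // how slice-like resizes would be expressed.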
1941
+
1942
+ //! Integer value which has a special name
1943
+ //!
1944
+ //! These could be:
1945
+ //! - threadIdx.x
1946
+ //! - blockIdx.y
1947
+ //! - blockDim.z
1948
+ //! - T3.stride[2]
1949
+ //!
1950
+ class NVF_API NamedScalar : public Val {
1951
+ public:
1952
+ NamedScalar(IrBuilderPasskey passkey, std::string name, DataType dtype);
1953
+
1954
+ NamedScalar(const NamedScalar* src, IrCloner* ir_cloner);
1955
+
1956
+ NVFUSER_DECLARE_CLONE
1957
+
1958
+ const std::string& name() const {
1959
+ return name_;
1960
+ }
1961
+
1962
+ bool sameAs(const Statement* other) const override;
1963
+
1964
+ std::string toString(int indent_size = 0) const override {
1965
+ return name_;
1966
+ }
1967
+
1968
+ std::string toInlineString(int indent_size = 0) const override {
1969
+ return name_;
1970
+ }
1971
+
1972
+ //! Check if this is threadIdx.{x,y,z}
1973
+ bool isThreadIdx() const {
1974
+ auto p = getParallelIndex();
1975
+ return (
1976
+ p == ParallelType::TIDx || p == ParallelType::TIDy ||
1977
+ p == ParallelType::TIDz);
1978
+ }
1979
+
1980
+ //! Check if this is blockIdx.{x,y,z}
1981
+ bool isBlockIdx() const {
1982
+ auto p = getParallelIndex();
1983
+ return (
1984
+ p == ParallelType::BIDx || p == ParallelType::BIDy ||
1985
+ p == ParallelType::BIDz);
1986
+ }
1987
+
1988
+ //! Check if this is blockDim.{x,y,z}
1989
+ bool isBlockDim() const {
1990
+ auto p = getParallelDim();
1991
+ return (
1992
+ p == ParallelType::TIDx || p == ParallelType::TIDy ||
1993
+ p == ParallelType::TIDz);
1994
+ }
1995
+
1996
+ //! Check if this is gridDim.{x,y,z}
1997
+ bool isGridDim() const {
1998
+ auto p = getParallelDim();
1999
+ return (
2000
+ p == ParallelType::BIDx || p == ParallelType::BIDy ||
2001
+ p == ParallelType::BIDz);
2002
+ }
2003
+
2004
+ //! Return the named scalar extent of a parallel dimension (e.g. blockDim.x)
2005
+ //! WARNING: Only works with Fusion container at the moment
2006
+ static NamedScalar* getParallelDim(ParallelType p_type);
2007
+
2008
+ //! Return the named scalar index of a parallel dimension (e.g. threadIdx.x)
2009
+ //! WARNING: Only works with Fusion container at the moment
2010
+ static NamedScalar* getParallelIndex(ParallelType p_type);
2011
+
2012
+ //! Return the parallel type of this NamedScalar if it is an extent of a
2013
+ //! parallel dimension
2014
+ std::optional<ParallelType> getParallelDim() const;
2015
+
2016
+ //! Return the parallel type of this NamedScalar if it is an index of a
2017
+ //! parallel dimension
2018
+ std::optional<ParallelType> getParallelIndex() const;
2019
+
2020
+ private:
2021
+ std::string name_;
2022
+ };
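+
+ // A minimal usage sketch, inside a Fusion as the warnings above require:
+ //
+ //   NamedScalar* bdimx = NamedScalar::getParallelDim(ParallelType::TIDx);
+ //   // bdimx->toString() == "blockDim.x", bdimx->isBlockDim() == true
+ //   NamedScalar* tidx = NamedScalar::getParallelIndex(ParallelType::TIDx);
+ //   // tidx->toString() == "threadIdx.x", tidx->isThreadIdx() == true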
2023
+
2024
+ class PadOp : public Expr {
2025
+ public:
2026
+ using Expr::Expr;
2027
+
2028
+ //! Pad a tensor as specified by a vector of integer scalars. For
2029
+ //! the actual semantics, see the torch.pad documentation. Note that
2030
+ //! unlike torch.pad, the pad_widths vector parameter must contain
2031
+ //! width vals for all dimensions. For non-padded dimensions, width
2032
+ //! vals should be integer zero.
2033
+ PadOp(
2034
+ IrBuilderPasskey passkey,
2035
+ TensorView* out,
2036
+ TensorView* inp,
2037
+ const std::vector<Val*>& pad_widths,
2038
+ Val* value);
2039
+
2040
+ NVFUSER_DECLARE_CLONE_AND_CREATE
2041
+
2042
+ const char* getOpString() const override {
2043
+ return "PadOp";
2044
+ }
2045
+
2046
+ std::string toString(int indent_size = 0) const override;
2047
+ std::string toInlineString(int indent_size = 0) const override;
2048
+
2049
+ std::vector<PolymorphicValue> evaluate(
2050
+ const ExpressionEvaluator& ee,
2051
+ const std::vector<PolymorphicValue>& inputs) const override;
2052
+
2053
+ Val* out() const {
2054
+ return output(0);
2055
+ }
2056
+
2057
+ Val* in() const {
2058
+ return input(0);
2059
+ }
2060
+
2061
+ Val* value() const {
2062
+ return input(1);
2063
+ }
2064
+
2065
+ //! Return axes that are actually padded, i.e., those that have
2066
+ //! non-zero pad widths
2067
+ std::vector<int64_t> getPaddedAxes() const;
2068
+
2069
+ //! Return pad widths of the given axis, which are just zero for non-padded
2070
+ //! dimensions
2071
+ std::pair<Val*, Val*> getPadWidths(int64_t axis) const;
2072
+
2073
+ //! Return the pad widths of all dimensions, including non-padded ones
2074
+ std::vector<Val*> getPadWidths() const;
2075
+
2076
+ private:
2077
+ //! Offset of pad_width inputs in the input vector
2078
+ int64_t getPadWidthInputOffset() const {
2079
+ return 2;
2080
+ }
2081
+
2082
+ //! Iterator to the first pad_width input
2083
+ auto getPadWidthInputBegin() const {
2084
+ return inputs().cbegin() + getPadWidthInputOffset();
2085
+ }
2086
+
2087
+ //! Iterator to the end of the pad_width inputs
2088
+ auto getPadWidthInputEnd() const {
2089
+ return inputs().cend();
2090
+ }
2091
+ };
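+
+ // A 1-D sketch of PadOp: padding an input of extent N with
+ // pad_widths == {left, right} gives an output extent of N + left + right;
+ // getPaddedAxes() then contains axis 0 only if either width is non-zero, and
+ // getPadWidths(0) returns the {left, right} pair of width Vals.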
2092
+
2093
+ // Similar to at::indexing::Slice
2094
+ struct Slice {
2095
+ Val* start = nullptr;
2096
+ Val* stop = nullptr;
2097
+ Val* step = nullptr;
2098
+ };
2099
+
2100
+ class SliceOp : public Expr {
2101
+ public:
2102
+ using Expr::Expr;
2103
+
2104
+ SliceOp(
2105
+ IrBuilderPasskey passkey,
2106
+ TensorView* out,
2107
+ TensorView* inp,
2108
+ const std::vector<Slice>& ranges);
2109
+
2110
+ NVFUSER_DECLARE_CLONE_AND_CREATE
2111
+
2112
+ const char* getOpString() const override {
2113
+ return "SliceOp";
2114
+ }
2115
+
2116
+ std::string toString(int indent_size = 0) const override;
2117
+ std::string toInlineString(int indent_size = 0) const override;
2118
+ std::vector<PolymorphicValue> evaluate(
2119
+ const ExpressionEvaluator& ee,
2120
+ const std::vector<PolymorphicValue>& inputs) const override;
2121
+
2122
+ TensorView* out() const {
2123
+ return output(0)->as<TensorView>();
2124
+ }
2125
+
2126
+ TensorView* in() const {
2127
+ return input(0)->as<TensorView>();
2128
+ }
2129
+
2130
+ //! Get normalized ranges for SliceOp.
2131
+ std::vector<Slice> getRanges() const;
2132
+
2133
+ private:
2134
+ //! Offset of ranges input in the input vector
2135
+ int getRangeInputOffset() const {
2136
+ return 1;
2137
+ }
2138
+
2139
+ //! Iterator to the first range input
2140
+ auto getRangeInputBegin() const {
2141
+ return inputs().cbegin() + getRangeInputOffset();
2142
+ }
2143
+
2144
+ //! Iterator to the end of the range inputs
2145
+ auto getRangeInputEnd() const {
2146
+ return inputs().cend();
2147
+ }
2148
+ };
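+
+ // A 1-D sketch of SliceOp, assuming start/stop/step are integer scalar Vals:
+ // a single range Slice{/*start=*/2, /*stop=*/10, /*step=*/1} selects input
+ // positions [2, 10), i.e. an output extent of 8; getRanges() returns the
+ // ranges with any omitted (nullptr) start/stop/step presumably normalized to
+ // their defaults.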
2149
+
2150
+ class NVF_API CatOp : public Expr {
2151
+ public:
2152
+ using Expr::Expr;
2153
+
2154
+ CatOp(
2155
+ IrBuilderPasskey passkey,
2156
+ Val* out,
2157
+ const std::vector<Val*>& inputs,
2158
+ int64_t concatenated_dim);
2159
+
2160
+ //! Create a cat op with the index and predicates for codegen. Only
2161
+ //! used for the Kernel container
2162
+ CatOp(
2163
+ IrBuilderPasskey passkey,
2164
+ Val* out,
2165
+ const std::vector<Val*>& inputs,
2166
+ int64_t concatenated_dim,
2167
+ Val* concatenated_domain_index,
2168
+ const std::vector<Val*>& preds);
2169
+
2170
+ NVFUSER_DECLARE_CLONE_AND_CREATE
2171
+
2172
+ const char* getOpString() const override {
2173
+ return "CatOp";
2174
+ }
2175
+
2176
+ std::string toString(int indent_size = 0) const override;
2177
+ std::string toInlineString(int indent_size = 0) const override;
2178
+ std::vector<PolymorphicValue> evaluate(
2179
+ const ExpressionEvaluator& ee,
2180
+ std::unordered_map<const Val*, PolymorphicValue>& known_values)
2181
+ const override;
2182
+
2183
+ int64_t concatenatedDim() const {
2184
+ return attribute<int64_t>(0);
2185
+ }
2186
+
2187
+ //! The index val that determines which input tensor should be used
2188
+ //! to fill the particular output position of this expression. Only
2189
+ //! valid after indexing
2190
+ Val* getConcatenatedDomainIndex() const;
2191
+
2192
+ //! Gets a Bool indicating if the input tensor specified by
2193
+ //! input_idx should be used to fill the output tensor. Only valid
2194
+ //! with the Kernel container
2195
+ Val* getPred(int input_idx) const;
2196
+ };
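+
+ // A shape-level sketch of CatOp: concatenating inputs of logical shape
+ // [ M, N1 ] and [ M, N2 ] with concatenated_dim == 1 produces an output of
+ // shape [ M, N1 + N2 ]; concatenatedDim() returns 1, and the index/predicate
+ // accessors below are only meaningful for the Kernel-side constructor.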
2197
+
2198
+ //! Matmul operator to be evaluated by the expression evaluator without decomposition.
2199
+ class MatmulOp : public Expr {
2200
+ public:
2201
+ using Expr::Expr;
2202
+
2203
+ MatmulOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b);
2204
+
2205
+ NVFUSER_DECLARE_CLONE_AND_CREATE
2206
+
2207
+ const char* getOpString() const override {
2208
+ return "MatmulOp";
2209
+ }
2210
+
2211
+ std::string toString(int indent_size = 0) const override;
2212
+ std::string toInlineString(int indent_size = 0) const override;
2213
+
2214
+ TensorView* out() const {
2215
+ return output(0)->as<TensorView>();
2216
+ }
2217
+
2218
+ TensorView* inA() const {
2219
+ return input(0)->as<TensorView>();
2220
+ }
2221
+
2222
+ TensorView* inB() const {
2223
+ return input(1)->as<TensorView>();
2224
+ }
2225
+
2226
+ std::vector<PolymorphicValue> evaluate(
2227
+ const ExpressionEvaluator& ee,
2228
+ const std::vector<PolymorphicValue>& inputs) const override;
2229
+ };
2230
+
2231
+ // Linear node with same functionality as F.linear
2232
+ // (https://pytorch.org/docs/stable/generated/torch.nn.functional.linear.html#torch.nn.functional.linear)
2233
+ class LinearOp : public Expr {
2234
+ public:
2235
+ using Expr::Expr;
2236
+
2237
+ LinearOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* bias);
2238
+
2239
+ NVFUSER_DECLARE_CLONE_AND_CREATE
2240
+
2241
+ const char* getOpString() const override {
2242
+ return "LinearOp";
2243
+ }
2244
+
2245
+ std::string toString(int indent_size = 0) const override;
2246
+ std::string toInlineString(int indent_size = 0) const override;
2247
+
2248
+ TensorView* out() const {
2249
+ return output(0)->as<TensorView>();
2250
+ }
2251
+
2252
+ TensorView* inA() const {
2253
+ return input(0)->as<TensorView>();
2254
+ }
2255
+
2256
+ TensorView* inB() const {
2257
+ return input(1)->as<TensorView>();
2258
+ }
2259
+
2260
+ TensorView* bias() const {
2261
+ if (has_bias()) {
2262
+ return input(2)->as<TensorView>();
2263
+ } else {
2264
+ return nullptr;
2265
+ }
2266
+ }
2267
+
2268
+ std::vector<PolymorphicValue> evaluate(
2269
+ const ExpressionEvaluator& ee,
2270
+ const std::vector<PolymorphicValue>& inputs) const override;
2271
+
2272
+ bool has_bias() const {
2273
+ return inputs().size() == 3;
2274
+ }
2275
+ };
2276
+
2277
+ /*
2278
+ SDPA node with the same functionality as at::_scaled_dot_product_flash_attention
2279
+ output = [N, H, L, Ev]
2280
+ logsumexp = [N, H, L]
2281
+ query_seq_len = scalar(int)
2282
+ key_seq_len = scalar(int)
2283
+ philox_seed = scalar tensor
2284
+ philox_offset = scalar tensor
2285
+ debug_attn_mask = scalar tensor (Thunder does not return a debug attn mask by
2286
+ setting `return_debug_mask=False` when invoking flash attention)
2287
+
2288
+ query = [N, H, L, E]
2289
+ key = [N, H, S, E]
2290
+ value = [N, H, S, Ev]
2291
+ dropout_p = scalar(double)
2292
+ is_causal = scalar(bool)
2293
+ scale = scalar(double)
2294
+
2295
+ N = number of sequences / batch size
2296
+ H = num of heads
2297
+ L = query sequence length / target sequence length
2298
+ S = key/value sequence length / src sequence length
2299
+ E = query/key embd dimension
2300
+ Ev = value embd dimension
2301
+
2302
+ For flash attention, E = Ev
2303
+ */
2304
+
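+ // A concrete instance of the legend above, with assumed sizes N = 2, H = 8,
+ // L = S = 128 and E = Ev = 64: query/key/value are [2, 8, 128, 64] tensors,
+ // the attention output is [2, 8, 128, 64] and logsumexp is [2, 8, 128].
+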
2305
+ class SdpaFwdOp : public Expr {
2306
+ public:
2307
+ using Expr::Expr;
2308
+
2309
+ SdpaFwdOp(
2310
+ IrBuilderPasskey,
2311
+ TensorView* output,
2312
+ TensorView* log_sumexp,
2313
+ TensorView* philox_seed,
2314
+ TensorView* philox_offset,
2315
+ Val* query,
2316
+ Val* key,
2317
+ Val* value,
2318
+ Val* dropout_p,
2319
+ Val* is_causal,
2320
+ Val* scale);
2321
+
2322
+ NVFUSER_DECLARE_CLONE_AND_CREATE
2323
+
2324
+ const char* getOpString() const override {
2325
+ return "SdpaFwdOp";
2326
+ }
2327
+
2328
+ std::string toString(int indent_size = 0) const override;
2329
+ std::string toInlineString(int indent_size = 0) const override;
2330
+
2331
+ TensorView* attn_out() const {
2332
+ return output(0)->as<TensorView>();
2333
+ }
2334
+
2335
+ TensorView* logsumexp() const {
2336
+ return output(1)->as<TensorView>();
2337
+ }
2338
+
2339
+ TensorView* philox_seed() const {
2340
+ return output(2)->as<TensorView>();
2341
+ }
2342
+
2343
+ TensorView* philox_offset() const {
2344
+ return output(3)->as<TensorView>();
2345
+ }
2346
+
2347
+ TensorView* query() const {
2348
+ return input(0)->as<TensorView>();
2349
+ }
2350
+
2351
+ TensorView* key() const {
2352
+ return input(1)->as<TensorView>();
2353
+ }
2354
+
2355
+ TensorView* value() const {
2356
+ return input(2)->as<TensorView>();
2357
+ }
2358
+
2359
+ Val* dropout_p() const {
2360
+ return input(3);
2361
+ }
2362
+
2363
+ Val* is_causal() const {
2364
+ return input(4);
2365
+ }
2366
+
2367
+ Val* scale() const {
2368
+ if (inputs().size() > 5) {
2369
+ return input(5);
2370
+ }
2371
+ return nullptr;
2372
+ }
2373
+
2374
+ std::vector<PolymorphicValue> evaluate(
2375
+ const ExpressionEvaluator& ee,
2376
+ const std::vector<PolymorphicValue>& inputs) const override;
2377
+ };
2378
+
2379
+ class Scope {
2380
+ public:
2381
+ explicit Scope(Expr* owner) : owner_(owner) {}
2382
+
2383
+ std::string toString(int indent_size = 0) const;
2384
+
2385
+ const std::vector<Expr*>& exprs() const {
2386
+ return exprs_;
2387
+ }
2388
+
2389
+ bool empty() const {
2390
+ return exprs_.empty();
2391
+ }
2392
+
2393
+ auto size() const {
2394
+ return exprs_.size();
2395
+ }
2396
+
2397
+ auto& at(size_t i) {
2398
+ return exprs_.at(i);
2399
+ }
2400
+
2401
+ auto& at(size_t i) const {
2402
+ return exprs_.at(i);
2403
+ }
2404
+
2405
+ auto& operator[](size_t i) {
2406
+ return at(i);
2407
+ }
2408
+
2409
+ auto& operator[](size_t i) const {
2410
+ return at(i);
2411
+ }
2412
+
2413
+ // Insert expr before expression at pos
2414
+ std::vector<Expr*>::iterator insert(size_t pos, Expr* expr);
2415
+
2416
+ // Insert expr before ref
2417
+ std::vector<Expr*>::iterator insert_before(Expr* ref, Expr* expr);
2418
+
2419
+ // Insert expr after ref
2420
+ std::vector<Expr*>::iterator insert_after(Expr* ref, Expr* expr);
2421
+
2422
+ void push_back(Expr* e) {
2423
+ exprs_.push_back(e);
2424
+ }
2425
+
2426
+ // Erase expr at pos
2427
+ void erase(size_t pos);
2428
+
2429
+ // Erase expr ref
2430
+ void erase(Expr* ref);
2431
+
2432
+ bool contains(Expr* expr) const;
2433
+
2434
+ void clear();
2435
+
2436
+ Expr* owner() const {
2437
+ return owner_;
2438
+ }
2439
+
2440
+ bool operator==(const Scope&) const {
2441
+ NVF_THROW("Should not reach here");
2442
+ }
2443
+
2444
+ // Insert expr before pos
2445
+ std::vector<Expr*>::iterator insert(
2446
+ std::vector<Expr*>::const_iterator pos,
2447
+ Expr* expr);
2448
+
2449
+ private:
2450
+ // Erase expr at pos
2451
+ void erase(std::vector<Expr*>::const_iterator pos);
2452
+
2453
+ private:
2454
+ std::vector<Expr*> exprs_;
2455
+
2456
+ //! Owner expression of this scope, e.g., IfThenElse
2457
+ Expr* owner_ = nullptr;
2458
+ };
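+
+ // A minimal usage sketch of Scope, where owner_expr and e0..e2 are
+ // hypothetical Expr* values:
+ //
+ //   Scope body(owner_expr);
+ //   body.push_back(e0);          // [e0]
+ //   body.insert(0, e1);          // [e1, e0]  (insert before position 0)
+ //   body.insert_after(e1, e2);   // [e1, e2, e0]
+ //   body.erase(e2);              // [e1, e0]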
2459
+
2460
+ //! ForLoop provides scoping around an int iterator from 0 to range. Exprs
2461
+ //! placed in its body are considered inside the scope of the for loop. In the
2462
+ //! future the implementation should look quite different so that we can do
2463
+ //! proper dependency analysis like in Fusion.
2464
+ //!
2465
+ //! TODO(kir): this is not a real expression
2466
+ //!
2467
+ //! ForLoop may represent a part of an iteration domain represented
2468
+ //! by iter_domain_. In that case, the loop extent field, extent_, may
2469
+ //! be smaller than the extent of iter_domain_.
2470
+ class ForLoop final : public Expr {
2471
+ public:
2472
+ using Expr::Expr;
2473
+
2474
+ //! By default, start and stop are the same as those of iter_domain.
2475
+ //! Step is one by default.
2476
+ //!
2477
+ //! TODO: cleaner way to set options?
2478
+ ForLoop(
2479
+ IrBuilderPasskey passkey,
2480
+ IterDomain* iter_domain,
2481
+ Val* index,
2482
+ Val* start,
2483
+ Val* stop,
2484
+ Val* step,
2485
+ bool vectorize,
2486
+ Val* vectorize_shift,
2487
+ bool unroll_required,
2488
+ CircularBufferLoopStage circular_buffer_loop_stage,
2489
+ int64_t circular_buffer_loop_stage_depth);
2490
+
2491
+ ForLoop(
2492
+ IrBuilderPasskey passkey,
2493
+ IterDomain* iter_domain,
2494
+ Val* index,
2495
+ CircularBufferLoopStage circular_buffer_loop_stage,
2496
+ int64_t circular_buffer_loop_stage_depth);
2497
+
2498
+ ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain);
2499
+
2500
+ ForLoop(IrBuilderPasskey passkey, const ForLoop* other);
2501
+
2502
+ NVFUSER_DECLARE_CLONE_AND_CREATE
2503
+
2504
+ const char* getOpString() const override {
2505
+ return "ForLoop";
2506
+ }
2507
+
2508
+ std::string toString(int indent_size = 0) const override;
2509
+ std::string toInlineString(int indent_size = 0) const override;
2510
+
2511
+ Val* index() const {
2512
+ return input(0);
2513
+ }
2514
+
2515
+ Val* indexOrStartIfTrivial() const {
2516
+ return isTrivial() ? start() : index();
2517
+ }
2518
+
2519
+ Val* start() const;
2520
+
2521
+ Val* stop() const;
2522
+
2523
+ Val* step() const;
2524
+
2525
+ Val* simplifiedStop() const;
2526
+
2527
+ // [pre | vectorize | post] <= inner-most, merged root domain
2528
+ // shift_ is applied to vectorize and post sections.
2529
+ Val* vectorize_shift() const {
2530
+ return attributeVal(4);
2531
+ }
2532
+
2533
+ IterDomain* iter_domain() const {
2534
+ return input(1)->as<IterDomain>();
2535
+ }
2536
+
2537
+ // TODO: Return pointer instead of reference to be more consistent
2538
+ Scope& body() {
2539
+ return attribute<Scope>(8);
2540
+ }
2541
+
2542
+ const Scope& body() const {
2543
+ return attribute<Scope>(8);
2544
+ }
2545
+
2546
+ bool empty() const {
2547
+ return body().empty();
2548
+ }
2549
+
2550
+ // vectorize is true when the for-loop contains a vectorize set
2551
+ // the flag is used to omit the for-loop from the kernel
2552
+ bool vectorize() const {
2553
+ return attribute<bool>(3);
2554
+ }
2555
+
2556
+ //! True if unrolled (i.e., "#pragma unroll" is attached)
2557
+ bool isUnrolled() const;
2558
+
2559
+ //! True if unroll is required for avoiding stack allocation
2560
+ bool isUnrollRequired() const {
2561
+ return attribute<bool>(5);
2562
+ }
2563
+
2564
+ //! Set unrolling required
2565
+ void requireUnroll() {
2566
+ attribute<bool>(5) = true;
2567
+ }
2568
+
2569
+ //! True if no actual for-loop is materialized
2570
+ bool isTrivial() const;
2571
+
2572
+ //! True if loop is grouped reduction/welford
2573
+ bool isGroup() const;
2574
+
2575
+ //! Returns the stage of a circular buffered iterdomain
2576
+ //! that this for loop materializes.
2577
+ auto circularBufferLoopStage() const {
2578
+ return attribute<CircularBufferLoopStage>(6);
2579
+ }
2580
+ auto circularBufferLoopStageDepth() const {
2581
+ return attribute<int64_t>(7);
2582
+ }
2583
+
2584
+ private:
2585
+ //! Returns if a loop could be unrolled.
2586
+ bool isUnrollable() const;
2587
+
2588
+ //! Not storing this as an attribute because this is only a cache for
2589
+ //! simplifiedStop. We are not interested in keeping this across clone/serde,
2590
+ //! etc.
2591
+ mutable Val* simplified_stop_ = nullptr;
2592
+ };
2593
+
2594
+ /*
2595
+ SDPA bwd node with the same functionality as
2596
+ at::_scaled_dot_product_flash_attention_backward
2597
+ grad_query = [N, H, L, E]
2598
+ grad_key = [N, H, S, E]
2599
+ grad_value = [N, H, S, Ev]
2600
+
2601
+ grad_output = [N, H, L, Ev]
2602
+ query = [N, H, L, E]
2603
+ key = [N, H, S, E]
2604
+ value = [N, H, S, Ev]
2605
+ output = [N, H, L, Ev]
2606
+ logsumexp = [N, H, L]
2607
+ dropout_p = scalar(double)
2608
+ is_causal = scalar(bool)
2609
+ philox_seed = scalar CPU tensor
2610
+ philox_offset = scalar CPU tensor
2611
+ scale = scalar(double)
2612
+
2613
+ N = number of sequences / batch size
2614
+ H = num of heads
2615
+ L = query sequence length / target sequence length
2616
+ S = key/value sequence length / src sequence length
2617
+ E = query/key embd dimension
2618
+ Ev = value embd dimension
2619
+
2620
+ For flash attention, E = Ev
2621
+ */
2622
+
2623
+ class SdpaBwdOp : public Expr {
2624
+ public:
2625
+ using Expr::Expr;
2626
+
2627
+ SdpaBwdOp(
2628
+ IrBuilderPasskey,
2629
+ TensorView* grad_query,
2630
+ TensorView* grad_key,
2631
+ TensorView* grad_value,
2632
+ TensorView* grad_output,
2633
+ TensorView* query,
2634
+ TensorView* key,
2635
+ TensorView* value,
2636
+ TensorView* output,
2637
+ TensorView* log_sumexp,
2638
+ Val* dropout_p,
2639
+ Val* is_causal,
2640
+ TensorView* philox_seed,
2641
+ TensorView* philox_offset,
2642
+ Val* scale);
2643
+
2644
+ NVFUSER_DECLARE_CLONE_AND_CREATE
2645
+
2646
+ const char* getOpString() const override {
2647
+ return "SdpaBwdOp";
2648
+ }
2649
+
2650
+ std::string toString(int indent_size = 0) const override;
2651
+ std::string toInlineString(int indent_size = 0) const override;
2652
+
2653
+ TensorView* grad_query() const {
2654
+ return output(0)->as<TensorView>();
2655
+ }
2656
+
2657
+ TensorView* grad_key() const {
2658
+ return output(1)->as<TensorView>();
2659
+ }
2660
+
2661
+ TensorView* grad_value() const {
2662
+ return output(2)->as<TensorView>();
2663
+ }
2664
+
2665
+ TensorView* grad_attn() const {
2666
+ return input(0)->as<TensorView>();
2667
+ }
2668
+
2669
+ TensorView* query() const {
2670
+ return input(1)->as<TensorView>();
2671
+ }
2672
+
2673
+ TensorView* key() const {
2674
+ return input(2)->as<TensorView>();
2675
+ }
2676
+
2677
+ TensorView* value() const {
2678
+ return input(3)->as<TensorView>();
2679
+ }
2680
+
2681
+ TensorView* attn_out() const {
2682
+ return input(4)->as<TensorView>();
2683
+ }
2684
+
2685
+ TensorView* logsumexp() const {
2686
+ return input(5)->as<TensorView>();
2687
+ }
2688
+
2689
+ Val* dropout_p() const {
2690
+ return input(6);
2691
+ }
2692
+
2693
+ Val* is_causal() const {
2694
+ return input(7);
2695
+ }
2696
+
2697
+ Val* philox_seed() const {
2698
+ return input(8);
2699
+ }
2700
+
2701
+ Val* philox_offset() const {
2702
+ return input(9);
2703
+ }
2704
+
2705
+ Val* scale() const {
2706
+ if (inputs().size() > 10) {
2707
+ return input(10);
2708
+ }
2709
+ return nullptr;
2710
+ }
2711
+
2712
+ std::vector<PolymorphicValue> evaluate(
2713
+ const ExpressionEvaluator& ee,
2714
+ const std::vector<PolymorphicValue>& inputs) const override;
2715
+ };
2716
+
2717
+ class EmbeddingFwdOp : public Expr {
2718
+ public:
2719
+ using Expr::Expr;
2720
+
2721
+ EmbeddingFwdOp(
2722
+ IrBuilderPasskey,
2723
+ TensorView* output,
2724
+ TensorView* input,
2725
+ TensorView* weight,
2726
+ Val* padding_idx,
2727
+ Val* max_norm,
2728
+ Val* norm_type,
2729
+ Val* scale_grad_by_freq,
2730
+ Val* sparse);
2731
+
2732
+ NVFUSER_DECLARE_CLONE_AND_CREATE
2733
+
2734
+ const char* getOpString() const override {
2735
+ return "EmbeddingFwdOp";
2736
+ }
2737
+
2738
+ std::string toString(int indent_size = 0) const override;
2739
+ std::string toInlineString(int indent_size = 0) const override;
2740
+
2741
+ TensorView* out() const {
2742
+ return output(0)->as<TensorView>();
2743
+ }
2744
+
2745
+ TensorView* in() const {
2746
+ return input(0)->as<TensorView>();
2747
+ }
2748
+
2749
+ TensorView* weight() const {
2750
+ return input(1)->as<TensorView>();
2751
+ }
2752
+
2753
+ Val* norm_type() const {
2754
+ return input(2);
2755
+ }
2756
+
2757
+ Val* scale_grad_by_freq() const {
2758
+ return input(3);
2759
+ }
2760
+
2761
+ Val* sparse() const {
2762
+ return input(4);
2763
+ }
2764
+
2765
+ Val* padding_idx() const {
2766
+ if (has_padding_idx()) {
2767
+ return input(5);
2768
+ }
2769
+ return nullptr;
2770
+ }
2771
+
2772
+ Val* max_norm() const {
2773
+ if (has_max_norm()) {
2774
+ return input(5 + has_padding_idx());
2775
+ }
2776
+ return nullptr;
2777
+ }
2778
+
2779
+ bool has_padding_idx() const {
2780
+ return attribute<bool>(0);
2781
+ }
2782
+
2783
+ bool has_max_norm() const {
2784
+ return attribute<bool>(1);
2785
+ }
2786
+
2787
+ std::vector<PolymorphicValue> evaluate(
2788
+ const ExpressionEvaluator& ee,
2789
+ const std::vector<PolymorphicValue>& inputs) const override;
2790
+ };
2791
+
2792
+ } // namespace nvfuser