nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,687 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
#include <c10/core/ScalarType.h>
#include <exceptions.h>

#include <ir/builder_passkey.h>
#include <polymorphic_value.h>
#include <type.h>
#include <utils.h>
#include <visibility.h>

#include <cstdint>
#include <iostream>
#include <limits>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
26
+
27
+ // TODO: Add more types (int32, int64)
28
+ // TODO: sameAs should have better logic to check against any type and return
29
+ // gracefully
30
+
31
/*
 * This file defines the base IR structure. Any IR node in this system will
 * inherit from one of the following classes: Statement, Expr, Val,
 * IrInputOutput. IR is any information that the code generation stack may need
 * for analysis. By analysis we are referring to anything done in response to a
 * user-facing call of this stack. This could be careful tracking of user calls,
 * and any transformation including optimizing transformations, user-declared
 * transformations, and lowering the IR.
 */
40
+
41
+ //! IR header hierarchy
42
+ //! 1. utils.h - PolymorphicBase and NonCopyable
43
+ //! 2. ** ir/base_nodes.h ** - Statement, Expr, and Val
44
+ //! 3. ir/internal_base_nodes.h - IterDomain and TensorDomain
45
+ //! 4. ir/interface_nodes.h - TensorView and Scalar
46
+ //! 5. ir/internal_nodes.h - Any internal-only IR nodes
47
+
48
+ namespace nvfuser {
49
+
50
//! Signed 32-bit identifier type for values.
using ValueId = int32_t;

//! Type of the numeric name every Statement carries (see Statement::name()).
using StmtNameType = unsigned int;

//! Sentinel name for a Statement that has not been assigned a name; it is the
//! initial value of Statement::name_. The limit is taken from StmtNameType
//! itself so the sentinel always matches the alias, even if the alias changes.
constexpr StmtNameType kInvalidStmName =
    std::numeric_limits<StmtNameType>::max();
56
+
57
// Forward declarations for classes that this header names before (or without)
// seeing their full definitions. Listed alphabetically.
class Expr;
class ExpressionEvaluator;
class Fusion;
class IrBuilderPasskey;
class IrCloner;
class IrContainer;
class IrContainerPasskey;
class Val;

// Kernel-IR classes referenced by Statement.
namespace kir {
class Kernel;
class Predicate;
} // namespace kir
70
+
71
//! Passkey type that only Expr can create.
//!
//! The sole constructor is private and Expr is the only friend, so only Expr
//! can default-construct an ExprPasskey. Any function that takes an
//! ExprPasskey parameter can therefore only be reached from code that Expr
//! hands a key to. The constructor is `explicit` so that in C++17 the class is
//! not an aggregate, which prevents `ExprPasskey{}` from bypassing the private
//! constructor.
class ExprPasskey {
 private:
  explicit ExprPasskey() = default;

  friend class Expr;
};
78
+
79
//! Declares the clone() override of a Statement subclass. Use it inside the
//! class body with no trailing semicolon; the macro already ends in one.
#define NVFUSER_DECLARE_CLONE \
  virtual Statement* clone(IrCloner* ir_cloner) const override;

//! Defines ClassName::clone() out of line. The definition simply forwards to
//! IrBuilder::clone(this, ir_cloner). Pair it with NVFUSER_DECLARE_CLONE in the
//! class declaration.
#define NVFUSER_DEFINE_CLONE(ClassName)                   \
  Statement* ClassName::clone(IrCloner* ir_cloner) const { \
    return IrBuilder::clone(this, ir_cloner);              \
  }
86
+
87
//! Statement is the highest level node representation. Everything that is
//! considered "IR" will be derived from this class at some point. Both Vals
//! and Exprs are Statements. If there are ever any more fundamental types,
//! they will also derive from Statement.
//!
//! We use Statements to pass around nodes whose concrete type is not known at
//! compile time. Therefore it is also important for the design to have a
//! dispatch system for a Statement, i.e. a way to succinctly traverse down the
//! inheritance hierarchy of a Statement at runtime. This is currently
//! implemented in dispatch.h.
96
+ class NVF_API Statement : public NonCopyable, public PolymorphicBase {
97
+ friend void swap(Fusion&, Fusion&) noexcept;
98
+ friend void swap(IrContainer& a, IrContainer& b) noexcept;
99
+
100
+ public:
101
+ Statement() = delete;
102
+
103
+ // Cloning constructor
104
+ Statement(const Statement* src, IrCloner* ir_cloner);
105
+
106
+ // Dispatch functions, definitions in dispatch.cpp
107
+ template <typename T>
108
+ static void dispatch(T handler, Statement*);
109
+
110
+ template <typename T>
111
+ static void constDispatch(T handler, const Statement* const);
112
+
113
+ template <typename T>
114
+ static void mutatorDispatch(T mutator, Statement*);
115
+
116
+ // Accessor functions to types. Vals always have a DataType, Exprs never do
117
+ virtual std::optional<ValType> getValType() const {
118
+ return std::nullopt;
119
+ }
120
+ virtual std::optional<DataType> getDataType() const {
121
+ return std::nullopt;
122
+ }
123
+
124
+ // Short cut to figure out if it is a value/expression
125
+ bool isVal() const {
126
+ return getValType() != std::nullopt;
127
+ }
128
+ bool isExpr() const {
129
+ return isA<Expr>();
130
+ }
131
+
132
+ // Make sure this is a Val and return it as a Val*
133
+ Val* asVal();
134
+
135
+ // Make sure this is an Expr and return it as an Expr*
136
+ Expr* asExpr();
137
+
138
+ // Return the fusion this statement belongs to
139
+ Fusion* fusion() const;
140
+
141
+ // Return the kernel this statement belongs to
142
+ kir::Kernel* kernel() const;
143
+
144
+ // Return the container this statement belongs to
145
+ IrContainer* container() const {
146
+ return ir_container_;
147
+ }
148
+
149
+ // Return the int that represents its name
150
+ StmtNameType name() const {
151
+ return name_;
152
+ }
153
+
154
+ // Set the statements' name. Typically the container will set the name,
155
+ // however if we're dealing with cloning, IrBuilder will set the name, this
156
+ // maybe should be from IrCloner, however I didn't want to add another
157
+ // passkey.
158
+ void setName(IrContainerPasskey, StmtNameType name);
159
+ void setName(IrBuilderPasskey, StmtNameType name);
160
+
161
+ virtual bool sameType(const Statement* const other) {
162
+ return typeid(*this) == typeid(*other);
163
+ }
164
+
165
+ // Return if this statement is the same as another statement
166
+ // TODO: should this run through dispatch on this and other?
167
+ virtual bool sameAs(const Statement* other) const {
168
+ return this == other;
169
+ }
170
+
171
+ static bool lessThan(const Statement* stmt1, const Statement* stmt2);
172
+
173
+ virtual std::string toString(int indent_size = 0) const;
174
+
175
+ virtual std::string toInlineString(int indent_size = 0) const;
176
+
177
+ virtual Statement* clone(IrCloner* ir_cloner) const;
178
+
179
+ protected:
180
+ Statement(IrBuilderPasskey);
181
+
182
+ // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
183
+ StmtNameType name_ = kInvalidStmName;
184
+
185
+ // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
186
+ IrContainer* ir_container_ = nullptr;
187
+ };
188
+
189
+ inline std::string toString(Statement* stmt) {
190
+ return stmt->toString();
191
+ }
192
+
193
+ //! A Val represents a "value." These are objects, like tensors, scalars, and
194
+ //! memory locations, that are inputs and outputs of computations (represented
195
+ //! by Exprs, below)
196
+ //!
197
+ //! Vals are constant and unique and should always be passed
198
+ //! around as a pointer. Val can generally be thought of as representing any
199
+ //! type of data. Some examples:
200
+ //! * a constant size like convolution filter width
201
//! * a runtime constant like batch normalization's momentum
//! * a "symbolic" tensor like one passed down from the JIT
//! * a memory buffer used in device code
//!
//! Adding a Val:
//! Right now adding a Val is quite involved. Vals can be defined in ir.h or in
//! their own header file. The following is what is currently needed to add a
//! new Val:
209
+ //!
210
+ //! 1) Definition inheriting from Val
211
+ //! - Members must be private or protected
212
+ //! - Accessor functions for members
213
+ //! - Must call Val constructor, Val constructor registers with fusion
214
+ //! - Implementation of bool sameAs(...)
215
+ //! - Must implement a "cloning" constructor, ex.
216
+ //! Scalar::Scalar(const Val* src, IrCloner* ir_cloner)
217
+ //! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
218
+ //! 3) Default mutator function should be added to mutator.cpp
219
+ //! 4a) Printing functions should be added to ir/iostream.h/.cpp
220
+ //! 4b) Graphviz generation must be added to ir/graphviz.h/.cpp
221
+ //! 5) An enum value must be added to ValType in type.h
222
+ //! 6) A string entry must be added in val_type_string_map
223
+ //!
224
class NVF_API Val : public Statement {
 public:
  // When we create a Val we immediately register them with the active fusion.
  //
  // Primary constructor. If a concrete value is supplied, NVF_CHECK verifies
  // that it is compatible with the data type. Without a value the Val is
  // symbolic (see isSymbolic()).
  explicit Val(
      IrBuilderPasskey passkey,
      ValType _vtype,
      DataType _dtype = DataType::Null,
      PolymorphicValue _value = std::monostate{})
      : Statement(passkey),
        vtype_(_vtype),
        dtype_(std::move(_dtype)),
        value_(std::move(_value)) {
    if (value_.hasValue()) {
      NVF_CHECK(
          hasCompatibleDataType(value_, dtype_),
          "Scalar value is not compatible with the given data type ",
          dtype_,
          " for value ",
          PolymorphicValue_functions::toString(value_));
    }
  }
  // ValType::Others Val of the given data type with no concrete value.
  explicit Val(IrBuilderPasskey passkey, DataType dtype)
      : Val(passkey, ValType::Others, std::move(dtype)) {}
  // Same as above, taking a primitive data type.
  explicit Val(IrBuilderPasskey passkey, PrimDataType dtype)
      : Val(passkey, ValType::Others, DataType(dtype)) {}
  // ValType::Others Val holding `value`, with its data type inferred via
  // nvfuser::getDataType. `value` is copied, not moved, because getDataType
  // reads it in the same argument list, whose evaluation order is unspecified.
  explicit Val(IrBuilderPasskey passkey, PolymorphicValue value)
      : Val(passkey, ValType::Others, nvfuser::getDataType(value), value) {}
  // ValType::Others Val of an explicit data type. The value is first converted
  // with castToDtype before being stored.
  explicit Val(IrBuilderPasskey passkey, PolymorphicValue value, DataType dtype)
      : Val(passkey,
            ValType::Others,
            dtype,
            castToDtype(std::move(value), dtype)) {}

  // NOTE: definition_ and uses_ are intentionally NOT copied by this
  // constructor, because cloning them may introduce cloning cycles. In the new
  // Val they keep their defaults (nullptr / empty), as do is_fusion_input_,
  // is_fusion_output_ and evaluator_index_. They are expected to be fixed up
  // later as part of the Fusion copy.
  //
  Val(const Val* src, IrCloner* ir_cloner)
      : Statement(src, ir_cloner),
        vtype_(src->vtype_),
        dtype_(src->dtype_),
        value_(src->value_) {}

  std::string toString(int indent_size = 0) const override;

  std::string toInlineString(int indent_size = 0) const override;

  // Dispatch functions, definitions in dispatch.cpp
  template <typename T>
  static void dispatch(T handler, Val*);

  template <typename T>
  static void constDispatch(T handler, const Val* const);

  template <typename T>
  static void mutatorDispatch(T mutator, Val*);

  // Always engaged for a Val (overrides Statement's std::nullopt).
  std::optional<ValType> getValType() const override {
    return vtype_;
  }

  ValType vtype() const {
    return vtype_;
  }

  DataType dtype() const {
    return dtype_;
  }

  // Concrete value attached to this Val; empty for symbolic Vals.
  const PolymorphicValue& value() const {
    return value_;
  }

  PolymorphicValue& value() {
    return value_;
  }

  // True when no concrete value is attached.
  bool isSymbolic() const {
    return !value_.hasValue();
  }

  // Throws if no DataType is found. Vals must have a DataType
  std::optional<DataType> getDataType() const override;

  // True for ValType::Others and ValType::NamedScalar.
  bool isScalar() const {
    return vtype_ == ValType::Others || vtype_ == ValType::NamedScalar;
  }

  // Returns if all dependencies are constant scalars
  bool isConstScalar() const;

  // Returns if all dependencies are constant integers
  bool isConstInt() const;

  // Scalar whose data type is integral.
  bool isIntegralScalar() const {
    return isScalar() && isIntegralType(dtype_);
  }

  // Scalar whose data type is floating point.
  bool isFloatingPointScalar() const {
    return isScalar() && isFloatingPointType(dtype_);
  }

  // Scalar of DataType::Bool.
  bool isABool() const {
    return isScalar() && dtype_ == DataType::Bool;
  }

  // If this Val's history is comprised only of constant values, will return a
  // PolymorphicValue. Cannot make constant as expression evaluator takes
  // non-constant Vals.
  PolymorphicValue evaluate();

  // Returns if no dependencies and is a constant scalar, i.e. a concrete value
  // is attached and there is no defining Expr. Virtual so subclasses can refine
  // the notion of constness.
  virtual bool isConst() const {
    return value_.hasValue() && definition() == nullptr;
  }

  bool isZero() const;
  bool isZeroInt() const;
  bool isOne() const;
  bool isOneInt() const;
  bool isTrue() const;
  bool isFalse() const;

  // Returns the Expr that this value is an output of, returns nullptr if none
  // was found. Fusion inputs always report nullptr, even if definition_ is set.
  Expr* definition() const {
    if (is_fusion_input_) {
      return nullptr;
    }
    return definition_;
  }

  // Determine if value definition matches given expression type
  template <typename T>
  inline bool isDefinitionType() const;

  //! Returns the Exprs for which this is an input.
  //! Note that uses() will occasionally trigger a deferred call to
  //! resetTvUses() which can be expensive as it requires traversing the graph
  //! using Val definitions.
  const std::vector<Expr*>& uses() const;

  bool isFusionInput() const {
    return is_fusion_input_;
  }

  bool isFusionOutput() const {
    return is_fusion_output_;
  }

  // Same dynamic type (Statement::sameType) and same data type. The as<Val>()
  // cast is only reached once the dynamic types are known to match.
  bool sameType(const Statement* other) override {
    return Statement::sameType(other) &&
        getDataType() == other->as<Val>()->getDataType();
  }

  bool sameAs(const Statement* other) const override;

  void setEvaluatorIndex(int to) {
    // Only allow resetting evaluator_index to -1 OR
    // setting evaluator_index if it isn't in-use
    NVF_ERROR(evaluator_index_ == -1 || to == -1);
    evaluator_index_ = to;
  }

  // Index assigned via setEvaluatorIndex; -1 when unassigned.
  int evaluatorIndex() const {
    return evaluator_index_;
  }

  // Following is managed by Fusion (or the kernel IR builder) and can change.
  // TODO: Protect with a passkey.
  void setDefinition(Expr* expr) {
    definition_ = expr;
  }

  NVFUSER_DECLARE_CLONE

 protected:
  friend Fusion;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const ValType vtype_;

  // TODO: Add fusion passkey for this
  void setIsFusionInput(bool is_fusion_input) {
    is_fusion_input_ = is_fusion_input;
  }

  // TODO: Add fusion passkey for this
  void setIsFusionOutput(bool is_fusion_output) {
    is_fusion_output_ = is_fusion_output;
  }

  // TODO: Add fusion or container passkey for this
  void setUses(const std::vector<Expr*>& uses) {
    uses_ = uses;
  }

  //! Insert a new expression into uses() if it is not already present and
  //! return whether an insertion occurred.
  bool addUse(Expr*);

  //! Remove an expression from uses() if it is already present and return
  //! whether a removal occurred.
  bool removeUse(Expr*);

 private:
  // There's only one instance where dtype can change, and that's through
  // resolving the index data type from nvfuser to either Int or Int32 for
  // welford operations.
  DataType dtype_;

  // Following is managed by Fusion and can change.
  bool is_fusion_input_ = false;
  bool is_fusion_output_ = false;

  Expr* definition_ = nullptr;
  std::vector<Expr*> uses_;

  // Expr evaluator idx; -1 means unassigned.
  int evaluator_index_ = -1;

  // The concrete value of this Val. This is only used for constant Vals.
  // Depending on the actual type of the Val, the allowed types of the
  // value_ can be different. For example, for a TensorView, the value_ must be
  // a at::Tensor, while for IterDomain, the value_ must be std::monostate{}.
  PolymorphicValue value_;
};
453
+
454
+ using newObjectFuncType = Expr*(
455
+ IrContainer*,
456
+ std::vector<Val*>,
457
+ std::vector<Val*>,
458
+ std::vector<Statement*>);
459
+
460
//! An Expr represents a "computation." These are functions that take inputs
//! and produce outputs, inputs and outputs all being Vals. There are
462
+ //! specializations of BinaryOp which takes 2 inputs and produces 1 output, and
463
+ //! UnaryOp which takes 1 input and produces 1 output. Exprs are unique and
464
+ //! immutable. Conceptually, Exprs could always be manipulated using unique
465
+ //! pointers, and we could add this later. However, for now Exprs can be
466
+ //! replaced in a fusion, but they cannot be modified in place.
467
+ //!
468
+ //! The IR is static single assignment (SSA). Values can only be defined as an
469
+ //! output of an Expr once. If they are re-defined the original definition is
470
+ //! deleted from the program, as opposed to an ordered redefinition of the
471
+ //! value in the program.
472
+ //!
473
//! Note: Registering an Expr with a Fusion actually has two parts; one part is
//! done in the Expr constructor, so that constructor should be called by
//! anything that inherits Expr. The issue with having registration in Expr's
//! constructor is that the constructor of an Expr will set outputs and inputs.
//! This information is important for registration with the Fusion, so it can
//! track the dependency chain.
479
+ //!
480
+ //! Adding an Expr:
481
+ //! Right now adding an Expr is quite involved. Expr's can be defined in ir.h
482
+ //! or in their own header file. The following is what is currently needed for
483
+ //! Expr definitions:
484
+ //!
485
+ //! 1) Definition inheriting from Expr.
486
+ //! - Members must be private or protected
487
+ //! - Accessor functions for members
488
+ //! - Constructors need to register with the Fusion after inputs/outputs
489
+ //! are defined
490
+ //! - Implementation of bool sameAs(...)
491
+ //! 2) dispatch.h/.cpp must be updated to include dispatch of the new Expr
492
+ //! 3) Default mutator function should be added to mutator.h/.cpp
493
+ //! 4) Printing functions should be added to ir/iostream.h/.cpp
494
+ //! 5) Lower case convenience functions should be added to arith.h/.cpp (If
495
+ //! user facing)
496
+ //! 7) A string entry must be added in expr_type_string_map
497
+ //! 8) Entry added to ir_graphviz .cpp/.h
498
+ //!
499
+ class NVF_API Expr : public Statement {
500
+ public:
501
+ explicit Expr(IrBuilderPasskey);
502
+
503
+ Expr(const Expr* src, IrCloner* ir_cloner);
504
+
505
+ Expr(
506
+ IrBuilderPasskey,
507
+ std::vector<Val*> inputs,
508
+ std::vector<Val*> outputs,
509
+ std::vector<Statement*> attributes);
510
+
511
+ virtual newObjectFuncType* newObjectFunc() const = 0;
512
+
513
+ // Creates a new instance of the expression with all its field copied.
514
+ // Note that unlike IrCloner, this function only do a shallow copy
515
+ Expr* shallowCopy() const;
516
+
517
+ // Check that if this and other are the same operator. This main difference
518
+ // from sameAs is that sameOp does not check the inputs.
519
+ virtual bool sameOp(const Expr* other) const;
520
+
521
+ bool sameAs(const Statement* other) const override;
522
+
523
+ virtual std::vector<PolymorphicValue> evaluate(
524
+ const ExpressionEvaluator& ee,
525
+ const std::vector<PolymorphicValue>& inputs) const;
526
+
527
+ // This version allows evaluation of multiple ops together instead of one op
528
+ // at a time by overriding and skipping computation of intermediate inputs
529
+ // that are not required. For example:
530
+ // 1. CatOp is internally preceded by PadOp but the ATen evaluation uses only
531
+ // the unpadded inputs and the evaluation of padded inputs can be skipped.
532
+ // 2. Evaluating patterns in matmul fallback such as MmaOp + Cast/ MmaOp +
533
+ // Bias + Cast
534
+ virtual std::vector<PolymorphicValue> evaluate(
535
+ const ExpressionEvaluator& ee,
536
+ std::unordered_map<const Val*, PolymorphicValue>& known_values) const;
537
+
538
+ // Input/output accessors
539
+ const auto& inputs() const {
540
+ return inputs_;
541
+ }
542
+
543
+ const auto& outputs() const {
544
+ return outputs_;
545
+ }
546
+
547
+ const auto& attributes() const {
548
+ return attributes_;
549
+ }
550
+
551
+ auto input(size_t index) const {
552
+ return inputs_.at(index);
553
+ }
554
+
555
+ auto output(size_t index) const {
556
+ return outputs_.at(index);
557
+ }
558
+
559
+ auto attribute(size_t index) const {
560
+ return attributes_.at(index);
561
+ }
562
+
563
+ auto attributeVal(size_t index) const {
564
+ return dynamic_cast<Val*>(attributes_.at(index));
565
+ }
566
+
567
+ template <typename T>
568
+ T& attribute(size_t index) const;
569
+
570
+ // Dispatch functions, definitions in dispatch.cpp
571
+ template <typename T>
572
+ static void dispatch(T handler, Expr*);
573
+
574
+ template <typename T>
575
+ static void constDispatch(T handler, const Expr* const);
576
+
577
+ // TODO: Protect based on being in kernel container
578
+ kir::Predicate* predicate() const;
579
+
580
+ // Creates a shallow copy the expression with the given predicate attached.
581
+ // TODO: Protect based on being in kernel container
582
+ Expr* withPredicate(kir::Predicate* predicate);
583
+
584
+ // TODO: Protect based on being in kernel container
585
+ kir::Predicate* writePredicate() const;
586
+
587
+ // Creates a shallow copy the expression with the given write-predicate
588
+ // attached.
589
+ // TODO: Protect based on being in kernel container
590
+ Expr* withWritePredicate(kir::Predicate* write_predicate);
591
+
592
+ // Get the name of an expression
593
+ virtual const char* getOpString() const = 0;
594
+
595
+ // Get the label for Graphviz
596
+ virtual std::string getGraphvizLabel() const;
597
+
598
+ //! Perform assertions on new_val to ensure that it is valid for this
599
+ //! particular expression. This ensures that invalid values are not propagated
600
+ //! through the graph during concretization.
601
+ virtual void checkConcretization(Val* old_val, Val* new_val) const;
602
+
603
+ protected:
604
+ // TODO: Protect based on being in kernel container
605
+ void setPredicate(kir::Predicate* predicate);
606
+
607
+ // TODO: Protect based on being in kernel container
608
+ void setWritePredicate(kir::Predicate* write_predicate);
609
+
610
+ // TODO: Add Fusion passkey
611
+ void addInput(Val* input) {
612
+ NVF_ERROR(input != nullptr);
613
+ inputs_.push_back(input);
614
+ }
615
+
616
+ // TODO: Add Fusion passkey
617
+ void addOutput(Val* output) {
618
+ NVF_ERROR(output != nullptr);
619
+ outputs_.push_back(output);
620
+ }
621
+
622
+ // TODO: Add Fusion passkey
623
+ void addAttribute(Statement* attr) {
624
+ attributes_.push_back(attr);
625
+ }
626
+
627
+ // TODO: Add Fusion passkey
628
+ void addDataAttribute(PolymorphicValue attr);
629
+
630
+ // TODO: Add Fusion passkey
631
+ template <typename T>
632
+ void addDataAttribute(T attr) {
633
+ if constexpr (PolymorphicValue::is_candidate_type<T>) {
634
+ addDataAttribute(PolymorphicValue(std::move(attr)));
635
+ } else {
636
+ addDataAttribute(Opaque(std::move(attr)));
637
+ }
638
+ }
639
+
640
+ ExprPasskey exprPasskey() {
641
+ return ExprPasskey();
642
+ }
643
+
644
+ std::vector<Statement*> attributes_;
645
+
646
+ private:
647
+ std::vector<Val*> inputs_;
648
+ std::vector<Val*> outputs_;
649
+ kir::Predicate* predicate_ = nullptr;
650
+
651
+ // Only used for reduction-related expressions
652
+ kir::Predicate* write_predicate_ = nullptr;
653
+ };
654
+
655
+ template <typename T>
656
+ bool Val::isDefinitionType() const {
657
+ if (definition() != nullptr) {
658
+ return definition()->isA<T>();
659
+ }
660
+ return false;
661
+ }
662
+
663
//! Declares, inside the body of a concrete Expr subclass, the boilerplate
//! every concrete Expr needs:
//!  - clone(): overrides the base-class clone taking an IrCloner,
//!  - newObject(): static factory whose signature matches newObjectFuncType,
//!  - newObjectFunc(): overrides Expr::newObjectFunc to return &newObject.
//! The out-of-line definitions come from
//! NVFUSER_DEFINE_CLONE_AND_CREATE(ClassName).
//! `virtual` is intentionally omitted: `override` already implies it
//! (clang-tidy modernize-use-override).
#define NVFUSER_DECLARE_CLONE_AND_CREATE                \
  Statement* clone(IrCloner* ir_cloner) const override; \
  static Expr* newObject(                               \
      IrContainer* container,                           \
      std::vector<Val*> inputs,                         \
      std::vector<Val*> outputs,                        \
      std::vector<Statement*> attributes);              \
  newObjectFuncType* newObjectFunc() const override {   \
    return newObject;                                   \
  }
673
+
674
//! Out-of-line definitions matching NVFUSER_DECLARE_CLONE_AND_CREATE for
//! `ClassName`:
//!  - clone() delegates to IrBuilder::clone,
//!  - newObject() constructs a ClassName inside `container` via
//!    IrBuilder::createInContainer.
//! The by-value vector parameters are local copies that are not used again,
//! so they are moved into createInContainer instead of being copied once
//! more (assumes createInContainer forwards its arguments to the
//! constructor -- NOTE(review): confirm against IrBuilder).
#define NVFUSER_DEFINE_CLONE_AND_CREATE(ClassName)         \
  Statement* ClassName::clone(IrCloner* ir_cloner) const { \
    return IrBuilder::clone(this, ir_cloner);              \
  }                                                        \
  Expr* ClassName::newObject(                              \
      IrContainer* container,                              \
      std::vector<Val*> inputs,                            \
      std::vector<Val*> outputs,                           \
      std::vector<Statement*> attributes) {                \
    return IrBuilder::createInContainer<ClassName>(        \
        container,                                         \
        std::move(inputs),                                 \
        std::move(outputs),                                \
        std::move(attributes));                            \
  }
686
+
687
+ } // namespace nvfuser