nvfuser-cu121-torch25 0.2.25.dev20250201__cp310-cp310-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,1125 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <exceptions.h>
11
+ #include <macros.h>
12
+ #include <visibility.h>
13
+
14
+ #include <c10/core/ScalarType.h>
15
+
16
+ #include <polymorphic_value.h>
17
+
18
+ #include <array>
19
+ #include <complex>
20
+ #include <cstdint>
21
+ #include <iostream>
22
+ #include <optional>
23
+ #include <string>
24
+ #include <type_traits>
25
+ #include <typeinfo>
26
+ #include <unordered_set>
27
+ #include <variant>
28
+
29
+ namespace nvfuser {
30
+
31
// Order of strength
//
// Category tag for every nvfuser Val; anything without a dedicated
// category falls into Others.
enum class ValType {
  TensorDomain,
  IterDomain,
  TensorView,
  NamedScalar,
  Predicate,
  TensorIndex,
  Stream,
  Others
};
42
+
43
// How a predicate is generated for an expression:
// Manual - The user provides the Bool value. Predicate generation is bypassed.
// Inline corresponds with PredicateCompute::getInlinePredicate
// Unswitch corresponds with UnswitchPredicate::get
// Misaligned - PredicateCompute::getInlinePredicate + Misaligned flag
// ReductionWrite - Same as Inline but without reduction axes
// LoopRotation - Predicate added by loop rotation, currently always true.
// ElectSync - Select a single thread to launch asynchronous operations.
enum class PredicateType {
  Manual,
  Inline,
  Unswitch,
  Vectorize,
  Misaligned,
  ReductionWrite,
  LoopRotation,
  ElectSync
};
60
+
61
// Index type is a convenience type that may be a 64 or 32 signed integer.
// This is helpful for math on indexing/size when we don't know what the index
// type might be. This allows us to prevent assuming the welford count must be
// int64_t which is relatively heavy to carry around. Index will be resolved
// at compile time with KernelIndexMode.
enum class PrimDataType {
  // Floating point types
  Double,
  Float,
  Half,
  BFloat16,
  Float8_e4m3fn,
  Float8_e5m2,
  // Integral types
  Char,
  Short,
  Int32,
  Int,
  Byte, // Following ATen convention
  UInt16, // Following ATen convention
  UInt32,
  UInt64,
  Index,
  // Boolean types
  Bool,
  // Complex types
  ComplexDouble,
  ComplexFloat,
  // Pointers
  SMemAddress,
  TMemAddress,
  // Null
  Null
};
95
+
96
+ #if defined(__GNUC__) && !defined(__clang__)
97
+ #pragma GCC diagnostic push
98
+ #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
99
+ #endif
100
+
101
+ struct DataType;
102
+
103
// A fixed-size homogeneous array type (element type + element count).
struct ArrayType {
  // Element type; held through shared_ptr so DataType can nest recursively.
  std::shared_ptr<DataType> type;
  size_t size;
  // Defined out of line below, once DataType's operator== is available.
  inline bool operator==(const ArrayType& other) const;
};
108
+
109
// A pointer type; `type` is the pointee's data type.
struct PointerType {
  std::shared_ptr<DataType> type;
  // Defined out of line below, once DataType's operator== is available.
  inline bool operator==(const PointerType& other) const;
};
113
+
114
// A named struct type: an ordered list of (name, type) fields plus a
// factory that creates the matching runtime Struct instance.
struct StructType {
  std::string name;
  // Factory producing a default-constructed instance of the concrete
  // Struct subclass this type describes.
  std::function<std::shared_ptr<Struct>()> create;

  struct FieldInfo {
    std::string name;
    std::shared_ptr<DataType> type;
    // Whether this field is materialized in generated kernel code.
    bool used_in_kernel = true;
  };

  std::vector<FieldInfo> fields;

  // Build a StructType describing the concrete Struct subclass T.
  template <typename T>
  static StructType make(std::vector<FieldInfo> fields, std::string name = "") {
    static_assert(
        std::is_base_of<Struct, T>::value,
        "StructType::make only accepts Struct types");
    return StructType{
        .name = std::move(name),
        .create =
            []() {
              return std::static_pointer_cast<Struct>(std::make_shared<T>());
            },
        .fields = std::move(fields)};
  }

  // Look up a field's type by name; throws via NVF_THROW if absent.
  inline const DataType& fieldDataType(const std::string& name) const {
    for (const auto& field : fields) {
      if (field.name == name) {
        return *field.type;
      }
    }
    NVF_THROW("Field ", name, " not found in struct ", this->name);
  }

  // Defined out of line below, once DataType's operator== is available.
  inline bool operator==(const StructType& other) const;
};
151
+
152
// A type nvfuser does not introspect, identified only by its C++
// std::type_info plus its byte size.
struct OpaqueType {
  std::string name;
  std::reference_wrapper<const std::type_info> type_info;
  size_t size;

  // Build an OpaqueType wrapping the C++ type T.
  template <typename T>
  static OpaqueType make(std::string name = "") {
    return OpaqueType{
        .name = std::move(name), .type_info = typeid(T), .size = sizeof(T)};
  }

  // Equality considers only the wrapped C++ type; name and size are
  // informational.
  inline bool operator==(const OpaqueType& other) const {
    return type_info.get() == other.type_info.get();
  }
};
167
+
168
// The full nvfuser data type: either a primitive scalar type or a
// composite (array / pointer / struct / opaque) type. The static
// constexpr members let callers write DataType::Float etc. as shorthand
// for the PrimDataType enumerators.
struct DataType {
  using VariantOfSupportedTypes = std::
      variant<PrimDataType, ArrayType, PointerType, StructType, OpaqueType>;
  VariantOfSupportedTypes type = PrimDataType::Null;

  // Implicit conversions from each alternative are intentional; they let
  // PrimDataType values be used wherever a DataType is expected.
  DataType() = default;
  DataType(VariantOfSupportedTypes type) : type(std::move(type)) {}
  DataType(PrimDataType type) : type(type) {}
  DataType(ArrayType type) : type(std::move(type)) {}
  DataType(PointerType type) : type(std::move(type)) {}
  DataType(StructType type) : type(std::move(type)) {}
  DataType(OpaqueType type) : type(std::move(type)) {}

  static constexpr PrimDataType Double = PrimDataType::Double;
  static constexpr PrimDataType Float = PrimDataType::Float;
  static constexpr PrimDataType Half = PrimDataType::Half;
  static constexpr PrimDataType Float8_e4m3fn = PrimDataType::Float8_e4m3fn;
  static constexpr PrimDataType Float8_e5m2 = PrimDataType::Float8_e5m2;
  static constexpr PrimDataType Index = PrimDataType::Index;
  static constexpr PrimDataType Char = PrimDataType::Char;
  static constexpr PrimDataType Short = PrimDataType::Short;
  static constexpr PrimDataType Int32 = PrimDataType::Int32;
  static constexpr PrimDataType Int = PrimDataType::Int;
  static constexpr PrimDataType Byte = PrimDataType::Byte;
  static constexpr PrimDataType UInt16 = PrimDataType::UInt16;
  static constexpr PrimDataType UInt32 = PrimDataType::UInt32;
  static constexpr PrimDataType UInt64 = PrimDataType::UInt64;
  static constexpr PrimDataType Bool = PrimDataType::Bool;
  static constexpr PrimDataType BFloat16 = PrimDataType::BFloat16;
  static constexpr PrimDataType ComplexFloat = PrimDataType::ComplexFloat;
  static constexpr PrimDataType ComplexDouble = PrimDataType::ComplexDouble;
  static constexpr PrimDataType SMemAddress = PrimDataType::SMemAddress;
  static constexpr PrimDataType TMemAddress = PrimDataType::TMemAddress;
  static constexpr PrimDataType Null = PrimDataType::Null;
};
203
+
204
// DataType equality delegates to the underlying variant's operator==.
inline bool operator==(const DataType& lhs, const DataType& rhs) {
  return lhs.type == rhs.type;
}

inline bool operator!=(const DataType& lhs, const DataType& rhs) {
  return !operator==(lhs, rhs);
}
211
+
212
+ bool ArrayType::operator==(const ArrayType& other) const {
213
+ return *type == *other.type && size == other.size;
214
+ }
215
+
216
+ bool PointerType::operator==(const PointerType& other) const {
217
+ return *type == *other.type;
218
+ }
219
+
220
+ bool StructType::operator==(const StructType& other) const {
221
+ if (fields.size() != other.fields.size()) {
222
+ return false;
223
+ }
224
+ for (auto i : c10::irange(fields.size())) {
225
+ if (fields[i].name != other.fields[i].name ||
226
+ *fields[i].type != *other.fields[i].type ||
227
+ fields[i].used_in_kernel != other.fields[i].used_in_kernel) {
228
+ return false;
229
+ }
230
+ }
231
+ return true;
232
+ }
233
+
234
// Forward the type query to the Struct implementation owned by the handle.
inline StructType StructHandle::type() const {
  return struct_ptr_->type();
}
237
+
238
// Build the StructType describing the metadata of a global tensor with
// element type `dtype`, `dim` logical dimensions, and `alloc_dim`
// allocation dimensions. Defined in the corresponding .cpp.
StructType globalTensorMetaData(
    const PrimDataType& dtype,
    size_t dim,
    size_t alloc_dim);

// Convenience overload: allocation domain has the same rank as the
// logical domain.
inline StructType globalTensorMetaData(const PrimDataType& dtype, size_t dim) {
  return globalTensorMetaData(dtype, dim, dim);
}
246
+
247
class Val;
//! Get the type of a Val's metadata, currently only supporting tensors
NVF_API DataType metaDataTypeOf(const Val* tv);

// Whether kernels index with 32-bit or 64-bit integers.
enum class KernelIndexMode { INT32, INT64 };

// Conversions between the index mode and its data-type representation.
PrimDataType indexModeToDtype(KernelIndexMode index_mode);
KernelIndexMode indexTypeToMode(DataType index_type);

// check if type preserves all information from base_type. Which indicates a
// cast from base_type -> type -> base_type should be bit-wise identical
bool isInclusiveType(const DataType& base_type, const DataType& type);
259
+
260
+ // Returns if the datatype is a floating point type
261
+ inline bool isFloatingPointType(DataType dtype) {
262
+ return dtype == DataType::Double || dtype == DataType::Float ||
263
+ dtype == DataType::Half || dtype == DataType::BFloat16 ||
264
+ dtype == DataType::Float8_e4m3fn || dtype == DataType::Float8_e5m2;
265
+ }
266
+
267
+ // Returns if the datatype is an integer type
268
+ inline bool isIntegralType(DataType dtype) {
269
+ return std::visit(
270
+ [](auto&& dtype) {
271
+ using T = std::decay_t<decltype(dtype)>;
272
+ if constexpr (std::is_same_v<T, PrimDataType>) {
273
+ switch (dtype) {
274
+ case DataType::Index:
275
+ case DataType::Char:
276
+ case DataType::Short:
277
+ case DataType::Int:
278
+ case DataType::Int32:
279
+ case DataType::Byte:
280
+ case DataType::UInt16:
281
+ case DataType::UInt32:
282
+ case DataType::UInt64:
283
+ return true;
284
+ default:
285
+ return false;
286
+ }
287
+ }
288
+ return false;
289
+ },
290
+ dtype.type);
291
+ }
292
+
293
+ // Returns if the datatype is an unsigned integer type
294
+ inline bool isUnsignedIntegralType(DataType dtype) {
295
+ return dtype == DataType::Byte || dtype == DataType::UInt16 ||
296
+ dtype == DataType::UInt32 || dtype == DataType::UInt64;
297
+ }
298
+
299
+ // Returns if the datatype is a pointer type
300
+ inline bool isPointerType(DataType dtype) {
301
+ return std::holds_alternative<PointerType>(dtype.type) ||
302
+ dtype == DataType::SMemAddress || dtype == DataType::TMemAddress;
303
+ }
304
+
305
+ // Returns if the datatype is an integer or pointer type
306
+ inline bool isIntegralOrPointerType(DataType dtype) {
307
+ return isIntegralType(dtype) || isPointerType(dtype);
308
+ }
309
+
310
+ // Returns if the datatype is a boolean type
311
+ inline bool isBooleanType(DataType dtype) {
312
+ return dtype == DataType::Bool;
313
+ }
314
+
315
+ // Returns if the datatype is a complex type
316
+ inline bool isComplexType(DataType dtype) {
317
+ return dtype == DataType::ComplexFloat || dtype == DataType::ComplexDouble;
318
+ }
319
+
320
// Returns if the datatype is a struct type
// (the original comment said "complex type" — a copy-paste slip; the
// code checks for the StructType alternative)
inline bool isStructType(DataType dtype) {
  return std::holds_alternative<StructType>(dtype.type);
}
324
+
325
// Return the corresponding scalar of a complex type
DataType getTypeFromComplexType(DataType dtype);
// Return the corresponding complex type of a scalar
DataType getComplexTypeFromType(DataType dtype);
// Return if the datatype is supported on the current device
NVF_API bool isSupportedTypeByDevice(DataType dtype);

// Size in bytes of the given data type. Defined in the corresponding .cpp.
NVF_API int64_t dataTypeSize(DataType type);

// If the index type is known it will be automatically used here
int64_t dataTypeSize(DataType type, DataType index_type);
336
+
337
// Compile-time mapping traits between nvfuser PrimDataType, ATen
// at::ScalarType, and native C++ types. Specializations are generated
// by the DEFINE_* macros below and removed again with #undef.
template <PrimDataType DT>
struct DataTypeToNativeType;

template <PrimDataType DT>
struct DataTypeToAtenType;

template <typename NativeType>
struct NativeTypeToDataType;

template <at::ScalarType aten_type>
struct AtenTypeToDataType;

template <at::ScalarType aten_type>
struct AtenTypeToNativeType;

// True for native C++ types that correspond to some PrimDataType.
template <typename NativeType>
struct IsPrimitiveNativeType : std::false_type {};

// Defines the DataType <-> native-type trait specializations.
#define DEFINE_DATATYPE_TO_NATIVE_TYPE(data_type, native_type) \
  template <>                                                  \
  struct DataTypeToNativeType<data_type> {                     \
    using type = native_type;                                  \
  };                                                           \
  template <>                                                  \
  struct NativeTypeToDataType<native_type> {                   \
    static constexpr PrimDataType type = data_type;            \
  };                                                           \
  template <>                                                  \
  struct IsPrimitiveNativeType<native_type> : std::true_type {}

// Additionally defines the ATen <-> DataType/native-type specializations.
#define DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(          \
    data_type, at_type, native_type)                      \
  DEFINE_DATATYPE_TO_NATIVE_TYPE(data_type, native_type); \
  template <>                                             \
  struct AtenTypeToDataType<at_type> {                    \
    static constexpr PrimDataType type = data_type;       \
  };                                                      \
  template <>                                             \
  struct AtenTypeToNativeType<at_type> {                  \
    using type = native_type;                             \
  }

DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Float,
    at::ScalarType::Float,
    float);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Double,
    at::ScalarType::Double,
    double);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Half,
    at::ScalarType::Half,
    at::Half);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::BFloat16,
    at::ScalarType::BFloat16,
    at::BFloat16);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Float8_e4m3fn,
    at::ScalarType::Float8_e4m3fn,
    at::Float8_e4m3fn);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Float8_e5m2,
    at::ScalarType::Float8_e5m2,
    at::Float8_e5m2);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Char,
    at::ScalarType::Char,
    int8_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Short,
    at::ScalarType::Short,
    int16_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Int32,
    at::ScalarType::Int,
    int);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Int,
    at::ScalarType::Long,
    int64_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Byte,
    at::ScalarType::Byte,
    uint8_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::UInt16,
    at::ScalarType::UInt16,
    uint16_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::UInt32,
    at::ScalarType::UInt32,
    uint32_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::UInt64,
    at::ScalarType::UInt64,
    uint64_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Bool,
    at::ScalarType::Bool,
    bool);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::ComplexFloat,
    at::ScalarType::ComplexFloat,
    std::complex<float>);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::ComplexDouble,
    at::ScalarType::ComplexDouble,
    std::complex<double>);

#undef DEFINE_DATATYPE_TO_NATIVE_TYPE
#undef DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE
450
+
451
// Infer the DataType of a runtime PolymorphicValue by probing every
// alternative the PolymorphicValue can hold. Throws via NVF_CHECK when
// the type cannot be inferred (unknown alternative, empty array, or a
// raw pointer value).
inline DataType getDataType(const PolymorphicValue& value) {
  std::optional<DataType> dtype = std::nullopt;
  PolymorphicValue::for_all_types([&value, &dtype](auto _) {
    using T = typename decltype(_)::type;
    if constexpr (IsPrimitiveNativeType<T>::value) {
      if (value.is<T>()) {
        dtype = NativeTypeToDataType<T>::type;
      }
    } else if constexpr (std::is_same_v<T, std::vector<PolymorphicValue>>) {
      if (value.is<T>()) {
        const auto& vec = value.as<T>();
        size_t size = vec.size();
        NVF_CHECK(size > 0, "Empty array is not supported");
        // Element type comes from the first element only; the array is
        // presumed homogeneous.
        dtype =
            ArrayType{std::make_shared<DataType>(getDataType(vec[0])), size};
      }
    } else if constexpr (std::is_same_v<T, Pointer>) {
      // For pointers in polymorphic value, we only store the data size of the
      // pointee, so it is impossible to infer the pointer type.
      NVF_CHECK(!value.is<T>(), "Can not infer pointer type.");
    } else if constexpr (std::is_same_v<T, StructHandle>) {
      if (value.is<T>()) {
        dtype = value.as<T>().type();
      }
    } else if constexpr (std::is_same_v<T, Opaque>) {
      if (value.is<T>()) {
        const auto& opaque = value.as<T>();
        dtype = DataType(OpaqueType{
            .type_info = opaque.any().type(), .size = opaque.size()});
      }
    }
  });
  NVF_CHECK(dtype.has_value(), "Unknown dtype for ", value.type().name());
  return dtype.value();
}
486
+
487
+ inline bool isCompatibleDataType(DataType dtype, DataType dtype2) {
488
+ if (dtype == dtype2) {
489
+ return true;
490
+ }
491
+ if (isIntegralType(dtype) && isIntegralType(dtype2)) {
492
+ return true;
493
+ }
494
+ if (isFloatingPointType(dtype) && isFloatingPointType(dtype2)) {
495
+ return true;
496
+ }
497
+ if (isComplexType(dtype) && isComplexType(dtype2)) {
498
+ return true;
499
+ }
500
+ if (std::holds_alternative<ArrayType>(dtype.type) &&
501
+ std::holds_alternative<ArrayType>(dtype2.type)) {
502
+ const auto& array_type = std::get<ArrayType>(dtype.type);
503
+ const auto& array_type2 = std::get<ArrayType>(dtype2.type);
504
+ return array_type.size == array_type2.size &&
505
+ isCompatibleDataType(*array_type.type, *array_type2.type);
506
+ }
507
+ if (std::holds_alternative<StructType>(dtype.type) &&
508
+ std::holds_alternative<StructType>(dtype2.type)) {
509
+ const auto& struct_type = std::get<StructType>(dtype.type);
510
+ const auto& struct_type2 = std::get<StructType>(dtype2.type);
511
+ if (struct_type.fields.size() != struct_type2.fields.size()) {
512
+ return false;
513
+ }
514
+ for (auto i : c10::irange(struct_type.fields.size())) {
515
+ if (struct_type.fields[i].name != struct_type2.fields[i].name ||
516
+ !isCompatibleDataType(
517
+ *struct_type.fields[i].type, *struct_type2.fields[i].type)) {
518
+ return false;
519
+ }
520
+ }
521
+ return true;
522
+ }
523
+ if (std::holds_alternative<OpaqueType>(dtype.type) &&
524
+ std::holds_alternative<OpaqueType>(dtype2.type)) {
525
+ const auto& opaque_type = std::get<OpaqueType>(dtype.type);
526
+ const auto& opaque_type2 = std::get<OpaqueType>(dtype2.type);
527
+ return opaque_type.type_info.get() == opaque_type2.type_info.get();
528
+ }
529
+ return false;
530
+ }
531
+
532
+ inline bool hasCompatibleDataType(
533
+ const PolymorphicValue& value,
534
+ DataType dtype) {
535
+ // We can not always completely infer data type from value, so we need some
536
+ // special handling here.
537
+ if (std::holds_alternative<PointerType>(dtype.type)) {
538
+ if (!value.is<Pointer>()) {
539
+ return false;
540
+ }
541
+ auto ptr = std::get<PointerType>(dtype.type);
542
+ return dataTypeSize(*ptr.type) == value.as<Pointer>().size();
543
+ } else if (std::holds_alternative<ArrayType>(dtype.type)) {
544
+ if (!value.is<std::vector>()) {
545
+ return false;
546
+ }
547
+ const auto& array_type = std::get<ArrayType>(dtype.type);
548
+ if (array_type.size != value.as<std::vector>().size()) {
549
+ return false;
550
+ }
551
+ if (array_type.size == 0) {
552
+ return true;
553
+ }
554
+ }
555
+ return isCompatibleDataType(getDataType(value), dtype);
556
+ }
557
+
558
+ #if defined(__GNUC__) && !defined(__clang__)
559
+ #pragma GCC diagnostic pop
560
+ #endif
561
+
562
+ //! Returns the number of base-10 digits required to guarantee a lossless
563
+ //! binary->text->binary round-trip. For exact types, this function returns 0.
564
+ int max_digits10(DataType dtype);
565
+
566
//! Tags identifying unary operations in the IR.
enum class UnaryOpType {
  // Type conversions
  Cast,
  BitCast,
  RefCast,

  // Elementwise math ops. NOTE(review): Address/Dereference look like
  // pointer address-of / indirection ops, inferred from their names --
  // confirm in codegen.
  Abs,
  Acos,
  Acosh,
  Address,
  Asin,
  Asinh,
  Atan,
  Atanh,
  Ceil,
  Cos,
  Cosh,
  Dereference,
  Exp,
  Exp2,
  Expm1,
  Erf,
  Erfc,
  Erfinv,
  Erfcinv,
  Floor,
  Frac,
  Gelu,
  Imag,
  Silu,
  Lgamma,
  Log,
  Log10,
  Log1p,
  Log2,
  Neg,
  Real,
  Reciprocal,
  Relu,
  Rsqrt,
  Round,
  Sigmoid,
  Signbit,
  Sin,
  Sinh,
  Sqrt,
  Tan,
  Tanh,
  Trunc,

  // Tools to help debugging
  Print,

  // Logical and bitwise negation
  LogicalNot,
  BitwiseNot,

  // Operators returning boolean values
  IsFinite,
  IsInf,
  IsNan,
  IsNegInf,
  IsPosInf,
  IsReal,

  // Special unary ops
  ElectSync,
  ToUnsignedSmemAddr,
  AdjustPartialLdMatrixAddrInTuring8,
  AdjustPartialLdMatrixAddrInTuring16
};
636
+
637
// TODO: Order of this list is important as it affects type promotion. it's not
// in the right order now.
enum class BinaryOpType {
  // Math Ops
  Add,
  Atan2,
  Div,
  Fmod,
  Max,
  Min,
  Mul,
  Nextafter,
  Pow,
  Remainder,
  Sub,
  // TypeAs,

  // Integer output ops.
  Mod,
  CeilDiv,
  Lshift,
  Rshift,
  Gcd,

  // Bitwise Ops
  // These always return integers, as if each arg is first cast to int
  // If changing modify isIntegerOp.
  BitwiseAnd,
  BitwiseOr,
  BitwiseXor,

  // Logical Ops
  // The position of the first op in this section matters -- see
  // isLogicalOp(BinaryOpType bopt). NOTE(review): the original comment said
  // "leave position of Mod as first logical op", but Eq is the first op
  // here; confirm against isLogicalOp's implementation.
  Eq,
  GE,
  GT,
  LE,
  LT,
  NE,

  // These ops compare as if each arg is first cast to bool
  LogicalAnd,
  LogicalOr,

  // generate complex from real and imaginary parts
  Complex
};
685
+
686
//! Scatter operation variants.
enum class ScatterOpType { Set };

//! Random-number-generation op kinds; the distribution of each variant is
//! described by its trailing comment.
enum class RNGOpType {
  Uniform, // Uniform in [0, 1)
  UniformRange, // Uniform in [low, high]
  NormalStandard, // Normal with mean 0, std 1
  NormalGeneral, // Normal with given mean and std
  Undefined,
};
695
+
696
// Return if output of operator should be an integer
bool isIntegerOp(const BinaryOpType bopt);
698
+
699
+ // Return if output of operator should be a boolean
700
+ bool isLogicalOp(const BinaryOpType bopt);
701
+
702
//! Tags identifying ternary operations in the IR.
enum class TernaryOpType { Clamp, Lerp, Threshold, Where };

//! Parallelization/annotation types an IterDomain can be bound to:
//! device dimension (DIDx), grid dimensions (BID*), block dimensions (TID*),
//! plus non-thread annotations (Vectorize, Unroll, ...). Serial is the
//! default, unparallelized type.
enum class ParallelType {
  DIDx,
  BIDz,
  BIDy,
  BIDx,
  TIDz,
  TIDy,
  TIDx,
  Stream,
  Vectorize,
  MisalignedVectorize,
  Unroll,
  Unswitch,
  Mma,
  Group,
  Bulk,
  Serial
};
722
+
723
+ std::unordered_set<ParallelType> allParallelTypesExcept(
724
+ const std::unordered_set<ParallelType>& except);
725
+
726
//! The six kernel thread-parallel types: grid (BIDx/y/z) followed by block
//! (TIDx/y/z) dimensions.
static constexpr std::array<ParallelType, 6> kParallelTypeThreads = {
    ParallelType::BIDx,
    ParallelType::BIDy,
    ParallelType::BIDz,
    ParallelType::TIDx,
    ParallelType::TIDy,
    ParallelType::TIDz};

//! Grid (block-index) parallel types only.
static constexpr std::array<ParallelType, 3> kParallelTypeBIDs = {
    ParallelType::BIDx,
    ParallelType::BIDy,
    ParallelType::BIDz};

//! Block (thread-index) parallel types only.
static constexpr std::array<ParallelType, 3> kParallelTypeTIDs = {
    ParallelType::TIDx,
    ParallelType::TIDy,
    ParallelType::TIDz};

//! Device-dimension parallel types; currently only DIDx.
static constexpr std::array<ParallelType, 1> kParallelTypeDIDs = {
    ParallelType::DIDx};
746
+
747
//! Memory space a tensor occupies during kernel execution.
enum class MemoryType { Local, Shared, Global, Tensor };

//! Kinds of IterDomains.
// Symbolic: Undetermined between Iteration or Broadcast
enum class IterType {
  Iteration,
  Reduction,
  Broadcast,
  Stride,
  GatherScatter,
  VectorComponent,
  Symbolic
};

// Used for Iteration Domain mapping modes in ComputeAtMap
enum class IdMappingMode {
  EXACT,
  ALMOSTEXACT,
  BROADCAST,
  PERMISSIVE,
  LOOP,
  // TODO: Reconsider if this graph is really necessary
  PERMISSIVE_RESIZE,
  // TODO: Reconsider if this graph is really necessary
  INNERMOST
};

//! All IdMappingMode values, for iterating over every mapping mode.
static constexpr std::array<IdMappingMode, 7> kIdMappingModes = {
    IdMappingMode::EXACT,
    IdMappingMode::ALMOSTEXACT,
    IdMappingMode::BROADCAST,
    IdMappingMode::PERMISSIVE,
    IdMappingMode::LOOP,
    IdMappingMode::PERMISSIVE_RESIZE,
    IdMappingMode::INNERMOST};
781
+
782
// See
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators
// for what each option means. Will also consider .L1::no_allocate because .cs
// still pollutes cache to some extent.
enum class CacheOp {
  Unspecified, // Opt in for the default cache operator or when the LoadStoreOp
               // doesn't take a cache operator.
  AllLevels, // Cache at all levels (presumably PTX .ca -- see link above).
  Streaming, // Streaming data, evict-first (presumably PTX .cs).
  Global, // Cache at global level (presumably PTX .cg).
};
793
+
794
//! Used to annotate the special memory intrinsics that a loadstore op will be
//! lowered to.
//!
//! SegmenterSet here is used to hint segmenter to break kernel on the output
//! of the node
//!
//! NOTE(review): entries after SegmenterSet appear to mirror hardware
//! instruction names (ldmatrix, cp.async, cp.async.bulk, TMA tensor-tile
//! copies, stmatrix, tensor-memory ld/st) -- confirm against lowering code.
enum class LoadStoreOpType {
  Set,
  SegmenterSet,
  LdMatrix,
  CpAsync,
  CpAsyncBulk,
  CpAsyncBulkTensorTile,
  StMatrix,
  LdTMem,
  StTMem
};
810
+
811
// Used to label what part of the circular buffered iterdomain
// a for loop is materializing. Which stages clone loads vs. computes is
// defined by hasCircularBufferLoad / hasCircularBufferConsume below.
enum class CircularBufferLoopStage {
  Prolog = 0,
  Main,
  Epilog,
  LoadWarp,
  ComputeWarp,
  EndOfStages, // A special placeholder used to iterate over all stages
  NotApplicable
};
822
+
823
+ // The circular buffer load expressions are cloned for these circular buffer
824
+ // loop types.
825
+ // e.g., No additional loads are required for the Epilogue stage.
826
+ inline bool hasCircularBufferLoad(CircularBufferLoopStage stage) {
827
+ return stage == CircularBufferLoopStage::Prolog ||
828
+ stage == CircularBufferLoopStage::Main ||
829
+ stage == CircularBufferLoopStage::LoadWarp;
830
+ }
831
+
832
+ // The consuming expressions of circular buffer are cloned for these circular
833
+ // buffer loop types.
834
+ // e.g., No actual computation occurs in the Prologue stage.
835
+ inline bool hasCircularBufferConsume(CircularBufferLoopStage stage) {
836
+ return stage == CircularBufferLoopStage::Main ||
837
+ stage == CircularBufferLoopStage::Epilog ||
838
+ stage == CircularBufferLoopStage::ComputeWarp;
839
+ }
840
+
841
+ // A loop type may have WAR hazard if any of the following is true:
842
+ // - The load *in this loop type* may overwrite a buffer being read by a
843
+ // compute somewhere (*may or may not be in this loop*)
844
+ // - The compute *in this loop type* reads circular buffer TVs that, if not
845
+ // properly handled, could be overwriten by a circular buffer loading
846
+ // somewhere (*may or may not be in this loop*)
847
+ inline bool mayHaveWarHazard(CircularBufferLoopStage stage) {
848
+ return stage == CircularBufferLoopStage::Main ||
849
+ stage == CircularBufferLoopStage::LoadWarp ||
850
+ stage == CircularBufferLoopStage::ComputeWarp;
851
+ }
852
+
853
//! Supported swizzle types,
//! corresponds to swizzles functions on the runtime cuda
//! naming it swizzle_2d to reserve the options to have a swizzle_1d.
//!
//! TODO: unify with existing swizzle logic, currently
//! doesn't have the same type.
enum class SwizzleType { NoSwizzle = 0, XOR, CyclicShift };
// Swizzle2DType extends the 1-D kinds above with ZShape.
enum class Swizzle2DType { NoSwizzle = 0, ZShape, XOR, CyclicShift };

//! Modes of swizzle, see [Note on swizzle mode].
enum class SwizzleMode { NoSwizzle = 0, Data, Loop };
864
+
865
+ // Returns if function needs an f suffix on the operator when operating on a
866
+ // float value i.e. sin->sinf
867
+ bool needFloatSuffix(UnaryOpType t);
868
+ bool needFloatSuffix(BinaryOpType t);
869
+ bool needFloatSuffix(RNGOpType t);
870
+
871
+ ValType promoteType(const ValType& t1, const ValType& t2);
872
+
873
// Helper for promoteType below: when (t1, t2) is exactly the DataType pair
// corresponding to native C++ types (Type1, Type2), return the DataType of
// std::common_type_t<Type1, Type2>, i.e. defer to C++'s own promotion rule
// for that pair. Undefined after use via #undef below.
#define HANDLE_TYPE_PROMOTION(Type1, Type2)                              \
  if (t1 == NativeTypeToDataType<Type1>::type &&                         \
      t2 == NativeTypeToDataType<Type2>::type) {                         \
    return NativeTypeToDataType<std::common_type_t<Type1, Type2>>::type; \
  }

// Expands HANDLE_TYPE_PROMOTION for Type1 paired with every supported
// native scalar type.
#define HANDLE_TYPE_PROMOTION1(Type1)                \
  HANDLE_TYPE_PROMOTION(Type1, float);               \
  HANDLE_TYPE_PROMOTION(Type1, double);              \
  HANDLE_TYPE_PROMOTION(Type1, int64_t);             \
  HANDLE_TYPE_PROMOTION(Type1, int);                 \
  HANDLE_TYPE_PROMOTION(Type1, bool);                \
  HANDLE_TYPE_PROMOTION(Type1, std::complex<float>); \
  HANDLE_TYPE_PROMOTION(Type1, std::complex<double>)
887
+
888
//! Computes the promoted DataType of t1 and t2 (the result type of an
//! arithmetic op mixing the two). Uses C++'s std::common_type rules via
//! HANDLE_TYPE_PROMOTION where a native C++ type exists, with special cases
//! for pointer arithmetic, DataType::Index, reduced-precision floats, and
//! one deliberate ATen-compatibility deviation documented inline.
//! Fails via NVF_CHECK when the two types are not promotable.
inline DataType promoteType(const DataType& t1, const DataType& t2) {
  if (t1 == t2) {
    return t1;
  }
  // pointer +- integer = pointer
  if (isPointerType(t1) && isIntegralType(t2)) {
    return t1;
  }
  if (isPointerType(t2) && isIntegralType(t1)) {
    return t2;
  }
  // When seeing DataType::Index, assuming we are computing index, so propagate
  // DataType::Index
  if ((t1 == DataType::Index && isIntegralType(t2)) ||
      (t2 == DataType::Index && isIntegralType(t1))) {
    return DataType::Index;
  }
  // Workaround a case where C++ and ATen have different type promotion rules
  if ((t1 == DataType::Double && t2 == DataType::ComplexFloat) ||
      (t2 == DataType::Double && t1 == DataType::ComplexFloat)) {
    // WARNING: ATen and C++ behave differently for this case. ATen returns
    // DataType::ComplexDouble but C++ returns DataType::ComplexFloat. Right now
    // we choose to be consistent with ATen.
    // TODO: I am pretty sure that for some cases we would need C++'s promotion
    // rule, for example, when we are simplifying scalar expressions, and for
    // other cases, we need ATen's promotion rule, for example, when we define
    // fusion from ATen graph. Fortunately, right now this is the only case to
    // worry about, and I don't think in practice, using ATen's rule would cause
    // any trouble.
    return DataType::ComplexDouble;
  }
  // Use C++ promotion rule when dtype has a native C++ type
  HANDLE_TYPE_PROMOTION1(float);
  HANDLE_TYPE_PROMOTION1(double);
  HANDLE_TYPE_PROMOTION1(int64_t);
  HANDLE_TYPE_PROMOTION1(int);
  HANDLE_TYPE_PROMOTION1(bool);
  HANDLE_TYPE_PROMOTION1(std::complex<float>);
  HANDLE_TYPE_PROMOTION1(std::complex<double>);
  // double + half/bfloat16 = double
  if ((t1 == DataType::Double && isFloatingPointType(t2)) ||
      (t2 == DataType::Double && isFloatingPointType(t1))) {
    return DataType::Double;
  }
  // float + half/bfloat16 = float
  // half + bfloat16 = float
  if (isFloatingPointType(t1) && isFloatingPointType(t2)) {
    return DataType::Float;
  }
  // complex + half/bfloat16 = complex
  if (isComplexType(t1)) {
    return t1;
  }
  if (isComplexType(t2)) {
    return t2;
  }
  // half + integers/bool = half
  // bfloat16 + integers/bool = bfloat16
  if (isFloatingPointType(t1)) {
    return t1;
  }
  if (isFloatingPointType(t2)) {
    return t2;
  }
  NVF_CHECK(false, "Expected promotable DataTypes but got: ", t1, " and ", t2);
}
954
+
955
+ #undef HANDLE_TYPE_PROMOTION
956
+ #undef HANDLE_TYPE_PROMOTION1
957
+
958
+ template <typename... Args>
959
+ inline DataType promoteType(
960
+ const DataType& t1,
961
+ const DataType& t2,
962
+ const Args&... args) {
963
+ return promoteType(t1, promoteType(t2, promoteType(args...)));
964
+ }
965
+
966
+ inline DataType promoteType(const std::vector<DataType>& types) {
967
+ NVF_CHECK(!types.empty(), "Can not promote empty type vector")
968
+ DataType result = types.at(0);
969
+ for (const auto& t : types) {
970
+ result = promoteType(result, t);
971
+ }
972
+ return result;
973
+ }
974
+
975
+ // If type cannot be found (i.e. codegen does not support provided type) returns
976
+ // DataType::Null
977
+ NVF_API DataType aten_to_data_type(const at::ScalarType& scalar_type);
978
+ NVF_API at::ScalarType data_type_to_aten(const DataType& data_type);
979
+
980
+ NVF_API std::ostream& operator<<(std::ostream&, const ValType);
981
+ std::ostream& operator<<(std::ostream&, const PredicateType);
982
+ NVF_API std::ostream& operator<<(std::ostream&, const DataType);
983
+ std::ostream& operator<<(std::ostream&, const UnaryOpType);
984
+ NVF_API std::ostream& operator<<(std::ostream&, const BinaryOpType);
985
+ std::ostream& operator<<(std::ostream&, const TernaryOpType);
986
+ std::ostream& operator<<(std::ostream&, const ScatterOpType);
987
+ std::ostream& operator<<(std::ostream&, const RNGOpType);
988
+ NVF_API std::ostream& operator<<(std::ostream&, const ParallelType);
989
+ NVF_API std::ostream& operator<<(std::ostream&, const MemoryType);
990
+ NVF_API std::ostream& operator<<(std::ostream&, const IterType);
991
+ std::ostream& operator<<(std::ostream&, const IdMappingMode);
992
+ NVF_API std::ostream& operator<<(std::ostream&, const LoadStoreOpType);
993
+ std::ostream& operator<<(std::ostream&, const CircularBufferLoopStage);
994
+ std::ostream& operator<<(std::ostream&, const SwizzleType&);
995
+ std::ostream& operator<<(std::ostream&, const Swizzle2DType&);
996
+ std::ostream& operator<<(std::ostream&, const SwizzleMode&);
997
+ std::ostream& operator<<(std::ostream&, const KernelIndexMode&);
998
+ NVF_API std::ostream& operator<<(std::ostream&, const CacheOp&);
999
+ std::ostream& operator<<(std::ostream& os, const std::optional<bool>&);
1000
+
1001
+ std::string stringifyThreadSize(const ParallelType);
1002
+ std::string stringifyThread(const ParallelType);
1003
+ std::string typePrefix(const DataType);
1004
+
1005
+ // TODO: ThreadDim should be BlockDim and BlockDim should be GridDim
1006
+ // Returns if parallel type is TID[x, y, z]
1007
+ NVF_API bool isParallelTypeThreadDim(ParallelType);
1008
+ // Returns if parallel type is BID[x, y, z]
1009
+ NVF_API bool isParallelTypeBlockDim(ParallelType);
1010
+ // Returns if parallel type is a grid or block parallelization dimension
1011
+ NVF_API bool isParallelTypeThread(ParallelType);
1012
+ // Returns if parallel type is DIDx
1013
+ NVF_API bool isParallelTypeDeviceDim(ParallelType);
1014
+
1015
+ NVF_API bool isParallelTypeVectorize(ParallelType);
1016
+
1017
+ std::optional<std::string> inline_op_str(const UnaryOpType);
1018
+ std::optional<std::string> inline_op_str(const BinaryOpType);
1019
+ std::optional<std::string> inline_op_str(const RNGOpType);
1020
+ std::optional<std::string> integer_op_str(const BinaryOpType);
1021
+ std::optional<std::string> bool_op_str(const BinaryOpType);
1022
+ const char* predicate_type2string(PredicateType t);
1023
+ const char* load_store_type2string(LoadStoreOpType t);
1024
+
1025
+ std::optional<std::string> cast_func_str(const std::pair<DataType, DataType>&);
1026
+
1027
//! Returns the size in bytes of a primitive data type.
//! Throws for DataType::Index (its concrete width is only decided at
//! kernel compile time) and for any type with no defined size.
constexpr inline size_t primDataTypeSize(PrimDataType type) {
  switch (type) {
    case DataType::Bool:
      return sizeof(bool);
    case DataType::ComplexDouble:
      return sizeof(std::complex<double>);
    case DataType::ComplexFloat:
      return sizeof(std::complex<float>);
    case DataType::Double:
      return sizeof(double);
    case DataType::Float:
      return sizeof(float);
    case DataType::Half:
      return sizeof(at::Half);
    case DataType::BFloat16:
      return sizeof(at::BFloat16);
    case DataType::Float8_e4m3fn:
      return sizeof(at::Float8_e4m3fn);
    case DataType::Float8_e5m2:
      return sizeof(at::Float8_e5m2);
    case DataType::Index:
      NVF_THROW("The actual type of Index is only known at compile time.");
    case DataType::Char:
      return sizeof(int8_t);
    case DataType::Short:
      return sizeof(int16_t);
    case DataType::Int32:
      return sizeof(int32_t);
    case DataType::Int:
      return sizeof(int64_t);
    case DataType::Byte:
      return sizeof(uint8_t);
    case DataType::UInt16:
      return sizeof(uint16_t);
    case DataType::UInt32:
    case DataType::SMemAddress:
    case DataType::TMemAddress:
      // Shared/tensor memory addresses are represented as 32-bit values.
      return sizeof(uint32_t);
    case DataType::UInt64:
      return sizeof(uint64_t);
    default:
      NVF_THROW("Size undefined for data type.");
  }
}
1071
+
1072
//! Keys identifying kernel launch configuration values (shared memory size
//! and the grid/block dimensions).
enum class LaunchConfigType {
  Compatible,
  SharedMemory,
  BIDz,
  BIDy,
  BIDx,
  TIDz,
  TIDy,
  TIDx
};

//! Name of the "magic zero" variable emitted into generated kernels.
const char* const kMagicZeroName = "nvfuser_zero";

//! Maximum number of reductions that can be grouped together. The
//! limit can be increased by extending struct Tuple define in tuple.cu.
static constexpr int kMaxNumGroupedReductions = 16;
1088
+
1089
//! Constructs a Pointer from a raw address; the recorded size is the byte
//! size of the pointee type (dataTypeSize(dtype)), not of the pointer itself.
Pointer::Pointer(void* ptr, DataType dtype)
    : ptr_(reinterpret_cast<std::byte*>(ptr)), size_(dataTypeSize(dtype)) {}
1091
+
1092
//! Casts `value` to a native representation compatible with `dtype`.
//! A value with no payload is returned unchanged; a value that already has a
//! compatible type is returned as-is. Otherwise, each primitive native type
//! is tried in for_all_types order and the value is re-wrapped via
//! static_cast when that native type is compatible with `dtype`.
inline PolymorphicValue castToDtype(
    PolymorphicValue value,
    const DataType& dtype) {
  if (!value.hasValue()) {
    return value;
  }
  // Cast the given value to the given data type. This enables interface
  // like: IrBuilder::create<Val>(0, DataType::Double) where value is
  // an integer but the desired data type is double.
  if (!hasCompatibleDataType(value, dtype)) {
    PolymorphicValue::for_all_types([&](auto _) {
      // `_` is a tag object carrying one candidate native type.
      using T = typename decltype(_)::type;
      if constexpr (IsPrimitiveNativeType<T>::value) {
        if (isCompatibleDataType(NativeTypeToDataType<T>::type, dtype)) {
          value = PolymorphicValue(static_cast<T>(value));
        }
      }
      // TODO: support arrays and pointers
    });
  }
  return value;
}
1114
+
1115
// Converts an enum value to its underlying integral type.
// It corresponds with std::to_underlying introduced in c++23
// https://en.cppreference.com/w/cpp/utility/to_underlying
template <typename E>
constexpr std::underlying_type_t<E> toUnderlying(E value) noexcept {
  return static_cast<std::underlying_type_t<E>>(value);
}
1122
+
1123
//! Kinds of asynchronous hardware operations an expression may correspond
//! to; NotAsync marks ordinary synchronous ops.
enum class AsyncOpType { NotAsync, CpAsync, CpAsyncBulk, WgMma };
1124
+
1125
+ } // namespace nvfuser