nvfuser_cu121_torch25-0.2.25.dev20250201-cp310-cp310-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/mma_type.h
@@ -0,0 +1,257 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <macros.h>
+
+ #include <exceptions.h>
+ #include <type.h>
+ #include <visibility.h>
+
+ #include <cstring>
+ #include <ostream>
+ #include <string>
+ #include <string_view>
+ #include <vector>
+
+ #include <cstdint>
+
+ namespace nvfuser {
+
+ constexpr std::string_view MATMUL_LOG_PREFIX = "[MATMUL DEBUG] ";
+
+ //! Named descriptors of domains in matmul
+ enum class MatmulDimRole { M = 0, N, K, Batch };
+
+ std::string toString(MatmulDimRole role);
+
+ //! Named descriptors of TensorView roles in fusion
+ //! OPERAND_A - an input to the fusion that is a producer of a matmul "A" input
+ //! OPERAND_B - an input to the fusion that is a producer of a matmul "B" input
+ //! OUTPUT - fusion outputs that have the matmul as a dependency
+ //! EPILOGUE_INPUT - an input to the fusion that is a producer of an
+ //!   OUTPUT, but not of an MMA input
+ //!
+ //! Note: bias vector tensors will be assigned to the EPILOGUE_INPUT role.
+ enum class MatmulTensorRole {
+   OPERAND_A = 0,
+   OPERAND_B,
+   OUTPUT,
+   EPILOGUE_INPUT
+ };
+
+ //! The expected number of occurrences of core TensorView roles in fusion
+ static constexpr size_t MATMUL_CORE_ROLES_EXPECTED_COUNT = 1;
+
+ //! Utility data structure for recording gemm tiles
+ struct GemmTile {
+   int64_t m, n, k;
+   GemmTile(int64_t m_, int64_t n_, int64_t k_) : m(m_), n(n_), k(k_) {}
+
+   bool operator==(const GemmTile& other) const {
+     return m == other.m && n == other.n && k == other.k;
+   }
+
+   GemmTile operator/(const GemmTile& other) const {
+     return GemmTile(m / other.m, n / other.n, k / other.k);
+   }
+
+   std::vector<int64_t> toVector() const {
+     return {m, n, k};
+   }
+ };
+
+ //! Utility data structure for recording matmul tile options
+ struct MatMulTileOptions {
+   GemmTile cta_tile = GemmTile(128, 128, 32);
+   GemmTile warp_tile = GemmTile(64, 64, 32);
+
+   MatMulTileOptions() = default;
+   MatMulTileOptions(GemmTile cta_tile_, GemmTile warp_tile_)
+       : cta_tile(cta_tile_), warp_tile(warp_tile_) {}
+
+   bool operator==(const MatMulTileOptions& other) const {
+     return cta_tile == other.cta_tile && warp_tile == other.warp_tile;
+   }
+ };
+
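As a quick illustration of how these two structs compose, a minimal sketch (not part of the shipped header): dividing the default CTA tile by the default warp tile with `GemmTile::operator/` gives the number of warp tiles per CTA along m, n, and k.

```cpp
#include <cassert>
#include <vector>

// With the defaults above (CTA 128x128x32, warp 64x64x32), element-wise
// division yields 2x2x1 warp tiles per CTA.
void tile_example() {
  nvfuser::MatMulTileOptions opts;
  nvfuser::GemmTile warps_per_cta = opts.cta_tile / opts.warp_tile;
  assert(warps_per_cta == nvfuser::GemmTile(2, 2, 1));
  assert(warps_per_cta.toVector() == (std::vector<int64_t>{2, 2, 1}));
}
```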
+ enum class MmaMacro : uint64_t;
+
+ struct MmaMacroEncode {
+   enum class Arch : uint16_t { NoMma, Volta, Turing, Ampere, Hopper } arch;
+   uint16_t m;
+   uint16_t n;
+   uint16_t k;
+
+   constexpr operator uint64_t() {
+     return (uint64_t)arch << 48 | (uint64_t)m << 32 | (uint64_t)n << 16 |
+         (uint64_t)k;
+   }
+
+   constexpr operator MmaMacro() {
+     return static_cast<MmaMacro>(static_cast<uint64_t>(*this));
+   }
+
+   constexpr MmaMacroEncode(MmaMacro macro)
+       : arch(Arch(toUnderlying(macro) >> 48)),
+         m((toUnderlying(macro) >> 32) & 0xFFFF),
+         n((toUnderlying(macro) >> 16) & 0xFFFF),
+         k(toUnderlying(macro) & 0xFFFF) {}
+
+   constexpr MmaMacroEncode(Arch arch_, uint16_t m_, uint16_t n_, uint16_t k_)
+       : arch(arch_), m(m_), n(n_), k(k_) {}
+ };
+
+ static_assert(sizeof(MmaMacroEncode) == sizeof(uint64_t));
+
+ //! Type of mma intrinsic macro to use
+ //! This determines which mma intrinsic from the runtime string is generated
+ //! to implement the mma op. The current plan is to have exactly one macro
+ //! for each (arch, datatype, operand layout) triple, though multiple
+ //! possibilities exist for some cases, e.g. for Turing and fp16 one can use
+ //! 16_8_8 or 16_8_16.
+ //! We will consider adding more choices that the scheduler can pick from
+ //! when our perf targets become more fine grained, which is more likely in
+ //! latency bound kernels.
+
+ #define MACRO(arch, m, n, k) \
+   arch##_##m##_##n##_##k = MmaMacroEncode(MmaMacroEncode::Arch::arch, m, n, k)
+
+ enum class MmaMacro : uint64_t {
+   NoMMA = 0,
+
+   MACRO(Turing, 16, 8, 8),
+   MACRO(Turing, 16, 8, 16),
+   MACRO(Turing, 16, 16, 16),
+
+   MACRO(Ampere, 16, 8, 16),
+   MACRO(Ampere, 16, 16, 16),
+
+   MACRO(Hopper, 64, 8, 16),
+   MACRO(Hopper, 64, 16, 16),
+   MACRO(Hopper, 64, 24, 16),
+   MACRO(Hopper, 64, 32, 16),
+   MACRO(Hopper, 64, 40, 16),
+   MACRO(Hopper, 64, 48, 16),
+   MACRO(Hopper, 64, 56, 16),
+   MACRO(Hopper, 64, 64, 16),
+   MACRO(Hopper, 64, 72, 16),
+   MACRO(Hopper, 64, 80, 16),
+   MACRO(Hopper, 64, 88, 16),
+   MACRO(Hopper, 64, 96, 16),
+   MACRO(Hopper, 64, 104, 16),
+   MACRO(Hopper, 64, 112, 16),
+   MACRO(Hopper, 64, 120, 16),
+   MACRO(Hopper, 64, 128, 16),
+   MACRO(Hopper, 64, 136, 16),
+   MACRO(Hopper, 64, 144, 16),
+   MACRO(Hopper, 64, 152, 16),
+   MACRO(Hopper, 64, 160, 16),
+   MACRO(Hopper, 64, 168, 16),
+   MACRO(Hopper, 64, 176, 16),
+   MACRO(Hopper, 64, 184, 16),
+   MACRO(Hopper, 64, 192, 16),
+   MACRO(Hopper, 64, 200, 16),
+   MACRO(Hopper, 64, 208, 16),
+   MACRO(Hopper, 64, 216, 16),
+   MACRO(Hopper, 64, 224, 16),
+   MACRO(Hopper, 64, 232, 16),
+   MACRO(Hopper, 64, 240, 16),
+   MACRO(Hopper, 64, 248, 16),
+   MACRO(Hopper, 64, 256, 16),
+ };
+
+ #undef MACRO
+
+ //! [Operand Layout Convention]
+ //! Operand layout, T=transposed/row_major, N=normal/col_major
+ //! Ordered by position of K
+ //! NT : K,M x K,N -> M,N
+ //! TT : M,K x K,N -> M,N
+ //! TN : M,K x N,K -> M,N
+ //! NN : K,M x N,K -> M,N
+ enum class MmaLayout { NT = 0, TT, TN, NN };
+
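A small illustrative helper (not from the header) spelling out which operand shapes each `MmaLayout` value denotes per the convention above, where the first letter describes A, the second describes B, and dimensions are listed outer-to-inner:

```cpp
// Human-readable operand shapes for each layout, per the comment above.
inline const char* operandShapes(nvfuser::MmaLayout layout) {
  switch (layout) {
    case nvfuser::MmaLayout::NT:
      return "A: [K,M], B: [K,N]"; // K outermost in both operands
    case nvfuser::MmaLayout::TT:
      return "A: [M,K], B: [K,N]";
    case nvfuser::MmaLayout::TN:
      return "A: [M,K], B: [N,K]"; // K innermost in both operands
    case nvfuser::MmaLayout::NN:
      return "A: [K,M], B: [N,K]";
  }
  return "";
}
```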
+ //! Indicates which dimension is innermost in the allocation domain of an
+ //! operand
+ enum class UnitDim { K, M_or_N };
+
+ //! Utility to annotate which input of mma this option struct describes
+ enum class MmaOperand { A, B };
+
+ //! GPU arch check for macro type
+ inline bool isTuring(MmaMacro macro) {
+   return MmaMacroEncode(macro).arch == MmaMacroEncode::Arch::Turing;
+ }
+
+ inline bool isAmpere(MmaMacro macro) {
+   return MmaMacroEncode(macro).arch == MmaMacroEncode::Arch::Ampere;
+ }
+
+ inline bool isHopper(MmaMacro macro) {
+   return MmaMacroEncode(macro).arch == MmaMacroEncode::Arch::Hopper;
+ }
+
+ //! Get the m size from macro type
+ inline int64_t getM(MmaMacro macro) {
+   return MmaMacroEncode(macro).m;
+ }
+
+ //! Get the n size from macro type
+ inline int64_t getN(MmaMacro macro) {
+   return MmaMacroEncode(macro).n;
+ }
+
+ //! Get the k size from macro type
+ inline int64_t getK(MmaMacro macro) {
+   return MmaMacroEncode(macro).k;
+ }
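A usage sketch for the accessors above (illustrative only), using an enumerator defined earlier in this header:

```cpp
#include <cassert>

void query_macro() {
  // Ampere_16_8_16 was defined via the MACRO helper above.
  auto macro = nvfuser::MmaMacro::Ampere_16_8_16;
  assert(nvfuser::isAmpere(macro) && !nvfuser::isHopper(macro));
  assert(nvfuser::getM(macro) == 16);
  assert(nvfuser::getN(macro) == 8);
  assert(nvfuser::getK(macro) == 16);
}
```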
+
+ // Unpacked constants from macro type:
+ //   exact numbers are defined by each individual instruction.
+ int getOutputRegisterSize(MmaMacro macro);
+ int getInputARegisterSize(MmaMacro macro);
+ int getInputBRegisterSize(MmaMacro macro);
+
+ // Unpack MMA op shape
+ GemmTile getMmaOpShape(MmaMacro macro);
+
+ // Warning: The values of the enum class must match the matrix descriptor as
+ // specified in:
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
+ // Do not edit the values of the enum class unless you know what you are doing.
+ enum class MmaInputSmemSwizzle {
+   None = 0,
+   B128 = 1,
+   B64 = 2,
+   B32 = 3,
+ };
+
+ constexpr int64_t core_matrix_width_bytes = 16;
+
+ int64_t getBytesFromSwizzle(MmaInputSmemSwizzle swizzle);
+ MmaInputSmemSwizzle getSwizzleFromBytes(int64_t bytes);
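The definitions of `getBytesFromSwizzle` and `getSwizzleFromBytes` live in the compiled library, not in this header. A plausible mapping implied by the enum names (an assumption, not the shipped implementation) would be:

```cpp
// Assumed mapping (not from the wheel): BNNN selects an NNN-byte swizzle
// period; None is assumed to fall back to the 16-byte core matrix width.
inline int64_t bytesFromSwizzleSketch(nvfuser::MmaInputSmemSwizzle swizzle) {
  switch (swizzle) {
    case nvfuser::MmaInputSmemSwizzle::B128:
      return 128;
    case nvfuser::MmaInputSmemSwizzle::B64:
      return 64;
    case nvfuser::MmaInputSmemSwizzle::B32:
      return 32;
    case nvfuser::MmaInputSmemSwizzle::None:
    default:
      return nvfuser::core_matrix_width_bytes; // assumption for None
  }
}
```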
+
+ // MMA stringify utils
+ NVF_API std::string toString(MmaLayout input_layout);
+ std::string toString(const GemmTile& tile);
+ NVF_API std::string toString(const MatMulTileOptions& opts);
+ NVF_API std::string toString(MmaMacro macro);
+ NVF_API std::string toString(MmaInputSmemSwizzle swizzle);
+ inline std::ostream& operator<<(
+     std::ostream& os,
+     MmaInputSmemSwizzle input_layout) {
+   os << toString(input_layout);
+   return os;
+ }
+
+ // MMA hash utils
+ NVF_API size_t hash(MmaMacro macro);
+ size_t hash(MmaLayout input_layout);
+ size_t hash(const GemmTile& tile);
+ NVF_API size_t hash(const MatMulTileOptions& opts);
+
+ } // namespace nvfuser
nvfuser/include/nvfuser/multidevice/c10d_mock.h
@@ -0,0 +1,175 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <ATen/core/TensorBody.h>
+ #include <ATen/core/ivalue.h>
+ #include <c10/util/intrusive_ptr.h>
+
+ namespace c10d {
+
+ inline void setDebugLevelFromEnvironment() {}
+
+ class Work : public torch::CustomClassHolder {
+  public:
+   void wait() {}
+ };
+
+ struct ReduceOp : torch::CustomClassHolder {
+   enum RedOpType {
+     SUM,
+     AVG,
+     PRODUCT,
+     MIN,
+     MAX,
+     BAND,
+     BOR,
+     BXOR,
+     UNUSED,
+   };
+
+   ReduceOp() = default;
+   ReduceOp(RedOpType op) : op_(op) {}
+
+   RedOpType op_ = UNUSED;
+ };
+
+ struct ReduceScatterOptions {
+   ReduceOp reduceOp = ReduceOp::UNUSED;
+ };
+
+ struct ScatterOptions {
+   int64_t rootRank = 0;
+ };
+
+ struct AllgatherOptions {};
+
+ struct GatherOptions {
+   int64_t rootRank = 0;
+ };
+
+ struct BroadcastOptions {
+   int64_t rootRank = 0;
+ };
+
+ struct AllreduceOptions {
+   ReduceOp reduceOp = ReduceOp::UNUSED;
+ };
+
+ struct ReduceOptions {
+   ReduceOp reduceOp = ReduceOp::UNUSED;
+   int64_t rootRank = 0;
+ };
+
+ struct BarrierOptions {
+   std::vector<int64_t> device_ids;
+ };
+
+ class Backend : public torch::CustomClassHolder {
+  public:
+   void startCoalescing() {}
+
+   c10::intrusive_ptr<Work> endCoalescing() {
+     return c10::make_intrusive<Work>();
+   }
+
+   const std::string getBackendName() const {
+     return "";
+   };
+
+   c10::intrusive_ptr<Work> barrier(
+       const BarrierOptions& opts = BarrierOptions()) {
+     return c10::make_intrusive<Work>();
+   }
+
+   c10::intrusive_ptr<Work> send(
+       std::vector<at::Tensor>& tensors,
+       int dstRank,
+       int tag) {
+     return c10::make_intrusive<Work>();
+   }
+
+   c10::intrusive_ptr<Work> recv(
+       std::vector<at::Tensor>& tensors,
+       int srcRank,
+       int tag) {
+     return c10::make_intrusive<Work>();
+   }
+
+   c10::intrusive_ptr<Work> allgather(
+       std::vector<std::vector<at::Tensor>>& outputTensors,
+       std::vector<at::Tensor>& inputTensors,
+       const AllgatherOptions& opts = AllgatherOptions()) {
+     return c10::make_intrusive<Work>();
+   }
+
+   c10::intrusive_ptr<Work> _allgather_base(
+       at::Tensor& outputBuffer,
+       at::Tensor& inputBuffer,
+       const AllgatherOptions& opts = AllgatherOptions()) {
+     return c10::make_intrusive<Work>();
+   }
+
+   c10::intrusive_ptr<Work> gather(
+       std::vector<std::vector<at::Tensor>>& outputTensors,
+       std::vector<at::Tensor>& inputTensors,
+       const GatherOptions& opts = GatherOptions()) {
+     return c10::make_intrusive<Work>();
+   }
+
+   c10::intrusive_ptr<Work> reduce_scatter(
+       std::vector<at::Tensor>& outputTensors,
+       std::vector<std::vector<at::Tensor>>& inputTensors,
+       const ReduceScatterOptions& opts = ReduceScatterOptions()) {
+     return c10::make_intrusive<Work>();
+   }
+
+   c10::intrusive_ptr<Work> _reduce_scatter_base(
+       at::Tensor& outputBuffer,
+       at::Tensor& inputBuffer,
+       const ReduceScatterOptions& opts = ReduceScatterOptions()) {
+     return c10::make_intrusive<Work>();
+   }
+
+   c10::intrusive_ptr<Work> scatter(
+       std::vector<at::Tensor>& outputTensors,
+       std::vector<std::vector<at::Tensor>>& inputTensors,
+       const ScatterOptions& opts = ScatterOptions()) {
+     return c10::make_intrusive<Work>();
+   }
+
+   c10::intrusive_ptr<Work> broadcast(
+       std::vector<at::Tensor>& tensors,
+       const BroadcastOptions& opts = BroadcastOptions()) {
+     return c10::make_intrusive<Work>();
+   }
+
+   c10::intrusive_ptr<Work> allreduce(
+       std::vector<at::Tensor>& tensors,
+       const AllreduceOptions& opts = AllreduceOptions()) {
+     return c10::make_intrusive<Work>();
+   }
+
+   c10::intrusive_ptr<Work> reduce(
+       std::vector<at::Tensor>& tensors,
+       const ReduceOptions& opts = ReduceOptions()) {
+     return c10::make_intrusive<Work>();
+   }
+
+   int getSize() const {
+     return 0;
+   }
+ };
+
+ struct TCPStoreOptions {
+   static constexpr uint16_t kDefaultPort = 0;
+ };
+
+ class TCPStore : public torch::CustomClassHolder {};
+
+ } // namespace c10d
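This header is a compile-time stand-in for PyTorch's `c10d` API, used when nvfuser is built without `NVFUSER_DISTRIBUTED` (see the `#ifdef` in `multidevice/communication.h` below): every collective is a no-op that returns a fresh `Work`, and `wait()` returns immediately. A minimal sketch of what calling code sees:

```cpp
#include <multidevice/c10d_mock.h>

// Sketch: the mock lets single-process builds exercise the same call shapes
// as real process-group code, with no communication actually happening.
void mock_roundtrip(c10d::Backend& backend, std::vector<at::Tensor>& tensors) {
  c10d::AllreduceOptions opts;
  opts.reduceOp = c10d::ReduceOp(c10d::ReduceOp::SUM);
  c10::intrusive_ptr<c10d::Work> work = backend.allreduce(tensors, opts);
  work->wait(); // no-op: nothing was posted
}
```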
nvfuser/include/nvfuser/multidevice/communication.h
@@ -0,0 +1,232 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <ir/base_nodes.h>
+ #include <ir/builder.h>
+ #include <ir/interface_nodes.h>
+ #include <multidevice/communicator.h>
+ #include <multidevice/device_mesh.h>
+ #include <multidevice/multidevice.h>
+ #ifdef NVFUSER_DISTRIBUTED
+ #include <torch/csrc/distributed/c10d/Types.hpp>
+ #else
+ #include <multidevice/c10d_mock.h>
+ #endif
+ #include <type.h>
+ #include <visibility.h>
+
+ namespace nvfuser {
+
+ enum class CommunicationType {
+   Gather,
+   Allgather,
+   Scatter,
+   Reduce,
+   Allreduce,
+   ReduceScatter,
+   Broadcast,
+   SendRecv
+ };
+
+ std::ostream& operator<<(std::ostream& os, const CommunicationType& type);
+
+ using RedOpType = c10d::ReduceOp::RedOpType;
+
+ // The class "Communication" represents an MPI-style communication operation
+ // to be executed on the network. The base class Communication should not be
+ // used directly but through its derived classes: Broadcast, Gather, Scatter,
+ // Allgather, and SendRecv. Other collectives will be added later.
+ class Communication : public Expr {
+  public:
+   using Expr::Expr;
+   // Only specify `root` for types that have root.
+   // Only specify `red_op` for reduction types.
+   // Only specify `scattered_axis` for ReduceScatter.
+   Communication(
+       IrBuilderPasskey passkey,
+       CommunicationType type,
+       TensorView* out,
+       TensorView* in,
+       Team team, // All devices involved in this communication. It must
+                  // include `root`. It can be a subset of `root`+`mesh` in
+                  // case of 2D sharding.
+       DeviceIdxType root = -1,
+       RedOpType red_op = RedOpType::UNUSED,
+       int64_t scattered_axis = -1);
+
+   Communication(const Communication& other) = delete;
+   Communication& operator=(const Communication& other) = delete;
+   Communication(Communication&& other) = delete;
+   Communication& operator=(Communication&& other) = delete;
+
+   NVFUSER_DECLARE_CLONE_AND_CREATE
+
+   std::string toString(int indent_size = 0) const override;
+   std::string toInlineString(int indent_size = 0) const override;
+   const char* getOpString() const override {
+     return "Communication";
+   }
+
+   CommunicationType type() const {
+     return attribute<CommunicationType>(0);
+   }
+
+   TensorView* out() const {
+     return output(0)->as<TensorView>();
+   }
+
+   TensorView* in() const {
+     return input(0)->as<TensorView>();
+   }
+
+   const Team& team() const {
+     return attribute<Team>(1);
+   }
+
+   // A convenience helper so the user doesn't need to convert size_t to
+   // int64_t.
+   int64_t team_size() const {
+     return static_cast<int64_t>(team().size());
+   }
+
+   DeviceIdxType root() const {
+     return attribute<DeviceIdxType>(2);
+   }
+
+   RedOpType reduceOp() const {
+     return attribute<RedOpType>(3);
+   }
+
+   int64_t scatteredAxis() const {
+     return attribute<int64_t>(4);
+   }
+
+   // PyTorch's process group expects the root to be specified as an integer
+   // between 0 and world_size-1. We choose it to be the device's relative
+   // index within the team.
+   int64_t getRootRelativeIndex();
+
+  private:
+   void validate();
+ };
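Per the comment on `getRootRelativeIndex`, the rank handed to the process group is the root's position within `team`. A minimal sketch of that lookup (an assumption about the out-of-line definition, which is not in this header; it presumes `Team` is a vector of `DeviceIdxType`):

```cpp
#include <algorithm>
#include <iterator>

// Assumed equivalent of Communication::getRootRelativeIndex() (not the
// shipped definition): find the root's position inside the team.
inline int64_t rootRelativeIndexSketch(
    const nvfuser::Team& team,
    nvfuser::DeviceIdxType root) {
  auto it = std::find(team.begin(), team.end(), root);
  return static_cast<int64_t>(std::distance(team.begin(), it));
}
```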
+
+ enum class P2PCommunicationType { SEND, RECV };
+
+ std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type);
+
+ class P2PCommunication : public Expr {
+  public:
+   using Expr::Expr;
+
+   P2PCommunication(
+       IrBuilderPasskey passkey,
+       P2PCommunicationType type,
+       TensorView* buffer,
+       Val* peer);
+
+   P2PCommunication(const P2PCommunication& other) = delete;
+   P2PCommunication& operator=(const P2PCommunication& other) = delete;
+   P2PCommunication(P2PCommunication&& other) = delete;
+   P2PCommunication& operator=(P2PCommunication&& other) = delete;
+
+   NVFUSER_DECLARE_CLONE_AND_CREATE
+
+   std::string toString(int indent_size = 0) const override;
+   std::string toInlineString(int indent_size = 0) const override;
+   const char* getOpString() const override {
+     return "P2PCommunication";
+   }
+
+   P2PCommunicationType type() const {
+     return attribute<P2PCommunicationType>(0);
+   }
+
+   TensorView* buffer() const {
+     return input(0)->as<TensorView>();
+   }
+
+   Val* peer() const {
+     return attributeVal(1);
+   }
+ };
+
+ // The method "post" triggers the execution of the communication. This call
+ // is non-blocking. The communication can be posted multiple times.
+ // It is assumed that the current device_index (given by
+ // communicator.deviceId()) belongs to the team of the communication,
+ // otherwise an error is thrown.
+ //
+ // NOTE: pytorch's NCCL process group API needs <team_size> buffers on root
+ // for scatter/gather operations.
+ // (*) Broadcast
+ //     Copies the root's src buffer to each device's dst buffer.
+ //     Requirements:
+ //       - the root is set and belongs to the team
+ //       - the root has one src buffer, and no or one dst buffer
+ //       - non-roots have no src buffer and one dst buffer
+ //       - all buffers have the same size
+ // (*) Gather
+ //     Copies each device's src buffer to the root's respective dst
+ //     buffer. The order of the sender devices matches the order of the
+ //     root's buffers.
+ //     Requirements:
+ //       - the root is set and belongs to the team
+ //       - the root has one src buffer and <team_size> dst buffers
+ //       - non-roots have one src buffer and no dst buffer
+ //       - all buffers have the same size
+ // (*) Allgather
+ //     Copies each device's src buffer to each device's respective dst
+ //     buffer. The order of the devices matches the order of the buffers.
+ //     Requirements:
+ //       - all devices have one src buffer and <team_size> dst buffers
+ //       - all buffers have the same size
+ // (*) Scatter
+ //     Copies each of the root's src buffers to each device's dst buffer.
+ //     The order of the buffers matches the order of the receiver devices.
+ //     Requirements:
+ //       - the root is set and belongs to the team
+ //       - the root has <team_size> src buffers and one dst buffer
+ //       - non-roots have no src buffer and one dst buffer
+ //       - all buffers have the same size
+ // (*) Reduce
+ //     Reduces the src buffers to the root's dst buffer.
+ //     Requirements:
+ //       - the root is set and belongs to the team
+ //       - the root has one src buffer and one dst buffer
+ //       - non-roots have one src buffer and no dst buffer
+ //       - all buffers have the same size
+ // (*) Allreduce
+ //     Reduces the src buffers to each device's dst buffer.
+ //     Requirements:
+ //       - all devices have one src buffer and one dst buffer
+ //       - all buffers have the same size
+ // (*) ReduceScatter
+ //     Reduces all the src buffers and shards the result to the dst buffers.
+ //     Requirements:
+ //       - all devices have <team_size> src buffers and one dst buffer
+ //       - all buffers have the same size
+ // (*) SendRecv
+ //     Copies the sender's src buffer to the receiver's dst buffer.
+ //     It is equivalent to a Broadcast with a team of size 2.
+ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
+     Communication* communication,
+     DeviceIdxType my_device_index,
+     c10d::Backend* backend,
+     at::Tensor input_tensor,
+     at::Tensor output_tensor);
+
+ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
+     P2PCommunication* communication,
+     DeviceIdxType my_device_index,
+     DeviceIdxType peer,
+     c10d::Backend* backend,
+     at::Tensor buffer);
+
+ } // namespace nvfuser
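A hypothetical call site for the first `postSingleCommunication` overload (all names here are stand-ins, not from the wheel). Posting is non-blocking and repeatable, and the calling device must belong to the communication's team:

```cpp
#include <multidevice/communication.h>

// Post a collective and block until it finishes (sketch, not shipped code).
c10::intrusive_ptr<c10d::Work> postAndWait(
    nvfuser::Communication* comm,
    nvfuser::DeviceIdxType my_device_index, // must belong to comm->team()
    c10d::Backend* backend,
    at::Tensor input,
    at::Tensor output) {
  auto work = nvfuser::postSingleCommunication(
      comm, my_device_index, backend, input, output);
  work->wait(); // posting is asynchronous; wait() blocks until completion
  return work;
}
```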