nvfuser_cu121_torch25-0.2.25.dev20250201-cp310-cp310-manylinux_2_28_x86_64.whl
- nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/mma_type.h
@@ -0,0 +1,257 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <macros.h>
+
+#include <exceptions.h>
+#include <type.h>
+#include <visibility.h>
+
+#include <cstring>
+#include <ostream>
+
+#include <cstdint>
+
+namespace nvfuser {
+
+constexpr std::string_view MATMUL_LOG_PREFIX = "[MATMUL DEBUG] ";
+
+//! Named descriptors of domains in matmul
+enum class MatmulDimRole { M = 0, N, K, Batch };
+
+std::string toString(MatmulDimRole role);
+
+//! Named descriptors of TensorView roles in fusion
+//! OPERAND_A - an input to the fusion that is a producer of a matmul "A" input
+//! OPERAND_B - an input to the fusion that is a producer of a matmul "B" input
+//! OUTPUT - fusion outputs that have the matmul as a dependency
+//! EPILOGUE_INPUT - an input to the fusion that is a producer of an
+//!   OUTPUT, but not of an MMA input
+//!
+//! Note: bias vector tensors will be assigned to the EPILOGUE_INPUT role.
+enum class MatmulTensorRole {
+  OPERAND_A = 0,
+  OPERAND_B,
+  OUTPUT,
+  EPILOGUE_INPUT
+};
+
+//! The expected number of occurrences of core TensorView roles in fusion
+static constexpr size_t MATMUL_CORE_ROLES_EXPECTED_COUNT = 1;
+
+//! Utility data structure for recording gemm tiles
+struct GemmTile {
+  int64_t m, n, k;
+  GemmTile(int64_t m_, int64_t n_, int64_t k_) : m(m_), n(n_), k(k_) {}
+
+  bool operator==(const GemmTile& other) const {
+    return m == other.m && n == other.n && k == other.k;
+  }
+
+  GemmTile operator/(const GemmTile& other) const {
+    return GemmTile(m / other.m, n / other.n, k / other.k);
+  }
+
+  std::vector<int64_t> toVector() const {
+    return {m, n, k};
+  }
+};
+
+//! Utility data structure for recording gemm tiles
+struct MatMulTileOptions {
+  GemmTile cta_tile = GemmTile(128, 128, 32);
+  GemmTile warp_tile = GemmTile(64, 64, 32);
+
+  MatMulTileOptions() = default;
+  MatMulTileOptions(GemmTile cta_tile_, GemmTile warp_tile_)
+      : cta_tile(cta_tile_), warp_tile(warp_tile_) {}
+
+  bool operator==(const MatMulTileOptions& other) const {
+    return cta_tile == other.cta_tile && warp_tile == other.warp_tile;
+  }
+};
+
+enum class MmaMacro : uint64_t;
+
+struct MmaMacroEncode {
+  enum class Arch : uint16_t { NoMma, Volta, Turing, Ampere, Hopper } arch;
+  uint16_t m;
+  uint16_t n;
+  uint16_t k;
+
+  constexpr operator uint64_t() {
+    return (uint64_t)arch << 48 | (uint64_t)m << 32 | (uint64_t)n << 16 |
+        (uint64_t)k;
+  }
+
+  constexpr operator MmaMacro() {
+    return static_cast<MmaMacro>(static_cast<uint64_t>(*this));
+  }
+
+  constexpr MmaMacroEncode(MmaMacro macro)
+      : arch(Arch(toUnderlying(macro) >> 48)),
+        m((toUnderlying(macro) >> 32) & 0xFFFF),
+        n((toUnderlying(macro) >> 16) & 0xFFFF),
+        k(toUnderlying(macro) & 0xFFFF) {}
+
+  constexpr MmaMacroEncode(Arch arch_, uint16_t m_, uint16_t n_, uint16_t k_)
+      : arch(arch_), m(m_), n(n_), k(k_) {}
+};
+
+static_assert(sizeof(MmaMacroEncode) == sizeof(uint64_t));
+
+//! Type of mma intrinsic macro to use
+//! This determines which mma intrinsic from the runtime string
+//! is generated to implement the mma op. The current plan
+//! is to have exactly one macro for each
+//! (arch, datatype, operand layout) triple, though there
+//! exist multiple possibilities for some cases, e.g. for Turing and fp16
+//! one can use 16_8_8 or 16_8_16.
+//! Will consider adding more choices that the scheduler can pick from
+//! when our perf target becomes more fine grained, which is more likely in
+//! latency bound kernels.
+
+#define MACRO(arch, m, n, k) \
+  arch##_##m##_##n##_##k = MmaMacroEncode(MmaMacroEncode::Arch::arch, m, n, k)
+
+enum class MmaMacro : uint64_t {
+  NoMMA = 0,
+
+  MACRO(Turing, 16, 8, 8),
+  MACRO(Turing, 16, 8, 16),
+  MACRO(Turing, 16, 16, 16),
+
+  MACRO(Ampere, 16, 8, 16),
+  MACRO(Ampere, 16, 16, 16),
+
+  MACRO(Hopper, 64, 8, 16),
+  MACRO(Hopper, 64, 16, 16),
+  MACRO(Hopper, 64, 24, 16),
+  MACRO(Hopper, 64, 32, 16),
+  MACRO(Hopper, 64, 40, 16),
+  MACRO(Hopper, 64, 48, 16),
+  MACRO(Hopper, 64, 56, 16),
+  MACRO(Hopper, 64, 64, 16),
+  MACRO(Hopper, 64, 72, 16),
+  MACRO(Hopper, 64, 80, 16),
+  MACRO(Hopper, 64, 88, 16),
+  MACRO(Hopper, 64, 96, 16),
+  MACRO(Hopper, 64, 104, 16),
+  MACRO(Hopper, 64, 112, 16),
+  MACRO(Hopper, 64, 120, 16),
+  MACRO(Hopper, 64, 128, 16),
+  MACRO(Hopper, 64, 136, 16),
+  MACRO(Hopper, 64, 144, 16),
+  MACRO(Hopper, 64, 152, 16),
+  MACRO(Hopper, 64, 160, 16),
+  MACRO(Hopper, 64, 168, 16),
+  MACRO(Hopper, 64, 176, 16),
+  MACRO(Hopper, 64, 184, 16),
+  MACRO(Hopper, 64, 192, 16),
+  MACRO(Hopper, 64, 200, 16),
+  MACRO(Hopper, 64, 208, 16),
+  MACRO(Hopper, 64, 216, 16),
+  MACRO(Hopper, 64, 224, 16),
+  MACRO(Hopper, 64, 232, 16),
+  MACRO(Hopper, 64, 240, 16),
+  MACRO(Hopper, 64, 248, 16),
+  MACRO(Hopper, 64, 256, 16),
+};
+
+#undef MACRO
+
+//! [Operand Layout Convention]
+//! Operand layout, T=transposed/row_major, N=normal/col_major
+//! Ordered by position of K
+//! NT : K,M x K,N -> M,N
+//! TT : M,K X K,N -> M,N
+//! TN : M,K X N,K -> M,N
+//! NN : K,M X N,K -> M,N
+enum class MmaLayout { NT = 0, TT, TN, NN };
+
+//! Indicates which dimension is innermost in the allocation domain of an
+//! operand
+enum class UnitDim { K, M_or_N };
+
+//! Utility to annotate which input of mma this option struct describes
+enum class MmaOperand { A, B };
+
+//! GPU arch check for macro type
+inline bool isTuring(MmaMacro macro) {
+  return MmaMacroEncode(macro).arch == MmaMacroEncode::Arch::Turing;
+}
+
+inline bool isAmpere(MmaMacro macro) {
+  return MmaMacroEncode(macro).arch == MmaMacroEncode::Arch::Ampere;
+}
+
+inline bool isHopper(MmaMacro macro) {
+  return MmaMacroEncode(macro).arch == MmaMacroEncode::Arch::Hopper;
+}
+
+//! Get the m size from macro type
+inline int64_t getM(MmaMacro macro) {
+  return MmaMacroEncode(macro).m;
+}
+
+//! Get the n size from macro type
+inline int64_t getN(MmaMacro macro) {
+  return MmaMacroEncode(macro).n;
+}
+
+//! Get the k size from macro type
+inline int64_t getK(MmaMacro macro) {
+  return MmaMacroEncode(macro).k;
+}
+
+// Unpacked constants from macro type:
+// exact numbers are defined by each individual instruction.
+int getOutputRegisterSize(MmaMacro macro);
+int getInputARegisterSize(MmaMacro macro);
+int getInputBRegisterSize(MmaMacro macro);
+
+// Unpack MMA op shape
+GemmTile getMmaOpShape(MmaMacro macro);
+
+// Warning: The values of the enum class must match the matrix descriptor as
+// specified in:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
+// Do not edit the values of the enum class unless you know what you are doing.
+enum class MmaInputSmemSwizzle {
+  None = 0,
+  B128 = 1,
+  B64 = 2,
+  B32 = 3,
+};
+
+constexpr int64_t core_matrix_width_bytes = 16;
+
+int64_t getBytesFromSwizzle(MmaInputSmemSwizzle swizzle);
+MmaInputSmemSwizzle getSwizzleFromBytes(int64_t bytes);
+
+// MMA stringify utils
+NVF_API std::string toString(MmaLayout input_layout);
+std::string toString(const GemmTile& tile);
+NVF_API std::string toString(const MatMulTileOptions& opts);
+NVF_API std::string toString(MmaMacro macro);
+NVF_API std::string toString(MmaInputSmemSwizzle swizzle);
+inline std::ostream& operator<<(
+    std::ostream& os,
+    MmaInputSmemSwizzle input_layout) {
+  os << toString(input_layout);
+  return os;
+}
+
+// MMA hash utils
+NVF_API size_t hash(MmaMacro macro);
+size_t hash(MmaLayout input_layout);
+size_t hash(const GemmTile& tile);
+NVF_API size_t hash(const MatMulTileOptions& opts);
+
+} // namespace nvfuser
nvfuser/include/nvfuser/multidevice/c10d_mock.h
@@ -0,0 +1,175 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <ATen/core/TensorBody.h>
+#include <ATen/core/ivalue.h>
+#include <c10/util/intrusive_ptr.h>
+
+namespace c10d {
+
+inline void setDebugLevelFromEnvironment() {}
+
+class Work : public torch::CustomClassHolder {
+ public:
+  void wait() {}
+};
+
+struct ReduceOp : torch::CustomClassHolder {
+  enum RedOpType {
+    SUM,
+    AVG,
+    PRODUCT,
+    MIN,
+    MAX,
+    BAND,
+    BOR,
+    BXOR,
+    UNUSED,
+  };
+
+  ReduceOp() = default;
+  ReduceOp(RedOpType op) : op_(op) {}
+
+  RedOpType op_ = UNUSED;
+};
+
+struct ReduceScatterOptions {
+  ReduceOp reduceOp = ReduceOp::UNUSED;
+};
+
+struct ScatterOptions {
+  int64_t rootRank = 0;
+};
+
+struct AllgatherOptions {};
+
+struct GatherOptions {
+  int64_t rootRank = 0;
+};
+
+struct BroadcastOptions {
+  int64_t rootRank = 0;
+};
+
+struct AllreduceOptions {
+  ReduceOp reduceOp = ReduceOp::UNUSED;
+};
+
+struct ReduceOptions {
+  ReduceOp reduceOp = ReduceOp::UNUSED;
+  int64_t rootRank = 0;
+};
+
+struct BarrierOptions {
+  std::vector<int64_t> device_ids;
+};
+
+class Backend : public torch::CustomClassHolder {
+ public:
+  void startCoalescing() {}
+
+  c10::intrusive_ptr<Work> endCoalescing() {
+    return c10::make_intrusive<Work>();
+  }
+
+  const std::string getBackendName() const {
+    return "";
+  }
+
+  c10::intrusive_ptr<Work> barrier(
+      const BarrierOptions& opts = BarrierOptions()) {
+    return c10::make_intrusive<Work>();
+  }
+
+  c10::intrusive_ptr<Work> send(
+      std::vector<at::Tensor>& tensors,
+      int dstRank,
+      int tag) {
+    return c10::make_intrusive<Work>();
+  }
+
+  c10::intrusive_ptr<Work> recv(
+      std::vector<at::Tensor>& tensors,
+      int srcRank,
+      int tag) {
+    return c10::make_intrusive<Work>();
+  }
+
+  c10::intrusive_ptr<Work> allgather(
+      std::vector<std::vector<at::Tensor>>& outputTensors,
+      std::vector<at::Tensor>& inputTensors,
+      const AllgatherOptions& opts = AllgatherOptions()) {
+    return c10::make_intrusive<Work>();
+  }
+
+  c10::intrusive_ptr<Work> _allgather_base(
+      at::Tensor& outputBuffer,
+      at::Tensor& inputBuffer,
+      const AllgatherOptions& opts = AllgatherOptions()) {
+    return c10::make_intrusive<Work>();
+  }
+
+  c10::intrusive_ptr<Work> gather(
+      std::vector<std::vector<at::Tensor>>& outputTensors,
+      std::vector<at::Tensor>& inputTensors,
+      const GatherOptions& opts = GatherOptions()) {
+    return c10::make_intrusive<Work>();
+  }
+
+  c10::intrusive_ptr<Work> reduce_scatter(
+      std::vector<at::Tensor>& outputTensors,
+      std::vector<std::vector<at::Tensor>>& inputTensors,
+      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
+    return c10::make_intrusive<Work>();
+  }
+
+  c10::intrusive_ptr<Work> _reduce_scatter_base(
+      at::Tensor& outputBuffer,
+      at::Tensor& inputBuffer,
+      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
+    return c10::make_intrusive<Work>();
+  }
+
+  c10::intrusive_ptr<Work> scatter(
+      std::vector<at::Tensor>& outputTensors,
+      std::vector<std::vector<at::Tensor>>& inputTensors,
+      const ScatterOptions& opts = ScatterOptions()) {
+    return c10::make_intrusive<Work>();
+  }
+
+  c10::intrusive_ptr<Work> broadcast(
+      std::vector<at::Tensor>& tensors,
+      const BroadcastOptions& opts = BroadcastOptions()) {
+    return c10::make_intrusive<Work>();
+  }
+
+  c10::intrusive_ptr<Work> allreduce(
+      std::vector<at::Tensor>& tensors,
+      const AllreduceOptions& opts = AllreduceOptions()) {
+    return c10::make_intrusive<Work>();
+  }
+
+  c10::intrusive_ptr<Work> reduce(
+      std::vector<at::Tensor>& tensors,
+      const ReduceOptions& opts = ReduceOptions()) {
+    return c10::make_intrusive<Work>();
+  }
+
+  int getSize() const {
+    return 0;
+  }
+};
+
+struct TCPStoreOptions {
+  static constexpr uint16_t kDefaultPort = 0;
+};
+
+class TCPStore : public torch::CustomClassHolder {};
+
+} // namespace c10d
nvfuser/include/nvfuser/multidevice/communication.h
@@ -0,0 +1,232 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <ir/base_nodes.h>
+#include <ir/builder.h>
+#include <ir/interface_nodes.h>
+#include <multidevice/communicator.h>
+#include <multidevice/device_mesh.h>
+#include <multidevice/multidevice.h>
+#ifdef NVFUSER_DISTRIBUTED
+#include <torch/csrc/distributed/c10d/Types.hpp>
+#else
+#include <multidevice/c10d_mock.h>
+#endif
+#include <type.h>
+#include <visibility.h>
+
+namespace nvfuser {
+
+enum class CommunicationType {
+  Gather,
+  Allgather,
+  Scatter,
+  Reduce,
+  Allreduce,
+  ReduceScatter,
+  Broadcast,
+  SendRecv
+};
+
+std::ostream& operator<<(std::ostream& os, const CommunicationType& type);
+
+using RedOpType = c10d::ReduceOp::RedOpType;
+
+// The class "Communication" represents an MPI-style communication
+// operation to be executed on the network. The base class
+// Communication should not be used directly but through its derived classes:
+// Broadcast, Gather, Scatter, Allgather, and SendRecv. Other collectives will
+// be added later.
+class Communication : public Expr {
+ public:
+  using Expr::Expr;
+  // Only specify `root` for types that have a root.
+  // Only specify `red_op` for reduction types.
+  // Only specify `scattered_axis` for ReduceScatter.
+  Communication(
+      IrBuilderPasskey passkey,
+      CommunicationType type,
+      TensorView* out,
+      TensorView* in,
+      Team team, // All devices involved in this communication. It must include
+                 // `root`. It can be a subset of `root`+`mesh` in case of 2D
+                 // sharding.
+      DeviceIdxType root = -1,
+      RedOpType red_op = RedOpType::UNUSED,
+      int64_t scattered_axis = -1);
+
+  Communication(const Communication& other) = delete;
+  Communication& operator=(const Communication& other) = delete;
+  Communication(Communication&& other) = delete;
+  Communication& operator=(Communication&& other) = delete;
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  std::string toString(int indent_size = 0) const override;
+  std::string toInlineString(int indent_size = 0) const override;
+  const char* getOpString() const override {
+    return "Communication";
+  }
+
+  CommunicationType type() const {
+    return attribute<CommunicationType>(0);
+  }
+
+  TensorView* out() const {
+    return output(0)->as<TensorView>();
+  }
+
+  TensorView* in() const {
+    return input(0)->as<TensorView>();
+  }
+
+  const Team& team() const {
+    return attribute<Team>(1);
+  }
+
+  // A convenience helper so the user doesn't need to convert size_t to
+  // int64_t.
+  int64_t team_size() const {
+    return static_cast<int64_t>(team().size());
+  }
+
+  DeviceIdxType root() const {
+    return attribute<DeviceIdxType>(2);
+  }
+
+  RedOpType reduceOp() const {
+    return attribute<RedOpType>(3);
+  }
+
+  int64_t scatteredAxis() const {
+    return attribute<int64_t>(4);
+  }
+
+  // PyTorch's process group expects the root to be specified
+  // as an integer between 0 and world_size-1. We choose it to be
+  // the device's relative index within the team.
+  int64_t getRootRelativeIndex();
+
+ private:
+  void validate();
+};
+
+enum class P2PCommunicationType { SEND, RECV };
+
+std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type);
+
+class P2PCommunication : public Expr {
+ public:
+  using Expr::Expr;
+
+  P2PCommunication(
+      IrBuilderPasskey passkey,
+      P2PCommunicationType type,
+      TensorView* buffer,
+      Val* peer);
+
+  P2PCommunication(const P2PCommunication& other) = delete;
+  P2PCommunication& operator=(const P2PCommunication& other) = delete;
+  P2PCommunication(P2PCommunication&& other) = delete;
+  P2PCommunication& operator=(P2PCommunication&& other) = delete;
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  std::string toString(int indent_size = 0) const override;
+  std::string toInlineString(int indent_size = 0) const override;
+  const char* getOpString() const override {
+    return "P2PCommunication";
+  }
+
+  P2PCommunicationType type() const {
+    return attribute<P2PCommunicationType>(0);
+  }
+
+  TensorView* buffer() const {
+    return input(0)->as<TensorView>();
+  }
+
+  Val* peer() const {
+    return attributeVal(1);
+  }
+};
+
+// The method "post" triggers the execution of the communication. This call is
+// non-blocking. The communication can be posted multiple times.
+// It is assumed that the current device_index (given by
+// communicator.deviceId()) belongs to the team of the communication,
+// otherwise an error is thrown.
+//
+// NOTE: PyTorch's NCCL process group API needs <team_size> buffers on root for
+// scatter/gather operations.
+// (*) Broadcast
+// Copies the root's src buffer to each device's dst buffer
+// Requirements:
+//   - the root is set and belongs to the team
+//   - the root has one src buffer, and no or one dst buffer
+//   - non-roots have no src buffer and one dst buffer
+//   - all buffers have the same size
+// (*) Gather
+// Copies each device's src buffer to the root's respective dst
+// buffer. The order of the sender devices matches the order of the
+// root's buffers.
+// Requirements:
+//   - the root is set and belongs to the team
+//   - the root has one src buffer and <team_size> dst buffers
+//   - non-roots have one src buffer and no dst buffer
+//   - all buffers have the same size
+// (*) Allgather
+// Copies each device's src buffer to each device's respective dst
+// buffer. The order of the devices matches the order of the
+// buffers.
+// Requirements:
+//   - all devices have one src buffer and <team_size> dst buffers
+//   - all buffers have the same size
+// (*) Scatter
+// Copies each of the root's src buffers to each device's dst buffer.
+// The order of the buffers matches the order of the receiver devices.
+// Requirements:
+//   - the root is set and belongs to the team
+//   - the root has <team_size> src buffers and one dst buffer
+//   - non-roots have no src buffer and one dst buffer
+//   - all buffers have the same size
+// (*) Reduce
+// Reduces the src buffers to the root's dst buffer.
+// Requirements:
+//   - the root is set and belongs to the team
+//   - the root has one src buffer and one dst buffer
+//   - non-roots have one src buffer and no dst buffer
+//   - all buffers have the same size
+// (*) Allreduce
+// Reduces the src buffers to the dst buffer.
+// Requirements:
+//   - all devices have one src buffer and one dst buffer
+//   - all buffers have the same size
+// (*) ReduceScatter
+// Reduces all the src buffers and shards the result to the dst buffers.
+// Requirements:
+//   - all devices have <team_size> src buffers and one dst buffer
+//   - all buffers have the same size
+// (*) SendRecv
+// Copies the sender's src buffer to the receiver's dst buffer.
+// It is equivalent to a Broadcast with a team of size == 2
+c10::intrusive_ptr<c10d::Work> postSingleCommunication(
+    Communication* communication,
+    DeviceIdxType my_device_index,
+    c10d::Backend* backend,
+    at::Tensor input_tensor,
+    at::Tensor output_tensor);
+
+c10::intrusive_ptr<c10d::Work> postSingleCommunication(
+    P2PCommunication* communication,
+    DeviceIdxType my_device_index,
+    c10d::Backend* backend,
+    DeviceIdxType peer,
+    at::Tensor buffer);
+
+} // namespace nvfuser
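postSingleCommunication is the bridge from the IR-level Communication node to an actual c10d collective on the calling rank: it is non-blocking, and the same node may be posted repeatedly, per the comment block above. A hedged sketch of the intended call pattern, assuming a Communication* already built via IrBuilder inside a fusion (the run helper and the Communicator::deviceId() usage follow the header's own comments but are illustrative, not part of the wheel):

    // Sketch only: assumes `comm` was created via IrBuilder inside a fusion
    // and `backend` was obtained from a Communicator for this process.
    void run(nvfuser::Communication* comm,
             nvfuser::Communicator& communicator,
             c10d::Backend* backend,
             at::Tensor in,
             at::Tensor out) {
      // Posting is non-blocking; the current device must belong to the
      // communication's team or an error is thrown.
      c10::intrusive_ptr<c10d::Work> work = nvfuser::postSingleCommunication(
          comm, communicator.deviceId(), backend, in, out);
      work->wait(); // block this rank until the collective completes
    }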