nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl
- nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h
@@ -0,0 +1,99 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <array>
#include <cstdint>
#include <memory>

namespace nvfuser {

namespace matmul_heuristic_plugin {

//! This is intended as a minimal interface for enabling matmul heuristics.
//! In order to plug in your own custom heuristic, create a dynamic library
//! defining a subclass of KernelConfig, overriding the `configure` method. This
//! class does not need to be exported from the dll, but you should export a
//! std::unique_ptr<KernelConfig> makeConfig() function that returns a
//! unique_ptr to an object of that type. The `configure` method will be called
//! on that object by nvfuser in order to fill the correct values in the class.
//!
//! If that library is located at /path/to/libfoo.so you can set
//! NVFUSER_MATMUL_HEURISTIC_PLUGIN=/path/to/libfoo.so to use the plugin to
//! determine matmul parameters automatically.

struct KernelConfig {
  //! This is the information available to the plugin to determine the kernel
  //! configuration.
  struct ProblemDescription {
    uint32_t m = -1;
    uint32_t n = -1;
    uint32_t k = -1;
    uint32_t batch_size = -1;
    //! Explicit integer mapping for layout
    enum class Layout {
      NN = 0,
      NT = 1,
      TN = 2,
      TT = 3,
    };
    Layout layout = Layout::TN;
    //! Precision is a string like HSH or TSS indicating input, compute, and
    //! accumulate precision where the letters are mapped to types using the
    //! following mapping:
    //!  B = Int8
    //!  I = Int32
    //!  Q = FP8 (E4M3)
    //!  R = FP8 (E5M2)
    //!  T = BFloat16
    //!  H = Float16
    //!  F = TensorFloat32
    //!  S = Float32
    //!  D = Float64
    //!  C = complex<float>
    //!  Z = complex<double>
    //! Note that some of these are not currently supported by nvFuser.
    const char* precision = "SSS";

    //! Supported vectorization of operands and epilogue inputs (bias)
    struct SupportedVectorization {
      uint8_t a = 16;
      uint8_t b = 16;
      uint8_t epilogue = 16;
    } supported_vec_size;
  } problem;

  using Tile = std::array<uint16_t, 3>;
  Tile cta_tile = {128, 128, 32};
  Tile warp_tile = {64, 64, 32};
  Tile instruction_tile = {16, 16, 16};
  Tile cluster_dims = {1, 1, 1};
  uint16_t splitk_factor = 1;
  uint8_t load_stages = 2;
  // The circular buffering prefetch distance will be set to
  //   load_stages - prefetch_gap
  uint8_t prefetch_gap = 1;
  uint8_t grid_swizzle_factor = 0;
  uint8_t cta_order = 0;
  bool circular_buffer_smem_read = true;
  bool async_gmem_load_operands = true;

 public:
  // This should be overridden to implement the actual heuristic logic
  virtual void configure() = 0;

  // This allows us to use a std::unique_ptr<KernelConfig> and call derived
  // classes' destructors on deletion.
  // See
  // https://clang.llvm.org/extra/clang-tidy/checks/cppcoreguidelines/virtual-class-destructor.html
  virtual ~KernelConfig() = default;
};

} // namespace matmul_heuristic_plugin

} // namespace nvfuser
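The header above documents the plugin contract: subclass KernelConfig in a shared library, override configure(), export a makeConfig() factory, and point NVFUSER_MATMUL_HEURISTIC_PLUGIN at the library. A minimal sketch of such a plugin follows; the class name, the tile choices, the include path, and the use of C linkage for the exported factory are illustrative assumptions rather than requirements stated in the header.

// my_matmul_heuristic_plugin.cpp -- illustrative sketch only, not part of the wheel.
// Build as a shared library, e.g.
//   g++ -shared -fPIC -I<wheel>/nvfuser/include/nvfuser my_matmul_heuristic_plugin.cpp -o libfoo.so
// then run with NVFUSER_MATMUL_HEURISTIC_PLUGIN=/path/to/libfoo.so
#include <memory>

#include <scheduler/matmul_heuristic_plugin_api.h> // include path is an assumption

namespace {

// The subclass does not need to be exported; only makeConfig() must be visible.
struct MyKernelConfig : nvfuser::matmul_heuristic_plugin::KernelConfig {
  void configure() override {
    // `problem` has already been filled in by nvFuser before this call.
    // Toy policy: bigger tiles for large problems, split-K for long K.
    if (problem.m >= 2048 && problem.n >= 2048) {
      cta_tile = {256, 128, 32};
      warp_tile = {128, 64, 32};
    }
    splitk_factor = (problem.k >= 8192) ? 2 : 1;
    load_stages = 4;
    prefetch_gap = 1;
    async_gmem_load_operands = true;
  }
};

} // namespace

// Factory looked up by nvFuser in the shared library. C linkage is used here so
// the symbol name is predictable; how the symbol is resolved is an assumption.
extern "C" std::unique_ptr<nvfuser::matmul_heuristic_plugin::KernelConfig>
makeConfig() {
  return std::make_unique<MyKernelConfig>();
}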
nvfuser/include/nvfuser/scheduler/matmul_utils.h
@@ -0,0 +1,54 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <exceptions.h>
#include <fusion.h>

namespace nvfuser {

class SchedulerRuntimeInfo;
class HeuristicDataCache;
class MatmulParams;

namespace matmul_utils {
//! An implementation of functionality that will prepare heuristics for fusion
//! that represents matmul. May return empty object if any of conditions are
//! not met.
//! TODO: Remove and only have public facing APIs through SchedulerEntry like
//! the other schedulers
std::unique_ptr<MatmulParams> getMatmulHeuristics(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicDataCache* data_cache = nullptr);

//! An implementation of compile time checks. Returns messasge if given fusion
//! does not represent matmul, otherwise an empty string is returned.
std::string getMatmulCompileTimeRejectReason(Fusion* fusion);

//! An implementation of runtime time checks. Returns messasge if given fusion
//! does not represent matmul, otherwise an empty string is returned.
std::string getMatmulRunTimeRejectReason(
    Fusion* fusion,
    HeuristicDataCache* data_cache,
    SchedulerRuntimeInfo& runtime_info);

//! This is a utility to determine whether we can use cp.async to load the
//! operands A and B. Heuristic plugins can use this to help them set
//! async_gmem_load_operands.
bool NVF_API isCpAsyncOperandLoadSupported(
    const MatmulParams* params,
    int64_t min_dtype_size);

// Move the broadcast axes to the left on the specified number of inner
// dimensions e.g. (when number_of_inner_pos == 3):
//   [... I0, B, I1] -> [... B, I0, I1]
// should probably be only used to order innermost mnk axes.
void moveInnerBroadcastLeft(TensorView* tv, int64_t number_of_inner_pos = 3);
} // namespace matmul_utils
} // namespace nvfuser
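A rough sketch of how the declarations above compose: run the compile-time and runtime reject checks first, then ask for heuristics. The wrapper function name and include paths below are hypothetical; only the three matmul_utils calls come from the header.

// Illustrative sketch only, not part of the wheel.
#include <memory>
#include <string>

#include <scheduler/matmul_heuristic.h> // for MatmulParams (assumed location)
#include <scheduler/matmul_utils.h>     // include path is an assumption

namespace nvfuser {

// Hypothetical helper: returns nullptr when the fusion is not a schedulable matmul.
std::unique_ptr<MatmulParams> tryGetMatmulParams(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicDataCache* data_cache = nullptr) {
  // Compile-time check: an empty string means the fusion looks like a matmul.
  if (!matmul_utils::getMatmulCompileTimeRejectReason(fusion).empty()) {
    return nullptr;
  }
  // Runtime check against the concrete inputs described by runtime_info.
  if (!matmul_utils::getMatmulRunTimeRejectReason(fusion, data_cache, runtime_info)
           .empty()) {
    return nullptr;
  }
  // May still return an empty pointer if some internal condition is not met.
  return matmul_utils::getMatmulHeuristics(fusion, runtime_info, data_cache);
}

} // namespace nvfuser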
nvfuser/include/nvfuser/scheduler/mma_utils.h
@@ -0,0 +1,500 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <exceptions.h>
#include <fusion.h>
#include <id_model/id_model.h>
#include <mma_type.h>
#include <scheduler/matmul_heuristic.h>
#include <scheduler/tools/abstract_tensor.h>
#include <val_graph.h>
#include <visibility.h>

#include <array>
#include <variant>
#include <vector>

namespace nvfuser {

namespace mma_utils {

//! Utilities in this namespace facilitates scheduling matmul kernels with
//! hierarchichal tiling specified in MatMulTileOptions.

//! A mapping from ValGroup pointers to MatmulDimRole. The ValGroups should
//! correspond to IterDomain groups from an IdModel's exact graph. This
using DimRolesMap = std::unordered_map<ValGroup, MatmulDimRole>;

//! Schedule utility for matmul prolog:
//!   Use all the threads on a CTA tile to load matmul operands
//!  into shared memory with the given vectorization word.
//! TODO:
//!  will need to add bank conflict removal swizzle in a follow up.
NVF_API void scheduleContiguousVectorLoad(
    TensorView* tv,
    MatMulTileOptions tile,
    int64_t vector_word,
    bool vectorize = true);

//! Schedule utility for mma output in matmul main loop:
//!  Realize the hierarchical tiling based on the given tiling options.
//! TODO: rewrite this one with makeTile
NVF_API void scheduleWarpTileWithReduction(
    TensorView* tv,
    MatMulTileOptions tile,
    MmaMacro macro);

//! Schedule utility for mma output in matmul main loop:
//!  Realize the hierarchical tiling based on the given tiling options
//!  on consumers of mma ops in epilog.
//! TODO: remove this one eventually.
NVF_API void scheduleWarpTileWithNoReduction(
    TensorView* tv,
    MatMulTileOptions tile,
    MmaMacro macro);

//! Lower level primitive spliting inner iterdomains into tiles:
//! Eg.
//!  A[B,I0,I1,I2] -> makeTile({1,2,3})
//! Gives A[B, I0o, I1o, I2o, I0i(1), I1i(2), I2i(3)]
void makeTile(TensorView* tv, const std::vector<int64_t>& tile_sizes);

//! The above call assumes the axes are [(B), M, N, K]. In this version, we
//! provide the dimension roles that are present for this tensor.
void makeTile(
    TensorView* tv,
    const GemmTile& tile_sizes,
    const std::vector<MatmulDimRole>& axis_roles);

//! We model each dimension of every tensor in the Fusion with ID roles
//! described by MatmulDimRole.
using AbstractMatmulTensor = TaggedAbstractTensor<MatmulDimRole>;

//! Abstract version of the above utility. Schedules the provided
//! AbstractMatmulTensor instead of a concrete TensorView.
void makeTile(
    AbstractMatmulTensor& canonicalized_abstract_tensor,
    const std::vector<int64_t>& tile_sizes);

//! Order the inner tile dimensions as the original order in
//! (maybe allocation) domain. Also putting broadcast domains on the left.
//! Eg. A[I0o,I1o,B2o,I0i,I1i,B2i] (maybe allocation domain: I1,B,I0)
//! -> A[I0o, I1o, B2o, B2i, I1i, I0i]
//! This is used to facilitate data layout swizzling and
//! defining vectorized loads.
void orderTiledConcreteIdAsMaybeAllocationDomain(TensorView* tv);

//! Orders the leaf ID canonically, and merges dims of the same role
//! The return value gives the role of each loop IterDomain in tv.
std::vector<MatmulDimRole> canonicalizeMmaTvOrdering(
    TensorView* tv,
    const ValGraph& broadcast_graph,
    const DimRolesMap& dim_roles,
    const std::vector<ValGroup>& ordering);

//! Given a TensorView matching the canonicalDimOrdering, schedule it by
//! merging dimensions with matching roles.
void mergeConsecutiveAxesWithSameRole(
    TensorView* tv,
    const DimRolesMap& dim_roles,
    const ValGraph* graph);

//! [MmaSwizzler]:
//!   This class is used to implement the thread swizzle format
//!     required for the mma macros, cf. PTX ISA 9.7.13.4.
//!
//!   The mma instructions (Volta and later arch) require specific
//!     thread mapping within a warp for both the mma inputs and
//!     mma outputs. All mma swizzle patterns seen so far turned out
//!     to be affine, so we could use the normal scheduler interface
//!     to fulfill the mma thread swizzle pattern. And fusion with
//!     other non-mma ops and validations can just natually rely on the current
//!     iterdomain infrastructure.
//!
//!   This is different from a normal scheduler utility though,
//!      as the thread mapping within a warp are *required* to be
//!      a specific pattern which currently translates to an enforced
//!      requirement that all the loop domains produced by MmaSwizzler
//!      cannot be further transformed (split/merge/reorder etc.).
//!
//!   Currently MmaSwizzler can be accessed by schedulers through
//!     TensorView::applyMmaSwizzle, and the current scheduling procedure is
//!     as follows:
//!
//!   Step 1. Before scheduling, the mma op needs to be configured with a macro
//!   type, either manually or inferred (eg. Ampere_16_8_8).
//!
//!   Step 2. Scheduler can tile the outer dimensions based on any heuristics,
//!   i.e. the CTA tiling, warp tiling, splitK etc.
//!
//!   Step 3. The scheduler will need to split the innermost part of the 3
//!   involved
//!   root dimensions, they need to be ordered as M,N,K on the rightmost of
//!   tensordomain (see [Operand Layout Convention] for exact definition).
//!
//!   For example before calling MmaSwizzler, the domain could look like:
//!   [TileM, TileN, TileK, Im(16), In(8), Rk(8)], to use Ampere_16_8_8.
//!   The rightmost 3 iterdomains need to be the innermost component of their
//!   corresponding root id, similar to vectorization except this requirement
//!   applies to all 3 rightmost dims.
//!
//!   Before applying swizzle, MmaSwizzler will try to validate:
//!     1. The "innermost-ness" of the rightmost 3 iterdomains. E.g:
//!        Xo, Xi = split(X, 16),
//!         Xo doesn't check, Xi would check.
//!     2. The rightmost three are constant sized, and they are ordered as
//!     M,N,K.
//!     In the case of operand schedule before the broadcast, only 2 of
//!     the axis are see, and they still need to follow the same order,
//!     i.e. need to be M,K or N,K.
//!     3. The rightmost three axes have matching size with the selected
//!     mma macro.
//!
//!   Step 4. MmaSwizzler will transform the rightmost 3 domains to the
//!   correct swizzle
//!   format and will parallelize the TIDx, which is reserved for lane id. The
//!   transformed inner iterdomains will be locked with WarpMapped tag so that
//!   they cannot be further transformed. Currently the only change that
//!   scheduler can still do after this step is to vectorize the innermost
//!   iterdomain.
//!
//! Notes:
//!   This version of implementation is trying to balance the composition
//!   flexibility and validation complexity. Currently the validation protocol
//!   is that if the rightmost 3 dimensions given to MmaSwizzler are indeed
//!   innermost components of the 3 root id's and their dimensions match the mma
//!   macro, the swizzle format produced by MmaSwizzler will be correct for
//!   the macro and we just lock the innermost iterdomains from further
//!   transformations.
//!
//!   Ninja users/schedulers might go for 2 cases that we currently don't
//!   support:
//!
//!   1. Equivalent affine transforms:
//!   Even though the mma swizzles are affine, there are still infinitely many
//!   equivalent ways to implement
//!   the same affine transform. E.g. io,ii = split(i,8); ioii =
//!   merge(io,ii); would make ioii equiv to i if it's a divisible split. One
//!   can use this to construct infinite many equivalent affine swizzles.
//!
//!   Users/schedulers might want to have a different but equivalent affine
//!   representation from the one provided
//!   by MmaSwizzler, but validating them needs some extra work
//!   canonicalizing the affine transforms. So short term wouldn't support
//!   this flexibility.
//!
//!   2. Swizzled data input:
//!   It is also possible that the data input has other swizzles before
//!   entering the fusion already and some might be natively compatible
//!   with mma format. This is a very broad category of use cases
//!   and we'd have to consider enabling any use like this case-by-case.
class MmaSwizzler {
 public:
  //! Applies the output mma swizzling to the given tv, should be used
  //!  on mma output or tv's involved in epilog fusion, i.e. bias.
  //! The rightmost iterdomains must follow the m,n,k convention before calling.
  static AbstractTensor scheduleMmaOutputAllocation(AbstractTensor t);

  //! Applies the input mma swizzling to the given tv as its allocation domain,
  //!  should be used on mma input or tv's involved in any fusion before mma, but
  //!  after smem read.
  //! The rightmost iterdomains must follow the m,n,k convention before calling.
  static void scheduleOperandRead(TensorView* tv, MmaOperand operand);
  static void scheduleOperandRead(TensorView* tv, MmaInputSmemSwizzle swizzle);

  //! Note [schedule of ldmatrix]
  //! If you look at the doc of ldmatrix and mma for Turing and Ampere:
  //! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-16816-float
  //! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
  //! you will find that, the memory layout of the output of ldmatrix, which
  //! matches with the input layout of MMA instruction, mismatch with the index
  //! that each thread uses to call ldmatrix. In nvFuser, we schedule the
  //! allocation domain of the ldmatrix output and mma inputs to be consistent
  //! with the memory layout of the output of ldmatrix, and we schedule the
  //! loop domain of the ldmatrix output to be consistent with the index that
  //! each thread uses to call ldmatrix. This function is used to schedule the
  //! loop domain of the ldmatrix output. The allocation domain of the ldmatrix
  //! output and mma inputs are scheduled in scheduleOperandRead, which must be
  //! called before this function.
  static void scheduleLdMatrix(TensorView* tv, MmaOperand operand);

  //! Function to schedule the load of the input operands of a
  //! Mma op. This internally calls swizzleTMABox. This function
  //! splits/tiles the inputs to correct 2D TMA boxes and calls the function
  //! above. Please note that we currently do not fully support not splitting
  //! the outer dimension. This only works when the inner-dimension is not
  //! split, that is the inner dim is less or equal to the swizzle size (in
  //! bytes). The outer dim here refers to the second ID from the end, so for
  //! the input [B, N, K], N would be outer. Broadcast is always moved
  //! outermost.
  static void scheduleTMALoadForMma(
      TensorView* tv,
      MmaInputSmemSwizzle swizzle);

  //! Parallelize all dims as bulk expect the first dims mentioned in the second
  //! param.
  static void parallelizeAsBulkSkippingFirstIDs(
      TensorView* tv,
      int64_t first_ids_to_skip);
};

//! Schedules the copy operation of output of a Mma op which resided in the
//! shared memory to global memory.
void scheduleTMAStoreForMmaOutput(TensorView* tv, MmaInputSmemSwizzle swizzle);

//! Schedules the copy operation of output of a Mma op which resided in the
//! registers to shared memory.
void scheduleStMatrixForMmaOutput(
    TensorView* tv,
    MmaInputSmemSwizzle swizzle,
    int64_t tile_m,
    int64_t tile_n);

void checkDimSize(
    TensorView* tv,
    std::vector<int64_t> axis,
    std::vector<int64_t> expect);

//! A constant with minimum number of fusion inputs that could be MMA inputs.
//! TODO: update for square matmuls where both inputs are the same tensor
constexpr size_t MIN_MATMUL_INPUTS_NUMBER = 2;

//! An alias for data structure for passing IterDomains representing problem
//! shape dimensions
//! TODO: extend definition for handling batch matmuls
using ProblemIterDomains = std::array<IterDomain*, 3>;

//! An alias for mapping between TensorView instance and its role in
//! matmul fusion definition, some roles can be assigned to more than
//! a single tv, for example input for beta scaling in epilogue
using TensorRolesMap =
    std::unordered_map<MatmulTensorRole, std::vector<TensorView*>>;

//! An alias for storing data types of the tensors in the mma op
//! the order is A, B, OUTPUT
using MmaDataTypes = std::array<DataType, 3>;

//! A wrapper for data containers with optional error message stored if
//! initialization of the data fails.
template <typename DataType>
class DataWrapperOpt {
 private:
  std::variant<std::string, DataType> data;

 public:
  DataWrapperOpt(std::string&& v) : data(std::move(v)) {}
  DataWrapperOpt(DataType&& v) : data(std::move(v)) {}

  bool isValid() const {
    return std::holds_alternative<DataType>(data);
  }
  DataType getData() const {
    return std::get<DataType>(data);
  }
  std::string getErrorMsg() const {
    if (data.valueless_by_exception() ||
        std::holds_alternative<std::string>(data)) {
      return "Uninitialized data in data holder object";
    } else {
      return std::get<std::string>(data);
    }
  }
};

//! This represents a single matmul operation, without a prologue or epilogue.
//! Each matmul has two inputs which might not be fusion inputs: A and B. It
//! also has one output, which can be Float or reduced precision. For MatmulOp
//! and LinearOp, the output is the same dtype as the inputs; so output does not
//! necessarily correspond to the output of a translated MmaOp and it might not
//! be a fusion output.
struct MatmulPattern {
  TensorView* A;
  TensorView* B;
  // This is not necessarily a Fusion output, but rather is the immediate output
  // representing a matmul in the current Fusion. The definition of this tensor
  // determines what kind of translation is needed, if any. Possible definition
  // Expr types are: MmaOp, ReductionOp (for mul-sum patterns), MatmulOp, and
  // LinearOp.
  TensorView* output;

  //! If the pattern is not already represented by an MmaOp, for example if
  //! there is a MatmulOp instead, this function modifies the fusion to insert
  //! an MmaOp. TensorViews A and B are unchanged, but this->output might be
  //! updated to reflect the replacement tensor.
  //!
  //! If avoid_intermediates is true, this function will use an
  //! MmaOp::AxisMapping instead of broadcasting and permuting axes, in order to
  //! avoid introducing unnecessary copies on Hopper and above.
  MmaOp* translateToMmaOp(bool avoid_intermediates = false);

  //! Given an IdModel, map groups of IterDomains to dimension roles
  //! (MatmulDimRole). Note that ValGroup is a shared_ptr to a
  //! VectorOfUniqueEntries<Val*>. We copy these as keys so that the returned
  //! object can safely outlive id_model.
  DimRolesMap getDimRoles(IdModel& id_model) const;

  std::string toString() const;
};

//! Traverse the fusion to find supported matmul patterns
std::vector<MatmulPattern> findMatmulPatterns(Fusion* fusion);

//! This is a vector of roles describing the inner dimension of each operand
using MatmulOperandInnerDims = std::vector<MatmulDimRole>;

using MatmulOperandInnerDimsOpt = DataWrapperOpt<MatmulOperandInnerDims>;
using ProblemIterDomainsOpt = DataWrapperOpt<ProblemIterDomains>;
using DimRolesMapOpt = DataWrapperOpt<DimRolesMap>;
using TensorRolesMapOpt = DataWrapperOpt<TensorRolesMap>;

using DomainsDesc = std::vector<MatmulDimRole>;
using DependenciesMap = std::map<TensorView*, DomainsDesc>;

//! Returns wrapped matmul input memory layout data, if supported, otherwise
//! returned object contains a message with failure root cause.
//!
//! Matmul layout depends only on fusion definition while mma layout relies on
//! HW implementation to handle input layout from fusion definition. Detailed
//! explanation:
//! - matmul layout which contains information about transposition of matmul
//!   inputs, it is based on the order of key domains (M,N K) in fusion input
//!   tensors,
//! - mma layout, some architectures (e.g. Hopper) support all combination of
//!   transposition of inputs in mma instructions, while other (e.g. Turing,
//!   Ampere) the only supported transposition is TN which means that mma
//!   instruction first input is transposed, the second input is non-transposed.
NVF_API MatmulOperandInnerDimsOpt getOperandInnerDims(
    const IdModel& id_model,
    const DimRolesMap& dim_roles,
    const TensorRolesMap& tensor_roles);

//! This version assumes the Fusion contains a single MatmulPattern, then builds
//! an IdModel and infers dim roles then calls the above function.
NVF_API MatmulOperandInnerDimsOpt getOperandInnerDims(Fusion* fusion);

//! Returns wrapped collection of TensorView roles in fusion.
//! An error message is stored in retruned object if valid data cannot
//! be gathered.
TensorRolesMapOpt getTensorRoles(
    Fusion* fusion,
    const IdModel& id_model,
    const DimRolesMap& dim_roles);

//! Return pair of whether use shared memory epilogue or not and whether to
//! reuse shared memory for the prologue at the expense of an additional block
//! sync.
//!
//! Returns true in first position if using shared memory epilogue won't cause
//! the decrease of occupancy ratio. The occupancy ratio is estimated using
//! register and shared memory usage. If ignore_occupancy_drop is set to true,
//! returns true if there is enough shared memory to launch the kernel without
//! considering the occupancy, useful for debug and validate shared memory
//! epilogue implementation.
//!
//! Returns true in the second position if reusing shared memory for the
//! epilogue does not increase occupancy.
std::pair<bool, bool> generateSharedMemoryEpilogueHeuristics(
    const MatMulTileOptions& gemm_tile,
    const int smem_circular_buffer_stage,
    const TensorRolesMap& tensor_roles,
    bool ignore_occupancy_drop = false);

//! This version assumes roles_map has been analyzed to determine smem datatypes
//! as well as guarantees about prologue smem reuse.
NVF_API std::pair<bool, bool> generateSharedMemoryEpilogueHeuristics(
    const MatMulTileOptions& gemm_tile,
    const int smem_circular_buffer_stage,
    const MmaDataTypes& data_types,
    bool smem_a_reuse_guaranteed = false,
    bool smem_b_reuse_guaranteed = false,
    bool ignore_occupancy_drop = false);

//! Compute the amount of shared memory we expect to need. The actual amount
//! allocated will be determined by aliasing (see alias_memory.cpp). This
//! function is useful for testing that we provide accurate information to our
//! heuristics.
int64_t computeExpectedSharedMemoryUsage(
    const MatmulParams* mparams,
    const MmaDataTypes& data_types,
    bool smem_a_reuse_guaranteed = false,
    bool smem_b_reuse_guaranteed = false);

//! Encode DataType as character using the following mapping (not all are
//! supported yet in nvFuser):
//!  B = Int8
//!  I = Int32
//!  Q = FP8 (E4M3)
//!  R = FP8 (E5M2)
//!  T = BFloat16
//!  H = Float16
//!  F = TensorFloat32
//!  S = Float32
//!  D = Float64
//!  C = complex<float>
//!  Z = complex<double>
char dtypeToChar(const DataType& dtype);

//! This function helps determine if ldmatrix requires a transpose.
bool isLdMatrixTranspose(const LoadStoreOp* ldst);

//! Get a total ordering of dimensions for known tensors. All dims of a
//! particular DimRole are adjacent in the output. We then set the order as
//! follows:
//!   1. Batch dimensions go first
//!   2. K dimensions are innermost
//!   3. M or N can be innermost, depending on the first output's allocation
//!   domain's innermost non-batch dimension.
//!   4. Within each DimRole, dims are ordered as follows:
//!     a. Batch, M, and N dimensions are ordered like the allocation domain of
//!     the first output
//!     b. K dimensions are ordered like the allocation domain of the first
//!     A operand
//!
//! NOTE: The broadcast graph is used for this so that we map broadcast
//! dimensions to non-broadcast.
// TODO: we might want more sophisticated ordering analysis for multi-dim role
// ordering to maximize vectorization across multiple tensors (rule 4)
std::vector<ValGroup> canonicalDimOrdering(
    const mma_utils::TensorRolesMap& tensor_roles,
    const mma_utils::DimRolesMap& dim_roles,
    const ValGraph& broadcast_graph);

//! Returns roles maps which have been merged across individual maps generated
//! by the provided matmul patterns.
//!
//! Returns std::nullopt if two patterns have incompatible roles
std::optional<std::pair<DimRolesMap, TensorRolesMap>> allPatternRoles(
    IdModel& id_model,
    const std::vector<MatmulPattern>& patterns);

// Utility to check concrete static size
inline void checkConcreteStaticDim(const AbstractId& abs_id) {
  IterDomain* id = representativeId(abs_id);
  NVF_ERROR(
      !id->isBroadcast() && !id->isReduction(),
      "no support for reduction or broadcast domains, but got ",
      id->toString());
  NVF_ERROR(
      id->extent()->isConstInt(),
      "swizzled dimension's extend must be known during scheduling, got ",
      id->toString());
}

//! Automatically generates the shared memory swizzled data layout for tma loads
//! in matmul mainloop. The shared memory data layout is always 2D currently.
//! This utility function assumes that the shared_mem_tv has the following
//! structure: [tile_row, tile_col]
//! Returns which swizzle format to use for mma inputs with tma loads.
MmaInputSmemSwizzle tmaSwizzleSharedMemory(TensorView* shared_mem_tv);

} // namespace mma_utils

std::string toString(const mma_utils::AbstractMatmulTensor& abten);

} // namespace nvfuser
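As a closing illustration of how the pattern and role utilities in mma_utils fit together, the sketch below finds matmul patterns, maps dimension roles, and queries operand inner dimensions. The helper name and include paths are hypothetical, and building an IdModel directly from the Fusion is an assumption about that class's constructor.

// Illustrative sketch only, not part of the wheel.
#include <string>
#include <vector>

#include <id_model/id_model.h>
#include <scheduler/mma_utils.h> // include paths relative to the nvfuser include dir (assumption)

namespace nvfuser {

// Hypothetical helper: returns an empty string on success, an error message otherwise.
std::string describeMatmul(Fusion* fusion) {
  // Find MmaOp / MatmulOp / LinearOp / mul-sum patterns in the fusion.
  std::vector<mma_utils::MatmulPattern> patterns =
      mma_utils::findMatmulPatterns(fusion);
  if (patterns.empty()) {
    return "no matmul pattern found";
  }

  // Map IterDomain groups to Batch/M/N/K roles for the first pattern.
  IdModel id_model(fusion); // assumes IdModel can be built straight from the Fusion
  mma_utils::DimRolesMap dim_roles = patterns.front().getDimRoles(id_model);
  (void)dim_roles;

  // The *Opt wrappers carry either the requested data or a failure message.
  mma_utils::MatmulOperandInnerDimsOpt inner_dims =
      mma_utils::getOperandInnerDims(fusion);
  if (!inner_dims.isValid()) {
    return inner_dims.getErrorMsg();
  }
  // inner_dims.getData() now lists the innermost dimension role of each operand.
  return "";
}

} // namespace nvfuser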