nvfuser-cu121-torch25 0.2.25.dev20250201__cp310-cp310-manylinux_2_28_x86_64.whl
Sign up to get free protection for your applications and to get access to all the features.
- nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,74 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <fusion.h>
|
11
|
+
#include <scheduler/matmul_heuristic.h>
|
12
|
+
#include <scheduler/mma_utils.h>
|
13
|
+
#include <val_graph.h>
|
14
|
+
#include <val_graph_visitor.h>
|
15
|
+
#include <visibility.h>
|
16
|
+
|
17
|
+
namespace nvfuser {
|
18
|
+
|
19
|
+
// Base class for AmpereMultipleMatmulScheduler and
// HopperMultipleMatmulScheduler
//
// Holds state shared by both subclasses: the fusion being scheduled, the
// matmul heuristic parameters, an IdModel over the fusion, and the matmul
// patterns plus tensor/dimension roles found in it. Subclasses implement
// run(); the protected helpers are defined out of line (not in this header).
//
// NOTE(review): this header ships next to a prebuilt libnvfuser_codegen.so,
// so member order/layout presumably must stay in sync with that binary.
class MultipleMatmulScheduler {
 public:
  // Stores `fusion` and `params` as non-owning raw pointers. The IdModel is
  // constructed with build_graphs=false, so no graphs are built here.
  MultipleMatmulScheduler(Fusion* fusion, const MatmulParams* params)
      : fusion_(fusion),
        params_(params),
        id_model_(fusion, /*build_graphs=*/false) {}
  virtual ~MultipleMatmulScheduler() = default;

  // Performs the scheduling; each subclass supplies the implementation.
  virtual void run() = 0;

 protected:
  // NOTE(review): presumably populates patterns_ — confirm in the definition.
  void findPatterns();

  // NOTE(review): behavior not visible in this header; see the definition.
  void translatePatterns();

  // Get tensor roles and id roles
  // When there are multiple matmul patterns, we can have conflicting roles.
  // For now we throw an error if this is the case.
  // TODO: This should be checked in canScheduleCompileTime
  void findRoles();

  // NOTE(review): presumably fills the num_*_dims_ counters below — confirm.
  void countDims();

  //! Rebuilds IdModel, then updates all ValGroups in abstract tensors to refer
  //! to the new IdModel. This is necessary whenever we perform an operation
  //! that creates a new TensorView, such as caching or rFactor
  void updateIdModel();

 protected:
  // Fusion being scheduled (not owned).
  Fusion* fusion_;
  // Matmul heuristic parameters (not owned).
  const MatmulParams* params_;
  // Built without graphs by the constructor.
  IdModel id_model_;

  // Broadcast graph of id_model_, which we modify at times using e.g.
  // AbstractTensor.split or by mapping vals in cacheAfter and rFactor
  ValGraph* graph_ = nullptr;
  std::vector<mma_utils::MatmulPattern> patterns_;
  mma_utils::DimRolesMap id_roles_;
  mma_utils::TensorRolesMap tensor_roles_;
  mma_utils::MatmulOperandInnerDims inner_dims_;

  // Dimension counters; all start at 0.
  int64_t num_splitk_dims_ = 0;
  int64_t num_device_dims_ = 0;
  int64_t num_local_batch_dims_ = 0;
  int64_t num_device_and_batch_dims_ = 0;

  // NOTE(review): presumably the A operands, B operands and MMA results of
  // the discovered patterns — confirm in the .cpp.
  std::vector<TensorView*> as_, bs_, mma_results_;
};
|
69
|
+
|
70
|
+
// Schedules `fusion` using the matmul heuristic parameters `mparams`.
// Exported from the shared library via NVF_API; the definition is not in this
// header. NOTE(review): presumably dispatches to an architecture-specific
// MultipleMatmulScheduler subclass — confirm in the definition.
NVF_API void scheduleMultipleMatmuls(
    Fusion* fusion,
    const MatmulParams* mparams);
|
73
|
+
|
74
|
+
} // namespace nvfuser
|
@@ -0,0 +1,48 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <scheduler/heuristic.h>
|
11
|
+
#include <scheduler/registry.h>
|
12
|
+
|
13
|
+
namespace nvfuser {
|
14
|
+
|
15
|
+
class Fusion;
|
16
|
+
class SchedulerRuntimeInfo;
|
17
|
+
class HeuristicDataCache;
|
18
|
+
|
19
|
+
//! NoOp scheduler represents the case where scheduler will
|
20
|
+
//! not do any scheduling operations and forward the un-scheduled
|
21
|
+
//! fusion directly to code generation and kernel compilation.
|
22
|
+
//!
|
23
|
+
//! Typical use case of this scheduler is to handle edge cases
|
24
|
+
//! such as where all tensors are size-1 or size-0.
|
25
|
+
|
26
|
+
class NoOpScheduler : public SchedulerEntry {
|
27
|
+
public:
|
28
|
+
//! Check if the no-op heuristics apply in given fusion
|
29
|
+
bool canScheduleCompileTime(Fusion* fusion) override;
|
30
|
+
|
31
|
+
bool canScheduleRunTime(
|
32
|
+
Fusion* fusion,
|
33
|
+
SchedulerRuntimeInfo& runtime_info,
|
34
|
+
HeuristicDataCache* data_cache = nullptr) override;
|
35
|
+
|
36
|
+
std::unique_ptr<HeuristicParams> computeHeuristics(
|
37
|
+
Fusion* fusion,
|
38
|
+
SchedulerRuntimeInfo& runtime_info,
|
39
|
+
HeuristicDataCache* data_cache) override;
|
40
|
+
|
41
|
+
void schedule(Fusion* fusion, const HeuristicParams* params) override;
|
42
|
+
|
43
|
+
constexpr static SchedulerType schedulerType() {
|
44
|
+
return SchedulerType::NoOp;
|
45
|
+
}
|
46
|
+
};
|
47
|
+
|
48
|
+
} // namespace nvfuser
|
@@ -0,0 +1,49 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <ATen/core/ivalue.h>
|
11
|
+
#include <exceptions.h>
|
12
|
+
#include <fusion.h>
|
13
|
+
#include <scheduler/reduction_heuristic.h>
|
14
|
+
#include <scheduler/registry.h>
|
15
|
+
#include <scheduler/utils.h>
|
16
|
+
#include <visibility.h>
|
17
|
+
|
18
|
+
// TODO: If caching inputs would require persistence we are sending it to the
|
19
|
+
// persistent kernel scheduler. This isn't necessary if the only persistent
|
20
|
+
// buffers are inputs as we could re-read them from global memory. Need to
|
21
|
+
// consider if this is worth implementing.
|
22
|
+
|
23
|
+
namespace nvfuser {
|
24
|
+
|
25
|
+
class SchedulerRuntimeInfo;
|
26
|
+
class HeuristicDataCache;
|
27
|
+
|
28
|
+
class InnerPersistentKernelScheduler : public SchedulerEntry {
|
29
|
+
public:
|
30
|
+
bool canScheduleCompileTime(Fusion* fusion) override;
|
31
|
+
|
32
|
+
bool canScheduleRunTime(
|
33
|
+
Fusion* fusion,
|
34
|
+
SchedulerRuntimeInfo& runtime_info,
|
35
|
+
HeuristicDataCache* data_cache = nullptr) override;
|
36
|
+
|
37
|
+
std::unique_ptr<HeuristicParams> computeHeuristics(
|
38
|
+
Fusion* fusion,
|
39
|
+
SchedulerRuntimeInfo& runtime_info,
|
40
|
+
HeuristicDataCache* data_cache) override;
|
41
|
+
|
42
|
+
void schedule(Fusion* fusion, const HeuristicParams* params) override;
|
43
|
+
|
44
|
+
constexpr static SchedulerType schedulerType() {
|
45
|
+
return SchedulerType::InnerPersistent;
|
46
|
+
}
|
47
|
+
};
|
48
|
+
|
49
|
+
} // namespace nvfuser
|
@@ -0,0 +1,51 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <ATen/core/ivalue.h>
|
11
|
+
#include <exceptions.h>
|
12
|
+
#include <fusion.h>
|
13
|
+
#include <scheduler/reduction_heuristic.h>
|
14
|
+
#include <scheduler/registry.h>
|
15
|
+
#include <scheduler/utils.h>
|
16
|
+
|
17
|
+
// TODO: If caching inputs would require persistence we are sending it to the
|
18
|
+
// persistent kernel scheduler. This isn't necessary if the only persistent
|
19
|
+
// buffers are inputs as we could re-read them from global memory. Need to
|
20
|
+
// consider if this is worth implementing.
|
21
|
+
|
22
|
+
namespace nvfuser {
|
23
|
+
|
24
|
+
class SchedulerRuntimeInfo;
|
25
|
+
class HeuristicDataCache;
|
26
|
+
|
27
|
+
class InnerOuterPersistentKernelScheduler : public SchedulerEntry {
|
28
|
+
public:
|
29
|
+
constexpr static int64_t threads_per_block_min = 128l;
|
30
|
+
constexpr static int64_t threads_per_block_max = 512l;
|
31
|
+
|
32
|
+
void schedule(Fusion* fusion, const HeuristicParams* params) override;
|
33
|
+
|
34
|
+
bool canScheduleCompileTime(Fusion* fusion) override;
|
35
|
+
|
36
|
+
bool canScheduleRunTime(
|
37
|
+
Fusion* fusion,
|
38
|
+
SchedulerRuntimeInfo& runtime_info,
|
39
|
+
HeuristicDataCache* data_cache = nullptr) override;
|
40
|
+
|
41
|
+
constexpr static SchedulerType schedulerType() {
|
42
|
+
return SchedulerType::InnerOuterPersistent;
|
43
|
+
}
|
44
|
+
|
45
|
+
std::unique_ptr<HeuristicParams> computeHeuristics(
|
46
|
+
Fusion* fusion,
|
47
|
+
SchedulerRuntimeInfo& runtime_info,
|
48
|
+
HeuristicDataCache* data_cache) override;
|
49
|
+
};
|
50
|
+
|
51
|
+
} // namespace nvfuser
|
@@ -0,0 +1,48 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <ATen/core/ivalue.h>
|
11
|
+
#include <exceptions.h>
|
12
|
+
#include <fusion.h>
|
13
|
+
#include <scheduler/reduction_heuristic.h>
|
14
|
+
#include <scheduler/registry.h>
|
15
|
+
#include <scheduler/utils.h>
|
16
|
+
#include <visibility.h>
|
17
|
+
|
18
|
+
// TODO: If caching inputs would require persistence we are sending it to the
|
19
|
+
// persistent kernel scheduler. This isn't necessary if the only persistent
|
20
|
+
// buffers are inputs as we could re-read them from global memory. Need to
|
21
|
+
// consider if this is worth implementing.
|
22
|
+
|
23
|
+
namespace nvfuser {
|
24
|
+
|
25
|
+
class SchedulerRuntimeInfo;
|
26
|
+
class HeuristicDataCache;
|
27
|
+
|
28
|
+
class OuterPersistentKernelScheduler : public SchedulerEntry {
|
29
|
+
public:
|
30
|
+
bool canScheduleCompileTime(Fusion* fusion) override;
|
31
|
+
|
32
|
+
bool canScheduleRunTime(
|
33
|
+
Fusion* fusion,
|
34
|
+
SchedulerRuntimeInfo& runtime_info,
|
35
|
+
HeuristicDataCache* data_cache = nullptr) override;
|
36
|
+
|
37
|
+
std::unique_ptr<HeuristicParams> computeHeuristics(
|
38
|
+
Fusion* fusion,
|
39
|
+
SchedulerRuntimeInfo& runtime_info,
|
40
|
+
HeuristicDataCache* data_cache) override;
|
41
|
+
|
42
|
+
void schedule(Fusion* fusion, const HeuristicParams* params) override;
|
43
|
+
|
44
|
+
constexpr static SchedulerType schedulerType() {
|
45
|
+
return SchedulerType::OuterPersistent;
|
46
|
+
}
|
47
|
+
};
|
48
|
+
} // namespace nvfuser
|
@@ -0,0 +1,379 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <exceptions.h>
#include <ir/all_nodes.h>
#include <runtime/executor_params.h>
#include <scheduler/reduction_utils.h>
#include <scheduler/scheduler_types.h>
#include <scheduler/utils.h>
#include <cmath>
#include <cstdint>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
|
20
|
+
|
21
|
+
namespace nvfuser {
|
22
|
+
class SchedulerRuntimeInfo;
|
23
|
+
class HeuristicDataCache;
|
24
|
+
|
25
|
+
namespace normalization_scheduler_utils {
|
26
|
+
|
27
|
+
//! Utility class to iterate candidates of launch configurations in a
|
28
|
+
//! preferred order. The iteration order is defined as:
|
29
|
+
//!
|
30
|
+
//! for bdimx in all valid bdimx in decreasing order
|
31
|
+
//! for gdimy in valid gdimy values in increasing order
|
32
|
+
//!
|
33
|
+
//! Each of bdimx and gdimy determines bdimy and gdimx, respectively,
|
34
|
+
//! such that the number of threads per block is always 256 and the
|
35
|
+
//! number of blocks is always equal to the number of SMs.
|
36
|
+
class PreferredLaunchConfig {
 public:
  //! Minimum blockDim.x.
  static constexpr int kMinBdimx = 8;
  //! Maximum blockDim.x.
  static constexpr int kMaxBdimx = 16;

  //! Defined out of line (not in this header).
  PreferredLaunchConfig();

  //! Current blockDim.x.
  int bdimx() const {
    return bdimx_;
  }

  //! Current blockDim.y.
  int bdimy() const {
    return bdimy_;
  }

  //! gridDim.x of the current grid configuration. Uses std::vector::at, so
  //! it throws std::out_of_range if grid_dims_pos_ is not a valid index.
  int gdimx() const {
    return gdimxAt(grid_dims_pos_);
  }

  //! gridDim.y of the current grid configuration; same precondition as
  //! gdimx().
  int gdimy() const {
    return gdimyAt(grid_dims_pos_);
  }

  //! Peek the next gdimx. -1 is returned if no further gdimx is available.
  int peekNextGdimx() const;

  //! Peek the next gdimy. -1 is returned if no further gdimy is available.
  int peekNextGdimy() const;

  //! Move to the next launch configuration. Will be marked as invalid
  //! if no valid configuration exists. Return true if successfully moved.
  bool moveToNextConfig();

  //! Try setting blockDim to the next valid config if
  //! available. Return false if no valid config exists. gridDim is
  //! reset.
  bool moveToNextBdim();

  //! Query if the next configuration will cause blockDim.x to become
  //! smaller.
  bool isNextSmallerBdimx() const;

  //! Query if blockDim.x can be further lowered
  bool canLowerBdimx() const;

  //! Query if no valid configuration is found
  bool isInvalid() const {
    return !valid_;
  }

 private:
  //! Populate the list of valid gridDim configurations
  void initValidGdims();

  //! gridDim.x stored at index `pos` of valid_grid_dims_ (bounds-checked).
  int gdimxAt(int pos) const {
    return valid_grid_dims_.at(pos).first;
  }

  //! gridDim.y stored at index `pos` of valid_grid_dims_ (bounds-checked).
  int gdimyAt(int pos) const {
    return valid_grid_dims_.at(pos).second;
  }

  //! Set blockDim.x and in turn blockDim.y. Return true if the
  //! specified blockDim.x is successfully set. If dry_run is true,
  //! just check if the given config is valid but do not modify the
  //! current config.
  bool setBdimx(int bdimx, bool dry_run = false);

  //! Rewind to the first entry of valid_grid_dims_.
  void resetGdim() {
    grid_dims_pos_ = 0;
  }

  void resetBdim() {
    // Start with the maximum bdimx and lower it until satisfactory
    // config is found
    setBdimx(kMaxBdimx);
  }

  //! Try setting gridDim to the next valid config if
  //! available. Return false if no valid config exists
  bool moveToNextGdim();

  int getNextGdimsPos() const;

  void invalidate() {
    valid_ = false;
  }

  //! Prints the current launch configuration. `cfg` is taken by const
  //! reference so printing does not copy the object (including its
  //! valid_grid_dims_ vector). Because gdimx()/gdimy() use
  //! std::vector::at, printing a config whose current grid position is out
  //! of range throws std::out_of_range.
  friend std::ostream& operator<<(
      std::ostream& os,
      const PreferredLaunchConfig& cfg) {
    os << "{gdimx: " << cfg.gdimx() << ", gdimy: " << cfg.gdimy()
       << ", bdimx: " << cfg.bdimx() << ", bdimy: " << cfg.bdimy() << "}";
    return os;
  }

 private:
  //! Remember if it is still a valid configuration
  bool valid_ = false;

  //! List of valid gridDims ordered by the dimension of
  //! gridDim.x. Larger gridDim.x is preferred as it would promote
  //! larger independent parallelism
  std::vector<std::pair<int, int>> valid_grid_dims_;
  //! The offset of the current gridDim in valid_grid_dims_
  int grid_dims_pos_ = 0;

  //! Current blockDim.x
  int bdimx_ = 0;
  //! Current blockDim.y
  int bdimy_ = 0;
};
|
148
|
+
|
149
|
+
//! Scheduling parameters for grid outer normalization, as returned (wrapped
//! in std::optional) by getGridOuterNormalizationParams.
//! NOTE(review): both factors default to -1, presumably meaning "not set" —
//! confirm against the producer in the .cpp.
struct GridOuterNormalizationParams {
  //! Launch configuration for the kernel.
  LaunchParams launch_params;
  //! Persistent buffer factor; -1 by default.
  int64_t persistent_buffer_factor = -1;
  //! Unswitch factor; -1 by default.
  int64_t unswitch_factor = -1;
};
|
155
|
+
|
156
|
+
//! Computes grid outer normalization scheduling parameters for the given
//! reduction/iteration sizes, vectorization factor and persistent buffer
//! size. NOTE(review): the std::optional return presumably signals that no
//! suitable configuration was found — confirm in the definition.
std::optional<GridOuterNormalizationParams> getGridOuterNormalizationParams(
    int64_t total_reduction_numel,
    int64_t total_iteration_numel,
    int64_t vectorize_factor,
    int64_t persistent_buffer_size);
|
161
|
+
|
162
|
+
//! Checks the iter type of each domain in the inner and outer reduction tvs:
//! inner reduction must be [I,I,...R,R]
//! outer reduction must be [R,R,...I,I]
bool checkIfReductionsAreInnerOuter(
    const std::vector<TensorView*>& inner_reduction_tvs,
    const std::vector<TensorView*>& outer_reduction_tvs);

//! Checks whether the inner reduction shares an input with the outer
//! reduction.
bool hasSharedInput(
    const std::vector<TensorView*>& inner_reduction_tvs,
    const std::vector<TensorView*>& outer_reduction_tvs);

//! The first part of the outer reduction is computed together with the inner
//! reduction, and the second part is scheduled separately. Therefore:
//! (1) the outer reduction tvs may only be connected to the inner reduction
//! tvs through their producers, and (2) because the outer reduction tvs are
//! also scheduled separately, they may likewise only be connected through
//! their producers.
bool isConnectedOnlyThroughReductionProducer(
    const std::vector<TensorView*>& inner_reduction_tvs,
    const std::vector<TensorView*>& outer_reduction_tvs);

// Returns true if every iteration domain in the inner reduction tv is a
// reduction domain in the outer reduction tv.
bool isReductionIterationAxisMatched(
    const std::vector<TensorView*>& inner_reduction_tvs,
    const std::vector<TensorView*>& outer_reduction_tvs);
|
188
|
+
|
189
|
+
//! In combined_inner_outer_reduction, the partial results of the outer
//! reductions must be persistent. Returns the size of these buffers, for use
//! when estimating register usage.
//! NOTE(review): the unit (bytes vs. elements) is not visible here — confirm.
int64_t partialReductionBufferSize(
    const std::vector<TensorView*>& outer_reduction_tvs,
    SchedulerRuntimeInfo& runtime_info);
|
195
|
+
|
196
|
+
// Return a SchedulerType (the persistent scheduler to use) based on the
// reduction type.
using ReductionType = reduction_scheduler_utils::ReductionType;
SchedulerType getPersistentHeuristicFor(ReductionType reduction_type);
|
199
|
+
|
200
|
+
struct PersistentKernelProperties {
|
201
|
+
int64_t inner_most_dimension_numel;
|
202
|
+
int64_t total_reduction_numel;
|
203
|
+
int64_t total_iteration_numel;
|
204
|
+
int64_t max_persistent_buffer_size;
|
205
|
+
int64_t n_tensor_inputs;
|
206
|
+
int64_t max_dtype_size;
|
207
|
+
int64_t vectorize_factor;
|
208
|
+
bool project_persistent_buffers;
|
209
|
+
PrimDataType index_type;
|
210
|
+
bool has_exp_op;
|
211
|
+
bool has_rng_op;
|
212
|
+
bool disable_project_to_avoid_recompute;
|
213
|
+
std::vector<TensorView*> persistent_buffers;
|
214
|
+
std::string toString() const {
|
215
|
+
std::stringstream ss;
|
216
|
+
ss << "===== Persistent Kernel Properties ========\n"
|
217
|
+
<< "inner_most_dimension_numel: " << inner_most_dimension_numel << "\n"
|
218
|
+
<< "total_reduction_numel: " << total_reduction_numel << "\n"
|
219
|
+
<< "total_iteration_numel: " << total_iteration_numel << "\n"
|
220
|
+
<< "max_persistent_buffer_size: " << max_persistent_buffer_size << "\n"
|
221
|
+
<< "n_tensor_inputs: " << n_tensor_inputs << "\n"
|
222
|
+
<< "max_input_dtype_size: " << max_dtype_size << "\n"
|
223
|
+
<< "max allowed vectorize_factor: " << vectorize_factor << "\n"
|
224
|
+
<< "disable_project_to_avoid_recompute: "
|
225
|
+
<< disable_project_to_avoid_recompute << "\n"
|
226
|
+
<< "project_persistent_buffers: " << project_persistent_buffers << "\n";
|
227
|
+
return ss.str();
|
228
|
+
}
|
229
|
+
};
|
230
|
+
// Gathers the PersistentKernelProperties of `fusion` for the persistent
// scheduler type `heuristic`, using `runtime_info` and `data_cache`.
PersistentKernelProperties getPersistentKernelProperties(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicDataCache* data_cache,
    SchedulerType heuristic);
|
235
|
+
|
236
|
+
// Verify the presence of a reduction TensorView connected to a Fusion input
|
237
|
+
void checkReductionTvForScheduling(Fusion* fusion, TensorView* ref_red_tv);
|
238
|
+
|
239
|
+
// Check the operations and input tensors of the fusion. This
|
240
|
+
// verification is a common step shared by all persistent kernel implementations
|
241
|
+
// during compile-time checks.
|
242
|
+
bool checkOpsAndInputs(Fusion* fusion, SchedulerType scheduler_type);
|
243
|
+
|
244
|
+
// Returns true if the reduction pattern is consistent. For the
|
245
|
+
// InnerPersistentKernelScheduler and OuterPersistentKernelScheduler, a single
|
246
|
+
// vector of TensorViews is provided, while for the
|
247
|
+
// InnerOuterPersistentKernelScheduler, two vectors of TensorViews are provided.
|
248
|
+
bool checkReductionPattern(
|
249
|
+
Fusion* fusion,
|
250
|
+
SchedulerType scheduler_type,
|
251
|
+
const std::vector<TensorView*>& reduction_tvs1,
|
252
|
+
const std::vector<TensorView*>& reduction_tvs2 = {});
|
253
|
+
|
254
|
+
// The compile-time checks for both the InnerPersistentKernelScheduler and
|
255
|
+
// OuterPersistentKernelScheduler are identical. These checks are constructed
|
256
|
+
// using checkOpsAndInputs, checkReductionPattern, and checkViewBufferTopology.
|
257
|
+
bool compileTimeCheck(Fusion* fusion, SchedulerType scheduler_type);
|
258
|
+
|
259
|
+
// Common preparations before the actual schedule, used by all persistent
|
260
|
+
// schedulers. Write to dummy_outputs, cached_inputs, reduction_tvs, and
|
261
|
+
// cached_outputs.
|
262
|
+
void beforeSchedule(
|
263
|
+
Fusion* fusion,
|
264
|
+
const ReductionParams* rparams,
|
265
|
+
std::vector<TensorView*>& dummy_outputs,
|
266
|
+
std::vector<TensorView*>& cached_inputs,
|
267
|
+
std::vector<TensorView*>& reduction_tvs,
|
268
|
+
std::vector<TensorView*>& smem_consumers,
|
269
|
+
std::vector<std::pair<TensorView*, TensorView*>>& cached_outputs);
|
270
|
+
|
271
|
+
// schedule a reduction tv, used by all persistent schedulers.
|
272
|
+
// will group reduction ops for OuterPersistentKernelScheduler with multiple
|
273
|
+
// reduction tvs.
|
274
|
+
TensorView* scheduleReductionGeneral(
|
275
|
+
Fusion* fusion,
|
276
|
+
const ReductionParams* rparams,
|
277
|
+
std::vector<TensorView*>& reduction_tvs,
|
278
|
+
SchedulerType scheduler_type);
|
279
|
+
|
280
|
+
// Used by InnerPersistentKernelScheduler and OuterPersistentKernelScheduler
|
281
|
+
void schedulePersistentKernel(
|
282
|
+
Fusion* fusion,
|
283
|
+
const ReductionParams* rparams,
|
284
|
+
SchedulerType scheduler_type);
|
285
|
+
|
286
|
+
// Get max register or shared memory size for persistent buffer
|
287
|
+
int64_t getMaxRegOrSharedMemorySizeForPersistentBuffer(
|
288
|
+
SchedulerRuntimeInfo& runtime_info,
|
289
|
+
const std::vector<TensorView*>& persistent_buffers,
|
290
|
+
const bool can_use_smem_persistent);
|
291
|
+
|
292
|
+
enum class BufferProjectionStrategy {
  // Recompute the persistent buffers from the fusion inputs, so only the
  // inputs have to be cached in registers or shared memory. Usually chosen
  // when the cached inputs are smaller than the persistent buffers.
  ProjectToInputs,
  // Skip projection so nothing is recomputed from the inputs. Saves compute at
  // the cost of more registers or shared memory; usually chosen when the
  // required buffer size is small and the hardware has a high
  // bandwidth-to-flops ratio.
  NoProjectToAvoidRecompute,
  // Projection is disabled for another reason, e.g. it would not reduce the
  // buffer size, recomputation would need very expensive rng ops, or it is
  // unsupported because of view ops.
  NoProjectOtherReasons
};
|
307
|
+
|
308
|
+
// Returns BufferProjectionStrategy based on buffer size, hardware, and fusion
|
309
|
+
// ops.
|
310
|
+
|
311
|
+
// This function is used by inner persistent and InnerOuter persistent
|
312
|
+
// schedulers.
|
313
|
+
// Using shared memory to store persistent buffers is not supported yet for
|
314
|
+
// inner persistent scheduler with 3D reduction type.
|
315
|
+
// TODO: Outer persistent scheduler should also use this function.
|
316
|
+
// If the scheduler is innerOuter with outer broadcast, projection is allowed
|
317
|
+
// even it leads to a larger buffer size becuase the scheduled kernel allows the
|
318
|
+
// reuse of the outer broadcast Tv when iterating over the outer reduction
|
319
|
+
// dimension and leads to higher performance ( TODO: needs re-evaluate, may not
|
320
|
+
// true if the buffer size is increased a lot when projecting to inputs). See
|
321
|
+
// https://github.com/NVIDIA/Fuser/issues/402
|
322
|
+
|
323
|
+
// However, we experimentally found that certain relatively expensive operations
|
324
|
+
// should not be projected even when that would require a larger buffer size.
|
325
|
+
// Specifically,
|
326
|
+
// - rng: should never be projected no matter how much larger the buffer would
|
327
|
+
// consume
|
328
|
+
// - exp in inner normalization: only allowed to get projected if the buffer is
|
329
|
+
// smaller than a certain size Otherwise, as long as the projected inputs are
|
330
|
+
// smaller than the original persistent buffers, this function returns true.
|
331
|
+
BufferProjectionStrategy isProjectBufferToInputs(
|
332
|
+
Fusion* fusion,
|
333
|
+
SchedulerRuntimeInfo& runtime_info,
|
334
|
+
const scheduler_utils::PersistentBufferInfo& persistent_buffer_info,
|
335
|
+
const scheduler_utils::PersistentBufferSizeReturn&
|
336
|
+
persistent_buffer_size_info,
|
337
|
+
const SchedulerType sh,
|
338
|
+
const bool can_use_smem_persistent,
|
339
|
+
const bool check_projected_buffer_size = true);
|
340
|
+
|
341
|
+
// Set memory type of persistent buffer marked in
|
342
|
+
// rparams->smem_persistent_buffers as shared memory. Return a vector of the
|
343
|
+
// consumers of the shared memory tensors, they are cached after the smem
|
344
|
+
// tensors and will be vectorized by the scheduler if possible to avoid shared
|
345
|
+
// memory bank conflicts.
|
346
|
+
std::vector<TensorView*> movePersistentBufferToSmem(
|
347
|
+
Fusion* fusion,
|
348
|
+
const ReductionParams* rparams,
|
349
|
+
const std::vector<TensorView*>& cached_inputs);
|
350
|
+
|
351
|
+
// Find the resolution points of a persistent buffer. See also
|
352
|
+
// the comments of PersistentBufferResolution in utils.cpp. Unlike
|
353
|
+
// PersistentBufferResolution, this analysis traverses a given fusion
|
354
|
+
// both forward and backward, which is necessary in some cases. For
|
355
|
+
// example:
|
356
|
+
//
|
357
|
+
// t0 = makeSymbolicTensor(2)
|
358
|
+
// t1 = makeSymbolicTensor(2)
|
359
|
+
// t2 = set(t0)
|
360
|
+
// t3 = sum(t2, 1)
|
361
|
+
// t4 = broadcast(t3, {false, true})
|
362
|
+
// t5 = add(t1, t2)
|
363
|
+
// t6 = add(t4, t1)
|
364
|
+
// fusion.addOutput(t5)
|
365
|
+
// fusion.addOutput(t6)
|
366
|
+
//
|
367
|
+
// The path from t2 to t3, t4 and t6 is a normalization path. While t1 itself
|
368
|
+
// does not depend on t2, since it is used with t2, inlining of t2
|
369
|
+
// also means t1 must be inlined, which in turn means t6 must be
|
370
|
+
// inlined. However, t6 depends on the reduction, inlining of t2 is
|
371
|
+
// not possible. For normalization fusions like this pattern,
|
372
|
+
// PersistentBufferResolution is not able to detect the resolution
|
373
|
+
// point. getResolutionPointsOf addresses the problem by traversing
|
374
|
+
// both forward and backward directions. See
|
375
|
+
// PersistentBufferTest.GetResolutionIssue1123 for a concrete example
|
376
|
+
std::vector<TensorView*> getResolutionPointsOf(TensorView* persistent_buffer);
|
377
|
+
|
378
|
+
} // namespace normalization_scheduler_utils
|
379
|
+
} // namespace nvfuser
|