nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl
- nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/scheduler/reduction_heuristic.h
@@ -0,0 +1,339 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <scheduler/heuristic.h>

#include <sstream>

namespace nvfuser {

class TensorView;

// Parameters of the reduction heuristic to describe the optimal schedule.
// Warning: the equality operator is intended for use in caching the kernel
// associated with these reduction parameters. It does not check if the
// launch parameters are equivalent!
class ReductionParams : public HeuristicParams {
 public:
  // Note that the heuristic type can be different from
  // SchedulerType::Reduction since ReductionParams is also used by, e.g.,
  // normalization schedulers.
  ReductionParams(SchedulerType scheduler_type = SchedulerType::Reduction)
      : HeuristicParams(scheduler_type) {};

  // Reducing the innermost dimension?
  bool fastest_dim = false;

  // Store input in shared memory or registers to reduce global memory reads
  bool persistent_kernel = false;

  // Project persistent buffers back to inputs to reduce persistent buffer size
  bool project_persistent_buffers = false;

  // Are we treating the scheduling as 3-dimensional? Can be useful for
  // patterns like [reduction, iteration, reduction].
  bool schedule_3D = false;

  // For outer reductions we may want to swap the gdimx and gdimy bindings to
  // amortize the cost of the final cleanup in grid reductions.
  bool flip_grid = false;

  // Inner Reduction Domain:

  // Reduce across the block?
  bool cross_block_inner_reduction = false;
  // Reduce across the grid?
  bool cross_grid_inner_reduction = false;
  // Unrolling/vectorization factor for the inner reduction dimension
  int64_t unroll_factor_inner_reduction = 1;

  // Extra unroll on top of vectorization
  int64_t unroll_factor_top_of_vectorization = 1;

  // Vectorize instead of unroll
  bool vectorize_inner_reduction = false;
  // Split grid dim for the iteration axis in case it's too large for CUDA
  bool split_grid_dim_inner_reduction = false;
  // Pad inner dimension to the nearest warp
  bool pad_inner_reduction_to_warp = false;
  // Register persistent buffer size in the inner dimension
  int64_t batches_per_block_inner_reduction = 1;

  // Which block parallel dimension should be used for the inner reduction.
  // !!WARNING!! Convenience field; this should be unique based on the
  // non-parallel-type parameters, and is not used for equivalence/hashing.
  ParallelType block_dim_inner_reduction = ParallelType::Serial;
  // Which grid parallel dimension should be used for the inner reduction.
  // !!WARNING!! Convenience field; this should be unique based on the
  // non-parallel-type parameters, and is not used for equivalence/hashing.
  ParallelType grid_dim_inner_reduction = ParallelType::Serial;

  // Iteration Domain:

  // Perform multiple reductions per block?
  bool multiple_reds_per_blk = false;
  // Unrolling/vectorization factor for the iteration dimension
  int64_t unroll_factor_iter_dom = 1;
  // Vectorize instead of unroll
  bool vectorize_iter_dom = false;
  // Inner split grid dim for the iteration axis in case it's too large for CUDA
  bool split_grid_dim_iter_dom_inner = false;
  // Outer split grid dim for the iteration axis in case it's too large for CUDA
  bool split_grid_dim_iter_dom_outer = false;

  // Which block parallel dimension should be used for the iter domain.
  // !!WARNING!! Convenience field; this should be unique based on the
  // non-parallel-type parameters, and is not used for equivalence/hashing.
  ParallelType block_dim_iter_dom = ParallelType::Serial;
  // Which grid parallel dimension should be used for the iter domain.
  // !!WARNING!! Convenience field; this should be unique based on the
  // non-parallel-type parameters, and is not used for equivalence/hashing.
  ParallelType grid_dim_iter_dom = ParallelType::Serial;

  // Outer Reduction Domain (if 3D scheduled):

  // Reduce across the block?
  bool cross_block_outer_reduction = false;
  // Reduce across the grid?
  bool cross_grid_outer_reduction = false;
  // Split grid dim for the iteration axis in case it's too large for CUDA
  bool split_grid_dim_outer_reduction = false;
  // Register persistent buffer size in the outer dimension
  int64_t batches_per_block_outer_reduction = 1;
  // Unrolling/vectorization factor for the outer reduction dimension
  int64_t unroll_factor_outer_reduction = 1;

  // Which block parallel dimension should be used for the outer reduction.
  // !!WARNING!! Convenience field; this should be unique based on the
  // non-parallel-type parameters, and is not used for equivalence/hashing.
  ParallelType block_dim_outer_reduction = ParallelType::Serial;
  // Which grid parallel dimension should be used for the outer reduction.
  // !!WARNING!! Convenience field; this should be unique based on the
  // non-parallel-type parameters, and is not used for equivalence/hashing.
  ParallelType grid_dim_outer_reduction = ParallelType::Serial;

  // Use computeWith on persistent buffers
  bool compute_persistent_buffer_with_first_consumer = false;

  bool static_bdimx = false;
  bool static_bdimy = false;

  bool isUnrolled() const {
    return unroll_factor_inner_reduction > 1 || unroll_factor_iter_dom > 1 ||
        unroll_factor_outer_reduction > 1;
  }

  // Specific to combined inner and outer reduction
  bool combined_inner_outer = false;
  // Use TIDx for the outer reduction axis
  bool tidx_for_outer_reduction = false;
  // Pad outer reduction to warp
  bool pad_outer_reduction_to_warp = false;
  // In the outer reduction part of the inner-outer persistent scheduler, the
  // inner dim may be further split by the grid
  bool combined_split_grid_inner_dim = false;
  // The partial result of the outer reduction is written to gmem and then
  // read back in a different parallel pattern; these set the vectorization
  // factors of that read and write.
  int64_t vectorization_factor_outer = 1;
  int64_t vectorization_factor_tmp_gmem_write = 1;
  // The inner reduction axis is parallelized by block_dim_inner_reduction
  // (usually TIDx); the remaining part is further parallelized by
  // block_dim_inner_reduction_extra (usually TIDy)
  ParallelType block_dim_inner_reduction_extra = ParallelType::Serial;

  // This vector stores the buffers that should be moved to shared memory.
  // TODO: For innerOuterPersistentHeuristic, only the persistent tensors in
  // the original fusion definition may be moved to shared memory; the
  // intermediate persistent tensors created by the scheduler to store the
  // partial outer reduction results are always stored in registers because
  // they are frequently accessed by both reads and writes. The code can be
  // extended to allow moving these intermediate persistent tensors to shared
  // memory when the shared memory is much larger than the register file.
  std::vector<TensorView*> smem_persistent_buffers;

 public:
  using HeuristicParams::HeuristicParams;

  // Warning: Does not check launch parameters!
  bool sameAs(const HeuristicParams* other_base) const override {
    auto other = dynamic_cast<const ReductionParams*>(other_base);
    if (other == nullptr) {
      return false;
    }

    bool attr_equal = other->cparams == cparams &&
        other->fastest_dim == fastest_dim &&
        other->persistent_kernel == persistent_kernel &&
        other->project_persistent_buffers == project_persistent_buffers &&
        other->schedule_3D == schedule_3D && other->flip_grid == flip_grid &&
        other->cross_block_inner_reduction == cross_block_inner_reduction &&
        other->cross_grid_inner_reduction == cross_grid_inner_reduction &&
        other->unroll_factor_inner_reduction ==
            unroll_factor_inner_reduction &&
        other->vectorize_inner_reduction == vectorize_inner_reduction &&
        other->split_grid_dim_inner_reduction ==
            split_grid_dim_inner_reduction &&
        other->pad_inner_reduction_to_warp == pad_inner_reduction_to_warp &&
        other->batches_per_block_inner_reduction ==
            batches_per_block_inner_reduction &&
        other->multiple_reds_per_blk == multiple_reds_per_blk &&
        other->unroll_factor_iter_dom == unroll_factor_iter_dom &&
        other->vectorize_iter_dom == vectorize_iter_dom &&
        other->split_grid_dim_iter_dom_inner ==
            split_grid_dim_iter_dom_inner &&
        other->split_grid_dim_iter_dom_outer ==
            split_grid_dim_iter_dom_outer &&
        other->cross_block_outer_reduction == cross_block_outer_reduction &&
        other->cross_grid_outer_reduction == cross_grid_outer_reduction &&
        other->unroll_factor_outer_reduction ==
            unroll_factor_outer_reduction &&
        other->split_grid_dim_outer_reduction ==
            split_grid_dim_outer_reduction &&
        other->batches_per_block_outer_reduction ==
            batches_per_block_outer_reduction &&
        other->compute_persistent_buffer_with_first_consumer ==
            compute_persistent_buffer_with_first_consumer &&
        other->combined_inner_outer == combined_inner_outer &&
        other->tidx_for_outer_reduction == tidx_for_outer_reduction &&
        other->pad_outer_reduction_to_warp == pad_outer_reduction_to_warp &&
        other->vectorization_factor_outer == vectorization_factor_outer &&
        other->combined_split_grid_inner_dim ==
            combined_split_grid_inner_dim &&
        other->unroll_factor_top_of_vectorization ==
            unroll_factor_top_of_vectorization &&
        other->vectorization_factor_tmp_gmem_write ==
            vectorization_factor_tmp_gmem_write;

    if (other->static_bdimy || static_bdimy) {
      attr_equal = attr_equal && other->lparams.bdimy() == lparams.bdimy();
    }
    if (other->static_bdimx || static_bdimx) {
      attr_equal = attr_equal && other->lparams.bdimx() == lparams.bdimx();
    }
    return attr_equal;
  }

  std::string toString() const override {
    std::stringstream ss;
    ss << "\n===== Reduction Parameters ========\n"
       << (tag.empty() ? "" : "Tag: ") << tag << "\n"
       << (fastest_dim ? "Red On Fastest Dim\n" : "Red On Slow Dim\n")
       << (persistent_kernel ? "Persistent Kernel\n" : "")
       << (project_persistent_buffers ? "Project Persistent Buffers\n" : "");
    if (batches_per_block_inner_reduction > 1 || persistent_kernel) {
      ss << "Batches per block: " << batches_per_block_inner_reduction << "\n";
    }

    if (schedule_3D) {
      ss << "3D Schedule\n"
         << "Outer Reduction: ";
      if (cross_block_outer_reduction) {
        ss << "cross block - " << block_dim_outer_reduction << " / ";
      }
      if (cross_grid_outer_reduction) {
        ss << "cross grid - " << grid_dim_outer_reduction << " / ";
        ss << (split_grid_dim_outer_reduction ? "split grid dim / " : "");
      }

      ss << (unroll_factor_outer_reduction > 1 ? "unroll / " : "");
      if (unroll_factor_outer_reduction > 1) {
        ss << "factor " << unroll_factor_outer_reduction << " ";
      }

      if (batches_per_block_outer_reduction > 1 || persistent_kernel) {
        ss << "persistent batch - " << batches_per_block_outer_reduction;
      }
    }

    ss << "\nIteration Domain: ";

    if (grid_dim_iter_dom != ParallelType::Serial) {
      ss << grid_dim_iter_dom << " / ";
      if (split_grid_dim_iter_dom_outer) {
        ss << "split grid dimension outer / ";
      } else if (split_grid_dim_iter_dom_inner) {
        ss << "split grid dimension inner / ";
      }
    }
    if (block_dim_iter_dom != ParallelType::Serial) {
      ss << block_dim_iter_dom << " / ";
    }
    ss << (multiple_reds_per_blk ? "multiple reductions per block / " : "")
       << (vectorize_iter_dom ? "vectorize / " : "")
       << (unroll_factor_iter_dom > 1 && !vectorize_iter_dom ? "unroll / "
                                                             : "");
    if (unroll_factor_iter_dom > 1) {
      ss << "factor " << unroll_factor_iter_dom;
    }

    ss << "\nInner Reduction Domain: ";

    if (cross_block_inner_reduction) {
      ss << "cross block - " << block_dim_inner_reduction << " / ";
      ss << (pad_inner_reduction_to_warp ? " pad to warp / " : "");
    }
    if (cross_grid_inner_reduction) {
      ss << "cross grid - " << grid_dim_inner_reduction << " / ";
      ss << (split_grid_dim_inner_reduction ? "split grid dim / " : "");
    }
    if (batches_per_block_inner_reduction > 1 || persistent_kernel) {
      ss << "persistent batch - " << batches_per_block_inner_reduction
         << " / ";
    }
    ss << (cross_grid_inner_reduction && split_grid_dim_inner_reduction
               ? "split grid dimension / "
               : "")
       << (vectorize_inner_reduction ? "vectorize / " : "")
       << (unroll_factor_inner_reduction > 1 && !vectorize_inner_reduction
               ? "unroll / "
               : "");
    if (unroll_factor_inner_reduction > 1) {
      ss << "factor " << unroll_factor_inner_reduction;
    }

    if (compute_persistent_buffer_with_first_consumer) {
      ss << "\ncomputeWith persistent buffers";
    }

    ss << "\n" << lparams.toString();
    ss << cparams.toString() << "\n";
    ss << "====================================\n";
    return ss.str();
  }

  // Warning: Hash is not based on launch parameters!
  size_t hash() const override {
    constexpr size_t bits = sizeof(std::size_t) * 8;
    size_t attr_hash = static_cast<size_t>(fastest_dim) << (bits - 1) ^
        static_cast<size_t>(persistent_kernel) << (bits - 2) ^
        static_cast<size_t>(project_persistent_buffers) << (bits - 3) ^
        static_cast<size_t>(schedule_3D) << (bits - 4) ^
        static_cast<size_t>(flip_grid) << (bits - 5) ^
        static_cast<size_t>(cross_block_inner_reduction) << (bits - 6) ^
        static_cast<size_t>(cross_grid_inner_reduction) << (bits - 7) ^
        static_cast<size_t>(unroll_factor_inner_reduction) << (bits - 8) ^
        static_cast<size_t>(vectorize_inner_reduction) << (bits - 9) ^
        static_cast<size_t>(split_grid_dim_inner_reduction) << (bits - 10) ^
        static_cast<size_t>(pad_inner_reduction_to_warp) << (bits - 11) ^
        static_cast<size_t>(batches_per_block_inner_reduction) << (bits - 12) ^
        static_cast<size_t>(multiple_reds_per_blk) << (bits - 13) ^
        static_cast<size_t>(unroll_factor_iter_dom) << (bits - 14) ^
        static_cast<size_t>(vectorize_iter_dom) << (bits - 15) ^
        static_cast<size_t>(split_grid_dim_iter_dom_outer) << (bits - 16) ^
        static_cast<size_t>(split_grid_dim_iter_dom_inner) << (bits - 17) ^
        static_cast<size_t>(cross_block_outer_reduction) << (bits - 18) ^
        static_cast<size_t>(cross_grid_outer_reduction) << (bits - 19) ^
        static_cast<size_t>(split_grid_dim_outer_reduction) << (bits - 20) ^
        static_cast<size_t>(batches_per_block_outer_reduction) << (bits - 21) ^
        static_cast<size_t>(unroll_factor_outer_reduction) << (bits - 22) ^
        static_cast<size_t>(compute_persistent_buffer_with_first_consumer)
            << (bits - 23) ^
        static_cast<size_t>(unroll_factor_top_of_vectorization) << (bits - 24);
    return attr_hash;
  }

  std::unique_ptr<HeuristicParams> clone() const override {
    return std::make_unique<ReductionParams>(*this);
  }
};

} // namespace nvfuser
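The equality and hash members above define the kernel-cache key: every scheduling attribute participates, while launch parameters are deliberately excluded unless static_bdimx/static_bdimy pin a block dimension. The following is a minimal sketch of that contract, assuming compilation against the headers bundled in this wheel; it is illustrative code, not part of the package.

```cpp
// Hedged sketch (not shipped in this wheel): exercising the ReductionParams
// caching contract declared in scheduler/reduction_heuristic.h.
#include <scheduler/reduction_heuristic.h>

#include <cassert>

using namespace nvfuser;

int main() {
  ReductionParams a;
  ReductionParams b;

  // Default-constructed parameter sets describe the same schedule, so they
  // compare equal and hash identically -- one kernel-cache key.
  assert(a.sameAs(&b));
  assert(a.hash() == b.hash());

  // Changing a scheduling attribute breaks equality; launch parameters are
  // intentionally not part of this comparison (see the class warnings).
  b.unroll_factor_inner_reduction = 4;
  b.vectorize_inner_reduction = true;
  assert(!a.sameAs(&b));
  assert(b.isUnrolled());

  // clone() preserves the full attribute set polymorphically.
  auto c = b.clone();
  assert(c->sameAs(&b));
  return 0;
}
```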
nvfuser/include/nvfuser/scheduler/reduction_utils.h
@@ -0,0 +1,159 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <exceptions.h>
#include <fusion.h>
#include <ir/all_nodes.h>
#include <scheduler/reduction_heuristic.h>
#include <visibility.h>

namespace nvfuser {

namespace reduction_scheduler_utils {

// Consistent parallelization based on the provided reduction parameters. The
// provided tensor is expected to be reduced by canonicalDimReduction before
// being sent here. reduction_tv should be provided as the TensorView to
// reduce. The rfactor of reduction_tv is returned if applicable; otherwise
// reduction_tv itself is returned.
TensorView* scheduleReductionTV(
    const ReductionParams* rparams,
    TensorView* reduction_tv,
    bool has_iter_axis);

// Inlining function intended for single or multi reduction fusions.
void multiReductionInliner(
    Fusion* fusion,
    TensorView* reduction_tv,
    TensorView* reference_tv,
    const bool unroll,
    const bool vectorize,
    const bool use_grouped_reduction,
    const int64_t vectorization_factor,
    std::vector<TensorView*> reduction_tvs,
    std::vector<TensorView*> cached_inputs,
    std::vector<std::pair<TensorView*, TensorView*>> cached_outputs,
    std::vector<TensorView*> smem_persistent_buffer_consumers = {},
    std::vector<TensorView*> dummy_outputs = {});

// Propagate transformations with an internal cutoff boundary at
// boundaryNodesSet. In P2C (forward) propagation, propagation to TensorViews
// in boundaryNodesSet is disabled; in C2P (backward) propagation, propagation
// from TensorViews in boundaryNodesSet is disabled.
NVF_API void propagateTransformation(
    TensorView* reference_tv,
    const std::unordered_set<TensorView*>& boundaryNodesSet =
        std::unordered_set<TensorView*>());

// Propagate the rfactor from the first reduction TensorView to the others.
void propagateRFactor(
    TensorView* reference_tv,
    TensorView* reduction_tv,
    const std::vector<TensorView*>& reduction_tvs);

// Get all cached input/output and shared memory TensorViews that are
// vectorizable and unrollable.
//
// Parameters:
//   reference_tv: TensorView created during rfactor, used to find
//                 vectorizable TensorViews.
//   is_vectorize: Indicates if vectorization is applied in the scheduler.
//   cached_inputs: Inputs cached in registers or shared memory.
//   cached_outputs: Outputs cached in registers.
NVF_API std::unordered_set<TensorView*> getCachedTvsToUnrollOrVectorize(
    TensorView* reference_tv,
    bool is_vectorize,
    const std::vector<TensorView*>& cached_inputs,
    const std::vector<std::pair<TensorView*, TensorView*>>& cached_outputs);

// Propagate parallelization from the reference TensorView to other
// TensorViews. Unroll, Vectorize, and MisalignedVectorize types are
// explicitly handled for TensorViews in unroll_vectorizable_cached_tvs.
// Clears unroll parallelization for reduction_tv and reference_tv if they
// shouldn't be unrolled.
//
// Parameters:
//   reduction_tv: The reduction TensorView being scheduled and parallelized.
//                 May need to clear its vectorization or convert it to a
//                 grouped reduction.
//   reference_tv: The reference TensorView being scheduled and parallelized;
//                 propagates parallelization to other selected TensorViews.
//   is_unroll_or_vectorization: Indicates if unroll or vectorization is used
//                 in the scheduler.
//   reduction_tvs: All reduction TensorViews in the fusion. May add grouped
//                 parallelization.
//   unroll_vectorizable_cached_tvs: Cached TensorViews that are unrollable
//                 or vectorizable.
//   selected_tvs: TensorViews selected for parallelization; defaults to all
//                 TensorViews.
NVF_API void propagateParallelization(
    TensorView* reduction_tv,
    TensorView* reference_tv,
    const bool is_unroll_or_vectorization,
    const bool use_grouped_reduction,
    const std::vector<TensorView*>& reduction_tvs,
    const std::unordered_set<TensorView*>& unroll_vectorizable_cached_tvs,
    const std::vector<TensorView*>& selected_tvs = {});

// Sort and rfactor the reference tv in a consistent way for the reduction
// inliner. Order of the sort is:
//
// [i-device dims, i-block dims, i-thread dims, i-constant sized,
//  i-non-constant sized, r-block dims, r-thread dims, r-non-constant sized,
//  r-constant sized, i/r-unswitched, i/r-unroll/vectorized, broadcasted dims]
//
// Rfactored axes are reductions bound to the grid or blocks. If no axes are
// bound to a grid or block dimension, the r-unswitch dimension is rfactored.
// The reduction inliner expects an rfactored domain.
NVF_API TensorView* sortAndRFactor(TensorView* reference_tv);

// If project_to_inputs is true, take all projectable persistent buffers and
// move them to the inputs. Otherwise, try to project to their immediate
// producers if those producers are persistent buffers.
// This function creates dummy outputs which should be used in later stages
// of the scheduling.
NVF_API std::vector<TensorView*> projectPersistentBuffers(
    Fusion* fusion,
    const bool project_to_inputs);

//! Get the reduction type based on the given fusion or reduction tvs.
//! If there are no reduction tvs, return None.
//! If there are only inner reduction tvs, return Inner.
//! If there are only outer reduction tvs, return Outer.
//! If there are both inner and outer reduction tvs, return InnerOuter.
enum class ReductionType { Inner, Outer, InnerOuter, None };
std::ostream& operator<<(std::ostream& os, ReductionType reduction_type);
std::string toString(ReductionType reduction_type);
ReductionType getReductionType(Fusion* fusion);
ReductionType getReductionType(const std::vector<TensorView*>& reduction_tvs);

/**
 * @brief Vectorize shared memory consumers
 *
 * Applies vectorization to shared memory consumers.
 * If the extent of the last dim times the vectorization factor exceeds
 * hardware limitations, an additional split is added.
 *
 * @param smem_consumers Vector of TensorView pointers representing shared
 * memory consumers
 * @param io_vectorization_factor Vectorization factor set for fusion inputs
 * and outputs
 * @note TODO: Optimize writing to shared memory and address bank conflicts
 * for float32 with an innermost extent of 8
 */
void sharedMemoryConsumerVectorization(
    std::vector<TensorView*>& smem_consumers,
    const int64_t io_vectorization_factor);

} // namespace reduction_scheduler_utils
} // namespace nvfuser
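Taken together, these declarations outline the skeleton of a reduction scheduler: project persistent buffers, schedule the reduction TensorView, propagate transformations and the rfactor, then propagate parallelization. The sketch below shows one plausible ordering of the calls; it is not code from this wheel, and the flag values (has_iter_axis, use_grouped_reduction, the vectorization choice) are illustrative assumptions rather than what the bundled schedulers actually pass.

```cpp
// Hedged sketch (not part of the wheel) of one plausible call sequence for
// the utilities declared in scheduler/reduction_utils.h.
#include <scheduler/reduction_utils.h>

using namespace nvfuser;
namespace rsu = nvfuser::reduction_scheduler_utils;

void sketchReductionSchedule(
    Fusion* fusion,
    const ReductionParams* rparams,
    TensorView* reduction_tv,
    const std::vector<TensorView*>& reduction_tvs,
    const std::vector<TensorView*>& cached_inputs,
    const std::vector<std::pair<TensorView*, TensorView*>>& cached_outputs) {
  // Optionally shrink persistent buffers by recomputing them from inputs;
  // the returned dummy outputs are meant for later stages of scheduling.
  if (rparams->project_persistent_buffers) {
    auto dummy_outputs =
        rsu::projectPersistentBuffers(fusion, /*project_to_inputs=*/true);
    (void)dummy_outputs;
  }

  // Schedule the reduction TensorView; the rfactored TV (if one was made)
  // becomes the reference for propagation.
  TensorView* reference_tv =
      rsu::scheduleReductionTV(rparams, reduction_tv, /*has_iter_axis=*/true);

  // Mirror the reference transformations and the rfactor onto the fusion.
  rsu::propagateTransformation(reference_tv);
  rsu::propagateRFactor(reference_tv, reduction_tv, reduction_tvs);

  // Work out which cached TVs may be unrolled/vectorized, then parallelize.
  const auto unroll_vectorizable_tvs = rsu::getCachedTvsToUnrollOrVectorize(
      reference_tv,
      rparams->vectorize_iter_dom,
      cached_inputs,
      cached_outputs);
  rsu::propagateParallelization(
      reduction_tv,
      reference_tv,
      rparams->isUnrolled(),
      /*use_grouped_reduction=*/false,
      reduction_tvs,
      unroll_vectorizable_tvs);
}
```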
nvfuser/include/nvfuser/scheduler/registry.h
@@ -0,0 +1,97 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once
#include <exceptions.h>
#include <expr_evaluator.h>
#include <fusion.h>
#include <scheduler/compile_time_info.h>
#include <scheduler/utils.h>

namespace nvfuser {

class HeuristicDataCache;
class HeuristicParams;
class SchedulerRuntimeInfo;

//! Virtual base class for schedule heuristics.
//! Heuristic implementations derive from this class and implement a
//! schedule(Fusion*) and a bool canSchedule(Fusion*) interface.
class SchedulerEntry {
 public:
  NVF_API virtual ~SchedulerEntry() = default;

  //! Fusion runtime facing API:
  //! schedule the given fusion with heuristics owned by this entry,
  //! for actual heuristics to override.
  NVF_API virtual void schedule(
      Fusion* fusion,
      const HeuristicParams* params) = 0;

  virtual std::unique_ptr<HeuristicParams> computeHeuristics(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicDataCache* data_cache = nullptr) = 0;

  // Compile-time check that the scheduler may be able to schedule the fusion.
  virtual bool canScheduleCompileTime(Fusion* fusion) = 0;

  // Runtime check that the scheduler can take the fusion. The scheduler must
  // be able to schedule the fusion if canScheduleCompileTime passes and this
  // returns true.
  virtual bool canScheduleRunTime(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicDataCache* data_cache = nullptr) = 0;

  // Dispatch the heuristic type to the right derived class of scheduler
  // entry. Scheduler entries are stateless, so this is a lightweight way to
  // dispatch to the virtual functions in this abstract class.
  NVF_API static std::unique_ptr<SchedulerEntry> makeSchedulerInstance(
      SchedulerType scheduler_type);

  // Checks that the provided scheduler type can schedule the fusion with the
  // provided inputs, schedules the fusion according to the heuristics
  // provided by the scheduler, and returns the heuristics. This is simply a
  // convenience function for a common testing pattern. If validate_scheduler
  // is set to false, canSchedule will not be checked.
  NVF_API static std::unique_ptr<HeuristicParams> scheduleWith(
      Fusion* fusion,
      SchedulerType scheduler_type,
      const at::ArrayRef<c10::IValue>& runtime_inputs,
      bool validate_scheduler = true);

  //! Heuristic comparison
  NVF_API bool sameAs(const SchedulerEntry* other);

  NVF_API const HeuristicParams* params() const {
    return params_.get();
  }

  std::unique_ptr<HeuristicParams> params_ = nullptr;
};

namespace Schedule {

//! External access to the canSchedule utilities through SchedulerEntry,
//! to avoid exposing a single function to the namespace.
bool canSchedule(
    SchedulerType sh,
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicDataCache* data_cache = nullptr,
    bool skip_compile_time_checks = false);

//! Fusion segmenter facing API:
//! returns a scheduler type that applies to the given fusion, or
//! SchedulerType::None if no schedule in the registry can handle it.
SchedulerType proposeHeuristics(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info);
} // namespace Schedule

} // namespace nvfuser
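For context, this registry is typically driven in two steps: ask it to propose a scheduler for a fusion, then dispatch to a concrete entry to compute heuristics and schedule. The helper below is a hedged sketch under the assumption that the Fusion and SchedulerRuntimeInfo are constructed elsewhere; it is not a function shipped in this wheel.

```cpp
// Hedged sketch (not part of the wheel): driving the scheduler registry
// declared in scheduler/registry.h.
#include <scheduler/registry.h>

#include <memory>

using namespace nvfuser;

std::unique_ptr<HeuristicParams> sketchScheduleFusion(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info) {
  // Segmenter-facing path: propose the best scheduler, or None.
  SchedulerType scheduler_type =
      Schedule::proposeHeuristics(fusion, runtime_info);
  if (scheduler_type == SchedulerType::None) {
    return nullptr; // no registered scheduler accepts this fusion
  }

  // Dispatch to the concrete (stateless) scheduler entry.
  auto scheduler = SchedulerEntry::makeSchedulerInstance(scheduler_type);

  // Compute heuristics at runtime, then schedule the fusion in place.
  auto params = scheduler->computeHeuristics(fusion, runtime_info);
  scheduler->schedule(fusion, params.get());
  return params;
}
```

In tests, the static convenience SchedulerEntry::scheduleWith collapses these steps into a single call that also validates canSchedule against concrete runtime inputs.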