nvfuser-cu121-torch25 0.2.25.dev20250201__cp310-cp310-manylinux_2_28_x86_64.whl
Sign up to get free protection for your applications and to get access to all the features.
- nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,111 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
#include <scheduler/all_schedulers.h>
|
10
|
+
|
11
|
+
namespace nvfuser {

class TensorView;
class ComputeAtLogicalDomainMap;
class ComputeAtMap;
class ExpressionEvaluator;
class KernelArgumentHolder;

//! Shared compile-time/runtime acceptance checks used by the scheduler
//! registry when deciding whether a fusion can be handled by a scheduler.
namespace registry_utils {

//! Returns true if the two output TensorViews have equivalent reduction
//! patterns according to the given logical domain map, i.e. they can be
//! scheduled together by a single reduction heuristic.
bool checkPatternEquivalence(
    TensorView* out_tv0,
    TensorView* out_tv1,
    const ComputeAtLogicalDomainMap& logical_map);

// Reusing some code from lowering specifically in lower_trivial_broadcast.cpp
// ConcretizedBroadcastDomains::maybeNonUniquelyConcretized this checks if
// there's a broadcast iteration domain that's being broadcasted to seemingly
// different extents, meaning we don't know in the kernel if the dimension is
// being broadcasted to one size multiple times or different sizes. This is a
// hard to optimize problem and likely indicates we shouldn't be fusing.
bool hasNonUniqueBcast(Fusion* fusion);

// TODO: remove this requirement entirely
//! Returns true if the fusion must be rejected by the given scheduler type
//! because of memory-promotion related limitations.
bool rejectScheduleForMemoryPromotion(
    Fusion* fusion,
    SchedulerType scheduler_type);

//! Returns true if the fusion's producer-consumer graph is connected, i.e.
//! it does not consist of multiple independent sub-graphs.
bool isConnectedFusionGraph(Fusion* fusion);

// Returns if a fusion cannot transformed into a consistent format since we
// can't transform forward through view operations, for example:
//
//   tv0[I0, I1, I2]
//   tv1[I0*I1, I2] = view(tv0)
//   tv2[I0, I1*I2] = view(tv0)
//
// If we start transform propagation at either tv1 or tv2, it would require
// "replaying forward" through the other. If we started at tv1 we'd have to be
// able to take tv2[I0, I1*I2] and transform it to [I0*I1, I2], however this
// would "undo" the view transformation which we do not support today.
//
// Returns true if a scenario like above is found in the fusion.
bool requiresForwardViewReplay(Fusion* fusion, ComputeAtMap& ca_map);

// Returns if view interferes with how we want to treat the reference, being
// at least a 2D reduction schedule but maybe a 3D reduction schedule.
bool reductionInterferingView(
    Fusion* fusion,
    const ComputeAtMap& ca_map,
    TensorView* reduction_reference);

// Check inputs, outputs and intermediates.
// Intermediates are contiguous, so strides are not necessary.
// Strides are required for inputs and also maybe for outputs as
// they may be non-contiguous. However, in our current interface,
// output strides are not available, so if there's any outputs that
// are non contiguous, need to fall back to 64-bit indexing.
PrimDataType getIndexTypeOfKernel(
    Fusion* fusion,
    const std::vector<TensorView*>& all_tvs,
    const KernelArgumentHolder& inputs,
    ExpressionEvaluator& ee);

//! Static topology checks shared by reduction-like schedulers: each method
//! inspects the fusion graph for a pattern that the schedulers cannot
//! handle and returns true if that problematic pattern is present.
class SchedulerTopologyChecker {
 public:
  // Checks if any broadcasts are resolved after a reduction that don't follow
  // the normalization pattern.
  static bool hasNonNormalizePostReductionBCast(Fusion* fusion);

  // Checks if any broadcasts are resolved after a reduction, this shouldn't
  // be accepted in the single reduction or multi-reduction scheduler.
  static bool hasPostReductionBCast(Fusion* fusion);

  // Checks if there's any unsupported operations post reduction. If outer
  // reduction we can fuse some pointwise ops if they don't require
  // broadcasting (checked in hasPostReductionBCast). For inner reductions we
  // cannot fuse any binary like operation (includes operations like shift
  // that we're not fusing right now) involving "new" inputs (not going
  // through a reduction).
  static bool supportedPostReductionFusion(
      Fusion* fusion,
      std::vector<TensorView*> reduction_tvs);

  // Checks if there's any gather-like ops that result in non-resolved
  // broadcast domains and then get squeezed before reaching reduction
  // TVs. The reduction scheduler uses reduction TVs as a scheduling
  // reference, so that won't be able to schedule the broadcast ID if
  // squeezed and its corresponding index-accessed producer ID, and
  // any IDs that the producer ID depends on.
  //
  // This analysis has some similarity as DomainMap. Can be
  // consolidated?
  static bool hasGatherToBroadcastBeforeReduction(
      Fusion* fusion,
      const std::vector<TensorView*>& reduction_tvs);
};

} // namespace registry_utils

} // namespace nvfuser
|
@@ -0,0 +1,41 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <scheduler/heuristic.h>
|
11
|
+
#include <scheduler/registry.h>
|
12
|
+
|
13
|
+
namespace nvfuser {

class Fusion;
class SchedulerRuntimeInfo;
class HeuristicDataCache;

//! Scheduler entry for fusions dominated by resize-based operations
//! (e.g. slice/pad-style index-range changes). Implements the standard
//! SchedulerEntry interface: compile-time acceptance, runtime acceptance,
//! heuristic computation, and scheduling.
class ResizeScheduler : public SchedulerEntry {
 public:
  //! Compile-time-only acceptance check; defined in the scheduler's .cpp.
  bool canScheduleCompileTime(Fusion* fusion) override;

  //! Runtime acceptance check. Always accepts: all rejection logic for this
  //! scheduler lives in canScheduleCompileTime, so no runtime input
  //! information is needed.
  bool canScheduleRunTime(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicDataCache* data_cache = nullptr) override {
    return true;
  }

  //! Builds the ResizeParams heuristics for the given fusion and runtime
  //! inputs; defined in the scheduler's .cpp.
  std::unique_ptr<HeuristicParams> computeHeuristics(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicDataCache* data_cache) override;

  //! Applies the scheduling transformations described by params to fusion.
  void schedule(Fusion* fusion, const HeuristicParams* params) override;

  //! Tag used by the registry to identify this scheduler.
  constexpr static SchedulerType schedulerType() {
    return SchedulerType::Resize;
  }
};

} // namespace nvfuser
|
@@ -0,0 +1,67 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <c10/util/hash.h>
|
11
|
+
#include <ir/interface_nodes.h>
|
12
|
+
#include <scheduler/heuristic.h>
|
13
|
+
#include <utils.h>
|
14
|
+
|
15
|
+
#include <sstream>
|
16
|
+
|
17
|
+
namespace nvfuser {
|
18
|
+
|
19
|
+
class ResizeParams : public HeuristicParams {
|
20
|
+
public:
|
21
|
+
ResizeParams() : HeuristicParams(SchedulerType::Resize) {};
|
22
|
+
|
23
|
+
// Split grid x dimension
|
24
|
+
bool split_grid_x_dim = false;
|
25
|
+
|
26
|
+
int64_t largest_input = -1;
|
27
|
+
|
28
|
+
int64_t vectorization_factor = 1;
|
29
|
+
|
30
|
+
static constexpr int64_t max_gdimx = (1L << 31) - 1L;
|
31
|
+
|
32
|
+
using HeuristicParams::HeuristicParams;
|
33
|
+
|
34
|
+
// Warning: Does not check launch parameters!
|
35
|
+
bool sameAs(const HeuristicParams* other_base) const override {
|
36
|
+
auto other = dynamic_cast<const ResizeParams*>(other_base);
|
37
|
+
if (other == nullptr) {
|
38
|
+
return false;
|
39
|
+
}
|
40
|
+
bool attr_equal = other->cparams == cparams &&
|
41
|
+
other->split_grid_x_dim == split_grid_x_dim &&
|
42
|
+
other->largest_input == largest_input &&
|
43
|
+
other->vectorization_factor == vectorization_factor;
|
44
|
+
return attr_equal;
|
45
|
+
}
|
46
|
+
|
47
|
+
std::string toString() const override {
|
48
|
+
std::stringstream ss;
|
49
|
+
ss << "\n===== Resize Parameters ========\n"
|
50
|
+
<< (tag.empty() ? "" : "Tag: ") << tag << " Resize Characteristics:\n"
|
51
|
+
<< " split grid x dim: " << split_grid_x_dim << "\n"
|
52
|
+
<< " index of largest input: " << largest_input << "\n"
|
53
|
+
<< " vectorization factor: " << vectorization_factor << "\n";
|
54
|
+
ss << "====================================\n";
|
55
|
+
return ss.str();
|
56
|
+
}
|
57
|
+
|
58
|
+
size_t hash() const override {
|
59
|
+
return c10::get_hash(split_grid_x_dim);
|
60
|
+
}
|
61
|
+
|
62
|
+
std::unique_ptr<HeuristicParams> clone() const override {
|
63
|
+
return std::make_unique<ResizeParams>(*this);
|
64
|
+
}
|
65
|
+
};
|
66
|
+
|
67
|
+
} // namespace nvfuser
|
@@ -0,0 +1,166 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
#include <cstddef>
|
10
|
+
#include <cstdint>
|
11
|
+
|
12
|
+
#include <expr_evaluator.h>
|
13
|
+
#include <fusion.h>
|
14
|
+
#include <runtime/executor_kernel_arg.h>
|
15
|
+
#include <utils.h>
|
16
|
+
#include <visibility.h>
|
17
|
+
|
18
|
+
namespace nvfuser {

class ExpressionEvaluator;

//! SchedulerRuntimeInfo is the abstraction introduced in
//! this PR for passing runtime input dependent information
//! to the schedulers and kernel caches.
//!
//! Note:
//! if any additional info needed, or maybe just the inputs themselves it
//! could just be added to this class, and they will be distributed to the
//! segmenter and schedulers.
//! It is important that input id encoding should be up to date with any change
//! of this class to avoid launching compiled kernels with illegal inputs.

class SchedulerRuntimeInfo : public NonCopyable {
 public:
  // Max vector size we will consider, in bytes,
  // currently set to 16B = 128b
  static constexpr int64_t max_alignment_size_in_byte = 16;

  //! Create runtime info for given fusion and input. Creating and binding
  //! evaluator is optional. The evaluator is used to manage intermediate
  //! integers in the fusion. We need them for segmenter and schedulers,
  //! but we don't need them when we are just using this class to provide
  //! additional encoding for kernel cache lookup.
  //!
  //! The index type of forced_index_type is used if given, no matter
  //! how large the actual arguments and fusion tensors
  //! are. CORRECTNESS IS NOT GUARANTEED.
  SchedulerRuntimeInfo(
      Fusion* complete_fusion,
      KernelArgumentHolder args,
      PrecomputedValues* precomputed_values = nullptr,
      const std::vector<TensorView*>& all_tvs = {},
      std::optional<PrimDataType> forced_index_type = std::nullopt);

  //! Convenience overload taking ATen IValues directly.
  NVF_API SchedulerRuntimeInfo(
      Fusion* complete_fusion,
      const at::ArrayRef<c10::IValue>& aten_inputs);

  //! Lookup for the alignment sizes of the given tv. Currently only returns
  //! actual alignment info for input tensors to the complete fusion,
  //! and for other intermediate/fuser-allocated tensors will
  //! return max_alignment_size_in_byte.
  size_t getAlignmentSize(TensorView* tv);

  //! Returns sizes of tensor dimensions in same order as allocation domain,
  //! ignoring any IterType::Reduction domains in the allocation domain. This
  //! only works for complete Fusion inputs whose allocation domain is a
  //! permutation of their root domain and will raise an exception otherwise.
  const std::vector<int64_t>& getInputAllocationSizes(TensorView* tv) const {
    NVF_ERROR(
        isInputTv(tv),
        "TensorView ",
        tv->toString(),
        " is not an input or its logical domain is not a permutation of its ",
        "allocation domain");
    auto sizes_it = input_sizes_.find(tv);
    // Every fusion input should have been registered at construction time.
    NVF_ERROR(sizes_it != input_sizes_.end());
    return sizes_it->second;
  }

  //! Returns strides of tensor in same order as allocation domain, in elements
  //! instead of bytes. Only works for complete Fusion inputs whose allocation
  //! domain is a permutation of their root domain and will raise an exception
  //! otherwise.
  const std::vector<int64_t>& getInputAllocationStrides(TensorView* tv) const {
    NVF_ERROR(
        isInputTv(tv),
        "TensorView ",
        tv->toString(),
        " is not an input or its logical domain is not a permutation of its ",
        "allocation domain");
    auto strides_it = input_strides_elements_.find(tv);
    // Every fusion input should have been registered at construction time.
    NVF_ERROR(strides_it != input_strides_elements_.end());
    return strides_it->second;
  }

  // Computes alignment size in bytes for provided ptr address
  static size_t computeAlignmentSize(size_t ptr_address);

  // Return the runtime pointer value for provided tensor view
  size_t ptrOf(TensorView* tv) const;

  //! Index type (32- or 64-bit) selected for the kernel at construction.
  PrimDataType getIndexType() const {
    return index_type_;
  }

  //! The complete (unsegmented) fusion this runtime info describes.
  Fusion* fusion() {
    return complete_fusion_;
  }

  //! Evaluator bound to the fusion inputs. Only valid when the constructor
  //! was asked to create one; errors otherwise.
  ExpressionEvaluator& expressionEvaluator() {
    NVF_ERROR(expression_evaluator_ != nullptr);
    return *expression_evaluator_;
  }

 private:
  // Build and bind full fusion inputs to an expression evaluator
  std::unique_ptr<ExpressionEvaluator> getExpressionEvaluator(
      const KernelArgumentHolder& inputs,
      PrecomputedValues* precomputed_values);

  // True when tv is one of the complete fusion's inputs (linear scan over
  // the input list).
  bool isInputTv(TensorView* tv) const {
    return std::find(
               complete_fusion_->inputs().begin(),
               complete_fusion_->inputs().end(),
               tv) != complete_fusion_->inputs().end();
  }

 private:
  // Returns the offset of tv in the inputs ignoring non tensor views. Used to
  // access input_sizes, input_strides, input_ptr
  int offsetTensorPos(TensorView* tv);

  // Expression evaluator used to probe sizes in the fusion IR
  std::unique_ptr<ExpressionEvaluator> expression_evaluator_ = nullptr;

  // Fusion reference that this runtime info is associated with
  Fusion* complete_fusion_ = nullptr;

  // Copy of aten input pointer addresses
  // TODO: Support output tensor pointers
  std::unordered_map<Val*, size_t> input_ptrs_;

  // Copy of aten input tensor sizes ordered like the TensorView's allocation
  // domain
  std::unordered_map<Val*, std::vector<int64_t>> input_sizes_;

  // Copy of aten input tensor strides (in elements) ordered like the
  // TensorView's allocation domain
  std::unordered_map<Val*, std::vector<int64_t>> input_strides_elements_;

  // Copy of aten input tensor strides (in bytes) for only discontiguous
  // dimensions
  std::unordered_map<Val*, std::vector<size_t>> input_discontig_strides_;

  // Cache for getAlignmentSize
  std::unordered_map<TensorView*, size_t> alignment_map_;

  // Found index mode kernel needs to be run in
  PrimDataType index_type_ = PrimDataType::Int;

  // TODO: Remove
  std::unordered_map<TensorView*, size_t> vectorword_map_;
};

} // namespace nvfuser
|
@@ -0,0 +1,80 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <visibility.h>
|
11
|
+
#include <array>
|
12
|
+
#include <ostream>
|
13
|
+
#include <string>
|
14
|
+
|
15
|
+
namespace nvfuser {
|
16
|
+
|
17
|
+
//! Each SchedulerType maps to a scheduler in distinct CPP files.
|
18
|
+
//! For instance, SchedulerType::PointWise maps to PointWiseScheduler in
|
19
|
+
//! pointwise.cpp.
|
20
|
+
//!
|
21
|
+
//! Each of the scheduler needs to provide 3 interface functions:
|
22
|
+
//!
|
23
|
+
//! 1. canScheduleCompileTime(Fusion* fusion) :
|
24
|
+
//!
|
25
|
+
//! This function contains compiled-time checks on the graph itself
|
26
|
+
//! without runtime input information. Only `fusion` is given in the
|
27
|
+
//! argument to make sure only compile-time available info is needed in
|
28
|
+
//! the check.
|
29
|
+
//!
|
30
|
+
//! This function is to be called exactly once on each segmented group
|
31
|
+
//! created in a segmented fusion so this part will not contribute to
|
32
|
+
//! dynamic shape latency.
|
33
|
+
//!
|
34
|
+
//! 2. canScheduleRunTime(
|
35
|
+
//! Fusion* fusion,
|
36
|
+
//! SchedulerRuntimeInfo& runtime_info,
|
37
|
+
//! HeuristicDataCache* data_cache = nullptr):
|
38
|
+
//! This function contains all canSchedule checks that will have to
|
39
|
+
//! involve runtime input information, and will be run both by the
|
40
|
+
//! segmenter and the kernel cache. The latency of this function will
|
41
|
+
//! contribute to dynamic shape latency so `data_cache` should be used as
|
42
|
+
//! much as possible to save re-computation.
|
43
|
+
//!
|
44
|
+
//! 3. schedule(fusion):
|
45
|
+
//!
|
46
|
+
//! This function will be called when compiling a kernel. It should apply
|
47
|
+
//! scheduling to the given fusion
|
48
|
+
|
49
|
+
enum class SchedulerType {
  // Sentinel: no scheduler selected (e.g. no heuristic accepted the fusion).
  None,
  // Fusion requiring no real computation — presumably forwards/reshapes
  // inputs without a kernel body (name-based; confirm in no_op.cpp).
  NoOp,
  // Element-wise fusions; maps to PointWiseScheduler in pointwise.cpp
  // (per the file-level comment above).
  PointWise,
  // Matrix-multiplication fusions.
  Matmul,
  // Non-persistent reduction fusions.
  Reduction,
  // Persistent-buffer reductions over the inner dimension.
  InnerPersistent,
  // Persistent reductions over both inner and outer dimensions.
  InnerOuterPersistent,
  // Persistent-buffer reductions over the outer dimension.
  OuterPersistent,
  // Transpose-dominated fusions (see scheduler/transpose.h for strategy).
  Transpose,
  // Evaluate expressions at runtime without generating a kernel
  // (name-based; confirm in expr_eval.cpp).
  ExprEval,
  // Resize/slice-pad style fusions.
  Resize
};
|
62
|
+
|
63
|
+
//! Define a schedule table to loop over all the heuristics in priority order.
|
64
|
+
// `inline` (C++17) gives this header-defined constant a single definition
// across all translation units; a plain namespace-scope `constexpr` variable
// has internal linkage, so every TU that includes this header would otherwise
// get its own copy with a distinct address.
inline constexpr std::array<SchedulerType, 10> all_heuristics_in_priority_order =
    {SchedulerType::ExprEval,
     SchedulerType::NoOp,
     SchedulerType::Matmul,
     SchedulerType::Reduction,
     SchedulerType::Resize,
     SchedulerType::Transpose,
     SchedulerType::PointWise,
     SchedulerType::InnerPersistent,
     SchedulerType::OuterPersistent,
     SchedulerType::InnerOuterPersistent};
|
75
|
+
|
76
|
+
std::string toString(SchedulerType sh);
|
77
|
+
|
78
|
+
NVF_API std::ostream& operator<<(std::ostream& os, SchedulerType sh);
|
79
|
+
|
80
|
+
} // namespace nvfuser
|
@@ -0,0 +1,114 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <ATen/core/ivalue.h>
|
11
|
+
#include <exceptions.h>
|
12
|
+
#include <fusion.h>
|
13
|
+
#include <scheduler/registry.h>
|
14
|
+
#include <scheduler/transpose_heuristic.h>
|
15
|
+
#include <visibility.h>
|
16
|
+
|
17
|
+
#define SUPPORT_SPLITTING_INNERMOST_DIM 0
|
18
|
+
|
19
|
+
namespace nvfuser {
|
20
|
+
|
21
|
+
// Note [Transpose scheduling]
|
22
|
+
//
|
23
|
+
// The target of transpose scheduling is to get coalesced global memory access
|
24
|
+
// to as much input and output tensors as possible. For a DAG with only pure
|
25
|
+
// pointwise operators, the scheduling is very simple because the inner most
|
26
|
+
// dimension of all input and output tensors are all mapped together in the
|
27
|
+
// ComputeAtMap, i.e., there is essentially only one inner most dimension. In
|
28
|
+
// such case, we just vectorize that inner most dimension and bind it to
|
29
|
+
// threadIdx.x identically for all input and output tensors. In the case where
|
30
|
+
// transposes are present in the DAG, the inner most dimensions of different
|
31
|
+
// inputs and outputs might not match. And there is no fixed pattern on which
|
32
|
+
// input/output tensors should share the same inner most dimension with which.
|
33
|
+
// Consider the following example DAGs ([T] represents transpose, all tensors
|
34
|
+
// are 2D):
|
35
|
+
//
|
36
|
+
// t0 t1 t0 t1 t0 t1 t0 t1 t0
|
37
|
+
// \ | \ / \ | \ | |
|
38
|
+
// \ [T] [T] [T] \ [T] t2 [T] [T]
|
39
|
+
// \ / \ / \ / \ / \ / \ |
|
40
|
+
// t2 t2 t2 t3 t3 t4 t5 [T]
|
41
|
+
// |
|
42
|
+
// t1
|
43
|
+
//
|
44
|
+
// In order to support all these cases in a general way, the following
|
45
|
+
// perspective is very important: What we are looking for is to bind threadIdx.x
|
46
|
+
// differently for different inputs and outputs, so there has to be some tensor
|
47
|
+
// somewhere in the DAG that we write and read with different threadIdx.x
|
48
|
+
// bindings. The tensor of binding swap can be any tensor on the path that
|
49
|
+
// connects inputs/outputs with different inner most dimension, especially, it
|
50
|
+
// does not necessarily have to be the tensor of the transpose operator. In
|
51
|
+
// other words, thanks to our indexing system which is already taking care of the
|
52
|
+
// correctness of transpose, the scheduler can freely choose where to realize
|
53
|
+
// these transposes as different threadIdx.x bindings. This observation greatly
|
54
|
+
// simplifies our scheduling.
|
55
|
+
//
|
56
|
+
// Our scheduling strategy is as follows: We first split the inputs and outputs
|
57
|
+
// of the fusion into two groups according to their inner most dimension. The
|
58
|
+
// inner most dimensions of tensors in the same group are mapped to each other,
|
59
|
+
// and they are not mapped to the inner most dimension of tensors in a different
|
60
|
+
// group. Depending on the transpose pattern, there can be more than two groups,
|
61
|
+
// if this is the case, we only consider the two largest groups, and the tensors
|
62
|
+
// in the remaining groups will just be accessed unvectorized and uncoalesced.
|
63
|
+
// We call the largest group as `group1` and the second largest group as
|
64
|
+
// `group2`. When we have the groups, we will make a 2D tiling [I1, I2] ->
|
65
|
+
// [I1/tile1, tile1, I2/tile2, tile2] on the inner most dimensions of group1 and
|
66
|
+
// group2. If I1 and I2 are too small to make a 32x32 tile, such as in the
|
67
|
+
// fusion of transpose(T1[1024, 2, 1024, 2], {1, 3}), we merge in other
|
68
|
+
// dimensions to make a virtual I1 and I2. The details of how we create virtual
|
69
|
+
// I1 and I2 are described in note [Supporting small transpose dimensions].
|
70
|
+
//
|
71
|
+
// Each tile [tile1, tile2] will be handled by a block, and the tensors that
|
72
|
+
// have mismatched threadIdx.x bindings will use shared memory. The outer IDs of
|
73
|
+
// the tiling split will be merged with non-tiled IDs and then binded to
|
74
|
+
// blockIdx.x for the entire DAG, regardless of which group a tensor belongs to.
|
75
|
+
// For the inner tile IDs [tile1, tile2], we need to transform and parallelize
|
76
|
+
// group 1 and group 2 differently. The intermediate tensors can be transformed
|
77
|
+
// and parallelized consistently either with group 1 or group 2. Here, since
|
78
|
+
// group 1 is larger than group 2, we decide to only transform and parallelize
|
79
|
+
// the cached inputs of group 2 together with group 2, and keep the rest of the
|
80
|
+
// DAG consistent with group 1.
|
81
|
+
//
|
82
|
+
// If you would like to see an example of how to manually schedule a complicated
|
83
|
+
// DAG using this idea, refer to:
|
84
|
+
// FusionManualScheduleTransposeComplexDAG1_CUDA
|
85
|
+
|
86
|
+
class SchedulerRuntimeInfo;
|
87
|
+
class HeuristicDataCache;
|
88
|
+
|
89
|
+
//! Utility for canSchedule interface to check if this fusion has at least two
|
90
|
+
//! groups, each with a fully broadcasted reference tensor.
|
91
|
+
NVF_API bool hasAtLeastTwoValidGroups(Fusion* fusion);
|
92
|
+
|
93
|
+
//! Scheduler for transpose-dominated fusions. Implements the 2D-tiling
//! strategy described in note [Transpose scheduling] above: split the
//! fusion's inputs/outputs into groups by inner-most dimension, tile the
//! two largest groups, and swap threadIdx.x bindings through shared memory.
class TransposeScheduler : public SchedulerEntry {
 public:
  //! Compile-time eligibility check on the graph alone — no runtime input
  //! information is available here. Called once per segmented group, so it
  //! does not contribute to dynamic-shape latency (see the interface
  //! contract documented for SchedulerType).
  bool canScheduleCompileTime(Fusion* fusion) override;

  //! Runtime eligibility check using input metadata in `runtime_info`.
  //! Run by both the segmenter and the kernel cache; `data_cache` should be
  //! used to memoize work, since this path contributes to dynamic-shape
  //! latency.
  bool canScheduleRunTime(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicDataCache* data_cache = nullptr) override;

  //! Computes the transpose heuristic parameters (tile sizes, vectorization,
  //! etc.) for this fusion and runtime input configuration.
  std::unique_ptr<HeuristicParams> computeHeuristics(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicDataCache* data_cache) override;

  //! Applies the transpose scheduling transformations to `fusion` using the
  //! previously computed `params`. Called when compiling a kernel.
  void schedule(Fusion* fusion, const HeuristicParams* params) override;

  //! Identifies this scheduler in dispatch tables such as
  //! all_heuristics_in_priority_order.
  constexpr static SchedulerType schedulerType() {
    return SchedulerType::Transpose;
  }
};
|
113
|
+
|
114
|
+
} // namespace nvfuser
|