nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl
- nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/scheduler/utils.h @@ -0,0 +1,771 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <device_lower/pass/loop_rotation.h>
#include <disjoint_set.h>
#include <exceptions.h>
#include <fusion.h>
#include <ir/all_nodes.h>
#include <ir/cloner.h>
#include <scheduler/reduction_heuristic.h>
#include <scheduler/tools/maxinfo_propagator.h>
#include <visibility.h>

namespace nvfuser {

class ComputeAtMap;
class SchedulerRuntimeInfo;
class HeuristicDataCache;

namespace scheduler_utils {

// Assume only half of the register file is available to spend on buffers;
// this is because when we allocate a buffer in registers it has to be accessed
// with a compile-time constant index. Unfortunately nvcc seems to be using
// many registers for indexing. This is a bad estimation of extra register use,
// but it's hard to get a better one.
constexpr int64_t register_file_size_full = (int64_t)256 * 1024;
constexpr int64_t register_file_size = register_file_size_full / 2;
constexpr int64_t register_file_size_56k = (int64_t)56 * 4 * 1024;

// Empirically observed number. Not guaranteed to be a good estimate.
constexpr int64_t register_overhead = 40l;
constexpr int64_t max_registers_per_thread = 255l;
constexpr int64_t bytes_per_register = 4l;

constexpr int64_t x_grid_limit = ((int64_t)1 << (int64_t)31) - (int64_t)1;
constexpr int64_t y_grid_limit = 65535;
constexpr int64_t z_grid_limit = 65535;
constexpr int64_t z_block_limit = 64;

// Find the largest power of 2 that is a factor of n. If n == 0, return the
// largest power of 2 representable by int64_t.
constexpr int64_t maxVectorizationWidth(int64_t n) {
  if (n == 0) {
    // Max representable int has null sign bit then all ones. Shift right then
    // xor to preserve only the most significant bit.
    int64_t m = std::numeric_limits<int64_t>::max();
    return m ^ (m >> 1);
  }
  // For example
  //   n              = b101101000
  //   n - 1          = b101100111
  //   ~(n - 1)       = b010011000
  //   n & (~(n - 1)) = b000001000
  // The key is that subtracting one flips all trailing 0s as well as the least
  // significant 1, so all of the other bits will fail the &, leaving
  // only that 1.
  return n & (~(n - 1));
}

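// [Editor's note] A few compile-time checks, added as a hedged illustration
// of the bit trick above (not part of the original header). Since the
// function is a pure constexpr, static_assert can verify it:
//   24 = 0b11000 -> lowest set bit 8; 20 = 0b10100 -> 4; 7 is odd -> 1.
static_assert(maxVectorizationWidth(24) == 8);
static_assert(maxVectorizationWidth(20) == 4);
static_assert(maxVectorizationWidth(7) == 1);
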
// Largest power of 2 less than or equal to n.
constexpr int64_t lastPow2(int64_t n) {
  NVF_ERROR(n >= 0);
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
  n |= (n >> 16); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
  n |= (n >> 32); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
  return std::max((int64_t)1, n - (n >> 1));
}

// Round up to a multiple of 8 or to the next power of 2, whichever is smaller.
constexpr int64_t roundUpPow2Or8(const int64_t x) {
  auto round_up_pow2 = lastPow2(x);
  if (round_up_pow2 < x) {
    round_up_pow2 *= 2;
  }
  constexpr int64_t kEight = 8;
  auto round_up_8 = x % kEight == 0 ? x : x + (kEight - x % kEight);
  return std::min(round_up_8, round_up_pow2);
}

constexpr int64_t roundUpPow2(const int64_t x) {
  auto round_up_pow2 = scheduler_utils::lastPow2(x);
  if (round_up_pow2 < x) {
    round_up_pow2 *= 2;
  }
  return round_up_pow2;
}

constexpr int64_t roundUpToN(const int64_t x, const int64_t n) {
  return x % n == 0 ? x : x + (n - x % n);
}

// Div x by y, but clamp the result to a minimum of 1.
inline int64_t safeDiv(const int64_t x, const int64_t y) {
  return std::max(x / y, (int64_t)1);
}

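// [Editor's note] Worked examples for the rounding helpers above, added as a
// hedged illustration (not part of the original header):
//   roundUpPow2Or8(5)  == 8   (next pow2 = 8, next multiple of 8 = 8)
//   roundUpPow2Or8(17) == 24  (next pow2 = 32, next multiple of 8 = 24)
//   safeDiv(3, 8)      == 1   (plain integer division would give 0)
static_assert(roundUpToN(10, 4) == 12);
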
// Split the given dimensions in `to_split`. Also update the dimensions in
// `to_update` to the positions in the split tensor. Splitting one dimension
// multiple times is supported, and if this is the case, then the order of
// `to_split` matters. All given dimensions are numbered as they were before
// any split.
void splitDims(
    TensorView* tv,
    std::vector<std::pair<int64_t, int64_t>> to_split, // (dim, size)
    std::vector<int64_t>& to_update);

inline void splitDims(
    TensorView* tv,
    std::vector<std::pair<int64_t, int64_t>> to_split) { // (dim, size)
  std::vector<int64_t> unused;
  splitDims(tv, std::move(to_split), unused);
}

// Merge all the given dimensions in `to_merge` into a single dimension. Also
// update the dimensions in `to_update` to the positions in the merged tensor.
// Returns the merged dimension. All given dimensions are numbered as they were
// before any merge.
// NOTE: merging is done on the entries in the order of `to_merge`, assuming an
// order from inner to outer.
std::optional<int64_t> mergeDims(
    TensorView* tv,
    std::vector<int64_t> to_merge,
    std::vector<int64_t>& to_update);

inline std::optional<int64_t> mergeDims(
    TensorView* tv,
    std::vector<int64_t> to_merge) {
  std::vector<int64_t> unused;
  return mergeDims(tv, std::move(to_merge), unused);
}

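// [Editor's note] A hypothetical usage sketch for splitDims, added for
// illustration; `exampleSplitTracking` is not part of the original header.
// Splitting dim 1 of a 3-D tensor shifts later positions right by one, and
// `to_update` is rewritten accordingly.
inline void exampleSplitTracking(TensorView* tv) {
  std::vector<int64_t> tracked = {2}; // track what was dim 2 of [I0, I1, I2]
  splitDims(tv, {{1, 4}}, tracked);   // domain becomes [I0, I1/4, 4, I2]
  // tracked[0] is now 3, the position of I2 after the split.
}
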
// Merges all reduction axes to the right side and returns the total number of
// reduction axes.
int64_t mergeReduction(TensorView* tv);

// Merges all non-reduction axes to the left side and returns the total number
// of iteration axes.
int64_t mergeNonReduction(TensorView* tv);

// Propagate the parallelization from the selected dimensions of the reference
// tensor to their corresponding dimensions in all selected tensors in the DAG.
// Position `pos` means selecting all the dimensions [0, 1, ..., pos - 1]. pos =
// -1 means selecting all dimensions. `selected_tvs` are selected tensors in the
// DAG. Empty `selected_tvs` means selecting all tensors in the fusion of
// `reference_tv`. `selected_parallel_types` are the selected parallel types.
// Empty `selected_parallel_types` means selecting all parallel types.
void parallelizeAllLike(
    TensorView* reference_tv,
    int64_t pos = -1,
    std::vector<TensorView*> selected_tvs = {},
    const std::unordered_set<ParallelType>& selected_parallel_types = {},
    bool propagate_padding = true);

inline void parallelizeAllLike(
    TensorView* reference_tv,
    std::vector<TensorView*> selected_tvs,
    const std::unordered_set<ParallelType>& selected_parallel_types = {},
    bool propagate_padding = true) {
  parallelizeAllLike(
      reference_tv,
      -1,
      std::move(selected_tvs),
      selected_parallel_types,
      propagate_padding);
}

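// [Editor's note] A hedged usage sketch (not part of the original header):
// bind parallel types on a reference tensor, then propagate them to the rest
// of the fusion. `exampleParallelize` is a hypothetical helper.
inline void exampleParallelize(TensorView* reference_tv) {
  reference_tv->axis(0)->parallelize(ParallelType::BIDx);
  reference_tv->axis(-1)->parallelize(ParallelType::TIDx);
  // Empty selected_tvs: every tensor in reference_tv's fusion receives the
  // matching BIDx/TIDx bindings on its mapped dimensions.
  parallelizeAllLike(reference_tv);
}
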
// Common hyperparameters used in the heuristic scheduler. These hyperparameters
// are passed to SchedulerEntry::computeHeuristics through the
// HeuristicDataCache. These hyperparameters alter the generation of the
// HeuristicParams for the scheduler.
struct SchedulerHyperParameters {
  SchedulerHyperParameters(
      int64_t vectorize_factor_,
      int64_t unroll_factor_,
      int64_t threads_per_block_min_,
      int64_t threads_per_block_max_)
      : vectorize_factor(vectorize_factor_),
        unroll_factor(unroll_factor_),
        threads_per_block_min(threads_per_block_min_),
        threads_per_block_max(threads_per_block_max_) {}

  //! Number of elements to load per vectorized load.
  int64_t vectorize_factor = 1;

  //! Number of iterations to unroll the for-loop.
  int64_t unroll_factor = 1;

  //! Minimum number of threads per block.
  int64_t threads_per_block_min = 1;

  //! Maximum number of threads per block.
  int64_t threads_per_block_max = 1;
};

struct PersistentBufferInfo {
  std::vector<TensorView*> persistent_buffers;
  std::unordered_set<IterDomain*> unmappable_dims;

  // Persistent buffers are needed until the path through the reduction -
  // broadcast chain is resolved by any other chain using the persistent buffer
  // that is not going through a reduction. This assumes all reduction paths
  // have the same reduction pattern. Order is the same as persistent_buffers.
  std::vector<std::vector<TensorView*>> persistent_buffer_resolution_points;

  // Not all persistent buffers can be projected to inputs. If a buffer can be
  // projected to the inputs, which may reduce the persistent buffer size (BN
  // backwards specifically), then keep track of it here. Persistent buffers
  // that have a persistent buffer/reduction before them should not be
  // projected through that.
  std::vector<TensorView*> projectable_persistent_buffers;

  // Track inputs of input-projectable buffers.
  std::vector<TensorView*> projectable_buffer_inputs;

  // Map unmappable dims to projectable_buffer_inputs.
  std::unordered_set<IterDomain*> unamppable_dims_projected_to_inputs;

  // Some parameters used in
  // normalization_scheduler_utils::isProjectBufferToInput
  bool has_view_ops = false;
  bool projection_with_exp_op = false;
  bool projection_with_rng_op = false;
};

// Buffers whose roots can't map to all producer roots based on compute at.
// These are the buffers we would make persistent in a persistent kernel or
// would have to recompute if we can't make a persistent kernel. This function
// will also return inputs as being marked persistent if they follow this
// pattern. It is important to note, however, that inputs don't strictly have
// to be persistent as they can simply be read multiple times from GMEM in the
// same kernel.
PersistentBufferInfo persistentBuffers(Fusion* fusion);

// A persistent tv can be projected to its producers when all the producers are
// persistent tvs and there is no reduction op.
bool canProjectToPersistentProducer(
    TensorView* buffer,
    const std::vector<TensorView*>& producers,
    const std::unordered_set<TensorView*>& persistent_buffer_set);

//! Evaluates if a persistent buffer can be projected to input tvs without
//! dependency on reduction tvs. Returns a std::pair with a boolean indicating
//! whether projection is feasible and a vector of projectable tvs.
//!
//! The function operates in two main steps:
//! (1) Checks if the persistent buffer has dependencies on any of the given
//! reduction tvs. If no dependencies are found, it returns true with an
//! empty vector of target broadcast tvs.
//! (2) If there are dependencies, it examines each reduction tv for an
//! associated broadcast tv that can be projected to. If all reduction tvs
//! have corresponding broadcast tvs, true is returned along with these tvs.
//! If any reduction tv lacks a corresponding broadcast tv, false is
//! returned with the current list of identified broadcast tvs.
std::pair<bool, std::vector<TensorView*>> canProjectToInputsWithoutReduction(
    const std::vector<TensorView*> reduction_tvs,
    TensorView* persistent_buffer);

struct ReductionTvProperties {
  // How many elements in the tensor view are there to reduce.
  int64_t total_reduction_numel = 1;

  // How many reductions do we need to perform, i.e. how many iter dimension
  // elements are there.
  int64_t total_iteration_numel = 1;

  // Is the inner most dimension a reduction; if there are no reductions, mark
  // true.
  bool fastest_dim_reduction = true;

  // How many elements are in the inner most dimension, merging surrounding
  // domains that match in type. This is used for 3D schedulers in
  // reduction/normalization.
  int64_t inner_most_dimension_numel = 1;

  // Same thing as above, but the number of dimensions instead of the numel.
  int64_t inner_most_dimension_ndims = 1;

  // Merging neighboring iteration domains, and reduction domains, what's the
  // resulting dimensionality of the problem.
  int64_t dimensionality = 1;
};

// Fill a ReductionTvProperties structure about tv.
ReductionTvProperties getReductionProperties(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    TensorView* tv);

// Struct to store persistent buffer sizes. Also holds the persistent buffer
// size when the buffers are projected to the inputs.
struct PersistentBufferSizeReturn {
  int64_t persistent_buffer_size = 0;
  int64_t projected_persistent_buffer_size = 0;
};

// Compute the amount of register space that would be needed to perform this
// kernel persistently, based only on buffers that must be persistent, and on
// the maximum of all minimum size requirements. I.e., if a buffer must be
// persistent, only hold the persistent dimension.
PersistentBufferSizeReturn persistentBufferSize(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    const PersistentBufferInfo& persistent_buffers,
    HeuristicDataCache* data_cache = nullptr);

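// [Editor's note] Hedged sketch of how the two queries above compose; the
// helper `examplePersistentBudgetCheck` is hypothetical and not part of this
// header. Gather the buffers that must live across the reduction, measure
// them, and compare against the register budget reserved for buffers.
inline bool examplePersistentBudgetCheck(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info) {
  PersistentBufferInfo info = persistentBuffers(fusion);
  PersistentBufferSizeReturn sizes =
      persistentBufferSize(fusion, runtime_info, info);
  // The kernel fits if the smaller of the raw and input-projected footprints
  // stays within the half register file reserved for buffers.
  int64_t needed = std::min(
      sizes.persistent_buffer_size, sizes.projected_persistent_buffer_size);
  return needed <= register_file_size;
}
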
// Merges the tensor view to the form:
// [IterationDomain, ReductionDomain]
// Returns whether <iteration dimensions, reduction dimensions> are present.
std::pair<bool, bool> canonicalDimReduction(
    Fusion* fusion,
    TensorView* tv,
    bool schedule_3D = false);

// Return a list of tensor views that are outputs of reduction operations,
// excluding resharding reduce expressions. If multiple outputs of an expression
// are found, only include one in the list.
std::vector<TensorView*> getReductionTvs(Fusion* fusion);

// Returns a list of TensorViews that are the consumer tv for a view operation.
std::vector<TensorView*> getViewTVs(Fusion* fusion);

// Returns a list of non-reduction TensorViews that have a root domain.
std::vector<TensorView*> getTVsWithNonReductionRFactor(Fusion* fusion);

// Reset inputs and outputs to global memory, everything else to local.
void clearMemorySpace(Fusion* fusion);

// Returns cached-after tensors of the fusion inputs if unrolled. Otherwise
// returns an empty vector.
std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll);

// Returns the pairs of <cache of each fusion output, corresponding output> for
// all outputs.
std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
    Fusion* fusion,
    bool unroll);

// Ignores broadcast and reduction, returns the iter domain in the allocation
// domain that's "inner most".
IterDomain* innerMostAllocDim(TensorView* tv);

// Looks through the fusion and finds all dims that match the one provided in
// the given tensorview. The iter domain must be a root domain. If inner_only,
// will only map dimensions if they're in the inner most position. This is
// important when projecting a dimension between an rfactor position and its
// root position when mapping from consumer to producer. If inner_only=true,
// takes the rfactor/root dimension that maps, projects it to the root/rfactor
// domain, but only follows the inner most path when encountering split/merge.
// When propagating backward, for split it will only propagate backwards if the
// mapped dimension is the inner portion of the split. For merge, inner_only
// doesn't make a difference and will propagate through the inner portion of
// the merge. When propagating forward, the logic is symmetric with the
// backward case.
class FindAllMappedDims : public MaxInfoSpanningTree::Propagator {
  std::unordered_map<TensorView*, IterDomain*> mapped_root_ids_;
  std::unordered_map<TensorView*, IterDomain*> mapped_logical_ids_;
  TensorView* starting_tv_ = nullptr;
  IterDomain* starting_id_ = nullptr;
  bool inner_only_;
  bool vectorize_pass_;

 public:
  FindAllMappedDims(
      TensorView* from,
      IterDomain* starting_id,
      bool inner_only,
      bool vectorize_pass);
  void setUp() override;
  void propagateC2P(TensorView* from, TensorView* to) override;
  void propagateP2C(TensorView* from, TensorView* to) override;
  void propagateSibling(TensorView* from, TensorView* to) override;
  std::unordered_set<IterDomain*> get() const;
};

// Checks if the tensor view has an iteration domain in vector_dims in its
// inner most root position (excluding broadcast and reduction), and checks if
// it is a contiguous dimension.
bool hasInnerDim(
    TensorView* tv,
    std::unordered_set<IterDomain*> vector_dims,
    bool should_vectorize);

// Returns all inputs and outputs that share the inner most dimension of the
// provided reference. If the reference is an input it ignores reduction axes;
// it will ignore all broadcast axes. If inner_only, will require inner->inner
// mapping in view; otherwise, it allows any inner->any mapping. If
// vectorize_pass, will check contiguity for vectorization; otherwise it just
// checks it has that inner dim.
std::vector<TensorView*> getInputsOutputsWithInnerDim(
    TensorView* reference_tv,
    bool inner_only,
    bool vectorize_pass);

// Holder return struct for the below function.
struct DisjointLogicalSetInfo {
  // const* to the disjoint set in disjoint_rfactor_set (passed in to
  // getDisjointLogicalSetsOf) that each iterdomain in the rfactor of ref is
  // mapped to.
  //
  // WARNING: these pointers are relative to the disjoint_rfactor_set reference
  // passed into getDisjointLogicalSetsOf; it's the user's responsibility to
  // maintain the lifetime of that reference to match this vector.
  std::vector<const VectorOfUniqueEntries<IterDomain*>*> disjoint_sets_of_ref;

  // Unique ID associated to the disjoint view group the logical id belongs to
  // in disjoint_sets_of_ref. It's straightforward to map from
  // disjoint_sets_of_ref to the vector, but not the other way around.
  std::vector<int64_t> disjoint_set_ids;

  // TensorView reference the above vectors are relative to.
  TensorView* ref;
};

// Returns disjoint rfactor sets mapped onto the given reference. Returns a pair
// of vectors of size rfactorDomain of reference. The vector of
// VectorOfUniqueEntries returns a const* to the disjoint set in
// disjoint_rfactor_set the iterdomain is mapped to. The integer vector
// represents which disjoint rfactor group the logical id belongs to. It's
// straightforward to map from the former to the latter, but not the latter to
// the former.
//
// Since we return a const* to entries in disjoint_rfactor_set, it must be
// passed in as a reference. The algorithm is N^2 based on the number of dims
// in the reference, but generating the disjoint rfactor set is likely the
// limiter on perf of this function.
//
// logical_reorder_map is provided to assume TensorView `of` will be reordered
// per the map.
DisjointLogicalSetInfo getDisjointLogicalSetsOf(
    Fusion* fusion,
    TensorView* of,
    DisjointSets<IterDomain*>& disjoint_rfactor_set,
    const std::unordered_map<int64_t, int64_t>& logical_reorder_map = {});

// Structure to hold byte multiples for break points. I.e. if we have the
// tensors:
// T0[I0, I1] float
// T1[I0, I1] bool
// T2[I0] half
// T3[I1] double
// and a break point of 1, the multiples would be:
// lhs_multiple = 4 + 1 + 2 = 7
// rhs_multiple = 4 + 1 + 8 = 13
struct BroadcastMultiple {
  int64_t rhs_multiple = 0;
  int64_t lhs_multiple = 0;
};

struct BroadcastMultipleInformation {
  std::vector<int64_t> view_disjoint_set_ids;
  std::vector<BroadcastMultiple> broadcast_multiples;
};

// Returns a vector of size reference_tv->getLogicalDomain().size() which
// is a view disjoint set id of each of those iter domains. If entries share
// the same value, they undergo view transformations in the fusion together.
// Broadcast multiples are also of size
// reference_tv->getLogicalDomain().size(); each entry [i] is the number of
// inputs/outputs that have a non-broadcast dimension mapped to the
// corresponding dimension in reference_tv. Broadcast multiples includes
// reference_tv if reference_tv is an input or output. Broadcast multiples is
// multiplied by the data type size. In the case of view operations the
// broadcast multiple is the full multiple size if any domain in the group maps
// to a non-broadcast dimension in the given input/output. Otherwise, if all
// dimensions are broadcast, that input/output will not contribute to the
// multiple.
//
// logical_reorder_map is provided to assume reference_tv will be reordered per
// the map.
BroadcastMultipleInformation getBroadcastMultiples(
    TensorView* reference_tv,
    DataType index_type,
    const std::unordered_map<int64_t, int64_t>& logical_reorder_map = {});

//! Propagate current transformations on from_tv up to the given
//! position, to all tensorviews on the owning fusion that have
//! a connection with `from_tv` on the fusion graph.
void transformPropagateToAllFrom(TensorView* from_tv, int64_t pos);

//! A type of custom transform propagator that propagates iterdomain
//! transforms from a source tv to all tvs that are selected
//! using a "direction" and a "boundary".
//!
//! The propagation model always assumes a `from_tv`, a `direction` and a
//! `boundary`.
//!
//! This propagator will only transform producers and consumers
//! of `from_tv`, and all propagation modes **require** a boundary to be
//! specified to signify where the propagation should stop.
//!
//! There are currently three modes of propagation: forward, backward and
//! both-way; see the comments on the interface functions for details.
struct BoundedDirectionalTransformPropagator {
  //! Custom option container for configuring
  //! the transform propagation actions.
  //! All option values default to false unless
  //! the corresponding setter is called.
  struct Options {
    //! If true, the transform propagator will
    //! also propagate parallel types from
    //! `from_tv` to all selected tvs.
    bool propagate_parallel_type = false;

    //! If true, the specified boundary tvs
    //! will also be replayed as `from_tv`.
    //! If false, they will not be affected
    //! by the propagation pass.
    bool transform_boundary = false;

    //! Sets the position boundary in parallel
    //! type propagation; see the comment on
    //! scheduler_utils::parallelizeAllLike.
    //! Only used if propagate_parallel_type==true.
    int64_t parallel_propagation_pos = -1;

    //! Setter for enabling parallel type
    //! propagation; see the comment on the variable.
    //!
    //! \param up_to_pos sets the parallel type
    //!  propagation boundary; see the comment on
    //!  scheduler_utils::parallelizeAllLike.
    Options propagateParallelType(int64_t up_to_pos = -1) {
      propagate_parallel_type = true;
      parallel_propagation_pos = up_to_pos;
      return *this;
    }

    //! Setter for enabling propagation to
    //! boundary tvs; see the comment on the variable.
    Options propagateToBoundary() {
      transform_boundary = true;
      return *this;
    }
  };

  //! Replay transforms from tensorview `from`
  //! to the tensorviews that are consumers
  //! of boundary tensorviews in `to` and producers of `from`.
  static void backward(
      TensorView* from,
      int64_t pos,
      std::vector<TensorView*> to,
      std::optional<Options> options = std::nullopt);

  //! Replay transforms from tensorview `from`
  //! to the tensorviews that are producers
  //! of boundary tensorviews in `to` and consumers of `from`.
  static void forward(
      TensorView* from,
      int64_t pos,
      std::vector<TensorView*> to,
      std::optional<Options> options = std::nullopt);

  //! Replay transforms from tensorview `from`
  //! to all the tensorviews that are consumers
  //! of tensorviews in `backward_to` and producers
  //! of tensorviews in `forward_to` while being
  //! either a producer or a consumer of tensorview `from`.
  static void bothWays(
      TensorView* from,
      int64_t pos,
      std::vector<TensorView*> backward_to,
      std::vector<TensorView*> forward_to,
      std::optional<Options> options = std::nullopt);

 private:
  //! Utility function:
  //! Will realize the transform propagation to the
  //! tensorviews in `included_tvs`.
  //! Assumes that all tvs in included_tvs are either
  //! a producer or a consumer of from_tv.
  static void propagate(
      TensorView* from_tv,
      int64_t pos,
      std::unordered_set<TensorView*> included_tvs,
      Options options);
};

// Schedulers typically start by merging some axes together then splitting,
// and propagating those transformations through the DAG. What we want to
// understand is if these merges can be supported through view operations.
// For example, it could be problematic to support a reduction fusion:
//
// tv0[2, 3, 4]
// tv1 = sum(tv0, {1, 2})
// tv2 = view(tv0, {6, 4})
//
// The first step of the reduction scheduler would be tv1->merge(1, 2).
// If we tried to propagate this transformation through the view it would make
// the view invalid. If we tried to propagate the view through the reduction,
// it would attempt to merge a reduction and non-reduction dimension. So for
// these types of fusions we would like to understand that the view considers
// axes 1 and 2 of tv1 as "non-separable" axes.
//
// If IterDomains are disjoint in the returned set, then they are considered
// "separable".
// Warning: This pass generates the IdGraphs; it is not intended for use at
// runtime.
DisjointSets<IterDomain*> disjointLogicalSets(Fusion* fusion);

// Makes sure that there are no group ids left of pos that match right of pos.
// e.g.
// [1, 0, 0] pos 2 would return false
// [1, 0, 0] pos 1 would return true
bool breakIsDisjoint(std::vector<int64_t> group_ids, int64_t pos);

// Generates an old-to-new map to reorder tv's domain as the logical order.
// Priority is given to inner most dimensions, for example:
// logical [i0, i1, i2]
// domain [i0*i2, i1]
// will produce the map {{0, 1}, {1, 0}}
// This is somewhat similar to orderTiledConcreteIdAsRoot.
std::unordered_map<int64_t, int64_t> domainReorderAsLogicalMap(TensorView* tv);

// Generates an old-to-new map to reorder tv's domain as the logical order.
// This only handles the simple case where allocation is a permutation of the
// logical domain; otherwise, the function returns an empty container.
std::unordered_map<int64_t, int64_t> maybeLogicalReorderAsAllocationMap(
    TensorView* tv);

// Assumes views are consistent as detected by
// registry.cpp::requiresForwardViewReplay returning false.
void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map);

//! Check if tv is an output of a fastest-dim reduction.
bool isFastestDimReduction(TensorView* tv);

// A wrapper for Fusion::rotateLoop that provides a more consistent interface.
inline void rotateLoop(
    TensorView* loop_tv,
    int64_t axis,
    std::unordered_set<Statement*> selection) {
  auto fusion = loop_tv->fusion();
  if (!fusion->hasManaged("loop_rotation")) {
    fusion->manage("loop_rotation", LoopRotationParam{});
  }
  fusion->getManaged<LoopRotationParam>("loop_rotation")
      .emplace_back(loop_tv, axis, std::move(selection));
}

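// [Editor's note] A hedged usage sketch; `exampleRotateLoop` and its tensors
// are hypothetical, not part of the original header. Repeated calls append
// entries to the managed "loop_rotation" parameter, which lowering consumes.
inline void exampleRotateLoop(TensorView* loop_tv, TensorView* load_tv) {
  // Select the producer tensor and its defining expression so its load can
  // be peeled into the loop prologue; axis 0 is the loop being rotated.
  rotateLoop(loop_tv, 0, {load_tv, load_tv->definition()});
}
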
//! Certain tensors may need to be placed on shared or global memory
//! due to data dependencies caused by resize operations. Create
//! caches of those tensors so that the original operations producing
//! them can keep using the same memory. This avoids, for example,
//! reductions to global memory.
//!
//! Example:
//!
//! tv1 = sum(tv0)
//! tv2 = some_resize_op(tv1);
//! tv3 = some_other_op(tv1);
//!
//! When tv1 is promoted to Global, we want to avoid reducing to a
//! global memory tensor. After the transformation by this function,
//! the fusion should look like:
//!
//! tv1 = sum(tv0);
//! tv4 = tv1
//! tv4->setMemoryType(Global)
//! tv2 = some_resize_op(tv4)
//! tv3 = some_other_op(tv1);
//!
//! Note that the sum reduction is done using a Local buffer, i.e.,
//! tv1, but the data dependency for the resize op is still satisfied
//! by having a copy of tv1, i.e., tv4. Note that the other op using
//! tv1 still uses tv1.
void prepareForMemoryTypePromotion(Fusion* fusion);

//! If a consumer tensor induces a data dependency between threads,
//! move its producer to a shared memory that is sufficient to satisfy
//! the dependency. For example, if the domain is parallelized
//! with blockIdx, the producer memory type will be changed to
//! Global. A proper RAW sync will be automatically inserted when the
//! fusion is lowered.
void promoteProducerMemoryTypes(
    Fusion* fusion,
    const std::vector<TensorView*>& input_caches);

//! Get all tensors that are connected to from_tvs without going through
//! any tvs in the cutoff_tv_set.
std::unordered_set<TensorView*> getAllTvsFrom(
    const std::vector<TensorView*>& from_tvs,
    const std::unordered_set<TensorView*>& cutoff_tv_set);

//! Get the persistent buffer size of a tensor.
int64_t getPersistentBufferSizeOfTensor(
    const TensorView* buffer,
    SchedulerRuntimeInfo& runtime_info,
    const PersistentBufferInfo& persistent_buffer_info);

//! The required shared memory size for a block includes two parts: (1) smem
//! for persistent buffers and (2) overhead. The overhead includes space
//! reserved by the CUDA driver and the reduction workspace, which depends on
//! the number of threads per block specified by the parameter
//! threads_per_block. By default, the function uses the maximum allowed number
//! of threads per block (threads_per_block = -1) to calculate the overhead.
//! The caller can specify a different value if they are sure about the max
//! value used at runtime.
int64_t getSharedMemoryOverheadPerBlock(
    Fusion* fusion,
    const std::vector<TensorView*>& reduction_tvs,
    int64_t threads_per_block = -1);

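// [Editor's note] Hedged sketch of a shared-memory budget check built on the
// two queries above; `exampleFitsInSharedMemory` and the capacity parameter
// are hypothetical, not part of the original header.
inline bool exampleFitsInSharedMemory(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    const std::vector<TensorView*>& reduction_tvs,
    const PersistentBufferInfo& persistent_buffer_info,
    TensorView* buffer,
    int64_t smem_capacity_bytes) {
  int64_t overhead = getSharedMemoryOverheadPerBlock(fusion, reduction_tvs);
  int64_t buffer_bytes = getPersistentBufferSizeOfTensor(
      buffer, runtime_info, persistent_buffer_info);
  // The buffer fits if its footprint plus driver/workspace overhead stays
  // within the device's per-block shared memory capacity.
  return buffer_bytes + overhead <= smem_capacity_bytes;
}
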
// Returns true if any Expr in `fusion` is resharding.
bool isResharding(Fusion* fusion);

// Move non-concretized broadcast domains to innermost
// positions. Broadcast domains mapped with any domains of the given tvs
// are ignored.
//
// The goal here is to find domains that are not scheduled by
// propagation from reference tensors (i.e., ignored_tvs). All
// schedulers make sure to include only schedulable domains but they
// may also allow non-concretized broadcast domains that have
// no mapping with any of the reference tensors. Since they are
// non-concretized, they should be safe to ignore. Ideally, they
// should just be removed from the fusion. For now, they are moved to
// innermost positions to prevent them from interfering with
// inlining. If they happened to be at the
// outermost position, the tensor wouldn't be inlined at all. See
// issue #2686 and PR #2799.
void moveNonConcretizedBroadcastInnermost(
    Fusion* fusion,
    const std::unordered_set<TensorView*>& ignored_tvs = {});

// Returns a factor representing the computation cost of the given fusion.
// Estimated using the number of MUFU operations, each weighted with a
// predefined factor.
int64_t getComputationCostFactor(Fusion* fusion);

// Returns the required bytes in flight to saturate the memory bandwidth.
int64_t getRequiredBytesInFlight();

// Returns true if the device has a high bandwidth-to-compute ratio.
bool isHighBandwidthFlopsRatio();

// Returns true if the fusion has computation that requires Floating-Point
// Multi-Function (MUFU) units, e.g. exponential, logarithm, sine, cosine,
// square root, hyperbolic tangent. Currently, we have only tested tanh, exp,
// and reciprocal. Note that if compiled with fast math (not supported yet) or
// directly lowered with inlined PTX, the inner reduction heuristics, which use
// this function to set the optimal unroll factor, need to be revised.
bool hasExpensiveMUFUops(Fusion* fusion);

// Reorder DID-parallelized axes to outermost positions. Returns
// the position of the outermost non-DID axis.
int64_t reorderDevicesToOuter(TensorView* tv);

// Returns the number of non-reduction/non-broadcast/non-device dims in the
// logical domain.
inline int64_t nLogicalDims(const TensorView* tv) {
  auto logical_dom = tv->getLogicalDomain();
  int64_t tv_n_dims = 0;
  for (auto dim : logical_dom) {
    if (!dim->isReduction() && !dim->isBroadcast() && !dim->isDeviceDim()) {
      tv_n_dims++;
    }
  }
  return tv_n_dims;
}

// Reorder the loop domain of a given tensor to align with a given list of
// reference IDs. Non-matching loop IDs are placed at the outermost positions.
void reorderTensorLike(TensorView* tv, const std::vector<IterDomain*>& ref);

} // namespace scheduler_utils
} // namespace nvfuser