nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl
- nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/multidevice/communicator.h
@@ -0,0 +1,179 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <ATen/core/TensorBody.h>
+#include <ATen/core/ivalue.h>
+#include <c10/util/intrusive_ptr.h>
+
+#include <exceptions.h>
+#include <multidevice/multidevice.h>
+#ifdef NVFUSER_DISTRIBUTED
+#include <torch/csrc/distributed/c10d/Backend.hpp>
+#include <torch/csrc/distributed/c10d/TCPStore.hpp>
+#include <torch/csrc/distributed/c10d/Work.hpp>
+#else
+#include <multidevice/c10d_mock.h>
+#endif
+#include <visibility.h>
+
+namespace nvfuser {
+
+// This file implements the class Communicator, which sets up the inter-process
+// Backend. This class contains inter-process information, such as the rank and
+// the world size, as well as the Process Group that can be called to perform
+// inter-process communications.
+//
+// Each process is associated with a unique deviceId and device. The actual MPI
+// rank remains private to the class and should not be used by the user. The
+// Communicator class privately holds the mappings ranks <-> device IDs <->
+// devices.
+
+using RankType = DeviceIdxType;
+
+// Supported backends. TODO: gloo untested
+enum class CommunicatorBackend { kNccl, kUcc, kGloo };
+
+std::ostream& operator<<(std::ostream& out, const CommunicatorBackend& cb);
+
+#ifdef USE_C10D_NCCL
+constexpr CommunicatorBackend comm_backend_default = CommunicatorBackend::kNccl;
+#else
+constexpr CommunicatorBackend comm_backend_default = CommunicatorBackend::kUcc;
+#endif
+constexpr int comm_server_local_rank_default = 0;
+
+class Communicator {
+ public:
+  static Communicator& getInstance() {
+    // Using a singleton this way isn't best practice. Ideally, we'd like to
+    // write
+    // ```
+    // static Communicator communicator;
+    // ```
+    // and let the destructor clean it up at program exit after `main` returns.
+    // This however would cause a "driver shutting down" error, likely because
+    // another static variable destructor shuts down the CUDA driver before
+    // ~Communicator. Note that the order of static variable destruction
+    // across translation units is undefined.
+    //
+    // Therefore, we `new Communicator()` as a raw pointer and let the user
+    // call Communicator::getInstance().cleanup() to clean up the Communicator
+    // explicitly before the end of `main`. For example, the cleanup method is
+    // called via MultiDeviceTestEnvironment::TearDown in C++ unit tests and
+    // nvfuser._cleanup() in Python.
+    static auto* communicator = new Communicator();
+    return *communicator;
+  }
+
+  Communicator(const Communicator&) = delete;
+  Communicator& operator=(const Communicator&) = delete;
+  ~Communicator() = delete;
+  // As said in `getInstance`, the user of this class is supposed to call this
+  // method to clean up the singleton. This obviously can only be called once.
+  void cleanup();
+
+  // Returns whether the distributed config is available.
+  auto is_available() const {
+    return is_available_;
+  }
+
+  // Returns the number of processes in the communicator.
+  auto size() const {
+    return size_;
+  }
+
+  // Returns the local number of processes in the communicator (within the
+  // node).
+  auto local_size() const {
+    return local_size_;
+  }
+
+  // Sets the communicator's default backend.
+  void setDefaultBackend(CommunicatorBackend backend) {
+    default_backend_ = backend;
+  }
+
+  // Performs a blocking barrier in the communicator.
+  void barrier(std::optional<CommunicatorBackend> backend = std::nullopt);
+
+  // Returns the backend associated with a team. The argument "prefix" is
+  // prepended to the key used to retrieve preexisting backends; it is used to
+  // distinguish between different backends with the same team.
+  c10d::Backend* getBackendForTeam(
+      const Team& team,
+      std::optional<CommunicatorBackend> backend,
+      const std::string& prefix = "");
+
+  // Returns the device associated with the current process.
+  auto device() const {
+    return at::Device("cuda:" + std::to_string(local_rank_));
+  }
+
+  // Returns the device ID associated with the current process.
+  DeviceIdxType deviceId() const {
+    return rankToDiD(rank_);
+  }
+
+  // Returns the local rank associated with the current process, i.e. the rank
+  // within a machine/node as opposed to the rank within the world.
+  RankType local_rank() const {
+    return local_rank_;
+  }
+
+  // Returns the world backend for the given communicator backend, or for the
+  // default backend if not specified.
+  c10d::Backend* getWorld(
+      std::optional<CommunicatorBackend> backend = std::nullopt);
+
+  // Returns whether a backend is available for creation.
+  bool isBackendAvailable(CommunicatorBackend backend) const {
+    if (backend == CommunicatorBackend::kUcc) {
+      return ucc_available_;
+    } else if (backend == CommunicatorBackend::kNccl) {
+      return nccl_available_;
+    }
+    return false;
+  }
+
+ private:
+  Communicator(
+      CommunicatorBackend backend = comm_backend_default,
+      RankType server_local_rank = comm_server_local_rank_default);
+
+  // Returns the rank corresponding to a device index.
+  RankType dIdToRank(DeviceIdxType d_id) const {
+    return static_cast<RankType>(d_id);
+  }
+
+  // Returns the device index corresponding to a rank.
+  DeviceIdxType rankToDiD(RankType rank) const {
+    return static_cast<DeviceIdxType>(rank);
+  }
+
+  CommunicatorBackend getBackend(std::optional<CommunicatorBackend> backend) {
+    return backend.value_or(default_backend_);
+  }
+
+  bool is_available_;
+  CommunicatorBackend default_backend_;
+  RankType rank_;
+  int64_t size_;
+  RankType local_rank_;
+  int64_t local_size_;
+  std::string master_addr_;
+  int master_port_;
+  bool ucc_available_;
+  bool nccl_available_;
+  // Stores the world's store used for the backend init.
+  c10::intrusive_ptr<c10d::TCPStore> store_;
+  // Cache for the created backends. The keys are strings generated from Teams.
+  std::unordered_map<std::string, c10::intrusive_ptr<c10d::Backend>> backends_;
+};
+
+} // namespace nvfuser
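Usage note: a minimal sketch of how the singleton above is meant to be driven, based solely on the comments in `getInstance()` and `cleanup()`; the launcher that spawns one process per device (e.g. torchrun/mpirun) is assumed.

```cpp
#include <multidevice/communicator.h>

int main() {
  auto& comm = nvfuser::Communicator::getInstance();
  if (comm.is_available()) {
    at::Device device = comm.device(); // cuda:<local_rank> for this process
    comm.barrier(); // blocking synchronization across the world
  }
  // Per the comment in getInstance(), cleanup() must be called exactly once,
  // before main() returns, while the CUDA driver is still alive.
  nvfuser::Communicator::getInstance().cleanup();
  return 0;
}
```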
nvfuser/include/nvfuser/multidevice/device_mesh.h
@@ -0,0 +1,95 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+
+#pragma once
+
+#include <vector>
+
+#include <exceptions.h>
+#include <multidevice/multidevice.h>
+#include <type.h>
+#include <visibility.h>
+
+namespace nvfuser {
+
+// The class DeviceMesh represents a set of (unique) devices on which a
+// Pipeline Stage will be executed. For now, we only support flat meshes, but
+// later we will add support for n-dimensional meshes.
+class DeviceMesh final {
+ public:
+  // https://google.github.io/styleguide/cppguide.html#Implicit_Conversions
+  //
+  // Not using `explicit` for the constructor that takes a vector would lead
+  // to contention between operator<<(std::vector) defined in c10/util/Logging.h
+  // and operator<<(DeviceMesh) defined later in this file, which would be
+  // resolved arbitrarily by the compiler.
+  //
+  // There is no such contention for std::initializer_list, so I chose to
+  // allow implicit conversion for that. This allows users to write `DeviceMesh
+  // mesh = {1, 2};`, which is more concise.
+  explicit DeviceMesh(std::vector<DeviceIdxType> devices = {});
+  DeviceMesh(std::initializer_list<DeviceIdxType> devices);
+  DeviceMesh(const DeviceMesh&) = default;
+  DeviceMesh(DeviceMesh&&) = default;
+  DeviceMesh& operator=(const DeviceMesh&) = default;
+  DeviceMesh& operator=(DeviceMesh&&) = default;
+
+  // Creates a device mesh of [0 .. num_devices-1]. I didn't make it a
+  // constructor because single-element initializer lists would be directed to
+  // use that instead of the constructor for vectors.
+  static DeviceMesh createForNumDevices(int64_t num_devices);
+
+  // Returns the number of devices in the mesh.
+  int64_t size() const {
+    return static_cast<int64_t>(vector_.size());
+  }
+
+  int64_t size(ParallelType parallel_type) const;
+
+  // Returns a vector containing the device indices of the mesh.
+  const std::vector<DeviceIdxType>& vector() const {
+    return vector_;
+  }
+
+  // Returns whether a device is present in the mesh.
+  bool has(const DeviceIdxType device) const {
+    return std::find(vector_.begin(), vector_.end(), device) != vector_.end();
+  }
+
+  // Returns the index of device in the mesh, or -1 if device is not present.
+  int64_t idxOf(const DeviceIdxType device) const {
+    auto it = std::find(vector_.begin(), vector_.end(), device);
+    if (it != vector_.end()) {
+      return std::distance(vector_.begin(), it);
+    }
+    return -1;
+  }
+
+  // Returns the device at a particular index in the mesh.
+  DeviceIdxType at(int64_t index) const {
+    return vector_.at(index);
+  }
+
+  bool operator==(const DeviceMesh& other) const {
+    return vector_ == other.vector();
+  }
+
+  bool operator!=(const DeviceMesh& other) const {
+    return vector_ != other.vector();
+  }
+
+ private:
+  void setDevices(std::vector<DeviceIdxType> devices);
+
+  // Stores the list of device indices.
+  std::vector<DeviceIdxType> vector_;
+};
+
+std::ostream& operator<<(std::ostream& out, const DeviceMesh& mesh);
+
+} // namespace nvfuser
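Usage note: a minimal sketch of constructing and querying a DeviceMesh using only the API declared above; the device indices are illustrative.

```cpp
#include <multidevice/device_mesh.h>

void deviceMeshExample() {
  nvfuser::DeviceMesh mesh = {0, 1, 2, 3}; // implicit initializer_list ctor
  auto by_count = nvfuser::DeviceMesh::createForNumDevices(4); // [0 .. 3]
  bool same = (mesh == by_count); // true: identical device indices
  int64_t n = mesh.size(); // 4
  int64_t pos = mesh.idxOf(2); // 2; would be -1 if absent
  bool present = mesh.has(5); // false
  (void)same; (void)n; (void)pos; (void)present;
}
```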
nvfuser/include/nvfuser/multidevice/executor.h
@@ -0,0 +1,107 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <c10/core/DeviceType.h>
+#include <exceptions.h>
+#include <fusion.h>
+#include <fusion_segmenter.h>
+#include <host_ir/executor.h>
+#include <ir/cloner.h>
+#include <multidevice/communication.h>
+#include <multidevice/communicator.h>
+#include <multidevice/multidevice.h>
+
+namespace nvfuser {
+
+/*
+The MultiDeviceExecutor executes a Fusion in a multi-device setting.
+It is instantiated from a Fusion and a Communicator.
+
+The Fusion must be scheduled prior to the instantiation of the
+MultiDeviceExecutor. One can use the multidevice scheduling API to specify
+the desired tensor sharding. It is composed of two aspects:
+*) Set each tensor's DeviceMesh, through TensorView::setDeviceMesh.
+*) Parallelize each tensor axis, possibly with the multidevice sharding
+   parallel type ParallelType::DIDx.
+
+We make the following assumptions on the Fusion:
+- Only one (non-reduction) axis is allowed to be parallelized
+  with ParallelType::DIDx. Moreover, this axis cannot be split/merged.
+- We only support 1D device meshes for now.
+- We only support TensorViews in communication segments.
+
+Summary of the different steps performed by the MultiDeviceExecutor:
+I. At instantiation:
+- Resharding "Set" exprs are automatically inserted in the fusion where a
+  network communication is needed. See the function insertReshardings.
+- The Fusion is segmented into segments which can be of two types:
+  1) compute segments, composed of non-resharding expressions only,
+     which can be executed purely on a single device, or
+  2) communication segments, composed of exactly one resharding expression,
+     which can be either a "Set" or a "Reduce" Expr.
+- The runtime order of execution of the different segments is computed in
+  prepareRuntimeOrder.
+
+II. At runtime, through the method runWithInput:
+- allocateRecvBuffers allocates on each device the necessary buffers to
+  store the data received from network communications.
+- Each (compute or comm) segment is executed separately, in order:
+  1) each compute segment is transformed into a fusion, compiled, and executed
+     on a single device, see postKernel;
+  2) each comm segment is lowered into a series of communications (defined in
+     multidevice/communications.h) that are posted on the stream.
+     "Wait" primitives are also posted on the stream.
+
+TODOs:
+*) The MultiDeviceExecutor should be integrated into FusionExecutorCache.
+*) The different steps should be divided into compilation, allocation,
+   runtime, etc. This will be done along the way once we have a better
+   symbolic representation of the multidevice modules.
+*) Allocation of buffers needs to be reimplemented.
+*) Need to work on auto-scheduling, in particular to combine inter-/intra-
+   device scheduling.
+*/
+
+class MultiDeviceExecutor {
+ public:
+  MultiDeviceExecutor(
+      std::unique_ptr<Fusion> fusion,
+      Communicator& comm,
+      hir::HostIrEvaluatorParams params = hir::HostIrEvaluatorParams());
+
+  // Runs the fusion on several devices with the given global inputs.
+  std::vector<at::Tensor> runWithInput(const std::vector<c10::IValue>& inputs);
+
+  // Returns the Communicator.
+  Communicator* comm() const {
+    return &comm_;
+  }
+
+  // Checks whether the runtime is valid and returns an error message.
+  // An empty message means that the runtime is valid.
+  std::string validate() const {
+    return host_ir_executor_->canRun();
+  }
+
+  //! Print to default debugging output stream
+  std::ostream& print(std::ostream& os = debug());
+
+  const auto& getFusionExecutorCaches() {
+    return host_ir_executor_->getFusionExecutorCaches();
+  }
+
+ private:
+  // Holds the Communicator to be used for execution.
+  Communicator& comm_;
+  // Holds the HostIrEvaluator used for execution.
+  std::unique_ptr<hir::HostIrEvaluator> host_ir_executor_;
+};
+
+} // namespace nvfuser
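Usage note: a minimal sketch of driving a pre-scheduled Fusion through MultiDeviceExecutor, following steps I and II of the comment above. Building and sharding the fusion are elided, and the `NVF_CHECK` macro from exceptions.h is assumed.

```cpp
#include <multidevice/executor.h>

std::vector<at::Tensor> runSharded(
    std::unique_ptr<nvfuser::Fusion> scheduled_fusion,
    const std::vector<c10::IValue>& inputs) {
  auto& comm = nvfuser::Communicator::getInstance();
  nvfuser::MultiDeviceExecutor executor(std::move(scheduled_fusion), comm);
  // validate() returns an empty string iff the runtime can run this fusion.
  const std::string error = executor.validate();
  NVF_CHECK(error.empty(), error);
  return executor.runWithInput(inputs);
}
```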
nvfuser/include/nvfuser/multidevice/multidevice.h
@@ -0,0 +1,18 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+
+#pragma once
+
+#include <c10/core/Device.h>
+
+namespace nvfuser {
+using DeviceIdxType = int64_t;
+using DimensionType = int;
+using DeviceType = c10::Device;
+using Team = std::vector<DeviceIdxType>;
+} // namespace nvfuser
nvfuser/include/nvfuser/multidevice/utils.h
@@ -0,0 +1,187 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <c10/util/ArrayRef.h>
+
+#include <compute_at_map.h>
+#include <fusion.h>
+#include <id_model/id_model.h>
+#include <ir/interface_nodes.h>
+#include <multidevice/multidevice.h>
+#include <visibility.h>
+
+namespace nvfuser {
+
+// Returns true iff nvFuser was compiled with distributed APIs enabled.
+NVF_API bool distributedEnabled();
+
+// For a resharding expression, either a set or a reduce, returns the root IDs
+// that change sharding:
+// (1) sharded root IterDomains that are added by the expression,
+// i.e. sharded IterDomains that are present in the output, but not the input;
+// (2) sharded root IterDomains that are removed by the expression,
+// i.e. sharded IterDomains that are present in the input, but not the output.
+// TODO: Analyze the loop domain for unsharded/sharded IDs and return their
+// parent root IDs.
+std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges(
+    TensorView* producer,
+    TensorView* consumer);
+
+// Returns whether a TensorView has a non-reduction axis parallelized on DIDx.
+// Checks that the other non-reduction axes are not parallelized on DIDx.
+bool isSharded(const TensorView*);
+
+// Returns the number of device dimensions in a TensorView's loop domain.
+int64_t numDeviceDims(const TensorView*);
+
+// Returns the subset of tvs whose elements have a different multi-device
+// sharding than ref.
+template <typename TvIterator>
+std::unordered_set<TensorView*> getTvsWithDifferentSharding(
+    TensorView* ref,
+    TvIterator tvs) {
+  std::unordered_set<TensorView*> ret;
+  const auto& reference_dom = ref->getLoopDomain();
+  FusionGuard fg(ref->fusion());
+  auto ca_map = ComputeAtMap(FusionGuard::getCurFusion());
+  std::unordered_map<IterDomain*, IterDomain*> concrete_to_reference_map;
+  for (auto id : reference_dom) {
+    auto ca_id =
+        ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE_RESIZE);
+    concrete_to_reference_map[ca_id] = id;
+  }
+
+  for (TensorView* tv : tvs) {
+    if (ref->getDeviceMesh().vector() != tv->getDeviceMesh().vector()) {
+      ret.insert(tv);
+      continue;
+    }
+    for (auto id : tv->getLoopDomain()) {
+      auto ca_id =
+          ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE_RESIZE);
+      if (concrete_to_reference_map.count(ca_id) > 0) {
+        auto ref_id = concrete_to_reference_map.at(ca_id);
+        if ((ref_id->isDeviceDim() || id->isDeviceDim()) &&
+            ref_id->getParallelType() != id->getParallelType()) {
+          ret.insert(tv);
+          break;
+        }
+      }
+    }
+  }
+  return ret;
+}
+
+// Returns whether an Expr embeds multi-device resharding.
+bool isResharding(const Expr* expr);
+
+// Returns whether two tensors have different shardings. Expects a
+// producer/consumer relationship between the arguments.
+bool haveDifferentShardings(
+    const TensorView* producer,
+    const TensorView* consumer,
+    const IdModel& id_model);
+
+// Returns whether a resharding expr reshards an inner axis.
+bool isInnerResharding(Expr* expr);
+
+// Shards all tensors in tvs like the reference.
+void shardAllLike(TensorView* ref, std::vector<TensorView*> tvs);
+
+// Shards all TVs between from and to, AND TVs created inside the fusion
+// between from and to. This is required because (1) expressions like
+// rng_uniform create a TV inside the fusion that is not on a path between
+// user-visible TVs, (2) multi-output expressions may have output tensors
+// that are not along a path to the fusion output and would not be reachable
+// otherwise, and (3) sharding propagation checks that all TVs in the fusion
+// are assigned a device mesh, regardless of whether they are reachable. To
+// keep the checks simple, we require that all TVs in the fusion are assigned
+// a mesh.
+void shardBetween(
+    const std::vector<TensorView*>& from,
+    const std::vector<TensorView*>& to,
+    TensorView* ref);
+// Same as above but using the outputs of the from and to expressions
+// to form the from and to TVs.
+void shardBetween(
+    const std::vector<Expr*>& from,
+    const std::vector<Expr*>& to,
+    TensorView* ref);
+
+// Returns the devices involved in an expr.
+std::set<DeviceIdxType> involvedDevices(Expr* expr);
+
+// Returns the number of device indices present across all
+// device meshes in the Fusion.
+int64_t requestedNumberOfDevices(Fusion*);
+
+// Removes the multi-device scheduling annotations.
+void unshard(Fusion*);
+void unshard(TensorView*);
+
+// Returns the index of the sharded logical axis that produces the allocation
+// IterDomain sharded on `parallel_type`. If `tv` isn't sharded on the parallel
+// type, returns -1.
+//
+// This is used to correlate `tv` and its corresponding at::Tensor, e.g., by
+// `unshardedSizes` and `shardTensor`. `at::Tensor::sizes` and
+// `tv->getLogicalDomain()` map one-to-one modulo reduction. However, a size in
+// `at::Tensor::sizes` is a factor of the corresponding logical IterDomain's
+// extent if that IterDomain is sharded.
+int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type);
+
+// Shards the input tensor along `axis`. How the tensor gets sliced along
+// `axis` is determined by `mesh` and `device_id`. Returns the sharded tensor.
+at::Tensor shardTensor(
+    at::Tensor tensor,
+    int64_t axis,
+    const DeviceMesh& mesh,
+    DeviceIdxType device_id);
+
+// Reorders a TensorView so that the DID-parallelized axes are in front.
+void reorderDIDToFront(TensorView*);
+
+// Given a TensorView and the shape of a sharded tensor of which certain
+// dimensions are partially allocated, returns the global shape that'll be used
+// to bind to the TensorView's logical domain. This is to solve #3282 so we can
+// bind a sharded tensor to a TensorView that has a DID-parallel loop domain.
+//
+// For example, when `tv` is
+//   logical: iM, iN
+//   allocation: iDIDx{D}, iN/D, iM
+// and `sizes` is [2, 3], the returned shape will be [2, 3D]. This is because,
+// according to the allocation domain, iM is fully allocated and iN is sharded
+// and thus partially allocated.
+//
+// If the TensorView is not sharded, this function returns `sizes`.
+//
+// Limitations:
+// - The function assumes that there are no Merges from logical to the
+// DID-parallel IterDomains in allocation. Otherwise, it's unclear which logical
+// dimension this DID-parallelization should be attributed to.
+// - The function assumes that all Splits from logical to the DID-parallel
+// IterDomains in allocation are even. This is because there are currently no
+// ways to pass in the global shape.
+//
+// Despite these limitations, I took this approach as a shortcut to fix #3282,
+// which blocked many other tasks. I'm however open to other better, long-term
+// solutions. Some alternatives considered in #3282 are:
+// - Try to bind `at::Tensor`s to allocation domains instead of logical. Many
+// `*Op::evaluate` methods (e.g.
+// https://github.com/NVIDIA/Fuser/blob/2415d904d1e9a5da7ca6fb1a55d3045bbd510341/csrc/ir/nodes.cpp#L4321-L4329)
+// assume the input/output `at::Tensor`s have the same dimension order as the
+// logical domain. Doing so would have to change them all.
+// - Try to pass into FusionExecutorCache both logical (global) shapes and
+// allocated (local) tensors for sharded TensorViews. The logical shapes would
+// have to be passed through FusionKernelRuntime, FusionExecutor,
+// ExpressionEvaluator, and so on, which is an API overhaul.
+std::vector<int64_t> unshardedSizes(
+    const TensorView* tv,
+    c10::IntArrayRef sizes);
+
+} // namespace nvfuser
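Usage note: a minimal sketch of the shardTensor / unshardedSizes round trip described above, assuming an even split and a `tv` whose allocation domain shards axis 1 on DIDx over a two-device mesh; the shapes and the per-device slicing shown in the comments are illustrative.

```cpp
#include <ATen/ATen.h>

#include <multidevice/device_mesh.h>
#include <multidevice/utils.h>

void shardingExample(nvfuser::TensorView* tv) {
  const auto mesh = nvfuser::DeviceMesh::createForNumDevices(2); // {0, 1}
  at::Tensor global = at::randn({4, 6}); // unsharded (global) tensor
  // Slice axis 1 across the mesh; device 0 holds its share of the columns.
  at::Tensor local =
      nvfuser::shardTensor(global, /*axis=*/1, mesh, /*device_id=*/0);
  // local.sizes() == [4, 3]; unshardedSizes() recovers the global shape by
  // scaling the sharded dimension back up by the mesh size.
  std::vector<int64_t> logical = nvfuser::unshardedSizes(tv, local.sizes());
  // logical == [4, 6] under the stated sharding of tv.
  (void)logical;
}
```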
nvfuser/include/nvfuser/non_divisible_split.h
@@ -0,0 +1,86 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <exceptions.h>
+#include <visibility.h>
+
+#include <ir/all_nodes.h>
+#include <iter_visitor.h>
+
+namespace nvfuser {
+
+//! See doc/reading/divisibility-of-split.md#predication
+//! If an IterDomain is split and its inner output domain is
+//! eventually split too, the second split must be divisible or the
+//! inner domain must be predicated. This class finds Split
+//! expressions that need to be divisible or predicated.
+//!
+//! Second splits are not limited to direct output domains of first
+//! splits; they include splits of indirect descendant domains as well.
+//!
+//! Predicating non-divisible split domains does not work if split
+//! output domains are vectorized, i.e., when ParallelType::Vectorize is
+//! applied to an inner domain of splits. If such a split is non-divisible,
+//! predicating the input domain of the non-divisible split results in
+//! the vectorized operation being predicated out entirely, since we do
+//! not generate a fall-back non-vectorized else path. A runtime check is
+//! done for those domains instead.
+class NVF_API NonDivisibleSplitInfo : public IterVisitor {
+ public:
+  void build(Fusion* fusion);
+
+  const auto& splitsToPredicate() const {
+    return splits_to_predicate_;
+  }
+
+  const auto& splitsToValidate() const {
+    return splits_to_validate_;
+  }
+
+ private:
+  using IterVisitor::handle;
+
+  void handle(Split* split) override;
+
+  void handle(Merge* merge) override;
+
+  //! True if reachable from inner domains of splits
+  bool isReachableFromInnerDomains(IterDomain* id) const;
+
+  //! Forward propagate the reachability information
+  void propagateReachability(Split* split, bool is_protected);
+
+  //! Forward propagate the reachability information
+  void propagateReachability(Merge* merge);
+
+  void clearReachability();
+
+  //! Returns the extent of a split output domain if it's not proven to
+  //! be divisible.
+  Val* getMaybeNonDivisibleExtent(Split* split) const;
+
+  //! Remove redundant predicates as divisibility may be validated at
+  //! run time
+  void removeRedundancy();
+
+  //! Add validations to GpuLower::current()->validations()
+  void addValidations();
+
+ private:
+  //! Split expressions whose input domain must be predicated
+  std::unordered_map<TensorView*, std::vector<Split*>> splits_to_predicate_;
+  //! Split expressions whose divisibility must be validated at run time
+  std::unordered_set<Split*> splits_to_validate_;
+
+  //! Temporarily used for analyzing each tensor
+  TensorView* current_tv_ = nullptr;
+  std::unordered_set<IterDomain*> inner_domains_;
+};
+
+} // namespace nvfuser
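Usage note: a self-contained sketch of the divisibility rule the class above detects, independent of nvFuser internals; all extents and names here are illustrative.

```cpp
#include <cstdint>
#include <iostream>

// After splitting I0{10} by 4, the inner output has extent 4. Splitting that
// inner domain again by f2 needs a predicate (or a runtime divisibility
// check) unless the inner extent is divisible by f2, because ceilDiv-style
// splits over-allocate iterations in the non-divisible case.
bool secondSplitNeedsPredicate(int64_t inner_extent, int64_t second_factor) {
  return inner_extent % second_factor != 0;
}

int main() {
  std::cout << secondSplitNeedsPredicate(4, 3) << "\n"; // 1: predicate needed
  std::cout << secondSplitNeedsPredicate(4, 2) << "\n"; // 0: divisible
  return 0;
}
```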