nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl
Sign up to get free protection for your applications and to get access to all the features.
- nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,27 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <ATen/ATen.h>
|
11
|
+
|
12
|
+
namespace nvfuser {
|
13
|
+
|
14
|
+
//! This returns a slice of a thread local at::Tensor that contains all zeroes.
|
15
|
+
//! Uses of this memory should always "clean up" by resetting the memory to zero
|
16
|
+
//! at the end of the kernel.
|
17
|
+
at::Tensor contigZeroedTensor(
|
18
|
+
const std::vector<int64_t>& sizes,
|
19
|
+
const c10::ScalarType& aten_dtype,
|
20
|
+
const c10::Device& device);
|
21
|
+
|
22
|
+
//! This should be called after each kernel launch to allow subsequent launches
|
23
|
+
//! to re-use allocated memory. Note that it does not free allocated zeroed
|
24
|
+
//! memory, but rather it marks all zeroed memory as available for re-use.
|
25
|
+
void releaseZeroedMemory();
|
26
|
+
|
27
|
+
} // namespace nvfuser
|
@@ -0,0 +1,47 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <exceptions.h>
|
11
|
+
#include <ir/all_nodes.h>
|
12
|
+
#include <visibility.h>
|
13
|
+
|
14
|
+
namespace nvfuser {
|
15
|
+
|
16
|
+
//! Horizontally fuse multiple reductions.
|
17
|
+
//!
|
18
|
+
//! Given a list of tensors produced by ReductionOp, create a new
|
19
|
+
//! GroupedReductionOp expression that takes the input tensors of the
|
20
|
+
//! original reductions and produces the given tensors, replacing
|
21
|
+
//! their defining expressions.
|
22
|
+
//!
|
23
|
+
//! GroupedReductionOp works just like ReductionOp with a potential
|
24
|
+
//! benefit of aggregating synchronizations across individual
|
25
|
+
//! reductions. See the reduction::gridReduce2 runtime function for a
|
26
|
+
//! two-input version of grid reduction.
|
27
|
+
//!
|
28
|
+
//! The grouped reductions must follow several constraints, which
|
29
|
+
//! include:
|
30
|
+
//! - There must not exist any data dependency between individual
|
31
|
+
//! reductions.
|
32
|
+
//! - All reduction output tensors must have the same number of
|
33
|
+
//! dimensions, the same transformations and the same axes to
|
34
|
+
//! reduce.
|
35
|
+
//!
|
36
|
+
//! Note that Welford is not allowed yet, though it should be
|
37
|
+
//! technically straightforward to support horizontal fusions of
|
38
|
+
//! welford ops. Unclear how common it would be in practice, though.
|
39
|
+
//!
|
40
|
+
//! \param reduction_outputs Tensors produced by ReductionOp
|
41
|
+
//! \param error_on_failure Throw an exception if an error is detected
|
42
|
+
//! \return True if successfully grouped
|
43
|
+
NVF_API bool groupReductions(
|
44
|
+
const std::vector<TensorView*>& reduction_outputs,
|
45
|
+
bool error_on_failure = true);
|
46
|
+
|
47
|
+
} // namespace nvfuser
|
@@ -0,0 +1,60 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <fusion.h>
|
11
|
+
#include <host_ir/host_ir.h>
|
12
|
+
|
13
|
+
namespace nvfuser {
|
14
|
+
|
15
|
+
class KernelExecutor;
|
16
|
+
|
17
|
+
namespace hir {
|
18
|
+
|
19
|
+
/*
|
20
|
+
HostIrContainer is used to represent a host program.
|
21
|
+
1) It inherits from Fusion, so that (Host) IRs can be resgistered to it.
|
22
|
+
2) It holds a vector of Host Expressions `top_level_exprs_` that represent the
|
23
|
+
host program. For now, this vector is manually managed. Moreover, because we use
|
24
|
+
a vector as data structure, top_level_exprs_ can only represent linear Host
|
25
|
+
programs. Later, we it should support non-linear program having a DAG structure.
|
26
|
+
*/
|
27
|
+
|
28
|
+
class HostIrContainer final : public Fusion {
|
29
|
+
public:
|
30
|
+
HostIrContainer() = default;
|
31
|
+
HostIrContainer(const HostIrContainer&) = delete;
|
32
|
+
HostIrContainer& operator=(const HostIrContainer&) = delete;
|
33
|
+
|
34
|
+
// Do not have a definition here as it requires the definition of
|
35
|
+
// KernelExecutor due to kernel_executors_.
|
36
|
+
// NOLINTNEXTLINE (modernize-use-equals-default)
|
37
|
+
~HostIrContainer() override;
|
38
|
+
|
39
|
+
//! Print to an output stream
|
40
|
+
std::ostream& print(std::ostream& os) const;
|
41
|
+
|
42
|
+
const std::vector<Expr*>& topLevelExprs() const;
|
43
|
+
|
44
|
+
void pushBackTopLevelExprs(Expr* expr);
|
45
|
+
|
46
|
+
void pushBackKernelExecutor(std::unique_ptr<KernelExecutor> ke);
|
47
|
+
|
48
|
+
KernelExecutor* getKernelExecutor(int64_t index) const;
|
49
|
+
|
50
|
+
Stream* getDefaultStream();
|
51
|
+
|
52
|
+
private:
|
53
|
+
std::vector<Expr*> top_level_exprs_;
|
54
|
+
std::vector<std::unique_ptr<KernelExecutor>> kernel_executors_;
|
55
|
+
Stream* default_stream_ = nullptr;
|
56
|
+
};
|
57
|
+
|
58
|
+
} // namespace hir
|
59
|
+
|
60
|
+
} // namespace nvfuser
|
@@ -0,0 +1,152 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <dispatch.h>
|
11
|
+
#include <expr_evaluator.h>
|
12
|
+
#include <host_ir/container.h>
|
13
|
+
#include <host_ir/host_ir.h>
|
14
|
+
#include <multidevice/communicator.h>
|
15
|
+
#include <runtime/executor.h>
|
16
|
+
#include <runtime/executor_abstract.h>
|
17
|
+
#include <runtime/executor_params.h>
|
18
|
+
#include <runtime/fusion_executor_cache.h>
|
19
|
+
|
20
|
+
#include <c10/cuda/CUDAStream.h>
|
21
|
+
|
22
|
+
namespace nvfuser {
|
23
|
+
|
24
|
+
class HostIrExecutor : public ExecutorAbstract {
|
25
|
+
public:
|
26
|
+
HostIrExecutor(
|
27
|
+
int64_t fusion_id = 0,
|
28
|
+
int64_t concrete_id = 0,
|
29
|
+
int64_t runtime_id = 0,
|
30
|
+
int64_t group_id = 0);
|
31
|
+
|
32
|
+
static bool supported(Fusion* fusion);
|
33
|
+
|
34
|
+
void compile(Fusion* fusion);
|
35
|
+
|
36
|
+
bool isCompiled() const override;
|
37
|
+
|
38
|
+
NVF_API std::vector<at::Tensor> run(
|
39
|
+
KernelArgumentHolder& args,
|
40
|
+
std::vector<at::Tensor> outputs = {});
|
41
|
+
|
42
|
+
const std::unique_ptr<hir::HostIrContainer>& hostContainer() const {
|
43
|
+
return host_ir_container_;
|
44
|
+
}
|
45
|
+
|
46
|
+
private:
|
47
|
+
std::unique_ptr<hir::HostIrContainer> host_ir_container_;
|
48
|
+
Communicator* communicator_;
|
49
|
+
};
|
50
|
+
|
51
|
+
namespace hir {
|
52
|
+
|
53
|
+
/*
|
54
|
+
a HostIrEvaluator evaluates a host programs represented through a
|
55
|
+
HostIrContainer It is instantiated with the desired HostIrContainer, and runs
|
56
|
+
the Host program with concrete inputs by calling the method runWithInput.
|
57
|
+
|
58
|
+
For now HostIrEvaluator is an interpreter; later we could rather compile host
|
59
|
+
code.
|
60
|
+
|
61
|
+
Note: most of the implementation is copy pasted for MultiDeviceExecutor. This
|
62
|
+
duplication will be resolved in the future.
|
63
|
+
*/
|
64
|
+
|
65
|
+
// Set of parameters that control the behavior of HostIrEvaluator
|
66
|
+
struct HostIrEvaluatorParams {
|
67
|
+
// Experimental: whether to use FusionExecutorCache rather than
|
68
|
+
// KernelExecutor.
|
69
|
+
bool use_fusion_executor_cache = false;
|
70
|
+
// Experimental: whether to apply auto-scheduling in FusionExecutorCache if
|
71
|
+
// use_fusion_executor_cache=true. WAR: temporary hack mainly use for
|
72
|
+
// development
|
73
|
+
bool skip_auto_scheduling = false;
|
74
|
+
// Experimental: whether to cache fusion executor. WAR: avoid recompilation
|
75
|
+
// but implicitely assumes that the input shape don't change over iterations
|
76
|
+
bool cache_fusion_executor = false;
|
77
|
+
// number of additional cuda streams to use at runtime for comm+compute
|
78
|
+
// pipelining
|
79
|
+
int64_t number_of_streams = 4;
|
80
|
+
};
|
81
|
+
|
82
|
+
class HostIrEvaluator final : public OptOutDispatch {
|
83
|
+
public:
|
84
|
+
HostIrEvaluator(
|
85
|
+
std::unique_ptr<HostIrContainer> container,
|
86
|
+
Communicator* communicator = nullptr,
|
87
|
+
HostIrEvaluatorParams = HostIrEvaluatorParams());
|
88
|
+
std::vector<at::Tensor> runWithInput(
|
89
|
+
std::unordered_map<Val*, c10::IValue> val_to_IValue);
|
90
|
+
|
91
|
+
const std::vector<Val*>& inputs() {
|
92
|
+
return container_->inputs();
|
93
|
+
}
|
94
|
+
|
95
|
+
const std::vector<Val*>& outputs() {
|
96
|
+
return container_->outputs();
|
97
|
+
}
|
98
|
+
|
99
|
+
std::ostream& print(std::ostream& os) const {
|
100
|
+
return container_->print(os);
|
101
|
+
};
|
102
|
+
|
103
|
+
const auto& getFusionExecutorCaches() {
|
104
|
+
return fec_;
|
105
|
+
};
|
106
|
+
|
107
|
+
const auto& getCudaStreams() {
|
108
|
+
return streams_;
|
109
|
+
}
|
110
|
+
|
111
|
+
// check if the runtime is valid returns an error msg.
|
112
|
+
// An empty message means that the runtime is valid
|
113
|
+
std::string canRun() const;
|
114
|
+
|
115
|
+
private:
|
116
|
+
using OptOutDispatch::handle;
|
117
|
+
void handle(SetCurrentStream* set_current_stream) override;
|
118
|
+
void handle(GetCurrentStream* get_current_stream) override;
|
119
|
+
void handle(Synchronize* synchronize) override;
|
120
|
+
void handle(PostOnStream* post_ir) override;
|
121
|
+
void handle(LaunchKernel* post_ir) override;
|
122
|
+
void handle(Communication* communication) override;
|
123
|
+
void handle(P2PCommunication* communication) override;
|
124
|
+
void handle(Wait* wait) override;
|
125
|
+
void handle(ForLoop* for_loop) override;
|
126
|
+
void handle(StartCoalescing* start_coalescing) override;
|
127
|
+
void handle(EndCoalescing* end_coalescing) override;
|
128
|
+
void handle(kir::IfThenElse* if_then_else) override;
|
129
|
+
void handle(MatmulOp* matmul) override;
|
130
|
+
void handle(LinearOp* linear) override;
|
131
|
+
void handle(kir::Allocate* allocate) override;
|
132
|
+
void unhandled(Statement* stmt) override;
|
133
|
+
|
134
|
+
c10::cuda::CUDAStream getCUDAStream(Stream* stream);
|
135
|
+
|
136
|
+
std::unique_ptr<HostIrContainer> container_;
|
137
|
+
Communicator* communicator_;
|
138
|
+
HostIrEvaluatorParams params_;
|
139
|
+
// Stores concrete computed values
|
140
|
+
ExpressionEvaluator expr_evaluator_;
|
141
|
+
// Cache Fusions, KernelExecutors
|
142
|
+
std::unordered_map<HostUnit*, std::unique_ptr<ExecutorAbstract>> executors_;
|
143
|
+
std::unordered_map<HostUnit*, FusionExecutorCache> fec_;
|
144
|
+
using StreamKey = std::variant<int64_t, Stream*>;
|
145
|
+
std::unordered_map<StreamKey, c10::cuda::CUDAStream> streams_;
|
146
|
+
std::unordered_map<Expr*, c10::intrusive_ptr<c10d::Work>> works_;
|
147
|
+
const int64_t my_device_index_;
|
148
|
+
};
|
149
|
+
|
150
|
+
} // namespace hir
|
151
|
+
|
152
|
+
} // namespace nvfuser
|
@@ -0,0 +1,320 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <fusion.h>
|
11
|
+
#include <ir/base_nodes.h>
|
12
|
+
#include <ir/builder.h>
|
13
|
+
#include <multidevice/communication.h>
|
14
|
+
#include <atomic>
|
15
|
+
|
16
|
+
namespace nvfuser {
|
17
|
+
|
18
|
+
namespace hir {
|
19
|
+
|
20
|
+
/*
|
21
|
+
Host Irs are used to represent a host program. They need to be registered in a
|
22
|
+
HostIrContainer. Each Ir represents a Host data or instruction.
|
23
|
+
*/
|
24
|
+
|
25
|
+
/*
|
26
|
+
HostUnit represents a Fusion in the Host Program. In other words, it
|
27
|
+
represents a compute graph (or a segment of a larger compute graph)
|
28
|
+
represented by a Fusion that should be compiled and executed as a bulked item
|
29
|
+
from the host perspective.
|
30
|
+
|
31
|
+
This IR can be thought as a thin layer around the class `Fusion`, which
|
32
|
+
furthermore inherits from `Expr` so that it is an "IR" in nvFuser IR
|
33
|
+
semantics.
|
34
|
+
|
35
|
+
This IRs fundamentally allows nested IR structures. It could potentially be
|
36
|
+
useful in other instances than HostIrs.
|
37
|
+
|
38
|
+
Its implementation is minimal, the only specifity being the moethod
|
39
|
+
`fusion_to_execute()` that returns the fusion that the IR represents.
|
40
|
+
|
41
|
+
Note: HostUnit has no I/O itself -- however the Fusion it embbeds has I/O of
|
42
|
+
course, which are not registered in the surrounding HostIrContainer.
|
43
|
+
|
44
|
+
Note: Whether HostUnit should inherit from Expr or Val is debatable. Both are
|
45
|
+
possible, I define it as an Expr for now here but am open to change it.
|
46
|
+
*/
|
47
|
+
class HostUnit : public Expr {
|
48
|
+
public:
|
49
|
+
using Expr::Expr;
|
50
|
+
HostUnit(IrBuilderPasskey passkey, std::unique_ptr<Fusion> fusion);
|
51
|
+
HostUnit(const HostUnit* src, IrCloner* ir_cloner);
|
52
|
+
|
53
|
+
HostUnit(const HostUnit& other) = delete;
|
54
|
+
HostUnit& operator=(const HostUnit& other) = delete;
|
55
|
+
HostUnit(HostUnit&& other) = delete;
|
56
|
+
HostUnit& operator=(HostUnit&& other) = delete;
|
57
|
+
|
58
|
+
NVFUSER_DECLARE_CLONE_AND_CREATE
|
59
|
+
std::string toString(int indent_size = 0) const override;
|
60
|
+
std::string toInlineString(int indent_size = 0) const override;
|
61
|
+
const char* getOpString() const override {
|
62
|
+
return "hir::HostUnit";
|
63
|
+
}
|
64
|
+
|
65
|
+
bool sameAs(const Statement* other) const override;
|
66
|
+
|
67
|
+
Fusion* fusion_to_execute() const {
|
68
|
+
return fusion_.get();
|
69
|
+
}
|
70
|
+
|
71
|
+
private:
|
72
|
+
std::unique_ptr<Fusion> fusion_;
|
73
|
+
};
|
74
|
+
|
75
|
+
/*
|
76
|
+
PostOnStream represents the host instruction of executing a HostUnit. Its I/O
|
77
|
+
represents in the host program the concrete I/O that will be bound at runtime
|
78
|
+
to the Fusion's I/O for compilation and execution. At runtime, PostOnStream
|
79
|
+
will compile and launch the kernel lowered from the HostUnit's embedded
|
80
|
+
Fusion.
|
81
|
+
|
82
|
+
Note: later PostOnStream will take a "Stream" argument
|
83
|
+
|
84
|
+
Note: later PostOnStream will also be able to launch network Communications
|
85
|
+
|
86
|
+
Note: later compilation and kernel launch will be separated and represented by
|
87
|
+
distinct Host IRs
|
88
|
+
*/
|
89
|
+
class PostOnStream : public Expr {
|
90
|
+
public:
|
91
|
+
using Expr::Expr;
|
92
|
+
PostOnStream(
|
93
|
+
IrBuilderPasskey passkey,
|
94
|
+
Expr* host_op,
|
95
|
+
std::vector<Val*> inputs,
|
96
|
+
std::vector<Val*> outputs);
|
97
|
+
|
98
|
+
PostOnStream(const PostOnStream& other) = delete;
|
99
|
+
PostOnStream& operator=(const PostOnStream& other) = delete;
|
100
|
+
PostOnStream(PostOnStream&& other) = delete;
|
101
|
+
PostOnStream& operator=(PostOnStream&& other) = delete;
|
102
|
+
|
103
|
+
NVFUSER_DECLARE_CLONE_AND_CREATE
|
104
|
+
|
105
|
+
std::string toString(int indent_size = 0) const override;
|
106
|
+
std::string toInlineString(int indent_size = 0) const override;
|
107
|
+
const char* getOpString() const override {
|
108
|
+
return "hir::PostOnStream";
|
109
|
+
}
|
110
|
+
|
111
|
+
bool sameAs(const Statement* other) const override;
|
112
|
+
|
113
|
+
Expr* hostOpToPost() const {
|
114
|
+
return attributes_.at(0)->as<Expr>();
|
115
|
+
}
|
116
|
+
};
|
117
|
+
|
118
|
+
class LaunchKernel : public Expr {
|
119
|
+
public:
|
120
|
+
using Expr::Expr;
|
121
|
+
LaunchKernel(
|
122
|
+
IrBuilderPasskey passkey,
|
123
|
+
int64_t hic_executor_index, // Index into the HostIrContainer's vector of
|
124
|
+
// KernelExecutors--i.e., the kernel this IR
|
125
|
+
// should launch
|
126
|
+
const std::vector<Val*>& inputs,
|
127
|
+
const std::vector<Val*>& outputs);
|
128
|
+
|
129
|
+
LaunchKernel(const LaunchKernel& other) = delete;
|
130
|
+
LaunchKernel& operator=(const LaunchKernel& other) = delete;
|
131
|
+
LaunchKernel(LaunchKernel&& other) = delete;
|
132
|
+
LaunchKernel& operator=(LaunchKernel&& other) = delete;
|
133
|
+
|
134
|
+
NVFUSER_DECLARE_CLONE_AND_CREATE
|
135
|
+
|
136
|
+
std::string toString(int indent_size = 0) const override;
|
137
|
+
std::string toInlineString(int indent_size = 0) const override;
|
138
|
+
const char* getOpString() const override {
|
139
|
+
return "hir::LaunchKernel";
|
140
|
+
}
|
141
|
+
|
142
|
+
int64_t getIndex() const {
|
143
|
+
return attribute<int64_t>(0);
|
144
|
+
}
|
145
|
+
};
|
146
|
+
|
147
|
+
class Stream : public Val {
|
148
|
+
public:
|
149
|
+
// if index is provided, the IR represents the streams whose index is the
|
150
|
+
// dynamic value of that index. Otherwise, it statically represents a new
|
151
|
+
// Stream.
|
152
|
+
Stream(IrBuilderPasskey passkey, Val* index = nullptr);
|
153
|
+
Stream(const Stream* src, IrCloner* ir_cloner);
|
154
|
+
bool sameAs(const Statement* other) const override;
|
155
|
+
|
156
|
+
NVFUSER_DECLARE_CLONE
|
157
|
+
std::string toString(int indent_size = 0) const override;
|
158
|
+
std::string toInlineString(int indent_size = 0) const override;
|
159
|
+
|
160
|
+
Val* index() const {
|
161
|
+
return index_;
|
162
|
+
}
|
163
|
+
|
164
|
+
private:
|
165
|
+
Val* index_ = nullptr;
|
166
|
+
};
|
167
|
+
|
168
|
+
// Host IR node that makes the given Stream the current stream.
class SetCurrentStream : public Expr {
 public:
  using Expr::Expr;
  SetCurrentStream(IrBuilderPasskey passkey, Stream* stream);

  // IR nodes are managed by the container; copying/moving is disallowed.
  SetCurrentStream(const SetCurrentStream& other) = delete;
  SetCurrentStream& operator=(const SetCurrentStream& other) = delete;
  SetCurrentStream(SetCurrentStream&& other) = delete;
  SetCurrentStream& operator=(SetCurrentStream&& other) = delete;

  NVFUSER_DECLARE_CLONE_AND_CREATE

  std::string toString(int indent_size = 0) const override;
  std::string toInlineString(int indent_size = 0) const override;
  const char* getOpString() const override {
    return "hir::SetCurrentStream";
  }

  bool sameAs(const Statement* other) const override;

  // The stream to make current, stored as the node's first attribute.
  Stream* stream() const {
    return attributes_.at(0)->as<Stream>();
  }
};
|
192
|
+
|
193
|
+
// Host IR node that captures the current stream into a Stream value.
class GetCurrentStream : public Expr {
 public:
  using Expr::Expr;
  GetCurrentStream(IrBuilderPasskey passkey);

  // IR nodes are managed by the container; copying/moving is disallowed.
  GetCurrentStream(const GetCurrentStream& other) = delete;
  GetCurrentStream& operator=(const GetCurrentStream& other) = delete;
  GetCurrentStream(GetCurrentStream&& other) = delete;
  GetCurrentStream& operator=(GetCurrentStream&& other) = delete;

  NVFUSER_DECLARE_CLONE_AND_CREATE

  std::string toString(int indent_size = 0) const override;
  const char* getOpString() const override {
    return "hir::GetCurrentStream";
  }

  // The Stream value produced by this node, stored as its first attribute.
  Stream* stream() const {
    return attributes_.at(0)->as<Stream>();
  }
};
|
214
|
+
|
215
|
+
// Host IR node that waits for completion of an expression (presumably a
// communication -- see communication() below).
class Wait : public Expr {
 public:
  using Expr::Expr;
  Wait(IrBuilderPasskey passkey, Expr* expr);

  // IR nodes are managed by the container; copying/moving is disallowed.
  Wait(const Wait& other) = delete;
  Wait& operator=(const Wait& other) = delete;
  Wait(Wait&& other) = delete;
  Wait& operator=(Wait&& other) = delete;

  NVFUSER_DECLARE_CLONE_AND_CREATE

  std::string toString(int indent_size = 0) const override;
  std::string toInlineString(int indent_size = 0) const override;
  const char* getOpString() const override {
    return "hir::Wait";
  }

  bool sameAs(const Statement* other) const override;

  // The expression being waited on, stored as the node's first attribute.
  Expr* communication() const {
    return attributes_.at(0)->as<Expr>();
  }
};
|
239
|
+
|
240
|
+
// Makes the current stream wait on the given stream. Non-blocking from the host
// point of view.
class Synchronize : public Expr {
 public:
  using Expr::Expr;
  Synchronize(IrBuilderPasskey passkey, Stream* stream);

  // IR nodes are managed by the container; copying/moving is disallowed.
  Synchronize(const Synchronize& other) = delete;
  Synchronize& operator=(const Synchronize& other) = delete;
  Synchronize(Synchronize&& other) = delete;
  Synchronize& operator=(Synchronize&& other) = delete;

  NVFUSER_DECLARE_CLONE_AND_CREATE

  std::string toString(int indent_size = 0) const override;
  std::string toInlineString(int indent_size = 0) const override;
  const char* getOpString() const override {
    return "hir::Synchronize";
  }

  bool sameAs(const Statement* other) const override;

  // The stream to synchronize with, stored as the node's first attribute.
  Stream* stream() const {
    return attributes_.at(0)->as<Stream>();
  }
};
|
266
|
+
|
267
|
+
// For ProcessGroupNCCL, startCoalescing and endCoalescing correspond to
// ncclGroupStart and ncclGroupEnd respectively. Those calls group p2p calls
// that need to be progressed together -- one global work handle returned by
// endCoalescing needs to be progressed. This has the following main advantages:
// 1) calls are progressed concurrently
// 2) since NICs are two-sided, send and recv calls need to be coalesced to
//    achieve full BW.
// 3) If not coalesced, we can easily reach a deadlock if the
//    send/recv pairs are not ordered correctly.
// It is in general preferable to coalesce send/recv calls. The only drawback is
// that we don't have a fine-grain control on synchronicity, in other words, we
// can only synchronize with the grouped communication at once.
// Remark: ProcessGroupUCC does not implement coalesced groups for now
class StartCoalescing : public Expr {
 public:
  using Expr::Expr;
  StartCoalescing(IrBuilderPasskey passkey);

  // IR nodes are managed by the container; copying/moving is disallowed.
  StartCoalescing(const StartCoalescing& other) = delete;
  StartCoalescing& operator=(const StartCoalescing& other) = delete;
  StartCoalescing(StartCoalescing&& other) = delete;
  StartCoalescing& operator=(StartCoalescing&& other) = delete;

  NVFUSER_DECLARE_CLONE_AND_CREATE

  std::string toString(int indent_size = 0) const override;
  std::string toInlineString(int indent_size = 0) const override;
  const char* getOpString() const override {
    return "hir::StartCoalescing";
  }
};
|
298
|
+
|
299
|
+
// Closes a coalescing group opened by StartCoalescing (see the comment on
// StartCoalescing for the rationale behind coalesced p2p groups).
class EndCoalescing : public Expr {
 public:
  using Expr::Expr;
  EndCoalescing(IrBuilderPasskey passkey);

  // IR nodes are managed by the container; copying/moving is disallowed.
  EndCoalescing(const EndCoalescing& other) = delete;
  EndCoalescing& operator=(const EndCoalescing& other) = delete;
  EndCoalescing(EndCoalescing&& other) = delete;
  EndCoalescing& operator=(EndCoalescing&& other) = delete;

  NVFUSER_DECLARE_CLONE_AND_CREATE

  std::string toString(int indent_size = 0) const override;
  std::string toInlineString(int indent_size = 0) const override;
  const char* getOpString() const override {
    return "hir::EndCoalescing";
  }
};
|
317
|
+
|
318
|
+
} // namespace hir
|
319
|
+
|
320
|
+
} // namespace nvfuser
|
@@ -0,0 +1,35 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <host_ir/container.h>
|
11
|
+
#include <ir/base_nodes.h>
|
12
|
+
#include <multidevice/communication.h>
|
13
|
+
#include <multidevice/multidevice.h>
|
14
|
+
|
15
|
+
namespace nvfuser {
|
16
|
+
|
17
|
+
// Static helpers to lower fusion/resharding expressions into host IR.
class HostIrLower {
 public:
  // The flag `ignore_inner_resharding` is useful because the preseg passes
  // `InsertReshardingsPass` and `ReorderShardedAxisPass` want different
  // behaviors
  static bool canLower(Expr* expr, bool ignore_inner_resharding = false);

  // Lower a sharded Expr into a series of Communication.
  static std::vector<Expr*> lower(Expr* c);

  // Lower a whole Fusion into a HostIrContainer for the given device.
  static std::unique_ptr<hir::HostIrContainer> lower(
      std::unique_ptr<Fusion> fusion,
      int64_t my_device_index);

 private:
  // Special-case lowering path; semantics defined in the implementation file.
  static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr);
};
|
34
|
+
|
35
|
+
} // namespace nvfuser
|
@@ -0,0 +1,56 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <device_lower/lower2device.h>
|
11
|
+
#include <device_lower/utils.h>
|
12
|
+
#include <id_model/id_model.h>
|
13
|
+
|
14
|
+
namespace nvfuser {
|
15
|
+
|
16
|
+
// Get the loop index of a given loop domain for circular buffer
// loops. nullptr is returned if not relevant.
//
// This is a WAR for circular buffering. TensorIndexer has a map of
// loop indices for all loop groups, however, it does not work with
// circular buffering. The loop graph is
// designed to represent each loop and each loop group is supposed
// to have a one-to-one relationship with each loop. However, for
// circular buffering, this assumption is broken as we are using
// the same iter domain for the prologue, main and epilogue
// loops. Ideally, those loops should have distinctive loop groups,
// but for now, here's a workaround to get a correct loop index
Val* getLoopIndexOfCircularBufferLoop(
    IterDomain* loop_id,
    const std::vector<ForLoop*>& for_loops,
    const IdModel& id_model);
|
32
|
+
|
33
|
+
// For a circular-buffering expr, the producer loop index needs to be
// advanced by (#stages - 1) if it's the main loop. Return the offset
// if it's applicable. Otherwise, nullptr is returned.
Val* getLoopIndexOffsetForProducerOfCircularBuffer(
    const Expr* expr,
    const ForLoop* for_loop,
    const IdModel& id_model);
|
40
|
+
|
41
|
+
// Get the additional offset for a circular buffer. This offset will
// be added to the normal linear index. For example, if this is a
// double buffered tensor, the offset would look like "i % 2", where i
// is the loop index of the double-buffer loop.
Val* getOffsetForCircularBufferTensor(
    TensorView* circular_buffer_tv,
    bool as_consumer, // whether circular_buffer_tv is accessed as a consumer
    const std::vector<ForLoop*>& for_loops);
|
49
|
+
|
50
|
+
// Find the circular buffering stage of a given circular buffered tensor
// (e.g. prologue, main or epilogue) based on the enclosing for-loops.
CircularBufferLoopStage getCircularBufferLoopStage(
    const TensorView* circular_buffer_tv,
    const std::vector<ForLoop*>& for_loops,
    const ValGraph& loop_graph);
|
55
|
+
|
56
|
+
} // namespace nvfuser
|