nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl
- nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/iter_visitor.h
@@ -0,0 +1,661 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <exceptions.h>
#include <visibility.h>

#include <bfs.h>
#include <dispatch.h>
#include <ir/base_nodes.h>
#include <type.h>

#include <deque>
#include <unordered_set>
#include <vector>

namespace nvfuser {

class Fusion;

/*
 * IterVisitor starts from leaf nodes, fusion outputs, or the provided values.
 * It walks the DAG backwards from the starting nodes to the roots. Each node
 * in the DAG will be called with handle(Statement*) in topological order, from
 * inputs of the fusion to outputs of the fusion.
 *
 * TODO: We may want a BFS version of this code to extract ILP, not implemented
 * yet.
 *
 * TODO: We may want to have ordering of outputs to inputs. I'm not sure why we
 * would want this, but it seems like it would be a reasonable request.
 */
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class NVF_API IterVisitor : public OptOutDispatch {
 public:
  ~IterVisitor() override = default;

  IterVisitor() = default;

  IterVisitor(const IterVisitor& other) = default;
  IterVisitor& operator=(const IterVisitor& other) = default;

  IterVisitor(IterVisitor&& other) = default;
  IterVisitor& operator=(IterVisitor&& other) = default;

 protected:
  // Functions return nodes in reverse order to be added to the to_visit queue.
  // These functions will start at outputs and propagate up through the DAG
  // to inputs based on depth-first traversal. next() could be called on a node
  // multiple times.
  virtual std::vector<Statement*> next(Statement* stmt);

  virtual std::vector<Statement*> next(Val* v);

  virtual std::vector<Statement*> next(Expr* expr);

  using OptOutDispatch::handle;

  // This dispatch function is called on every Statement* in topological order,
  // starting from outputs to inputs.
  void dispatch(Statement* s) override;

  // This dispatch function is called on every Expr* in topological order,
  // starting from outputs to inputs.
  void dispatch(Expr* e) override;

  // This dispatch function is called on every Val* in topological order,
  // starting from outputs to inputs.
  void dispatch(Val* v) override;

  // The entire stack during traversal. stmt_stack.back().back() is the node
  // that is being called in handle(). stmt_stack.back() contains siblings (not
  // guaranteed to be all siblings throughout traversal). stmt_stack.front()
  // contains the outputs we started with (not guaranteed to be all outputs
  // throughout traversal).
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::vector<Statement*>> stmt_stack;

  void traverseHelper(Fusion* fusion, bool traverse_all_paths = false);

 public:
  //! Traverses nodes in Fusion from inputs in topological order to "to", i.e.
  //! from inputs towards outputs.
  //! \param traverse_all_paths = false only calls handle on each Statement*
  //! once; traverse_all_paths = true traverses all paths between
  //! expressions/values and calls handle on a Statement* for every path from
  //! inputs to "to".
  //! \param traverse_into_members When hitting nodes like TensorView,
  //! TensorDomain, or IterDomain where there are members of the nodes that are
  //! Vals, a value of "true" will also traverse into those member Vals, a
  //! value of "false" will not traverse into the members.
  //! \param traverse_attributes When true, traverse into expr
  //! attributes. Note that attributes of template type Attribute are
  //! not traversed as there's no dispatch support.
  //! \param traverse_siblings When true, traverse all outputs of
  //! active multi-output expressions, even if those Expr outputs are not used
  //! in paths to Fusion outputs.
  void traverseTo(
      const std::vector<Val*>& to,
      bool traverse_all_paths = false,
      bool traverse_into_members = false,
      bool traverse_attributes = false,
      bool traverse_siblings = false);

  //! Traverses nodes in Fusion from inputs in topological order to "to", i.e.
  //! from inputs towards outputs.
  //! \param traverse_all_paths = false only calls handle on each Statement*
  //! once; traverse_all_paths = true traverses all paths between
  //! expressions/values and calls handle on a Statement* for every path from
  //! inputs to "to".
  //! \param traverse_into_members When hitting nodes like TensorView,
  //! TensorDomain, or IterDomain where there are members of the nodes that are
  //! Vals, a value of "true" will also traverse into those member Vals, a
  //! value of "false" will not traverse into the members.
  //! \param from Specified values to start traversing. If a "from" Val is not
  //! on a path from inputs to "to", it will not be visited. If there's a path
  //! from inputs to "to" that doesn't go through "from", that input and the
  //! path from it will also be traversed.
  //! \param traverse_attributes When true, traverse into expr
  //! attributes. Note that attributes of template type Attribute are
  //! not traversed as there's no dispatch support.
  //! \param traverse_siblings When true, traverse all outputs of
  //! active multi-output expressions, even if those Expr outputs are not used
  //! in paths to Fusion outputs.
  void traverseBetween(
      const std::unordered_set<Val*>& from,
      const std::vector<Val*>& to,
      bool traverse_all_paths = false,
      bool traverse_into_members = false,
      bool traverse_attributes = false,
      bool traverse_siblings = false);

  // Iterates from terminating outputs registered with the fusion. Terminating
  // means the value is not used to generate any other value used in producing
  // registered outputs.
  void traverse(Fusion* fusion);

  // Same as traverse but it traverses every edge, meaning it will traverse
  // values more than once.
  void traverseAllPaths(Fusion* fusion);

  //! Get inputs to vals. Possible input vals can be optionally
  //! given. If not, vals with no producers are returned.
  //
  // TODO: This doesn't seem to fit with IterVisitor. Should probably be moved
  // out of the class.
  static std::vector<Val*> getInputsTo(
      const std::vector<Val*>& vals,
      const std::vector<Val*>& inputs = {});
};
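// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the upstream header): a minimal
// IterVisitor subclass that records every TensorView reachable from the
// registered fusion outputs. The subclass name is hypothetical, and a valid
// Fusion* built elsewhere is assumed.
//
//   class TensorViewCollector : public IterVisitor {
//    public:
//     std::vector<TensorView*> found;
//
//     using IterVisitor::handle;
//     void handle(TensorView* tv) override {
//       // Called once per TensorView, producers before consumers.
//       found.push_back(tv);
//     }
//   };
//
//   // TensorViewCollector collector;
//   // collector.traverse(fusion);  // start from terminating outputs
// ---------------------------------------------------------------------------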

/*
 * Backward visitor calls handle in reverse order from outputs to inputs.
 * It would be really nice to unify this with IterVisitor, however, the
 * challenge there is that we specify traversal from outputs towards inputs
 * because it implicitly provides DCE. However, if users are not careful, they
 * could miss necessary outputs to do a backward traversal.
 *
 * BackwardVisitor checks that all outputs of an Expr are visited before
 * visiting the Expr. If we don't provide nodes to start from on all backward
 * paths of those outputs, we will never visit the Expr.
 *
 * The first step of BackwardVisitor is to make sure we've specified enough
 * outputs to guarantee that we will traverse all outputs of all exprs during
 * the backward traversal. For the case where we don't require visiting all
 * outputs of some exprs (an example being the `N` output of welford ops),
 * `must_cover_all_expr_outputs` is added to disable the check, and in this
 * case the visitor pass needs to be aware that:
 * 1. Exprs in the `from` list with any output that has a use chain that
 *    ends with a final consumer `will be` visited.
 * 2. Vals in the `from` list that don't have a use chain that ends with
 *    a final consumer `will not be` visited, even though their
 *    definition expr might be visited. An example is if the `N` output
 *    of a welford op is unused, but other outputs are, the welford op
 *    will be visited but the `N` output will not.
 */
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class BackwardVisitor : public OptOutDispatch {
 public:
  // clang-tidy: cppcoreguidelines-virtual-class-destructor
  ~BackwardVisitor() override = default;

 protected:
  BackwardVisitor(bool must_cover_all_expr_outputs = true)
      : must_cover_all_expr_outputs_(must_cover_all_expr_outputs) {}

  BackwardVisitor(const BackwardVisitor& other) = default;
  BackwardVisitor& operator=(const BackwardVisitor& other) = default;

  BackwardVisitor(BackwardVisitor&& other) = default;
  BackwardVisitor& operator=(BackwardVisitor&& other) = default;

  // Functions return nodes in reverse order to be added to the to_visit queue.
  // These functions will start at outputs and propagate up through the DAG
  // to inputs based on depth-first traversal. next() could be called on a node
  // multiple times.
  virtual std::vector<Statement*> next(Statement* stmt);

  virtual std::vector<Statement*> next(Expr* expr);

  virtual std::vector<Statement*> next(Val* val);

  using OptOutDispatch::handle;

  // This dispatch function is called on every Statement* in topological order,
  // starting from outputs to inputs.
  // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
  virtual void dispatch(Statement* stmt) override;

  // This dispatch function is called on every Expr* in topological order,
  // starting from outputs to inputs.
  // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
  virtual void dispatch(Expr* expr) override;

  // This dispatch function is called on every Val* in topological order,
  // starting from outputs to inputs.
  // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
  virtual void dispatch(Val* val) override;

  // All exprs that need to be visited in this traversal. Labeled in
  // topological order (size_t).
  std::unordered_map<Expr*, size_t> traversal_exprs_;

  // The entire stack during traversal. stmt_stack.back().back() is the node
  // that is being called in handle(). stmt_stack.back() contains siblings (not
  // guaranteed to be all siblings throughout traversal). stmt_stack.front()
  // contains the inputs we started with (not guaranteed to be all outputs
  // throughout traversal).
  std::deque<std::deque<Statement*>> stmt_stack_;

  // Starts at nodes provided in from, traverses from these nodes to inputs.
  // Calls handle on all Statement*s in topologically sorted order.
  // traverseAllPaths = false only calls handle on each Statement* once;
  // traverseAllPaths = true traverses all paths from nodes in from to inputs,
  // calling handle on a Statement* for every path from "from" nodes to inputs.
  void traverseTo(const std::vector<Val*>& from, bool traverseAllPaths = false);

  bool must_cover_all_expr_outputs_ = true;
};
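// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the upstream header): a hypothetical
// backward pass that counts expressions while walking from the fusion outputs
// back towards the inputs. It assumes a valid Fusion* whose outputs() supply
// the starting values; the default must_cover_all_expr_outputs check stays
// enabled.
//
//   class CountExprsBackward : public BackwardVisitor {
//    public:
//     size_t num_exprs = 0;
//
//     using BackwardVisitor::handle;
//     void dispatch(Expr* expr) override {
//       // Reached only after all of expr's outputs have been visited.
//       ++num_exprs;
//       BackwardVisitor::dispatch(expr);
//     }
//
//     void run(Fusion* fusion) {
//       traverseTo(fusion->outputs());
//     }
//   };
// ---------------------------------------------------------------------------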

class DependencyCheck {
 public:
  // Returns if "dependency" is a dependency of "of".
  NVF_API static bool isDependencyOf(Val* dependency, Val* of);

  // Finds a Val* path from "of" to "dependency". Returns that path.
  // deque.back() is "of", deque[0] is "dependency" if a chain exists.
  NVF_API static std::deque<Val*> getSingleDependencyChain(
      Val* dependency,
      Val* of);

  // Finds all Val* paths from "of" to "dependency". Returns those paths.
  // deque[i].back() is "of", and deque[i][0] is "dependency". Returns an
  // empty deque if no dependency is found.
  static std::deque<std::deque<Val*>> getAllDependencyChains(
      Val* dependency,
      Val* of);

  // Finds all Val* paths from all leaf nodes to "dependency". Returns those
  // paths. deque[i].back() are leaf nodes, and deque[i][0] is "dependency".
  // Returns an empty deque if there are no uses of dependency found.
  static std::deque<std::deque<Val*>> getAllUseChains(Val* dependency);

  // Grab all values that exist between and including the provided
  // vals. Returned values are topologically ordered, and unique.
  NVF_API static std::vector<Val*> getAllValsBetween(
      const std::unordered_set<Val*>& dependencies,
      const std::vector<Val*>& of);

  // Returns all dependent exprs that exist between
  // the provided vals.
  static std::vector<Expr*> getAllExprsBetween(
      const std::unordered_set<Val*>& dependencies,
      const std::vector<Val*>& of);

  // Return registered outputs of the fusion that are a dependency of any val
  // of "of".
  static std::unordered_set<Val*> getAllOutputsOf(
      const std::unordered_set<Val*>& of);

  // Return all Vals that depend on the given Vals.
  static std::unordered_set<Val*> getAllDependentVals(
      const std::unordered_set<Val*>& of);
};
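// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the upstream header): typical
// dependency queries, assuming `producer` and `consumer` are Val* handles
// taken from an already-built fusion.
//
//   // Is `producer` needed to compute `consumer`?
//   bool needed = DependencyCheck::isDependencyOf(producer, consumer);
//
//   // One concrete producer->consumer chain (empty if none exists);
//   // chain[0] == producer, chain.back() == consumer.
//   std::deque<Val*> chain =
//       DependencyCheck::getSingleDependencyChain(producer, consumer);
//
//   // Every Val on any path between the two sets, topologically ordered.
//   std::vector<Val*> between =
//       DependencyCheck::getAllValsBetween({producer}, {consumer});
// ---------------------------------------------------------------------------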

// Expr sort will take a fusion and return a topologically sorted list of
// expressions.
class StmtSort : public IterVisitor {
 protected:
  StmtSort() = default;

  std::vector<Statement*> stmts;

  using IterVisitor::handle;

  void dispatch(Statement* stmt) override;

 public:
  // If traverse_members it will also extract all member nodes in the sorted
  // statement list in the fusion, i.e. all IterDomains, extents, and their
  // associated expressions. Similarly, if traverse_attributes it will
  // grab all nodes associated as Expr attributes.
  NVF_API static std::vector<Statement*> getStmts(
      Fusion* fusion,
      bool traverse_members = false,
      bool traverse_attributes = false,
      bool traverse_siblings = false);

  // Returns ordered Statements required to produce 'to', including 'to'.
  NVF_API static std::vector<Statement*> getStmtsTo(
      const std::vector<Val*>& to,
      bool traverse_members = false,
      bool traverse_attributes = false,
      bool traverse_siblings = false);

  // Returns all ordered Statements of a given fusion. Unlike
  // getStmts, for TensorDomain, all of its iter domains and exprs are
  // grabbed and returned in a topological order.
  NVF_API static std::vector<Statement*> getAllStmts(
      Fusion* fusion,
      bool traverse_members = false,
      bool traverse_attributes = false,
      bool traverse_siblings = false);

  // Returns ordered Statements required to produce 'to', including
  // 'to'. Unlike getStmtsTo, for TensorDomain, all of its iter domains and
  // exprs are grabbed and returned in a topological order, regardless of
  // `traverse_members`.
  //
  // The to vals are assumed to be either TensorView or scalar
  // Val. This assumption could be removed if desired.
  NVF_API static std::vector<Statement*> getAllStmtsTo(
      const std::vector<Val*>& to,
      bool traverse_members = false,
      bool traverse_attributes = false,
      bool traverse_siblings = false);

  // Returns ordered Statements required to produce 'from', including 'from'.
  // Stops traversal once hitting any Statements in 'to'. Includes Statements
  // in 'to'.
  //
  // Warning: this doesn't necessarily prevent statements before `to` from
  // being returned. e.g.
  //   i1 = i0
  //   i2 = i1
  //   i3 = i2
  //   i4 = i3 + i1
  //   getExprs(fusion, {i4}, {i3})
  // will return the definition and values {i0, i1, i4}.
  // i3 is dependent on i1, but since i4 also is, the traversal will go down
  // the i4->i1->i0 path, even though the i4->i3-//>i2->i1 path is blocked.
  //
  // If traverse_members it will also extract all member nodes in the sorted
  // expr list in the fusion, i.e. all expressions on IterDomains, extents, etc.
  NVF_API static std::vector<Statement*> getStmtsBetween(
      const std::vector<Val*>& from,
      const std::vector<Val*>& to,
      bool traverse_members = false,
      bool traverse_attributes = false,
      bool traverse_siblings = false);

  // Same as the getStmts version but filters to only return the Expr*s.
  static std::vector<Expr*> getExprs(
      const Fusion* fusion,
      bool traverse_members = false,
      bool traverse_attributes = false,
      bool traverse_siblings = false);

  // Same as the getStmtsTo version but filters to only return the Expr*s.
  NVF_API static std::vector<Expr*> getExprsTo(
      const std::vector<Val*>& to,
      bool traverse_members = false,
      bool traverse_attributes = false,
      bool traverse_siblings = false);

  // Same as the getStmtsBetween version but filters to only return the Expr*s.
  NVF_API static std::vector<Expr*> getExprsBetween(
      const std::vector<Val*>& from,
      const std::vector<Val*>& to,
      bool traverse_members = false,
      bool traverse_attributes = false,
      bool traverse_siblings = false);
};
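// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the upstream header): getting a
// topologically ordered expression list for a fusion, assuming `fusion` is a
// valid Fusion*.
//
//   std::vector<Expr*> exprs = StmtSort::getExprs(fusion);
//   for (Expr* e : exprs) {
//     // Producers always appear before their consumers in this list.
//   }
//
//   // Every Statement (Vals and Exprs), additionally walking into
//   // TensorView members such as IterDomains and extents:
//   auto stmts = StmtSort::getStmts(fusion, /*traverse_members=*/true);
// ---------------------------------------------------------------------------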

class InputsOf : public IterVisitor {
 private:
  std::unordered_set<Val*> grabbed_inputs;
  std::vector<Val*> ordered_inputs;

  using IterVisitor::handle;

  void dispatch(Val* v) final;

 public:
  NVF_API static std::vector<Val*> output(Val* output_);
  static std::vector<Val*> outputs(const std::vector<Val*>& outputs_);
};

//! This is a generic traversal class that is used to modify a Fusion graph by
//! replacing Vals to simplify computation or remove dead code. This differs
//! from OptOutMutator, which is built for mutating TensorViews in-place in a
//! graph by altering the associated IterDomains, and which does not easily
//! handle modifying TensorView definitions and Fusion outputs during
//! traversal.
//!
//! Derived classes should override handle() for relevant Exprs and they should
//! make use of registerReplacement() to change the definitions of Vals in the
//! graph. Note that if replacements are made using registerReplacement(old_val,
//! new_val), then neither new_val nor any new Statements produced in creating
//! it will be traversed by this class. Also note that any Vals or Exprs that
//! are previously marked dead will not be processed by handle().
class DeadCodeRemover : BackwardVisitor {
 public:
  DeadCodeRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {}

  DeadCodeRemover(const DeadCodeRemover& other) = default;
  DeadCodeRemover& operator=(const DeadCodeRemover& other) = default;

  DeadCodeRemover(DeadCodeRemover&& other) = default;
  DeadCodeRemover& operator=(DeadCodeRemover&& other) = default;

  //! Instead of traverseTo, run() is the entry point for this class, and we
  //! always traverse from outputs backward to their inputs.
  //!
  //! Returns a bool indicating whether the Fusion was modified or not.
  bool run();

  inline Fusion* fusion() const {
    return fusion_;
  }

 protected:
  using BackwardVisitor::handle;

  void dispatch(Statement* stmt) override;
  void dispatch(Expr* expr) override;

  //! We implement this in order to remove dangling TensorViews whose uses are
  //! all dead. Note that we do not remove other ValTypes like Scalars since
  //! they might still be used as attributes or members of other objects, which
  //! is not reflected by Val::uses().
  void handle(TensorView* tv) override;

  //! Registers a Val for replacement in outputs and in all its uses.
  //!
  //! Note that replacement does not occur immediately, but will be done after
  //! the traversal is completed. This is so that any Val* and Expr* pointers
  //! may be safely dereferenced during traversal.
  //!
  //! The argument old_val is always marked Dead by this method. If old_val is
  //! a Fusion input, we do not replace it. If old_val's definition is non-null
  //! and has other outputs which are not dead, we do not remove old_val.
  //!
  //! Returns whether old_val was registered for removal from the Fusion.
  bool registerReplacement(Val* old_val, Val* new_val);

  //! Find whether a statement is not marked as live code.
  inline bool isDead(Statement* stmt) const {
    return live_statements_.find(stmt) == live_statements_.end();
  }

  //! Find whether a statement is marked as live code.
  inline bool isLive(Statement* stmt) const {
    return !isDead(stmt);
  }

  //! Check whether all outputs of an expression have been marked dead
  inline bool allOutputsDead(Expr* expr) const {
    return std::all_of(
        expr->outputs().begin(), expr->outputs().end(), [&](Val* outp) {
          return isDead(outp);
        });
  }

  //! Check whether all uses have been marked dead
  inline bool allUsesDead(Val* val) const {
    auto fu_it = future_uses_.find(val);
    if (fu_it != future_uses_.end() && !fu_it->second.empty()) {
      // Regardless of whether current uses are marked dead, this appears in a
      // replacement expression, so it has a future live use and we should keep
      // it.
      return false;
    }

    return std::all_of(val->uses().begin(), val->uses().end(), [&](Expr* use) {
      return isDead(use);
    });
  }

 private:
  //! Removes an Expr* from the Fusion, if possible.
  //!
  //! The Expr will _only_ be marked dead and removed if all of its outputs are
  //! already marked dead. In this case all the outputs will also be removed
  //! from the Fusion.
  //!
  //! Returns whether the Expr was marked dead and removed from the Fusion.
  bool maybeRemoveExpr(Expr* expr);

  //! Mark a single Statement as being alive.
  inline void markLive(Statement* stmt) {
    live_statements_.insert(stmt);
    if (auto e = dynamic_cast<Expr*>(stmt)) {
      // Check if this expression is already in uses() for each of its inputs
      // and if not, record it in future_uses_
      for (Val* inp : e->inputs()) {
        if (std::find(inp->uses().begin(), inp->uses().end(), e) ==
            inp->uses().end()) {
          auto fu_it = future_uses_.find(inp);
          if (fu_it == future_uses_.end()) {
            future_uses_.emplace(inp, std::unordered_set<Expr*>({e}));
          } else {
            fu_it->second.insert(e);
          }
        }
      }
    }
  }

  //! Ensure that a Statement and its upstream Statements are alive. If it is
  //! an Expr, ensure all its inputs are alive. If it's a Val with a
  //! definition, recurse to the definition. Newly-created Statements default
  //! to being dead, so this method is called when adding a Statement to the
  //! active path of the Fusion inside registerReplacement.
  void markLiveRecursive(Statement* stmt);

  //! Mark a single Statement as being dead. This does not remove stmt from the
  //! Fusion. It is an error to call this on a Fusion output.
  //!
  //! Returns true if the statement was previously live, and false otherwise.
  bool markDead(Statement* stmt);

  //! Register a Val for later removal.
  void registerRemoval(Val* val);

  //! Register an Expr for later removal.
  //!
  //! Note that if any of its outputs are removed, expr will be removed even if
  //! it is not marked for removal, and all its outputs will have their
  //! definitions set to nullptr.
  inline void registerRemoval(Expr* expr) {
    exprs_to_remove_.push_back(expr);
  }

  //! All modifications to the Fusion are registered during traversal, then
  //! later they are committed by this method. For safety, this should only be
  //! run after traversing the graph.
  //!
  //! Returns a bool indicating whether any modifications were performed.
  bool modifyFusion() const;

 private:
  //! The Fusion associated with live_statements_
  Fusion* fusion_;

  //! Statements are marked dead by removing them from this set
  std::unordered_set<Statement*> live_statements_;

  //! Vals to be replaced in outputs and with replaceValInExprInputs in all
  //! uses.
  std::vector<std::pair<Val*, Val*>> vals_to_replace_;

  //! Statements that will be removed. We remove Vals before Exprs, so we track
  //! them separately here.
  std::vector<Val*> vals_to_remove_;
  std::vector<Expr*> exprs_to_remove_;

  //! This holds additional _future_ uses of each val. val->uses() only returns
  //! currently live uses, so until we have finalized all replacements, new
  //! uses will not appear there. The mapping below gets populated whenever we
  //! mark an expression as live, if that expression is not already in
  //! inp->uses() for any of its inputs.
  std::unordered_map<Val*, std::unordered_set<Expr*>> future_uses_;
};
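// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the upstream header): running dead
// code removal on a fusion, assuming `fusion` is a valid Fusion*. run()
// commits any registered replacements and removals once the backward
// traversal finishes.
//
//   DeadCodeRemover remover(fusion);
//   bool changed = remover.run();
//   if (changed) {
//     // The graph was simplified; dangling TensorViews whose uses were all
//     // dead have been removed.
//   }
// ---------------------------------------------------------------------------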

struct IRDefinitions {
  decltype(auto) operator()(Val* val) const {
    auto def = val->definition();
    if (def == nullptr) {
      return std::vector<Expr*>{};
    }
    return std::vector<Expr*>{val->definition()};
  }
};

struct IRUses {
  decltype(auto) operator()(Val* val) const {
    return val->uses();
  }
};

struct IRInputs {
  decltype(auto) operator()(Expr* expr) const {
    return expr->inputs();
  }
};

struct IROutputs {
  decltype(auto) operator()(Expr* expr) const {
    return expr->outputs();
  }
};

template <>
struct GetValType<Expr*> {
  using type = Val*;
};

class IRBFS
    : public BFS<Expr*, Val*, IRDefinitions, IRUses, IRInputs, IROutputs> {
 public:
  IRBFS(
      std::vector<NodeType> from_groups,
      std::vector<NodeType> to_groups,
      bool require_all_to_visited,
      Direction allowed_direction = Direction::Undefined)
      : BFS(IRDefinitions{},
            IRUses{},
            IRInputs{},
            IROutputs{},
            std::move(from_groups),
            std::move(to_groups),
            require_all_to_visited,
            allowed_direction) {}
};

inline std::vector<Val*> getInputsOfExpr(Expr* expr, Direction dir) {
  return getInputsOfExpr<Expr*>(expr, dir, IRInputs(), IROutputs());
}

inline std::vector<Val*> getOutputsOfExpr(Expr* expr, Direction dir) {
  return getOutputsOfExpr<Expr*>(expr, dir, IRInputs(), IROutputs());
}

class IRPermissiveBFS : public BFSWithPermissiveDependence<
                            Expr*,
                            Val*,
                            IRDefinitions,
                            IRUses,
                            IRInputs,
                            IROutputs> {
 public:
  IRPermissiveBFS(
      std::vector<NodeType> from_groups,
      std::vector<NodeType> to_groups,
      bool require_all_to_visited,
      Direction allowed_direction = Direction::Undefined)
      : BFSWithPermissiveDependence(
            IRDefinitions{},
            IRUses{},
            IRInputs{},
            IROutputs{},
            std::move(from_groups),
            std::move(to_groups),
            require_all_to_visited,
            allowed_direction) {}
};

} // namespace nvfuser