nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl
Sign up to get free protection for your applications and to get access to all the features.
- nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,446 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <disjoint_set.h>
|
11
|
+
#include <ir/all_nodes.h>
|
12
|
+
|
13
|
+
#include <iostream>
|
14
|
+
#include <string>
|
15
|
+
#include <type_traits>
|
16
|
+
#include <unordered_map>
|
17
|
+
#include <vector>
|
18
|
+
|
19
|
+
namespace nvfuser {
|
20
|
+
|
21
|
+
// ValGraph is a DAG of Vals and Exprs connected by their input and
|
22
|
+
// output dependencies. Each graph node is a collection of
|
23
|
+
// either Vals or Exprs that are grouped together through mapVals and
|
24
|
+
// mapExprs, respectively.
|
25
|
+
//
|
26
|
+
// The primary use case of ValGraph is for representing groupings and
|
27
|
+
// dependencies of iteration domains. For example, given a fusion as
|
28
|
+
// shown below:
|
29
|
+
//
|
30
|
+
// T1 = set(T0);
|
31
|
+
// T2 = set(T1);
|
32
|
+
//
|
33
|
+
// T0: root [I0, I1], loop [I0, I1]
|
34
|
+
// T1: root [I2, I3], loop [I2*I3/4, 4]
|
35
|
+
// T2: root [I4, I5], loop [I4*I5/4, 4]
|
36
|
+
//
|
37
|
+
// The Exact ValGraph consists of ValGroups of:
|
38
|
+
//
|
39
|
+
// - {I0, I2, I4}
|
40
|
+
// - {I1, I3, I5}
|
41
|
+
// - {I2*I3, I4*I5}
|
42
|
+
// - {I2*I3/4, I4*I5/4}
|
43
|
+
// - {4, 4}
|
44
|
+
//
|
45
|
+
// and ExprGroups of:
|
46
|
+
//
|
47
|
+
// - {merge of I2 and I3, merge of I4 and I5}
|
48
|
+
// - {split of I2*I3, split of I4*I5}
|
49
|
+
//
|
50
|
+
// ValGraph can be used with any Val types; however, it's currently
|
51
|
+
// only tested with IterDomain. Some of the routines might need to be
|
52
|
+
// extended for other Val types.
|
53
|
+
|
54
|
+
using ValGroup = std::shared_ptr<VectorOfUniqueEntries<Val*>>;
|
55
|
+
using ValGroups = VectorOfUniqueEntries<ValGroup>;
|
56
|
+
using ExprGroup = std::shared_ptr<VectorOfUniqueEntries<Expr*>>;
|
57
|
+
using ExprGroups = VectorOfUniqueEntries<ExprGroup>;
|
58
|
+
|
59
|
+
class ValGraph {
|
60
|
+
public:
|
61
|
+
ValGraph() = default;
|
62
|
+
|
63
|
+
ValGraph(const ValGraph& other);
|
64
|
+
ValGraph(ValGraph&& other) = default;
|
65
|
+
|
66
|
+
ValGraph& operator=(const ValGraph& other);
|
67
|
+
ValGraph& operator=(ValGraph&& other) = default;
|
68
|
+
|
69
|
+
ValGraph(bool propagate_through_exprs)
|
70
|
+
: propagate_through_exprs_(propagate_through_exprs) {}
|
71
|
+
|
72
|
+
// Returns the disjoint val set.
|
73
|
+
const DisjointSets<Val*>& disjointValSets() const {
|
74
|
+
return disjoint_vals_;
|
75
|
+
}
|
76
|
+
|
77
|
+
// Returns the disjoint Expr set.
|
78
|
+
const DisjointSets<Expr*>& disjointExprSets() const {
|
79
|
+
return disjoint_exprs_;
|
80
|
+
}
|
81
|
+
|
82
|
+
// Return if there's a group entry in the graph for this expr
|
83
|
+
bool hasGroup(Expr* expr) const;
|
84
|
+
|
85
|
+
// Return if there's a group entry in the graph for this val
|
86
|
+
bool hasGroup(Val* val) const;
|
87
|
+
|
88
|
+
// Convert expr to its exprGroup, assert that it exists.
|
89
|
+
const ExprGroup& toGroup(Expr* expr) const;
|
90
|
+
|
91
|
+
// Convert Val to its ValGroup, assert that it exists.
|
92
|
+
const ValGroup& toGroup(Val* val) const;
|
93
|
+
|
94
|
+
// Convert a vector-like container of Val* or Expr* to their
|
95
|
+
// ValGroups or ExprGroups. The vector-like container type must
|
96
|
+
// define the element type as value_type
|
97
|
+
template <
|
98
|
+
typename ContainerType,
|
99
|
+
typename ElementType = typename std::remove_pointer<
|
100
|
+
typename ContainerType::value_type>::type,
|
101
|
+
typename RetType = typename std::conditional<
|
102
|
+
std::is_base_of<Val, ElementType>::value,
|
103
|
+
ValGroups,
|
104
|
+
ExprGroups>::type,
|
105
|
+
typename = std::enable_if_t<
|
106
|
+
std::is_base_of<Val, ElementType>::value ||
|
107
|
+
std::is_base_of<Expr, ElementType>::value>>
|
108
|
+
RetType toGroups(const ContainerType& entries) const {
|
109
|
+
RetType groups;
|
110
|
+
for (auto entry : entries) {
|
111
|
+
groups.pushBack(toGroup(entry));
|
112
|
+
}
|
113
|
+
return groups;
|
114
|
+
}
|
115
|
+
|
116
|
+
// Return output/input Val groups of provided expr
|
117
|
+
// Note that the same ValGroup can show up multiple times, so the
|
118
|
+
// output type cannot be VectorOfUniqueEntries
|
119
|
+
std::vector<ValGroup> outputGroups(const ExprGroup& expr) const;
|
120
|
+
std::vector<ValGroup> inputGroups(const ExprGroup& expr) const;
|
121
|
+
|
122
|
+
// Return Val groups that have no definition.
|
123
|
+
ValGroups getTerminatingInputs() const;
|
124
|
+
|
125
|
+
// Recursively traverses uses of the IdGroups in 'of' and returns all
|
126
|
+
// ExprGroups that have a use in their definition of the provided IdGroups.
|
127
|
+
ExprGroups allUsesOf(const ValGroups& of) const;
|
128
|
+
|
129
|
+
// Recursively traverses definitions of the IdGroups in 'of' and returns all
|
130
|
+
// ExprGroups used in this history of defining the 'of' IdGroups.
|
131
|
+
ExprGroups allDefinitionsOf(const ValGroups& of) const;
|
132
|
+
|
133
|
+
//! Returns the expressions associated with the
|
134
|
+
//! definitions of the provided ValGroup.
|
135
|
+
//!
|
136
|
+
//! Each ExprGroup of the returned ExprGroup vector is proven to be
|
137
|
+
//! equivalent. The ExprGroup vector holds expression groups that are not
|
138
|
+
//! equivalent, but produce one of the ValGroups within the same disjoint Val
|
139
|
+
//! set.
|
140
|
+
const ExprGroups& getDefinitions(const ValGroup& val_group) const;
|
141
|
+
|
142
|
+
//! Same as getDefinitions but for uses instead of
|
143
|
+
//! definitions
|
144
|
+
const ExprGroups& getUses(const ValGroup& val_group) const;
|
145
|
+
|
146
|
+
bool hasDefinitions(const ValGroup& val_group) const;
|
147
|
+
|
148
|
+
bool hasUses(const ValGroup& val_group) const;
|
149
|
+
|
150
|
+
// Uses the ValGraph to produce mappings between from and to.
|
151
|
+
// Supports one to many mappings. If a single Val in from maps to
|
152
|
+
// multiple Vals in to, the order of the Vals in value of
|
153
|
+
// the map is preserved to be the order provided in to.
|
154
|
+
//
|
155
|
+
// Example:
|
156
|
+
// tv0: [i0, b1]
|
157
|
+
// tv1: [i2, i3]
|
158
|
+
// tv2: [i4, i5]
|
159
|
+
// tv2 = tv0 + tv1
|
160
|
+
//
|
161
|
+
// tv0: [i0*b1] CA(1)
|
162
|
+
// tv1: [i2*i3] CA(1)
|
163
|
+
// tv2: [i4*i5] CA(1)
|
164
|
+
//
|
165
|
+
// Between tv0 and tv2, the Permissive graph would map:
|
166
|
+
// {i0, i4}
|
167
|
+
// {b1, i5}
|
168
|
+
// {i0*b1, i4*i5}
|
169
|
+
//
|
170
|
+
// Here, buildMapBetween with:
|
171
|
+
// from: {i0, b1, i0*b1}
|
172
|
+
// to: {i4, i5, i4*i5}
|
173
|
+
// will return a map of:
|
174
|
+
// i0: {i4}
|
175
|
+
// b1: {i5}
|
176
|
+
// i0*b1: {i4*i5}
|
177
|
+
std::unordered_map<Val*, VectorOfUniqueEntries<Val*>> buildMapBetween(
|
178
|
+
const std::vector<Val*>& from,
|
179
|
+
const std::vector<Val*>& to) const;
|
180
|
+
|
181
|
+
// Alias of the above on unique vector entries
|
182
|
+
std::unordered_map<Val*, VectorOfUniqueEntries<Val*>> buildMapBetween(
|
183
|
+
const VectorOfUniqueEntries<Val*>& from,
|
184
|
+
const VectorOfUniqueEntries<Val*>& to) const;
|
185
|
+
|
186
|
+
std::string toString() const;
|
187
|
+
|
188
|
+
std::string toGraphvizDotGraph() const;
|
189
|
+
|
190
|
+
// Initializes entries for the provided Val with its definitions and
|
191
|
+
// uses. The provided Val will have its own new ValGroup, each item in the
|
192
|
+
// definitions and uses will become a new ExprGroup, and these new ExprGroups
|
193
|
+
// will be the definitions and uses of the new ValGroup.
|
194
|
+
void initializeVal(
|
195
|
+
Val* val,
|
196
|
+
const VectorOfUniqueEntries<Expr*>& definitions,
|
197
|
+
const VectorOfUniqueEntries<Expr*>& uses);
|
198
|
+
|
199
|
+
// Same as the above except val->definition() and val->uses() are
|
200
|
+
// used
|
201
|
+
void initializeVal(Val* val);
|
202
|
+
|
203
|
+
// Initializes entries for the provided Val. The provided Val will be added to
|
204
|
+
// the provided existing ValGroup. There will be no changes on the definitions
|
205
|
+
// and uses of the provided ValGroup.
|
206
|
+
void initializeVal(Val* v, ValGroup vg) {
|
207
|
+
disjoint_vals_.appendToSet(v, vg);
|
208
|
+
}
|
209
|
+
|
210
|
+
// Add expr to the disjoint sets as a sole group. Used for
|
211
|
+
// registering replayed domains and exprs. Error if the expr is
|
212
|
+
// already registered.
|
213
|
+
void registerExpr(Expr* expr);
|
214
|
+
|
215
|
+
// Returns true if first and second are expressions through which
|
216
|
+
// this ValGraph has matching inputs (if forward), or outputs (if not
|
217
|
+
// forward). Returning true means the expressions are "the same", in that
|
218
|
+
// they modify matching original inputs by the same amount.
|
219
|
+
bool exprsMap(Expr* first, Expr* second, bool forward) const;
|
220
|
+
|
221
|
+
// Check basic consistencies of val and expr groups and their
|
222
|
+
// mappings.
|
223
|
+
void validateConsistency() const;
|
224
|
+
|
225
|
+
void addUniqueUses(const ValGroup& id_group, const ExprGroup& uses) {
|
226
|
+
unique_uses_.at(id_group).pushBack(uses);
|
227
|
+
}
|
228
|
+
|
229
|
+
void addUniqueDefinitions(const ValGroup& id_group, const ExprGroup& defs) {
|
230
|
+
unique_definitions_.at(id_group).pushBack(defs);
|
231
|
+
}
|
232
|
+
|
233
|
+
// Set val0 and val1 to mapped in this graph, attempt to propagate
|
234
|
+
// new mapping through val0/val1 definitions/uses.
|
235
|
+
void mapVals(Val* val0, Val* val1);
|
236
|
+
|
237
|
+
// Checks if expr0 and expr1 should map together, maps them together, and if
|
238
|
+
// expression propagation is on, propagates mapping through
|
239
|
+
// them. The forward parameter determines the direction of the
|
240
|
+
// propagation. The expressions are mapped if the inputs are mapped
|
241
|
+
// when the forward parameter is true. This should
|
242
|
+
// be the only call in ValGraph to mapThroughExpr.
|
243
|
+
void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward);
|
244
|
+
|
245
|
+
// Can't back prop through merge without making sure one input actually
|
246
|
+
// matches. This can be done on a map or extent basis.
|
247
|
+
// TODO: Move this to val_graph.cpp once validation_utils.cpp is
|
248
|
+
// retired.
|
249
|
+
template <typename T>
|
250
|
+
static bool shouldMapMergeBackward(
|
251
|
+
Merge* merge0,
|
252
|
+
Merge* merge1,
|
253
|
+
const DisjointSets<T*>& id_sets) {
|
254
|
+
auto extent_match = [](IterDomain* id0, IterDomain* id1) -> bool {
|
255
|
+
return id0->extent()->sameAs(id1->extent()) ||
|
256
|
+
(id0->extent()->isConstInt() && id1->extent()->isConstInt() &&
|
257
|
+
id0->extent()->evaluate().as<int64_t>() ==
|
258
|
+
id1->extent()->evaluate().as<int64_t>());
|
259
|
+
};
|
260
|
+
|
261
|
+
// If one pair of the domains are mapped in the given graph, the
|
262
|
+
// backward merge is considered mapped
|
263
|
+
if (id_sets.permissiveAreMapped(merge0->outer(), merge1->outer()) ||
|
264
|
+
id_sets.permissiveAreMapped(merge0->inner(), merge1->inner())) {
|
265
|
+
return true;
|
266
|
+
}
|
267
|
+
|
268
|
+
// Considered mapped if the extents are equal
|
269
|
+
if (extent_match(merge0->outer(), merge1->outer()) ||
|
270
|
+
extent_match(merge0->inner(), merge1->inner())) {
|
271
|
+
return true;
|
272
|
+
}
|
273
|
+
|
274
|
+
// The mapped ID group may have different extents depending on the
|
275
|
+
// mapping conditions. For example, the Permissive graph may have a
|
276
|
+
// symbolic extent as well as an extent of 1 for broadcast
|
277
|
+
// domains. Those other mapped domains need to be checked as well.
|
278
|
+
|
279
|
+
// First, the outer groups
|
280
|
+
auto outer0_group = id_sets.mappingExists(merge0->outer())
|
281
|
+
? id_sets.disjointSetMap().at(merge0->outer())
|
282
|
+
: std::make_shared<VectorOfUniqueEntries<T*>>(
|
283
|
+
VectorOfUniqueEntries<T*>{merge0->outer()});
|
284
|
+
auto outer1_group = id_sets.mappingExists(merge1->outer())
|
285
|
+
? id_sets.disjointSetMap().at(merge1->outer())
|
286
|
+
: std::make_shared<VectorOfUniqueEntries<T*>>(
|
287
|
+
VectorOfUniqueEntries<T*>{merge1->outer()});
|
288
|
+
|
289
|
+
for (T* outer0 : *outer0_group) {
|
290
|
+
for (T* outer1 : *outer1_group) {
|
291
|
+
if (extent_match(
|
292
|
+
outer0->template as<IterDomain>(),
|
293
|
+
outer1->template as<IterDomain>())) {
|
294
|
+
return true;
|
295
|
+
}
|
296
|
+
}
|
297
|
+
}
|
298
|
+
|
299
|
+
// Check the inner groups as well if not already matched
|
300
|
+
auto inner0_group = id_sets.mappingExists(merge0->inner())
|
301
|
+
? id_sets.disjointSetMap().at(merge0->inner())
|
302
|
+
: std::make_shared<VectorOfUniqueEntries<T*>>(
|
303
|
+
VectorOfUniqueEntries<T*>{merge0->inner()});
|
304
|
+
auto inner1_group = id_sets.mappingExists(merge1->inner())
|
305
|
+
? id_sets.disjointSetMap().at(merge1->inner())
|
306
|
+
: std::make_shared<VectorOfUniqueEntries<T*>>(
|
307
|
+
VectorOfUniqueEntries<T*>{merge1->inner()});
|
308
|
+
|
309
|
+
for (T* inner0 : *inner0_group) {
|
310
|
+
for (T* inner1 : *inner1_group) {
|
311
|
+
if (extent_match(
|
312
|
+
inner0->template as<IterDomain>(),
|
313
|
+
inner1->template as<IterDomain>())) {
|
314
|
+
return true;
|
315
|
+
}
|
316
|
+
}
|
317
|
+
}
|
318
|
+
|
319
|
+
return false;
|
320
|
+
}
|
321
|
+
|
322
|
+
private:
|
323
|
+
// Map expr0 and expr1 with each other, update unique_definitions_
|
324
|
+
// and unique_uses_
|
325
|
+
// TODO: Make this variant hidden?
|
326
|
+
void mapExprs(Expr* expr0, Expr* expr1);
|
327
|
+
|
328
|
+
// Checks if expr's are considered "the same" where sameness is
|
329
|
+
// defined as inputs and outputs in the same position across
|
330
|
+
// expressions are mapped. If the expressions are determined the
|
331
|
+
// same then
|
332
|
+
//
|
333
|
+
// if forward
|
334
|
+
// will map outputs
|
335
|
+
// else
|
336
|
+
// will map inputs
|
337
|
+
//
|
338
|
+
// Returns true if expressions were mapped through.
|
339
|
+
bool mapThroughExpr(Expr* first, Expr* second, bool forward);
|
340
|
+
|
341
|
+
private:
|
342
|
+
// If propagate_through_exprs_ = false, then mapThroughExpr will not be called
|
343
|
+
// as a consequence of calling mapVals. Likewise, mapThroughExpr will not be
|
344
|
+
// called (again) as a result of calling mapThroughExpr.
|
345
|
+
//
|
346
|
+
// Note: For the second sentence of above... mapThroughExpr can call mapVals
|
347
|
+
// which could in turn call mapThroughExpr again, but
|
348
|
+
// propagate_through_exprs_ as mentioned above prevents that from happening.
|
349
|
+
bool propagate_through_exprs_ = true;
|
350
|
+
|
351
|
+
// Keeps a disjoint set entry for all Vals.
|
352
|
+
DisjointSets<Val*> disjoint_vals_;
|
353
|
+
|
354
|
+
// Keeps a disjoint set entry for all Exprs.
|
355
|
+
DisjointSets<Expr*> disjoint_exprs_;
|
356
|
+
|
357
|
+
// Definitions of ValGroup. There can be multiple definitions due to
|
358
|
+
// replays.
|
359
|
+
std::unordered_map<ValGroup, ExprGroups> unique_definitions_;
|
360
|
+
|
361
|
+
std::unordered_map<ValGroup, ExprGroups> unique_uses_;
|
362
|
+
};
|
363
|
+
|
364
|
+
struct ValGroupAndItsGraph {
|
365
|
+
ValGroup group;
|
366
|
+
ValGraph* graph;
|
367
|
+
bool operator==(const ValGroupAndItsGraph& other) const {
|
368
|
+
return group == other.group && graph == other.graph;
|
369
|
+
}
|
370
|
+
bool operator!=(const ValGroupAndItsGraph& other) const {
|
371
|
+
return !operator==(other);
|
372
|
+
}
|
373
|
+
operator const ValGroup&() const {
|
374
|
+
return group;
|
375
|
+
}
|
376
|
+
};
|
377
|
+
|
378
|
+
inline std::ostream& operator<<(
|
379
|
+
std::ostream& os,
|
380
|
+
const ValGroupAndItsGraph& g) {
|
381
|
+
return os << g.group;
|
382
|
+
}
|
383
|
+
|
384
|
+
// Returns the first pair of id's in ids detected to match each other on the
|
385
|
+
// given ID graph. TODO: what this is really looking for is if
|
386
|
+
// there's any overlapping between the iter domains in the provided set.
|
387
|
+
//
|
388
|
+
// i.e. if we have:
|
389
|
+
// tv0 = arange(6).reshape({3, 2})
|
390
|
+
// tv1 = tv0[3, 2].t()
|
391
|
+
// tv2 = tv0[3, 2].reshape({2, 3})
|
392
|
+
// tv3 = tv1 + tv2
|
393
|
+
//
|
394
|
+
// Then we can see this overlap in the tv3 expression as:
|
395
|
+
//
|
396
|
+
// tv0 = { {0, 1, 2},
|
397
|
+
// {3, 4, 5} }
|
398
|
+
//
|
399
|
+
// tv1 = { {0, 3},
|
400
|
+
// {1, 4},
|
401
|
+
// {2, 5} }
|
402
|
+
//
|
403
|
+
// tv2 = { {0, 1},
|
404
|
+
// {2, 3},
|
405
|
+
// {4, 5} }
|
406
|
+
//
|
407
|
+
// The elements in tv1 {3, 1, 4, 2}, map respectively to the elements in tv2
|
408
|
+
// {1, 2, 3, 4}. The reason this is so important is it means that generating
|
409
|
+
// tv3 is no longer a trivially parallelizable problem (if we include the dag
|
410
|
+
// all the way to tv0). So tv0's axes cannot be inlined across both the tv0
|
411
|
+
// and tv1 path. This breaks some assumptions we have today in schedulers that
|
412
|
+
// will assume tv2 can be trivially inlined/parallelized. Instead we'd need to
|
413
|
+
// take into consideration the effective communication going on here, so that
|
414
|
+
// we pull multiple values of tv0 to compute tv3.
|
415
|
+
//
|
416
|
+
// Note, however, that the above example is not detectable at this
|
417
|
+
// moment as the self mapping is partial through reshape. The analysis
|
418
|
+
// below would need to be extended to consider producer and consumers
|
419
|
+
// of domains as well rather than just root, logical and loop domains.
|
420
|
+
std::optional<std::pair<IterDomain*, IterDomain*>> detectSelfMapping(
|
421
|
+
const std::vector<IterDomain*>& ids,
|
422
|
+
const ValGraph& id_graph);
|
423
|
+
|
424
|
+
struct SelfMapping {
|
425
|
+
IterDomain* id1;
|
426
|
+
IterDomain* id2;
|
427
|
+
// For debugging, records which domain `id1` and `id2` belong to. This value
|
428
|
+
// is either "Root", "Logical", or "Leaf". Consider making it an enum.
|
429
|
+
std::string where;
|
430
|
+
};
|
431
|
+
|
432
|
+
// Returns if a self mapping was detected that would invalidate assumptions of
|
433
|
+
// the overall lowering system.
|
434
|
+
//
|
435
|
+
// It is assumed that for any tensor represented by a list of domains,
|
436
|
+
// those domains should never be mapped with each other. It may be
|
437
|
+
// possible to lift this assumption, but it's unclear if it could
|
438
|
+
// matter in practice.
|
439
|
+
//
|
440
|
+
// TODO: Can we make this more of an alias analysis?
|
441
|
+
// Ref: https://github.com/csarofeen/pytorch/pull/1954#discussion_r961940498
|
442
|
+
std::optional<SelfMapping> hasSelfMapping(
|
443
|
+
const TensorView* tv,
|
444
|
+
const ValGraph& id_graph);
|
445
|
+
|
446
|
+
} // namespace nvfuser
|
@@ -0,0 +1,259 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <bfs.h>
|
11
|
+
#include <disjoint_set.h>
|
12
|
+
#include <id_model/to_string.h>
|
13
|
+
#include <ir/all_nodes.h>
|
14
|
+
#include <val_graph.h>
|
15
|
+
|
16
|
+
namespace nvfuser {
|
17
|
+
|
18
|
+
// Iterates through a Val Graph in topological order, calling handle on
|
19
|
+
// all Val and all Expr groups in a forward topological order.
|
20
|
+
//
|
21
|
+
// Warning: A ValGraph is not guaranteed to be a DAG. In fact, the
|
22
|
+
// AlmostExact and Permissive graphs would have cycles with a ValGroup
|
23
|
+
// and an ExprGroup. For example:
|
24
|
+
//
|
25
|
+
// [i0, 1]
|
26
|
+
// merge
|
27
|
+
// [i0*1]
|
28
|
+
// Current ValGroups: {{i0}, {1}, {i0*1}}
|
29
|
+
// map i0 and i0*1 as they effectively have the same extent
|
30
|
+
// Final ValGroups: {{i0, i0*1}, {1}}
|
31
|
+
//
|
32
|
+
// Here, the merge expr is the user of i0 and the definition of
|
33
|
+
// i0*1. Since i0 and i0*1 are mapped, the dependency chain looks
|
34
|
+
// like:
|
35
|
+
//
|
36
|
+
// {i0, i0*1} ----> {merge} ----> {i0, i0*1}
|
37
|
+
// use def
|
38
|
+
//
|
39
|
+
// These ExprGroups are called trivial ExprGroups (see also
|
40
|
+
// ValGraph::isTrivialExprGroup).
|
41
|
+
//
|
42
|
+
// Strictly speaking, these cycles mean there's no valid topological
|
43
|
+
// order anymore. In our use cases for IdModel, however, it's likely
|
44
|
+
// sufficient to return an ordering such as:
|
45
|
+
//
|
46
|
+
// {i0, i0*1} -> {merge}
|
47
|
+
//
|
48
|
+
// I.e., we visit {i0, i0*1} first even though {merge} is technically
|
49
|
+
// a definition.
|
50
|
+
//
|
51
|
+
// Another alternative may be simply giving up when such a cycle is
|
52
|
+
// detected, which may be preferable as it would be less
|
53
|
+
// confusing. At this moment, this visitor is only used with graphs
|
54
|
+
// with no such cycle. Should be revisited when necessary.
|
55
|
+
//
|
56
|
+
// Warning: This is not a great iterator if there's a desire to minimize paths
|
57
|
+
// traveled to simply visit all ValGroups in order. See ExprsBetween to see how
|
58
|
+
// we might minimize paths.
|
59
|
+
class ValGraphVisitor {
|
60
|
+
public:
|
61
|
+
ValGraphVisitor() = delete;
|
62
|
+
|
63
|
+
ValGraphVisitor& operator=(const ValGraphVisitor& other) = delete;
|
64
|
+
|
65
|
+
ValGraphVisitor& operator=(ValGraphVisitor&& other) = delete;
|
66
|
+
|
67
|
+
virtual ~ValGraphVisitor() = default;
|
68
|
+
|
69
|
+
protected:
|
70
|
+
ValGraphVisitor(const ValGraph& val_graph, bool allow_cycle = true)
|
71
|
+
: val_graph_(val_graph), allow_cycle_(allow_cycle) {}
|
72
|
+
|
73
|
+
ValGraphVisitor(const ValGraphVisitor& other) = default;
|
74
|
+
|
75
|
+
ValGraphVisitor(ValGraphVisitor&& other) = default;
|
76
|
+
|
77
|
+
virtual void handle(const ValGroup& val_group) = 0;
|
78
|
+
virtual void handle(const ExprGroup& expr_group) = 0;
|
79
|
+
|
80
|
+
// Returns if the traversal was successful. If false, error_message_
|
81
|
+
// should be populated.
|
82
|
+
bool traverse();
|
83
|
+
|
84
|
+
const ValGraph& graph() {
|
85
|
+
return val_graph_;
|
86
|
+
};
|
87
|
+
|
88
|
+
const std::string& errorMessage() const {
|
89
|
+
return error_message_;
|
90
|
+
}
|
91
|
+
|
92
|
+
private:
|
93
|
+
const ValGraph& val_graph_;
|
94
|
+
bool allow_cycle_ = true;
|
95
|
+
std::string error_message_;
|
96
|
+
};
|
97
|
+
|
98
|
+
// Statement sorting based on ValGraphVisitor, see warnings on ValGraphVisitor.
|
99
|
+
class ValGraphStmtSort : public ValGraphVisitor {
|
100
|
+
public:
|
101
|
+
ValGraphStmtSort(const ValGraph& val_graph, bool allow_cycle = true)
|
102
|
+
: ValGraphVisitor(val_graph, allow_cycle) {
|
103
|
+
NVF_ERROR(ValGraphVisitor::traverse(), errorMessage());
|
104
|
+
}
|
105
|
+
|
106
|
+
// Return non-reference so that code like below can work
|
107
|
+
// for (auto expr_group: ValGraphStmtSort(graph).exprs())
|
108
|
+
ExprGroups exprs() const {
|
109
|
+
return sorted_exprs_;
|
110
|
+
}
|
111
|
+
|
112
|
+
ValGroups vals() const {
|
113
|
+
return sorted_vals_;
|
114
|
+
}
|
115
|
+
|
116
|
+
~ValGraphStmtSort() override = default;
|
117
|
+
|
118
|
+
protected:
|
119
|
+
using ValGraphVisitor::handle;
|
120
|
+
|
121
|
+
void handle(const ValGroup& val_group) override {
|
122
|
+
sorted_vals_.pushBack(val_group);
|
123
|
+
}
|
124
|
+
|
125
|
+
void handle(const ExprGroup& expr_group) override {
|
126
|
+
sorted_exprs_.pushBack(expr_group);
|
127
|
+
}
|
128
|
+
|
129
|
+
ExprGroups sorted_exprs_;
|
130
|
+
ValGroups sorted_vals_;
|
131
|
+
};
|
132
|
+
|
133
|
+
// Declared here, defined elsewhere. NOTE(review): presumably reports whether
// `graph` contains a cycle -- confirm against the definition.
bool isCyclic(const ValGraph& graph);
|
134
|
+
|
135
|
+
class ValGraphDefinitions {
|
136
|
+
const ValGraph& graph_;
|
137
|
+
|
138
|
+
public:
|
139
|
+
ValGraphDefinitions(const ValGraph& graph) : graph_(graph) {}
|
140
|
+
decltype(auto) operator()(const ValGroup& val_group) const {
|
141
|
+
return graph_.getDefinitions(val_group);
|
142
|
+
}
|
143
|
+
};
|
144
|
+
|
145
|
+
class ValGraphUses {
|
146
|
+
const ValGraph& graph_;
|
147
|
+
|
148
|
+
public:
|
149
|
+
ValGraphUses(const ValGraph& graph) : graph_(graph) {}
|
150
|
+
decltype(auto) operator()(const ValGroup& val_group) const {
|
151
|
+
return graph_.getUses(val_group);
|
152
|
+
}
|
153
|
+
};
|
154
|
+
|
155
|
+
class ValGraphInputs {
|
156
|
+
const ValGraph& graph_;
|
157
|
+
|
158
|
+
public:
|
159
|
+
ValGraphInputs(const ValGraph& graph) : graph_(graph) {}
|
160
|
+
decltype(auto) operator()(const ExprGroup& expr_group) const {
|
161
|
+
return graph_.inputGroups(expr_group);
|
162
|
+
}
|
163
|
+
};
|
164
|
+
|
165
|
+
class ValGraphOutputs {
|
166
|
+
const ValGraph& graph_;
|
167
|
+
|
168
|
+
public:
|
169
|
+
ValGraphOutputs(const ValGraph& graph) : graph_(graph) {}
|
170
|
+
decltype(auto) operator()(const ExprGroup& expr_group) const {
|
171
|
+
return graph_.outputGroups(expr_group);
|
172
|
+
}
|
173
|
+
};
|
174
|
+
|
175
|
+
template <>
|
176
|
+
struct GetValType<ExprGroup> {
|
177
|
+
using type = ValGroup;
|
178
|
+
};
|
179
|
+
|
180
|
+
// BFS instantiated for ValGraph: expressions are ExprGroups, values are
// ValGroups, and the four accessor functors are the ValGraph* wrappers
// declared in this header.
class ValGraphBFS : public BFS<
                        ExprGroup,
                        ValGroup,
                        ValGraphDefinitions,
                        ValGraphUses,
                        ValGraphInputs,
                        ValGraphOutputs> {
 public:
  // Builds the four accessor functors from `graph` and hands them, together
  // with the remaining arguments, straight to the BFS base class.
  ValGraphBFS(
      const ValGraph& graph,
      std::vector<NodeType> from_groups,
      std::vector<NodeType> to_groups,
      bool require_all_to_visited = true,
      Direction allowed_direction = Direction::Undefined)
      : BFS(ValGraphDefinitions(graph),
            ValGraphUses(graph),
            ValGraphInputs(graph),
            ValGraphOutputs(graph),
            std::move(from_groups),
            std::move(to_groups),
            require_all_to_visited,
            allowed_direction) {}

  // Convenience shortcut to the generic getExprsBetween, instantiated with
  // this class; `graph` is passed along as the trailing argument.
  static std::pair<ValGraphBFS::ExprPath, bool> getExprGroupsBetween(
      const ValGraph& graph,
      const ValGroups& from,
      const ValGroups& to,
      bool require_all_to_visited = true,
      Direction allowed_direction = Direction::Undefined) {
    return getExprsBetween<ValGraphBFS>(
        from.vector(),
        to.vector(),
        require_all_to_visited,
        allowed_direction,
        graph);
  }
};
|
218
|
+
|
219
|
+
// BFSWithPermissiveDependence instantiated for ValGraph: expressions are
// ExprGroups, values are ValGroups, and the four accessor functors are the
// ValGraph* wrappers declared in this header.
class ValGraphPermissiveBFS : public BFSWithPermissiveDependence<
                                  ExprGroup,
                                  ValGroup,
                                  ValGraphDefinitions,
                                  ValGraphUses,
                                  ValGraphInputs,
                                  ValGraphOutputs> {
 public:
  // Builds the four accessor functors from `graph` and hands them, together
  // with the remaining arguments, straight to the base class.
  ValGraphPermissiveBFS(
      const ValGraph& graph,
      std::vector<NodeType> from_groups,
      std::vector<NodeType> to_groups,
      bool require_all_to_visited = true,
      Direction allowed_direction = Direction::Undefined)
      : BFSWithPermissiveDependence(
            ValGraphDefinitions(graph),
            ValGraphUses(graph),
            ValGraphInputs(graph),
            ValGraphOutputs(graph),
            std::move(from_groups),
            std::move(to_groups),
            require_all_to_visited,
            allowed_direction) {}

  // Convenience shortcut to the generic getExprsBetween, instantiated with
  // this class; `graph` is passed along as the trailing argument.
  static std::pair<ValGraphPermissiveBFS::ExprPath, bool> getExprGroupsBetween(
      const ValGraph& graph,
      const ValGroups& from,
      const ValGroups& to,
      bool require_all_to_visited = true,
      Direction allowed_direction = Direction::Undefined) {
    return getExprsBetween<ValGraphPermissiveBFS>(
        from.vector(),
        to.vector(),
        require_all_to_visited,
        allowed_direction,
        graph);
  }
};
|
258
|
+
|
259
|
+
} // namespace nvfuser
|