nvfuser-cu121-torch25 0.2.25.dev20250201__cp310-cp310-manylinux_2_28_x86_64.whl
Sign up to get free protection for your applications and to get access to all the features.
- nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,744 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <exceptions.h>
|
11
|
+
#include <ir/base_nodes.h>
|
12
|
+
#include <optional>
|
13
|
+
|
14
|
+
//! IR header hierarchy
|
15
|
+
//! 1. utils.h - PolymorphicBase and NonCopyable
|
16
|
+
//! 2. ir/base_nodes.h - Statement, Expr, and Val
|
17
|
+
//! 3. ** ir/internal_base_nodes.h ** - IterDomain and TensorDomain
|
18
|
+
//! 4. ir/interface_nodes.h - TensorView and Scalar
|
19
|
+
//! 5. ir/internal_nodes.h - Any internal-only IR nodes
|
20
|
+
|
21
|
+
namespace nvfuser {
|
22
|
+
|
23
|
+
// Friends for direct access to split
|
24
|
+
class TensorDomain;
|
25
|
+
class IterDomain;
|
26
|
+
class ReplayTransformations;
|
27
|
+
class IndexReferenceReplay;
|
28
|
+
class ViewTransform;
|
29
|
+
class Scope;
|
30
|
+
class IrCloner;
|
31
|
+
struct AnalyzeViewResult;
|
32
|
+
|
33
|
+
// Convenience utility to initialize IterDomain's without having to sort through
|
34
|
+
// all the default values. Intended to be used with
|
35
|
+
// IterDomain::IterDomain(IrBuilderPasskey, IterDomainBuilder).
|
36
|
+
class IterDomainBuilder {
|
37
|
+
public:
|
38
|
+
// Match legacy constructor
|
39
|
+
NVF_API IterDomainBuilder(Val* _start, Val* _extent);
|
40
|
+
|
41
|
+
// Grab all the parameters from id to set the IterDomainBuilder
|
42
|
+
NVF_API IterDomainBuilder(const IterDomain* id);
|
43
|
+
|
44
|
+
// Resets defaults for rfactor, is padded dim, padded to size, and is mma
|
45
|
+
// swizzle which should only be set during scheduling.
|
46
|
+
IterDomainBuilder& resetSchedulingParams();
|
47
|
+
|
48
|
+
// Resets is_rfactor_domain
|
49
|
+
IterDomainBuilder& resetRfactor();
|
50
|
+
|
51
|
+
IterDomainBuilder& start(Val* _start);
|
52
|
+
IterDomainBuilder& extent(Val* _extent);
|
53
|
+
NVF_API IterDomainBuilder& expanded_extent(Val* _expanded_extent);
|
54
|
+
IterDomainBuilder& stop_offset(Val* _stop_offset);
|
55
|
+
IterDomainBuilder& parallel_type(ParallelType _parallel_type);
|
56
|
+
NVF_API IterDomainBuilder& iter_type(IterType _iter_type);
|
57
|
+
IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain);
|
58
|
+
IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension);
|
59
|
+
IterDomainBuilder& padded_to_size(std::optional<int64_t> _padded_to_size);
|
60
|
+
|
61
|
+
NVF_API IterDomain* build() const;
|
62
|
+
|
63
|
+
// Must have start and extent at least
|
64
|
+
IterDomainBuilder() = delete;
|
65
|
+
|
66
|
+
Val* start_ = nullptr;
|
67
|
+
Val* extent_ = nullptr;
|
68
|
+
Val* expanded_extent_ = nullptr;
|
69
|
+
Val* stop_offset_ = nullptr;
|
70
|
+
ParallelType parallel_type_ = ParallelType::Serial;
|
71
|
+
IterType iter_type_ = IterType::Iteration;
|
72
|
+
|
73
|
+
// Only relevant at scheduling time or compile time.
|
74
|
+
bool is_rfactor_domain_ = false;
|
75
|
+
bool is_padded_dimension_ = false;
|
76
|
+
std::optional<int64_t> padded_to_size_ = std::nullopt;
|
77
|
+
};
|
78
|
+
|
79
|
+
//! Simply a representation of an annotated 1D iterable from start to extent.
|
80
|
+
//! A TensorDomain, which represents how to iterate over a tensor, is made up of
|
81
|
+
//! IterDomains to form an ND iterable. We directly set parallelization strategies
|
82
|
+
//! on IterDomains.
|
83
|
+
class NVF_API IterDomain : public Val {
|
84
|
+
public:
|
85
|
+
IterDomain(IrBuilderPasskey, const IterDomainBuilder& args);
|
86
|
+
|
87
|
+
// Legacy constructor, TODO: should start moving to use the IterDomainBuilder
|
88
|
+
// constructor. Same as the above but can set the offset of the stop point.
|
89
|
+
IterDomain(
|
90
|
+
IrBuilderPasskey,
|
91
|
+
Val* start,
|
92
|
+
Val* extent,
|
93
|
+
Val* expanded_extent,
|
94
|
+
Val* stop_offset,
|
95
|
+
ParallelType parallel_type,
|
96
|
+
IterType iter_type,
|
97
|
+
bool is_rfactor_domain,
|
98
|
+
bool is_padded_dimension,
|
99
|
+
std::optional<int64_t> padded_to_size);
|
100
|
+
|
101
|
+
IterDomain(const IterDomain* src, IrCloner* ir_cloner);
|
102
|
+
|
103
|
+
NVFUSER_DECLARE_CLONE
|
104
|
+
|
105
|
+
bool sameAs(const Statement* other) const override;
|
106
|
+
|
107
|
+
std::string toString(int indent_size = 0) const override;
|
108
|
+
|
109
|
+
std::string toInlineString(int indent_size = 0) const override;
|
110
|
+
|
111
|
+
//! Returns a new IterDomain matching properties of this
|
112
|
+
//!
|
113
|
+
//! This does NOT copy the is_rfactor_domain flag.
|
114
|
+
//!
|
115
|
+
//! When map_with_original is true, the clone of the original is
|
116
|
+
//! mapped in the Exact graph.
|
117
|
+
IterDomain* cloneWithoutRFactor(bool map_with_original = false);
|
118
|
+
|
119
|
+
//! Clone a vector of domains
|
120
|
+
static std::vector<IterDomain*> clone(
|
121
|
+
const std::vector<IterDomain*>& domains);
|
122
|
+
|
123
|
+
//! The optional parameters of rfactor_domain and iter_type can be
|
124
|
+
//! used to override the default behavior.
|
125
|
+
static IterDomain* merge(
|
126
|
+
IterDomain* outer,
|
127
|
+
IterDomain* inner,
|
128
|
+
std::optional<bool> rfactor_domain = std::nullopt,
|
129
|
+
std::optional<IterType> iter_type = std::nullopt);
|
130
|
+
|
131
|
+
//! The optional parameters of rfactor_domain, outer_iter_type and
|
132
|
+
//! inner_iter_type can be used to override the default behavior.
|
133
|
+
static std::pair<IterDomain*, IterDomain*> split(
|
134
|
+
IterDomain* in,
|
135
|
+
Val* factor,
|
136
|
+
bool inner_split,
|
137
|
+
std::optional<bool> rfactor_domain = std::nullopt,
|
138
|
+
std::optional<IterType> outer_iter_type = std::nullopt,
|
139
|
+
std::optional<IterType> inner_iter_type = std::nullopt);
|
140
|
+
|
141
|
+
//! Resize an IterDomain by expanding both the left and right sides
|
142
|
+
//! by given widths. The resulting IterDomain has an extent of
|
143
|
+
//! (left_expansion + in->extent() + right_expansion). Note that the
|
144
|
+
//! expansion factors can be negative, meaning the input IterDomain
|
145
|
+
//! is shrunk. This is the case when resize is used to represent
|
146
|
+
//! slice.
|
147
|
+
//!
|
148
|
+
//! When mark_as_rfactor is true, the output IterDomain
|
149
|
+
//! is marked as an rfactor domain. For example, expressions such as
|
150
|
+
//! PadOp and SliceOp resize IterDomains and generate rfactor
|
151
|
+
//! resized domains.
|
152
|
+
//!
|
153
|
+
//! Usually, the IterType of the output IterDomain will be Symbolic. This is
|
154
|
+
//! because unless the left and right expansions are known at Fusion
|
155
|
+
//! definition we cannot be sure that the output will have an extent != 1. In
|
156
|
+
//! case the output extent is in fact 1, we will set the IterType to
|
157
|
+
//! Broadcast. If the left and right expansions are constant, and sum to at
|
158
|
+
//! least two, then even an empty input will result in an Iteration IterType.
|
159
|
+
//! In these cases, we will set the output IterType to Iteration at
|
160
|
+
//! definition. Otherwise, it will be set to Symbolic and will be resolved
|
161
|
+
//! when concretization is performed by FusionExecutorCache.
|
162
|
+
//!
|
163
|
+
//! The optional iter_type argument can be used to force the output IterType,
|
164
|
+
//! but for safety its use should typically be confined to concretization.
|
165
|
+
static IterDomain* resize(
|
166
|
+
IterDomain* in,
|
167
|
+
Val* left_expansion,
|
168
|
+
Val* right_expansion,
|
169
|
+
bool mark_as_rfactor = false,
|
170
|
+
std::optional<IterType> iter_type = std::nullopt);
|
171
|
+
|
172
|
+
bool isReduction() const {
|
173
|
+
return getIterType() == IterType::Reduction;
|
174
|
+
}
|
175
|
+
|
176
|
+
bool isIteration() const {
|
177
|
+
return getIterType() == IterType::Iteration;
|
178
|
+
}
|
179
|
+
|
180
|
+
bool isRFactorProduct() const {
|
181
|
+
return is_rfactor_domain_;
|
182
|
+
}
|
183
|
+
|
184
|
+
bool isBroadcast() const {
|
185
|
+
return getIterType() == IterType::Broadcast;
|
186
|
+
}
|
187
|
+
|
188
|
+
bool isSymbolic() const {
|
189
|
+
return getIterType() == IterType::Symbolic;
|
190
|
+
}
|
191
|
+
|
192
|
+
bool isGatherScatter() const {
|
193
|
+
return getIterType() == IterType::GatherScatter;
|
194
|
+
}
|
195
|
+
|
196
|
+
bool isStride() const {
|
197
|
+
return getIterType() == IterType::Stride;
|
198
|
+
}
|
199
|
+
|
200
|
+
bool isVectorComponent() const {
|
201
|
+
return getIterType() == IterType::VectorComponent;
|
202
|
+
}
|
203
|
+
|
204
|
+
bool isParallelized() const {
|
205
|
+
return getParallelType() != ParallelType::Serial;
|
206
|
+
}
|
207
|
+
|
208
|
+
//! Return if this iter domain is mapped to a grid dimension
|
209
|
+
bool isBlockDim() const {
|
210
|
+
return isParallelTypeBlockDim(getParallelType());
|
211
|
+
}
|
212
|
+
|
213
|
+
//! Return if this iter domain is mapped to a block dimension
|
214
|
+
bool isThreadDim() const {
|
215
|
+
return isParallelTypeThreadDim(getParallelType());
|
216
|
+
}
|
217
|
+
|
218
|
+
//! Return if this iter domain is either mapped to a block or grid dimension
|
219
|
+
bool isThread() const {
|
220
|
+
return (isBlockDim() || isThreadDim());
|
221
|
+
}
|
222
|
+
|
223
|
+
bool isDeviceDim() const {
|
224
|
+
return isParallelTypeDeviceDim(getParallelType());
|
225
|
+
}
|
226
|
+
|
227
|
+
void parallelize(ParallelType t);
|
228
|
+
|
229
|
+
ParallelType getParallelType() const {
|
230
|
+
return parallel_type_;
|
231
|
+
}
|
232
|
+
|
233
|
+
IterType getIterType() const {
|
234
|
+
return iter_type_;
|
235
|
+
}
|
236
|
+
|
237
|
+
Val* start() const {
|
238
|
+
return start_;
|
239
|
+
}
|
240
|
+
|
241
|
+
Val* stop() const;
|
242
|
+
|
243
|
+
Val* stopOffset() const;
|
244
|
+
|
245
|
+
Val* extent() const {
|
246
|
+
NVF_ERROR(extent_ != nullptr);
|
247
|
+
return extent_;
|
248
|
+
}
|
249
|
+
|
250
|
+
bool hasExpandedExtent() const {
|
251
|
+
return expanded_extent_ != nullptr;
|
252
|
+
}
|
253
|
+
|
254
|
+
// Returns the expanded extent of a strided broadcast entry.
|
255
|
+
Val* expandedExtent() const {
|
256
|
+
NVF_ERROR(
|
257
|
+
hasExpandedExtent(),
|
258
|
+
"Requested expanded extent, but none found on this dimension.");
|
259
|
+
return expanded_extent_;
|
260
|
+
}
|
261
|
+
|
262
|
+
Val* getMaybeExpandedExtent() const {
|
263
|
+
if (hasExpandedExtent()) {
|
264
|
+
return expandedExtent();
|
265
|
+
}
|
266
|
+
return extent();
|
267
|
+
}
|
268
|
+
|
269
|
+
//! Dimension padding interface:
|
270
|
+
//! 2 modes are currently supported:
|
271
|
+
//!
|
272
|
+
//! - mode 1: if to_size is given as a positive number,
|
273
|
+
//! the dimension will be padded to the size so that
|
274
|
+
//! this iterdomain will be compile-time constant
|
275
|
+
//! size and it is the scheduler's responsibility
|
276
|
+
//! to ensure no input larger than the padded size
|
277
|
+
//! will be observed
|
278
|
+
//!
|
279
|
+
//! - mode 2: if no to_size is given, this dimension
|
280
|
+
//! is "dynamically" padded up to the nearest multiple
|
281
|
+
//! of a warp size, i.e. 17 padded to 32, 33 padded to 64
|
282
|
+
//! based on the given input.
|
283
|
+
void padToMultipleOfWarp(std::optional<int64_t> maybe_to_size = {}) {
|
284
|
+
// Currently only restricted to TIDx to generate warp reduce
|
285
|
+
NVF_CHECK(
|
286
|
+
parallel_type_ == ParallelType::TIDx,
|
287
|
+
"padToMultipleOfWarp : warp padding only supported on TIDx parallel dimension");
|
288
|
+
is_padded_dimension_ = true;
|
289
|
+
if (maybe_to_size.has_value()) {
|
290
|
+
if (maybe_to_size.value() > 0) {
|
291
|
+
padded_to_size_ = maybe_to_size.value();
|
292
|
+
}
|
293
|
+
}
|
294
|
+
}
|
295
|
+
|
296
|
+
//! Indicates if this iterdomain had padding
|
297
|
+
//! (either dynamic or static)
|
298
|
+
bool hasPaddingToMultipleOfWarp() const {
|
299
|
+
return is_padded_dimension_;
|
300
|
+
}
|
301
|
+
|
302
|
+
//! Returns a concrete value if this iterdomain
|
303
|
+
//! has been padded to a static size.
|
304
|
+
std::optional<int64_t> getMaybeSizeAfterPadding() const {
|
305
|
+
return padded_to_size_;
|
306
|
+
}
|
307
|
+
|
308
|
+
//! True if range of iteration domain isn't across the full extent
|
309
|
+
bool maybePartial() const;
|
310
|
+
|
311
|
+
//! Check if IterDomain is a broadcast axis with compile-time
|
312
|
+
//! known extent. This is the case with all size-1 IterDomains on
|
313
|
+
//! a TensorView's root domain when the TensorView is created.
|
314
|
+
bool isImplicitBroadcast() const {
|
315
|
+
return isBroadcast() && extent()->isOneInt();
|
316
|
+
}
|
317
|
+
|
318
|
+
//! Split for stride by a given factor. It effectively does an inner
|
319
|
+
//! split by the factor and sets the inner domain as a Stride
|
320
|
+
//! domain.
|
321
|
+
std::pair<IterDomain*, IterDomain*> stridedSplit(int64_t factor);
|
322
|
+
|
323
|
+
//! Marks that this id represents an
|
324
|
+
//! instruction loop, mma use only.
|
325
|
+
//!
|
326
|
+
//! An instruction loop can be considered a generalization of
|
327
|
+
//! vectorization. It also represents a loop that's implemented
|
328
|
+
//! by an instruction and should not be realized by codegen and
|
329
|
+
//! cannot be inlined with.
|
330
|
+
//! As an example, if a mma macro, call it mma_eg implements:
|
331
|
+
//! for m in M
|
332
|
+
//! for n in N
|
333
|
+
//! for k in K
|
334
|
+
//! C[m,n] += A[m,k]*B[k,n],
|
335
|
+
//! But the generated code should simply be:
|
336
|
+
//! mma_eg(C,A,B)
|
337
|
+
//! without the 3 level loopnest, i.e. they're instruction loops.
|
338
|
+
//!
|
339
|
+
//! In the actual mma macros, the loopnest each macro implements is a
|
340
|
+
//! transformed version of above to match the mma swizzle.
|
341
|
+
//! So different macros have different implicit loopnests.
|
342
|
+
//! MmaSwizzler will label the instruction loops case-by-case.
|
343
|
+
bool isMma() const {
|
344
|
+
return parallel_type_ == ParallelType::Mma;
|
345
|
+
}
|
346
|
+
|
347
|
+
//! Marks that this id represents an instruction loop, cp.async.bulk use only.
|
348
|
+
bool isBulk() const {
|
349
|
+
return parallel_type_ == ParallelType::Bulk;
|
350
|
+
}
|
351
|
+
|
352
|
+
//! Applies 2D swizzle on a rectangular tile defined by
|
353
|
+
//! a pair of iterdomains.
|
354
|
+
static std::pair<IterDomain*, IterDomain*> swizzle(
|
355
|
+
SwizzleType swizzle_type,
|
356
|
+
IterDomain* in_x,
|
357
|
+
IterDomain* in_y);
|
358
|
+
static std::pair<IterDomain*, IterDomain*> swizzle(
|
359
|
+
Swizzle2DType swizzle_type,
|
360
|
+
IterDomain* in_x,
|
361
|
+
IterDomain* in_y,
|
362
|
+
SwizzleMode swizzle_mode = SwizzleMode::Data);
|
363
|
+
|
364
|
+
protected:
|
365
|
+
friend TensorDomain;
|
366
|
+
friend ReplayTransformations;
|
367
|
+
friend IndexReferenceReplay;
|
368
|
+
|
369
|
+
private:
|
370
|
+
//! Valid range is defined as [start:-stop_offset]
|
371
|
+
Val* const start_ = nullptr;
|
372
|
+
Val* const extent_ = nullptr;
|
373
|
+
|
374
|
+
// Broadcast dimensions are assumed to be size 1 for the sake of code
|
375
|
+
// generation. If a user calls `expand` on a tensor, though, that dimension is
|
376
|
+
// still considered a broadcast dimension. However if we ever output that
|
377
|
+
// dimension it should be a size dictated by the `expand` operation, and have
|
378
|
+
// a stride of zero. Since this extent is important to track, but not
|
379
|
+
// necessarily generate code for (still want loops on broadcast to be of size
|
380
|
+
// 0), we simply store it separately from extent_. Having an expanded_extent_
|
381
|
+
// is only allowed with broadcasted dimensions. Only in this instance does it
|
382
|
+
// make sense to have an expanded_extent_, because it's used when users are
|
383
|
+
// expecting return tensors to have a physical domain. If a user simply
|
384
|
+
// "broadcasts" an operation
|
385
|
+
Val* const expanded_extent_ = nullptr;
|
386
|
+
|
387
|
+
//! Distance of stop from the end
|
388
|
+
Val* const stop_offset_ = nullptr;
|
389
|
+
ParallelType parallel_type_ = ParallelType::Serial;
|
390
|
+
IterType iter_type_ = IterType::Iteration;
|
391
|
+
bool is_rfactor_domain_ = false;
|
392
|
+
bool is_padded_dimension_ = false;
|
393
|
+
std::optional<int64_t> padded_to_size_ = std::nullopt;
|
394
|
+
};
|
395
|
+
|
396
|
+
//! TensorDomain holds a vector of IterDomains. It holds an IterDomain for every
|
397
|
+
//! logical axis in its associated tensor. TensorDomain does not directly hold
|
398
|
+
//! the Tensor it is associated with, and in theory could be associated with
|
399
|
+
//! multiple tensors. TensorDomain's primary responsibility is to provide a
|
400
|
+
//! mechanism to access history of transformations that were used to generate
|
401
|
+
//! it. This is done through the normal interaction of Expr/Val in Fusion. i.e.
|
402
|
+
//! if we want to know the previous operation generating a particular
|
403
|
+
//! TensorDomain we can simply call:
|
404
|
+
//!
|
405
|
+
//! FusionGuard::getCurFusion()->definition(a_tensor_domain)
|
406
|
+
//!
|
407
|
+
//! which should give us an operation in the list [split, merge] or similar
|
408
|
+
//! operations that take in a TensorDomain, apply a transformation and output
|
409
|
+
//! a tensor domain.
|
410
|
+
class TensorDomain : public Val {
|
411
|
+
public:
|
412
|
+
NVF_API explicit TensorDomain(
|
413
|
+
IrBuilderPasskey,
|
414
|
+
std::vector<IterDomain*> logical_domain,
|
415
|
+
std::vector<std::optional<bool>> contiguity = {});
|
416
|
+
|
417
|
+
// See notes [ Note stride order and contiguity vector ] in
|
418
|
+
// python_bindings.cpp
|
419
|
+
TensorDomain(
|
420
|
+
IrBuilderPasskey,
|
421
|
+
std::vector<IterDomain*> logical_domain,
|
422
|
+
std::vector<int64_t> stride_order,
|
423
|
+
std::vector<std::optional<bool>> contiguity = {});
|
424
|
+
|
425
|
+
TensorDomain(
|
426
|
+
IrBuilderPasskey,
|
427
|
+
std::vector<IterDomain*> logical_domain,
|
428
|
+
std::vector<IterDomain*> loop_domain,
|
429
|
+
std::vector<std::optional<bool>> contiguity = {});
|
430
|
+
|
431
|
+
TensorDomain(
|
432
|
+
IrBuilderPasskey,
|
433
|
+
std::vector<IterDomain*> root_domain,
|
434
|
+
std::vector<IterDomain*> logical_domain,
|
435
|
+
std::vector<IterDomain*> loop_domain,
|
436
|
+
std::vector<std::optional<bool>> contiguity = {});
|
437
|
+
|
438
|
+
TensorDomain(
|
439
|
+
IrBuilderPasskey,
|
440
|
+
std::vector<IterDomain*> root_domain,
|
441
|
+
std::vector<IterDomain*> logical_domain,
|
442
|
+
std::vector<IterDomain*> allocation,
|
443
|
+
std::vector<IterDomain*> loop_domain,
|
444
|
+
std::vector<std::optional<bool>> contiguity = {},
|
445
|
+
std::vector<IterDomain*> additional_ids = {});
|
446
|
+
|
447
|
+
TensorDomain(IrBuilderPasskey, const TensorDomain* src);
|
448
|
+
|
449
|
+
TensorDomain(const TensorDomain* src, IrCloner* ir_cloner);
|
450
|
+
|
451
|
+
NVFUSER_DECLARE_CLONE
|
452
|
+
|
453
|
+
bool operator==(const TensorDomain& other) const;
|
454
|
+
bool operator!=(const TensorDomain& other) const {
|
455
|
+
return !(*this == other);
|
456
|
+
}
|
457
|
+
|
458
|
+
int64_t nDims() const {
|
459
|
+
return (int64_t)loop_domain_.size();
|
460
|
+
}
|
461
|
+
|
462
|
+
bool sameAs(const Statement* other) const override;
|
463
|
+
|
464
|
+
static bool sameAs(
|
465
|
+
const std::vector<IterDomain*>& lhs,
|
466
|
+
const std::vector<IterDomain*>& rhs);
|
467
|
+
|
468
|
+
// When `loop_only` is false, prints also the root, logical and allocation
|
469
|
+
// domain if not empty.
|
470
|
+
std::string toString(int indent_size, bool loop_only) const;
|
471
|
+
std::string toString(int indent_size = 0) const override;
|
472
|
+
std::string toInlineString(int indent_size = 0) const override;
|
473
|
+
|
474
|
+
// Note: [Contiguity]
|
475
|
+
// Contiguity is a vector of optional<bool> which has the same number of
|
476
|
+
// elements as logical_domain_. The contiguity of a broadcast dimension is
|
477
|
+
// meaningless, so it has to be nullopt. The contiguity of a non-broadcasting
|
478
|
+
// dimension is true if and only if it is memory dense with the next
|
479
|
+
// non-broadcasting dimension.
|
480
|
+
// For example, if I have a tensor torch.zeros(4, 1, 3).expand(-1, 10, -1),
|
481
|
+
// the contiguity will be (true, nullopt, true), which means 4 is memory dense
|
482
|
+
// with 3.
|
483
|
+
const std::vector<std::optional<bool>>& contiguity() const {
|
484
|
+
return contiguity_;
|
485
|
+
}
|
486
|
+
|
487
|
+
// The python frontend has a stride_order argument in the define_tensor
|
488
|
+
// function. This argument allows the user to specify the allocation domain
|
489
|
+
// for the TensorView. When translating the CPP Fusion into a Python
|
490
|
+
// FusionDefinition, the stride_order argument is required if this
|
491
|
+
// TensorDomain's allocation domain is a permutation of the logical domain.
|
492
|
+
// This function generates the stride_order argument for this TensorDomain.
|
493
|
+
std::vector<int64_t> strideOrder() const;
|
494
|
+
|
495
|
+
NVF_API void setContiguity(const std::vector<std::optional<bool>>& contig);
|
496
|
+
|
497
|
+
std::string getContiguityString() const {
|
498
|
+
return toDelimitedString(contiguity(), /*delim=*/" ");
|
499
|
+
}
|
500
|
+
|
501
|
+
bool hasReduction() const {
|
502
|
+
return has_reduction_;
|
503
|
+
}
|
504
|
+
|
505
|
+
bool hasBlockReduction() const;
|
506
|
+
bool hasGridReduction() const;
|
507
|
+
bool hasBlockBroadcast() const;
|
508
|
+
bool hasGridBroadcast() const;
|
509
|
+
|
510
|
+
bool hasBroadcast() const {
|
511
|
+
return no_bcast_domain_.size() != loop_domain_.size();
|
512
|
+
}
|
513
|
+
|
514
|
+
bool hasRoot() const {
|
515
|
+
return !root_domain_.empty();
|
516
|
+
}
|
517
|
+
|
518
|
+
bool hasAllocation() const {
|
519
|
+
return !allocation_domain_.empty();
|
520
|
+
}
|
521
|
+
|
522
|
+
// Returns if rfactor domain only consists of id's of iter type.
|
523
|
+
bool hasViewLikeRFactor() const;
|
524
|
+
|
525
|
+
bool hasVectorize() const;
|
526
|
+
|
527
|
+
NVF_API bool hasSymbolicAxis() const;
|
528
|
+
|
529
|
+
std::optional<int64_t> getReductionAxis() const;
|
530
|
+
|
531
|
+
const std::vector<IterDomain*>& noReductions() const {
|
532
|
+
return no_reduction_domain_;
|
533
|
+
}
|
534
|
+
|
535
|
+
const std::vector<IterDomain*>& noBroadcasts() const {
|
536
|
+
return no_bcast_domain_;
|
537
|
+
}
|
538
|
+
|
539
|
+
// The input logical domain. The root domain of a consumer should equal the
|
540
|
+
// logical domain of its producer ignoring reduction dimensions.
|
541
|
+
const std::vector<IterDomain*>& root() const {
|
542
|
+
return root_domain_;
|
543
|
+
};
|
544
|
+
|
545
|
+
const std::vector<IterDomain*>& maybeRoot() const {
|
546
|
+
return root_domain_.empty() ? logical_domain_ : root_domain_;
|
547
|
+
};
|
548
|
+
|
549
|
+
// Check if id is a root ID. Always return false if there's no root
|
550
|
+
// domain.
|
551
|
+
bool isRoot(const IterDomain* id) const {
|
552
|
+
return hasRoot() &&
|
553
|
+
std::find(root().begin(), root().end(), id) != root().end();
|
554
|
+
}
|
555
|
+
|
556
|
+
bool isMaybeRoot(const IterDomain* id) const {
|
557
|
+
return (hasRoot() && isRoot(id)) || (!hasRoot() && isLogical(id));
|
558
|
+
}
|
559
|
+
|
560
|
+
// The output logical domain.
|
561
|
+
const std::vector<IterDomain*>& logical() const {
|
562
|
+
return logical_domain_;
|
563
|
+
};
|
564
|
+
|
565
|
+
// Check if id is a logical ID.
|
566
|
+
bool isLogical(const IterDomain* id) const {
|
567
|
+
return std::find(logical().begin(), logical().end(), id) != logical().end();
|
568
|
+
}
|
569
|
+
|
570
|
+
// The allocation domain. This describes how data is stored in memory in
|
571
|
+
// outer-to-inner order.
|
572
|
+
const std::vector<IterDomain*>& allocation() const {
|
573
|
+
return allocation_domain_;
|
574
|
+
}
|
575
|
+
|
576
|
+
// Check if id is an allocation ID. Always return false if there's
|
577
|
+
// no allocation domain.
|
578
|
+
bool isAllocation(const IterDomain* id) const {
|
579
|
+
return hasAllocation() &&
|
580
|
+
std::find(allocation().begin(), allocation().end(), id) !=
|
581
|
+
allocation().end();
|
582
|
+
}
|
583
|
+
|
584
|
+
// The loop domain after scheduling. This defines loop nests and loop indices.
|
585
|
+
const std::vector<IterDomain*>& loop() const {
|
586
|
+
return loop_domain_;
|
587
|
+
}
|
588
|
+
|
589
|
+
const std::vector<IterDomain*>& initialLoop() const {
|
590
|
+
return initial_loop_domain_;
|
591
|
+
}
|
592
|
+
|
593
|
+
// Check if id is a loop ID.
|
594
|
+
bool isLoop(const IterDomain* id) const {
|
595
|
+
return std::find(loop().begin(), loop().end(), id) != loop().end();
|
596
|
+
}
|
597
|
+
|
598
|
+
// Check if id is an intial loop ID.
|
599
|
+
bool isInitialLoop(const IterDomain* id) const {
|
600
|
+
return std::find(initialLoop().begin(), initialLoop().end(), id) !=
|
601
|
+
loop().end();
|
602
|
+
}
|
603
|
+
|
604
|
+
// Get all IDs that is on the shortest path between any of the domains
|
605
|
+
// (logical domain, root domain, loop domain, allocation domain) following
|
606
|
+
// definition and uses path. Return values are topologically ordered and
|
607
|
+
// unique.
|
608
|
+
std::vector<IterDomain*> allIDs() const;
|
609
|
+
|
610
|
+
// Similar to allIDs but returns all ID expressions.
|
611
|
+
std::vector<Expr*> allExprs() const;
|
612
|
+
|
613
|
+
// Combine allIDs and allExprs
|
614
|
+
std::vector<Statement*> allStatements() const;
|
615
|
+
|
616
|
+
const std::vector<IterDomain*>& maybeAllocation() const {
|
617
|
+
return hasAllocation() ? allocation_domain_ : logical();
|
618
|
+
};
|
619
|
+
|
620
|
+
// Additional IDs that are not on the path from one of
|
621
|
+
// root/logical/allocation/loop domain to another. We need to keep track of
|
622
|
+
// these IDs to ensure that we can find all paths/IDs of interest.
|
623
|
+
const std::vector<IterDomain*>& additionalIDs() const {
|
624
|
+
return additional_ids_;
|
625
|
+
}
|
626
|
+
|
627
|
+
// Set the loop domain of this TensorDomain.
|
628
|
+
NVF_API void setLoopDomain(std::vector<IterDomain*> new_loop_domain);
|
629
|
+
|
630
|
+
// Set the allocation domain of this TensorDomain. Because contiguity is
|
631
|
+
// always defined w.r.t. the allocation domain, the contiguity must be updated
|
632
|
+
// accordingly.
|
633
|
+
NVF_API void setAllocationDomain(
|
634
|
+
std::vector<IterDomain*> new_allocation_domain,
|
635
|
+
std::vector<std::optional<bool>> new_contiguity);
|
636
|
+
|
637
|
+
// Similar to the previous one, but with new contiguity filled with all true
|
638
|
+
// or all false.
|
639
|
+
void setAllocationDomain(
|
640
|
+
std::vector<IterDomain*> new_allocation_domain,
|
641
|
+
bool new_contiguity) {
|
642
|
+
auto contiguity_flags =
|
643
|
+
getContiguityFilledWith(new_allocation_domain, new_contiguity);
|
644
|
+
setAllocationDomain(
|
645
|
+
std::move(new_allocation_domain), std::move(contiguity_flags));
|
646
|
+
}
|
647
|
+
|
648
|
+
void resetDomains() {
|
649
|
+
no_reduction_domain_ = noReductions(loop_domain_);
|
650
|
+
no_bcast_domain_ = noBroadcasts(loop_domain_);
|
651
|
+
has_reduction_ = hasReduction(loop_domain_);
|
652
|
+
}
|
653
|
+
|
654
|
+
// i here is int, as we want to accept negative value and ::size_type can be a
|
655
|
+
// uint.
|
656
|
+
IterDomain* axis(int64_t i) const;
|
657
|
+
|
658
|
+
int64_t posOf(IterDomain* id) const;
|
659
|
+
|
660
|
+
//! Returns a position of a root domain
|
661
|
+
int64_t rootPosOf(IterDomain* id) const;
|
662
|
+
|
663
|
+
//! Create a new broadcast IterDomain with the given extent in the loop domain
|
664
|
+
void broadcast(int64_t axis, Val* extent);
|
665
|
+
|
666
|
+
// Split "axis" into 2 axes
|
667
|
+
//! inner_split dictates if the factor section of the split should be inside
|
668
|
+
//! the
|
669
|
+
//! remainer or outside.
|
670
|
+
//! e.g. split(0, 4, inner_split = true) will result in:
|
671
|
+
//! tv[id{extent}] -> tv[id{ceilDiv(extent, factor)}, id{factor}]
|
672
|
+
//! e.g. split(0, 4, inner_split = false) will result in:
|
673
|
+
//! tv[id{extent}] -> tv[id{factor}, id{ceilDiv(extent, factor)}]
|
674
|
+
void split(int64_t axis_, Val* factor, bool inner_split);
|
675
|
+
|
676
|
+
// Merge axis_o and axis_i. axis_i is the fast changing dimension. Resulting
|
677
|
+
// axis is by default placed at original position axis_o
|
678
|
+
void merge(int64_t axis_o, int64_t axis_i);
|
679
|
+
|
680
|
+
// Reorder axes according to map[old_pos] = new_pos
|
681
|
+
void reorder(const std::unordered_map<int64_t, int64_t>& old2new);
|
682
|
+
|
683
|
+
//! Applies 2D swizzle on a rectangular tile defined by
|
684
|
+
//! a pair of iterdomains contained in this domain.
|
685
|
+
void swizzle(SwizzleType swizzle_type, int64_t x, int64_t y);
|
686
|
+
void swizzle(
|
687
|
+
Swizzle2DType swizzle_type,
|
688
|
+
int64_t x,
|
689
|
+
int64_t y,
|
690
|
+
SwizzleMode swizzle_mode = SwizzleMode::Data);
|
691
|
+
|
692
|
+
// Transform TensorView according to merge and split transformations
|
693
|
+
TensorDomain* view(const AnalyzeViewResult& view_analysis);
|
694
|
+
|
695
|
+
TensorDomain* flatten(int64_t start_dim, int64_t end_dim);
|
696
|
+
|
697
|
+
static std::vector<IterDomain*> orderedAs(
|
698
|
+
const std::vector<IterDomain*>& td,
|
699
|
+
const std::unordered_map<int64_t, int64_t>& old2new);
|
700
|
+
|
701
|
+
NVF_API static std::vector<IterDomain*> noReductions(
|
702
|
+
const std::vector<IterDomain*>&);
|
703
|
+
NVF_API static std::vector<IterDomain*> noBroadcasts(
|
704
|
+
const std::vector<IterDomain*>&);
|
705
|
+
NVF_API static std::vector<IterDomain*> noDevices(
|
706
|
+
const std::vector<IterDomain*>&);
|
707
|
+
|
708
|
+
static bool hasBroadcast(const std::vector<IterDomain*>&);
|
709
|
+
static bool hasReduction(const std::vector<IterDomain*>&);
|
710
|
+
|
711
|
+
// Get a vector whose size is the number of IDs in the given logical_domain
|
712
|
+
// filled with fill_value or nullopt depending on whether its corresponding ID
|
713
|
+
// is broadcast.
|
714
|
+
NVF_API static std::vector<std::optional<bool>> getContiguityFilledWith(
|
715
|
+
const std::vector<IterDomain*>& logical_domain,
|
716
|
+
bool fill_value);
|
717
|
+
|
718
|
+
// pair is in order where second is the consumer of first
|
719
|
+
std::pair<TensorDomain*, TensorDomain*> rFactor(
|
720
|
+
const std::vector<int64_t>& axes);
|
721
|
+
|
722
|
+
private:
|
723
|
+
int64_t wrapDim(int64_t dim) const {
|
724
|
+
return nvfuser::wrapDim(dim, nDims());
|
725
|
+
}
|
726
|
+
|
727
|
+
private:
|
728
|
+
const std::vector<IterDomain*> root_domain_;
|
729
|
+
const std::vector<IterDomain*> logical_domain_;
|
730
|
+
std::vector<IterDomain*> allocation_domain_;
|
731
|
+
std::vector<IterDomain*> loop_domain_;
|
732
|
+
// Initial loop domain. Loop domain is updated with transformations
|
733
|
+
// such as split, but the initial loop domain can only change with
|
734
|
+
// setLoopDomain
|
735
|
+
std::vector<IterDomain*> initial_loop_domain_;
|
736
|
+
std::vector<IterDomain*> additional_ids_;
|
737
|
+
|
738
|
+
std::vector<IterDomain*> no_bcast_domain_;
|
739
|
+
std::vector<IterDomain*> no_reduction_domain_;
|
740
|
+
std::vector<std::optional<bool>> contiguity_;
|
741
|
+
bool has_reduction_ = false;
|
742
|
+
};
|
743
|
+
|
744
|
+
} // namespace nvfuser
|