nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl
- nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/index_compute.h
@@ -0,0 +1,651 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <exceptions.h>
#include <iter_visitor.h>
#include <logical_domain_map.h>

#include <unordered_map>
#include <unordered_set>
#include <vector>

/*
 * Index compute takes in a list of indices typically generated from the
 * surrounding for loop nest. The number of indices is intended to match the
 * number of dimensions of the incoming TensorView, which may have fewer or
 * more dimensions than its allocation domain due to split/merge operations.
 * Split/merge operations are then replayed backwards to produce the resulting
 * indices (based on input indices) that match the allocation dimension.
 *
 * For example with a GLOBAL tensor:
 * TV[I, K]
 * TV[Io, Ii{4}, K] = TV.split(I, factor=4)
 * ALLOC: NONE
 * INDEX: indexCompute {i, j, k} -> {i * 4 + j, k}
 * FLATTENED_INDEX: {i * 4 + j, k} -> {(i * 4 + j) * K + k}
 * PREDICATE: {i * 4 + j, k} -> i * 4 + j < I
 *
 *
 * For example with a SHARED tensor:
 *
 * global_TV[I, K]
 * global_TV[Io, Ii{4}, K] = global_TV.split(I, factor=4)
 * smem_TV.compute_at(global_TV, 1)
 * global_TV.parallelize(1, threadIdx.x)
 *
 * ALLOC: alloc(smem_TV, 4 x K)
 * INDEX: indexCompute(smem_TV, {threadIdx.x, k}) -> {threadIdx.x, k}
 * FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {(threadIdx.x * 4 + j) * K + k}
 * PREDICATE: {threadIdx.x * 4 + j, k} -> threadIdx.x * 4 + j < I // Same as if
 * global
 *
 *
 * For example with a LOCAL tensor:
 * global_TV[I, K, L]
 * global_TV[Io, Ii{4}, K, L] = global_TV.split(I, factor=4)
 * reg_TV.compute_at(global_TV, 2)
 * global_TV.parallelize(1, threadIdx.x)
 * global_TV{i, j, k, l} -> { i * 4 + j, k, l }
 * global_TV{ i * 4 + j, k, l } -> { (i * 4 + j) * K * L + k * L + l}
 *
 * ALLOC: alloc(reg_TV, K x L)
 * INDEX: {k, l} -> {k, l}
 * FLATTENED_INDEX: {k, l} -> {k * L + l}
 * PREDICATE: i * 4 + j < I && k < K && l < L -> // Same as if global
 *
 * These indices can then be flattened later based on strides.
 */

namespace nvfuser {

class ContigIDs;
class LoopIndexing;
struct IndexFromIdGraph;
class TensorIndexer;

class IndexCompute : public BackwardVisitor {
 protected:
  using BackwardVisitor::handle;

  void dispatch(Expr*) override;

  void handle(Split*) override;
  void handle(Merge*) override;
  void handle(Swizzle*) override;
  void handle(Swizzle2D*) override;
  void handle(Resize*) override;

  // Returns extent_map_[id] if it exists, otherwise returns id->extent()
  Val* getExtent(IterDomain* id) const;

  //! True if a domain is not used to index
  bool isZero(IterDomain* id) const;
  //! True if any dependent of a domain is not used to index
  bool hasZeroMerged(IterDomain* id) const;

  //! Returns the concrete ID from the compute at EXACT mode map if
  //! concrete_id_pass == true, otherwise returns the id passed in.
  //! Helps unify the expr handling logic in reference domain and concrete id
  //! based traversal.
  IterDomain* maybeGetExactMapConcreteID(IterDomain* id) const;

  //! (Concrete indexing pass only)
  //! Collect permissive index bindings from the given expression.
  //! See also permissive_map_ and LoopIndexing::getBackwardOutOfLineExprList.
  void collectIndexIntoPermissiveMap(const LoopIndexing& loop_indexing);

  //! (Concrete indexing pass only)
  //! Iterate through id_expr's inputs and pull index vals from the permissive
  //! map when both of the following are true:
  //! 1. the output id is missing in index_map_.
  //! 2. the output id is found in the permissive map.
  void updateIndexMapFromPermissiveMap(const Expr* id_expr);

  //! Initialize unswitched_domain_map_ from the loop unswitched
  //! domains
  void initializeUnswitchDomainMap();

  //! Propagate unswitched map info from expr outputs to inputs
  void updateUnswitchedDomains(Expr* expr);

  //! Query if an IterDomain has a dependent unswitched domain
  bool hasUnswitchedDependentDomains(IterDomain* id) const;

  //! Query if the usual modulo propagation may be invalid for a merge
  //! inner path
  bool isModuloInvalidUnswitchedIndex(
      IterDomain* out_concrete_id,
      Val* out_ind,
      Val* inner_extent) const;

  // Tensor domain we're mapping back to allocation
  const TensorDomain* td_; // NOLINT

  // Map we update as we propagate backward, containing all IDs in the
  // propagation. Initial indices are mapped with this map at tv->domain()
  // and are back propagated to tv->getMaybeAllocationDomain(). This index_map_
  // keeps the indices at intermediate IterDomains in that back propagation.
  std::unordered_map<IterDomain*, Val*> index_map_; // NOLINT

  // Map from IterDomains to their broadcasted extent. If a TV has I0*I1 but
  // its producer has B0*I1, this map will contain a mapping from the ID{B0*I1}
  // to the extent I0*I1. Also contains updated extents if we merge in a 0
  // index. See zero_merged_in_.
  std::unordered_map<IterDomain*, Val*> extent_map_; // NOLINT

  // Keeps track of domains that do not contribute to indexing
  std::unordered_set<IterDomain*> zero_domains_; // NOLINT

  // This set keeps track of IterDomains that have had a zero index merged into
  // them. This happens if we do something like tv->axis(0)->split(4) then
  // tv->computeAt(1, ...). If this tensor is in smem or lmem, the backward
  // indexing would be (0, i), and when we do the backward computation that
  // zero and i would attempt to be merged together. We handle indices like
  // these specially.
  std::unordered_set<IterDomain*> zero_merged_in_;

  // IDs that are a result of contiguous merges
  std::unordered_set<IterDomain*> contig_ids_;

  // Indicates whether we should propagate an index down a particular
  // IterDomain path if there's an option
  std::unordered_set<IterDomain*> preferred_paths_;

  // Temporary flag which tells IndexCompute to use concrete ids from the exact
  // map rather than the actual IDs used in the ID expressions.
  bool concrete_id_pass_ = false;

  // Mode of swizzle that is activated in this index compute
  // instance. Swizzles of a different mode are treated as no-ops.
  // Currently data mode swizzles are handled the same as before in the
  // IndexSwizzle pass, while loop mode swizzles are handled early on in the
  // concrete indexing pass. See also [Note on swizzle mode]
  SwizzleMode swizzle_mode_ = SwizzleMode::NoSwizzle;

  // (Concrete id pass only)
  // Contains the indexing math that could be resolved with only the
  // iterdomains on the right of the consumer_tv's ca axis, i.e. the
  // ones that correspond to the loops that consumer_tv would not
  // share with any of its consumers.
  // These indexing vals should be kept separate from index_map_ and
  // should only be used when the indexing traversal follows the
  // order defined in LoopIndexingAnalysis::traverseFromDomainVals.
  std::unordered_map<IterDomain*, Val*> permissive_index_map_;

  //! Leaf domains that have maximum index values for unswitch
  //! predicates. These domains need extra adjustments when going
  //! through modulo operations for merge inner domains, as modulo does
  //! not always guarantee to preserve the maximum-ness property.
  std::unordered_set<IterDomain*> unswitched_loop_domains_;

  //! Mappings from unswitched IterDomains to their unswitched
  //! domains and their inner domains. Used to figure out if a modulo
  //! could invalidate the maximum-ness property of an unswitched index.
  //!
  //! Mappings are created in a bottom-up fashion from loop to root
  //! such that fine-grained domain mappings are kept as much as
  //! possible for making the modulo analysis most precise.
  //!
  //! Specifically, for the loop domains, this just maps unswitched
  //! domains, i.e., those included in unswitched_loop_domains_, to
  //! themselves. There'll be no mapping for those loop domains that
  //! are not included in unswitched_loop_domains_. The mappings of
  //! all other domains are defined based on their consumer
  //! domains. By default, they are also just mapped
  //! to themselves if any of the consumers are also mapped. However,
  //! when a domain is the input to a split, the mappings of the split output
  //! domains are tracked separately and the split input will be
  //! mapped to two sets of unswitched domains, one from the inner
  //! output and another from the outer output. The mapping info from
  //! the inner output is propagated as is, whereas the mapping info
  //! from the outer output is prepended with the inner output
  //! domain so that the unswitched domain list includes its inner
  //! domain. Note that the semantics of inner domains is defined based
  //! on split operations since they define the propagated index math.
  //!
  //! The reason for tracking the information from split outer domains
  //! separately is to avoid adjusting the unswitched predicate index
  //! as much as possible. For example, here's a common transpose
  //! scheduling pattern:
  //!
  //! // Initial 2D tensor
  //! [i0, i1]
  //! // Create a square tile of 32x32
  //! -> [i0 / 32, 32, i1 / 32, 32]
  //! -> [i0 / 32 * i1 / 32, 32 * 32]
  //! // Factor out a small domain (commonly vectorized)
  //! -> [i0 / 32 * i1 / 32, 32 * 32 / 4, 4]
  //! // Factor out another domain (commonly parallelized by TIDx)
  //! -> [i0 / 32 * i1 / 32, 32 * 32 / 4 / 128, 128, 4]
  //!
  //! Notice that the merge of "32 * 32" is not contiguous, so we need
  //! to predicate its input domains by propagating index exprs
  //! through the merge inner path with "% 32". If any of the final
  //! loop domains are unswitched, we need to make sure the index expr
  //! sent through "% 32" is the maximum for the domain of extent
  //! "32". Conservatively, this can just be 31; however, that isn't
  //! always strictly required. For example, suppose the innermost
  //! domain of extent 4 is unswitched. Its initial index is
  //! 3. Propagating it through the merge inner path as usual is
  //! guaranteed to be correct. More generally, it's always the case
  //! when the inner extent of a merge is divisible by the extent of
  //! an unswitched output and its domains. Suppose also the third
  //! innermost domain is also unswitched; its initial index is 1. Its
  //! contribution through the merge inner path is zero as the initial
  //! index is multiplied by the extents of its inner domains, i.e.,
  //! 128 and 4, and they are divisible by the extent of the merge
  //! inner domain. Again, more generally, if the stride of an
  //! unswitched domain is a multiple of the inner extent of the merge
  //! operation producing the unswitched domain, there's no
  //! contribution from the unswitched domain, so it doesn't matter if
  //! it's maximum or not.
  //!
  //! In the above pattern, the second innermost domain is commonly
  //! parallelized with TIDx. Suppose it's also unswitched. Notice
  //! that there's no concern for that domain invalidating the
  //! maximum-ness property, as threadIdx.x is the only valid initial
  //! index value for each thread. However, this is the reason we keep track
  //! of the split output contributions separately. More specifically,
  //! the intermediate domain of (32 * 32 / 4) will have an index of
  //! (1 * 128 + threadIdx.x), and the domain of (32 * 32) will have
  //! (1 * 128 * 4 + threadIdx.x * 4 + 3). As discussed above, we can
  //! reason that the first and third components of this
  //! unswitched expression are safe with respect to the propagation
  //! with modulo by 32. The second component is also safe as that's
  //! the only valid index for the domain. If not separately tracked,
  //! all we could know would be that the extent of (32 * 32) is
  //! 1024. Since part of the dependent domains are parallelized, the
  //! propagated index is not guaranteed to be 1023, so we would need
  //! to make a conservative decision to send 1023 to the merge inner
  //! path.
  std::unordered_map<IterDomain*, std::vector<std::deque<IterDomain*>>>
      unswitched_domain_map_;

 public:
  const std::unordered_map<IterDomain*, Val*>& indexMap() const {
    return index_map_;
  }

  const std::unordered_map<IterDomain*, Val*>& extentMap() const {
    return extent_map_;
  }

  const std::unordered_set<IterDomain*>& zeroDomains() const {
    return zero_domains_;
  }

  const std::unordered_set<IterDomain*>& zeroMergedIn() const {
    return zero_merged_in_;
  }

  // Propagate back from _td using initial_index_map
  IndexCompute(
      const TensorDomain* _td,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> _extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> _zero_merged_in,
      std::unordered_set<IterDomain*> preferred_paths = {});

  IndexCompute(
      const TensorDomain* _td,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> _extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> _zero_merged_in,
      const ContigIDs& contig_finder,
      std::unordered_set<IterDomain*> preferred_paths = {},
      std::unordered_set<IterDomain*> unswitched_domains = {});

  // Entry point used for concrete id based traversal. This traversal is
  // assumed to start at the loop IDs provided by initial_index_map.
  IndexCompute(
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> preferred_paths,
      std::unordered_set<IterDomain*> unswitched_domains = {});

  // Updates index_map, extent_map, and zero_merged_in based on id_map and
  // returns a new IndexCompute ready to be used.
  IndexCompute updateIndexCompute(
      const TensorDomain* new_td,
      const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
          id_map,
      const ContigIDs& contig_finder) const;

  // Interface to run index traversal through the loop indexing analysis
  // result, to be used with the entry point for concrete id based traversal.
  void run(const LoopIndexing& loop_indexing);

  virtual void run();
};

//! Apply swizzle and update allocation indices accordingly
class IndexSwizzle : public IndexCompute {
 public:
  IndexSwizzle(
      const TensorView* tv,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> zero_merged_in);

  IndexSwizzle(
      const TensorView* tv,
      const TensorDomain* domain,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> zero_merged_in);

  void run() override;

 protected:
  using IndexCompute::handle;

  void dispatch(Expr* e) override;

  void handle(Swizzle2D* swizzle_2d) override;

 private:
  const TensorView* tv_ = nullptr;
  std::unordered_set<IterDomain*> swizzled_ids_;
};

//! Information about a predicate. By default, it corresponds to a
//! single logical domain but may cover multiple logical domains due to
//! contiguous indexing.
class PredicateInfo {
  friend class Index;
  friend class TensorIndexer;

 public:
  const auto& startPredicate() const {
    return start_predicate_;
  }

  auto& startPredicate() {
    return start_predicate_;
  }

  const auto& startOffset() const {
    return start_offset_;
  }

  const auto& stopPredicate() const {
    return stop_predicate_;
  }

  const auto& stopOffset() const {
    return stop_offset_;
  }

  const auto& predicatedDomains() const {
    return predicated_domains_;
  }

  const auto& loopDomains() const {
    return loop_domains_;
  }

  CircularBufferLoopStage loopStage() const {
    return loop_stage_;
  }

  //! Return a false RootPredicateInfo, i.e., both start and stop
  //! predicates are false.
  static PredicateInfo getFalseInfo();

 private:
  // predicate for the lower end
  Val* start_predicate_ = nullptr;
  // predicate for the upper end
  Val* stop_predicate_ = nullptr;
  // Offset of the start predicate
  Val* start_offset_ = nullptr;
  // Offset of the stop predicate
  Val* stop_offset_ = nullptr;
  // Track which domains are covered by the generated predicates
  std::unordered_set<IterDomain*> predicated_domains_;
  // Loop domains used for the predicate domains
  std::unordered_set<IterDomain*> loop_domains_;
  // Circular buffer loop stage if applicable
  CircularBufferLoopStage loop_stage_ = CircularBufferLoopStage::NotApplicable;
};

// Simple interface for IndexCompute
// If getComputeAtAxis and more generally the TensorView const model is fixed,
// we can make the below tensorviews const.
class Index {
 private:
  // Producer indexing if it's in shared or local memory
  static std::vector<Val*> getNonGlobalProducerStridedIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {});

  // Consumer indexing if it's in shared or local memory
  static std::vector<Val*> getNonGlobalConsumerStridedIndices(
      const TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {});

  // Get the strides of a tensor used for the index lowering
  static std::vector<Val*> getStrides(TensorView* tv);

  // Get the allocation indices of a consumer tensor
  static std::vector<Val*> getConsumerAllocationIndices(
      const TensorView* tv,
      const std::vector<ForLoop*>& loops,
      const IndexFromIdGraph& index_from_id_graph);

  // Get the allocation indices of a producer tensor
  static std::vector<Val*> getProducerAllocationIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {});

 public:
  // Producer indexing if it's in global memory
  static std::vector<Val*> getGlobalProducerStridedIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {});

  // Consumer indexing if it's in global memory
  static std::vector<Val*> getGlobalConsumerStridedIndices(
      TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<int, Val*>& override_index = {});

  // Indexing functions
  // Consumer = Producer
  // i.e. T0 = T1... -> T0 is the consumer, T1 is the producer
  // Producer indexing dispatch
  // The argument `generate_pointer` specifies whether to generate a pointer
  // for the tensor. If it is a global tensor, then generate T1.data. If it is
  // a shared memory tensor, then use the `cvta` ptx to convert the shared
  // memory address to an unsigned int for indexing. Search `toSmem` in the
  // codebase for additional information. This argument is effective only if
  // the indexed tensor is a shared memory or global tensor. On other memory
  // types, this argument will cause an error.
  static kir::TensorIndex* getProducerIndex(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {},
      bool generate_pointer = false,
      DataType as_type = DataType::Null);

  // Consumer index dispatch
  static kir::TensorIndex* getConsumerIndex(
      TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<int, Val*>& override_index = {},
      bool generate_pointer = false,
      DataType as_type = DataType::Null);

  //! Returns a vector of strided indices mapped onto the
  //! allocation domain of a producer tensor. The size of the returned
  //! vector is guaranteed to be equal to the number of axes of the
  //! indexing allocation domain.
  static Val* getProducerStridedIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {},
      bool generate_pointer = false);

  //! Returns a vector of strided indices mapped onto the
  //! allocation domain of a consumer tensor. The size of the returned
  //! vector is guaranteed to be equal to the number of axes of the
  //! indexing allocation domain.
  static Val* getConsumerStridedIndices(
      TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<int, Val*>& override_index = {},
      bool generate_pointer = false);

  //! Returns the logical index linearized from a multi-dimensional address
  //! into a linear memory address of a consumer tensor. The returned index is
  //! intended to be used for the computation of some tensor factories, such as
  //! iota and rand (for Philox pseudo random sequences).
  static Val* getLinearLogicalIndex(
      TensorView* consumer_tv,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops);

  //! Returns a vector of logical indices mapped onto the logical
  //! domain of a consumer tensor. The returned index is intended
  //! to be used for the computation of some tensor factories, such as
  //! eye.
  static std::vector<Val*> getConsumerPerDimLogicalIndex(
      TensorView* consumer_tv,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops);

  //! Returns a vector of logical indices mapped onto the logical
  //! domain of a producer tensor.
  static std::vector<Val*> getProducerPerDimLogicalIndex(
      TensorView* producer_tv,
      const TensorView* consumer_tv,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {});

  //! Take a consumer tensorview and loop nest and generate predicates
  //! associated with the concrete roots of the loop nest. Returns a list of
  //! predicates and a list of concrete roots they're associated with. It
  //! is assumed that no predicate is required if index[i] is an index
  //! directly from a for loop. This will not catch all cases if we actually
  //! have static size information, for example:
  //!
  //! TV[I].split(4)
  //! would produce the code:
  //! for(i : I/4)
  //!   for(j : 4)
  //!     if( i * 4 + j < TV.size(0))
  //!       TV[i * 4 + j]...
  //!
  //! However, if we had TV.size[0] = 16 at "compile time", then we wouldn't
  //! need the predicate. This will be caught by canOmitPredicate in the
  //! predicate lowering.
  //!
  //! unswitch_or_vec_loop is the for loop to start the unswitch-like
  //! predicate from. This is not a bool value because, if we have an unswitch
  //! loop with a vectorized loop inside, we only want to base the "unswitch"
  //! like predicate on the vectorized loop.
  static std::vector<PredicateInfo> getReferenceRootPredicates(
      TensorView* consumer_tv,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      ForLoop* unswitch_or_vec_loop);

  //! Compute the result for iota
  static Val* iota(
      TensorView* consumer_tv,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      Val* start,
      Val* step,
      DataType dtype);

  //! Compute the result for eye
  static Val* eye(
      TensorView* consumer_tv,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      DataType dtype);

  //! Compute the global index and the expected bytes for the complete_tx
  //! mechanism for CpAsyncBulk.
  static std::pair<Val*, Val*> getCpAsyncBulkGmemIndex(
      const LoadStoreOp* ldst,
      Val* mbarrier,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops);
};

// Used for local and shared index mapping. Returns a map from loops
// to loop indices as well as a set of loops that do not contribute to
// indexing.
// TODO: could be cleaned up further.
std::pair<std::unordered_map<ForLoop*, Val*>, std::unordered_set<ForLoop*>>
indexMapFromTV(
    const TensorView* tv,
    const std::vector<ForLoop*>& loops,
    const std::unordered_set<ForLoop*>& rotated_loops,
    ForLoop* alloc_loop,
    bool as_consumer,
    ForLoop* circular_buffer_loop = nullptr);

//! Set "pragma unroll" required for loops that indexing of Local
//! tensors depends on.
//!
//! \param tv Indexed tensor
//! \param alloc_loop Allocation loop of tv
//! \param loops The current loop structure
//! \param id_map Producer-to-consumer map in case of indexing as producer
void ensureStaticIndexing(
    const TensorView* tv,
    ForLoop* alloc_loop,
    const std::vector<ForLoop*>& loops,
    const std::unordered_map<IterDomain*, IterDomain*>& id_map = {});

struct PredicateDomainInfo {
 public:
  // Iteration domain to predicate
  IterDomain* id = nullptr;
  // The set of iteration domains that make up the id. If this is for
  // a non-divisible split, the set only contains the id itself. This
  // set is used to remove redundant predicates when gathering
  // unswitch predicates.
  std::unordered_set<IterDomain*> covered_ids;
  // True if this predicate is for an intermediate domain. Examples
  // include domains with non-divisible splits and resized domains.
  bool is_intermediate_domain = false;
};

// Get all domains that need to be predicated due to non-divisible splits
std::vector<PredicateDomainInfo> getNonDivisibleConsumerDomainsToPredicate(
    TensorView* consumer_tv);

} // namespace nvfuser
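
The header comment above walks through how loop indices are replayed backward through split/merge expressions to recover allocation indices, a flattened index, and a predicate. The following is a minimal, self-contained C++ sketch of just the GLOBAL-tensor example from that comment (a single split by a constant factor); it is illustrative only, not part of the wheel or of the nvfuser API, and the names I, K, and factor are assumptions chosen to mirror the comment.

// Illustrative sketch only: scalar model of the backward index propagation
// described in the index_compute.h header comment (GLOBAL-tensor example).
#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t I = 10; // extent of the split domain
  constexpr int64_t K = 3; // extent of the trailing domain
  constexpr int64_t factor = 4; // TV.split(I, factor=4): [I, K] -> [ceilDiv(I, 4), 4, K]

  // Loop indices {i, j, k} generated by the surrounding loop nest.
  for (int64_t i = 0; i < (I + factor - 1) / factor; ++i) {
    for (int64_t j = 0; j < factor; ++j) {
      for (int64_t k = 0; k < K; ++k) {
        // INDEX: backward replay of the split, {i, j, k} -> {i * factor + j, k}
        const int64_t alloc_i = i * factor + j;
        // PREDICATE: guard the non-divisible split, i * factor + j < I
        if (alloc_i >= I) {
          continue;
        }
        // FLATTENED_INDEX: {alloc_i, k} -> alloc_i * K + k
        const int64_t flat_index = alloc_i * K + k;
        std::cout << flat_index << '\n';
      }
    }
  }
  return 0;
}

As the comment's SHARED and LOCAL examples show, the same arithmetic applies when some loop indices come from parallel dimensions such as threadIdx.x; only the allocation shrinks to the axes right of the compute-at position, while the predicate stays the same as in the global case.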