nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl
- nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/codegen.h
@@ -0,0 +1,26 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <exceptions.h>
+#include <kernel.h>
+#include <visibility.h>
+
+#include <string>
+
+namespace nvfuser {
+namespace codegen {
+
+//! Generates a CUDA kernel definition for the given kernel
+NVF_API std::string generateCudaKernel(
+    const kir::Kernel* kernel,
+    const std::string& kernel_name = "CUDAGeneratedKernel",
+    std::optional<int64_t> num_threads_per_cta = std::nullopt);
+
+} // namespace codegen
+} // namespace nvfuser
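As an aside (not part of the wheel contents above): the header exposes a single free function for emitting CUDA source from a lowered kernel. Below is a minimal sketch of how it could be called, assuming a kir::Kernel* has already been produced by nvfuser's lowering pipeline; the helper name emitKernelSource is hypothetical.

#include <codegen.h>
#include <kernel.h>

#include <string>

// Hypothetical helper: emit CUDA source for an already-lowered kernel,
// keeping the default kernel name and letting the generator choose the
// per-CTA thread count (std::nullopt).
std::string emitKernelSource(const nvfuser::kir::Kernel* kernel) {
  return nvfuser::codegen::generateCudaKernel(kernel);
}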
nvfuser/include/nvfuser/compute_at.h
@@ -0,0 +1,28 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <ir/interface_nodes.h>
+
+namespace nvfuser {
+
+class TensorDomain;
+class TensorView;
+
+struct ComputeAt {
+ public:
+  // Runs the compute at pass making producer look like consumer, computing
+  // producer relative to consumer
+  static void runAt(
+      TensorView* producer,
+      TensorView* consumer,
+      int64_t consumer_position,
+      ComputeAtMode mode = ComputeAtMode::Standard);
+};
+
+} // namespace nvfuser
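Again purely illustrative, not part of the diff: ComputeAt::runAt is the entry point of the compute-at pass declared above. A minimal sketch follows, assuming producer and consumer TensorViews that already belong to an active Fusion; the function name inlineProducerAt and the position value 2 are hypothetical.

#include <compute_at.h>

// Hypothetical usage: compute `producer` relative to `consumer` at loop
// position 2, using the default ComputeAtMode::Standard.
void inlineProducerAt(
    nvfuser::TensorView* producer,
    nvfuser::TensorView* consumer) {
  nvfuser::ComputeAt::runAt(producer, consumer, /*consumer_position=*/2);
}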
nvfuser/include/nvfuser/compute_at_map.h
@@ -0,0 +1,394 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <device_lower/analysis/trivial_broadcast.h>
+#include <disjoint_set.h>
+#include <exceptions.h>
+#include <ir/all_nodes.h>
+#include <kernel_ir.h>
+#include <visibility.h>
+
+#include <deque>
+#include <unordered_map>
+
+namespace nvfuser {
+
+class IdModelValidator;
+
+// There's four modes of these iter domain mappings all uniquely important in
+// the lowering process.
+//
+// For EXACT/PERMISSIVE mode consider:
+//
+// consumer[i0, b1] = producer[i0]
+// consumer->merge(0) (consumer will now be [i0 * b1])
+// When producer is replayed as consumer (the direction we use for mapping)
+// with BestEffortReplay forward_bcast_mismatch = True the producer to
+// consumer map will have both a mapping of consumer(i0) to producer(i0) as
+// well as consumer(i0*b1) to producer(i0). This latter mapping is important
+// for loop nest mappings as the consumer will generate a loop based on i0*b1
+// and the producer may be computeAt inside this loop nest. However, for
+// indexing we do not want these two maps as producer may be indexed as i0*i1
+// depending on the loop nest structure and how it was built. Therefore we
+// really need to carry (at least) two sets of maps around for lowering.
+//
+// LOOP mode is important if we have something like:
+// consumer[i0o, threadIdx.x{i0i}] = producer[i0o, threadIdx.y{i0i}](computeAt
+// = 1) which can easily happen when using shared memory. We want to make sure
+// that the iteration domain used for loop construction (concreteId) has the
+// proper parallelization strategy. In parallel mode we do typical iteration
+// domain mapping, however we remove from it any iteration domains outside the
+// computeAt of producer when mapping. This guarentees we won't map
+// IterDomains that could have different parallelization strategies. We also
+// propagate the parallel strategy in parallel mode so all mapped IDs that
+// must have the same parallel type, do.
+//
+// IdMappingMode::LOOP
+//   Only maps leaf axes to left of compute at
+//   Forward broadcast axes in replay
+// IdMappingMode::PERMISSIVE
+//   Forward broadcast axes in replay
+//   Map all iteration domains
+//   Always contain root mappings (otherwise they could have been forwarded in
+//   broadcast)
+// IdMappingMode::PERMISSIVE_RESIZE
+//   Include everything in PERMISSIVE. Map also domains that are
+//   inputs and outputs of resize ops. Used for, e.g., propagating
+//   parallel types across those domains. It also maps producers and
+//   consumers of gathered and scattered domains
+// IdMappingMode::INNERMOST
+//   Include everything in PERMISSIVE_RESIZE. Maps also iter domain across
+//   split/merge to the inner domain, it is used to map inner most iter domain.
+//   i.e. transpose scheduler use this to map inner most domain.
+// IdMappingMode::EXACT
+//   Don't map any broadcast axes to non-broadcast axes
+//   Do not forward through any broadcast IDs
+// IdMappingMode::AlmostExact
+//   Forward through broadcast axes, but not through to a non-broadcast axis
+//     i.e. id{b1*i0}, id{i0} are mapped
+//          id{i1*i0}, id{i0} are not mapped (this part is the difference from
+//          PERMISSIVE)
+//   Forward through split one axes, i.e. id{ceilDiv(i0, 1)}, id{i0} are mapped
+//
+class IterDomainGraph {
+ public:
+  NVF_API IterDomainGraph(Fusion* fusion, bool allow_self_mapping = false);
+
+  const DisjointSets<IterDomain*>& permissiveNodes() const {
+    return permissive_nodes_;
+  }
+  const DisjointSets<IterDomain*>& exactNodes() const {
+    return exact_nodes_;
+  }
+  const DisjointSets<IterDomain*>& almostExactNodes() const {
+    return almost_exact_nodes_;
+  }
+  const DisjointSets<IterDomain*>& loopNodes() const {
+    return loop_nodes_;
+  }
+  const DisjointSets<IterDomain*>& permissiveResizeNodes() const {
+    return permissive_resize_nodes_;
+  }
+  const DisjointSets<IterDomain*>& innermostNodes() const {
+    return innermost_nodes_;
+  }
+
+  // Consumers and producers is not symmetric like the other sets
+  const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
+  consumers() const {
+    return consumers_;
+  }
+  const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
+  producers() const {
+    return producers_;
+  }
+
+  const DisjointSets<IterDomain*>& siblings() const {
+    return sibling_sets_;
+  }
+
+  const VectorOfUniqueEntries<IterDomain*>& allIds() const {
+    return all_ids_;
+  }
+
+  const std::unordered_set<IterDomain*>& rfactorIds() const {
+    return rfactor_ids_;
+  }
+
+  // Returns if first and second are expressions through which the provided
+  // id_map have matching inputs (if forward), or outputs (if not forward).
+  // Returning true means the expressions are "the same", in terms they modify
+  // matching original extents, by the same amount.
+  static bool exprsMap(
+      Expr* first,
+      Expr* second,
+      bool forward,
+      const DisjointSets<IterDomain*>& id_map);
+
+  bool hasSelfMapping() const {
+    return self_mapping_info_.has_value();
+  }
+
+  // Update the LOOP nodes with resolved computeWith
+  void updateComputeWith(TensorView* compute_with_tv);
+
+ private:
+  void build(Fusion* fusion);
+
+  void initializeId(IterDomain* id, bool is_rfactor_id, bool is_loop_id);
+
+  // Checks if exprsMap then if forward will map outputs else inputs in exact
+  // and permissive map.
+  void mapThroughExpr(Expr* first, Expr* second, bool forward);
+
+  DisjointSets<IterDomain*> permissive_nodes_;
+  DisjointSets<IterDomain*> exact_nodes_;
+  DisjointSets<IterDomain*> almost_exact_nodes_;
+  DisjointSets<IterDomain*> loop_nodes_;
+  DisjointSets<IterDomain*> permissive_resize_nodes_;
+  DisjointSets<IterDomain*> innermost_nodes_;
+
+  // Consumers and producers is not symmetric like the other sets.
+  // Mapping is based on the most permissive map, i.e., the
+  // permissive-resize map.
+  std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
+      consumers_;
+  std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
+      producers_;
+
+  DisjointSets<IterDomain*> sibling_sets_;
+
+  VectorOfUniqueEntries<IterDomain*> all_ids_;
+
+  // This used to only have non-reduction rfactor IDs. Changed to
+  // include reduction rfactor IDs as well at PR #2562
+  std::unordered_set<IterDomain*> rfactor_ids_;
+
+  std::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
+      self_mapping_info_ = std::nullopt;
+
+  // Temporary interface exposure for validating IdModel
+  friend class IdModelValidator;
+};
+
+using CircularBufferIndices = std::unordered_map<CircularBufferLoopStage, Val*>;
+
+class ComputeAtMap {
+ public:
+  ComputeAtMap() = delete;
+  ComputeAtMap(const ComputeAtMap&) = delete;
+  ComputeAtMap& operator=(const ComputeAtMap&) = delete;
+  ComputeAtMap(ComputeAtMap&&) = default;
+  ComputeAtMap& operator=(ComputeAtMap&&) = default;
+  NVF_API ComputeAtMap(Fusion* fusion, bool allow_self_mapping = false);
+
+  //! Run through disjoint sets in the LOOP map, make sure there's only one
+  //! non-serial parallel type in each disjoint set, set the parallel type of
+  //! all IterDomains in the disjoint set to that PType.
+  void validateAndPropagatePType();
+
+  //! Run through disjoint sets in the LOOP map and allocate the index
+  //! variable for the associated for loop that will be generated
+  //! for each disjoint sets in the loop map. This pre-allocation makes
+  //! 2 key assumptions about computeAt map that would very likely be
+  //! long term invariant:
+  //!   1. All kir::forloop created in the lowering pass should belong
+  //!   to one of the disjoint sets in loop map.
+  //!   2. The lowering pass will *never* create a loop nest with 2
+  //!   different nesting levels mapped together, i.e. the case below
+  //!   never occurs:
+  //!    for i in IterDomain1
+  //!     for j in IterDomain2
+  //!      ...
+  //!   With loop_map.areMapped(IterDomain1, IterDomain2) == true.
+  //! Under this condition, we can pre-allocate all required index
+  //!  variable integers before creating any kir::forloop, and this
+  //!  would help optimizing the generated integer math for indexing.
+  void allocateIndexVariables();
+
+  //! Returns if id0 and id1 are mapped to each other with provided
+  //! IdMappingMode
+  NVF_API bool areMapped(IterDomain* id0, IterDomain* id1, IdMappingMode mode)
+      const;
+
+  //! Returns an iter domain that is the maximum expanded size of all iter
+  //! domains the one provided maps to. Useful for opening loops to the correct
+  //! iteration size. Not guarenteed to return the same ID every call, but is
+  //! guarenteed to return iter domains in the same disjoint set.
+  NVF_API IterDomain* getConcreteMappedID(IterDomain* id, IdMappingMode mode)
+      const;
+
+  //! Returns a list of expressions that produce the iter domains of all exact
+  //! mapped id's to 'id'. Expressions that are the same exact transformations
+  //! are deduplicated in the returned expressions.
+  std::vector<Expr*> uniqueExactDefinitions(IterDomain* id) const {
+    auto disjoint_set = disjointSetOf(id, IdMappingMode::EXACT);
+    auto unique_exact_definition_it =
+        unique_exact_definitions_.find(disjoint_set);
+    if (unique_exact_definition_it == unique_exact_definitions_.end()) {
+      return {};
+    }
+    return unique_exact_definition_it->second;
+  }
+
+  //! Returns a list of expressions that *use* the iter domains of all exact
+  //! mapped id's to 'id'. Expressions that are the same exact transformations
+  //! are deduplicated in the returned expressions.
+  std::vector<Expr*> uniqueExactUses(IterDomain* id) const {
+    auto disjoint_set = disjointSetOf(id, IdMappingMode::EXACT);
+    auto unique_exact_use_it = unique_exact_uses_.find(disjoint_set);
+    if (unique_exact_use_it == unique_exact_uses_.end()) {
+      return {};
+    }
+    return unique_exact_use_it->second;
+  }
+
+  // Prints mapping information, forwards to an internal IterDomainGraph
+  std::string toString() const;
+
+  // Returns if the provided ID is an rfactor id
+  bool isRfactor(IterDomain* ref_id) const;
+
+  // Returns all logical domains in rfactor_concrete_count_reset_domains_ that
+  // are in the disjoint set of the provided IterDomain. This will be every
+  // rfactor ID the provided ID "depends" on in the map.
+  std::vector<IterDomain*> getLogicalDomainsOfIdGroup(
+      IterDomain* ref_id,
+      IdMappingMode mode) const;
+
+  const IterDomainGraph& idGraph() const {
+    return id_graph_;
+  }
+
+  //! Get the ID sets for a provided IdMappingMode
+  const DisjointSets<IterDomain*>& getIdSets(IdMappingMode mode) const;
+
+  // Returns if the ID actually has a disjoint set meaning it has been processed
+  // in the creation of the compute at map.
+  bool idExistsInMap(IterDomain* id, IdMappingMode mode = IdMappingMode::EXACT)
+      const;
+
+  //! Returns the pre-allocated index variable integer used in
+  //! the ForLoop corresponding to the given IterDomain.
+  //! this interface is only valid if the ID has a loop mapping,
+  //! ca_map will throw exceptions if given iterdomain doesn't
+  //! have a loop map entry.
+  Val* getIndexVariable(
+      IterDomain* id,
+      CircularBufferLoopStage circular_buffer_loop_stage =
+          CircularBufferLoopStage::NotApplicable) const;
+
+  // Returns if expr_1 and expr_2 have exact mapped IterDomains in
+  // inputs/outputs (order matters) and if the expressions have matching
+  // parameters.
+  bool areExactExprs(Expr* expr_1, Expr* expr_2);
+
+  // Produce the disjoint set containing provided id with mapping mode.
+  const std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>& disjointSetOf(
+      IterDomain* id,
+      IdMappingMode mode) const;
+
+  // Update the LOOP map with resolved computeWith
+  void updateComputeWith(TensorView* compute_with_tv);
+
+  // Traverses through definitions of exact maps (unique_exact_definitions_) to
+  // all input ID's from provided exact_sets. Returns all the exact map concrete
+  // IDs of all the exact sets that on the path to and including the inputs
+  // required to construct the exact concrete id of of_id.
+  VectorOfUniqueEntries<std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>
+  getAllDisjointSetProducers(
+      const VectorOfUniqueEntries<
+          std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>& exact_sets)
+      const;
+
+  // Traverses through uses of exact maps (unique_exact_uses_) to
+  // all input ID's from provided exact_sets. Returns all the exact map concrete
+  // IDs of all the exact sets that on the path to and including the inputs
+  // required to construct the exact concrete id of of_id.
+  VectorOfUniqueEntries<std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>
+  getAllDisjointSetConsumers(
+      const VectorOfUniqueEntries<
+          std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>& exact_sets)
+      const;
+
+ private:
+  // Traverses through definitions of exact maps (unique_exact_definitions_) to
+  // input ID's from provided ID. Returns all the exact map concrete IDs of the
+  // exact sets that are inputs required to construct the exact concrete id of
+  // of_id.
+  VectorOfUniqueEntries<std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>
+  getInputDisjointSetsOf(IterDomain* of_id, bool stop_at_rfactor = true);
+
+  // Build id_graph_
+  void build(Fusion* fusion);
+
+  // Build concrete_id_cache_
+  // Build a single entry in concrete_cache_id_
+  IterDomain* computeConcreteId(IterDomain* id, IdMappingMode mode);
+  void buildConcreteIds();
+
+  // Relies on concrete_id_cache_, buildConcreteIds() must be run before this.
+  void buildUniqueExactExprMaps();
+
+  // Should be built once and never modified again.
+  IterDomainGraph id_graph_;
+
+  // Used specifically for concrete ID computation
+  ConcretizedBroadcastDomains concretized_bcasts_;
+
+  // Prevent needing to recompute concrete_id's in compute at map.
+  // VectorOfUniqueEntries is unique across mapping modes, so don't need to use
+  // mapping mode directly in this cache. const
+  // VectorOfUniqueEntries<IterDomain*>& is what's returned by
+  // ComputeAtMap::disjointSetOf which can be used directly.
+  std::unordered_map<
+      std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>,
+      IterDomain*>
+      concrete_id_cache_;
+
+  // Unique expressions operating on exact disjoint set. For each IterDomain in
+  // each exact disjoint set will log its definition in the std::vector<Expr*>.
+  // If another expression is already in the set where inputs and outputs
+  // exactly match with the expression to add along with the other parameters of
+  // the transformation (like split's factor, or swizzles types) then the
+  // expression will not be added as it would be a "duplicate" transformation.
+  std::unordered_map<
+      std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>,
+      std::vector<Expr*>>
+      unique_exact_definitions_;
+
+  // Same as unique_exact_definitions_ but for uses instead of definitions
+  std::unordered_map<
+      std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>,
+      std::vector<Expr*>>
+      unique_exact_uses_;
+
+  //! Allocated Loop index variable through the CA map.
+  //! only valid for disjoint sets on the loop ca map.
+  std::unordered_map<const VectorOfUniqueEntries<IterDomain*>*, Val*>
+      loop_index_variable_map_;
+
+  //! Allocated loop indices for circular buffer loop.
+  //! only valid for disjoint sets on the loop ca map
+  //! that have circular buffer-ed iterdomains.
+  using CircularBufferIndicesPtr = std::unique_ptr<CircularBufferIndices>;
+  std::unordered_map<
+      const VectorOfUniqueEntries<IterDomain*>*,
+      CircularBufferIndicesPtr>
+      circular_buffered_loop_index_variable_map_;
+
+  // Shortcut to access the fusion this computeAt map was
+  // built from.
+  Fusion* fusion_;
+
+  // Temporary interface exposure for validating IdModel
+  friend class IdModelValidator;
+};
+
+} // namespace nvfuser
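To close with one more illustrative sketch (not part of the wheel): the ComputeAtMap declared above is constructed from a Fusion and then queried per IdMappingMode. The snippet below shows one plausible query sequence, assuming a Fusion and two IterDomains obtained elsewhere; the function name inspectExactMapping is hypothetical.

#include <compute_at_map.h>

// Hypothetical query: check whether two IterDomains are EXACT-mapped and,
// if so, fetch the concrete ID that lowering would use for that disjoint set.
void inspectExactMapping(
    nvfuser::Fusion* fusion,
    nvfuser::IterDomain* id0,
    nvfuser::IterDomain* id1) {
  nvfuser::ComputeAtMap ca_map(fusion);
  if (ca_map.areMapped(id0, id1, nvfuser::IdMappingMode::EXACT)) {
    nvfuser::IterDomain* concrete =
        ca_map.getConcreteMappedID(id0, nvfuser::IdMappingMode::EXACT);
    (void)concrete; // e.g. use it to pick the loop extent
  }
}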