nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl
Sign up to get free protection for your applications and to get access to all the features.
- nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,298 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
#include <exceptions.h>
|
10
|
+
#include <visibility.h>
|
11
|
+
|
12
|
+
#include <python_frontend/fusion_record.h>
|
13
|
+
#include <runtime/fusion_executor_cache.h>
|
14
|
+
#include <scheduler/compile_time_info.h>
|
15
|
+
#include <scheduler/registry.h>
|
16
|
+
|
17
|
+
#include <memory>
|
18
|
+
#include <mutex>
|
19
|
+
|
20
|
+
namespace nvfuser::python_frontend {
|
21
|
+
|
22
|
+
//! \struct UserSchedule
|
23
|
+
//! \brief A container to hold a scheduled Fusion IR as well as an executor
|
24
|
+
//! to contain the corresponding generated kernel.
|
25
|
+
struct UserSchedule {
|
26
|
+
UserSchedule(int64_t fusion_id, int64_t device_id);
|
27
|
+
|
28
|
+
//! Runtime information for schedulers
|
29
|
+
std::unique_ptr<SchedulerRuntimeInfo> runtime_info;
|
30
|
+
|
31
|
+
//! The scheduler heuristic for this UserSchedule
|
32
|
+
std::unique_ptr<SchedulerEntry> scheduler;
|
33
|
+
|
34
|
+
//! The parameters for scheduler heuristic.
|
35
|
+
std::unique_ptr<HeuristicParams> heuristic_params;
|
36
|
+
|
37
|
+
//! The compile-time data cache.
|
38
|
+
std::unique_ptr<HeuristicDataCache> data_cache;
|
39
|
+
|
40
|
+
//! Concretized, Scheduled Fusion IR
|
41
|
+
std::unique_ptr<Fusion> scheduled_fusion;
|
42
|
+
|
43
|
+
//! Generated kernel container
|
44
|
+
std::unique_ptr<KernelExecutor> executor;
|
45
|
+
|
46
|
+
//! ID of fusion in python frontend fusion cache
|
47
|
+
int64_t fusion_id_ = -1;
|
48
|
+
|
49
|
+
//! device ID for this user schedule
|
50
|
+
int64_t device_id_ = -1;
|
51
|
+
|
52
|
+
//! Get scheduler runtime info for UserSchedule
|
53
|
+
SchedulerRuntimeInfo* runtimeInfo() {
|
54
|
+
NVF_ERROR(
|
55
|
+
runtime_info != nullptr,
|
56
|
+
"Requires SchedulerRuntimeInfo to use heuristic schedulers");
|
57
|
+
return runtime_info.get();
|
58
|
+
}
|
59
|
+
|
60
|
+
//! Get Fusion for UserSchedule
|
61
|
+
Fusion* fusion() {
|
62
|
+
NVF_ERROR(
|
63
|
+
scheduled_fusion != nullptr,
|
64
|
+
"Requires Fusion to use heuristic schedulers");
|
65
|
+
return scheduled_fusion.get();
|
66
|
+
}
|
67
|
+
|
68
|
+
//! Return if we can schedule FusionDefinition with heuristic.
|
69
|
+
bool canSchedule(const SchedulerType& heuristic);
|
70
|
+
|
71
|
+
//! Return if we can schedule FusionDefinition with heuristic along with any
|
72
|
+
//! debug messages from canScheduleRejectReason.
|
73
|
+
std::tuple<bool, std::string> canScheduleDebug(
|
74
|
+
const SchedulerType& scheduler_type);
|
75
|
+
|
76
|
+
//! Create scheduler and get heuristic parameters for fusion.
|
77
|
+
HeuristicParams* computeHeuristics(SchedulerType scheduler_type);
|
78
|
+
|
79
|
+
//! Schedule fusion with selected heuristics and scheduler.
|
80
|
+
void schedule();
|
81
|
+
|
82
|
+
//! Schedule fusion with heuristic.
|
83
|
+
void scheduleWithType(SchedulerType scheduler_type);
|
84
|
+
};
|
85
|
+
|
86
|
+
//! \struct FusionSchedules
|
87
|
+
//! \brief A container for auto generated and user defined schedules
|
88
|
+
//! that correspond to compiled kernels for each complete Fusion Definition.
|
89
|
+
struct FusionSchedules {
|
90
|
+
FusionSchedules(int64_t fusion_id = 0);
|
91
|
+
Fusion* preschedFusion();
|
92
|
+
|
93
|
+
//! Schedules Automatically generated by nvFuser for dynamic inputs. (default)
|
94
|
+
//! NOTE: The FusionExecutorCache also holds the Unscheduled Fusion IR
|
95
|
+
std::unique_ptr<FusionExecutorCache> auto_gen_schedules;
|
96
|
+
//! Schedules defined by the user for specific input sizes.
|
97
|
+
//! They are also generated per device as all devices may not be the same.
|
98
|
+
//! Key: Input Encoding hash of Fusion inputs as is created by the
|
99
|
+
//! InputsIdLookup struct found inside of the FusionCache.
|
100
|
+
//! Value: A vector based on device_id of User Defined Fusion Schedules.
|
101
|
+
std::unordered_map<size_t, std::unordered_map<int, UserSchedule>>
|
102
|
+
user_def_schedules;
|
103
|
+
//! Keeps a pointer to the last scheduled Fusion IR for printing
|
104
|
+
Fusion* last_user_def_scheduled_ir;
|
105
|
+
//! Keeps a pointer to the last executed executor for printing its cuda kernel
|
106
|
+
KernelExecutor* last_user_def_executor;
|
107
|
+
//! For thread-Safe locking of Fusion Schedules
|
108
|
+
std::mutex scheds_lock;
|
109
|
+
//! ID of fusion in python frontend fusion cache
|
110
|
+
int64_t fusion_id_ = -1;
|
111
|
+
//! Fusion IDs of input arguments for FusionState
|
112
|
+
std::vector<int64_t> inputs_fid_;
|
113
|
+
//! IDs for Extents for TensorView input arguments for FusionState
|
114
|
+
std::vector<int64_t> extents_fid_;
|
115
|
+
//! Fusion IDs of output arguments for FusionState
|
116
|
+
std::vector<int64_t> outputs_fid_;
|
117
|
+
//! Map Fusion Val to its corresponding FusionDefinition index
|
118
|
+
std::unordered_map<const Val*, int64_t> map_value_to_fid_;
|
119
|
+
};
|
120
|
+
|
121
|
+
//! \struct TrieNode
|
122
|
+
//! \brief Is the container for a Node in a prefix tree or trie
|
123
|
+
//! where each node represents a statement in a fusion definition and
|
124
|
+
//! the leaf Nodes represent a complete Fusion that is cached.
|
125
|
+
|
126
|
+
struct TrieNode {
|
127
|
+
TrieNode(
|
128
|
+
RecordFunctor* rec,
|
129
|
+
TrieNode* _parent = nullptr,
|
130
|
+
size_t _fusion_id = 0);
|
131
|
+
|
132
|
+
// Queries whether the entry denotes a leaf node which also represents
|
133
|
+
// the end of a Fusion entry in the cache.
|
134
|
+
bool isTerminal() const;
|
135
|
+
//! getException returns the cached Exception raised during construction of
|
136
|
+
//! Fusion. It returns std::nullopt if no error was thrown. This function is
|
137
|
+
//! called at the end of FusionDefinition::finalizeDefinition to avoid
|
138
|
+
//! silently using a bad FusionDefinition cached in FusionCache.
|
139
|
+
std::optional<std::string> getException();
|
140
|
+
//! setException is called to record exception message thrown during
|
141
|
+
//! construction of Fusion.
|
142
|
+
void setException(const char* e);
|
143
|
+
//! Serialize TrieNode using flatbuffers
|
144
|
+
NVF_API flatbuffers::Offset<serde::TrieNode> serialize(
|
145
|
+
flatbuffers::FlatBufferBuilder& builder,
|
146
|
+
const std::map<RecordFunctor*, size_t>&
|
147
|
+
map_record_functor_to_trie_node_id);
|
148
|
+
|
149
|
+
//! An entry's primary data is the record it holds
|
150
|
+
std::unique_ptr<RecordFunctor> record;
|
151
|
+
//! A hash map of the children for the current node.
|
152
|
+
//! The hash map hashes a pointer to a RecordFunctor because
|
153
|
+
//! the hash function is virtual.
|
154
|
+
std::unordered_map<RecordFunctor*, std::unique_ptr<TrieNode>> children;
|
155
|
+
//! An index into FusionCache's vector of nvFuser object that holds an
|
156
|
+
//! unscheduled Fusion. The id is only valid if the entry is terminal.
|
157
|
+
size_t fusion_id;
|
158
|
+
//! Count of times the Entry is traversed
|
159
|
+
size_t visits;
|
160
|
+
//! Parent node for printing
|
161
|
+
TrieNode* parent;
|
162
|
+
//! For thread-Safe locking of a node
|
163
|
+
std::mutex trie_node_lock;
|
164
|
+
//! exception is used to track if we failed to create a valid fusion for
|
165
|
+
//! FusionDefinition at this given TrieNode
|
166
|
+
std::optional<std::string> exception = std::nullopt;
|
167
|
+
};
|
168
|
+
|
169
|
+
//! \class FusionCache
|
170
|
+
//! \brief A singleton class used in the nvFuser python interface
|
171
|
+
//! to manage the caching of fusions.
|
172
|
+
//!
|
173
|
+
//! The fusion cache implements a prefix tree (trie) of records in order to
|
174
|
+
//! cache fusions. A leaf of the tree with a terminal node contains a
|
175
|
+
//! container for caching the kernels generated for specific fusions.
|
176
|
+
//!
|
177
|
+
//! \todo
|
178
|
+
//! Add the ability to evict a fusion. There is currently a max number
|
179
|
+
//! of fusions that is checked to prevent a runaway case.
|
180
|
+
//!
|
181
|
+
//! \note
|
182
|
+
//! Thread-Safety is assured by the Python GIL. If a no-GIL python is used
|
183
|
+
//! then further scrutiny needs to be applied to the mutexes used to limit
|
184
|
+
//! access to the singleton pointer, node creation, and user schedule
|
185
|
+
//! creation. Otherwise, the Python GIL provides a natural thread based mutex
|
186
|
+
//! that does not allow for multiple threads to interact.
|
187
|
+
|
188
|
+
class FusionCache {
|
189
|
+
//! The constructor is private given the FusionCache is only constructed
|
190
|
+
//! as a singleton.
|
191
|
+
FusionCache(size_t max_fusions, std::optional<int64_t> selected_device);
|
192
|
+
|
193
|
+
public:
|
194
|
+
//! Copy and Assignment of the FusionCache is not supported
|
195
|
+
//! clang-tidy: deleted member function should be public
|
196
|
+
FusionCache(const FusionCache&) = delete;
|
197
|
+
FusionCache& operator=(const FusionCache&) = delete;
|
198
|
+
|
199
|
+
//! The following public methods are the python interface methods
|
200
|
+
|
201
|
+
//! Gets a pointer to the singleton and creates a new one if necessary
|
202
|
+
NVF_API static FusionCache* get(
|
203
|
+
size_t max_fusions = 16384,
|
204
|
+
std::optional<int64_t> selected_device = std::nullopt,
|
205
|
+
bool load_from_default_workspace = true);
|
206
|
+
//! Number of fusions cached
|
207
|
+
NVF_API size_t numFusions() const;
|
208
|
+
//! Get device associated with this FusionCache
|
209
|
+
NVF_API std::optional<int64_t> deviceId() const;
|
210
|
+
//! print cache contents
|
211
|
+
NVF_API void print(std::ostream& os) const;
|
212
|
+
//! print cache stats
|
213
|
+
NVF_API void stats(std::ostream& os) const;
|
214
|
+
//! Reset Cache to an empty state
|
215
|
+
NVF_API static void reset();
|
216
|
+
|
217
|
+
//! Serialize Fusion Cache using flatbuffers
|
218
|
+
NVF_API void serialize(std::string filename) const;
|
219
|
+
//! Deserialize Fusion Cache using flatbuffers
|
220
|
+
NVF_API void deserialize(std::string filename);
|
221
|
+
|
222
|
+
//! The rest of the public methods are only used in C++
|
223
|
+
|
224
|
+
//! Thread-Unsafe: Queries the current trie node to see if a record matches
|
225
|
+
//! one of its children
|
226
|
+
NVF_API std::optional<TrieNode*> queryChildren(
|
227
|
+
TrieNode* node,
|
228
|
+
RecordFunctor* rec) const;
|
229
|
+
//! Query a Fusion's Schedules based on fusion id or cache id
|
230
|
+
FusionSchedules* queryFusionSchedules(size_t fusion_id) const;
|
231
|
+
//! Determine if a user schedule exists for given inputs.
|
232
|
+
bool existUserSchedule(
|
233
|
+
const FusionSchedules* scheds,
|
234
|
+
const at::ArrayRef<c10::IValue>& inputs,
|
235
|
+
int device);
|
236
|
+
//! Lookup the User Schedule Id and return std::nullopt if one does not exist.
|
237
|
+
//! NOTE: this method cannot be const because the InputsIdLookup can
|
238
|
+
//! cause a modification to that data member for cache eviction.
|
239
|
+
std::optional<size_t> queryUserScheduleId(
|
240
|
+
const FusionSchedules* scheds,
|
241
|
+
const at::ArrayRef<c10::IValue>& inputs);
|
242
|
+
//! Lookup the User Schedule based on Id
|
243
|
+
const UserSchedule& queryUserSchedule(
|
244
|
+
const FusionSchedules* scheds,
|
245
|
+
size_t id,
|
246
|
+
int device) const;
|
247
|
+
//! Thread-Safe: Creates a child node for the current cache entry and an
|
248
|
+
//! optional fusion_id is returned if the new entry is terminal
|
249
|
+
NVF_API TrieNode* createChild(TrieNode* node, RecordFunctor* rec);
|
250
|
+
//! Create a User Schedule for the given inputs and device
|
251
|
+
UserSchedule* createUserSchedule(
|
252
|
+
FusionSchedules* scheds,
|
253
|
+
const at::ArrayRef<c10::IValue>& inputs,
|
254
|
+
int device,
|
255
|
+
bool overwrite_existing_schedule = false);
|
256
|
+
//! Get the root Trie ptr
|
257
|
+
NVF_API TrieNode* rootTriePtr();
|
258
|
+
|
259
|
+
private:
|
260
|
+
//! The static pointer to the FusionCache
|
261
|
+
static FusionCache* singleton_;
|
262
|
+
//! Lock for accessing the singleton by multiple threads
|
263
|
+
static std::mutex singleton_lock_;
|
264
|
+
|
265
|
+
//! The max allowed number of fusions in the cache
|
266
|
+
size_t max_fusions_;
|
267
|
+
//! A separate process is created for each device in a distributed setting.
|
268
|
+
//! Each FusionCache becomes associated with a device.
|
269
|
+
std::optional<int64_t> device_id_;
|
270
|
+
//! The root (start) of the prefix tree to start a cache look up of a given
|
271
|
+
//! fusion definition.
|
272
|
+
std::unique_ptr<TrieNode> root_;
|
273
|
+
//! A vector of nvFuser Fusion IR fusions.
|
274
|
+
std::vector<std::unique_ptr<FusionSchedules>> fusions_;
|
275
|
+
//! A vector of Terminal trie nodes for Stats collection
|
276
|
+
std::vector<TrieNode*> terminal_nodes_;
|
277
|
+
|
278
|
+
//! Items specifically to aid user defined schedules. These data members
|
279
|
+
//! are for the mechanics of user schedule usage and don't make sense as
|
280
|
+
//! part of an abstraction
|
281
|
+
|
282
|
+
// Inputs for user defined schedules are encoded into an integer Id
|
283
|
+
// NOTE: I would prefer this be per FusionSchedules object but the container
|
284
|
+
// is not allowed to be copied or moved.
|
285
|
+
InputsIdLookup user_def_input_encodings_;
|
286
|
+
};
|
287
|
+
|
288
|
+
//! Serialize Fusion Cache to common workspace
|
289
|
+
//! /tmp/nvfuser_kernel_db/nvf_serde_[cuda_major]_[cuda_minor]_[nvrtc_major]_[nvrtc_minor]
|
290
|
+
//!
|
291
|
+
//! '''python
|
292
|
+
//! # Use atexit to automatically call serialize on program exit
|
293
|
+
//! import atexit
|
294
|
+
//! atexit.register(nvfuser.serialize)
|
295
|
+
//! '''
|
296
|
+
NVF_API void serialize();
|
297
|
+
|
298
|
+
} // namespace nvfuser::python_frontend
|
@@ -0,0 +1,372 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <functional>
|
11
|
+
#include <iostream>
|
12
|
+
#include <unordered_map>
|
13
|
+
|
14
|
+
#include <exceptions.h>
|
15
|
+
#include <python_frontend/distributed_tensor.h>
|
16
|
+
#include <python_frontend/fusion_state.h>
|
17
|
+
#include <python_frontend/segmentation.h>
|
18
|
+
#include <visibility.h>
|
19
|
+
|
20
|
+
namespace nvfuser::python_frontend {
|
21
|
+
|
22
|
+
class FusionCache;
|
23
|
+
class FusionDefinition;
|
24
|
+
class FusionInterface;
|
25
|
+
class FusionState;
|
26
|
+
struct RecordFunctor;
|
27
|
+
class SegmentationState;
|
28
|
+
struct TrieNode;
|
29
|
+
struct UserSchedule;
|
30
|
+
|
31
|
+
//! This is a helper function used to print a python formatted
|
32
|
+
//! Fusion IR DataType when printing a fusion definition.
|
33
|
+
|
34
|
+
NVF_API const char* dtypeToPyString(PrimDataType t);
|
35
|
+
|
36
|
+
//! The Tensor and Scalar classes are used to define separate function
|
37
|
+
//! signatures in the FusionDefinition to identify the appropriate Operator
|
38
|
+
//! function.
|
39
|
+
//!
|
40
|
+
//! Example:
|
41
|
+
//!
|
42
|
+
//! add(Tensor* arg1, Tensor* arg2) -> Tensor*
|
43
|
+
//! add(Tensor* arg1, Val* arg2) -> Tensor*
|
44
|
+
//! add(Val* arg1, Val* arg2) -> Val*
|
45
|
+
struct Tensor {
|
46
|
+
Tensor(size_t _index, size_t _dims, FusionDefinition* _fd)
|
47
|
+
: index(_index), dims(_dims), fusion_definition(_fd) {}
|
48
|
+
|
49
|
+
size_t operator()() const {
|
50
|
+
return index;
|
51
|
+
}
|
52
|
+
|
53
|
+
bool operator==(const Tensor& other) const {
|
54
|
+
if (index != other.index) {
|
55
|
+
return false;
|
56
|
+
}
|
57
|
+
|
58
|
+
if (dims != other.dims) {
|
59
|
+
return false;
|
60
|
+
}
|
61
|
+
|
62
|
+
if (fusion_definition != other.fusion_definition) {
|
63
|
+
return false;
|
64
|
+
}
|
65
|
+
return true;
|
66
|
+
}
|
67
|
+
|
68
|
+
bool operator!=(const Tensor& other) const {
|
69
|
+
return !(*this == other);
|
70
|
+
}
|
71
|
+
|
72
|
+
//! A unique index to identify each recorded state item.
|
73
|
+
size_t index;
|
74
|
+
size_t dims;
|
75
|
+
|
76
|
+
//! Pointer to the FusionDefinition used to create this tensor
|
77
|
+
//! The FusionDefinition pointer is necessary to enable special
|
78
|
+
//! dunder operations (ie __add__()) from the python API.
|
79
|
+
FusionDefinition* fusion_definition;
|
80
|
+
};
|
81
|
+
|
82
|
+
struct Scalar {
|
83
|
+
Scalar(size_t _index, FusionDefinition* _fd)
|
84
|
+
: index(_index), fusion_definition(_fd) {}
|
85
|
+
|
86
|
+
size_t operator()() const {
|
87
|
+
return index;
|
88
|
+
}
|
89
|
+
|
90
|
+
bool operator==(const Scalar& other) const {
|
91
|
+
if (index != other.index) {
|
92
|
+
return false;
|
93
|
+
}
|
94
|
+
|
95
|
+
if (fusion_definition != other.fusion_definition) {
|
96
|
+
return false;
|
97
|
+
}
|
98
|
+
return true;
|
99
|
+
}
|
100
|
+
|
101
|
+
bool operator!=(const Scalar& other) const {
|
102
|
+
return !(*this == other);
|
103
|
+
}
|
104
|
+
|
105
|
+
//! A unique index to identify each recorded state item.
|
106
|
+
size_t index;
|
107
|
+
|
108
|
+
//! Pointer to the FusionDefinition used to create this scalar
|
109
|
+
//! The FusionDefinition pointer is necessary to enable special
|
110
|
+
//! dunder operations (ie __add__()) from the python API.
|
111
|
+
FusionDefinition* fusion_definition;
|
112
|
+
};
|
113
|
+
|
114
|
+
struct Vector {
|
115
|
+
Vector(size_t _index, size_t _size, FusionDefinition* _fd)
|
116
|
+
: index(_index), size(_size), fusion_definition(_fd) {}
|
117
|
+
|
118
|
+
size_t operator()() const {
|
119
|
+
return index;
|
120
|
+
}
|
121
|
+
|
122
|
+
bool operator==(const Vector& other) const {
|
123
|
+
if (index != other.index) {
|
124
|
+
return false;
|
125
|
+
}
|
126
|
+
|
127
|
+
if (size != other.size) {
|
128
|
+
return false;
|
129
|
+
}
|
130
|
+
|
131
|
+
if (fusion_definition != other.fusion_definition) {
|
132
|
+
return false;
|
133
|
+
}
|
134
|
+
return true;
|
135
|
+
}
|
136
|
+
|
137
|
+
bool operator!=(const Vector& other) const {
|
138
|
+
return !(*this == other);
|
139
|
+
}
|
140
|
+
|
141
|
+
//! A unique index to identifiy each recorded state item.
|
142
|
+
size_t index;
|
143
|
+
//! Elements in the vector
|
144
|
+
size_t size;
|
145
|
+
|
146
|
+
//! Pointer to the FusionDefinition used to create this scalar
|
147
|
+
FusionDefinition* fusion_definition;
|
148
|
+
};
|
149
|
+
|
150
|
+
//! FusionDefinition defines the C++ side of a Python Context manager to
|
151
|
+
//! encapsulate the definition of fusion operations.
|
152
|
+
//!
|
153
|
+
//! The FusionDefinition records the state definitions and operations prior
|
154
|
+
//! to exiting the context manager. Upon exit, the operations are queried
|
155
|
+
//! in a cache and the recorded records are used to build an nvFuser Fusion
|
156
|
+
//! object if the definition missed in the cache.
|
157
|
+
//!
|
158
|
+
//! The nested Operators class was designed to allow the user to query all the
|
159
|
+
//! available Operators in the FusionDefinition via python help.
|
160
|
+
//!
|
161
|
+
//! Example:
|
162
|
+
//! help(FusionDefinition.Operators)
|
163
|
+
class NVF_API FusionDefinition : public FusionState {
 public:
  //! \param id Optional fusion id; completed() reports true once id() holds
  //!   a value.
  //! \param max_length Maximum length of the definition (see max_length_).
  FusionDefinition(std::optional<size_t> id, size_t max_length = 256);

  // The copy/move/assign constructors/operators are removed
  FusionDefinition(const FusionDefinition& fd) = delete;
  FusionDefinition(FusionDefinition&& fd) = delete;
  FusionDefinition& operator=(const FusionDefinition& fd) = delete;
  FusionDefinition& operator=(FusionDefinition&& fd) = delete;

  //! Enter Python Context Manager -- Reset trie for new cache lookup
  NVF_API FusionDefinition* setupDefinition();
  //! Exit Python Context Manager -- Triggers Fusion IR build if it is not
  //! cached
  NVF_API void finalizeDefinition();
  //! Check that a user schedule exists for FusionDefinition and input
  //! arguments on device.
  NVF_API bool existSchedule(const at::ArrayRef<c10::IValue>& inputs);
  //! Setup user scheduling of a fusion
  //! Copies fusion object and sets up FusionGuard
  NVF_API void setupSchedule(
      const at::ArrayRef<c10::IValue>& inputs,
      bool overwrite_existing_schedule = false);
  //! Finalize user scheduling of a fusion
  //! resets FusionGuard, lowers IR to a kernel, compiles kernel
  NVF_API void finalizeSchedule(const at::ArrayRef<c10::IValue>& inputs);
  //! A hook that gets called right before
  //! FusionDefinition.multidevice_schedule.
  NVF_API void setupMultideviceSchedule();
  //! A hook that gets called right after FusionDefinition.multidevice_schedule.
  NVF_API void finalizeMultideviceSchedule();
  //! Prints a python function representing the definition
  NVF_API void print(std::ostream& os) const;
  //! Executes a fusion if a valid definition or cache lookup occurred prior.
  //!
  //! This method returns a list of `DistributedTensor`s. Each
  //! `DistributedTensor` is either the local view of a distributed tensor
  //! (when the mesh is non-empty) or a non-distributed tensor
  //! (when the mesh is empty).
  //!
  //! Alternatives considered:
  //! 1. Return std::vector<std::variant<at::Tensor, DistributedTensor>>.
  //! Because DistributedTensor can also represent a non-distributed tensor, I
  //! chose the current API for simplicity -- C++ is more verbose than Python
  //! when dealing with dynamic types.
  //! 2. Return std::variant<std::vector<at::Tensor>,
  //! std::vector<DistributedTensor>>. Same reason.
  //! 3. Store output shardings (i.e. the mesh and the mesh-to-tensor-axis
  //! mapping) to a field of FusionDefinition and retrieve it using another
  //! method. This would be similar to getDebugOutput. I didn't choose that
  //! because it introduced a new state in the class that could get out of sync.
  NVF_API std::vector<DistributedTensor> execute(
      const at::ArrayRef<c10::IValue>& inputs,
      std::optional<int8_t> device,
      bool override_user_schedule,
      bool capture_debug_output,
      bool profile,
      std::vector<std::string> _enable_options,
      std::vector<std::string> _disable_options) const;
  //! Return debugging output captured through execution with
  //! capture_debug_output=true
  std::optional<std::string> getDebugOutput() const {
    return debug_output_;
  }
  // Returns the tolerances values based on reduction sizes.
  NVF_API std::vector<std::pair<double, double>> getValTolerances(
      const at::ArrayRef<c10::IValue>& inputs);

  //! Return the unscheduled Fusion IR
  NVF_API std::string fusionIr();
  //! Return the user scheduled FusionIR;
  NVF_API std::string userScheduleIr();
  //! Return the Cuda code for the last executed set of inputs
  NVF_API std::string lastCudaCode(
      bool intrinsic_code,
      bool override_user_schedule) const;
  //! Return the Cuda code for the given inputs
  NVF_API std::string cudaCodeFor(
      const at::ArrayRef<c10::IValue>& inputs,
      bool intrinsic_code,
      bool override_user_schedule) const;
  //! Return the scheduled Fusion IR for the last executed set of inputs
  NVF_API std::string lastScheduledFusionIr(
      bool tensor_transforms,
      bool override_user_schedule) const;
  //! Return the scheduled Fusion IR for the given inputs
  NVF_API std::string scheduledFusionIrFor(
      const at::ArrayRef<c10::IValue>& inputs,
      bool tensor_transforms,
      bool override_user_schedule) const;
  //! Return fusion id of defined FusionDefinition
  NVF_API std::optional<size_t> id() const;
  //! Prints the Prescheduled Fusion IR representation
  void printMathIr();

  //! True once id() holds a value. Const because it only queries the const
  //! id() accessor, so it can also be called on const objects.
  bool completed() const {
    return id().has_value();
  }

  //! Return a prescheduled Fusion object
  Fusion* preschedFusion();

  //! Return UserSchedule struct if it exists
  UserSchedule* userSchedule();

  //! These methods are used to record the FusionDefinition for cache lookup

  //! Defines a Tensor State Record
  NVF_API Tensor addTensor(TensorView* tv);
  //! Defines a Scalar State Record
  NVF_API Scalar defineScalar();
  //! Defines a Tensor State Record
  NVF_API Tensor defineTensor(size_t dims);
  //! Defines a Vector State Record
  NVF_API Vector defineVector(size_t size);
  //! Defines a Record that records the operation required to
  //! build the corresponding Fusion IR operation on cache miss.
  NVF_API void defineRecord(RecordFunctor* record);
  //! Gets a Record State object
  NVF_API State recordingState(size_t index) const;
  //! Get all Tensors in FusionState.
  NVF_API std::vector<Tensor> tensors();

  //! Run segmentation algorithm on FusionDefinition. Returns the number of
  //! segments.
  NVF_API int64_t setupSegmentation(const at::ArrayRef<c10::IValue>& inputs);
  //! Given an empty FusionDefinition and a segment id, buildSegment creates the
  //! CPP Fusion, translates it to the python FusionDefinition, then return a
  //! mapping from segment fusion state indices to the original fusion state
  //! indices.
  NVF_API std::unordered_map<int64_t, int64_t> buildSegment(
      FusionDefinition& segment_fd,
      int64_t segment_id);
  //! After creating segments, destroy SegmentationState.
  NVF_API void finalizeSegmentation();

 private:
  //! Returns the FusionCache Ptr that holds the cache of Fusions
  FusionCache* fusionCache() const;
  //! Composite operations can create hidden TensorViews in the CPP fusion
  //! These TensorViews are not visible from python definition. This function
  //! finds and adds them to FusionDefinition
  void findHiddenTensorViews(Fusion* fusion);
  //! Update Symbolic FusionStates after DynamicTransform pass
  void updateSymbolicStates(
      const std::unordered_map<Val*, Val*>& symbolic_to_concretized_map);
  // Check that the NvFuser TensorView and the Python Tensor dimensions match.
  // Apply after buildFusionIr
  void verifyTensorDimensions();

  //! Holds the defined maximum length of a FusionDefinition in order to
  //! prevent a run away error. The user should feel free to increase this
  //! number as appropriate.
  size_t max_length_;
  //! Fusion Cache Id for Scheduled Fusion.
  std::optional<size_t> fusion_id_;
  //! A pointer to the FusionCache.
  FusionCache* fusion_cache_;
  //! Current pointer to node in FusionCache.
  TrieNode* trie_node_;

  // Book keeping data members for user created schedules

  //! Data member for holding previous fusion container when manually setting
  //! the fusion guard.
  Fusion* prev_fusion_;
  //! Data member for holding the current user schedule object
  UserSchedule* user_sched_;
  //! Number of recording_states_ before applying user schedule
  int64_t num_recording_states_presched_ = 0;
  //! Data member that creates SegmentedFusion from cloned, prescheduled Fusion
  //! then translates the segments to python FusionDefinitions.
  std::unique_ptr<SegmentationState> segmentation_state_;

 public:
  //! The Operators are not directly defined in this header. They are defined
  //! in the python bindings through lambda functions so the user only needs to
  //! define new operators in one place.
  //! Operators define what operations are fused.
  struct Operators {
    Operators(FusionDefinition* fd) : fusion_definition(fd) {}
    //! Valid only while the owning FusionDefinition is not yet completed.
    bool validUse() const {
      return !fusion_definition->completed();
    }

    FusionDefinition* fusion_definition;
  };

  //! The SchedOperators are not directly defined in this header. They are
  //! defined in the python bindings through lambda functions so the user only
  //! needs to define new operators in one place.
  //! SchedOperators allow the user to define how a fusion should be blocked
  //! for execution.
  struct SchedOperators {
    SchedOperators(FusionDefinition* fd) : fusion_definition(fd) {}
    //! Valid only once the owning FusionDefinition is completed.
    bool validUse() const {
      return fusion_definition->completed();
    }

    FusionDefinition* fusion_definition;
  };

  Operators ops;
  SchedOperators sched;

 private:
  //! Value returned by getDebugOutput(); empty by default.
  //! NOTE(review): mutable, presumably so the const execute() can store the
  //! captured output -- confirm against the implementation.
  mutable std::optional<std::string> debug_output_ = std::nullopt;
};
|
371
|
+
|
372
|
+
} // namespace nvfuser::python_frontend
|