nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl
- nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/ops/composite.h
@@ -0,0 +1,130 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <exceptions.h>
#include <visibility.h>

#include <ir/interface_nodes.h>
#include <type.h>

//
// The operations defined in this header are intended as user-facing functions.
// The user will provide the necessary input TensorViews and the function will
// create the correct intermediate nodes and return the output TensorViews.
//

namespace nvfuser {

struct ForwardDropoutResult {
  TensorView* output = nullptr;
  TensorView* mask = nullptr;
};

NVF_API ForwardDropoutResult dropout(TensorView* x, Val* prob);

NVF_API ForwardDropoutResult dropout(TensorView* x, Val* prob, Val* scale);

NVF_API TensorView* dropout_backward(
    TensorView* dy,
    TensorView* mask,
    Val* scale);

NVF_API TensorView* triu(TensorView* tv, Val* offset);

struct LstmResult {
  TensorView* cell = nullptr;
  TensorView* hidden = nullptr;
};

NVF_API LstmResult lstm(
    TensorView* prev_cell,
    TensorView* in_x,
    TensorView* forget_x,
    TensorView* cell_x,
    TensorView* out_x);

// Linear functions which take in two tensors of shapes input[*, in_features],
// weight[out_features, in_features] / [in_features] and an optional bias of
// shape [out_features] or 0D scalar. Bias can only be given if weight is a 2-D
// tensor.
TensorView* linear(TensorView* input, TensorView* weight, TensorView* bias);
// This is an implementation detail to reflect when linear is called
// without a bias. This calls the above function. We use this function
// since it simplifies creating a Python API which takes optional arguments.
// Other options include using lambdas or creating a new RecordFunctor for
// Linear.
TensorView* linear(TensorView* input, TensorView* weight);

NVF_API TensorView* sign(TensorView* x);
NVF_API Val* sign(Val* x);
TensorView* softplus(TensorView* x, Val* beta, Val* threshold);
NVF_API TensorView* gelu(TensorView* x);
NVF_API TensorView* gelu_backward(TensorView* dy, TensorView* x);
TensorView* tanh_gelu(TensorView* x);
TensorView* tanh_gelu_backward(TensorView* dy, TensorView* x);
TensorView* tanh_backward(TensorView* dy, TensorView* tanh_x);
TensorView* leaky_relu(TensorView* x, Val* negative_slope);

NVF_API TensorView* view_as_real(TensorView* x);

// Matmul function which takes in tensors with the shapes
// A[*, M, K] / A[K] and B[*, K, N] / B[K], but the tensors may have different
// layouts via strides. This has the same functionality as torch.matmul.
TensorView* matmul(TensorView* tv_a, TensorView* tv_b);

// Scaled Dot Product Flash Attention Forward Result
struct SdpfaFwdResult {
  TensorView* output = nullptr;
  TensorView* log_sumexp = nullptr;
  TensorView* philox_seed = nullptr;
  TensorView* philox_offset = nullptr;
};

// Scaled Dot Product Flash Attention Forward API.
// Returns the same output as at::_scaled_dot_product_flash_attention
SdpfaFwdResult sdpfa_fwd(
    TensorView* query,
    TensorView* key,
    TensorView* value,
    Val* dropout_p,
    Val* is_causal,
    Val* scale);

// Scaled Dot Product Flash Attention Backward Result
struct SdpfaBwdResult {
  TensorView* grad_query = nullptr;
  TensorView* grad_key = nullptr;
  TensorView* grad_value = nullptr;
};

// Scaled Dot Product Flash Attention Backward API.
// Returns the same output as at::_scaled_dot_product_flash_attention_backward
SdpfaBwdResult sdpfa_bwd(
    TensorView* grad_output,
    TensorView* query,
    TensorView* key,
    TensorView* value,
    TensorView* output,
    TensorView* log_sumexp,
    Val* dropout_p,
    Val* is_causal,
    TensorView* philox_seed,
    TensorView* philox_offset,
    Val* scale);

TensorView* embedding_fwd(
    TensorView* input,
    TensorView* weight,
    Val* padding_idx,
    Val* max_norm,
    Val* norm_type,
    Val* scale_grad_by_freq,
    Val* sparse);

} // namespace nvfuser
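The declarations above are the user-facing composite ops shipped in ops/composite.h. As a rough illustration of how they compose, here is a hedged C++ sketch of building a linear-plus-dropout fusion. The Fusion, FusionGuard, TensorViewBuilder, and IrBuilder::create scaffolding is assumed from fusion.h, fusion_guard.h, ir/interface_nodes.h, and ir/builder.h elsewhere in this wheel; it is illustrative usage, not code taken from the package.

#include <fusion.h>
#include <fusion_guard.h>
#include <ir/builder.h>
#include <ops/all_ops.h>

using namespace nvfuser;

// Sketch: define a fusion that applies linear() followed by dropout().
void buildLinearDropoutFusion(Fusion* fusion) {
  FusionGuard fg(fusion);  // make `fusion` the fusion being defined

  // Symbolic inputs: input[*, in_features], weight[out_features, in_features],
  // bias[out_features]; built via TensorViewBuilder (assumed helper usage).
  TensorView* x = TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
  TensorView* w = TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
  TensorView* b = TensorViewBuilder().ndims(1).dtype(DataType::Float).build();
  fusion->addInput(x);
  fusion->addInput(w);
  fusion->addInput(b);

  // linear() creates the matmul-plus-bias nodes; dropout() returns both the
  // scaled output and the Bernoulli mask as a ForwardDropoutResult.
  TensorView* y = linear(x, w, b);
  Val* prob = IrBuilder::create<Val>(0.1);
  ForwardDropoutResult d = dropout(y, prob);

  fusion->addOutput(d.output);
  fusion->addOutput(d.mask);  // keep the mask so dropout_backward() can reuse it
}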
nvfuser/include/nvfuser/ops/indexing.h
@@ -0,0 +1,55 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <exceptions.h>
#include <visibility.h>

#include <ir/interface_nodes.h>
#include <type.h>

namespace nvfuser {

NVF_API TensorView* select(TensorView* tv, int64_t dim, Val* index);

// torch.index_select
NVF_API TensorView* indexSelect(
    TensorView* input,
    int64_t dim,
    TensorView* index);

// torch.gather
NVF_API TensorView* torchGather(
    TensorView* input,
    int64_t dim,
    TensorView* index);

// torch.scatter
TensorView* scatterOp(
    ScatterOpType type,
    TensorView* self,
    int64_t dim,
    TensorView* index,
    TensorView* src);

NVF_API TensorView* scatter(
    TensorView* self,
    int64_t dim,
    TensorView* index,
    TensorView* src);

//! numpy.take_along_axis
//! (https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html)
//! Note the order of the parameters follows the numpy order, which is
//! different from torchGather.
NVF_API TensorView* takeAlongAxis(
    TensorView* input,
    TensorView* index,
    int64_t dim);

} // namespace nvfuser
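A short hedged sketch contrasting the gather-style ops declared above: torchGather takes (input, dim, index) while takeAlongAxis follows the numpy order (input, index, dim). The surrounding Fusion scaffolding is assumed from the other headers in this wheel rather than quoted from the package.

#include <fusion.h>
#include <fusion_guard.h>
#include <ops/all_ops.h>

using namespace nvfuser;

// Sketch: route the same index tensor through both gather entry points.
void buildGatherFusion(Fusion* fusion) {
  FusionGuard fg(fusion);
  TensorView* in = TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
  TensorView* idx = TensorViewBuilder().ndims(2).dtype(DataType::Int).build();
  fusion->addInput(in);
  fusion->addInput(idx);

  // torch.gather-style argument order: (input, dim, index)
  TensorView* g = torchGather(in, /*dim=*/1, idx);
  // numpy.take_along_axis argument order: (input, index, dim)
  TensorView* t = takeAlongAxis(in, idx, /*dim=*/1);

  fusion->addOutput(g);
  fusion->addOutput(t);
}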
nvfuser/include/nvfuser/ops/normalization.h
@@ -0,0 +1,263 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <exceptions.h>
#include <visibility.h>

#include <ir/interface_nodes.h>
#include <type.h>

#include <tuple>
#include <vector>

//
// The operations defined in this header are intended as user-facing functions.
// The user will provide the necessary input TensorViews and the function will
// create the correct intermediate nodes and return the output TensorViews.
//

namespace nvfuser {

struct ForwardNormResult {
  TensorView* output = nullptr;
  TensorView* mean = nullptr;
  TensorView* invstd = nullptr;
};

struct BackwardNormResult {
  TensorView* grad_input = nullptr;
  TensorView* grad_weight = nullptr;
  TensorView* grad_bias = nullptr;
};

struct ForwardRMSNormResult {
  TensorView* output = nullptr;
  TensorView* invstd = nullptr;
};

struct BackwardRMSNormResult {
  TensorView* grad_input = nullptr;
  TensorView* grad_weight = nullptr;
};

struct VarMeanResult {
  TensorView* var = nullptr;
  TensorView* mean = nullptr;
};

} // namespace nvfuser

namespace std {

// Make these results behave like a std::tuple
using nvfuser::BackwardNormResult;
using nvfuser::BackwardRMSNormResult;
using nvfuser::ForwardNormResult;
using nvfuser::ForwardRMSNormResult;
using nvfuser::TensorView;
using nvfuser::VarMeanResult;

template <int i>
constexpr TensorView* get(const ForwardNormResult& results) {
  if (i == 0) {
    return results.output;
  }
  if (i == 1) {
    return results.mean;
  }
  if (i == 2) {
    return results.invstd;
  }
  return nullptr;
}

template <int i>
constexpr TensorView* get(const BackwardNormResult& results) {
  if (i == 0) {
    return results.grad_input;
  }
  if (i == 1) {
    return results.grad_weight;
  }
  if (i == 2) {
    return results.grad_bias;
  }
  return nullptr;
}

template <int i>
constexpr TensorView* get(const ForwardRMSNormResult& results) {
  if (i == 0) {
    return results.output;
  }
  if (i == 1) {
    return results.invstd;
  }
  return nullptr;
}

template <int i>
constexpr TensorView* get(const BackwardRMSNormResult& results) {
  if (i == 0) {
    return results.grad_input;
  }
  if (i == 1) {
    return results.grad_weight;
  }
  return nullptr;
}

template <int i>
constexpr TensorView* get(const VarMeanResult& results) {
  if (i == 0) {
    return results.var;
  }
  if (i == 1) {
    return results.mean;
  }
  return nullptr;
}

} // namespace std

namespace nvfuser {

TensorView* mean(TensorView* x, const std::vector<int64_t>& dims, bool keepdim);

NVF_API TensorView* variance(
    TensorView* x,
    const std::vector<int64_t>& dims,
    bool unbiased,
    bool keepdim);

NVF_API TensorView* variance(
    TensorView* x,
    const std::vector<int64_t>& dims,
    int64_t correction,
    bool keepdim);

NVF_API VarMeanResult variance_mean(
    TensorView* x,
    const std::vector<int64_t>& dims,
    int64_t correction,
    bool keepdim);

NVF_API TensorView* standard_deviation(
    TensorView* x,
    const std::vector<int64_t>& dims,
    bool unbiased,
    bool keepdim);

NVF_API TensorView* softmax(TensorView* x, int64_t dim);

NVF_API TensorView* softmax_backward(
    TensorView* dy,
    TensorView* y,
    const int64_t dim);

NVF_API TensorView* log_softmax(TensorView* x, int64_t dim);

NVF_API TensorView* log_softmax_backward(
    TensorView* dy,
    TensorView* y,
    const int64_t dim);

NVF_API ForwardNormResult layer_norm(
    TensorView* x,
    const std::vector<int64_t>& norm_shape,
    TensorView* weight,
    TensorView* bias,
    Val* eps);

NVF_API ForwardNormResult layer_norm(
    TensorView* x,
    const int64_t kNormShapeNumDims,
    TensorView* weight,
    TensorView* bias,
    Val* eps);

NVF_API ForwardRMSNormResult rms_norm(
    TensorView* x,
    const std::vector<int64_t>& norm_shape,
    TensorView* weight,
    Val* eps);

NVF_API ForwardRMSNormResult rms_norm(
    TensorView* x,
    const int64_t kNormShapeNumDims,
    TensorView* weight,
    Val* eps);

NVF_API BackwardNormResult layer_norm_backward(
    TensorView* dy,
    TensorView* x,
    const std::vector<int64_t>& norm_shape,
    TensorView* mean,
    TensorView* rstd,
    TensorView* weight,
    TensorView* bias,
    const std::vector<bool>& output_mask);

NVF_API BackwardRMSNormResult rms_norm_backward(
    TensorView* dy,
    TensorView* x,
    const std::vector<int64_t>& norm_shape,
    TensorView* rstd,
    TensorView* weight,
    const std::vector<bool>& output_mask);

NVF_API ForwardNormResult batch_norm(
    TensorView* x,
    TensorView* weight,
    TensorView* bias,
    TensorView* running_mean,
    TensorView* running_var,
    const bool kTraining,
    Val* momentum,
    Val* eps,
    bool channels_last = false);

NVF_API BackwardNormResult batch_norm_backward(
    TensorView* x,
    TensorView* dy,
    TensorView* weight,
    TensorView* running_mean,
    TensorView* running_var,
    TensorView* save_mean,
    TensorView* save_invstd,
    const bool kTraining,
    Val* eps,
    const std::vector<bool>& output_mask,
    bool channels_last = false);

NVF_API ForwardNormResult instance_norm(
    TensorView* x,
    TensorView* weight,
    TensorView* bias,
    TensorView* running_mean,
    TensorView* running_var,
    const bool kUseInputStats, // kTraining?
    Val* momentum,
    Val* eps,
    bool channels_last = false);

NVF_API BackwardNormResult instance_norm_backward(
    TensorView* x,
    TensorView* dy,
    TensorView* weight,
    TensorView* running_mean,
    TensorView* running_var,
    TensorView* save_mean,
    TensorView* save_invstd,
    const bool kTraining,
    Val* eps,
    const std::vector<bool>& output_mask,
    bool channels_last = false);

} // namespace nvfuser
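To illustrate how the result structs and the std::get overloads above are meant to be used, here is a hedged layer_norm sketch. As before, the Fusion, FusionGuard, TensorViewBuilder, and IrBuilder scaffolding is assumed from other headers in this wheel rather than quoted from it.

#include <fusion.h>
#include <fusion_guard.h>
#include <ir/builder.h>
#include <ops/all_ops.h>

using namespace nvfuser;

// Sketch: layer_norm over the last dimension, unpacking ForwardNormResult.
void buildLayerNormFusion(Fusion* fusion, int64_t hidden_size) {
  FusionGuard fg(fusion);
  TensorView* x = TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
  TensorView* w = TensorViewBuilder().ndims(1).dtype(DataType::Float).build();
  TensorView* b = TensorViewBuilder().ndims(1).dtype(DataType::Float).build();
  fusion->addInput(x);
  fusion->addInput(w);
  fusion->addInput(b);

  Val* eps = IrBuilder::create<Val>(1e-5);
  // norm_shape lists the trailing dimensions being normalized over.
  ForwardNormResult r = layer_norm(x, {hidden_size}, w, b, eps);

  // Members can be read directly, or tuple-style via the std::get overloads
  // declared in this header (std::get<0>(r) is r.output, and so on).
  fusion->addOutput(r.output);
  fusion->addOutput(r.mean);
  fusion->addOutput(r.invstd);
}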
nvfuser/include/nvfuser/ops/utils.h
@@ -0,0 +1,127 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <exceptions.h>
#include <ir/base_nodes.h>
#include <ir/interface_nodes.h>
#include <scheduler/matmul_utils.h>
#include <type.h>
#include <visibility.h>

#include <vector>

namespace nvfuser {

enum class AttnRole { Q = 0, K, V, Mask };

namespace ops {

TensorView* maybe_broadcast_inner_to_rank(TensorView* t, size_t rank);

// A utility function that broadcasts an index TensorView to the rank of the
// other TensorView.
TensorView* maybeBroadcastIndexTv(TensorView* t, size_t dim, size_t rank);

// A utility function that checks if an index tv is already broadcast to the
// correct shape for index_select.
bool isIndexAlreadyBroadcast(
    const std::vector<IterDomain*>& index_domain,
    size_t dim,
    size_t rank);

Val* simplifiedInt(Val* val);

// If one size is nullptr, return the other. If both symbolic just return v1. If
// one's concrete, prefer that one (simplified). If both concrete make sure
// they're the same size.
Val* promoteSize(Val* v1, Val* v2);

// Will return a new value of type val with the DataType dtype.
Val* newScalar(ValType vtype, DataType dtype);

IterType promoteIterType(IterType type1, IterType type2);

// For MatmulOp, the input iterdomains at a given index do not necessarily map
// to the output iterdomain at that index. This function aligns the input
// iterdomain to the output and returns a vector where each element is the input
// iterdomain corresponding to the output iterdomain at that index. If the
// element is nullptr, there is no mapping between input-output at that index.
// Based on the input dimensions, the following cases are possible:
// 1. A/B is 1D: [M, K] x [K] -> [M] (Mapping A: {id_M}, Mapping B: {nullptr})
// or [K] x [N, K] -> [N] (Mapping A: {nullptr}, Mapping B: {id_N})
// 2. A and B are 2D: [M, K] x [K, N] -> [M, N] (Mapping A: {id_M, nullptr},
// Mapping B: {nullptr, id_N})
// 3. A/B are at least 1D and one of them is > 2D: [B, M, K] x [K, N] -> [B, M,
// N] (Mapping A: {id_B, id_M, nullptr}, Mapping B: {nullptr, nullptr, id_N})
// Args:
// 1. input_domain: root/logical domain without reductions for any input to
// MatmulOp
// 2. input_position: Specifies if the input is A / B (0 or 1)
// 3: out_size: MatmulOp output dimension (input and output may not be the same
// size).
std::vector<IterDomain*> mapMatmulOpIterDomains(
    const std::vector<IterDomain*>& input_domain,
    int64_t input_position,
    size_t out_size);

// For LinearOp, the output is the same as the first input (A[*,
// in_features]) for all but the last dimension. If the second input is 2D
// (B[out_features, in_features]), the last dimension of output is out_features.
// If bias is 1D (bias[out_features]) it maps to the last dimension of the
// output. Args:
// 1. input_domain: root/logical domain without reductions for any input to
// LinearOp
// 2. input_position: Specifies if the input is A / B / Bias (0, 1, or 2)
// (MatmulTensorRole::Input_A/Input_B/Input_C) 3: out_size: LinearOp output
// dimension (input and output may not be the same size).
std::vector<IterDomain*> mapLinearOpIterDomains(
    const std::vector<IterDomain*>& input_domain,
    int64_t input_position,
    size_t out_size,
    bool k_bcast);

// Takes a vector of aligned input iterdomains to create the output iterdomain.
// This is used if the input iterdomains are not trivially mapped to the output
// iterdomains, e.g., MatmulOp. If given, the forced_iter_type argument will
// be the output IterType regardless of the inputs; otherwise the output
// IterType is inferred from ids.
IterDomain* newOutputIterDomain(
    const std::vector<IterDomain*>& ids,
    const std::optional<IterType> force_iter_type = std::nullopt);

// Takes a vector of `Val*`s and assumes they are all aligned to create the
// output tensorview, e.g., for BinaryOp. `vals` can contain scalars, e.g., when
// creating the output TensorView for `tv0+scalar`. This is for convenience and
// scalars will be ignored.
std::vector<IterDomain*> newOutputDomain(const std::vector<Val*>& vals);

TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype);

std::vector<Val*> maybeBroadcast(const std::vector<Val*>& vals);

NVF_API Val* newValLike(Val* val, DataType dtype);

// returns the minimum init value for reduction:
//   -inf for floating type;
//   lowest value for integer type;
//   false for bool.
Val* getMinimumValue(DataType v);

// returns the maximum init value for reduction:
//   inf for floating type;
//   highest value for integer type;
//   true for bool.
Val* getMaximumValue(DataType v);

std::vector<unsigned int> canonicalizeAxes(
    const std::vector<int64_t>& axes,
    int64_t ndims);

} // namespace ops
} // namespace nvfuser
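The ops:: helpers above are internal building blocks used when defining new elementwise ops. As a loose, hedged sketch of how they might fit together (based only on the comments above, not on the actual implementation in the package), an output TensorView for a binary op could be derived roughly like this:

#include <ops/utils.h>

#include <vector>

using namespace nvfuser;

// Sketch: derive the output TensorView for an elementwise binary op.
// maybeBroadcast() is assumed to align the operands' ranks/broadcast dims,
// and newOutputTV() to build an output whose domain is promoted from them
// (scalars in the operand list would simply be ignored, per the comment above).
TensorView* buildBinaryOpOutput(TensorView* lhs, TensorView* rhs, DataType dtype) {
  std::vector<Val*> operands = ops::maybeBroadcast({lhs, rhs});
  return ops::newOutputTV(operands, dtype);
}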