nvfuser-cu121-torch25 0.2.25.dev20250201__cp310-cp310-manylinux_2_28_x86_64.whl
Sign up to get free protection for your applications and to get access to all the features.
- nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,1125 @@
|
|
1
|
+
// clang-format off
|
2
|
+
/*
|
3
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
|
4
|
+
* All rights reserved.
|
5
|
+
* SPDX-License-Identifier: BSD-3-Clause
|
6
|
+
*/
|
7
|
+
// clang-format on
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <exceptions.h>
|
11
|
+
#include <macros.h>
|
12
|
+
#include <visibility.h>
|
13
|
+
|
14
|
+
#include <c10/core/ScalarType.h>
|
15
|
+
|
16
|
+
#include <polymorphic_value.h>
|
17
|
+
|
18
|
+
#include <array>
|
19
|
+
#include <complex>
|
20
|
+
#include <cstdint>
|
21
|
+
#include <iostream>
|
22
|
+
#include <optional>
|
23
|
+
#include <string>
|
24
|
+
#include <type_traits>
|
25
|
+
#include <typeinfo>
|
26
|
+
#include <unordered_set>
|
27
|
+
#include <variant>
|
28
|
+
|
29
|
+
namespace nvfuser {
|
30
|
+
|
31
|
+
// Kinds of IR values. The original note below says "Order of strength":
// the relative order of the enumerators appears to be significant, so do
// not reorder them.
// Order of strength
enum class ValType {
  TensorDomain,
  IterDomain,
  TensorView,
  NamedScalar,
  Predicate,
  TensorIndex,
  Stream,
  Others // catch-all for value kinds without a dedicated tag
};
|
42
|
+
|
43
|
+
// How a kernel predicate (bounds/guard condition) is generated.
// Manual - The user provides the Bool value. Predicate generation is bypassed.
// Inline corresponds with PredicateCompute::getInlinePredicate
// Unswitch corresponds with UnswitchPredicate::get
// Vectorize - (no comment in original) presumably guards vectorized
//   accesses — TODO(review): confirm against PredicateCompute.
// Misaligned - PredicateCompute::getInlinePredicate + Misaligned flag
// ReductionWrite - Same as Inline but without reduction axes
// LoopRotation - Predicate added by loop rotation, currently always true.
// ElectSync - Select a single thread to launch asynchronous operations.
enum class PredicateType {
  Manual,
  Inline,
  Unswitch,
  Vectorize,
  Misaligned,
  ReductionWrite,
  LoopRotation,
  ElectSync
};
|
60
|
+
|
61
|
+
// Primitive scalar type tags understood by the code generator.
//
// Index type is a convenience type that may be a 64 or 32 signed integer.
// This is helpful for math on indexing/size when we don't know what the index
// type might be. This allows us to prevent assuming the welford count must be
// int64_t which is relatively heavy to carry around. Index will be resolved
// at compile time with KernelIndexMode.
enum class PrimDataType {
  // Floating point types
  Double,
  Float,
  Half,
  BFloat16,
  Float8_e4m3fn,
  Float8_e5m2,
  // Integral types
  Char,
  Short,
  Int32,
  Int, // 64-bit signed (mapped to at::ScalarType::Long below)
  Byte, // Following ATen convention
  UInt16, // Following ATen convention
  UInt32,
  UInt64,
  Index, // resolved to Int32 or Int at compile time via KernelIndexMode
  // Boolean types
  Bool,
  // Complex types
  ComplexDouble,
  ComplexFloat,
  // Pointers
  SMemAddress,
  TMemAddress,
  // Null
  Null
};
|
95
|
+
|
96
|
+
#if defined(__GNUC__) && !defined(__clang__)
|
97
|
+
#pragma GCC diagnostic push
|
98
|
+
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
|
99
|
+
#endif
|
100
|
+
|
101
|
+
struct DataType;
|
102
|
+
|
103
|
+
// A fixed-size array type: `size` elements of one element type.
struct ArrayType {
  // Element type. Held by shared_ptr because DataType is still an
  // incomplete type at this point in the header.
  std::shared_ptr<DataType> type;
  // Number of elements.
  size_t size;
  // Deep equality (compares element type and size); defined out-of-line
  // below, after DataType is complete.
  inline bool operator==(const ArrayType& other) const;
};
|
108
|
+
|
109
|
+
// A typed pointer: pointer-to-`type`.
struct PointerType {
  // Pointee type. shared_ptr because DataType is incomplete here.
  std::shared_ptr<DataType> type;
  // Deep equality on the pointee type; defined out-of-line below.
  inline bool operator==(const PointerType& other) const;
};
|
113
|
+
|
114
|
+
// An aggregate type with named, typed fields, mirroring a C++ struct.
struct StructType {
  std::string name;
  // Factory producing a default-constructed concrete Struct instance of
  // this type, type-erased behind the Struct base class.
  std::function<std::shared_ptr<Struct>()> create;

  struct FieldInfo {
    std::string name;
    // Field type. shared_ptr because DataType is incomplete here.
    std::shared_ptr<DataType> type;
    // Whether this field participates in generated kernel code.
    bool used_in_kernel = true;
  };

  // Fields in declaration order; order matters for equality (see
  // operator== below).
  std::vector<FieldInfo> fields;

  // Build a StructType whose `create` returns a new T. T must derive from
  // Struct (enforced at compile time).
  template <typename T>
  static StructType make(std::vector<FieldInfo> fields, std::string name = "") {
    static_assert(
        std::is_base_of<Struct, T>::value,
        "StructType::make only accepts Struct types");
    return StructType{
        .name = std::move(name),
        .create =
            []() {
              return std::static_pointer_cast<Struct>(std::make_shared<T>());
            },
        .fields = std::move(fields)};
  }

  // Linear lookup of a field's type by name; throws via NVF_THROW if the
  // field does not exist.
  inline const DataType& fieldDataType(const std::string& name) const {
    for (const auto& field : fields) {
      if (field.name == name) {
        return *field.type;
      }
    }
    NVF_THROW("Field ", name, " not found in struct ", this->name);
  }

  // Field-wise equality; defined out-of-line below. Note: `name` and
  // `create` are not compared there.
  inline bool operator==(const StructType& other) const;
};
|
151
|
+
|
152
|
+
// A type treated as a black box: identified only by its C++ type_info and
// its size in bytes.
struct OpaqueType {
  std::string name;
  // Identity of the wrapped C++ type. reference_wrapper because
  // std::type_info is neither copyable nor assignable.
  std::reference_wrapper<const std::type_info> type_info;
  size_t size; // sizeof the wrapped C++ type

  // Build an OpaqueType describing the C++ type T.
  template <typename T>
  static OpaqueType make(std::string name = "") {
    return OpaqueType{
        .name = std::move(name), .type_info = typeid(T), .size = sizeof(T)};
  }

  // Equality compares only the type_info identity; `name` is ignored and
  // `size` is implied by the type.
  inline bool operator==(const OpaqueType& other) const {
    return type_info.get() == other.type_info.get();
  }
};
|
167
|
+
|
168
|
+
// A dynamic type descriptor: either a primitive scalar tag (PrimDataType)
// or a structured type (array / pointer / struct / opaque). Implicitly
// constructible from each alternative, and re-exports the PrimDataType
// enumerators as static constants so `DataType::Float` etc. work at call
// sites that expect either form.
struct DataType {
  using VariantOfSupportedTypes = std::
      variant<PrimDataType, ArrayType, PointerType, StructType, OpaqueType>;
  // Defaults to Null, i.e. "no type".
  VariantOfSupportedTypes type = PrimDataType::Null;

  DataType() = default;
  DataType(VariantOfSupportedTypes type) : type(std::move(type)) {}
  DataType(PrimDataType type) : type(type) {}
  DataType(ArrayType type) : type(std::move(type)) {}
  DataType(PointerType type) : type(std::move(type)) {}
  DataType(StructType type) : type(std::move(type)) {}
  DataType(OpaqueType type) : type(std::move(type)) {}

  // Convenience aliases mirroring the PrimDataType enumerators.
  static constexpr PrimDataType Double = PrimDataType::Double;
  static constexpr PrimDataType Float = PrimDataType::Float;
  static constexpr PrimDataType Half = PrimDataType::Half;
  static constexpr PrimDataType Float8_e4m3fn = PrimDataType::Float8_e4m3fn;
  static constexpr PrimDataType Float8_e5m2 = PrimDataType::Float8_e5m2;
  static constexpr PrimDataType Index = PrimDataType::Index;
  static constexpr PrimDataType Char = PrimDataType::Char;
  static constexpr PrimDataType Short = PrimDataType::Short;
  static constexpr PrimDataType Int32 = PrimDataType::Int32;
  static constexpr PrimDataType Int = PrimDataType::Int;
  static constexpr PrimDataType Byte = PrimDataType::Byte;
  static constexpr PrimDataType UInt16 = PrimDataType::UInt16;
  static constexpr PrimDataType UInt32 = PrimDataType::UInt32;
  static constexpr PrimDataType UInt64 = PrimDataType::UInt64;
  static constexpr PrimDataType Bool = PrimDataType::Bool;
  static constexpr PrimDataType BFloat16 = PrimDataType::BFloat16;
  static constexpr PrimDataType ComplexFloat = PrimDataType::ComplexFloat;
  static constexpr PrimDataType ComplexDouble = PrimDataType::ComplexDouble;
  static constexpr PrimDataType SMemAddress = PrimDataType::SMemAddress;
  static constexpr PrimDataType TMemAddress = PrimDataType::TMemAddress;
  static constexpr PrimDataType Null = PrimDataType::Null;
};
|
203
|
+
|
204
|
+
// Two DataTypes are equal iff their underlying variants compare equal
// (same alternative, equal payload).
inline bool operator==(const DataType& a, const DataType& b) {
  return a.type == b.type;
}
|
207
|
+
|
208
|
+
// Negation of operator== above.
inline bool operator!=(const DataType& a, const DataType& b) {
  return !(a == b);
}
|
211
|
+
|
212
|
+
// Arrays match when both the element count and the element type agree.
// The cheap size comparison runs first.
bool ArrayType::operator==(const ArrayType& other) const {
  return size == other.size && *type == *other.type;
}
|
215
|
+
|
216
|
+
// Pointers match when their pointee types match.
bool PointerType::operator==(const PointerType& other) const {
  const DataType& mine = *type;
  const DataType& theirs = *other.type;
  return mine == theirs;
}
|
219
|
+
|
220
|
+
// Field-wise equality: same number of fields, and each corresponding pair
// agrees on name, type, and kernel usage. Note that `name` and `create`
// are intentionally not compared.
bool StructType::operator==(const StructType& other) const {
  const size_t n = fields.size();
  if (n != other.fields.size()) {
    return false;
  }
  for (size_t i = 0; i < n; ++i) {
    const FieldInfo& a = fields[i];
    const FieldInfo& b = other.fields[i];
    // Same short-circuit order as before: name, then type, then flag.
    const bool same = a.name == b.name && *a.type == *b.type &&
        a.used_in_kernel == b.used_in_kernel;
    if (!same) {
      return false;
    }
  }
  return true;
}
|
233
|
+
|
234
|
+
// Forwards to the concrete Struct instance, which reports its own
// StructType. (StructHandle itself is declared elsewhere.)
inline StructType StructHandle::type() const {
  return struct_ptr_->type();
}
|
237
|
+
|
238
|
+
// Metadata struct type for a global-memory tensor of element type `dtype`,
// with logical rank `dim` and allocation rank `alloc_dim`. Defined in the
// corresponding .cpp; the exact field layout is not visible in this header.
StructType globalTensorMetaData(
    const PrimDataType& dtype,
    size_t dim,
    size_t alloc_dim);

// Convenience overload: allocation rank equals the logical rank.
inline StructType globalTensorMetaData(const PrimDataType& dtype, size_t dim) {
  return globalTensorMetaData(dtype, dim, dim);
}
|
246
|
+
|
247
|
+
class Val;
//! Get the type of a Val's metadata, currently only supporting tensors
NVF_API DataType metaDataTypeOf(const Val* tv);

// Whether kernel index arithmetic is performed in 32-bit or 64-bit
// signed integers.
enum class KernelIndexMode { INT32, INT64 };

// Conversions between the index-mode enum and the corresponding dtype.
PrimDataType indexModeToDtype(KernelIndexMode index_mode);
KernelIndexMode indexTypeToMode(DataType index_type);

// check if type preserves all information from base_type. Which indicates a
// cast from base_type -> type -> base_type should be bit-wise identical
bool isInclusiveType(const DataType& base_type, const DataType& type);
|
259
|
+
|
260
|
+
// Returns if the datatype is a floating point type
|
261
|
+
inline bool isFloatingPointType(DataType dtype) {
|
262
|
+
return dtype == DataType::Double || dtype == DataType::Float ||
|
263
|
+
dtype == DataType::Half || dtype == DataType::BFloat16 ||
|
264
|
+
dtype == DataType::Float8_e4m3fn || dtype == DataType::Float8_e5m2;
|
265
|
+
}
|
266
|
+
|
267
|
+
// Returns if the datatype is an integer type. Only the PrimDataType
// alternative of the variant can be integral; arrays, pointers, structs,
// and opaque types all report false.
inline bool isIntegralType(DataType dtype) {
  return std::visit(
      [](auto&& dtype) {
        using T = std::decay_t<decltype(dtype)>;
        if constexpr (std::is_same_v<T, PrimDataType>) {
          switch (dtype) {
            // Index counts as integral: it resolves to a 32- or 64-bit int.
            case DataType::Index:
            case DataType::Char:
            case DataType::Short:
            case DataType::Int:
            case DataType::Int32:
            case DataType::Byte:
            case DataType::UInt16:
            case DataType::UInt32:
            case DataType::UInt64:
              return true;
            default:
              return false;
          }
        }
        // Non-primitive alternatives are never integral.
        return false;
      },
      dtype.type);
}
|
292
|
+
|
293
|
+
// Returns if the datatype is an unsigned integer type
|
294
|
+
inline bool isUnsignedIntegralType(DataType dtype) {
|
295
|
+
return dtype == DataType::Byte || dtype == DataType::UInt16 ||
|
296
|
+
dtype == DataType::UInt32 || dtype == DataType::UInt64;
|
297
|
+
}
|
298
|
+
|
299
|
+
// Returns if the datatype is a pointer type
|
300
|
+
inline bool isPointerType(DataType dtype) {
|
301
|
+
return std::holds_alternative<PointerType>(dtype.type) ||
|
302
|
+
dtype == DataType::SMemAddress || dtype == DataType::TMemAddress;
|
303
|
+
}
|
304
|
+
|
305
|
+
// Returns if the datatype is an integer or pointer type
|
306
|
+
inline bool isIntegralOrPointerType(DataType dtype) {
|
307
|
+
return isIntegralType(dtype) || isPointerType(dtype);
|
308
|
+
}
|
309
|
+
|
310
|
+
// Returns if the datatype is a boolean type
|
311
|
+
inline bool isBooleanType(DataType dtype) {
|
312
|
+
return dtype == DataType::Bool;
|
313
|
+
}
|
314
|
+
|
315
|
+
// Returns if the datatype is a complex type
|
316
|
+
inline bool isComplexType(DataType dtype) {
|
317
|
+
return dtype == DataType::ComplexFloat || dtype == DataType::ComplexDouble;
|
318
|
+
}
|
319
|
+
|
320
|
+
// Returns if the datatype is a complex type
|
321
|
+
inline bool isStructType(DataType dtype) {
|
322
|
+
return std::holds_alternative<StructType>(dtype.type);
|
323
|
+
}
|
324
|
+
|
325
|
+
// Return the corresponding scalar of a complex type
|
326
|
+
DataType getTypeFromComplexType(DataType dtype);
|
327
|
+
// Return the corresponding complex type of a scalar
|
328
|
+
DataType getComplexTypeFromType(DataType dtype);
|
329
|
+
// Return if the datatype is supported on the current device
|
330
|
+
NVF_API bool isSupportedTypeByDevice(DataType dtype);
|
331
|
+
|
332
|
+
NVF_API int64_t dataTypeSize(DataType type);
|
333
|
+
|
334
|
+
// If the index type is known it will be automatically used here
|
335
|
+
int64_t dataTypeSize(DataType type, DataType index_type);
|
336
|
+
|
337
|
+
// Compile-time mapping traits between PrimDataType tags, ATen scalar
// types, and native C++ types. Specializations are generated below by the
// DEFINE_DATATYPE_* macros.

// PrimDataType tag -> native C++ type.
template <PrimDataType DT>
struct DataTypeToNativeType;

// PrimDataType tag -> at::ScalarType. NOTE(review): declared here, but no
// specializations are visible in this header — confirm where (or whether)
// it is populated.
template <PrimDataType DT>
struct DataTypeToAtenType;

// Native C++ type -> PrimDataType tag.
template <typename NativeType>
struct NativeTypeToDataType;

// at::ScalarType -> PrimDataType tag.
template <at::ScalarType aten_type>
struct AtenTypeToDataType;

// at::ScalarType -> native C++ type.
template <at::ScalarType aten_type>
struct AtenTypeToNativeType;

// True exactly for native types given a mapping by the macros below.
template <typename NativeType>
struct IsPrimitiveNativeType : std::false_type {};
|
354
|
+
|
355
|
+
// Specializes DataTypeToNativeType, NativeTypeToDataType, and
// IsPrimitiveNativeType for one (PrimDataType, native type) pair.
#define DEFINE_DATATYPE_TO_NATIVE_TYPE(data_type, native_type) \
  template <>                                                  \
  struct DataTypeToNativeType<data_type> {                     \
    using type = native_type;                                  \
  };                                                           \
  template <>                                                  \
  struct NativeTypeToDataType<native_type> {                   \
    static constexpr PrimDataType type = data_type;            \
  };                                                           \
  template <>                                                  \
  struct IsPrimitiveNativeType<native_type> : std::true_type {}

// Same as above, plus the ATen-side traits (AtenTypeToDataType and
// AtenTypeToNativeType) for the given at::ScalarType.
#define DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( \
    data_type, at_type, native_type)             \
  DEFINE_DATATYPE_TO_NATIVE_TYPE(data_type, native_type); \
  template <>                                    \
  struct AtenTypeToDataType<at_type> {           \
    static constexpr PrimDataType type = data_type; \
  };                                             \
  template <>                                    \
  struct AtenTypeToNativeType<at_type> {         \
    using type = native_type;                    \
  }
|
378
|
+
|
379
|
+
// Concrete (PrimDataType, at::ScalarType, native C++ type) triples.
// Note the ATen conventions: DataType::Int maps to at::ScalarType::Long
// (64-bit) while DataType::Int32 maps to at::ScalarType::Int.
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Float,
    at::ScalarType::Float,
    float);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Double,
    at::ScalarType::Double,
    double);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Half,
    at::ScalarType::Half,
    at::Half);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::BFloat16,
    at::ScalarType::BFloat16,
    at::BFloat16);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Float8_e4m3fn,
    at::ScalarType::Float8_e4m3fn,
    at::Float8_e4m3fn);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Float8_e5m2,
    at::ScalarType::Float8_e5m2,
    at::Float8_e5m2);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Char,
    at::ScalarType::Char,
    int8_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Short,
    at::ScalarType::Short,
    int16_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Int32,
    at::ScalarType::Int,
    int);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Int,
    at::ScalarType::Long,
    int64_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Byte,
    at::ScalarType::Byte,
    uint8_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::UInt16,
    at::ScalarType::UInt16,
    uint16_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::UInt32,
    at::ScalarType::UInt32,
    uint32_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::UInt64,
    at::ScalarType::UInt64,
    uint64_t);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::Bool,
    at::ScalarType::Bool,
    bool);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::ComplexFloat,
    at::ScalarType::ComplexFloat,
    std::complex<float>);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
    DataType::ComplexDouble,
    at::ScalarType::ComplexDouble,
    std::complex<double>);

// The macros are generation-only helpers; retire them immediately.
#undef DEFINE_DATATYPE_TO_NATIVE_TYPE
#undef DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE
|
450
|
+
|
451
|
+
// Infer the DataType of a runtime PolymorphicValue by probing each type
// the polymorphic variant can hold. Throws (via NVF_CHECK) for values with
// no inferable type: raw pointers, empty arrays, or unmapped types.
inline DataType getDataType(const PolymorphicValue& value) {
  std::optional<DataType> dtype = std::nullopt;
  PolymorphicValue::for_all_types([&value, &dtype](auto _) {
    using T = typename decltype(_)::type;
    if constexpr (IsPrimitiveNativeType<T>::value) {
      // Scalar case: use the static native-type -> PrimDataType mapping.
      if (value.is<T>()) {
        dtype = NativeTypeToDataType<T>::type;
      }
    } else if constexpr (std::is_same_v<T, std::vector<PolymorphicValue>>) {
      // Array case: the element type is inferred (recursively) from the
      // first element, which is why an empty array must be rejected.
      if (value.is<T>()) {
        const auto& vec = value.as<T>();
        size_t size = vec.size();
        NVF_CHECK(size > 0, "Empty array is not supported");
        dtype =
            ArrayType{std::make_shared<DataType>(getDataType(vec[0])), size};
      }
    } else if constexpr (std::is_same_v<T, Pointer>) {
      // For pointers in polymorphic value, we only store the data size of the
      // pointee, so it is impossible to infer the pointer type.
      NVF_CHECK(!value.is<T>(), "Can not infer pointer type.");
    } else if constexpr (std::is_same_v<T, StructHandle>) {
      // Struct case: the handle's concrete instance knows its StructType.
      if (value.is<T>()) {
        dtype = value.as<T>().type();
      }
    } else if constexpr (std::is_same_v<T, Opaque>) {
      // Opaque case: carries its own type_info and size; name left empty.
      if (value.is<T>()) {
        const auto& opaque = value.as<T>();
        dtype = DataType(OpaqueType{
            .type_info = opaque.any().type(), .size = opaque.size()});
      }
    }
  });
  NVF_CHECK(dtype.has_value(), "Unknown dtype for ", value.type().name());
  return dtype.value();
}
|
486
|
+
|
487
|
+
//! Returns true when two DataTypes are interchangeable for value storage:
//! identical types, both integral, both floating point, both complex, or
//! structurally compatible arrays / structs / opaque types.
inline bool isCompatibleDataType(DataType dtype, DataType dtype2) {
  if (dtype == dtype2) {
    return true;
  }
  // Scalar categories are mutually compatible within each category.
  const bool both_integral = isIntegralType(dtype) && isIntegralType(dtype2);
  const bool both_floating =
      isFloatingPointType(dtype) && isFloatingPointType(dtype2);
  const bool both_complex = isComplexType(dtype) && isComplexType(dtype2);
  if (both_integral || both_floating || both_complex) {
    return true;
  }
  // Arrays: same length and compatible element types.
  if (const auto* array = std::get_if<ArrayType>(&dtype.type)) {
    const auto* array2 = std::get_if<ArrayType>(&dtype2.type);
    if (array2 == nullptr) {
      return false;
    }
    return array->size == array2->size &&
        isCompatibleDataType(*array->type, *array2->type);
  }
  // Structs: fields must match pairwise by name with compatible types.
  if (const auto* strct = std::get_if<StructType>(&dtype.type)) {
    const auto* strct2 = std::get_if<StructType>(&dtype2.type);
    if (strct2 == nullptr) {
      return false;
    }
    if (strct->fields.size() != strct2->fields.size()) {
      return false;
    }
    for (size_t i = 0; i < strct->fields.size(); ++i) {
      if (strct->fields[i].name != strct2->fields[i].name ||
          !isCompatibleDataType(
              *strct->fields[i].type, *strct2->fields[i].type)) {
        return false;
      }
    }
    return true;
  }
  // Opaque types: compatible iff they wrap the same C++ type_info.
  if (const auto* opaque = std::get_if<OpaqueType>(&dtype.type)) {
    const auto* opaque2 = std::get_if<OpaqueType>(&dtype2.type);
    if (opaque2 == nullptr) {
      return false;
    }
    return opaque->type_info.get() == opaque2->type_info.get();
  }
  return false;
}
|
531
|
+
|
532
|
+
inline bool hasCompatibleDataType(
|
533
|
+
const PolymorphicValue& value,
|
534
|
+
DataType dtype) {
|
535
|
+
// We can not always completely infer data type from value, so we need some
|
536
|
+
// special handling here.
|
537
|
+
if (std::holds_alternative<PointerType>(dtype.type)) {
|
538
|
+
if (!value.is<Pointer>()) {
|
539
|
+
return false;
|
540
|
+
}
|
541
|
+
auto ptr = std::get<PointerType>(dtype.type);
|
542
|
+
return dataTypeSize(*ptr.type) == value.as<Pointer>().size();
|
543
|
+
} else if (std::holds_alternative<ArrayType>(dtype.type)) {
|
544
|
+
if (!value.is<std::vector>()) {
|
545
|
+
return false;
|
546
|
+
}
|
547
|
+
const auto& array_type = std::get<ArrayType>(dtype.type);
|
548
|
+
if (array_type.size != value.as<std::vector>().size()) {
|
549
|
+
return false;
|
550
|
+
}
|
551
|
+
if (array_type.size == 0) {
|
552
|
+
return true;
|
553
|
+
}
|
554
|
+
}
|
555
|
+
return isCompatibleDataType(getDataType(value), dtype);
|
556
|
+
}
|
557
|
+
|
558
|
+
// Re-enable the GCC-only diagnostics suppressed by a matching `push` earlier
// in this file.
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif

//! Returns the number of base-10 digits required to guarantee a lossless
//! binary->text->binary round-trip. For exact types, this function returns 0.
int max_digits10(DataType dtype);
|
565
|
+
|
566
|
+
//! Supported unary operations.
enum class UnaryOpType {
  Cast,
  BitCast,
  RefCast,

  Abs,
  Acos,
  Acosh,
  Address,
  Asin,
  Asinh,
  Atan,
  Atanh,
  Ceil,
  Cos,
  Cosh,
  Dereference,
  Exp,
  Exp2,
  Expm1,
  Erf,
  Erfc,
  Erfinv,
  Erfcinv,
  Floor,
  Frac,
  Gelu,
  Imag,
  Silu,
  Lgamma,
  Log,
  Log10,
  Log1p,
  Log2,
  Neg,
  Real,
  Reciprocal,
  Relu,
  Rsqrt,
  Round,
  Sigmoid,
  Signbit,
  Sin,
  Sinh,
  Sqrt,
  Tan,
  Tanh,
  Trunc,

  // Tools to help debugging
  Print,

  // Logical and bitwise negation
  LogicalNot,
  BitwiseNot,

  // Operators returning boolean values
  IsFinite,
  IsInf,
  IsNan,
  IsNegInf,
  IsPosInf,
  IsReal,

  // Special unary ops
  ElectSync,
  ToUnsignedSmemAddr,
  AdjustPartialLdMatrixAddrInTuring8,
  AdjustPartialLdMatrixAddrInTuring16
};
|
636
|
+
|
637
|
+
// TODO: Order of this list is important as it affects type promotion. it's not
// in the right order now.
enum class BinaryOpType {
  // Math Ops
  Add,
  Atan2,
  Div,
  Fmod,
  Max,
  Min,
  Mul,
  Nextafter,
  Pow,
  Remainder,
  Sub,
  // TypeAs,

  // Integer output ops.
  Mod,
  CeilDiv,
  Lshift,
  Rshift,
  Gcd,

  // Bitwise Ops
  // These always return integers, as if each arg is first cast to int
  // If changing modify isIntegerOp.
  BitwiseAnd,
  BitwiseOr,
  BitwiseXor,

  // Logical Ops
  // Int operations, leave position of Mod as first logical op see
  // isLogicalOp(BinaryOpType bopt)
  Eq,
  GE,
  GT,
  LE,
  LT,
  NE,

  // These ops compare as if each arg is first cast to bool
  LogicalAnd,
  LogicalOr,

  // generate complex from real and imaginary parts
  Complex
};
|
685
|
+
|
686
|
+
// Scatter currently only supports plain element assignment.
enum class ScatterOpType { Set };

// Random-number-generation op flavors.
enum class RNGOpType {
  Uniform, // Uniform in [0, 1)
  UniformRange, // Uniform in [low, high]
  NormalStandard, // Normal with mean 0, std 1
  NormalGeneral, // Normal with given mean and std
  Undefined,
};

// Returns if the output of the operator is an integer (see the
// "Integer output ops" / "Bitwise Ops" sections of BinaryOpType).
// NOTE(review): the original comment said "boolean", which appears to be a
// copy-paste from isLogicalOp below.
bool isIntegerOp(const BinaryOpType bopt);

// Return if output of operator should be a boolean
bool isLogicalOp(const BinaryOpType bopt);

enum class TernaryOpType { Clamp, Lerp, Threshold, Where };
|
703
|
+
|
704
|
+
// Parallelization axes that an IterDomain may be bound to.
enum class ParallelType {
  DIDx, // device dimension — see kParallelTypeDIDs below
  BIDz, // CUDA grid (block-index) dimensions — see kParallelTypeBIDs
  BIDy,
  BIDx,
  TIDz, // CUDA block (thread-index) dimensions — see kParallelTypeTIDs
  TIDy,
  TIDx,
  Stream,
  Vectorize,
  MisalignedVectorize,
  Unroll,
  Unswitch,
  Mma,
  Group,
  Bulk,
  Serial
};
|
722
|
+
|
723
|
+
// Returns the set of all ParallelTypes minus those listed in `except`.
std::unordered_set<ParallelType> allParallelTypesExcept(
    const std::unordered_set<ParallelType>& except);

// All grid (BID) and block (TID) parallel types.
static constexpr std::array<ParallelType, 6> kParallelTypeThreads = {
    ParallelType::BIDx,
    ParallelType::BIDy,
    ParallelType::BIDz,
    ParallelType::TIDx,
    ParallelType::TIDy,
    ParallelType::TIDz};

// Grid (block-index) parallel types only.
static constexpr std::array<ParallelType, 3> kParallelTypeBIDs = {
    ParallelType::BIDx,
    ParallelType::BIDy,
    ParallelType::BIDz};

// Block (thread-index) parallel types only.
static constexpr std::array<ParallelType, 3> kParallelTypeTIDs = {
    ParallelType::TIDx,
    ParallelType::TIDy,
    ParallelType::TIDz};

// Device-dimension parallel types only.
static constexpr std::array<ParallelType, 1> kParallelTypeDIDs = {
    ParallelType::DIDx};
|
746
|
+
|
747
|
+
// Memory space a TensorView may be allocated in.
enum class MemoryType { Local, Shared, Global, Tensor };

// Symbolic: Undetermined between Iteration or Broadcast
enum class IterType {
  Iteration,
  Reduction,
  Broadcast,
  Stride,
  GatherScatter,
  VectorComponent,
  Symbolic
};

// Used for Iteration Domain mapping modes in ComputeAtMap
enum class IdMappingMode {
  EXACT,
  ALMOSTEXACT,
  BROADCAST,
  PERMISSIVE,
  LOOP,
  // TODO: Reconsider if this graph is really necessary
  PERMISSIVE_RESIZE,
  // TODO: Reconsider if this graph is really necessary
  INNERMOST
};

// Every IdMappingMode value, in declaration order, for iteration.
static constexpr std::array<IdMappingMode, 7> kIdMappingModes = {
    IdMappingMode::EXACT,
    IdMappingMode::ALMOSTEXACT,
    IdMappingMode::BROADCAST,
    IdMappingMode::PERMISSIVE,
    IdMappingMode::LOOP,
    IdMappingMode::PERMISSIVE_RESIZE,
    IdMappingMode::INNERMOST};

// See
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators
// for what each option means. Will also consider .L1::no_allocate because .cs
// still pollutes cache to some extent.
enum class CacheOp {
  Unspecified, // Opt in for the default cache operator or when the LoadStoreOp
               // doesn't take a cache operator.
  AllLevels,
  Streaming,
  Global,
};
|
793
|
+
|
794
|
+
//! Used to annotate the special memory intrinsics that a loadstore op will be
//! lowered to.
//!
//! SegmenterSet here is used to hint segmenter to break kernel on the output
//! of the node
enum class LoadStoreOpType {
  Set,
  SegmenterSet,
  LdMatrix,
  CpAsync,
  CpAsyncBulk,
  CpAsyncBulkTensorTile,
  StMatrix,
  LdTMem,
  StTMem
};
|
810
|
+
|
811
|
+
// Used to label what part of the circular buffered iterdomain
// a for loop is materializing.
enum class CircularBufferLoopStage {
  Prolog = 0,
  Main,
  Epilog,
  LoadWarp,
  ComputeWarp,
  EndOfStages, // A special placeholder used to iterate over all stages
  NotApplicable
};
|
822
|
+
|
823
|
+
// The circular buffer load expressions are cloned for these circular buffer
|
824
|
+
// loop types.
|
825
|
+
// e.g., No additional loads are required for the Epilogue stage.
|
826
|
+
inline bool hasCircularBufferLoad(CircularBufferLoopStage stage) {
|
827
|
+
return stage == CircularBufferLoopStage::Prolog ||
|
828
|
+
stage == CircularBufferLoopStage::Main ||
|
829
|
+
stage == CircularBufferLoopStage::LoadWarp;
|
830
|
+
}
|
831
|
+
|
832
|
+
// The consuming expressions of circular buffer are cloned for these circular
|
833
|
+
// buffer loop types.
|
834
|
+
// e.g., No actual computation occurs in the Prologue stage.
|
835
|
+
inline bool hasCircularBufferConsume(CircularBufferLoopStage stage) {
|
836
|
+
return stage == CircularBufferLoopStage::Main ||
|
837
|
+
stage == CircularBufferLoopStage::Epilog ||
|
838
|
+
stage == CircularBufferLoopStage::ComputeWarp;
|
839
|
+
}
|
840
|
+
|
841
|
+
// A loop type may have WAR hazard if any of the following is true:
|
842
|
+
// - The load *in this loop type* may overwrite a buffer being read by a
|
843
|
+
// compute somewhere (*may or may not be in this loop*)
|
844
|
+
// - The compute *in this loop type* reads circular buffer TVs that, if not
|
845
|
+
// properly handled, could be overwriten by a circular buffer loading
|
846
|
+
// somewhere (*may or may not be in this loop*)
|
847
|
+
inline bool mayHaveWarHazard(CircularBufferLoopStage stage) {
|
848
|
+
return stage == CircularBufferLoopStage::Main ||
|
849
|
+
stage == CircularBufferLoopStage::LoadWarp ||
|
850
|
+
stage == CircularBufferLoopStage::ComputeWarp;
|
851
|
+
}
|
852
|
+
|
853
|
+
//! Supported swizzle types,
//! corresponds to swizzles functions on the runtime cuda
//! naming it swizzle_2d to reserve the options to have a swizzle_1d.
//!
//! TODO: unify with existing swizzle logic, currently
//! doesn't have the same type.
enum class SwizzleType { NoSwizzle = 0, XOR, CyclicShift };
enum class Swizzle2DType { NoSwizzle = 0, ZShape, XOR, CyclicShift };

//! Modes of swizzle, see [Note on swizzle mode].
enum class SwizzleMode { NoSwizzle = 0, Data, Loop };

// Returns if function needs an f suffix on the operator when operating on a
// float value i.e. sin->sinf
bool needFloatSuffix(UnaryOpType t);
bool needFloatSuffix(BinaryOpType t);
bool needFloatSuffix(RNGOpType t);
|
870
|
+
|
871
|
+
// Promotes two ValTypes to a common ValType (defined out of line).
ValType promoteType(const ValType& t1, const ValType& t2);

// Expands to one promotion rule for an ordered pair of native types, using
// C++'s std::common_type as the promotion result.
#define HANDLE_TYPE_PROMOTION(Type1, Type2)                              \
  if (t1 == NativeTypeToDataType<Type1>::type &&                         \
      t2 == NativeTypeToDataType<Type2>::type) {                         \
    return NativeTypeToDataType<std::common_type_t<Type1, Type2>>::type; \
  }

// Expands HANDLE_TYPE_PROMOTION for Type1 against every supported native
// scalar type.
#define HANDLE_TYPE_PROMOTION1(Type1)                \
  HANDLE_TYPE_PROMOTION(Type1, float);               \
  HANDLE_TYPE_PROMOTION(Type1, double);              \
  HANDLE_TYPE_PROMOTION(Type1, int64_t);             \
  HANDLE_TYPE_PROMOTION(Type1, int);                 \
  HANDLE_TYPE_PROMOTION(Type1, bool);                \
  HANDLE_TYPE_PROMOTION(Type1, std::complex<float>); \
  HANDLE_TYPE_PROMOTION(Type1, std::complex<double>)
|
887
|
+
|
888
|
+
//! Promotes two DataTypes to a common DataType. Rules, in order: identical
//! types; pointer +- integral stays a pointer; Index wins over other integral
//! types; Double/ComplexFloat promotes to ComplexDouble (ATen rule, see
//! below); native C++ types follow std::common_type; then half/bfloat16
//! interactions. Throws (via NVF_CHECK) if no rule applies.
inline DataType promoteType(const DataType& t1, const DataType& t2) {
  if (t1 == t2) {
    return t1;
  }
  // pointer +- integer = pointer
  if (isPointerType(t1) && isIntegralType(t2)) {
    return t1;
  }
  if (isPointerType(t2) && isIntegralType(t1)) {
    return t2;
  }
  // When seeing DataType::Index, assuming we are computing index, so propagate
  // DataType::Index
  if ((t1 == DataType::Index && isIntegralType(t2)) ||
      (t2 == DataType::Index && isIntegralType(t1))) {
    return DataType::Index;
  }
  // Workaround a case where C++ and ATen have different type promotion rules
  if ((t1 == DataType::Double && t2 == DataType::ComplexFloat) ||
      (t2 == DataType::Double && t1 == DataType::ComplexFloat)) {
    // WARNING: ATen and C++ behave differently for this case. ATen returns
    // DataType::ComplexDouble but C++ returns DataType::ComplexFloat. Right now
    // we choose to be consistent with ATen.
    // TODO: I am pretty sure that for some cases we would need C++'s promotion
    // rule, for example, when we are simplifying scalar expressions, and for
    // other cases, we need ATen's promotion rule, for example, when we define
    // fusion from ATen graph. Fortunately, right now this is the only case to
    // worry about, and I don't think in practice, using ATen's rule would cause
    // any trouble.
    return DataType::ComplexDouble;
  }
  // Use C++ promotion rule when dtype has a native C++ type
  HANDLE_TYPE_PROMOTION1(float);
  HANDLE_TYPE_PROMOTION1(double);
  HANDLE_TYPE_PROMOTION1(int64_t);
  HANDLE_TYPE_PROMOTION1(int);
  HANDLE_TYPE_PROMOTION1(bool);
  HANDLE_TYPE_PROMOTION1(std::complex<float>);
  HANDLE_TYPE_PROMOTION1(std::complex<double>);
  // double + half/bfloat16 = double
  if ((t1 == DataType::Double && isFloatingPointType(t2)) ||
      (t2 == DataType::Double && isFloatingPointType(t1))) {
    return DataType::Double;
  }
  // float + half/bfloat16 = float
  // half + bfloat16 = float
  if (isFloatingPointType(t1) && isFloatingPointType(t2)) {
    return DataType::Float;
  }
  // complex + half/bfloat16 = complex
  if (isComplexType(t1)) {
    return t1;
  }
  if (isComplexType(t2)) {
    return t2;
  }
  // half + integers/bool = half
  // bfloat16 + integers/bool = bfloat16
  if (isFloatingPointType(t1)) {
    return t1;
  }
  if (isFloatingPointType(t2)) {
    return t2;
  }
  NVF_CHECK(false, "Expected promotable DataTypes but got: ", t1, " and ", t2);
}

// The promotion macros are only used by promoteType above.
#undef HANDLE_TYPE_PROMOTION
#undef HANDLE_TYPE_PROMOTION1
|
958
|
+
// Variadic overload: folds the binary promoteType over three or more
// DataTypes.
template <typename... Args>
inline DataType promoteType(
    const DataType& t1,
    const DataType& t2,
    const Args&... args) {
  return promoteType(t1, promoteType(t2, promoteType(args...)));
}
|
965
|
+
|
966
|
+
//! Promotes all DataTypes in `types` to a single common type by folding the
//! binary promoteType over the vector. Throws (via NVF_CHECK) on an empty
//! input or when any pair is not promotable.
inline DataType promoteType(const std::vector<DataType>& types) {
  // Fix: the original invocation lacked a trailing semicolon, relying on the
  // macro's expansion to tolerate it.
  NVF_CHECK(!types.empty(), "Can not promote empty type vector");
  DataType result = types.at(0);
  // Start from the second element: the original also promoted the seed with
  // itself, which is a guaranteed no-op (t1 == t2 short-circuits).
  for (size_t i = 1; i < types.size(); ++i) {
    result = promoteType(result, types[i]);
  }
  return result;
}
|
974
|
+
|
975
|
+
// If type cannot be found (i.e. codegen does not support provided type) returns
// DataType::Null
NVF_API DataType aten_to_data_type(const at::ScalarType& scalar_type);
NVF_API at::ScalarType data_type_to_aten(const DataType& data_type);

// Stream-output operators for the enums declared in this header.
NVF_API std::ostream& operator<<(std::ostream&, const ValType);
std::ostream& operator<<(std::ostream&, const PredicateType);
NVF_API std::ostream& operator<<(std::ostream&, const DataType);
std::ostream& operator<<(std::ostream&, const UnaryOpType);
NVF_API std::ostream& operator<<(std::ostream&, const BinaryOpType);
std::ostream& operator<<(std::ostream&, const TernaryOpType);
std::ostream& operator<<(std::ostream&, const ScatterOpType);
std::ostream& operator<<(std::ostream&, const RNGOpType);
NVF_API std::ostream& operator<<(std::ostream&, const ParallelType);
NVF_API std::ostream& operator<<(std::ostream&, const MemoryType);
NVF_API std::ostream& operator<<(std::ostream&, const IterType);
std::ostream& operator<<(std::ostream&, const IdMappingMode);
NVF_API std::ostream& operator<<(std::ostream&, const LoadStoreOpType);
std::ostream& operator<<(std::ostream&, const CircularBufferLoopStage);
std::ostream& operator<<(std::ostream&, const SwizzleType&);
std::ostream& operator<<(std::ostream&, const Swizzle2DType&);
std::ostream& operator<<(std::ostream&, const SwizzleMode&);
std::ostream& operator<<(std::ostream&, const KernelIndexMode&);
NVF_API std::ostream& operator<<(std::ostream&, const CacheOp&);
std::ostream& operator<<(std::ostream& os, const std::optional<bool>&);

// String helpers for codegen / debugging output.
std::string stringifyThreadSize(const ParallelType);
std::string stringifyThread(const ParallelType);
std::string typePrefix(const DataType);

// TODO: ThreadDim should be BlockDim and BlockDim should be GridDim
// Returns if parallel type is TID[x, y, z]
NVF_API bool isParallelTypeThreadDim(ParallelType);
// Returns if parallel type is BID[x, y, z]
NVF_API bool isParallelTypeBlockDim(ParallelType);
// Returns if parallel type is a grid or block parallelization dimension
NVF_API bool isParallelTypeThread(ParallelType);
// Returns if parallel type is DIDx
NVF_API bool isParallelTypeDeviceDim(ParallelType);

NVF_API bool isParallelTypeVectorize(ParallelType);

// Return std::nullopt when the op has no inline operator spelling.
std::optional<std::string> inline_op_str(const UnaryOpType);
std::optional<std::string> inline_op_str(const BinaryOpType);
std::optional<std::string> inline_op_str(const RNGOpType);
std::optional<std::string> integer_op_str(const BinaryOpType);
std::optional<std::string> bool_op_str(const BinaryOpType);
const char* predicate_type2string(PredicateType t);
const char* load_store_type2string(LoadStoreOpType t);

// Returns the cast helper name for a (from, to) DataType pair, if any.
std::optional<std::string> cast_func_str(const std::pair<DataType, DataType>&);
|
1026
|
+
|
1027
|
+
// Returns the size in bytes of a primitive DataType. Throws for
// DataType::Index (its concrete width is only known at compile time of the
// generated kernel) and for any type without a defined size.
constexpr inline size_t primDataTypeSize(PrimDataType type) {
  switch (type) {
    case DataType::Bool:
      return sizeof(bool);
    case DataType::ComplexDouble:
      return sizeof(std::complex<double>);
    case DataType::ComplexFloat:
      return sizeof(std::complex<float>);
    case DataType::Double:
      return sizeof(double);
    case DataType::Float:
      return sizeof(float);
    case DataType::Half:
      return sizeof(at::Half);
    case DataType::BFloat16:
      return sizeof(at::BFloat16);
    case DataType::Float8_e4m3fn:
      return sizeof(at::Float8_e4m3fn);
    case DataType::Float8_e5m2:
      return sizeof(at::Float8_e5m2);
    case DataType::Index:
      NVF_THROW("The actual type of Index is only known at compile time.");
    case DataType::Char:
      return sizeof(int8_t);
    case DataType::Short:
      return sizeof(int16_t);
    case DataType::Int32:
      return sizeof(int32_t);
    case DataType::Int:
      return sizeof(int64_t);
    case DataType::Byte:
      return sizeof(uint8_t);
    case DataType::UInt16:
      return sizeof(uint16_t);
    // SMemAddress / TMemAddress are represented as 32-bit values.
    case DataType::UInt32:
    case DataType::SMemAddress:
    case DataType::TMemAddress:
      return sizeof(uint32_t);
    case DataType::UInt64:
      return sizeof(uint64_t);
    default:
      NVF_THROW("Size undefined for data type.");
  }
}
|
1071
|
+
|
1072
|
+
// Launch-configuration slots resolved at kernel launch time.
enum class LaunchConfigType {
  Compatible,
  SharedMemory,
  BIDz,
  BIDy,
  BIDx,
  TIDz,
  TIDy,
  TIDx
};

// Symbol name of the "magic zero" value emitted into generated kernels.
const char* const kMagicZeroName = "nvfuser_zero";

//! Maximum number of reductions that can be grouped together. The
//! limit can be increased by extending struct Tuple define in tuple.cu.
static constexpr int kMaxNumGroupedReductions = 16;
|
1088
|
+
|
1089
|
+
// Constructs a Pointer over `ptr`; the stored element size is the byte size
// of `dtype` as reported by dataTypeSize.
Pointer::Pointer(void* ptr, DataType dtype)
    : ptr_(reinterpret_cast<std::byte*>(ptr)), size_(dataTypeSize(dtype)) {}
|
1091
|
+
|
1092
|
+
//! Casts `value` to a representation compatible with `dtype` when possible.
//! An empty value is returned unchanged; only primitive native types are
//! currently converted (arrays and pointers pass through untouched).
inline PolymorphicValue castToDtype(
    PolymorphicValue value,
    const DataType& dtype) {
  if (!value.hasValue()) {
    return value;
  }
  // Cast the given value to the given data type. This enables interface
  // like: IrBuilder::create<Val>(0, DataType::Double) where value is
  // an integer but the desired data type is double.
  if (hasCompatibleDataType(value, dtype)) {
    return value;
  }
  PolymorphicValue::for_all_types([&](auto type_tag) {
    using T = typename decltype(type_tag)::type;
    if constexpr (IsPrimitiveNativeType<T>::value) {
      if (isCompatibleDataType(NativeTypeToDataType<T>::type, dtype)) {
        value = PolymorphicValue(static_cast<T>(value));
      }
    }
    // TODO: support arrays and pointers
  });
  return value;
}
|
1114
|
+
|
1115
|
+
// Converts an enum to its underlying type.
// It corresponds with std::to_underlying introduced in c++23
// https://en.cppreference.com/w/cpp/utility/to_underlying
template <typename E>
constexpr auto toUnderlying(E e) noexcept {
  using Raw = std::underlying_type_t<E>;
  return static_cast<Raw>(e);
}
|
1122
|
+
|
1123
|
+
// Classification of asynchronous operation kinds.
enum class AsyncOpType { NotAsync, CpAsync, CpAsyncBulk, WgMma };

} // namespace nvfuser
|