nvfuser-cu121-torch25 0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/ops/composite.h
@@ -0,0 +1,130 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <visibility.h>
+
+ #include <ir/interface_nodes.h>
+ #include <type.h>
+
+ //
+ // The operations defined in this header are intended as user-facing functions.
+ // The user will provide the necessary input TensorViews and the function will
+ // create the correct intermediate nodes and return the output TensorViews.
+ //
+
+ namespace nvfuser {
+
+ struct ForwardDropoutResult {
+   TensorView* output = nullptr;
+   TensorView* mask = nullptr;
+ };
+
+ NVF_API ForwardDropoutResult dropout(TensorView* x, Val* prob);
+
+ NVF_API ForwardDropoutResult dropout(TensorView* x, Val* prob, Val* scale);
+
+ NVF_API TensorView* dropout_backward(
+     TensorView* dy,
+     TensorView* mask,
+     Val* scale);
+
+ NVF_API TensorView* triu(TensorView* tv, Val* offset);
+
+ struct LstmResult {
+   TensorView* cell = nullptr;
+   TensorView* hidden = nullptr;
+ };
+
+ NVF_API LstmResult lstm(
+     TensorView* prev_cell,
+     TensorView* in_x,
+     TensorView* forget_x,
+     TensorView* cell_x,
+     TensorView* out_x);
+
+ // Linear function which takes in two tensors of shapes input[*, in_features],
+ // weight[out_features, in_features] / [in_features] and an optional bias of
+ // shape [out_features] or a 0-D scalar. Bias can only be given if weight is a
+ // 2-D tensor.
+ TensorView* linear(TensorView* input, TensorView* weight, TensorView* bias);
+ // This is an implementation detail to reflect when linear is called
+ // without a bias. This calls the above function. We use this function
+ // since it simplifies creating a Python API which takes optional arguments.
+ // Other options include using lambdas or creating a new RecordFunctor for
+ // Linear.
+ TensorView* linear(TensorView* input, TensorView* weight);
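// Editor's illustration (not part of composite.h): a minimal sketch of how the
// two linear() overloads might be used when defining a fusion. It assumes the
// Fusion, FusionGuard, and TensorViewBuilder APIs from this package's other
// headers (fusion.h, ir/interface_nodes.h); the function name and shapes below
// are illustrative only.
#include <fusion.h>
#include <ops/all_ops.h>

void define_linear_example(nvfuser::Fusion* fusion) {
  using namespace nvfuser;
  FusionGuard fg(fusion);
  // input[*, in_features], weight[out_features, in_features], bias[out_features]
  TensorView* x = TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
  TensorView* w = TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
  TensorView* b = TensorViewBuilder().ndims(1).dtype(DataType::Float).build();
  fusion->addInput(x);
  fusion->addInput(w);
  fusion->addInput(b);
  TensorView* y_bias = linear(x, w, b); // [*, out_features]
  TensorView* y_nobias = linear(x, w);  // bias-less overload; calls the 3-arg form
  fusion->addOutput(y_bias);
  fusion->addOutput(y_nobias);
}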
63
+
+ NVF_API TensorView* sign(TensorView* x);
+ NVF_API Val* sign(Val* x);
+ TensorView* softplus(TensorView* x, Val* beta, Val* threshold);
+ NVF_API TensorView* gelu(TensorView* x);
+ NVF_API TensorView* gelu_backward(TensorView* dy, TensorView* x);
+ TensorView* tanh_gelu(TensorView* x);
+ TensorView* tanh_gelu_backward(TensorView* dy, TensorView* x);
+ TensorView* tanh_backward(TensorView* dy, TensorView* tanh_x);
+ TensorView* leaky_relu(TensorView* x, Val* negative_slope);
+
+ NVF_API TensorView* view_as_real(TensorView* x);
+
+ // Matmul function which takes in tensors with the shapes
+ // A[*, M, K] / A[K] and B[*, K, N] / B[K], but the tensors may have different
+ // layouts via strides. This has the same functionality as torch.matmul
+ TensorView* matmul(TensorView* tv_a, TensorView* tv_b);
+
+ // Scaled Dot Product Flash Attention Forward Result
+ struct SdpfaFwdResult {
+   TensorView* output = nullptr;
+   TensorView* log_sumexp = nullptr;
+   TensorView* philox_seed = nullptr;
+   TensorView* philox_offset = nullptr;
+ };
+
+ // Scaled Dot Product Flash Attention Forward API.
+ // Returns the same output as at::_scaled_dot_product_flash_attention
+ SdpfaFwdResult sdpfa_fwd(
+     TensorView* query,
+     TensorView* key,
+     TensorView* value,
+     Val* dropout_p,
+     Val* is_causal,
+     Val* scale);
+
+ // Scaled Dot Product Flash Attention Backward Result
+ struct SdpfaBwdResult {
+   TensorView* grad_query = nullptr;
+   TensorView* grad_key = nullptr;
+   TensorView* grad_value = nullptr;
+ };
+
+ // Scaled Dot Product Flash Attention Backward API.
+ // Returns the same output as at::_scaled_dot_product_flash_attention_backward
+ SdpfaBwdResult sdpfa_bwd(
+     TensorView* grad_output,
+     TensorView* query,
+     TensorView* key,
+     TensorView* value,
+     TensorView* output,
+     TensorView* log_sumexp,
+     Val* dropout_p,
+     Val* is_causal,
+     TensorView* philox_seed,
+     TensorView* philox_offset,
+     Val* scale);
+
+ TensorView* embedding_fwd(
+     TensorView* input,
+     TensorView* weight,
+     Val* padding_idx,
+     Val* max_norm,
+     Val* norm_type,
+     Val* scale_grad_by_freq,
+     Val* sparse);
+
+ } // namespace nvfuser
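Taken together, these declarations form the graph-building surface: the caller supplies input TensorViews and receives output TensorViews that can be registered as fusion outputs. The snippet below is a minimal sketch of that pattern using matmul and gelu; it assumes the Fusion, FusionGuard, and TensorViewBuilder APIs from this package's other headers (the function name, dtypes, and shapes are illustrative), and it omits scheduling and execution.

#include <fusion.h>
#include <ops/all_ops.h>

// Sketch: A[M, K] x B[K, N] followed by a GELU epilogue in one fusion definition.
void define_matmul_gelu_example(nvfuser::Fusion* fusion) {
  using namespace nvfuser;
  FusionGuard fg(fusion);
  TensorView* a = TensorViewBuilder().ndims(2).dtype(DataType::Half).build();
  TensorView* b = TensorViewBuilder().ndims(2).dtype(DataType::Half).build();
  fusion->addInput(a);
  fusion->addInput(b);
  TensorView* mm = matmul(a, b); // same broadcasting behavior as torch.matmul
  TensorView* out = gelu(mm);    // elementwise op consumes the matmul result
  fusion->addOutput(out);
}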
nvfuser/include/nvfuser/ops/indexing.h
@@ -0,0 +1,55 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <visibility.h>
+
+ #include <ir/interface_nodes.h>
+ #include <type.h>
+
+ namespace nvfuser {
+
+ NVF_API TensorView* select(TensorView* tv, int64_t dim, Val* index);
+
+ // torch.index_select
+ NVF_API TensorView* indexSelect(
+     TensorView* input,
+     int64_t dim,
+     TensorView* index);
+
+ // torch.gather
+ NVF_API TensorView* torchGather(
+     TensorView* input,
+     int64_t dim,
+     TensorView* index);
+
+ // torch.scatter
+ TensorView* scatterOp(
+     ScatterOpType type,
+     TensorView* self,
+     int64_t dim,
+     TensorView* index,
+     TensorView* src);
+
+ NVF_API TensorView* scatter(
+     TensorView* self,
+     int64_t dim,
+     TensorView* index,
+     TensorView* src);
+
+ //! numpy.take_along_axis
+ //! (https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html)
+ //! Note the order of the parameters follows the numpy order, which is
+ //! different from torchGather.
+ NVF_API TensorView* takeAlongAxis(
+     TensorView* input,
+     TensorView* index,
+     int64_t dim);
+
+ } // namespace nvfuser
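The comment on takeAlongAxis is worth emphasizing: torchGather follows the torch.gather argument order (input, dim, index), while takeAlongAxis follows the numpy order (input, index, dim). A minimal sketch of the difference, assuming the same Fusion/TensorViewBuilder setup as above (names and shapes illustrative):

#include <fusion.h>
#include <ops/all_ops.h>

void define_gather_example(nvfuser::Fusion* fusion) {
  using namespace nvfuser;
  FusionGuard fg(fusion);
  TensorView* inp = TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
  TensorView* idx = TensorViewBuilder().ndims(2).dtype(DataType::Int).build();
  fusion->addInput(inp);
  fusion->addInput(idx);
  TensorView* g = torchGather(inp, /*dim=*/1, idx);   // torch.gather order
  TensorView* t = takeAlongAxis(inp, idx, /*dim=*/1); // numpy order: index before dim
  fusion->addOutput(g);
  fusion->addOutput(t);
}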
nvfuser/include/nvfuser/ops/normalization.h
@@ -0,0 +1,263 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <visibility.h>
+
+ #include <ir/interface_nodes.h>
+ #include <type.h>
+
+ #include <tuple>
+ #include <vector>
+
+ //
+ // The operations defined in this header are intended as user-facing functions.
+ // The user will provide the necessary input TensorViews and the function will
+ // create the correct intermediate nodes and return the output TensorViews.
+ //
+
+ namespace nvfuser {
+
+ struct ForwardNormResult {
+   TensorView* output = nullptr;
+   TensorView* mean = nullptr;
+   TensorView* invstd = nullptr;
+ };
+
+ struct BackwardNormResult {
+   TensorView* grad_input = nullptr;
+   TensorView* grad_weight = nullptr;
+   TensorView* grad_bias = nullptr;
+ };
+
+ struct ForwardRMSNormResult {
+   TensorView* output = nullptr;
+   TensorView* invstd = nullptr;
+ };
+
+ struct BackwardRMSNormResult {
+   TensorView* grad_input = nullptr;
+   TensorView* grad_weight = nullptr;
+ };
+
+ struct VarMeanResult {
+   TensorView* var = nullptr;
+   TensorView* mean = nullptr;
+ };
+
+ } // namespace nvfuser
+
+ namespace std {
+
+ // Make these results behave like a std::tuple
+ using nvfuser::BackwardNormResult;
+ using nvfuser::BackwardRMSNormResult;
+ using nvfuser::ForwardNormResult;
+ using nvfuser::ForwardRMSNormResult;
+ using nvfuser::TensorView;
+ using nvfuser::VarMeanResult;
+
+ template <int i>
+ constexpr TensorView* get(const ForwardNormResult& results) {
+   if (i == 0) {
+     return results.output;
+   }
+   if (i == 1) {
+     return results.mean;
+   }
+   if (i == 2) {
+     return results.invstd;
+   }
+   return nullptr;
+ }
+
+ template <int i>
+ constexpr TensorView* get(const BackwardNormResult& results) {
+   if (i == 0) {
+     return results.grad_input;
+   }
+   if (i == 1) {
+     return results.grad_weight;
+   }
+   if (i == 2) {
+     return results.grad_bias;
+   }
+   return nullptr;
+ }
+
+ template <int i>
+ constexpr TensorView* get(const ForwardRMSNormResult& results) {
+   if (i == 0) {
+     return results.output;
+   }
+   if (i == 1) {
+     return results.invstd;
+   }
+   return nullptr;
+ }
+
+ template <int i>
+ constexpr TensorView* get(const BackwardRMSNormResult& results) {
+   if (i == 0) {
+     return results.grad_input;
+   }
+   if (i == 1) {
+     return results.grad_weight;
+   }
+   return nullptr;
+ }
+
+ template <int i>
+ constexpr TensorView* get(const VarMeanResult& results) {
+   if (i == 0) {
+     return results.var;
+   }
+   if (i == 1) {
+     return results.mean;
+   }
+   return nullptr;
+ }
+
+ } // namespace std
+
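// Editor's illustration (not part of normalization.h): the get() overloads above
// give the result structs positional, tuple-style access in addition to the
// named members; both forms refer to the same TensorViews.
inline nvfuser::TensorView* forward_norm_output(
    const nvfuser::ForwardNormResult& r) {
  // std::get<0>(r) resolves to the overload defined above and is equivalent to
  // r.output; indices 1 and 2 select r.mean and r.invstd respectively.
  return std::get<0>(r);
}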
129
+ namespace nvfuser {
+
+ TensorView* mean(TensorView* x, const std::vector<int64_t>& dims, bool keepdim);
+
+ NVF_API TensorView* variance(
+     TensorView* x,
+     const std::vector<int64_t>& dims,
+     bool unbiased,
+     bool keepdim);
+
+ NVF_API TensorView* variance(
+     TensorView* x,
+     const std::vector<int64_t>& dims,
+     int64_t correction,
+     bool keepdim);
+
+ NVF_API VarMeanResult variance_mean(
+     TensorView* x,
+     const std::vector<int64_t>& dims,
+     int64_t correction,
+     bool keepdim);
+
+ NVF_API TensorView* standard_deviation(
+     TensorView* x,
+     const std::vector<int64_t>& dims,
+     bool unbiased,
+     bool keepdim);
+
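// Editor's illustration (not part of normalization.h): the two variance()
// overloads differ only in how the divisor is specified. `unbiased` is the
// boolean torch-style flag, while `correction` is the generalized Bessel
// correction (divisor N - correction). Assuming these follow torch.var
// semantics, unbiased=true corresponds to correction=1 and unbiased=false to
// correction=0, so the two calls below should define the same computation.
inline nvfuser::TensorView* unbiased_variance_example(
    nvfuser::TensorView* x,
    const std::vector<int64_t>& dims) {
  auto* via_flag =
      nvfuser::variance(x, dims, /*unbiased=*/true, /*keepdim=*/false);
  auto* via_correction =
      nvfuser::variance(x, dims, /*correction=*/int64_t{1}, /*keepdim=*/false);
  (void)via_correction;
  return via_flag;
}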
157
+ NVF_API TensorView* softmax(TensorView* x, int64_t dim);
+
+ NVF_API TensorView* softmax_backward(
+     TensorView* dy,
+     TensorView* y,
+     const int64_t dim);
+
+ NVF_API TensorView* log_softmax(TensorView* x, int64_t dim);
+
+ NVF_API TensorView* log_softmax_backward(
+     TensorView* dy,
+     TensorView* y,
+     const int64_t dim);
+
+ NVF_API ForwardNormResult layer_norm(
+     TensorView* x,
+     const std::vector<int64_t>& norm_shape,
+     TensorView* weight,
+     TensorView* bias,
+     Val* eps);
+
+ NVF_API ForwardNormResult layer_norm(
+     TensorView* x,
+     const int64_t kNormShapeNumDims,
+     TensorView* weight,
+     TensorView* bias,
+     Val* eps);
+
+ NVF_API ForwardRMSNormResult rms_norm(
+     TensorView* x,
+     const std::vector<int64_t>& norm_shape,
+     TensorView* weight,
+     Val* eps);
+
+ NVF_API ForwardRMSNormResult rms_norm(
+     TensorView* x,
+     const int64_t kNormShapeNumDims,
+     TensorView* weight,
+     Val* eps);
+
+ NVF_API BackwardNormResult layer_norm_backward(
+     TensorView* dy,
+     TensorView* x,
+     const std::vector<int64_t>& norm_shape,
+     TensorView* mean,
+     TensorView* rstd,
+     TensorView* weight,
+     TensorView* bias,
+     const std::vector<bool>& output_mask);
+
+ NVF_API BackwardRMSNormResult rms_norm_backward(
+     TensorView* dy,
+     TensorView* x,
+     const std::vector<int64_t>& norm_shape,
+     TensorView* rstd,
+     TensorView* weight,
+     const std::vector<bool>& output_mask);
+
+ NVF_API ForwardNormResult batch_norm(
+     TensorView* x,
+     TensorView* weight,
+     TensorView* bias,
+     TensorView* running_mean,
+     TensorView* running_var,
+     const bool kTraining,
+     Val* momentum,
+     Val* eps,
+     bool channels_last = false);
+
+ NVF_API BackwardNormResult batch_norm_backward(
+     TensorView* x,
+     TensorView* dy,
+     TensorView* weight,
+     TensorView* running_mean,
+     TensorView* running_var,
+     TensorView* save_mean,
+     TensorView* save_invstd,
+     const bool kTraining,
+     Val* eps,
+     const std::vector<bool>& output_mask,
+     bool channels_last = false);
+
+ NVF_API ForwardNormResult instance_norm(
+     TensorView* x,
+     TensorView* weight,
+     TensorView* bias,
+     TensorView* running_mean,
+     TensorView* running_var,
+     const bool kUseInputStats, // kTraining?
+     Val* momentum,
+     Val* eps,
+     bool channels_last = false);
+
+ NVF_API BackwardNormResult instance_norm_backward(
+     TensorView* x,
+     TensorView* dy,
+     TensorView* weight,
+     TensorView* running_mean,
+     TensorView* running_var,
+     TensorView* save_mean,
+     TensorView* save_invstd,
+     const bool kTraining,
+     Val* eps,
+     const std::vector<bool>& output_mask,
+     bool channels_last = false);
+
+ } // namespace nvfuser
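As with the composite ops, these functions build the normalization subgraph into the active fusion and return the saved statistics alongside the output. Below is a minimal layer_norm sketch, assuming Fusion, FusionGuard, TensorViewBuilder, and IrBuilder::create<Val> from this package's other headers (fusion.h, ir/interface_nodes.h, ir/builder.h); names and shapes are illustrative.

#include <fusion.h>
#include <ir/builder.h>
#include <ops/all_ops.h>

void define_layer_norm_example(nvfuser::Fusion* fusion) {
  using namespace nvfuser;
  FusionGuard fg(fusion);
  // x: [batch, hidden]; weight, bias: [hidden]; normalize over the last dim.
  TensorView* x = TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
  TensorView* w = TensorViewBuilder().ndims(1).dtype(DataType::Float).build();
  TensorView* b = TensorViewBuilder().ndims(1).dtype(DataType::Float).build();
  fusion->addInput(x);
  fusion->addInput(w);
  fusion->addInput(b);
  Val* eps = IrBuilder::create<Val>(1e-5);
  ForwardNormResult r = layer_norm(x, /*kNormShapeNumDims=*/1, w, b, eps);
  fusion->addOutput(r.output);
  fusion->addOutput(r.mean);   // saved statistics, e.g., for the backward pass
  fusion->addOutput(r.invstd);
}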
nvfuser/include/nvfuser/ops/utils.h
@@ -0,0 +1,127 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <ir/base_nodes.h>
+ #include <ir/interface_nodes.h>
+ #include <scheduler/matmul_utils.h>
+ #include <type.h>
+ #include <visibility.h>
+
+ #include <vector>
+
+ namespace nvfuser {
+
+ enum class AttnRole { Q = 0, K, V, Mask };
+
+ namespace ops {
+
+ TensorView* maybe_broadcast_inner_to_rank(TensorView* t, size_t rank);
+
+ // A utility function that broadcasts the index TensorView to the rank of the
+ // other TensorView.
+ TensorView* maybeBroadcastIndexTv(TensorView* t, size_t dim, size_t rank);
+
+ // A utility function that checks if the index tv is already broadcast to the
+ // correct shape for index_select.
+ bool isIndexAlreadyBroadcast(
+     const std::vector<IterDomain*>& index_domain,
+     size_t dim,
+     size_t rank);
+
+ Val* simplifiedInt(Val* val);
+
+ // If one size is nullptr, return the other. If both are symbolic, just return
+ // v1. If one is concrete, prefer that one (simplified). If both are concrete,
+ // make sure they're the same size.
+ Val* promoteSize(Val* v1, Val* v2);
+
+ // Will return a new value of type val with the DataType dtype.
+ Val* newScalar(ValType vtype, DataType dtype);
+
+ IterType promoteIterType(IterType type1, IterType type2);
+
+ // For MatmulOp, the input iterdomains at a given index do not necessarily map
+ // to the output iterdomain at that index. This function aligns the input
+ // iterdomains to the output and returns a vector where each element is the
+ // input iterdomain corresponding to the output iterdomain at that index. If an
+ // element is nullptr, there is no input-output mapping at that index.
+ // Based on the input dimensions, the following cases are possible:
+ // 1. A/B is 1D: [M, K] x [K] -> [M] (Mapping A: {id_M}, Mapping B: {nullptr})
+ //    or [K] x [N, K] -> [N] (Mapping A: {nullptr}, Mapping B: {id_N})
+ // 2. A and B are 2D: [M, K] x [K, N] -> [M, N] (Mapping A: {id_M, nullptr},
+ //    Mapping B: {nullptr, id_N})
+ // 3. A/B are at least 1D and one of them is > 2D: [B, M, K] x [K, N] ->
+ //    [B, M, N] (Mapping A: {id_B, id_M, nullptr}, Mapping B: {nullptr,
+ //    nullptr, id_N})
+ // Args:
+ // 1. input_domain: root/logical domain without reductions for any input to
+ //    MatmulOp
+ // 2. input_position: specifies if the input is A / B (0 or 1)
+ // 3. out_size: MatmulOp output dimension (input and output may not be the
+ //    same size).
+ std::vector<IterDomain*> mapMatmulOpIterDomains(
+     const std::vector<IterDomain*>& input_domain,
+     int64_t input_position,
+     size_t out_size);
+
+ // For LinearOp, the output is the same as the first input (A[*, in_features])
+ // for all but the last dimension. If the second input is 2D
+ // (B[out_features, in_features]), the last dimension of the output is
+ // out_features. If bias is 1D (bias[out_features]), it maps to the last
+ // dimension of the output.
+ // Args:
+ // 1. input_domain: root/logical domain without reductions for any input to
+ //    LinearOp
+ // 2. input_position: specifies if the input is A / B / Bias (0, 1, or 2)
+ //    (MatmulTensorRole::Input_A/Input_B/Input_C)
+ // 3. out_size: LinearOp output dimension (input and output may not be the
+ //    same size).
+ std::vector<IterDomain*> mapLinearOpIterDomains(
+     const std::vector<IterDomain*>& input_domain,
+     int64_t input_position,
+     size_t out_size,
+     bool k_bcast);
+
+ // Takes a vector of aligned input iterdomains to create the output iterdomain.
+ // This is used if the input iterdomains are not trivially mapped to the output
+ // iterdomains, e.g., for MatmulOp. If given, the forced_iter_type argument will
+ // be the output IterType regardless of the inputs; otherwise the output
+ // IterType is inferred from ids.
+ IterDomain* newOutputIterDomain(
+     const std::vector<IterDomain*>& ids,
+     const std::optional<IterType> force_iter_type = std::nullopt);
+
+ // Takes a vector of `Val*`s and assumes they are all aligned to create the
+ // output tensorview, e.g., for BinaryOp. `vals` can contain scalars, e.g., when
+ // creating the output TensorView for `tv0 + scalar`. This is for convenience
+ // and scalars will be ignored.
+ std::vector<IterDomain*> newOutputDomain(const std::vector<Val*>& vals);
+
+ TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype);
+
+ std::vector<Val*> maybeBroadcast(const std::vector<Val*>& vals);
+
+ NVF_API Val* newValLike(Val* val, DataType dtype);
+
+ // Returns the minimum init value for reduction:
+ //   -inf for floating types;
+ //   lowest value for integer types;
+ //   false for bool.
+ Val* getMinimumValue(DataType v);
+
+ // Returns the maximum init value for reduction:
+ //   inf for floating types;
+ //   highest value for integer types;
+ //   true for bool.
+ Val* getMaximumValue(DataType v);
+
+ std::vector<unsigned int> canonicalizeAxes(
+     const std::vector<int64_t>& axes,
+     int64_t ndims);
+
+ } // namespace ops
+ } // namespace nvfuser
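These utilities are internal helpers used by the op implementations above rather than user-facing entry points. To make the intent of canonicalizeAxes concrete: it maps possibly negative, Python-style axis indices into non-negative positions in [0, ndims). The standalone sketch below is an editor's illustration of that contract, not nvfuser's implementation.

#include <cstdint>
#include <stdexcept>
#include <vector>

// Wrap each axis into [0, ndims): -1 becomes ndims - 1, and out-of-range
// axes are rejected.
std::vector<unsigned int> canonicalize_axes_sketch(
    const std::vector<int64_t>& axes,
    int64_t ndims) {
  std::vector<unsigned int> result;
  result.reserve(axes.size());
  for (int64_t axis : axes) {
    const int64_t wrapped = axis < 0 ? axis + ndims : axis;
    if (wrapped < 0 || wrapped >= ndims) {
      throw std::out_of_range("axis out of range");
    }
    result.push_back(static_cast<unsigned int>(wrapped));
  }
  return result;
}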