nvfuser_cu121_torch25-0.2.25.dev20250201-cp310-cp310-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/python_frontend/translation_utils.h
@@ -0,0 +1,308 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+ #include <ir/all_nodes.h>
+ #include <ops/all_ops.h>
+
+ namespace nvfuser::python_frontend {
+
+ // Get std::function for UnaryOp
+ template <typename ResultType, typename... ArgTypes>
+ std::function<ResultType(ArgTypes...)> getFunction(const UnaryOp* uop) {
+   auto wrap_function = [](ResultType (*fn)(ArgTypes...)) { return fn; };
+
+   switch (uop->getUnaryOpType()) {
+     case UnaryOpType::Abs:
+       return wrap_function(abs);
+     case UnaryOpType::Acos:
+       return wrap_function(acos);
+     case UnaryOpType::Acosh:
+       return wrap_function(acosh);
+     case UnaryOpType::Asin:
+       return wrap_function(asin);
+     case UnaryOpType::Asinh:
+       return wrap_function(asinh);
+     case UnaryOpType::Atan:
+       return wrap_function(atan);
+     case UnaryOpType::Atanh:
+       return wrap_function(atanh);
+     case UnaryOpType::Ceil:
+       return wrap_function(ceil);
+     case UnaryOpType::Cos:
+       return wrap_function(cos);
+     case UnaryOpType::Cosh:
+       return wrap_function(cosh);
+     case UnaryOpType::Exp:
+       return wrap_function(exp);
+     case UnaryOpType::Exp2:
+       return wrap_function(exp2);
+     case UnaryOpType::Expm1:
+       return wrap_function(expm1);
+     case UnaryOpType::Erf:
+       return wrap_function(erf);
+     case UnaryOpType::Erfc:
+       return wrap_function(erfc);
+     case UnaryOpType::Erfinv:
+       return wrap_function(erfinv);
+     case UnaryOpType::Erfcinv:
+       return wrap_function(erfcinv);
+     case UnaryOpType::Floor:
+       return wrap_function(floor);
+     case UnaryOpType::Frac:
+       return wrap_function(frac);
+     case UnaryOpType::Lgamma:
+       return wrap_function(lgamma);
+     case UnaryOpType::Log:
+       return wrap_function(log);
+     case UnaryOpType::Log10:
+       return wrap_function(log10);
+     case UnaryOpType::Log1p:
+       return wrap_function(log1p);
+     case UnaryOpType::Log2:
+       return wrap_function(log2);
+     case UnaryOpType::Neg:
+       return wrap_function(neg);
+     case UnaryOpType::LogicalNot:
+       return wrap_function(logical_not);
+     case UnaryOpType::BitwiseNot:
+       return wrap_function(bitwise_not);
+     case UnaryOpType::Reciprocal:
+       return wrap_function(reciprocal);
+     case UnaryOpType::Relu:
+       return wrap_function(relu);
+     case UnaryOpType::Rsqrt:
+       return wrap_function(rsqrt);
+     case UnaryOpType::Round:
+       return wrap_function(round);
+     case UnaryOpType::Sigmoid:
+       return wrap_function(sigmoid);
+     case UnaryOpType::Signbit:
+       return wrap_function(signbit);
+     case UnaryOpType::Silu:
+       return wrap_function(silu);
+     case UnaryOpType::Sin:
+       return wrap_function(sin);
+     case UnaryOpType::Sinh:
+       return wrap_function(sinh);
+     case UnaryOpType::Sqrt:
+       return wrap_function(sqrt);
+     case UnaryOpType::Tan:
+       return wrap_function(tan);
+     case UnaryOpType::Tanh:
+       return wrap_function(tanh);
+     case UnaryOpType::Trunc:
+       return wrap_function(trunc);
+     case UnaryOpType::IsFinite:
+       return wrap_function(isfinite);
+     case UnaryOpType::IsInf:
+       return wrap_function(isinf);
+     case UnaryOpType::IsNan:
+       return wrap_function(isnan);
+     case UnaryOpType::IsNegInf:
+       return wrap_function(isneginf);
+     case UnaryOpType::IsPosInf:
+       return wrap_function(isposinf);
+     case UnaryOpType::IsReal:
+       return wrap_function(isreal);
+     case UnaryOpType::Real:
+       return wrap_function(real);
+     case UnaryOpType::Imag:
+       return wrap_function(imag);
+     default:
+       NVF_CHECK(
+           false,
+           "Unexpected operator type: ",
+           uop->getUnaryOpType(),
+           " in ",
+           uop->toString());
+   }
+ }
+
+ // Get std::function for BinaryOp
+ template <typename ResultType, typename... ArgTypes>
+ std::function<ResultType(ArgTypes...)> getFunction(const BinaryOp* bop) {
+   auto wrap_function = [](ResultType (*fn)(ArgTypes...)) { return fn; };
+
+   switch (bop->getBinaryOpType()) {
+     case BinaryOpType::Add:
+       return wrap_function(add);
+       break;
+     case BinaryOpType::Atan2:
+       return wrap_function(atan2);
+       break;
+     case BinaryOpType::Div:
+       return wrap_function(div);
+       break;
+     case BinaryOpType::Fmod:
+       return wrap_function(fmod);
+       break;
+     case BinaryOpType::Mul:
+       return wrap_function(mul);
+       break;
+     case BinaryOpType::Nextafter:
+       return wrap_function(nextafter);
+       break;
+     case BinaryOpType::Pow:
+       return wrap_function(pow);
+       break;
+     case BinaryOpType::Remainder:
+       return wrap_function(remainder);
+       break;
+     case BinaryOpType::Sub:
+       return wrap_function(sub);
+       break;
+     case BinaryOpType::Mod:
+       return wrap_function(mod);
+       break;
+     case BinaryOpType::Eq:
+       return wrap_function(eq);
+       break;
+     case BinaryOpType::NE:
+       return wrap_function(ne);
+       break;
+     case BinaryOpType::GT:
+       return wrap_function(gt);
+       break;
+     case BinaryOpType::GE:
+       return wrap_function(ge);
+       break;
+     case BinaryOpType::LT:
+       return wrap_function(lt);
+       break;
+     case BinaryOpType::LE:
+       return wrap_function(le);
+       break;
+     case BinaryOpType::BitwiseAnd:
+       return wrap_function(bitwise_and);
+       break;
+     case BinaryOpType::BitwiseOr:
+       return wrap_function(bitwise_or);
+       break;
+     case BinaryOpType::BitwiseXor:
+       return wrap_function(bitwise_xor);
+       break;
+     case BinaryOpType::LogicalAnd:
+       return wrap_function(logical_and);
+       break;
+     case BinaryOpType::LogicalOr:
+       return wrap_function(logical_or);
+       break;
+     case BinaryOpType::Lshift:
+       return wrap_function(bitwise_left_shift);
+       break;
+     case BinaryOpType::Rshift:
+       return wrap_function(bitwise_right_shift);
+       break;
+     case BinaryOpType::Gcd:
+       return wrap_function(gcd);
+       break;
+     case BinaryOpType::Min:
+       return wrap_function(minimum);
+       break;
+     case BinaryOpType::Max:
+       return wrap_function(maximum);
+       break;
+     case BinaryOpType::CeilDiv:
+       return wrap_function(ceilDiv);
+       break;
+     default:
+       NVF_CHECK(
+           false,
+           "Unexpected operator type: ",
+           bop->getBinaryOpType(),
+           " in ",
+           bop->toString());
+   }
+ }
+
+ // Get std::function for TernaryOp
+ template <typename ResultType, typename... ArgTypes>
+ std::function<ResultType(ArgTypes...)> getFunction(const TernaryOp* top) {
+   auto wrap_function = [](ResultType (*fn)(ArgTypes...)) { return fn; };
+
+   // clamp and threshold define a subset of TernaryOp configurations, so they
+   // are handled in a separate template specialization.
+   switch (top->getTernaryOpType()) {
+     case TernaryOpType::Lerp:
+       return wrap_function(lerp);
+       break;
+     case TernaryOpType::Where:
+       return wrap_function(where);
+       break;
+     case TernaryOpType::Threshold:
+     case TernaryOpType::Clamp:
+       NVF_CHECK(
+           false,
+           "Invalid function arguments for operator type: ",
+           top->getTernaryOpType(),
+           " in ",
+           top->toString());
+     default:
+       NVF_CHECK(
+           false,
+           "Unexpected operator type: ",
+           top->getTernaryOpType(),
+           " in ",
+           top->toString());
+   }
+ }
+
+ // Fully specialized template functions to create std::function for TernaryOp.
+ template <>
+ std::function<TensorView*(TensorView*, Val*, Val*)> getFunction<
+     TensorView*,
+     TensorView*,
+     Val*,
+     Val*>(const TernaryOp* top);
+
+ template <>
+ std::function<Val*(Val*, Val*, Val*)> getFunction<Val*, Val*, Val*, Val*>(
+     const TernaryOp* top);
+
+ // Get std::function for ReductionOp
+ template <typename ResultType, typename... ArgTypes>
+ std::function<ResultType(ArgTypes...)> getFunction(const ReductionOp* rop) {
+   switch (rop->getReductionOpType()) {
+     case BinaryOpType::Add:
+       return sum;
+       break;
+     case BinaryOpType::Mul:
+       return prod;
+       break;
+     case BinaryOpType::Max:
+       return max;
+       break;
+     case BinaryOpType::Min:
+       return min;
+       break;
+     default:
+       NVF_CHECK(
+           false,
+           "Unexpected reduction operator type: ",
+           rop->getReductionOpType(),
+           " in ",
+           rop->toString());
+   }
+ }
+
+ // Get string name for UnaryOp
+ std::string getString(const UnaryOp* uop);
+
+ // Get string name for BinaryOp
+ std::string getString(const BinaryOp* bop);
+
+ // Get string name for TernaryOp
+ std::string getString(const TernaryOp* top);
+
+ // Get string name for ReductionOp
+ std::string getString(const ReductionOp* rop);
+
+ // Get serde record type for ReductionOp
+ serde::RecordType getSerdeType(const ReductionOp* rop);
+
+ } // namespace nvfuser::python_frontend
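
The `getFunction` templates above map an IR node's op-type enum onto the matching free function from nvfuser's C++ op library, wrapped in a `std::function` whose signature is chosen by the caller's template arguments. A minimal usage sketch, assuming a valid fusion context and a `UnaryOp` whose operands are `Val*`; the `applyUnary` helper here is hypothetical, for illustration only:

    // Sketch only: resolve a UnaryOp node to a callable and re-apply it.
    // getFunction<Val*, Val*> yields std::function<Val*(Val*)>; instantiating
    // with TensorView* works the same way, since the ops pulled in through
    // <ops/all_ops.h> are overloaded for both Val* and TensorView*.
    using namespace nvfuser;
    using namespace nvfuser::python_frontend;

    Val* applyUnary(const UnaryOp* uop, Val* arg) {
      std::function<Val*(Val*)> fn = getFunction<Val*, Val*>(uop);
      return fn(arg); // e.g. dispatches to abs(Val*) for UnaryOpType::Abs
    }
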
nvfuser/include/nvfuser/scheduler/all_schedulers.h
@@ -0,0 +1,17 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+ #include <scheduler/expr_eval_sched.h>
+ #include <scheduler/matmul.h>
+ #include <scheduler/no_op.h>
+ #include <scheduler/normalization_inner.h>
+ #include <scheduler/normalization_inner_outer.h>
+ #include <scheduler/normalization_outer.h>
+ #include <scheduler/pointwise.h>
+ #include <scheduler/reduction.h>
+ #include <scheduler/transpose.h>
nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h
@@ -0,0 +1,206 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <ATen/cuda/CUDAContext.h>
+ #include <scheduler/multi_matmul.h>
+
+ namespace nvfuser {
+
+ // A scheduled fusion can contain multiple MmaOps. Each one outputs a
+ // TensorView* which we call an mma_result. Each MmaOp also has two input
+ // TensorViews, which we call "ab" and "bb" since they are the immediate A and
+ // B operands and contain broadcast dimensions. Again, there can be multiple
+ // abs and multiple bbs in one fusion. These TensorViews are loaded from
+ // global memory tensors that we call "a" and "b" into shared memory tensors
+ // called acw_smem and bcw_smem. They are loaded from shared memory to
+ // register buffers we call "acr" and "bcr" ("cr" meaning "cache read" in this
+ // context).
+ //
+ // Putting this all together, we have the following order for a simple matmul:
+ //
+ //   a -> acw_smem -> acr -> ... -> ab
+ //                                     \ .
+ //                                      mma_result -> ... -> dc -> d
+ //                                     /
+ //   b -> bcw_smem -> bcr -> ... -> bb
+ //
+ // The ... indicate that there might be other tensors involved in a prologue
+ // or epilogue section at that location.
+ //
+ // In the next example, there are two matmuls, both using the same "a"
+ // operand:
+ //
+ //   b1 -> bcw_smem1 -> bcr1 -> ... -> bb1
+ //                                         \ .
+ //                                          mma_result1
+ //                                         /            \ .
+ //   a -> acw_smem -> acr -> ... -> ab                    ... -> dc -> d
+ //                                         \            /
+ //                                          mma_result2
+ //                                         /
+ //   b2 -> bcw_smem2 -> bcr2 -> ... -> bb2
+ //
+ // Note that there can be more than one output d, and each one will have its
+ // own register cache dc.
+ //
+ // Split-K and smem epilogue unswizzling add two additional tensors for each
+ // mma in the fusion: splitk_sum and smem_epilogue.
+ //
+ //   // No split-K, no smem epilogue unswizzling:
+ //   mma_result -> ... -> dc -> d
+ //   // split-K, no smem epilogue unswizzling:
+ //   mma_result -> splitk_sum -> ... -> dc -> d
+ //   // smem epilogue unswizzling, no split-K:
+ //   mma_result -> smem_epilogue -> ... -> dc -> d
+ //   // split-K and smem epilogue unswizzling:
+ //   mma_result -> smem_epilogue -> splitk_sum -> ... -> dc -> d
+ //
+ // These additional tensors are added to each mma_result in the fusion.
+ //
+ // Each of the named tensors above is scheduled differently. We schedule them
+ // by building AbstractTensors for each tensor category; these are held in
+ // AmpereMultipleMatmulScheduler::schedules_.
+ // TODO: Inherit from SchedulerEntry
+ class AmpereMultipleMatmulScheduler : public MultipleMatmulScheduler {
+  public:
+   AmpereMultipleMatmulScheduler(Fusion* fusion, const MatmulParams* params)
+       : MultipleMatmulScheduler(fusion, params) {
+     const auto device_prop = at::cuda::getCurrentDeviceProperties();
+     const int cc = device_prop->major * 10 + device_prop->minor;
+     NVF_ERROR(
+         cc >= 75 && cc < 90,
+         "This matmul scheduler is restricted to Ampere and Turing.");
+   }
+
+   void run() final;
+
+  private:
+   void cacheInputsAndOutputs();
+
+   // The current tensor naming convention is included here for reference. It
+   // is temporary and will change over time; eventually, the whole body of
+   // this function will become a set of utility functions for the different
+   // sections of matmul (fusion) kernels, each with its own build-out to do.
+   //
+   // The current naming convention is based on the following formula:
+   //
+   //   d = alpha * (a x b) + beta * c
+   //
+   // and is defined in the following way:
+   //
+   //   operands assumed in global memory: a, b, c
+   //
+   //   registers staging the global load: ar, br (short for a/b read)
+   //
+   //   shared mem cache of operands: acw_smem, bcw_smem (short for a/b
+   //   cache_write smem)
+   //
+   //   registers at shared memory load output: acr, bcr (short for a/b cache
+   //   read)
+   //
+   //   register tensor input to the actual mma op: ab, bb (short for a/b
+   //   broadcasted)
+   //
+   //   accumulator register: mma_result
+   //     - mma_result is the MmaOp output if there is an epilogue
+   //     - mma_result is dc (short for d cache) if there is no epilogue
+   //
+   //   result in global memory: d
+
+   // Currently, only a, b, c and d are supported as fusion inputs/outputs,
+   // i.e., there is no prologue fusion yet.
+   void defineOperandCaches();
+
+   void cacheOperandsToSmem(
+       const std::vector<TensorView*>& operands,
+       std::vector<TensorView*>& smem_operands,
+       int64_t vec_size);
+
+   // We add two LoadStore operators to the inputs of our fusions. The first
+   // one is for a read from global memory and the second one (below) is for a
+   // cache read. As an optimization, we avoid adding an operator if there's an
+   // existing LoadStoreOp present. Please note that for the second LoadStore
+   // we don't propagate the allocation domain, since the scheduler sets the
+   // allocation domain in the registers.
+   void cacheOperandsToRegisters(
+       const std::vector<TensorView*>& tv_smems,
+       std::vector<TensorView*>& tv_rs);
+
+   //! Swizzle the M and N outer dimensions after makeTile has been called.
+   //! This updates outer_dim_roles if we introduce a new dimension, which can
+   //! happen if tv is missing a merged axis, in which case we skip merging
+   //! after the split. This is analogous to forwarding during transform
+   //! propagation.
+   void swizzleBlockTiles(
+       TensorView* tv,
+       std::vector<MatmulDimRole>& outer_dim_roles);
+
+   //! This calls orig->cacheAfter() and also updates the broadcast graph to
+   //! reflect the new IterDomain mappings
+   TensorView* cacheAfter(
+       TensorView* orig,
+       LoadStoreOpType op_type = LoadStoreOpType::Set,
+       CacheOp cache_op = CacheOp::AllLevels,
+       bool propagate_allocation_domain = false);
+
+   //! Do block tiling for a collection of TensorViews. The tensors should be
+   //! unscheduled before this method is called.
+   //!   1) Axes are ordered according to canonicalDimOrdering, and then axes
+   //!      with the same role are merged.
+   //!   2) After that, we perform splits according to
+   //!      params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki].
+   //!   3) Depending on the value of params_->grid_swizzle_factor, if the TV
+   //!      has both M and N dimensions, we perform a 2D swizzle of the outer
+   //!      dimensions Mo and No.
+   //!   4) Finally, we do a split-K split if splitk_factor is not 1.
+   std::vector<std::vector<MatmulDimRole>> blockTileTensors(
+       const std::vector<TensorView*>& tvs);
+
+   //! Schedule the loads of all operands from global memory to shared memory.
+   //! Starting from the basic tiled schedule, we swizzle the operand memory.
+   //! Note that the cache op and LoadStoreOpType are already set during
+   //! defineOperandCaches().
+   void scheduleOperandSmemStores();
+
+   // MmaOperand contains only A and B. If tvs are outputs (i.e. not
+   // operands), then operand_type should be std::nullopt.
+   void scheduleMmaOperands(
+       std::vector<TensorView*>& tvs,
+       const std::optional<MmaOperand> operand_type);
+
+   void scheduleMmaResults();
+
+   void schedulePrologues();
+
+   void scheduleOutputTensor(TensorView* c);
+
+   void scheduleEpilogue();
+
+   //! Propagates transformations from the fusion outputs to fusion input tvs
+   //! that are producers in the epilogue. The propagation targets input tvs
+   //! that are not assigned core roles, i.e. that are not MMA inputs.
+   void scheduleFusionInputsForEpilogue();
+
+   void scheduleSplitKSum();
+
+   void setUpInlining();
+
+   // NOTE: this should be called after the acw_smem, acr, ..., ab, and
+   // mma_result transforms have been applied and inlining has been set up.
+   void setUpCircularBuffering();
+
+  private:
+   std::vector<std::pair<TensorView*, TensorView*>> cached_outputs_;
+
+   std::vector<ValGroup> canonical_dim_ordering_;
+
+   std::vector<TensorView*> acw_smems_, bcw_smems_, acrs_, bcrs_, abs_, bbs_,
+       splitk_sums_, smem_epilogues_;
+ };
+
+ } // namespace nvfuser
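
The operand staging chain described in the comment above (a -> acw_smem -> acr) can be pictured in terms of the public TensorView caching API. A rough sketch under stated assumptions: `a` is a fusion input, the scheduler's private `cacheAfter` wrapper is approximated by plain `TensorView::cacheAfter` (so the broadcast-graph update it performs is omitted), and vectorization and swizzling are left out:

    // Sketch only: stage operand `a` through shared memory into registers.
    TensorView* acw_smem = a->cacheAfter();   // global -> smem staging copy
    acw_smem->setMemoryType(MemoryType::Shared);
    TensorView* acr = acw_smem->cacheAfter(); // smem -> register cache read
    // The real scheduler also selects a LoadStoreOpType for the smem copy
    // (e.g. cp.async on Ampere) and tiles/swizzles these tensors as the
    // class comment describes.
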
nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h
@@ -0,0 +1,19 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <fusion.h>
+ #include <visibility.h>
+
+ namespace nvfuser {
+
+ // Visits all global-to-local vector loads in `fusion` and refines their cache
+ // policies.
+ NVF_API void refineCachePolicy(Fusion* fusion);
+
+ } // namespace nvfuser
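
As a usage note, a pass with this signature would typically be applied to a fully constructed fusion; a minimal sketch, assuming the fusion body has already been defined (the setup shown is illustrative, not part of this header):

    // Sketch only: apply the cache-policy refinement pass to a fusion.
    Fusion fusion;
    FusionGuard fg(&fusion);
    // ... define fusion inputs, ops, and outputs here ...
    refineCachePolicy(&fusion); // refine global-to-local vector load policies
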