nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/device_lower/lower2device.h
@@ -0,0 +1,391 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+
+ #include <compute_at_map.h>
+ #include <device_lower/analysis/circular_buffer.h>
+ #include <device_lower/analysis/fused_reduction.h>
+ #include <device_lower/analysis/predicate_elimination.h>
+ #include <device_lower/analysis/sync_information.h>
+ #include <device_lower/analysis/tensor_memory.h>
+ #include <device_lower/analysis/thread_predicate.h>
+ #include <device_lower/analysis/tma.h>
+ #include <device_lower/analysis/trivial_broadcast.h>
+ #include <device_lower/id_model_options.h>
+ #include <device_lower/pass/allocation.h>
+ #include <device_lower/pass/circular_buffer.h>
+ #include <device_lower/pass/predicate.h>
+ #include <device_lower/pass/scalar_hoist.h>
+ #include <device_lower/pass/warp_reduce.h>
+ #include <exceptions.h>
+ #include <expr_simplifier.h>
+ #include <id_model/id_model.h>
+ #include <id_model/indexing.h>
+ #include <ir/all_nodes.h>
+ #include <kernel.h>
+ #include <kernel_ir.h>
+ #include <logical_domain_map.h>
+ #include <non_divisible_split.h>
+ #include <options.h>
+ #include <parallel_dimension_map.h>
+ #include <runtime/executor_params.h>
+ #include <vectorization_info.h>
+ #include <visibility.h>
+
+ #include <functional>
+ #include <memory>
+ #include <ostream>
+ #include <unordered_map>
+ #include <unordered_set>
+
+ namespace nvfuser {
+
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+ class GpuLower : public NonCopyable {
+   class KernelIrMapper;
+
+  public:
+   GpuLower() = delete;
+
+   using Pass = std::pair<
+       std::string, // name of the pass
+       std::function<std::vector<Expr*>(const std::vector<Expr*>&)>>;
+
+   // GpuLower lowers the provided fusion into a kernel which can be translated
+   // into CUDA code. index_type allows compiling the kernel with int32
+   // indexing instead of int64 for additional performance.
+   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+   NVF_API explicit GpuLower(
+       Fusion* fusion,
+       const CompileParams& cparams = CompileParams());
+
+   NVF_API kir::Kernel* kernel() const;
+
+   //! Returns the currently active lowering object.
+   //! It's an error if no lowering is in progress.
+   static GpuLower* current();
+
+   //! Query if lowering is in progress
+   static bool hasCurrent();
+
+   //! Actually run the lowering by executing the passes in the order given by
+   //! passes_
+   NVF_API kir::Kernel* run();
+
+   const PrimDataType& indexType() const {
+     return cparams_.index_type.value();
+   }
+
+   const auto& minDeviceVersion() const {
+     return min_device_version_;
+   }
+
+   const std::string& minDeviceVersionReason() const {
+     return min_device_version_reason_;
+   }
+
+   std::shared_ptr<const ConcretizedBroadcastDomains>
+   concretizedBroadcastDomains() {
+     return concretized_broadcast_domains_;
+   }
+
+   const ThreadPredicateMap& threadPredMap() const {
+     return thread_pred_map_;
+   }
+
+   // Returns non-const reference. Necessary to reset a predicate flag
+   // when a broadcast expression is fused into a reduction.
+   ThreadPredicateMap& threadPredMap() {
+     return thread_pred_map_;
+   }
+
+   std::shared_ptr<const ComputeAtMap> caMap() const {
+     return std::const_pointer_cast<const ComputeAtMap>(compute_at_map_);
+   }
+
+   bool hasIdModel() const {
+     return id_model_.get() != nullptr;
+   }
+
+   IdModel& idModel() {
+     NVF_ERROR(id_model_.get());
+     return *id_model_;
+   }
+
+   const IdModel& idModel() const {
+     NVF_ERROR(id_model_.get());
+     return *id_model_;
+   }
+
+   bool isTensorIndexerEnabled() const {
+     return tensor_indexer_.get() != nullptr;
+   }
+
+   TensorIndexer& tensorIndexer() {
+     NVF_ERROR(tensor_indexer_.get());
+     return *tensor_indexer_;
+   }
+
+   const TensorIndexer& tensorIndexer() const {
+     NVF_ERROR(tensor_indexer_.get());
+     return *tensor_indexer_;
+   }
+
+   const ParallelDimensionMap& parallelDimensionMap() const {
+     return parallel_dimension_map_;
+   }
+
+   ParallelDimensionMap& parallelDimensionMap() {
+     return parallel_dimension_map_;
+   }
+
+   PredicateElimination& predicateElimination() {
+     NVF_ERROR(pred_elimination_.get() != nullptr);
+     return *pred_elimination_;
+   }
+
+   const PredicateElimination& predicateElimination() const {
+     NVF_ERROR(pred_elimination_.get() != nullptr);
+     return *pred_elimination_;
+   }
+
+   LocalAllocationInfoMap& localAllocationInfoMap() {
+     return local_allocation_info_map_;
+   }
+
+   const WarpPaddedParallelInfo& getWarpPaddedParallelInfo() const {
+     return warp_pad_info_;
+   }
+
+   auto& nonDivisibleSplitInfo() {
+     return non_divisible_split_info_;
+   }
+
+   const auto& nonDivisibleSplitInfo() const {
+     return non_divisible_split_info_;
+   }
+
+   const auto& divisibleSplitSet() const {
+     return divisible_splits_;
+   }
+
+   CircularBufferInfo& circularBufferInfo() {
+     return circular_buffer_info_;
+   }
+
+   TmaCircularBufferInfo& tmaCircularBufferInfo() {
+     return tma_circular_buffer_info_;
+   }
+
+   CommonScalarMap& commonScalarMap() {
+     return common_scalar_map_;
+   }
+
+   const auto& vectorizedAccesses() const {
+     return vectorized_accesses_;
+   }
+
+   auto& vectorizedAccesses() {
+     return vectorized_accesses_;
+   }
+
+   const auto& vectorizedSetInfo() const {
+     return vectorized_set_info_;
+   }
+
+   auto& vectorizedSetInfo() {
+     return vectorized_set_info_;
+   }
+
+   FusedReductionInfo& fusedReductionInfo() {
+     return fused_reduction_info_;
+   }
+
+   std::shared_ptr<const SyncMap> syncMap() const {
+     return sync_map_;
+   }
+
+   kir::KernelPerformanceProfile& profile() {
+     return profile_;
+   }
+
+   std::unordered_map<const Expr*, TensorView*>& ldstMBarrierMap() {
+     return ldst_mbarrier_map_;
+   }
+
+   const std::unordered_map<const Expr*, TensorView*>& ldstMBarrierMap() const {
+     return ldst_mbarrier_map_;
+   }
+
+   bool isNvFuserZeroEnabled() {
+     if (isOptionDisabled(DisableOption::MagicZero)) {
+       return false;
+     }
+     return cparams_.enable_magic_zero;
+   }
+
+   // This is an interface to propagate information after expression
+   // replacement on the kernel IR. E.g.:
+   //   for ...
+   //     c = a + b (expr 0)
+   // after any pass that does replacement:
+   //   for ...
+   //     c1 = a1 + b1 (expr1)
+   // The previous analysis that was performed on expr0 might still
+   // be valid on expr1 but that info would be lost after replacement.
+   // This function provides an interface to manually update the info
+   // in any pass that performs replacement.
+   void propagateExprInfo(const Expr* old_expr, const Expr* new_expr);
+
+   std::vector<Val*>& allKnownVals() {
+     return all_known_vals_;
+   }
+
+   const std::vector<Val*>& allKnownVals() const {
+     return all_known_vals_;
+   }
+
+   const std::vector<Pass>& passes() const {
+     return passes_;
+   }
+
+   std::vector<Pass>& passes() {
+     return passes_;
+   }
+
+   std::unordered_map<TensorView*, const TMAInfo>& consumerToTMAInfo() {
+     return consumer_to_tma_info_;
+   }
+
+   const std::unordered_map<TensorView*, const TMAInfo>& consumerToTMAInfo()
+       const {
+     return consumer_to_tma_info_;
+   }
+
+   const TensorMemoryInfo& tmemInfo() const {
+     return tmem_info_;
+   }
+
+   TensorMemoryInfo& tmemInfo() {
+     return tmem_info_;
+   }
+
+   // Register a boolean Val as a predicate to validate at run time. Optional
+   // validation error messages can be given as args.
+   template <typename... Args>
+   void validate(Val* validation_condition, Args... args) {
+     auto sv = simplifyExpr(validation_condition);
+     if (sv->isTrue()) {
+       // If validation_condition is simplified to true, we know that the
+       // condition is always true regardless of the runtime values of the
+       // inputs. We can skip the validation. For example, we are not interested
+       // in validating that 3 < 4 or i % 8 < 8 every time we run the kernel.
+       return;
+     }
+     std::string message = to_str(args...);
+     NVF_ERROR(!sv->isFalse(), message);
+     validations_.emplace_back(sv, message);
+   }
+
+   const std::vector<std::pair<const Val*, std::string>>& validations() const {
+     return validations_;
+   }
+
+   std::vector<std::pair<const Val*, std::string>>& validations() {
+     return validations_;
+   }
+
+   // Get the index variable assigned for a given loop ID. Currently
+   // it's a wrapper around ComputeAtMap::getIndexVariable or
+   // IdModel::getLoopIndexVariable if IdModelEnableOption::Loop is
+   // enabled.
+   Val* getLoopIndexVariable(
+       IterDomain* id,
+       CircularBufferLoopStage stage =
+           CircularBufferLoopStage::NotApplicable) const;
+
+   const IdModelOptions idModelOptions() const {
+     return id_model_options_;
+   }
+
+  private:
+   void analysis(Fusion* fusion);
+
+   // Goes through the parallelized iterdomains of the used TVs and finds
+   // the parallel dimensions that need to be padded to a multiple of the
+   // warp size.
+   void collectPaddedParallelDims();
+
+   bool resolveComputeWith(Fusion* fusion);
+
+  private:
+   // Lowered Kernel IR
+   std::unique_ptr<kir::Kernel> kernel_;
+
+   // Passes to lower the kernel, in order
+   std::vector<Pass> passes_;
+
+   // Some stateful information during lowering
+   // TODO: A lot of this information uses a define-the-class-then-call-build
+   // pattern. It would be safer to wrap all of these in unique pointers and
+   // remove the build interface and default constructor. That way they
+   // couldn't be accessed without being initialized.
+   std::pair<int64_t, int64_t> min_device_version_;
+   std::string min_device_version_reason_;
+   std::shared_ptr<const ConcretizedBroadcastDomains>
+       concretized_broadcast_domains_;
+   ThreadPredicateMap thread_pred_map_;
+   std::unique_ptr<PredicateElimination> pred_elimination_;
+   std::shared_ptr<ComputeAtMap> compute_at_map_;
+   LocalAllocationInfoMap local_allocation_info_map_;
+   WarpPaddedParallelInfo warp_pad_info_;
+   ParallelDimensionMap parallel_dimension_map_;
+   NonDivisibleSplitInfo non_divisible_split_info_;
+   CircularBufferInfo circular_buffer_info_;
+   TmaCircularBufferInfo tma_circular_buffer_info_;
+   CommonScalarMap common_scalar_map_;
+   FusedReductionInfo fused_reduction_info_;
+   std::shared_ptr<const SyncMap> sync_map_;
+   kir::KernelPerformanceProfile profile_;
+   std::unordered_set<Split*> divisible_splits_;
+   CompileParams cparams_;
+   std::unique_ptr<IdModel> id_model_;
+   std::unique_ptr<TensorIndexer> tensor_indexer_;
+   std::unordered_map<TensorView*, const TMAInfo> consumer_to_tma_info_;
+
+   // Track which tensor views are inputs or outputs of a vectorized operation
+   // and their maximum vectorized access size
+   // std::unordered_map<TensorView*, VectorizationInfo> vectorized_accesses_;
+   std::unordered_map<TensorView*, int64_t> vectorized_accesses_;
+   // Info on each vectorized set op
+   std::vector<VectorizedSetInfo> vectorized_set_info_;
+
+   // All vals that are known to the kernel, including fusion inputs and
+   // precomputed values
+   std::vector<Val*> all_known_vals_;
+
+   // Keep track of the mbarrier used for each load/store operation
+   std::unordered_map<const Expr*, TensorView*> ldst_mbarrier_map_;
+
+   // Information about tensor memory usage
+   TensorMemoryInfo tmem_info_;
+
+   //! Keep track of validations needed at runtime. For example, a pair of
+   //! "extent mod split_factor == 0" and an error message for divisibility
+   //! check for vectorization.
+   std::vector<std::pair<const Val*, std::string>> validations_;
+
+   Fusion* fusion_ = nullptr;
+
+   // A temporary option set to selectively enable IdModel usage
+   IdModelOptions id_model_options_;
+ };
+
+ } // namespace nvfuser
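
For orientation, a minimal usage sketch (not part of the wheel contents; the helper name lowerToKernel is illustrative). Per the comments in the header above, a GpuLower is constructed from a Fusion, optionally with CompileParams, and run() executes the registered lowering passes in order and returns the Kernel IR, which is also exposed via kernel().

    // Sketch only: assumes a scheduled Fusion* is produced elsewhere by the
    // nvfuser frontend/scheduler; error handling is omitted.
    #include <device_lower/lower2device.h>

    namespace nvfuser {

    kir::Kernel* lowerToKernel(Fusion* fusion) {
      GpuLower lower(fusion);            // default CompileParams
      kir::Kernel* kernel = lower.run(); // runs passes() in order
      return kernel;                     // also retrievable via lower.kernel()
    }

    } // namespace nvfuser
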
nvfuser/include/nvfuser/device_lower/pass/alias_memory.h
@@ -0,0 +1,37 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+
+ #include <dispatch.h>
+ #include <ir/all_nodes.h>
+
+ #include <vector>
+
+ namespace nvfuser {
+
+ //! Reuse Allocation nodes via pointer aliasing
+ //!
+ //! First pass finds candidate TensorViews
+ //! A candidate TensorView is anything in shared memory OR
+ //! in local memory with a static size larger than register_size_threshold
+ //!
+ //! Second pass finds appropriate input Allocate Node
+ //! among candidate TensorViews
+ //!
+ //! Alias Criteria:
+ //!  If input is a candidate TensorView,
+ //!  input allocation has the same size as output allocation,
+ //!  thread bindings match,
+ //!  is not used after this op:
+ //!  then alias output Allocate to input Allocate.
+ //!
+ std::vector<Expr*> reuseMemoryAllocations(const std::vector<Expr*>& exprs);
+
+ } // namespace nvfuser
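
The alias criteria listed above can be read as a single predicate. The following is purely an illustrative restatement; every helper name here is hypothetical and not an API of this header:

    // Hypothetical sketch of the alias criteria documented above; the helper
    // functions stand in for the analyses the pass performs internally.
    bool canAliasInputAllocation(const TensorView* in, const TensorView* out) {
      return isReuseCandidate(in) &&                   // smem, or large static local buffer
          allocationSize(in) == allocationSize(out) && // same allocation size
          threadBindingsMatch(in, out) &&              // same thread bindings
          isLastUse(in);                               // input not used after this op
    }

When all four conditions hold, the pass points the output's Allocate at the input's Allocate instead of creating a new buffer.
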
nvfuser/include/nvfuser/device_lower/pass/allocation.h
@@ -0,0 +1,32 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+
+ #include <ir/all_nodes.h>
+ #include <kernel_ir.h>
+
+ #include <vector>
+
+ namespace nvfuser {
+
+ //! Buffer allocation information to store in GPU lower to avoid
+ //! logic duplication
+ struct LocalAllocationInfo {
+   kir::Allocate* alloc_expr = nullptr;
+   std::vector<IterDomain*> alloc_domains;
+ };
+
+ using LocalAllocationInfoMap = std::
+     unordered_map<const kir::Allocate*, std::unique_ptr<LocalAllocationInfo>>;
+
+ //! Insert buffer allocations
+ std::vector<Expr*> insertAllocations(const std::vector<Expr*>& exprs);
+
+ } // namespace nvfuser
nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h
@@ -0,0 +1,191 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+
+ #include <ir/all_nodes.h>
+ #include <kernel_ir.h>
+ #include <kernel_ir_dispatch.h>
+
+ // Double buffering a tensor doubles its allocation size and uses two
+ // buffers to facilitate computation and memory access
+ // overlapping. The basic form of the code looks as follows:
+ //
+ // Before:
+ //   for i
+ //     x[S]; // allocation
+ //     for j:
+ //       x[j] = y[i, j]
+ //     for j:
+ //       ... = x[j]
+ //
+ // After:
+ //   X[S * 2]; // allocation
+ //   for i in 0 to 1: // Prologue
+ //     for j:
+ //       x[j] = y[i, j]
+ //
+ //   for i in 0 to N-1: // Main
+ //     for j:
+ //       x[j + (1 - i % 2) * S] = y[i + 1, j]
+ //     for j:
+ //       ... = x[j + (i % 2) * S]
+ //
+ //   for i in N-1 to N: // Epilogue
+ //     for j:
+ //       ... = x[j + (i % 2) * S]
+ //
+ // Here, S is the original size of tensor x.
+ //
+ // The i loop is the double buffer loop of tensor x, where double
+ // buffering is applied to the tensor. The first step of lowering is
+ // to find the double buffering axis for each double buffered
+ // tensor. It must not be parallelized as it isn't possible to double
+ // buffer parallelized loops. Also, an unrolled axis expands the
+ // allocation and is intended to make the loop completely unrolled,
+ // which also conflicts with double buffering. So, basically, the double
+ // buffering axis is the inner-most axis within the axes left
+ // of the CA position. However, when it is parallelized or unrolled, a
+ // further left axis is picked.
+ //
+ // Once the double buffer axis is determined, the main task is to
+ // replicate the corresponding double buffer loop as illustrated
+ // above. The Prologue loop is to just fetch the first element to
+ // populate the buffer. The main loop is mostly the same as the
+ // original loop, except for the indexing change to switch the two
+ // buffers. When used as a consumer, an offset of (1 - i % 2) * S is
+ // added, whereas (i % 2) * S is added when used as a producer. Here,
+ // i is the index of the double buffer loop. The Epilogue loop is just
+ // for the last iteration of the loop. Since the main loop reads one
+ // element ahead of the producer of the double buffered tensor, it
+ // would require an additional guard to prevent buffer overruns with
+ // the producer if the main loop were also used for the last
+ // iteration. However, the value loaded by the invalid load would not
+ // be used, so instead of adding the additional predicate, the Epilogue
+ // loop is replicated from the original loop, except for the load
+ // expression since it's not used. Note that this overrun does not
+ // happen when the producer is on gmem, so in that case, this
+ // additional replication is not done.
+ //
+ // When creating those three types of loops, additional care must be
+ // taken when multiple tensors are double buffered. When multiple
+ // tensors use the same loop as their double buffer loop, one pass of
+ // replication takes care of them at once, meaning the same Prologue,
+ // Main, Epilogue loops are used for the multiple tensors.
+ //
+ // Other tasks to do for a double buffer tensor include:
+ // - Move allocation to outside of the double buffer loop
+ // - Double the allocation size
+ // - Omit the RAW sync in the Main and Epilogue loops
+
+ // [Circular buffer] A generalization of double buffering.
+ // On sm80+ hardware there is asynchronous copy infrastructure that
+ // motivates a circular buffering generalization of double buffering.
+ // Almost all analyses previously done for double buffering are exactly
+ // the same with circular buffering, except for the introduction of a
+ // new concept: `stage depth`.
+ //
+ // The `stage depth` is defined as the multiplier of extra buffering
+ // space used. In the case of double buffering, the stage depth would
+ // be 2.
+ //
+ // A circular buffered loop structure would look as follows, which
+ // exactly parallels the double buffered loop structure, since
+ // it is an exact generalization for the same purpose.
+ //
+ // Here S is the original allocation size as above,
+ // D is the stage depth. With D=2, the below loop structure becomes
+ // exactly the same as the case in double buffering.
+ //
+ //   allocate X[S*D] // allocation
+ //   for i in 0..D-1: // prolog
+ //     for j in ...
+ //       if pred:
+ //         x[i*S+j] = y[i, j];
+ //
+ //   for i in 0..N: // main loop
+ //     for j in ...
+ //       if pred:
+ //         x[((i+D-1)%D)*S+j] = y[i+D-1, j];
+ //     for j in ...
+ //       .. = x[(i%D)*S+j]
+ //
+ // (Epilog omitted since this only makes sense when using
+ // cp.async, where the producer will be in global mem and the consumer will
+ // be in shared mem).
+ //
+ // The profitability of this optimization comes from extra tolerance
+ // of global memory pipeline latency, as on the expression `.. = x[(i%D)*S+j]`
+ // we only need to make sure the data for the current iteration is
+ // completed while the remaining D-2 load iterations could still be in progress
+ // and overlap with the computes of the current loop.
+ //
+ // To express this pattern on sm80+ hardware we can group the loads
+ // in each iteration of the circular buffered loop as one "transaction",
+ // and specify how many transactions we want to have completed when
+ // we insert the async barriers.
+ //
+ //   allocate X[S*D] // allocation
+ //   for i in 0..D-1: // prolog
+ //     for j in ...
+ //       if pred:
+ //         x[i*S+j] = y[i, j];
+ //     cp.async.commit; // mark the transaction boundary
+ //
+ //   # At this point we have D-1 transactions on the fly,
+ //   #  and for the first iteration of the main loop we need
+ //   #  one transaction completed, so we leave D-2 transactions
+ //   #  on the fly, which would be the input to the barrier instruction.
+ //
+ //   cp.async.wait D-2 // ensure all but the last D-2 transactions complete.
+ //
+ //   for i in 0..N: // main loop
+ //     # At this point we always have D-2 transactions on the fly
+ //     #  and one completed.
+ //     for j in ...
+ //       if pred:
+ //         x[((i+D-1)%D)*S+j] = y[i+D-1, j];
+ //     for j in ...
+ //       .. = x[(i%D)*S+j]
+ //     cp.async.commit; // mark the transaction boundary for the
+ //                      //  load issued in this iteration.
+ //     # At this point we have D-1 transactions on the fly,
+ //     #  and none completed.
+ //     cp.async.wait D-2; // Ensure all but the last D-2 transactions complete.
+ //     __syncthreads(); // Needed because each thread only ensures completion
+ //                      // of its own async copies, so a block-wide sync is
+ //                      // required here to ensure completion of the
+ //                      // whole tile.
+
+ namespace nvfuser {
+
+ class TmaCircularBufferInfo {
+  public:
+   // Map cpAsyncBulk to its tensor index
+   void recordTensorIndex(const Expr* expr, kir::TensorIndex* index);
+
+   // Check if tensor index exists for expression
+   bool existsTensorIndex(const Expr* expr) const;
+
+   // Get tensor index for expression
+   kir::TensorIndex* getTensorIndex(const Expr* expr);
+
+  private:
+   // Track mbarrier used for cpAsyncBulk load operation. Required by indexing
+   // pass.
+   std::unordered_map<const Expr*, kir::TensorIndex*> ldst_mbarrier_index_map_;
+ };
+
+ class CircularBufferPass {
+  public:
+   //! Apply circular buffering transformations
+   static std::vector<Expr*> run(const std::vector<Expr*>& exprs);
+ };
+
+ } // namespace nvfuser
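
The stage-offset arithmetic in the comment above can be stated compactly. This is only a restatement of the indexing shown in that pseudocode (S = original allocation size, D = stage depth), not code shipped in the wheel:

    // Illustrative only: circular-buffer stage offsets from the comment above.
    // With D == 2 these reduce to the double-buffering offsets
    // (1 - i % 2) * S (stage being filled) and (i % 2) * S (stage being read).
    int64_t fillStageOffset(int64_t i, int64_t S, int64_t D) {
      return ((i + D - 1) % D) * S; // where iteration i's async copy writes
    }
    int64_t readStageOffset(int64_t i, int64_t S, int64_t D) {
      return (i % D) * S; // where iteration i's compute reads
    }
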
nvfuser/include/nvfuser/device_lower/pass/expr_sort.h
@@ -0,0 +1,17 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <ir/base_nodes.h>
+
+ namespace nvfuser {
+
+ std::vector<Expr*> reorderExprsForComputeAt();
+
+ } // namespace nvfuser
nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h
@@ -0,0 +1,21 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <dispatch.h>
+ #include <fusion.h>
+ #include <ir/all_nodes.h>
+
+ #include <vector>
+
+ namespace nvfuser {
+
+ // Replaces Transpose, Shift, Gather, and View Ops with LoadStoreOps
+ std::vector<Expr*> loadStoreOpInserter(const std::vector<Expr*>& exprs);
+
+ } // namespace nvfuser