nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h
@@ -0,0 +1,28 @@
+ // clang-format off
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+ // clang-format on
+ #pragma once
+
+ #include <compute_at_map.h>
+ #include <fusion.h>
+ #include <ir/all_nodes.h>
+ #include <visibility.h>
+
+ namespace nvfuser {
+
+ // Looks through all transformations associated with view, or enforced divisible
+ // vectorization splits, and gathers all splits that provably don't have a
+ // remainder; therefore the extents of the associated IterDomains do not require
+ // a ceilDiv expression.
+ NVF_API std::unordered_set<Split*> getAllDivisibleSplits(Fusion* fusion);
+
+ // Same as above but will use the provided ComputeAtMap instead of building its own.
+ NVF_API std::unordered_set<Split*> getAllDivisibleSplits(
+     Fusion* fusion,
+     const ComputeAtMap* ca_map);
+
+ } // namespace nvfuser
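
A minimal usage sketch for getAllDivisibleSplits, not part of the wheel contents. It assumes include paths relative to the packaged nvfuser include directory and a scheduled Fusion provided by the surrounding lowering code.

    #include <device_lower/analysis/divisible_split.h>

    #include <unordered_set>

    // `fusion` is assumed to be a fully scheduled Fusion* built elsewhere.
    void skipCeilDivForDivisibleSplits(nvfuser::Fusion* fusion) {
      // Splits in this set provably have no remainder, so the extents of the
      // IterDomains they produce do not need a ceilDiv when indexing is generated.
      std::unordered_set<nvfuser::Split*> divisible =
          nvfuser::getAllDivisibleSplits(fusion);

      // If a ComputeAtMap has already been built for this fusion, the second
      // overload reuses it instead of building a new one (ca_map assumed to exist):
      // auto divisible2 = nvfuser::getAllDivisibleSplits(fusion, &ca_map);
      (void)divisible;
    }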
nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h
@@ -0,0 +1,36 @@
+ // clang-format off
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <ir/all_nodes.h>
+
+ namespace nvfuser {
+
+ //! Keeps track of certain patterns of reductions.
+ //!
+ //! - Allreduce IterDomain: a domain that is both reduced and broadcast.
+ class FusedReductionInfo {
+  public:
+   void markAsAllreduce(IterDomain* id);
+
+   bool isAllreduce(IterDomain* id) const;
+
+  private:
+   // Reduction IterDomains that are also broadcast
+   std::unordered_set<IterDomain*> allreduce_ids_;
+ };
+
+ //! Detect reductions and broadcasts that are eligible for the fused
+ //! reduction kernel. When found, the predicate flags of the broadcast
+ //! are unset, which effectively makes the broadcast just a unary set
+ //! op.
+ //! TODO: Consider moving the warp-based fused reduction here.
+ void fuseReductionsAndBroadcasts(Fusion*);
+
+ } // namespace nvfuser
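
A minimal sketch of how this analysis might be driven, not part of the wheel contents. The include path and the IterDomain pointer are assumptions; only the declarations shown in the header above are used.

    #include <device_lower/analysis/fused_reduction.h>

    // `fusion` is assumed to be the fusion being lowered.
    void detectAllreducePatterns(nvfuser::Fusion* fusion) {
      // Unsets the predicate flags of broadcasts that complete an allreduce
      // pattern, turning those broadcasts into plain unary set ops.
      nvfuser::fuseReductionsAndBroadcasts(fusion);

      // FusedReductionInfo records which IterDomains participate in an
      // allreduce (reduced and then broadcast). `id` would be a hypothetical
      // IterDomain* taken from a TensorView in the fusion:
      nvfuser::FusedReductionInfo info;
      // info.markAsAllreduce(id);
      // bool is_allreduce = info.isAllreduce(id);
    }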
nvfuser/include/nvfuser/device_lower/analysis/index_compute.h
@@ -0,0 +1,322 @@
+ // clang-format off
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <fusion.h>
+ #include <index_compute.h>
+
+ namespace nvfuser {
+
+ // Struct to hold useful information from an index pass on the iterdomain graph.
+ // Used to return the IndexCompute structure back to the indexing calls in
+ // index_compute.cpp. Other structures are required to resolve the actual
+ // indexing math there.
+ struct IndexFromIdGraph {
+   IndexCompute index;
+   IndexCompute concrete_index;
+   std::unordered_map<IterDomain*, Val*> initial_concrete_index_map;
+   std::vector<IterDomain*> resolved_loop_domains;
+
+   explicit IndexFromIdGraph(
+       IndexCompute index,
+       IndexCompute concrete_index,
+       std::unordered_map<IterDomain*, Val*> initial_concrete_index_map,
+       std::vector<IterDomain*> loop_domains);
+ };
+
+ //! Indexing interface. Returns IndexFromIdGraph, which the IndexCompute object
+ //! can be queried from directly for the produced indexing. If producer_tv !=
+ //! nullptr the producer will be indexed; if producer_tv == nullptr the consumer
+ //! will be indexed. If is_global, global indexing will be done; otherwise shared
+ //! memory or local indexing will be performed.
+ IndexFromIdGraph getTensorIndexFromIdGraph(
+     const std::vector<ForLoop*>& loops,
+     const std::unordered_set<ForLoop*>& rotated_loops,
+     const TensorView* consumer_tv,
+     const TensorView* producer_tv = nullptr,
+     bool is_global = true,
+     const std::unordered_map<IterDomain*, IterDomain*>& c2p_map = {});
+
+ //! Indexing interface for calculating a predicate index. Returns IndexFromIdGraph,
+ //! which the IndexCompute object can be queried from directly for the produced
+ //! indexing. If is_start_predicate, will produce indexing math for the start
+ //! predicates.
+ IndexFromIdGraph getPredicateIndexingFromIdGraph(
+     const std::vector<ForLoop*>& loops,
+     const std::unordered_set<ForLoop*>& rotated_loops,
+     TensorView* consumer_tv,
+     ForLoop* unswitch_or_vec_loop,
+     IterDomain* circular_buffer_axis,
+     bool is_start_predicate);
+
+ //! getTensorIndexFromIdGraph is the function that index_compute will call very
+ //! straightforwardly. However, for implementing the new indexing logic that
+ //! starts to abstract some of the indexing away from index_compute, we need to
+ //! move quite a bit of the intertwined indexing logic away from the
+ //! index_compute file and the index_reference_replay file. This is because we
+ //! want to separate out what has to be done on the fly from what analysis we
+ //! can do early on with the iter domain graph and associated properties.
+ //!
+ //! getTensorIndexFromIdGraph places this analysis internally in
+ //! LoopIndexingAnalysis. LoopIndexingAnalysis, though, has to communicate to:
+ //!   1) index_compute.cpp::IndexCompute to tell IndexCompute which expressions
+ //!      it needs to traverse to compute the indexing math.
+ //!
+ //! LoopIndexing is nothing but a mechanism for this communication.
+ //!
+ //! Holds information needed to produce indexing math. In the current version of
+ //! the indexing pass, the iter domains combined with the loop nests are the
+ //! source of truth in terms of resolving the actual integer indexing math from
+ //! the sequence of iterdomain transforms.
+ //!
+ //! This information is critical in resolving indexing associated with complex
+ //! broadcast patterns. Check FusionComplexBCast* test cases as well as
+ //! Indexing* tests for examples where resolving indices from IterDomain
+ //! transformations can be challenging.
+ //!
+ //! The source of this challenge is inlining patterns where the IterDomains
+ //! responsible for control flow are not local to a particular TensorView.
+ //! Broadcast, operations like view/reshape, and gather/shift can make indexing
+ //! local buffers complex because of the complex effects that inlining into other
+ //! TensorViews produces.
+ //!
+ //! TODO:
+ //!   The first iteration tries to match the semantics of reference
+ //!   replay without any new logic. A follow-up iteration will
+ //!   need to revisit a few further pathological patterns.
+ //!
+ //! Note:
+ //!   The current implementation of the loop indexing pass works on
+ //!   equivalence classes defined by the ComputeAt exact map. The
+ //!   list of expressions stored in this class forms a "reference" graph of
+ //!   iterdomain expressions when all of their inputs and outputs are replaced
+ //!   with their exact concrete mapped ids.
+ //!
+ //!   An invariant in such a graph of iterdomain expressions is that
+ //!   each iterdomain is produced exactly once and is either a loop domain
+ //!   or has been consumed exactly once by another expression. This makes sure
+ //!   that a well-defined indexing can be generated for each of the concrete ids
+ //!   whenever we either forward or backward traverse the graph.
+ class LoopIndexing {
+  public:
+   //! Returns the original loop nest.
+   const auto& loops() const {
+     return loops_;
+   }
+
+   //! Returns the vector of Iterdomains
+   //! that match the original loop pattern.
+   const auto& loopDomains() const {
+     return loop_domains_;
+   }
+
+   const auto& loopRootDomains() const {
+     return loop_root_;
+   }
+
+   //! Returns the consumer tv that the view info
+   //! was derived from.
+   auto consumerTv() const {
+     return consumer_tv_;
+   }
+
+   //! Returns the set of Iterdomain transforms that
+   //! define the correct indexing path, in forward
+   //! topological order.
+   std::vector<Expr*> getForwardExprList() const;
+
+   //! Returns the set of Iterdomain transforms that
+   //! define the correct indexing path, in backward
+   //! topological order.
+   std::vector<Expr*> getBackwardExprList() const;
+
+   //! Returns the set of out-of-line expressions in
+   //! reverse topological order.
+   const std::vector<Expr*>& getBackwardOutOfLineExprList() const {
+     return out_of_line_exprs_;
+   }
+
+   //! Returns all exact concrete ids that were produced
+   //! or consumed in the selected indexing expressions.
+   std::unordered_set<IterDomain*> getAllExactConcreteIdSet() const;
+
+  private:
+   friend class LoopIndexingAnalysis;
+
+   //! The loop nest that this loop indexing is derived from.
+   std::vector<ForLoop*> loops_;
+
+   //! Consumer tv, where the view-related info was derived from.
+   const TensorView* consumer_tv_ = nullptr;
+
+   //! The source iterdomains that all the Iterdomain transforms
+   //! in this loop nest originated from.
+   std::vector<IterDomain*> loop_root_;
+
+   //! The loop iterdomains that the original loop nests correspond
+   //! to. May be longer than loops_, with the dangling iterdomains
+   //! appended towards the end.
+   std::vector<IterDomain*> loop_domains_;
+
+   //! The selected sequence of expressions that should represent
+   //! the correct indexing math from the given loop nest.
+   std::vector<Expr*> index_exprs_;
+
+   //! The subset of the sequence of expressions that can be resolved
+   //! with only the iterdomains on the right of the consumer tv's ca
+   //! axis.
+   //! Expressions are ordered in reverse topological order.
+   std::vector<Expr*> out_of_line_exprs_;
+ };
+
+ class LoopIndexingAnalysis {
+  public:
+   static LoopIndexing fromLoopAndConsumer(
+       const std::vector<ForLoop*>& loops,
+       const TensorView* consumer_tv);
+
+   //! Return all concrete IDs that are reachable from a given list
+   //! of consumer loop IDs. Reachability is defined as the existence
+   //! of an indexing path from the loop IDs.
+   static VectorOfUniqueEntries<IterDomain*> getReplayableConcreteIDs(
+       const std::vector<IterDomain*>& consumer_loop_ids,
+       const TensorView* consumer_tv);
+
+  private:
+   explicit LoopIndexingAnalysis(
+       const std::vector<ForLoop*>& loops,
+       const TensorView* consumer_tv);
+
+   explicit LoopIndexingAnalysis(
+       const std::vector<IterDomain*>& consumer_loop_ids,
+       const TensorView* consumer_tv);
+
+   void run();
+
+   //! Populate derived information into a LoopIndexing
+   //! data structure.
+   LoopIndexing getLoopIndexing(const std::vector<ForLoop*>& loops) {
+     LoopIndexing indexing;
+     indexing.loops_ = loops;
+     indexing.consumer_tv_ = consumer_tv_;
+     indexing.loop_root_ = loop_root_domains_;
+     indexing.loop_domains_ = loop_domains_.vector();
+     indexing.index_exprs_ = replayed_exprs_;
+     indexing.out_of_line_exprs_ = out_of_line_exprs_;
+     return indexing;
+   }
+
+   //! Validates that the current loop structure is well formed, in the sense
+   //! that ca_map would not map any two loops in the loop nest together.
+   void validateLoopStructure(const std::vector<ForLoop*>& loops);
+
+   //! Start at the loop iter domains, and traverse back into history on the
+   //! concrete IDs in the exact map, calling "visitExpr" on expressions through
+   //! the history.
+   void traverseFromDomainVals();
+
+   //! Concretize the given iterdomain and record the visit (in deterministic
+   //! order) in terms of the exact mapped concrete id. Marks the mapping of the
+   //! id to the concrete id in "concrete_to_original_id_" and returns the
+   //! concrete id.
+   IterDomain* concretizeAndVisitId(IterDomain* id);
+
+   //! If an equivalent expression has already been processed this function
+   //! simply returns. Otherwise puts the exact concrete IDs of inputs in
+   //! consumed_concrete_, and concrete IDs of outputs in produced_concrete_.
+   //! Then adds the expression to replayed_exprs_.
+   void visitExpr(Expr* expr);
+
+   //! Iterates through provided vals, calls concretizeAndVisitId on them, and
+   //! returns whether any of the returned vals are in existing_ids. This is used
+   //! to check if inputs or outputs of ID expressions have already been
+   //! produced/consumed in the traversal. Indexing only needs to consume/produce
+   //! one IterDomain per exact disjoint set.
+   bool visitIdsAndCheckDuplication(
+       const std::vector<Val*>& vals,
+       const std::unordered_set<IterDomain*>& existing_ids);
+
+   //! Fills loop_domains_ with the corresponding replayed_concrete_id mapping to
+   //! the provided loops. Must be done after the exact iterdomain "replay"
+   //! (traverseFromDomainVals). loop_domains_ are the original_id, not the
+   //! concrete_id (translated with concrete_to_original_id). These iter domains
+   //! are used to grab the history that will be replayed in IndexCompute. We're
+   //! looking for "new" root domains and subsequent transformations, filling in
+   //! any missing "outputs" (or inputs for backward traversal). Then fills
+   //! loop_domains_ with all of these iter domains.
+   void constructLoopDomains();
+
+   //! Fills out_of_line_exprs_ by traversing the selected list of
+   //! expressions in reverse topological order and collecting iterdomains
+   //! on the indexing paths that only involve loop ids on the right
+   //! of the consumer's ca axis.
+   void collectOutOfLineExprs();
+
+  private:
+   //! Original consumer tv to derive view info from.
+   const TensorView* consumer_tv_ = nullptr;
+
+   // Exact concrete domains that have been used
+   // in the traversal connection.
+   std::unordered_set<IterDomain*> produced_concrete_;
+   std::unordered_set<IterDomain*> consumed_concrete_;
+
+   //! Iterdomains that the corresponding loops are generated from.
+   std::vector<IterDomain*> initial_loop_domain_ids_;
+
+   //! All ids in the consumer's transform history.
+   std::vector<Val*> all_consumer_id_vals_;
+
+   //! Concrete iterdomains visited in the domain traversal,
+   //! in the order they are visited in traverseFromDomainVals.
+   VectorOfUniqueEntries<IterDomain*> replayed_concrete_ids_;
+
+   //! Keeping track of the original visited ids before they
+   //! were concretized.
+   std::unordered_map<IterDomain*, IterDomain*> concrete_to_original_id_;
+
+   //! Map from concrete id to its single consumer on the selected
+   //! iterdomain expression list.
+   std::unordered_map<IterDomain*, Expr*> concrete_id_to_consumer_;
+
+   //! Source domains that all the Iterdomain transforms
+   //! in the loop nest originated from.
+   std::vector<IterDomain*> loop_root_domains_;
+
+   //! Leaf domains representing the original loop structure.
+   VectorOfUniqueEntries<IterDomain*> loop_domains_;
+
+   //! Selected list of exprs that will produce and consume each
+   //! of the exact concrete ids from the loop nest exactly once.
+   std::vector<Expr*> replayed_exprs_;
+
+   //! Set of expressions from the selected list that can be
+   //! resolved from axes on the right of the ca axes.
+   std::vector<Expr*> out_of_line_exprs_;
+ };
+
+ // When indexing, there is sometimes an option to propagate an index down
+ // multiple paths. This will return the IterDomains in the history of the
+ // reference domain and mark which paths should be taken (if there's a
+ // preference) to reach the roots provided in preferred_roots.
+ std::unordered_set<IterDomain*> buildLoopIndexingPreferredPath(
+     const TensorView* original_tv,
+     const LoopIndexing& loop_indexing,
+     bool use_replay_map = false,
+     std::unordered_map<IterDomain*, IterDomain*> p2c_map = {});
+
+ // Get a logical IterDomain that is mapped with an IterDomain. If
+ // multiple such IDs exist, select one whose input IDs are mapped with
+ // the consumer IDs. This is to ensure the path from the loop
+ // IterDomains to the root matches the consumer tensor.
+ IterDomain* getLogicalIDToTraverse(
+     IterDomain* id,
+     const std::vector<Val*>& consumer_all_ids);
+
+ } // namespace nvfuser
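
A minimal sketch of how these interfaces could be queried from lowering code, not part of the wheel contents. The include path and the surrounding function are assumptions; only declarations from the header above are called.

    #include <device_lower/analysis/index_compute.h>

    #include <unordered_set>
    #include <vector>

    // `loops`, `rotated_loops`, and `consumer_tv` are assumed to come from the
    // lowering context for the tensor currently being indexed.
    void sketchIndexingQueries(
        const std::vector<nvfuser::ForLoop*>& loops,
        const std::unordered_set<nvfuser::ForLoop*>& rotated_loops,
        const nvfuser::TensorView* consumer_tv) {
      // Run the loop indexing analysis once and reuse the result.
      nvfuser::LoopIndexing loop_indexing =
          nvfuser::LoopIndexingAnalysis::fromLoopAndConsumer(loops, consumer_tv);

      // Expressions to traverse, in forward topological order, when replaying
      // the indexing math in IndexCompute.
      std::vector<nvfuser::Expr*> forward_exprs =
          loop_indexing.getForwardExprList();

      // Consumer (global memory) indexing through the id-graph interface.
      nvfuser::IndexFromIdGraph consumer_index =
          nvfuser::getTensorIndexFromIdGraph(
              loops,
              rotated_loops,
              consumer_tv,
              /*producer_tv=*/nullptr,
              /*is_global=*/true);
      (void)forward_exprs;
      (void)consumer_index;
    }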
nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h
@@ -0,0 +1,71 @@
+ // clang-format off
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+ // clang-format on
+ #pragma once
+ #include <exceptions.h>
+
+ #include <ir/all_nodes.h>
+ #include <kernel_ir.h>
+
+ #include <vector>
+
+ namespace nvfuser {
+
+ class PredicateElimination : public IterVisitor {
+  public:
+   PredicateElimination(Fusion* fusion);
+
+   //! True if expr does not need a predicate
+   //!
+   //! \param expr Tensor expression
+   bool canOmitPredicate(const Expr* expr) const;
+
+   bool needsSharedMemoryPredicate(const Expr* expr) const;
+
+   //! Value to initialize out-of-bound regions
+   Val* getInitValue(TensorView* tv) const;
+
+   //! Dump to string for debugging
+   std::string toString() const;
+
+   // A utility to set the removal info of `to` the same as `from`.
+   // See issue #1641.
+   // We build predicate info before lowering, but more expressions
+   // are created during lowering that this class also needs to
+   // keep track of to make sure correct predicate removal is
+   // applied.
+   // This utility is a quick patch for the missing information,
+   // since it might be better just to recompute predicate info
+   // if all expressions were mutated, but that'd take much more
+   // global info to reliably track.
+   void propagateRemovalInfo(const Expr* from, const Expr* to);
+
+   const std::unordered_set<const Expr*>& getNonPredicatedExprs() const {
+     return non_predicated_exprs_;
+   }
+
+  private:
+   using IterVisitor::handle;
+
+   void dispatch(Expr* expr) final;
+
+   //! Set a value to initialize out-of-bound regions
+   bool setDefaultInitValue(TensorView* tv);
+   //! Set a value to initialize out-of-bound regions of reduction tensors
+   bool setReductionInitValue(TensorView* tv, Val* reduction_init);
+
+   //! Check if expr needs to be predicated
+   bool needsPredicate(Expr* expr) const;
+
+  private:
+   //! Expressions that are found to be safe without predicates
+   std::unordered_set<const Expr*> non_predicated_exprs_;
+   //! Tensors and their initialization values
+   std::unordered_map<TensorView*, Val*> init_value_map_;
+ };
+
+ } // namespace nvfuser
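
A minimal sketch of how the analysis above could be consulted during lowering, not part of the wheel contents. The include path, the surrounding function, and `new_expr` are assumptions.

    #include <device_lower/analysis/predicate_elimination.h>

    // `fusion` and `expr` are assumed to come from the lowering context.
    void sketchPredicateDecision(
        nvfuser::Fusion* fusion, const nvfuser::Expr* expr) {
      nvfuser::PredicateElimination pred_elim(fusion);

      if (pred_elim.canOmitPredicate(expr)) {
        // Safe to emit the expression without a bounds predicate. Out-of-bound
        // regions of its outputs are covered by the recorded init values,
        // queried per tensor via pred_elim.getInitValue(tv).
      }

      // If lowering later clones or replaces `expr` with a hypothetical
      // `new_expr`, the removal decision can be carried over:
      // pred_elim.propagateRemovalInfo(expr, new_expr);
    }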
nvfuser/include/nvfuser/device_lower/analysis/sync_information.h
@@ -0,0 +1,47 @@
+ // clang-format off
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <ir/all_nodes.h>
+ #include <parallel_type_bitmap.h>
+ #include <visibility.h>
+
+ #include <unordered_map>
+
+ namespace nvfuser {
+
+ class SyncMap {
+  public:
+   //! Validates that all tensors are consistently parallelized. Basically,
+   //! when a producer axis is threaded, either with threadIdx or
+   //! blockIdx, there must be a mapped consumer axis with the
+   //! same ParallelType, with some exceptions.
+   //!
+   //! ComputeAtMap is already built, as it is used to validate consistency.
+   //!
+   //! Fills needs_raw_sync with output TVs that need a RAW sync on smem or
+   //! gmem. The second entry in this map is the set of parallel dimensions
+   //! being communicated across.
+   NVF_API SyncMap(Fusion* fusion);
+
+   std::string toString() const;
+
+   ParallelTypeBitmap needsRawSync(TensorView* tv) const {
+     auto it = needs_raw_sync_.find(tv);
+     if (it != needs_raw_sync_.end()) {
+       return it->second;
+     }
+     return ParallelTypeBitmap();
+   }
+
+  private:
+   std::unordered_map<TensorView*, ParallelTypeBitmap> needs_raw_sync_;
+ };
+
+ } // namespace nvfuser
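
A minimal sketch of querying the sync map, not part of the wheel contents. The include path and the surrounding function are assumptions; the bitmap query helpers named in the comment are assumed, not taken from this header.

    #include <device_lower/analysis/sync_information.h>

    // `fusion` and `tv` are assumed to come from the lowering context.
    void sketchRawSyncQuery(nvfuser::Fusion* fusion, nvfuser::TensorView* tv) {
      nvfuser::SyncMap sync_map(fusion);

      // Bitmap of parallel types across which `tv` is communicated and
      // therefore needs a RAW sync (block- or grid-level) before consumers
      // can read it. An empty bitmap means no sync is required for `tv`.
      nvfuser::ParallelTypeBitmap raw = sync_map.needsRawSync(tv);
      // Inspection helpers such as hasTID()/hasBID() are assumed to exist on
      // ParallelTypeBitmap (declared in parallel_type_bitmap.h, not shown here).
      (void)raw;
    }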
nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h
@@ -0,0 +1,65 @@
+ // clang-format off
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+ // clang-format on
+ #pragma once
+
+ namespace nvfuser {
+
+ class Fusion;
+
+ // Information used to lower tensor memory. So far, no information is needed;
+ // computeTMemInfo just checks that there is only one tensor on TMem in the
+ // fusion. This limitation is described in the note below, and it exists only
+ // for incremental development. It will be removed in the near future.
+ struct TensorMemoryInfo;
+ TensorMemoryInfo computeTMemInfo(Fusion* fusion);
+
+ // Note: [Tensor Memory Allocation]
+ //
+ // Tensor memory is a very special memory, so its allocation is also very
+ // different from other memory types.
+ //
+ // It is highly recommended to read the PTX documentation for tensor memory
+ // if you are not already familiar with it:
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-memory
+ //
+ // The first thing to note is that TMem does not have virtualization. This means
+ // we can not just allocate starting from address 0 the way we allocate shared
+ // memory, relying on a page table to translate the same virtual address of
+ // different CTAs to different physical addresses. There are no virtual TMem
+ // addresses. All addresses are physical addresses.
+ //
+ // Because multiple CTAs can execute on the same SM simultaneously, there must
+ // be some handshaking mechanism for each CTA to know the region of TMem that it
+ // can use. This is done by using the PTX instruction tcgen05.alloc. To ensure
+ // safety, there is a mutex "I have the right to allocate TMem" in the
+ // hardware. At the beginning of each CTA, the CTA will try to acquire the mutex
+ // automatically. If it fails, the CTA will be blocked until the mutex is free.
+ // This means only one CTA can allocate TMem at a time. Once the CTA has
+ // finished allocating TMem, it should release the mutex to relinquish the right
+ // to allocate. After the right to allocate is relinquished, this CTA can not
+ // allocate new TMem any more, but it can still access the TMem that it has
+ // allocated, and it can also free the TMem that it has allocated. Once one CTA
+ // relinquishes the right to allocate, the next CTA that is blocked will be
+ // unblocked and can acquire the mutex to allocate TMem.
+ //
+ // Currently, TMem allocation is not supported in nvFuser. We currently only
+ // allow one TensorView to be on TMem, and because we never relinquish the right
+ // to allocate TMem, CTAs will be serialized on the SM. A new CTA can be
+ // scheduled on an SM only after the previous CTA on that SM has completely
+ // finished executing. Thanks to this serialization, we can just skip allocating
+ // and assume that our only TMem TensorView owns the entire TMem, because we are
+ // sure that there will not be another CTA using that address. As a result, we
+ // can just provide address 0 to our instructions that access TMem. In
+ // principle, it is clearly wrong to write to an address that is not allocated,
+ // but because we are sure that it will in practice work for the specific unit
+ // test that we are targeting, we do it this way for incremental development.
+
+ struct TensorMemoryInfo {};
+
+ } // namespace nvfuser
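
A minimal sketch of where this check could sit in host-side lowering, not part of the wheel contents. The include path and the surrounding function are assumptions; the only API used is the computeTMemInfo declaration shown above.

    #include <device_lower/analysis/tensor_memory.h>

    // `fusion` is assumed to be the fusion currently being lowered.
    void sketchTMemLowering(nvfuser::Fusion* fusion) {
      // Per the note above, this currently only verifies that at most one
      // TensorView lives on TMem. Given the CTA serialization it relies on,
      // that tensor is then addressed at offset 0 without a real
      // tcgen05.alloc being emitted.
      nvfuser::TensorMemoryInfo tmem_info = nvfuser::computeTMemInfo(fusion);
      (void)tmem_info;
    }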