nvfuser-cu121-torch25 0.2.25.dev20250201__cp310-cp310-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,651 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <exceptions.h>
#include <iter_visitor.h>
#include <logical_domain_map.h>

#include <unordered_map>
#include <unordered_set>
#include <vector>

/*
 * Index compute takes in a list of indices typically generated from the
 * surrounding for loop nest. The number of indices is intended to match the
 * number of dimensions of the incoming TensorView, which may have fewer or
 * more dimensions than its allocation domain due to split/merge operations.
 * Split/merge operations are then replayed backwards to produce resulting
 * indices (based on input indices) that match the allocation dimension.
 *
 * For example, with a GLOBAL tensor:
 * TV[I, K]
 * TV[Io, Ii{4}, K] = TV.split(I, factor=4)
 * ALLOC: NONE
 * INDEX: indexCompute {i, j, k} -> {i * 4 + j, k}
 * FLATTENED_INDEX: {i * 4 + j, k} -> {(i * 4 + j) * K + k}
 * PREDICATE: {i * 4 + j, k} -> i * 4 + j < I
 *
 *
 * For example, with a SHARED tensor:
 *
 * global_TV[I, K]
 * global_TV[Io, Ii{4}, K] = global_TV.split(I, factor=4)
 * smem_TV.compute_at(global_TV, 1)
 * global_TV.parallelize(1, threadIdx.x)
 *
 * ALLOC: alloc(smem_TV, 4 x K)
 * INDEX: indexCompute(smem_TV, {threadIdx.x, k}) -> {threadIdx.x, k}
 * FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {(threadIdx.x * 4 + j) * K + k}
 * PREDICATE: {threadIdx.x * 4 + j, k} -> threadIdx.x * 4 + j < I // Same as if
 * global
 *
 *
 * For example, with a LOCAL tensor:
 * global_TV[I, K, L]
 * global_TV[Io, Ii{4}, K, L] = global_TV.split(I, factor=4)
 * reg_TV.compute_at(global_TV, 2)
 * global_TV.parallelize(1, threadIdx.x)
 * global_TV{i, j, k, l} -> { i * 4 + j, k, l }
 * global_TV{ i * 4 + j, k, l } -> { (i * 4 + j) * K * L + k * L + l}
 *
 * ALLOC: alloc(reg_TV, K x L)
 * INDEX: {k, l} -> {k, l}
 * FLATTENED_INDEX: {k, l} -> {k * L + l}
 * PREDICATE: i * 4 + j < I && k < K && l < L -> // Same as if global
 *
 * These indices can then be flattened later based on strides.
 */
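To make the GLOBAL example above concrete, here is a small standalone sketch (not part of the header) of the arithmetic that the backward split replay produces for TV.split(I, factor=4); the concrete extents I = 10 and K = 3 and the ceil-divided outer extent are illustrative assumptions.

// Standalone illustration of the GLOBAL example above (not part of nvfuser).
// Sizes are arbitrary; Io is the rounded-up outer extent produced by the split.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t I = 10, K = 3;     // logical extents of TV[I, K]
  const int64_t Io = (I + 3) / 4;  // outer extent after split(I, factor=4)

  for (int64_t i = 0; i < Io; ++i) {     // loop over Io
    for (int64_t j = 0; j < 4; ++j) {    // loop over Ii{4}
      for (int64_t k = 0; k < K; ++k) {  // loop over K
        // Backward replay of the split: {i, j, k} -> {i * 4 + j, k}
        const int64_t idx_I = i * 4 + j;
        // PREDICATE: the split may be non-divisible, so guard against I
        if (idx_I >= I) {
          continue;
        }
        // FLATTENED_INDEX: {i * 4 + j, k} -> (i * 4 + j) * K + k
        const int64_t flat = idx_I * K + k;
        assert(flat >= 0 && flat < I * K);
      }
    }
  }
  return 0;
}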

namespace nvfuser {

class ContigIDs;
class LoopIndexing;
struct IndexFromIdGraph;
class TensorIndexer;

class IndexCompute : public BackwardVisitor {
 protected:
  using BackwardVisitor::handle;

  void dispatch(Expr*) override;

  void handle(Split*) override;
  void handle(Merge*) override;
  void handle(Swizzle*) override;
  void handle(Swizzle2D*) override;
  void handle(Resize*) override;

  // Return extent_map_[id] if it exists, otherwise return id->extent()
  Val* getExtent(IterDomain* id) const;

  //! True if a domain is not used to index
  bool isZero(IterDomain* id) const;
  //! True if any dependent of a domain is not used to index
  bool hasZeroMerged(IterDomain* id) const;

  //! Returns the concrete ID from the compute at EXACT mode map if
  //! concrete_id_pass == true, otherwise returns the id passed in.
  //! Helps unify the expr handling logic in reference domain and concrete id
  //! based traversal.
  IterDomain* maybeGetExactMapConcreteID(IterDomain* id) const;

  //! (Concrete indexing pass only)
  //! Collect permissive index binding from the given expression.
  //! See also permissive_map_ and LoopIndexing::getBackwardOutOfLineExprList.
  void collectIndexIntoPermissiveMap(const LoopIndexing& loop_indexing);

  //! (Concrete indexing pass only)
  //! Iterate through id_expr's inputs and pull index vals from the permissive
  //! map, when both of the following are true:
  //! 1. the output id is missing in index_map_.
  //! 2. the output id is found in the permissive map.
  void updateIndexMapFromPermissiveMap(const Expr* id_expr);

  //! Initialize unswitched_domain_map_ from the loop unswitched
  //! domains
  void initializeUnswitchDomainMap();

  //! Propagate unswitched map info from expr outputs to inputs
  void updateUnswitchedDomains(Expr* expr);

  //! Query if an IterDomain has a dependent unswitched domain
  bool hasUnswitchedDependentDomains(IterDomain* id) const;

  //! Query if the usual modulo propagation may be invalid for a merge
  //! inner path
  bool isModuloInvalidUnswitchedIndex(
      IterDomain* out_concrete_id,
      Val* out_ind,
      Val* inner_extent) const;

  // Tensor domain we're mapping back to allocation
  const TensorDomain* td_; // NOLINT

  // Map we update as we propagate backward, containing all IDs in the
  // propagation. Initial indices are mapped with this map at tv->domain()
  // and are back propagated to tv->getMaybeAllocationDomain(). This index_map_
  // keeps the indices at intermediate IterDomains in that back propagation.
  std::unordered_map<IterDomain*, Val*> index_map_; // NOLINT

  // Map from IterDomains to their broadcasted extents. If a TV has I0*I1 but
  // its producer has B0*I1, this map will contain a mapping from the ID{B0*I1}
  // to the extent I0*I1. Also contains updated extents if we merge in a 0
  // index. See zero_merged_in_.
  std::unordered_map<IterDomain*, Val*> extent_map_; // NOLINT

  // Keeps track of domains that do not contribute to indexing
  std::unordered_set<IterDomain*> zero_domains_; // NOLINT

  // This set keeps track of IterDomains that have had a zero index merged into
  // them. This happens if we do something like tv->axis(0)->split(4) then
  // tv->computeAt(1, ...). If this tensor is in smem or lmem, the backward
  // indexing would be (0, i); when we then do the backward computation, that
  // zero and i would attempt to be merged together. We handle indices like
  // these specially.
  std::unordered_set<IterDomain*> zero_merged_in_;

  // IDs that are a result of contiguous merges
  std::unordered_set<IterDomain*> contig_ids_;

  // Indicates whether we should propagate an index down a particular
  // IterDomain path if there's an option
  std::unordered_set<IterDomain*> preferred_paths_;

  // Temporary flag which tells IndexCompute to use concrete ids from the exact
  // map rather than the actual IDs used in the ID expressions.
  bool concrete_id_pass_ = false;

  // Mode of swizzle that is activated in this index compute
  // instance. Swizzles of a different mode are treated as no-ops.
  // Currently, data mode swizzles are handled the same as before in the
  // IndexSwizzle pass, while loop mode swizzles are handled early on in the
  // concrete indexing pass. See also [Note on swizzle mode]
  SwizzleMode swizzle_mode_ = SwizzleMode::NoSwizzle;

  // (Concrete id pass only)
  // Contains the indexing math that could be resolved with only the
  // iterdomains on the right of the consumer_tv's ca axis, i.e. the
  // ones that correspond to the loops that consumer_tv would not
  // share with any of its consumers.
  // These indexing vals should be kept separate from index_map_ and
  // should only be used when the indexing traversal follows the
  // order defined in LoopIndexingAnalysis::traverseFromDomainVals.
  std::unordered_map<IterDomain*, Val*> permissive_index_map_;

  //! Leaf domains that have maximum index values for unswitch
  //! predicates. These domains need extra adjustments when going
  //! through modulo operations for merge inner domains, as modulo does
  //! not always guarantee to preserve the maximum-ness property.
  std::unordered_set<IterDomain*> unswitched_loop_domains_;

  //! Mappings from unswitched IterDomains to their unswitched
  //! domains and their inner domains. Used to figure out if a modulo
  //! operation could invalidate the maximum-ness property of an
  //! unswitched index.
  //!
  //! Mappings are created in a bottom-up fashion from loop to root
  //! such that fine-grained domain mappings are kept as much as
  //! possible for making the modulo analysis most precise.
  //!
  //! Specifically, for the loop domains, this just maps unswitched
  //! domains, i.e., those included in unswitched_loop_domains_, to
  //! themselves. There'll be no mapping for those loop domains that
  //! are not included in unswitched_loop_domains_. The mappings of
  //! all other domains are defined based on their consumer
  //! domains. By default, they are also just mapped
  //! to themselves if any of the consumers are also mapped. However,
  //! when a domain is the input to a split, the mappings of the split output
  //! domains are tracked separately and the split input will be
  //! mapped to two sets of unswitched domains, one from the inner
  //! output and another from the outer output. The mapping info from
  //! the inner output is propagated as is, whereas the mapping info
  //! from the outer output is prepended with the inner output
  //! domain so that the unswitched domain list includes its inner
  //! domain. Note that the semantics of inner domains are defined based
  //! on split operations since they define propagated index math.
  //!
  //! The reason for tracking the information from split outer domains
  //! separately is to avoid adjusting the unswitched predicate index
  //! as much as possible. For example, here's a common transpose
  //! scheduling pattern:
  //!
  //! // Initial 2D tensor
  //! [i0, i1]
  //! // Create a square tile of 32x32
  //! -> [i0 / 32, 32, i1 / 32, 32]
  //! -> [i0 / 32 * i1 / 32, 32 * 32]
  //! // Factor out a small domain (commonly vectorized)
  //! -> [i0 / 32 * i1 / 32, 32 * 32 / 4, 4]
  //! // Factor out another domain (commonly parallelized by TIDx)
  //! -> [i0 / 32 * i1 / 32, 32 * 32 / 4 / 128, 128, 4]
  //!
  //! Notice that the merge of "32 * 32" is not contiguous, so we need
  //! to predicate its input domains by propagating index exprs
  //! through the merge inner path with "% 32". If any of the final
  //! loop domains are unswitched, we need to make sure the index expr
  //! sent through "% 32" is the maximum for the domain of extent
  //! "32". Conservatively, this can just be 31; however, that isn't
  //! always strictly required. For example, suppose the innermost
  //! domain of extent 4 is unswitched. Its initial index is
  //! 3. Propagating it through the merge inner path as usual is
  //! guaranteed to be correct. More generally, it's always the case
  //! when the inner extent of a merge is divisible by the extent of
  //! an unswitched output and its domains. Suppose also that the third
  //! innermost domain is unswitched; its initial index is 1. Its
  //! contribution through the merge inner path is zero as the initial
  //! index is multiplied by the extents of its inner domains, i.e.,
  //! 128 and 4, and they are divisible by the extent of the merge
  //! inner domain. Again, more generally, if the stride of an
  //! unswitched domain is a multiple of the inner extent of the merge
  //! operation producing the unswitched domain, there's no
  //! contribution from the unswitched domain, so it doesn't matter if
  //! it's maximum or not.
  //!
  //! In the above pattern, the second innermost domain is commonly
  //! parallelized with TIDx. Suppose it's also unswitched. Notice
  //! that there's no concern of that domain invalidating the
  //! maximum-ness property, as threadIdx.x is the only valid initial
  //! index value for each thread. However, this is the reason we keep track
  //! of the split output contributions separately. More specifically,
  //! the intermediate domain of (32 * 32 / 4) will have an index of
  //! (1 * 128 + threadIdx.x), and the domain of (32 * 32) will have
  //! (1 * 128 * 4 + threadIdx.x * 4 + 3). As discussed above, we can
  //! reason that the first and third components of this
  //! unswitched expression are safe with respect to the propagation
  //! with modulo by 32. The second component is also safe as that's
  //! the only valid index for the domain. If not separately tracked,
  //! all we could know would be that the extent of (32 * 32) is
  //! 1024. Since part of the dependent domains are parallelized, the
  //! propagated index is not guaranteed to be 1023, so we would need
  //! to make a conservative decision to send 1023 to the merge inner
  //! path.
  std::unordered_map<IterDomain*, std::vector<std::deque<IterDomain*>>>
      unswitched_domain_map_;

 public:
  const std::unordered_map<IterDomain*, Val*>& indexMap() const {
    return index_map_;
  }

  const std::unordered_map<IterDomain*, Val*>& extentMap() const {
    return extent_map_;
  }

  const std::unordered_set<IterDomain*>& zeroDomains() const {
    return zero_domains_;
  }

  const std::unordered_set<IterDomain*>& zeroMergedIn() const {
    return zero_merged_in_;
  }

  // Propagate back from _td using initial_index_map
  IndexCompute(
      const TensorDomain* _td,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> _extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> _zero_merged_in,
      std::unordered_set<IterDomain*> preferred_paths = {});

  IndexCompute(
      const TensorDomain* _td,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> _extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> _zero_merged_in,
      const ContigIDs& contig_finder,
      std::unordered_set<IterDomain*> preferred_paths = {},
      std::unordered_set<IterDomain*> unswitched_domains = {});

  // Entry point used for concrete id based traversal. This traversal is
  // assumed to start at loop IDs provided by initial_index_map.
  IndexCompute(
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> preferred_paths,
      std::unordered_set<IterDomain*> unswitched_domains = {});

  // Updates index_map, extent_map, and zero_merged_in based on id_map and
  // returns a new IndexCompute ready to be used.
  IndexCompute updateIndexCompute(
      const TensorDomain* new_td,
      const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
          id_map,
      const ContigIDs& contig_finder) const;

  // Interface to run index traversal through the loop indexing analysis result,
  // to be used with the entry point for concrete id based traversal.
  void run(const LoopIndexing& loop_indexing);

  virtual void run();
};
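The unswitched_domain_map_ comment above argues that, for the transpose pattern it describes, sending the per-domain unswitched maxima through the "% 32" merge inner path preserves maximum-ness. The standalone sketch below (not part of nvfuser) checks that claim numerically for the quoted extents and strides; the loop bounds and variable names are illustrative assumptions.

// Standalone check of the modulo / "maximum-ness" argument (not part of nvfuser).
// Extents and strides follow the pattern [.., 32*32/4/128, 128, 4] indexing into
// the non-contiguous merge of 32*32, whose inner path applies "% 32".
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  const int64_t kInnerExtent = 32;           // extent of the merge inner domain
  for (int64_t tid = 0; tid < 128; ++tid) {  // threadIdx.x (stride 4)
    // Per-domain maxima used by the unswitch predicate: i2 = 1 (stride 512),
    // i4 = 3 (stride 1); tid is the only valid value for its domain.
    const int64_t unswitched = 1 * 512 + tid * 4 + 3;
    // Compute the true maximum of idx % 32 over all values the unswitched
    // loops could take for this thread.
    int64_t max_mod = -1;
    for (int64_t i2 = 0; i2 < 2; ++i2) {
      for (int64_t i4 = 0; i4 < 4; ++i4) {
        const int64_t idx = i2 * 512 + tid * 4 + i4;
        max_mod = std::max(max_mod, idx % kInnerExtent);
      }
    }
    // Modulo propagation preserves maximum-ness for this pattern, so the
    // conservative value 1023 is not needed.
    assert(unswitched % kInnerExtent == max_mod);
  }
  return 0;
}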

//! Apply swizzle and update allocation indices accordingly
class IndexSwizzle : public IndexCompute {
 public:
  IndexSwizzle(
      const TensorView* tv,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> zero_merged_in);

  IndexSwizzle(
      const TensorView* tv,
      const TensorDomain* domain,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> zero_merged_in);

  void run() override;

 protected:
  using IndexCompute::handle;

  void dispatch(Expr* e) override;

  void handle(Swizzle2D* swizzle_2d) override;

 private:
  const TensorView* tv_ = nullptr;
  std::unordered_set<IterDomain*> swizzled_ids_;
};

//! Information about a predicate. By default, it corresponds to a
//! single logical domain but may cover multiple logical domains due to
//! contiguous indexing.
class PredicateInfo {
  friend class Index;
  friend class TensorIndexer;

 public:
  const auto& startPredicate() const {
    return start_predicate_;
  }

  auto& startPredicate() {
    return start_predicate_;
  }

  const auto& startOffset() const {
    return start_offset_;
  }

  const auto& stopPredicate() const {
    return stop_predicate_;
  }

  const auto& stopOffset() const {
    return stop_offset_;
  }

  const auto& predicatedDomains() const {
    return predicated_domains_;
  }

  const auto& loopDomains() const {
    return loop_domains_;
  }

  CircularBufferLoopStage loopStage() const {
    return loop_stage_;
  }

  //! Return a false PredicateInfo, i.e., both start and stop
  //! predicates are false.
  static PredicateInfo getFalseInfo();

 private:
  // Predicate for the lower end
  Val* start_predicate_ = nullptr;
  // Predicate for the upper end
  Val* stop_predicate_ = nullptr;
  // Offset of the start predicate
  Val* start_offset_ = nullptr;
  // Offset of the stop predicate
  Val* stop_offset_ = nullptr;
  // Track which domains are covered by the generated predicates
  std::unordered_set<IterDomain*> predicated_domains_;
  // Loop domains used for the predicate domains
  std::unordered_set<IterDomain*> loop_domains_;
  // Circular buffer loop stage if applicable
  CircularBufferLoopStage loop_stage_ = CircularBufferLoopStage::NotApplicable;
};

// Simple interface for IndexCompute
// If getComputeAtAxis and, more generally, the TensorView const model are
// fixed, we can make the below tensorviews const.
class Index {
 private:
  // Producer indexing if it's in shared or local memory
  static std::vector<Val*> getNonGlobalProducerStridedIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {});

  // Consumer indexing if it's in shared or local memory
  static std::vector<Val*> getNonGlobalConsumerStridedIndices(
      const TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {});

  // Get the strides of a tensor used for the index lowering
  static std::vector<Val*> getStrides(TensorView* tv);

  // Get the allocation indices of a consumer tensor
  static std::vector<Val*> getConsumerAllocationIndices(
      const TensorView* tv,
      const std::vector<ForLoop*>& loops,
      const IndexFromIdGraph& index_from_id_graph);

  // Get the allocation indices of a producer tensor
  static std::vector<Val*> getProducerAllocationIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {});

 public:
  // Producer indexing if it's in global memory
  static std::vector<Val*> getGlobalProducerStridedIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {});

  // Consumer indexing if it's in global memory
  static std::vector<Val*> getGlobalConsumerStridedIndices(
      TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<int, Val*>& override_index = {});

  // Indexing functions
  // Consumer = Producer
  // i.e. T0 = T1... -> T0 is the consumer, T1 is the producer
  // Producer indexing dispatch
  // The argument `generate_pointer` specifies whether to generate a pointer for
  // the tensor. For a global tensor, generate T1.data. For a shared memory
  // tensor, use the `cvta` ptx to convert the shared memory address to an
  // unsigned int for indexing. Search `toSmem` in the codebase for additional
  // information. This argument is effective only if the indexed tensor is a
  // shared memory or global tensor. On other memory types, this argument will
  // cause an error.
  static kir::TensorIndex* getProducerIndex(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {},
      bool generate_pointer = false,
      DataType as_type = DataType::Null);

  // Consumer index dispatch
  static kir::TensorIndex* getConsumerIndex(
      TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<int, Val*>& override_index = {},
      bool generate_pointer = false,
      DataType as_type = DataType::Null);

  //! Returns a vector of strided indices mapped onto the
  //! allocation domain of a producer tensor. The size of the returned
  //! vector is guaranteed to be equal to the number of axes of the
  //! indexing allocation domain.
  static Val* getProducerStridedIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {},
      bool generate_pointer = false);

  //! Returns a vector of strided indices mapped onto the
  //! allocation domain of a consumer tensor. The size of the returned
  //! vector is guaranteed to be equal to the number of axes of the
  //! indexing allocation domain.
  static Val* getConsumerStridedIndices(
      TensorView* consumer,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<int, Val*>& override_index = {},
      bool generate_pointer = false);

  //! Returns the logical index linearized from a multi-dimensional address
  //! into a linear memory address of a consumer tensor. The returned index is
  //! intended to be used for the computation of some tensor factories, such as
  //! iota and rand (for Philox pseudo random sequences).
  static Val* getLinearLogicalIndex(
      TensorView* consumer_tv,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops);

  //! Returns a vector of logical indices mapped onto the logical
  //! domain of a consumer tensor. The returned index is intended
  //! to be used for the computation of some tensor factories, such as
  //! eye.
  static std::vector<Val*> getConsumerPerDimLogicalIndex(
      TensorView* consumer_tv,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops);

  //! Returns a vector of logical indices mapped onto the logical
  //! domain of a producer tensor.
  static std::vector<Val*> getProducerPerDimLogicalIndex(
      TensorView* producer_tv,
      const TensorView* consumer_tv,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      const std::unordered_map<IterDomain*, Val*>& override_index = {});

  //! Takes a consumer tensorview and loop nest and generates predicates
  //! associated with the concrete roots of the loop nest. Returns a list of
  //! predicates, and a list of concrete roots they're associated with. It
  //! is assumed that no predicate is required if index[i] is an index
  //! directly from a for loop. This will not catch all cases if we actually
  //! have static size information, for example:
  //!
  //! TV[I].split(4)
  //! would produce the code:
  //! for(i : I/4)
  //!   for(j : 4)
  //!     if( i * 4 + j < TV.size(0))
  //!       TV[i * 4 + j]...
  //!
  //! However, if we had TV.size[0] = 16 at "compile time", then we wouldn't
  //! need the predicate. This will be caught by canOmitPredicate in the
  //! predicate lowering.
  //!
  //! unswitch_or_vec_loop is the for loop to start the unswitch-like
  //! predicate. This is not a bool value because, if we have an unswitch loop
  //! with a vectorized loop inside, we only want to base the "unswitch"-like
  //! predicate on the vectorized loop.
  static std::vector<PredicateInfo> getReferenceRootPredicates(
      TensorView* consumer_tv,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      ForLoop* unswitch_or_vec_loop);

  //! Compute the result for iota
  static Val* iota(
      TensorView* consumer_tv,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      Val* start,
      Val* step,
      DataType dtype);

  //! Compute the result for eye
  static Val* eye(
      TensorView* consumer_tv,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops,
      DataType dtype);

  //! Compute the global index and the expected bytes for the complete_tx
  //! mechanism for CpAsyncBulk.
  static std::pair<Val*, Val*> getCpAsyncBulkGmemIndex(
      const LoadStoreOp* ldst,
      Val* mbarrier,
      const std::vector<ForLoop*>& loops,
      const std::unordered_set<ForLoop*>& rotated_loops);
};
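The getReferenceRootPredicates comment above uses TV[I].split(4) as its example of a predicate that becomes unnecessary when the extent is statically divisible by the split factor. The standalone sketch below (not part of nvfuser) counts the guarded-out iterations of that loop nest; countPredicatedOut and the concrete extents are illustrative assumptions.

// Standalone illustration of the split predicate example (not part of nvfuser).
#include <cassert>
#include <cstdint>

// Count how many (i, j) pairs of the split loop nest would fail the
// "i * 4 + j < I" predicate for a given extent I.
int64_t countPredicatedOut(int64_t I) {
  const int64_t Io = (I + 3) / 4;  // ceil-divided outer extent of split(4)
  int64_t out_of_bounds = 0;
  for (int64_t i = 0; i < Io; ++i) {
    for (int64_t j = 0; j < 4; ++j) {
      if (i * 4 + j >= I) {
        ++out_of_bounds;  // these iterations must be guarded
      }
    }
  }
  return out_of_bounds;
}

int main() {
  // I = 16 is divisible by the split factor, so the predicate never fires
  // and could be omitted (what canOmitPredicate detects at compile time).
  assert(countPredicatedOut(16) == 0);
  // I = 10 is not divisible, so the guard is required for 2 iterations.
  assert(countPredicatedOut(10) == 2);
  return 0;
}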

// Used for local and shared index mapping. Returns a map from loops
// to loop indices as well as a set of loops that do not contribute to
// indexing.
// TODO: could be cleaned up further.
std::pair<std::unordered_map<ForLoop*, Val*>, std::unordered_set<ForLoop*>>
indexMapFromTV(
    const TensorView* tv,
    const std::vector<ForLoop*>& loops,
    const std::unordered_set<ForLoop*>& rotated_loops,
    ForLoop* alloc_loop,
    bool as_consumer,
    ForLoop* circular_buffer_loop = nullptr);

//! Set the "pragma unroll" required for loops that the indexing of Local
//! tensors depends on.
//!
//! \param tv Indexed tensor
//! \param alloc_loop Allocation loop of tv
//! \param loops The current loop structure
//! \param id_map Producer-to-consumer map in case of indexing as producer
void ensureStaticIndexing(
    const TensorView* tv,
    ForLoop* alloc_loop,
    const std::vector<ForLoop*>& loops,
    const std::unordered_map<IterDomain*, IterDomain*>& id_map = {});

struct PredicateDomainInfo {
 public:
  // Iteration domain to predicate
  IterDomain* id = nullptr;
  // The set of iteration domains that make up the id. If this is for
  // a non-divisible split, the set only contains the id itself. This
  // set is used to remove redundant predicates when gathering
  // unswitch predicates.
  std::unordered_set<IterDomain*> covered_ids;
  // True if this predicate is for an intermediate domain. Examples
  // include domains with non-divisible splits and resized domains.
  bool is_intermediate_domain = false;
};

// Get all domains that need to be predicated due to non-divisible splits
std::vector<PredicateDomainInfo> getNonDivisibleConsumerDomainsToPredicate(
    TensorView* consumer_tv);

} // namespace nvfuser
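As a closing illustration of the "flattened later based on strides" step mentioned at the top of the header, the standalone sketch below (not part of nvfuser) shows the plain arithmetic that the strided-index helpers in class Index express symbolically; flatten and the concrete strides are illustrative assumptions.

// Standalone illustration of stride-based flattening (not part of nvfuser):
// index = sum_i alloc_index[i] * stride[i].
#include <cassert>
#include <cstdint>
#include <vector>

int64_t flatten(const std::vector<int64_t>& alloc_index,
                const std::vector<int64_t>& strides) {
  assert(alloc_index.size() == strides.size());
  int64_t flat = 0;
  for (size_t i = 0; i < alloc_index.size(); ++i) {
    flat += alloc_index[i] * strides[i];
  }
  return flat;
}

int main() {
  // A contiguous [I, K] = [10, 3] tensor has strides {K, 1} = {3, 1}, so
  // flattening {i0, k} gives i0 * K + k, matching the FLATTENED_INDEX
  // example at the top of the header.
  const std::vector<int64_t> strides = {3, 1};
  assert(flatten({2, 1}, strides) == 2 * 3 + 1);
  // Non-contiguous strides (e.g. a padded row of 8 elements) only change the
  // multipliers; the index expression has the same shape.
  assert(flatten({2, 1}, {8, 1}) == 17);
  return 0;
}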