nvfuser-cu121-torch25 0.2.25.dev20250201__cp310-cp310-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,577 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <disjoint_set.h>
11
+ #include <exceptions.h>
12
+ #include <ir/all_nodes.h>
13
+ #include <iter_visitor.h>
14
+ #include <utils.h>
15
+ #include <visibility.h>
16
+
17
+ namespace nvfuser {
18
+
19
+ //! Generic interface for mapping logical domains of a producer-consumer pair.
20
+ class LogicalDomainMap : public PolymorphicBase {
21
+ public:
22
+ //! Return a map from a producer TensorDomain to a consumer
23
+ //! TensorDomain
24
+ //!
25
+ //! \param producer A producer TensorDomain
26
+ //! \param consumer A consumer TensorDomain
27
+ //! \param dims_to_map Maps only producer logical domains in this set
28
+ std::unordered_map<IterDomain*, IterDomain*> mapProducerToConsumer(
29
+ const TensorDomain* producer,
30
+ const TensorDomain* consumer,
31
+ const std::unordered_set<IterDomain*>& dims_to_map) const;
32
+
33
+ //! Return a map from a producer TensorDomain to a consumer
34
+ //! TensorDomain
35
+ //!
36
+ //! \param producer A producer TensorDomain
37
+ //! \param consumer A consumer TensorDomain
38
+ std::unordered_map<IterDomain*, IterDomain*> mapProducerToConsumer(
39
+ const TensorDomain* producer,
40
+ const TensorDomain* consumer) const;
41
+
42
+ //! Return a map from a consumer TensorDomain to a producer
43
+ //! TensorDomain
44
+ //!
45
+ //! \param consumer A consumer TensorDomain
46
+ //! \param producer A producer TensorDomain
47
+ //! \param dims_to_map Maps only consumer root domains in this set
48
+ std::unordered_map<IterDomain*, IterDomain*> mapConsumerToProducer(
49
+ const TensorDomain* consumer,
50
+ const TensorDomain* producer,
51
+ const std::unordered_set<IterDomain*>& dims_to_map) const;
52
+
53
+ //! Return a map from a consumer TensorDomain to a producer
54
+ //! TensorDomain
55
+ //!
56
+ //! \param consumer A consumer TensorDomain
57
+ //! \param producer A producer TensorDomain
58
+ std::unordered_map<IterDomain*, IterDomain*> mapConsumerToProducer(
59
+ const TensorDomain* consumer,
60
+ const TensorDomain* producer) const;
61
+
62
+ protected:
63
+ //! Return a map between logical IterDomains of a producer-consumer
64
+ //! pair.
65
+ //!
66
+ //! \param producer A producer TensorDomain
67
+ //! \param consumer A consumer TensorDomain
68
+ //! \param dims_to_map Maps only from IterDomains in this set
69
+ //! \param producer_to_consumer Maps from producer to consumer if true
70
+ virtual std::unordered_map<IterDomain*, IterDomain*> map(
71
+ const TensorDomain* producer,
72
+ const TensorDomain* consumer,
73
+ const std::unordered_set<IterDomain*>& dims_to_map,
74
+ bool producer_to_consumer) const = 0;
75
+ };
76
+
77
+ //! Maps logical domains of a producer-consumer pair. This class only
78
+ //! looks at the given pair of TensorViews and does not take into
79
+ //! consideration the constraints of the computeAt transformation,
80
+ //! i.e., unable to compute the same tensors multiple times. This
81
+ //! should not be used for transformations implementing computeAt, but
82
+ //! should be valid otherwise.
83
+ class PairwiseLogicalDomainMap : public LogicalDomainMap {
84
+ public:
85
+ //! When require_same_extent is false, domains that may have
86
+ //! different extents are also mapped. For example, IDs of lookup
87
+ //! tensors in gather may have larger extents than the corresponding
88
+ //! IDs of the output and index tensors. This relaxation is
89
+ //! necessary when indexing into lookup tensors as producers.
90
+ //!
91
+ //! \param producer The producer tensor of a producer-consumer pair.
92
+ //! \param consumer The consumer tensor of a producer-consumer pair.
93
+ PairwiseLogicalDomainMap(
94
+ const TensorView* producer,
95
+ const TensorView* consumer);
96
+
97
+ PairwiseLogicalDomainMap& mapBroadcast(bool b) {
98
+ map_broadcast_ = b;
99
+ return *this;
100
+ }
101
+
102
+ //! If b is true: map symbolic domains with other IterDomains even if their
103
+ //! extents don't match. If b is false (default): map symbolic domains with
104
+ //! other IterDomains only if their extents match.
105
+ PairwiseLogicalDomainMap& mapSymbolic(bool b) {
106
+ map_symbolic_ = b;
107
+ return *this;
108
+ }
109
+
110
+ PairwiseLogicalDomainMap& mapDifferentExtents(bool b) {
111
+ map_different_extents_ = b;
112
+ return *this;
113
+ }
114
+
115
+ PairwiseLogicalDomainMap& mapIndexedDomains(bool b) {
116
+ map_indexed_domains_ = b;
117
+ return *this;
118
+ }
119
+
120
+ const TensorView* producerTv() const {
121
+ return producer_tv_;
122
+ }
123
+
124
+ const TensorView* consumerTv() const {
125
+ return consumer_tv_;
126
+ }
127
+
128
+ std::string toString() const;
129
+
130
+ // Helper methods on top of LogicalDomainMap::mapProducerToConsumer and
131
+ // LogicalDomainMap::mapConsumerToProducer. This way, the caller doesn't have
132
+ // to specify the producer domain and the consumer domain, which is redundant
133
+ // and error-prone.
134
+ std::unordered_map<IterDomain*, IterDomain*> mapProducerToConsumer(
135
+ const std::unordered_set<IterDomain*>* dims_to_map = nullptr) const;
136
+ std::unordered_map<IterDomain*, IterDomain*> mapConsumerToProducer(
137
+ const std::unordered_set<IterDomain*>* dims_to_map = nullptr) const;
138
+
139
+ protected:
140
+ std::unordered_map<IterDomain*, IterDomain*> map(
141
+ const TensorDomain* producer,
142
+ const TensorDomain* consumer,
143
+ const std::unordered_set<IterDomain*>& dims_to_map,
144
+ bool producer_to_consumer) const override;
145
+
146
+ private:
147
+ const TensorView* producer_tv_ = nullptr;
148
+ const TensorView* consumer_tv_ = nullptr;
149
+
150
+ //! Options to allow more permissive mappings
151
+
152
+ //! Map broadcast and non-broadcast domains. Note that this is on by
153
+ //! default
154
+ bool map_broadcast_ = true;
155
+ //! Map symbolic domains with other IterDomains, even if their extents don't
156
+ //! match. Note that this is off by default, in which case they are mapped
157
+ //! only if their extents match.
158
+ bool map_symbolic_ = false;
159
+ //! Map domains that may have different extents, e.g., torchGather
160
+ bool map_different_extents_ = false;
161
+ //! Map domains that are indirectly accessed, e.g., indexSelect
162
+ bool map_indexed_domains_ = false;
163
+ };
164
+
165
+ //! Represents an iteration domain of a TensorDomain. Only used for
166
+ //! logical domain mapping.
167
+ //!
168
+ //! Note that an IterDomain object may be reused
169
+ //! across multiple TensorDomains, but an IterDomain in a
170
+ //! TensorDomain may not be necessarily mappable to the same
171
+ //! IterDomain used in a different TensorDomain. Thus, for the purpose
172
+ //! of logical domain mapping, an iteration domain needs to be identified
173
+ //! with an IterDomain and its TensorDomain.
174
+ class DomainKey {
175
+ public:
176
+ DomainKey() = default;
177
+ DomainKey(
178
+ const TensorDomain* td,
179
+ const IterDomain* id,
180
+ const IterDomain* concrete_id = nullptr)
181
+ : td_(td), id_(id), concrete_id_(concrete_id) {}
182
+ const TensorDomain* td() const {
183
+ return td_;
184
+ }
185
+ const IterDomain* id() const {
186
+ return id_;
187
+ }
188
+ const IterDomain* concreteId() const {
189
+ return concrete_id_;
190
+ }
191
+ bool operator==(const DomainKey& other) const {
192
+ return td() == other.td() && id() == other.id() &&
193
+ concreteId() == other.concreteId();
194
+ }
195
+ bool operator!=(const DomainKey& other) const {
196
+ return !(*this == other);
197
+ }
198
+
199
+ std::string toString() const;
200
+
201
+ private:
202
+ const TensorDomain* td_ = nullptr;
203
+ const IterDomain* id_ = nullptr;
204
+ const IterDomain* concrete_id_ = nullptr;
205
+ };
206
+
207
+ struct DomainKeyHash {
208
+ std::size_t operator()(const DomainKey& key) const {
209
+ return std::hash<const TensorDomain*>{}(key.td()) ^
210
+ std::hash<const IterDomain*>{}(key.id());
211
+ }
212
+ };
213
+
214
+ using DomainKeySet = std::unordered_set<DomainKey, DomainKeyHash>;
215
+
216
+ template <typename Mapped>
217
+ using DomainKeyMap = std::unordered_map<DomainKey, Mapped, DomainKeyHash>;
218
+
219
+ class ComputeAtLogicalDomainMap;
220
+
221
+ //! A helper class to find all DomainKeys that are consumers of
222
+ //! reduction outputs. Such consumer IterDomains may not be mapped to
223
+ //! the producer reduction domain since the corresponding reduction
224
+ //! loop must be closed before any of the consumers can appear.
225
+ class UnmappableReductionDomains : private IterVisitor {
226
+ public:
227
+ UnmappableReductionDomains();
228
+ ~UnmappableReductionDomains() override = default;
229
+
230
+ //! Returns true when mapping consumer domains would cause a
231
+ //! reduction output domain to be mapped with a consumer domain of
232
+ //! the redution. It needs to be avoided as computing consumers of
233
+ //! reduction outputs within the corresponding reduction loop is not
234
+ //! possible. This routine is used to build logical domain mappings.
235
+ bool isReductionOutputMapped(
236
+ const DomainKeySet& consumer_domains,
237
+ const ComputeAtLogicalDomainMap& logical_map) const;
238
+
239
+ std::string toString() const;
240
+
241
+ private:
242
+ using IterVisitor::handle;
243
+ void handle(ReductionOp* op) override;
244
+ void handle(GroupedReductionOp* op) override;
245
+ void handle(WelfordOp* op) override;
246
+ void handle(MmaOp* op) override;
247
+
248
+ void handleReductionOutput(TensorView* out_tv);
249
+
250
+ private:
251
+ //! Map from Reduction output DomainKeys to consumer DomainKeys
252
+ DomainKeyMap<DomainKeySet> reduction_domains_;
253
+ //! Map from Reduction output DomainKeys to producer DomainKeys
254
+ DomainKeyMap<DomainKeySet> reduction_domain_inputs_;
255
+ };
256
+
257
+ //! Models logical-domain mappings for computeAt
258
+ //!
259
+ //! Two iteration domains are mapped when computeAt of one iteration
260
+ //! domain is possible at another iteration domain. Consider a simple
261
+ //! example:
262
+ //! T2 [i0,i1] = T1[i2,i3] + T0[i4,i5]
263
+ //! This will create mappings between i0, i2 and i4.
264
+ //!
265
+ //! Note that with views, there can be multiple domains mapped with
266
+ //! the same domain. Thus, obtaining one-to-one maps can
267
+ //! fail. Currently, the only use of this class is getMappableDims,
268
+ //! which just grabs any domain that is mappable, which works no
269
+ //! matter view is used or not.
270
+ class NVF_API ComputeAtLogicalDomainMap : public LogicalDomainMap {
271
+ friend class ComputeAtLogicalDomainMapBuilder;
272
+
273
+ public:
274
+ //! Builds a mapping table by analyzing the current
275
+ //! fusion. Overwrite a previous table if any.
276
+ //!
277
+ //! \param map_through_reduction If set
278
+ //! true, will disable UnmappableReductionDomains check.
279
+ //! This is only for re-using logic in detecting
280
+ //! normalization fusions, which deviates slightly from
281
+ //! intended use of this class. Should always be true
282
+ //! in compute_at use cases.
283
+ void build(bool map_through_reduction = false);
284
+
285
+ //! Returns if key(td_a, id_a) and key(td_b, id_b) are mapped to eachother
286
+ //! (equivalent), or are the same key.
287
+ //!
288
+ //! \param td_a A TensorDomain
289
+ //! \param id_a An IterDomain in td_a
290
+ //! \param td_b Another TensorDomain
291
+ //! \param id_b An IterDomain in td_b
292
+ //! \returns Boolean representing if they are mapped
293
+ bool canMap(
294
+ const TensorDomain* td_a,
295
+ const IterDomain* id_a,
296
+ const TensorDomain* td_b,
297
+ const IterDomain* id_b) const;
298
+
299
+ //! Make a TensorDomain an alias of another TensorDomain
300
+ //!
301
+ //! This is for the computeAt transformation, where TensorViews are
302
+ //! updated with new TensorDomains. Since they keep using the same
303
+ //! logical doamins, the logical mapping remains valid but needs to
304
+ //! reflect the use of new TensorDomains as aliases of the existing
305
+ //! ones.
306
+ //!
307
+ //! \param td An existing TensorDomain
308
+ //! \param td_alias An alias of td
309
+ void setAlias(const TensorDomain* td, const TensorDomain* td_alias);
310
+
311
+ //! Return a map between TensorDomains
312
+ //!
313
+ //! Unlike the other map functions, two TensorDomains do not need to
314
+ //! be a producer-consumer pair. Since they may not be a
315
+ //! producer-consumer pair, this function requires proper domains, which may
316
+ //! be root or logical domains. Also, no error check is done as we do not
317
+ //! assume producer-consumer relationship.
318
+ //!
319
+ //! Note that an exception is thrown when a domain is found to be
320
+ //! mapped to multiple domains, which can happen with views.
321
+ //!
322
+ //! \param from_td A TensorDomain from which a map is created
323
+ //! \param from_dom A root/logical domain of from_td
324
+ //! \param to_td A TensorDomain to which a map is created
325
+ //! \param to_dom A root/logical domain of to_td
326
+ std::unordered_map<IterDomain*, IterDomain*> mapBestEffort(
327
+ const TensorDomain* from_td,
328
+ const std::vector<IterDomain*>& from_dom,
329
+ const TensorDomain* to_td,
330
+ const std::vector<IterDomain*>& to_dom) const;
331
+
332
+ // Returns an unordered set of all iter domains in producer and consumer that
333
+ // can map to eachother
334
+ std::unordered_set<IterDomain*> getMappableDims(
335
+ const TensorDomain* producer,
336
+ const TensorDomain* consumer) const;
337
+
338
+ std::string toString() const;
339
+
340
+ //! Returns true if id in td is concretized
341
+ bool isConcretized(const TensorDomain* td, const IterDomain* id) const;
342
+
343
+ private:
344
+ //! Returns if key_a and key(td_b, id_b) are mapped to eachother (equivalent),
345
+ //! or are the same key.
346
+ //!
347
+ //! \param key_a A DomainKey
348
+ //! \param td_b Another TensorDomain
349
+ //! \param id_b An IterDomain in td_b
350
+ //! \returns Boolean representing if they are mapped
351
+ bool canMap(
352
+ const DomainKey& key_a,
353
+ const TensorDomain* td_b,
354
+ const IterDomain* id_b) const;
355
+
356
+ //! Returns if key_a and key_b are mapped to each other (equivalent), or are
357
+ //! the same key. Returns false if two keys are not known to be mapped.
358
+ bool canMap(const DomainKey& key_a, const DomainKey& key_b) const;
359
+
360
+ //! Returns the set of (non-broadcast) DomainKeys that id in td is
361
+ //! broadcasted to. Can result in more than one "concrete" DomainKey.
362
+ std::vector<DomainKey> getConcretizedKeys(
363
+ const TensorDomain* td,
364
+ const IterDomain* id) const;
365
+
366
+ //! Returns the set of (non-broadcast) iter domains that id in td is
367
+ //! broadcasted to. Can result in more than one "concrete" iter domain.
368
+ std::unordered_set<const IterDomain*>& getConcretizedDomains(
369
+ const TensorDomain* td,
370
+ const IterDomain* id);
371
+
372
+ //! Return a map between logical IterDomains of a producer-consumer
373
+ //! pair.
374
+ //!
375
+ //! \param producer A producer TensorDomain
376
+ //! \param consumer A consumer TensorDomain
377
+ //! \param dims_to_map Maps only from IterDomains in this set
378
+ //! \param producer_to_consumer Maps from producer to consumer if true
379
+ std::unordered_map<IterDomain*, IterDomain*> map(
380
+ const TensorDomain* producer,
381
+ const TensorDomain* consumer,
382
+ const std::unordered_set<IterDomain*>& dims_to_map,
383
+ bool producer_to_consumer) const override;
384
+
385
+ private:
386
+ //! Disjoint set of all mapped <TD, ID> keys to determine axes equivalency
387
+ DisjointSets<DomainKey, DomainKeyHash> eq_set_;
388
+
389
+ //! All IterDomains in the mapping that are a broadcast ID
390
+ DomainKeyMap<std::unordered_set<const IterDomain*>> bcast_map_;
391
+
392
+ //! Broadcast iter domain that does not match dimensions in its produer,
393
+ //! meaning it is a brand new domain in its TensorDomain.
394
+ DomainKeySet new_broadcast_domains_;
395
+
396
+ //! Broadcast iter domain that does not match dimensions in its consumer,
397
+ //! meaning it is a removed domain in its TensorDomain.
398
+ DomainKeySet removed_broadcast_domains_;
399
+
400
+ //! Keep track of window axes so that the map function can ignore them.
401
+ std::unordered_set<IterDomain*> window_axes_;
402
+ };
403
+
404
//! Create a DisjointSets of logical IterDomains by traversing the
//! current fusion entirely. IterDomains that can be mapped to each
//! other with computeAt are grouped into the same subset in the
//! DisjointSets.
class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor {
 public:
  //! Builds the given map by traversing the current fusion backward.
  //!
  //! \param logical_map The map to populate; mutated in place
  //! \param map_through_reduction Disables the UnmappableReductionDomains
  //!   check; should always be false for compute_at use cases
  explicit ComputeAtLogicalDomainMapBuilder(
      ComputeAtLogicalDomainMap& logical_map,
      bool map_through_reduction = false);

 private:
  //! Initialize the bcast map for fusion outputs
  void initializeBcastMap(const TensorView* tv, const IterDomain* id);

  //! Set a pair of producer-consumer domain keys as mappable
  void setMapped(const DomainKey& producer, const DomainKey& consumer);

  //! Records two domains are invalid to map
  void setInvalid(const DomainKey& key1, const DomainKey& key2);

  //! Check that no pair of domains in the given set has been recorded
  //! as invalid to map
  bool isInvalid(const DomainKeySet& domains) const;

  //! Track a pair of producer-consumer domains as potentially mappable. Inserts
  //! entries into pending_map_, but does not add anything into the logical_map_
  //! (added when handle is called on a TensorView). Maybe mapped will, however,
  //! immediately propagate broadcast iter domains.
  void setMaybeMapped(
      const TensorDomain* producer_td,
      const IterDomain* producer_id,
      const TensorDomain* consumer_td,
      const IterDomain* consumer_id);

  //! Queue a producer-consumer key pair in pending_map_
  void addToPendingList(const DomainKey& producer, const DomainKey& consumer);

  //! Map pointwise IterDomains from inputs of expressions to outputs.
  //! Do not map reduction IterDomains in inputs.
  void mapPointwiseLikeOp(Expr* e);

  using BackwardVisitor::handle;

  void dispatch(Expr* e) override;

  // The handlers below dispatch on concrete Expr types during the
  // backward traversal. Most op types map their input and output
  // IterDomains pointwise via mapPointwiseLikeOp; the remaining ones
  // (RNGOp, ViewAsScalar, BroadcastOp, SqueezeOp, TensorView) are
  // declared here and defined out of line.

  void handle(UnaryOp* uop) override {
    mapPointwiseLikeOp(uop);
  }

  void handle(BinaryOp* bop) override {
    mapPointwiseLikeOp(bop);
  }

  void handle(TernaryOp* top) override {
    mapPointwiseLikeOp(top);
  }

  void handle(RNGOp* top) override;

  void handle(SelectOp* op) override {
    mapPointwiseLikeOp(op);
  }

  void handle(IndexSelectOp* op) override {
    mapPointwiseLikeOp(op);
  }

  void handle(TorchGatherOp* op) override {
    mapPointwiseLikeOp(op);
  }

  void handle(ReductionOp* op) override {
    mapPointwiseLikeOp(op);
  }

  void handle(GroupedReductionOp* op) override {
    mapPointwiseLikeOp(op);
  }

  void handle(WelfordOp* wop) override {
    mapPointwiseLikeOp(wop);
  }

  void handle(LoadStoreOp* ldst) override {
    mapPointwiseLikeOp(ldst);
  }

  void handle(MmaOp* wop) override {
    mapPointwiseLikeOp(wop);
  }

  void handle(ViewOp* op) override {
    mapPointwiseLikeOp(op);
  }

  void handle(ViewAsScalar* op) override;

  void handle(BroadcastOp* op) override;

  void handle(SqueezeOp* op) override;

  void handle(ExpandOp* op) override {
    mapPointwiseLikeOp(op);
  }

  void handle(RepeatOp* op) override {
    mapPointwiseLikeOp(op);
  }

  void handle(PadOp* op) override {
    // For compute-at, padded id should be mapped
    mapPointwiseLikeOp(op);
  }

  void handle(SliceOp* op) override {
    mapPointwiseLikeOp(op);
  }

  void handle(CatOp* op) override {
    // For compute-at, concat id should be mapped
    mapPointwiseLikeOp(op);
  }

  void handle(TensorView* tv) override;

  //! Maps all pending mappings.
  //! This is called for each of TensorViews in a backward traversal,
  //! recursively building mappings from the output tensors to the
  //! input tensors.
  void mapAllPendingMappings(const DomainKey& key);

  //! Maps all pending mappings for id of td. When id is a broadcast,
  //! mapping is done separately for each concrete domain.
  void mapAllPendingMappings(const TensorDomain* td, IterDomain* id);

  //! Check if it is safe to map all domains in the given set together
  //! (see setInvalid / isInvalid)
  bool safeToMap(const DomainKeySet& domains);

 private:
  //! The map being built; owned by the caller of the constructor
  ComputeAtLogicalDomainMap& logical_map_;
  //! Keep track of what we want to try and map
  DomainKeyMap<DomainKeySet> pending_map_;
  //! Expressions already processed by the traversal
  std::unordered_set<Expr*> visited_;
  //! Helper class to find invalid mappings due to reductions
  UnmappableReductionDomains incompatible_domains_;
  //! Running vector of domain pairs that are invalid to map
  std::vector<std::pair<DomainKey, DomainKey>> invalid_mappings_;

  //! Disable UnmappableReductions check, should
  //! always be false for compute_at use cases
  bool map_through_reduction_ = false;
};
553
+
554
//! Maps logical domains of an entire fusion. Does not map broadcast
//! domains with non-broadcast domains.
class NVF_API ExactLogicalDomainMap : public LogicalDomainMap {
 public:
  //! Builds the exact map for the given fusion.
  // NOTE(review): single-argument constructor is not explicit — confirm
  // implicit conversion from Fusion* is intended.
  ExactLogicalDomainMap(Fusion* fusion);

  //! Whether the two IterDomains are exactly mapped to each other
  bool areMapped(const IterDomain* id_a, const IterDomain* id_b) const;

  std::string toString() const;

  //! The underlying disjoint sets of mapped IterDomains
  const DisjointSets<const IterDomain*>& getMappedSets() const;

 protected:
  //! Return a map between logical IterDomains of a producer-consumer
  //! pair.
  //!
  //! \param producer A producer TensorDomain
  //! \param consumer A consumer TensorDomain
  //! \param dims_to_map Maps only from IterDomains in this set
  //! \param producer_to_consumer Maps from producer to consumer if true
  std::unordered_map<IterDomain*, IterDomain*> map(
      const TensorDomain* producer,
      const TensorDomain* consumer,
      const std::unordered_set<IterDomain*>& dims_to_map,
      bool producer_to_consumer) const override;

 private:
  //! Disjoint sets of exactly-mapped IterDomains
  DisjointSets<const IterDomain*> eq_sets_;
};
576
+
577
+ } // namespace nvfuser
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on

#pragma once

// Standard-library / language-level feature detection helpers.

// <bits/c++config.h> is the libstdc++ internal header that defines the
// __GLIBCXX__ version macro; guard the include because other standard
// libraries (e.g. libc++, MSVC STL) do not provide it.
#if __has_include(<bits/c++config.h>)
#include <bits/c++config.h>
#endif

// Defined when the detected libstdc++ is recent enough (release date
// 2023-07-14 or later, per __GLIBCXX__) to, as the macro name states,
// support std::unordered_set with an incomplete value type.
// NOTE(review): left undefined (not 0) otherwise — check with #ifdef,
// not #if equality against 0.
#if defined(__GLIBCXX__) && __GLIBCXX__ >= 20230714
#define STD_UNORDERED_SET_SUPPORTS_INCOMPLETE_TYPE 1
#endif

// IS_CPP20: 1 when compiling as C++20 or later, 0 otherwise
// (202002L is the __cplusplus value for C++20).
#if __cplusplus < 202002L
#define IS_CPP20 0
#else
#define IS_CPP20 1
#endif