nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,87 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <ir/all_nodes.h>
11
+ #include <val_graph.h>
12
+
13
+ #include <string>
14
+ #include <vector>
15
+
16
+ namespace nvfuser {
17
+
18
+ std::string toString(const std::vector<Val*>& val_group, int indent_size = 0);
19
+
20
+ std::string toString(
21
+ const std::vector<IterDomain*>& id_group,
22
+ int indent_size = 0);
23
+
24
+ std::string toString(
25
+ const ValGroup& id_group,
26
+ int indent_size = 0,
27
+ bool with_ptr = false);
28
+
29
+ std::string toString(
30
+ const std::vector<ValGroup>& id_groups,
31
+ int indent_size = 0,
32
+ bool with_ptr = false);
33
+
34
+ std::string toString(
35
+ const ValGroups& id_groups,
36
+ int indent_size = 0,
37
+ bool with_ptr = false);
38
+
39
+ std::string toInlineString(const std::vector<ValGroup>& id_groups);
40
+ std::string toInlineString(const ValGroups& id_groups);
41
+
42
+ std::string toString(const std::vector<Expr*>& expr_group, int indent_size = 0);
43
+ std::string toString(
44
+ const ExprGroup& expr_group,
45
+ int indent_size = 0,
46
+ bool with_ptr = false);
47
+
48
+ std::string toString(
49
+ const ValGraph& id_graph,
50
+ const std::vector<Expr*>& expr_group,
51
+ int indent_size = 0,
52
+ bool with_ptr = false);
53
+ std::string toString(
54
+ const ValGraph& id_graph,
55
+ const ExprGroup& expr_groups,
56
+ int indent_size = 0,
57
+ bool with_ptr = false);
58
+
59
+ std::string toString(
60
+ const ValGraph& id_graph,
61
+ const std::vector<ExprGroup>& expr_groups,
62
+ int indent_size = 0,
63
+ bool with_ptr = false);
64
+ std::string toString(
65
+ const ValGraph& id_graph,
66
+ const ExprGroups& expr_groups,
67
+ int indent_size = 0,
68
+ bool with_ptr = false);
69
+
70
+ std::string idGroupsString(
71
+ const ValGraph& id_graph,
72
+ int indent_size = 0,
73
+ bool with_ptr = false);
74
+ std::string exprGroupsString(
75
+ const ValGraph& id_graph,
76
+ int indent_size = 0,
77
+ bool with_ptr = false);
78
+ std::string definitionsString(
79
+ const ValGraph& id_graph,
80
+ int indent_size = 0,
81
+ bool with_ptr = false);
82
+ std::string usesString(
83
+ const ValGraph& id_graph,
84
+ int indent_size = 0,
85
+ bool with_ptr = false);
86
+
87
+ } // namespace nvfuser
@@ -0,0 +1,58 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <c10/macros/Export.h>
11
+
12
+ #include <ir/all_nodes.h>
13
+
14
+ #include <unordered_map>
15
+ #include <vector>
16
+
17
+ namespace nvfuser {
18
+
19
+ // TODO: Consider merging this class with the existing replay
20
+ // classes. The use cases are not exactly the same, so it isn't
21
+ // immediately clear if they could be trivially merged.
22
+ class ReplayTransform : OptInConstDispatch {
23
+ public:
24
+ // Replays expression_to_match with the provided ordered_inputs. Inputs should
25
+ // be ordered as they would be used in provided expression. Returns new
26
+ // replayed expression.
27
+ static Expr* replayAs(
28
+ const std::vector<IterDomain*>& ordered_inputs,
29
+ const Expr* expression_to_match);
30
+
31
+ private:
32
+ ReplayTransform(
33
+ const std::vector<IterDomain*>& ordered_inputs,
34
+ const Expr* expression_to_match);
35
+
36
+ using OptInConstDispatch::handle;
37
+
38
+ // We're going to replay this split operation on the corresponding ID
39
+ void handle(const Split* split) final;
40
+
41
+ // We're going to replay this merge operation on the corresponding IDs
42
+ void handle(const Merge* merge) final;
43
+
44
+ // We're going to replay this swizzle operation on the corresponding IDs
45
+ // if replaying swizzle is enabled.
46
+ void handle(const Swizzle2D* swizzle_2d) final;
47
+
48
+ void handle(const Swizzle* swizzle) final;
49
+
50
+ // We're going to replay this resize operation on the corresponding IDs
51
+ // if replaying resize is enabled.
52
+ void handle(const Resize* resize) final;
53
+
54
+ Expr* replayed_expr_ = nullptr;
55
+ const std::vector<IterDomain*>& input_ids_;
56
+ };
57
+
58
+ } // namespace nvfuser
@@ -0,0 +1,176 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <expr_simplifier.h>
11
+ #include <id_model/id_model.h>
12
+ #include <id_model/to_string.h>
13
+ #include <ir/utils.h>
14
+ #include <options.h>
15
+ #include <utils.h>
16
+
17
+ #include <functional>
18
+ #include <iostream>
19
+ #include <sstream>
20
+
21
+ namespace nvfuser {
22
+
23
+ // Options to enable the IdModel-based tensor indexer selectively
24
+ enum class IdModelEnableOption {
25
+ ConsumerIndex,
26
+ ProducerIndex,
27
+ InlinePredicate,
28
+ UnswitchPredicate,
29
+ // Uses the loop promotion to generate loops. Indexing and
30
+ // predication need to be enabled as well.
31
+ Loop,
32
+ };
33
+
34
+ inline std::unordered_set<IdModelEnableOption> getIdModelEnabledOptions() {
35
+ std::unordered_set<IdModelEnableOption> opts;
36
+
37
+ if (hasEnableOptionArgument(EnableOption::IdModel, "consumer_index") ||
38
+ hasEnableOptionArgument(EnableOption::IdModel, "index") ||
39
+ hasEnableOptionArgument(EnableOption::IdModel, "all")) {
40
+ opts.insert(IdModelEnableOption::ConsumerIndex);
41
+ }
42
+
43
+ if (hasEnableOptionArgument(EnableOption::IdModel, "producer_index") ||
44
+ hasEnableOptionArgument(EnableOption::IdModel, "index") ||
45
+ hasEnableOptionArgument(EnableOption::IdModel, "all")) {
46
+ opts.insert(IdModelEnableOption::ProducerIndex);
47
+ }
48
+
49
+ if (hasEnableOptionArgument(EnableOption::IdModel, "inline_predicate") ||
50
+ hasEnableOptionArgument(EnableOption::IdModel, "predicate") ||
51
+ hasEnableOptionArgument(EnableOption::IdModel, "all")) {
52
+ opts.insert(IdModelEnableOption::InlinePredicate);
53
+ }
54
+
55
+ if (hasEnableOptionArgument(EnableOption::IdModel, "unswitch_predicate") ||
56
+ hasEnableOptionArgument(EnableOption::IdModel, "predicate") ||
57
+ hasEnableOptionArgument(EnableOption::IdModel, "all")) {
58
+ opts.insert(IdModelEnableOption::UnswitchPredicate);
59
+ }
60
+
61
+ if (hasEnableOptionArgument(EnableOption::IdModel, "loop") ||
62
+ hasEnableOptionArgument(EnableOption::IdModel, "all")) {
63
+ opts.insert(IdModelEnableOption::Loop);
64
+ }
65
+
66
+ // Loop requires ConsumerIndex, ProducerIndex, InlinePredicate and
67
+ // UnswitchPredicate
68
+ if (opts.find(IdModelEnableOption::Loop) != opts.end()) {
69
+ NVF_ERROR(
70
+ opts.find(IdModelEnableOption::ConsumerIndex) != opts.end(),
71
+ "ConsumerIndex required for Loop");
72
+ NVF_ERROR(
73
+ opts.find(IdModelEnableOption::ProducerIndex) != opts.end(),
74
+ "ProducerIndex required for Loop");
75
+ NVF_ERROR(
76
+ opts.find(IdModelEnableOption::InlinePredicate) != opts.end(),
77
+ "InlinePredicate required for Loop");
78
+ NVF_ERROR(
79
+ opts.find(IdModelEnableOption::UnswitchPredicate) != opts.end(),
80
+ "UnswitchPredicate required for Loop");
81
+ }
82
+
83
+ return opts;
84
+ }
85
+
86
+ inline bool isIdModelOptionEnabled(IdModelEnableOption option) {
87
+ const auto opts = getIdModelEnabledOptions();
88
+ return opts.find(option) != opts.end();
89
+ }
90
+
91
+ // Get the promotion domain of a given loop domain.
92
+ inline IterDomain* getLoopPromotion(
93
+ IterDomain* loop_id,
94
+ const IdModel& id_model) {
95
+ const auto& loop_graph = id_model.idGraph(IdMappingMode::LOOP);
96
+ const auto& loop_promotion_map = id_model.loopPromotionMap();
97
+ const auto& loop_group = loop_graph.toGroup(loop_id);
98
+
99
+ auto loop_promotion_map_it = loop_promotion_map.find(loop_group);
100
+ NVF_ERROR(
101
+ loop_promotion_map_it != loop_promotion_map.end(),
102
+ "No loop promotion found: ",
103
+ loop_id->toString(),
104
+ ". Loop group: ",
105
+ nvfuser::toString(loop_group));
106
+
107
+ return loop_promotion_map_it->second;
108
+ }
109
+
110
+ // Get the loop domains of a given expr. Currently, they're always
111
+ // the loop domains of a consumer tensor, but in the future this
112
+ // function may return the loop domains of a producer for
113
+ // producer-based indexing.
114
+ inline std::vector<IterDomain*> getLoopIds(
115
+ const Expr* expr,
116
+ const IdModel& id_model) {
117
+ // Assume consumer-based indexing. Needs to revisit for ops like
118
+ // scatter
119
+ NVF_ERROR(!expr->outputs().empty());
120
+ auto output_tv = ir_utils::getTvOutput(expr);
121
+ NVF_ERROR(output_tv != nullptr);
122
+ auto loop_ids = output_tv->getLoopDomain();
123
+
124
+ for (auto& loop_id : loop_ids) {
125
+ loop_id = getLoopPromotion(loop_id, id_model);
126
+ }
127
+
128
+ return loop_ids;
129
+ }
130
+
131
+ inline ParallelType getParallelType(const ValGroup& loop_group) {
132
+ ParallelType common_pt = ParallelType::Serial;
133
+ for (const auto val : *loop_group) {
134
+ auto pt = val->as<IterDomain>()->getParallelType();
135
+ if (common_pt == pt || pt == ParallelType::Serial) {
136
+ continue;
137
+ } else if (common_pt == ParallelType::Serial) {
138
+ common_pt = pt;
139
+ } else {
140
+ // Inconsistent parallelization
141
+ NVF_THROW(
142
+ "Inconsistent parallelization detected. ",
143
+ "Known type: ",
144
+ common_pt,
145
+ "New type: ",
146
+ pt);
147
+ }
148
+ }
149
+
150
+ return common_pt;
151
+ }
152
+
153
+ // Check if the loop index of a loop group should be always
154
+ // just zero. For example, a loop group with an extent of one, i.e.,
155
+ // a broadcast-only loop group, should just use zero.
156
+ inline bool shouldUseZeroIndex(
157
+ const ValGroup& loop_group,
158
+ const IdModel& id_model) {
159
+ // Trivial loop
160
+ auto promotion_id =
161
+ getLoopPromotion(loop_group->front()->as<IterDomain>(), id_model);
162
+
163
+ // ExprSimplify should be disabled here as it would fail to
164
+ // recognize size-one IterDomain.
165
+ DisableOptionsGuard options_guard;
166
+ DisableOptionsGuard::getCurOptions().unset(DisableOption::ExprSimplify);
167
+
168
+ if (promotion_id->isBroadcast() ||
169
+ simplifyExpr(promotion_id->extent())->isOneInt()) {
170
+ return true;
171
+ }
172
+
173
+ return false;
174
+ }
175
+
176
+ } // namespace nvfuser
@@ -0,0 +1,55 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <compute_at_map.h>
11
+ #include <id_model/id_model.h>
12
+ #include <val_graph.h>
13
+
14
+ namespace nvfuser {
15
+
16
+ // Note that this class is a friend of ComputeAtMap as it needs to
17
+ // have private access
18
+ class IdModelValidator {
19
+ public:
20
+ IdModelValidator(Fusion* fusion, bool allow_self_mapping = false);
21
+
22
+ // Validate a given exact graph of IdModel by comparing it with
23
+ // ComputeAtMap. Their maps should
24
+ // be almost the same but there are some differences.
25
+ // - In ComputeAtMap, swizzles are just skipped no matter what swizzle
26
+ // type is used, so only swizzle outputs are mapped. In IdModel,
27
+ // only swizzle inputs are mapped, except for Loop swizzles where
28
+ // their inputs and outputs are mapped.
29
+ // - In ComputeAtMap, mappings are local. For example, if domain x0 is
30
+ // split to x1 and x2, and also domain y0 is split to y1 and
31
+ // y2. Suppose x0 and y0 are exactly mapped and the two splits are
32
+ // also considered exactly the same, IdModel maps x1 and y1, and x2
33
+ // and y2, respectively, whereas that doesn't happen with ComputeAtMap
34
+ //
35
+ // Accounting for the first difference doesn't seem trivial, so when
36
+ // swizzle is used we give up validating the exact graph. The second
37
+ // difference is whether mappings are propagated, which can be
38
+ // accounted for by updating the ComputeAtMap as is done in IdModel.
39
+ void checkExactGraphEquivalence(const ValGraph& exact_graph);
40
+
41
+ void checkAlmostExactGraphEquivalence(const ValGraph& almost_exact_graph);
42
+
43
+ void checkPermissiveGraphEquivalence(const ValGraph& permissive_graph);
44
+
45
+ private:
46
+ // Propagate mappings in a ComputeAtMap as is done in IdModel
47
+ static void fullyPropagateMappings(DisjointSets<IterDomain*>& id_sets);
48
+
49
+ private:
50
+ ComputeAtMap ca_map_;
51
+ // Validation is not enabled if swizzle is found. See the comment above
52
+ bool has_swizzle_ = false;
53
+ };
54
+
55
+ } // namespace nvfuser