nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,111 @@
1
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once
#include <scheduler/all_schedulers.h>

namespace nvfuser {

class TensorView;
class ComputeAtLogicalDomainMap;
class ComputeAtMap;
class ExpressionEvaluator;
class KernelArgumentHolder;

namespace registry_utils {

// Checks whether two fusion outputs have equivalent iteration/reduction
// patterns under the given logical domain map.
// NOTE(review): the exact equivalence criteria live in the implementation;
// confirm against the .cpp before relying on details.
bool checkPatternEquivalence(
    TensorView* out_tv0,
    TensorView* out_tv1,
    const ComputeAtLogicalDomainMap& logical_map);

// Reusing some code from lowering specifically in lower_trivial_broadcast.cpp
// ConcretizedBroadcastDomains::maybeNonUniquelyConcretized this checks if
// there's a broadcast iteration domain that's being broadcasted to seemingly
// different extents, meaning we don't know in the kernel if the dimension is
// being broadcasted to one size multiple times or different sizes. This is a
// hard to optimize problem and likely indicates we shouldn't be fusing.
bool hasNonUniqueBcast(Fusion* fusion);

// Rejects scheduling with the given scheduler type when memory promotion
// would be required (see implementation for the precise conditions).
// TODO: remove this requirement entirely
bool rejectScheduleForMemoryPromotion(
    Fusion* fusion,
    SchedulerType scheduler_type);

// Returns whether the fusion's expression graph forms a single connected
// component (no disjoint subgraphs).
bool isConnectedFusionGraph(Fusion* fusion);

// Returns if a fusion cannot be transformed into a consistent format since we
// can't transform forward through view operations, for example:
//
// tv0[I0, I1, I2]
// tv1[I0*I1, I2] = view(tv0)
// tv2[I0, I1*I2] = view(tv0)
//
// If we start transform propagation at either tv1 or tv2, it would require
// "replaying forward" through the other. If we started at tv1 we'd have to be
// able to take tv2[I0, I1*I2] and transform it to [I0*I1, I2], however this
// would "undo" the view transformation which we do not support today.
//
// Returns true if a scenario like above is found in the fusion.
bool requiresForwardViewReplay(Fusion* fusion, ComputeAtMap& ca_map);

// Returns if view interferes with how we want to treat the reference, being
// at least a 2D reduction schedule but maybe a 3D reduction schedule.
bool reductionInterferingView(
    Fusion* fusion,
    const ComputeAtMap& ca_map,
    TensorView* reduction_reference);

// Determines the index type (32- vs 64-bit) the kernel must be compiled with.
// Check inputs, outputs and intermediates.
// Intermediates are contiguous, so strides are not necessary.
// Strides are required for inputs and also maybe for outputs as
// they may be non-contiguous. However, in our current interface,
// output strides are not available, so if there's any outputs that
// are non contiguous, need to fall back to 64-bit indexing.
PrimDataType getIndexTypeOfKernel(
    Fusion* fusion,
    const std::vector<TensorView*>& all_tvs,
    const KernelArgumentHolder& inputs,
    ExpressionEvaluator& ee);

// Compile-time topology checks used by schedulers to reject fusions whose
// broadcast/reduction structure they cannot handle.
class SchedulerTopologyChecker {
 public:
  // Checks if any broadcasts are resolved after a reduction that don't follow
  // the normalization pattern
  static bool hasNonNormalizePostReductionBCast(Fusion* fusion);

  // Checks if any broadcasts are resolved after a reduction, this shouldn't
  // be accepted in the single reduction or multi-reduction scheduler
  static bool hasPostReductionBCast(Fusion* fusion);

  // Checks if there's any unsupported operations post reduction. If outer
  // reduction we can fuse some pointwise ops if they don't require
  // broadcasting (checked in hasPostReductionBCast). For inner reductions we
  // cannot fuse any binary like operation (includes operations like shift
  // that we're not fusing right now) involving "new" inputs (not going
  // through a reduction).
  static bool supportedPostReductionFusion(
      Fusion* fusion,
      std::vector<TensorView*> reduction_tvs);

  // Checks if there's any gather-like ops that result in non-resolved
  // broadcast domains and then get squeezed before reaching reduction
  // TVs. The reduction scheduler uses reduction TVs as a scheduling
  // reference, so that won't be able to schedule the broadcast ID if
  // squeezed and its corresponding index-accessed producer ID, and
  // any IDs that the producer ID depends on.
  //
  // This analysis has some similarity as DomainMap. Can be
  // consolidated?
  static bool hasGatherToBroadcastBeforeReduction(
      Fusion* fusion,
      const std::vector<TensorView*>& reduction_tvs);
};

} // namespace registry_utils

} // namespace nvfuser
@@ -0,0 +1,41 @@
1
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <scheduler/heuristic.h>
#include <scheduler/registry.h>

namespace nvfuser {

class Fusion;
class SchedulerRuntimeInfo;
class HeuristicDataCache;

//! Scheduler entry for the Resize scheduler. Acceptance is decided entirely
//! by the compile-time check; the runtime check unconditionally accepts.
class ResizeScheduler : public SchedulerEntry {
 public:
  //! Compile-time check on the fusion graph; declared here, defined in the
  //! corresponding .cpp.
  bool canScheduleCompileTime(Fusion* fusion) override;

  //! Runtime check. Always true: this scheduler imposes no constraints that
  //! depend on runtime input information.
  bool canScheduleRunTime(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicDataCache* data_cache = nullptr) override {
    return true;
  }

  //! Computes the heuristic parameters for the given fusion and runtime
  //! inputs (presumably ResizeParams — confirm in the implementation).
  std::unique_ptr<HeuristicParams> computeHeuristics(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicDataCache* data_cache) override;

  //! Applies the resize scheduling to the fusion in place using the given
  //! heuristic parameters.
  void schedule(Fusion* fusion, const HeuristicParams* params) override;

  //! Identifies this scheduler in the SchedulerType registry.
  constexpr static SchedulerType schedulerType() {
    return SchedulerType::Resize;
  }
};

} // namespace nvfuser
@@ -0,0 +1,67 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <c10/util/hash.h>
11
+ #include <ir/interface_nodes.h>
12
+ #include <scheduler/heuristic.h>
13
+ #include <utils.h>
14
+
15
+ #include <sstream>
16
+
17
+ namespace nvfuser {
18
+
19
+ class ResizeParams : public HeuristicParams {
20
+ public:
21
+ ResizeParams() : HeuristicParams(SchedulerType::Resize) {};
22
+
23
+ // Split grid x dimension
24
+ bool split_grid_x_dim = false;
25
+
26
+ int64_t largest_input = -1;
27
+
28
+ int64_t vectorization_factor = 1;
29
+
30
+ static constexpr int64_t max_gdimx = (1L << 31) - 1L;
31
+
32
+ using HeuristicParams::HeuristicParams;
33
+
34
+ // Warning: Does not check launch parameters!
35
+ bool sameAs(const HeuristicParams* other_base) const override {
36
+ auto other = dynamic_cast<const ResizeParams*>(other_base);
37
+ if (other == nullptr) {
38
+ return false;
39
+ }
40
+ bool attr_equal = other->cparams == cparams &&
41
+ other->split_grid_x_dim == split_grid_x_dim &&
42
+ other->largest_input == largest_input &&
43
+ other->vectorization_factor == vectorization_factor;
44
+ return attr_equal;
45
+ }
46
+
47
+ std::string toString() const override {
48
+ std::stringstream ss;
49
+ ss << "\n===== Resize Parameters ========\n"
50
+ << (tag.empty() ? "" : "Tag: ") << tag << " Resize Characteristics:\n"
51
+ << " split grid x dim: " << split_grid_x_dim << "\n"
52
+ << " index of largest input: " << largest_input << "\n"
53
+ << " vectorization factor: " << vectorization_factor << "\n";
54
+ ss << "====================================\n";
55
+ return ss.str();
56
+ }
57
+
58
+ size_t hash() const override {
59
+ return c10::get_hash(split_grid_x_dim);
60
+ }
61
+
62
+ std::unique_ptr<HeuristicParams> clone() const override {
63
+ return std::make_unique<ResizeParams>(*this);
64
+ }
65
+ };
66
+
67
+ } // namespace nvfuser
@@ -0,0 +1,166 @@
1
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once
#include <cstddef>
#include <cstdint>

#include <expr_evaluator.h>
#include <fusion.h>
#include <runtime/executor_kernel_arg.h>
#include <utils.h>
#include <visibility.h>

namespace nvfuser {

class ExpressionEvaluator;

//! SchedulerRuntimeInfo is the abstraction introduced in
//! this PR for passing runtime input dependent information
//! to the schedulers and kernel caches.
//!
//! Note:
//! if any additional info needed, or maybe just the inputs themselves it
//! could just be added to this class, and they will be distributed to the
//! segmenter and schedulers.
//! It is important that input id encoding should be up to date with any change
//! of this class to avoid launching compiled kernels with illegal inputs.

class SchedulerRuntimeInfo : public NonCopyable {
 public:
  // Max vector size we will consider, in bytes,
  // currently set to 16B = 128b
  static constexpr int64_t max_alignment_size_in_byte = 16;

  //! Create runtime info for given fusion and input. Creating and binding
  //! evaluator is optional. The evaluator is used to manage intermediate
  //! integers in the fusion. We need them for segmenter and schedulers,
  //! but we don't need them when we are just using this class to provide
  //! additional encoding for kernel cache lookup.
  //!
  //! The index type of forced_index_type is used if given, no matter
  //! how large the actual arguments and fusion tensors
  //! are. CORRECTNESS IS NOT GUARANTEED.
  SchedulerRuntimeInfo(
      Fusion* complete_fusion,
      KernelArgumentHolder args,
      PrecomputedValues* precomputed_values = nullptr,
      const std::vector<TensorView*>& all_tvs = {},
      std::optional<PrimDataType> forced_index_type = std::nullopt);

  //! Convenience overload taking raw ATen inputs.
  //! NOTE(review): presumably wraps them into a KernelArgumentHolder and
  //! delegates to the primary constructor — confirm in the implementation.
  NVF_API SchedulerRuntimeInfo(
      Fusion* complete_fusion,
      const at::ArrayRef<c10::IValue>& aten_inputs);

  //! Lookup for the alignment sizes of the given tv. Currently only returns
  //! actual alignment info for input tensors to the complete fusion,
  //! and for other intermediate/fuser-allocated tensors will
  //! return max_alignment_size_in_byte.
  size_t getAlignmentSize(TensorView* tv);

  //! Returns sizes of tensor dimensions in same order as allocation domain,
  //! ignoring any IterType::Reduction domains in the allocation domain. This
  //! only works for complete Fusion inputs whose allocation domain is a
  //! permutation of their root domain and will raise an exception otherwise.
  const std::vector<int64_t>& getInputAllocationSizes(TensorView* tv) const {
    // Both NVF_ERROR checks below raise if tv is not a recorded fusion input.
    NVF_ERROR(
        isInputTv(tv),
        "TensorView ",
        tv->toString(),
        " is not an input or its logical domain is not a permutation of its ",
        "allocation domain");
    auto sizes_it = input_sizes_.find(tv);
    NVF_ERROR(sizes_it != input_sizes_.end());
    return sizes_it->second;
  }

  //! Returns strides of tensor in same order as allocation domain, in elements
  //! instead of bytes. Only works for complete Fusion inputs whose allocation
  //! domain is a permutation of their root domain and will raise an exception
  //! otherwise.
  const std::vector<int64_t>& getInputAllocationStrides(TensorView* tv) const {
    NVF_ERROR(
        isInputTv(tv),
        "TensorView ",
        tv->toString(),
        " is not an input or its logical domain is not a permutation of its ",
        "allocation domain");
    auto strides_it = input_strides_elements_.find(tv);
    NVF_ERROR(strides_it != input_strides_elements_.end());
    return strides_it->second;
  }

  // Computes alignment size in bytes for provided ptr address
  static size_t computeAlignmentSize(size_t ptr_address);

  // Return the runtime pointer value for provided tensor view
  size_t ptrOf(TensorView* tv) const;

  // Index type the kernel needs to be run in, as determined at construction
  // (or forced via forced_index_type)
  PrimDataType getIndexType() const {
    return index_type_;
  }

  // The complete fusion this runtime info is associated with
  Fusion* fusion() {
    return complete_fusion_;
  }

  // Expression evaluator bound to the fusion inputs; raises if this instance
  // was constructed without one (see primary constructor)
  ExpressionEvaluator& expressionEvaluator() {
    NVF_ERROR(expression_evaluator_ != nullptr);
    return *expression_evaluator_;
  }

 private:
  // Build and bind full fusion inputs to an expression evaluator
  std::unique_ptr<ExpressionEvaluator> getExpressionEvaluator(
      const KernelArgumentHolder& inputs,
      PrecomputedValues* precomputed_values);

  // True iff tv is one of the complete fusion's inputs
  bool isInputTv(TensorView* tv) const {
    return std::find(
               complete_fusion_->inputs().begin(),
               complete_fusion_->inputs().end(),
               tv) != complete_fusion_->inputs().end();
  }

 private:
  // Returns the offset of tv in the inputs ignoring non tensor views. Used to
  // access input_sizes, input_strides, input_ptr
  int offsetTensorPos(TensorView* tv);

  // Expression evaluator used to probe sizes in the fusion IR
  std::unique_ptr<ExpressionEvaluator> expression_evaluator_ = nullptr;

  // Fusion reference that this runtime info is associated with
  Fusion* complete_fusion_ = nullptr;

  // Copy of aten input pointer addresses
  // TODO: Support output tensor pointers
  std::unordered_map<Val*, size_t> input_ptrs_;

  // Copy of aten input tensor sizes ordered like the TensorView's allocation
  // domain
  std::unordered_map<Val*, std::vector<int64_t>> input_sizes_;

  // Copy of aten input tensor strides (in elements) ordered like the
  // TensorView's allocation domain
  std::unordered_map<Val*, std::vector<int64_t>> input_strides_elements_;

  // Copy of aten input tensor strides (in bytes) for only discontiguous
  // dimensions
  std::unordered_map<Val*, std::vector<size_t>> input_discontig_strides_;

  // Cache for getAlignmentSize
  std::unordered_map<TensorView*, size_t> alignment_map_;

  // Found index mode kernel needs to be run in
  PrimDataType index_type_ = PrimDataType::Int;

  // TODO: Remove
  std::unordered_map<TensorView*, size_t> vectorword_map_;
};

} // namespace nvfuser
@@ -0,0 +1,80 @@
1
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <visibility.h>
#include <array>
#include <ostream>
#include <string>

namespace nvfuser {

//! Each SchedulerType maps to a scheduler in distinct CPP files.
//! For instance, SchedulerType::PointWise maps to PointWiseScheduler in
//! pointwise.cpp.
//!
//! Each of the scheduler needs to provide 3 interface functions:
//!
//! 1. canScheduleCompileTime(Fusion* fusion) :
//!
//!    This function contains compile-time checks on the graph itself
//!    without runtime input information. Only `fusion` is given in the
//!    argument to make sure only compile-time available info is needed in
//!    the check.
//!
//!    This function is to be called exactly once on each segmented group
//!    created in a segmented fusion so this part will not contribute to
//!    dynamic shape latency.
//!
//! 2. canScheduleRunTime(
//!      Fusion* fusion,
//!      SchedulerRuntimeInfo& runtime_info,
//!      HeuristicDataCache* data_cache = nullptr):
//!    This function contains all canSchedule checks that will have to
//!    involve runtime input information, and will be run both by the
//!    segmenter and the kernel cache. The latency of this function will
//!    contribute to dynamic shape latency so `data_cache` should be used as
//!    much as possible to save re-computation.
//!
//! 3. schedule(fusion):
//!
//!    This function will be called when compiling a kernel. It should apply
//!    scheduling to the given fusion

enum class SchedulerType {
  None,
  NoOp,
  PointWise,
  Matmul,
  Reduction,
  InnerPersistent,
  InnerOuterPersistent,
  OuterPersistent,
  Transpose,
  ExprEval,
  Resize
};

//! Define a schedule table to loop over all the heuristics in priority order.
//! Every SchedulerType except None appears exactly once; the array size must
//! be kept in sync with the enum above.
constexpr std::array<SchedulerType, 10> all_heuristics_in_priority_order = {
    SchedulerType::ExprEval,
    SchedulerType::NoOp,
    SchedulerType::Matmul,
    SchedulerType::Reduction,
    SchedulerType::Resize,
    SchedulerType::Transpose,
    SchedulerType::PointWise,
    SchedulerType::InnerPersistent,
    SchedulerType::OuterPersistent,
    SchedulerType::InnerOuterPersistent};

//! Human-readable name of the scheduler type.
std::string toString(SchedulerType sh);

//! Stream insertion; presumably prints the same name as toString — confirm
//! in the implementation.
NVF_API std::ostream& operator<<(std::ostream& os, SchedulerType sh);

} // namespace nvfuser
@@ -0,0 +1,114 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <ATen/core/ivalue.h>
11
+ #include <exceptions.h>
12
+ #include <fusion.h>
13
+ #include <scheduler/registry.h>
14
+ #include <scheduler/transpose_heuristic.h>
15
+ #include <visibility.h>
16
+
17
+ #define SUPPORT_SPLITTING_INNERMOST_DIM 0
18
+
19
+ namespace nvfuser {
20
+
21
+ // Note [Transpose scheduling]
22
+ //
23
+ // The target of transpose scheduling is to get coalesced global memory access
24
+ // to as many input and output tensors as possible. For a DAG with only pure
25
+ // pointwise operators, the scheduling is very simple because the inner most
26
+ // dimension of all input and output tensors are all mapped together in the
27
+ // ComputeAtMap, i.e., there is essentially only one inner most dimension. In
28
+ // such case, we just vectorize that inner most dimension and bind it to
29
+ // threadIdx.x identically for all input and output tensors. In the case where
30
+ // transposes are present in the DAG, the inner most dimensions of different
31
+ // inputs and outputs might not match. And there is no fixed pattern on which
32
+ // input/output tensors should share the same inner most dimension with which.
33
+ // Consider the following example DAGs ([T] represents transpose, all tensors
34
+ // are 2D):
35
+ //
36
+ // t0 t1 t0 t1 t0 t1 t0 t1 t0
37
+ // \ | \ / \ | \ | |
38
+ // \ [T] [T] [T] \ [T] t2 [T] [T]
39
+ // \ / \ / \ / \ / \ / \ |
40
+ // t2 t2 t2 t3 t3 t4 t5 [T]
41
+ // |
42
+ // t1
43
+ //
44
+ // In order to support all these cases in a general way, the following
45
+ // perspective is very important: What we are looking for is to bind threadIdx.x
46
+ // differently for different inputs and outputs, so there has to be some tensor
47
+ // somewhere in the DAG that we write and read with different threadIdx.x
48
+ // bindings. The tensor of binding swap can be any tensor on the path that
49
+ // connects inputs/outputs with different inner most dimension, especially, it
50
+ // does not necessarily have to be the tensor of the transpose operator. In
51
+ // other words, thanks to our indexing system which is already taking care of the
52
+ // correctness of transpose, the scheduler can freely choose where to realize
53
+ // these transposes as different threadIdx.x bindings. This observation greatly
54
+ // simplifies our scheduling.
55
+ //
56
+ // Our scheduling strategy is as follows: We first split the inputs and outputs
57
+ // of the fusion into two groups according to their inner most dimension. The
58
+ // inner most dimensions of tensors in the same group are mapped to each other,
59
+ // and they are not mapped to the inner most dimension of tensors in a different
60
+ // group. Depending on the transpose pattern, there can be more than two groups,
61
+ // if this is the case, we only consider the two largest groups, and the tensors
62
+ // in the remaining groups will just be accessed unvectorized and uncoalesced.
63
+ // We call the largest group as `group1` and the second largest group as
64
+ // `group2`. When we have the groups, we will make a 2D tiling [I1, I2] ->
65
+ // [I1/tile1, tile1, I2/tile2, tile2] on the inner most dimensions of group1 and
66
+ // group2. If I1 and I2 are too small to make a 32x32 tile, such as in the
67
+ // fusion of transpose(T1[1024, 2, 1024, 2], {1, 3}), we merge in other
68
+ // dimensions to make a virtual I1 and I2. The details of how we create virtual
69
+ // I1 and I2 are described in note [Supporting small transpose dimensions].
70
+ //
71
+ // Each tile [tile1, tile2] will be handled by a block, and the tensors that
72
+ // have mismatched threadIdx.x bindings will use shared memory. The outer IDs of
73
+ // the tiling split will be merged with non-tiled IDs and then bound to
74
+ // blockIdx.x for the entire DAG, regardless of which group a tensor belongs to.
75
+ // For the inner tile IDs [tile1, tile2], we need to transform and parallelize
76
+ // group 1 and group 2 differently. The intermediate tensors can be transformed
77
+ // and parallelized consistently either with group 1 or group 2. Here, since
78
+ // group 1 is larger than group 2, we decide to only transform and parallelize
79
+ // the cached inputs of group 2 together with group 2, and keep the rest of the
80
+ // DAG consistent with group 1.
81
+ //
82
+ // If you would like to see an example of how to manually schedule a complicated
83
+ // DAG using this idea, refer to:
84
+ // FusionManualScheduleTransposeComplexDAG1_CUDA
85
+
86
+ class SchedulerRuntimeInfo;
87
+ class HeuristicDataCache;
88
+
89
+ //! Utility for canSchedule interface to check if this fusion has at least two
90
+ //! groups, each with a fully broadcasted reference tensor.
91
+ NVF_API bool hasAtLeastTwoValidGroups(Fusion* fusion);
92
+
93
+ class TransposeScheduler : public SchedulerEntry {
94
+ public:
95
+ bool canScheduleCompileTime(Fusion* fusion) override;
96
+
97
+ bool canScheduleRunTime(
98
+ Fusion* fusion,
99
+ SchedulerRuntimeInfo& runtime_info,
100
+ HeuristicDataCache* data_cache = nullptr) override;
101
+
102
+ std::unique_ptr<HeuristicParams> computeHeuristics(
103
+ Fusion* fusion,
104
+ SchedulerRuntimeInfo& runtime_info,
105
+ HeuristicDataCache* data_cache) override;
106
+
107
+ void schedule(Fusion* fusion, const HeuristicParams* params) override;
108
+
109
+ constexpr static SchedulerType schedulerType() {
110
+ return SchedulerType::Transpose;
111
+ }
112
+ };
113
+
114
+ } // namespace nvfuser