nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,322 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <fusion.h>
11
+ #include <scheduler/pointwise_utils.h>
12
+ #include <scheduler/scheduler_types.h>
13
+ #include <scheduler/tools/domain_map.h>
14
+ #include <scheduler/utils.h>
15
+ #include <scheduler/vectorize_helper.h>
16
+
17
+ namespace nvfuser {
18
+
19
+ //! namespace for hosting catalog of possible compile time
20
+ //! info that can be cached. Each possible entry type has
21
+ //! a value in `CompileTimeEntryType` and an entry type class
22
+ //! definition like `VectorizableInputsAndOutputs`. The corresponding
23
+ //! classes contain their entry type, data type and maybe more
24
+ //! later depending on use cases.
25
+ namespace HeuristicCompileTime {
26
+
27
+ //! Each entry type under this category represents some information
28
+ //! that can be inferred compile-time, i.e. without any runtime input
29
+ //! meta data. They will be stored in `HeuristicDataCache` and will
30
+ //! be re-used each time the same fusion is visited.
31
+
32
+ //! Enum for all possible types of cached entries of compile-time info. Each entry type definition class below exposes one of these values as its static `EntryType` member.
33
+ enum class CompileTimeEntryType {
34
+ DOMAIN_MAP,
35
+ TRANSPOSE_DOMAIN_MAP,
36
+ REFERENCE_TENSORS,
37
+ REFERENCE_TENSORS_FOR_GROUPS,
38
+ VECTORIZABLE_INPUTS_AND_OUTPUTS,
39
+ INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS,
40
+ TV_TO_CONTIG_INNER_SIZE_MAPS,
41
+ UNROLLABLE_INPUTS_AND_OUTPUTS,
42
+ REDUCTION_TVS,
43
+ PERSISTENT_BUFFER_INFO,
44
+ SCOPE_PERSISTENT_FACTOR_INFO,
45
+ BROADCAST_BYTE_MULTIPLES,
46
+ INNER_MOST_DIMS_INFO,
47
+ CAN_SCHEDULE_TRANSPOSE,
48
+ CAN_SCHEDULE_MUL_SUM_AS_MMA, // NOTE(review): no entry type definition class in this header refers to this value - confirm whether it is still used
49
+ LOGICAL_REORDER_MAP,
50
+ VECTORIZATION_BREAK_POINT_OF_RED_PROD,
51
+ SCHEDULE_HYPERPARAMETERS
52
+ };
53
+
54
+ //! Entry type definition class for `DOMAIN_MAP`,
55
+ //! stores the domain map of a fusion as a scheduler_tools::DomainMap.
56
+ class DomainMap {
57
+ public:
58
+ using DataType = scheduler_tools::DomainMap; // type handed back by HeuristicDataCacheEntry<DomainMap>::get()
59
+ static const CompileTimeEntryType EntryType =
60
+ CompileTimeEntryType::DOMAIN_MAP;
61
+ };
62
+
63
+ //! Entry type definition class for `TRANSPOSE_DOMAIN_MAP`,
64
+ //! stores the domain map of a fusion, used by transpose scheduler.
65
+ class TransposeDomainMap {
66
+ public:
67
+ using DataType = scheduler_tools::DomainMap; // same data type as the DOMAIN_MAP entry, kept under a separate enum value
68
+ static const CompileTimeEntryType EntryType =
69
+ CompileTimeEntryType::TRANSPOSE_DOMAIN_MAP;
70
+ };
71
+
72
+ //! Entry type definition class for `REFERENCE_TENSORS`,
73
+ //! stores the reference TensorViews used to schedule a fusion.
74
+ class ReferenceTensors {
75
+ public:
76
+ using DataType = std::vector<TensorView*>; // raw, non-owning TensorView pointers
77
+ static const CompileTimeEntryType EntryType =
78
+ CompileTimeEntryType::REFERENCE_TENSORS;
79
+ };
80
+
81
+ //! Entry type definition class for `REFERENCE_TENSORS_FOR_GROUPS`,
82
+ //! stores the reference TensorViews used to schedule a fusion, used by
83
+ //! transpose scheduler.
84
+ class ReferenceTensorsForGroups {
85
+ public:
86
+ using DataType = std::vector<TensorView*>; // raw, non-owning TensorView pointers
87
+ static const CompileTimeEntryType EntryType =
88
+ CompileTimeEntryType::REFERENCE_TENSORS_FOR_GROUPS;
89
+ };
90
+
91
+ //! Entry type definition class for `VECTORIZABLE_INPUTS_AND_OUTPUTS`,
92
+ //! stores the vectorizable TensorViews among a fusion's inputs and outputs.
93
+ class VectorizableInputsAndOutputs {
94
+ public:
95
+ using DataType = std::vector<TensorView*>; // raw, non-owning TensorView pointers
96
+ static const CompileTimeEntryType EntryType =
97
+ CompileTimeEntryType::VECTORIZABLE_INPUTS_AND_OUTPUTS;
98
+ };
99
+
100
+ //! Entry type definition class for `TV_TO_CONTIG_INNER_SIZE_MAPS`,
101
+ //! stores a vector of maps from TensorView* to Val*. NOTE(review): the earlier description here was a copy of the VECTORIZABLE_INPUTS_AND_OUTPUTS one; presumably each Val* is the contiguous inner size of the keyed TensorView - confirm.
102
+ class TvToContigInnerSizeMaps {
103
+ public:
104
+ using DataType = std::vector<std::unordered_map<TensorView*, Val*>>; // meaning of the outer vector index is not visible in this header
105
+ static const CompileTimeEntryType EntryType =
106
+ CompileTimeEntryType::TV_TO_CONTIG_INNER_SIZE_MAPS;
107
+ };
108
+
109
+ //! Entry type definition class for `INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS`,
110
+ //! stores the fusion's inputs and outputs grouped by innermost dimension.
111
+ class InputsOutputsInnerDimGroups {
112
+ public:
113
+ using DataType = std::vector<std::vector<TensorView*>>; // one inner vector per group
114
+ static const CompileTimeEntryType EntryType =
115
+ CompileTimeEntryType::INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS;
116
+ };
117
+
118
+ //! Entry type definition class for `UNROLLABLE_INPUTS_AND_OUTPUTS`,
119
+ //! stores the unrollable TensorViews among a fusion's inputs and outputs.
120
+ class UnrollableInputsAndOutputs {
121
+ public:
122
+ using DataType = std::vector<TensorView*>; // raw, non-owning TensorView pointers
123
+ static const CompileTimeEntryType EntryType =
124
+ CompileTimeEntryType::UNROLLABLE_INPUTS_AND_OUTPUTS;
125
+ };
126
+
127
+ //! Entry type definition class for `REDUCTION_TVS`,
128
+ //! stores all TensorViews with reduction axes in a fusion.
129
+ class ReductionTVs {
130
+ public:
131
+ using DataType = std::vector<TensorView*>; // raw, non-owning TensorView pointers
132
+ static const CompileTimeEntryType EntryType =
133
+ CompileTimeEntryType::REDUCTION_TVS;
134
+ };
135
+
136
+ //! Entry type definition class for `PERSISTENT_BUFFER_INFO`,
137
+ //! stores persistent buffers inferred from the topology and scheduling of a fusion.
138
+ class PersistentBufferInfo {
139
+ public:
140
+ using DataType = scheduler_utils::PersistentBufferInfo; // note: a different type from this class although it shares the unqualified name
141
+ static const CompileTimeEntryType EntryType =
142
+ CompileTimeEntryType::PERSISTENT_BUFFER_INFO;
143
+ };
144
+
145
+ //! Entry type definition class for `INNER_MOST_DIMS_INFO`.
146
+ //! Used in the transpose scheduler to store innermost IterDomains and their
147
+ //! position in reference1 of group 1 and group 2.
148
+ //! Note: a negative value indicates mapping failure.
149
+ class InnerMostDimInfo {
150
+ public:
151
+ using DataType = std::vector<int64_t>; // signed so that the negative failure marker described above can be stored
152
+ static const CompileTimeEntryType EntryType =
153
+ CompileTimeEntryType::INNER_MOST_DIMS_INFO;
154
+ };
155
+
156
+ //! Auxiliary data type for the `SCOPE_PERSISTENT_FACTOR_INFO` entry type: maps a Val* to a vector of bool flags (it is the DataType of ScopePersistentFactorInfo).
157
+ using ScopedPersistenceBufferMap = std::unordered_map<Val*, std::vector<bool>>;
158
+
159
+ //! Entry type definition class for `SCOPE_PERSISTENT_FACTOR_INFO`.
160
+ //! Tracks which buffers are active at a given Val*. The order of the bool vector is
161
+ //! based on the persistent buffer order from the persistent buffer info, which is then
162
+ //! followed by the projectable persistent buffers' inputs. True in the bool
163
+ //! vector means the persistent buffer is active at the generation of the key.
164
+ class ScopePersistentFactorInfo {
165
+ public:
166
+ using DataType = ScopedPersistenceBufferMap; // std::unordered_map<Val*, std::vector<bool>>
167
+ static const CompileTimeEntryType EntryType =
168
+ CompileTimeEntryType::SCOPE_PERSISTENT_FACTOR_INFO;
169
+ };
170
+
171
+ //! Entry type definition class for `BROADCAST_BYTE_MULTIPLES`,
172
+ //! stores "byte multiples" information. When a 2D scheduler is used, this
173
+ //! information can be used to figure out how many bytes have to be transferred with
174
+ //! varying split locations. See BroadcastMultiple definition for more
175
+ //! information.
176
+ class BroadcastMultiples {
177
+ public:
178
+ using DataType = scheduler_utils::BroadcastMultipleInformation; // declared outside this header
179
+ static const CompileTimeEntryType EntryType =
180
+ CompileTimeEntryType::BROADCAST_BYTE_MULTIPLES;
181
+ };
182
+
183
+ //! Entry type definition class for `CAN_SCHEDULE_TRANSPOSE`,
184
+ //! stores whether the transpose scheduler can schedule this fusion.
185
+ class CanScheduleTranspose {
186
+ public:
187
+ using DataType = bool; // the cached yes/no answer
188
+ static const CompileTimeEntryType EntryType =
189
+ CompileTimeEntryType::CAN_SCHEDULE_TRANSPOSE;
190
+ };
191
+
192
+ //! Entry type definition class for `LOGICAL_REORDER_MAP`,
193
+ //! stores a map from int64_t to int64_t. NOTE(review): the earlier description ("the domain map of a fusion") was copied from DOMAIN_MAP; presumably this maps old axis positions to new ones for a reorder - confirm.
194
+ class LogicalReorderMap {
195
+ public:
196
+ using DataType = std::unordered_map<int64_t, int64_t>; // key/value meaning is not visible in this header
197
+ static const CompileTimeEntryType EntryType =
198
+ CompileTimeEntryType::LOGICAL_REORDER_MAP;
199
+ };
200
+
201
+ class VectorizationBreakPointOfReductionProducer { //! Entry type definition class for `VECTORIZATION_BREAK_POINT_OF_RED_PROD`.
202
+ public:
203
+ using DataType = int64_t; // the cached value is a single int64_t; presumably a break point position - confirm
204
+ static const CompileTimeEntryType EntryType =
205
+ CompileTimeEntryType::VECTORIZATION_BREAK_POINT_OF_RED_PROD;
206
+ };
207
+
208
+ //! Entry type definition class for `SCHEDULE_HYPERPARAMETERS`,
209
+ //! stores hyperparameters for SchedulerEntry::computeHeuristics.
210
+ class SchedulerHyperParameters {
211
+ public:
212
+ using DataType = scheduler_utils::SchedulerHyperParameters; // note: a different type from this class although it shares the unqualified name
213
+ static const CompileTimeEntryType EntryType =
214
+ CompileTimeEntryType::SCHEDULE_HYPERPARAMETERS;
215
+ };
216
+
217
+ //! Base class for unified storage in `HeuristicDataCache`;
218
+ //! each entry in `HeuristicDataCache` will be a subclass. It only records which CompileTimeEntryType the entry holds.
219
+ class CompileTimeInfoBase : public PolymorphicBase {
220
+ public:
221
+ CompileTimeInfoBase(CompileTimeEntryType entry_type) // NOTE(review): single-argument constructor is not marked explicit
222
+ : entry_type_(entry_type) {}
223
+ CompileTimeEntryType type() { // returns the entry type given at construction; NOTE(review): not const-qualified
224
+ return entry_type_;
225
+ }
226
+
227
+ private:
228
+ CompileTimeEntryType entry_type_; // set once in the constructor and never modified in this class
229
+ };
230
+
231
+ } // namespace HeuristicCompileTime
232
+
233
+ //! Note: Do NOT export this class. MSVC issue with exported class that contains
234
+ //! std::vector<unique_ptr<xxx>>: https://godbolt.org/z/3E4e8T1P1
235
+ //! Compile-time information cache for `canSchedule` and `getHeuristics`
236
+ //! interfaces. Each cache instance stores information that could be inferred at
237
+ //! compile time in a fusion and therefore corresponds to an instance of
238
+ //! KernelExecutor.
239
+ class HeuristicDataCache {
240
+ using EntryOwningPtr =
241
+ std::unique_ptr<HeuristicCompileTime::CompileTimeInfoBase>;
242
+ using EntryPtr = HeuristicCompileTime::CompileTimeInfoBase*;
243
+ using EntryType = HeuristicCompileTime::CompileTimeEntryType;
244
+
245
+ public:
246
+ bool hasEntry(EntryType entry_type) { // true when entry_type_map_ holds a key equal to entry_type; NOTE(review): not const-qualified
247
+ return entry_type_map_.find(entry_type) != entry_type_map_.end();
248
+ }
249
+
250
+ void insert(EntryOwningPtr new_entry); // takes ownership of new_entry; defined outside this header
251
+
252
+ EntryPtr at(EntryType entry_type) { // non-owning pointer to the entry; unordered_map::at throws std::out_of_range when the key is absent
253
+ return entry_type_map_.at(entry_type);
254
+ }
255
+
256
+ private:
257
+ std::vector<EntryOwningPtr> entries_; // owning storage for the entries
258
+ std::unordered_map<EntryType, EntryPtr> entry_type_map_; // lookup by entry type; presumably points into entries_ - confirm in insert()
259
+ };
260
+
261
+ //! A utility class to facilitate accessing HeuristicDataCache.
262
+ //! This utility is needed because the information to be stored
263
+ //! in HeuristicDataCache is used in several different scenarios
264
+ //! and we want to support all these use cases in one interface.
265
+ //! The current use examples are:
266
+ //! 1. During fusion segmentation process, all the fusions
267
+ //! given to canSchedule are temporary and therefore the
268
+ //! compile time info does not need to be cached, and in fact
269
+ //! a cache wouldn't be instantiated by that time.
270
+ //!
271
+ //! 2. When a kernel is created for the first time, entries will be
272
+ //! missing in the cache and all the computed information will be
273
+ //! captured and written into the cache.
274
+ //!
275
+ //! 3. When we check a compiled fusion for heuristic hit, we want to
276
+ //! use the cached info to save runtime latency.
277
+ //!
278
+ //! The designed interface is used as:
279
+ //! auto entry = HeuristicDataCacheEntry<EntryClass>(data_cache, maker_fn);
280
+ //! auto& data = entry.get();
281
+ //!
282
+ //! `maker_fn` will be called to compute the information when no cached data
283
+ //! exists and `entry` will own the computed data when no data cache is
284
+ //! supplied.
285
+ template <typename EntryClass>
286
+ class HeuristicDataCacheEntry {
287
+ using EntryDataType = typename EntryClass::DataType;
288
+ using EntryDataTypeOwnPtr = std::unique_ptr<EntryDataType>;
289
+ using MakerFnType = std::function<EntryDataTypeOwnPtr()>;
290
+
291
+ public:
292
+ //! Creates a data entry with type defined in EntryClass,
293
+ //! e.g. EntryClass = VectorizableInputsAndOutputs;
294
+ //!
295
+ //! @param data_cache a pointer to an instantiated compile-time
296
+ //! info cache. The info data will be
297
+ //! 1. read from data cache if data cache is not recording.
298
+ //! 2. written into data cache if data cache is recording.
299
+ //! 3. managed by owned_data_ if data cache is nullptr. NOTE(review): the HeuristicDataCache class in this header exposes no "recording" state; presumably cases 1 and 2 correspond to hasEntry() being true or false - confirm against the constructor definition.
300
+ //! @param fn
301
+ //! The factory function that needs to return an owning pointer
302
+ //! i.e. std::unique_ptr<EntryClass::DataType>. It will only
303
+ //! be called either when data cache is recording or when no data
304
+ //! cache is given.
305
+ HeuristicDataCacheEntry(HeuristicDataCache* data_cache, MakerFnType fn);
306
+
307
+ //! Unified interface to get actual data, either from cache
308
+ //! or from factory function. Returns a reference to the object data_ptr_ points to.
309
+ EntryDataType& get() {
310
+ return *data_ptr_; // no null check here; relies on the constructor having set data_ptr_
311
+ }
312
+
313
+ private:
314
+ //! Internal data owning pointer that will manage the computed
315
+ //! data when there is no data cache.
316
+ EntryDataTypeOwnPtr owned_data_ = nullptr;
317
+
318
+ //! Pointer to the valid data entry that could be accessed.
319
+ EntryDataType* data_ptr_ = nullptr;
320
+ };
321
+
322
+ } // namespace nvfuser
@@ -0,0 +1,68 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <debug.h>
11
+ #include <options.h>
12
+ #include <utils.h>
13
+
14
+ #include <iostream>
15
+
16
+ namespace nvfuser {
17
+
18
+ namespace scheduler_debug_utils {
19
+
20
+ // Basic logging utility for any messages in scheduler or segmenter. Concatenates all arguments with c10::str and writes the result plus a newline to debug() when either the FusionSegmenterLog or the SchedulerVerbose debug dump option is enabled; otherwise it does nothing.
21
+ template <typename... Args>
22
+ void canScheduleMessage(const Args&... args) {
23
+ // Using builtin expect to reduce the overhead slightly,
24
+ // alternatively may want to allow this message in debug
25
+ // build only but that'd be inconvenient for user support.
26
+ if (C10_UNLIKELY(
27
+ isDebugDumpEnabled(DebugDumpOption::FusionSegmenterLog) ||
28
+ isDebugDumpEnabled(DebugDumpOption::SchedulerVerbose))) {
29
+ debug() << c10::str(args...) << "\n";
30
+ }
31
+ }
32
+
33
+ // Short-cut message for flagging why schedulers cannot schedule fusions,
34
+ // assuming first argument is heuristic type (not actively checked). Forwards to canScheduleMessage, so output is subject to the same debug dump options.
35
+ template <typename SchedulerType, typename... Args>
36
+ void canScheduleRejectReason(
37
+ SchedulerType scheduler_type,
38
+ const Args&... args) {
39
+ canScheduleMessage(
40
+ "Scheduler _", scheduler_type, "_ ***rejected*** because : ", args...);
41
+ }
42
+
43
+ // Variadic logging helpers gated on the SchedulerVerbose debug dump option. Based on
44
+ // https://learn.microsoft.com/en-us/cpp/cpp/ellipses-and-variadic-templates?view=msvc-170#example
45
+ inline void log() { // zero-argument overload: only terminates the line
46
+ if (isDebugDumpEnabled(DebugDumpOption::SchedulerVerbose)) {
47
+ debug() << std::endl;
48
+ }
49
+ }
50
+
51
+ template <typename T>
52
+ void log(const T& t) { // single-argument overload: prints t and terminates the line
53
+ if (isDebugDumpEnabled(DebugDumpOption::SchedulerVerbose)) {
54
+ debug() << t << std::endl; // std::endl also flushes the stream
55
+ }
56
+ }
57
+
58
+ template <typename First, typename... Rest>
59
+ void log(const First& first, const Rest&... rest) { // prints first with no separator, then recurses on the remaining arguments
60
+ if (isDebugDumpEnabled(DebugDumpOption::SchedulerVerbose)) {
61
+ debug() << first;
62
+ log(rest...); // the option is checked again at every recursion step
63
+ }
64
+ }
65
+
66
+ } // namespace scheduler_debug_utils
67
+
68
+ } // namespace nvfuser
@@ -0,0 +1,45 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <scheduler/heuristic.h>
11
+ #include <scheduler/registry.h>
12
+
13
+ namespace nvfuser {
14
+
15
+ class Fusion;
16
+ class SchedulerRuntimeInfo;
17
+ class HeuristicDataCache;
18
+
19
+ // ExprEval scheduler represents the case where we allocate outputs directly
20
+ // using EE (presumably the expression evaluator - confirm). No code is generated.
21
+ class ExprEvalScheduler : public SchedulerEntry {
22
+ public:
23
+ // This scheduler only accepts MatmulOp. NOTE(review): the definition is not in this header, so this statement cannot be checked here.
24
+ bool canScheduleCompileTime(Fusion* fusion) override;
25
+
26
+ bool canScheduleRunTime(
27
+ Fusion* fusion,
28
+ SchedulerRuntimeInfo& runtime_info,
29
+ HeuristicDataCache* data_cache) override {
30
+ return true; // always accepts at run time; none of the arguments are inspected
31
+ }
32
+
33
+ std::unique_ptr<HeuristicParams> computeHeuristics(
34
+ Fusion* fusion,
35
+ SchedulerRuntimeInfo& runtime_info,
36
+ HeuristicDataCache* data_cache) override;
37
+
38
+ void schedule(Fusion* fusion, const HeuristicParams* params) override;
39
+
40
+ constexpr static SchedulerType schedulerType() { // compile-time tag identifying this scheduler
41
+ return SchedulerType::ExprEval;
42
+ }
43
+ };
44
+
45
+ } // namespace nvfuser
@@ -0,0 +1,113 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <runtime/executor_params.h>
11
+ #include <scheduler/scheduler_types.h>
12
+ #include <utils.h>
13
+
14
+ #include <string>
15
+
16
+ namespace nvfuser {
17
+
18
+ class SchedulerRuntimeInfo;
19
+ class HeuristicDataCache;
20
+
21
+ // Top-level class representing heuristic parameters. Most schedulers
22
+ // have their own subclasses to hold their specific parameters, except
23
+ // for ExprEval schedulers.
24
+ class HeuristicParams : public PolymorphicBase {
25
+ public:
26
+ std::string tag = "";
27
+
28
+ LaunchParams lparams;
29
+ CompileParams cparams;
30
+ const SchedulerType scheduler_type; // fixed at construction; being const, it also disables the implicit copy assignment
31
+
32
+ virtual std::string toString() const { // base implementation prints only the scheduler type
33
+ std::stringstream ss;
34
+ ss << "Heuristic Params (" << scheduler_type << ")";
35
+ return ss.str();
36
+ }
37
+
38
+ virtual size_t hash() const { // base implementation returns a constant
39
+ return 0;
40
+ }; // NOTE(review): stray semicolon after the function body
41
+
42
+ virtual bool sameAs(const HeuristicParams* other) const { // compares dynamic type, scheduler_type and cparams; lparams and tag are not compared here
43
+ if (!other->isStrictlyA<HeuristicParams>()) { // other is dereferenced without a null check
44
+ return false;
45
+ }
46
+ if (other->scheduler_type != scheduler_type) {
47
+ return false;
48
+ }
49
+ return other->cparams == cparams;
50
+ }
51
+
52
+ HeuristicParams() = delete;
53
+ explicit HeuristicParams(SchedulerType _scheduler_type)
54
+ : scheduler_type(_scheduler_type) {}; // NOTE(review): stray semicolon after the constructor body
55
+
56
+ virtual std::unique_ptr<HeuristicParams> clone() const { // copy-constructs a HeuristicParams from *this
57
+ return std::make_unique<HeuristicParams>(*this);
58
+ }
59
+ };
60
+
61
+ //! Auxiliary class for storing heuristics. The managed data is either
62
+ //! a single heuristic for complete fusion, or a vector of heuristics used for
63
+ //! a segmented fusion.
64
+ class HeuristicParamsList {
65
+ public:
66
+ //! Constructor for segmented fusion case. Created with empty list and
67
+ //! uses emplaceBack for inserting heuristics in order
68
+ explicit HeuristicParamsList() = default;
69
+
70
+ //! Constructor fills heuristics_ with nullptr, which allows us to create
71
+ //! SchedulerEntries out of order.
72
+ explicit HeuristicParamsList(size_t num_heuristics) {
73
+ heuristics_.reserve(num_heuristics);
74
+ std::fill_n(std::back_inserter(heuristics_), num_heuristics, nullptr); // appends num_heuristics null unique_ptr slots
75
+ }
76
+
77
+ //! Constructor for complete fusion case, generates the scheduler entry
78
+ //! for the fusion. NOTE(review): the earlier wording mentioned "the given expression", but no expression is passed; the definition is outside this header.
79
+ explicit HeuristicParamsList(
80
+ SchedulerType scheduler_type,
81
+ SchedulerRuntimeInfo& runtime_info,
82
+ HeuristicDataCache* data_cache = nullptr);
83
+
84
+ HeuristicParamsList(const HeuristicParamsList&) = delete;
85
+ HeuristicParamsList& operator=(const HeuristicParamsList&) = delete;
86
+
87
+ std::unique_ptr<HeuristicParams>& at(int index) { // bounds-checked via vector::at; index is converted from int to the vector size type
88
+ return heuristics_.at(index);
89
+ }
90
+
91
+ //! Place a heuristic on the list. Applies to segmented fusion only.
92
+ void emplaceBack(std::unique_ptr<HeuristicParams>&& pt) {
93
+ NVF_ERROR(is_segmented_);
94
+ heuristics_.emplace_back(std::move(pt));
95
+ }
96
+
97
+ //! Returns list of heuristics for a segmented fusion.
98
+ const std::vector<std::unique_ptr<HeuristicParams>>& heuristicsList() const {
99
+ return heuristics_;
100
+ }
101
+
102
+ //! Returns the single heuristic for a complete fusion.
103
+ HeuristicParams* singleKernelHeuristics() const {
104
+ NVF_ERROR(!is_segmented_);
105
+ return heuristics_.begin()->get(); // NOTE(review): dereferences begin() without checking that heuristics_ is non-empty
106
+ }
107
+
108
+ private:
109
+ std::vector<std::unique_ptr<HeuristicParams>> heuristics_;
110
+ bool is_segmented_ = true; // nothing in this header sets it to false; presumably the out-of-line complete-fusion constructor does - confirm
111
+ };
112
+
113
+ } // namespace nvfuser