nvfuser_cu121_torch25-0.2.25.dev20250201-cp310-cp310-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/contiguity.h
@@ -0,0 +1,351 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+
+ #include <compute_at_map.h>
+ #include <device_lower/analysis/trivial_broadcast.h>
+ #include <disjoint_set.h>
+ #include <ir/all_nodes.h>
+
+ namespace nvfuser {
+
+ // Goes through the transformations associated with a series of ids and
+ // allocation ids. Checks the ordering of the iteration domains through these
+ // operations to pick out which operations are consistently ordered. For
+ // example: [i0, i1, i2]
+ // ->split(0, 4)->merge(1)->merge(1)->merge(0)
+ // is consistently ordered from largest to smallest extents, but
+ // ->split(0, 4)->merge(1)->merge(0, 2)->merge(0) is not consistently ordered
+ // with the allocation domain.
+ //
+ // This property is important for understanding the contiguity of dimensions
+ // through complex transformations.
+ class OrderedIdInformation : public OptInDispatch {
+  public:
+   static OrderedIdInformation get(
+       const std::vector<IterDomain*>& ids,
+       const std::vector<IterDomain*>& alloc_domain,
+       std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info) {
+     OrderedIdInformation info(alloc_domain, concrete_info);
+     info.traverseTo(ids);
+     return info;
+   }
+
+   const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
+   idToAllocIds() const {
+     return id_to_alloc_ids_;
+   }
+
+   virtual bool isConsistentlyOrdered(IterDomain* id) const {
+     return consistently_ordered_ids_.find(id) !=
+         consistently_ordered_ids_.end();
+   }
+
+   bool exclusivelyConsumesAllocs(IterDomain* id) const {
+     return exclusively_consumes_allocs_.find(id) !=
+         exclusively_consumes_allocs_.end();
+   }
+
+   virtual std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>::
+       const_iterator
+       findAllocIDs(IterDomain* id) const {
+     return id_to_alloc_ids_.find(id);
+   }
+
+  protected:
+   OrderedIdInformation(
+       const std::vector<IterDomain*>& alloc_domain,
+       std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info =
+           nullptr);
+
+   void traverseTo(const std::vector<IterDomain*>& ids);
+
+   // Returns whether the id in active_ids should be in
+   // exclusively_consumes_allocs_
+   bool checkExclusivelyConsumesAllocs(IterDomain* id);
+
+   void handle(Split*) override;
+
+   void handle(Merge* merge) override;
+
+   void handle(Swizzle* swizzle) override;
+
+   void handle(Swizzle2D* swizzle) override;
+
+   void handle(Resize* resize) override;
+
+   virtual std::vector<IterDomain*>::const_iterator findActiveId(
+       IterDomain* id) const {
+     return std::find(active_ids_.begin(), active_ids_.end(), id);
+   }
+
+   bool isActiveId(IterDomain* id) const {
+     return findActiveId(id) != active_ids_.end();
+   }
+
+   int64_t getActiveIdPos(IterDomain* id) const {
+     auto it = findActiveId(id);
+     NVF_ERROR(it != active_ids_.end());
+     return std::distance(active_ids_.begin(), it);
+   }
+
+   bool isConcretized(IterDomain* id) const {
+     NVF_ERROR(concrete_info_ != nullptr);
+     return concrete_info_->isConcretized(id);
+   }
+
+  protected:
+   // Track which allocation ids were used to generate each iter domain
+   std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
+       id_to_alloc_ids_;
+
+   // Track all IterDomains that have consistently ordered transforms for
+   // contiguity. I.e. if we have:
+   //
+   // alloc = [i0, i1, i2]
+   // i3 = merge(i0, i2)
+   // i3 would not be consistently ordered
+   //
+   // alloc = [i0, i1, i2]
+   // i4, i5 = split(merge(merge(i0, i1), i2), 4)
+   // i4 and i5 would be consistently ordered
+   //
+   // alloc = [i0, i1, i2, i3]
+   // i4 = merge(i1, i2) would also be consistently ordered
+   std::unordered_set<IterDomain*> consistently_ordered_ids_;
+
+   // Active series of IterDomains that are updated while we're processing the
+   // domain. Helps us identify which ids are consistently_ordered_ids_. Used
+   // for intermediate storage, not to return.
+   std::vector<IterDomain*> active_ids_;
+
+   // IterDomains in this set exclusively consume all the uses of their
+   // allocations. For example:
+   //   [i0, i1] split(0, f)->merge(1) -> [ceilDiv(i0, f), f*i1]
+   // Neither resulting iter domain exclusively consumes the allocations. With
+   // another merge(0) -> [ceilDiv(i0, f)*f*i1], the resulting iter domain does
+   // exclusively consume the allocations.
+   //
+   // Also:
+   //   [i0, i1, i2, i3] merge(1)->merge(1) -> [i0, i1*i2*i3]
+   // Both resulting iter domains do exclusively consume their allocations.
+   std::unordered_set<IterDomain*> exclusively_consumes_allocs_;
+
+   // Broadcast domains that are concretized cannot be considered contiguously
+   // indexable.
+   // TODO: This constraint is more conservative than necessary as it's only if
+   // the domain is concretized within the local indexing, not in the entire
+   // fusion.
+   std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info_;
+
+   // TODO: Temporary WAR to do ContigIDGroup-specific processing
+   bool using_id_graph_ = false;
+ };
+
+ // Based on the provided divisible split set, goes through expressions and
+ // marks all IterDomains that are dependent on a non-divisible split.
+ class NonDivisibleSplitDependencies : public OptInDispatch {
+  public:
+   NonDivisibleSplitDependencies() = delete;
+
+   NonDivisibleSplitDependencies(
+       const std::vector<IterDomain*>& ids,
+       const std::vector<IterDomain*>& alloc_domain,
+       const std::unordered_set<Split*>& divisible_splits);
+
+   bool dependsOnNonDivisibleSplit(IterDomain* id) const {
+     return depends_on_non_divisible_split.find(id) !=
+         depends_on_non_divisible_split.end();
+   }
+
+  private:
+   std::unordered_set<IterDomain*> depends_on_non_divisible_split;
+ };
+
+ // A merge is contiguous if:
+ //   - The inputs of the outer domain are to the left, in the allocation
+ //     domain, of the inputs of the inner domain.
+ //   - All inputs are contiguous in the allocation domain:
+ //     - All are marked as contiguous.
+ //     - The only gaps between inputs are broadcast or reduction dims.
+ //   - There are no split transformations performed on outer or inner.
+ //   - All transformations on outer or inner are contiguous merges.
+ // If these criteria hold, then we can index the input allocation domains of
+ // this merge with the indexing provided to the output of the merge in the
+ // backward index pass.
+
+ class ContigIDs : public OptInDispatch {
+  public:
+   //! Check through the history of ids whose inputs map to alloc_domain with
+   //! contiguity alloc_contiguity. Return an unordered_set of all merges that
+   //! are contiguous. Ignoring the allocation ordering is primarily used for
+   //! predicate generation, in which case we can linearize indexing of any ID
+   //! that only consists of merge operations.
+   //!
+   //! Mapping information from CA Index concrete to reference domains
+   //! is used to find if merged output domains can be indexed. If there's
+   //! no mapping to a reference domain, there's no corresponding
+   //! index, so it isn't marked as a contig merge.
+   //!
+   //! p2c_id_map can be used when replayed producer domains are
+   //! analyzed, in which case producer-to-consumer maps should be
+   //! passed.
+   //!
+   //! If ignore_indexability is true, ignore the constraint on indexing. It is
+   //! the caller that is responsible for its correctness.
+   //! Not really sure why but clang-tidy only complains about
+   //! std::unordered_map if passed as a const reference.
+   ContigIDs(
+       const std::vector<IterDomain*>& ids,
+       const std::vector<IterDomain*>& alloc_domain,
+       const std::vector<std::optional<bool>>& alloc_contiguity,
+       const std::unordered_set<IterDomain*>& final_ids,
+       const std::unordered_map<IterDomain*, Val*>& index_map,
+       const std::unordered_set<Split*>& divisible_splits,
+       std::unordered_map<IterDomain*, IterDomain*> p2c_id_map = {},
+       bool ignore_indexability = false,
+       bool ignore_consistent_ordering = false);
+
+   //! \param ids IterDomains on the loop domain we're looking for contiguous
+   //!   indexing into.
+   //! \param alloc_domain the allocation domain of the domain we're looking
+   //!   for contiguous indexing into.
+   //! \param alloc_contiguity the contiguity of the alloc_domain.
+   //! \param concrete_to_ref concrete ids of the exact map that the reference
+   //!   index is using for indexing.
+   //! \param divisible_splits a set of all splits in the fusion that are
+   //!   divisible.
+   //! \param ca_map compute at map of the fusion.
+   //! \param concrete_info concretized broadcast information of the fusion.
+   //! \param p2c_id_map map from producer to consumer ids used for indexing
+   //!   producer tensors.
+   //! \param ignore_consistent_ordering consistent ordering of merges matters
+   //!   for actual indexing into tensors but not for predicate analysis, as
+   //!   merges don't map to a physical address in predicate generation.
+   //! \param ignore_indexability can only be true when a real concrete_to_ref
+   //!   map is provided, since what it checks is whether the index is actually
+   //!   indexable based on the reference.
+   ContigIDs(
+       const std::vector<IterDomain*>& ids,
+       const std::vector<IterDomain*>& alloc_domain,
+       const std::vector<std::optional<bool>>& alloc_contiguity,
+       const std::unordered_set<IterDomain*>& final_ids,
+       const std::unordered_map<IterDomain*, Val*>& index_map,
+       const std::unordered_set<Split*>& divisible_splits,
+       std::shared_ptr<const ComputeAtMap> ca_map,
+       std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info,
+       std::unordered_map<IterDomain*, IterDomain*> p2c_id_map = {},
+       bool ignore_indexability = false,
+       bool ignore_consistent_ordering = false);
+
+   //! Return an empty ContigIDs with no contiguous ID
+   static ContigIDs getNonContigIDs();
+
+   const std::unordered_set<IterDomain*>& contigIDs() const {
+     return contig_ids_;
+   }
+
+   const std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>&
+   withinContigIDs() const {
+     return within_contig_ids_;
+   }
+
+   const std::unordered_map<IterDomain*, IterDomain*>& allocToIndexedID()
+       const {
+     return alloc_to_indexed_id_;
+   }
+
+   VectorOfUniqueEntries<IterDomain*> indexedAllocIDs(IterDomain* id) const {
+     auto alloc_ids_it = consistent_transform_info_->idToAllocIds().find(id);
+     if (alloc_ids_it == consistent_transform_info_->idToAllocIds().end()) {
+       return {};
+     }
+     return alloc_ids_it->second;
+   }
+
+  private:
+   using OptInDispatch::handle;
+
+   bool inAlloc(const std::vector<IterDomain*>& ids) {
+     return std::all_of(ids.begin(), ids.end(), [this](IterDomain* id) {
+       return is_contig_alloc_.find(id) != is_contig_alloc_.end();
+     });
+   }
+
+   bool isContig(IterDomain* id) {
+     return contig_ids_.find(id) != contig_ids_.end();
+   }
+
+   // Split outputs are not contiguous, don't need to do anything.
+   void handle(Split*) override {}
+
+   void handle(Merge* merge) override;
+
+   // TODO:
+   // Currently not propagating any contiguity information
+   // as contiguity is generally not preserved after swizzles.
+   // But in follow-ups we could gradually add back a few special
+   // cases, depending on specific swizzle type and axes.
+   void handle(Swizzle* swizzle) override {}
+   void handle(Swizzle2D* swizzle) override {}
+
+   void handle(Resize* resize) override {}
+
+   IterDomain* getCAIndexConcreteId(IterDomain* id) const;
+
+   //! True if an ID is indexable.
+   //! E.g., a merged domain with broadcast may not be indexable when
+   //! its corresponding reference tensor has non-broadcast domains.
+   bool isIndexable(IterDomain* id) const;
+
+   //! Return an ID mapped with id_map_ or itself
+   IterDomain* getMappedId(IterDomain* id) const;
+
+  private:
+   void build(const std::vector<IterDomain*>& ids);
+
+   //! Allocation domains to analyze contiguity
+   const std::vector<IterDomain*>& alloc_domain_;
+   //! Contiguity of alloc_domain_
+   const std::vector<std::optional<bool>>& alloc_contiguity_;
+   //! Domains where indexing/predicates cannot be done with their
+   //! consumers' domains
+   const std::unordered_set<IterDomain*>& final_ids_;
+   //! Mapping of concrete domains to indices. Just used to check if
+   //! there's an index for an IterDomain.
+   const std::unordered_map<IterDomain*, Val*>& index_map_;
+   // Divisible split information, as we can still consider iter domains
+   // contiguous through divisible splits.
+   const std::unordered_set<Split*>& divisible_splits_;
+
+   std::shared_ptr<const ComputeAtMap> ca_map_;
+   std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info_;
+
+   //! Producer-to-consumer index map in the case of analyzing replayed
+   //! producer tensors
+   const std::unordered_map<IterDomain*, IterDomain*> p2c_id_map_;
+
+   const bool ignore_indexability_ = false;
+   const bool ignore_consistent_ordering_ = false;
+
+   //! Mapping of allocation domain to bool indicating contiguity
+   std::unordered_map<IterDomain*, bool> is_contig_alloc_;
+   // Mark if ids are the result of contiguous merges
+   std::unordered_set<IterDomain*> contig_ids_;
+   // Given a contiguous domain, return all iter domains within its history.
+   std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
+       within_contig_ids_;
+   //! Mapping of allocation domain to the actual indexed domain, which can
+   //! be itself or a contig merged domain if found.
+   std::unordered_map<IterDomain*, IterDomain*> alloc_to_indexed_id_;
+
+   std::unique_ptr<const OrderedIdInformation> consistent_transform_info_;
+
+   NonDivisibleSplitDependencies non_divisible_id_info_;
+
+   //! IDs that depend on resize output IDs
+   std::unordered_set<IterDomain*> resize_deps_;
+ };
+
+ } // namespace nvfuser
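
The backward index pass described above boils down to simple extent arithmetic. The following standalone C++ sketch (illustrative extents and helper names only, not nvFuser code) shows why a consistently ordered merge of adjacent, contiguous dimensions can be addressed by a single linear index:

// Illustration only: merging adjacent, consistently ordered, contiguous
// dimensions preserves linear addressing.
#include <cassert>
#include <cstdint>

// Row-major offset into a contiguous [e0, e1, e2] allocation.
int64_t offset3d(int64_t i0, int64_t i1, int64_t i2, int64_t e1, int64_t e2) {
  return (i0 * e1 + i1) * e2 + i2;
}

int main() {
  const int64_t e0 = 3, e1 = 4, e2 = 5;
  // Merge (i0, i1) -> m with extent e0*e1. Because i0 sits immediately to
  // the left of i1 in the allocation domain and both are contiguous, the
  // merged index m addresses memory linearly: offset == m * e2 + i2.
  for (int64_t m = 0; m < e0 * e1; ++m) {
    for (int64_t i2 = 0; i2 < e2; ++i2) {
      const int64_t i0 = m / e1; // backward pass recovers the outer index...
      const int64_t i1 = m % e1; // ...and the inner index from m
      assert(offset3d(i0, i1, i2, e1, e2) == m * e2 + i2);
    }
  }
  return 0;
}

If the merge were not consistently ordered (e.g., merge(i0, i2) above), no single stride would make the merged index linear, which is exactly what ContigIDs rules out.
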
nvfuser/include/nvfuser/cuda_utils.h
@@ -0,0 +1,50 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <cuda_runtime.h>
+ #include <driver_api.h>
+ #include <exceptions.h>
+ #include <nvrtc.h>
+
+ #define NVFUSER_NVRTC_SAFE_CALL(x)                 \
+   do {                                             \
+     nvrtcResult _result = x;                       \
+     NVF_ERROR(                                     \
+         _result == NVRTC_SUCCESS,                  \
+         "NVRTC error: " #x " failed with error ",  \
+         nvrtcGetErrorString(_result));             \
+   } while (0)
+
+ #define NVFUSER_CUDA_SAFE_CALL(x)                  \
+   do {                                             \
+     CUresult _result = x;                          \
+     if (_result != CUDA_SUCCESS) {                 \
+       const char* msg;                             \
+       const char* name;                            \
+       cuGetErrorName(_result, &name);              \
+       cuGetErrorString(_result, &msg);             \
+       NVF_ERROR(                                   \
+           _result == CUDA_SUCCESS,                 \
+           "CUDA error: ",                          \
+           name,                                    \
+           " failed with error ",                   \
+           msg);                                    \
+     }                                              \
+   } while (0)
+
+ #define NVFUSER_CUDA_RT_SAFE_CALL(x)               \
+   do {                                             \
+     cudaError_t _result = x;                       \
+     NVF_ERROR(                                     \
+         _result == cudaSuccess,                    \
+         "CUDA error: ",                            \
+         cudaGetErrorName(_result),                 \
+         " failed with error ",                     \
+         cudaGetErrorString(_result));              \
+   } while (0)
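
A usage sketch for these macros (the include path and call sites are assumptions; cudaMalloc, cudaFree, and cuCtxSynchronize are standard CUDA runtime/driver calls):

// Sketch: wrap raw CUDA calls so failures surface through NVF_ERROR with
// the error name and message attached.
#include <cuda_utils.h>

void roundtripAlloc(size_t bytes) {
  void* ptr = nullptr;
  NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&ptr, bytes)); // runtime API call
  NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(ptr));
}

void syncContext() {
  NVFUSER_CUDA_SAFE_CALL(cuCtxSynchronize()); // driver API call
}
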
nvfuser/include/nvfuser/debug.h
@@ -0,0 +1,50 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <visibility.h>
+
+ #include <iostream>
+ #include <string>
+ #include <unordered_map>
+ #include <vector>
+
+ namespace nvfuser {
+
+ //! This guard controls the output for debug info, such as any output
+ //! resulting from use of the $NVFUSER_DUMP environment variable options.
+ //! Debug output can be captured like so:
+ //!
+ //!   std::stringstream ss;
+ //!   {
+ //!     DebugStreamGuard dsg(ss);
+ //!     // Unmodified original code
+ //!
+ //!     // ss.str() now holds a std::string of debug info.
+ //!     // The guard resets the debug stream at the end of its lifetime.
+ //!   }
+ //!   // Code after the dsg object is destroyed will use the previously-set
+ //!   // stream, which defaults to std::cout.
+ class DebugStreamGuard {
+  public:
+   NVF_API DebugStreamGuard(std::ostream& stream);
+
+   NVF_API ~DebugStreamGuard();
+
+   static std::ostream& getCurStream();
+
+   void setCurStream(std::ostream& stream);
+
+  private:
+   std::ostream* prev_stream_;
+ };
+
+ //! This is just a short alias to avoid having to type
+ //! DebugStreamGuard::getCurStream() for each line we want to debug-print.
+ NVF_API std::ostream& debug();
+
+ } // namespace nvfuser
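
Spelled out, the capture pattern from the class comment looks like this (a sketch assuming the nvfuser headers are on the include path):

// Sketch: capture nvFuser debug output into a string instead of std::cout.
#include <debug.h>

#include <sstream>
#include <string>

std::string captureDebug() {
  std::stringstream ss;
  {
    nvfuser::DebugStreamGuard dsg(ss);
    nvfuser::debug() << "captured, not printed to std::cout\n";
  } // guard destroyed: previous stream (std::cout by default) is restored
  return ss.str();
}
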
nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h
@@ -0,0 +1,53 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <ir/base_nodes.h>
+ #include <kernel.h>
+ #include <polymorphic_value.h>
+ #include <runtime/executor_params.h>
+ #include <visibility.h>
+
+ #include <unordered_map>
+ #include <utility>
+
+ namespace nvfuser {
+
+ // For more info on shared memory access, see pages 54-72 of:
+ // https://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf
+
+ // Warning: The bank conflict checking utility here is not a replacement for
+ // Nsight Compute. This utility currently has the following assumptions and
+ // limitations:
+ //
+ // 1. It assumes that the data of the tensor is accessed by
+ //    `T0[index]`, where `index` is the one stored in the `TensorIndex`
+ //    object.
+ // 2. It only checks the first iteration. If we have something like
+ //    `T1_s[tidx, 5]`, then different iterations may have different
+ //    conflicts, and they will not all be evaluated.
+ // 3. It assumes that all tensors are independent, which means:
+ //    3.1 All shared memory tensors are allocated starting from a multiple of
+ //        4*32 bytes.
+ //    3.2 The only source of bank conflicts is from within a tensor;
+ //        there is no bank conflict between different tensors.
+ //
+ // Also note that this utility will not provide an accurate estimation if the
+ // above assumptions are not satisfied.
+
+ // Returns (expression, input conflict ways, output conflict ways)
+ // way == 0 --> not applicable (for example, tensor is not a smem tensor)
+ // way == 1 --> no conflict
+ NVF_API std::unordered_map<const Expr*, std::pair<int64_t, int64_t>>
+ getBankConflictInfo(
+     const kir::Kernel* kernel,
+     LaunchParams launch_params = {},
+     const std::unordered_map<Val*, PolymorphicValue>& known_values = {});
+
+ } // namespace nvfuser
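
A possible call-site sketch for the utility above; the lowered kir::Kernel is assumed to come from elsewhere, and default launch parameters are used:

// Sketch: iterate the per-expression (input ways, output ways) results.
#include <device_lower/analysis/bank_conflict.h>

#include <iostream>

void reportBankConflicts(const nvfuser::kir::Kernel* kernel) {
  // Default launch params, no extra known scalar values.
  auto info = nvfuser::getBankConflictInfo(kernel);
  for (const auto& [expr, ways] : info) {
    // ways.first / ways.second: input / output conflict ways.
    // 0 = not applicable, 1 = conflict-free (see comment above).
    std::cout << expr->toString() << " in: " << ways.first
              << " out: " << ways.second << "\n";
  }
}
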
nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h
@@ -0,0 +1,109 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+
+ #include <ir/all_nodes.h>
+ #include <kernel_ir.h>
+ #include <kernel_ir_dispatch.h>
+
+ namespace nvfuser {
+
+ IterDomain* getCircularBufferAxis(const TensorView* tv);
+
+ void validateCircularBufferedTensor(const TensorView* tv);
+
+ class CircularBufferInfo {
+   // Lowering information of circular buffered tensors.
+   struct TvInfo {
+     IterDomain* circular_buffer_axis = nullptr;
+     Val* original_alloc_size = nullptr;
+   };
+
+  public:
+   void build(Fusion* fusion);
+
+   void setCircularBufferTv(const TensorView* tv);
+
+   IterDomain* getCircularBufferAxis(const TensorView* tv) const;
+
+   //! Get all valid circular buffer TensorViews
+   std::vector<const TensorView*> getCircularBufferTvs() const;
+
+   //! Get a loop that matches with a given circular-buffer axis. If
+   //! ignore_prologue is true, a matched loop is ignored if it's a
+   //! prologue loop.
+   static ForLoop* getCircularBufferLoop(
+       IterDomain* axis,
+       const std::vector<ForLoop*>& loops,
+       bool ignore_prologue = false);
+
+   //! Get a loop that matches with the circular-buffer axis of a given
+   //! circular-buffered tensor. If ignore_prologue is true, a matched
+   //! loop is ignored if it's a prologue loop.
+   ForLoop* getCircularBufferLoop(
+       const TensorView* tv,
+       const std::vector<ForLoop*>& loops,
+       bool ignore_prologue = false);
+
+   //! Get the circular-buffered tensors for the given loop/axis.
+   std::unordered_set<const TensorView*> getCircularBufferTvs(
+       ForLoop* axis) const;
+   std::unordered_set<const TensorView*> getCircularBufferTvs(
+       IterDomain* axis) const;
+
+   void setOriginalAllocSize(const TensorView* tv, Val* size);
+
+   Val* getOriginalAllocSize(const TensorView* tv);
+
+   //! Returns true if the iterdomain will be realized
+   //! as a circular buffer loop.
+   bool isCircularBufferedIterDomain(IterDomain* id);
+
+   //! Get the circular buffer options for the given axis.
+   const CircularBufferOptions& getCircularBufferOptionsFor(
+       IterDomain* circular_buffered_id) const;
+
+   std::string toString() const;
+
+  private:
+   const TvInfo& getTvInfo(const TensorView* tv) const;
+
+   TvInfo& getTvInfo(const TensorView* tv);
+
+   //! Set the circular buffer options for the given circular_buffered_id.
+   //! Current code generation only supports one option per loop disjoint set,
+   //! so this function will throw an error if trying to set different options
+   //! on iterdomains that are loop mapped.
+   void setCircularBufferOptions(
+       IterDomain* circular_buffered_id,
+       const CircularBufferOptions& opt);
+
+  private:
+   //! Keeps track of information for lowering circular buffered tensors
+   std::unordered_map<const TensorView*, TvInfo> map_;
+
+   //! Keeps track of which concrete loop map is realizing circular buffer
+   //! iterdomains.
+   std::unordered_set<const IterDomain*> concrete_circular_buffered_loop_id_;
+
+   //! Keeps track of circular buffer loop stage depth and prefetch distance.
+   //! Currently, for each disjoint set of loop-mapped iterdomains, only one
+   //! stage depth and prefetch distance is supported, so that the loops can
+   //! indeed be shared, with the same prologue extent and main loop offset.
+   std::unordered_map<IterDomain*, CircularBufferOptions>
+       circular_buffer_options_;
+
+   //! Keeps track of circular buffer tvs for each disjoint set of loop-mapped
+   //! iterdomains.
+   std::unordered_map<IterDomain*, std::unordered_set<const TensorView*>>
+       circular_buffer_tvs_;
+ };
+
+ } // namespace nvfuser
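
A hedged sketch of how a lowering pass might query this analysis; it assumes `info` was already populated via build(fusion), and the fields of CircularBufferOptions are not shown here:

// Sketch: look up the circular-buffer axis and its per-loop-set options.
void inspectCircularBuffer(
    const nvfuser::CircularBufferInfo& info,
    const nvfuser::TensorView* tv) {
  if (nvfuser::IterDomain* axis = info.getCircularBufferAxis(tv)) {
    // One CircularBufferOptions per loop-mapped disjoint set (see above);
    // stage depth and prefetch distance live in this struct.
    const auto& opts = info.getCircularBufferOptionsFor(axis);
    (void)opts;
  }
}
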
nvfuser/include/nvfuser/device_lower/analysis/device_version.h
@@ -0,0 +1,65 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <c10/macros/Export.h>
+ #include <exceptions.h>
+
+ #include <compute_at_map.h>
+ #include <dispatch.h>
+ #include <ir/all_nodes.h>
+ #include <kernel_ir.h>
+
+ #include <vector>
+
+ namespace nvfuser {
+
+ class LoopIndexing;
+ class ComputeAtMap;
+
+ //! Traverses a Fusion to find the minimum supported CUDA compute capability
+ //! that will be able to run the generated kernel.
+ class MinimumDeviceVersion : private IterVisitor {
+  public:
+   static std::pair<std::pair<int, int>, std::string> compute(Fusion* fusion) {
+     MinimumDeviceVersion mdv;
+     mdv.traverse(fusion);
+     return {mdv.min_version_, mdv.reason_};
+   }
+
+  private:
+   using IterVisitor::dispatch;
+   using IterVisitor::handle;
+   using IterVisitor::traverse;
+
+   //! Check dtypes of all Vals. BFloat16 requires Ampere (8.0+), Float8_xxx
+   //! requires Hopper (9.0+)
+   void dispatch(Val* v) final;
+
+   //! MmaOp currently supports Turing and newer (7.5+), depending on the macro
+   void handle(MmaOp* mma_op) final;
+
+   //! LoadStoreOpType::CpAsync requires Ampere (8.0+)
+   //! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
+   void handle(LoadStoreOp* ls_op) final;
+
+   //! If a TensorView has warp-specialized circular buffering, it will use the
+   //! setmaxnreg ptx instruction, which requires Hopper (9.0+).
+   //! https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg
+   void handle(TensorView* tv) final;
+
+   //! Bump min_version_ to at least this value
+   void ensureVersion(std::pair<int, int> version, std::string reason);
+
+  private:
+   std::pair<int, int> min_version_ = {7, 0};
+   std::string reason_ =
+       "nvFuser supports Volta and above (compute capability 7.0+)";
+ };
+
+ } // namespace nvfuser
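
A small sketch of the intended entry point, with output routed through debug() from debug.h shown earlier (the call site and Fusion pointer are assumed context):

// Sketch: report the minimum compute capability required by a fusion.
void printMinimumVersion(nvfuser::Fusion* fusion) {
  auto [version, reason] = nvfuser::MinimumDeviceVersion::compute(fusion);
  nvfuser::debug() << "requires compute capability " << version.first << "."
                   << version.second << " or newer: " << reason << "\n";
}
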