nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,334 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <exceptions.h>
11
+ #include <utils.h>
12
+ #include <visibility.h>
13
+
14
+ #include <complex>
15
+ #include <unordered_map>
16
+
17
+ // dispatch.h removes the need to add manual dispatch in every class that
18
+ // wants to define how to process a series of nodes. dispatch.h provides 4
19
+ // classes that can be inherited providing a means to override functions on a
20
+ // per-node basis. There are currently 4 provided dispatch mechanisms:
21
+ //
22
+ // OptOutDispatch:
23
+ //
24
+ // provides the functions:
25
+ // virtual void handle(ValType* irnode){}
26
+ //
27
+ // This provides a mechanism to override this handle for particular node
28
+ // types. For example if we only wanted to actually run a function on
29
+ // BinaryOps, we could inherit OptOutDispatch and simply override: void
30
+ // handle(BinaryOp*) { doSomething; } Then we could run through all our
31
+ // Statement* and call OptOutDispatch::handle(statement). When a BinaryOp is
32
+ // encountered our override function will be called. For every other node,
33
+ // nothing will be done.
34
+ //
35
+ // OptInDispatch:
36
+ //
37
+ // This class is similar to OptOutDispatch, however if we encounter a node
38
+ // that we haven't specified an override for in the derived class, an error
39
+ // will be thrown. This is useful if we create a class that is expected to
40
+ // handle any type of node it encounters.
41
+ //
42
+ // OptOutMutator:
43
+ //
44
+ // This class is similar to OptOutDispatch except the functions provided are of
45
+ // type: virtual Statement* mutate(Statement*) this is useful for when we want
46
+ // to have an IR node result from our overloaded functions.
47
+ //
48
+ // OptInMutator:
49
+ //
50
+ // This class is similar to OptInDispatch except the functions provided are of
51
+ // type: virtual Statement* mutate(Statement*) this is useful for when we want
52
+ // to have an IR node result from our overloaded functions.
53
+
54
+ namespace nvfuser {
55
+ class IrContainer;
56
+ class Fusion;
57
+
58
+ // Hierarchical dispatch functions for handle
59
+ class Statement;
60
+ class Expr;
61
+ class Val;
62
+
63
+ #define DISPATCH_FOR_ALL_VALS(f) \
64
+ f(IterDomain); \
65
+ f(TensorDomain); \
66
+ f(TensorView); \
67
+ f(NamedScalar);
68
+ #define DISPATCH_FOR_ALL_KIR_VALS(f) f(Predicate) f(TensorIndex)
69
+ #define DISPATCH_FOR_ALL_HIR_VALS(f) f(Stream)
70
+
71
+ #define DISPATCH_FOR_ALL_EXPRS(f) \
72
+ f(FullOp); \
73
+ f(IotaOp); \
74
+ f(EyeOp); \
75
+ f(UnaryOp); \
76
+ f(BinaryOp); \
77
+ f(TernaryOp); \
78
+ f(ArrayConstruct); \
79
+ f(StructConstruct); \
80
+ f(GetAttr); \
81
+ f(GetItem); \
82
+ f(ReverseArray); \
83
+ f(GetMetaData); \
84
+ f(TensorConstruct); \
85
+ f(SelectOp); \
86
+ f(IndexSelectOp); \
87
+ f(TorchGatherOp); \
88
+ f(ScatterOp); \
89
+ f(RNGOp); \
90
+ f(ReductionOp); \
91
+ f(GroupedReductionOp); \
92
+ f(WelfordOp); \
93
+ f(GroupedWelfordOp); \
94
+ f(LoadStoreOp); \
95
+ f(MmaOp); \
96
+ f(BroadcastOp); \
97
+ f(SqueezeOp); \
98
+ f(ExpandOp); \
99
+ f(RepeatOp); \
100
+ f(ViewAsScalar); \
101
+ f(ViewOp); \
102
+ f(CatOp); \
103
+ f(PadOp); \
104
+ f(SliceOp); \
105
+ f(Split); \
106
+ f(Merge); \
107
+ f(Swizzle); \
108
+ f(Swizzle2D); \
109
+ f(Resize); \
110
+ f(MatmulOp); \
111
+ f(LinearOp); \
112
+ f(SdpaFwdOp); \
113
+ f(SdpaBwdOp); \
114
+ f(EmbeddingFwdOp); \
115
+ f(Communication); \
116
+ f(ForLoop); \
117
+ f(P2PCommunication);
118
+ #define DISPATCH_FOR_ALL_KIR_EXPRS(f) \
119
+ f(Allocate); \
120
+ f(Asm); \
121
+ f(BlockSync); \
122
+ f(GridSync); \
123
+ f(FenceAsyncProxy); \
124
+ f(WgMmaFence); \
125
+ f(SetMaxNReg); \
126
+ f(Return); \
127
+ f(MBarrierInit); \
128
+ f(MBarrierInvalidate); \
129
+ f(MBarrierArrive); \
130
+ f(MBarrierArriveExpectTx); \
131
+ f(MBarrierWait); \
132
+ f(MBarrierWaitParity); \
133
+ f(BlockSerializeWait); \
134
+ f(BlockSerializeRelease); \
135
+ f(AsyncWait); \
136
+ f(AsyncCommit); \
137
+ f(IfThenElse); \
138
+ f(GridReduction); \
139
+ f(GroupedGridReduction); \
140
+ f(GridBroadcast); \
141
+ f(GridWelford); \
142
+ f(GroupedGridWelford); \
143
+ f(VectorizedWelfordOp); \
144
+ f(AllocateFusedReduction); \
145
+ f(InitMagicZero); \
146
+ f(UpdateMagicZero); \
147
+ f(GetRNGSeedAndOffsetFromHost); \
148
+ f(EncodeTensorMapTiled);
149
+ #define DISPATCH_FOR_ALL_HIR_EXPRS(f) \
150
+ f(HostUnit); \
151
+ f(PostOnStream); \
152
+ f(LaunchKernel); \
153
+ f(SetCurrentStream); \
154
+ f(GetCurrentStream); \
155
+ f(Wait); \
156
+ f(Synchronize); \
157
+ f(StartCoalescing); \
158
+ f(EndCoalescing);
159
+
160
+ // Forward declarations for all Val and Expr types
161
+
162
+ #define M(e) class e;
163
+ DISPATCH_FOR_ALL_VALS(M);
164
+ DISPATCH_FOR_ALL_EXPRS(M);
165
+ #undef M
166
+
167
+ namespace kir {
168
+
169
+ #define M(e) class e;
170
+ DISPATCH_FOR_ALL_KIR_VALS(M)
171
+ DISPATCH_FOR_ALL_KIR_EXPRS(M)
172
+ #undef M
173
+
174
+ } // namespace kir
175
+
176
+ namespace hir {
177
+
178
+ #define M(e) class e;
179
+ DISPATCH_FOR_ALL_HIR_VALS(M)
180
+ DISPATCH_FOR_ALL_HIR_EXPRS(M)
181
+ #undef M
182
+
183
+ } // namespace hir
184
+
185
+ namespace assoc_comm {
186
+ class FlattenedAssocCommOp;
187
+ } // namespace assoc_comm
188
+
189
+ // By default, all IR nodes are handled in this dispatch, and will call an empty
190
+ // function on all nodes.
191
+ class OptOutConstDispatch : public PolymorphicBase {
192
+ protected:
193
+ virtual void unhandled(const Statement*) {}
194
+
195
+ public:
196
+ // Hierarchical dispatch functions for handle
197
+ virtual void dispatch(const Statement*);
198
+ virtual void dispatch(const Expr*);
199
+ virtual void dispatch(const Val*);
200
+
201
+ #define M(e) virtual void handle(const e* stmt);
202
+ M(Val);
203
+ DISPATCH_FOR_ALL_VALS(M)
204
+ DISPATCH_FOR_ALL_EXPRS(M)
205
+ M(assoc_comm::FlattenedAssocCommOp);
206
+ #undef M
207
+ #define M(e) virtual void handle(const kir::e* stmt);
208
+ DISPATCH_FOR_ALL_KIR_VALS(M)
209
+ DISPATCH_FOR_ALL_KIR_EXPRS(M)
210
+ #undef M
211
+ #define M(e) virtual void handle(const hir::e* stmt);
212
+ DISPATCH_FOR_ALL_HIR_VALS(M)
213
+ DISPATCH_FOR_ALL_HIR_EXPRS(M)
214
+ #undef M
215
+ };
216
+
217
+ class NVF_API OptOutDispatch : public PolymorphicBase {
218
+ protected:
219
+ virtual void unhandled(Statement*);
220
+
221
+ public:
222
+ // Hierarchical dispatch functions for handle
223
+ virtual void dispatch(Statement*);
224
+ virtual void dispatch(Expr*);
225
+ virtual void dispatch(Val*);
226
+
227
+ #define M(e) virtual void handle(e* stmt);
228
+ M(Val);
229
+ DISPATCH_FOR_ALL_VALS(M)
230
+ DISPATCH_FOR_ALL_EXPRS(M)
231
+ M(assoc_comm::FlattenedAssocCommOp);
232
+ #undef M
233
+ #define M(e) virtual void handle(kir::e* stmt);
234
+ DISPATCH_FOR_ALL_KIR_VALS(M)
235
+ DISPATCH_FOR_ALL_KIR_EXPRS(M)
236
+ #undef M
237
+ #define M(e) virtual void handle(hir::e* stmt);
238
+ DISPATCH_FOR_ALL_HIR_VALS(M)
239
+ DISPATCH_FOR_ALL_HIR_EXPRS(M)
240
+ #undef M
241
+ };
242
+
243
+ class OptInConstDispatch : public OptOutConstDispatch {
244
+ public:
245
+ using OptOutConstDispatch::handle;
246
+
247
+ protected:
248
+ void unhandled(const Statement* stmt) final;
249
+ };
250
+
251
+ class OptInDispatch : public OptOutDispatch {
252
+ public:
253
+ using OptOutDispatch::handle;
254
+
255
+ protected:
256
+ void unhandled(Statement* stmt) final;
257
+ };
258
+
259
+ // Class to perform mutations on Fusion IR. Exprs can simply be redefined, but
260
+ // when mutating values they have to be registered through registerMutation so
261
+ // that exprs can detect there's been a mutation and know to modify all
262
+ // instances of that Val. This means each Val should be mutated "consistently".
263
+ // Otherwise behavior may be difficult to understand as it depends on which
264
+ // order mutate is called in. This class expects the user to topologically call
265
+ // the statements of interest so inputs are called and mutated before exprs depending
266
+ // on them.
267
+ //
268
+ // Warning: TensorViews need to be treated carefully. As we don't generally
269
+ // register their mutation when their tensor domains only change. If a TV needs
270
+ // to be swapped out, it needs to be registered as a "proper" mutation like
271
+ // other vals, on top of TensorDomain being updated in the mutated TensorView.
272
+ //
273
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
274
+ class NVF_API OptOutMutator : public PolymorphicBase {
275
+ public:
276
+ // Hierarchical dispatch functions for handle
277
+ virtual void dispatchMutate(Statement* s);
278
+ virtual void dispatchMutate(Val* v);
279
+
280
+ void registerMutation(Val* val, Val* mutation);
281
+
282
+ Val* maybeMutated(Val* val) const;
283
+
284
+ std::unordered_map<Val*, Val*> mutations_;
285
+
286
+ //****Functions below defined in mutator.cpp*****
287
+
288
+ // Vals
289
+ virtual void mutate(Val*);
290
+
291
+ #define M(e) virtual void mutate(e* stmt);
292
+ DISPATCH_FOR_ALL_VALS(M)
293
+ #undef M
294
+ #define M(e) virtual void mutate(kir::e* stmt);
295
+ DISPATCH_FOR_ALL_KIR_VALS(M)
296
+ #undef M
297
+
298
+ //! This method replaces e if any inputs or attributes are registered for
299
+ //! mutation.
300
+ virtual void mutate(Expr* e) {
301
+ mutateExpr(
302
+ e,
303
+ /*replace_outputs*/ false,
304
+ /*replace_inputs*/ true,
305
+ /*replace_attrs*/ true);
306
+ }
307
+
308
+ //! Unlike mutate(Expr*), this method replaces e only if any outputs are
309
+ //! registered for mutation. Inputs and attributes are unchanged. This method
310
+ //! is useful for transferring the definition of e's current outputs to
311
+ //! their respective registered mutations.
312
+ Expr* mutateExprOutputsOnly(Expr* e) {
313
+ return mutateExpr(
314
+ e,
315
+ /*replace_outputs*/ true,
316
+ /*replace_inputs*/ false,
317
+ /*replace_attrs*/ false);
318
+ }
319
+
320
+ protected:
321
+ virtual void removeExpr(IrContainer*, Expr*) const;
322
+ virtual void registerNewExpr(Expr*) {}
323
+
324
+ private:
325
+ //! Replaces Expr if any inputs, attrs, or outputs are registered for
326
+ //! mutation. See comment on mutateExprOutputsOnly for more information.
327
+ Expr* mutateExpr(
328
+ Expr*,
329
+ bool replace_outputs = false,
330
+ bool replace_inputs = true,
331
+ bool replace_attrs = true);
332
+ };
333
+
334
+ } // namespace nvfuser
@@ -0,0 +1,49 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <cuda.h>
11
+
12
+ // How to lazily load a driver API and invoke it? Just forget about lazy loading
13
+ // and write code as if you are using the driver API directly. Magic will
14
+ // happen. To understand how the magic works, please refer to the cpp file's doc
15
+ // "How does the magic work?"
16
+
17
+ namespace nvfuser {
18
+
19
+ #define DECLARE_DRIVER_API_WRAPPER(funcName) \
20
+ extern decltype(::funcName)* funcName;
21
+
22
+ // List of driver APIs that you want the magic to happen.
23
+ #define ALL_DRIVER_API_WRAPPER_CUDA11(fn) \
24
+ fn(cuDeviceGetAttribute); \
25
+ fn(cuDeviceGetName); \
26
+ fn(cuFuncGetAttribute); \
27
+ fn(cuFuncSetAttribute); \
28
+ fn(cuGetErrorName); \
29
+ fn(cuGetErrorString); \
30
+ fn(cuLaunchCooperativeKernel); \
31
+ fn(cuLaunchKernel); \
32
+ fn(cuModuleGetFunction); \
33
+ fn(cuModuleLoadDataEx); \
34
+ fn(cuModuleUnload); \
35
+ fn(cuOccupancyMaxActiveBlocksPerMultiprocessor)
36
+
37
+ #if (CUDA_VERSION >= 12000)
38
+ #define ALL_DRIVER_API_WRAPPER(fn) \
39
+ ALL_DRIVER_API_WRAPPER_CUDA11(fn); \
40
+ fn(cuTensorMapEncodeTiled)
41
+ #else
42
+ #define ALL_DRIVER_API_WRAPPER ALL_DRIVER_API_WRAPPER_CUDA11
43
+ #endif
44
+
45
+ ALL_DRIVER_API_WRAPPER(DECLARE_DRIVER_API_WRAPPER);
46
+
47
+ #undef DECLARE_DRIVER_API_WRAPPER
48
+
49
+ } // namespace nvfuser
@@ -0,0 +1,316 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <exceptions.h>
11
+ #include <visibility.h>
12
+
13
+ #include <expr_evaluator.h>
14
+ #include <ir/all_nodes.h>
15
+ #include <ir/cloner.h>
16
+ #include <ir/iostream.h>
17
+ #include <iter_visitor.h>
18
+ #include <logical_domain_map.h>
19
+ #include <transform_view.h>
20
+ #include <utils.h>
21
+
22
+ #include <functional>
23
+ #include <memory>
24
+ #include <vector>
25
+
26
+ namespace nvfuser {
27
+
28
+ class Fusion;
29
+ class DynamicTransformInitialInfoBuilder;
30
+
31
//! Initial information derived only from the symbolic Fusion without input
//! sizes. Built once per Fusion (by DynamicTransformInitialInfoBuilder) and
//! reused every time concrete input sizes arrive.
class DynamicTransformInitialInfo {
 public:
  //! The Fusion this info was derived from (non-owning).
  Fusion* fusion() const {
    return fusion_;
  }

  //! Return whether any dynamic transforms exist in the Fusion, or whether
  //! there are any tensors which could potentially be empty (size-0 extent)
  //! given some user input. In either of these cases, concretization may change
  //! the structure of the Fusion.
  //!
  //! NOTE(review): dynamic_expanded_tvs_ and dynamic_factory_tvs_ are not
  //! checked directly here; presumably their symbolic extents always appear in
  //! maybe_zero_extents_ so hasPossibleEmptyTensor() covers them — confirm.
  bool isDynamic() const {
    return hasPossibleEmptyTensor() || !dynamic_reshaped_tvs_.empty() ||
        !dynamic_resized_ids_.empty();
  }

  //! Return whether there are any tensors with unknown extent in some
  //! dimension, so that they might be empty
  bool hasPossibleEmptyTensor() const {
    return !maybe_zero_extents_.empty();
  }

  //! Return a set of scalars that are inputs or extents of input TensorViews
  //! and that appear in inputs to dynamic expressions. Any Vals not in this
  //! list do not affect concretization.
  const std::unordered_set<Val*>& getRootDynamicVals() const {
    return root_dynamic_vals_;
  }

  //! Return a set of scalars that appear as extents in TensorViews in the
  //! Fusion. If any of these evaluate to zero, there is at least one empty
  //! TensorView present.
  const std::vector<Val*>& getMaybeZeroExtents() const {
    return maybe_zero_extents_;
  }

  //! Return a vector of outputs of ViewOp expressions that have dynamic output
  //! shapes
  const std::vector<TensorView*>& getDynamicReshapedTensorViews() const {
    return dynamic_reshaped_tvs_;
  }

  //! Return a vector of outputs of Resize expressions that have symbolic output
  //! IterTypes
  const std::vector<IterDomain*>& getDynamicResizedIterDomains() const {
    return dynamic_resized_ids_;
  }

  //! Return a vector of outputs of ExpandOp expressions that have Symbolic
  //! output IterTypes
  const std::vector<TensorView*>& getDynamicExpandedTensorViews() const {
    return dynamic_expanded_tvs_;
  }

  //! Return a vector of outputs of factory expressions like full, iota,
  //! normal, and uniform that have Symbolic output IterTypes
  const std::vector<TensorView*>& getDynamicFactoryOutputs() const {
    return dynamic_factory_tvs_;
  }

  //! Human-readable dump of the collected info (for debugging/logging).
  std::string toString() const;

  //! Deep-copy this info so it refers to Statements in a cloned Fusion.
  DynamicTransformInitialInfo clone(IrCloner& ir_cloner) const;

  //! Return a set containing positions in inputs() holding any scalar input
  //! that would affect the structure of the concretized Fusion.
  const std::unordered_set<size_t>& scalarInputsAffectingConcretization()
      const {
    return scalar_inputs_affecting_concretization_;
  }

 protected:
  //! Holds the set of scalar fusion inputs that affect concretization.
  std::unordered_set<size_t> scalar_inputs_affecting_concretization_;

 private:
  //! Private: only the builder (friend below) may construct this.
  DynamicTransformInitialInfo(Fusion* fusion) : fusion_(fusion) {}

 private:
  Fusion* fusion_ = nullptr;

  // We hold vectors of the _outputs_ of dynamic ops. The reason we don't hold
  // the ops themselves is that during concretization, the ops will actually be
  // removed by ir_utils::replaceValInExprInputs. The outputs will not: their
  // definitions will merely be altered. When the ops are replaced, if we had
  // referred to them directly here, we would run into segfaults. Referring only
  // to the outputs avoids this issue.
  std::vector<TensorView*> dynamic_reshaped_tvs_;

  std::vector<IterDomain*> dynamic_resized_ids_;

  std::vector<TensorView*> dynamic_expanded_tvs_;

  std::vector<TensorView*> dynamic_factory_tvs_;

  // This is a minimal set of scalars to check for empty tensors. If any are
  // zero, we should traverse to find empty tensors.
  std::unordered_set<Val*> maybe_zero_extents_set_;
  // The set above is populated then used to create this unique vector
  std::vector<Val*> maybe_zero_extents_;

  // Root Vals that determine concretization
  std::unordered_set<Val*> root_dynamic_vals_;

  friend class DynamicTransformInitialInfoBuilder;
};
138
+
139
//! A set of transformations for a symbolic fusion with concrete sizes
//! of the fusion inputs. Constructed from a DynamicTransformInitialInfo plus
//! an ExpressionEvaluator with input scalars bound; serves as a cache key
//! (see operator== and hash()) for structurally equivalent concretizations.
class DynamicTransformConcretizationInfo {
 public:
  //! exact_map is optional; when absent, presumably one is derived internally
  //! from the fusion — confirm in the cpp file.
  NVF_API DynamicTransformConcretizationInfo(
      const DynamicTransformInitialInfo* initial_info,
      ExpressionEvaluator* expr_eval,
      ExactLogicalDomainMap* exact_map = nullptr);

  //! Return a vector of integers each corresponding to the position in
  //! initialInfo()->getMaybeZeroExtents() of an extent Val which is guaranteed
  //! to be zero.
  const std::vector<int64_t>& getEmptyExtents() const {
    return empty_extents_;
  }

  //! Return a vector of pairs holding the index of each reshaped TensorView in
  //! the vector returned by initialInfo()->getDynamicReshapedTensorViews(),
  //! along with an AnalyzeViewResult describing how that reshape operation
  //! should be decomposed into split, merge, squeeze, and broadcast transforms.
  //!
  //! In case there are any zeros in the size of the input and output we will
  //! not perform a reshape but rather replace the output with full(). Then
  //! instead of an AnalyzeViewResult we will hold a vector of symbolic sizes
  //! indicating how to concretize the output IterDomains.
  //!
  //! The symbolic sizes are the actual sizes 0 or 1, or -1 if the size of a
  //! given reshaped dimension is greater than 1.
  using ViewConcretizationInfo =
      std::variant<AnalyzeViewResult, std::vector<int64_t>>;
  const std::vector<std::pair<int64_t, ViewConcretizationInfo>>&
  getReshapeTransforms() const {
    return reshape_transforms_;
  }

  //! Return a vector of pairs holding the index of each resized IterDomain in
  //! the vector returned by initialInfo()->getDynamicResizedIterDomains(),
  //! along with the IterType it should be concretized to.
  const std::vector<std::pair<int64_t, IterType>>& getResizeIterTypes() const {
    return resize_itertypes_;
  }

  //! Return a vector of pairs holding the index of each expanded TensorView in
  //! the vector returned by initialInfo()->getDynamicExpandedTensorViews(),
  //! along with a vector of bools describing whether each axis in the output
  //! root domain is expanded.
  const std::vector<std::pair<int64_t, std::vector<bool>>>& getExpandAxes()
      const {
    return expand_axes_;
  }

  //! Return a vector of vectors of pairs. Each vector of pairs corresponds to a
  //! TensorView returned by by initialInfo()->getDynamicFactoryOutputs(). The
  //! pairs contain an integer position of a Symbolic axis and the IterType that
  //! axis will be converted to.
  const std::vector<std::vector<std::pair<int64_t, IterType>>>&
  getFactoryOutputIterTypes() const {
    return factory_output_itertypes_;
  }

  //! Comparison operator for the purposes of determining cache hits. This does
  //! not guarantee equality of all members. Instead, it returns equal if the
  //! resulting concretizations would be structurally equivalent. Note that
  //! pointers to Statements may differ between equivalent concretizations due
  //! to cloning before concretization.
  NVF_API bool operator==(
      const DynamicTransformConcretizationInfo& other) const;

  bool operator!=(const DynamicTransformConcretizationInfo& other) const {
    return !(*this == other);
  }

  //! Given an ExpressionEvaluator which already has input scalars bound to it,
  //! determine the decomposition of each dynamic reshape operation to use
  //! during concretization.
  void analyzeReshapes(ExpressionEvaluator* expr_eval);

  //! Given an ExpressionEvaluator which already has input scalars bound to it,
  //! determine the concrete IterType of each resized IterDomain.
  void analyzeResizes(ExpressionEvaluator* expr_eval);

  //! Given an ExpressionEvaluator which already has input scalars bound to it,
  //! determine which axes of dynamic expand operations are expanded.
  void analyzeExpands(ExpressionEvaluator* expr_eval);

  //! Given an ExpressionEvaluator which already has input scalars bound to it,
  //! determine the IterTypes of factory function outputs.
  void analyzeFactoryOutputs(ExpressionEvaluator* expr_eval);

  //! The initial (input-size-independent) info this was computed from
  //! (non-owning).
  const DynamicTransformInitialInfo* initialInfo() const {
    return initial_info_;
  }

  //! Re-point this info at a different initial-info object (e.g. after the
  //! Fusion and its info have been cloned).
  void setInitialInfo(const DynamicTransformInitialInfo* initial_info) {
    initial_info_ = initial_info;
  }

  Fusion* fusion() const {
    return initial_info_->fusion();
  }

  NVF_API std::string toString() const;

  //! Hash consistent with operator== (see std::hash specialization below the
  //! namespace).
  NVF_API size_t hash() const;

 private:
  //! Private: constructs an empty info to be filled in by the analyze*
  //! methods.
  DynamicTransformConcretizationInfo(
      const DynamicTransformInitialInfo* initial_info)
      : initial_info_(initial_info) {}

 private:
  const DynamicTransformInitialInfo* initial_info_ = nullptr;

  //! Holds the index of the output TensorView in the vector returned by
  //! initial_info_->getDynamicReshapedTensorViews(), and the corresponding
  //! result of analyzeView (or list of IterTypes for output of full() in the
  //! case of empty reshapes).
  std::vector<std::pair<int64_t, ViewConcretizationInfo>> reshape_transforms_;

  //! Holds a vector of indices into initial_info_.getMaybeZeroExtents() which
  //! evaluate to 0
  std::vector<int64_t> empty_extents_;

  //! Holds the index of the resized IterDomain (output of the Resize op) in the
  //! vector returned by initial_info_->getDynamicResizedIterDomains() along
  //! with its concretized IterType
  std::vector<std::pair<int64_t, IterType>> resize_itertypes_;

  //! Holds the index of the expanded TensorView in the vector returned by
  //! initial_info_->getDynamicExpandedTensorViews(), and a corresponding vector
  //! of bools indicating whether each axis is in fact expanded.
  std::vector<std::pair<int64_t, std::vector<bool>>> expand_axes_;

  //! Holds the axis and IterType corresponding to each TensorView returned by
  //! initial_info_->getDynamicFactoryOutputs().
  std::vector<std::vector<std::pair<int64_t, IterType>>>
      factory_output_itertypes_;

  //! NOTE(review): no class named DynamicTransformInfoBuilder is declared in
  //! this header (only DynamicTransformInitialInfoBuilder is forward-declared
  //! above) — confirm this friend declaration is not stale.
  friend class DynamicTransformInfoBuilder;
};
279
+
280
//! Stateless entry points for dynamic-shape analysis and in-place
//! concretization of a Fusion.
class DynamicTransform {
 public:
  //! Get initial information before we have inputs. This analyzes the Fusion to
  //! determine whether it has dynamic operations, and caches their position for
  //! faster concretization once inputs are available.
  NVF_API static DynamicTransformInitialInfo getInitialInfo(Fusion* fusion);

  //! Concretizes a given fusion. Note that the concretization is
  //! in-place and the given fusion is modified. Return a map from old, symbolic
  //! values to new, concrete values.
  NVF_API static std::unordered_map<Val*, Val*> concretizeFusion(
      Fusion* fusion,
      const DynamicTransformConcretizationInfo* info);

  //! Calls the above after computing concretization info from inputs
  static std::unordered_map<Val*, Val*> concretizeFusion(
      Fusion* fusion,
      const std::vector<c10::IValue>& aten_inputs);

  //! Calls the above after computing concretization info from
  //! KernelArgumentHolder
  static std::unordered_map<Val*, Val*> concretizeFusion(
      Fusion* fusion,
      const KernelArgumentHolder& args);
};
305
+
306
+ } // namespace nvfuser
307
+
308
namespace std {
//! Hash support so DynamicTransformConcretizationInfo can key unordered
//! containers (e.g. a concretization cache). Delegates to the class's own
//! hash(), which is declared alongside its structural operator==.
template <>
struct hash<nvfuser::DynamicTransformConcretizationInfo> {
  size_t operator()(
      const nvfuser::DynamicTransformConcretizationInfo& info) const {
    return info.hash();
  }
};
} // namespace std