nvfuser-cu121-torch25 0.2.25.dev20250201__cp310-cp310-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/python_frontend/fusion_record.h
@@ -0,0 +1,3124 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <c10/util/complex.h>
#include <debug.h>
#include <exceptions.h>
#include <ir/interface_nodes.h>
#include <ops/all_ops.h>
#include <options.h>
#include <python_frontend/fusion_definition.h>
#include <python_frontend/fusion_state.h>
#include <serde/fusion_cache_generated.h>
#include <serde/polymorphic_value.h>
#include <serde/utils.h>
#include <utils.h>

#include <algorithm>
#include <complex>
#include <variant>

namespace nvfuser::python_frontend {

//! RecordFunctor is the base class record for operations recorded by
//! the FusionState. It is, in essence, a node in the graph with
//! input edges, args, and output edges, where the stored
//! values are indices into the recorded state.
//!
//! The virtual functor operator is executed on a cache miss to build the
//! appropriate part of the nvFuser Fusion IR for a given record.
//!
//! The hash and equality operators are used to facilitate the hashing of
//! RecordFunctors in a hash map, given those operators need to be
//! specified for custom objects.
//!
//! The print function is used to print the given Record as a statement
//! in a Python-formatted function.

struct RecordFunctor {
  RecordFunctor(
      std::vector<State> _args,
      std::vector<State> _outputs,
      std::string _name,
      serde::RecordType _record_type,
      bool _inline_def = false)
      : args_(std::move(_args)),
        arg_names_(args_.size()),
        outputs_(std::move(_outputs)),
        name_(std::move(_name)),
        record_type_(_record_type),
        inline_def_(
            _inline_def &&
            !isOptionDisabled(DisableOption::PythonInlineDefinitions)) {
    // Set this Record as the parent of each output
    if (inline_def_) {
      for (auto& out : outputs_) {
        out.setInlineDefRecord(this);
      }
    }
  }
  RecordFunctor(const RecordFunctor& other)
      : args_(other.args_),
        arg_names_(other.arg_names_),
        outputs_(other.outputs_),
        name_(other.name_),
        record_type_(other.record_type_),
        inline_def_(other.inline_def_) {
    // Set this Record as the parent of each output
    if (inline_def_) {
      for (auto& out : outputs_) {
        out.setInlineDefRecord(this);
      }
    }
  }
  virtual ~RecordFunctor() = default;
  //! Allows for copying of Child Class objects with RecordFunctor pointers.
  virtual RecordFunctor* clone() = 0;

  //! The base class hashes the type, outputs, and args as follows:
  //! | 63 - 56 | 55 - 48 | 47 ----------- 32 | 31 ------------------------ 0 |
  //! | Type    | Outputs | Args              | Child Class Specified         |
  virtual size_t hash() const {
    size_t arg_hash = 0;
    for (auto arg : args_) {
      arg_hash ^= ((arg.index << 1) ^ static_cast<size_t>(arg.stype));
    }
    size_t output_hash = 0;
    for (auto output : outputs_) {
      output_hash ^= ((output.index << 1) ^ static_cast<size_t>(output.stype));
    }
    // NOTE: The inline_def is not part of the hash as it is not used for
    // comparison
    return ((static_cast<size_t>(record_type_) & 0xff) << 56) |
        ((output_hash & 0xff) << 48) | ((arg_hash & 0xffff) << 32);
  }

  //! The base virtual equality operator is defined so all child
  //! classes can utilize the check for the same args and outputs.
  virtual bool operator==(const RecordFunctor& other) const {
    auto result = (record_type_ == other.record_type_);
    result = result && (args_.size() == other.args_.size()) &&
        (outputs_.size() == other.outputs_.size());
    result = result && (arg_names_ == other.arg_names_);
    if (result) {
      for (size_t i = 0; i < args_.size(); ++i) {
        if ((args_[i].index != other.args_[i].index) ||
            (args_[i].stype != other.args_[i].stype)) {
          result = false;
          break;
        }
      }
    }
    if (result) {
      for (size_t i = 0; i < outputs_.size(); ++i) {
        if ((outputs_[i].index != other.outputs_[i].index) ||
            (outputs_[i].stype != other.outputs_[i].stype)) {
          result = false;
          break;
        }
      }
    }
    // NOTE: The inline_def is not part of the equality operator as it is not
    // used for comparison
    return result;
  }

  //! Abstraction for an operation to build this record's nvFuser Fusion IR
  //! piece if the recording has a cache miss.
  virtual void operator()(FusionState& fd) = 0;

  //! Abstraction for storing data specific to a record functor.
  virtual std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
      flatbuffers::FlatBufferBuilder& builder) const {
    return {serde::RecordData::NONE, flatbuffers::Offset<void>()};
  }

  //! The base serialize function that handles args, outputs, name and
  //! recordType. Child RecordFunctors should overload the recordData function
  //! if they have supplementary attributes.
  virtual flatbuffers::Offset<serde::RecordFunctor> serialize(
      flatbuffers::FlatBufferBuilder& builder) const {
    // See table definition for RecordFunctor in serde/fusion_cache.fbs

    std::vector<serde::State> fb_args;
    fb_args.reserve(args_.size());
    for (auto& it : args_) {
      fb_args.emplace_back(it.index, it.stype);
    }
    auto args_fb =
        builder.CreateVectorOfStructs(fb_args.data(), fb_args.size());

    std::vector<serde::State> fb_outputs;
    fb_outputs.reserve(outputs_.size());
    for (auto& it : outputs_) {
      fb_outputs.emplace_back(it.index, it.stype);
    }
    auto outputs_fb =
        builder.CreateVectorOfStructs(fb_outputs.data(), fb_outputs.size());

    auto&& [record_data_type, record_data] = recordData(builder);

    return serde::CreateRecordFunctor(
        builder,
        args_fb,
        outputs_fb,
        builder.CreateString(name_),
        recordType(),
        record_data_type,
        record_data);
  }

  //! The base print function when printing Record for a given FusionState
  //! in Python-formatted code.
  virtual void print(std::ostream& os, bool close_function = true) const {
    NVF_ERROR(
        !inline_def_,
        "The default print function does not handle inline definitions!");
    bool first_output = true;
    for (auto& output : outputs_) {
      if (first_output) {
        first_output = false;
      } else {
        os << ", ";
      }
      os << output;
    }
    if (always_returns_tuple_) {
      os << ",";
    }
    if (!outputs_.empty()) {
      os << " = "
         << "fd." << name_ << "(";
    } else {
      os << "fd." << name_ << "(";
    }
    bool first_arg = true;
    size_t idx = 0;
    for (auto& arg : args_) {
      if (first_arg) {
        first_arg = false;
      } else {
        os << ", ";
      }
      if (!arg_names_[idx].empty()) {
        os << arg_names_[idx] << "=";
      }
      ++idx;
      os << arg;
    }
    if (close_function) {
      os << ")";
    }
  }

  size_t numOutputs() const {
    return outputs_.size();
  }

  const std::vector<State>& outputs() const {
    return outputs_;
  }
  std::vector<State>& args() {
    return args_;
  }

  serde::RecordType recordType() const {
    return record_type_;
  }

  bool inlineDef() const {
    return inline_def_;
  }

  //! Set the name of an argument. If given, it will be listed as a keyword
  //! argument during printing using the given name as the key. Unnamed
  //! arguments are the default, and are listed as positional arguments before
  //! any named arguments.
  void setArgName(size_t pos, std::string name) {
    arg_names_.at(pos) = name;
  }

 protected:
  //! Inputs that are indices into the FusionState's Recorded State.
  std::vector<State> args_;
  //! String name to print for arg in Python, if any. Defaults to empty.
  std::vector<std::string> arg_names_;
  //! Outputs that are indices into the FusionState's Recorded State.
  std::vector<State> outputs_;
  //! Record Name
  std::string name_;
  //! Record Type of child class used for hashing
  //! enum class RecordType is defined in flatbuffer schema
  serde::RecordType record_type_;
  //! Indicates if a record was defined inline with another record for printing
  bool inline_def_;
  //! Whether this record type returns a tuple of unknown length. This is only
  //! used for TensorSizesRecord.
  bool always_returns_tuple_ = false;
};
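
As a reading aid, the packing in RecordFunctor::hash() can be reproduced in isolation. The following standalone sketch is not nvFuser code and its field values are invented; it only shows how a record type, an output hash, an arg hash, and a child-specific hash land in the documented bit ranges and can be recovered:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical stand-ins; the real values come from serde::RecordType and
  // the XOR-folds over the State vectors shown above.
  uint64_t record_type = 0x2a;
  uint64_t output_hash = 0x11;
  uint64_t arg_hash = 0xbeef;
  uint64_t child_bits = 0x12345678;

  // Same packing as RecordFunctor::hash(): type | outputs | args | child.
  uint64_t h = ((record_type & 0xff) << 56) | ((output_hash & 0xff) << 48) |
      ((arg_hash & 0xffff) << 32) | (child_bits & 0xffffffff);

  // Each field can be read back out of its bit range.
  std::printf(
      "type=%llx outputs=%llx args=%llx child=%llx\n",
      (unsigned long long)(h >> 56),
      (unsigned long long)((h >> 48) & 0xff),
      (unsigned long long)((h >> 32) & 0xffff),
      (unsigned long long)(h & 0xffffffff));
  return 0;
}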

//! The OpRecord RecordFunctor is the most widely used child class because
//! it utilizes variadic template arguments to represent unary, binary,
//! ternary, and other similar flavors of operations in nvFuser that have
//! a mix of Tensor and Scalar arguments only.
//!
//! The additional data member of this child class records the function
//! signature of the nvFuser Arith Operation to be replayed upon a cache
//! miss by the functor operator() call.

template <class OutType, class... ArgTypes>
struct OpRecord : RecordFunctor {
  OpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
      std::string _name,
      serde::RecordType record_type,
      std::function<OutType(ArgTypes...)> fusion_op)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            _name,
            record_type),
        fusion_op_(fusion_op) {}
  ~OpRecord() override = default;
  RecordFunctor* clone() final {
    return new OpRecord(*this);
  }

  //! Child specific hash function in lower 32 bits.
  //! | 31 ------------------------------------- 0 |
  //! | Arith Function Sigs hash code               |
  size_t hash() const final {
    auto result = RecordFunctor::hash();
    return result | (fusion_op_.target_type().hash_code() & 0xffffffff);
  }

  bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    // A successful cast indicates a RecordFunctor of the same child class
    if (auto child_ptr = dynamic_cast<const OpRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
      if (result) {
        // Match the nvFuser arith function types
        result = result &&
            (fusion_op_.target_type() == child_ptr->fusion_op_.target_type());
        if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) {
          debug() << "\nOpRecord: " << name_ << " Target Type [self: 0x"
                  << fusion_op_.target_type().name() << "] [other: 0x"
                  << child_ptr->fusion_op_.target_type().name() << "] ";
        }
        // Match the nvFuser arith function pointers
        // IMPORTANT! you need to dereference the target pointer in order
        // to match the function
        result = result &&
            (*fusion_op_.template target<OutType (*)(ArgTypes...)>() ==
             *child_ptr->fusion_op_
                  .template target<OutType (*)(ArgTypes...)>());
        if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) {
          debug()
              << "Target Ptr [self: 0x" << std::hex
              << (size_t)*fusion_op_.template target<OutType (*)(ArgTypes...)>()
              << "] [other: 0x" << std::hex
              << (size_t)*child_ptr->fusion_op_
                     .template target<OutType (*)(ArgTypes...)>()
              << "]\n";
        }
      }
    }
    return result;
  }

  //! The variadic set of indices for the number of args for this op are
  //! deduced by providing the index_sequence as a parameter. Similarly,
  //! the tuple type is also deduced.
  //!
  //! The tuple type is used to decide whether to cast the input argument
  //! to a Fusion IR TensorView or leave it as a Fusion IR Val (Scalar).
  //!
  //! A deduced binary op could look like:
  //!   OutType opFunc<std::tuple<TensorView*, TensorView*>, 0, 1>
  //! A deduced ternary op could look like:
  //!   OutType opFunc<std::tuple<TensorView*, Val*, Val*>, 0, 1, 2>
  template <class TupleType, std::size_t... Is>
  OutType opFunc(FusionState& fd, TupleType& tp, std::index_sequence<Is...>) {
    return fusion_op_(
        dynamic_cast<typename std::tuple_element<Is, TupleType>::type>(
            fd.getFusionState(args_.at(Is).index))...);
  }

  void operator()(FusionState& fd) final {
    using arg_tuple_t = std::tuple<ArgTypes...>;
    auto indices =
        std::make_index_sequence<std::tuple_size<arg_tuple_t>::value>();
    // The tuple variable is never populated, it is passed for its type.
    arg_tuple_t inputs;
    auto output = opFunc(fd, inputs, indices);
    fd.setFusionState(outputs_.at(0).index, output);
  }

 private:
  //! An nvFuser Arith Operation function signature
  std::function<OutType(ArgTypes...)> fusion_op_;
};
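
The opFunc() deduction described above is easier to see in a stripped-down form. This self-contained toy, in which Val, TensorView, and ToyOp are stand-ins invented for the sketch, shows the same index_sequence trick: each argument slot is dynamic_cast to the pointer type recorded in the tuple before the stored std::function is invoked:

#include <functional>
#include <iostream>
#include <tuple>
#include <utility>
#include <vector>

// Toy stand-ins for nvFuser's Val*/TensorView* hierarchy.
struct Val {
  virtual ~Val() = default;
};
struct TensorView : Val {
  int id = 0;
};

template <class Out, class... Args>
struct ToyOp {
  std::function<Out(Args...)> fn;
  // Mirrors OpRecord::opFunc: slot Is is cast to tuple element Is.
  template <class Tuple, std::size_t... Is>
  Out call(const std::vector<Val*>& state, std::index_sequence<Is...>) {
    return fn(dynamic_cast<std::tuple_element_t<Is, Tuple>>(state.at(Is))...);
  }
  Out operator()(const std::vector<Val*>& state) {
    return call<std::tuple<Args...>>(
        state, std::make_index_sequence<sizeof...(Args)>());
  }
};

int main() {
  TensorView a, b;
  a.id = 1;
  b.id = 2;
  ToyOp<int, TensorView*, TensorView*> add{
      [](TensorView* x, TensorView* y) { return x->id + y->id; }};
  std::cout << add({&a, &b}) << '\n'; // prints 3
}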

struct SliceOpRecord : RecordFunctor {
  SliceOpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
      bool manual_normalization)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            "ops.slice",
            serde::RecordType::SliceOp),
        manual_normalization_(manual_normalization) {
    arg_names_[1] = "start_indices";
    arg_names_[2] = "end_indices";
    arg_names_[3] = "strides";
  }
  ~SliceOpRecord() override = default;
  RecordFunctor* clone() final {
    return new SliceOpRecord(*this);
  }

  //! Child specific hash function in lower 32 bits.
  //! | 31                    | 30 ------------------------ 0 |
  //! | manual_normalization? | other                         |
  size_t hash() const final {
    auto result = RecordFunctor::hash();
    result |= ((static_cast<size_t>(manual_normalization_) & 0x1) << 31);
    return result;
  }

  bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (auto child_ptr = dynamic_cast<const SliceOpRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
      result = result &&
          (manual_normalization_ == child_ptr->manual_normalization_);
    }
    return result;
  }

  void operator()(FusionState& fd) final {
    TensorView* arg = fd.getFusionState(args_.at(0).index)->as<TensorView>();
    const std::vector<Val*>& start = fd.getFusionStateVector(args_.at(1).index);
    const std::vector<Val*>& end = fd.getFusionStateVector(args_.at(2).index);
    const std::vector<Val*>& stride =
        fd.getFusionStateVector(args_.at(3).index);
    std::vector<Slice> vec_slice;
    for (const auto idx : c10::irange(arg->domain()->noReductions().size())) {
      // NOTE: there's an extra move; we can use emplace_back if we write
      // some constructors for Slice.
      Val* start_idx = start.at(idx);
      Val* end_idx = end.at(idx);
      Val* stride_idx = stride.at(idx);
      NVF_CHECK(
          !start_idx->isConstInt() || start_idx->evaluate().as<int64_t>() >= 0,
          "Slice operation start_indices must be greater than or equal to 0. Start Indices: ",
          start_idx->evaluate().as<int64_t>());
      NVF_CHECK(
          !start_idx->isConstInt() || !end_idx->isConstInt() ||
              end_idx->evaluate().as<int64_t>() >=
                  start_idx->evaluate().as<int64_t>(),
          "Slice operation end_indices must be greater than or equal to start_indices. Start Indices: ",
          start_idx->evaluate().as<int64_t>(),
          " End Indices: ",
          end_idx->evaluate().as<int64_t>());
      NVF_CHECK(
          stride_idx->isConstInt() && stride_idx->evaluate().as<int64_t>() == 1,
          "nvFuser Limitation: All slice operation strides must be of const size 1.");
      vec_slice.push_back({start_idx, end_idx, stride_idx});
    }
    auto output = slice(arg, vec_slice, manual_normalization_);
    fd.setFusionState(outputs_.at(0).index, output);
  }

  void print(std::ostream& os, bool close_function = true) const final {
    RecordFunctor::print(os, false);
    os << ", manual_normalization=" << manual_normalization_;
    if (close_function) {
      os << ")";
    }
  }

  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
      flatbuffers::FlatBufferBuilder& builder) const final {
    return {
        serde::RecordData::Slice,
        serde::CreateSlice(builder, manual_normalization_).Union()};
  }

 private:
  //! A flag to skip the slice normalization step in a composite operation.
  bool manual_normalization_;
};
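
Note that the NVF_CHECKs above only fire when an index is known at definition time (isConstInt()). A minimal standalone sketch of the constant-index rules, start >= 0, end >= start, and stride exactly 1, outside of any nvFuser types (validateSlices is a name invented here):

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

struct Slice {
  int64_t start, end, stride;
};

// Mirrors the constant-index checks in SliceOpRecord::operator().
void validateSlices(const std::vector<Slice>& slices) {
  for (const auto& s : slices) {
    if (s.start < 0) {
      throw std::runtime_error(
          "start_indices must be >= 0, got " + std::to_string(s.start));
    }
    if (s.end < s.start) {
      throw std::runtime_error("end_indices must be >= start_indices");
    }
    if (s.stride != 1) {
      throw std::runtime_error("all slice strides must be constant 1");
    }
  }
}

int main() {
  validateSlices({{0, 8, 1}, {2, 4, 1}}); // ok: throws on invalid input
  return 0;
}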

struct ReshapeOpRecord : RecordFunctor {
  ReshapeOpRecord(std::vector<State> _args, std::vector<State> _outputs)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            "ops.reshape",
            serde::RecordType::ReshapeOp) {
    arg_names_[1] = "new_shape";
  }
  ~ReshapeOpRecord() override = default;
  RecordFunctor* clone() final {
    return new ReshapeOpRecord(*this);
  }

  void operator()(FusionState& fd) final {
    TensorView* arg = fd.getFusionState(args_.at(0).index)->as<TensorView>();
    const std::vector<Val*>& new_shape =
        fd.getFusionStateVector(args_.at(1).index);
    auto output = reshape(arg, new_shape);
    fd.setFusionState(outputs_.at(0).index, output);
  }
};

struct PadOpRecord : RecordFunctor {
  PadOpRecord(std::vector<State> _args, std::vector<State> _outputs)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            "ops.pad",
            serde::RecordType::PadOp) {}
  ~PadOpRecord() override = default;
  RecordFunctor* clone() final {
    return new PadOpRecord(*this);
  }

  void operator()(FusionState& fd) final {
    auto arg = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
    const std::vector<Val*>& val_widths =
        fd.getFusionStateVector(args_.at(1).index);

    TensorView* output = nullptr;
    if (args_.at(2).stype == serde::StateType::Scalar) {
      output = pad(arg, val_widths, fd.getFusionState(args_.at(2).index));
    } else { // default: None
      NVF_ERROR(args_.at(2).stype == serde::StateType::None);
      output = pad(arg, val_widths);
    }

    fd.setFusionState(outputs_.at(0).index, output);
  }
};
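
PadOpRecord picks between two pad() overloads depending on whether the third state slot is a Scalar or None. A hypothetical 1-D analogue using std::optional for the fill value behaves the same way (pad1d is invented for this sketch, with the fill defaulting to zero):

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Pads `data` with `left`/`right` copies of an optional fill value.
std::vector<int64_t> pad1d(
    const std::vector<int64_t>& data,
    int64_t left,
    int64_t right,
    std::optional<int64_t> fill = std::nullopt) {
  std::vector<int64_t> out(left, fill.value_or(0));
  out.insert(out.end(), data.begin(), data.end());
  out.insert(out.end(), right, fill.value_or(0));
  return out;
}

int main() {
  for (auto v : pad1d({1, 2, 3}, 2, 1)) std::cout << v << ' '; // 0 0 1 2 3 0
  std::cout << '\n';
  for (auto v : pad1d({1, 2, 3}, 1, 1, 9)) std::cout << v << ' '; // 9 1 2 3 9
  std::cout << '\n';
}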

template <serde::RecordType op_type>
struct DimsOpRecord : RecordFunctor {
  DimsOpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
      std::vector<int64_t> dims,
      std::string name)
      : RecordFunctor(std::move(_args), std::move(_outputs), name, op_type) {
    int64_t rank = (int64_t)dims.size();
    dims_.reserve(rank);
    std::unordered_set<int64_t> dims_set;
    for (auto dim : dims) {
      dims_set.insert(dim);
      if (dim < 0) {
        NVF_CHECK(
            dim >= -rank,
            name + " dims argument is out of range, expects >= -" +
                std::to_string(rank) + ", but got: " + std::to_string(dim));
        dim += rank;
      } else {
        NVF_CHECK(
            dim < rank,
            name + " dims argument is out of range, expects < " +
                std::to_string(rank) + ", but got: " + std::to_string(dim));
      }
      dims_.push_back(dim);
    }
    NVF_CHECK(
        dims_set.size() == dims.size(),
        name + " got duplicated dimension entries: " + toDelimitedString(dims));
  }
  ~DimsOpRecord() override = default;
  RecordFunctor* clone() final {
    return new DimsOpRecord(*this);
  }

  size_t hash() const final {
    auto result = RecordFunctor::hash();
    size_t dims_hash = 0;
    for (auto dim : dims_) {
      hashCombine(dims_hash, static_cast<size_t>(dim));
    }
    return result | (dims_hash & 0xffff);
  }

  bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (auto child_ptr = dynamic_cast<const DimsOpRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
      if (result) {
        result = (dims_.size() == child_ptr->dims_.size());
        if (result) {
          for (size_t i = 0; i < dims_.size(); ++i) {
            if (dims_[i] != child_ptr->dims_[i]) {
              result = false;
              break;
            }
          }
        }
      }
    }
    return result;
  }

  void operator()(FusionState& fd) final {
    if constexpr (op_type == serde::RecordType::PermuteOp) {
      auto arg =
          fd.getFusionState(args_.at(0).index)->template as<TensorView>();
      auto output = permute(arg, dims_);
      fd.setFusionState(outputs_.at(0).index, output);
    } else if constexpr (op_type == serde::RecordType::StrideOrderOp) {
      auto arg =
          fd.getFusionState(args_.at(0).index)->template as<TensorView>();
      auto output = set(arg);
      std::vector<IterDomain*> allocation_domain =
          ir_utils::strideOrderToAllocation(output->getLogicalDomain(), dims_);
      output->setAllocationDomain(allocation_domain, true);
      fd.setFusionState(outputs_.at(0).index, output);
    } else {
      NVF_THROW("op_type is not recognized by dims operator.");
    }
  }

  void print(std::ostream& os, bool close_function = true) const final {
    RecordFunctor::print(os, false);
    if constexpr (op_type == serde::RecordType::PermuteOp) {
      os << ", dims=[";
    } else if constexpr (op_type == serde::RecordType::StrideOrderOp) {
      os << ", stride_order=[";
    } else {
      NVF_THROW("op_type is not recognized by dims operator.");
    }
    bool first_arg = true;
    for (auto dim : dims_) {
      if (first_arg) {
        first_arg = false;
      } else {
        os << ", ";
      }
      os << dim;
    }
    os << "]";
    if (close_function) {
      os << ")";
    }
  }

  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
      flatbuffers::FlatBufferBuilder& builder) const final {
    return {
        serde::RecordData::Dims,
        serde::CreateDimsDirect(builder, &dims_).Union()};
  }

 private:
  //! Represents the mapping from the original shape to the new shape
  std::vector<int64_t> dims_;
};
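
The dimension normalization performed in the constructor above is a common idiom: negative dims wrap by the rank, both signs are range-checked, and duplicates are rejected on the values as written (before wrapping). A standalone sketch, with normalizeDims a name invented here:

#include <cstdint>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>

// Mirrors DimsOpRecord's constructor logic for a tensor of the given rank.
std::vector<int64_t> normalizeDims(std::vector<int64_t> dims, int64_t rank) {
  std::unordered_set<int64_t> seen;
  for (auto& dim : dims) {
    seen.insert(dim); // duplicates checked on the raw value, as in the header
    if (dim < 0) {
      if (dim < -rank) {
        throw std::out_of_range("dim out of range: " + std::to_string(dim));
      }
      dim += rank;
    } else if (dim >= rank) {
      throw std::out_of_range("dim out of range: " + std::to_string(dim));
    }
  }
  if (seen.size() != dims.size()) {
    throw std::invalid_argument("duplicated dimension entries");
  }
  return dims;
}

int main() {
  auto d = normalizeDims({0, -1}, 3); // -> {0, 2}
  return d[1] == 2 ? 0 : 1;
}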

struct SqueezeOpRecord : RecordFunctor {
  SqueezeOpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
      std::vector<int64_t> dims,
      bool squeeze_expanded = false)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            "ops.squeeze",
            serde::RecordType::SqueezeOp),
        dims_(std::move(dims)),
        squeeze_expanded_(squeeze_expanded) {}
  ~SqueezeOpRecord() override = default;
  RecordFunctor* clone() final {
    return new SqueezeOpRecord(*this);
  }

  //! Child specific hash function in lower 32 bits.
  //! | 31                | 30 -------------------------------- 0 |
  //! | squeeze_expanded? | Squeeze Dim hash                       |
  size_t hash() const final {
    auto result = RecordFunctor::hash();
    size_t squeeze_dims_hash = 0;
    for (auto dim : dims_) {
      squeeze_dims_hash ^= static_cast<size_t>(dim);
    }
    result = result | (squeeze_dims_hash & 0x7fffffff);
    result |= ((static_cast<size_t>(squeeze_expanded_) & 0x1) << 31);
    return result;
  }

  bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (auto child_ptr = dynamic_cast<const SqueezeOpRecord*>(&other)) {
      result = RecordFunctor::operator==(other) && (dims_ == child_ptr->dims_);
    }
    return result;
  }

  void operator()(FusionState& fd) final {
    auto arg = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
    // In PyTorch, the squeeze operation cannot remove expanded dimensions.
    // In nvFuser, for reduction operations, we apply squeeze to remove
    // broadcast and expanded iterDomains. The squeeze_expanded_ flag bypasses
    // the assertion used to match PyTorch's behavior.
    auto output = squeeze(arg, dims_, squeeze_expanded_);
    fd.setFusionState(outputs_.at(0).index, output);
  }

  void print(std::ostream& os, bool close_function = true) const final {
    RecordFunctor::print(os, false);
    os << ", dims=[";
    bool first_arg = true;
    for (auto dim : dims_) {
      if (first_arg) {
        first_arg = false;
      } else {
        os << ", ";
      }
      os << dim;
    }
    os << "], squeeze_expanded=" << (squeeze_expanded_ ? "True" : "False");
    if (close_function) {
      os << ")";
    }
  }

  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
      flatbuffers::FlatBufferBuilder& builder) const final {
    return {
        serde::RecordData::Squeeze,
        serde::CreateSqueezeDirect(builder, &dims_, squeeze_expanded_).Union()};
  }

 private:
  //! Dimensions to squeeze.
  std::vector<int64_t> dims_;
  //! Option to remove expanded dimensions
  bool squeeze_expanded_;
};
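
Note the two hashing styles used by these records: SqueezeOpRecord XOR-folds its dims, which discards ordering ({0, 1} and {1, 0} collide, harmless since operator== still disambiguates), while DimsOpRecord uses hashCombine, which is order-sensitive. The boost-style combiner below is an assumed implementation for illustration only; the real hashCombine is declared elsewhere in nvFuser (utils.h):

#include <cstddef>
#include <iostream>

// Assumed boost-style combiner; mutates the seed so call order matters.
void hashCombine(size_t& seed, size_t value) {
  seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
}

int main() {
  size_t a = 0, b = 0;
  hashCombine(a, 1);
  hashCombine(a, 2);
  hashCombine(b, 2);
  hashCombine(b, 1);
  std::cout << (a != b) << '\n'; // 1: combine order is preserved

  size_t x = 1 ^ 2, y = 2 ^ 1; // XOR fold, SqueezeOpRecord style
  std::cout << (x == y) << '\n'; // 1: order is lost
}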

//! Specialized Record Functor for the FusionState's broadcast_in_dim op.
// NOTE: output_ndims gives the rank of the output tensor. This size can be
// found from the State after the definition is read and the Fusion IR is in
// the process of being created. However, prior to that point, the size is
// needed for matching a Fusion Record node in the Trie used to cache
// definitions.
struct BroadcastInDimOpRecord : RecordFunctor {
  BroadcastInDimOpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
      size_t output_ndims,
      std::vector<int64_t> broadcast_dims)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            "ops.broadcast_in_dim",
            serde::RecordType::BroadcastInDim),
        output_ndims_(output_ndims),
        broadcast_dims_(std::move(broadcast_dims)) {
    arg_names_[1] = "shape";
  }
  ~BroadcastInDimOpRecord() override = default;
  RecordFunctor* clone() final {
    return new BroadcastInDimOpRecord(*this);
  }

  //! Child specific hash function in lower 32 bits.
  //! | 31 ------------------------------------- 0 |
  //! | broadcast_dims hash                         |
  size_t hash() const final {
    auto result = RecordFunctor::hash();
    size_t broadcast_dims_hash = 0;
    for (auto dim : broadcast_dims_) {
      broadcast_dims_hash |= 1 << ((output_ndims_ - 1) - dim);
    }
    return result | (broadcast_dims_hash & 0xffffffff);
  }

  bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (auto child_ptr = dynamic_cast<const BroadcastInDimOpRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
      if (result) {
        result =
            ((output_ndims_ == child_ptr->output_ndims_) &&
             (broadcast_dims_.size() == child_ptr->broadcast_dims_.size()));
        if (result) {
          for (size_t i = 0; i < broadcast_dims_.size(); ++i) {
            if (broadcast_dims_[i] != child_ptr->broadcast_dims_[i]) {
              result = false;
              break;
            }
          }
        }
      }
    }
    return result;
  }

  void operator()(FusionState& fd) final {
    auto arg = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
    const std::vector<Val*>& output_shape =
        fd.getFusionStateVector(args_.at(1).index);

    const auto& arg_domains_nr = arg->domain()->noReductions();
    const auto arg_ndims = arg_domains_nr.size();
    NVF_CHECK(
        output_ndims_ >= arg_ndims,
        "The new shape is expected to be greater-than-or-equal to the input: ",
        output_ndims_,
        " vs ",
        arg_ndims);
    NVF_CHECK(
        arg_ndims == broadcast_dims_.size(),
        "The broadcast dimensions should match the input dimensions: ",
        arg_ndims,
        " vs ",
        broadcast_dims_.size(),
        ". arg = ",
        arg->toString());

    std::vector<bool> is_broadcast_dim(output_ndims_, true);
    for (const auto idx : c10::irange(broadcast_dims_.size())) {
      if (idx > 0) {
        NVF_CHECK(
            broadcast_dims_[idx - 1] < broadcast_dims_[idx],
            "Broadcast dimension is not greater than the previous value.");
      }
      NVF_CHECK(
          broadcast_dims_[idx] < static_cast<int>(output_ndims_),
          "Invalid broadcast_dims value.");
      is_broadcast_dim.at(broadcast_dims_[idx]) = false;
    }

    auto output = broadcast(arg, is_broadcast_dim);
    auto expanded_output = expand(output, output_shape);

    fd.setFusionState(outputs_.at(0).index, expanded_output);
  }

  void print(std::ostream& os, bool close_function = true) const final {
    RecordFunctor::print(os, false);
    os << ", broadcast_dims=[";
    bool first_arg = true;
    for (auto dim : broadcast_dims_) {
      if (first_arg) {
        first_arg = false;
      } else {
        os << ", ";
      }
      os << dim;
    }
    os << "]";
    if (close_function) {
      os << ")";
    }
  }

  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
      flatbuffers::FlatBufferBuilder& builder) const final {
    return {
        serde::RecordData::BroadcastInDim,
        serde::CreateBroadcastInDimDirect(
            builder, output_ndims_, &broadcast_dims_)
            .Union()};
  };

 private:
  //! Number of dims of the shape vector used to communicate the output tensor
  //! shape.
  size_t output_ndims_;
  //! Communicates which dimensions of the output the input tensor maps to.
  //! For instance, for output [2, 3, 4] and input [3], this vector would
  //! contain [1].
  std::vector<int64_t> broadcast_dims_;
};
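
The mask construction in BroadcastInDimOpRecord::operator() can be isolated as follows: every output axis defaults to a new broadcast axis unless it is claimed by broadcast_dims, which must be strictly increasing and in range (broadcastMask is a name invented for this sketch):

#include <cstdint>
#include <stdexcept>
#include <vector>

// Returns true for output axes that become broadcast axes.
std::vector<bool> broadcastMask(
    const std::vector<int64_t>& broadcast_dims,
    size_t output_ndims) {
  std::vector<bool> is_broadcast(output_ndims, true);
  for (size_t i = 0; i < broadcast_dims.size(); ++i) {
    if (i > 0 && broadcast_dims[i - 1] >= broadcast_dims[i]) {
      throw std::invalid_argument("broadcast_dims must be increasing");
    }
    if (broadcast_dims[i] >= static_cast<int64_t>(output_ndims)) {
      throw std::invalid_argument("broadcast_dims out of range");
    }
    is_broadcast.at(broadcast_dims[i]) = false;
  }
  return is_broadcast;
}

int main() {
  // Input [3] mapped into output [2, 3, 4]: only axis 1 comes from the input.
  auto mask = broadcastMask({1}, 3); // -> {true, false, true}
  return mask[1] == false ? 0 : 1;
}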

//! Specialized Record Functor for the FusionState's broadcast op.

struct BroadcastOpRecord : RecordFunctor {
  BroadcastOpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
      std::string _name,
      std::vector<bool> is_broadcast_dim)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            _name,
            serde::RecordType::BroadcastOp),
        is_broadcast_dim_(std::move(is_broadcast_dim)) {}
  ~BroadcastOpRecord() override = default;
  RecordFunctor* clone() final {
    return new BroadcastOpRecord(*this);
  }

  size_t hash() const final {
    auto result = RecordFunctor::hash();
    size_t is_broadcast_dim_hash = 0;
    for (size_t i = 0; i < is_broadcast_dim_.size(); ++i) {
      is_broadcast_dim_hash |=
          (is_broadcast_dim_[i] << (is_broadcast_dim_.size() - 1 - i));
    }
    return result | (is_broadcast_dim_hash & 0xfff);
  }

  bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (auto child_ptr = dynamic_cast<const BroadcastOpRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
      result = result &&
          std::equal(
              is_broadcast_dim_.begin(),
              is_broadcast_dim_.end(),
              child_ptr->is_broadcast_dim_.begin());
    }
    return result;
  }

  void operator()(FusionState& fd) final {
    auto arg = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
    auto output = broadcast(arg, is_broadcast_dim_);
    fd.setFusionState(outputs_.at(0).index, output);
  }

  void print(std::ostream& os, bool close_function = true) const final {
    RecordFunctor::print(os, false);
    os << ", is_broadcast_dim=[";
    bool first_arg = true;
    for (auto dim : is_broadcast_dim_) {
      if (first_arg) {
        first_arg = false;
      } else {
        os << ", ";
      }
      os << (dim ? "True" : "False");
    }
    os << "]";
    if (close_function) {
      os << ")";
    }
  }

  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
      flatbuffers::FlatBufferBuilder& builder) const final {
    auto fb_broadcast_dims = builder.CreateVector(is_broadcast_dim_);

    serde::BroadcastBuilder bcast_builder(builder);
    bcast_builder.add_broadcast_dims(fb_broadcast_dims);
    auto expr_data = bcast_builder.Finish();
    return {serde::RecordData::Broadcast, expr_data.Union()};
  }

 private:
  //! Communicates which dimensions in the output are broadcasted.
  std::vector<bool> is_broadcast_dim_;
};
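
BroadcastOpRecord::hash() packs the boolean mask MSB-first and keeps only the low 12 bits, so only the leading dozen axes influence the child-specific hash. In isolation (packFlags is invented here):

#include <cstddef>
#include <iostream>
#include <vector>

// Index 0 lands in the highest bit, as in BroadcastOpRecord::hash();
// everything past 12 bits is masked away.
size_t packFlags(const std::vector<bool>& flags) {
  size_t bits = 0;
  for (size_t i = 0; i < flags.size(); ++i) {
    bits |= (static_cast<size_t>(flags[i]) << (flags.size() - 1 - i));
  }
  return bits & 0xfff;
}

int main() {
  std::cout << packFlags({true, false, true}) << '\n'; // 0b101 = 5
}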

//! Specialized Record Functor for the FusionState's expand op.
struct ExpandOpRecord : RecordFunctor {
  ExpandOpRecord(std::vector<State> _args, std::vector<State> _outputs)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            "ops.expand",
            serde::RecordType::ExpandOp) {
    arg_names_[1] = "shape";
  }
  ~ExpandOpRecord() override = default;
  RecordFunctor* clone() final {
    return new ExpandOpRecord(*this);
  }

  //! Child specific hash function in lower 32 bits.
  //! | 31 --------------------------------------- 0 |
  //! | None                                          |
  size_t hash() const final {
    return RecordFunctor::hash();
  }

  bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (dynamic_cast<const ExpandOpRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
    }
    return result;
  }

  void operator()(FusionState& fd) final {
    auto arg = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
    const std::vector<Val*>& output_shape =
        fd.getFusionStateVector(args_.at(1).index);

    size_t arg_ndims = arg->domain()->noReductions().size();
    NVF_CHECK(
        output_shape.size() == arg_ndims,
        "The new shape is expected to be equal to the input: ",
        output_shape.size(),
        " vs ",
        arg_ndims);
    auto expanded_output = expand(arg, output_shape);

    fd.setFusionState(outputs_.at(0).index, expanded_output);
  }

  void print(std::ostream& os, bool close_function = true) const final {
    RecordFunctor::print(os, false);
    if (close_function) {
      os << ")";
    }
  }
};

template <class OutType, class ArgType>
struct CastOpRecord : RecordFunctor {
  CastOpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
      std::string _name,
      serde::RecordType record_type,
      std::function<OutType(DataType, ArgType)> fusion_op,
      PrimDataType dtype)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            _name,
            record_type),
        fusion_op_(fusion_op),
        dtype_(dtype) {}
  ~CastOpRecord() override = default;
  RecordFunctor* clone() final {
    return new CastOpRecord(*this);
  }

  //! Child specific hash function in lower 32 bits.
  //! | 31 --- 24 | 23 -------------------------- 0 |
  //! | Dtype     | Arith Function Sig hash code    |
  size_t hash() const final {
    auto result = RecordFunctor::hash();
    result |= ((static_cast<size_t>(dtype_) & 0xff) << 24);
    result |= (fusion_op_.target_type().hash_code() & 0xffffff);
    return result;
  }

  bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (auto child_ptr = dynamic_cast<const CastOpRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
      if (result) {
        result = result &&
            (fusion_op_.target_type() == child_ptr->fusion_op_.target_type());
        if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) {
          debug() << "\nCastOpRecord: " << name_ << " Target Type [self: 0x"
                  << fusion_op_.target_type().name() << "] [other: 0x"
                  << child_ptr->fusion_op_.target_type().name() << "]";
        }
        // IMPORTANT! you need to dereference the target pointer in order
        // to match the function
        result = result &&
            (*fusion_op_.template target<OutType (*)(DataType, ArgType)>() ==
             *child_ptr->fusion_op_
                  .template target<OutType (*)(DataType, ArgType)>());
        if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) {
          debug() << " Target Ptr [self: 0x" << std::hex
                  << (size_t)*fusion_op_
                         .template target<OutType (*)(DataType, ArgType)>()
                  << "] [other: 0x" << std::hex
                  << (size_t)*child_ptr->fusion_op_
                         .template target<OutType (*)(DataType, ArgType)>()
                  << "]\n";
        }
        result = result && (dtype_ == child_ptr->dtype_);
      }
    }
    return result;
  }

  void operator()(FusionState& fd) final {
    auto arg = dynamic_cast<ArgType>(fd.getFusionState(args_.at(0).index));
    auto output = fusion_op_(dtype_, arg);
    fd.setFusionState(outputs_.at(0).index, output);
  }

  void print(std::ostream& os, bool close_function = true) const final {
    RecordFunctor::print(os, false);
    os << ", dtype=" << dtypeToPyString(dtype_);
    if (close_function) {
      os << ")";
    }
  }

  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
      flatbuffers::FlatBufferBuilder& builder) const final {
    return {
        serde::RecordData::Dtype,
        serde::CreateDtype(builder, nvfuser::toUnderlying(dtype_)).Union()};
  }

 private:
  //! nvFuser arith function signature
  std::function<OutType(DataType, ArgType)> fusion_op_;
  //! Type to cast to.
  PrimDataType dtype_;
};
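
Both OpRecord and CastOpRecord compare their std::function members in two steps, and the "IMPORTANT" comments are worth demonstrating: target_type() only identifies the stored callable's type, so two different free functions with the same signature compare equal by type; telling them apart requires dereferencing target<...>(). A standalone demonstration (timesTwo/timesTen are invented stand-ins for nvFuser's arith entry points):

#include <functional>
#include <iostream>

int timesTwo(int x) { return 2 * x; }
int timesTen(int x) { return 10 * x; }

int main() {
  std::function<int(int)> f = timesTwo;
  std::function<int(int)> g = timesTen;
  std::function<int(int)> h = timesTwo;

  // Both wrap an int(*)(int), so the type comparison cannot distinguish them.
  std::cout << (f.target_type() == g.target_type()) << '\n'; // 1

  // target<Fn>() returns a pointer to the stored function pointer (nullptr on
  // a type mismatch); dereferencing it yields the actual function address.
  using Fn = int (*)(int);
  std::cout << (*f.target<Fn>() == *g.target<Fn>()) << '\n'; // 0
  std::cout << (*f.target<Fn>() == *h.target<Fn>()) << '\n'; // 1
}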

struct CatOpRecord : RecordFunctor {
  CatOpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
      int64_t dim,
      bool manual_padding)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            "ops.cat",
            serde::RecordType::CatOp),
        dim_(dim),
        manual_padding_(manual_padding) {}
  ~CatOpRecord() override = default;
  RecordFunctor* clone() final {
    return new CatOpRecord(*this);
  }

  //! Child specific hash function in lower 32 bits.
  //! | 31              | 30 ------------------------ 0 |
  //! | manual_padding? | dim                           |
  size_t hash() const final {
    auto result = RecordFunctor::hash();
    result |= ((static_cast<size_t>(manual_padding_) & 0x1) << 31);
    result |= (static_cast<size_t>(dim_) & 0x7fff);
    return result;
  }

  bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (auto child_ptr = dynamic_cast<const CatOpRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
      result = result && (dim_ == child_ptr->dim_);
      result = result && (manual_padding_ == child_ptr->manual_padding_);
    }
    return result;
  }

  void operator()(FusionState& fd) final {
    std::vector<TensorView*> input_tvs;
    input_tvs.reserve(args_.size());
    for (auto& a : args_) {
      input_tvs.push_back(
          fd.getFusionState(a.index)->template as<TensorView>());
    }
    auto output =
        cat(input_tvs, dim_, /*iter_type_opt=*/std::nullopt, manual_padding_);
    fd.setFusionState(outputs_.at(0).index, output);
  }

  void print(std::ostream& os, bool close_function = true) const final {
    // Similar to RecordFunctor::print(os, false), but don't print args
    bool first_output = true;
    for (auto& output : outputs_) {
      if (first_output) {
        first_output = false;
      } else {
        os << ", ";
      }
      os << output;
    }
    if (always_returns_tuple_) {
      os << ",";
    }
    if (!outputs_.empty()) {
      os << " = "
         << "fd." << name_ << "(";
    } else {
      os << "fd." << name_ << "(";
    }
    os << "[";
    bool first_arg = true;
    for (auto& arg : args_) {
      if (first_arg) {
        first_arg = false;
      } else {
        os << ", ";
      }
      os << arg;
    }
    os << "], dim=" << dim_;
    os << ", manual_padding=" << manual_padding_;
    if (close_function) {
      os << ")";
    }
  }

  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
      flatbuffers::FlatBufferBuilder& builder) const final {
    return {
        serde::RecordData::Cat,
        serde::CreateCat(builder, dim_, manual_padding_).Union()};
  }

 private:
  //! The dimension along which we will concatenate
  int64_t dim_;
  //! A flag to skip the pad operation in the cat composite operation.
  bool manual_padding_;
};
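
For reference, the print() above brackets the tensor arguments into a Python list instead of emitting them positionally as the base class does. A recorded two-tensor cat would print roughly as the following line; the state indices are illustrative, and manual_padding streams as 0/1 because no std::boolalpha is applied:

T2 = fd.ops.cat([T0, T1], dim=0, manual_padding=0)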

//! Specialized Record Functor for recording FusionState End.
//! The accompanying Fusion Cache Entry holds a Fusion Object.

struct EndRecord : RecordFunctor {
  EndRecord() : RecordFunctor({}, {}, "end", serde::RecordType::End) {}
  ~EndRecord() override = default;
  RecordFunctor* clone() final {
    return new EndRecord(*this);
  }

  //! Child specific hash function in lower 32 bits.
  //! | 31 --------------------------------------- 0 |
  //! | None                                          |
  size_t hash() const final {
    return RecordFunctor::hash();
  }

  bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (dynamic_cast<const EndRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
    }
    return result;
  }

  void operator()(FusionState& fd) final {}
};
+
+//! Specialized Record Functor for recording FusionState input tensors.
+
+struct TensorRecord : RecordFunctor {
+  TensorRecord(
+      std::vector<State> _outputs,
+      std::vector<int64_t> _shape,
+      std::vector<std::optional<bool>> _contiguity,
+      PrimDataType _dtype,
+      bool _is_cpu = false,
+      std::vector<int64_t> _stride_order = {})
+      : RecordFunctor(
+            {},
+            std::move(_outputs),
+            "define_tensor",
+            serde::RecordType::Tensor),
+        shape_(std::move(_shape)),
+        contiguity_(std::move(_contiguity)),
+        stride_order_(std::move(_stride_order)),
+        dtype_(_dtype),
+        is_cpu_(_is_cpu) {
+    if (!stride_order_.empty()) {
+      int64_t rank = (int64_t)stride_order_.size();
+      std::unordered_set<int64_t> order_set;
+      for (auto& order : stride_order_) {
+        order_set.insert(order);
+        if (order < 0) {
+          NVF_CHECK(
+              order >= -rank,
+              "define_tensor stride_order argument is out of range, expects >= -" +
+                  std::to_string(rank) + ", but got: " + std::to_string(order));
+          order += rank;
+        } else {
+          NVF_CHECK(
+              order < rank,
+              "define_tensor stride_order argument is out of range, expects < " +
+                  std::to_string(rank) + ", but got: " + std::to_string(order));
+        }
+      }
+      NVF_CHECK(
+          order_set.size() == stride_order_.size(),
+          "define_tensor got duplicated stride_order entries: " +
+              toDelimitedString(stride_order_));
+    }
+  }
+  ~TensorRecord() override = default;
+  RecordFunctor* clone() final {
+    return new TensorRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! |  31  | 30 --- 24 | 23 --------- 12 | 11 ------------------------ 0 |
+  //! | CPU? |   Dtype   | Symbolic Sizes  | Contiguous Info & stride_order |
+  size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    size_t ssize_hash = 0;
+    for (size_t i = 0; i < shape_.size(); ++i) {
+      size_t ssize = 0;
+      if (shape_[i] == -1) {
+        ssize = 1;
+      }
+      ssize_hash |= (ssize << (shape_.size() - 1 - i));
+    }
+    size_t contig_stride_hash = 0;
+    for (size_t i = 0; i < contiguity_.size(); ++i) {
+      auto contiguity_value = contiguity_[i];
+      contig_stride_hash |=
+          ((contiguity_value.has_value() && contiguity_value.value())
+           << (contiguity_.size() - 1 - i));
+    }
+    for (size_t i = 0; i < stride_order_.size(); ++i) {
+      contig_stride_hash ^= (stride_order_[i] << i);
+    }
+
+    result |= ((static_cast<size_t>(is_cpu_) & 0x1) << 31);
+    result |= ((static_cast<size_t>(dtype_) & 0x7f) << 24);
+    return result | ((ssize_hash & 0xfff) << 12) | (contig_stride_hash & 0xfff);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const TensorRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      result = result && (dtype_ == child_ptr->dtype_);
+      result = result && (is_cpu_ == child_ptr->is_cpu_);
+      if (result) {
+        result =
+            ((shape_.size() == child_ptr->shape_.size()) &&
+             (stride_order_.size() == child_ptr->stride_order_.size()) &&
+             (contiguity_.size() == child_ptr->contiguity_.size()));
+        if (result) {
+          for (size_t i = 0; i < shape_.size(); ++i) {
+            if (shape_[i] != child_ptr->shape_[i]) {
+              result = false;
+              break;
+            }
+          }
+        }
+        if (result) {
+          for (size_t i = 0; i < stride_order_.size(); ++i) {
+            if (stride_order_[i] != child_ptr->stride_order_[i]) {
+              result = false;
+              break;
+            }
+          }
+        }
+        if (result) {
+          for (size_t i = 0; i < contiguity_.size(); ++i) {
+            if (contiguity_[i] != child_ptr->contiguity_[i]) {
+              result = false;
+              break;
+            }
+          }
+        }
+      }
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    auto rank = shape_.size();
+    std::vector<bool> is_expand(rank);
+
+    for (const auto index : c10::irange(rank)) {
+      // Since the contiguity_ vector is ordered according to the allocation
+      // domain, while is_expand follows the root domain, we need to map
+      // between the two with `contig_index` and `index`.
+      //
+      // stride_order[i] indicates that:
+      //   `logical_domain[i]` (and therefore `root_domain[i]` for input) maps
+      //   to `alloc_domain[rank - 1 - stride_order_[i]]`
+      //
+      // Hence `index` on the root domain corresponds to the contiguity
+      // index `contig_index = rank - 1 - stride_order[index]`.
+      const auto contig_index = stride_order_.empty()
+          ? index
+          : rank - 1 - static_cast<size_t>(stride_order_[index]);
+      const bool is_broadcast = !contiguity_[contig_index].has_value();
+      const bool has_non_broadcast_size = (shape_[index] != 1);
+      // A root dimension is an expand dimension if:
+      //   The dimension is marked as a broadcast; and
+      //   The dimension has an expanded extent.
+      is_expand[index] = is_broadcast && has_non_broadcast_size;
+    }
+
+    auto tv = TensorViewBuilder()
+                  .contiguity(contiguity_)
+                  .shape(shape_)
+                  .dtype(dtype_)
+                  .expanded(std::move(is_expand))
+                  .strideOrder(stride_order_)
+                  .build();
+
+    if (shape_.empty() && is_cpu_) {
+      tv->setCpuScalar(true);
+    } else {
+      NVF_CHECK(!is_cpu_, "CPU non-scalar tensor is not supported!");
+    }
+
+    fd.setFusionState(outputs_.at(0).index, tv);
+    fd.addInput(tv, outputs_.at(0).index);
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << "shape=[";
+    bool first_arg = true;
+    for (auto ss : shape_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      os << ss;
+    }
+    os << "], contiguity=[";
+    first_arg = true;
+    for (auto ci : contiguity_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      if (!ci.has_value()) {
+        os << "None";
+      } else {
+        if (*ci) {
+          os << "True";
+        } else {
+          os << "False";
+        }
+      }
+    }
+    os << "], dtype=" << dtypeToPyString(dtype_);
+    os << ", is_cpu=" << (is_cpu_ ? "True" : "False");
+    if (!stride_order_.empty()) {
+      os << ", stride_order=[";
+      bool first_arg = true;
+      for (auto item : stride_order_) {
+        if (first_arg) {
+          first_arg = false;
+        } else {
+          os << ", ";
+        }
+        os << item;
+      }
+      os << "]";
+    }
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    auto fb_sizes = builder.CreateVector(shape_);
+
+    auto mapOptionalToEnum = [](std::optional<bool> v) -> serde::Contiguity {
+      if (!v.has_value()) {
+        return serde::Contiguity::None;
+      } else if (v.value()) {
+        return serde::Contiguity::Contiguous;
+      } else {
+        return serde::Contiguity::Strided;
+      }
+    };
+    std::vector<serde::Contiguity> contiguity_enum;
+    std::transform(
+        contiguity_.cbegin(),
+        contiguity_.cend(),
+        std::back_inserter(contiguity_enum),
+        mapOptionalToEnum);
+    auto fb_contiguity_enum = builder.CreateVector(contiguity_enum);
+    auto fb_stride_order = builder.CreateVector(stride_order_);
+
+    serde::TensorBuilder tensor_builder(builder);
+    tensor_builder.add_sizes(fb_sizes);
+    tensor_builder.add_contiguity(fb_contiguity_enum);
+    tensor_builder.add_stride_order(fb_stride_order);
+    tensor_builder.add_dtype(toUnderlying(dtype_));
+    tensor_builder.add_is_cpu(is_cpu_);
+    auto expr_data = tensor_builder.Finish();
+    return {serde::RecordData::Tensor, expr_data.Union()};
+  }
+
+ private:
+  //! A vector of tensor dimension sizes.
+  //! This vector only captures sizes of -1 or 1 to indicate a symbolic
+  //! dimension (-1) or a broadcast dimension (1).
+  std::vector<int64_t> shape_;
+  //! A vector to indicate whether a tensor dimension is contiguous
+  //! with the dimension just to its right.
+  std::vector<std::optional<bool>> contiguity_;
+  //! A vector to indicate the stride order of the tensor
+  std::vector<int64_t> stride_order_;
+  //! Tensor data type.
+  PrimDataType dtype_;
+  //! Notes a scalar CPU Tensor
+  bool is_cpu_;
+};
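The stride_order mapping used in `operator()` above is easy to verify by hand. A minimal standalone sketch (editorial illustration, not part of the patch) of the `contig_index = rank - 1 - stride_order[index]` rule:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // stride_order[i] gives the position of logical axis i counted from the
  // innermost allocation axis, so contig_index = rank - 1 - stride_order[i].
  std::vector<int64_t> stride_order = {0, 2, 1};  // hypothetical 3-D layout
  const int64_t rank = static_cast<int64_t>(stride_order.size());
  for (int64_t index = 0; index < rank; ++index) {
    const int64_t contig_index = rank - 1 - stride_order[index];
    std::cout << "root axis " << index << " -> contiguity slot "
              << contig_index << "\n";
  }
  // Prints: axis 0 -> slot 2, axis 1 -> slot 0, axis 2 -> slot 1
}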
+
+//! Specialized Record Functor for recording FusionState outputs.
+
+template <class OutputType>
+struct OutputRecord : RecordFunctor {
+  OutputRecord(
+      std::vector<State> _args,
+      serde::RecordType record_type,
+      std::vector<int64_t> stride_order = {})
+      : RecordFunctor(std::move(_args), {}, "add_output", record_type) {
+    if (!stride_order.empty()) {
+      stride_order_ = stride_order;
+    }
+  }
+  ~OutputRecord() override = default;
+  RecordFunctor* clone() final {
+    return new OutputRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 ---------------------------------------- 0 |
+  //! |               stride_order hash                |
+  size_t hash() const final {
+    size_t stride_order_hash = 0;
+    for (auto i : c10::irange(stride_order_.size())) {
+      stride_order_hash = (stride_order_hash << 4) | stride_order_[i];
+    }
+    return RecordFunctor::hash() | (stride_order_hash & 0xffffffff);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const OutputRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      if (result) {
+        result = (stride_order_.size() == child_ptr->stride_order_.size());
+        if (result) {
+          for (size_t i = 0; i < stride_order_.size(); ++i) {
+            if (stride_order_[i] != child_ptr->stride_order_[i]) {
+              result = false;
+              break;
+            }
+          }
+        }
+      }
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    auto output = fd.getFusionState(args_.at(0).index);
+    Val* alias_input = nullptr;
+    if (args_.size() == 2) {
+      alias_input = fd.getFusionState(args_.at(1).index);
+    }
+
+    if (alias_input) {
+      NVF_CHECK(
+          stride_order_.empty(),
+          "stride_order can't be dictated for aliased outputs.");
+      if constexpr (std::is_same_v<OutputType, TensorView>) {
+        fd.aliasOutputToInput(output, alias_input);
+      } else {
+        NVF_THROW("Scalar outputs should not alias inputs.");
+      }
+    } else {
+      if constexpr (std::is_same_v<OutputType, TensorView>) {
+        auto tv_output = output->template as<TensorView>();
+        if (!stride_order_.empty()) {
+          auto logical_domain = tv_output->getLogicalDomain();
+          std::vector<IterDomain*> allocation_domain =
+              ir_utils::strideOrderToAllocation(logical_domain, stride_order_);
+          tv_output->setAllocationDomain(allocation_domain, true);
+        }
+        fd.addOutput(tv_output, args_.at(0).index);
+      } else {
+        NVF_CHECK(
+            stride_order_.empty(),
+            "stride_order can't be dictated for scalar outputs.");
+        fd.addOutput(output, args_.at(0).index);
+      }
+    }
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    if (!stride_order_.empty()) {
+      os << ", stride_order=[";
+      bool first_arg = true;
+      for (auto item : stride_order_) {
+        if (first_arg) {
+          first_arg = false;
+        } else {
+          os << ", ";
+        }
+        os << item;
+      }
+      os << "]";
+    }
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::Output,
+        serde::CreateOutputDirect(builder, &stride_order_).Union()};
+  }
+
+ private:
+  //! The stride order of the output tensor
+  std::vector<int64_t> stride_order_;
+};
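A side note on the nibble packing in `OutputRecord::hash()`: each stride_order entry is shifted into 4 bits, so only about eight dimensions fit in the lower 32 bits before entries start colliding. A standalone sketch (editorial illustration, not part of the patch):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

size_t stride_order_hash(const std::vector<int64_t>& stride_order) {
  size_t h = 0;
  for (auto entry : stride_order) {
    h = (h << 4) | static_cast<size_t>(entry);  // one nibble per entry
  }
  return h & 0xffffffff;
}

int main() {
  std::cout << std::hex << stride_order_hash({0, 2, 1}) << "\n";  // 21
}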
+
+//! Specialized Record Functor for the FusionState's sum/min/max ops.
+
+struct ReductionOpRecord : RecordFunctor {
+  ReductionOpRecord(
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      std::string _name,
+      serde::RecordType record_type,
+      std::function<
+          TensorView*(TensorView*, const std::vector<int64_t>&, bool, DataType)>
+          fusion_op,
+      std::vector<int64_t> axes,
+      bool keep_dim,
+      PrimDataType dtype)
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            _name,
+            record_type),
+        fusion_op_(std::move(fusion_op)),
+        axes_(std::move(axes)),
+        keep_dim_(keep_dim),
+        dtype_(dtype) {}
+  ~ReductionOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new ReductionOpRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 -- 28 | 27 --- 20 | 19 ----------------- 0 |
+  //! | keep_dim |   Dtype   | Axes Hash              |
+  size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    size_t axes_hash = 0;
+    // Normally I would make a little endian hash of the axes but I do not
+    // know the size of the tensor based on just the record information.
+    for (auto i : c10::irange(axes_.size())) {
+      axes_hash |= (1 << axes_[i]);
+    }
+
+    return result | (static_cast<size_t>(keep_dim_) << 28) |
+        ((static_cast<size_t>(dtype_) & 0xff) << 20) | (axes_hash & 0xfffff);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const ReductionOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      if (result) {
+        result = result &&
+            (fusion_op_.target_type() == child_ptr->fusion_op_.target_type());
+        if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) {
+          debug() << "\nReductionOpRecord: " << name_
+                  << " Target Type [self: 0x" << fusion_op_.target_type().name()
+                  << "] [other: 0x"
+                  << child_ptr->fusion_op_.target_type().name() << "]";
+        }
+        // IMPORTANT! You need to dereference the target pointer in order
+        // to match the function.
+        result = result &&
+            (*fusion_op_.template target<
+                 TensorView* (*)(TensorView*,
+                                 const std::vector<int64_t>&,
+                                 bool,
+                                 DataType)>() ==
+             *child_ptr->fusion_op_.template target<
+                 TensorView* (*)(TensorView*,
+                                 const std::vector<int64_t>&,
+                                 bool,
+                                 DataType)>());
+        if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) {
+          debug() << " Target Ptr [self: 0x" << std::hex
+                  << (size_t)*fusion_op_.template target<
+                         TensorView* (*)(TensorView*,
+                                         const std::vector<int64_t>&,
+                                         bool,
+                                         DataType)>()
+                  << "] [other: 0x" << std::hex
+                  << (size_t)*child_ptr->fusion_op_.template target<
+                         TensorView* (*)(TensorView*,
+                                         const std::vector<int64_t>&,
+                                         bool,
+                                         DataType)>()
+                  << "]\n";
+        }
+        result = result && (keep_dim_ == child_ptr->keep_dim_);
+        result = result && (dtype_ == child_ptr->dtype_);
+        if (result) {
+          result = (axes_.size() == child_ptr->axes_.size());
+          if (result) {
+            for (size_t i = 0; i < axes_.size(); ++i) {
+              if (axes_[i] != child_ptr->axes_[i]) {
+                result = false;
+                break;
+              }
+            }
+          }
+        }
+      }
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    auto arg = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
+    auto output = fusion_op_(arg, axes_, keep_dim_, dtype_);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << ", dims=[";
+    bool first_arg = true;
+    for (auto axis : axes_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      os << axis;
+    }
+    os << "]";
+    os << ", keepdim=" << (keep_dim_ ? "True" : "False");
+    os << ", dtype=" << dtypeToPyString(dtype_);
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    // TODO add dtype
+    return {
+        serde::RecordData::Reduction,
+        serde::CreateReductionDirect(
+            builder, &axes_, keep_dim_, toUnderlying(dtype_))
+            .Union()};
+  }
+
+ private:
+  //! nvFuser arith function signature for a given reduction operation
+  std::function<
+      TensorView*(TensorView*, const std::vector<int64_t>&, bool, DataType)>
+      fusion_op_;
+  //! The tensor dimensions to reduce
+  std::vector<int64_t> axes_;
+  //! Indicates whether to keep the reduced dimension(s).
+  bool keep_dim_;
+  //! The output data type.
+  PrimDataType dtype_;
+};
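The axes hash in `ReductionOpRecord::hash()` is a plain bitmask, so it is order-insensitive and only distinguishes axes within the 20 bits reserved for it. A standalone sketch (editorial illustration, not part of the patch):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

size_t axes_mask(const std::vector<int64_t>& axes) {
  size_t mask = 0;
  for (auto axis : axes) {
    mask |= (1u << axis);  // axis i sets bit i
  }
  return mask & 0xfffff;  // only 20 bits are kept
}

int main() {
  std::cout << std::hex << axes_mask({0, 2}) << "\n";  // 5
  std::cout << std::hex << axes_mask({2, 0}) << "\n";  // 5 (same mask)
}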
+
+struct IndexSelectOpRecord : RecordFunctor {
+  IndexSelectOpRecord(
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      int64_t dim)
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            "ops.index_select",
+            serde::RecordType::IndexSelectOp),
+        dim_(dim) {}
+  ~IndexSelectOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new IndexSelectOpRecord(*this);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const IndexSelectOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other) && dim_ == child_ptr->dim_;
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    auto arg1 = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
+    auto arg3 = fd.getFusionState(args_.at(1).index)->template as<TensorView>();
+
+    Val* output = indexSelect(arg1, dim_, arg3);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << ", dim=" << dim_;
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::Dimension,
+        serde::CreateDimension(builder, dim_).Union()};
+  }
+
+ private:
+  //! Dimension to select.
+  int64_t dim_;
+};
+
+// TODO Merge IndexSelectOpRecord and SelectOpRecord for cleaner interface.
+// If the index TensorView is a scalar, then use select operation.
+struct SelectOpRecord : RecordFunctor {
+  SelectOpRecord(
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      int64_t dim)
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            "ops.select",
+            serde::RecordType::SelectOp),
+        dim_(dim) {}
+  ~SelectOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new SelectOpRecord(*this);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const SelectOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other) && dim_ == child_ptr->dim_;
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    auto arg1 = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
+    auto arg3 = fd.getFusionState(args_.at(1).index);
+
+    Val* output = select(arg1, dim_, arg3);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << ", dim=" << dim_;
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::Dimension,
+        serde::CreateDimension(builder, dim_).Union()};
+  }
+
+ private:
+  //! Dimension to select.
+  int64_t dim_;
+};
+
+struct TorchGatherOpRecord : RecordFunctor {
+  TorchGatherOpRecord(
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      int64_t dim)
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            "ops.gather",
+            serde::RecordType::TorchGatherOp),
+        dim_(dim) {}
+  ~TorchGatherOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new TorchGatherOpRecord(*this);
+  }
+
+  void operator()(FusionState& fd) final {
+    auto arg1 = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
+    auto arg3 = fd.getFusionState(args_.at(1).index)->template as<TensorView>();
+
+    Val* output = torchGather(arg1, dim_, arg3);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const TorchGatherOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other) && dim_ == child_ptr->dim_;
+    }
+    return result;
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << ", dim=" << dim_;
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::Dimension,
+        serde::CreateDimension(builder, dim_).Union()};
+  }
+
+ private:
+  //! Dimension along which to gather.
+  int64_t dim_;
+};
+
+//! Similar to TorchGatherOpRecord but enforces that non-index dimension
+//! extents match between index tensor and value tensor.
+struct TakeAlongAxisOpRecord : RecordFunctor {
+  TakeAlongAxisOpRecord(
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      int64_t dim)
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            "ops.take_along_axis",
+            serde::RecordType::TakeAlongAxisOp),
+        dim_(dim) {}
+  ~TakeAlongAxisOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new TakeAlongAxisOpRecord(*this);
+  }
+
+  void operator()(FusionState& fd) final {
+    auto arg1 = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
+    auto arg3 = fd.getFusionState(args_.at(1).index)->template as<TensorView>();
+
+    Val* output = takeAlongAxis(arg1, arg3, dim_);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const TakeAlongAxisOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other) && dim_ == child_ptr->dim_;
+    }
+    return result;
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << ", dim=" << dim_;
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::Dimension,
+        serde::CreateDimension(builder, dim_).Union()};
+  }
+
+ private:
+  //! Dimension along which to gather.
+  int64_t dim_;
+};
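Both records above implement the PyTorch-style gather rule along `dim`; the only difference the doc comment notes is that take_along_axis requires the non-indexed extents of the index tensor to match the value tensor. A standalone sketch of the 2-D, dim=1 indexing rule `out[i][j] = in[i][idx[i][j]]` (editorial illustration, not part of the patch):

#include <iostream>
#include <vector>

int main() {
  std::vector<std::vector<int>> in = {{10, 20, 30}, {40, 50, 60}};
  std::vector<std::vector<int>> idx = {{2, 0, 1}, {1, 1, 0}};
  for (size_t i = 0; i < in.size(); ++i) {
    for (size_t j = 0; j < idx[i].size(); ++j) {
      std::cout << in[i][idx[i][j]] << " ";  // 30 10 20 / 50 50 40
    }
    std::cout << "\n";
  }
}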
+
+//! Specialized Record Functor for recording FusionState scalars for both
+//! inputs and constants.
+
+struct ScalarRecord : RecordFunctor {
+  ScalarRecord(
+      std::vector<State> _outputs,
+      PolymorphicValue value,
+      std::optional<PrimDataType> dtype,
+      bool inline_def = false)
+      : RecordFunctor(
+            {},
+            std::move(_outputs),
+            "define_scalar",
+            serde::RecordType::Scalar,
+            inline_def),
+        value_(
+            dtype.has_value() ? castToDtype(std::move(value), dtype.value())
+                              : std::move(value)),
+        dtype_(
+            dtype.has_value()
+                ? dtype.value()
+                : std::get<PrimDataType>(getDataType(value_).type)) {}
+  ~ScalarRecord() override = default;
+  RecordFunctor* clone() final {
+    return new ScalarRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 --------------------------------------- 0 |
+  //! |                    Dtype                      |
+  size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    return result | (static_cast<size_t>(dtype_) & 0xffffffff);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    if (auto child_ptr = dynamic_cast<const ScalarRecord*>(&other)) {
+      if (RecordFunctor::operator==(other)) {
+        if (value_.hasValue() != child_ptr->value_.hasValue() ||
+            dtype_ != child_ptr->dtype_) {
+          return false;
+        }
+        if (value_.hasValue()) {
+          if (value_.is<double>() && std::isnan(value_.as<double>()) &&
+              std::isnan(child_ptr->value_.as<double>())) {
+            return true;
+          } else {
+            return value_ == child_ptr->value_;
+          }
+        } else {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
+  void operator()(FusionState& fd) final {
+    Val* output = IrBuilder::create<nvfuser::Val>(value_, dtype_);
+    if (!value_.hasValue()) {
+      fd.addInput(output, outputs_.at(0).index);
+    }
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    if (inline_def_) {
+      NVF_CHECK(
+          value_.hasValue(),
+          "Only ScalarRecords with values support inline definitions!");
+      if (value_.is<bool>()) {
+        NVF_CHECK(
+            dtype_ == PrimDataType::Bool,
+            "A ScalarRecord for a Bool inline definition does not have a matching data type!");
+        os << ((bool)value_ ? "True" : "False");
+      } else if (value_.is<double>()) {
+        NVF_CHECK(
+            dtype_ == PrimDataType::Double,
+            "A ScalarRecord for a Double inline definition does not have a matching data type!");
+        if (std::isinf(value_.as<double>())) {
+          if (std::signbit(value_.as<double>())) {
+            os << "float(\"-inf\")";
+          } else {
+            os << "float(\"inf\")";
+          }
+        } else if (std::isnan(value_.as<double>())) {
+          os << "float(\"nan\")";
+        } else {
+          os << std::showpoint << value_.as<double>();
+        }
+      } else if (value_.is<int64_t>()) {
+        NVF_CHECK(
+            dtype_ == PrimDataType::Int,
+            "A ScalarRecord for an Int inline definition does not have a matching data type!");
+        os << value_;
+      } else {
+        NVF_THROW("A ScalarRecord with an unsupported inline definition type!");
+      }
+      // NOTE: close_function is not relevant for the inline definition as the
+      // printing is specific to each operator and not partially done with the
+      // base class print method.
+    } else {
+      RecordFunctor::print(os, false);
+      if (value_.hasValue()) {
+        if (value_.is<bool>()) {
+          os << ((bool)value_ ? "True" : "False");
+        } else if (value_.is<std::complex<double>>()) {
+          os << std::showpoint << std::real(value_.as<std::complex<double>>())
+             << "+" << std::showpoint
+             << std::imag(value_.as<std::complex<double>>()) << "j";
+        } else if (value_.is<double>()) {
+          if (std::isinf(value_.as<double>())) {
+            if (std::signbit(value_.as<double>())) {
+              os << "float(\"-inf\")";
+            } else {
+              os << "float(\"inf\")";
+            }
+          } else if (std::isnan(value_.as<double>())) {
+            os << "float(\"nan\")";
+          } else {
+            os << std::showpoint << value_.as<double>();
+          }
+        } else if (value_.is<int64_t>()) {
+          os << value_;
+        } else {
+          NVF_CHECK(false, "Unsupported dtype.");
+        }
+      } else {
+        os << "None";
+      }
+
+      os << ", dtype=" << dtypeToPyString(dtype_);
+
+      if (close_function) {
+        os << ")";
+      }
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::Scalar,
+        serde::serializeScalar(builder, value_, dtype_).Union()};
+  }
+
+  inline std::pair<serde::RecordData, flatbuffers::Offset<void>> valueRecordData(
+      flatbuffers::FlatBufferBuilder& builder,
+      PolymorphicValue value) const;
+
+ private:
+  //! The scalar's value; an input does not hold a value.
+  PolymorphicValue value_;
+  //! Scalar data type.
+  PrimDataType dtype_;
+};
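The NaN special case in `ScalarRecord::operator==` exists because IEEE-754 NaN never compares equal to itself, yet two records defining `float("nan")` should hit the same cache entry. A standalone sketch (editorial illustration, not part of the patch):

#include <cmath>
#include <iostream>

int main() {
  double a = std::nan("");
  double b = std::nan("");
  std::cout << (a == b) << "\n";                          // 0: plain compare fails
  std::cout << (std::isnan(a) && std::isnan(b)) << "\n";  // 1: record-style check
}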
+
+//! Specialized Record Functor for recording FusionDefinition Start.
+//! There should only ever be one instance of this Record in the
+//! Fusion Cache.
+
+struct StartRecord : RecordFunctor {
+  StartRecord() : RecordFunctor({}, {}, "start", serde::RecordType::Start) {}
+  ~StartRecord() override = default;
+  RecordFunctor* clone() final {
+    return new StartRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 --------------------------------------- 0 |
+  //! |                     None                      |
+  size_t hash() const final {
+    return RecordFunctor::hash();
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (dynamic_cast<const StartRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {}
+};
+
+//! Specialized Record Functors for Normalization-based ops.
+
+struct NormOpRecord : RecordFunctor {
+  NormOpRecord(
+      std::vector<State> args,
+      std::vector<State> outputs,
+      std::string name,
+      serde::RecordType type,
+      std::vector<int64_t> axes,
+      int64_t correction,
+      bool keep_dim)
+      : RecordFunctor(std::move(args), std::move(outputs), name, type),
+        axes_(std::move(axes)),
+        correction_(correction),
+        keep_dim_(keep_dim) {}
+  ~NormOpRecord() override = default;
+  RecordFunctor* clone() override = 0;
+
+  // I am skipping Bessel's correction value in the hash because
+  // I suspect we might change it to a bool from a 64-bit value.
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 -- 28 | 27 ----------------------------- 0 |
+  //! | keep_dim | Axes Hash                          |
+  size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    size_t axes_hash = 0;
+    // Normally I would make a little endian hash of the axes but I do not
+    // know the size of the tensor based on just the record information.
+    for (auto i : c10::irange(axes_.size())) {
+      axes_hash |= (1 << axes_[i]);
+    }
+    return result | (static_cast<size_t>(keep_dim_) << 28) |
+        (axes_hash & 0xfffffff);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const NormOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      result = result && (correction_ == child_ptr->correction_);
+      result = result && (keep_dim_ == child_ptr->keep_dim_);
+      if (result) {
+        result = (axes_.size() == child_ptr->axes_.size());
+        if (result) {
+          for (size_t i = 0; i < axes_.size(); ++i) {
+            if (axes_[i] != child_ptr->axes_[i]) {
+              result = false;
+              break;
+            }
+          }
+        }
+      }
+    }
+    return result;
+  }
+
+  //! Each NormOp Child should define the operator() to build the IR
+  void operator()(FusionState& fd) override = 0;
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << ", dims=[";
+    bool first_arg = true;
+    for (auto axis : axes_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      os << axis;
+    }
+    os << "]";
+    os << ", correction=" << correction_;
+    os << ", keepdim=" << (keep_dim_ ? "True" : "False");
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::Norm,
+        serde::CreateNormDirect(builder, &axes_, correction_, keep_dim_)
+            .Union()};
+  }
+
+ protected:
+  //! Dimensions of tensor to reduce for variance calculation
+  std::vector<int64_t> axes_;
+  //! Bessel's correction value
+  int64_t correction_;
+  //! Indicates whether to keep the reduced dimension(s).
+  bool keep_dim_;
+};
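For reference, the `correction_` member is a generalized Bessel's correction: the sum of squared deviations is divided by (N - correction) rather than N, so correction=0 gives the population variance and correction=1 the sample variance. A standalone sketch (editorial illustration, not part of the patch):

#include <iostream>
#include <vector>

double variance(const std::vector<double>& x, long correction) {
  double mean = 0.0;
  for (double v : x) mean += v;
  mean /= static_cast<double>(x.size());
  double sq_sum = 0.0;
  for (double v : x) sq_sum += (v - mean) * (v - mean);
  // Divide by (N - correction) instead of N.
  return sq_sum / (static_cast<double>(x.size()) - static_cast<double>(correction));
}

int main() {
  std::vector<double> x = {1.0, 2.0, 3.0, 4.0};
  std::cout << variance(x, 0) << "\n";  // 1.25    (population)
  std::cout << variance(x, 1) << "\n";  // 1.66667 (Bessel-corrected sample)
}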
+
+struct VarianceOpRecord : NormOpRecord {
+  VarianceOpRecord(
+      std::vector<State> args,
+      std::vector<State> outputs,
+      std::vector<int64_t> axes,
+      int64_t correction,
+      bool keep_dim)
+      : NormOpRecord(
+            std::move(args),
+            std::move(outputs),
+            "ops.var",
+            serde::RecordType::VarianceOp,
+            std::move(axes),
+            correction,
+            keep_dim) {}
+  ~VarianceOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new VarianceOpRecord(*this);
+  }
+
+  void operator()(FusionState& fd) final {
+    auto arg = fd.getFusionState(args_.at(0).index)->as<TensorView>();
+    auto output = variance(arg, axes_, correction_, keep_dim_);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+};
+
+//! VarianceMean requires a separate Record because nvFuser defines the output
+//! of var_mean as a custom struct.
+struct VarianceMeanOpRecord : NormOpRecord {
+  VarianceMeanOpRecord(
+      std::vector<State> args,
+      std::vector<State> outputs,
+      std::vector<int64_t> axes,
+      int64_t correction,
+      bool keep_dim)
+      : NormOpRecord(
+            std::move(args),
+            std::move(outputs),
+            "ops.var_mean",
+            serde::RecordType::VarianceMeanOp,
+            std::move(axes),
+            correction,
+            keep_dim) {}
+  ~VarianceMeanOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new VarianceMeanOpRecord(*this);
+  }
+
+  void operator()(FusionState& fd) final {
+    auto arg = fd.getFusionState(args_.at(0).index)->as<TensorView>();
+    auto output = variance_mean(arg, axes_, correction_, keep_dim_);
+    fd.setFusionState(outputs_.at(0).index, output.var);
+    fd.setFusionState(outputs_.at(1).index, output.mean);
+  }
+};
+
+struct WelfordOpRecord : RecordFunctor {
+  WelfordOpRecord(
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      std::vector<int64_t> axes)
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            "ops.welford",
+            serde::RecordType::WelfordOp),
+        axes_(std::move(axes)) {}
+  ~WelfordOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new WelfordOpRecord(*this);
+  }
+
+  size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    size_t axes_hash = 0;
+    for (auto axis : axes_) {
+      hashCombine(axes_hash, static_cast<size_t>(axis));
+    }
+    return result | (axes_hash & 0xffff);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const WelfordOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      if (result) {
+        result = (axes_.size() == child_ptr->axes_.size());
+        if (result) {
+          for (size_t i = 0; i < axes_.size(); ++i) {
+            if (axes_[i] != child_ptr->axes_[i]) {
+              result = false;
+              break;
+            }
+          }
+        }
+      }
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    auto arg = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
+    auto output = WelfordRaw(arg, axes_);
+    fd.setFusionState(outputs_.at(0).index, output.avg);
+    fd.setFusionState(outputs_.at(1).index, output.var_sum);
+    fd.setFusionState(outputs_.at(2).index, output.n);
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << ", dims=[";
+    bool first_arg = true;
+    for (auto axis : axes_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      os << axis;
+    }
+    os << "]";
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::Welford,
+        serde::CreateWelfordDirect(builder, &axes_).Union()};
+  }
+
+ private:
+  //! The tensor dimensions to reduce
+  std::vector<int64_t> axes_;
+};
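The (avg, var_sum, n) triple this record produces follows Welford's online recurrence; the population variance is recovered as var_sum / n. A standalone scalar sketch (editorial illustration, not part of the patch):

#include <iostream>

int main() {
  double avg = 0.0, var_sum = 0.0;
  long n = 0;
  for (double x : {1.0, 2.0, 3.0, 4.0}) {
    ++n;
    double delta = x - avg;
    avg += delta / static_cast<double>(n);   // running average
    var_sum += delta * (x - avg);            // running sum of squared deviations
  }
  std::cout << avg << " " << var_sum / static_cast<double>(n) << "\n";  // 2.5 1.25
}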
+
+struct BatchNormOpRecord : RecordFunctor {
+  BatchNormOpRecord(
+      std::vector<State> args,
+      std::vector<State> outputs,
+      bool training,
+      bool channels_last)
+      : RecordFunctor(
+            std::move(args),
+            std::move(outputs),
+            "ops.batch_norm",
+            serde::RecordType::BatchNormOp),
+        training_(training),
+        channels_last_(channels_last) {}
+  ~BatchNormOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new BatchNormOpRecord(*this);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const BatchNormOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      result = result && (training_ == child_ptr->training_);
+      result = result && (channels_last_ == child_ptr->channels_last_);
+    }
+    return result;
+  }
+
+  size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    return result | (static_cast<size_t>(training_) << 28) |
+        (static_cast<size_t>(channels_last_) << 29);
+  }
+
+  void operator()(FusionState& fd) final {
+    auto x = fd.getFusionState(args_.at(0).index)->as<TensorView>();
+    auto weight = (args_.at(1).stype == serde::StateType::Tensor)
+        ? fd.getFusionState(args_.at(1).index)->as<TensorView>()
+        : nullptr;
+    auto bias = (args_.at(2).stype == serde::StateType::Tensor)
+        ? fd.getFusionState(args_.at(2).index)->as<TensorView>()
+        : nullptr;
+    auto running_mean = (args_.at(3).stype == serde::StateType::Tensor)
+        ? fd.getFusionState(args_.at(3).index)->as<TensorView>()
+        : nullptr;
+    auto running_var = (args_.at(4).stype == serde::StateType::Tensor)
+        ? fd.getFusionState(args_.at(4).index)->as<TensorView>()
+        : nullptr;
+    auto momentum = fd.getFusionState(args_.at(5).index)->as<Val>();
+    auto eps = fd.getFusionState(args_.at(6).index)->as<Val>();
+    auto output = batch_norm(
+        x,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        training_,
+        momentum,
+        eps,
+        channels_last_);
+    fd.setFusionState(outputs_.at(0).index, output.output);
+    fd.setFusionState(outputs_.at(1).index, output.mean);
+    fd.setFusionState(outputs_.at(2).index, output.invstd);
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << ", training=" << (training_ ? "True" : "False");
+    os << ", channels_last=" << (channels_last_ ? "True" : "False");
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::BatchNorm,
+        serde::CreateBatchNorm(builder, training_, channels_last_).Union()};
+  }
+
+ private:
+  bool training_;
+  bool channels_last_;
+};
+
+//! Specialized Record Functor for the FusionState's tensor_sizes op.
+//! Uses the default hash() and print() methods of Record Functor
+
+struct TensorSizesRecord : RecordFunctor {
+  TensorSizesRecord(std::vector<State> args, std::vector<State> outputs)
+      : RecordFunctor(
+            std::move(args),
+            std::move(outputs),
+            "ops.tensor_sizes",
+            serde::RecordType::TensorSizes) {
+    always_returns_tuple_ = true;
+  }
+  ~TensorSizesRecord() override = default;
+  RecordFunctor* clone() final {
+    return new TensorSizesRecord(*this);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (dynamic_cast<const TensorSizesRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    auto arg = fd.getFusionState(args_.at(0).index)->as<TensorView>();
+    auto sizes = shape(arg);
+    for (const auto idx : c10::irange(sizes.size())) {
+      fd.setFusionState(outputs_.at(idx).index, sizes[idx]);
+    }
+  }
+};
+
+//! Specialized Record Functor for the shape op.
+//! Uses the default hash() and print() methods of Record Functor
+
+struct ShapeOpRecord : RecordFunctor {
+  ShapeOpRecord(std::vector<State> args, std::vector<State> outputs)
+      : RecordFunctor(
+            std::move(args),
+            std::move(outputs),
+            "ops.shape",
+            serde::RecordType::ShapeOp) {}
+  ~ShapeOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new ShapeOpRecord(*this);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (dynamic_cast<const ShapeOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    auto arg = fd.getFusionState(args_.at(0).index)->as<TensorView>();
+    auto result = shape(arg);
+    fd.setFusionStateVector(outputs_.at(0).index, result);
+  }
+};
+
+//! Specialized Record Functor for the size op.
+//! Uses the default hash() and print() methods of Record Functor
+
+struct SizeOpRecord : RecordFunctor {
+  SizeOpRecord(std::vector<State> args, std::vector<State> outputs, int64_t dim)
+      : RecordFunctor(
+            std::move(args),
+            std::move(outputs),
+            "ops.size",
+            serde::RecordType::SizeOp),
+        dim_(dim) {}
+  ~SizeOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new SizeOpRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 -------------------------------------- 0 |
+  //! |                     dim                      |
+  size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    return result | (static_cast<size_t>(dim_) & 0xffffffff);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const SizeOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      result = result && (dim_ == child_ptr->dim_);
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    auto arg = fd.getFusionState(args_.at(0).index)->as<TensorView>();
+    auto result = size(arg, dim_);
+    fd.setFusionState(outputs_.at(0).index, result);
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {serde::RecordData::Size, serde::CreateSize(builder, dim_).Union()};
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << ", dim=" << dim_;
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+ private:
+  int64_t dim_;
+};
+
+//! Specialized Record Functor for the at() op.
+//! Uses the default hash() and print() methods of Record Functor
+
+struct AtOpRecord : RecordFunctor {
+  AtOpRecord(std::vector<State> args, std::vector<State> outputs, int64_t index)
+      : RecordFunctor(
+            std::move(args),
+            std::move(outputs),
+            "ops.at",
+            serde::RecordType::AtOp),
+        index_(index) {}
+  ~AtOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new AtOpRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 -------------------------------------- 0 |
+  //! |                    index                     |
+  size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    return result | (static_cast<size_t>(index_) & 0xffffffff);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const AtOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      result = result && (index_ == child_ptr->index_);
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    NVF_CHECK(
+        args_.at(0).stype == serde::StateType::Vector,
+        "Expected Vector State!");
+    const std::vector<Val*>& arg = fd.getFusionStateVector(args_.at(0).index);
+    auto result = at(arg, index_);
+    fd.setFusionState(outputs_.at(0).index, result);
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {serde::RecordData::At, serde::CreateAt(builder, index_).Union()};
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << ", index=" << index_;
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+ private:
+  int64_t index_;
+};
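How the three shape-query records above relate, in host-side terms: `ops.shape` yields all extents as a vector, `ops.size` picks one dimension directly from a tensor, and `ops.at` picks one element out of an already-materialized vector. A standalone sketch of that relationship (editorial illustration, not part of the patch; plain `std::vector` stands in for the fusion state):

#include <iostream>
#include <vector>

int main() {
  std::vector<long> shape = {8, 16, 32};  // like ops.shape(tv)
  long dim1 = shape[1];                   // like ops.size(tv, dim=1)  -> 16
  long at2 = shape.at(2);                 // like ops.at(vec, index=2) -> 32
  std::cout << dim1 << " " << at2 << "\n";
}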
+
+struct FullOpRecord : RecordFunctor {
+  FullOpRecord(
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      PrimDataType dtype)
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            "ops.full",
+            serde::RecordType::FullOp),
+        dtype_(dtype) {
+    setArgName(0, "shape");
+    setArgName(1, "fill_value");
+  }
+  ~FullOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new FullOpRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 -------------------------------------- 0 |
+  //! |                    Dtype                     |
+  size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    result |= (static_cast<size_t>(dtype_) & 0xffffffff);
+    return result;
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const FullOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other) && dtype_ == child_ptr->dtype_;
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    const std::vector<Val*>& shape = fd.getFusionStateVector(args_.at(0).index);
+    auto fill_value = fd.getFusionState(args_.at(1).index);
+
+    auto output = full(shape, fill_value, dtype_);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  void print(std::ostream& os, bool close_function = true) const override {
+    RecordFunctor::print(os, false);
+    os << ", dtype=" << dtypeToPyString(dtype_);
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::TensorCreationSymbolic,
+        serde::CreateTensorCreationSymbolic(builder, toUnderlying(dtype_))
+            .Union()};
+  }
+
+ private:
+  //! Type of output
+  PrimDataType dtype_;
+};
+
+struct IotaOpRecord : RecordFunctor {
+  IotaOpRecord(
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      PrimDataType dtype)
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            "ops.iota",
+            serde::RecordType::IotaOp),
+        dtype_(dtype) {}
+  ~IotaOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new IotaOpRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 -------------------------------------- 0 |
+  //! |                    Dtype                     |
+  size_t hash() const final {
+    return RecordFunctor::hash() | static_cast<uint32_t>(dtype_);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const IotaOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other) && dtype_ == child_ptr->dtype_;
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    auto length = fd.getFusionState(args_.at(0).index);
+    auto start = (args_.at(1).stype == serde::StateType::Scalar)
+        ? fd.getFusionState(args_.at(1).index)->as<Val>()
+        : nullptr;
+    auto step = (args_.at(2).stype == serde::StateType::Scalar)
+        ? fd.getFusionState(args_.at(2).index)->as<Val>()
+        : nullptr;
+    auto output = iota(length, start, step, dtype_);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  void print(std::ostream& os, bool close_function = true) const override {
+    RecordFunctor::print(os, false);
+    os << ", dtype=" << dtypeToPyString(dtype_);
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::Dtype,
+        serde::CreateDtype(builder, nvfuser::toUnderlying(dtype_)).Union()};
+  }
+
+ private:
+  //! Type of output
+  PrimDataType dtype_;
+};
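For readers unfamiliar with the op family: iota materializes the arithmetic sequence out[i] = start + i * step for i in [0, length), which appears to be what the `iota(length, start, step, dtype_)` call above builds. A standalone sketch (editorial illustration, not part of the patch):

#include <iostream>

int main() {
  const long length = 4, start = 2, step = 3;
  for (long i = 0; i < length; ++i) {
    std::cout << start + i * step << " ";  // 2 5 8 11
  }
  std::cout << "\n";
}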
+
+//! Specialized Record Functors for random ops.
+template <serde::RecordType RType>
+struct RandomDistOpRecord : RecordFunctor {
+  RandomDistOpRecord(
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      PrimDataType dtype)
+      : RecordFunctor(std::move(_args), std::move(_outputs), "", RType),
+        dtype_(dtype) {
+    if constexpr (RType == serde::RecordType::UniformDistOp) {
+      name_ = "ops.uniform";
+    } else if constexpr (RType == serde::RecordType::NormalDistOp) {
+      name_ = "ops.normal";
+    } else {
+      static_assert(
+          (RType == serde::RecordType::NormalDistOp) ||
+          (RType == serde::RecordType::UniformDistOp));
+    }
+    setArgName(2, "shape");
+    if (args_.size() == 5) {
+      setArgName(3, "rng_seed");
+      setArgName(4, "rng_offset");
+    }
+  }
+  ~RandomDistOpRecord() override = default;
+  RecordFunctor* clone() final {
+    return new RandomDistOpRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 --------------------------------------- 0 |
+  //! |                    Dtype                      |
+  size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    return result | (static_cast<size_t>(dtype_) & 0xffffffff);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const RandomDistOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      result = result && (dtype_ == child_ptr->dtype_);
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    auto arg1 = fd.getFusionState(args_.at(0).index);
+    auto arg2 = fd.getFusionState(args_.at(1).index);
+    const std::vector<Val*>& output_shape =
+        fd.getFusionStateVector(args_.at(2).index);
+
+    Val* output = nullptr;
+    if constexpr (RType == serde::RecordType::UniformDistOp) {
+      if (args_.size() == 3) { // stochastic uniform
+        output = uniform(output_shape, arg1, arg2, dtype_);
+      } else if (args_.size() == 5) { // provided seed and offset
+        auto seed = fd.getFusionState(args_.at(3).index);
+        auto offset = fd.getFusionState(args_.at(4).index);
+        output = uniform(output_shape, arg1, arg2, dtype_, seed, offset);
+      }
+    } else if constexpr (RType == serde::RecordType::NormalDistOp) {
+      if (args_.size() == 3) { // stochastic normal
+        output = normal(output_shape, arg1, arg2, dtype_);
+      } else if (args_.size() == 5) { // provided seed and offset
+        auto seed = fd.getFusionState(args_.at(3).index);
+        auto offset = fd.getFusionState(args_.at(4).index);
+        output = normal(output_shape, arg1, arg2, dtype_, seed, offset);
+      }
+    } else {
+      static_assert(
+          (RType == serde::RecordType::NormalDistOp) ||
+          (RType == serde::RecordType::UniformDistOp));
+    }
+
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << ", dtype=" << dtypeToPyString(dtype_);
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::TensorCreationSymbolic,
+        serde::CreateTensorCreationSymbolic(builder, toUnderlying(dtype_))
+            .Union()};
+  }
+
+ private:
+  //! DataType of output
+  PrimDataType dtype_;
+};
+
+//! Specialized Record Functor for recording Vector of Scalars
+
+struct VectorRecord : RecordFunctor {
+  VectorRecord(
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      PrimDataType dtype,
+      bool inline_def = false)
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            "define_vector",
+            serde::RecordType::Vector,
+            inline_def),
+        dtype_(dtype) {}
+  ~VectorRecord() override = default;
+  RecordFunctor* clone() final {
+    return new VectorRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 --------------------------------------- 0 |
+  //! |                    Dtype                      |
+  size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    return result | (static_cast<size_t>(dtype_) & 0xffffffff);
+  }
+
+  bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const VectorRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      result = result && (dtype_ == child_ptr->dtype_);
+    }
+    return result;
+  }
+
+  void operator()(FusionState& fd) final {
+    std::vector<Val*> output(args_.size(), nullptr);
+    NVF_CHECK(
+        dtype_ == DataType::Int,
+        "Only the Int dtype is supported by a vector of sizes, but got: ",
+        dtype_);
+    for (size_t i = 0; i < args_.size(); ++i) {
+      NVF_CHECK(
+          args_.at(i).stype == serde::StateType::Scalar,
+          "Unsupported State type!");
+      output.at(i) = fd.getFusionState(args_.at(i).index);
+    }
+    fd.setFusionStateVector(outputs_.at(0).index, output);
+  }
+
+  void print(std::ostream& os, bool close_function = true) const final {
+    if (inline_def_) {
+      bool first_arg = true;
+      NVF_CHECK(outputs_.size() == 1, "VectorRecord does not have exactly 1 output!");
+      os << "[";
+      for (auto& arg : args_) {
+        if (first_arg) {
+          first_arg = false;
+        } else {
+          os << ", ";
+        }
+        os << arg;
+      }
+      os << "]";
+    } else {
+      bool first_output = true;
+      for (auto& output : outputs_) {
+        if (first_output) {
+          first_output = false;
+        } else {
+          os << ", ";
+        }
+        os << output;
+      }
+      os << " = fd." << name_ << "([";
+      bool first_arg = true;
+      for (auto& arg : args_) {
+        if (first_arg) {
+          first_arg = false;
+        } else {
+          os << ", ";
+        }
+        os << arg;
+      }
+      os << "], dtype=" << dtypeToPyString(dtype_);
+      if (close_function) {
+        os << ")";
+      }
+    }
+  }
+
+  std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
+      flatbuffers::FlatBufferBuilder& builder) const final {
+    return {
+        serde::RecordData::Vector,
+        serde::CreateVector(builder, nvfuser::toUnderlying(dtype_)).Union()};
+  }
+
+ private:
+  //! Scalar data type.
+  PrimDataType dtype_;
+};
2969
+
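`VectorRecord::print()` above emits two different Python reprs: when the vector was defined inline it prints only the bracketed literal (e.g. `[S0, S1]`), otherwise it prints a full `fd.define_vector(...)` statement. A self-contained toy reproduction of that branching (the state names and the `v0` output name are illustrative only):

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Emits "[S0, S1]" or "v0 = fd.define_vector([S0, S1], dtype=DataType.Int)".
std::string printVector(const std::vector<std::string>& args, bool inline_def) {
  std::ostringstream os;
  os << (inline_def ? "[" : "v0 = fd.define_vector([");
  bool first_arg = true; // same comma-separation idiom as print() above
  for (const auto& arg : args) {
    if (!first_arg) {
      os << ", ";
    }
    first_arg = false;
    os << arg;
  }
  os << (inline_def ? "]" : "], dtype=DataType.Int)");
  return os.str();
}

int main() {
  std::cout << printVector({"S0", "S1"}, /*inline_def=*/true) << "\n";
  std::cout << printVector({"S0", "S1"}, /*inline_def=*/false) << "\n";
}
```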
+ struct SdpaFwdOpRecord : RecordFunctor {
+   SdpaFwdOpRecord(std::vector<State> args, std::vector<State> outputs)
+       : RecordFunctor(
+             std::move(args),
+             std::move(outputs),
+             "ops.sdpfa_fwd",
+             serde::RecordType::SdpaFwdOp) {}
+   ~SdpaFwdOpRecord() override = default;
+   RecordFunctor* clone() final {
+     return new SdpaFwdOpRecord(*this);
+   }
+
+   void operator()(FusionState& fd) final {
+     auto query = fd.getFusionState(args_.at(0).index)->as<TensorView>();
+     auto key = fd.getFusionState(args_.at(1).index)->as<TensorView>();
+     auto value = fd.getFusionState(args_.at(2).index)->as<TensorView>();
+     auto dropout_p = (args_.at(3).stype == serde::StateType::Scalar)
+         ? fd.getFusionState(args_.at(3).index)->as<Val>()
+         : nullptr;
+     auto is_causal = (args_.at(4).stype == serde::StateType::Scalar)
+         ? fd.getFusionState(args_.at(4).index)->as<Val>()
+         : nullptr;
+     auto scale = (args_.at(5).stype == serde::StateType::Scalar)
+         ? fd.getFusionState(args_.at(5).index)->as<Val>()
+         : nullptr;
+     auto output = sdpfa_fwd(query, key, value, dropout_p, is_causal, scale);
+     fd.setFusionState(outputs_.at(0).index, output.output);
+     fd.setFusionState(outputs_.at(1).index, output.log_sumexp);
+     fd.setFusionState(outputs_.at(2).index, output.philox_seed);
+     fd.setFusionState(outputs_.at(3).index, output.philox_offset);
+   }
+ };
+
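Note the recurring pattern for optional arguments here: trailing Python arguments such as `dropout_p`, `is_causal`, and `scale` are looked up in the `FusionState` only when their `stype` is `serde::StateType::Scalar`; any other state kind (a `None` placeholder) is forwarded as `nullptr` so the op applies its default. A self-contained toy version of that dispatch, with simplified stand-ins for `State` and the fusion state:

```cpp
#include <cstdio>

enum class StateType { Scalar, None };
struct State {
  StateType stype;
  int index;
};

// Stand-in for fd.getFusionState(index); returns a dummy pointer.
int* getFusionState(int index) {
  static int vals[16];
  return &vals[index];
}

// Mirrors the dropout_p / is_causal / scale lookups: optional args that were
// not provided arrive as non-Scalar states and become nullptr.
int* maybeScalar(const State& s) {
  return (s.stype == StateType::Scalar) ? getFusionState(s.index) : nullptr;
}

int main() {
  State provided{StateType::Scalar, 3};
  State omitted{StateType::None, 0};
  std::printf(
      "%d %d\n",
      maybeScalar(provided) != nullptr,
      maybeScalar(omitted) != nullptr); // prints: 1 0
}
```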
+ struct SdpaBwdOpRecord : RecordFunctor {
+   SdpaBwdOpRecord(std::vector<State> args, std::vector<State> outputs)
+       : RecordFunctor(
+             std::move(args),
+             std::move(outputs),
+             "ops.sdpfa_bwd",
+             serde::RecordType::SdpaBwdOp) {}
+   ~SdpaBwdOpRecord() override = default;
+   RecordFunctor* clone() final {
+     return new SdpaBwdOpRecord(*this);
+   }
+
+   void operator()(FusionState& fd) final {
+     auto grad_output = fd.getFusionState(args_.at(0).index)->as<TensorView>();
+     auto query = fd.getFusionState(args_.at(1).index)->as<TensorView>();
+     auto key = fd.getFusionState(args_.at(2).index)->as<TensorView>();
+     auto value = fd.getFusionState(args_.at(3).index)->as<TensorView>();
+     auto output = fd.getFusionState(args_.at(4).index)->as<TensorView>();
+     auto log_sumexp = fd.getFusionState(args_.at(5).index)->as<TensorView>();
+
+     auto dropout_p = (args_.at(6).stype == serde::StateType::Scalar)
+         ? fd.getFusionState(args_.at(6).index)->as<Val>()
+         : nullptr;
+     auto is_causal = (args_.at(7).stype == serde::StateType::Scalar)
+         ? fd.getFusionState(args_.at(7).index)->as<Val>()
+         : nullptr;
+
+     auto philox_seed = fd.getFusionState(args_.at(8).index)->as<TensorView>();
+     auto philox_offset =
+         fd.getFusionState(args_.at(9).index)->as<TensorView>();
+
+     auto scale = (args_.at(10).stype == serde::StateType::Scalar)
+         ? fd.getFusionState(args_.at(10).index)->as<Val>()
+         : nullptr;
+
+     auto grad = sdpfa_bwd(
+         grad_output,
+         query,
+         key,
+         value,
+         output,
+         log_sumexp,
+         dropout_p,
+         is_causal,
+         philox_seed,
+         philox_offset,
+         scale);
+     fd.setFusionState(outputs_.at(0).index, grad.grad_query);
+     fd.setFusionState(outputs_.at(1).index, grad.grad_key);
+     fd.setFusionState(outputs_.at(2).index, grad.grad_value);
+   }
+ };
+
+ struct EmbeddingFwdOpRecord : RecordFunctor {
+   EmbeddingFwdOpRecord(std::vector<State> args, std::vector<State> outputs)
+       : RecordFunctor(
+             std::move(args),
+             std::move(outputs),
+             "ops.embedding_fwd",
+             serde::RecordType::EmbeddingFwdOp) {}
+   ~EmbeddingFwdOpRecord() override = default;
+   RecordFunctor* clone() final {
+     return new EmbeddingFwdOpRecord(*this);
+   }
+
+   void operator()(FusionState& fd) final {
+     auto input = fd.getFusionState(args_.at(0).index)->as<TensorView>();
+     auto weight = fd.getFusionState(args_.at(1).index)->as<TensorView>();
+     auto padding_idx = (args_.at(2).stype == serde::StateType::Scalar)
+         ? fd.getFusionState(args_.at(2).index)->as<Val>()
+         : nullptr;
+     auto max_norm = (args_.at(3).stype == serde::StateType::Scalar)
+         ? fd.getFusionState(args_.at(3).index)->as<Val>()
+         : nullptr;
+     auto norm_type = (args_.at(4).stype == serde::StateType::Scalar)
+         ? fd.getFusionState(args_.at(4).index)->as<Val>()
+         : nullptr;
+     auto scale_grad_by_freq = (args_.at(5).stype == serde::StateType::Scalar)
+         ? fd.getFusionState(args_.at(5).index)->as<Val>()
+         : nullptr;
+     auto sparse = (args_.at(6).stype == serde::StateType::Scalar)
+         ? fd.getFusionState(args_.at(6).index)->as<Val>()
+         : nullptr;
+
+     auto output = embedding_fwd(
+         input,
+         weight,
+         padding_idx,
+         max_norm,
+         norm_type,
+         scale_grad_by_freq,
+         sparse);
+     fd.setFusionState(outputs_.at(0).index, output);
+   }
+ };
+
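Unlike `RandomDistOpRecord` and `VectorRecord`, the three op records above carry no state beyond their args and outputs, so the inherited `RecordFunctor::hash()` and `operator==` already distinguish them and only `clone()` and `operator()()` are overridden. A hypothetical record following that minimal shape (the op name `ops.my_op` and the pass-through body are illustrative, not real nvfuser identifiers; a real record would also get its own `serde::RecordType` entry):

```cpp
// Sketch only: assumes the surrounding python_frontend declarations.
struct MyOpRecord : RecordFunctor {
  MyOpRecord(std::vector<State> args, std::vector<State> outputs)
      : RecordFunctor(
            std::move(args),
            std::move(outputs),
            "ops.my_op", // hypothetical frontend name
            serde::RecordType::SdpaFwdOp) {} // placeholder RecordType
  ~MyOpRecord() override = default;
  RecordFunctor* clone() final {
    return new MyOpRecord(*this);
  }

  void operator()(FusionState& fd) final {
    auto in = fd.getFusionState(args_.at(0).index)->as<TensorView>();
    // A real record would call an op here; this sketch just forwards input.
    fd.setFusionState(outputs_.at(0).index, in);
  }
};
```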
+ } // namespace nvfuser::python_frontend
+
+ //! Creating the template specialized hash and equal_to functions for a
+ //! RecordFunctor object in order to use hash maps (unordered_maps) in STL.
+ namespace std {
+ using namespace nvfuser::python_frontend;
+
+ template <>
+ struct hash<RecordFunctor*> {
+   size_t operator()(const RecordFunctor* p) const {
+     NVF_CHECK(p, "The RecordFunctor Pointer for hashing is null!");
+     return p->hash();
+   }
+ };
+ template <>
+ struct equal_to<RecordFunctor*> {
+   bool operator()(const RecordFunctor* p, const RecordFunctor* q) const {
+     NVF_CHECK(
+         p,
+         "The RecordFunctor Pointer on the lhs of an equality check is null!");
+     NVF_CHECK(
+         q,
+         "The RecordFunctor Pointer on the rhs of an equality check is null!");
+     return p->operator==(*q);
+   }
+ };
+ } // namespace std
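These `std::hash` / `std::equal_to` specializations let containers keyed on `RecordFunctor*` compare records by value, via `hash()` and `operator==`, rather than by pointer identity, which is how the frontend can recognize an equivalent record built on a later call. A sketch of the resulting behavior, assuming this header is included (record construction elided):

```cpp
#include <unordered_map>

using namespace nvfuser::python_frontend;

// Because of the specializations above, lookups in this map call
// RecordFunctor::hash() and RecordFunctor::operator==, so two distinct
// pointers to equivalent records land in the same bucket and compare equal.
std::unordered_map<RecordFunctor*, size_t> record_ids;

size_t idFor(RecordFunctor* rec, size_t next_id) {
  // Inserting an equivalent-but-distinct record pointer is a cache hit:
  // try_emplace finds the existing entry and returns its id.
  auto [it, inserted] = record_ids.try_emplace(rec, next_id);
  return it->second;
}
```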