nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/include/nvfuser/serde/fusion_cache_generated.h
@@ -0,0 +1,4319 @@
+ // automatically generated by the FlatBuffers compiler, do not modify
+
+
+ #ifndef FLATBUFFERS_GENERATED_FUSIONCACHE_NVFUSER_SERDE_H_
+ #define FLATBUFFERS_GENERATED_FUSIONCACHE_NVFUSER_SERDE_H_
+
+ #include "flatbuffers/flatbuffers.h"
+
+ // Ensure the included flatbuffers.h is the same version as when this file was
+ // generated, otherwise it may not be compatible.
+ static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
+               FLATBUFFERS_VERSION_MINOR == 3 &&
+               FLATBUFFERS_VERSION_REVISION == 3,
+              "Non-compatible flatbuffers version included");
+
+ namespace nvfuser {
+ namespace serde {
+
+ struct State;
+
+ struct Scalar;
+ struct ScalarBuilder;
+
+ struct ScalarCpu;
+ struct ScalarCpuBuilder;
+
+ struct TensorArg;
+ struct TensorArgBuilder;
+
+ struct PolymorphicValue;
+ struct PolymorphicValueBuilder;
+
+ struct KernelArgumentHolder;
+ struct KernelArgumentHolderBuilder;
+
+ struct TensorShape;
+ struct TensorShapeBuilder;
+
+ struct LaunchParams;
+ struct LaunchParamsBuilder;
+
+ struct GlobalBufferInfo;
+ struct GlobalBufferInfoBuilder;
+
+ struct ExecutorEntry;
+ struct ExecutorEntryBuilder;
+
+ struct At;
+ struct AtBuilder;
+
+ struct BatchNorm;
+ struct BatchNormBuilder;
+
+ struct Broadcast;
+ struct BroadcastBuilder;
+
+ struct BroadcastInDim;
+ struct BroadcastInDimBuilder;
+
+ struct Cat;
+ struct CatBuilder;
+
+ struct Dtype;
+ struct DtypeBuilder;
+
+ struct Dimension;
+ struct DimensionBuilder;
+
+ struct Norm;
+ struct NormBuilder;
+
+ struct Output;
+ struct OutputBuilder;
+
+ struct Dims;
+ struct DimsBuilder;
+
+ struct Reduction;
+ struct ReductionBuilder;
+
+ struct Size;
+ struct SizeBuilder;
+
+ struct Slice;
+ struct SliceBuilder;
+
+ struct Squeeze;
+ struct SqueezeBuilder;
+
+ struct Tensor;
+ struct TensorBuilder;
+
+ struct TensorCreationSymbolic;
+ struct TensorCreationSymbolicBuilder;
+
+ struct Vector;
+ struct VectorBuilder;
+
+ struct Welford;
+ struct WelfordBuilder;
+
+ struct CudaKernel;
+ struct CudaKernelBuilder;
+
+ struct KernelExecutor;
+ struct KernelExecutorBuilder;
+
+ struct SegmentedEdge;
+ struct SegmentedEdgeBuilder;
+
+ struct SegmentedGroup;
+ struct SegmentedGroupBuilder;
+
+ struct SegmentedFusion;
+ struct SegmentedFusionBuilder;
+
+ struct FusionKernelRuntime;
+ struct FusionKernelRuntimeBuilder;
+
+ struct EncodingEntry;
+
+ struct InputsIdLookup;
+ struct InputsIdLookupBuilder;
+
+ struct KernelRuntimeState;
+ struct KernelRuntimeStateBuilder;
+
+ struct FusionExecutorCache;
+ struct FusionExecutorCacheBuilder;
+
+ struct RecordFunctor;
+ struct RecordFunctorBuilder;
+
+ struct TrieNode;
+ struct TrieNodeBuilder;
+
+ struct FusionCache;
+ struct FusionCacheBuilder;
+
+ enum class StateType : int32_t {
+   Tensor = 0,
+   Scalar = 1,
+   Vector = 2,
+   None = 3,
+   MIN = Tensor,
+   MAX = None
+ };
+
+ inline const StateType (&EnumValuesStateType())[4] {
+   static const StateType values[] = {
+     StateType::Tensor,
+     StateType::Scalar,
+     StateType::Vector,
+     StateType::None
+   };
+   return values;
+ }
+
+ inline const char * const *EnumNamesStateType() {
+   static const char * const names[5] = {
+     "Tensor",
+     "Scalar",
+     "Vector",
+     "None",
+     nullptr
+   };
+   return names;
+ }
+
+ inline const char *EnumNameStateType(StateType e) {
+   if (::flatbuffers::IsOutRange(e, StateType::Tensor, StateType::None)) return "";
+   const size_t index = static_cast<size_t>(e);
+   return EnumNamesStateType()[index];
+ }
+
+ enum class Contiguity : int32_t {
+   Strided = 0,
+   Contiguous = 1,
+   None = 2,
+   MIN = Strided,
+   MAX = None
+ };
+
+ inline const Contiguity (&EnumValuesContiguity())[3] {
+   static const Contiguity values[] = {
+     Contiguity::Strided,
+     Contiguity::Contiguous,
+     Contiguity::None
+   };
+   return values;
+ }
+
+ inline const char * const *EnumNamesContiguity() {
+   static const char * const names[4] = {
+     "Strided",
+     "Contiguous",
+     "None",
+     nullptr
+   };
+   return names;
+ }
+
+ inline const char *EnumNameContiguity(Contiguity e) {
+   if (::flatbuffers::IsOutRange(e, Contiguity::Strided, Contiguity::None)) return "";
+   const size_t index = static_cast<size_t>(e);
+   return EnumNamesContiguity()[index];
+ }
+
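Each enum in this header follows the same generated pattern: the raw enum class, an EnumValues* array, a null-terminated EnumNames* string table, and a bounds-checked EnumName* accessor. A minimal usage sketch (illustrative only, not part of the wheel; the include path assumes the wheel's nvfuser/include directory is on the compiler's search path):

    #include <iostream>

    #include "nvfuser/serde/fusion_cache_generated.h"

    int main() {
      using nvfuser::serde::Contiguity;
      using nvfuser::serde::StateType;

      // In-range values map to their schema names via the generated name table
      // (found through argument-dependent lookup in nvfuser::serde).
      std::cout << EnumNameStateType(StateType::Vector) << "\n";  // "Vector"

      // Out-of-range values are caught by ::flatbuffers::IsOutRange and map to "".
      std::cout << EnumNameContiguity(static_cast<Contiguity>(99)) << "\n";  // ""
      return 0;
    }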
+ enum class RecordType : int32_t {
+   Base = 0,
+   AtOp = 1,
+   BatchNormOp = 2,
+   BroadcastOp = 3,
+   BroadcastInDim = 4,
+   CastTv = 5,
+   CastVal = 6,
+   CatOp = 7,
+   EmbeddingFwdOp = 8,
+   End = 9,
+   ExpandOp = 10,
+   FullOp = 11,
+   IotaOp = 12,
+   IndexSelectOp = 13,
+   SelectOp = 14,
+   TorchGatherOp = 15,
+   TakeAlongAxisOp = 16,
+   Unary_TV = 17,
+   Unary_VAL = 18,
+   Binary_TV = 19,
+   Binary_VAL = 20,
+   Binary_TV_VAL = 21,
+   Binary_VAL_TV = 22,
+   Ternary_TV = 23,
+   Ternary_VAL = 24,
+   Ternary_TV_TV_VAL = 25,
+   Ternary_TV_VAL_TV = 26,
+   Ternary_VAL_TV_TV = 27,
+   Ternary_VAL_VAL_TV = 28,
+   Ternary_TV_VAL_VAL = 29,
+   Ternary_VAL_TV_VAL = 30,
+   Ternary_Alpha_TV = 31,
+   Ternary_Alpha_VAL = 32,
+   Ternary_Alpha_TV_TV_VAL = 33,
+   Ternary_Alpha_TV_VAL_TV = 34,
+   Ternary_Alpha_VAL_TV_TV = 35,
+   Ternary_Alpha_VAL_VAL_TV = 36,
+   Ternary_Alpha_TV_VAL_VAL = 37,
+   Ternary_Alpha_VAL_TV_VAL = 38,
+   NormalDistOp = 39,
+   OutputTv = 40,
+   OutputVal = 41,
+   PadOp = 42,
+   PermuteOp = 43,
+   StrideOrderOp = 44,
+   ReductionMax = 45,
+   ReductionMin = 46,
+   ReductionProd = 47,
+   ReductionSum = 48,
+   ReshapeOp = 49,
+   Scalar = 50,
+   SdpaFwdOp = 51,
+   SdpaBwdOp = 52,
+   ShapeOp = 53,
+   SizeOp = 54,
+   SliceOp = 55,
+   SqueezeOp = 56,
+   Start = 57,
+   Tensor = 58,
+   TensorSizes = 59,
+   UniformDistOp = 60,
+   VarianceOp = 61,
+   VarianceMeanOp = 62,
+   Vector = 63,
+   WelfordOp = 64,
+   MIN = Base,
+   MAX = WelfordOp
+ };
+
+ inline const RecordType (&EnumValuesRecordType())[65] {
+   static const RecordType values[] = {
+     RecordType::Base,
+     RecordType::AtOp,
+     RecordType::BatchNormOp,
+     RecordType::BroadcastOp,
+     RecordType::BroadcastInDim,
+     RecordType::CastTv,
+     RecordType::CastVal,
+     RecordType::CatOp,
+     RecordType::EmbeddingFwdOp,
+     RecordType::End,
+     RecordType::ExpandOp,
+     RecordType::FullOp,
+     RecordType::IotaOp,
+     RecordType::IndexSelectOp,
+     RecordType::SelectOp,
+     RecordType::TorchGatherOp,
+     RecordType::TakeAlongAxisOp,
+     RecordType::Unary_TV,
+     RecordType::Unary_VAL,
+     RecordType::Binary_TV,
+     RecordType::Binary_VAL,
+     RecordType::Binary_TV_VAL,
+     RecordType::Binary_VAL_TV,
+     RecordType::Ternary_TV,
+     RecordType::Ternary_VAL,
+     RecordType::Ternary_TV_TV_VAL,
+     RecordType::Ternary_TV_VAL_TV,
+     RecordType::Ternary_VAL_TV_TV,
+     RecordType::Ternary_VAL_VAL_TV,
+     RecordType::Ternary_TV_VAL_VAL,
+     RecordType::Ternary_VAL_TV_VAL,
+     RecordType::Ternary_Alpha_TV,
+     RecordType::Ternary_Alpha_VAL,
+     RecordType::Ternary_Alpha_TV_TV_VAL,
+     RecordType::Ternary_Alpha_TV_VAL_TV,
+     RecordType::Ternary_Alpha_VAL_TV_TV,
+     RecordType::Ternary_Alpha_VAL_VAL_TV,
+     RecordType::Ternary_Alpha_TV_VAL_VAL,
+     RecordType::Ternary_Alpha_VAL_TV_VAL,
+     RecordType::NormalDistOp,
+     RecordType::OutputTv,
+     RecordType::OutputVal,
+     RecordType::PadOp,
+     RecordType::PermuteOp,
+     RecordType::StrideOrderOp,
+     RecordType::ReductionMax,
+     RecordType::ReductionMin,
+     RecordType::ReductionProd,
+     RecordType::ReductionSum,
+     RecordType::ReshapeOp,
+     RecordType::Scalar,
+     RecordType::SdpaFwdOp,
+     RecordType::SdpaBwdOp,
+     RecordType::ShapeOp,
+     RecordType::SizeOp,
+     RecordType::SliceOp,
+     RecordType::SqueezeOp,
+     RecordType::Start,
+     RecordType::Tensor,
+     RecordType::TensorSizes,
+     RecordType::UniformDistOp,
+     RecordType::VarianceOp,
+     RecordType::VarianceMeanOp,
+     RecordType::Vector,
+     RecordType::WelfordOp
+   };
+   return values;
+ }
+
+ inline const char * const *EnumNamesRecordType() {
+   static const char * const names[66] = {
+     "Base",
+     "AtOp",
+     "BatchNormOp",
+     "BroadcastOp",
+     "BroadcastInDim",
+     "CastTv",
+     "CastVal",
+     "CatOp",
+     "EmbeddingFwdOp",
+     "End",
+     "ExpandOp",
+     "FullOp",
+     "IotaOp",
+     "IndexSelectOp",
+     "SelectOp",
+     "TorchGatherOp",
+     "TakeAlongAxisOp",
+     "Unary_TV",
+     "Unary_VAL",
+     "Binary_TV",
+     "Binary_VAL",
+     "Binary_TV_VAL",
+     "Binary_VAL_TV",
+     "Ternary_TV",
+     "Ternary_VAL",
+     "Ternary_TV_TV_VAL",
+     "Ternary_TV_VAL_TV",
+     "Ternary_VAL_TV_TV",
+     "Ternary_VAL_VAL_TV",
+     "Ternary_TV_VAL_VAL",
+     "Ternary_VAL_TV_VAL",
+     "Ternary_Alpha_TV",
+     "Ternary_Alpha_VAL",
+     "Ternary_Alpha_TV_TV_VAL",
+     "Ternary_Alpha_TV_VAL_TV",
+     "Ternary_Alpha_VAL_TV_TV",
+     "Ternary_Alpha_VAL_VAL_TV",
+     "Ternary_Alpha_TV_VAL_VAL",
+     "Ternary_Alpha_VAL_TV_VAL",
+     "NormalDistOp",
+     "OutputTv",
+     "OutputVal",
+     "PadOp",
+     "PermuteOp",
+     "StrideOrderOp",
+     "ReductionMax",
+     "ReductionMin",
+     "ReductionProd",
+     "ReductionSum",
+     "ReshapeOp",
+     "Scalar",
+     "SdpaFwdOp",
+     "SdpaBwdOp",
+     "ShapeOp",
+     "SizeOp",
+     "SliceOp",
+     "SqueezeOp",
+     "Start",
+     "Tensor",
+     "TensorSizes",
+     "UniformDistOp",
+     "VarianceOp",
+     "VarianceMeanOp",
+     "Vector",
+     "WelfordOp",
+     nullptr
+   };
+   return names;
+ }
+
+ inline const char *EnumNameRecordType(RecordType e) {
+   if (::flatbuffers::IsOutRange(e, RecordType::Base, RecordType::WelfordOp)) return "";
+   const size_t index = static_cast<size_t>(e);
+   return EnumNamesRecordType()[index];
+ }
+
+ enum class RecordData : uint8_t {
+   NONE = 0,
+   At = 1,
+   BatchNorm = 2,
+   Broadcast = 3,
+   BroadcastInDim = 4,
+   Cat = 5,
+   Dimension = 6,
+   Dtype = 7,
+   Norm = 8,
+   Output = 9,
+   Dims = 10,
+   Slice = 11,
+   Squeeze = 12,
+   Reduction = 13,
+   Scalar = 14,
+   Size = 15,
+   Tensor = 16,
+   TensorCreationSymbolic = 17,
+   Vector = 18,
+   Welford = 19,
+   MIN = NONE,
+   MAX = Welford
+ };
+
+ inline const RecordData (&EnumValuesRecordData())[20] {
+   static const RecordData values[] = {
+     RecordData::NONE,
+     RecordData::At,
+     RecordData::BatchNorm,
+     RecordData::Broadcast,
+     RecordData::BroadcastInDim,
+     RecordData::Cat,
+     RecordData::Dimension,
+     RecordData::Dtype,
+     RecordData::Norm,
+     RecordData::Output,
+     RecordData::Dims,
+     RecordData::Slice,
+     RecordData::Squeeze,
+     RecordData::Reduction,
+     RecordData::Scalar,
+     RecordData::Size,
+     RecordData::Tensor,
+     RecordData::TensorCreationSymbolic,
+     RecordData::Vector,
+     RecordData::Welford
+   };
+   return values;
+ }
+
+ inline const char * const *EnumNamesRecordData() {
+   static const char * const names[21] = {
+     "NONE",
+     "At",
+     "BatchNorm",
+     "Broadcast",
+     "BroadcastInDim",
+     "Cat",
+     "Dimension",
+     "Dtype",
+     "Norm",
+     "Output",
+     "Dims",
+     "Slice",
+     "Squeeze",
+     "Reduction",
+     "Scalar",
+     "Size",
+     "Tensor",
+     "TensorCreationSymbolic",
+     "Vector",
+     "Welford",
+     nullptr
+   };
+   return names;
+ }
+
+ inline const char *EnumNameRecordData(RecordData e) {
+   if (::flatbuffers::IsOutRange(e, RecordData::NONE, RecordData::Welford)) return "";
+   const size_t index = static_cast<size_t>(e);
+   return EnumNamesRecordData()[index];
+ }
+
+ template<typename T> struct RecordDataTraits {
+   static const RecordData enum_value = RecordData::NONE;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::At> {
+   static const RecordData enum_value = RecordData::At;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::BatchNorm> {
+   static const RecordData enum_value = RecordData::BatchNorm;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Broadcast> {
+   static const RecordData enum_value = RecordData::Broadcast;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::BroadcastInDim> {
+   static const RecordData enum_value = RecordData::BroadcastInDim;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Cat> {
+   static const RecordData enum_value = RecordData::Cat;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Dimension> {
+   static const RecordData enum_value = RecordData::Dimension;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Dtype> {
+   static const RecordData enum_value = RecordData::Dtype;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Norm> {
+   static const RecordData enum_value = RecordData::Norm;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Output> {
+   static const RecordData enum_value = RecordData::Output;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Dims> {
+   static const RecordData enum_value = RecordData::Dims;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Slice> {
+   static const RecordData enum_value = RecordData::Slice;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Squeeze> {
+   static const RecordData enum_value = RecordData::Squeeze;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Reduction> {
+   static const RecordData enum_value = RecordData::Reduction;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Scalar> {
+   static const RecordData enum_value = RecordData::Scalar;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Size> {
+   static const RecordData enum_value = RecordData::Size;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Tensor> {
+   static const RecordData enum_value = RecordData::Tensor;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::TensorCreationSymbolic> {
+   static const RecordData enum_value = RecordData::TensorCreationSymbolic;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Vector> {
+   static const RecordData enum_value = RecordData::Vector;
+ };
+
+ template<> struct RecordDataTraits<nvfuser::serde::Welford> {
+   static const RecordData enum_value = RecordData::Welford;
+ };
+
+ bool VerifyRecordData(::flatbuffers::Verifier &verifier, const void *obj, RecordData type);
+ bool VerifyRecordDataVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<RecordData> *types);
+
+ enum class PolymorphicValueData : uint8_t {
+   NONE = 0,
+   Scalar = 1,
+   ScalarCpu = 2,
+   TensorArg = 3,
+   MIN = NONE,
+   MAX = TensorArg
+ };
+
+ inline const PolymorphicValueData (&EnumValuesPolymorphicValueData())[4] {
+   static const PolymorphicValueData values[] = {
+     PolymorphicValueData::NONE,
+     PolymorphicValueData::Scalar,
+     PolymorphicValueData::ScalarCpu,
+     PolymorphicValueData::TensorArg
+   };
+   return values;
+ }
+
+ inline const char * const *EnumNamesPolymorphicValueData() {
+   static const char * const names[5] = {
+     "NONE",
+     "Scalar",
+     "ScalarCpu",
+     "TensorArg",
+     nullptr
+   };
+   return names;
+ }
+
+ inline const char *EnumNamePolymorphicValueData(PolymorphicValueData e) {
+   if (::flatbuffers::IsOutRange(e, PolymorphicValueData::NONE, PolymorphicValueData::TensorArg)) return "";
+   const size_t index = static_cast<size_t>(e);
+   return EnumNamesPolymorphicValueData()[index];
+ }
+
+ template<typename T> struct PolymorphicValueDataTraits {
+   static const PolymorphicValueData enum_value = PolymorphicValueData::NONE;
+ };
+
+ template<> struct PolymorphicValueDataTraits<nvfuser::serde::Scalar> {
+   static const PolymorphicValueData enum_value = PolymorphicValueData::Scalar;
+ };
+
+ template<> struct PolymorphicValueDataTraits<nvfuser::serde::ScalarCpu> {
+   static const PolymorphicValueData enum_value = PolymorphicValueData::ScalarCpu;
+ };
+
+ template<> struct PolymorphicValueDataTraits<nvfuser::serde::TensorArg> {
+   static const PolymorphicValueData enum_value = PolymorphicValueData::TensorArg;
+ };
+
+ bool VerifyPolymorphicValueData(::flatbuffers::Verifier &verifier, const void *obj, PolymorphicValueData type);
+ bool VerifyPolymorphicValueDataVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<PolymorphicValueData> *types);
+
+ FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) State FLATBUFFERS_FINAL_CLASS {
+  private:
+   int32_t index_;
+   int32_t type_;
+
+  public:
+   State()
+       : index_(0),
+         type_(0) {
+   }
+   State(int32_t _index, nvfuser::serde::StateType _type)
+       : index_(::flatbuffers::EndianScalar(_index)),
+         type_(::flatbuffers::EndianScalar(static_cast<int32_t>(_type))) {
+   }
+   int32_t index() const {
+     return ::flatbuffers::EndianScalar(index_);
+   }
+   nvfuser::serde::StateType type() const {
+     return static_cast<nvfuser::serde::StateType>(::flatbuffers::EndianScalar(type_));
+   }
+ };
+ FLATBUFFERS_STRUCT_END(State, 8);
+
+ FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) EncodingEntry FLATBUFFERS_FINAL_CLASS {
+  private:
+   uint64_t id_;
+   uint64_t lru_iter_;
+
+  public:
+   EncodingEntry()
+       : id_(0),
+         lru_iter_(0) {
+   }
+   EncodingEntry(uint64_t _id, uint64_t _lru_iter)
+       : id_(::flatbuffers::EndianScalar(_id)),
+         lru_iter_(::flatbuffers::EndianScalar(_lru_iter)) {
+   }
+   uint64_t id() const {
+     return ::flatbuffers::EndianScalar(id_);
+   }
+   uint64_t lru_iter() const {
+     return ::flatbuffers::EndianScalar(lru_iter_);
+   }
+ };
+ FLATBUFFERS_STRUCT_END(EncodingEntry, 16);
+
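State and EncodingEntry differ from the tables that follow: they are fixed-layout structs (FLATBUFFERS_STRUCT_END pins them at 8 and 16 bytes) stored inline rather than behind an offset, with EndianScalar making the accessors endian-safe. A hedged sketch of direct construction, assuming the same include path as above:

    #include <cassert>

    #include "nvfuser/serde/fusion_cache_generated.h"

    int main() {
      // Fixed-size structs are built directly on the stack; no builder is needed.
      nvfuser::serde::State s(/*_index=*/3, nvfuser::serde::StateType::Tensor);
      assert(s.index() == 3);
      assert(s.type() == nvfuser::serde::StateType::Tensor);

      nvfuser::serde::EncodingEntry e(/*_id=*/42, /*_lru_iter=*/7);
      assert(e.id() == 42 && e.lru_iter() == 7);
      return 0;
    }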
+ struct Scalar FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+   typedef ScalarBuilder Builder;
+   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+     VT_DTYPE = 4,
+     VT_HAS_VALUE = 6,
+     VT_VALUE_TYPE = 8,
+     VT_BOOL_VALUE = 10,
+     VT_LONG_VALUE = 12,
+     VT_DOUBLE_VALUE = 14,
+     VT_REAL_VALUE = 16,
+     VT_IMAG_VALUE = 18
+   };
+   int64_t dtype() const {
+     return GetField<int64_t>(VT_DTYPE, 0);
+   }
+   bool has_value() const {
+     return GetField<uint8_t>(VT_HAS_VALUE, 0) != 0;
+   }
+   int64_t value_type() const {
+     return GetField<int64_t>(VT_VALUE_TYPE, 0);
+   }
+   bool bool_value() const {
+     return GetField<uint8_t>(VT_BOOL_VALUE, 0) != 0;
+   }
+   int64_t long_value() const {
+     return GetField<int64_t>(VT_LONG_VALUE, 0);
+   }
+   double double_value() const {
+     return GetField<double>(VT_DOUBLE_VALUE, 0.0);
+   }
+   double real_value() const {
+     return GetField<double>(VT_REAL_VALUE, 0.0);
+   }
+   double imag_value() const {
+     return GetField<double>(VT_IMAG_VALUE, 0.0);
+   }
+   bool Verify(::flatbuffers::Verifier &verifier) const {
+     return VerifyTableStart(verifier) &&
+            VerifyField<int64_t>(verifier, VT_DTYPE, 8) &&
+            VerifyField<uint8_t>(verifier, VT_HAS_VALUE, 1) &&
+            VerifyField<int64_t>(verifier, VT_VALUE_TYPE, 8) &&
+            VerifyField<uint8_t>(verifier, VT_BOOL_VALUE, 1) &&
+            VerifyField<int64_t>(verifier, VT_LONG_VALUE, 8) &&
+            VerifyField<double>(verifier, VT_DOUBLE_VALUE, 8) &&
+            VerifyField<double>(verifier, VT_REAL_VALUE, 8) &&
+            VerifyField<double>(verifier, VT_IMAG_VALUE, 8) &&
+            verifier.EndTable();
+   }
+ };
+
+ struct ScalarBuilder {
+   typedef Scalar Table;
+   ::flatbuffers::FlatBufferBuilder &fbb_;
+   ::flatbuffers::uoffset_t start_;
+   void add_dtype(int64_t dtype) {
+     fbb_.AddElement<int64_t>(Scalar::VT_DTYPE, dtype, 0);
+   }
+   void add_has_value(bool has_value) {
+     fbb_.AddElement<uint8_t>(Scalar::VT_HAS_VALUE, static_cast<uint8_t>(has_value), 0);
+   }
+   void add_value_type(int64_t value_type) {
+     fbb_.AddElement<int64_t>(Scalar::VT_VALUE_TYPE, value_type, 0);
+   }
+   void add_bool_value(bool bool_value) {
+     fbb_.AddElement<uint8_t>(Scalar::VT_BOOL_VALUE, static_cast<uint8_t>(bool_value), 0);
+   }
+   void add_long_value(int64_t long_value) {
+     fbb_.AddElement<int64_t>(Scalar::VT_LONG_VALUE, long_value, 0);
+   }
+   void add_double_value(double double_value) {
+     fbb_.AddElement<double>(Scalar::VT_DOUBLE_VALUE, double_value, 0.0);
+   }
+   void add_real_value(double real_value) {
+     fbb_.AddElement<double>(Scalar::VT_REAL_VALUE, real_value, 0.0);
+   }
+   void add_imag_value(double imag_value) {
+     fbb_.AddElement<double>(Scalar::VT_IMAG_VALUE, imag_value, 0.0);
+   }
+   explicit ScalarBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+       : fbb_(_fbb) {
+     start_ = fbb_.StartTable();
+   }
+   ::flatbuffers::Offset<Scalar> Finish() {
+     const auto end = fbb_.EndTable(start_);
+     auto o = ::flatbuffers::Offset<Scalar>(end);
+     return o;
+   }
+ };
+
+ inline ::flatbuffers::Offset<Scalar> CreateScalar(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     int64_t dtype = 0,
+     bool has_value = false,
+     int64_t value_type = 0,
+     bool bool_value = false,
+     int64_t long_value = 0,
+     double double_value = 0.0,
+     double real_value = 0.0,
+     double imag_value = 0.0) {
+   ScalarBuilder builder_(_fbb);
+   builder_.add_imag_value(imag_value);
+   builder_.add_real_value(real_value);
+   builder_.add_double_value(double_value);
+   builder_.add_long_value(long_value);
+   builder_.add_value_type(value_type);
+   builder_.add_dtype(dtype);
+   builder_.add_bool_value(bool_value);
+   builder_.add_has_value(has_value);
+   return builder_.Finish();
+ }
+
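Scalar shows the full generated table pattern: a vtable-offset enum, typed getters with defaults, a Verify method, a builder, and a CreateScalar convenience wrapper. A round-trip sketch using standard FlatBuffers API (illustrative only; the dtype/value_type codes are opaque integers here, their meaning is defined by nvFuser, not by this example):

    #include <cassert>

    #include "nvfuser/serde/fusion_cache_generated.h"

    int main() {
      ::flatbuffers::FlatBufferBuilder fbb;
      // Write: an integer scalar holding 42 (dtype/value_type codes left at 0 here).
      auto offset = nvfuser::serde::CreateScalar(
          fbb, /*dtype=*/0, /*has_value=*/true, /*value_type=*/0,
          /*bool_value=*/false, /*long_value=*/42);
      fbb.Finish(offset);

      // Read back: GetRoot, then the generated Verify() before trusting the buffer.
      auto* scalar =
          ::flatbuffers::GetRoot<nvfuser::serde::Scalar>(fbb.GetBufferPointer());
      ::flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize());
      assert(scalar->Verify(verifier));
      assert(scalar->has_value() && scalar->long_value() == 42);
      return 0;
    }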
+ struct ScalarCpu FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+   typedef ScalarCpuBuilder Builder;
+   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+     VT_SCALAR_VALUE = 4
+   };
+   const nvfuser::serde::Scalar *scalar_value() const {
+     return GetPointer<const nvfuser::serde::Scalar *>(VT_SCALAR_VALUE);
+   }
+   bool Verify(::flatbuffers::Verifier &verifier) const {
+     return VerifyTableStart(verifier) &&
+            VerifyOffset(verifier, VT_SCALAR_VALUE) &&
+            verifier.VerifyTable(scalar_value()) &&
+            verifier.EndTable();
+   }
+ };
+
+ struct ScalarCpuBuilder {
+   typedef ScalarCpu Table;
+   ::flatbuffers::FlatBufferBuilder &fbb_;
+   ::flatbuffers::uoffset_t start_;
+   void add_scalar_value(::flatbuffers::Offset<nvfuser::serde::Scalar> scalar_value) {
+     fbb_.AddOffset(ScalarCpu::VT_SCALAR_VALUE, scalar_value);
+   }
+   explicit ScalarCpuBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+       : fbb_(_fbb) {
+     start_ = fbb_.StartTable();
+   }
+   ::flatbuffers::Offset<ScalarCpu> Finish() {
+     const auto end = fbb_.EndTable(start_);
+     auto o = ::flatbuffers::Offset<ScalarCpu>(end);
+     return o;
+   }
+ };
+
+ inline ::flatbuffers::Offset<ScalarCpu> CreateScalarCpu(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     ::flatbuffers::Offset<nvfuser::serde::Scalar> scalar_value = 0) {
+   ScalarCpuBuilder builder_(_fbb);
+   builder_.add_scalar_value(scalar_value);
+   return builder_.Finish();
+ }
+
+ struct TensorArg FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+   typedef TensorArgBuilder Builder;
+   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+     VT_PTR = 4,
+     VT_SIZES = 6,
+     VT_STRIDES = 8,
+     VT_DTYPE = 10
+   };
+   uint64_t ptr() const {
+     return GetField<uint64_t>(VT_PTR, 0);
+   }
+   const ::flatbuffers::Vector<int64_t> *sizes() const {
+     return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_SIZES);
+   }
+   const ::flatbuffers::Vector<int64_t> *strides() const {
+     return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_STRIDES);
+   }
+   int64_t dtype() const {
+     return GetField<int64_t>(VT_DTYPE, 0);
+   }
+   bool Verify(::flatbuffers::Verifier &verifier) const {
+     return VerifyTableStart(verifier) &&
+            VerifyField<uint64_t>(verifier, VT_PTR, 8) &&
+            VerifyOffset(verifier, VT_SIZES) &&
+            verifier.VerifyVector(sizes()) &&
+            VerifyOffset(verifier, VT_STRIDES) &&
+            verifier.VerifyVector(strides()) &&
+            VerifyField<int64_t>(verifier, VT_DTYPE, 8) &&
+            verifier.EndTable();
+   }
+ };
+
+ struct TensorArgBuilder {
+   typedef TensorArg Table;
+   ::flatbuffers::FlatBufferBuilder &fbb_;
+   ::flatbuffers::uoffset_t start_;
+   void add_ptr(uint64_t ptr) {
+     fbb_.AddElement<uint64_t>(TensorArg::VT_PTR, ptr, 0);
+   }
+   void add_sizes(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> sizes) {
+     fbb_.AddOffset(TensorArg::VT_SIZES, sizes);
+   }
+   void add_strides(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> strides) {
+     fbb_.AddOffset(TensorArg::VT_STRIDES, strides);
+   }
+   void add_dtype(int64_t dtype) {
+     fbb_.AddElement<int64_t>(TensorArg::VT_DTYPE, dtype, 0);
+   }
+   explicit TensorArgBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+       : fbb_(_fbb) {
+     start_ = fbb_.StartTable();
+   }
+   ::flatbuffers::Offset<TensorArg> Finish() {
+     const auto end = fbb_.EndTable(start_);
+     auto o = ::flatbuffers::Offset<TensorArg>(end);
+     return o;
+   }
+ };
+
+ inline ::flatbuffers::Offset<TensorArg> CreateTensorArg(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     uint64_t ptr = 0,
+     ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> sizes = 0,
+     ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> strides = 0,
+     int64_t dtype = 0) {
+   TensorArgBuilder builder_(_fbb);
+   builder_.add_dtype(dtype);
+   builder_.add_ptr(ptr);
+   builder_.add_strides(strides);
+   builder_.add_sizes(sizes);
+   return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<TensorArg> CreateTensorArgDirect(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     uint64_t ptr = 0,
+     const std::vector<int64_t> *sizes = nullptr,
+     const std::vector<int64_t> *strides = nullptr,
+     int64_t dtype = 0) {
+   auto sizes__ = sizes ? _fbb.CreateVector<int64_t>(*sizes) : 0;
+   auto strides__ = strides ? _fbb.CreateVector<int64_t>(*strides) : 0;
+   return nvfuser::serde::CreateTensorArg(
+       _fbb,
+       ptr,
+       sizes__,
+       strides__,
+       dtype);
+ }
+
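Each table that owns vectors also gets a Create*Direct overload: it accepts plain std::vector pointers and serializes the FlatBuffers vectors itself before forwarding to the offset-based Create* function. A sketch for TensorArg (illustrative; a real nvFuser argument would carry a live device pointer and a meaningful dtype code):

    #include <cstdint>
    #include <vector>

    #include "nvfuser/serde/fusion_cache_generated.h"

    int main() {
      ::flatbuffers::FlatBufferBuilder fbb;
      std::vector<int64_t> sizes = {2, 3};
      std::vector<int64_t> strides = {3, 1};  // row-major contiguous 2x3

      // The Direct overload creates the size/stride vectors in the buffer for us.
      auto arg = nvfuser::serde::CreateTensorArgDirect(
          fbb, /*ptr=*/0, &sizes, &strides, /*dtype=*/0);
      fbb.Finish(arg);

      auto* t =
          ::flatbuffers::GetRoot<nvfuser::serde::TensorArg>(fbb.GetBufferPointer());
      return t->sizes()->size() == 2 ? 0 : 1;
    }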
+ struct PolymorphicValue FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+   typedef PolymorphicValueBuilder Builder;
+   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+     VT_DATA_TYPE = 4,
+     VT_DATA = 6
+   };
+   nvfuser::serde::PolymorphicValueData data_type() const {
+     return static_cast<nvfuser::serde::PolymorphicValueData>(GetField<uint8_t>(VT_DATA_TYPE, 0));
+   }
+   const void *data() const {
+     return GetPointer<const void *>(VT_DATA);
+   }
+   template<typename T> const T *data_as() const;
+   const nvfuser::serde::Scalar *data_as_Scalar() const {
+     return data_type() == nvfuser::serde::PolymorphicValueData::Scalar ? static_cast<const nvfuser::serde::Scalar *>(data()) : nullptr;
+   }
+   const nvfuser::serde::ScalarCpu *data_as_ScalarCpu() const {
+     return data_type() == nvfuser::serde::PolymorphicValueData::ScalarCpu ? static_cast<const nvfuser::serde::ScalarCpu *>(data()) : nullptr;
+   }
+   const nvfuser::serde::TensorArg *data_as_TensorArg() const {
+     return data_type() == nvfuser::serde::PolymorphicValueData::TensorArg ? static_cast<const nvfuser::serde::TensorArg *>(data()) : nullptr;
+   }
+   bool Verify(::flatbuffers::Verifier &verifier) const {
+     return VerifyTableStart(verifier) &&
+            VerifyField<uint8_t>(verifier, VT_DATA_TYPE, 1) &&
+            VerifyOffset(verifier, VT_DATA) &&
+            VerifyPolymorphicValueData(verifier, data(), data_type()) &&
+            verifier.EndTable();
+   }
+ };
+
+ template<> inline const nvfuser::serde::Scalar *PolymorphicValue::data_as<nvfuser::serde::Scalar>() const {
+   return data_as_Scalar();
+ }
+
+ template<> inline const nvfuser::serde::ScalarCpu *PolymorphicValue::data_as<nvfuser::serde::ScalarCpu>() const {
+   return data_as_ScalarCpu();
+ }
+
+ template<> inline const nvfuser::serde::TensorArg *PolymorphicValue::data_as<nvfuser::serde::TensorArg>() const {
+   return data_as_TensorArg();
+ }
+
+ struct PolymorphicValueBuilder {
+   typedef PolymorphicValue Table;
+   ::flatbuffers::FlatBufferBuilder &fbb_;
+   ::flatbuffers::uoffset_t start_;
+   void add_data_type(nvfuser::serde::PolymorphicValueData data_type) {
+     fbb_.AddElement<uint8_t>(PolymorphicValue::VT_DATA_TYPE, static_cast<uint8_t>(data_type), 0);
+   }
+   void add_data(::flatbuffers::Offset<void> data) {
+     fbb_.AddOffset(PolymorphicValue::VT_DATA, data);
+   }
+   explicit PolymorphicValueBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+       : fbb_(_fbb) {
+     start_ = fbb_.StartTable();
+   }
+   ::flatbuffers::Offset<PolymorphicValue> Finish() {
+     const auto end = fbb_.EndTable(start_);
+     auto o = ::flatbuffers::Offset<PolymorphicValue>(end);
+     return o;
+   }
+ };
+
+ inline ::flatbuffers::Offset<PolymorphicValue> CreatePolymorphicValue(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     nvfuser::serde::PolymorphicValueData data_type = nvfuser::serde::PolymorphicValueData::NONE,
+     ::flatbuffers::Offset<void> data = 0) {
+   PolymorphicValueBuilder builder_(_fbb);
+   builder_.add_data(data);
+   builder_.add_data_type(data_type);
+   return builder_.Finish();
+ }
+
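PolymorphicValue is a tagged union: data_type() carries the PolymorphicValueData tag, data() the untyped payload, and the data_as_* accessors (together with the *Traits specializations above) recover the concrete table only when the tag matches. A hedged sketch of storing and retrieving a Scalar payload:

    #include <cassert>

    #include "nvfuser/serde/fusion_cache_generated.h"

    int main() {
      namespace serde = nvfuser::serde;
      ::flatbuffers::FlatBufferBuilder fbb;

      auto scalar = serde::CreateScalar(fbb, 0, /*has_value=*/true, 0, false, 7);
      // .Union() erases the payload type; the tag records what it was.
      auto pv = serde::CreatePolymorphicValue(
          fbb, serde::PolymorphicValueData::Scalar, scalar.Union());
      fbb.Finish(pv);

      auto* root =
          ::flatbuffers::GetRoot<serde::PolymorphicValue>(fbb.GetBufferPointer());
      // data_as_TensorArg() returns nullptr because the tag says Scalar.
      assert(root->data_as_TensorArg() == nullptr);
      const serde::Scalar* s = root->data_as_Scalar();
      assert(s != nullptr && s->long_value() == 7);
      return 0;
    }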
+ struct KernelArgumentHolder FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+   typedef KernelArgumentHolderBuilder Builder;
+   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+     VT_ARGUMENTS = 4,
+     VT_DEVICE_INDEX = 6,
+     VT_CACHE_ID = 8
+   };
+   const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::PolymorphicValue>> *arguments() const {
+     return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::PolymorphicValue>> *>(VT_ARGUMENTS);
+   }
+   int8_t device_index() const {
+     return GetField<int8_t>(VT_DEVICE_INDEX, 0);
+   }
+   uint64_t cache_id() const {
+     return GetField<uint64_t>(VT_CACHE_ID, 0);
+   }
+   bool Verify(::flatbuffers::Verifier &verifier) const {
+     return VerifyTableStart(verifier) &&
+            VerifyOffset(verifier, VT_ARGUMENTS) &&
+            verifier.VerifyVector(arguments()) &&
+            verifier.VerifyVectorOfTables(arguments()) &&
+            VerifyField<int8_t>(verifier, VT_DEVICE_INDEX, 1) &&
+            VerifyField<uint64_t>(verifier, VT_CACHE_ID, 8) &&
+            verifier.EndTable();
+   }
+ };
+
+ struct KernelArgumentHolderBuilder {
+   typedef KernelArgumentHolder Table;
+   ::flatbuffers::FlatBufferBuilder &fbb_;
+   ::flatbuffers::uoffset_t start_;
+   void add_arguments(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::PolymorphicValue>>> arguments) {
+     fbb_.AddOffset(KernelArgumentHolder::VT_ARGUMENTS, arguments);
+   }
+   void add_device_index(int8_t device_index) {
+     fbb_.AddElement<int8_t>(KernelArgumentHolder::VT_DEVICE_INDEX, device_index, 0);
+   }
+   void add_cache_id(uint64_t cache_id) {
+     fbb_.AddElement<uint64_t>(KernelArgumentHolder::VT_CACHE_ID, cache_id, 0);
+   }
+   explicit KernelArgumentHolderBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+       : fbb_(_fbb) {
+     start_ = fbb_.StartTable();
+   }
+   ::flatbuffers::Offset<KernelArgumentHolder> Finish() {
+     const auto end = fbb_.EndTable(start_);
+     auto o = ::flatbuffers::Offset<KernelArgumentHolder>(end);
+     return o;
+   }
+ };
+
+ inline ::flatbuffers::Offset<KernelArgumentHolder> CreateKernelArgumentHolder(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::PolymorphicValue>>> arguments = 0,
+     int8_t device_index = 0,
+     uint64_t cache_id = 0) {
+   KernelArgumentHolderBuilder builder_(_fbb);
+   builder_.add_cache_id(cache_id);
+   builder_.add_arguments(arguments);
+   builder_.add_device_index(device_index);
+   return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<KernelArgumentHolder> CreateKernelArgumentHolderDirect(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     const std::vector<::flatbuffers::Offset<nvfuser::serde::PolymorphicValue>> *arguments = nullptr,
+     int8_t device_index = 0,
+     uint64_t cache_id = 0) {
+   auto arguments__ = arguments ? _fbb.CreateVector<::flatbuffers::Offset<nvfuser::serde::PolymorphicValue>>(*arguments) : 0;
+   return nvfuser::serde::CreateKernelArgumentHolder(
+       _fbb,
+       arguments__,
+       device_index,
+       cache_id);
+ }
+
+ struct TensorShape FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+   typedef TensorShapeBuilder Builder;
+   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+     VT_SHAPE = 4
+   };
+   const ::flatbuffers::Vector<int64_t> *shape() const {
+     return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_SHAPE);
+   }
+   bool Verify(::flatbuffers::Verifier &verifier) const {
+     return VerifyTableStart(verifier) &&
+            VerifyOffset(verifier, VT_SHAPE) &&
+            verifier.VerifyVector(shape()) &&
+            verifier.EndTable();
+   }
+ };
+
+ struct TensorShapeBuilder {
+   typedef TensorShape Table;
+   ::flatbuffers::FlatBufferBuilder &fbb_;
+   ::flatbuffers::uoffset_t start_;
+   void add_shape(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> shape) {
+     fbb_.AddOffset(TensorShape::VT_SHAPE, shape);
+   }
+   explicit TensorShapeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+       : fbb_(_fbb) {
+     start_ = fbb_.StartTable();
+   }
+   ::flatbuffers::Offset<TensorShape> Finish() {
+     const auto end = fbb_.EndTable(start_);
+     auto o = ::flatbuffers::Offset<TensorShape>(end);
+     return o;
+   }
+ };
+
+ inline ::flatbuffers::Offset<TensorShape> CreateTensorShape(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> shape = 0) {
+   TensorShapeBuilder builder_(_fbb);
+   builder_.add_shape(shape);
+   return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<TensorShape> CreateTensorShapeDirect(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     const std::vector<int64_t> *shape = nullptr) {
+   auto shape__ = shape ? _fbb.CreateVector<int64_t>(*shape) : 0;
+   return nvfuser::serde::CreateTensorShape(
+       _fbb,
+       shape__);
+ }
+
+ struct LaunchParams FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+   typedef LaunchParamsBuilder Builder;
+   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+     VT_GDIMX = 4,
+     VT_GDIMY = 6,
+     VT_GDIMZ = 8,
+     VT_BDIMX = 10,
+     VT_BDIMY = 12,
+     VT_BDIMZ = 14,
+     VT_SMEM = 16,
+     VT_OUTPUT_SIZES = 18
+   };
+   int64_t gdimx() const {
+     return GetField<int64_t>(VT_GDIMX, 0);
+   }
+   int64_t gdimy() const {
+     return GetField<int64_t>(VT_GDIMY, 0);
+   }
+   int64_t gdimz() const {
+     return GetField<int64_t>(VT_GDIMZ, 0);
+   }
+   int64_t bdimx() const {
+     return GetField<int64_t>(VT_BDIMX, 0);
+   }
+   int64_t bdimy() const {
+     return GetField<int64_t>(VT_BDIMY, 0);
+   }
+   int64_t bdimz() const {
+     return GetField<int64_t>(VT_BDIMZ, 0);
+   }
+   int64_t smem() const {
+     return GetField<int64_t>(VT_SMEM, 0);
+   }
+   const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::TensorShape>> *output_sizes() const {
+     return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::TensorShape>> *>(VT_OUTPUT_SIZES);
+   }
+   bool Verify(::flatbuffers::Verifier &verifier) const {
+     return VerifyTableStart(verifier) &&
+            VerifyField<int64_t>(verifier, VT_GDIMX, 8) &&
+            VerifyField<int64_t>(verifier, VT_GDIMY, 8) &&
+            VerifyField<int64_t>(verifier, VT_GDIMZ, 8) &&
+            VerifyField<int64_t>(verifier, VT_BDIMX, 8) &&
+            VerifyField<int64_t>(verifier, VT_BDIMY, 8) &&
+            VerifyField<int64_t>(verifier, VT_BDIMZ, 8) &&
+            VerifyField<int64_t>(verifier, VT_SMEM, 8) &&
+            VerifyOffset(verifier, VT_OUTPUT_SIZES) &&
+            verifier.VerifyVector(output_sizes()) &&
+            verifier.VerifyVectorOfTables(output_sizes()) &&
+            verifier.EndTable();
+   }
+ };
+
+ struct LaunchParamsBuilder {
+   typedef LaunchParams Table;
+   ::flatbuffers::FlatBufferBuilder &fbb_;
+   ::flatbuffers::uoffset_t start_;
+   void add_gdimx(int64_t gdimx) {
+     fbb_.AddElement<int64_t>(LaunchParams::VT_GDIMX, gdimx, 0);
+   }
+   void add_gdimy(int64_t gdimy) {
+     fbb_.AddElement<int64_t>(LaunchParams::VT_GDIMY, gdimy, 0);
+   }
+   void add_gdimz(int64_t gdimz) {
+     fbb_.AddElement<int64_t>(LaunchParams::VT_GDIMZ, gdimz, 0);
+   }
+   void add_bdimx(int64_t bdimx) {
+     fbb_.AddElement<int64_t>(LaunchParams::VT_BDIMX, bdimx, 0);
+   }
+   void add_bdimy(int64_t bdimy) {
+     fbb_.AddElement<int64_t>(LaunchParams::VT_BDIMY, bdimy, 0);
+   }
+   void add_bdimz(int64_t bdimz) {
+     fbb_.AddElement<int64_t>(LaunchParams::VT_BDIMZ, bdimz, 0);
+   }
+   void add_smem(int64_t smem) {
+     fbb_.AddElement<int64_t>(LaunchParams::VT_SMEM, smem, 0);
+   }
+   void add_output_sizes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::TensorShape>>> output_sizes) {
+     fbb_.AddOffset(LaunchParams::VT_OUTPUT_SIZES, output_sizes);
+   }
+   explicit LaunchParamsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+       : fbb_(_fbb) {
+     start_ = fbb_.StartTable();
+   }
+   ::flatbuffers::Offset<LaunchParams> Finish() {
+     const auto end = fbb_.EndTable(start_);
+     auto o = ::flatbuffers::Offset<LaunchParams>(end);
+     return o;
+   }
+ };
+
+ inline ::flatbuffers::Offset<LaunchParams> CreateLaunchParams(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     int64_t gdimx = 0,
+     int64_t gdimy = 0,
+     int64_t gdimz = 0,
+     int64_t bdimx = 0,
+     int64_t bdimy = 0,
+     int64_t bdimz = 0,
+     int64_t smem = 0,
+     ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::TensorShape>>> output_sizes = 0) {
+   LaunchParamsBuilder builder_(_fbb);
+   builder_.add_smem(smem);
+   builder_.add_bdimz(bdimz);
+   builder_.add_bdimy(bdimy);
+   builder_.add_bdimx(bdimx);
+   builder_.add_gdimz(gdimz);
+   builder_.add_gdimy(gdimy);
+   builder_.add_gdimx(gdimx);
+   builder_.add_output_sizes(output_sizes);
+   return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<LaunchParams> CreateLaunchParamsDirect(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     int64_t gdimx = 0,
+     int64_t gdimy = 0,
+     int64_t gdimz = 0,
+     int64_t bdimx = 0,
+     int64_t bdimy = 0,
+     int64_t bdimz = 0,
+     int64_t smem = 0,
+     const std::vector<::flatbuffers::Offset<nvfuser::serde::TensorShape>> *output_sizes = nullptr) {
+   auto output_sizes__ = output_sizes ? _fbb.CreateVector<::flatbuffers::Offset<nvfuser::serde::TensorShape>>(*output_sizes) : 0;
+   return nvfuser::serde::CreateLaunchParams(
+       _fbb,
+       gdimx,
+       gdimy,
+       gdimz,
+       bdimx,
+       bdimy,
+       bdimz,
+       smem,
+       output_sizes__);
+ }
+
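LaunchParams packs a CUDA launch configuration (grid dims, block dims, dynamic shared memory) together with per-output shapes as a vector of nested TensorShape tables; in FlatBuffers, nested tables must be serialized before the table that references them. A sketch under that assumption, with arbitrarily chosen values:

    #include <cstdint>
    #include <vector>

    #include "nvfuser/serde/fusion_cache_generated.h"

    int main() {
      namespace serde = nvfuser::serde;
      ::flatbuffers::FlatBufferBuilder fbb;

      // Inner tables first: one output of shape [128, 256].
      std::vector<int64_t> shape = {128, 256};
      std::vector<::flatbuffers::Offset<serde::TensorShape>> outputs = {
          serde::CreateTensorShapeDirect(fbb, &shape)};

      // Then the outer table: a 128-block, 256-thread, 48 KiB-smem launch.
      auto lp = serde::CreateLaunchParamsDirect(
          fbb, /*gdimx=*/128, /*gdimy=*/1, /*gdimz=*/1,
          /*bdimx=*/256, /*bdimy=*/1, /*bdimz=*/1,
          /*smem=*/48 * 1024, &outputs);
      fbb.Finish(lp);

      auto* p =
          ::flatbuffers::GetRoot<serde::LaunchParams>(fbb.GetBufferPointer());
      return p->output_sizes()->Get(0)->shape()->Get(1) == 256 ? 0 : 1;
    }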
1275
+ struct GlobalBufferInfo FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef GlobalBufferInfoBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_TV = 4,
+ VT_SIZES = 6,
+ VT_STRIDES = 8,
+ VT_DTYPE = 10,
+ VT_ZERO_INIT = 12,
+ VT_RESETS_TO_ZERO = 14,
+ VT_IS_PROFILE_BUFFER = 16,
+ VT_IS_FUSION_OUTPUT = 18
+ };
+ int64_t tv() const {
+ return GetField<int64_t>(VT_TV, -1LL);
+ }
+ const ::flatbuffers::Vector<int64_t> *sizes() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_SIZES);
+ }
+ const ::flatbuffers::Vector<int64_t> *strides() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_STRIDES);
+ }
+ int64_t dtype() const {
+ return GetField<int64_t>(VT_DTYPE, 0);
+ }
+ bool zero_init() const {
+ return GetField<uint8_t>(VT_ZERO_INIT, 0) != 0;
+ }
+ bool resets_to_zero() const {
+ return GetField<uint8_t>(VT_RESETS_TO_ZERO, 0) != 0;
+ }
+ bool is_profile_buffer() const {
+ return GetField<uint8_t>(VT_IS_PROFILE_BUFFER, 0) != 0;
+ }
+ bool is_fusion_output() const {
+ return GetField<uint8_t>(VT_IS_FUSION_OUTPUT, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_TV, 8) &&
+ VerifyOffset(verifier, VT_SIZES) &&
+ verifier.VerifyVector(sizes()) &&
+ VerifyOffset(verifier, VT_STRIDES) &&
+ verifier.VerifyVector(strides()) &&
+ VerifyField<int64_t>(verifier, VT_DTYPE, 8) &&
+ VerifyField<uint8_t>(verifier, VT_ZERO_INIT, 1) &&
+ VerifyField<uint8_t>(verifier, VT_RESETS_TO_ZERO, 1) &&
+ VerifyField<uint8_t>(verifier, VT_IS_PROFILE_BUFFER, 1) &&
+ VerifyField<uint8_t>(verifier, VT_IS_FUSION_OUTPUT, 1) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct GlobalBufferInfoBuilder {
+ typedef GlobalBufferInfo Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_tv(int64_t tv) {
+ fbb_.AddElement<int64_t>(GlobalBufferInfo::VT_TV, tv, -1LL);
+ }
+ void add_sizes(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> sizes) {
+ fbb_.AddOffset(GlobalBufferInfo::VT_SIZES, sizes);
+ }
+ void add_strides(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> strides) {
+ fbb_.AddOffset(GlobalBufferInfo::VT_STRIDES, strides);
+ }
+ void add_dtype(int64_t dtype) {
+ fbb_.AddElement<int64_t>(GlobalBufferInfo::VT_DTYPE, dtype, 0);
+ }
+ void add_zero_init(bool zero_init) {
+ fbb_.AddElement<uint8_t>(GlobalBufferInfo::VT_ZERO_INIT, static_cast<uint8_t>(zero_init), 0);
+ }
+ void add_resets_to_zero(bool resets_to_zero) {
+ fbb_.AddElement<uint8_t>(GlobalBufferInfo::VT_RESETS_TO_ZERO, static_cast<uint8_t>(resets_to_zero), 0);
+ }
+ void add_is_profile_buffer(bool is_profile_buffer) {
+ fbb_.AddElement<uint8_t>(GlobalBufferInfo::VT_IS_PROFILE_BUFFER, static_cast<uint8_t>(is_profile_buffer), 0);
+ }
+ void add_is_fusion_output(bool is_fusion_output) {
+ fbb_.AddElement<uint8_t>(GlobalBufferInfo::VT_IS_FUSION_OUTPUT, static_cast<uint8_t>(is_fusion_output), 0);
+ }
+ explicit GlobalBufferInfoBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<GlobalBufferInfo> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<GlobalBufferInfo>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<GlobalBufferInfo> CreateGlobalBufferInfo(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t tv = -1LL,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> sizes = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> strides = 0,
+ int64_t dtype = 0,
+ bool zero_init = false,
+ bool resets_to_zero = false,
+ bool is_profile_buffer = false,
+ bool is_fusion_output = false) {
+ GlobalBufferInfoBuilder builder_(_fbb);
+ builder_.add_dtype(dtype);
+ builder_.add_tv(tv);
+ builder_.add_strides(strides);
+ builder_.add_sizes(sizes);
+ builder_.add_is_fusion_output(is_fusion_output);
+ builder_.add_is_profile_buffer(is_profile_buffer);
+ builder_.add_resets_to_zero(resets_to_zero);
+ builder_.add_zero_init(zero_init);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<GlobalBufferInfo> CreateGlobalBufferInfoDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t tv = -1LL,
+ const std::vector<int64_t> *sizes = nullptr,
+ const std::vector<int64_t> *strides = nullptr,
+ int64_t dtype = 0,
+ bool zero_init = false,
+ bool resets_to_zero = false,
+ bool is_profile_buffer = false,
+ bool is_fusion_output = false) {
+ auto sizes__ = sizes ? _fbb.CreateVector<int64_t>(*sizes) : 0;
+ auto strides__ = strides ? _fbb.CreateVector<int64_t>(*strides) : 0;
+ return nvfuser::serde::CreateGlobalBufferInfo(
+ _fbb,
+ tv,
+ sizes__,
+ strides__,
+ dtype,
+ zero_init,
+ resets_to_zero,
+ is_profile_buffer,
+ is_fusion_output);
+ }
+
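+ // NOTE: editorial sketch, not flatc output. GlobalBufferInfo describes one
+ // global-memory allocation; tv and dtype are int64 handles whose encodings
+ // are defined by nvFuser elsewhere, so the values here are placeholders:
+ //
+ //   ::flatbuffers::FlatBufferBuilder fbb;
+ //   std::vector<int64_t> sizes{8, 16};
+ //   std::vector<int64_t> strides{16, 1};
+ //   auto info = nvfuser::serde::CreateGlobalBufferInfoDirect(
+ //       fbb, /*tv=*/42, &sizes, &strides, /*dtype=*/0,
+ //       /*zero_init=*/true, /*resets_to_zero=*/false,
+ //       /*is_profile_buffer=*/false, /*is_fusion_output=*/true);
+ //   fbb.Finish(info);
+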
+ struct ExecutorEntry FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef ExecutorEntryBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_INIT = 4,
+ VT_LAUNCH_PARAMS = 6,
+ VT_OUTPUTS = 8,
+ VT_INTERMEDIATES = 10
+ };
+ bool init() const {
+ return GetField<uint8_t>(VT_INIT, 0) != 0;
+ }
+ const nvfuser::serde::LaunchParams *launch_params() const {
+ return GetPointer<const nvfuser::serde::LaunchParams *>(VT_LAUNCH_PARAMS);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>> *outputs() const {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>> *>(VT_OUTPUTS);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>> *intermediates() const {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>> *>(VT_INTERMEDIATES);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_INIT, 1) &&
+ VerifyOffset(verifier, VT_LAUNCH_PARAMS) &&
+ verifier.VerifyTable(launch_params()) &&
+ VerifyOffset(verifier, VT_OUTPUTS) &&
+ verifier.VerifyVector(outputs()) &&
+ verifier.VerifyVectorOfTables(outputs()) &&
+ VerifyOffset(verifier, VT_INTERMEDIATES) &&
+ verifier.VerifyVector(intermediates()) &&
+ verifier.VerifyVectorOfTables(intermediates()) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct ExecutorEntryBuilder {
+ typedef ExecutorEntry Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_init(bool init) {
+ fbb_.AddElement<uint8_t>(ExecutorEntry::VT_INIT, static_cast<uint8_t>(init), 0);
+ }
+ void add_launch_params(::flatbuffers::Offset<nvfuser::serde::LaunchParams> launch_params) {
+ fbb_.AddOffset(ExecutorEntry::VT_LAUNCH_PARAMS, launch_params);
+ }
+ void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>>> outputs) {
+ fbb_.AddOffset(ExecutorEntry::VT_OUTPUTS, outputs);
+ }
+ void add_intermediates(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>>> intermediates) {
+ fbb_.AddOffset(ExecutorEntry::VT_INTERMEDIATES, intermediates);
+ }
+ explicit ExecutorEntryBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ExecutorEntry> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ExecutorEntry>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<ExecutorEntry> CreateExecutorEntry(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ bool init = false,
+ ::flatbuffers::Offset<nvfuser::serde::LaunchParams> launch_params = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>>> outputs = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>>> intermediates = 0) {
+ ExecutorEntryBuilder builder_(_fbb);
+ builder_.add_intermediates(intermediates);
+ builder_.add_outputs(outputs);
+ builder_.add_launch_params(launch_params);
+ builder_.add_init(init);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<ExecutorEntry> CreateExecutorEntryDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ bool init = false,
+ ::flatbuffers::Offset<nvfuser::serde::LaunchParams> launch_params = 0,
+ const std::vector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>> *outputs = nullptr,
+ const std::vector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>> *intermediates = nullptr) {
+ auto outputs__ = outputs ? _fbb.CreateVector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>>(*outputs) : 0;
+ auto intermediates__ = intermediates ? _fbb.CreateVector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>>(*intermediates) : 0;
+ return nvfuser::serde::CreateExecutorEntry(
+ _fbb,
+ init,
+ launch_params,
+ outputs__,
+ intermediates__);
+ }
+
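+ // NOTE: editorial sketch, not flatc output. ExecutorEntry nests the two
+ // tables above; as with any FlatBuffer, child tables must be finished
+ // before the parent is assembled, which the Direct helpers guarantee by
+ // creating each sub-table up front:
+ //
+ //   ::flatbuffers::FlatBufferBuilder fbb;
+ //   auto lp = nvfuser::serde::CreateLaunchParamsDirect(
+ //       fbb, /*gdimx=*/1, /*gdimy=*/1, /*gdimz=*/1,
+ //       /*bdimx=*/32, /*bdimy=*/1, /*bdimz=*/1, /*smem=*/0);
+ //   std::vector<int64_t> sizes{4};
+ //   std::vector<int64_t> strides{1};
+ //   std::vector<::flatbuffers::Offset<nvfuser::serde::GlobalBufferInfo>> outs{
+ //       nvfuser::serde::CreateGlobalBufferInfoDirect(fbb, /*tv=*/0, &sizes, &strides)};
+ //   auto entry = nvfuser::serde::CreateExecutorEntryDirect(
+ //       fbb, /*init=*/true, lp, &outs, /*intermediates=*/nullptr);
+ //   fbb.Finish(entry);
+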
+ struct At FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef AtBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_INDEX = 4
+ };
+ int64_t index() const {
+ return GetField<int64_t>(VT_INDEX, 0);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_INDEX, 8) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct AtBuilder {
+ typedef At Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_index(int64_t index) {
+ fbb_.AddElement<int64_t>(At::VT_INDEX, index, 0);
+ }
+ explicit AtBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<At> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<At>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<At> CreateAt(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t index = 0) {
+ AtBuilder builder_(_fbb);
+ builder_.add_index(index);
+ return builder_.Finish();
+ }
+
+ struct BatchNorm FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef BatchNormBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_TRAINING = 4,
+ VT_CHANNELS_LAST = 6
+ };
+ bool training() const {
+ return GetField<uint8_t>(VT_TRAINING, 0) != 0;
+ }
+ bool channels_last() const {
+ return GetField<uint8_t>(VT_CHANNELS_LAST, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_TRAINING, 1) &&
+ VerifyField<uint8_t>(verifier, VT_CHANNELS_LAST, 1) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct BatchNormBuilder {
+ typedef BatchNorm Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_training(bool training) {
+ fbb_.AddElement<uint8_t>(BatchNorm::VT_TRAINING, static_cast<uint8_t>(training), 0);
+ }
+ void add_channels_last(bool channels_last) {
+ fbb_.AddElement<uint8_t>(BatchNorm::VT_CHANNELS_LAST, static_cast<uint8_t>(channels_last), 0);
+ }
+ explicit BatchNormBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<BatchNorm> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<BatchNorm>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<BatchNorm> CreateBatchNorm(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ bool training = false,
+ bool channels_last = false) {
+ BatchNormBuilder builder_(_fbb);
+ builder_.add_channels_last(channels_last);
+ builder_.add_training(training);
+ return builder_.Finish();
+ }
+
+ struct Broadcast FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef BroadcastBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BROADCAST_DIMS = 4
+ };
+ const ::flatbuffers::Vector<uint8_t> *broadcast_dims() const {
+ return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_BROADCAST_DIMS);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_BROADCAST_DIMS) &&
+ verifier.VerifyVector(broadcast_dims()) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct BroadcastBuilder {
+ typedef Broadcast Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_broadcast_dims(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> broadcast_dims) {
+ fbb_.AddOffset(Broadcast::VT_BROADCAST_DIMS, broadcast_dims);
+ }
+ explicit BroadcastBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Broadcast> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Broadcast>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Broadcast> CreateBroadcast(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> broadcast_dims = 0) {
+ BroadcastBuilder builder_(_fbb);
+ builder_.add_broadcast_dims(broadcast_dims);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<Broadcast> CreateBroadcastDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<uint8_t> *broadcast_dims = nullptr) {
+ auto broadcast_dims__ = broadcast_dims ? _fbb.CreateVector<uint8_t>(*broadcast_dims) : 0;
+ return nvfuser::serde::CreateBroadcast(
+ _fbb,
+ broadcast_dims__);
+ }
+
+ struct BroadcastInDim FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef BroadcastInDimBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OUTPUT_SIZE = 4,
+ VT_BROADCAST_DIMS = 6
+ };
+ uint64_t output_size() const {
+ return GetField<uint64_t>(VT_OUTPUT_SIZE, 0);
+ }
+ const ::flatbuffers::Vector<int64_t> *broadcast_dims() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_BROADCAST_DIMS);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint64_t>(verifier, VT_OUTPUT_SIZE, 8) &&
+ VerifyOffset(verifier, VT_BROADCAST_DIMS) &&
+ verifier.VerifyVector(broadcast_dims()) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct BroadcastInDimBuilder {
+ typedef BroadcastInDim Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_output_size(uint64_t output_size) {
+ fbb_.AddElement<uint64_t>(BroadcastInDim::VT_OUTPUT_SIZE, output_size, 0);
+ }
+ void add_broadcast_dims(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> broadcast_dims) {
+ fbb_.AddOffset(BroadcastInDim::VT_BROADCAST_DIMS, broadcast_dims);
+ }
+ explicit BroadcastInDimBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<BroadcastInDim> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<BroadcastInDim>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<BroadcastInDim> CreateBroadcastInDim(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ uint64_t output_size = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> broadcast_dims = 0) {
+ BroadcastInDimBuilder builder_(_fbb);
+ builder_.add_output_size(output_size);
+ builder_.add_broadcast_dims(broadcast_dims);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<BroadcastInDim> CreateBroadcastInDimDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ uint64_t output_size = 0,
+ const std::vector<int64_t> *broadcast_dims = nullptr) {
+ auto broadcast_dims__ = broadcast_dims ? _fbb.CreateVector<int64_t>(*broadcast_dims) : 0;
+ return nvfuser::serde::CreateBroadcastInDim(
+ _fbb,
+ output_size,
+ broadcast_dims__);
+ }
+
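+ // NOTE: editorial sketch, not flatc output. The two broadcast records
+ // differ in encoding: Broadcast appears to keep one uint8 flag per output
+ // dimension, while BroadcastInDim keeps the output rank plus the output
+ // positions of the input axes:
+ //
+ //   ::flatbuffers::FlatBufferBuilder fbb;
+ //   std::vector<uint8_t> flags{1, 0};   // dim 0 broadcast, dim 1 kept
+ //   auto b = nvfuser::serde::CreateBroadcastDirect(fbb, &flags);
+ //   std::vector<int64_t> out_pos{1};    // input dim 0 maps to output dim 1
+ //   auto bid = nvfuser::serde::CreateBroadcastInDimDirect(
+ //       fbb, /*output_size=*/2, &out_pos);
+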
+ struct Cat FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef CatBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_DIM = 4,
+ VT_MANUAL_PADDING = 6
+ };
+ int64_t dim() const {
+ return GetField<int64_t>(VT_DIM, 0);
+ }
+ bool manual_padding() const {
+ return GetField<uint8_t>(VT_MANUAL_PADDING, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_DIM, 8) &&
+ VerifyField<uint8_t>(verifier, VT_MANUAL_PADDING, 1) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct CatBuilder {
+ typedef Cat Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_dim(int64_t dim) {
+ fbb_.AddElement<int64_t>(Cat::VT_DIM, dim, 0);
+ }
+ void add_manual_padding(bool manual_padding) {
+ fbb_.AddElement<uint8_t>(Cat::VT_MANUAL_PADDING, static_cast<uint8_t>(manual_padding), 0);
+ }
+ explicit CatBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Cat> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Cat>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Cat> CreateCat(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t dim = 0,
+ bool manual_padding = false) {
+ CatBuilder builder_(_fbb);
+ builder_.add_dim(dim);
+ builder_.add_manual_padding(manual_padding);
+ return builder_.Finish();
+ }
+
+ struct Dtype FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef DtypeBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_DTYPE = 4
+ };
+ int64_t dtype() const {
+ return GetField<int64_t>(VT_DTYPE, 0);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_DTYPE, 8) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct DtypeBuilder {
+ typedef Dtype Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_dtype(int64_t dtype) {
+ fbb_.AddElement<int64_t>(Dtype::VT_DTYPE, dtype, 0);
+ }
+ explicit DtypeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Dtype> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Dtype>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Dtype> CreateDtype(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t dtype = 0) {
+ DtypeBuilder builder_(_fbb);
+ builder_.add_dtype(dtype);
+ return builder_.Finish();
+ }
+
+ struct Dimension FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef DimensionBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_DIM = 4
+ };
+ int64_t dim() const {
+ return GetField<int64_t>(VT_DIM, 0);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_DIM, 8) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct DimensionBuilder {
+ typedef Dimension Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_dim(int64_t dim) {
+ fbb_.AddElement<int64_t>(Dimension::VT_DIM, dim, 0);
+ }
+ explicit DimensionBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Dimension> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Dimension>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Dimension> CreateDimension(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t dim = 0) {
+ DimensionBuilder builder_(_fbb);
+ builder_.add_dim(dim);
+ return builder_.Finish();
+ }
+
+ struct Norm FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef NormBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_AXES = 4,
+ VT_CORRECTION = 6,
+ VT_KEEP_DIM = 8
+ };
+ const ::flatbuffers::Vector<int64_t> *axes() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_AXES);
+ }
+ int64_t correction() const {
+ return GetField<int64_t>(VT_CORRECTION, 0);
+ }
+ bool keep_dim() const {
+ return GetField<uint8_t>(VT_KEEP_DIM, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_AXES) &&
+ verifier.VerifyVector(axes()) &&
+ VerifyField<int64_t>(verifier, VT_CORRECTION, 8) &&
+ VerifyField<uint8_t>(verifier, VT_KEEP_DIM, 1) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct NormBuilder {
+ typedef Norm Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_axes(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> axes) {
+ fbb_.AddOffset(Norm::VT_AXES, axes);
+ }
+ void add_correction(int64_t correction) {
+ fbb_.AddElement<int64_t>(Norm::VT_CORRECTION, correction, 0);
+ }
+ void add_keep_dim(bool keep_dim) {
+ fbb_.AddElement<uint8_t>(Norm::VT_KEEP_DIM, static_cast<uint8_t>(keep_dim), 0);
+ }
+ explicit NormBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Norm> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Norm>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Norm> CreateNorm(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> axes = 0,
+ int64_t correction = 0,
+ bool keep_dim = false) {
+ NormBuilder builder_(_fbb);
+ builder_.add_correction(correction);
+ builder_.add_axes(axes);
+ builder_.add_keep_dim(keep_dim);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<Norm> CreateNormDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int64_t> *axes = nullptr,
+ int64_t correction = 0,
+ bool keep_dim = false) {
+ auto axes__ = axes ? _fbb.CreateVector<int64_t>(*axes) : 0;
+ return nvfuser::serde::CreateNorm(
+ _fbb,
+ axes__,
+ correction,
+ keep_dim);
+ }
+
+ struct Output FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef OutputBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_STRIDE_ORDER = 4
+ };
+ const ::flatbuffers::Vector<int64_t> *stride_order() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_STRIDE_ORDER);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_STRIDE_ORDER) &&
+ verifier.VerifyVector(stride_order()) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct OutputBuilder {
+ typedef Output Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_stride_order(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> stride_order) {
+ fbb_.AddOffset(Output::VT_STRIDE_ORDER, stride_order);
+ }
+ explicit OutputBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Output> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Output>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Output> CreateOutput(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> stride_order = 0) {
+ OutputBuilder builder_(_fbb);
+ builder_.add_stride_order(stride_order);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<Output> CreateOutputDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int64_t> *stride_order = nullptr) {
+ auto stride_order__ = stride_order ? _fbb.CreateVector<int64_t>(*stride_order) : 0;
+ return nvfuser::serde::CreateOutput(
+ _fbb,
+ stride_order__);
+ }
+
+ struct Dims FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef DimsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_DIMS = 4
+ };
+ const ::flatbuffers::Vector<int64_t> *dims() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_DIMS);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_DIMS) &&
+ verifier.VerifyVector(dims()) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct DimsBuilder {
+ typedef Dims Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_dims(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> dims) {
+ fbb_.AddOffset(Dims::VT_DIMS, dims);
+ }
+ explicit DimsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Dims> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Dims>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Dims> CreateDims(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> dims = 0) {
+ DimsBuilder builder_(_fbb);
+ builder_.add_dims(dims);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<Dims> CreateDimsDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int64_t> *dims = nullptr) {
+ auto dims__ = dims ? _fbb.CreateVector<int64_t>(*dims) : 0;
+ return nvfuser::serde::CreateDims(
+ _fbb,
+ dims__);
+ }
+
+ struct Reduction FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef ReductionBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_AXES = 4,
+ VT_KEEP_DIM = 6,
+ VT_DTYPE = 8
+ };
+ const ::flatbuffers::Vector<int64_t> *axes() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_AXES);
+ }
+ bool keep_dim() const {
+ return GetField<uint8_t>(VT_KEEP_DIM, 0) != 0;
+ }
+ int64_t dtype() const {
+ return GetField<int64_t>(VT_DTYPE, 0);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_AXES) &&
+ verifier.VerifyVector(axes()) &&
+ VerifyField<uint8_t>(verifier, VT_KEEP_DIM, 1) &&
+ VerifyField<int64_t>(verifier, VT_DTYPE, 8) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct ReductionBuilder {
+ typedef Reduction Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_axes(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> axes) {
+ fbb_.AddOffset(Reduction::VT_AXES, axes);
+ }
+ void add_keep_dim(bool keep_dim) {
+ fbb_.AddElement<uint8_t>(Reduction::VT_KEEP_DIM, static_cast<uint8_t>(keep_dim), 0);
+ }
+ void add_dtype(int64_t dtype) {
+ fbb_.AddElement<int64_t>(Reduction::VT_DTYPE, dtype, 0);
+ }
+ explicit ReductionBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Reduction> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Reduction>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Reduction> CreateReduction(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> axes = 0,
+ bool keep_dim = false,
+ int64_t dtype = 0) {
+ ReductionBuilder builder_(_fbb);
+ builder_.add_dtype(dtype);
+ builder_.add_axes(axes);
+ builder_.add_keep_dim(keep_dim);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<Reduction> CreateReductionDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int64_t> *axes = nullptr,
+ bool keep_dim = false,
+ int64_t dtype = 0) {
+ auto axes__ = axes ? _fbb.CreateVector<int64_t>(*axes) : 0;
+ return nvfuser::serde::CreateReduction(
+ _fbb,
+ axes__,
+ keep_dim,
+ dtype);
+ }
+
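+ // NOTE: editorial sketch, not flatc output. A Reduction record over axes
+ // {0, 2} with reduced dims dropped; dtype 0 is whatever the schema's
+ // default dtype encoding denotes:
+ //
+ //   ::flatbuffers::FlatBufferBuilder fbb;
+ //   std::vector<int64_t> axes{0, 2};
+ //   auto red = nvfuser::serde::CreateReductionDirect(
+ //       fbb, &axes, /*keep_dim=*/false, /*dtype=*/0);
+ //   fbb.Finish(red);
+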
+ struct Size FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef SizeBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_DIM = 4
+ };
+ int64_t dim() const {
+ return GetField<int64_t>(VT_DIM, 0);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_DIM, 8) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct SizeBuilder {
+ typedef Size Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_dim(int64_t dim) {
+ fbb_.AddElement<int64_t>(Size::VT_DIM, dim, 0);
+ }
+ explicit SizeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Size> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Size>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Size> CreateSize(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t dim = 0) {
+ SizeBuilder builder_(_fbb);
+ builder_.add_dim(dim);
+ return builder_.Finish();
+ }
+
+ struct Slice FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef SliceBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MANUAL_NORMALIZATION = 4
+ };
+ bool manual_normalization() const {
+ return GetField<uint8_t>(VT_MANUAL_NORMALIZATION, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_MANUAL_NORMALIZATION, 1) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct SliceBuilder {
+ typedef Slice Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_manual_normalization(bool manual_normalization) {
+ fbb_.AddElement<uint8_t>(Slice::VT_MANUAL_NORMALIZATION, static_cast<uint8_t>(manual_normalization), 0);
+ }
+ explicit SliceBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Slice> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Slice>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Slice> CreateSlice(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ bool manual_normalization = false) {
+ SliceBuilder builder_(_fbb);
+ builder_.add_manual_normalization(manual_normalization);
+ return builder_.Finish();
+ }
+
+ struct Squeeze FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef SqueezeBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_SQUEEZE_DIMS = 4,
+ VT_SQUEEZE_EXPANDED = 6
+ };
+ const ::flatbuffers::Vector<int64_t> *squeeze_dims() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_SQUEEZE_DIMS);
+ }
+ bool squeeze_expanded() const {
+ return GetField<uint8_t>(VT_SQUEEZE_EXPANDED, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_SQUEEZE_DIMS) &&
+ verifier.VerifyVector(squeeze_dims()) &&
+ VerifyField<uint8_t>(verifier, VT_SQUEEZE_EXPANDED, 1) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct SqueezeBuilder {
+ typedef Squeeze Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_squeeze_dims(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> squeeze_dims) {
+ fbb_.AddOffset(Squeeze::VT_SQUEEZE_DIMS, squeeze_dims);
+ }
+ void add_squeeze_expanded(bool squeeze_expanded) {
+ fbb_.AddElement<uint8_t>(Squeeze::VT_SQUEEZE_EXPANDED, static_cast<uint8_t>(squeeze_expanded), 0);
+ }
+ explicit SqueezeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Squeeze> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Squeeze>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Squeeze> CreateSqueeze(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> squeeze_dims = 0,
+ bool squeeze_expanded = false) {
+ SqueezeBuilder builder_(_fbb);
+ builder_.add_squeeze_dims(squeeze_dims);
+ builder_.add_squeeze_expanded(squeeze_expanded);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<Squeeze> CreateSqueezeDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int64_t> *squeeze_dims = nullptr,
+ bool squeeze_expanded = false) {
+ auto squeeze_dims__ = squeeze_dims ? _fbb.CreateVector<int64_t>(*squeeze_dims) : 0;
+ return nvfuser::serde::CreateSqueeze(
+ _fbb,
+ squeeze_dims__,
+ squeeze_expanded);
+ }
+
+ struct Tensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef TensorBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_SIZES = 4,
+ VT_CONTIGUITY = 6,
+ VT_STRIDE_ORDER = 8,
+ VT_DTYPE = 10,
+ VT_IS_CPU = 12
+ };
+ const ::flatbuffers::Vector<int64_t> *sizes() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_SIZES);
+ }
+ const ::flatbuffers::Vector<nvfuser::serde::Contiguity> *contiguity() const {
+ return GetPointer<const ::flatbuffers::Vector<nvfuser::serde::Contiguity> *>(VT_CONTIGUITY);
+ }
+ const ::flatbuffers::Vector<int64_t> *stride_order() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_STRIDE_ORDER);
+ }
+ int64_t dtype() const {
+ return GetField<int64_t>(VT_DTYPE, 0);
+ }
+ bool is_cpu() const {
+ return GetField<uint8_t>(VT_IS_CPU, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_SIZES) &&
+ verifier.VerifyVector(sizes()) &&
+ VerifyOffset(verifier, VT_CONTIGUITY) &&
+ verifier.VerifyVector(contiguity()) &&
+ VerifyOffset(verifier, VT_STRIDE_ORDER) &&
+ verifier.VerifyVector(stride_order()) &&
+ VerifyField<int64_t>(verifier, VT_DTYPE, 8) &&
+ VerifyField<uint8_t>(verifier, VT_IS_CPU, 1) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct TensorBuilder {
+ typedef Tensor Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_sizes(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> sizes) {
+ fbb_.AddOffset(Tensor::VT_SIZES, sizes);
+ }
+ void add_contiguity(::flatbuffers::Offset<::flatbuffers::Vector<nvfuser::serde::Contiguity>> contiguity) {
+ fbb_.AddOffset(Tensor::VT_CONTIGUITY, contiguity);
+ }
+ void add_stride_order(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> stride_order) {
+ fbb_.AddOffset(Tensor::VT_STRIDE_ORDER, stride_order);
+ }
+ void add_dtype(int64_t dtype) {
+ fbb_.AddElement<int64_t>(Tensor::VT_DTYPE, dtype, 0);
+ }
+ void add_is_cpu(bool is_cpu) {
+ fbb_.AddElement<uint8_t>(Tensor::VT_IS_CPU, static_cast<uint8_t>(is_cpu), 0);
+ }
+ explicit TensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Tensor> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Tensor>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Tensor> CreateTensor(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> sizes = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<nvfuser::serde::Contiguity>> contiguity = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> stride_order = 0,
+ int64_t dtype = 0,
+ bool is_cpu = false) {
+ TensorBuilder builder_(_fbb);
+ builder_.add_dtype(dtype);
+ builder_.add_stride_order(stride_order);
+ builder_.add_contiguity(contiguity);
+ builder_.add_sizes(sizes);
+ builder_.add_is_cpu(is_cpu);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<Tensor> CreateTensorDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int64_t> *sizes = nullptr,
+ const std::vector<nvfuser::serde::Contiguity> *contiguity = nullptr,
+ const std::vector<int64_t> *stride_order = nullptr,
+ int64_t dtype = 0,
+ bool is_cpu = false) {
+ auto sizes__ = sizes ? _fbb.CreateVector<int64_t>(*sizes) : 0;
+ auto contiguity__ = contiguity ? _fbb.CreateVector<nvfuser::serde::Contiguity>(*contiguity) : 0;
+ auto stride_order__ = stride_order ? _fbb.CreateVector<int64_t>(*stride_order) : 0;
+ return nvfuser::serde::CreateTensor(
+ _fbb,
+ sizes__,
+ contiguity__,
+ stride_order__,
+ dtype,
+ is_cpu);
+ }
+
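+ // NOTE: editorial sketch, not flatc output. Tensor pairs each size with a
+ // Contiguity enum value (defined earlier in this header, not shown in this
+ // hunk), so the enum value below is a placeholder cast:
+ //
+ //   ::flatbuffers::FlatBufferBuilder fbb;
+ //   std::vector<int64_t> sizes{2, 3};
+ //   std::vector<nvfuser::serde::Contiguity> contig(
+ //       2, static_cast<nvfuser::serde::Contiguity>(0));
+ //   auto t = nvfuser::serde::CreateTensorDirect(
+ //       fbb, &sizes, &contig, /*stride_order=*/nullptr,
+ //       /*dtype=*/0, /*is_cpu=*/false);
+ //   fbb.Finish(t);
+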
+ struct TensorCreationSymbolic FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef TensorCreationSymbolicBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_DTYPE = 4
+ };
+ int64_t dtype() const {
+ return GetField<int64_t>(VT_DTYPE, 0);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_DTYPE, 8) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct TensorCreationSymbolicBuilder {
+ typedef TensorCreationSymbolic Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_dtype(int64_t dtype) {
+ fbb_.AddElement<int64_t>(TensorCreationSymbolic::VT_DTYPE, dtype, 0);
+ }
+ explicit TensorCreationSymbolicBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<TensorCreationSymbolic> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<TensorCreationSymbolic>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<TensorCreationSymbolic> CreateTensorCreationSymbolic(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t dtype = 0) {
+ TensorCreationSymbolicBuilder builder_(_fbb);
+ builder_.add_dtype(dtype);
+ return builder_.Finish();
+ }
+
+ struct Vector FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef VectorBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_DTYPE = 4
+ };
+ int64_t dtype() const {
+ return GetField<int64_t>(VT_DTYPE, 0);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_DTYPE, 8) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct VectorBuilder {
+ typedef Vector Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_dtype(int64_t dtype) {
+ fbb_.AddElement<int64_t>(Vector::VT_DTYPE, dtype, 0);
+ }
+ explicit VectorBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Vector> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Vector>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Vector> CreateVector(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t dtype = 0) {
+ VectorBuilder builder_(_fbb);
+ builder_.add_dtype(dtype);
+ return builder_.Finish();
+ }
+
+ struct Welford FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef WelfordBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_AXES = 4
+ };
+ const ::flatbuffers::Vector<int64_t> *axes() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_AXES);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_AXES) &&
+ verifier.VerifyVector(axes()) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct WelfordBuilder {
+ typedef Welford Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_axes(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> axes) {
+ fbb_.AddOffset(Welford::VT_AXES, axes);
+ }
+ explicit WelfordBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Welford> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Welford>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<Welford> CreateWelford(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> axes = 0) {
+ WelfordBuilder builder_(_fbb);
+ builder_.add_axes(axes);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<Welford> CreateWelfordDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int64_t> *axes = nullptr) {
+ auto axes__ = axes ? _fbb.CreateVector<int64_t>(*axes) : 0;
+ return nvfuser::serde::CreateWelford(
+ _fbb,
+ axes__);
+ }
+
+ struct CudaKernel FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef CudaKernelBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_KERNEL_NAME = 4,
+ VT_COMPILE_ARGS = 6,
+ VT_CUBIN = 8,
+ VT_CUBIN_FILENAME = 10,
+ VT_PTX = 12,
+ VT_PTX_FILENAME = 14,
+ VT_BLOCK_SIZE = 16
+ };
+ const ::flatbuffers::String *kernel_name() const {
+ return GetPointer<const ::flatbuffers::String *>(VT_KERNEL_NAME);
+ }
+ const ::flatbuffers::String *compile_args() const {
+ return GetPointer<const ::flatbuffers::String *>(VT_COMPILE_ARGS);
+ }
+ const ::flatbuffers::Vector<uint8_t> *cubin() const {
+ return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_CUBIN);
+ }
+ const ::flatbuffers::String *cubin_filename() const {
+ return GetPointer<const ::flatbuffers::String *>(VT_CUBIN_FILENAME);
+ }
+ const ::flatbuffers::Vector<uint8_t> *ptx() const {
+ return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_PTX);
+ }
+ const ::flatbuffers::String *ptx_filename() const {
+ return GetPointer<const ::flatbuffers::String *>(VT_PTX_FILENAME);
+ }
+ int64_t block_size() const {
+ return GetField<int64_t>(VT_BLOCK_SIZE, -1LL);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_KERNEL_NAME) &&
+ verifier.VerifyString(kernel_name()) &&
+ VerifyOffset(verifier, VT_COMPILE_ARGS) &&
+ verifier.VerifyString(compile_args()) &&
+ VerifyOffset(verifier, VT_CUBIN) &&
+ verifier.VerifyVector(cubin()) &&
+ VerifyOffset(verifier, VT_CUBIN_FILENAME) &&
+ verifier.VerifyString(cubin_filename()) &&
+ VerifyOffset(verifier, VT_PTX) &&
+ verifier.VerifyVector(ptx()) &&
+ VerifyOffset(verifier, VT_PTX_FILENAME) &&
+ verifier.VerifyString(ptx_filename()) &&
+ VerifyField<int64_t>(verifier, VT_BLOCK_SIZE, 8) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct CudaKernelBuilder {
+ typedef CudaKernel Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_kernel_name(::flatbuffers::Offset<::flatbuffers::String> kernel_name) {
+ fbb_.AddOffset(CudaKernel::VT_KERNEL_NAME, kernel_name);
+ }
+ void add_compile_args(::flatbuffers::Offset<::flatbuffers::String> compile_args) {
+ fbb_.AddOffset(CudaKernel::VT_COMPILE_ARGS, compile_args);
+ }
+ void add_cubin(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> cubin) {
+ fbb_.AddOffset(CudaKernel::VT_CUBIN, cubin);
+ }
+ void add_cubin_filename(::flatbuffers::Offset<::flatbuffers::String> cubin_filename) {
+ fbb_.AddOffset(CudaKernel::VT_CUBIN_FILENAME, cubin_filename);
+ }
+ void add_ptx(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> ptx) {
+ fbb_.AddOffset(CudaKernel::VT_PTX, ptx);
+ }
+ void add_ptx_filename(::flatbuffers::Offset<::flatbuffers::String> ptx_filename) {
+ fbb_.AddOffset(CudaKernel::VT_PTX_FILENAME, ptx_filename);
+ }
+ void add_block_size(int64_t block_size) {
+ fbb_.AddElement<int64_t>(CudaKernel::VT_BLOCK_SIZE, block_size, -1LL);
+ }
+ explicit CudaKernelBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<CudaKernel> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<CudaKernel>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<CudaKernel> CreateCudaKernel(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::String> kernel_name = 0,
+ ::flatbuffers::Offset<::flatbuffers::String> compile_args = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> cubin = 0,
+ ::flatbuffers::Offset<::flatbuffers::String> cubin_filename = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> ptx = 0,
+ ::flatbuffers::Offset<::flatbuffers::String> ptx_filename = 0,
+ int64_t block_size = -1LL) {
+ CudaKernelBuilder builder_(_fbb);
+ builder_.add_block_size(block_size);
+ builder_.add_ptx_filename(ptx_filename);
+ builder_.add_ptx(ptx);
+ builder_.add_cubin_filename(cubin_filename);
+ builder_.add_cubin(cubin);
+ builder_.add_compile_args(compile_args);
+ builder_.add_kernel_name(kernel_name);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<CudaKernel> CreateCudaKernelDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const char *kernel_name = nullptr,
+ const char *compile_args = nullptr,
+ const std::vector<uint8_t> *cubin = nullptr,
+ const char *cubin_filename = nullptr,
+ const std::vector<uint8_t> *ptx = nullptr,
+ const char *ptx_filename = nullptr,
+ int64_t block_size = -1LL) {
+ auto kernel_name__ = kernel_name ? _fbb.CreateString(kernel_name) : 0;
+ auto compile_args__ = compile_args ? _fbb.CreateString(compile_args) : 0;
+ auto cubin__ = cubin ? _fbb.CreateVector<uint8_t>(*cubin) : 0;
+ auto cubin_filename__ = cubin_filename ? _fbb.CreateString(cubin_filename) : 0;
+ auto ptx__ = ptx ? _fbb.CreateVector<uint8_t>(*ptx) : 0;
+ auto ptx_filename__ = ptx_filename ? _fbb.CreateString(ptx_filename) : 0;
+ return nvfuser::serde::CreateCudaKernel(
+ _fbb,
+ kernel_name__,
+ compile_args__,
+ cubin__,
+ cubin_filename__,
+ ptx__,
+ ptx_filename__,
+ block_size);
+ }
+
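+ // NOTE: editorial sketch, not flatc output. CudaKernel can carry the cubin
+ // and/or ptx images inline; the bytes and names below are placeholders:
+ //
+ //   ::flatbuffers::FlatBufferBuilder fbb;
+ //   std::vector<uint8_t> cubin_bytes{0x7f, 0x45, 0x4c, 0x46};
+ //   auto ck = nvfuser::serde::CreateCudaKernelDirect(
+ //       fbb, /*kernel_name=*/"nvfuser_kernel0",
+ //       /*compile_args=*/"--gpu-architecture=sm_90",
+ //       &cubin_bytes, /*cubin_filename=*/"kernel.cubin",
+ //       /*ptx=*/nullptr, /*ptx_filename=*/nullptr, /*block_size=*/128);
+ //   fbb.Finish(ck);
+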
+ struct KernelExecutor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef KernelExecutorBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_DEVICE_SMEM_LIMIT = 4,
+ VT_BLOCK_SIZE_HIGH_WATER_MARK = 6,
+ VT_MAXRREGCOUNT_HIGH_WATER_MARK = 8,
+ VT_WARP_SIZE = 10,
+ VT_HEURISTIC = 12,
+ VT_FUSION_ID = 14,
+ VT_CONCRETE_ID = 16,
+ VT_RUNTIME_ID = 18,
+ VT_GROUP_ID = 20,
+ VT_KERNEL_CODE = 22,
+ VT_EXECUTOR_ENTRY_LOOKUP_KEYS = 24,
+ VT_EXECUTOR_ENTRY_LOOKUP_VALUES = 26,
+ VT_INDEX_TYPE = 28,
+ VT_COMPILED_KERNEL = 30
+ };
+ int64_t device_smem_limit() const {
+ return GetField<int64_t>(VT_DEVICE_SMEM_LIMIT, 0);
+ }
+ int64_t block_size_high_water_mark() const {
+ return GetField<int64_t>(VT_BLOCK_SIZE_HIGH_WATER_MARK, 0);
+ }
+ int64_t maxrregcount_high_water_mark() const {
+ return GetField<int64_t>(VT_MAXRREGCOUNT_HIGH_WATER_MARK, 0);
+ }
+ int64_t warp_size() const {
+ return GetField<int64_t>(VT_WARP_SIZE, 0);
+ }
+ int64_t heuristic() const {
+ return GetField<int64_t>(VT_HEURISTIC, 0);
+ }
+ int64_t fusion_id() const {
+ return GetField<int64_t>(VT_FUSION_ID, 0);
+ }
+ int64_t concrete_id() const {
+ return GetField<int64_t>(VT_CONCRETE_ID, 0);
+ }
+ int64_t runtime_id() const {
+ return GetField<int64_t>(VT_RUNTIME_ID, 0);
+ }
+ int64_t group_id() const {
+ return GetField<int64_t>(VT_GROUP_ID, 0);
+ }
+ const ::flatbuffers::String *kernel_code() const {
+ return GetPointer<const ::flatbuffers::String *>(VT_KERNEL_CODE);
+ }
+ const ::flatbuffers::Vector<uint64_t> *executor_entry_lookup_keys() const {
+ return GetPointer<const ::flatbuffers::Vector<uint64_t> *>(VT_EXECUTOR_ENTRY_LOOKUP_KEYS);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::ExecutorEntry>> *executor_entry_lookup_values() const {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::ExecutorEntry>> *>(VT_EXECUTOR_ENTRY_LOOKUP_VALUES);
+ }
+ int64_t index_type() const {
+ return GetField<int64_t>(VT_INDEX_TYPE, 0);
+ }
+ const nvfuser::serde::CudaKernel *compiled_kernel() const {
+ return GetPointer<const nvfuser::serde::CudaKernel *>(VT_COMPILED_KERNEL);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_DEVICE_SMEM_LIMIT, 8) &&
+ VerifyField<int64_t>(verifier, VT_BLOCK_SIZE_HIGH_WATER_MARK, 8) &&
+ VerifyField<int64_t>(verifier, VT_MAXRREGCOUNT_HIGH_WATER_MARK, 8) &&
+ VerifyField<int64_t>(verifier, VT_WARP_SIZE, 8) &&
+ VerifyField<int64_t>(verifier, VT_HEURISTIC, 8) &&
+ VerifyField<int64_t>(verifier, VT_FUSION_ID, 8) &&
+ VerifyField<int64_t>(verifier, VT_CONCRETE_ID, 8) &&
+ VerifyField<int64_t>(verifier, VT_RUNTIME_ID, 8) &&
+ VerifyField<int64_t>(verifier, VT_GROUP_ID, 8) &&
+ VerifyOffset(verifier, VT_KERNEL_CODE) &&
+ verifier.VerifyString(kernel_code()) &&
+ VerifyOffset(verifier, VT_EXECUTOR_ENTRY_LOOKUP_KEYS) &&
+ verifier.VerifyVector(executor_entry_lookup_keys()) &&
+ VerifyOffset(verifier, VT_EXECUTOR_ENTRY_LOOKUP_VALUES) &&
+ verifier.VerifyVector(executor_entry_lookup_values()) &&
+ verifier.VerifyVectorOfTables(executor_entry_lookup_values()) &&
+ VerifyField<int64_t>(verifier, VT_INDEX_TYPE, 8) &&
+ VerifyOffset(verifier, VT_COMPILED_KERNEL) &&
+ verifier.VerifyTable(compiled_kernel()) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct KernelExecutorBuilder {
+ typedef KernelExecutor Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_device_smem_limit(int64_t device_smem_limit) {
+ fbb_.AddElement<int64_t>(KernelExecutor::VT_DEVICE_SMEM_LIMIT, device_smem_limit, 0);
+ }
+ void add_block_size_high_water_mark(int64_t block_size_high_water_mark) {
+ fbb_.AddElement<int64_t>(KernelExecutor::VT_BLOCK_SIZE_HIGH_WATER_MARK, block_size_high_water_mark, 0);
+ }
+ void add_maxrregcount_high_water_mark(int64_t maxrregcount_high_water_mark) {
+ fbb_.AddElement<int64_t>(KernelExecutor::VT_MAXRREGCOUNT_HIGH_WATER_MARK, maxrregcount_high_water_mark, 0);
+ }
+ void add_warp_size(int64_t warp_size) {
+ fbb_.AddElement<int64_t>(KernelExecutor::VT_WARP_SIZE, warp_size, 0);
+ }
+ void add_heuristic(int64_t heuristic) {
+ fbb_.AddElement<int64_t>(KernelExecutor::VT_HEURISTIC, heuristic, 0);
+ }
+ void add_fusion_id(int64_t fusion_id) {
+ fbb_.AddElement<int64_t>(KernelExecutor::VT_FUSION_ID, fusion_id, 0);
+ }
+ void add_concrete_id(int64_t concrete_id) {
+ fbb_.AddElement<int64_t>(KernelExecutor::VT_CONCRETE_ID, concrete_id, 0);
+ }
+ void add_runtime_id(int64_t runtime_id) {
+ fbb_.AddElement<int64_t>(KernelExecutor::VT_RUNTIME_ID, runtime_id, 0);
+ }
+ void add_group_id(int64_t group_id) {
+ fbb_.AddElement<int64_t>(KernelExecutor::VT_GROUP_ID, group_id, 0);
+ }
+ void add_kernel_code(::flatbuffers::Offset<::flatbuffers::String> kernel_code) {
+ fbb_.AddOffset(KernelExecutor::VT_KERNEL_CODE, kernel_code);
+ }
+ void add_executor_entry_lookup_keys(::flatbuffers::Offset<::flatbuffers::Vector<uint64_t>> executor_entry_lookup_keys) {
+ fbb_.AddOffset(KernelExecutor::VT_EXECUTOR_ENTRY_LOOKUP_KEYS, executor_entry_lookup_keys);
+ }
+ void add_executor_entry_lookup_values(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::ExecutorEntry>>> executor_entry_lookup_values) {
+ fbb_.AddOffset(KernelExecutor::VT_EXECUTOR_ENTRY_LOOKUP_VALUES, executor_entry_lookup_values);
+ }
+ void add_index_type(int64_t index_type) {
+ fbb_.AddElement<int64_t>(KernelExecutor::VT_INDEX_TYPE, index_type, 0);
+ }
+ void add_compiled_kernel(::flatbuffers::Offset<nvfuser::serde::CudaKernel> compiled_kernel) {
+ fbb_.AddOffset(KernelExecutor::VT_COMPILED_KERNEL, compiled_kernel);
+ }
+ explicit KernelExecutorBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<KernelExecutor> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<KernelExecutor>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<KernelExecutor> CreateKernelExecutor(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t device_smem_limit = 0,
+ int64_t block_size_high_water_mark = 0,
+ int64_t maxrregcount_high_water_mark = 0,
+ int64_t warp_size = 0,
+ int64_t heuristic = 0,
+ int64_t fusion_id = 0,
+ int64_t concrete_id = 0,
+ int64_t runtime_id = 0,
+ int64_t group_id = 0,
+ ::flatbuffers::Offset<::flatbuffers::String> kernel_code = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint64_t>> executor_entry_lookup_keys = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::ExecutorEntry>>> executor_entry_lookup_values = 0,
+ int64_t index_type = 0,
+ ::flatbuffers::Offset<nvfuser::serde::CudaKernel> compiled_kernel = 0) {
+ KernelExecutorBuilder builder_(_fbb);
+ builder_.add_index_type(index_type);
+ builder_.add_group_id(group_id);
+ builder_.add_runtime_id(runtime_id);
+ builder_.add_concrete_id(concrete_id);
+ builder_.add_fusion_id(fusion_id);
+ builder_.add_heuristic(heuristic);
+ builder_.add_warp_size(warp_size);
+ builder_.add_maxrregcount_high_water_mark(maxrregcount_high_water_mark);
+ builder_.add_block_size_high_water_mark(block_size_high_water_mark);
+ builder_.add_device_smem_limit(device_smem_limit);
+ builder_.add_compiled_kernel(compiled_kernel);
+ builder_.add_executor_entry_lookup_values(executor_entry_lookup_values);
+ builder_.add_executor_entry_lookup_keys(executor_entry_lookup_keys);
+ builder_.add_kernel_code(kernel_code);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<KernelExecutor> CreateKernelExecutorDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t device_smem_limit = 0,
+ int64_t block_size_high_water_mark = 0,
+ int64_t maxrregcount_high_water_mark = 0,
+ int64_t warp_size = 0,
+ int64_t heuristic = 0,
+ int64_t fusion_id = 0,
+ int64_t concrete_id = 0,
+ int64_t runtime_id = 0,
+ int64_t group_id = 0,
+ const char *kernel_code = nullptr,
+ const std::vector<uint64_t> *executor_entry_lookup_keys = nullptr,
+ const std::vector<::flatbuffers::Offset<nvfuser::serde::ExecutorEntry>> *executor_entry_lookup_values = nullptr,
+ int64_t index_type = 0,
+ ::flatbuffers::Offset<nvfuser::serde::CudaKernel> compiled_kernel = 0) {
+ auto kernel_code__ = kernel_code ? _fbb.CreateString(kernel_code) : 0;
+ auto executor_entry_lookup_keys__ = executor_entry_lookup_keys ? _fbb.CreateVector<uint64_t>(*executor_entry_lookup_keys) : 0;
+ auto executor_entry_lookup_values__ = executor_entry_lookup_values ? _fbb.CreateVector<::flatbuffers::Offset<nvfuser::serde::ExecutorEntry>>(*executor_entry_lookup_values) : 0;
2804
+ return nvfuser::serde::CreateKernelExecutor(
2805
+ _fbb,
2806
+ device_smem_limit,
2807
+ block_size_high_water_mark,
2808
+ maxrregcount_high_water_mark,
2809
+ warp_size,
2810
+ heuristic,
2811
+ fusion_id,
2812
+ concrete_id,
2813
+ runtime_id,
2814
+ group_id,
2815
+ kernel_code__,
2816
+ executor_entry_lookup_keys__,
2817
+ executor_entry_lookup_values__,
2818
+ index_type,
2819
+ compiled_kernel);
2820
+ }
2821
+
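Usage sketch (editorial, not part of the generated file): a minimal round trip through the KernelExecutor API above. The include name and every field value are assumptions for illustration only.

// Sketch: build, finish, and verify one KernelExecutor table.
// Assumes this generated header is available as "fusion_cache_generated.h".
#include <cstdint>
#include <vector>
#include "fusion_cache_generated.h"

inline bool RoundTripKernelExecutor() {
  ::flatbuffers::FlatBufferBuilder fbb;
  std::vector<uint64_t> keys = {0, 1};  // illustrative lookup keys
  auto ke = nvfuser::serde::CreateKernelExecutorDirect(
      fbb,
      /*device_smem_limit=*/49152,
      /*block_size_high_water_mark=*/256,
      /*maxrregcount_high_water_mark=*/255,
      /*warp_size=*/32,
      /*heuristic=*/0,
      /*fusion_id=*/1,
      /*concrete_id=*/0,
      /*runtime_id=*/0,
      /*group_id=*/0,
      /*kernel_code=*/"// CUDA source elided",
      /*executor_entry_lookup_keys=*/&keys);
  fbb.Finish(ke);
  // Verify() walks exactly the vtable checks shown above before any
  // field of an untrusted buffer is read.
  ::flatbuffers::Verifier v(fbb.GetBufferPointer(), fbb.GetSize());
  return v.VerifyBuffer<nvfuser::serde::KernelExecutor>(nullptr);
}
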
2822
+ struct SegmentedEdge FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef SegmentedEdgeBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_FROM_SEGMENTED_GROUP = 4,
+ VT_TO_SEGMENTED_GROUP = 6,
+ VT_VAL = 8
+ };
+ int64_t from_segmented_group() const {
+ return GetField<int64_t>(VT_FROM_SEGMENTED_GROUP, 0);
+ }
+ int64_t to_segmented_group() const {
+ return GetField<int64_t>(VT_TO_SEGMENTED_GROUP, 0);
+ }
+ int64_t val() const {
+ return GetField<int64_t>(VT_VAL, 0);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_FROM_SEGMENTED_GROUP, 8) &&
+ VerifyField<int64_t>(verifier, VT_TO_SEGMENTED_GROUP, 8) &&
+ VerifyField<int64_t>(verifier, VT_VAL, 8) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct SegmentedEdgeBuilder {
+ typedef SegmentedEdge Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_from_segmented_group(int64_t from_segmented_group) {
+ fbb_.AddElement<int64_t>(SegmentedEdge::VT_FROM_SEGMENTED_GROUP, from_segmented_group, 0);
+ }
+ void add_to_segmented_group(int64_t to_segmented_group) {
+ fbb_.AddElement<int64_t>(SegmentedEdge::VT_TO_SEGMENTED_GROUP, to_segmented_group, 0);
+ }
+ void add_val(int64_t val) {
+ fbb_.AddElement<int64_t>(SegmentedEdge::VT_VAL, val, 0);
+ }
+ explicit SegmentedEdgeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SegmentedEdge> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SegmentedEdge>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<SegmentedEdge> CreateSegmentedEdge(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t from_segmented_group = 0,
+ int64_t to_segmented_group = 0,
+ int64_t val = 0) {
+ SegmentedEdgeBuilder builder_(_fbb);
+ builder_.add_val(val);
+ builder_.add_to_segmented_group(to_segmented_group);
+ builder_.add_from_segmented_group(from_segmented_group);
+ return builder_.Finish();
+ }
+
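Usage sketch (editorial; assumes the same generated-header include as the first sketch): serializing one SegmentedEdge and reading it back. The group indices and val are placeholders; in real use they index into a SegmentedFusion's groups and into the fusion's value table.

inline int64_t EdgeValRoundTrip() {
  ::flatbuffers::FlatBufferBuilder fbb;
  auto edge = nvfuser::serde::CreateSegmentedEdge(
      fbb, /*from_segmented_group=*/0, /*to_segmented_group=*/1, /*val=*/42);
  fbb.Finish(edge);
  // Read the table back from the builder's own buffer.
  auto* e = ::flatbuffers::GetRoot<nvfuser::serde::SegmentedEdge>(
      fbb.GetBufferPointer());
  return e->val();  // 42
}
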
2883
+ struct SegmentedGroup FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef SegmentedGroupBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_PRODUCER_EDGES = 4,
+ VT_CONSUMER_EDGES = 6,
+ VT_INPUT_VALS = 8,
+ VT_OUTPUT_VALS = 10,
+ VT_GROUP_ID = 12,
+ VT_HEURISTIC = 14,
+ VT_EXPRS = 16,
+ VT_LEVEL = 18,
+ VT_VISITED = 20,
+ VT_MERGE_WITH_SEGMENTED_GROUP = 22,
+ VT_MERGE_THROUGH_SEGMENTED_EDGE = 24,
+ VT_MERGED = 26,
+ VT_IS_FUSION_INPUT = 28
+ };
+ const ::flatbuffers::Vector<int64_t> *producer_edges() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_PRODUCER_EDGES);
+ }
+ const ::flatbuffers::Vector<int64_t> *consumer_edges() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_CONSUMER_EDGES);
+ }
+ const ::flatbuffers::Vector<int64_t> *input_vals() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_INPUT_VALS);
+ }
+ const ::flatbuffers::Vector<int64_t> *output_vals() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_OUTPUT_VALS);
+ }
+ int32_t group_id() const {
+ return GetField<int32_t>(VT_GROUP_ID, 0);
+ }
+ int64_t heuristic() const {
+ return GetField<int64_t>(VT_HEURISTIC, 0);
+ }
+ const ::flatbuffers::Vector<int64_t> *exprs() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_EXPRS);
+ }
+ int32_t level() const {
+ return GetField<int32_t>(VT_LEVEL, 0);
+ }
+ bool visited() const {
+ return GetField<uint8_t>(VT_VISITED, 0) != 0;
+ }
+ int64_t merge_with_segmented_group() const {
+ return GetField<int64_t>(VT_MERGE_WITH_SEGMENTED_GROUP, 0);
+ }
+ int64_t merge_through_segmented_edge() const {
+ return GetField<int64_t>(VT_MERGE_THROUGH_SEGMENTED_EDGE, 0);
+ }
+ bool merged() const {
+ return GetField<uint8_t>(VT_MERGED, 0) != 0;
+ }
+ bool is_fusion_input() const {
+ return GetField<uint8_t>(VT_IS_FUSION_INPUT, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_PRODUCER_EDGES) &&
+ verifier.VerifyVector(producer_edges()) &&
+ VerifyOffset(verifier, VT_CONSUMER_EDGES) &&
+ verifier.VerifyVector(consumer_edges()) &&
+ VerifyOffset(verifier, VT_INPUT_VALS) &&
+ verifier.VerifyVector(input_vals()) &&
+ VerifyOffset(verifier, VT_OUTPUT_VALS) &&
+ verifier.VerifyVector(output_vals()) &&
+ VerifyField<int32_t>(verifier, VT_GROUP_ID, 4) &&
+ VerifyField<int64_t>(verifier, VT_HEURISTIC, 8) &&
+ VerifyOffset(verifier, VT_EXPRS) &&
+ verifier.VerifyVector(exprs()) &&
+ VerifyField<int32_t>(verifier, VT_LEVEL, 4) &&
+ VerifyField<uint8_t>(verifier, VT_VISITED, 1) &&
+ VerifyField<int64_t>(verifier, VT_MERGE_WITH_SEGMENTED_GROUP, 8) &&
+ VerifyField<int64_t>(verifier, VT_MERGE_THROUGH_SEGMENTED_EDGE, 8) &&
+ VerifyField<uint8_t>(verifier, VT_MERGED, 1) &&
+ VerifyField<uint8_t>(verifier, VT_IS_FUSION_INPUT, 1) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct SegmentedGroupBuilder {
+ typedef SegmentedGroup Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_producer_edges(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> producer_edges) {
+ fbb_.AddOffset(SegmentedGroup::VT_PRODUCER_EDGES, producer_edges);
+ }
+ void add_consumer_edges(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> consumer_edges) {
+ fbb_.AddOffset(SegmentedGroup::VT_CONSUMER_EDGES, consumer_edges);
+ }
+ void add_input_vals(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> input_vals) {
+ fbb_.AddOffset(SegmentedGroup::VT_INPUT_VALS, input_vals);
+ }
+ void add_output_vals(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> output_vals) {
+ fbb_.AddOffset(SegmentedGroup::VT_OUTPUT_VALS, output_vals);
+ }
+ void add_group_id(int32_t group_id) {
+ fbb_.AddElement<int32_t>(SegmentedGroup::VT_GROUP_ID, group_id, 0);
+ }
+ void add_heuristic(int64_t heuristic) {
+ fbb_.AddElement<int64_t>(SegmentedGroup::VT_HEURISTIC, heuristic, 0);
+ }
+ void add_exprs(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> exprs) {
+ fbb_.AddOffset(SegmentedGroup::VT_EXPRS, exprs);
+ }
+ void add_level(int32_t level) {
+ fbb_.AddElement<int32_t>(SegmentedGroup::VT_LEVEL, level, 0);
+ }
+ void add_visited(bool visited) {
+ fbb_.AddElement<uint8_t>(SegmentedGroup::VT_VISITED, static_cast<uint8_t>(visited), 0);
+ }
+ void add_merge_with_segmented_group(int64_t merge_with_segmented_group) {
+ fbb_.AddElement<int64_t>(SegmentedGroup::VT_MERGE_WITH_SEGMENTED_GROUP, merge_with_segmented_group, 0);
+ }
+ void add_merge_through_segmented_edge(int64_t merge_through_segmented_edge) {
+ fbb_.AddElement<int64_t>(SegmentedGroup::VT_MERGE_THROUGH_SEGMENTED_EDGE, merge_through_segmented_edge, 0);
+ }
+ void add_merged(bool merged) {
+ fbb_.AddElement<uint8_t>(SegmentedGroup::VT_MERGED, static_cast<uint8_t>(merged), 0);
+ }
+ void add_is_fusion_input(bool is_fusion_input) {
+ fbb_.AddElement<uint8_t>(SegmentedGroup::VT_IS_FUSION_INPUT, static_cast<uint8_t>(is_fusion_input), 0);
+ }
+ explicit SegmentedGroupBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SegmentedGroup> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SegmentedGroup>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<SegmentedGroup> CreateSegmentedGroup(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> producer_edges = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> consumer_edges = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> input_vals = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> output_vals = 0,
+ int32_t group_id = 0,
+ int64_t heuristic = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> exprs = 0,
+ int32_t level = 0,
+ bool visited = false,
+ int64_t merge_with_segmented_group = 0,
+ int64_t merge_through_segmented_edge = 0,
+ bool merged = false,
+ bool is_fusion_input = false) {
+ SegmentedGroupBuilder builder_(_fbb);
+ builder_.add_merge_through_segmented_edge(merge_through_segmented_edge);
+ builder_.add_merge_with_segmented_group(merge_with_segmented_group);
+ builder_.add_heuristic(heuristic);
+ builder_.add_level(level);
+ builder_.add_exprs(exprs);
+ builder_.add_group_id(group_id);
+ builder_.add_output_vals(output_vals);
+ builder_.add_input_vals(input_vals);
+ builder_.add_consumer_edges(consumer_edges);
+ builder_.add_producer_edges(producer_edges);
+ builder_.add_is_fusion_input(is_fusion_input);
+ builder_.add_merged(merged);
+ builder_.add_visited(visited);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<SegmentedGroup> CreateSegmentedGroupDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int64_t> *producer_edges = nullptr,
+ const std::vector<int64_t> *consumer_edges = nullptr,
+ const std::vector<int64_t> *input_vals = nullptr,
+ const std::vector<int64_t> *output_vals = nullptr,
+ int32_t group_id = 0,
+ int64_t heuristic = 0,
+ const std::vector<int64_t> *exprs = nullptr,
+ int32_t level = 0,
+ bool visited = false,
+ int64_t merge_with_segmented_group = 0,
+ int64_t merge_through_segmented_edge = 0,
+ bool merged = false,
+ bool is_fusion_input = false) {
+ auto producer_edges__ = producer_edges ? _fbb.CreateVector<int64_t>(*producer_edges) : 0;
+ auto consumer_edges__ = consumer_edges ? _fbb.CreateVector<int64_t>(*consumer_edges) : 0;
+ auto input_vals__ = input_vals ? _fbb.CreateVector<int64_t>(*input_vals) : 0;
+ auto output_vals__ = output_vals ? _fbb.CreateVector<int64_t>(*output_vals) : 0;
+ auto exprs__ = exprs ? _fbb.CreateVector<int64_t>(*exprs) : 0;
+ return nvfuser::serde::CreateSegmentedGroup(
+ _fbb,
+ producer_edges__,
+ consumer_edges__,
+ input_vals__,
+ output_vals__,
+ group_id,
+ heuristic,
+ exprs__,
+ level,
+ visited,
+ merge_with_segmented_group,
+ merge_through_segmented_edge,
+ merged,
+ is_fusion_input);
+ }
+
+ struct SegmentedFusion FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef SegmentedFusionBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALID = 4,
+ VT_SEGMENTED_FUSION_NAME = 6,
+ VT_NUM_VALS = 8,
+ VT_NUM_EXPRS = 10,
+ VT_EDGES = 12,
+ VT_GROUPS = 14,
+ VT_FORCE_FP16_TV_SET = 16,
+ VT_FORCE_HALF_PRECISION_TYPE = 18
+ };
+ bool valid() const {
+ return GetField<uint8_t>(VT_VALID, 0) != 0;
+ }
+ uint64_t segmented_fusion_name() const {
+ return GetField<uint64_t>(VT_SEGMENTED_FUSION_NAME, 0);
+ }
+ uint64_t num_vals() const {
+ return GetField<uint64_t>(VT_NUM_VALS, 0);
+ }
+ uint64_t num_exprs() const {
+ return GetField<uint64_t>(VT_NUM_EXPRS, 0);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::SegmentedEdge>> *edges() const {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::SegmentedEdge>> *>(VT_EDGES);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::SegmentedGroup>> *groups() const {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::SegmentedGroup>> *>(VT_GROUPS);
+ }
+ const ::flatbuffers::Vector<int64_t> *force_fp16_tv_set() const {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_FORCE_FP16_TV_SET);
+ }
+ int64_t force_half_precision_type() const {
+ return GetField<int64_t>(VT_FORCE_HALF_PRECISION_TYPE, 0);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_VALID, 1) &&
+ VerifyField<uint64_t>(verifier, VT_SEGMENTED_FUSION_NAME, 8) &&
+ VerifyField<uint64_t>(verifier, VT_NUM_VALS, 8) &&
+ VerifyField<uint64_t>(verifier, VT_NUM_EXPRS, 8) &&
+ VerifyOffset(verifier, VT_EDGES) &&
+ verifier.VerifyVector(edges()) &&
+ verifier.VerifyVectorOfTables(edges()) &&
+ VerifyOffset(verifier, VT_GROUPS) &&
+ verifier.VerifyVector(groups()) &&
+ verifier.VerifyVectorOfTables(groups()) &&
+ VerifyOffset(verifier, VT_FORCE_FP16_TV_SET) &&
+ verifier.VerifyVector(force_fp16_tv_set()) &&
+ VerifyField<int64_t>(verifier, VT_FORCE_HALF_PRECISION_TYPE, 8) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct SegmentedFusionBuilder {
+ typedef SegmentedFusion Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_valid(bool valid) {
+ fbb_.AddElement<uint8_t>(SegmentedFusion::VT_VALID, static_cast<uint8_t>(valid), 0);
+ }
+ void add_segmented_fusion_name(uint64_t segmented_fusion_name) {
+ fbb_.AddElement<uint64_t>(SegmentedFusion::VT_SEGMENTED_FUSION_NAME, segmented_fusion_name, 0);
+ }
+ void add_num_vals(uint64_t num_vals) {
+ fbb_.AddElement<uint64_t>(SegmentedFusion::VT_NUM_VALS, num_vals, 0);
+ }
+ void add_num_exprs(uint64_t num_exprs) {
+ fbb_.AddElement<uint64_t>(SegmentedFusion::VT_NUM_EXPRS, num_exprs, 0);
+ }
+ void add_edges(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::SegmentedEdge>>> edges) {
+ fbb_.AddOffset(SegmentedFusion::VT_EDGES, edges);
+ }
+ void add_groups(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::SegmentedGroup>>> groups) {
+ fbb_.AddOffset(SegmentedFusion::VT_GROUPS, groups);
+ }
+ void add_force_fp16_tv_set(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> force_fp16_tv_set) {
+ fbb_.AddOffset(SegmentedFusion::VT_FORCE_FP16_TV_SET, force_fp16_tv_set);
+ }
+ void add_force_half_precision_type(int64_t force_half_precision_type) {
+ fbb_.AddElement<int64_t>(SegmentedFusion::VT_FORCE_HALF_PRECISION_TYPE, force_half_precision_type, 0);
+ }
+ explicit SegmentedFusionBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SegmentedFusion> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SegmentedFusion>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<SegmentedFusion> CreateSegmentedFusion(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ bool valid = false,
+ uint64_t segmented_fusion_name = 0,
+ uint64_t num_vals = 0,
+ uint64_t num_exprs = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::SegmentedEdge>>> edges = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::SegmentedGroup>>> groups = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> force_fp16_tv_set = 0,
+ int64_t force_half_precision_type = 0) {
+ SegmentedFusionBuilder builder_(_fbb);
+ builder_.add_force_half_precision_type(force_half_precision_type);
+ builder_.add_num_exprs(num_exprs);
+ builder_.add_num_vals(num_vals);
+ builder_.add_segmented_fusion_name(segmented_fusion_name);
+ builder_.add_force_fp16_tv_set(force_fp16_tv_set);
+ builder_.add_groups(groups);
+ builder_.add_edges(edges);
+ builder_.add_valid(valid);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<SegmentedFusion> CreateSegmentedFusionDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ bool valid = false,
+ uint64_t segmented_fusion_name = 0,
+ uint64_t num_vals = 0,
+ uint64_t num_exprs = 0,
+ const std::vector<::flatbuffers::Offset<nvfuser::serde::SegmentedEdge>> *edges = nullptr,
+ const std::vector<::flatbuffers::Offset<nvfuser::serde::SegmentedGroup>> *groups = nullptr,
+ const std::vector<int64_t> *force_fp16_tv_set = nullptr,
+ int64_t force_half_precision_type = 0) {
+ auto edges__ = edges ? _fbb.CreateVector<::flatbuffers::Offset<nvfuser::serde::SegmentedEdge>>(*edges) : 0;
+ auto groups__ = groups ? _fbb.CreateVector<::flatbuffers::Offset<nvfuser::serde::SegmentedGroup>>(*groups) : 0;
+ auto force_fp16_tv_set__ = force_fp16_tv_set ? _fbb.CreateVector<int64_t>(*force_fp16_tv_set) : 0;
+ return nvfuser::serde::CreateSegmentedFusion(
+ _fbb,
+ valid,
+ segmented_fusion_name,
+ num_vals,
+ num_exprs,
+ edges__,
+ groups__,
+ force_fp16_tv_set__,
+ force_half_precision_type);
+ }
+
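Usage sketch (editorial; same include assumption as above): wiring SegmentedEdge and SegmentedGroup tables into a SegmentedFusion. Groups reference edges by index into the fusion's edges vector, so the nested tables are created first. All ids and counts are placeholders.

inline ::flatbuffers::Offset<nvfuser::serde::SegmentedFusion>
BuildTinySegmentedFusion(::flatbuffers::FlatBufferBuilder& fbb) {
  // One edge from group 0 to group 1 carrying value index 7.
  std::vector<::flatbuffers::Offset<nvfuser::serde::SegmentedEdge>> edges{
      nvfuser::serde::CreateSegmentedEdge(fbb, 0, 1, /*val=*/7)};
  std::vector<int64_t> producer_edges{0};  // index into `edges`
  std::vector<::flatbuffers::Offset<nvfuser::serde::SegmentedGroup>> groups{
      nvfuser::serde::CreateSegmentedGroupDirect(fbb, &producer_edges)};
  return nvfuser::serde::CreateSegmentedFusionDirect(
      fbb,
      /*valid=*/true,
      /*segmented_fusion_name=*/0,
      /*num_vals=*/2,
      /*num_exprs=*/1,
      &edges,
      &groups);
}
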
3227
+ struct FusionKernelRuntime FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef FusionKernelRuntimeBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_FUSION_ID = 4,
+ VT_CONCRETE_ID = 6,
+ VT_RUNTIME_ID = 8,
+ VT_ARGS = 10,
+ VT_EXECUTORS = 12,
+ VT_SEGMENTED_FUSION = 14
+ };
+ int64_t fusion_id() const {
+ return GetField<int64_t>(VT_FUSION_ID, 0);
+ }
+ int64_t concrete_id() const {
+ return GetField<int64_t>(VT_CONCRETE_ID, 0);
+ }
+ int64_t runtime_id() const {
+ return GetField<int64_t>(VT_RUNTIME_ID, 0);
+ }
+ const nvfuser::serde::KernelArgumentHolder *args() const {
+ return GetPointer<const nvfuser::serde::KernelArgumentHolder *>(VT_ARGS);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::KernelExecutor>> *executors() const {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::KernelExecutor>> *>(VT_EXECUTORS);
+ }
+ const nvfuser::serde::SegmentedFusion *segmented_fusion() const {
+ return GetPointer<const nvfuser::serde::SegmentedFusion *>(VT_SEGMENTED_FUSION);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_FUSION_ID, 8) &&
+ VerifyField<int64_t>(verifier, VT_CONCRETE_ID, 8) &&
+ VerifyField<int64_t>(verifier, VT_RUNTIME_ID, 8) &&
+ VerifyOffset(verifier, VT_ARGS) &&
+ verifier.VerifyTable(args()) &&
+ VerifyOffset(verifier, VT_EXECUTORS) &&
+ verifier.VerifyVector(executors()) &&
+ verifier.VerifyVectorOfTables(executors()) &&
+ VerifyOffset(verifier, VT_SEGMENTED_FUSION) &&
+ verifier.VerifyTable(segmented_fusion()) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct FusionKernelRuntimeBuilder {
+ typedef FusionKernelRuntime Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_fusion_id(int64_t fusion_id) {
+ fbb_.AddElement<int64_t>(FusionKernelRuntime::VT_FUSION_ID, fusion_id, 0);
+ }
+ void add_concrete_id(int64_t concrete_id) {
+ fbb_.AddElement<int64_t>(FusionKernelRuntime::VT_CONCRETE_ID, concrete_id, 0);
+ }
+ void add_runtime_id(int64_t runtime_id) {
+ fbb_.AddElement<int64_t>(FusionKernelRuntime::VT_RUNTIME_ID, runtime_id, 0);
+ }
+ void add_args(::flatbuffers::Offset<nvfuser::serde::KernelArgumentHolder> args) {
+ fbb_.AddOffset(FusionKernelRuntime::VT_ARGS, args);
+ }
+ void add_executors(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::KernelExecutor>>> executors) {
+ fbb_.AddOffset(FusionKernelRuntime::VT_EXECUTORS, executors);
+ }
+ void add_segmented_fusion(::flatbuffers::Offset<nvfuser::serde::SegmentedFusion> segmented_fusion) {
+ fbb_.AddOffset(FusionKernelRuntime::VT_SEGMENTED_FUSION, segmented_fusion);
+ }
+ explicit FusionKernelRuntimeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<FusionKernelRuntime> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<FusionKernelRuntime>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<FusionKernelRuntime> CreateFusionKernelRuntime(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t fusion_id = 0,
+ int64_t concrete_id = 0,
+ int64_t runtime_id = 0,
+ ::flatbuffers::Offset<nvfuser::serde::KernelArgumentHolder> args = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::KernelExecutor>>> executors = 0,
+ ::flatbuffers::Offset<nvfuser::serde::SegmentedFusion> segmented_fusion = 0) {
+ FusionKernelRuntimeBuilder builder_(_fbb);
+ builder_.add_runtime_id(runtime_id);
+ builder_.add_concrete_id(concrete_id);
+ builder_.add_fusion_id(fusion_id);
+ builder_.add_segmented_fusion(segmented_fusion);
+ builder_.add_executors(executors);
+ builder_.add_args(args);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<FusionKernelRuntime> CreateFusionKernelRuntimeDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t fusion_id = 0,
+ int64_t concrete_id = 0,
+ int64_t runtime_id = 0,
+ ::flatbuffers::Offset<nvfuser::serde::KernelArgumentHolder> args = 0,
+ const std::vector<::flatbuffers::Offset<nvfuser::serde::KernelExecutor>> *executors = nullptr,
+ ::flatbuffers::Offset<nvfuser::serde::SegmentedFusion> segmented_fusion = 0) {
+ auto executors__ = executors ? _fbb.CreateVector<::flatbuffers::Offset<nvfuser::serde::KernelExecutor>>(*executors) : 0;
+ return nvfuser::serde::CreateFusionKernelRuntime(
+ _fbb,
+ fusion_id,
+ concrete_id,
+ runtime_id,
+ args,
+ executors__,
+ segmented_fusion);
+ }
+
+ struct InputsIdLookup FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef InputsIdLookupBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MAX_CACHE_SIZE = 4,
+ VT_CURRENT_ID = 6,
+ VT_LRU_CACHE = 8,
+ VT_ENCODING_LOOKUP_KEYS = 10,
+ VT_ENCODING_LOOKUP_VALUES = 12
+ };
+ uint64_t max_cache_size() const {
+ return GetField<uint64_t>(VT_MAX_CACHE_SIZE, 0);
+ }
+ uint64_t current_id() const {
+ return GetField<uint64_t>(VT_CURRENT_ID, 0);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *lru_cache() const {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_LRU_CACHE);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *encoding_lookup_keys() const {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_ENCODING_LOOKUP_KEYS);
+ }
+ const ::flatbuffers::Vector<const nvfuser::serde::EncodingEntry *> *encoding_lookup_values() const {
+ return GetPointer<const ::flatbuffers::Vector<const nvfuser::serde::EncodingEntry *> *>(VT_ENCODING_LOOKUP_VALUES);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint64_t>(verifier, VT_MAX_CACHE_SIZE, 8) &&
+ VerifyField<uint64_t>(verifier, VT_CURRENT_ID, 8) &&
+ VerifyOffset(verifier, VT_LRU_CACHE) &&
+ verifier.VerifyVector(lru_cache()) &&
+ verifier.VerifyVectorOfStrings(lru_cache()) &&
+ VerifyOffset(verifier, VT_ENCODING_LOOKUP_KEYS) &&
+ verifier.VerifyVector(encoding_lookup_keys()) &&
+ verifier.VerifyVectorOfStrings(encoding_lookup_keys()) &&
+ VerifyOffset(verifier, VT_ENCODING_LOOKUP_VALUES) &&
+ verifier.VerifyVector(encoding_lookup_values()) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct InputsIdLookupBuilder {
+ typedef InputsIdLookup Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_max_cache_size(uint64_t max_cache_size) {
+ fbb_.AddElement<uint64_t>(InputsIdLookup::VT_MAX_CACHE_SIZE, max_cache_size, 0);
+ }
+ void add_current_id(uint64_t current_id) {
+ fbb_.AddElement<uint64_t>(InputsIdLookup::VT_CURRENT_ID, current_id, 0);
+ }
+ void add_lru_cache(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> lru_cache) {
+ fbb_.AddOffset(InputsIdLookup::VT_LRU_CACHE, lru_cache);
+ }
+ void add_encoding_lookup_keys(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> encoding_lookup_keys) {
+ fbb_.AddOffset(InputsIdLookup::VT_ENCODING_LOOKUP_KEYS, encoding_lookup_keys);
+ }
+ void add_encoding_lookup_values(::flatbuffers::Offset<::flatbuffers::Vector<const nvfuser::serde::EncodingEntry *>> encoding_lookup_values) {
+ fbb_.AddOffset(InputsIdLookup::VT_ENCODING_LOOKUP_VALUES, encoding_lookup_values);
+ }
+ explicit InputsIdLookupBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<InputsIdLookup> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<InputsIdLookup>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<InputsIdLookup> CreateInputsIdLookup(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ uint64_t max_cache_size = 0,
+ uint64_t current_id = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> lru_cache = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> encoding_lookup_keys = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<const nvfuser::serde::EncodingEntry *>> encoding_lookup_values = 0) {
+ InputsIdLookupBuilder builder_(_fbb);
+ builder_.add_current_id(current_id);
+ builder_.add_max_cache_size(max_cache_size);
+ builder_.add_encoding_lookup_values(encoding_lookup_values);
+ builder_.add_encoding_lookup_keys(encoding_lookup_keys);
+ builder_.add_lru_cache(lru_cache);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<InputsIdLookup> CreateInputsIdLookupDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ uint64_t max_cache_size = 0,
+ uint64_t current_id = 0,
+ const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *lru_cache = nullptr,
+ const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *encoding_lookup_keys = nullptr,
+ const std::vector<nvfuser::serde::EncodingEntry> *encoding_lookup_values = nullptr) {
+ auto lru_cache__ = lru_cache ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*lru_cache) : 0;
+ auto encoding_lookup_keys__ = encoding_lookup_keys ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*encoding_lookup_keys) : 0;
+ auto encoding_lookup_values__ = encoding_lookup_values ? _fbb.CreateVectorOfStructs<nvfuser::serde::EncodingEntry>(*encoding_lookup_values) : 0;
+ return nvfuser::serde::CreateInputsIdLookup(
+ _fbb,
+ max_cache_size,
+ current_id,
+ lru_cache__,
+ encoding_lookup_keys__,
+ encoding_lookup_values__);
+ }
+
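Usage sketch (editorial; same include assumption as above): persisting the LRU strings of an InputsIdLookup. String offsets must be created before the table is started, which CreateInputsIdLookupDirect handles. The encoding_lookup_values field is left at its default here because EncodingEntry is a struct defined elsewhere in this header; all strings and sizes are placeholders.

inline ::flatbuffers::Offset<nvfuser::serde::InputsIdLookup>
BuildInputsIdLookup(::flatbuffers::FlatBufferBuilder& fbb) {
  std::vector<::flatbuffers::Offset<::flatbuffers::String>> lru{
      fbb.CreateString("input_set_a"), fbb.CreateString("input_set_b")};
  return nvfuser::serde::CreateInputsIdLookupDirect(
      fbb, /*max_cache_size=*/8, /*current_id=*/2, &lru);
}
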
3446
+ struct KernelRuntimeState FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef KernelRuntimeStateBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_DEVICE_ID = 4,
+ VT_CONCRETE_ID = 6,
+ VT_HAS_DYNAMIC_TRANSFORM_INFO = 8,
+ VT_RUNTIMES = 10
+ };
+ int64_t device_id() const {
+ return GetField<int64_t>(VT_DEVICE_ID, 0);
+ }
+ int64_t concrete_id() const {
+ return GetField<int64_t>(VT_CONCRETE_ID, 0);
+ }
+ bool has_dynamic_transform_info() const {
+ return GetField<uint8_t>(VT_HAS_DYNAMIC_TRANSFORM_INFO, 0) != 0;
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::FusionKernelRuntime>> *runtimes() const {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::FusionKernelRuntime>> *>(VT_RUNTIMES);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_DEVICE_ID, 8) &&
+ VerifyField<int64_t>(verifier, VT_CONCRETE_ID, 8) &&
+ VerifyField<uint8_t>(verifier, VT_HAS_DYNAMIC_TRANSFORM_INFO, 1) &&
+ VerifyOffset(verifier, VT_RUNTIMES) &&
+ verifier.VerifyVector(runtimes()) &&
+ verifier.VerifyVectorOfTables(runtimes()) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct KernelRuntimeStateBuilder {
+ typedef KernelRuntimeState Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_device_id(int64_t device_id) {
+ fbb_.AddElement<int64_t>(KernelRuntimeState::VT_DEVICE_ID, device_id, 0);
+ }
+ void add_concrete_id(int64_t concrete_id) {
+ fbb_.AddElement<int64_t>(KernelRuntimeState::VT_CONCRETE_ID, concrete_id, 0);
+ }
+ void add_has_dynamic_transform_info(bool has_dynamic_transform_info) {
+ fbb_.AddElement<uint8_t>(KernelRuntimeState::VT_HAS_DYNAMIC_TRANSFORM_INFO, static_cast<uint8_t>(has_dynamic_transform_info), 0);
+ }
+ void add_runtimes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::FusionKernelRuntime>>> runtimes) {
+ fbb_.AddOffset(KernelRuntimeState::VT_RUNTIMES, runtimes);
+ }
+ explicit KernelRuntimeStateBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<KernelRuntimeState> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<KernelRuntimeState>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<KernelRuntimeState> CreateKernelRuntimeState(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t device_id = 0,
+ int64_t concrete_id = 0,
+ bool has_dynamic_transform_info = false,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::FusionKernelRuntime>>> runtimes = 0) {
+ KernelRuntimeStateBuilder builder_(_fbb);
+ builder_.add_concrete_id(concrete_id);
+ builder_.add_device_id(device_id);
+ builder_.add_runtimes(runtimes);
+ builder_.add_has_dynamic_transform_info(has_dynamic_transform_info);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<KernelRuntimeState> CreateKernelRuntimeStateDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t device_id = 0,
+ int64_t concrete_id = 0,
+ bool has_dynamic_transform_info = false,
+ const std::vector<::flatbuffers::Offset<nvfuser::serde::FusionKernelRuntime>> *runtimes = nullptr) {
+ auto runtimes__ = runtimes ? _fbb.CreateVector<::flatbuffers::Offset<nvfuser::serde::FusionKernelRuntime>>(*runtimes) : 0;
+ return nvfuser::serde::CreateKernelRuntimeState(
+ _fbb,
+ device_id,
+ concrete_id,
+ has_dynamic_transform_info,
+ runtimes__);
+ }
+
+ struct FusionExecutorCache FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef FusionExecutorCacheBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_FUSION_ID = 4,
+ VT_INPUTS_CACHE = 6,
+ VT_KERNEL_RUNTIMES_MAP = 8,
+ VT_KERNEL_CACHE_KEYS = 10,
+ VT_KERNEL_CACHE_VALUES = 12
+ };
+ int64_t fusion_id() const {
+ return GetField<int64_t>(VT_FUSION_ID, 0);
+ }
+ const nvfuser::serde::InputsIdLookup *inputs_cache() const {
+ return GetPointer<const nvfuser::serde::InputsIdLookup *>(VT_INPUTS_CACHE);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::KernelRuntimeState>> *kernel_runtimes_map() const {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::KernelRuntimeState>> *>(VT_KERNEL_RUNTIMES_MAP);
+ }
+ const ::flatbuffers::Vector<uint64_t> *kernel_cache_keys() const {
+ return GetPointer<const ::flatbuffers::Vector<uint64_t> *>(VT_KERNEL_CACHE_KEYS);
+ }
+ const ::flatbuffers::Vector<uint64_t> *kernel_cache_values() const {
+ return GetPointer<const ::flatbuffers::Vector<uint64_t> *>(VT_KERNEL_CACHE_VALUES);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_FUSION_ID, 8) &&
+ VerifyOffset(verifier, VT_INPUTS_CACHE) &&
+ verifier.VerifyTable(inputs_cache()) &&
+ VerifyOffset(verifier, VT_KERNEL_RUNTIMES_MAP) &&
+ verifier.VerifyVector(kernel_runtimes_map()) &&
+ verifier.VerifyVectorOfTables(kernel_runtimes_map()) &&
+ VerifyOffset(verifier, VT_KERNEL_CACHE_KEYS) &&
+ verifier.VerifyVector(kernel_cache_keys()) &&
+ VerifyOffset(verifier, VT_KERNEL_CACHE_VALUES) &&
+ verifier.VerifyVector(kernel_cache_values()) &&
+ verifier.EndTable();
+ }
+ };
+
+ struct FusionExecutorCacheBuilder {
+ typedef FusionExecutorCache Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_fusion_id(int64_t fusion_id) {
+ fbb_.AddElement<int64_t>(FusionExecutorCache::VT_FUSION_ID, fusion_id, 0);
+ }
+ void add_inputs_cache(::flatbuffers::Offset<nvfuser::serde::InputsIdLookup> inputs_cache) {
+ fbb_.AddOffset(FusionExecutorCache::VT_INPUTS_CACHE, inputs_cache);
+ }
+ void add_kernel_runtimes_map(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::KernelRuntimeState>>> kernel_runtimes_map) {
+ fbb_.AddOffset(FusionExecutorCache::VT_KERNEL_RUNTIMES_MAP, kernel_runtimes_map);
+ }
+ void add_kernel_cache_keys(::flatbuffers::Offset<::flatbuffers::Vector<uint64_t>> kernel_cache_keys) {
+ fbb_.AddOffset(FusionExecutorCache::VT_KERNEL_CACHE_KEYS, kernel_cache_keys);
+ }
+ void add_kernel_cache_values(::flatbuffers::Offset<::flatbuffers::Vector<uint64_t>> kernel_cache_values) {
+ fbb_.AddOffset(FusionExecutorCache::VT_KERNEL_CACHE_VALUES, kernel_cache_values);
+ }
+ explicit FusionExecutorCacheBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<FusionExecutorCache> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<FusionExecutorCache>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<FusionExecutorCache> CreateFusionExecutorCache(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t fusion_id = 0,
+ ::flatbuffers::Offset<nvfuser::serde::InputsIdLookup> inputs_cache = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::KernelRuntimeState>>> kernel_runtimes_map = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint64_t>> kernel_cache_keys = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint64_t>> kernel_cache_values = 0) {
+ FusionExecutorCacheBuilder builder_(_fbb);
+ builder_.add_fusion_id(fusion_id);
+ builder_.add_kernel_cache_values(kernel_cache_values);
+ builder_.add_kernel_cache_keys(kernel_cache_keys);
+ builder_.add_kernel_runtimes_map(kernel_runtimes_map);
+ builder_.add_inputs_cache(inputs_cache);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<FusionExecutorCache> CreateFusionExecutorCacheDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t fusion_id = 0,
+ ::flatbuffers::Offset<nvfuser::serde::InputsIdLookup> inputs_cache = 0,
+ const std::vector<::flatbuffers::Offset<nvfuser::serde::KernelRuntimeState>> *kernel_runtimes_map = nullptr,
+ const std::vector<uint64_t> *kernel_cache_keys = nullptr,
+ const std::vector<uint64_t> *kernel_cache_values = nullptr) {
+ auto kernel_runtimes_map__ = kernel_runtimes_map ? _fbb.CreateVector<::flatbuffers::Offset<nvfuser::serde::KernelRuntimeState>>(*kernel_runtimes_map) : 0;
+ auto kernel_cache_keys__ = kernel_cache_keys ? _fbb.CreateVector<uint64_t>(*kernel_cache_keys) : 0;
+ auto kernel_cache_values__ = kernel_cache_values ? _fbb.CreateVector<uint64_t>(*kernel_cache_values) : 0;
+ return nvfuser::serde::CreateFusionExecutorCache(
+ _fbb,
+ fusion_id,
+ inputs_cache,
+ kernel_runtimes_map__,
+ kernel_cache_keys__,
+ kernel_cache_values__);
+ }
+
+ struct RecordFunctor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef RecordFunctorBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_ARGS = 4,
+ VT_OUTPUTS = 6,
+ VT_NAME = 8,
+ VT_TYPE = 10,
+ VT_DATA_TYPE = 12,
+ VT_DATA = 14
+ };
+ const ::flatbuffers::Vector<const nvfuser::serde::State *> *args() const {
+ return GetPointer<const ::flatbuffers::Vector<const nvfuser::serde::State *> *>(VT_ARGS);
+ }
+ const ::flatbuffers::Vector<const nvfuser::serde::State *> *outputs() const {
+ return GetPointer<const ::flatbuffers::Vector<const nvfuser::serde::State *> *>(VT_OUTPUTS);
+ }
+ const ::flatbuffers::String *name() const {
+ return GetPointer<const ::flatbuffers::String *>(VT_NAME);
+ }
+ nvfuser::serde::RecordType type() const {
+ return static_cast<nvfuser::serde::RecordType>(GetField<int32_t>(VT_TYPE, 0));
+ }
+ nvfuser::serde::RecordData data_type() const {
+ return static_cast<nvfuser::serde::RecordData>(GetField<uint8_t>(VT_DATA_TYPE, 0));
+ }
+ const void *data() const {
+ return GetPointer<const void *>(VT_DATA);
+ }
+ template<typename T> const T *data_as() const;
+ const nvfuser::serde::At *data_as_At() const {
+ return data_type() == nvfuser::serde::RecordData::At ? static_cast<const nvfuser::serde::At *>(data()) : nullptr;
+ }
+ const nvfuser::serde::BatchNorm *data_as_BatchNorm() const {
+ return data_type() == nvfuser::serde::RecordData::BatchNorm ? static_cast<const nvfuser::serde::BatchNorm *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Broadcast *data_as_Broadcast() const {
+ return data_type() == nvfuser::serde::RecordData::Broadcast ? static_cast<const nvfuser::serde::Broadcast *>(data()) : nullptr;
+ }
+ const nvfuser::serde::BroadcastInDim *data_as_BroadcastInDim() const {
+ return data_type() == nvfuser::serde::RecordData::BroadcastInDim ? static_cast<const nvfuser::serde::BroadcastInDim *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Cat *data_as_Cat() const {
+ return data_type() == nvfuser::serde::RecordData::Cat ? static_cast<const nvfuser::serde::Cat *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Dimension *data_as_Dimension() const {
+ return data_type() == nvfuser::serde::RecordData::Dimension ? static_cast<const nvfuser::serde::Dimension *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Dtype *data_as_Dtype() const {
+ return data_type() == nvfuser::serde::RecordData::Dtype ? static_cast<const nvfuser::serde::Dtype *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Norm *data_as_Norm() const {
+ return data_type() == nvfuser::serde::RecordData::Norm ? static_cast<const nvfuser::serde::Norm *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Output *data_as_Output() const {
+ return data_type() == nvfuser::serde::RecordData::Output ? static_cast<const nvfuser::serde::Output *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Dims *data_as_Dims() const {
+ return data_type() == nvfuser::serde::RecordData::Dims ? static_cast<const nvfuser::serde::Dims *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Slice *data_as_Slice() const {
+ return data_type() == nvfuser::serde::RecordData::Slice ? static_cast<const nvfuser::serde::Slice *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Squeeze *data_as_Squeeze() const {
+ return data_type() == nvfuser::serde::RecordData::Squeeze ? static_cast<const nvfuser::serde::Squeeze *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Reduction *data_as_Reduction() const {
+ return data_type() == nvfuser::serde::RecordData::Reduction ? static_cast<const nvfuser::serde::Reduction *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Scalar *data_as_Scalar() const {
+ return data_type() == nvfuser::serde::RecordData::Scalar ? static_cast<const nvfuser::serde::Scalar *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Size *data_as_Size() const {
+ return data_type() == nvfuser::serde::RecordData::Size ? static_cast<const nvfuser::serde::Size *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Tensor *data_as_Tensor() const {
+ return data_type() == nvfuser::serde::RecordData::Tensor ? static_cast<const nvfuser::serde::Tensor *>(data()) : nullptr;
+ }
+ const nvfuser::serde::TensorCreationSymbolic *data_as_TensorCreationSymbolic() const {
+ return data_type() == nvfuser::serde::RecordData::TensorCreationSymbolic ? static_cast<const nvfuser::serde::TensorCreationSymbolic *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Vector *data_as_Vector() const {
+ return data_type() == nvfuser::serde::RecordData::Vector ? static_cast<const nvfuser::serde::Vector *>(data()) : nullptr;
+ }
+ const nvfuser::serde::Welford *data_as_Welford() const {
+ return data_type() == nvfuser::serde::RecordData::Welford ? static_cast<const nvfuser::serde::Welford *>(data()) : nullptr;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_ARGS) &&
+ verifier.VerifyVector(args()) &&
+ VerifyOffset(verifier, VT_OUTPUTS) &&
+ verifier.VerifyVector(outputs()) &&
+ VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) &&
+ VerifyField<int32_t>(verifier, VT_TYPE, 4) &&
+ VerifyField<uint8_t>(verifier, VT_DATA_TYPE, 1) &&
+ VerifyOffset(verifier, VT_DATA) &&
+ VerifyRecordData(verifier, data(), data_type()) &&
+ verifier.EndTable();
+ }
+ };
+
+ template<> inline const nvfuser::serde::At *RecordFunctor::data_as<nvfuser::serde::At>() const {
+ return data_as_At();
+ }
+
+ template<> inline const nvfuser::serde::BatchNorm *RecordFunctor::data_as<nvfuser::serde::BatchNorm>() const {
+ return data_as_BatchNorm();
+ }
+
+ template<> inline const nvfuser::serde::Broadcast *RecordFunctor::data_as<nvfuser::serde::Broadcast>() const {
+ return data_as_Broadcast();
+ }
+
+ template<> inline const nvfuser::serde::BroadcastInDim *RecordFunctor::data_as<nvfuser::serde::BroadcastInDim>() const {
+ return data_as_BroadcastInDim();
+ }
+
+ template<> inline const nvfuser::serde::Cat *RecordFunctor::data_as<nvfuser::serde::Cat>() const {
+ return data_as_Cat();
+ }
+
+ template<> inline const nvfuser::serde::Dimension *RecordFunctor::data_as<nvfuser::serde::Dimension>() const {
+ return data_as_Dimension();
+ }
+
+ template<> inline const nvfuser::serde::Dtype *RecordFunctor::data_as<nvfuser::serde::Dtype>() const {
+ return data_as_Dtype();
+ }
+
+ template<> inline const nvfuser::serde::Norm *RecordFunctor::data_as<nvfuser::serde::Norm>() const {
+ return data_as_Norm();
+ }
+
+ template<> inline const nvfuser::serde::Output *RecordFunctor::data_as<nvfuser::serde::Output>() const {
+ return data_as_Output();
+ }
+
+ template<> inline const nvfuser::serde::Dims *RecordFunctor::data_as<nvfuser::serde::Dims>() const {
+ return data_as_Dims();
+ }
+
+ template<> inline const nvfuser::serde::Slice *RecordFunctor::data_as<nvfuser::serde::Slice>() const {
+ return data_as_Slice();
+ }
+
+ template<> inline const nvfuser::serde::Squeeze *RecordFunctor::data_as<nvfuser::serde::Squeeze>() const {
+ return data_as_Squeeze();
+ }
+
+ template<> inline const nvfuser::serde::Reduction *RecordFunctor::data_as<nvfuser::serde::Reduction>() const {
+ return data_as_Reduction();
+ }
+
+ template<> inline const nvfuser::serde::Scalar *RecordFunctor::data_as<nvfuser::serde::Scalar>() const {
+ return data_as_Scalar();
+ }
+
+ template<> inline const nvfuser::serde::Size *RecordFunctor::data_as<nvfuser::serde::Size>() const {
+ return data_as_Size();
+ }
+
+ template<> inline const nvfuser::serde::Tensor *RecordFunctor::data_as<nvfuser::serde::Tensor>() const {
+ return data_as_Tensor();
+ }
+
+ template<> inline const nvfuser::serde::TensorCreationSymbolic *RecordFunctor::data_as<nvfuser::serde::TensorCreationSymbolic>() const {
+ return data_as_TensorCreationSymbolic();
+ }
+
+ template<> inline const nvfuser::serde::Vector *RecordFunctor::data_as<nvfuser::serde::Vector>() const {
+ return data_as_Vector();
+ }
+
+ template<> inline const nvfuser::serde::Welford *RecordFunctor::data_as<nvfuser::serde::Welford>() const {
+ return data_as_Welford();
+ }
+
+ struct RecordFunctorBuilder {
+ typedef RecordFunctor Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_args(::flatbuffers::Offset<::flatbuffers::Vector<const nvfuser::serde::State *>> args) {
+ fbb_.AddOffset(RecordFunctor::VT_ARGS, args);
+ }
+ void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector<const nvfuser::serde::State *>> outputs) {
+ fbb_.AddOffset(RecordFunctor::VT_OUTPUTS, outputs);
+ }
+ void add_name(::flatbuffers::Offset<::flatbuffers::String> name) {
+ fbb_.AddOffset(RecordFunctor::VT_NAME, name);
+ }
+ void add_type(nvfuser::serde::RecordType type) {
+ fbb_.AddElement<int32_t>(RecordFunctor::VT_TYPE, static_cast<int32_t>(type), 0);
+ }
+ void add_data_type(nvfuser::serde::RecordData data_type) {
+ fbb_.AddElement<uint8_t>(RecordFunctor::VT_DATA_TYPE, static_cast<uint8_t>(data_type), 0);
+ }
+ void add_data(::flatbuffers::Offset<void> data) {
+ fbb_.AddOffset(RecordFunctor::VT_DATA, data);
+ }
+ explicit RecordFunctorBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<RecordFunctor> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<RecordFunctor>(end);
+ return o;
+ }
+ };
+
+ inline ::flatbuffers::Offset<RecordFunctor> CreateRecordFunctor(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<const nvfuser::serde::State *>> args = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<const nvfuser::serde::State *>> outputs = 0,
+ ::flatbuffers::Offset<::flatbuffers::String> name = 0,
+ nvfuser::serde::RecordType type = nvfuser::serde::RecordType::Base,
+ nvfuser::serde::RecordData data_type = nvfuser::serde::RecordData::NONE,
+ ::flatbuffers::Offset<void> data = 0) {
+ RecordFunctorBuilder builder_(_fbb);
+ builder_.add_data(data);
+ builder_.add_type(type);
+ builder_.add_name(name);
+ builder_.add_outputs(outputs);
+ builder_.add_args(args);
+ builder_.add_data_type(data_type);
+ return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<RecordFunctor> CreateRecordFunctorDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<nvfuser::serde::State> *args = nullptr,
+ const std::vector<nvfuser::serde::State> *outputs = nullptr,
+ const char *name = nullptr,
+ nvfuser::serde::RecordType type = nvfuser::serde::RecordType::Base,
+ nvfuser::serde::RecordData data_type = nvfuser::serde::RecordData::NONE,
+ ::flatbuffers::Offset<void> data = 0) {
+ auto args__ = args ? _fbb.CreateVectorOfStructs<nvfuser::serde::State>(*args) : 0;
+ auto outputs__ = outputs ? _fbb.CreateVectorOfStructs<nvfuser::serde::State>(*outputs) : 0;
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ return nvfuser::serde::CreateRecordFunctor(
+ _fbb,
+ args__,
+ outputs__,
+ name__,
+ type,
+ data_type,
+ data);
+ }
+
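Usage sketch (editorial; same include assumption as above): reading back RecordFunctor's data union. The union member is discriminated by data_type(), each data_as_X() accessor returns nullptr unless X is the stored member, and the templated data_as<T>() specializations dispatch to the same checked cast. Scalar here stands in for any union member defined earlier in this header.

inline bool HoldsScalar(const nvfuser::serde::RecordFunctor* rf) {
  // Checked cast: non-null only when the stored RecordData member is Scalar.
  if (const nvfuser::serde::Scalar* s = rf->data_as_Scalar()) {
    (void)s;  // Scalar's field accessors are defined earlier in the header
    return true;
  }
  // The templated form performs the identical check.
  return rf->data_as<nvfuser::serde::Scalar>() != nullptr;
}
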
3889
+ struct TrieNode FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
3890
+ typedef TrieNodeBuilder Builder;
3891
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
3892
+ VT_RECORD = 4,
3893
+ VT_CHILDREN = 6,
3894
+ VT_FUSION_ID = 8,
3895
+ VT_VISITS = 10,
3896
+ VT_IS_TERMINAL = 12
3897
+ };
3898
+ const nvfuser::serde::RecordFunctor *record() const {
3899
+ return GetPointer<const nvfuser::serde::RecordFunctor *>(VT_RECORD);
3900
+ }
3901
+ const ::flatbuffers::Vector<uint64_t> *children() const {
3902
+ return GetPointer<const ::flatbuffers::Vector<uint64_t> *>(VT_CHILDREN);
3903
+ }
3904
+ uint64_t fusion_id() const {
3905
+ return GetField<uint64_t>(VT_FUSION_ID, 0);
3906
+ }
3907
+ uint64_t visits() const {
3908
+ return GetField<uint64_t>(VT_VISITS, 0);
3909
+ }
3910
+ bool is_terminal() const {
3911
+ return GetField<uint8_t>(VT_IS_TERMINAL, 0) != 0;
3912
+ }
3913
+ bool Verify(::flatbuffers::Verifier &verifier) const {
3914
+ return VerifyTableStart(verifier) &&
3915
+ VerifyOffset(verifier, VT_RECORD) &&
3916
+ verifier.VerifyTable(record()) &&
3917
+ VerifyOffset(verifier, VT_CHILDREN) &&
3918
+ verifier.VerifyVector(children()) &&
3919
+ VerifyField<uint64_t>(verifier, VT_FUSION_ID, 8) &&
3920
+ VerifyField<uint64_t>(verifier, VT_VISITS, 8) &&
3921
+ VerifyField<uint8_t>(verifier, VT_IS_TERMINAL, 1) &&
3922
+ verifier.EndTable();
3923
+ }
3924
+ };
3925
+
3926
+ struct TrieNodeBuilder {
+   typedef TrieNode Table;
+   ::flatbuffers::FlatBufferBuilder &fbb_;
+   ::flatbuffers::uoffset_t start_;
+   void add_record(::flatbuffers::Offset<nvfuser::serde::RecordFunctor> record) {
+     fbb_.AddOffset(TrieNode::VT_RECORD, record);
+   }
+   void add_children(::flatbuffers::Offset<::flatbuffers::Vector<uint64_t>> children) {
+     fbb_.AddOffset(TrieNode::VT_CHILDREN, children);
+   }
+   void add_fusion_id(uint64_t fusion_id) {
+     fbb_.AddElement<uint64_t>(TrieNode::VT_FUSION_ID, fusion_id, 0);
+   }
+   void add_visits(uint64_t visits) {
+     fbb_.AddElement<uint64_t>(TrieNode::VT_VISITS, visits, 0);
+   }
+   void add_is_terminal(bool is_terminal) {
+     fbb_.AddElement<uint8_t>(TrieNode::VT_IS_TERMINAL, static_cast<uint8_t>(is_terminal), 0);
+   }
+   explicit TrieNodeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+       : fbb_(_fbb) {
+     start_ = fbb_.StartTable();
+   }
+   ::flatbuffers::Offset<TrieNode> Finish() {
+     const auto end = fbb_.EndTable(start_);
+     auto o = ::flatbuffers::Offset<TrieNode>(end);
+     return o;
+   }
+ };
+
+ inline ::flatbuffers::Offset<TrieNode> CreateTrieNode(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     ::flatbuffers::Offset<nvfuser::serde::RecordFunctor> record = 0,
+     ::flatbuffers::Offset<::flatbuffers::Vector<uint64_t>> children = 0,
+     uint64_t fusion_id = 0,
+     uint64_t visits = 0,
+     bool is_terminal = false) {
+   TrieNodeBuilder builder_(_fbb);
+   builder_.add_visits(visits);
+   builder_.add_fusion_id(fusion_id);
+   builder_.add_children(children);
+   builder_.add_record(record);
+   builder_.add_is_terminal(is_terminal);
+   return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<TrieNode> CreateTrieNodeDirect(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     ::flatbuffers::Offset<nvfuser::serde::RecordFunctor> record = 0,
+     const std::vector<uint64_t> *children = nullptr,
+     uint64_t fusion_id = 0,
+     uint64_t visits = 0,
+     bool is_terminal = false) {
+   auto children__ = children ? _fbb.CreateVector<uint64_t>(*children) : 0;
+   return nvfuser::serde::CreateTrieNode(
+       _fbb,
+       record,
+       children__,
+       fusion_id,
+       visits,
+       is_terminal);
+ }
+
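+ // Note: CreateTrieNode adds fields in decreasing alignment order (the
+ // 8-byte visits/fusion_id first, the 1-byte flag last), the usual
+ // flatbuffers layout trick to minimize padding. Illustrative leaf-node
+ // sketch (placeholder values):
+ //
+ //   ::flatbuffers::FlatBufferBuilder fbb;
+ //   std::vector<uint64_t> children;  // empty for a leaf
+ //   auto node = nvfuser::serde::CreateTrieNodeDirect(
+ //       fbb, /*record=*/0, &children,
+ //       /*fusion_id=*/0, /*visits=*/1, /*is_terminal=*/true);
+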
+ struct FusionCache FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+   typedef FusionCacheBuilder Builder;
+   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+     VT_MAX_FUSIONS = 4,
+     VT_STRUCTURE = 6,
+     VT_TERMINAL_NODES = 8,
+     VT_AUTO_GEN_SCHEDULES = 10,
+     VT_GLOBAL_FUSION_COUNT = 12,
+     VT_DEVICE_MAJOR = 14,
+     VT_DEVICE_MINOR = 16,
+     VT_CUDA_MAJOR = 18,
+     VT_CUDA_MINOR = 20
+   };
+   uint64_t max_fusions() const {
+     return GetField<uint64_t>(VT_MAX_FUSIONS, 0);
+   }
+   const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::TrieNode>> *structure() const {
+     return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::TrieNode>> *>(VT_STRUCTURE);
+   }
+   const ::flatbuffers::Vector<uint64_t> *terminal_nodes() const {
+     return GetPointer<const ::flatbuffers::Vector<uint64_t> *>(VT_TERMINAL_NODES);
+   }
+   const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::FusionExecutorCache>> *auto_gen_schedules() const {
+     return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::FusionExecutorCache>> *>(VT_AUTO_GEN_SCHEDULES);
+   }
+   int64_t global_fusion_count() const {
+     return GetField<int64_t>(VT_GLOBAL_FUSION_COUNT, 0);
+   }
+   int64_t device_major() const {
+     return GetField<int64_t>(VT_DEVICE_MAJOR, 0);
+   }
+   int64_t device_minor() const {
+     return GetField<int64_t>(VT_DEVICE_MINOR, 0);
+   }
+   int64_t cuda_major() const {
+     return GetField<int64_t>(VT_CUDA_MAJOR, 0);
+   }
+   int64_t cuda_minor() const {
+     return GetField<int64_t>(VT_CUDA_MINOR, 0);
+   }
+   bool Verify(::flatbuffers::Verifier &verifier) const {
+     return VerifyTableStart(verifier) &&
+            VerifyField<uint64_t>(verifier, VT_MAX_FUSIONS, 8) &&
+            VerifyOffset(verifier, VT_STRUCTURE) &&
+            verifier.VerifyVector(structure()) &&
+            verifier.VerifyVectorOfTables(structure()) &&
+            VerifyOffset(verifier, VT_TERMINAL_NODES) &&
+            verifier.VerifyVector(terminal_nodes()) &&
+            VerifyOffset(verifier, VT_AUTO_GEN_SCHEDULES) &&
+            verifier.VerifyVector(auto_gen_schedules()) &&
+            verifier.VerifyVectorOfTables(auto_gen_schedules()) &&
+            VerifyField<int64_t>(verifier, VT_GLOBAL_FUSION_COUNT, 8) &&
+            VerifyField<int64_t>(verifier, VT_DEVICE_MAJOR, 8) &&
+            VerifyField<int64_t>(verifier, VT_DEVICE_MINOR, 8) &&
+            VerifyField<int64_t>(verifier, VT_CUDA_MAJOR, 8) &&
+            VerifyField<int64_t>(verifier, VT_CUDA_MINOR, 8) &&
+            verifier.EndTable();
+   }
+ };
+
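+ // Note: device_major/device_minor and cuda_major/cuda_minor record the GPU
+ // compute capability and CUDA toolkit version at serialization time; they
+ // are presumably checked on load so a cache built for one device or toolkit
+ // is not replayed on another.
+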
+ struct FusionCacheBuilder {
+   typedef FusionCache Table;
+   ::flatbuffers::FlatBufferBuilder &fbb_;
+   ::flatbuffers::uoffset_t start_;
+   void add_max_fusions(uint64_t max_fusions) {
+     fbb_.AddElement<uint64_t>(FusionCache::VT_MAX_FUSIONS, max_fusions, 0);
+   }
+   void add_structure(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::TrieNode>>> structure) {
+     fbb_.AddOffset(FusionCache::VT_STRUCTURE, structure);
+   }
+   void add_terminal_nodes(::flatbuffers::Offset<::flatbuffers::Vector<uint64_t>> terminal_nodes) {
+     fbb_.AddOffset(FusionCache::VT_TERMINAL_NODES, terminal_nodes);
+   }
+   void add_auto_gen_schedules(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::FusionExecutorCache>>> auto_gen_schedules) {
+     fbb_.AddOffset(FusionCache::VT_AUTO_GEN_SCHEDULES, auto_gen_schedules);
+   }
+   void add_global_fusion_count(int64_t global_fusion_count) {
+     fbb_.AddElement<int64_t>(FusionCache::VT_GLOBAL_FUSION_COUNT, global_fusion_count, 0);
+   }
+   void add_device_major(int64_t device_major) {
+     fbb_.AddElement<int64_t>(FusionCache::VT_DEVICE_MAJOR, device_major, 0);
+   }
+   void add_device_minor(int64_t device_minor) {
+     fbb_.AddElement<int64_t>(FusionCache::VT_DEVICE_MINOR, device_minor, 0);
+   }
+   void add_cuda_major(int64_t cuda_major) {
+     fbb_.AddElement<int64_t>(FusionCache::VT_CUDA_MAJOR, cuda_major, 0);
+   }
+   void add_cuda_minor(int64_t cuda_minor) {
+     fbb_.AddElement<int64_t>(FusionCache::VT_CUDA_MINOR, cuda_minor, 0);
+   }
+   explicit FusionCacheBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+       : fbb_(_fbb) {
+     start_ = fbb_.StartTable();
+   }
+   ::flatbuffers::Offset<FusionCache> Finish() {
+     const auto end = fbb_.EndTable(start_);
+     auto o = ::flatbuffers::Offset<FusionCache>(end);
+     return o;
+   }
+ };
+
+ inline ::flatbuffers::Offset<FusionCache> CreateFusionCache(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     uint64_t max_fusions = 0,
+     ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::TrieNode>>> structure = 0,
+     ::flatbuffers::Offset<::flatbuffers::Vector<uint64_t>> terminal_nodes = 0,
+     ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<nvfuser::serde::FusionExecutorCache>>> auto_gen_schedules = 0,
+     int64_t global_fusion_count = 0,
+     int64_t device_major = 0,
+     int64_t device_minor = 0,
+     int64_t cuda_major = 0,
+     int64_t cuda_minor = 0) {
+   FusionCacheBuilder builder_(_fbb);
+   builder_.add_cuda_minor(cuda_minor);
+   builder_.add_cuda_major(cuda_major);
+   builder_.add_device_minor(device_minor);
+   builder_.add_device_major(device_major);
+   builder_.add_global_fusion_count(global_fusion_count);
+   builder_.add_max_fusions(max_fusions);
+   builder_.add_auto_gen_schedules(auto_gen_schedules);
+   builder_.add_terminal_nodes(terminal_nodes);
+   builder_.add_structure(structure);
+   return builder_.Finish();
+ }
+
+ inline ::flatbuffers::Offset<FusionCache> CreateFusionCacheDirect(
+     ::flatbuffers::FlatBufferBuilder &_fbb,
+     uint64_t max_fusions = 0,
+     const std::vector<::flatbuffers::Offset<nvfuser::serde::TrieNode>> *structure = nullptr,
+     const std::vector<uint64_t> *terminal_nodes = nullptr,
+     const std::vector<::flatbuffers::Offset<nvfuser::serde::FusionExecutorCache>> *auto_gen_schedules = nullptr,
+     int64_t global_fusion_count = 0,
+     int64_t device_major = 0,
+     int64_t device_minor = 0,
+     int64_t cuda_major = 0,
+     int64_t cuda_minor = 0) {
+   auto structure__ = structure ? _fbb.CreateVector<::flatbuffers::Offset<nvfuser::serde::TrieNode>>(*structure) : 0;
+   auto terminal_nodes__ = terminal_nodes ? _fbb.CreateVector<uint64_t>(*terminal_nodes) : 0;
+   auto auto_gen_schedules__ = auto_gen_schedules ? _fbb.CreateVector<::flatbuffers::Offset<nvfuser::serde::FusionExecutorCache>>(*auto_gen_schedules) : 0;
+   return nvfuser::serde::CreateFusionCache(
+       _fbb,
+       max_fusions,
+       structure__,
+       terminal_nodes__,
+       auto_gen_schedules__,
+       global_fusion_count,
+       device_major,
+       device_minor,
+       cuda_major,
+       cuda_minor);
+ }
+
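+ // Illustrative assembly sketch (placeholder values): nested objects must be
+ // finished before the root table, so TrieNodes are created first and the
+ // vector of their offsets is handed to CreateFusionCacheDirect.
+ //
+ //   ::flatbuffers::FlatBufferBuilder fbb;
+ //   std::vector<::flatbuffers::Offset<nvfuser::serde::TrieNode>> structure{
+ //       nvfuser::serde::CreateTrieNodeDirect(fbb)};
+ //   std::vector<uint64_t> terminal_nodes;
+ //   auto cache = nvfuser::serde::CreateFusionCacheDirect(
+ //       fbb, /*max_fusions=*/8192, &structure, &terminal_nodes,
+ //       /*auto_gen_schedules=*/nullptr, /*global_fusion_count=*/0,
+ //       /*device_major=*/8, /*device_minor=*/0,
+ //       /*cuda_major=*/12, /*cuda_minor=*/1);
+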
+ inline bool VerifyRecordData(::flatbuffers::Verifier &verifier, const void *obj, RecordData type) {
+   switch (type) {
+     case RecordData::NONE: {
+       return true;
+     }
+     case RecordData::At: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::At *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::BatchNorm: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::BatchNorm *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Broadcast: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Broadcast *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::BroadcastInDim: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::BroadcastInDim *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Cat: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Cat *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Dimension: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Dimension *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Dtype: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Dtype *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Norm: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Norm *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Output: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Output *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Dims: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Dims *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Slice: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Slice *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Squeeze: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Squeeze *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Reduction: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Reduction *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Scalar: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Scalar *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Size: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Size *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Tensor: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Tensor *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::TensorCreationSymbolic: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::TensorCreationSymbolic *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Vector: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Vector *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case RecordData::Welford: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Welford *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     default: return true;
+   }
+ }
+
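+ // Note: this is the standard generated dispatcher for a flatbuffers union.
+ // The RecordData enum tag selects which concrete table the opaque pointer
+ // is cast to and verified as; unrecognized tags return true, so buffers
+ // written by a newer schema remain readable by older code.
+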
+ inline bool VerifyRecordDataVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<RecordData> *types) {
+   if (!values || !types) return !values && !types;
+   if (values->size() != types->size()) return false;
+   for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+     if (!VerifyRecordData(
+         verifier, values->Get(i), types->GetEnum<RecordData>(i))) {
+       return false;
+     }
+   }
+   return true;
+ }
+
+ inline bool VerifyPolymorphicValueData(::flatbuffers::Verifier &verifier, const void *obj, PolymorphicValueData type) {
+   switch (type) {
+     case PolymorphicValueData::NONE: {
+       return true;
+     }
+     case PolymorphicValueData::Scalar: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::Scalar *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case PolymorphicValueData::ScalarCpu: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::ScalarCpu *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     case PolymorphicValueData::TensorArg: {
+       auto ptr = reinterpret_cast<const nvfuser::serde::TensorArg *>(obj);
+       return verifier.VerifyTable(ptr);
+     }
+     default: return true;
+   }
+ }
+
+ inline bool VerifyPolymorphicValueDataVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<PolymorphicValueData> *types) {
+   if (!values || !types) return !values && !types;
+   if (values->size() != types->size()) return false;
+   for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+     if (!VerifyPolymorphicValueData(
+         verifier, values->Get(i), types->GetEnum<PolymorphicValueData>(i))) {
+       return false;
+     }
+   }
+   return true;
+ }
+
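+ // Note: vectors of unions are serialized as two parallel vectors, one of
+ // type tags and one of values; the *Vector verifiers above walk them in
+ // lockstep and reject the buffer if the lengths disagree.
+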
+ inline const nvfuser::serde::FusionCache *GetFusionCache(const void *buf) {
+   return ::flatbuffers::GetRoot<nvfuser::serde::FusionCache>(buf);
+ }
+
+ inline const nvfuser::serde::FusionCache *GetSizePrefixedFusionCache(const void *buf) {
+   return ::flatbuffers::GetSizePrefixedRoot<nvfuser::serde::FusionCache>(buf);
+ }
+
+ inline const char *FusionCacheIdentifier() {
+   return "NV01";
+ }
+
+ inline bool FusionCacheBufferHasIdentifier(const void *buf) {
+   return ::flatbuffers::BufferHasIdentifier(
+       buf, FusionCacheIdentifier());
+ }
+
+ inline bool SizePrefixedFusionCacheBufferHasIdentifier(const void *buf) {
+   return ::flatbuffers::BufferHasIdentifier(
+       buf, FusionCacheIdentifier(), true);
+ }
+
+ inline bool VerifyFusionCacheBuffer(
+     ::flatbuffers::Verifier &verifier) {
+   return verifier.VerifyBuffer<nvfuser::serde::FusionCache>(FusionCacheIdentifier());
+ }
+
+ inline bool VerifySizePrefixedFusionCacheBuffer(
+     ::flatbuffers::Verifier &verifier) {
+   return verifier.VerifySizePrefixedBuffer<nvfuser::serde::FusionCache>(FusionCacheIdentifier());
+ }
+
+ inline void FinishFusionCacheBuffer(
+     ::flatbuffers::FlatBufferBuilder &fbb,
+     ::flatbuffers::Offset<nvfuser::serde::FusionCache> root) {
+   fbb.Finish(root, FusionCacheIdentifier());
+ }
+
+ inline void FinishSizePrefixedFusionCacheBuffer(
+     ::flatbuffers::FlatBufferBuilder &fbb,
+     ::flatbuffers::Offset<nvfuser::serde::FusionCache> root) {
+   fbb.FinishSizePrefixed(root, FusionCacheIdentifier());
+ }
+
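+ // Illustrative round trip (placeholder values): finish the buffer with the
+ // "NV01" file identifier, verify it, then read the root table back.
+ //
+ //   ::flatbuffers::FlatBufferBuilder fbb;
+ //   auto root = nvfuser::serde::CreateFusionCacheDirect(fbb);
+ //   nvfuser::serde::FinishFusionCacheBuffer(fbb, root);
+ //
+ //   ::flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize());
+ //   if (nvfuser::serde::VerifyFusionCacheBuffer(verifier)) {
+ //     auto *cache = nvfuser::serde::GetFusionCache(fbb.GetBufferPointer());
+ //     uint64_t n = cache->max_fusions();
+ //   }
+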
+ } // namespace serde
+ } // namespace nvfuser
+
+ #endif // FLATBUFFERS_GENERATED_FUSIONCACHE_NVFUSER_SERDE_H_