nvfuser-cu121-torch25 0.2.25.dev20250201__cp312-cp312-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,298 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+ #include <exceptions.h>
10
+ #include <visibility.h>
11
+
12
+ #include <python_frontend/fusion_record.h>
13
+ #include <runtime/fusion_executor_cache.h>
14
+ #include <scheduler/compile_time_info.h>
15
+ #include <scheduler/registry.h>
16
+
17
+ #include <memory>
18
+ #include <mutex>
19
+
20
+ namespace nvfuser::python_frontend {
21
+
22
+ //! \struct UserSchedule
23
+ //! \brief A container to hold a scheduled Fusion IR as well as an executor
24
+ //! to contain the corresponding generated kernel.
25
+ struct UserSchedule {
26
+ UserSchedule(int64_t fusion_id, int64_t device_id);
27
+
28
+ //! Runtime information for schedulers
29
+ std::unique_ptr<SchedulerRuntimeInfo> runtime_info;
30
+
31
+ //! The scheduler heuristic for this UserSchedule
32
+ std::unique_ptr<SchedulerEntry> scheduler;
33
+
34
+ //! The parameters for scheduler heuristic.
35
+ std::unique_ptr<HeuristicParams> heuristic_params;
36
+
37
+ //! The compile-time data cache.
38
+ std::unique_ptr<HeuristicDataCache> data_cache;
39
+
40
+ //! Concretized, Scheduled Fusion IR
41
+ std::unique_ptr<Fusion> scheduled_fusion;
42
+
43
+ //! Generated kernel container
44
+ std::unique_ptr<KernelExecutor> executor;
45
+
46
+ //! ID of fusion in python frontend fusion cache
47
+ int64_t fusion_id_ = -1;
48
+
49
+ //! device ID for this user schedule
50
+ int64_t device_id_ = -1;
51
+
52
+ //! Get scheduler runtime info for UserSchedule
53
+ SchedulerRuntimeInfo* runtimeInfo() {
54
+ NVF_ERROR(
55
+ runtime_info != nullptr,
56
+ "Requires SchedulerRuntimeInfo to use heuristic schedulers");
57
+ return runtime_info.get();
58
+ }
59
+
60
+ //! Get Fusion for UserSchedule
61
+ Fusion* fusion() {
62
+ NVF_ERROR(
63
+ scheduled_fusion != nullptr,
64
+ "Requires Fusion to use heuristic schedulers");
65
+ return scheduled_fusion.get();
66
+ }
67
+
68
+ //! Return if we can schedule FusionDefinition with heuristic.
69
+ bool canSchedule(const SchedulerType& heuristic);
70
+
71
+ //! Return if we can schedule FusionDefinition with heuristic along with any
72
+ //! debug messages from canScheduleRejectReason.
73
+ std::tuple<bool, std::string> canScheduleDebug(
74
+ const SchedulerType& scheduler_type);
75
+
76
+ //! Create scheduler and get heuristic parameters for fusion.
77
+ HeuristicParams* computeHeuristics(SchedulerType scheduler_type);
78
+
79
+ //! Schedule fusion with selected heuristics and scheduler.
80
+ void schedule();
81
+
82
+ //! Schedule fusion with heuristic.
83
+ void scheduleWithType(SchedulerType scheduler_type);
84
+ };
85
+
86
+ //! \struct FusionSchedules
87
+ //! \brief A container for auto generated and user defined schedules
88
+ //! that correspond to compiled kernels for each complete Fusion Definition.
89
+ struct FusionSchedules {
90
+ FusionSchedules(int64_t fusion_id = 0);
91
+ Fusion* preschedFusion();
92
+
93
+ //! Schedules Automatically generated by nvFuser for dynamic inputs. (default)
94
+ //! NOTE: The FusionExecutorCache also holds the Unscheduled Fusion IR
95
+ std::unique_ptr<FusionExecutorCache> auto_gen_schedules;
96
+ //! Schedules defined by the user for specific input sizes.
97
+ //! They are also generated per device as all devices may not be the same.
98
+ //! Key: Input Encoding hash of Fusion inputs as is created by the
99
+ //! InputsIdLookup struct found inside of the FusionCache.
100
+ //! Value: A map keyed by device_id of User Defined Fusion Schedules.
101
+ std::unordered_map<size_t, std::unordered_map<int, UserSchedule>>
102
+ user_def_schedules;
103
+ //! Keeps a pointer to the last scheduled Fusion IR for printing
104
+ Fusion* last_user_def_scheduled_ir;
105
+ //! Keeps a pointer to the last executed executor for printing its cuda kernel
106
+ KernelExecutor* last_user_def_executor;
107
+ //! For thread-Safe locking of Fusion Schedules
108
+ std::mutex scheds_lock;
109
+ //! ID of fusion in python frontend fusion cache
110
+ int64_t fusion_id_ = -1;
111
+ //! Fusion IDs of input arguments for FusionState
112
+ std::vector<int64_t> inputs_fid_;
113
+ //! IDs for Extents for TensorView input arguments for FusionState
114
+ std::vector<int64_t> extents_fid_;
115
+ //! Fusion IDs of output arguments for FusionState
116
+ std::vector<int64_t> outputs_fid_;
117
+ //! Map Fusion Val to its corresponding FusionDefinition index
118
+ std::unordered_map<const Val*, int64_t> map_value_to_fid_;
119
+ };
120
+
121
+ //! \struct TrieNode
122
+ //! \brief Is the container for a Node in a prefix tree or trie
123
+ //! where each node represents a statement in a fusion definition and
124
+ //! the leaf Nodes represent a complete Fusion that is cached.
125
+
126
+ struct TrieNode {
127
+ TrieNode(
128
+ RecordFunctor* rec,
129
+ TrieNode* _parent = nullptr,
130
+ size_t _fusion_id = 0);
131
+
132
+ // Queries whether the entry denotes a leaf node which also represents
133
+ // the end of a Fusion entry in the cache.
134
+ bool isTerminal() const;
135
+ //! getException returns the cached Exception raised during construction of
136
+ //! Fusion. It returns std::nullopt if no error was thrown. This function is
137
+ //! called at the end of FusionDefinition::finalizeDefinition to avoid
138
+ //! silently using a bad FusionDefinition cached in FusionCache.
139
+ std::optional<std::string> getException();
140
+ //! setException is called to record exception message thrown during
141
+ //! construction of Fusion.
142
+ void setException(const char* e);
143
+ //! Serialize TrieNode using flatbuffers
144
+ NVF_API flatbuffers::Offset<serde::TrieNode> serialize(
145
+ flatbuffers::FlatBufferBuilder& builder,
146
+ const std::map<RecordFunctor*, size_t>&
147
+ map_record_functor_to_trie_node_id);
148
+
149
+ //! An entry's primary data is the record it holds
150
+ std::unique_ptr<RecordFunctor> record;
151
+ //! A hash map of the children for the current node.
152
+ //! The hash map hashes a pointer to a RecordFunctor because
153
+ //! the hash function is virtual.
154
+ std::unordered_map<RecordFunctor*, std::unique_ptr<TrieNode>> children;
155
+ //! An index into FusionCache's vector of nvFuser object that holds an
156
+ //! unscheduled Fusion. The id is only valid if the entry is terminal.
157
+ size_t fusion_id;
158
+ //! Count of times the Entry is traversed
159
+ size_t visits;
160
+ //! Parent node for printing
161
+ TrieNode* parent;
162
+ //! For thread-Safe locking of a node
163
+ std::mutex trie_node_lock;
164
+ //! exception is used to track if we failed to create a valid fusion for
165
+ //! FusionDefinition at this given TrieNode
166
+ std::optional<std::string> exception = std::nullopt;
167
+ };
168
+
169
+ //! \class FusionCache
170
+ //! \brief A singleton class used in the nvFuser python interface
171
+ //! to manage the caching of fusions.
172
+ //!
173
+ //! The fusion cache implements a prefix tree (trie) of records in order to
174
+ //! cache fusions. A leaf of the tree with a terminal node contains a
175
+ //! container for caching the kernels generated for specific fusions.
176
+ //!
177
+ //! \todo
178
+ //! Add the ability to evict a fusion. There is currently a max number
179
+ //! of fusions that is checked to prevent a runaway case.
180
+ //!
181
+ //! \note
182
+ //! Thread-Safety is assured by the Python GIL. If a no-GIL python is used
183
+ //! then further scrutiny needs to be applied to the mutexes used to limit
184
+ //! access to the singleton pointer, node creation, and user schedule
185
+ //! creation. Otherwise, the Python GIL provides a natural thread based mutex
186
+ //! that does not allow for multiple threads to interact.
187
+
188
+ class FusionCache {
189
+ //! The constructor is private given the FusionCache is only constructed
190
+ //! as a singleton.
191
+ FusionCache(size_t max_fusions, std::optional<int64_t> selected_device);
192
+
193
+ public:
194
+ //! Copy and Assignment of the FusionCache is not supported
195
+ //! clang-tidy: deleted member function should be public
196
+ FusionCache(const FusionCache&) = delete;
197
+ FusionCache& operator=(const FusionCache&) = delete;
198
+
199
+ //! The next 4 public methods are the python interface methods
200
+
201
+ //! Gets a pointer to the singleton and creates a new one if necessary
202
+ NVF_API static FusionCache* get(
203
+ size_t max_fusions = 16384,
204
+ std::optional<int64_t> selected_device = std::nullopt,
205
+ bool load_from_default_workspace = true);
206
+ //! Number of fusions cached
207
+ NVF_API size_t numFusions() const;
208
+ //! Get device associated with this FusionCache
209
+ NVF_API std::optional<int64_t> deviceId() const;
210
+ //! print cache contents
211
+ NVF_API void print(std::ostream& os) const;
212
+ //! print cache stats
213
+ NVF_API void stats(std::ostream& os) const;
214
+ //! Reset Cache to an empty state
215
+ NVF_API static void reset();
216
+
217
+ //! Serialize Fusion Cache using flatbuffers
218
+ NVF_API void serialize(std::string filename) const;
219
+ //! Deserialize Fusion Cache using flatbuffers
220
+ NVF_API void deserialize(std::string filename);
221
+
222
+ //! The rest of the public methods are only used in C++
223
+
224
+ //! Thread-Unsafe: Queries the current trie node to see if a record matches
225
+ //! one of its children
226
+ NVF_API std::optional<TrieNode*> queryChildren(
227
+ TrieNode* node,
228
+ RecordFunctor* rec) const;
229
+ //! Query a Fusion's Schedules based on fusion id or cache id
230
+ FusionSchedules* queryFusionSchedules(size_t fusion_id) const;
231
+ //! Determine if a user schedule exists for given inputs.
232
+ bool existUserSchedule(
233
+ const FusionSchedules* scheds,
234
+ const at::ArrayRef<c10::IValue>& inputs,
235
+ int device);
236
+ //! Lookup the User Schedule Id and return std::nullopt if one does not exist.
237
+ //! NOTE: this method cannot be const because the InputsIdLookup can
238
+ //! cause a modification to that data member for cache eviction.
239
+ std::optional<size_t> queryUserScheduleId(
240
+ const FusionSchedules* scheds,
241
+ const at::ArrayRef<c10::IValue>& inputs);
242
+ //! Lookup the User Schedule based on Id
243
+ const UserSchedule& queryUserSchedule(
244
+ const FusionSchedules* scheds,
245
+ size_t id,
246
+ int device) const;
247
+ //! Thread-Safe: Creates a child node for the current cache entry and an
248
+ //! optional fusion_id is returned if the new entry is terminal
249
+ NVF_API TrieNode* createChild(TrieNode* node, RecordFunctor* rec);
250
+ //! Create a User Schedule for the given inputs and device
251
+ UserSchedule* createUserSchedule(
252
+ FusionSchedules* scheds,
253
+ const at::ArrayRef<c10::IValue>& inputs,
254
+ int device,
255
+ bool overwrite_existing_schedule = false);
256
+ //! Get the root Trie ptr
257
+ NVF_API TrieNode* rootTriePtr();
258
+
259
+ private:
260
+ //! The static pointer to the FusionCache
261
+ static FusionCache* singleton_;
262
+ //! Lock for accessing the singleton by multiple threads
263
+ static std::mutex singleton_lock_;
264
+
265
+ //! The max allowed number of fusions in the cache
266
+ size_t max_fusions_;
267
+ //! A separate process is created for each device in a distributed setting.
268
+ //! Each FusionCache becomes associated with a device.
269
+ std::optional<int64_t> device_id_;
270
+ //! The root (start) of the prefix tree to start a cache look up of a given
271
+ //! fusion definition.
272
+ std::unique_ptr<TrieNode> root_;
273
+ //! A vector of nvFuser Fusion IR fusions.
274
+ std::vector<std::unique_ptr<FusionSchedules>> fusions_;
275
+ //! A vector of Terminal trie nodes for Stats collection
276
+ std::vector<TrieNode*> terminal_nodes_;
277
+
278
+ //! Items specifically to aid user defined schedules. These data members
279
+ //! are for the mechanics of user schedule usage and don't make sense as
280
+ //! part of an abstraction
281
+
282
+ // Inputs for user defined schedules are encoded into an integer Id
283
+ // NOTE: I would prefer this be per FusionSchedules object but the container
284
+ // is not allowed to be copied or moved.
285
+ InputsIdLookup user_def_input_encodings_;
286
+ };
287
+
288
+ //! Serialize Fusion Cache to common workspace
289
+ //! /tmp/nvfuser_kernel_db/nvf_serde_[cuda_major]_[cuda_minor]_[nvrtc_major]_[nvrtc_minor]
290
+ //!
291
+ //! '''python
292
+ //! # Use atexit to automatically call serialize on program exit
293
+ //! import atexit
294
+ //! atexit.register(nvfuser.serialize)
295
+ //! '''
296
+ NVF_API void serialize();
297
+
298
+ } // namespace nvfuser::python_frontend
@@ -0,0 +1,372 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <functional>
11
+ #include <iostream>
12
+ #include <unordered_map>
13
+
14
+ #include <exceptions.h>
15
+ #include <python_frontend/distributed_tensor.h>
16
+ #include <python_frontend/fusion_state.h>
17
+ #include <python_frontend/segmentation.h>
18
+ #include <visibility.h>
19
+
20
+ namespace nvfuser::python_frontend {
21
+
22
+ class FusionCache;
23
+ class FusionDefinition;
24
+ class FusionInterface;
25
+ class FusionState;
26
+ struct RecordFunctor;
27
+ class SegmentationState;
28
+ struct TrieNode;
29
+ struct UserSchedule;
30
+
31
+ //! This is a helper function used to print a python formatted
32
+ //! Fusion IR DataType when printing a fusion definition.
33
+
34
+ NVF_API const char* dtypeToPyString(PrimDataType t);
35
+
36
+ //! The Tensor and Scalar classes are used to define separate function
37
+ //! signatures in the FusionDefinition to identify the appropriate Operator
38
+ //! function.
39
+ //!
40
+ //! Example:
41
+ //!
42
+ //! add(Tensor* arg1, Tensor* arg2) -> Tensor*
43
+ //! add(Tensor* arg1, Val* arg2) -> Tensor*
44
+ //! add(Val* arg1, Val* arg2) -> Val*
45
+ struct Tensor {
46
+ Tensor(size_t _index, size_t _dims, FusionDefinition* _fd)
47
+ : index(_index), dims(_dims), fusion_definition(_fd) {}
48
+
49
+ size_t operator()() const {
50
+ return index;
51
+ }
52
+
53
+ bool operator==(const Tensor& other) const {
54
+ if (index != other.index) {
55
+ return false;
56
+ }
57
+
58
+ if (dims != other.dims) {
59
+ return false;
60
+ }
61
+
62
+ if (fusion_definition != other.fusion_definition) {
63
+ return false;
64
+ }
65
+ return true;
66
+ }
67
+
68
+ bool operator!=(const Tensor& other) const {
69
+ return !(*this == other);
70
+ }
71
+
72
+ //! A unique index to identify each recorded state item.
73
+ size_t index;
74
+ size_t dims;
75
+
76
+ //! Pointer to the FusionDefinition used to create this tensor
77
+ //! The FusionDefinition pointer is necessary to enable special
78
+ //! dunder operations (ie __add__()) from the python API.
79
+ FusionDefinition* fusion_definition;
80
+ };
81
+
82
+ struct Scalar {
83
+ Scalar(size_t _index, FusionDefinition* _fd)
84
+ : index(_index), fusion_definition(_fd) {}
85
+
86
+ size_t operator()() const {
87
+ return index;
88
+ }
89
+
90
+ bool operator==(const Scalar& other) const {
91
+ if (index != other.index) {
92
+ return false;
93
+ }
94
+
95
+ if (fusion_definition != other.fusion_definition) {
96
+ return false;
97
+ }
98
+ return true;
99
+ }
100
+
101
+ bool operator!=(const Scalar& other) const {
102
+ return !(*this == other);
103
+ }
104
+
105
+ //! A unique index to identify each recorded state item.
106
+ size_t index;
107
+
108
+ //! Pointer to the FusionDefinition used to create this scalar
109
+ //! The FusionDefinition pointer is necessary to enable special
110
+ //! dunder operations (ie __add__()) from the python API.
111
+ FusionDefinition* fusion_definition;
112
+ };
113
+
114
+ struct Vector {
115
+ Vector(size_t _index, size_t _size, FusionDefinition* _fd)
116
+ : index(_index), size(_size), fusion_definition(_fd) {}
117
+
118
+ size_t operator()() const {
119
+ return index;
120
+ }
121
+
122
+ bool operator==(const Vector& other) const {
123
+ if (index != other.index) {
124
+ return false;
125
+ }
126
+
127
+ if (size != other.size) {
128
+ return false;
129
+ }
130
+
131
+ if (fusion_definition != other.fusion_definition) {
132
+ return false;
133
+ }
134
+ return true;
135
+ }
136
+
137
+ bool operator!=(const Vector& other) const {
138
+ return !(*this == other);
139
+ }
140
+
141
+ //! A unique index to identify each recorded state item.
142
+ size_t index;
143
+ //! Elements in the vector
144
+ size_t size;
145
+
146
+ //! Pointer to the FusionDefinition used to create this vector
147
+ FusionDefinition* fusion_definition;
148
+ };
149
+
150
+ //! FusionDefinition defines the C++ side of a Python Context manager to
151
+ //! encapsulate the definition of fusion operations.
152
+ //!
153
+ //! The FusionDefinition records the state definitions and operations prior
154
+ //! to exiting the context manager. Upon exit, the operations are queried
155
+ //! in a cache and the recorded records are used to build an nvFuser Fusion
156
+ //! object if the definition missed in the cache.
157
+ //!
158
+ //! The nested Operators class was designed to allow the user to query all the
159
+ //! available Operators in the FusionDefinition via python help.
160
+ //!
161
+ //! Example:
162
+ //! help(FusionDefinition.Operators)
163
+ class NVF_API FusionDefinition : public FusionState {
164
+ public:
165
+ FusionDefinition(std::optional<size_t> id, size_t max_length = 256);
166
+
167
+ // The copy/move/assign constructors/operators are removed
168
+ FusionDefinition(const FusionDefinition& fd) = delete;
169
+ FusionDefinition(FusionDefinition&& fd) = delete;
170
+ FusionDefinition& operator=(const FusionDefinition& fd) = delete;
171
+ FusionDefinition& operator=(FusionDefinition&& fd) = delete;
172
+
173
+ //! Enter Python Context Manager -- Reset trie for new cache lookup
174
+ NVF_API FusionDefinition* setupDefinition();
175
+ //! Exit Python Context Manager -- Triggers Fusion IR build if it is not
176
+ //! cached
177
+ NVF_API void finalizeDefinition();
178
+ //! Check that a user schedule exists for FusionDefinition and input
179
+ //! arguments on device.
180
+ NVF_API bool existSchedule(const at::ArrayRef<c10::IValue>& inputs);
181
+ //! Setup user scheduling of a fusion
182
+ //! Copies fusion object and sets up FusionGuard
183
+ NVF_API void setupSchedule(
184
+ const at::ArrayRef<c10::IValue>& inputs,
185
+ bool overwrite_existing_schedule = false);
186
+ //! Finalized use scheduling of a fusion
187
+ //! resets FusionGuard, lowers IR to a kernel, compiles kernel
188
+ NVF_API void finalizeSchedule(const at::ArrayRef<c10::IValue>& inputs);
189
+ //! A hook that gets called right before
190
+ //! FusionDefinition.multidevice_schedule.
191
+ NVF_API void setupMultideviceSchedule();
192
+ //! A hook that gets called right after FusionDefinition.multidevice_schedule.
193
+ NVF_API void finalizeMultideviceSchedule();
194
+ //! Prints a python function representing the definition
195
+ NVF_API void print(std::ostream& os) const;
196
+ //! Executes a fusion if a valid definition or cache lookup occurred prior.
197
+ //!
198
+ //! This method returns a list of `DistributedTensor`s. Each
199
+ //! `DistributedTensor` is either the local view of a distributed tensor
200
+ //! (when the mesh is non-empty) or a non-distributed tensor
201
+ //! (when the mesh is empty).
202
+ //!
203
+ //! Alternatives considered:
204
+ //! 1. Return std::vector<std::variant<at::Tensor, DistributedTensor>>.
205
+ //! Because DistributedTensor can also represent a non-distributed tensor, I
206
+ //! chose the current API for simplicity -- C++ is more verbose than Python
207
+ //! when dealing with dynamic types.
208
+ //! 2. Return std::variant<std::vector<at::Tensor>,
209
+ //! std::vector<DistributedTensor>>. Same reason.
210
+ //! 3. Store output shardings (i.e. the mesh and the mesh-to-tenseor-axis
211
+ //! mapping) to a field of FusionDefinition and retrieve it using another
212
+ //! method. This would be similar to getDebugOutput. I didn't choose that
213
+ //! because it introduced a new state in the class that could get out of sync.
214
+ NVF_API std::vector<DistributedTensor> execute(
215
+ const at::ArrayRef<c10::IValue>& inputs,
216
+ std::optional<int8_t> device,
217
+ bool override_user_schedule,
218
+ bool capture_debug_output,
219
+ bool profile,
220
+ std::vector<std::string> _enable_options,
221
+ std::vector<std::string> _disable_options) const;
222
+ //! Return debugging output captured through exeuction with
223
+ //! capture_debug_output=true
224
+ std::optional<std::string> getDebugOutput() const {
225
+ return debug_output_;
226
+ }
227
+ // Returns the tolerances values based on reduction sizes.
228
+ NVF_API std::vector<std::pair<double, double>> getValTolerances(
229
+ const at::ArrayRef<c10::IValue>& inputs);
230
+
231
+ //! Return the unscheduled Fusion IR
232
+ NVF_API std::string fusionIr();
233
+ //! Return the user scheduled FusionIR;
234
+ NVF_API std::string userScheduleIr();
235
+ //! Return the Cuda code for the last executed set of inputs
236
+ NVF_API std::string lastCudaCode(
237
+ bool intrinsic_code,
238
+ bool override_user_schedule) const;
239
+ //! Return the Cuda code for the given inputs
240
+ NVF_API std::string cudaCodeFor(
241
+ const at::ArrayRef<c10::IValue>& inputs,
242
+ bool intrinsic_code,
243
+ bool override_user_schedule) const;
244
+ //! Return the Cuda code for the last executed set of inputs
245
+ NVF_API std::string lastScheduledFusionIr(
246
+ bool tensor_transforms,
247
+ bool override_user_schedule) const;
248
+ //! Return the Cuda code for the given inputs
249
+ NVF_API std::string scheduledFusionIrFor(
250
+ const at::ArrayRef<c10::IValue>& inputs,
251
+ bool tensor_transforms,
252
+ bool override_user_schedule) const;
253
+ //! Return fusion id of defined FusionDefinition
254
+ NVF_API std::optional<size_t> id() const;
255
+ //! Prints the Prescheduled Fusion IR representation
256
+ void printMathIr();
257
+
258
+ bool completed() {
259
+ return id().has_value();
260
+ }
261
+
262
+ //! Return a prescheduled Fusion object
263
+ Fusion* preschedFusion();
264
+
265
+ //! Return UserSchedule struct if it exists
266
+ UserSchedule* userSchedule();
267
+
268
+ //! These methods are used to record the FusionDefinition for cache lookup
269
+
270
+ //! Defines a Tensor State Record
271
+ NVF_API Tensor addTensor(TensorView* tv);
272
+ //! Defines a Scalar State Record
273
+ NVF_API Scalar defineScalar();
274
+ //! Defines a Tensor State Record
275
+ NVF_API Tensor defineTensor(size_t dims);
276
+ //! Defines a Vector State Record
277
+ NVF_API Vector defineVector(size_t size);
278
+ //! Defines a Record that records the operation required to
279
+ //! build the corresponding Fusion IR operation on cache miss.
280
+ NVF_API void defineRecord(RecordFunctor* record);
281
+ //! Gets a Record State object
282
+ NVF_API State recordingState(size_t index) const;
283
+ //! Get all Tensors in FusionState.
284
+ NVF_API std::vector<Tensor> tensors();
285
+
286
+ //! Run segmentation algorithm on FusionDefinition. Returns the number of
287
+ //! segments.
288
+ NVF_API int64_t setupSegmentation(const at::ArrayRef<c10::IValue>& inputs);
289
+ //! Given an empty FusionDefinition and a segment id, buildSegment creates the
290
+ //! CPP Fusion, translates it to the python FusionDefinition, then return a
291
+ //! mapping from segment fusion state indices to the original fusion state
292
+ //! indices.
293
+ NVF_API std::unordered_map<int64_t, int64_t> buildSegment(
294
+ FusionDefinition& segment_fd,
295
+ int64_t segment_id);
296
+ //! After creating segments, destroy SegmentationState.
297
+ NVF_API void finalizeSegmentation();
298
+
299
+ private:
300
+ //! Returns the FusionCache Ptr that holds the cache of Fusions
301
+ FusionCache* fusionCache() const;
302
+ //! Composite operations can create hidden TensorViews in the CPP fusion
303
+ //! These TensorViews are not visible from python definition. This function
304
+ //! finds and adds them to FusionDefinition
305
+ void findHiddenTensorViews(Fusion* fusion);
306
+ //! Update Symbolic FusionStates after DynamicTransform pass
307
+ void updateSymbolicStates(
308
+ const std::unordered_map<Val*, Val*>& symbolic_to_concretized_map);
309
+ // Check that the NvFuser TensorView and the Python Tensor dimensions match.
310
+ // Apply after buildFusionIr
311
+ void verifyTensorDimensions();
312
+
313
+ //! Holds the defined maximum length of a FusionDefinition in order to
314
+ //! prevent a run away error. The user should feel free to increase this
315
+ //! number as appropriate.
316
+ size_t max_length_;
317
+ //! Fusion Cache Id for Scheduled Fusion.
318
+ std::optional<size_t> fusion_id_;
319
+ //! A pointer to the FusionCache.
320
+ FusionCache* fusion_cache_;
321
+ //! Current pointer to node in FusionCache.
322
+ TrieNode* trie_node_;
323
+
324
+ // Book keeping data members for user created schedules
325
+
326
+ //! Data member for holding previous fusion container when manually setting
327
+ //! the fusion guard.
328
+ Fusion* prev_fusion_;
329
+ //! Data member for holding the current user schedule object
330
+ UserSchedule* user_sched_;
331
+ //! Number of recording_states_ before applying user schedule
332
+ int64_t num_recording_states_presched_ = 0;
333
+ //! Data member that creates SegmentedFusion from cloned, prescheduled Fusion
334
+ //! then translates the segments to python FusionDefinitions.
335
+ std::unique_ptr<SegmentationState> segmentation_state_;
336
+
337
+ public:
338
+ //! The Operators are not directly defined in this header. They are defined
339
+ //! in the python bindings through lambda functions so the user only needs to
340
+ //! define new operators in one place.
341
+ //! Operators define what operations are fused.
342
+ struct Operators {
343
+ Operators(FusionDefinition* fd) : fusion_definition(fd) {}
344
+ bool validUse() const {
345
+ return !fusion_definition->completed();
346
+ }
347
+
348
+ FusionDefinition* fusion_definition;
349
+ };
350
+
351
+ //! The SchedOperators are not directly defined in this header. They are
352
+ //! defined in the python bindings through lambda functions so the user only
353
+ //! needs to define new operators in one place.
354
+ //! SchedOperators allow the user to define how a fusion should be blocked
355
+ //! for execution.
356
+ struct SchedOperators {
357
+ SchedOperators(FusionDefinition* fd) : fusion_definition(fd) {}
358
+ bool validUse() const {
359
+ return fusion_definition->completed();
360
+ }
361
+
362
+ FusionDefinition* fusion_definition;
363
+ };
364
+
365
+ Operators ops;
366
+ SchedOperators sched;
367
+
368
+ private:
369
+ mutable std::optional<std::string> debug_output_ = std::nullopt;
370
+ };
371
+
372
+ } // namespace nvfuser::python_frontend