nvfuser-cu121-torch25 0.2.25.dev20250201__cp310-cp310-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,27 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <ATen/ATen.h>
11
+
12
+ namespace nvfuser {
13
+
14
+ //! This returns a slice of a thread local at::Tensor that contains all zeroes.
15
+ //! Uses of this memory should always "clean up" by resetting the memory to zero
16
+ //! at the end of the kernel.
17
+ at::Tensor contigZeroedTensor(
18
+ const std::vector<int64_t>& sizes,
19
+ const c10::ScalarType& aten_dtype,
20
+ const c10::Device& device);
21
+
22
+ //! This should be called after each kernel launch to allow subsequent launches
23
+ //! to re-use allocated memory. Note that it does not free allocated zeroed
24
+ //! memory, but rather it marks all zeroed memory as available for re-use.
25
+ void releaseZeroedMemory();
26
+
27
+ } // namespace nvfuser
@@ -0,0 +1,47 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <exceptions.h>
11
+ #include <ir/all_nodes.h>
12
+ #include <visibility.h>
13
+
14
+ namespace nvfuser {
15
+
16
+ //! Horizontally fuse multiple reductions.
17
+ //!
18
+ //! Given a list of tensors produced by ReductionOp, create a new
19
+ //! GroupedReductionOp expression that takes the input tensors of the
20
+ //! original reductions and produces the given tensors, replacing
21
+ //! their defining expressions.
22
+ //!
23
+ //! GroupedReductionOp works just like ReductionOp with a potential
24
+ //! benefit of aggregating synchronizations across individual
25
+ //! reductions. See the reduction::gridReduce2 runtime function for a
26
+ //! two-input version of grid reduction.
27
+ //!
28
+ //! The grouped reductions must follow several constraints, which
29
+ //! include:
30
+ //! - There must not exist any data dependency between individual
31
+ //! reductions.
32
+ //! - All reduction output tensors must have the same number of
33
+ //! dimensions, the same transformations and the same axes to
34
+ //! reduce.
35
+ //!
36
+ //! Note that Welford is not allowed yet, though it should be
37
+ //! technically straightforward to support horizontal fusions of
38
+ //! welford ops. Unclear how common it would be in practice, though.
39
+ //!
40
+ //! \param reduction_outputs Tensors produced by ReductionOp
41
+ //! \param error_on_failure Throw an exception if an error is detected
42
+ //! \return True if successfully grouped
43
+ NVF_API bool groupReductions(
44
+ const std::vector<TensorView*>& reduction_outputs,
45
+ bool error_on_failure = true);
46
+
47
+ } // namespace nvfuser
@@ -0,0 +1,60 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <fusion.h>
11
+ #include <host_ir/host_ir.h>
12
+
13
+ namespace nvfuser {
14
+
15
+ class KernelExecutor;
16
+
17
+ namespace hir {
18
+
19
+ /*
20
+ HostIrContainer is used to represent a host program.
21
+ 1) It inherits from Fusion, so that (Host) IRs can be registered to it.
22
+ 2) It holds a vector of Host Expressions `top_level_exprs_` that represent the
23
+ host program. For now, this vector is manually managed. Moreover, because we use
24
+ a vector as data structure, top_level_exprs_ can only represent linear Host
25
+ programs. Later, it should support non-linear programs having a DAG structure.
26
+ */
27
+
28
+ class HostIrContainer final : public Fusion {
29
+ public:
30
+ HostIrContainer() = default;
31
+ HostIrContainer(const HostIrContainer&) = delete;
32
+ HostIrContainer& operator=(const HostIrContainer&) = delete;
33
+
34
+ // Do not have a definition here as it requires the definition of
35
+ // KernelExecutor due to kernel_executors_.
36
+ // NOLINTNEXTLINE (modernize-use-equals-default)
37
+ ~HostIrContainer() override;
38
+
39
+ //! Print to an output stream
40
+ std::ostream& print(std::ostream& os) const;
41
+
42
+ const std::vector<Expr*>& topLevelExprs() const;
43
+
44
+ void pushBackTopLevelExprs(Expr* expr);
45
+
46
+ void pushBackKernelExecutor(std::unique_ptr<KernelExecutor> ke);
47
+
48
+ KernelExecutor* getKernelExecutor(int64_t index) const;
49
+
50
+ Stream* getDefaultStream();
51
+
52
+ private:
53
+ std::vector<Expr*> top_level_exprs_;
54
+ std::vector<std::unique_ptr<KernelExecutor>> kernel_executors_;
55
+ Stream* default_stream_ = nullptr;
56
+ };
57
+
58
+ } // namespace hir
59
+
60
+ } // namespace nvfuser
@@ -0,0 +1,152 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <dispatch.h>
11
+ #include <expr_evaluator.h>
12
+ #include <host_ir/container.h>
13
+ #include <host_ir/host_ir.h>
14
+ #include <multidevice/communicator.h>
15
+ #include <runtime/executor.h>
16
+ #include <runtime/executor_abstract.h>
17
+ #include <runtime/executor_params.h>
18
+ #include <runtime/fusion_executor_cache.h>
19
+
20
+ #include <c10/cuda/CUDAStream.h>
21
+
22
+ namespace nvfuser {
23
+
24
+ class HostIrExecutor : public ExecutorAbstract {
25
+ public:
26
+ HostIrExecutor(
27
+ int64_t fusion_id = 0,
28
+ int64_t concrete_id = 0,
29
+ int64_t runtime_id = 0,
30
+ int64_t group_id = 0);
31
+
32
+ static bool supported(Fusion* fusion);
33
+
34
+ void compile(Fusion* fusion);
35
+
36
+ bool isCompiled() const override;
37
+
38
+ NVF_API std::vector<at::Tensor> run(
39
+ KernelArgumentHolder& args,
40
+ std::vector<at::Tensor> outputs = {});
41
+
42
+ const std::unique_ptr<hir::HostIrContainer>& hostContainer() const {
43
+ return host_ir_container_;
44
+ }
45
+
46
+ private:
47
+ std::unique_ptr<hir::HostIrContainer> host_ir_container_;
48
+ Communicator* communicator_;
49
+ };
50
+
51
+ namespace hir {
52
+
53
+ /*
54
+ A HostIrEvaluator evaluates a host program represented through a
55
+ HostIrContainer. It is instantiated with the desired HostIrContainer, and runs
56
+ the Host program with concrete inputs by calling the method runWithInput.
57
+
58
+ For now HostIrEvaluator is an interpreter; later we could rather compile host
59
+ code.
60
+
61
+ Note: most of the implementation is copy-pasted from MultiDeviceExecutor. This
62
+ duplication will be resolved in the future.
63
+ */
64
+
65
+ // Set of parameters that control the behavior of HostIrEvaluator
66
+ struct HostIrEvaluatorParams {
67
+ // Experimental: whether to use FusionExecutorCache rather than
68
+ // KernelExecutor.
69
+ bool use_fusion_executor_cache = false;
70
+ // Experimental: whether to skip auto-scheduling in FusionExecutorCache if
71
+ // use_fusion_executor_cache=true. WAR: temporary hack mainly used for
72
+ // development
73
+ bool skip_auto_scheduling = false;
74
+ // Experimental: whether to cache fusion executor. WAR: avoid recompilation
75
+ // but implicitly assumes that the input shapes don't change over iterations
76
+ bool cache_fusion_executor = false;
77
+ // number of additional cuda streams to use at runtime for comm+compute
78
+ // pipelining
79
+ int64_t number_of_streams = 4;
80
+ };
81
+
82
+ class HostIrEvaluator final : public OptOutDispatch {
83
+ public:
84
+ HostIrEvaluator(
85
+ std::unique_ptr<HostIrContainer> container,
86
+ Communicator* communicator = nullptr,
87
+ HostIrEvaluatorParams = HostIrEvaluatorParams());
88
+ std::vector<at::Tensor> runWithInput(
89
+ std::unordered_map<Val*, c10::IValue> val_to_IValue);
90
+
91
+ const std::vector<Val*>& inputs() {
92
+ return container_->inputs();
93
+ }
94
+
95
+ const std::vector<Val*>& outputs() {
96
+ return container_->outputs();
97
+ }
98
+
99
+ std::ostream& print(std::ostream& os) const {
100
+ return container_->print(os);
101
+ };
102
+
103
+ const auto& getFusionExecutorCaches() {
104
+ return fec_;
105
+ };
106
+
107
+ const auto& getCudaStreams() {
108
+ return streams_;
109
+ }
110
+
111
+ // Checks whether the runtime is valid and returns an error message.
112
+ // An empty message means that the runtime is valid
113
+ std::string canRun() const;
114
+
115
+ private:
116
+ using OptOutDispatch::handle;
117
+ void handle(SetCurrentStream* set_current_stream) override;
118
+ void handle(GetCurrentStream* get_current_stream) override;
119
+ void handle(Synchronize* synchronize) override;
120
+ void handle(PostOnStream* post_ir) override;
121
+ void handle(LaunchKernel* post_ir) override;
122
+ void handle(Communication* communication) override;
123
+ void handle(P2PCommunication* communication) override;
124
+ void handle(Wait* wait) override;
125
+ void handle(ForLoop* for_loop) override;
126
+ void handle(StartCoalescing* start_coalescing) override;
127
+ void handle(EndCoalescing* end_coalescing) override;
128
+ void handle(kir::IfThenElse* if_then_else) override;
129
+ void handle(MatmulOp* matmul) override;
130
+ void handle(LinearOp* linear) override;
131
+ void handle(kir::Allocate* allocate) override;
132
+ void unhandled(Statement* stmt) override;
133
+
134
+ c10::cuda::CUDAStream getCUDAStream(Stream* stream);
135
+
136
+ std::unique_ptr<HostIrContainer> container_;
137
+ Communicator* communicator_;
138
+ HostIrEvaluatorParams params_;
139
+ // Stores concrete computed values
140
+ ExpressionEvaluator expr_evaluator_;
141
+ // Cache Fusions, KernelExecutors
142
+ std::unordered_map<HostUnit*, std::unique_ptr<ExecutorAbstract>> executors_;
143
+ std::unordered_map<HostUnit*, FusionExecutorCache> fec_;
144
+ using StreamKey = std::variant<int64_t, Stream*>;
145
+ std::unordered_map<StreamKey, c10::cuda::CUDAStream> streams_;
146
+ std::unordered_map<Expr*, c10::intrusive_ptr<c10d::Work>> works_;
147
+ const int64_t my_device_index_;
148
+ };
149
+
150
+ } // namespace hir
151
+
152
+ } // namespace nvfuser
@@ -0,0 +1,320 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <fusion.h>
11
+ #include <ir/base_nodes.h>
12
+ #include <ir/builder.h>
13
+ #include <multidevice/communication.h>
14
+ #include <atomic>
15
+
16
+ namespace nvfuser {
17
+
18
+ namespace hir {
19
+
20
+ /*
21
+ Host Irs are used to represent a host program. They need to be registered in a
22
+ HostIrContainer. Each Ir represents a Host data or instruction.
23
+ */
24
+
25
+ /*
26
+ HostUnit represents a Fusion in the Host Program. In other words, it
27
+ represents a compute graph (or a segment of a larger compute graph)
28
+ represented by a Fusion that should be compiled and executed as a bulked item
29
+ from the host perspective.
30
+
31
+ This IR can be thought of as a thin layer around the class `Fusion`, which
32
+ furthermore inherits from `Expr` so that it is an "IR" in nvFuser IR
33
+ semantics.
34
+
35
+ This IR fundamentally allows nested IR structures. It could potentially be
36
+ useful in other instances than HostIrs.
37
+
38
+ Its implementation is minimal, the only specificity being the method
39
+ `fusion_to_execute()` that returns the fusion that the IR represents.
40
+
41
+ Note: HostUnit has no I/O itself -- however the Fusion it embeds has I/O of
42
+ course, which are not registered in the surrounding HostIrContainer.
43
+
44
+ Note: Whether HostUnit should inherit from Expr or Val is debatable. Both are
45
+ possible, I define it as an Expr for now here but am open to change it.
46
+ */
47
+ class HostUnit : public Expr {
48
+ public:
49
+ using Expr::Expr;
50
+ HostUnit(IrBuilderPasskey passkey, std::unique_ptr<Fusion> fusion);
51
+ HostUnit(const HostUnit* src, IrCloner* ir_cloner);
52
+
53
+ HostUnit(const HostUnit& other) = delete;
54
+ HostUnit& operator=(const HostUnit& other) = delete;
55
+ HostUnit(HostUnit&& other) = delete;
56
+ HostUnit& operator=(HostUnit&& other) = delete;
57
+
58
+ NVFUSER_DECLARE_CLONE_AND_CREATE
59
+ std::string toString(int indent_size = 0) const override;
60
+ std::string toInlineString(int indent_size = 0) const override;
61
+ const char* getOpString() const override {
62
+ return "hir::HostUnit";
63
+ }
64
+
65
+ bool sameAs(const Statement* other) const override;
66
+
67
+ Fusion* fusion_to_execute() const {
68
+ return fusion_.get();
69
+ }
70
+
71
+ private:
72
+ std::unique_ptr<Fusion> fusion_;
73
+ };
74
+
75
+ /*
76
+ PostOnStream represents the host instruction of executing a HostUnit. Its I/O
77
+ represents in the host program the concrete I/O that will be bound at runtime
78
+ to the Fusion's I/O for compilation and execution. At runtime, PostOnStream
79
+ will compile and launch the kernel lowered from the HostUnit's embedded
80
+ Fusion.
81
+
82
+ Note: later PostOnStream will take a "Stream" argument
83
+
84
+ Note: later PostOnStream will also be able to launch network Communications
85
+
86
+ Note: later compilation and kernel launch will be separated and represented by
87
+ distinct Host IRs
88
+ */
89
+ class PostOnStream : public Expr {
90
+ public:
91
+ using Expr::Expr;
92
+ PostOnStream(
93
+ IrBuilderPasskey passkey,
94
+ Expr* host_op,
95
+ std::vector<Val*> inputs,
96
+ std::vector<Val*> outputs);
97
+
98
+ PostOnStream(const PostOnStream& other) = delete;
99
+ PostOnStream& operator=(const PostOnStream& other) = delete;
100
+ PostOnStream(PostOnStream&& other) = delete;
101
+ PostOnStream& operator=(PostOnStream&& other) = delete;
102
+
103
+ NVFUSER_DECLARE_CLONE_AND_CREATE
104
+
105
+ std::string toString(int indent_size = 0) const override;
106
+ std::string toInlineString(int indent_size = 0) const override;
107
+ const char* getOpString() const override {
108
+ return "hir::PostOnStream";
109
+ }
110
+
111
+ bool sameAs(const Statement* other) const override;
112
+
113
+ Expr* hostOpToPost() const {
114
+ return attributes_.at(0)->as<Expr>();
115
+ }
116
+ };
117
+
118
+ class LaunchKernel : public Expr {
119
+ public:
120
+ using Expr::Expr;
121
+ LaunchKernel(
122
+ IrBuilderPasskey passkey,
123
+ int64_t hic_executor_index, // Index into the HostIrContainer's vector of
124
+ // KernelExecutors--i.e., the kernel this IR
125
+ // should launch
126
+ const std::vector<Val*>& inputs,
127
+ const std::vector<Val*>& outputs);
128
+
129
+ LaunchKernel(const LaunchKernel& other) = delete;
130
+ LaunchKernel& operator=(const LaunchKernel& other) = delete;
131
+ LaunchKernel(LaunchKernel&& other) = delete;
132
+ LaunchKernel& operator=(LaunchKernel&& other) = delete;
133
+
134
+ NVFUSER_DECLARE_CLONE_AND_CREATE
135
+
136
+ std::string toString(int indent_size = 0) const override;
137
+ std::string toInlineString(int indent_size = 0) const override;
138
+ const char* getOpString() const override {
139
+ return "hir::LaunchKernel";
140
+ }
141
+
142
+ int64_t getIndex() const {
143
+ return attribute<int64_t>(0);
144
+ }
145
+ };
146
+
147
+ class Stream : public Val {
148
+ public:
149
+ // if index is provided, the IR represents the stream whose index is the
150
+ // dynamic value of that index. Otherwise, it statically represents a new
151
+ // Stream.
152
+ Stream(IrBuilderPasskey passkey, Val* index = nullptr);
153
+ Stream(const Stream* src, IrCloner* ir_cloner);
154
+ bool sameAs(const Statement* other) const override;
155
+
156
+ NVFUSER_DECLARE_CLONE
157
+ std::string toString(int indent_size = 0) const override;
158
+ std::string toInlineString(int indent_size = 0) const override;
159
+
160
+ Val* index() const {
161
+ return index_;
162
+ }
163
+
164
+ private:
165
+ Val* index_ = nullptr;
166
+ };
167
+
168
+ class SetCurrentStream : public Expr {
169
+ public:
170
+ using Expr::Expr;
171
+ SetCurrentStream(IrBuilderPasskey passkey, Stream* stream);
172
+
173
+ SetCurrentStream(const SetCurrentStream& other) = delete;
174
+ SetCurrentStream& operator=(const SetCurrentStream& other) = delete;
175
+ SetCurrentStream(SetCurrentStream&& other) = delete;
176
+ SetCurrentStream& operator=(SetCurrentStream&& other) = delete;
177
+
178
+ NVFUSER_DECLARE_CLONE_AND_CREATE
179
+
180
+ std::string toString(int indent_size = 0) const override;
181
+ std::string toInlineString(int indent_size = 0) const override;
182
+ const char* getOpString() const override {
183
+ return "hir::SetCurrentStream";
184
+ }
185
+
186
+ bool sameAs(const Statement* other) const override;
187
+
188
+ Stream* stream() const {
189
+ return attributes_.at(0)->as<Stream>();
190
+ }
191
+ };
192
+
193
+ class GetCurrentStream : public Expr { // hir Expr; presumably captures the current stream into stream() -- confirm in the .cpp.
194
+ public:
195
+ using Expr::Expr;
196
+ GetCurrentStream(IrBuilderPasskey passkey);
197
+ // Copy and move operations are all deleted.
198
+ GetCurrentStream(const GetCurrentStream& other) = delete;
199
+ GetCurrentStream& operator=(const GetCurrentStream& other) = delete;
200
+ GetCurrentStream(GetCurrentStream&& other) = delete;
201
+ GetCurrentStream& operator=(GetCurrentStream&& other) = delete;
202
+
203
+ NVFUSER_DECLARE_CLONE_AND_CREATE
204
+ // String rendering, plus the op name "hir::GetCurrentStream".
205
+ std::string toString(int indent_size = 0) const override;
206
+ const char* getOpString() const override {
207
+ return "hir::GetCurrentStream";
208
+ }
209
+ // Returns attribute 0 viewed as a Stream; the constructor takes no Stream, so it is presumably created internally -- confirm.
210
+ Stream* stream() const {
211
+ return attributes_.at(0)->as<Stream>();
212
+ }
213
+ };
214
+
215
+ class Wait : public Expr { // hir Expr; presumably waits for completion of communication() -- confirm in the .cpp.
216
+ public:
217
+ using Expr::Expr;
218
+ Wait(IrBuilderPasskey passkey, Expr* expr);
219
+ // Copy and move operations are all deleted.
220
+ Wait(const Wait& other) = delete;
221
+ Wait& operator=(const Wait& other) = delete;
222
+ Wait(Wait&& other) = delete;
223
+ Wait& operator=(Wait&& other) = delete;
224
+
225
+ NVFUSER_DECLARE_CLONE_AND_CREATE
226
+ // String renderings, plus the op name "hir::Wait".
227
+ std::string toString(int indent_size = 0) const override;
228
+ std::string toInlineString(int indent_size = 0) const override;
229
+ const char* getOpString() const override {
230
+ return "hir::Wait";
231
+ }
232
+ // Comparison with another Statement; defined out of line.
233
+ bool sameAs(const Statement* other) const override;
234
+ // Returns attribute 0 viewed as an Expr (presumably the constructor's `expr` -- confirm).
235
+ Expr* communication() const {
236
+ return attributes_.at(0)->as<Expr>();
237
+ }
238
+ };
239
+
240
+ // Makes the current stream wait on the given stream. Non-blocking from the
241
+ // host's point of view.
242
+ class Synchronize : public Expr {
243
+ public:
244
+ using Expr::Expr;
245
+ Synchronize(IrBuilderPasskey passkey, Stream* stream);
246
+ // Copy and move operations are all deleted.
247
+ Synchronize(const Synchronize& other) = delete;
248
+ Synchronize& operator=(const Synchronize& other) = delete;
249
+ Synchronize(Synchronize&& other) = delete;
250
+ Synchronize& operator=(Synchronize&& other) = delete;
251
+
252
+ NVFUSER_DECLARE_CLONE_AND_CREATE
253
+ // String renderings, plus the op name "hir::Synchronize".
254
+ std::string toString(int indent_size = 0) const override;
255
+ std::string toInlineString(int indent_size = 0) const override;
256
+ const char* getOpString() const override {
257
+ return "hir::Synchronize";
258
+ }
259
+ // Comparison with another Statement; defined out of line.
260
+ bool sameAs(const Statement* other) const override;
261
+ // Returns attribute 0 viewed as a Stream (presumably the stream being waited on -- confirm).
262
+ Stream* stream() const {
263
+ return attributes_.at(0)->as<Stream>();
264
+ }
265
+ };
266
+
267
+ // For ProcessGroupNCCL, startCoalescing and endCoalescing correspond to
268
+ // ncclGroupStart and ncclGroupEnd respectively. Those calls group p2p calls
269
+ // that need to be progressed together -- one global work handle returned by
270
+ // endCoalescing needs to be progressed. This has the following main advantages:
271
+ // 1) Calls are progressed concurrently.
272
+ // 2) Since NICs are two-sided, a send and a recv call need to be coalesced to
273
+ // achieve full bandwidth.
274
+ // 3) If not coalesced, we can easily reach a deadlock if the
275
+ // send/recv pairs are not ordered correctly.
276
+ // It is in general preferable to coalesce send/recv calls. The only drawback is
277
+ // that we lose fine-grained control over synchronization: we can only
278
+ // synchronize with the grouped communication as a whole.
279
+ // Remark: ProcessGroupUCC does not implement coalesced groups for now.
280
+ class StartCoalescing : public Expr {
281
+ public:
282
+ using Expr::Expr;
283
+ StartCoalescing(IrBuilderPasskey passkey);
284
+ // Copy and move operations are all deleted.
285
+ StartCoalescing(const StartCoalescing& other) = delete;
286
+ StartCoalescing& operator=(const StartCoalescing& other) = delete;
287
+ StartCoalescing(StartCoalescing&& other) = delete;
288
+ StartCoalescing& operator=(StartCoalescing&& other) = delete;
289
+
290
+ NVFUSER_DECLARE_CLONE_AND_CREATE
291
+ // String renderings, plus the op name "hir::StartCoalescing".
292
+ std::string toString(int indent_size = 0) const override;
293
+ std::string toInlineString(int indent_size = 0) const override;
294
+ const char* getOpString() const override {
295
+ return "hir::StartCoalescing";
296
+ }
297
+ };
298
+
299
+ class EndCoalescing : public Expr { // hir Expr paired with StartCoalescing; presumably closes the coalesced group -- confirm in the .cpp.
300
+ public:
301
+ using Expr::Expr;
302
+ EndCoalescing(IrBuilderPasskey passkey);
303
+ // Copy and move operations are all deleted.
304
+ EndCoalescing(const EndCoalescing& other) = delete;
305
+ EndCoalescing& operator=(const EndCoalescing& other) = delete;
306
+ EndCoalescing(EndCoalescing&& other) = delete;
307
+ EndCoalescing& operator=(EndCoalescing&& other) = delete;
308
+
309
+ NVFUSER_DECLARE_CLONE_AND_CREATE
310
+ // String renderings, plus the op name "hir::EndCoalescing".
311
+ std::string toString(int indent_size = 0) const override;
312
+ std::string toInlineString(int indent_size = 0) const override;
313
+ const char* getOpString() const override {
314
+ return "hir::EndCoalescing";
315
+ }
316
+ };
317
+
318
+ } // namespace hir
319
+
320
+ } // namespace nvfuser
@@ -0,0 +1,35 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <host_ir/container.h>
11
+ #include <ir/base_nodes.h>
12
+ #include <multidevice/communication.h>
13
+ #include <multidevice/multidevice.h>
14
+
15
+ namespace nvfuser {
16
+
17
+ class HostIrLower { // Groups static-only lowering entry points; no instance state is declared.
18
+ public:
19
+ // The flag `ignore_inner_resharding` is useful because the preseg passes
20
+ // `InsertReshardingsPass` and `ReorderShardedAxisPass` want different
21
+ // behaviors. It defaults to false.
22
+ static bool canLower(Expr* expr, bool ignore_inner_resharding = false);
23
+
24
+ // Lower a sharded Expr into a series of Communication.
25
+ static std::vector<Expr*> lower(Expr* c);
26
+ // Takes ownership of `fusion` and returns an owned hir::HostIrContainer. `my_device_index` presumably identifies the calling device -- confirm.
27
+ static std::unique_ptr<hir::HostIrContainer> lower(
28
+ std::unique_ptr<Fusion> fusion,
29
+ int64_t my_device_index);
30
+
31
+ private:
32
+ static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr); // Private helper; presumably used by lower(Expr*) -- confirm.
33
+ };
34
+
35
+ } // namespace nvfuser
@@ -0,0 +1,56 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <device_lower/lower2device.h>
11
+ #include <device_lower/utils.h>
12
+ #include <id_model/id_model.h>
13
+
14
+ namespace nvfuser {
15
+
16
+ // Get the loop index of a given loop domain for circular buffer
17
+ // loops. nullptr is returned if not relevant.
18
+ //
19
+ // This is a workaround (WAR) for circular buffering. TensorIndexer has a
20
+ // map of loop indices for all loop groups; however, it does not work with
21
+ // circular buffering. The loop graph is
22
+ // designed to represent each loop, and each loop group is supposed
23
+ // to have a one-to-one relationship with each loop. However, for
24
+ // circular buffering, this assumption is broken as we are using
25
+ // the same iter domain for the prologue, main and epilogue
26
+ // loops. Ideally, those loops should have distinct loop groups,
27
+ // but for now, here's a workaround to get a correct loop index.
28
+ Val* getLoopIndexOfCircularBufferLoop(
29
+ IterDomain* loop_id, // Loop domain whose index is requested.
30
+ const std::vector<ForLoop*>& for_loops, // Presumably the enclosing loop nest -- confirm.
31
+ const IdModel& id_model);
32
+
33
+ // For a circular-buffering expr, the producer loop index needs to be
34
+ // advanced by (#stages - 1) if it's the main loop. Returns the offset
35
+ // if it's applicable; otherwise, nullptr is returned.
36
+ Val* getLoopIndexOffsetForProducerOfCircularBuffer(
37
+ const Expr* expr, // The circular-buffering expr being indexed.
38
+ const ForLoop* for_loop, // Loop whose index may need the offset.
39
+ const IdModel& id_model);
40
+
41
+ // Get the additional offset for a circular buffer. This offset will
42
+ // be added to the normal linear index. For example, if this is a
43
+ // double-buffered tensor, the offset would look like "i % 2", where i
44
+ // is the loop index of the double-buffer loop.
45
+ Val* getOffsetForCircularBufferTensor(
46
+ TensorView* circular_buffer_tv, // Tensor whose extra offset is computed.
47
+ bool as_consumer, // Presumably selects consumer-side vs. producer-side indexing -- confirm.
48
+ const std::vector<ForLoop*>& for_loops);
49
+
50
+ // Find the circular buffering stage of a given circular-buffered tensor.
51
+ CircularBufferLoopStage getCircularBufferLoopStage(
52
+ const TensorView* circular_buffer_tv, // Tensor whose stage is looked up.
53
+ const std::vector<ForLoop*>& for_loops, // Presumably the enclosing loop nest -- confirm.
54
+ const ValGraph& loop_graph);
55
+
56
+ } // namespace nvfuser