nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,179 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <ATen/core/TensorBody.h>
+ #include <ATen/core/ivalue.h>
+ #include <c10/util/intrusive_ptr.h>
+
+ #include <exceptions.h>
+ #include <multidevice/multidevice.h>
+ #ifdef NVFUSER_DISTRIBUTED
+ #include <torch/csrc/distributed/c10d/Backend.hpp>
+ #include <torch/csrc/distributed/c10d/TCPStore.hpp>
+ #include <torch/csrc/distributed/c10d/Work.hpp>
+ #else
+ #include <multidevice/c10d_mock.h>
+ #endif
+ #include <visibility.h>
+
+ namespace nvfuser {
+
+ // This file implements the class Communicator, which sets up the inter-process
+ // Backend. This class contains inter-process information, such as the rank and
+ // the world size, as well as the Process Group that can be called to perform
+ // inter-process communications.
+ //
+ // Each process is associated with a unique deviceId and device. The actual MPI
+ // rank remains private to the class and should not be used by the user. The
+ // communicator class privately holds the mappings ranks <-> device IDs <->
+ // device.
+
+ using RankType = DeviceIdxType;
+
+ // Supported backends. TODO: gloo untested
+ enum class CommunicatorBackend { kNccl, kUcc, kGloo };
+
+ std::ostream& operator<<(std::ostream& out, const CommunicatorBackend& cb);
+
+ #ifdef USE_C10D_NCCL
+ constexpr CommunicatorBackend comm_backend_default = CommunicatorBackend::kNccl;
+ #else
+ constexpr CommunicatorBackend comm_backend_default = CommunicatorBackend::kUcc;
+ #endif
+ constexpr int comm_server_local_rank_default = 0;
+
+ class Communicator {
+  public:
+   static Communicator& getInstance() {
+     // Using a singleton this way isn't best practice. Ideally, we'd like to
+     // ```
+     // static Communicator communicator;
+     // ```
+     // and let the destructor clean it up at program exit after `main` returns.
+     // This however would cause a "driver shutting down" error, likely because
+     // another static variable destructor shuts down the CUDA driver before
+     // ~Communicator. Note that the order of static variable destruction
+     // across translation units is undefined.
+     //
+     // Therefore, we `new Communicator()` as a raw pointer and let the user
+     // call Communicator::getInstance().cleanup() to clean up the Communicator
+     // explicitly before the end of `main`. For example, the cleanup method is
+     // called via MultiDeviceTestEnvironment::TearDown in C++ unit tests and
+     // nvfuser._cleanup() in Python.
+     static auto* communicator = new Communicator();
+     return *communicator;
+   }
+
+   Communicator(const Communicator&) = delete;
+   Communicator& operator=(const Communicator&) = delete;
+   ~Communicator() = delete;
+   // As explained in `getInstance`, the user of this class is expected to call
+   // this method to clean up the singleton. It can only be called once.
+   void cleanup();
+
+   // returns whether the distributed config is available
+   auto is_available() const {
+     return is_available_;
+   }
+
+   // returns the number of processes in the communicator
+   auto size() const {
+     return size_;
+   }
+
+   // returns the local number of processes in the communicator (within the node)
+   auto local_size() const {
+     return local_size_;
+   }
+
+   // sets the communicator's default backend
+   void setDefaultBackend(CommunicatorBackend backend) {
+     default_backend_ = backend;
+   }
+
+   // performs a blocking barrier in the communicator
+   void barrier(std::optional<CommunicatorBackend> backend = std::nullopt);
+
+   // returns the backend associated with a team.
+   // The argument "prefix" is prepended to the key used to retrieve preexisting
+   // backends. Prefix is used to distinguish between different backends with
+   // the same team.
+   c10d::Backend* getBackendForTeam(
+       const Team& team,
+       std::optional<CommunicatorBackend> backend,
+       const std::string& prefix = "");
+
+   // returns the device associated with the current process
+   auto device() const {
+     return at::Device("cuda:" + std::to_string(local_rank_));
+   }
+
+   // returns the device Id associated with the current process
+   DeviceIdxType deviceId() const {
+     return rankToDiD(rank_);
+   }
+
+   // returns the local rank associated with the current process,
+   // i.e. the rank within a machine/node as opposed to the rank within the
+   // world.
+   RankType local_rank() const {
+     return local_rank_;
+   }
+
+   // returns the world backend for the given communicator backend, or the
+   // default backend if not specified.
+   c10d::Backend* getWorld(
+       std::optional<CommunicatorBackend> backend = std::nullopt);
+
+   // returns whether a backend is available for creation
+   bool isBackendAvailable(CommunicatorBackend backend) const {
+     if (backend == CommunicatorBackend::kUcc) {
+       return ucc_available_;
+     } else if (backend == CommunicatorBackend::kNccl) {
+       return nccl_available_;
+     }
+     return false;
+   }
+
+  private:
+   Communicator(
+       CommunicatorBackend backend = comm_backend_default,
+       RankType server_local_rank = comm_server_local_rank_default);
+
+   // returns the rank corresponding to a device index
+   RankType dIdToRank(DeviceIdxType d_id) const {
+     return static_cast<RankType>(d_id);
+   }
+
+   // returns the device index corresponding to a rank
+   DeviceIdxType rankToDiD(RankType rank) const {
+     return static_cast<DeviceIdxType>(rank);
+   }
+
+   CommunicatorBackend getBackend(std::optional<CommunicatorBackend> backend) {
+     return backend.value_or(default_backend_);
+   }
+
+   bool is_available_;
+   CommunicatorBackend default_backend_;
+   RankType rank_;
+   int64_t size_;
+   RankType local_rank_;
+   int64_t local_size_;
+   std::string master_addr_;
+   int master_port_;
+   bool ucc_available_;
+   bool nccl_available_;
+   // stores the world's store used for the backend init
+   c10::intrusive_ptr<c10d::TCPStore> store_;
+   // cache for the created backends. The keys are strings generated from Teams
+   std::unordered_map<std::string, c10::intrusive_ptr<c10d::Backend>> backends_;
+ };
+
+ } // namespace nvfuser
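
Note: the comments in getInstance() and cleanup() imply a specific usage pattern. Below is a minimal sketch of a hypothetical driver (the surrounding main() is an assumption, not part of the wheel; the member functions are those declared above):

#include <multidevice/communicator.h>

int main() {
  using namespace nvfuser;
  Communicator& comm = Communicator::getInstance();
  if (comm.is_available()) {
    int64_t world_size = comm.size();   // processes in the communicator
    RankType local = comm.local_rank(); // rank within this node
    DeviceIdxType id = comm.deviceId(); // device index of this process
    (void)world_size;
    (void)local;
    (void)id;
    comm.barrier(); // blocking barrier on the default backend
  }
  // The destructor is deleted, so cleanup() must be called explicitly
  // before main() returns to avoid racing CUDA driver shutdown.
  comm.cleanup();
  return 0;
}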
@@ -0,0 +1,95 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+
+ #pragma once
+
+ #include <vector>
+
+ #include <exceptions.h>
+ #include <multidevice/multidevice.h>
+ #include <type.h>
+ #include <visibility.h>
+
+ namespace nvfuser {
+
+ // The class DeviceMesh represents a set of (unique) devices on which a
+ // Pipeline Stage will be executed. For now, we only support flat meshes, but
+ // later we will add support for n-dimensional meshes.
+ class DeviceMesh final {
+  public:
+   // https://google.github.io/styleguide/cppguide.html#Implicit_Conversions
+   //
+   // Not using `explicit` for the constructor that takes a vector would lead
+   // to contention between operator<<(std::vector) defined in
+   // c10/util/Logging.h and operator<<(DeviceMesh) defined later in this file,
+   // which would be resolved arbitrarily by the compiler.
+   //
+   // There is no such contention for std::initializer_list, so I chose to
+   // allow implicit conversion for that. This allows users to write
+   // `DeviceMesh mesh = {1, 2};`, which is more concise.
+   explicit DeviceMesh(std::vector<DeviceIdxType> devices = {});
+   DeviceMesh(std::initializer_list<DeviceIdxType> devices);
+   DeviceMesh(const DeviceMesh&) = default;
+   DeviceMesh(DeviceMesh&&) = default;
+   DeviceMesh& operator=(const DeviceMesh&) = default;
+   DeviceMesh& operator=(DeviceMesh&&) = default;
+
+   // Creates a device mesh of [0 .. num_devices-1]. I didn't make it a
+   // constructor because single-element initializer lists would be directed to
+   // use that instead of the constructor for vectors.
+   static DeviceMesh createForNumDevices(int64_t num_devices);
+
+   // Returns the number of devices in the mesh
+   int64_t size() const {
+     return static_cast<int64_t>(vector_.size());
+   }
+
+   int64_t size(ParallelType parallel_type) const;
+
+   // Returns a vector containing the device indices of the mesh
+   const std::vector<DeviceIdxType>& vector() const {
+     return vector_;
+   }
+
+   // Returns whether a device is present in the mesh
+   bool has(const DeviceIdxType device) const {
+     return std::find(vector_.begin(), vector_.end(), device) != vector_.end();
+   }
+
+   // Returns the index of device in the mesh, or -1 if device is not present.
+   int64_t idxOf(const DeviceIdxType device) const {
+     auto it = std::find(vector_.begin(), vector_.end(), device);
+     if (it != vector_.end()) {
+       return std::distance(vector_.begin(), it);
+     }
+     return -1;
+   }
+
+   // Returns the device at a particular index in the mesh
+   DeviceIdxType at(int64_t index) const {
+     return vector_.at(index);
+   }
+
+   bool operator==(const DeviceMesh& other) const {
+     return vector_ == other.vector();
+   }
+
+   bool operator!=(const DeviceMesh& other) const {
+     return vector_ != other.vector();
+   }
+
+  private:
+   void setDevices(std::vector<DeviceIdxType> devices);
+
+   // stores the list of device indices
+   std::vector<DeviceIdxType> vector_;
+ };
+
+ std::ostream& operator<<(std::ostream& out, const DeviceMesh& mesh);
+
+ } // namespace nvfuser
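
Note: the constructor comments above translate into the following construction patterns. A small sketch using only the members declared in this header (the wrapper function is illustrative):

#include <multidevice/device_mesh.h>

void deviceMeshExamples() {
  using namespace nvfuser;
  // Implicit conversion is allowed only for initializer lists, so the
  // concise form compiles:
  DeviceMesh mesh = {1, 2};
  // A vector requires the explicit constructor:
  DeviceMesh from_vec(std::vector<DeviceIdxType>{0, 1, 2, 3});
  // Factory for the contiguous mesh [0 .. 3]; a constructor would clash
  // with single-element initializer lists:
  DeviceMesh four = DeviceMesh::createForNumDevices(4);

  bool present = mesh.has(2);       // true: device 2 is in the mesh
  int64_t idx = four.idxOf(7);      // -1: device 7 is not in the mesh
  DeviceIdxType first = four.at(0); // 0
  (void)present;
  (void)idx;
  (void)first;
}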
@@ -0,0 +1,107 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <c10/core/DeviceType.h>
+ #include <exceptions.h>
+ #include <fusion.h>
+ #include <fusion_segmenter.h>
+ #include <host_ir/executor.h>
+ #include <ir/cloner.h>
+ #include <multidevice/communication.h>
+ #include <multidevice/communicator.h>
+ #include <multidevice/multidevice.h>
+
+ namespace nvfuser {
+
+ /*
+ The MultiDeviceExecutor executes a Fusion in a multi-device setting.
+ It is instantiated from a Fusion and a Communicator.
+
+ The Fusion must be scheduled prior to the instantiation of the
+ MultiDeviceExecutor. One can use the multidevice scheduling API to specify
+ the desired tensor sharding. It is composed of two aspects:
+ *) Set each tensor's DeviceMesh, through TensorView::setDeviceMesh
+ *) parallelize each tensor axis, possibly with the multidevice sharding
+    parallel type ParallelType::DIDx
+
+ We make the following assumptions on the Fusion:
+ - Only one (non-reduction) axis is allowed to be parallelized
+   with ParallelType::DIDx. Moreover, this axis cannot be split/merged.
+ - We only support 1D device meshes for now
+ - We only support TensorViews in communication segments.
+
+ Summary of the different steps performed by the MultiDeviceExecutor:
+ I. At instantiation:
+ - resharding "Set" exprs are automatically inserted in the fusion where a
+   network communication is needed. See the function insertReshardings.
+ - the Fusion is segmented into segments which can be of two types:
+   1) compute segments, composed of non-Resharding expressions only,
+      which can be executed purely on a single device
+   or
+   2) communication segments, composed of exactly one resharding expression,
+      which can be either a "Set" or a "Reduce" Expr.
+ - the runtime order of execution of the different segments is computed in
+   prepareRuntimeOrder
+
+ II. At runtime, through the method runWithInput:
+ - allocateRecvBuffers allocates on each device the necessary buffers to
+   store the data received from network communications
+ - Each (compute or comm) segment is executed separately, in order:
+   1) each compute segment is transformed into a fusion, compiled and executed
+      on a single device, see postKernel
+   2) each comm segment is lowered into a series of communications (defined in
+      multidevice/communications.h) and posted on the stream.
+      "Wait" primitives are also posted on the stream.
+
+ TODOS:
+ *) the MultiDeviceExecutor should be integrated into FusionExecutorCache.
+ *) The different steps should be divided into compilation, allocation,
+    runtime, etc. This will be done along the way once we have a better
+    symbolic representation of the multidevice modules
+ *) Allocation of buffers needs to be reimplemented
+ *) Need to work on auto-scheduling, in particular, to combine inter-/intra-
+    device scheduling.
+ */
+
+ class MultiDeviceExecutor {
+  public:
+   MultiDeviceExecutor(
+       std::unique_ptr<Fusion> fusion,
+       Communicator& comm,
+       hir::HostIrEvaluatorParams params = hir::HostIrEvaluatorParams());
+
+   // Runs the fusion on several devices with the given global inputs
+   std::vector<at::Tensor> runWithInput(const std::vector<c10::IValue>& inputs);
+
+   // Returns the Communicator
+   Communicator* comm() const {
+     return &comm_;
+   }
+
+   // Checks if the runtime is valid and returns an error msg.
+   // An empty message means that the runtime is valid.
+   std::string validate() const {
+     return host_ir_executor_->canRun();
+   }
+
+   //! Print to default debugging output stream
+   std::ostream& print(std::ostream& os = debug());
+
+   const auto& getFusionExecutorCaches() {
+     return host_ir_executor_->getFusionExecutorCaches();
+   };
+
+  private:
+   // holds the Communicator to be used for execution
+   Communicator& comm_;
+   // holds the HostIrEvaluator used for execution
+   std::unique_ptr<hir::HostIrEvaluator> host_ir_executor_;
+ };
+
+ } // namespace nvfuser
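
Note: a condensed sketch of the instantiate-then-run flow described in the block comment above. The fusion is assumed to be already scheduled with the multidevice API; the wrapper function and error handling are assumptions, while the constructor, validate(), and runWithInput() are the members declared in this header:

#include <multidevice/executor.h>

#include <memory>
#include <stdexcept>

std::vector<at::Tensor> runSharded(
    std::unique_ptr<nvfuser::Fusion> fusion,
    const std::vector<c10::IValue>& inputs) {
  using namespace nvfuser;
  // Each TensorView must already have a DeviceMesh set and at most one
  // non-reduction axis parallelized with ParallelType::DIDx.
  Communicator& comm = Communicator::getInstance();
  MultiDeviceExecutor executor(std::move(fusion), comm);
  // validate() returns an empty string iff the runtime can run.
  std::string error = executor.validate();
  if (!error.empty()) {
    throw std::runtime_error(error);
  }
  return executor.runWithInput(inputs);
}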
@@ -0,0 +1,18 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+
+ #pragma once
+
+ #include <c10/core/Device.h>
+
+ namespace nvfuser {
+ using DeviceIdxType = int64_t;
+ using DimensionType = int;
+ using DeviceType = c10::Device;
+ using Team = std::vector<DeviceIdxType>;
+ } // namespace nvfuser
@@ -0,0 +1,187 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <c10/util/ArrayRef.h>
+
+ #include <compute_at_map.h>
+ #include <fusion.h>
+ #include <id_model/id_model.h>
+ #include <ir/interface_nodes.h>
+ #include <multidevice/multidevice.h>
+ #include <visibility.h>
+
+ namespace nvfuser {
+
+ // Returns true iff nvFuser was compiled with distributed APIs enabled.
+ NVF_API bool distributedEnabled();
+
+ // For a resharding expression, either a set or reduce, returns the root IDs
+ // that change sharding:
+ // (1) sharded root IterDomains that are added by the expression,
+ //     i.e. sharded IterDomains present in the output but not the input.
+ // (2) sharded root IterDomains that are removed by the expression,
+ //     i.e. sharded IterDomains present in the input but not the output.
+ // TODO: Analyze loop domain for unsharded/sharded IDs and return their
+ // parent root IDs.
+ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges(
+     TensorView* producer,
+     TensorView* consumer);
+
+ // Returns whether a TensorView has a non-reduction axis parallelized on DIDx.
+ // Checks that the other non-reduction axes are not parallelized on DIDx.
+ bool isSharded(const TensorView*);
+
+ // Returns the number of device dimensions in a TensorView's loop domain.
+ int64_t numDeviceDims(const TensorView*);
+
+ // Returns the subset of tvs whose elements have a different multi-device
+ // sharding than ref
+ template <typename TvIterator>
+ std::unordered_set<TensorView*> getTvsWithDifferentSharding(
+     TensorView* ref,
+     TvIterator tvs) {
+   std::unordered_set<TensorView*> ret;
+   const auto& reference_dom = ref->getLoopDomain();
+   FusionGuard fg(ref->fusion());
+   auto ca_map = ComputeAtMap(FusionGuard::getCurFusion());
+   std::unordered_map<IterDomain*, IterDomain*> concrete_to_reference_map;
+   for (auto id : reference_dom) {
+     auto ca_id =
+         ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE_RESIZE);
+     concrete_to_reference_map[ca_id] = id;
+   }
+
+   for (TensorView* tv : tvs) {
+     if (ref->getDeviceMesh().vector() != tv->getDeviceMesh().vector()) {
+       ret.insert(tv);
+       continue;
+     }
+     for (auto id : tv->getLoopDomain()) {
+       auto ca_id =
+           ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE_RESIZE);
+       if (concrete_to_reference_map.count(ca_id) > 0) {
+         auto ref_id = concrete_to_reference_map.at(ca_id);
+         if ((ref_id->isDeviceDim() || id->isDeviceDim()) &&
+             ref_id->getParallelType() != id->getParallelType()) {
+           ret.insert(tv);
+           break;
+         }
+       }
+     }
+   }
+   return ret;
+ }
+
+ // Returns whether an Expr embeds multi-device resharding
+ bool isResharding(const Expr* expr);
+
+ // Returns whether two tensors have different shardings. Expects a
+ // producer/consumer relationship between the arguments.
+ bool haveDifferentShardings(
+     const TensorView* producer,
+     const TensorView* consumer,
+     const IdModel& id_model);
+
+ // Returns whether a resharding expr reshards an inner axis
+ bool isInnerResharding(Expr* expr);
+
+ // Shards all tensors in tvs like reference
+ void shardAllLike(TensorView* ref, std::vector<TensorView*> tvs);
+
+ // Shards all TVs between from and to, AND between TVs created inside the
+ // fusion and to. This is required because (1) expressions like rng_uniform
+ // create a TV inside the fusion that is not on a path from user-visible TVs,
+ // (2) multi-output expressions may have output tensors that are not along a
+ // path to the fusion output and would not be reachable otherwise, and (3)
+ // sharding propagation checks that all TVs in the fusion are assigned a
+ // device mesh, regardless of whether they are reachable. To keep the checks
+ // simple, we require that all TVs in the fusion are assigned a mesh.
+ void shardBetween(
+     const std::vector<TensorView*>& from,
+     const std::vector<TensorView*>& to,
+     TensorView* ref);
+ // Same as above but using the outputs of the from and to expressions
+ // to form the from and to TVs.
+ void shardBetween(
+     const std::vector<Expr*>& from,
+     const std::vector<Expr*>& to,
+     TensorView* ref);
+
+ // Returns the devices involved in an expr
+ std::set<DeviceIdxType> involvedDevices(Expr* expr);
+
+ // Returns the number of device indices present across all
+ // device meshes in the Fusion
+ int64_t requestedNumberOfDevices(Fusion*);
+
+ // removes the multi-device scheduling annotations
+ void unshard(Fusion*);
+ void unshard(TensorView*);
+
+ // Returns the index of the sharded logical axis that produces the allocation
+ // IterDomain sharded on `parallel_type`. If `tv` isn't sharded on the parallel
+ // type, returns -1.
+ //
+ // This is used to correlate `tv` and its corresponding at::Tensor, e.g., by
+ // `unshardedSizes` and `shardTensor`. `at::Tensor::sizes` and
+ // `tv->getLogicalDomain()` map one-to-one modulo reduction. However, a size in
+ // `at::Tensor::sizes` is a factor of the corresponding logical IterDomain's
+ // extent if that IterDomain is sharded.
+ int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type);
+
+ // Shards the input tensor along `axis`. How the tensor gets sliced along
+ // `axis` is determined by `mesh` and `device_id`. Returns the sharded tensor.
+ at::Tensor shardTensor(
+     at::Tensor tensor,
+     int64_t axis,
+     const DeviceMesh& mesh,
+     DeviceIdxType device_id);
+
+ // Reorders a TensorView so that the DID-parallelized axes are in front.
+ void reorderDIDToFront(TensorView*);
+
+ // Given a TensorView and the shape of a sharded tensor of which certain
+ // dimensions are partially allocated, returns the global shape that'll be used
+ // to bind to the TensorView's logical domain. This is to solve #3282 so we can
+ // bind a sharded tensor to a TensorView that has a DID-parallel loop domain.
+ //
+ // For example, when `tv` is
+ //   logical: iM, iN
+ //   allocation: iDIDx{D}, iN/D, iM
+ // and `sizes` is [2, 3], the returned shape will be [2, 3D]. This is because,
+ // according to the allocation domain, iM is fully allocated and iN is sharded
+ // and thus partially allocated.
+ //
+ // If the TensorView is not sharded, this function returns `sizes`.
+ //
+ // Limitations:
+ // - The function assumes that there are no Merges from logical to the
+ //   DID-parallel IterDomains in allocation. Otherwise, it's unclear which
+ //   logical dimension this DID-parallelization should be attributed to.
+ // - The function assumes that all Splits from logical to the DID-parallel
+ //   IterDomains in allocation are even. This is because there are currently no
+ //   ways to pass in the global shape.
+ //
+ // Despite these limitations, I took this approach as a shortcut to fix #3282,
+ // which blocked many other tasks. I'm however open to other better, long-term
+ // solutions. Some alternatives considered in #3282 are:
+ // - Try to bind `at::Tensor`s to allocation domains instead of logical. Many
+ //   `*Op::evaluate` methods (e.g.
+ //   https://github.com/NVIDIA/Fuser/blob/2415d904d1e9a5da7ca6fb1a55d3045bbd510341/csrc/ir/nodes.cpp#L4321-L4329)
+ //   assume the input/output `at::Tensor`s have the same dimension order as the
+ //   logical domain. Doing so would require changing them all.
+ // - Try to pass into FusionExecutorCache both logical (global) shapes and
+ //   allocated (local) tensors for sharded TensorViews. The logical shapes would
+ //   have to be passed through FusionKernelRuntime, FusionExecutor,
+ //   ExpressionEvaluator, and so on, which is an API overhaul.
+ std::vector<int64_t> unshardedSizes(
+     const TensorView* tv,
+     c10::IntArrayRef sizes);
+
+ } // namespace nvfuser
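
Note: to make the worked example in the unshardedSizes() comment concrete, here is a hypothetical call site. The tv is assumed to match that example, with logical domain (iM, iN) and allocation domain (iDIDx{D}, iN/D, iM):

#include <multidevice/utils.h>

std::vector<int64_t> globalShape(const nvfuser::TensorView* tv) {
  // The locally allocated tensor has shape [2, 3]: iM is fully allocated
  // (2 elements) and iN is sharded D ways (3 elements per device).
  std::vector<int64_t> global = nvfuser::unshardedSizes(tv, {2, 3});
  // Per the contract documented above, global == {2, 3 * D}.
  return global;
}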
@@ -0,0 +1,86 @@
+ // clang-format off
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  */
+ // clang-format on
+ #pragma once
+
+ #include <exceptions.h>
+ #include <visibility.h>
+
+ #include <ir/all_nodes.h>
+ #include <iter_visitor.h>
+
+ namespace nvfuser {
+
+ //! See doc/reading/divisibility-of-split.md#predication
+ //! If an IterDomain is split and its inner output domain is
+ //! eventually split too, the second split must be divisible or the
+ //! inner domain must be predicated. This class finds Split
+ //! expressions that need to be divisible or predicated.
+ //!
+ //! Second splits are not limited to direct output domains of
+ //! first splits but include indirect descendant domains as well.
+ //!
+ //! Predicating non-divisible split domains does not work if split
+ //! output domains are vectorized, i.e., when ParallelType::Vectorize is
+ //! applied to an inner domain of splits. If such a split is non-divisible,
+ //! predicating the input domain of the non-divisible split would predicate
+ //! out the vectorized operation entirely, since we do not generate a
+ //! fall-back non-vectorized else path. A runtime check is done for those
+ //! domains instead.
+ class NVF_API NonDivisibleSplitInfo : public IterVisitor {
+  public:
+   void build(Fusion* fusion);
+
+   const auto& splitsToPredicate() const {
+     return splits_to_predicate_;
+   }
+
+   const auto& splitsToValidate() const {
+     return splits_to_validate_;
+   }
+
+  private:
+   using IterVisitor::handle;
+
+   void handle(Split* split) override;
+
+   void handle(Merge* merge) override;
+
+   //! True if reachable from inner domains of splits
+   bool isReachableFromInnerDomains(IterDomain* id) const;
+
+   //! Forward propagate the reachability information
+   void propagateReachability(Split* split, bool is_protected);
+
+   //! Forward propagate the reachability information
+   void propagateReachability(Merge* merge);
+
+   void clearReachability();
+
+   //! Returns the extent of a split output domain if it's not proven to
+   //! be divisible.
+   Val* getMaybeNonDivisibleExtent(Split* split) const;
+
+   //! Remove redundant predicates as divisibility may be validated at
+   //! run time
+   void removeRedundancy();
+
+   //! Add validations to GpuLower::current()->validations()
+   void addValidations();
+
+  private:
+   //! Split expressions whose input domain must be predicated
+   std::unordered_map<TensorView*, std::vector<Split*>> splits_to_predicate_;
+   //! Split expressions whose divisibility must be validated at run time
+   std::unordered_set<Split*> splits_to_validate_;
+
+   //! Temporarily used for analyzing each tensor
+   TensorView* current_tv_ = nullptr;
+   std::unordered_set<IterDomain*> inner_domains_;
+ };
+
+ } // namespace nvfuser
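
Note: a hypothetical scheduling fragment illustrating the chained-split condition the class detects. TensorViewBuilder, set(), and split() are assumed from nvFuser's C++ scheduling API; only the two splits matter:

#include <fusion.h>
#include <ops/alias.h>

void nonDivisibleSplitExample() {
  using namespace nvfuser;
  Fusion fusion;
  FusionGuard fg(&fusion);
  TensorView* tv0 = TensorViewBuilder().ndims(1).build();
  fusion.addInput(tv0);
  TensorView* tv1 = set(tv0);
  fusion.addOutput(tv1);
  tv1->split(0, 4); // first split: extent -> (ceilDiv(extent, 4), 4)
  // Second split of the first split's inner output by 3. Since 4 % 3 != 0,
  // this split is non-divisible: NonDivisibleSplitInfo records it either in
  // splitsToPredicate() (its input domain must be predicated) or, when
  // divisibility can only be decided from concrete runtime extents, in
  // splitsToValidate().
  tv1->split(1, 3);
}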