mindspore 2.7.0__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,393 @@
+ # Copyright 2025 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """ Distributed data parallel wrapper. """
+ from __future__ import absolute_import
+
+ __all__ = ["DistributedDataParallel"]
+
+ import itertools
+ from contextlib import contextmanager
+ from typing import Optional
+ import mindspore.nn as nn
+ import mindspore.log as logger
+ from mindspore import Tensor, mint
+ from mindspore.common import dtype as mstype
+ from mindspore.mint.distributed import get_world_size
+ from mindspore.communication import GlobalComm
+ from mindspore.common.api import _pynative_executor
+ from mindspore.mint.distributed import broadcast, get_global_rank
+ from mindspore.parallel.distributed.flatten_grad_buffer import FlattenGradBuffer
+ from mindspore._c_expression import Reducer, _find_unused_parameters
+
+
+ def get_data_parallel_group():
+     """get default global data parallel group"""
+     return GlobalComm.WORLD_COMM_GROUP
+
+
+ def get_data_parallel_world_size(group):
+     """get group world size"""
+     return get_world_size(group)
+
+
+ def _find_tensors(obj):
+     if isinstance(obj, Tensor):
+         return [obj]
+     if isinstance(obj, (list, tuple)):
+         return itertools.chain.from_iterable(map(_find_tensors, obj))
+     if isinstance(obj, dict):
+         return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
+
+     return []
+
+
+ class DistributedDataParallel(nn.Cell):
+     """
+     DistributedDataParallel wrapper. DistributedDataParallel allocates a contiguous memory buffer for gradients.
+     Parameters' gradients are combined into multiple buckets, which are the unit of all-reduce communication
+     within the data parallel group, so that communication latency can be overlapped with computation.
+
+     .. warning::
+         - This method is currently only supported in PyNative mode.
+         - This is an experimental interface that may be changed or removed in the future.
+
+     Args:
+         module (nn.Cell): the module to be wrapped with DDP.
+         init_sync (bool, optional): whether to sync params from rank 0 of `process_group` at initialization.
+             Default: ``True``.
+         process_group (str, optional): the communication group for data parallel. Default: ``None``.
+         bucket_cap_mb (int, optional): size of each bucket in MB; defaults to 25 MB if not set. Default: ``None``.
+         find_unused_parameters (bool, optional): whether to find unused params in the bucket. Default: ``False``.
+         average_in_collective (bool, optional): if ``True``, allreduce-sum within the DP group first and then
+             scale by the DP size; otherwise scale the local gradients first and then allreduce-sum.
+             Default: ``False``.
+         static_graph (bool, optional): indicates whether the network is static. For a static network, the
+             parameter `find_unused_parameters` is ignored, unused parameters are searched for in the
+             first step, and buckets are rebuilt in execution order before the second step to achieve
+             better performance. Default: ``False``.
+         reducer_mode (str, optional): the reducer backend to use, either "CppReducer" for the C++ backend or
+             "PythonReducer" for the Python backend. Default: ``"CppReducer"``.
+
+     Returns:
+         Model wrapped with DistributedDataParallel.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         .. note::
+             - When enabling recomputation or gradient freezing, the model should be wrapped by
+               `DistributedDataParallel` at the outermost layer.
+             - Before running the following examples, you need to configure the communication environment
+               variables. For Ascend devices, it is recommended to use the msrun startup method
+               without any third-party or configuration file dependencies. For detailed information, refer to
+               `msrun launch <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_ .
+
+         >>> from mindspore.parallel.distributed import DistributedDataParallel
+         >>> from mindspore.mint.optim import AdamW
+         >>> from mindspore import Parameter, Tensor, ops, nn
+         >>> import mindspore as ms
+         >>> from mindspore.communication import init
+         >>> from mindspore.mint.distributed.distributed import init_process_group
+         >>> ms.set_context(mode=ms.PYNATIVE_MODE)
+         >>> init_process_group()
+         >>> # Define the network structure of LeNet5. Refer to
+         >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+         >>> net = LeNet5()
+         >>> net = DistributedDataParallel(module=net,
+         ...                               bucket_cap_mb=None,
+         ...                               average_in_collective=True,
+         ...                               static_graph=True)
+         >>> optimizer = AdamW(net.trainable_params(), 1e-4)
+         >>> loss_fn = nn.CrossEntropyLoss()
+         >>>
+         >>> def forward_fn(data, target):
+         ...     logits = net(data)
+         ...     loss = loss_fn(logits, target)
+         ...     return loss, logits
+         >>>
+         >>> grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True)
+         >>>
+         >>> # Create the dataset taking MNIST as an example. Refer to
+         >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
+         >>> dataset = create_dataset()
+         >>> for epoch in range(1):
+         ...     step = 0
+         ...     for image, label in dataset:
+         ...         (loss_value, _), grads = grad_fn(image, label)
+         ...         optimizer(grads)
+         ...         net.zero_grad()
+         ...         step += 1
+         ...         print("epoch: %s, step: %s, loss is %.15f" % (epoch, step, loss_value))
+     """
+
+     def __init__(self, module, init_sync=True, process_group=None, bucket_cap_mb: Optional[int] = None,
+                  find_unused_parameters=False, average_in_collective: bool = False, static_graph=False,
+                  reducer_mode="CppReducer"):
+         super(DistributedDataParallel, self).__init__(auto_prefix=False)
+         self.init_sync = init_sync
+         self.bucket_cap_mb = bucket_cap_mb
+         self.average_in_collective = average_in_collective
+         self.grad_reduce_in_fp32 = False
+         self.process_group = process_group if process_group else get_data_parallel_group()
+         self.static_graph = static_graph
+         self.find_unused_parameters = find_unused_parameters
+
+         self.module = module
+         self.param_to_buffer = {}
+         self.has_buckets_grad_sync = False
+
+         # default is 25 MB for each bucket
+         if bucket_cap_mb is None:
+             bucket_cap_mb = 25
+         self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
+
+         # grads sync with allreduce comm
+         self.sync_enabled = True
+         self.reducer_mode = reducer_mode  # "CppReducer" or "PythonReducer"
+         self.buffers = []
+         self.has_mark_unused_param = False
+
+         bucketed_params = []
+         self.skipped_params = []
+         for _, param in self.module.parameters_and_names():
+             if not param.requires_grad:
+                 self.skipped_params.append(param)
+                 continue
+             param.grad = None
+             param.main_grad = None
+             bucketed_params.append(param)
+         if self.average_in_collective:
+             # allreduce to add grads, then scale grads with dp size
+             self.gradient_scaling_factor = 1.0
+         else:
+             # scale grads with dp size locally, then allreduce to add grads
+             data_parallel_world_size = get_data_parallel_world_size(self.process_group)
+             self.gradient_scaling_factor = 1.0 / data_parallel_world_size
+         self.bucketed_params = bucketed_params
+
+         if self.reducer_mode == "CppReducer":
+             self.reducer = Reducer(self.bucketed_params,
+                                    self.process_group,
+                                    bucket_cap_mb,
+                                    self.grad_reduce_in_fp32,
+                                    average_in_collective,
+                                    static_graph,
+                                    find_unused_parameters)
+             if self.init_sync:
+                 self.broadcast_coalesced()
+             return
+         # allocate buffer for trained params
+         self.buffers = self.allocate_buffers_for_parameters(
+             self.bucketed_params,
+             group=self.process_group,
+             gradient_scaling_factor=self.gradient_scaling_factor,
+         )
+         if self.init_sync:
+             self.broadcast_coalesced()
+
+         # register hook for bucket grad reduce
+         self._register_hook_for_params()
+
+         # bucket rebuilding
+         self.rebuilt_params_ = []
+         self.buffer_iterations = 0
+         self.has_bucket_rebuilt = False
+         self.buffer_issued = 0
+         self.triggered_once = False
+
+     def _group_params_by_dtype(self, input_params):
+         param_and_grad_dtype_to_params = {}
+         # group all params by parameter's data type and their gradient's data type.
+         for param in input_params:
+             param_dtype = param.dtype
+             grad_dtype = mstype.float32 if self.grad_reduce_in_fp32 else param.dtype
+             if (param_dtype, grad_dtype) not in param_and_grad_dtype_to_params:
+                 param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = []
+             param_and_grad_dtype_to_params[(param_dtype, grad_dtype)].append(param)
+         return param_and_grad_dtype_to_params
+
+     def allocate_buffers_for_parameters(self, input_params, group, gradient_scaling_factor):
+         """allocate buffers for parameters in different dtype groups."""
+         param_and_grad_dtype_to_params = self._group_params_by_dtype(input_params)
+
+         buffers = []
+         # allocate buffer for each group separately
+         for (param_dtype, grad_dtype,), params in param_and_grad_dtype_to_params.items():
+             buffers.append(
+                 FlattenGradBuffer(
+                     average_in_collective=self.average_in_collective,
+                     param_dtype=param_dtype,
+                     grad_dtype=grad_dtype,
+                     params=params,
+                     data_parallel_group=group,
+                     bucket_size=self.bucket_bytes_cap,
+                     gradient_scaling_factor=gradient_scaling_factor,
+                     ddp_handle=self,
+                 )
+             )
+             for param in params:
+                 self.param_to_buffer[param] = buffers[-1]
+         logger.debug("allocate buffers for parameters: %s", buffers)
+         return buffers
+
+     def final_grad_reduce(self):
+         """trigger final grad reduction"""
+         logger.debug("trigger ddp final grad reduce, %d, %d", self.static_graph, len(self.unused_param))
+         if self._should_rebuild_buckets():
+             for param in self.unused_param:
+                 self.rebuilt_params_.append(param)
+         for buffer in self.buffers:
+             buffer.final_grad_reduce()
+             buffer.issued = 0
+         self.buffer_issued = 0
+
+     def _register_hook_for_params(self):
+         """register backward hook for each param."""
+         for param in self.module.get_parameters():
+             if param.requires_grad:
+                 param.register_hook(self._make_param_hook(param))
+
+     def _post_forward(self, output):
+         """prepare for backward (e.g. find unused params) if needed"""
+         if self.reducer_mode == "CppReducer":
+             if _pynative_executor.grad_flag() and self.sync_enabled:
+                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
+         else:
+             unused_param_idx = []
+             if self.static_graph and not self.triggered_once:
+                 self.triggered_once = True
+                 self.find_unused_parameters = False
+                 unused_param_idx = _find_unused_parameters(list(_find_tensors(output)), self.bucketed_params)
+             elif self.find_unused_parameters:
+                 unused_param_idx = _find_unused_parameters(list(_find_tensors(output)), self.bucketed_params)
+             self.unused_param = [self.bucketed_params[idx] for idx in unused_param_idx]
+             self.unused_param_name = [param.name for param in self.unused_param]
+             self.has_mark_unused_param = False
+
+     def _pre_forward(self):
+         """pre-process of the forward pass to allocate buffers for parameters."""
+         if self.reducer_mode == "CppReducer":
+             if _pynative_executor.grad_flag() and self.sync_enabled:
+                 self.reducer.prepare_for_forward()
+                 self.reducer.rebuild_buckets()
+             return
+         if self.rebuilt_params_ and self._should_rebuild_buckets():
+             for i in self.rebuilt_params_:
+                 i.old_grad = i.grad
+
+             self.buffers = self.allocate_buffers_for_parameters(
+                 self.rebuilt_params_,
+                 group=self.process_group,
+                 gradient_scaling_factor=self.gradient_scaling_factor,
+             )
+             for buffer in self.buffers:
+                 buffer.sync_enabled = self.sync_enabled
+
+             for i in self.rebuilt_params_:
+                 i.grad.copy_(i.old_grad)
+                 i.old_grad = None
+
+             logger.debug("register unused param: %s", self.rebuilt_params_)
+             self.has_bucket_rebuilt = True
+             self.rebuilt_params_ = []
+
+     def construct(self, *inputs, **inputs_dict):
+         """construct for DistributedDataParallel."""
+         self._pre_forward()
+         output = self.module(*inputs, **inputs_dict)
+         self._post_forward(output)
+         return output
+
+     def zero_grad(self):
+         """DDP accumulates grads automatically; grads are zeroed only when zero_grad() is called manually."""
+         if self.reducer_mode == "CppReducer":
+             self.reducer.zero_grad()
+         else:
+             for buffer in self.buffers:
+                 buffer.reset()
+
+     def _enable_sync(self, enable):
+         """enable grad buffer sync or not."""
+         for buffer in self.buffers:
+             buffer.sync_enabled = enable
+         self.sync_enabled = enable
+
+     @contextmanager
+     def no_sync(self):
+         """Context manager helper function. While active, no grad allreduce synchronization will be executed."""
+         self._enable_sync(False)
+         try:
+             yield
+         finally:
+             self._enable_sync(True)
+
+     def _should_rebuild_buckets(self):
+         if self.static_graph and not self.has_bucket_rebuilt:
+             return True
+         return False
+
+     def _make_param_hook(self, param):
+         """make a closure function as the param hook."""
+         def param_hook(grad):
+             if not self.has_mark_unused_param:
+                 for cur_param in self.unused_param:
+                     buffer = self.param_to_buffer[cur_param]
+                     logger.debug("register unused param: %s", cur_param)
+                     buffer.register_grad_ready(cur_param)
+                 self.has_mark_unused_param = True
+             elif param.name in self.unused_param_name:
+                 logger.debug("unused param already registered: %s", param)
+                 return param.grad
+
+             logger.debug("register normal param: %s", param)
+             buffer = self.param_to_buffer[param]
+             param.grad.add_(grad)
+             buffer.register_grad_ready(param)
+             if self._should_rebuild_buckets():
+                 self.rebuilt_params_.append(param)
+             return param.grad
+
+         return param_hook
+
+     def broadcast_coalesced(self):
+         """broadcast params from rank 0"""
+         if self.reducer_mode == "CppReducer":
+             buckets = [[self.bucketed_params[idx] for idx in bucket] for bucket in self.reducer.bucket_indices]
+         else:
+             buckets = [bucket.params_list for buffer in self.buffers for bucket in buffer.buckets]
+         if self.skipped_params:
+             param_and_grad_dtype_to_params = self._group_params_by_dtype(self.skipped_params)
+             for params_list in param_and_grad_dtype_to_params.values():
+                 buckets.append(params_list)
+
+         def finish(rate_limiter):
+             for _ in rate_limiter:
+                 handle, coalesced, params = rate_limiter.pop(0)
+                 handle.wait()
+                 ptr = 0
+                 for param in params:
+                     param.view(-1).copy_(coalesced[ptr:ptr + param.numel()])
+                     ptr += param.numel()
+
+         rate_limiter = []
+         for params in buckets:
+             flat_tensors = [t.view(-1) for t in params]
+             coalesced = mint.cat(flat_tensors)
+             global_rank = get_global_rank(self.process_group, 0)
+             handle = broadcast(coalesced, src=global_rank, group=self.process_group, async_op=True)
+             rate_limiter.append((handle, coalesced, params))
+
+             if len(rate_limiter) >= 2:
+                 finish(rate_limiter)
+         finish(rate_limiter)
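
The wrapper above accumulates into param.grad between calls to zero_grad() and exposes a no_sync() context manager that disables the bucket allreduce. The following is a minimal sketch of how those pieces could be combined for gradient accumulation; it is not taken from the package documentation. LeNet5, create_dataset, and the accumulation factor are placeholders borrowed from the docstring example above.

# Sketch: gradient accumulation with the new DistributedDataParallel wrapper.
# Inside no_sync() the per-parameter hooks still add into param.grad, but the
# bucket allreduce is skipped; the step taken outside the context triggers the
# synchronized reduce before the optimizer update.
import mindspore as ms
from mindspore import nn
from mindspore.mint.optim import AdamW
from mindspore.mint.distributed.distributed import init_process_group
from mindspore.parallel.distributed import DistributedDataParallel

ms.set_context(mode=ms.PYNATIVE_MODE)
init_process_group()

net = DistributedDataParallel(module=LeNet5())            # LeNet5: placeholder network
optimizer = AdamW(net.trainable_params(), 1e-4)
loss_fn = nn.CrossEntropyLoss()
grad_fn = ms.value_and_grad(lambda x, y: loss_fn(net(x), y), None, net.trainable_params())

accum_steps = 4                                            # hypothetical accumulation factor
for step, (image, label) in enumerate(create_dataset()):   # create_dataset: placeholder
    if (step + 1) % accum_steps != 0:
        with net.no_sync():                                # keep accumulating, skip allreduce
            _, grads = grad_fn(image, label)
    else:
        _, grads = grad_fn(image, label)                   # allreduce fires on this backward
        optimizer(grads)
        net.zero_grad()                                    # DDP accumulates until cleared
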
@@ -0,0 +1,295 @@
+ # Copyright 2025 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """ Param and grad buffer, bucket implementation. """
+ from __future__ import absolute_import
+
+ __all__ = ["Bucket", "FlattenGradBuffer"]
+
+ from enum import Enum
+ import numpy as np
+ from mindspore import mint, Tensor
+ from mindspore.common.initializer import Zero
+ from mindspore.communication.management import get_group_size
+ import mindspore.communication.comm_func as comm_func
+
+
+ class BufferType(Enum):
+     PARAM = 0
+     GRAD = 1
+
+
+ MEM_ALIGN_SIZE = 512
+ ALIGN_BYTES = 32
+ MIN_BUCKET_SIZE = int(1 * 1024 * 1024)
+ DEFAULT_BUCKET_SIZE = int(25 * 1024 * 1024)
+
+
+ class Bucket:
+     """
+     Bucket to track a subset of parameters and gradients in the buffer. A Bucket records the parameters
+     whose gradient has already been computed. It also provides functionality to synchronize gradients among
+     the data parallel group once all parameters' gradients have been computed.
+
+     Args:
+         average_in_collective (bool): Whether to scale grads before or after AllReduce; ``True`` means
+             scaling after AllReduce.
+         params (List(Parameters)): Parameters belonging to this bucket.
+         grad_data (Tensor): A section of the buffer's gradient data, corresponding to the parameters in this bucket.
+         offset (int): Start index in the buffer.
+         numel_unpadded (int): Number of unpadded elements in the bucket.
+         data_parallel_group (str): Data parallel group name.
+         data_parallel_world_size (int): Data parallel group size.
+         gradient_scaling_factor (float): Works together with `average_in_collective`; it is 1.0 when
+             `average_in_collective` is ``True``, otherwise 1.0 / dp_size.
+     """
+
+     def __init__(self, average_in_collective, params, grad_data, offset, numel_unpadded, data_parallel_group,
+                  data_parallel_world_size, gradient_scaling_factor):
+         self.average_in_collective = average_in_collective
+         self.params_list = params
+         self.params = set(params)
+         self.params_grad_ready = set()
+         self.grad_data = grad_data
+         self.grad_data_numel = self.grad_data.numel()
+         self.offset = offset
+         self.numel_unpadded = numel_unpadded
+         self.data_parallel_group = data_parallel_group
+         self.data_parallel_world_size = data_parallel_world_size
+         self.gradient_scaling_factor = gradient_scaling_factor
+
+         if self.data_parallel_world_size > 1:
+             self.grad_reducer = comm_func.all_reduce
+
+         self.reset()
+
+     def inplace_reduce_dp(self, src):
+         """conduct all-reduce/reduce-scatter on the src tensor and update the result into the target in place."""
+         self.communication_result, self.communication_handle = self.grad_reducer(
+             src, "sum", self.data_parallel_group, async_op=True
+         )
+
+     def reset(self):
+         """reset bucket for the next iteration."""
+         self.params_grad_ready = set()
+         self.is_reduce_issued = False
+         self.communication_handle = None
+         self.communication_result = None
+
+     def issue_grad_reduce(self):
+         """issue grad reduce for the local grad data view."""
+         if self.is_reduce_issued:
+             raise RuntimeError("The bucket reduce is already issued")
+
+         if self.gradient_scaling_factor != 1.0:
+             self.grad_data.copy_(mint.mul(self.grad_data, self.gradient_scaling_factor))
+
+         if self.data_parallel_world_size > 1:
+             self.inplace_reduce_dp(self.grad_data)
+
+         self.is_reduce_issued = True
+
+     def final_grad_reduce(self):
+         """finalize grad reduce for the local grad data view."""
+         start_idx = 0
+         end_idx = self.grad_data_numel
+         target = self.grad_data[start_idx:end_idx]
+
+         if not self.is_reduce_issued:
+             raise RuntimeError(
+                 f"The bucket reduce has not been issued "
+                 f"with only {len(self.params_grad_ready)}/{len(self.params)} params ready"
+             )
+
+         if self.data_parallel_world_size > 1:
+             self.communication_handle.wait()
+             target.copy_(self.communication_result)
+             self.communication_result = None
+         if self.average_in_collective:
+             target.copy_(mint.div(target, self.data_parallel_world_size))
+
+     def register_grad_ready(self, param):
+         """register grad ready and issue bucket grad reduce when the bucket is ready."""
+         if param not in self.params:
+             raise ValueError("The param to be registered is not in the bucket")
+
+         if param in self.params_grad_ready:
+             raise ValueError(f"The param {param} is already registered")
+
+         self.params_grad_ready.add(param)
+         if len(self.params_grad_ready) == len(self.params):
+             self.issue_grad_reduce()
+             return True
+
+         return False
+
+     def __repr__(self):
+         return f"Bucket (offset={self.offset}, param_lens={len(self.params)})"
+
+
+ class FlattenGradBuffer:
+     """
+     Allocate a contiguous memory buffer for the given parameters and their corresponding gradients. The buffer
+     is broken up into small buckets, which are the unit for all-reduce/reduce-scatter communication during
+     back-propagation.
+
+     Args:
+         average_in_collective (bool): Whether to scale grads before or after AllReduce; ``True`` means
+             scaling after AllReduce.
+         param_dtype (mindspore.dtype): The parameters' datatype.
+         grad_dtype (mindspore.dtype): The gradients' datatype.
+         params (List(Parameters)): Parameters belonging to this buffer.
+         data_parallel_group (str): Data parallel group name.
+         bucket_size (int): Bucket size threshold used to partition buckets.
+         gradient_scaling_factor (float): Scaling factor applied to gradients; 1.0 when `average_in_collective`
+             is ``True``, otherwise 1.0 / dp_size.
+     """
+
+     def __init__(self, average_in_collective, param_dtype, grad_dtype, params, data_parallel_group,
+                  bucket_size, gradient_scaling_factor, ddp_handle):
+         super(FlattenGradBuffer, self).__init__()
+         self.param_dtype = param_dtype
+         self.grad_dtype = grad_dtype
+         self.data_parallel_group = data_parallel_group
+         self.data_parallel_world_size = get_group_size(group=self.data_parallel_group)
+         self.gradient_scaling_factor = gradient_scaling_factor
+         self.average_in_collective = average_in_collective
+
+         self.buckets = []
+         self.param_index_map = {}
+         self.param_to_bucket = {}
+         self.sync_enabled = True
+         self.issued = 0
+         self.ddp_handle = ddp_handle
+
+         buckets_metadata = self.calc_partition_metadata(bucket_size, params)
+         self.instantiate_buckets(buckets_metadata, params)
+
+     def calc_partition_metadata(self, bucket_size, params):
+         """calc bucket partition metadata"""
+         # helper func
+         def _need_new_bucket(bucket_numel, bucket_id):
+             target_bucket_size = bucket_size
+             if bucket_id == 0 and bucket_size == DEFAULT_BUCKET_SIZE:
+                 target_bucket_size = MIN_BUCKET_SIZE
+             return (
+                 bucket_size is not None
+                 and bucket_numel != 0
+                 and bucket_numel >= target_bucket_size
+             )
+
+         def _build_bucket():
+             nonlocal buckets_metadata, bucket_start_index, bucket_params, bucket_id
+             bucket_end_index = data_start_index
+             buckets_metadata.append(
+                 (bucket_start_index, bucket_end_index, bucket_params)
+             )
+             bucket_start_index = bucket_end_index
+             bucket_id = bucket_id + 1
+             bucket_params = []
+
+         param_data_list = []
+         buckets_metadata = []
+         data_start_index = 0
+         data_end_index = 0
+         bucket_id = 0
+         bucket_start_index = 0
+         bucket_params = []
+         for param in params[::]:  # traverse from the beginning
+             last_bucket_numel = data_start_index - bucket_start_index
+             if _need_new_bucket(last_bucket_numel, bucket_id):
+                 _build_bucket()
+             data_end_index = data_start_index + param.numel()
+             bucket_params.append(param)
+             param_data_list.append(param)
+             self.param_index_map[param] = (data_start_index, data_end_index, bucket_id)
+             data_start_index = data_end_index
+
+         # add a bucket for the last few params which do not reach the bucket_size threshold
+         if data_start_index - bucket_start_index > 0:
+             bucket_end_index = data_start_index
+             buckets_metadata.append(
+                 (bucket_start_index, bucket_end_index, bucket_params)
+             )
+             data_start_index = bucket_end_index
+
+         # allocate contiguous memory for parameters and gradients
+         self.numel = data_start_index
+         self.grad_data = Tensor(shape=(self.numel), dtype=self.grad_dtype, init=Zero())
+         self.grad_data.init_data()
+         self.numel_unpadded = 0
+         return buckets_metadata
+
+     def instantiate_buckets(self, buckets_metadata, params):
+         """build bucket instances according to partition metadata"""
+         for bucket_start_index, bucket_end_index, bucket_params in buckets_metadata:
+             local_grad_data = self.grad_data[bucket_start_index:bucket_end_index]
+             self.numel_unpadded += bucket_end_index - bucket_start_index
+             bucket = Bucket(
+                 average_in_collective=self.average_in_collective,
+                 params=bucket_params,
+                 grad_data=local_grad_data,
+                 offset=bucket_start_index,
+                 numel_unpadded=bucket_end_index - bucket_start_index,
+                 data_parallel_group=self.data_parallel_group,
+                 data_parallel_world_size=self.data_parallel_world_size,
+                 gradient_scaling_factor=self.gradient_scaling_factor,
+             )
+             self.buckets.append(bucket)
+             for param in bucket_params:
+                 self.param_to_bucket[param] = bucket
+
+         for param in params:
+             data_start_index, _, _ = self.param_index_map[param]
+             param.grad = self._get_buffer_slice(
+                 param.shape, data_start_index, BufferType.GRAD
+             )
+
+     def _get_buffer_slice(self, shape, start_index, buffer_type):
+         """get a buffer view with the given shape"""
+         end_index = start_index + int(np.prod(shape))
+         if start_index < 0 or end_index > self.numel:
+             raise ValueError("index out of range")
+         if buffer_type == BufferType.GRAD:
+             buffer_tensor = self.grad_data[start_index:end_index]
+         else:
+             raise TypeError("Invalid buffer type for _get_buffer_slice.")
+         buffer_tensor = buffer_tensor.view(shape)
+         return buffer_tensor
+
+     def reset(self):
+         """reset buffer for the next iteration."""
+         self.grad_data.zero_()
+         for bucket in self.buckets:
+             bucket.reset()
+         self.sync_enabled = True
+
+     def final_grad_reduce(self):
+         """finalize grad reduce for each bucket"""
+         for bucket in self.buckets:
+             bucket.final_grad_reduce()
+
+     def register_grad_ready(self, param):
+         """register a ready grad in its bucket"""
+         if self.sync_enabled:
+             bucket = self.param_to_bucket[param]
+             if bucket.register_grad_ready(param):
+                 self.issued += 1
+                 if self.issued == len(self.buckets):
+                     self.ddp_handle.buffer_issued += 1
+                     if self.ddp_handle.buffer_issued == len(self.ddp_handle.buffers):
+                         self.ddp_handle.final_grad_reduce()
+
+     def __repr__(self):
+         param_index_with_name = {
+             param.name: index for (param, index) in self.param_index_map.items()
+         }
+         return f"Buffer has buckets: \n {self.buckets} \n and param_index_map: \n {param_index_with_name}"