mindspore 2.7.0rc1-cp310-cp310-win_amd64.whl → 2.7.1-cp310-cp310-win_amd64.whl

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of mindspore has been marked as possibly problematic.

Files changed (370)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -30,7 +30,7 @@ def _generate_cmd(cmd, cmd_args, output_name):
 
     """
     if cmd not in ['python', 'pytest', 'python3']:
-        # If user don't set binary file name, defaulty use 'python' to launch the job.
+        # If user don't set binary file name, defaultly use 'python' to launch the job.
         command = f"python {cmd} {' '.join(cmd_args)} > {output_name} 2>&1 &"
     else:
         command = f"{cmd} {' '.join(cmd_args)} > {output_name} 2>&1 &"
@@ -42,7 +42,7 @@ def _generate_cmd_args_list(cmd, cmd_args):
     Generates arguments list for 'Popen'. It consists of a binary file name and subsequential arguments.
     """
     if cmd not in ['python', 'pytest', 'python3']:
-        # If user don't set binary file name, defaulty use 'python' to launch the job.
+        # If user don't set binary file name, defaultly use 'python' to launch the job.
         return ['python'] + [cmd] + cmd_args
     return [cmd] + cmd_args
 
@@ -55,7 +55,7 @@ def _generate_cmd_args_list_with_core(cmd, cmd_args, affinity_cpu_str):
     taskset_args = ['taskset'] + ['-c'] + [affinity_cpu_str]
     final_cmd = []
     if cmd not in ['python', 'pytest', 'python3']:
-        # If user don't set binary file name, defaulty use 'python' to launch the job.
+        # If user don't set binary file name, defaultly use 'python' to launch the job.
        final_cmd = taskset_args + ['python'] + [cmd] + cmd_args
     else:
        final_cmd = taskset_args + [cmd] + cmd_args
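
The three hunks above touch the msrun launch helpers that build the argument list handed to Popen; the only textual change is the corrected comment, but the branching they describe is easy to restate. A minimal standalone sketch of that branching, where 'train.py' and the sample arguments are made up for illustration and only the if/else mirrors the diff:

    # Standalone illustration of the launch-argument branching shown in the hunks above.
    def build_args(cmd, cmd_args):
        if cmd not in ['python', 'pytest', 'python3']:
            # A plain script name gets an implicit 'python' launcher prepended.
            return ['python'] + [cmd] + cmd_args
        return [cmd] + cmd_args

    print(build_args('train.py', ['--epochs', '1']))  # ['python', 'train.py', '--epochs', '1']
    print(build_args('pytest', ['-q', 'tests']))       # ['pytest', '-q', 'tests']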
@@ -143,8 +143,14 @@ def _parse_global_device_to_cpu_map(local_rank_id, physical_device_id, device_to
     Parse the global device_to_cpu_map and return a cpu list for assigned local_rank_id.
 
     """
+    if local_rank_id >= len(list(device_to_cpu_map.keys())):
+        logger.warning(f"Cannot find process[{local_rank_id}] in args '--bind_core'. "
+                       "Will not launch process with taskset.")
+        return ""
     input_device_id = int(list(device_to_cpu_map.keys())[local_rank_id].replace("device", ""))
     if physical_device_id != input_device_id:
+        logger.warning(f"Cannot find physical_device_id[{physical_device_id}] for process[{local_rank_id}] "
+                       "in args '--bind_core'. Will not launch process with taskset.")
         return ""
     affinity_cpu_list = list(device_to_cpu_map.values())[local_rank_id]
     affinity_cpu_str = ",".join(affinity_cpu_list)
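
The added guard returns an empty affinity string instead of indexing past the end of the `--bind_core` mapping. A rough self-contained sketch of the same checks, using a made-up two-device map; the "deviceN" key format and the list-of-CPU-ranges values are inferred from the code in this hunk, not from separate documentation:

    # Hypothetical device_to_cpu_map; key/value shapes are inferred from the hunk above.
    device_to_cpu_map = {"device0": ["0-7"], "device1": ["8-15"]}

    def affinity_cpu_str(local_rank_id, physical_device_id):
        # New guard: a rank with no entry in the map simply gets no taskset binding.
        if local_rank_id >= len(device_to_cpu_map):
            return ""
        keys = list(device_to_cpu_map.keys())
        if physical_device_id != int(keys[local_rank_id].replace("device", "")):
            return ""
        return ",".join(list(device_to_cpu_map.values())[local_rank_id])

    print(affinity_cpu_str(0, 0))  # '0-7'  -> worker would be launched under 'taskset -c 0-7'
    print(affinity_cpu_str(5, 5))  # ''     -> worker is launched without taskset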
@@ -212,8 +218,6 @@ def _generate_bind_core_strategy(local_rank_id, device_to_cpu_map, arg_bind_core
     if isinstance(arg_bind_core, dict):
         affinity_cpu_str = _parse_global_device_to_cpu_map(local_rank_id, physical_device_id, arg_bind_core)
         if not affinity_cpu_str:
-            logger.warning(f"Failed to find physical_device_id[{physical_device_id}] for "
-                           f"process[{local_rank_id}]. Will not launch process with taskset.")
             return None
     elif arg_bind_core is True:
         cpu_list_for_device = device_to_cpu_map.get(physical_device_id, [])
@@ -125,14 +125,16 @@ def get_args():
         default=-1,
         type=int,
         choices=[0, 1, 2, 3],
-        help="specifies simulation level. When this argument is set, msrun only spawns one process "
-             "but export RANK_SIZE with value worker_num and RANK_ID with value sim_rank_id."
+        help="specifies simulation level. This argument activates dryrun mode, functioning "
+             "equivalently to environment variable 'MS_SIMULATION_LEVEL' while having higher priority."
     )
     parser.add_argument(
         "--sim_rank_id",
         default=-1,
         type=int,
-        help="specifies simulation process's rank id. Only one process is spawned in simulation scenario."
+        help="specifies simulation process's rank id. When this argument is set, only one process "
+             "is spawned on dryrun mode, functioning equivalently to environment variable 'RANK_ID' "
+             "while having higher priority."
     )
     parser.add_argument(
         "--rank_table_file",
mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py
@@ -1,22 +1,21 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-Provide llm boost for inference, such as LlamaBoost.
-"""
-from __future__ import absolute_import
-
-from mindspore.experimental.llm_boost.ascend_native.llama_boost_ascend_native import LlamaBoostAscendNative
-
-__all__ = ['LlamaBoostAscendNative']
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""distributed init"""
+from mindspore.parallel.distributed.distributed_data_parallel import DistributedDataParallel
+
+__all__ = [
+    "DistributedDataParallel",
+]
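
Since the rewritten `__init__.py` re-exports the new wrapper, the class added in the next hunk should be importable directly from the package (assuming the 2.7.1 wheel is installed):

    from mindspore.parallel.distributed import DistributedDataParallel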
mindspore/parallel/distributed/distributed_data_parallel.py
@@ -0,0 +1,393 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" Distributed data parallel wrapper. """
+from __future__ import absolute_import
+
+__all__ = ["DistributedDataParallel"]
+
+import itertools
+from contextlib import contextmanager
+from typing import Optional
+import mindspore.nn as nn
+import mindspore.log as logger
+from mindspore import Tensor, mint
+from mindspore.common import dtype as mstype
+from mindspore.mint.distributed import get_world_size
+from mindspore.communication import GlobalComm
+from mindspore.common.api import _pynative_executor
+from mindspore.mint.distributed import broadcast, get_global_rank
+from mindspore.parallel.distributed.flatten_grad_buffer import FlattenGradBuffer
+from mindspore._c_expression import Reducer, _find_unused_parameters
+
+
+def get_data_parallel_group():
+    """get default global data parallel group"""
+    return GlobalComm.WORLD_COMM_GROUP
+
+
+def get_data_parallel_world_size(group):
+    """get group world size"""
+    return get_world_size(group)
+
+
+def _find_tensors(obj):
+    if isinstance(obj, Tensor):
+        return [obj]
+    if isinstance(obj, (list, tuple)):
+        return itertools.chain.from_iterable(map(_find_tensors, obj))
+    if isinstance(obj, dict):
+        return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
+
+    return []
+
+
+class DistributedDataParallel(nn.Cell):
+    """
+    DistributedDataParallel wrapper. DistributedDataParallel allocates contiguous memory buffer for gradients.
+    Parameters' gradients will be combined into multiple buckets which are the unit to conduct all-reduce
+    communication among data parallel group to overlap communication latency.
+
+    .. warning::
+        - The method is currently only supported in PyNative mode.
+        - This is an experimental interface, may be changed or canceled in the future.
+
+    Args:
+        module (nn.Cell): the module to be wrapped with DDP.
+        init_sync (bool, optional): whether to sync params from rank0 of process_group when init. Default: ``True``.
+        process_group (str, optional): the comm group of data prallel. Default: ``None``.
+        bucket_cap_mb (int, optional): size of bucket in MB, default is 25MB if not set. Default: ``None``.
+        find_unused_parameters (bool, optional): whether to find unused params in the bucket. Default: ``False``.
+        average_in_collective (bool, optional): True means allreduce sum within DP group firstly then scaling with
+            dp size. Otherwise scaling local rank grad first and then allreduce sum. Default: ``False``.
+        static_graph (bool, optional): Indicate whether it is a static network. When it is a static network, the
+            parameter `find_unused_parameters` will be ignored, and unused parameters will be searched for in the
+            first step. Bucket reconstruction will be performed in execution order before the second step to achieve
+            better performance. Default: ``False``.
+        reducer_mode (str, optional): the backend to be used, could be "CppReducer" for cpp backend or "PythonReducer"
+            for Python backend. Default: ``"CppReducer"``.
+
+    Returns:
+        Model wrapped with DistributedDataParallel.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        .. note::
+            - When enabling recomputation or gradient freezing, the model should be wrapped by
+              `DistributedDataParallel` at the outermost layer.
+            - Before running the following examples, you need to configure the communication environment variables.
+              For Ascend devices, it is recommended to use the msrun startup method
+              without any third-party or configuration file dependencies. For detailed information, refer to
+              `msrun launch <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_ .
+
+        >>> from mindspore.parallel.distributed import DistributedDataParallel
+        >>> from mindspore.mint.optim import AdamW
+        >>> from mindspore import Parameter, Tensor, ops, nn
+        >>> import mindspore as ms
+        >>> from mindspore.communication import init
+        >>> from mindspore.mint.distributed.distributed import init_process_group
+        >>> ms.set_context(mode=ms.PYNATIVE_MODE)
+        >>> init_process_group()
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
+        >>> net = DistributedDataParallel(module=net,
+        ...                               bucket_cap_mb=None,
+        ...                               average_in_collective=True,
+        ...                               static_graph=True)
+        >>> optimizer = AdamW(net.trainable_params(), 1e-4)
+        >>> loss_fn = nn.CrossEntropyLoss()
+        >>>
+        >>> def forward_fn(data, target):
+        ...     logits = net(data)
+        ...     loss = loss_fn(logits, target)
+        ...     return loss, logits
+        >>>
+        >>> grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True)
+        >>>
+        >>> # Create the dataset taking MNIST as an example. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
+        >>> dataset = create_dataset()
+        >>> for epoch in range(1):
+        ...     step = 0
+        ...     for image, label in dataset:
+        ...         (loss_value, _), grads = grad_fn(image, label)
+        ...         optimizer(grads)
+        ...         net.zero_grad()
+        ...         step += 1
+        ...         print("epoch: %s, step: %s, loss is %.15f" % (epoch, step, loss_value))
+    """
+
+    def __init__(self, module, init_sync=True, process_group=None, bucket_cap_mb: Optional[int] = None,
+                 find_unused_parameters=False, average_in_collective: bool = False, static_graph=False,
+                 reducer_mode="CppReducer"):
+        super(DistributedDataParallel, self).__init__(auto_prefix=False)
+        self.init_sync = init_sync
+        self.bucket_cap_mb = bucket_cap_mb
+        self.average_in_collective = average_in_collective
+        self.grad_reduce_in_fp32 = False
+        self.process_group = process_group if process_group else get_data_parallel_group()
+        self.static_graph = static_graph
+        self.find_unused_parameters = find_unused_parameters
+
+        self.module = module
+        self.param_to_buffer = {}
+        self.has_buckets_grad_sync = False
+
+        # default is 25MB for each buck
+        if bucket_cap_mb is None:
+            bucket_cap_mb = 25
+        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
+
+        # grads sync with allreduce comm
+        self.sync_enabled = True
+        self.reducer_mode = reducer_mode  # "CppReducer" or "PythonReducer"
+        self.buffers = []
+        self.has_mark_unused_param = False
+
+        bucketed_params = []
+        self.skipped_params = []
+        for _, param in self.module.parameters_and_names():
+            if not param.requires_grad:
+                self.skipped_params.append(param)
+                continue
+            param.grad = None
+            param.main_grad = None
+            bucketed_params.append(param)
+        if self.average_in_collective:
+            # allreduce to add grads, then to scale grads with dp size
+            self.gradient_scaling_factor = 1.0
+        else:
+            # scale grads with dp size locally, then allreduce to add grads
+            data_parallel_world_size = get_data_parallel_world_size(self.process_group)
+            self.gradient_scaling_factor = 1.0 / data_parallel_world_size
+        self.bucketed_params = bucketed_params
+
+        if self.reducer_mode == "CppReducer":
+            self.reducer = Reducer(self.bucketed_params,
+                                   self.process_group,
+                                   bucket_cap_mb,
+                                   self.grad_reduce_in_fp32,
+                                   average_in_collective,
+                                   static_graph,
+                                   find_unused_parameters)
+            if self.init_sync:
+                self.broadcast_coalesced()
+            return
+        # allocate buffer for trained params
+        self.buffers = self.allocate_buffers_for_parameters(
+            self.bucketed_params,
+            group=self.process_group,
+            gradient_scaling_factor=self.gradient_scaling_factor,
+        )
+        if self.init_sync:
+            self.broadcast_coalesced()
+
+        # register hook for bucket grad reduce
+        self._register_hook_for_params()
+
+        # bucket rebuilding
+        self.rebuilt_params_ = []
+        self.buffer_iterations = 0
+        self.has_bucket_rebuilt = False
+        self.buffer_issued = 0
+        self.triggered_once = False
+
+    def _group_params_by_dtype(self, input_params):
+        param_and_grad_dtype_to_params = {}
+        # group all params by parameter's data type and their gradient's data type.
+        for param in input_params:
+            param_dtype = param.dtype
+            grad_dtype = mstype.float32 if self.grad_reduce_in_fp32 else param.dtype
+            if (param_dtype, grad_dtype) not in param_and_grad_dtype_to_params:
+                param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = []
+            param_and_grad_dtype_to_params[(param_dtype, grad_dtype)].append(param)
+        return param_and_grad_dtype_to_params
+
+    def allocate_buffers_for_parameters(self, input_params, group, gradient_scaling_factor):
+        """allocate buffers for parameters in different dtype group."""
+        param_and_grad_dtype_to_params = self._group_params_by_dtype(input_params)
+
+        buffers = []
+        # allocate buffer for each group separately
+        for (param_dtype, grad_dtype,), params in param_and_grad_dtype_to_params.items():
+            buffers.append(
+                FlattenGradBuffer(
+                    average_in_collective=self.average_in_collective,
+                    param_dtype=param_dtype,
+                    grad_dtype=grad_dtype,
+                    params=params,
+                    data_parallel_group=group,
+                    bucket_size=self.bucket_bytes_cap,
+                    gradient_scaling_factor=gradient_scaling_factor,
+                    ddp_handle=self,
+                )
+            )
+            for param in params:
+                self.param_to_buffer[param] = buffers[-1]
+        logger.debug("allocate buffers for parameters: %s", buffers)
+        return buffers
+
+    def final_grad_reduce(self):
+        """trigger final grad reduction"""
+        logger.debug("trigger ddp final grad reduce, %d, %d", self.static_graph, len(self.unused_param))
+        if self._should_rebuild_buckets():
+            for param in self.unused_param:
+                self.rebuilt_params_.append(param)
+        for buffer in self.buffers:
+            buffer.final_grad_reduce()
+            buffer.issued = 0
+        self.buffer_issued = 0
+
+    def _register_hook_for_params(self):
+        """register backward hook for each params."""
+        for param in self.module.get_parameters():
+            if param.requires_grad:
+                param.register_hook(self._make_param_hook(param))
+
+    def _post_forward(self, output):
+        """prepare for backward (e.g. find unused params) if needed"""
+        if self.reducer_mode == "CppReducer":
+            if _pynative_executor.grad_flag() and self.sync_enabled:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+        else:
+            unused_param_idx = []
+            if self.static_graph and not self.triggered_once:
+                self.triggered_once = True
+                self.find_unused_parameters = False
+                unused_param_idx = _find_unused_parameters(list(_find_tensors(output)), self.bucketed_params)
+            elif self.find_unused_parameters:
+                unused_param_idx = _find_unused_parameters(list(_find_tensors(output)), self.bucketed_params)
+            self.unused_param = [self.bucketed_params[idx] for idx in unused_param_idx]
+            self.unused_param_name = [param.name for param in self.unused_param]
+            self.has_mark_unused_param = False
+
+    def _pre_forward(self):
+        """pre-process of forward pass to allocate buffer for parameters."""
+        if self.reducer_mode == "CppReducer":
+            if _pynative_executor.grad_flag() and self.sync_enabled:
+                self.reducer.prepare_for_forward()
+                self.reducer.rebuild_buckets()
+            return
+        if self.rebuilt_params_ and self._should_rebuild_buckets():
+            for i in self.rebuilt_params_:
+                i.old_grad = i.grad
+
+            self.buffers = self.allocate_buffers_for_parameters(
+                self.rebuilt_params_,
+                group=self.process_group,
+                gradient_scaling_factor=self.gradient_scaling_factor,
+            )
+            for buffer in self.buffers:
+                buffer.sync_enabled = self.sync_enabled
+
+            for i in self.rebuilt_params_:
+                i.grad.copy_(i.old_grad)
+                i.old_grad = None
+
+            logger.debug("register unused param: %s", self.rebuilt_params_)
+            self.has_bucket_rebuilt = True
+            self.rebuilt_params_ = []
+
+    def construct(self, *inputs, **inputs_dict):
+        """construct for DistributedDataParallel."""
+        self._pre_forward()
+        output = self.module(*inputs, **inputs_dict)
+        self._post_forward(output)
+        return output
+
+    def zero_grad(self):
+        """DPP will accumulate grads automatically, it will zero grads when call zero_grad() manually."""
+        if self.reducer_mode == "CppReducer":
+            self.reducer.zero_grad()
+        else:
+            for buffer in self.buffers:
+                buffer.reset()
+
+    def _enable_sync(self, enable):
+        """enable grad buffer sync or not."""
+        for buffer in self.buffers:
+            buffer.sync_enabled = enable
+        self.sync_enabled = enable
+
+    @contextmanager
+    def no_sync(self):
+        """Context manager helper function. When enabled, no grad allreduce synchronization will be executed."""
+        self._enable_sync(False)
+        try:
+            yield
+        finally:
+            self._enable_sync(True)
+
+    def _should_rebuild_buckets(self):
+        if self.static_graph and not self.has_bucket_rebuilt:
+            return True
+        return False
+
+    def _make_param_hook(self, param):
+        """make closure function as the param hook."""
+        def param_hook(grad):
+            if not self.has_mark_unused_param:
+                for cur_param in self.unused_param:
+                    buffer = self.param_to_buffer[cur_param]
+                    logger.debug("register unused param: %s", cur_param)
+                    buffer.register_grad_ready(cur_param)
+                self.has_mark_unused_param = True
+            elif param.name in self.unused_param_name:
+                logger.debug("unused param already registered: %s", param)
+                return param.grad
+
+            logger.debug("register normal param: %s", param)
+            buffer = self.param_to_buffer[param]
+            param.grad.add_(grad)
+            buffer.register_grad_ready(param)
+            if self._should_rebuild_buckets():
+                self.rebuilt_params_.append(param)
+            return param.grad
+
+        return param_hook
+
+    def broadcast_coalesced(self):
+        """broadcast params from rank 0"""
+        if self.reducer_mode == "CppReducer":
+            buckets = [[self.bucketed_params[idx] for idx in bucket] for bucket in self.reducer.bucket_indices]
+        else:
+            buckets = [bucket.params_list for buffer in self.buffers for bucket in buffer.buckets]
+        if self.skipped_params:
+            param_and_grad_dtype_to_params = self._group_params_by_dtype(self.skipped_params)
+            for params_list in param_and_grad_dtype_to_params.values():
+                buckets.append(params_list)
+
+        def finish(rate_limiter):
+            for _ in rate_limiter:
+                handle, coalesced, params = rate_limiter.pop(0)
+                handle.wait()
+                ptr = 0
+                for param in params:
+                    param.view(-1).copy_(coalesced[ptr:ptr + param.numel()])
+                    ptr += param.numel()
+
+        rate_limiter = []
+        for params in buckets:
+            flat_tensors = [t.view(-1) for t in params]
+            coalesced = mint.cat(flat_tensors)
+            global_rank = get_global_rank(self.process_group, 0)
+            handle = broadcast(coalesced, src=global_rank, group=self.process_group, async_op=True)
+            rate_limiter.append((handle, coalesced, params))

+            if len(rate_limiter) >= 2:
+                finish(rate_limiter)
+        finish(rate_limiter)
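
The `no_sync()` context manager and `zero_grad()` added above suggest a gradient-accumulation pattern, sketched below. This is an illustration only: it reuses the style of the class docstring example, `nn.Dense` stands in for a real model, it assumes a communication environment already configured (e.g. launched via msrun), and whether the gradients returned by `value_and_grad` or the bucketed `param.grad` values should feed the optimizer on the synchronizing step is not spelled out by this diff.

    # Hedged sketch: accumulate locally under no_sync(), let the final micro-batch allreduce.
    import mindspore as ms
    from mindspore import nn, ops
    from mindspore.mint.optim import AdamW
    from mindspore.mint.distributed.distributed import init_process_group
    from mindspore.parallel.distributed import DistributedDataParallel

    ms.set_context(mode=ms.PYNATIVE_MODE)
    init_process_group()

    net = DistributedDataParallel(module=nn.Dense(16, 4), average_in_collective=True)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = AdamW(net.trainable_params(), 1e-4)
    grad_fn = ms.value_and_grad(lambda x, y: loss_fn(net(x), y), None, net.trainable_params())

    accumulation_steps = 4
    for step in range(8):
        data, target = ops.randn(8, 16), ops.zeros(8, dtype=ms.int32)
        if (step + 1) % accumulation_steps != 0:
            with net.no_sync():                  # grads accumulate locally, no allreduce issued
                _, grads = grad_fn(data, target)
        else:
            _, grads = grad_fn(data, target)     # this micro-batch triggers the bucket allreduce
            optimizer(grads)
            net.zero_grad()                      # DDP keeps accumulating until zero_grad()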