mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (285)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/mindir_util.py +2 -2
  26. mindspore/common/parameter.py +46 -13
  27. mindspore/common/recompute.py +39 -9
  28. mindspore/common/sparse_tensor.py +7 -3
  29. mindspore/common/tensor.py +209 -29
  30. mindspore/communication/__init__.py +1 -1
  31. mindspore/communication/_comm_helper.py +38 -3
  32. mindspore/communication/comm_func.py +310 -55
  33. mindspore/communication/management.py +14 -14
  34. mindspore/context.py +123 -22
  35. mindspore/dataset/__init__.py +1 -1
  36. mindspore/dataset/audio/__init__.py +1 -1
  37. mindspore/dataset/core/config.py +7 -0
  38. mindspore/dataset/core/validator_helpers.py +7 -0
  39. mindspore/dataset/engine/cache_client.py +1 -1
  40. mindspore/dataset/engine/datasets.py +72 -44
  41. mindspore/dataset/engine/datasets_audio.py +7 -7
  42. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  43. mindspore/dataset/engine/datasets_text.py +20 -20
  44. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  45. mindspore/dataset/engine/datasets_vision.py +33 -33
  46. mindspore/dataset/engine/iterators.py +29 -0
  47. mindspore/dataset/engine/obs/util.py +7 -0
  48. mindspore/dataset/engine/queue.py +114 -60
  49. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  50. mindspore/dataset/engine/validators.py +34 -14
  51. mindspore/dataset/text/__init__.py +1 -4
  52. mindspore/dataset/transforms/__init__.py +0 -3
  53. mindspore/dataset/utils/line_reader.py +2 -0
  54. mindspore/dataset/vision/__init__.py +1 -4
  55. mindspore/dataset/vision/utils.py +1 -1
  56. mindspore/dataset/vision/validators.py +2 -1
  57. mindspore/dnnl.dll +0 -0
  58. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  59. mindspore/experimental/es/embedding_service.py +883 -0
  60. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  61. mindspore/experimental/llm_boost/__init__.py +21 -0
  62. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  63. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  64. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  65. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  66. mindspore/experimental/llm_boost/register.py +129 -0
  67. mindspore/experimental/llm_boost/utils.py +31 -0
  68. mindspore/experimental/optim/adamw.py +85 -0
  69. mindspore/experimental/optim/optimizer.py +3 -0
  70. mindspore/hal/__init__.py +3 -3
  71. mindspore/hal/contiguous_tensors_handle.py +175 -0
  72. mindspore/hal/stream.py +18 -0
  73. mindspore/include/api/model_group.h +13 -1
  74. mindspore/include/api/types.h +10 -10
  75. mindspore/include/dataset/config.h +2 -2
  76. mindspore/include/dataset/constants.h +2 -2
  77. mindspore/include/dataset/execute.h +2 -2
  78. mindspore/include/dataset/vision.h +4 -0
  79. mindspore/jpeg62.dll +0 -0
  80. mindspore/log.py +1 -1
  81. mindspore/mindrecord/filewriter.py +68 -51
  82. mindspore/mindspore_backend.dll +0 -0
  83. mindspore/mindspore_common.dll +0 -0
  84. mindspore/mindspore_core.dll +0 -0
  85. mindspore/mindspore_glog.dll +0 -0
  86. mindspore/mindspore_np_dtype.dll +0 -0
  87. mindspore/mindspore_ops.dll +0 -0
  88. mindspore/mint/__init__.py +495 -46
  89. mindspore/mint/distributed/__init__.py +31 -0
  90. mindspore/mint/distributed/distributed.py +254 -0
  91. mindspore/mint/nn/__init__.py +266 -21
  92. mindspore/mint/nn/functional.py +125 -19
  93. mindspore/mint/nn/layer/__init__.py +39 -0
  94. mindspore/mint/nn/layer/activation.py +133 -0
  95. mindspore/mint/nn/layer/normalization.py +477 -0
  96. mindspore/mint/nn/layer/pooling.py +110 -0
  97. mindspore/mint/optim/adamw.py +28 -7
  98. mindspore/mint/special/__init__.py +63 -0
  99. mindspore/multiprocessing/__init__.py +2 -1
  100. mindspore/nn/__init__.py +0 -1
  101. mindspore/nn/cell.py +275 -93
  102. mindspore/nn/layer/activation.py +211 -44
  103. mindspore/nn/layer/basic.py +113 -3
  104. mindspore/nn/layer/embedding.py +120 -2
  105. mindspore/nn/layer/normalization.py +101 -5
  106. mindspore/nn/layer/padding.py +34 -48
  107. mindspore/nn/layer/pooling.py +161 -7
  108. mindspore/nn/layer/transformer.py +3 -3
  109. mindspore/nn/loss/__init__.py +2 -2
  110. mindspore/nn/loss/loss.py +84 -6
  111. mindspore/nn/optim/__init__.py +2 -1
  112. mindspore/nn/optim/adadelta.py +1 -1
  113. mindspore/nn/optim/adam.py +1 -1
  114. mindspore/nn/optim/lamb.py +1 -1
  115. mindspore/nn/optim/tft_wrapper.py +127 -0
  116. mindspore/nn/wrap/cell_wrapper.py +12 -23
  117. mindspore/nn/wrap/grad_reducer.py +5 -5
  118. mindspore/nn/wrap/loss_scale.py +17 -3
  119. mindspore/numpy/__init__.py +1 -1
  120. mindspore/numpy/array_creations.py +65 -68
  121. mindspore/numpy/array_ops.py +64 -60
  122. mindspore/numpy/fft.py +610 -75
  123. mindspore/numpy/logic_ops.py +11 -10
  124. mindspore/numpy/math_ops.py +85 -84
  125. mindspore/numpy/utils_const.py +4 -4
  126. mindspore/opencv_core452.dll +0 -0
  127. mindspore/opencv_imgcodecs452.dll +0 -0
  128. mindspore/opencv_imgproc452.dll +0 -0
  129. mindspore/ops/__init__.py +6 -4
  130. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  131. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  132. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  133. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  134. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  135. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  136. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  137. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  138. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  139. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  140. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  141. mindspore/ops/composite/base.py +85 -48
  142. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  143. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  144. mindspore/ops/function/__init__.py +22 -0
  145. mindspore/ops/function/array_func.py +490 -153
  146. mindspore/ops/function/debug_func.py +113 -1
  147. mindspore/ops/function/fft_func.py +15 -2
  148. mindspore/ops/function/grad/grad_func.py +3 -2
  149. mindspore/ops/function/math_func.py +558 -207
  150. mindspore/ops/function/nn_func.py +817 -383
  151. mindspore/ops/function/other_func.py +3 -2
  152. mindspore/ops/function/random_func.py +184 -8
  153. mindspore/ops/function/reshard_func.py +13 -11
  154. mindspore/ops/function/sparse_unary_func.py +1 -1
  155. mindspore/ops/function/vmap_func.py +3 -2
  156. mindspore/ops/functional.py +24 -14
  157. mindspore/ops/op_info_register.py +3 -3
  158. mindspore/ops/operations/__init__.py +6 -1
  159. mindspore/ops/operations/_grad_ops.py +2 -76
  160. mindspore/ops/operations/_infer_ops.py +1 -1
  161. mindspore/ops/operations/_inner_ops.py +71 -94
  162. mindspore/ops/operations/array_ops.py +12 -146
  163. mindspore/ops/operations/comm_ops.py +42 -53
  164. mindspore/ops/operations/custom_ops.py +83 -19
  165. mindspore/ops/operations/debug_ops.py +42 -10
  166. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  167. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  168. mindspore/ops/operations/math_ops.py +12 -223
  169. mindspore/ops/operations/nn_ops.py +20 -114
  170. mindspore/ops/operations/other_ops.py +7 -4
  171. mindspore/ops/operations/random_ops.py +46 -1
  172. mindspore/ops/primitive.py +18 -6
  173. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  174. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  175. mindspore/ops_generate/gen_constants.py +36 -0
  176. mindspore/ops_generate/gen_ops.py +67 -52
  177. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  178. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  179. mindspore/ops_generate/op_proto.py +10 -3
  180. mindspore/ops_generate/pyboost_utils.py +14 -1
  181. mindspore/ops_generate/template.py +43 -21
  182. mindspore/parallel/__init__.py +3 -1
  183. mindspore/parallel/_auto_parallel_context.py +28 -8
  184. mindspore/parallel/_cell_wrapper.py +83 -0
  185. mindspore/parallel/_parallel_serialization.py +47 -19
  186. mindspore/parallel/_tensor.py +81 -11
  187. mindspore/parallel/_utils.py +13 -1
  188. mindspore/parallel/algo_parameter_config.py +5 -5
  189. mindspore/parallel/checkpoint_transform.py +46 -39
  190. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  191. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  192. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  193. mindspore/parallel/parameter_broadcast.py +3 -4
  194. mindspore/parallel/shard.py +162 -31
  195. mindspore/parallel/transform_safetensors.py +993 -0
  196. mindspore/profiler/__init__.py +2 -1
  197. mindspore/profiler/common/constant.py +29 -0
  198. mindspore/profiler/common/registry.py +47 -0
  199. mindspore/profiler/common/util.py +28 -0
  200. mindspore/profiler/dynamic_profiler.py +694 -0
  201. mindspore/profiler/envprofiling.py +17 -19
  202. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  203. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  204. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  205. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  206. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  207. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  208. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  209. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  210. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  211. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  212. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  213. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  214. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  215. mindspore/profiler/parser/framework_parser.py +1 -391
  216. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  217. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  218. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  219. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  220. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  221. mindspore/profiler/parser/profiler_info.py +78 -6
  222. mindspore/profiler/profiler.py +153 -0
  223. mindspore/profiler/profiling.py +280 -412
  224. mindspore/rewrite/__init__.py +1 -2
  225. mindspore/rewrite/common/namespace.py +4 -4
  226. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  227. mindspore/run_check/_check_version.py +36 -103
  228. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  229. mindspore/swresample-4.dll +0 -0
  230. mindspore/swscale-6.dll +0 -0
  231. mindspore/tinyxml2.dll +0 -0
  232. mindspore/train/__init__.py +4 -3
  233. mindspore/train/_utils.py +28 -2
  234. mindspore/train/amp.py +171 -53
  235. mindspore/train/callback/__init__.py +2 -2
  236. mindspore/train/callback/_callback.py +4 -4
  237. mindspore/train/callback/_checkpoint.py +85 -22
  238. mindspore/train/callback/_cluster_monitor.py +1 -1
  239. mindspore/train/callback/_flops_collector.py +1 -0
  240. mindspore/train/callback/_loss_monitor.py +3 -3
  241. mindspore/train/callback/_on_request_exit.py +134 -31
  242. mindspore/train/callback/_summary_collector.py +5 -5
  243. mindspore/train/callback/_tft_register.py +352 -0
  244. mindspore/train/dataset_helper.py +7 -3
  245. mindspore/train/metrics/metric.py +3 -3
  246. mindspore/train/metrics/roc.py +4 -4
  247. mindspore/train/mind_ir_pb2.py +44 -39
  248. mindspore/train/model.py +134 -58
  249. mindspore/train/serialization.py +336 -112
  250. mindspore/turbojpeg.dll +0 -0
  251. mindspore/utils/__init__.py +21 -0
  252. mindspore/utils/utils.py +60 -0
  253. mindspore/version.py +1 -1
  254. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  255. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
  256. mindspore/include/c_api/ms/abstract.h +0 -67
  257. mindspore/include/c_api/ms/attribute.h +0 -197
  258. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  259. mindspore/include/c_api/ms/base/macros.h +0 -32
  260. mindspore/include/c_api/ms/base/status.h +0 -33
  261. mindspore/include/c_api/ms/base/types.h +0 -283
  262. mindspore/include/c_api/ms/context.h +0 -102
  263. mindspore/include/c_api/ms/graph.h +0 -160
  264. mindspore/include/c_api/ms/node.h +0 -606
  265. mindspore/include/c_api/ms/tensor.h +0 -161
  266. mindspore/include/c_api/ms/value.h +0 -84
  267. mindspore/mindspore_shared_lib.dll +0 -0
  268. mindspore/nn/extend/basic.py +0 -140
  269. mindspore/nn/extend/embedding.py +0 -143
  270. mindspore/nn/extend/layer/normalization.py +0 -109
  271. mindspore/nn/extend/pooling.py +0 -117
  272. mindspore/nn/layer/embedding_service.py +0 -531
  273. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  274. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  275. mindspore/ops/extend/__init__.py +0 -53
  276. mindspore/ops/extend/array_func.py +0 -218
  277. mindspore/ops/extend/math_func.py +0 -76
  278. mindspore/ops/extend/nn_func.py +0 -308
  279. mindspore/ops/silent_check.py +0 -162
  280. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  281. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  282. mindspore/train/callback/_mindio_ttp.py +0 -443
  283. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  284. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  285. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
@@ -168,6 +168,24 @@ class _AutoParallelContext:
         self.check_context_handle()
         return _ParallelFusionConfig.CONFIG
 
+    def set_dump_local_norm(self, dump_local_norm):
+        """
+        Set dump local norm for auto parallel.
+
+        Args:
+            dump_local_norm (bool): User need to specify if he want to dump local norm. Default: False
+
+        Raises:
+            KeyError: When key of comm_fusion is not 'allreduce'.
+        """
+        self.check_context_handle()
+        self._context_handle.set_dump_local_norm(dump_local_norm)
+
+    def get_dump_local_norm(self):
+        """Get dump local norm."""
+        self.check_context_handle()
+        return self._context_handle.get_dump_local_norm()
+
     def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
         """
         Set fusion threshold (MB) for auto parallel.
@@ -584,7 +602,7 @@ class _AutoParallelContext:
         self.check_context_handle()
         dir_path = os.path.dirname(strategy_ckpt_save_file)
         if dir_path and not os.path.exists(dir_path):
-            os.makedirs(dir_path, exist_ok=True)
+            os.makedirs(dir_path, mode=0o700, exist_ok=True)
         self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
 
     def get_strategy_ckpt_save_file(self):
@@ -643,7 +661,7 @@ class _AutoParallelContext:
         self.check_context_handle()
         dir_path = os.path.dirname(group_ckpt_save_file)
         if dir_path and not os.path.exists(dir_path):
-            os.makedirs(dir_path)
+            os.makedirs(dir_path, mode=0o700, exist_ok=True)
         self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)
 
     def get_parameter_broadcast_is_set(self):
@@ -1117,9 +1135,9 @@ class _AutoParallelContext:
         """
         self.check_context_handle()
         if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
-            return
+            self.set_enable_all_gather_fusion(True)
         if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
-            return
+            self.set_enable_reduce_scatter_fusion(True)
         if not isinstance(comm_fusion, dict):
             raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
                 comm_type, type(comm_fusion)))
@@ -1153,7 +1171,7 @@ class _AutoParallelContext:
         """
         self.check_context_handle()
         if not self.get_enable_all_reduce_fusion():
-            return
+            self.set_enable_all_reduce_fusion(True)
         if not isinstance(comm_fusion, dict):
             raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
                 type(comm_fusion)))
@@ -1210,7 +1228,7 @@ def _set_ops_strategy_json_config(type="SAVE", path="", mode="all"):
     """
     dir_path = os.path.dirname(path)
     if dir_path and not os.path.exists(dir_path):
-        os.makedirs(dir_path)
+        os.makedirs(dir_path, mode=0o700, exist_ok=True)
     check_type = ["SAVE", "LOAD"]
     check_mode = ["all", "principal"]
     if type in check_type and mode in check_mode:
@@ -1266,7 +1284,8 @@ _set_auto_parallel_context_func_map = {
     "sharding_propagation": auto_parallel_context().set_sharding_propagation,
     "enable_alltoall": auto_parallel_context().set_enable_alltoall,
     "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
-    "comm_fusion": auto_parallel_context().set_comm_fusion}
+    "comm_fusion": auto_parallel_context().set_comm_fusion,
+    "dump_local_norm": auto_parallel_context().set_dump_local_norm}
 
 _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
@@ -1298,7 +1317,8 @@ _get_auto_parallel_context_func_map = {
     "enable_alltoall": auto_parallel_context().get_enable_alltoall,
     "comm_fusion": auto_parallel_context().get_comm_fusion,
     "strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config,
-    "full_batch_is_set": auto_parallel_context().get_full_batch_is_set}
+    "full_batch_is_set": auto_parallel_context().get_full_batch_is_set,
+    "dump_local_norm": auto_parallel_context().get_dump_local_norm}
 
 
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
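
Note: because the new "dump_local_norm" key is registered in both the setter and getter maps above, it should be reachable through the public set_auto_parallel_context / get_auto_parallel_context entry points. A minimal sketch of how that would look, assuming only the boolean flag shown in this diff:

    import mindspore as ms

    # Assumption: "dump_local_norm" behaves like any other auto-parallel key, since
    # the diff registers it in both _set_ and _get_auto_parallel_context_func_map.
    ms.set_auto_parallel_context(dump_local_norm=True)
    print(ms.get_auto_parallel_context("dump_local_norm"))  # expected: True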
@@ -16,11 +16,16 @@
 from __future__ import absolute_import
 from __future__ import division
 
+import numpy as np
+
+from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.ops import operations as P
 from mindspore.ops.operations.comm_ops import AllGather
 from mindspore.communication import GlobalComm
 from mindspore.common import jit
+from mindspore.communication import create_group
+from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
 
 _ALLGATHER_CELL = None
 
@@ -30,6 +35,7 @@ class AllGatherCell(Cell):
     Allgather cell, used in model parallel scenario.
     To allgather the selected parameter slice from each device.
     """
+
     def __init__(self, group, do_reshape, after_reshape_slice_shape):
         super(AllGatherCell, self).__init__(auto_prefix=False)
         self.allgather = AllGather(group)
@@ -54,6 +60,7 @@ class SaveOptShardCkptCell(Cell):
     Note:
         This could be optimized later with less communication consumption.
     """
+
     def __init__(self, group, do_reshape, after_reshape_slice_shape):
         super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
         self.allgather1 = AllGather(group)
@@ -71,6 +78,21 @@ class SaveOptShardCkptCell(Cell):
         return x
 
 
+class SingleCommunicator(Cell):
+    """
+    Used to broadcast single parameter.
+    """
+
+    def __init__(self, group_name):
+        super(SingleCommunicator, self).__init__()
+        self.allreduce = P.AllReduce(group=group_name)
+        self.add_flags(skip_auto_parallel_compile=True)
+
+    def construct(self, loaded_param):
+        result = self.allreduce(loaded_param)
+        return result
+
+
 def get_allgather_cell(group, need_merge_twice=False, do_reshape=False, after_reshape_slice_shape=()):
     """Get AllGatherCell object."""
     global _ALLGATHER_CELL
@@ -89,3 +111,64 @@ def destroy_allgather_cell():
     global _ALLGATHER_CELL
     if _ALLGATHER_CELL:
         _ALLGATHER_CELL = None
+
+
+def _chang_parallel_context(origin_dataset_strategy):
+    """Change the original parallel state."""
+    if context.get_context("mode") == context.GRAPH_MODE:
+        context.set_auto_parallel_context(parallel_mode="hybrid_parallel")
+    if origin_dataset_strategy != "data_parallel":
+        context.set_auto_parallel_context(dataset_strategy="data_parallel")
+
+
+def _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy):
+    """Restore the original parallel state."""
+    if context.get_context("mode") == context.GRAPH_MODE:
+        context.set_auto_parallel_context(parallel_mode=origin_parallel_mode)
+    if origin_dataset_strategy != "data_parallel":
+        context.set_auto_parallel_context(dataset_strategy=origin_dataset_strategy)
+
+
+def _single_parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
+    """
+    Broadcast single parameter to other rank in data parallel dimension.
+    """
+    from mindspore import Tensor
+    origin_parallel_mode = context.get_auto_parallel_context("parallel_mode")
+    origin_dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
+    if layout:
+        param_redundancy = get_parameter_redundancy(layout, initial_rank)
+    else:
+        param_redundancy = get_parameter_redundancy(net)
+    if not param_redundancy:
+        return
+    single_params = remove_param_redundancy(param_redundancy)
+    if not single_params:
+        return
+    param_redundancy_reversed = {}
+    for key, redundancy in param_redundancy.items():
+        for item in redundancy:
+            if len(item) == 1:
+                continue
+            if cur_rank in item:
+                param_redundancy_reversed.setdefault(item, []).append(key)
+    if not param_redundancy_reversed or cur_rank not in single_params:
+        return
+    net_param_dict = net.parameters_dict()
+    _chang_parallel_context(origin_dataset_strategy)
+    for group, params in param_redundancy_reversed.items():
+        create_group(str(group), list(group))
+        allreduce_input = []
+        for param in params:
+            if param not in net_param_dict:
+                continue
+            real_param = net_param_dict[param]
+            if param not in single_params[cur_rank]:
+                real_param.set_data(Tensor(np.zeros(real_param.shape), dtype=real_param.dtype), real_param.sliced)
+            allreduce_input.append(real_param)
+        if not allreduce_input:
+            continue
+        communicator = SingleCommunicator(str(group))
+        for real_param in allreduce_input:
+            real_param.set_data(communicator(real_param), real_param.sliced)
+    _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy)
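
The _single_parameter_broadcast helper added above leans on a simple identity: every rank that does not own a parameter zeroes its local copy, so an AllReduce (sum) over the redundancy group leaves the owner's value on all ranks, i.e. it acts as a broadcast. A minimal single-process numpy sketch of that reduction (illustrative values, no MindSpore communication involved):

    import numpy as np

    # Simulate a redundancy group of 4 ranks; rank 0 owns the loaded parameter.
    owner_value = np.array([1.5, -2.0, 3.25])
    rank_buffers = [owner_value if rank == 0 else np.zeros_like(owner_value)
                    for rank in range(4)]

    # AllReduce(sum) over the group: non-owners contribute zeros, so the sum
    # equals the owner's value on every rank.
    reduced = np.sum(rank_buffers, axis=0)
    assert np.array_equal(reduced, owner_value)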
@@ -24,7 +24,6 @@ from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_
     _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
     _extract_layout_item
 
-
 MAX_PATH_LENGTH = 1024
 
 
@@ -37,14 +36,17 @@ def _convert_to_list(strategy, rank_id=None):
             dev_mat = list(layout.dev_matrix[0].dim)
             tensor_map = list(layout.tensor_map[0].dim)
             param_split_shape = list(layout.param_split_shape[0].dim)
+            field_size = int(layout.field)
+            shard_stride = int(layout.opt_weight_shard_step)
+            shard_size = int(layout.opt_weight_shard_size)
             pipeline_stage = 0
             origin_param_name = param_name
             if "-" in param_name:
                 pipeline_stage, origin_param_name = param_name.split("-")
                 pipeline_stage = int(pipeline_stage)
             if origin_param_name not in train_map:
-                train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape, int(layout.field),
-                                                int(layout.opt_weight_shard_step), int(layout.opt_weight_shard_size),
+                train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape, field_size,
+                                                shard_stride, shard_size,
                                                 [pipeline_stage]]
             else:
                 update_pipeline_stage_list = train_map.get(origin_param_name)[6] + [pipeline_stage]
@@ -54,15 +56,15 @@ def _convert_to_list(strategy, rank_id=None):
                     not_device0_nor_pipeline0 = ((rank_id // stage_device_num) > 0) and (pipeline_stage > 0)
                     if is_device0_and_pipeline0 or not_device0_nor_pipeline0:
                         train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
-                                                        int(layout.field), int(layout.opt_weight_shard_step),
-                                                        int(layout.opt_weight_shard_size), update_pipeline_stage_list]
+                                                        field_size, shard_stride,
+                                                        shard_size, update_pipeline_stage_list]
                     else:
                         train_map.get(origin_param_name)[6] = update_pipeline_stage_list
                 else:
                     if np.all(pipeline_stage <= np.array(update_pipeline_stage_list)):
                         train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
-                                                        int(layout.field), int(layout.opt_weight_shard_step),
-                                                        int(layout.opt_weight_shard_size), update_pipeline_stage_list]
+                                                        field_size, shard_stride,
+                                                        shard_size, update_pipeline_stage_list]
                     else:
                         train_map.get(origin_param_name)[6] = update_pipeline_stage_list
         except BaseException as e:
@@ -174,6 +176,8 @@ def _build_json_strategy(strategy_filename):
 
 def _build_searched_strategy(strategy_filename):
     """build searched strategy"""
+    if strategy_filename is None:
+        return strategy_filename
     _check_strategy_file(strategy_filename)
     if strategy_filename[-5:] != ".json":
         return _build_protobuf_strategy(strategy_filename)
@@ -239,7 +243,10 @@ def _extract_layout_map(strategy_file, rank_id=None):
     """Extract layout map"""
     layout_map = None
     if strategy_file is not None:
-        src_strategy = _build_searched_strategy(strategy_file)
+        if not isinstance(strategy_file, dict):
+            src_strategy = _build_searched_strategy(strategy_file)
+        else:
+            src_strategy = strategy_file
         layout_map = _convert_to_list(src_strategy, rank_id)
     return layout_map
 
@@ -248,7 +255,10 @@ def _extract_pipeline_stage_num(strategy_file):
     """extract pipeline stage num"""
     pipeline_stage_num = 1
     if strategy_file is not None:
-        src_strategy = _build_searched_strategy(strategy_file)
+        if not isinstance(strategy_file, dict):
+            src_strategy = _build_searched_strategy(strategy_file)
+        else:
+            src_strategy = strategy_file
         layout_map = _convert_to_list(src_strategy)
         pipeline_stage_set = set()
         for _, layout in layout_map.items():
@@ -323,7 +333,10 @@ def _get_device_num_from_strategy(strategy_file=None):
     """Get device num from strategy file"""
     if strategy_file is None:
         return 1
-    src_strategy = _build_searched_strategy(strategy_file)
+    if not isinstance(strategy_file, dict):
+        src_strategy = _build_searched_strategy(strategy_file)
+    else:
+        src_strategy = strategy_file
     strategy_list = _convert_to_list(src_strategy)
     device_mat = list(strategy_list.values())[0][0]
     return np.prod(device_mat)
@@ -341,14 +354,15 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_list, dst
         from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
             src_strategy_list.get(param_name))
         from_device_num = np.prod(from_dev_matrix)
-        fake_tensor_shape = [8] * len(from_tensor_map)
         to_dev_matrix = [1]
-        to_tensor_map = [-1] * len(fake_tensor_shape)
+        to_tensor_map = [-1] * len(from_tensor_map)
         to_opt_shard_step = 0
         to_opt_shard_size = 0
         if dst_strategy_list is not None:
             to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
                 dst_strategy_list.get(param_name))
+        to_device_num = np.prod(to_dev_matrix)
+        fake_tensor_shape = [max(from_device_num, to_device_num)] * len(from_tensor_map)
         handled_key = (from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size,
                        to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size)
         if handled_key in handled_layout:
@@ -433,7 +447,6 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                             device_list, rank_id)
 
-
         from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
         to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
         _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
@@ -443,10 +456,10 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         transform_tensor = ms.Tensor(param_total_dict_copy[rank_id % device_num])
         requires_grad = param_attr_dict[param_name][rank_id % device_num][0]
         layerwise_parallel = param_attr_dict[param_name][rank_id % device_num][1]
-        transform_para = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
+        transform_param = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
         if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
-            transform_para.set_dtype(ms.bfloat16)
-        transform_param_dict[param_name] = transform_para
+            transform_param.set_dtype(ms.bfloat16)
+        transform_param_dict[param_name] = transform_param
     if device_num < 1:
         raise ValueError("None of the parameters in checkpoint file are in either src strategy or "
                          "dst strategy. Please check correctness of strategy files.")
@@ -454,13 +467,13 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
     # Handle those parameter like learning_rate, global_step which not in strategy_file.
     for param_name, _ in param_total_dict.items():
         if param_name not in transform_param_dict:
-            transform_para = ms.Parameter(
+            transform_param = ms.Parameter(
                 ms.Tensor(param_total_dict[param_name][rank_id % device_num]), param_name,
                 param_attr_dict[param_name][rank_id % device_num][0],
                 param_attr_dict[param_name][rank_id % device_num][1])
             if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
-                transform_para.set_dtype(ms.bfloat16)
-            transform_param_dict[param_name] = transform_para
+                transform_param.set_dtype(ms.bfloat16)
+            transform_param_dict[param_name] = transform_param
 
     transform_param_list = [{"name": param_name, "data": param_data}
                             for param_name, param_data in transform_param_dict.items()]
@@ -531,3 +544,18 @@ def _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple):
                 continue
             to_slice_tensor_shape += (item // to_tensor_strategy[i],)
         param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
+
+
+def _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_sharded_device_index, rank):
+    """Calculate rank list for optimizer parallel when first dim of parameter is sharded by other parallel method"""
+    total_device_num = 1
+    for n in device_arrangement:
+        total_device_num *= n
+    if first_dim_sharded_device_index != len(device_arrangement) - 1:
+        return list(range(0, total_device_num))
+    first_dim_sharded_size = device_arrangement[-1 - first_dim_sharded_device_index]
+    range_size = total_device_num // first_dim_sharded_size
+    offset = rank % range_size
+    start = rank - offset
+    param_total_list = list(range(start, start + range_size))
+    return param_total_list
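
As a quick sanity check of _get_param_list_when_first_dim_sharded above, the arithmetic can be restated standalone and run on a hypothetical device arrangement (the input values below are illustrative, not taken from the diff):

    def param_list_when_first_dim_sharded(device_arrangement, first_dim_sharded_device_index, rank):
        # Standalone restatement of the helper's arithmetic, for illustration only.
        total_device_num = 1
        for n in device_arrangement:
            total_device_num *= n
        if first_dim_sharded_device_index != len(device_arrangement) - 1:
            return list(range(0, total_device_num))
        first_dim_sharded_size = device_arrangement[-1 - first_dim_sharded_device_index]
        range_size = total_device_num // first_dim_sharded_size
        offset = rank % range_size
        return list(range(rank - offset, rank - offset + range_size))

    # Hypothetical 2 x 4 arrangement, sharding index on the last dim, rank 5:
    print(param_list_when_first_dim_sharded([2, 4], 1, 5))  # [4, 5, 6, 7]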
@@ -334,8 +334,10 @@ def _extract_layout_item(layout_item):
     tensor_map = layout_item[1]
     opt_shard_step = layout_item[4]
     opt_shard_size = layout_item[5]
+    tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map)
+    model_parallel_shard_size = np.prod(tensor_strategy)
     if opt_shard_size == -1:
-        opt_shard_size = np.prod(dev_matrix) // opt_shard_step
+        opt_shard_size = np.prod(dev_matrix) // model_parallel_shard_size
     return dev_matrix, tensor_map, opt_shard_step, opt_shard_size
 
 
@@ -406,12 +408,35 @@ def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_ste
     if opt_shard_step == 0 or opt_shard_size == 0:
         return dev_matrix, tensor_map, list(origin_full_tensor_shape)
     tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map)
-    model_parallel_shard_size = np.prod(tensor_strategy)
-    if model_parallel_shard_size != opt_shard_step:
+    repeated_dim = []
+    dev_sharded_index = []
+    for dim in tensor_map:
+        if dim != -1:
+            dev_sharded_index.append(len(dev_matrix) - dim - 1)
+    for index, value in enumerate(dev_matrix):
+        if index not in dev_sharded_index and value > 1:
+            repeated_dim.append(index)
+    if not repeated_dim:
+        raise ValueError("The device_matrix {} and tensor_map {} cannot sharding opt_shard".
+                         format(dev_matrix, tensor_map))
+    if len(repeated_dim) == 1 and np.prod(dev_matrix[repeated_dim[0] + 1:]) != opt_shard_step:
         raise ValueError("The optimizer sharding step {} is not equal to the model parallel sharding size {}.".
-                         format(opt_shard_step, model_parallel_shard_size))
-
+                         format(opt_shard_step, np.prod(dev_matrix[repeated_dim[0] + 1:])))
     first_dim_no_sharding_size = origin_full_tensor_shape[0] // tensor_strategy[0]
+    if (len(repeated_dim) < len(dev_matrix) and len(repeated_dim) > 1) or repeated_dim[0] > 0:
+        tensor_shape_new = list(origin_full_tensor_shape)
+        tensor_shape_new[0] = tensor_strategy[0]
+        accu_shp = 1
+        for i in range(len(repeated_dim) - 1):
+            opt_sharding_size = dev_matrix[repeated_dim[i]]
+            tensor_shape_new.insert(i + 1, opt_sharding_size)
+            accu_shp = accu_shp * opt_sharding_size
+        tensor_shape_new.insert(len(repeated_dim), first_dim_no_sharding_size // accu_shp)
+        tensor_map_new = list(copy.deepcopy(tensor_map))
+        for index, r_dim in enumerate(repeated_dim):
+            tensor_map_new.insert(index + 1, len(dev_matrix) - r_dim - 1)
+        return list(dev_matrix), tensor_map_new, tensor_shape_new
+
     full_tensor_shape = list(origin_full_tensor_shape)
     full_tensor_shape[0] = tensor_strategy[0]
     full_tensor_shape.insert(1, first_dim_no_sharding_size)
@@ -452,7 +477,7 @@ def _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_te
     result_map = {self_rank: transform_operators}
     for operators in transform_operators:
         op_name = operators[0]
-        if op_name == "AllGather":
+        if op_name == "AllConcat":
             groups = operators[1][:-1]
             stack.append((index, groups))
         index += 1
@@ -466,7 +491,7 @@ def _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_te
         index = 0
         for operators in new_transform_operators:
             op_name = operators[0]
-            if op_name == "AllGather" and index < group_info[0]:
+            if op_name == "AllConcat" and index < group_info[0]:
                 groups = operators[1][:-1]
                 stack.insert(0, (index, groups))
             index += 1
@@ -491,7 +516,7 @@ def _generate_transform_operator_stack(transform_operators_map, self_rank):
         level = queue_front[1]
         current_operator = queue_front[2]
         if level >= 1:
-            if current_operator[0] == "AllGather":
+            if current_operator[0] == "AllConcat":
                 current_group = current_operator[1][:-1]
                 for rank_id in current_group:
                     handle_queue.append((rank_id, level - 1, transform_operators_map[rank_id][level - 1]))
@@ -523,7 +548,7 @@ def _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, dev
         if operator[0] != op_name:
             raise ValueError("The operator in the same level should be equal in the transform tensor operator "
                              "list, but the find {} and {} in level {}".format(op_name, operator[0], cur_level))
-        if operator[0] != "AllGather":
+        if operator[0] != "AllConcat":
             tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(tensor_dict[rank_id % device_num],
                                                                              operator)
             continue
@@ -532,7 +557,7 @@ def _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, dev
                 raise ValueError("The checkpoint file of rank {} is missing.".format(rank % device_num))
             allgather_list = [tensor_dict[rank % device_num] for rank in operator[1][:-1]]
             tmp_tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(allgather_list, operator)
-        if op_name == "AllGather":
+        if op_name == "AllConcat":
             for rank, value in tmp_tensor_dict.items():
                 tensor_dict[rank % device_num] = value
         level_operators.clear()
@@ -621,7 +646,7 @@ def _apply_operator(operator_name):
         return numpy_data[slice_index]
 
     _apply_operator_map = {"Reshape": _apply_reshape_operator, "StridedSlice": _apply_slice_operator,
-                           "AllGather": _apply_allconcat_operator}
+                           "AllConcat": _apply_allconcat_operator}
     return _apply_operator_map.get(operator_name)
 
 
@@ -658,3 +683,48 @@ def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
     for i in range(1, len(tensor_slices_col)):
         new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)
     return Tensor(new_tensor)
+
+
+def _load_tensor_shape(dev_mat, tensor_map, full_shape=None, rank_id=-1):
+    """get tensor shape by slice"""
+    if rank_id == -1:
+        rank = get_rank()
+    else:
+        rank = rank_id
+    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
+    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
+    np_tensor_list = _chunk_shape_by_strategy(full_shape, tensor_strategy)
+    np_tensor_slice_index = np_tensor_list[int(tensor_slice_index)]
+    res = []
+    for index in np_tensor_slice_index:
+        res.append(slice(index[0], index[1]))
+    return tuple(res)
+
+
+def _chunk_shape_by_strategy(full_shape, strategy):
+    """chunk shape by strategy"""
+    shape = []
+    for i in full_shape:
+        shape.append([0, i])
+    return _chunk_shape(shape, strategy, len(strategy))
+
+
+def _chunk_shape(np_tensor, strategy, depth):
+    """_chunk shape"""
+    output = []
+    axis = len(np_tensor) - depth
+    left, right = np_tensor[axis]
+    num = strategy[0]
+    chunk_size = (right - left) / num
+    append = [[i, int(i + chunk_size)] for i in range(left, right) if i % chunk_size == 0]
+    np_tensor_new = []
+    for i in append:
+        np_tensor_tmp = copy.deepcopy(np_tensor)
+        np_tensor_tmp[axis] = i
+        np_tensor_new.append(np_tensor_tmp)
+    if depth == 1:
+        return np_tensor_new
+    for ret_ in np_tensor_new:
+        output.extend(
+            _chunk_shape(ret_, strategy[len(strategy) - depth + 1:len(strategy)], depth - 1))
+    return output
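
To make the slicing behaviour of the new _chunk_shape_by_strategy/_chunk_shape helpers concrete, the same recursion can be restated standalone and applied to a hypothetical full shape [8, 4] split with strategy [2, 2] (values chosen only for illustration):

    import copy

    def chunk_shape(bounds, strategy, depth):
        # Split the [start, stop) bounds of one axis per recursion level.
        axis = len(bounds) - depth
        left, right = bounds[axis]
        chunk_size = (right - left) / strategy[0]
        pieces = [[i, int(i + chunk_size)] for i in range(left, right) if i % chunk_size == 0]
        chunked = []
        for piece in pieces:
            new_bounds = copy.deepcopy(bounds)
            new_bounds[axis] = piece
            chunked.append(new_bounds)
        if depth == 1:
            return chunked
        out = []
        for item in chunked:
            out.extend(chunk_shape(item, strategy[len(strategy) - depth + 1:], depth - 1))
        return out

    # An [8, 4] tensor sharded 2-way on each axis yields four slices:
    print(chunk_shape([[0, 8], [0, 4]], [2, 2], 2))
    # [[[0, 4], [0, 2]], [[0, 4], [2, 4]], [[4, 8], [0, 2]], [[4, 8], [2, 4]]]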
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Utils of auto parallel"""
+import os
 from importlib import import_module
 import numpy as np
 import mindspore as ms
@@ -22,12 +23,13 @@ from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype
 from mindspore.common import dtype as mstype
 from mindspore.communication.management import get_group_size, get_rank
+from mindspore.communication._comm_helper import _is_initialized
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.common.seed import get_seed
 from mindspore._c_expression import GraphExecutor_
 from mindspore.parallel._tensor import _load_tensor_by_layout
 
-SUPPORTED_TUPLE_IN_TUPLE_STRATEGY = ["GroupedMatmul", "FusedInferAttentionScore"]
+SUPPORTED_TUPLE_IN_TUPLE_STRATEGY = ["GroupedMatmul", "FusedInferAttentionScore", "Custom"]
 
 
 def _get_parallel_mode():
@@ -45,6 +47,16 @@ def _is_in_auto_parallel_mode():
     return _get_parallel_mode() in [ms.ParallelMode.SEMI_AUTO_PARALLEL, ms.ParallelMode.AUTO_PARALLEL]
 
 
+def _is_parallel_mode():
+    if not _is_initialized() or context.get_context('mode') == context.PYNATIVE_MODE:
+        return False
+    if os.getenv("RUN_MODE") != "predict":
+        return False
+    if get_group_size() > 1 and _get_parallel_mode() == ms.ParallelMode.STAND_ALONE:
+        return True
+    return False
+
+
 def _is_in_data_parallel_mode():
     return _get_parallel_mode() == ms.ParallelMode.DATA_PARALLEL
 
@@ -234,7 +234,7 @@ def set_algo_parameters(**kwargs):
 
     Args:
         fully_use_devices (bool): Whether ONLY searching strategies that fully use all available devices.
-            Default: ``True`` . For example with 8 devices available, if set ``True`` , strategy (4, 1) will not be
+            Default: ``False`` . For example with 8 devices available, if set ``True`` , strategy (4, 1) will not be
             included in ReLU's candidate strategies, because strategy (4, 1) only utilizes 4 devices.
         elementwise_op_strategy_follow (bool): Whether the elementwise operator has the consistent strategies as its
             subsequent operators. Elementwise operators refer to operators that operate on input element by element,
@@ -264,14 +264,14 @@ def set_algo_parameters(**kwargs):
 
             For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
             Please see the `rank table startup
-            <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_
+            <https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
             for more details.
 
             For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
-            <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ .
+            <https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .
 
             For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
-            Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ .
+            Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
 
         >>> import numpy as np
         >>> import mindspore as ms
@@ -386,7 +386,7 @@ def reset_algo_parameters():
 
     After reset, the values of the attributes are:
 
-    - fully_use_devices: True.
+    - fully_use_devices: False.
     - elementwise_op_strategy_follow: False.
     - enable_algo_approxi: False.
     - algo_approxi_epsilon: 0.1.
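
Since the documented default of fully_use_devices flips from True to False in this release, a script that depended on the old search behaviour would presumably have to request it explicitly. A minimal sketch, assuming the set_algo_parameters/get_algo_parameters entry points referenced in this docstring:

    import mindspore as ms

    # Restore the pre-2.4.0 behaviour explicitly; only the documented default
    # changed, the keyword itself is assumed unchanged.
    ms.set_algo_parameters(fully_use_devices=True)
    print(ms.get_algo_parameters("fully_use_devices"))  # expected: True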