mindspore 2.3.0-cp39-cp39-win_amd64.whl → 2.4.1-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (287)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/initializer.py +51 -15
  26. mindspore/common/mindir_util.py +2 -2
  27. mindspore/common/parameter.py +62 -15
  28. mindspore/common/recompute.py +39 -9
  29. mindspore/common/sparse_tensor.py +7 -3
  30. mindspore/common/tensor.py +183 -37
  31. mindspore/communication/__init__.py +1 -1
  32. mindspore/communication/_comm_helper.py +38 -3
  33. mindspore/communication/comm_func.py +315 -60
  34. mindspore/communication/management.py +14 -14
  35. mindspore/context.py +132 -22
  36. mindspore/dataset/__init__.py +1 -1
  37. mindspore/dataset/audio/__init__.py +1 -1
  38. mindspore/dataset/core/config.py +7 -0
  39. mindspore/dataset/core/validator_helpers.py +7 -0
  40. mindspore/dataset/engine/cache_client.py +1 -1
  41. mindspore/dataset/engine/datasets.py +72 -44
  42. mindspore/dataset/engine/datasets_audio.py +7 -7
  43. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  44. mindspore/dataset/engine/datasets_text.py +20 -20
  45. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  46. mindspore/dataset/engine/datasets_vision.py +33 -33
  47. mindspore/dataset/engine/iterators.py +29 -0
  48. mindspore/dataset/engine/obs/util.py +7 -0
  49. mindspore/dataset/engine/queue.py +114 -60
  50. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  51. mindspore/dataset/engine/validators.py +34 -14
  52. mindspore/dataset/text/__init__.py +1 -4
  53. mindspore/dataset/transforms/__init__.py +0 -3
  54. mindspore/dataset/utils/line_reader.py +2 -0
  55. mindspore/dataset/vision/__init__.py +1 -4
  56. mindspore/dataset/vision/utils.py +1 -1
  57. mindspore/dataset/vision/validators.py +2 -1
  58. mindspore/dnnl.dll +0 -0
  59. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  60. mindspore/experimental/es/embedding_service.py +883 -0
  61. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  62. mindspore/experimental/llm_boost/__init__.py +21 -0
  63. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  64. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  65. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  66. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  67. mindspore/experimental/llm_boost/register.py +129 -0
  68. mindspore/experimental/llm_boost/utils.py +31 -0
  69. mindspore/experimental/optim/adamw.py +85 -0
  70. mindspore/experimental/optim/optimizer.py +3 -0
  71. mindspore/hal/__init__.py +3 -3
  72. mindspore/hal/contiguous_tensors_handle.py +175 -0
  73. mindspore/hal/stream.py +18 -0
  74. mindspore/include/api/model_group.h +13 -1
  75. mindspore/include/api/types.h +10 -10
  76. mindspore/include/dataset/config.h +2 -2
  77. mindspore/include/dataset/constants.h +2 -2
  78. mindspore/include/dataset/execute.h +2 -2
  79. mindspore/include/dataset/vision.h +4 -0
  80. mindspore/jpeg62.dll +0 -0
  81. mindspore/log.py +1 -1
  82. mindspore/mindrecord/filewriter.py +68 -51
  83. mindspore/mindspore_backend.dll +0 -0
  84. mindspore/mindspore_common.dll +0 -0
  85. mindspore/mindspore_core.dll +0 -0
  86. mindspore/mindspore_glog.dll +0 -0
  87. mindspore/mindspore_np_dtype.dll +0 -0
  88. mindspore/mindspore_ops.dll +0 -0
  89. mindspore/mint/__init__.py +983 -46
  90. mindspore/mint/distributed/__init__.py +31 -0
  91. mindspore/mint/distributed/distributed.py +254 -0
  92. mindspore/mint/nn/__init__.py +268 -23
  93. mindspore/mint/nn/functional.py +125 -19
  94. mindspore/mint/nn/layer/__init__.py +39 -0
  95. mindspore/mint/nn/layer/activation.py +133 -0
  96. mindspore/mint/nn/layer/normalization.py +477 -0
  97. mindspore/mint/nn/layer/pooling.py +110 -0
  98. mindspore/mint/optim/adamw.py +26 -13
  99. mindspore/mint/special/__init__.py +63 -0
  100. mindspore/multiprocessing/__init__.py +2 -1
  101. mindspore/nn/__init__.py +0 -1
  102. mindspore/nn/cell.py +276 -96
  103. mindspore/nn/layer/activation.py +211 -44
  104. mindspore/nn/layer/basic.py +137 -10
  105. mindspore/nn/layer/embedding.py +137 -2
  106. mindspore/nn/layer/normalization.py +101 -5
  107. mindspore/nn/layer/padding.py +34 -48
  108. mindspore/nn/layer/pooling.py +161 -7
  109. mindspore/nn/layer/transformer.py +3 -3
  110. mindspore/nn/loss/__init__.py +2 -2
  111. mindspore/nn/loss/loss.py +84 -6
  112. mindspore/nn/optim/__init__.py +2 -1
  113. mindspore/nn/optim/adadelta.py +1 -1
  114. mindspore/nn/optim/adam.py +1 -1
  115. mindspore/nn/optim/lamb.py +1 -1
  116. mindspore/nn/optim/tft_wrapper.py +124 -0
  117. mindspore/nn/wrap/cell_wrapper.py +12 -23
  118. mindspore/nn/wrap/grad_reducer.py +5 -5
  119. mindspore/nn/wrap/loss_scale.py +17 -3
  120. mindspore/numpy/__init__.py +1 -1
  121. mindspore/numpy/array_creations.py +65 -68
  122. mindspore/numpy/array_ops.py +64 -60
  123. mindspore/numpy/fft.py +610 -75
  124. mindspore/numpy/logic_ops.py +11 -10
  125. mindspore/numpy/math_ops.py +85 -84
  126. mindspore/numpy/utils_const.py +4 -4
  127. mindspore/opencv_core452.dll +0 -0
  128. mindspore/opencv_imgcodecs452.dll +0 -0
  129. mindspore/opencv_imgproc452.dll +0 -0
  130. mindspore/ops/__init__.py +6 -4
  131. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  132. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  133. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  134. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  135. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  136. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  137. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  138. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  139. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  140. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  141. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  142. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  143. mindspore/ops/composite/base.py +85 -48
  144. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  145. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  146. mindspore/ops/function/__init__.py +22 -0
  147. mindspore/ops/function/array_func.py +492 -153
  148. mindspore/ops/function/debug_func.py +113 -1
  149. mindspore/ops/function/fft_func.py +15 -2
  150. mindspore/ops/function/grad/grad_func.py +3 -2
  151. mindspore/ops/function/math_func.py +564 -207
  152. mindspore/ops/function/nn_func.py +817 -383
  153. mindspore/ops/function/other_func.py +3 -2
  154. mindspore/ops/function/random_func.py +402 -12
  155. mindspore/ops/function/reshard_func.py +13 -11
  156. mindspore/ops/function/sparse_unary_func.py +1 -1
  157. mindspore/ops/function/vmap_func.py +3 -2
  158. mindspore/ops/functional.py +24 -14
  159. mindspore/ops/op_info_register.py +3 -3
  160. mindspore/ops/operations/__init__.py +7 -2
  161. mindspore/ops/operations/_grad_ops.py +2 -76
  162. mindspore/ops/operations/_infer_ops.py +1 -1
  163. mindspore/ops/operations/_inner_ops.py +71 -94
  164. mindspore/ops/operations/array_ops.py +14 -146
  165. mindspore/ops/operations/comm_ops.py +63 -53
  166. mindspore/ops/operations/custom_ops.py +83 -19
  167. mindspore/ops/operations/debug_ops.py +42 -10
  168. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  169. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  170. mindspore/ops/operations/math_ops.py +12 -223
  171. mindspore/ops/operations/nn_ops.py +20 -114
  172. mindspore/ops/operations/other_ops.py +7 -4
  173. mindspore/ops/operations/random_ops.py +46 -1
  174. mindspore/ops/primitive.py +18 -6
  175. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  176. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  177. mindspore/ops_generate/gen_constants.py +36 -0
  178. mindspore/ops_generate/gen_ops.py +67 -52
  179. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  180. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  181. mindspore/ops_generate/op_proto.py +10 -3
  182. mindspore/ops_generate/pyboost_utils.py +14 -1
  183. mindspore/ops_generate/template.py +43 -21
  184. mindspore/parallel/__init__.py +3 -1
  185. mindspore/parallel/_auto_parallel_context.py +31 -9
  186. mindspore/parallel/_cell_wrapper.py +85 -0
  187. mindspore/parallel/_parallel_serialization.py +47 -19
  188. mindspore/parallel/_tensor.py +127 -13
  189. mindspore/parallel/_utils.py +53 -22
  190. mindspore/parallel/algo_parameter_config.py +5 -5
  191. mindspore/parallel/checkpoint_transform.py +46 -39
  192. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  193. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  194. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  195. mindspore/parallel/parameter_broadcast.py +3 -4
  196. mindspore/parallel/shard.py +162 -31
  197. mindspore/parallel/transform_safetensors.py +1146 -0
  198. mindspore/profiler/__init__.py +2 -1
  199. mindspore/profiler/common/constant.py +29 -0
  200. mindspore/profiler/common/registry.py +47 -0
  201. mindspore/profiler/common/util.py +28 -0
  202. mindspore/profiler/dynamic_profiler.py +694 -0
  203. mindspore/profiler/envprofiling.py +17 -19
  204. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  205. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  206. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  207. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  208. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  209. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  210. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  211. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  212. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  213. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  214. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  215. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  216. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  217. mindspore/profiler/parser/framework_parser.py +1 -391
  218. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  219. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  220. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  221. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  222. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  223. mindspore/profiler/parser/profiler_info.py +78 -6
  224. mindspore/profiler/profiler.py +153 -0
  225. mindspore/profiler/profiling.py +285 -413
  226. mindspore/rewrite/__init__.py +1 -2
  227. mindspore/rewrite/common/namespace.py +4 -4
  228. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  229. mindspore/run_check/_check_version.py +39 -104
  230. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  231. mindspore/swresample-4.dll +0 -0
  232. mindspore/swscale-6.dll +0 -0
  233. mindspore/tinyxml2.dll +0 -0
  234. mindspore/train/__init__.py +4 -3
  235. mindspore/train/_utils.py +105 -19
  236. mindspore/train/amp.py +171 -53
  237. mindspore/train/callback/__init__.py +2 -2
  238. mindspore/train/callback/_callback.py +4 -4
  239. mindspore/train/callback/_checkpoint.py +97 -31
  240. mindspore/train/callback/_cluster_monitor.py +1 -1
  241. mindspore/train/callback/_flops_collector.py +1 -0
  242. mindspore/train/callback/_loss_monitor.py +3 -3
  243. mindspore/train/callback/_on_request_exit.py +145 -31
  244. mindspore/train/callback/_summary_collector.py +5 -5
  245. mindspore/train/callback/_tft_register.py +375 -0
  246. mindspore/train/dataset_helper.py +15 -3
  247. mindspore/train/metrics/metric.py +3 -3
  248. mindspore/train/metrics/roc.py +4 -4
  249. mindspore/train/mind_ir_pb2.py +44 -39
  250. mindspore/train/model.py +154 -58
  251. mindspore/train/serialization.py +342 -128
  252. mindspore/turbojpeg.dll +0 -0
  253. mindspore/utils/__init__.py +21 -0
  254. mindspore/utils/utils.py +60 -0
  255. mindspore/version.py +1 -1
  256. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  257. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
  258. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
  259. mindspore/include/c_api/ms/abstract.h +0 -67
  260. mindspore/include/c_api/ms/attribute.h +0 -197
  261. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  262. mindspore/include/c_api/ms/base/macros.h +0 -32
  263. mindspore/include/c_api/ms/base/status.h +0 -33
  264. mindspore/include/c_api/ms/base/types.h +0 -283
  265. mindspore/include/c_api/ms/context.h +0 -102
  266. mindspore/include/c_api/ms/graph.h +0 -160
  267. mindspore/include/c_api/ms/node.h +0 -606
  268. mindspore/include/c_api/ms/tensor.h +0 -161
  269. mindspore/include/c_api/ms/value.h +0 -84
  270. mindspore/mindspore_shared_lib.dll +0 -0
  271. mindspore/nn/extend/basic.py +0 -140
  272. mindspore/nn/extend/embedding.py +0 -143
  273. mindspore/nn/extend/layer/normalization.py +0 -109
  274. mindspore/nn/extend/pooling.py +0 -117
  275. mindspore/nn/layer/embedding_service.py +0 -531
  276. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  277. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  278. mindspore/ops/extend/__init__.py +0 -53
  279. mindspore/ops/extend/array_func.py +0 -218
  280. mindspore/ops/extend/math_func.py +0 -76
  281. mindspore/ops/extend/nn_func.py +0 -308
  282. mindspore/ops/silent_check.py +0 -162
  283. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  284. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  285. mindspore/train/callback/_mindio_ttp.py +0 -443
  286. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  287. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0

mindspore/parallel/_auto_parallel_context.py

@@ -76,6 +76,7 @@ class _PipelineConfig:
 class _PipelineScheduler:
     PIPELINE_1F1B = "1f1b"
     PIPELINE_GPIPE = "gpipe"
+    PIPELINE_SEQPIPE = "seqpipe"


 class _AutoParallelContext:
@@ -168,6 +169,24 @@ class _AutoParallelContext:
         self.check_context_handle()
         return _ParallelFusionConfig.CONFIG

+    def set_dump_local_norm(self, dump_local_norm):
+        """
+        Set dump local norm for auto parallel.
+
+        Args:
+            dump_local_norm (bool): User need to specify if he want to dump local norm. Default: False
+
+        Raises:
+            KeyError: When key of comm_fusion is not 'allreduce'.
+        """
+        self.check_context_handle()
+        self._context_handle.set_dump_local_norm(dump_local_norm)
+
+    def get_dump_local_norm(self):
+        """Get dump local norm."""
+        self.check_context_handle()
+        return self._context_handle.get_dump_local_norm()
+
     def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
         """
         Set fusion threshold (MB) for auto parallel.
@@ -584,7 +603,7 @@ class _AutoParallelContext:
         self.check_context_handle()
         dir_path = os.path.dirname(strategy_ckpt_save_file)
         if dir_path and not os.path.exists(dir_path):
-            os.makedirs(dir_path, exist_ok=True)
+            os.makedirs(dir_path, mode=0o700, exist_ok=True)
         self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)

     def get_strategy_ckpt_save_file(self):
@@ -643,7 +662,7 @@ class _AutoParallelContext:
         self.check_context_handle()
         dir_path = os.path.dirname(group_ckpt_save_file)
         if dir_path and not os.path.exists(dir_path):
-            os.makedirs(dir_path)
+            os.makedirs(dir_path, mode=0o700, exist_ok=True)
         self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)

     def get_parameter_broadcast_is_set(self):
@@ -896,7 +915,8 @@ class _AutoParallelContext:
                              pipeline_config[pp_interleave])

         Validator.check_string(pipeline_config[pp_scheduler], [_PipelineScheduler.PIPELINE_1F1B,
-                                                               _PipelineScheduler.PIPELINE_GPIPE])
+                                                               _PipelineScheduler.PIPELINE_GPIPE,
+                                                               _PipelineScheduler.PIPELINE_SEQPIPE])
         if not pipeline_config[pp_interleave] and pipeline_config[pp_scheduler] != _PipelineScheduler.PIPELINE_1F1B:
             raise ValueError(f"When pipeline_interleave is False, {pp_scheduler} is not supported")

@@ -1117,9 +1137,9 @@ class _AutoParallelContext:
         """
         self.check_context_handle()
         if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
-            return
+            self.set_enable_all_gather_fusion(True)
         if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
-            return
+            self.set_enable_reduce_scatter_fusion(True)
         if not isinstance(comm_fusion, dict):
             raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
                 comm_type, type(comm_fusion)))
@@ -1153,7 +1173,7 @@ class _AutoParallelContext:
         """
         self.check_context_handle()
         if not self.get_enable_all_reduce_fusion():
-            return
+            self.set_enable_all_reduce_fusion(True)
         if not isinstance(comm_fusion, dict):
             raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
                 type(comm_fusion)))
@@ -1210,7 +1230,7 @@ def _set_ops_strategy_json_config(type="SAVE", path="", mode="all"):
     """
     dir_path = os.path.dirname(path)
     if dir_path and not os.path.exists(dir_path):
-        os.makedirs(dir_path)
+        os.makedirs(dir_path, mode=0o700, exist_ok=True)
     check_type = ["SAVE", "LOAD"]
     check_mode = ["all", "principal"]
     if type in check_type and mode in check_mode:
@@ -1266,7 +1286,8 @@ _set_auto_parallel_context_func_map = {
     "sharding_propagation": auto_parallel_context().set_sharding_propagation,
     "enable_alltoall": auto_parallel_context().set_enable_alltoall,
     "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
-    "comm_fusion": auto_parallel_context().set_comm_fusion}
+    "comm_fusion": auto_parallel_context().set_comm_fusion,
+    "dump_local_norm": auto_parallel_context().set_dump_local_norm}

 _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
@@ -1298,7 +1319,8 @@ _get_auto_parallel_context_func_map = {
     "enable_alltoall": auto_parallel_context().get_enable_alltoall,
     "comm_fusion": auto_parallel_context().get_comm_fusion,
     "strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config,
-    "full_batch_is_set": auto_parallel_context().get_full_batch_is_set}
+    "full_batch_is_set": auto_parallel_context().get_full_batch_is_set,
+    "dump_local_norm": auto_parallel_context().get_dump_local_norm}


 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
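
For orientation on the new dump_local_norm entries above: since the key is registered in both the setter and getter maps, it should be reachable through the public auto-parallel context API. A minimal usage sketch, which is an inference from this diff rather than text quoted from it:

import mindspore as ms

# Hedged sketch: toggle local-norm dumping the same way other keys in
# _set_auto_parallel_context_func_map are toggled. Only meaningful in a
# distributed (semi_)auto_parallel setup.
ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel", dump_local_norm=True)
print(ms.get_auto_parallel_context("dump_local_norm"))  # expected: True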

mindspore/parallel/_cell_wrapper.py

@@ -16,11 +16,16 @@
 from __future__ import absolute_import
 from __future__ import division

+import numpy as np
+
+from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.ops import operations as P
 from mindspore.ops.operations.comm_ops import AllGather
 from mindspore.communication import GlobalComm
 from mindspore.common import jit
+from mindspore.communication import create_group
+from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy

 _ALLGATHER_CELL = None

@@ -30,6 +35,7 @@ class AllGatherCell(Cell):
     Allgather cell, used in model parallel scenario.
     To allgather the selected parameter slice from each device.
     """
+
     def __init__(self, group, do_reshape, after_reshape_slice_shape):
         super(AllGatherCell, self).__init__(auto_prefix=False)
         self.allgather = AllGather(group)
@@ -54,6 +60,7 @@ class SaveOptShardCkptCell(Cell):
     Note:
         This could be optimized later with less communication consumption.
     """
+
     def __init__(self, group, do_reshape, after_reshape_slice_shape):
         super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
         self.allgather1 = AllGather(group)
@@ -71,6 +78,21 @@ class SaveOptShardCkptCell(Cell):
         return x


+class SingleCommunicator(Cell):
+    """
+    Used to broadcast single parameter.
+    """
+
+    def __init__(self, group_name):
+        super(SingleCommunicator, self).__init__()
+        self.allreduce = P.AllReduce(group=group_name)
+        self.add_flags(skip_auto_parallel_compile=True)
+
+    def construct(self, loaded_param):
+        result = self.allreduce(loaded_param)
+        return result
+
+
 def get_allgather_cell(group, need_merge_twice=False, do_reshape=False, after_reshape_slice_shape=()):
     """Get AllGatherCell object."""
     global _ALLGATHER_CELL
@@ -89,3 +111,66 @@ def destroy_allgather_cell():
     global _ALLGATHER_CELL
     if _ALLGATHER_CELL:
         _ALLGATHER_CELL = None
+
+
+def _chang_parallel_context(origin_dataset_strategy):
+    """Change the original parallel state."""
+    if context.get_context("mode") == context.GRAPH_MODE:
+        context.set_auto_parallel_context(parallel_mode="hybrid_parallel")
+    if origin_dataset_strategy != "data_parallel":
+        context.set_auto_parallel_context(dataset_strategy="data_parallel")
+
+
+def _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy):
+    """Restore the original parallel state."""
+    if context.get_context("mode") == context.GRAPH_MODE:
+        context.set_auto_parallel_context(parallel_mode=origin_parallel_mode)
+    if origin_dataset_strategy != "data_parallel":
+        if origin_dataset_strategy is not None and isinstance(origin_dataset_strategy, list):
+            origin_dataset_strategy = tuple(tuple(ds_item) for ds_item in origin_dataset_strategy)
+        context.set_auto_parallel_context(dataset_strategy=origin_dataset_strategy)
+
+
+def _single_parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
+    """
+    Broadcast single parameter to other rank in data parallel dimension.
+    """
+    from mindspore import Tensor
+    origin_parallel_mode = context.get_auto_parallel_context("parallel_mode")
+    origin_dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
+    if layout:
+        param_redundancy = get_parameter_redundancy(layout, initial_rank)
+    else:
+        param_redundancy = get_parameter_redundancy(net)
+    if not param_redundancy:
+        return
+    single_params = remove_param_redundancy(param_redundancy)
+    if not single_params:
+        return
+    param_redundancy_reversed = {}
+    for key, redundancy in param_redundancy.items():
+        for item in redundancy:
+            if len(item) == 1:
+                continue
+            if cur_rank in item:
+                param_redundancy_reversed.setdefault(item, []).append(key)
+    if not param_redundancy_reversed or cur_rank not in single_params:
+        return
+    net_param_dict = net.parameters_dict()
+    _chang_parallel_context(origin_dataset_strategy)
+    for group, params in param_redundancy_reversed.items():
+        create_group(str(group), list(group))
+        allreduce_input = []
+        for param in params:
+            if param not in net_param_dict:
+                continue
+            real_param = net_param_dict[param]
+            if param not in single_params[cur_rank]:
+                real_param.set_data(Tensor(np.zeros(real_param.shape), dtype=real_param.dtype), real_param.sliced)
+            allreduce_input.append(real_param)
+        if not allreduce_input:
+            continue
+        communicator = SingleCommunicator(str(group))
+        for real_param in allreduce_input:
+            real_param.set_data(communicator(real_param), real_param.sliced)
+    _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy)
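
The zero-then-AllReduce pattern in _single_parameter_broadcast above is worth spelling out: every rank that does not own a parameter overwrites its copy with zeros, so the group AllReduce (a sum) leaves each rank holding exactly the owner's values. A plain-NumPy illustration of that identity (not MindSpore code; the helper name is made up):

import numpy as np

def fake_allreduce_sum(per_rank_copies):
    """Stand-in for an AllReduce(sum) across a communication group."""
    return np.sum(per_rank_copies, axis=0)

owner_value = np.array([1.0, 2.0, 3.0])
group_size = 4
# Rank 0 owns the parameter; the other ranks zero their local copies first.
copies = [owner_value] + [np.zeros_like(owner_value) for _ in range(group_size - 1)]
restored = fake_allreduce_sum(copies)
assert np.allclose(restored, owner_value)  # every rank ends up with the owner's data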

mindspore/parallel/_parallel_serialization.py

@@ -24,7 +24,6 @@ from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_
     _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
     _extract_layout_item

-
 MAX_PATH_LENGTH = 1024


@@ -37,14 +36,17 @@ def _convert_to_list(strategy, rank_id=None):
             dev_mat = list(layout.dev_matrix[0].dim)
             tensor_map = list(layout.tensor_map[0].dim)
             param_split_shape = list(layout.param_split_shape[0].dim)
+            field_size = int(layout.field)
+            shard_stride = int(layout.opt_weight_shard_step)
+            shard_size = int(layout.opt_weight_shard_size)
             pipeline_stage = 0
             origin_param_name = param_name
             if "-" in param_name:
                 pipeline_stage, origin_param_name = param_name.split("-")
                 pipeline_stage = int(pipeline_stage)
             if origin_param_name not in train_map:
-                train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape, int(layout.field),
-                                                int(layout.opt_weight_shard_step), int(layout.opt_weight_shard_size),
+                train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape, field_size,
+                                                shard_stride, shard_size,
                                                 [pipeline_stage]]
             else:
                 update_pipeline_stage_list = train_map.get(origin_param_name)[6] + [pipeline_stage]
@@ -54,15 +56,15 @@ def _convert_to_list(strategy, rank_id=None):
                     not_device0_nor_pipeline0 = ((rank_id // stage_device_num) > 0) and (pipeline_stage > 0)
                     if is_device0_and_pipeline0 or not_device0_nor_pipeline0:
                         train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
-                                                        int(layout.field), int(layout.opt_weight_shard_step),
-                                                        int(layout.opt_weight_shard_size), update_pipeline_stage_list]
+                                                        field_size, shard_stride,
+                                                        shard_size, update_pipeline_stage_list]
                     else:
                         train_map.get(origin_param_name)[6] = update_pipeline_stage_list
                 else:
                     if np.all(pipeline_stage <= np.array(update_pipeline_stage_list)):
                         train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
-                                                        int(layout.field), int(layout.opt_weight_shard_step),
-                                                        int(layout.opt_weight_shard_size), update_pipeline_stage_list]
+                                                        field_size, shard_stride,
+                                                        shard_size, update_pipeline_stage_list]
                     else:
                         train_map.get(origin_param_name)[6] = update_pipeline_stage_list
         except BaseException as e:
@@ -174,6 +176,8 @@ def _build_json_strategy(strategy_filename):

 def _build_searched_strategy(strategy_filename):
     """build searched strategy"""
+    if strategy_filename is None:
+        return strategy_filename
     _check_strategy_file(strategy_filename)
     if strategy_filename[-5:] != ".json":
         return _build_protobuf_strategy(strategy_filename)
@@ -239,7 +243,10 @@ def _extract_layout_map(strategy_file, rank_id=None):
     """Extract layout map"""
     layout_map = None
     if strategy_file is not None:
-        src_strategy = _build_searched_strategy(strategy_file)
+        if not isinstance(strategy_file, dict):
+            src_strategy = _build_searched_strategy(strategy_file)
+        else:
+            src_strategy = strategy_file
         layout_map = _convert_to_list(src_strategy, rank_id)
     return layout_map

@@ -248,7 +255,10 @@ def _extract_pipeline_stage_num(strategy_file):
     """extract pipeline stage num"""
     pipeline_stage_num = 1
     if strategy_file is not None:
-        src_strategy = _build_searched_strategy(strategy_file)
+        if not isinstance(strategy_file, dict):
+            src_strategy = _build_searched_strategy(strategy_file)
+        else:
+            src_strategy = strategy_file
         layout_map = _convert_to_list(src_strategy)
         pipeline_stage_set = set()
         for _, layout in layout_map.items():
@@ -323,7 +333,10 @@ def _get_device_num_from_strategy(strategy_file=None):
     """Get device num from strategy file"""
     if strategy_file is None:
         return 1
-    src_strategy = _build_searched_strategy(strategy_file)
+    if not isinstance(strategy_file, dict):
+        src_strategy = _build_searched_strategy(strategy_file)
+    else:
+        src_strategy = strategy_file
     strategy_list = _convert_to_list(src_strategy)
     device_mat = list(strategy_list.values())[0][0]
     return np.prod(device_mat)
@@ -341,14 +354,15 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_list, dst
        from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
            src_strategy_list.get(param_name))
        from_device_num = np.prod(from_dev_matrix)
-       fake_tensor_shape = [8] * len(from_tensor_map)
        to_dev_matrix = [1]
-       to_tensor_map = [-1] * len(fake_tensor_shape)
+       to_tensor_map = [-1] * len(from_tensor_map)
        to_opt_shard_step = 0
        to_opt_shard_size = 0
        if dst_strategy_list is not None:
            to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
                dst_strategy_list.get(param_name))
+       to_device_num = np.prod(to_dev_matrix)
+       fake_tensor_shape = [max(from_device_num, to_device_num)] * len(from_tensor_map)
        handled_key = (from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size,
                       to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size)
        if handled_key in handled_layout:
@@ -433,7 +447,6 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
        param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                            device_list, rank_id)

-
        from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
        to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
        _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
@@ -443,10 +456,10 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
        transform_tensor = ms.Tensor(param_total_dict_copy[rank_id % device_num])
        requires_grad = param_attr_dict[param_name][rank_id % device_num][0]
        layerwise_parallel = param_attr_dict[param_name][rank_id % device_num][1]
-       transform_para = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
+       transform_param = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
        if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
-           transform_para.set_dtype(ms.bfloat16)
-       transform_param_dict[param_name] = transform_para
+           transform_param.set_dtype(ms.bfloat16)
+       transform_param_dict[param_name] = transform_param
    if device_num < 1:
        raise ValueError("None of the parameters in checkpoint file are in either src strategy or "
                         "dst strategy. Please check correctness of strategy files.")
@@ -454,13 +467,13 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
    # Handle those parameter like learning_rate, global_step which not in strategy_file.
    for param_name, _ in param_total_dict.items():
        if param_name not in transform_param_dict:
-           transform_para = ms.Parameter(
+           transform_param = ms.Parameter(
                ms.Tensor(param_total_dict[param_name][rank_id % device_num]), param_name,
                param_attr_dict[param_name][rank_id % device_num][0],
                param_attr_dict[param_name][rank_id % device_num][1])
            if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
-               transform_para.set_dtype(ms.bfloat16)
-           transform_param_dict[param_name] = transform_para
+               transform_param.set_dtype(ms.bfloat16)
+           transform_param_dict[param_name] = transform_param

    transform_param_list = [{"name": param_name, "data": param_data}
                            for param_name, param_data in transform_param_dict.items()]
@@ -531,3 +544,18 @@ def _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple):
                continue
            to_slice_tensor_shape += (item // to_tensor_strategy[i],)
        param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
+
+
+def _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_sharded_device_index, rank):
+    """Calculate rank list for optimizer parallel when first dim of parameter is sharded by other parallel method"""
+    total_device_num = 1
+    for n in device_arrangement:
+        total_device_num *= n
+    if first_dim_sharded_device_index != len(device_arrangement) - 1:
+        return list(range(0, total_device_num))
+    first_dim_sharded_size = device_arrangement[-1 - first_dim_sharded_device_index]
+    range_size = total_device_num // first_dim_sharded_size
+    offset = rank % range_size
+    start = rank - offset
+    param_total_list = list(range(start, start + range_size))
+    return param_total_list
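
A quick worked example of the _get_param_list_when_first_dim_sharded helper added at the end of this hunk, copied verbatim so it can be run without MindSpore; the interpretation in the comments is my reading of the code, not text from the diff:

def _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_sharded_device_index, rank):
    total_device_num = 1
    for n in device_arrangement:
        total_device_num *= n
    if first_dim_sharded_device_index != len(device_arrangement) - 1:
        return list(range(0, total_device_num))
    first_dim_sharded_size = device_arrangement[-1 - first_dim_sharded_device_index]
    range_size = total_device_num // first_dim_sharded_size
    offset = rank % range_size
    start = rank - offset
    return list(range(start, start + range_size))

# device_arrangement [2, 4] with the sharded index equal to the last position:
# the 8 devices are grouped into blocks of 8 // 2 = 4, and rank 5 gets its block.
print(_get_param_list_when_first_dim_sharded([2, 4], 1, 5))  # [4, 5, 6, 7]
# Any other sharded index falls back to the full rank list 0..7.
print(_get_param_list_when_first_dim_sharded([2, 4], 0, 5))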

mindspore/parallel/_tensor.py

@@ -334,8 +334,10 @@ def _extract_layout_item(layout_item):
     tensor_map = layout_item[1]
     opt_shard_step = layout_item[4]
     opt_shard_size = layout_item[5]
+    tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map)
+    model_parallel_shard_size = np.prod(tensor_strategy)
     if opt_shard_size == -1:
-        opt_shard_size = np.prod(dev_matrix) // opt_shard_step
+        opt_shard_size = np.prod(dev_matrix) // model_parallel_shard_size
     return dev_matrix, tensor_map, opt_shard_step, opt_shard_size


@@ -406,12 +408,35 @@ def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_ste
     if opt_shard_step == 0 or opt_shard_size == 0:
         return dev_matrix, tensor_map, list(origin_full_tensor_shape)
     tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map)
-    model_parallel_shard_size = np.prod(tensor_strategy)
-    if model_parallel_shard_size != opt_shard_step:
+    repeated_dim = []
+    dev_sharded_index = []
+    for dim in tensor_map:
+        if dim != -1:
+            dev_sharded_index.append(len(dev_matrix) - dim - 1)
+    for index, value in enumerate(dev_matrix):
+        if index not in dev_sharded_index and value > 1:
+            repeated_dim.append(index)
+    if not repeated_dim:
+        raise ValueError("The device_matrix {} and tensor_map {} cannot sharding opt_shard".
+                         format(dev_matrix, tensor_map))
+    if len(repeated_dim) == 1 and np.prod(dev_matrix[repeated_dim[0] + 1:]) != opt_shard_step:
         raise ValueError("The optimizer sharding step {} is not equal to the model parallel sharding size {}.".
-                         format(opt_shard_step, model_parallel_shard_size))
-
+                         format(opt_shard_step, np.prod(dev_matrix[repeated_dim[0] + 1:])))
     first_dim_no_sharding_size = origin_full_tensor_shape[0] // tensor_strategy[0]
+    if (len(repeated_dim) < len(dev_matrix) and len(repeated_dim) > 1) or repeated_dim[0] > 0:
+        tensor_shape_new = list(origin_full_tensor_shape)
+        tensor_shape_new[0] = tensor_strategy[0]
+        accu_shp = 1
+        for i in range(len(repeated_dim) - 1):
+            opt_sharding_size = dev_matrix[repeated_dim[i]]
+            tensor_shape_new.insert(i + 1, opt_sharding_size)
+            accu_shp = accu_shp * opt_sharding_size
+        tensor_shape_new.insert(len(repeated_dim), first_dim_no_sharding_size // accu_shp)
+        tensor_map_new = list(copy.deepcopy(tensor_map))
+        for index, r_dim in enumerate(repeated_dim):
+            tensor_map_new.insert(index + 1, len(dev_matrix) - r_dim - 1)
+        return list(dev_matrix), tensor_map_new, tensor_shape_new
+
     full_tensor_shape = list(origin_full_tensor_shape)
     full_tensor_shape[0] = tensor_strategy[0]
     full_tensor_shape.insert(1, first_dim_no_sharding_size)
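
To make the new repeated-dimension search in _construct_tensor_layout_for_opt_shard concrete, here is the core index arithmetic pulled out into a standalone snippet, with illustrative values of my own choosing rather than values from the diff:

dev_matrix = [4, 2, 2]
tensor_map = [1, 0]

# tensor_map values index the device matrix from the right, so value 1 maps to
# dev_matrix position 1 and value 0 maps to position 2; both are taken by model
# parallelism, leaving position 0 (size 4) as the "repeated" axis available for
# optimizer sharding.
dev_sharded_index = [len(dev_matrix) - dim - 1 for dim in tensor_map if dim != -1]
repeated_dim = [index for index, value in enumerate(dev_matrix)
                if index not in dev_sharded_index and value > 1]
print(dev_sharded_index, repeated_dim)  # [1, 2] [0]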
@@ -452,7 +477,7 @@ def _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_te
     result_map = {self_rank: transform_operators}
     for operators in transform_operators:
         op_name = operators[0]
-        if op_name == "AllGather":
+        if op_name == "AllConcat":
             groups = operators[1][:-1]
             stack.append((index, groups))
         index += 1
@@ -466,7 +491,7 @@ def _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_te
             index = 0
             for operators in new_transform_operators:
                 op_name = operators[0]
-                if op_name == "AllGather" and index < group_info[0]:
+                if op_name == "AllConcat" and index < group_info[0]:
                     groups = operators[1][:-1]
                     stack.insert(0, (index, groups))
                 index += 1
@@ -491,7 +516,7 @@ def _generate_transform_operator_stack(transform_operators_map, self_rank):
         level = queue_front[1]
         current_operator = queue_front[2]
         if level >= 1:
-            if current_operator[0] == "AllGather":
+            if current_operator[0] == "AllConcat":
                 current_group = current_operator[1][:-1]
                 for rank_id in current_group:
                     handle_queue.append((rank_id, level - 1, transform_operators_map[rank_id][level - 1]))
@@ -523,7 +548,7 @@ def _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, dev
            if operator[0] != op_name:
                raise ValueError("The operator in the same level should be equal in the transform tensor operator "
                                 "list, but the find {} and {} in level {}".format(op_name, operator[0], cur_level))
-           if operator[0] != "AllGather":
+           if operator[0] != "AllConcat":
                tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(tensor_dict[rank_id % device_num],
                                                                                 operator)
                continue
@@ -532,7 +557,7 @@
                    raise ValueError("The checkpoint file of rank {} is missing.".format(rank % device_num))
            allgather_list = [tensor_dict[rank % device_num] for rank in operator[1][:-1]]
            tmp_tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(allgather_list, operator)
-       if op_name == "AllGather":
+       if op_name == "AllConcat":
            for rank, value in tmp_tensor_dict.items():
                tensor_dict[rank % device_num] = value
        level_operators.clear()
@@ -565,6 +590,8 @@ def _apply_operator(operator_name):
         Returns:
             The data of tensor after apply operator.
         """
+        if str(type(numpy_data)) == "<class 'builtins.PySafeSlice'>":
+            numpy_data = numpy_data[:]
         if not isinstance(numpy_data, np.ndarray):
             raise TypeError("The data should be a numpy.ndarray.")
         _check_operator(reshape_op)
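
The new PySafeSlice branch above suggests these transform operators can now receive lazily loaded safetensors slices (see also the new mindspore/parallel/transform_safetensors.py in the file list) and materialize them with [:] before treating them as ndarrays. A small sketch of where such an object comes from, assuming the third-party safetensors package and its NumPy front end; the file name is made up:

import numpy as np
from safetensors import safe_open
from safetensors.numpy import save_file

save_file({"w": np.arange(6, dtype=np.float32).reshape(2, 3)}, "demo.safetensors")
with safe_open("demo.safetensors", framework="np") as f:
    lazy = f.get_slice("w")   # a lazy PySafeSlice, not yet an ndarray
    materialized = lazy[:]    # slicing loads the data as a NumPy array
print(type(lazy).__name__, materialized.shape)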
@@ -604,8 +631,6 @@
         Returns:
             The data of tensor after apply operator.
         """
-        if not isinstance(numpy_data, np.ndarray):
-            raise TypeError("The data should be a numpy.ndarray.")
         _check_operator(slice_op)
         if len(slice_op[1]) % 3 != 0:
             raise ValueError("The slice operator information is wrong.")
@@ -621,7 +646,7 @@
         return numpy_data[slice_index]

     _apply_operator_map = {"Reshape": _apply_reshape_operator, "StridedSlice": _apply_slice_operator,
-                           "AllGather": _apply_allconcat_operator}
+                           "AllConcat": _apply_allconcat_operator}
     return _apply_operator_map.get(operator_name)


@@ -658,3 +683,92 @@ def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
     for i in range(1, len(tensor_slices_col)):
         new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)
     return Tensor(new_tensor)
+
+
+def _load_tensor_shape(dev_mat, tensor_map, full_shape=None, rank_id=-1):
+    """get tensor shape by slice"""
+    if rank_id == -1:
+        rank = get_rank()
+    else:
+        rank = rank_id
+    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
+    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
+    np_tensor_list = _chunk_shape_by_strategy(full_shape, tensor_strategy)
+    np_tensor_slice_index = np_tensor_list[int(tensor_slice_index)]
+    res = []
+    for index in np_tensor_slice_index:
+        res.append(slice(index[0], index[1]))
+    return tuple(res)
+
+
+def _count_tensor_shape(dev_mat, tensor_map, full_shape=None, rank_id=-1):
+    """get tensor shape"""
+    if rank_id == -1:
+        rank = get_rank()
+    else:
+        rank = rank_id
+    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
+    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
+    np_tensor_list = _chunk_shape_by_strategy(full_shape, tensor_strategy)
+    np_tensor_slice_index = np_tensor_list[int(tensor_slice_index)]
+    res = []
+    for index in np_tensor_slice_index:
+        res.append(index[1] - index[0])
+    return res
+
+
+def _load_tensor_shape_by_layout(tensor, layout, rank_id):
+    """get tensor shape by layout"""
+    if not isinstance(layout, tuple):
+        raise TypeError("The layout should be tuple! layout is {}".format(layout))
+    if len(layout) < 7:
+        raise ValueError("The length of layout must be larger than 6! layout is {}".format(layout))
+    slice_shape = layout[2]
+    if slice_shape:
+        return slice_shape
+    tensor_map = layout[1]
+    if not tensor_map:
+        return tensor.shape
+    dev_mat = layout[0]
+    uniform_split = layout[4]
+    group = layout[5]
+    full_shape = layout[6]
+    if not full_shape:
+        full_shape = tensor.shape
+    if uniform_split == 0:
+        raise RuntimeError("The load tensor only support uniform split now")
+    tensor_slice_shape = _count_tensor_shape(dev_mat, tensor_map, full_shape, rank_id)
+    if group:
+        # get a totally shard tensor slice for parallel optimizer
+        size = get_group_size(group)
+        tensor_slice_shape[0] //= size
+    return tensor_slice_shape
+
+
+def _chunk_shape_by_strategy(full_shape, strategy):
+    """chunk shape by strategy"""
+    shape = []
+    for i in full_shape:
+        shape.append([0, i])
+    return _chunk_shape(shape, strategy, len(strategy))
+
+
+def _chunk_shape(np_tensor, strategy, depth):
+    """_chunk shape"""
+    output = []
+    axis = len(np_tensor) - depth
+    left, right = np_tensor[axis]
+    num = strategy[0]
+    chunk_size = (right - left) / num
+    append = [[i, int(i + chunk_size)] for i in range(left, right) if i % chunk_size == 0]
+    np_tensor_new = []
+    for i in append:
+        np_tensor_tmp = copy.deepcopy(np_tensor)
+        np_tensor_tmp[axis] = i
+        np_tensor_new.append(np_tensor_tmp)
+    if depth == 1:
+        return np_tensor_new
+    for ret_ in np_tensor_new:
+        output.extend(
+            _chunk_shape(ret_, strategy[len(strategy) - depth + 1:len(strategy)], depth - 1))
+    return output
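
As a sanity check on the new chunking helpers, here are standalone copies (docstrings dropped) plus the output for a 4x6 shape split with strategy [2, 3]; runnable without MindSpore:

import copy

def _chunk_shape(np_tensor, strategy, depth):
    output = []
    axis = len(np_tensor) - depth
    left, right = np_tensor[axis]
    num = strategy[0]
    chunk_size = (right - left) / num
    append = [[i, int(i + chunk_size)] for i in range(left, right) if i % chunk_size == 0]
    np_tensor_new = []
    for i in append:
        np_tensor_tmp = copy.deepcopy(np_tensor)
        np_tensor_tmp[axis] = i
        np_tensor_new.append(np_tensor_tmp)
    if depth == 1:
        return np_tensor_new
    for ret_ in np_tensor_new:
        output.extend(
            _chunk_shape(ret_, strategy[len(strategy) - depth + 1:len(strategy)], depth - 1))
    return output

def _chunk_shape_by_strategy(full_shape, strategy):
    shape = []
    for i in full_shape:
        shape.append([0, i])
    return _chunk_shape(shape, strategy, len(strategy))

# A [4, 6] shape with strategy [2, 3] is cut into 2 x 3 = 6 index ranges,
# enumerated row-major; each entry is a [start, stop] pair per dimension:
# [[0, 2], [0, 2]], [[0, 2], [2, 4]], [[0, 2], [4, 6]],
# [[2, 4], [0, 2]], [[2, 4], [2, 4]], [[2, 4], [4, 6]]
print(_chunk_shape_by_strategy([4, 6], [2, 3]))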