mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged for review.
Files changed (285)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/mindir_util.py +2 -2
  26. mindspore/common/parameter.py +46 -13
  27. mindspore/common/recompute.py +39 -9
  28. mindspore/common/sparse_tensor.py +7 -3
  29. mindspore/common/tensor.py +209 -29
  30. mindspore/communication/__init__.py +1 -1
  31. mindspore/communication/_comm_helper.py +38 -3
  32. mindspore/communication/comm_func.py +310 -55
  33. mindspore/communication/management.py +14 -14
  34. mindspore/context.py +123 -22
  35. mindspore/dataset/__init__.py +1 -1
  36. mindspore/dataset/audio/__init__.py +1 -1
  37. mindspore/dataset/core/config.py +7 -0
  38. mindspore/dataset/core/validator_helpers.py +7 -0
  39. mindspore/dataset/engine/cache_client.py +1 -1
  40. mindspore/dataset/engine/datasets.py +72 -44
  41. mindspore/dataset/engine/datasets_audio.py +7 -7
  42. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  43. mindspore/dataset/engine/datasets_text.py +20 -20
  44. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  45. mindspore/dataset/engine/datasets_vision.py +33 -33
  46. mindspore/dataset/engine/iterators.py +29 -0
  47. mindspore/dataset/engine/obs/util.py +7 -0
  48. mindspore/dataset/engine/queue.py +114 -60
  49. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  50. mindspore/dataset/engine/validators.py +34 -14
  51. mindspore/dataset/text/__init__.py +1 -4
  52. mindspore/dataset/transforms/__init__.py +0 -3
  53. mindspore/dataset/utils/line_reader.py +2 -0
  54. mindspore/dataset/vision/__init__.py +1 -4
  55. mindspore/dataset/vision/utils.py +1 -1
  56. mindspore/dataset/vision/validators.py +2 -1
  57. mindspore/dnnl.dll +0 -0
  58. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  59. mindspore/experimental/es/embedding_service.py +883 -0
  60. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  61. mindspore/experimental/llm_boost/__init__.py +21 -0
  62. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  63. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  64. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  65. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  66. mindspore/experimental/llm_boost/register.py +129 -0
  67. mindspore/experimental/llm_boost/utils.py +31 -0
  68. mindspore/experimental/optim/adamw.py +85 -0
  69. mindspore/experimental/optim/optimizer.py +3 -0
  70. mindspore/hal/__init__.py +3 -3
  71. mindspore/hal/contiguous_tensors_handle.py +175 -0
  72. mindspore/hal/stream.py +18 -0
  73. mindspore/include/api/model_group.h +13 -1
  74. mindspore/include/api/types.h +10 -10
  75. mindspore/include/dataset/config.h +2 -2
  76. mindspore/include/dataset/constants.h +2 -2
  77. mindspore/include/dataset/execute.h +2 -2
  78. mindspore/include/dataset/vision.h +4 -0
  79. mindspore/jpeg62.dll +0 -0
  80. mindspore/log.py +1 -1
  81. mindspore/mindrecord/filewriter.py +68 -51
  82. mindspore/mindspore_backend.dll +0 -0
  83. mindspore/mindspore_common.dll +0 -0
  84. mindspore/mindspore_core.dll +0 -0
  85. mindspore/mindspore_glog.dll +0 -0
  86. mindspore/mindspore_np_dtype.dll +0 -0
  87. mindspore/mindspore_ops.dll +0 -0
  88. mindspore/mint/__init__.py +495 -46
  89. mindspore/mint/distributed/__init__.py +31 -0
  90. mindspore/mint/distributed/distributed.py +254 -0
  91. mindspore/mint/nn/__init__.py +266 -21
  92. mindspore/mint/nn/functional.py +125 -19
  93. mindspore/mint/nn/layer/__init__.py +39 -0
  94. mindspore/mint/nn/layer/activation.py +133 -0
  95. mindspore/mint/nn/layer/normalization.py +477 -0
  96. mindspore/mint/nn/layer/pooling.py +110 -0
  97. mindspore/mint/optim/adamw.py +28 -7
  98. mindspore/mint/special/__init__.py +63 -0
  99. mindspore/multiprocessing/__init__.py +2 -1
  100. mindspore/nn/__init__.py +0 -1
  101. mindspore/nn/cell.py +275 -93
  102. mindspore/nn/layer/activation.py +211 -44
  103. mindspore/nn/layer/basic.py +113 -3
  104. mindspore/nn/layer/embedding.py +120 -2
  105. mindspore/nn/layer/normalization.py +101 -5
  106. mindspore/nn/layer/padding.py +34 -48
  107. mindspore/nn/layer/pooling.py +161 -7
  108. mindspore/nn/layer/transformer.py +3 -3
  109. mindspore/nn/loss/__init__.py +2 -2
  110. mindspore/nn/loss/loss.py +84 -6
  111. mindspore/nn/optim/__init__.py +2 -1
  112. mindspore/nn/optim/adadelta.py +1 -1
  113. mindspore/nn/optim/adam.py +1 -1
  114. mindspore/nn/optim/lamb.py +1 -1
  115. mindspore/nn/optim/tft_wrapper.py +127 -0
  116. mindspore/nn/wrap/cell_wrapper.py +12 -23
  117. mindspore/nn/wrap/grad_reducer.py +5 -5
  118. mindspore/nn/wrap/loss_scale.py +17 -3
  119. mindspore/numpy/__init__.py +1 -1
  120. mindspore/numpy/array_creations.py +65 -68
  121. mindspore/numpy/array_ops.py +64 -60
  122. mindspore/numpy/fft.py +610 -75
  123. mindspore/numpy/logic_ops.py +11 -10
  124. mindspore/numpy/math_ops.py +85 -84
  125. mindspore/numpy/utils_const.py +4 -4
  126. mindspore/opencv_core452.dll +0 -0
  127. mindspore/opencv_imgcodecs452.dll +0 -0
  128. mindspore/opencv_imgproc452.dll +0 -0
  129. mindspore/ops/__init__.py +6 -4
  130. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  131. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  132. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  133. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  134. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  135. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  136. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  137. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  138. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  139. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  140. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  141. mindspore/ops/composite/base.py +85 -48
  142. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  143. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  144. mindspore/ops/function/__init__.py +22 -0
  145. mindspore/ops/function/array_func.py +490 -153
  146. mindspore/ops/function/debug_func.py +113 -1
  147. mindspore/ops/function/fft_func.py +15 -2
  148. mindspore/ops/function/grad/grad_func.py +3 -2
  149. mindspore/ops/function/math_func.py +558 -207
  150. mindspore/ops/function/nn_func.py +817 -383
  151. mindspore/ops/function/other_func.py +3 -2
  152. mindspore/ops/function/random_func.py +184 -8
  153. mindspore/ops/function/reshard_func.py +13 -11
  154. mindspore/ops/function/sparse_unary_func.py +1 -1
  155. mindspore/ops/function/vmap_func.py +3 -2
  156. mindspore/ops/functional.py +24 -14
  157. mindspore/ops/op_info_register.py +3 -3
  158. mindspore/ops/operations/__init__.py +6 -1
  159. mindspore/ops/operations/_grad_ops.py +2 -76
  160. mindspore/ops/operations/_infer_ops.py +1 -1
  161. mindspore/ops/operations/_inner_ops.py +71 -94
  162. mindspore/ops/operations/array_ops.py +12 -146
  163. mindspore/ops/operations/comm_ops.py +42 -53
  164. mindspore/ops/operations/custom_ops.py +83 -19
  165. mindspore/ops/operations/debug_ops.py +42 -10
  166. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  167. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  168. mindspore/ops/operations/math_ops.py +12 -223
  169. mindspore/ops/operations/nn_ops.py +20 -114
  170. mindspore/ops/operations/other_ops.py +7 -4
  171. mindspore/ops/operations/random_ops.py +46 -1
  172. mindspore/ops/primitive.py +18 -6
  173. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  174. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  175. mindspore/ops_generate/gen_constants.py +36 -0
  176. mindspore/ops_generate/gen_ops.py +67 -52
  177. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  178. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  179. mindspore/ops_generate/op_proto.py +10 -3
  180. mindspore/ops_generate/pyboost_utils.py +14 -1
  181. mindspore/ops_generate/template.py +43 -21
  182. mindspore/parallel/__init__.py +3 -1
  183. mindspore/parallel/_auto_parallel_context.py +28 -8
  184. mindspore/parallel/_cell_wrapper.py +83 -0
  185. mindspore/parallel/_parallel_serialization.py +47 -19
  186. mindspore/parallel/_tensor.py +81 -11
  187. mindspore/parallel/_utils.py +13 -1
  188. mindspore/parallel/algo_parameter_config.py +5 -5
  189. mindspore/parallel/checkpoint_transform.py +46 -39
  190. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  191. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  192. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  193. mindspore/parallel/parameter_broadcast.py +3 -4
  194. mindspore/parallel/shard.py +162 -31
  195. mindspore/parallel/transform_safetensors.py +993 -0
  196. mindspore/profiler/__init__.py +2 -1
  197. mindspore/profiler/common/constant.py +29 -0
  198. mindspore/profiler/common/registry.py +47 -0
  199. mindspore/profiler/common/util.py +28 -0
  200. mindspore/profiler/dynamic_profiler.py +694 -0
  201. mindspore/profiler/envprofiling.py +17 -19
  202. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  203. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  204. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  205. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  206. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  207. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  208. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  209. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  210. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  211. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  212. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  213. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  214. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  215. mindspore/profiler/parser/framework_parser.py +1 -391
  216. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  217. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  218. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  219. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  220. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  221. mindspore/profiler/parser/profiler_info.py +78 -6
  222. mindspore/profiler/profiler.py +153 -0
  223. mindspore/profiler/profiling.py +280 -412
  224. mindspore/rewrite/__init__.py +1 -2
  225. mindspore/rewrite/common/namespace.py +4 -4
  226. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  227. mindspore/run_check/_check_version.py +36 -103
  228. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  229. mindspore/swresample-4.dll +0 -0
  230. mindspore/swscale-6.dll +0 -0
  231. mindspore/tinyxml2.dll +0 -0
  232. mindspore/train/__init__.py +4 -3
  233. mindspore/train/_utils.py +28 -2
  234. mindspore/train/amp.py +171 -53
  235. mindspore/train/callback/__init__.py +2 -2
  236. mindspore/train/callback/_callback.py +4 -4
  237. mindspore/train/callback/_checkpoint.py +85 -22
  238. mindspore/train/callback/_cluster_monitor.py +1 -1
  239. mindspore/train/callback/_flops_collector.py +1 -0
  240. mindspore/train/callback/_loss_monitor.py +3 -3
  241. mindspore/train/callback/_on_request_exit.py +134 -31
  242. mindspore/train/callback/_summary_collector.py +5 -5
  243. mindspore/train/callback/_tft_register.py +352 -0
  244. mindspore/train/dataset_helper.py +7 -3
  245. mindspore/train/metrics/metric.py +3 -3
  246. mindspore/train/metrics/roc.py +4 -4
  247. mindspore/train/mind_ir_pb2.py +44 -39
  248. mindspore/train/model.py +134 -58
  249. mindspore/train/serialization.py +336 -112
  250. mindspore/turbojpeg.dll +0 -0
  251. mindspore/utils/__init__.py +21 -0
  252. mindspore/utils/utils.py +60 -0
  253. mindspore/version.py +1 -1
  254. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  255. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
  256. mindspore/include/c_api/ms/abstract.h +0 -67
  257. mindspore/include/c_api/ms/attribute.h +0 -197
  258. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  259. mindspore/include/c_api/ms/base/macros.h +0 -32
  260. mindspore/include/c_api/ms/base/status.h +0 -33
  261. mindspore/include/c_api/ms/base/types.h +0 -283
  262. mindspore/include/c_api/ms/context.h +0 -102
  263. mindspore/include/c_api/ms/graph.h +0 -160
  264. mindspore/include/c_api/ms/node.h +0 -606
  265. mindspore/include/c_api/ms/tensor.h +0 -161
  266. mindspore/include/c_api/ms/value.h +0 -84
  267. mindspore/mindspore_shared_lib.dll +0 -0
  268. mindspore/nn/extend/basic.py +0 -140
  269. mindspore/nn/extend/embedding.py +0 -143
  270. mindspore/nn/extend/layer/normalization.py +0 -109
  271. mindspore/nn/extend/pooling.py +0 -117
  272. mindspore/nn/layer/embedding_service.py +0 -531
  273. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  274. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  275. mindspore/ops/extend/__init__.py +0 -53
  276. mindspore/ops/extend/array_func.py +0 -218
  277. mindspore/ops/extend/math_func.py +0 -76
  278. mindspore/ops/extend/nn_func.py +0 -308
  279. mindspore/ops/silent_check.py +0 -162
  280. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  281. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  282. mindspore/train/callback/_mindio_ttp.py +0 -443
  283. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  284. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  285. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/train/callback/_tft_register.py (new file)
@@ -0,0 +1,352 @@
+ # Copyright 2024 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """Checkpoint related classes and functions."""
+
+ import os
+ from mindspore.train.serialization import save_checkpoint
+ from mindspore.parallel._utils import _get_device_num
+ from mindspore import _checkparam as Validator
+ from mindspore.train.callback._callback import Callback
+ from mindspore import context
+ from mindspore.common.parameter import Parameter
+ from mindspore.communication import get_rank, get_group_size
+ from mindspore import log as logger
+ from mindspore.train.serialization import _get_cur_rank_dp
+ from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post
+ from mindspore._c_expression import clean_tdt_channel
+ from mindspore._c_expression import send_recv
+ from mindspore._c_expression import CollectiveManager
+ from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
+
+ def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
+     """ Common func to generate ckpt dir name."""
+     tmp = "_tmp" if is_tmp_file else ""
+     mid_dir = f"tft_saved_checkpoints-step_{str(step)}{tmp}"
+     return os.path.join(ckpt_save_path, mid_dir)
+
+ def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
+     """ Callback used for TFT save ckpt function when errors occur."""
+     logger.info("Enter _save_checkpoint_on_failure function")
+     ckpt_save_path = cb_ctx.ckpt_save_path
+     cb_params = args
+     cur_rank = get_rank()
+     cur_step_num = cb_params.cur_step_num
+     cur_epoch_num = cb_params.cur_epoch_num
+     batch_num = cb_params.batch_num
+     if cur_step_num > step:
+         cur_epoch_num = (step - 1) // batch_num + 1
+     step_num_in_epoch = int((step - 1) % batch_num + 1)
+
+     append_dict = {}
+     append_dict["epoch_num"] = cur_epoch_num
+     append_dict["step_num"] = step
+     append_dict["cur_rank"] = cur_rank
+     append_dict["batch_num"] = batch_num
+     append_dict["__exception_save__"] = True
+
+     append_dict["global_step"] = Parameter([cb_ctx.global_step])
+     outputs = cb_params.net_outputs
+     if isinstance(outputs, (tuple, list)) and len(outputs) >= 3:
+         append_dict["loss_scale"] = outputs[2]
+
+     ckpt_file = f"ttp_rank_{str(cur_rank)}-{str(cur_epoch_num)}_{str(step_num_in_epoch)}.ckpt"
+     cur_ckpt_dir = _get_ckpt_dir(step, ckpt_save_path, True) + "/rank_" + str(cur_rank)
+     os.makedirs(cur_ckpt_dir, exist_ok=True)
+     cur_file = os.path.join(cur_ckpt_dir, ckpt_file)
+     save_checkpoint(cb_params.train_network, cur_file,
+                     integrated_save=False, append_dict=append_dict)
+     logger.info("Finish _save_checkpoint_on_failure function")
+
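To make the step-to-epoch arithmetic in _save_checkpoint_on_failure concrete, a worked example with hypothetical numbers (batch_num = 100 steps per epoch):

    # Global step 250 falls in epoch 3, at step 50 within that epoch.
    step, batch_num = 250, 100
    cur_epoch_num = (step - 1) // batch_num + 1          # 249 // 100 + 1 == 3
    step_num_in_epoch = int((step - 1) % batch_num + 1)  # 249 % 100 + 1 == 50
    assert (cur_epoch_num, step_num_in_epoch) == (3, 50)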
+ def _rename_save_result(step, cb_ctx):
+     """ Callback used for TFT rename function after ckpt save callback was finished and successful."""
+     logger.info("Enter _rename_save_result function")
+     tmp_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, True)
+     fin_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, False)
+
+     os.rename(tmp_dir, fin_dir)
+     logger.info("Finish _rename_save_result function")
+
+ def _tft_exit_cb(ctx):
+     logger.error("Enter mindio ttp exit process, which means another rank raised an exception; "
+                  "check other ranks' logs!")
+     _tft_sem_post()
+     os._exit(1)  # pylint: disable=W0212
+
+
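The two functions above implement a two-phase save: checkpoints land in a "_tmp" directory first, which _rename_save_result promotes only after every rank's save callback succeeds. A short illustration with a hypothetical step number and save path:

    print(_get_ckpt_dir(250, "/data/ckpt", True))
    # /data/ckpt/tft_saved_checkpoints-step_250_tmp
    print(_get_ckpt_dir(250, "/data/ckpt", False))
    # /data/ckpt/tft_saved_checkpoints-step_250  (the post-rename name)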
+ def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
+     """ Callback used for TFT repair function."""
+     logger.info("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
+     if (repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value
+             or repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value):
+         logger.info("Enter _tft_repair_callback uce REPAIR_DEVICE device_id: {}".format(cb_ctx.device_id))
+         _repair_device(cb_ctx.device_id)
+
+     if (repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value
+             or repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value):
+         logger.info("Enter _tft_repair_callback SEND_RECV repair type: "
+                     "{}, src_rank: {}, dst_rank: {}".format(repair_info["repair_type"], repair_info["src"],
+                                                             repair_info["dst"]))
+         cb_params = args
+         src_rank = repair_info["src"][0]
+         dst_rank = repair_info["dst"][0]
+         send_recv(cb_params.network.trainable_params(), src_rank, dst_rank)
+     logger.info("Finish _tft_repair_callback")
+
+
+ def _tft_clean_callback(is_uce_error, ctx):
+     """ Callback used for TFT clean function."""
+     logger.info("Enter _tft_clean_callback")
+     ret = 0
+     if is_uce_error:
+         _get_uce_mem_info(ctx.device_id)
+         err_strategy = _get_uce_process_strategy()
+         logger.info("_tft_clean_callback err_strategy: {}".format(err_strategy))
+         if err_strategy == "RS_UCE_HIGHLEVEL":
+             ret = 0
+         elif err_strategy == "RS_UCE_LOWLEVEL":
+             ret = 2
+         else:
+             ret = 1
+     clean_tdt_channel()
+     logger.info("Enter _tft_clean_callback resume_hccl_comm")
+     CollectiveManager.get_instance().resume_hccl_comm()
+     logger.info("Finish _tft_clean_callback, ret: {}".format(ret))
+     return ret
+
+
+ def _tft_stop_callback(cb_ctx):
+     """ Callback used for TFT stop function."""
+     logger.info("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
+     _stop_device(cb_ctx.device_id)
+     logger.info("Finish _tft_stop_callback")
+
+
+ class TFTRegister(Callback):
+     """
+     This callback is used to enable the TFT feature
+     `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_.
+     This callback will execute TFT operations during the training process, such as TFT init, report and
+     exception handling.
+
+     Note:
+         Required for Ascend graph mode only, and sink size must be less than or equal to 1.
+
+     Args:
+         ctrl_rank_id (int): TFT controller's running rank_id, used for init TFT controller.
+         ctrl_ip (str): TFT controller's ip address, used for init TFT controller.
+         ctrl_port (int): TFT controller's ip port, used for init TFT controller and processor.
+         ckpt_save_path (str): Checkpoint save directory when a failure occurs. Checkpoint files will be saved
+             to a directory named tft_saved_checkpoints-step_{cur_step_num} under this directory.
+
+     Raises:
+         Exception: TFT init failed.
+         ModuleNotFoundError: MindIO TFT whl package is not installed.
+
+     Examples:
+         >>> import numpy as np
+         >>> import os
+         >>> import math
+         >>> import mindspore as ms
+         >>> import mindspore.dataset as ds
+         >>> from mindspore import nn, ops, Parameter, train
+         >>> from mindspore.communication import init, get_rank
+         >>> from mindspore.common.initializer import initializer, HeUniform
+         >>> from mindspore.train import Model, TFTRegister
+         >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
+         >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
+         >>> init()
+         >>> ms.set_seed(1)
+         >>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
+         ...     "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
+         >>> class MatMulCell(nn.Cell):
+         ...     def __init__(self, param=None, shape=None):
+         ...         super().__init__()
+         ...         if shape is None:
+         ...             shape = [28 * 28, 512]
+         ...         weight_init = HeUniform(math.sqrt(5))
+         ...         self.param = Parameter(initializer(weight_init, shape), name="param")
+         ...         if param is not None:
+         ...             self.param = param
+         ...         self.print = ops.Print()
+         ...         self.matmul = ops.MatMul()
+         ...
+         ...     def construct(self, x):
+         ...         out = self.matmul(x, self.param)
+         ...         self.print("out is:", out)
+         ...         return out
+         >>>
+         >>> class Network(nn.Cell):
+         ...     def __init__(self):
+         ...         super().__init__()
+         ...         self.flatten = nn.Flatten()
+         ...         self.layer1 = MatMulCell()
+         ...         self.relu1 = nn.ReLU()
+         ...         self.layer2 = nn.Dense(512, 512)
+         ...         self.relu2 = nn.ReLU()
+         ...         self.layer3 = nn.Dense(512, 10)
+         ...
+         ...     def construct(self, x):
+         ...         x = self.flatten(x)
+         ...         x = self.layer1(x)
+         ...         x = self.relu1(x)
+         ...         x = self.layer2(x)
+         ...         x = self.relu2(x)
+         ...         logits = self.layer3(x)
+         ...         return logits
+         >>>
+         >>> net = Network()
+         >>> net.layer1.pipeline_stage = 0
+         >>> net.relu1.pipeline_stage = 0
+         >>> net.layer2.pipeline_stage = 0
+         >>> net.relu2.pipeline_stage = 1
+         >>> net.layer3.pipeline_stage = 1
+         >>>
+         >>> def create_dataset(batch_size):
+         ...     dataset_path = os.getenv("DATA_PATH")
+         ...     dataset = ds.MnistDataset(dataset_path)
+         ...     image_transforms = [
+         ...         ds.vision.Rescale(1.0 / 255.0, 0),
+         ...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
+         ...         ds.vision.HWC2CHW()
+         ...     ]
+         ...     label_transform = ds.transforms.TypeCast(ms.int32)
+         ...     dataset = dataset.map(image_transforms, 'image')
+         ...     dataset = dataset.map(label_transform, 'label')
+         ...     dataset = dataset.batch(batch_size)
+         ...     return dataset
+         >>>
+         >>> data_set = create_dataset(32)
+         >>>
+         >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
+         >>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
+         >>> loss_fn = nn.CrossEntropyLoss()
+         >>>
+         >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
+         >>> net_with_loss.set_train()
+         >>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
+         >>> tft_cb = TFTRegister(0, "192.168.0.1", 2000, "./tft_checkpoint/")
+         >>> loss_cb = train.LossMonitor(1)
+         >>> model.train(1, data_set, callbacks=[tft_cb, loss_cb])
+     """
+
+     def __init__(self, ctrl_rank_id, ctrl_ip, ctrl_port, ckpt_save_path):
+         super(TFTRegister, self).__init__()
+
+         tft_env = os.getenv("MS_ENABLE_TFT", "")
+         if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env):
+             raise ValueError("MindIO TFT register needs the switch turned on [MS_ENABLE_TFT='{TTP:1,UCE:1}']!")
+         mode = context.get_context("mode")
+         device_target = context.get_context("device_target")
+         if device_target != "Ascend" or mode != context.GRAPH_MODE:
+             raise ValueError("MindIO adapter is only supported on Ascend devices with GRAPH mode!")
+
+         # let it raise errors if the mindio_ttp package is not installed
+         from mindio_ttp import framework_ttp as tft
+         self.tft = tft
+         self.global_step = 0
+         Validator.check_non_negative_int(ctrl_port)
+         self.has_init_replica = False
+         self._controller_ip = ctrl_ip
+         self._controller_rank_id = ctrl_rank_id
+         self._controller_port = ctrl_port
+         self.device_id = context.get_context("device_id")
+         self._init_tft()
+         self.ckpt_save_path = ckpt_save_path
+
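Note the MS_ENABLE_TFT guard above: the callback refuses to construct unless the switch is present. A minimal sketch of the required setup, based solely on the check in __init__ (the exact accepted value format follows the error message):

    import os
    # Must be set before constructing TFTRegister; either TTP:1 or UCE:1
    # is enough to pass the check.
    os.environ["MS_ENABLE_TFT"] = "{TTP:1,UCE:1}"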
+     def _set_tft_optimizer_replica(self, run_context):
+         """ Set MindIO TFT optimizer replica info, used internally. """
+         cur_rank = get_rank()
+         cb_params = run_context.original_args()
+         train_network = cb_params.train_network
+         # in data_parallel mode, every rank has the same train parameters
+         if context.get_auto_parallel_context("parallel_mode") == "data_parallel":
+             group_size = get_group_size()
+             dp = tuple(range(group_size))
+         else:
+             param_layout_dict = train_network.parameter_layout_dict
+             dp = _get_cur_rank_dp(param_layout_dict) if param_layout_dict else _get_cur_rank_dp(train_network)
+         logger.warning(f"Set TFT replica with dp: {dp}.")
+         replica_info = [
+             {
+                 "type": 1,
+                 "rank_list": dp,
+                 "replica_cnt": len(dp),
+                 "replica_shift": 0
+             }
+         ]
+         self.tft.tft_set_optimizer_replica(cur_rank, replica_info)
+
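To make the data-parallel branch above concrete: with 8 ranks in pure data_parallel mode, every rank holds a full parameter copy, so the replica group is all ranks (a sketch restating the code, not an additional API):

    group_size = 8
    dp = tuple(range(group_size))  # (0, 1, 2, 3, 4, 5, 6, 7)
    replica_info = [{"type": 1, "rank_list": dp,
                     "replica_cnt": len(dp), "replica_shift": 0}]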
+     def _init_tft(self):
+         """ Init MindIO TFT, used internally. """
+         logger.info("Begin to init tft.")
+         self.tft.tft_register_save_ckpt_handler(_save_checkpoint_on_failure, self)
+         self.tft.tft_register_rename_handler(_rename_save_result, self)
+         self.tft.tft_register_exit_handler(_tft_exit_cb, self)
+         self.tft.tft_register_stop_handler(_tft_stop_callback, self)
+         self.tft.tft_register_clean_handler(_tft_clean_callback, self)
+         self.tft.tft_register_repair_handler(_tft_repair_callback, self)
+
+         world_size = _get_device_num()
+         cur_rank = get_rank()
+         enable_local_copy = False
+         enable_arf = False
+         enable_zit = False
+         enable_tls = False
+         tls_key_dir = ""
+
+         if cur_rank == self._controller_rank_id:
+             logger.info(f"Begin to start tft controller on rank_id: {cur_rank}")
+             self.tft.tft_init_controller(cur_rank, world_size, enable_local_copy, enable_arf, enable_zit)
+             self.tft.tft_start_controller(self._controller_ip, self._controller_port, enable_tls, tls_key_dir)
+             logger.info("Finished starting tft controller.")
+
+         logger.info("Begin to start tft processor.")
+         self.tft.tft_init_processor(cur_rank, world_size, enable_local_copy, enable_tls, tls_key_dir)
+         self.tft.tft_start_processor(self._controller_ip, self._controller_port)
+         logger.info("Finished starting tft processor.")
+
+     def on_train_step_end(self, run_context):
+         """
+         Report status to MindIO TFT after every step finishes.
+
+         Args:
+             run_context (RunContext): Context of the train running. Refer to
+                 :class:`mindspore.train.RunContext` for details.
+         """
+         if self.has_init_replica is False:
+             self.has_init_replica = True
+             self._set_tft_optimizer_replica(run_context)
+         cb_params = run_context.original_args()
+         if cb_params.optimizer is not None:
+             self.global_step = int(cb_params.optimizer.global_step.data)
+         else:
+             self.global_step = int(cb_params.network.optimizer.global_step.data)
+         logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
+         self.tft.tft_end_updating_os(cb_params.cur_step_num)
+         logger.info("END Set optimizer finish step status to TFT.")
+
+
+     def on_train_begin(self, run_context):
+         cb_params = run_context.original_args()
+         sink_size = cb_params.get("sink_size", 0)
+         if sink_size > 1:
+             raise ValueError("TFT feature doesn't support sink_size > 1.")
+         logger.info("Set step args to TFT.")
+         self.tft.tft_set_step_args(cb_params)
+
+     def end(self, run_context):
+         cur_rank = get_rank()
+         if cur_rank == self._controller_rank_id:
+             self.tft.tft_destroy_controller()
+         self.tft.tft_destroy_processor()
mindspore/train/dataset_helper.py
@@ -213,7 +213,8 @@ def _get_dataset_aux(dataset):
  def connect_network_with_dataset(network, dataset_helper):
      """
      Connect the `network` with dataset in `dataset_helper`. Only supported in `sink mode
-     <https://mindspore.cn/tutorials/experts/en/master/optimize/execution_opt.html>`_, (dataset_sink_mode=True).
+     <https://mindspore.cn/docs/en/master/model_train/train_process/train_optimize.html>`_,
+     (dataset_sink_mode=True).

      Args:
          network (Cell): The training network for dataset.
@@ -261,7 +262,9 @@ def connect_network_with_dataset(network, dataset_helper):
              "The dataset has been connected to other network, please check the code.")
      is_dynamic = bool(network.get_inputs())
      queue_name = dataset.__transfer_dataset__.queue_name
-     if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic):
+     # In pipeline parallel, some stages have no GetNext, should not get in.
+     use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
+     if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel:
          dataset_types, dataset_shapes = dataset_helper.get_data_info()
          # Need to do full_batch for shapes which also do in the _DatasetIterMSLoopSink
          if _need_to_full():
@@ -302,7 +305,8 @@ def connect_network_with_dataset(network, dataset_helper):
          dataset_types, dataset_shapes = dataset_helper.types_shapes()
          aux.__shape_type__ = str(dataset_types) + str(dataset_shapes)

-     if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic):
+     if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic) and \
+             not use_pipeline_parallel:
          dataset_helper.get_data_info()
      network.add_flags(sink_mode=True)
      return network
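Both new guards key off the same flag, so a script only takes the pipeline-parallel path when more than one stage is configured, e.g. (using the same context API as the TFTRegister example above):

    import mindspore as ms
    # With pipeline_stages > 1, the dynamic-sink branches above are skipped,
    # since stages without a GetNext op must not query data info.
    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL,
                                 pipeline_stages=2)
    assert ms.get_auto_parallel_context("pipeline_stages") > 1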
mindspore/train/metrics/metric.py
@@ -200,7 +200,7 @@ class Metric(metaclass=ABCMeta):

          Tutorial Examples:
              - `Evaluation Metrics - Customized Metrics
-               <https://mindspore.cn/tutorials/en/master/advanced/model/metric.html#customized-metrics>`_
+               <https://mindspore.cn/docs/en/master/model_train/train_process/model/metric.html#customized-metrics>`_
          """
          raise NotImplementedError('Must define clear function to use this base class')

@@ -214,7 +214,7 @@ class Metric(metaclass=ABCMeta):

          Tutorial Examples:
              - `Evaluation Metrics - Customized Metrics
-               <https://mindspore.cn/tutorials/en/master/advanced/model/metric.html#customized-metrics>`_
+               <https://mindspore.cn/docs/en/master/model_train/train_process/model/metric.html#customized-metrics>`_
          """
          raise NotImplementedError('Must define eval function to use this base class')

@@ -231,7 +231,7 @@ class Metric(metaclass=ABCMeta):

          Tutorial Examples:
              - `Evaluation Metrics - Customized Metrics
-               <https://mindspore.cn/tutorials/en/master/advanced/model/metric.html#customized-metrics>`_
+               <https://mindspore.cn/docs/en/master/model_train/train_process/model/metric.html#customized-metrics>`_
          """
          raise NotImplementedError('Must define update function to use this base class')
mindspore/train/metrics/roc.py
@@ -42,18 +42,18 @@ class ROC(Metric):
          >>> from mindspore.train import ROC
          >>>
          >>> # 1) binary classification example
-         >>> x = Tensor(np.array([3, 1, 4, 2]))
+         >>> x = Tensor(np.array([0.28, 0.55, 0.15, 0.05]))
          >>> y = Tensor(np.array([0, 1, 2, 3]))
          >>> metric = ROC(pos_label=2)
          >>> metric.clear()
          >>> metric.update(x, y)
          >>> fpr, tpr, thresholds = metric.eval()
          >>> print(fpr)
-         [0. 0. 0.33333333 0.6666667 1.]
+         [0. 0.33333333 0.66666667 0.66666667 1. ]
          >>> print(tpr)
-         [0. 1. 1. 1. 1.]
+         [0. 0. 0. 1. 1.]
          >>> print(thresholds)
-         [5 4 3 2 1]
+         [1.55 0.55 0.28 0.15 0.05]
          >>>
          >>> # 2) multiclass classification example
          >>> x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05],
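The corrected docstring values can be cross-checked against scikit-learn's analogous API (an independent verification, not part of MindSpore; recent scikit-learn versions report the first threshold as inf rather than max(score) + 1):

    import numpy as np
    from sklearn.metrics import roc_curve

    x = np.array([0.28, 0.55, 0.15, 0.05])
    y = np.array([0, 1, 2, 3])
    # pos_label=2 -> treat class 2 as positive, everything else as negative.
    fpr, tpr, thresholds = roc_curve(y == 2, x, drop_intermediate=False)
    # fpr        -> [0.  0.33333333  0.66666667  0.66666667  1.]
    # tpr        -> [0.  0.  0.  1.  1.]
    # thresholds -> [1.55  0.55  0.28  0.15  0.05]  (older sklearn; newer emits inf first)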