mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (275)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/common/__init__.py +6 -4
  13. mindspore/common/_pijit_context.py +190 -0
  14. mindspore/common/_register_for_tensor.py +2 -1
  15. mindspore/common/_tensor_overload.py +139 -0
  16. mindspore/common/api.py +102 -87
  17. mindspore/common/dump.py +5 -6
  18. mindspore/common/generator.py +1 -7
  19. mindspore/common/hook_handle.py +14 -26
  20. mindspore/common/initializer.py +51 -15
  21. mindspore/common/mindir_util.py +2 -2
  22. mindspore/common/parameter.py +62 -15
  23. mindspore/common/recompute.py +39 -9
  24. mindspore/common/sparse_tensor.py +7 -3
  25. mindspore/common/tensor.py +183 -37
  26. mindspore/communication/__init__.py +1 -1
  27. mindspore/communication/_comm_helper.py +38 -3
  28. mindspore/communication/comm_func.py +315 -60
  29. mindspore/communication/management.py +14 -14
  30. mindspore/context.py +132 -22
  31. mindspore/dataset/__init__.py +1 -1
  32. mindspore/dataset/audio/__init__.py +1 -1
  33. mindspore/dataset/core/config.py +7 -0
  34. mindspore/dataset/core/validator_helpers.py +7 -0
  35. mindspore/dataset/engine/cache_client.py +1 -1
  36. mindspore/dataset/engine/datasets.py +72 -44
  37. mindspore/dataset/engine/datasets_audio.py +7 -7
  38. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  39. mindspore/dataset/engine/datasets_text.py +20 -20
  40. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  41. mindspore/dataset/engine/datasets_vision.py +33 -33
  42. mindspore/dataset/engine/iterators.py +29 -0
  43. mindspore/dataset/engine/obs/util.py +7 -0
  44. mindspore/dataset/engine/queue.py +114 -60
  45. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  46. mindspore/dataset/engine/validators.py +34 -14
  47. mindspore/dataset/text/__init__.py +1 -4
  48. mindspore/dataset/transforms/__init__.py +0 -3
  49. mindspore/dataset/utils/line_reader.py +2 -0
  50. mindspore/dataset/vision/__init__.py +1 -4
  51. mindspore/dataset/vision/utils.py +1 -1
  52. mindspore/dataset/vision/validators.py +2 -1
  53. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  54. mindspore/experimental/es/embedding_service.py +883 -0
  55. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  56. mindspore/experimental/llm_boost/__init__.py +21 -0
  57. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  58. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  59. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  60. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  61. mindspore/experimental/llm_boost/register.py +129 -0
  62. mindspore/experimental/llm_boost/utils.py +31 -0
  63. mindspore/experimental/optim/adamw.py +85 -0
  64. mindspore/experimental/optim/optimizer.py +3 -0
  65. mindspore/hal/__init__.py +3 -3
  66. mindspore/hal/contiguous_tensors_handle.py +175 -0
  67. mindspore/hal/stream.py +18 -0
  68. mindspore/include/api/model_group.h +13 -1
  69. mindspore/include/api/types.h +10 -10
  70. mindspore/include/dataset/config.h +2 -2
  71. mindspore/include/dataset/constants.h +2 -2
  72. mindspore/include/dataset/execute.h +2 -2
  73. mindspore/include/dataset/vision.h +4 -0
  74. mindspore/log.py +1 -1
  75. mindspore/mindrecord/filewriter.py +68 -51
  76. mindspore/mindspore_backend.dll +0 -0
  77. mindspore/mindspore_common.dll +0 -0
  78. mindspore/mindspore_core.dll +0 -0
  79. mindspore/mindspore_np_dtype.dll +0 -0
  80. mindspore/mindspore_ops.dll +0 -0
  81. mindspore/mint/__init__.py +983 -46
  82. mindspore/mint/distributed/__init__.py +31 -0
  83. mindspore/mint/distributed/distributed.py +254 -0
  84. mindspore/mint/nn/__init__.py +268 -23
  85. mindspore/mint/nn/functional.py +125 -19
  86. mindspore/mint/nn/layer/__init__.py +39 -0
  87. mindspore/mint/nn/layer/activation.py +133 -0
  88. mindspore/mint/nn/layer/normalization.py +477 -0
  89. mindspore/mint/nn/layer/pooling.py +110 -0
  90. mindspore/mint/optim/adamw.py +26 -13
  91. mindspore/mint/special/__init__.py +63 -0
  92. mindspore/multiprocessing/__init__.py +2 -1
  93. mindspore/nn/__init__.py +0 -1
  94. mindspore/nn/cell.py +276 -96
  95. mindspore/nn/layer/activation.py +211 -44
  96. mindspore/nn/layer/basic.py +137 -10
  97. mindspore/nn/layer/embedding.py +137 -2
  98. mindspore/nn/layer/normalization.py +101 -5
  99. mindspore/nn/layer/padding.py +34 -48
  100. mindspore/nn/layer/pooling.py +161 -7
  101. mindspore/nn/layer/transformer.py +3 -3
  102. mindspore/nn/loss/__init__.py +2 -2
  103. mindspore/nn/loss/loss.py +84 -6
  104. mindspore/nn/optim/__init__.py +2 -1
  105. mindspore/nn/optim/adadelta.py +1 -1
  106. mindspore/nn/optim/adam.py +1 -1
  107. mindspore/nn/optim/lamb.py +1 -1
  108. mindspore/nn/optim/tft_wrapper.py +124 -0
  109. mindspore/nn/wrap/cell_wrapper.py +12 -23
  110. mindspore/nn/wrap/grad_reducer.py +5 -5
  111. mindspore/nn/wrap/loss_scale.py +17 -3
  112. mindspore/numpy/__init__.py +1 -1
  113. mindspore/numpy/array_creations.py +65 -68
  114. mindspore/numpy/array_ops.py +64 -60
  115. mindspore/numpy/fft.py +610 -75
  116. mindspore/numpy/logic_ops.py +11 -10
  117. mindspore/numpy/math_ops.py +85 -84
  118. mindspore/numpy/utils_const.py +4 -4
  119. mindspore/opencv_core452.dll +0 -0
  120. mindspore/opencv_imgcodecs452.dll +0 -0
  121. mindspore/opencv_imgproc452.dll +0 -0
  122. mindspore/ops/__init__.py +6 -4
  123. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  124. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  125. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  126. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  127. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  128. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  129. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  130. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  131. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  132. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  133. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  134. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  135. mindspore/ops/composite/base.py +85 -48
  136. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  137. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  138. mindspore/ops/function/__init__.py +22 -0
  139. mindspore/ops/function/array_func.py +492 -153
  140. mindspore/ops/function/debug_func.py +113 -1
  141. mindspore/ops/function/fft_func.py +15 -2
  142. mindspore/ops/function/grad/grad_func.py +3 -2
  143. mindspore/ops/function/math_func.py +564 -207
  144. mindspore/ops/function/nn_func.py +817 -383
  145. mindspore/ops/function/other_func.py +3 -2
  146. mindspore/ops/function/random_func.py +402 -12
  147. mindspore/ops/function/reshard_func.py +13 -11
  148. mindspore/ops/function/sparse_unary_func.py +1 -1
  149. mindspore/ops/function/vmap_func.py +3 -2
  150. mindspore/ops/functional.py +24 -14
  151. mindspore/ops/op_info_register.py +3 -3
  152. mindspore/ops/operations/__init__.py +7 -2
  153. mindspore/ops/operations/_grad_ops.py +2 -76
  154. mindspore/ops/operations/_infer_ops.py +1 -1
  155. mindspore/ops/operations/_inner_ops.py +71 -94
  156. mindspore/ops/operations/array_ops.py +14 -146
  157. mindspore/ops/operations/comm_ops.py +63 -53
  158. mindspore/ops/operations/custom_ops.py +83 -19
  159. mindspore/ops/operations/debug_ops.py +42 -10
  160. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  161. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  162. mindspore/ops/operations/math_ops.py +12 -223
  163. mindspore/ops/operations/nn_ops.py +20 -114
  164. mindspore/ops/operations/other_ops.py +7 -4
  165. mindspore/ops/operations/random_ops.py +46 -1
  166. mindspore/ops/primitive.py +18 -6
  167. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  168. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  169. mindspore/ops_generate/gen_constants.py +36 -0
  170. mindspore/ops_generate/gen_ops.py +67 -52
  171. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  172. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  173. mindspore/ops_generate/op_proto.py +10 -3
  174. mindspore/ops_generate/pyboost_utils.py +14 -1
  175. mindspore/ops_generate/template.py +43 -21
  176. mindspore/parallel/__init__.py +3 -1
  177. mindspore/parallel/_auto_parallel_context.py +31 -9
  178. mindspore/parallel/_cell_wrapper.py +85 -0
  179. mindspore/parallel/_parallel_serialization.py +47 -19
  180. mindspore/parallel/_tensor.py +127 -13
  181. mindspore/parallel/_utils.py +53 -22
  182. mindspore/parallel/algo_parameter_config.py +5 -5
  183. mindspore/parallel/checkpoint_transform.py +46 -39
  184. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  185. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  186. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  187. mindspore/parallel/parameter_broadcast.py +3 -4
  188. mindspore/parallel/shard.py +162 -31
  189. mindspore/parallel/transform_safetensors.py +1146 -0
  190. mindspore/profiler/__init__.py +2 -1
  191. mindspore/profiler/common/constant.py +29 -0
  192. mindspore/profiler/common/registry.py +47 -0
  193. mindspore/profiler/common/util.py +28 -0
  194. mindspore/profiler/dynamic_profiler.py +694 -0
  195. mindspore/profiler/envprofiling.py +17 -19
  196. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  197. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  198. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  199. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  200. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  201. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  202. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  203. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  205. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  206. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  207. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  208. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  209. mindspore/profiler/parser/framework_parser.py +1 -391
  210. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  211. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  212. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  213. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  214. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  215. mindspore/profiler/parser/profiler_info.py +78 -6
  216. mindspore/profiler/profiler.py +153 -0
  217. mindspore/profiler/profiling.py +285 -413
  218. mindspore/rewrite/__init__.py +1 -2
  219. mindspore/rewrite/common/namespace.py +4 -4
  220. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  221. mindspore/run_check/_check_version.py +39 -104
  222. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  223. mindspore/train/__init__.py +4 -3
  224. mindspore/train/_utils.py +105 -19
  225. mindspore/train/amp.py +171 -53
  226. mindspore/train/callback/__init__.py +2 -2
  227. mindspore/train/callback/_callback.py +4 -4
  228. mindspore/train/callback/_checkpoint.py +97 -31
  229. mindspore/train/callback/_cluster_monitor.py +1 -1
  230. mindspore/train/callback/_flops_collector.py +1 -0
  231. mindspore/train/callback/_loss_monitor.py +3 -3
  232. mindspore/train/callback/_on_request_exit.py +145 -31
  233. mindspore/train/callback/_summary_collector.py +5 -5
  234. mindspore/train/callback/_tft_register.py +375 -0
  235. mindspore/train/dataset_helper.py +15 -3
  236. mindspore/train/metrics/metric.py +3 -3
  237. mindspore/train/metrics/roc.py +4 -4
  238. mindspore/train/mind_ir_pb2.py +44 -39
  239. mindspore/train/model.py +154 -58
  240. mindspore/train/serialization.py +342 -128
  241. mindspore/utils/__init__.py +21 -0
  242. mindspore/utils/utils.py +60 -0
  243. mindspore/version.py +1 -1
  244. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  245. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
  246. mindspore/include/c_api/ms/abstract.h +0 -67
  247. mindspore/include/c_api/ms/attribute.h +0 -197
  248. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  249. mindspore/include/c_api/ms/base/macros.h +0 -32
  250. mindspore/include/c_api/ms/base/status.h +0 -33
  251. mindspore/include/c_api/ms/base/types.h +0 -283
  252. mindspore/include/c_api/ms/context.h +0 -102
  253. mindspore/include/c_api/ms/graph.h +0 -160
  254. mindspore/include/c_api/ms/node.h +0 -606
  255. mindspore/include/c_api/ms/tensor.h +0 -161
  256. mindspore/include/c_api/ms/value.h +0 -84
  257. mindspore/mindspore_shared_lib.dll +0 -0
  258. mindspore/nn/extend/basic.py +0 -140
  259. mindspore/nn/extend/embedding.py +0 -143
  260. mindspore/nn/extend/layer/normalization.py +0 -109
  261. mindspore/nn/extend/pooling.py +0 -117
  262. mindspore/nn/layer/embedding_service.py +0 -531
  263. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  264. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  265. mindspore/ops/extend/__init__.py +0 -53
  266. mindspore/ops/extend/array_func.py +0 -218
  267. mindspore/ops/extend/math_func.py +0 -76
  268. mindspore/ops/extend/nn_func.py +0 -308
  269. mindspore/ops/silent_check.py +0 -162
  270. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  271. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  272. mindspore/train/callback/_mindio_ttp.py +0 -443
  273. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
  274. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  275. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/train/__init__.py CHANGED
@@ -27,10 +27,10 @@ from mindspore.train.loss_scale_manager import LossScaleManager, FixedLossScaleM
  from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, \
  load, parse_print, build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, \
  async_ckpt_thread_status, restore_group_info_list, convert_model, obfuscate_model, export_split_mindir, \
- load_checkpoint_async, check_checkpoint
+ load_checkpoint_async, check_checkpoint, get_ckpt_path_with_strategy
  from mindspore.train.callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, \
  CheckpointConfig, RunContext, LearningRateScheduler, SummaryLandscape, FlopsUtilizationCollector, \
- History, LambdaCallback, ReduceLROnPlateau, EarlyStopping, OnRequestExit, BackupAndRestore, MindIOTTPAdapter
+ History, LambdaCallback, ReduceLROnPlateau, EarlyStopping, OnRequestExit, BackupAndRestore, TFTRegister
  from mindspore.train.summary import SummaryRecord
  from mindspore.train.train_thor import ConvertNetUtils, ConvertModelUtils
  from mindspore.train.metrics import *
@@ -40,7 +40,8 @@ __all__ = ["Model", "DatasetHelper", "connect_network_with_dataset", "build_trai
  "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint", "check_checkpoint",
  "load_param_into_net", "export", "load", "export_split_mindir", "parse_print", "build_searched_strategy",
  "merge_sliced_parameter", "load_distributed_checkpoint", "async_ckpt_thread_status",
- "restore_group_info_list", "convert_model", "data_sink", "obfuscate_model", "load_checkpoint_async"]
+ "restore_group_info_list", "convert_model", "data_sink", "obfuscate_model", "load_checkpoint_async",
+ "get_ckpt_path_with_strategy"]
  __all__.extend(callback.__all__)
  __all__.extend(summary.__all__)
  __all__.extend(train_thor.__all__)
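The two hunks above change the public surface of mindspore.train: get_ckpt_path_with_strategy is newly exported, and the MindIO TTP callback MindIOTTPAdapter is replaced by TFTRegister. A minimal import check against 2.4.1, sketched from the export changes shown above rather than taken from the package's own tests, would be:

# Sketch: assumes mindspore 2.4.1 is installed; names come from the hunks above.
from mindspore.train import get_ckpt_path_with_strategy   # new public helper in 2.4.1
from mindspore.train.callback import TFTRegister           # replaces MindIOTTPAdapter from 2.3.0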
mindspore/train/_utils.py CHANGED
@@ -16,6 +16,8 @@
  from __future__ import absolute_import

  import os
+ import threading
+ from datetime import datetime
  import json
  from collections.abc import Iterable

@@ -25,15 +27,18 @@ from mindspore.common.tensor import Tensor
  from mindspore._c_expression import Tensor as Tensor_
  from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
  from mindspore.common import dtype as mstype
+ from mindspore import context
  from mindspore import log as logger
  from mindspore import _checkparam as Validator
  from mindspore.common.api import _cell_graph_executor
+ from mindspore.communication import get_group_size
  from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
  from mindspore.train.checkpoint_pb2 import Checkpoint
  from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy
  from mindspore.train.lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
  from mindspore.parallel._parallel_serialization import _make_dir
  from mindspore.ops.operations import debug_ops
+ from mindspore.nn.cell import Cell


  def _convert_type(types):
@@ -71,6 +76,18 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_inf
  queue_name = _cell_graph_executor.get_queue_name(phase)
  if queue_name is None:
  queue_name = str("")
+
+ use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
+
+ # temp env to disable dynamic feature of sink size 1
+ dynamic_sink1_env = os.getenv("MS_DEV_DYNAMIC_SINK1", None)
+ dynamic_sink1 = True
+ if dynamic_sink1_env and dynamic_sink1_env.strip() in ['False', 'false']:
+ dynamic_sink1 = False
+
+ if use_pipeline_parallel or not dynamic_sink1:
+ create_data_info_queue = False
+
  exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end,
  create_data_info_queue=create_data_info_queue, queue_name=queue_name)
  _cell_graph_executor.init_dataset(exec_dataset.queue_name,
@@ -295,10 +312,68 @@ def parse_strategy_ckpt(file_name):

  for ele in param.parallel_layouts.tensor_map[0].ListFields()[0][1]:
  tensor_map.append(ele)
- layout_dict[param.param_name] = [dev_matrix, tensor_map]
+ layout_dict[param.param_name] = [dev_matrix, tensor_map, param.parallel_layouts.opt_weight_shard_step,
+ param.parallel_layouts.opt_weight_shard_size]
  return layout_dict


+ def _get_strategy_opt_shard(param_redundancy_dict, parameter_layout_opt_shard):
+ """Strategy ckpt append opt shard."""
+ for key, value in parameter_layout_opt_shard.items():
+ if value[1] not in (-1, 0):
+ opt_para_num = value[1]
+ param_redundancy_ranks = param_redundancy_dict.get(key)
+ res = []
+ for param_ranks in param_redundancy_ranks:
+ if len(param_ranks) % opt_para_num == 0:
+ for i in range(0, opt_para_num):
+ res.append(param_ranks[i::opt_para_num])
+ param_redundancy_dict[key] = tuple(res)
+
+
+ def _get_layout_opt_shard(layout_obj, param_redundancy_dict):
+ """Layout ckpt append opt shard."""
+ for key, value in layout_obj.items():
+ if value[5]:
+ world_groups = ("hccl_world_group", "nccl_world_group", "mccl_world_group")
+ if value[5] in world_groups:
+ opt_para_num = get_group_size()
+ elif "-" in value[5]:
+ opt_para_str = value[5].split("-")[0]
+ opt_para_num = int(opt_para_str)
+ else:
+ raise ValueError(f"For get_parameter_redundancy, the format of the parallel communication domain for "
+ f"the optimizer is incorrect.")
+ param_redundancy_ranks = param_redundancy_dict.get(key)
+ res = []
+ for param_ranks in param_redundancy_ranks:
+ if len(param_ranks) % opt_para_num == 0:
+ for i in range(0, opt_para_num):
+ res.append(param_ranks[i::opt_para_num])
+ param_redundancy_dict[key] = tuple(res)
+
+
+ def _get_parameter_redundancy_without_opt_shard(parameter_layout, param_redundancy_dict, initial_rank):
+ """Get parameter redundancy without opt shard."""
+ for key, (slices, deploy_loc, *_) in parameter_layout.items():
+ redundancy_matrix = np.zeros(shape=slices + [len(slices)], dtype=np.int8)
+ for i in deploy_loc:
+ internal_slice = tuple(slice(None) for _ in range(i))
+ for j in range(slices[-i - 1]):
+ if i == -1:
+ continue
+ else:
+ redundancy_matrix[(..., j) + internal_slice + (i,)] = j
+ locate_list = redundancy_matrix.reshape((-1, len(slices))).tolist()
+ redundancy_dict = {}
+ for index, locate in enumerate(locate_list):
+ redundancy_dict.setdefault(tuple(locate), []).append(index + initial_rank)
+ redundancy_list = []
+ for _, indices in sorted(redundancy_dict.items()):
+ redundancy_list.append(tuple(indices))
+ param_redundancy_dict[key] = tuple(redundancy_list)
+
+
  def get_parameter_redundancy(layout_obj, initial_rank=0):
  """
  Get parameter redundancy map.
@@ -319,31 +394,31 @@ def get_parameter_redundancy(layout_obj, initial_rank=0):
  'param4': ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15))}
  """
  if isinstance(layout_obj, str):
- parameter_layout = parse_strategy_ckpt(layout_obj)
+ parameter_layout_total = parse_strategy_ckpt(layout_obj)
+ parameter_layout = {}
+ parameter_layout_opt_shard = {}
+ for key, value in parameter_layout_total.items():
+ parameter_layout[key] = value[0:2]
+ parameter_layout_opt_shard[key] = value[2:]
+ elif isinstance(layout_obj, Cell):
+ from mindspore.communication.management import get_process_group_ranks
+ groups_ranks = (tuple(get_process_group_ranks()),)
+ param_redundancy_dict = {param.name: groups_ranks for _, param in layout_obj.parameters_and_names()}
+ return param_redundancy_dict
  else:
  parameter_layout = {}
  for k, v in layout_obj.items():
  parameter_layout[k] = v[:2]

  param_redundancy_dict = {}
- for key, (slices, deploy_loc, *_) in parameter_layout.items():
- redundancy_matrix = np.zeros(shape=slices + [len(slices)], dtype=np.int8)
- for i in deploy_loc:
- internal_slice = tuple(slice(None) for _ in range(i))
- for j in range(slices[-i - 1]):
- if i == -1:
- continue
- else:
- redundancy_matrix[(..., j) + internal_slice + (i,)] = j
- locate_list = redundancy_matrix.reshape((-1, len(slices))).tolist()
- redundancy_dict = {}
- for index, locate in enumerate(locate_list):
- redundancy_dict.setdefault(tuple(locate), []).append(index+initial_rank)
- redundancy_list = []
- for _, indices in sorted(redundancy_dict.items()):
- redundancy_list.append(tuple(indices))

- param_redundancy_dict[key] = tuple(redundancy_list)
+ _get_parameter_redundancy_without_opt_shard(parameter_layout, param_redundancy_dict, initial_rank)
+
+ if isinstance(layout_obj, str):
+ _get_strategy_opt_shard(param_redundancy_dict, parameter_layout_opt_shard)
+ else:
+ _get_layout_opt_shard(layout_obj, param_redundancy_dict)
+
  return param_redundancy_dict


@@ -437,3 +512,14 @@ def parse_hccl_file(hccl_file_path):
  rankid_dict[int(device["rank_id"])] = device["device_ip"]

  return rankid_dict
+
+
+ def vlog_print(level, module, file, line, message):
+ '''Read environment variable VLOG_v and print to log'''
+ if os.environ.get("VLOG_v") == level:
+ now = datetime.now()
+ formatted_time = now.strftime("%Y-%m-%d-%H:%M:%S.%f")[:-3] + f".{now.microsecond // 1000}"
+ path = 'mindspore' + file.split("mindspore")[-1]
+ pid = os.getpid()
+ thread_id = threading.get_ident()
+ print(f"[V{level}] {module}({pid},{thread_id},python):{formatted_time} [{path}:{line}] {message}", flush=True)
mindspore/train/amp.py CHANGED
@@ -16,6 +16,9 @@
  from __future__ import absolute_import
  import inspect
  import types
+ from typing import Any
+ import functools
+ import collections

  import mindspore as ms
  from mindspore import nn
@@ -29,8 +32,9 @@ from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScal
  from mindspore import boost, context
  from mindspore.ops import operations as P
  from mindspore.ops import Primitive
+ from mindspore.ops import auto_generate as gen
  from mindspore import log as logger
-
+ from mindspore._c_expression.amp import pop_amp_strategy, push_amp_strategy, create_amp_strategy, AmpLevel

  AMP_WHITE_LIST = [
  nn.Conv1d,
@@ -52,17 +56,67 @@ AMP_WHITE_LIST = [
  P.BatchMatMul,
  P.PReLU,
  P.ReLU,
- P.Ger
+ P.Ger,
  ]

-
  AMP_BLACK_LIST = [
  nn.BatchNorm1d,
  nn.BatchNorm2d,
  nn.BatchNorm3d,
- nn.LayerNorm
+ nn.LayerNorm,
  ]

+ AMP_AUTO_WHITE_LIST = [
+ P.Conv2D,
+ P.Conv3D,
+ P.Conv2DTranspose,
+ P.Conv3DTranspose,
+ gen.Convolution,
+ P.MatMul,
+ gen.MatMulExt,
+ P.BatchMatMul,
+ gen.BatchMatMulExt,
+ gen.PReLU,
+ P.Einsum,
+ gen.Dense,
+ gen.Addmm,
+ ]
+
+ AMP_AUTO_BLACK_LIST = [
+ gen.Pow,
+ gen.ACos,
+ gen.Asin,
+ gen.Cosh,
+ P.Erfinv,
+ P.Exp,
+ P.Expm1,
+ P.Log,
+ P.Log1p,
+ P.Reciprocal,
+ P.Rsqrt,
+ P.Sinh,
+ P.Tan,
+ P.Softplus,
+ gen.SoftplusExt,
+ P.LayerNorm,
+ gen.LayerNormExt,
+ P.BatchNorm,
+ gen.GroupNorm,
+ P.KLDivLoss,
+ P.SmoothL1Loss,
+ P.MultilabelMarginLoss,
+ P.SoftMarginLoss,
+ P.TripletMarginLoss,
+ P.MultiMarginLoss,
+ P.BCEWithLogitsLoss,
+ P.Pdist,
+ P.Cdist,
+ P.Renorm,
+ ]
+
+ # Indicates which inputs of primitives need to be converted
+ AMP_PRIM_ARG_TABLE = collections.defaultdict(list, {})
+
  # Primitives in inner amp black list will not be converted in O2/O3
  _INNER_AMP_BLACK_LIST = []

@@ -302,6 +356,42 @@ def _auto_black_list(network, black_list, dtype):
  return network


+ class amp_decorator:
+ """
+ Auto mixed precision decorator.
+ Type of lists: List[Tuple[str, List[int]]]
+ """
+ def __init__(self, amp_level, amp_dtype, white_list, black_list):
+ self.amp_level = amp_level
+ self.amp_dtype = amp_dtype
+ self.white_list = white_list
+ self.black_list = black_list
+
+ def __enter__(self):
+ push_amp_strategy(self.amp_level, self.amp_dtype, self.white_list, self.black_list)
+
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
+ pop_amp_strategy()
+
+
+ def _set_amp_decorator(obj, amp_level, amp_dtype, white_list, black_list):
+ """
+ Set auto mixed precision context decorator for object.
+ Type of lists: List[Tuple[str, List[int]]]
+ """
+ if inspect.isfunction(obj) or inspect.ismethod(obj):
+ @functools.wraps(obj)
+ def wrapper(*args, **kwargs):
+ with amp_decorator(amp_level, amp_dtype, white_list, black_list):
+ return obj(*args, **kwargs)
+ return wrapper
+ if isinstance(obj, nn.Cell):
+ obj.construct = types.MethodType(
+ _set_amp_decorator(obj.construct.__func__, amp_level, amp_dtype, white_list, black_list), obj)
+ return obj
+ raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell or function, bot got {type(obj)}.")
+
+
  def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
  """
  Returns a network processed with auto mixed precision.
@@ -312,26 +402,44 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
  converted to lower precision float, and calculation results are converted back to full precision float,
  i.e. ``mstype.float32`` .

- The framework has a set of built-in blacklists and whitelists, and the `amp_level` determines which cells and
- operators are specifically converted.
+ The `amp_level` and its corresponding lists determine which cells and operators are converted.

- The current built-in whitelist contents are:
+ When `amp_level` is set to ``O0``, no cells and operators are converted.

- [:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
- :class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
- :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
- :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
- :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
- :class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
- :class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]
+ When `amp_level` is set to ``O1``, cells and operators in whitelist will be converted to lower precision
+ operations. For details on whitelist, refer to :func:`mindspore.amp.get_white_list`.

- The current built-in blacklist contents are:
+ When `amp_level` is set to ``O2``, cells in blacklist will maintain full precision, and cells outside the
+ list will be converted to low precision. For details on blacklist, refer to :func:`mindspore.amp.get_black_list`.

- [:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
- :class:`mindspore.nn.LayerNorm`]
+ When `amp_level` is set to ``O3``, all cells will be converted to low precision.
+
+ When `amp_level` is set to ``auto``, operators in `auto_whitelist` will be converted to lower precision
+ operations, operators in `auto_blacklist` will be converted to full precision operations, operators in
+ `promote_list` will be converted to the higher accuracy float type of the operator inputs, and operators
+ not listed will run in the type defined by their inputs.
+
+ Operators in `auto_whitelist` are:
+
+ ``Conv2D``, ``Conv3D``, ``Conv2DTranspose``, ``Conv3DTranspose``, ``Convolution``, ``MatMul``, ``MatMulExt``,
+ ``BatchMatMul``, ``BatchMatMulExt``, ``PReLU``, ``Einsum``, ``Dense``, ``Addmm``
+
+ Operators in `auto_blacklist` are:
+
+ ``Pow``, ``ACos``, ``Asin``, ``Cosh``, ``Erfinv``, ``Exp``, ``Expm1``, ``Log``, ``Log1p``, ``Reciprocal``,
+ ``Rsqrt``, ``Sinh``, ``Tan``, ``Softplus``, ``SoftplusExt``, ``LayerNorm``, ``LayerNormExt``, ``BatchNorm``,
+ ``GroupNorm``, ``KLDivLoss``, ``SmoothL1Loss``, ``MultilabelMarginLoss``, ``SoftMarginLoss``,
+ ``TripletMarginLoss``, ``MultiMarginLoss``, ``BCEWithLogitsLoss``, ``Pdist``, ``Cdist``, ``Renorm``,
+ ``ReduceProd``, ``Softmax``, ``LogSoftmax``, ``CumProd``, ``CumSum``, ``CumsumExt``, ``ProdExt``, ``SumExt``,
+ ``Norm``
+
+ Operators in `promote_list` are:
+
+ ``Addcdiv``, ``Addcmul``, ``Cross``, ``_PyboostCrossPrim``, ``Dot``, ``GridSampler2D``, ``GridSampler3D``,
+ ``BiasAdd``

  For details on automatic mixed precision, refer to
- `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/advanced/mixed_precision.html>`_ .
+ `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ .

  Note:
  - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
@@ -339,10 +447,18 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
  - If interfaces like `Model` and `build_train_network` is used to train the network which is converted by
  mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
  need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
+ - When `amp_level` is set to ``auto``, the output of the network may be lower precision. In this case, you
+ may need to manually convert the type to avoid type inconsistency errors of the loss function.
+ - When `amp_level` is set to ``auto``, and cells in the network are configured with `to_float`, the accuracy
+ specified by `to_float` takes effect first.
+
+ .. warning::
+ ``auto`` level of `amp_level` is an experimental API that is subject to change or deletion.

  Args:
- network (Cell): Definition of the network.
- amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: ``"O0"`` .
+ network (Union[Cell, function]): Definition of the network. Function type is supported only when `amp_level`
+ is set to ``auto`` .
+ amp_level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .

  - "O0": Do not change.
  - "O1": Convert cells and operators in whitelist to lower precision operations, and keep full
@@ -350,12 +466,16 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
  - "O2": Keep full precision operations for cells and operators in blacklist, and convert the rest
  to lower precision operations.
  - "O3": Cast network to lower precision.
+ - "auto": Operators in `auto_whitelist` will be converted to lower precision operations, operators in
+ `auto_blacklist` will be converted to full precision, operators in `promote_list` will be converted
+ to the higher accuracy float type of the operator inputs, and operators not listed will run in the
+ type defined by their inputs.

  dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
  default: ``mstype.float16`` .

  Raises:
- TypeError: If `network` is not a Cell.
+ TypeError: If `network` is not a Cell or a function.
  ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
  ValueError: If `amp_level` is not within the supported range.

@@ -368,7 +488,12 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
  >>> net = amp.auto_mixed_precision(network, amp_level)
  """
  if not isinstance(network, nn.Cell):
- raise TypeError("The network type should be Cell.")
+ if amp_level == "auto":
+ if not inspect.isfunction(network) and not inspect.ismethod(network):
+ raise TypeError("For amp_level 'auto', the network type should be Cell or function.")
+ # function is supported for amp_level 'auto'
+ else:
+ raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell.")

  if dtype not in (mstype.float16, mstype.bfloat16):
  raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
@@ -377,7 +502,7 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
  return network

  # Return network if the same amp level has already been configurated
- if getattr(network, "_amp_level") in ("O1", "O2", "O3"):
+ if hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O1", "O2", "O3", "auto"):
  logger.warning(f"The network's auto mixed-precision level is adjusted from {getattr(network, '_amp_level')} "
  f"to {amp_level}, and repeated calls to mixed-precision interfaces can cause performance "
  f"degradation.")
@@ -396,8 +521,16 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
  else:
  network.to_float(dtype)
  network = _OutputTo32(network)
+ elif amp_level == "auto":
+ white_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_WHITE_LIST]
+ black_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_BLACK_LIST]
+ # set amp_strategy attribute for the object
+ amp_strategy = create_amp_strategy(AmpLevel.AmpAuto, dtype, white_list, black_list)
+ setattr(network, "amp_strategy", amp_strategy)
+ # set amp_strategy context decorator for the object
+ network = _set_amp_decorator(network, AmpLevel.AmpAuto, dtype, white_list, black_list)
  else:
- raise ValueError("The amp level {} is not supported".format(amp_level))
+ raise ValueError(f"The amp level {amp_level} is not supported")

  setattr(network, "_amp_level", amp_level)

@@ -437,6 +570,10 @@ _config_level = {
  "O3": {
  "keep_batchnorm_fp32": False,
  "cast_model_type": mstype.float16,
+ "loss_scale_manager": None},
+ "auto": {
+ "keep_batchnorm_fp32": False,
+ "cast_model_type": mstype.float32,
  "loss_scale_manager": None}}


@@ -461,20 +598,11 @@ def _check_kwargs(key_words):
  def _check_level(level, boost_level):
  """Check level."""
  if not isinstance(level, str):
- raise TypeError("The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'], \
- but got type {}.".format(type(level)))
+ raise TypeError(f"The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'],"
+ f"but got type {type(level)}.")
  validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], validator.IN)
  validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], validator.IN)

- if level == "auto":
- device_target = context.get_context('device_target')
- if device_target == "GPU":
- level = "O2"
- elif device_target == "Ascend":
- level = "O3"
- else:
- raise ValueError("Level `auto` only support when `device_target` is GPU or Ascend.")
-
  enable_boost = False
  if boost_level in ["O1", "O2"]:
  enable_boost = True
@@ -499,7 +627,8 @@ def _add_loss_network(network, loss_fn, cast_model_type):
  return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)

  validator.check_value_type('loss_fn', loss_fn, nn.Cell)
- if cast_model_type == mstype.float16:
+ if cast_model_type in (mstype.float16, mstype.bfloat16) or \
+ (hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O2", "O3", "auto")):
  network = WithLossCell(network, loss_fn)
  else:
  network = nn.WithLossCell(network, loss_fn)
@@ -555,20 +684,10 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
  Default: ``None`` .
  level (str): Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: ``'O0'`` .

- - 'O0': Do not change.
- - 'O1': Cast the operators in white_list to float16, the remaining operators are kept in float32.
- The operators in the whitelist: [Conv1d, Conv2d, Conv3d, Conv1dTranspose, Conv2dTranspose,
- Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
- - 'O2': Cast network to float16, keep `mindspore.nn.BatchNorm` series interface,
- :class:`mindspore.nn.LayerNorm` and `loss_fn` (if set) run in float32, using dynamic loss scale.
- - 'O3': Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
- - 'auto': Set to level to recommended level in different devices. Set level to 'O2' on GPU, Set
- level to 'O3' Ascend. The recommended level is chosen by the export experience, not applicable to all
- scenarios. User should specify the level for special network.
-
- 'O2' is recommended on GPU, 'O3' is recommended on Ascend. Property of `keep_batchnorm_fp32`,
- `cast_model_type` and `loss_scale_manager` determined by `level` setting may be overwritten by settings in
- `kwargs`.
+ For details on amp level, refer to :func:`mindspore.amp.auto_mixed_precision`.
+
+ Property of `keep_batchnorm_fp32`, `cast_model_type` and `loss_scale_manager` determined by `level`
+ setting may be overwritten by settings in `kwargs`.

  boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
  training. Supports ['O0', 'O1', 'O2']. Default: ``'O0'`` .
@@ -649,7 +768,7 @@

  def get_white_list():
  """
- Provide a copy of internal white list used by auto mixed precision.
+ Provide a copy of internal white list used by auto mixed precision with `amp_level` set to ``O1``.

  The current built-in whitelist contents are:

@@ -687,7 +806,7 @@

  def get_black_list():
  """
- Provide a copy of internal black list used by auto mixed precision.
+ Provide a copy of internal black list used by auto mixed precision with `amp_level` set to ``O2``.

  The current built-in blacklist contents are:

@@ -710,7 +829,6 @@

  def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
  """
- Custom mixed precision by setting whitelist or blacklist.
  When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
  When the `black_list` is provided, cells that are not in `black_list` will perform the pereision conversion.
  Only one of `white_list` and `black_list` should be provided.
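Putting the amp.py changes together, the new ``auto`` level attaches an amp strategy to the network and wraps its construct call in the amp_decorator context shown above. A rough usage sketch based on the updated docstring (the toy Net cell below is invented purely for illustration):

import mindspore as ms
from mindspore import amp, nn

class Net(nn.Cell):
    # hypothetical toy network used only to illustrate the call
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(16, 8)

    def construct(self, x):
        return self.dense(x)

# 'auto': whitelist ops run in float16, blacklist ops stay in float32,
# promote-list ops follow the highest-precision input type.
net = amp.auto_mixed_precision(Net(), amp_level="auto", dtype=ms.float16)
# Per the note in the docstring above, the output may now be lower precision, so a manual
# cast back to float32 may be needed before a float32-only loss function.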
mindspore/train/callback/__init__.py CHANGED
@@ -36,9 +36,9 @@ from mindspore.train.callback._reduce_lr_on_plateau import ReduceLROnPlateau
  from mindspore.train.callback._on_request_exit import OnRequestExit
  from mindspore.train.callback._backup_and_restore import BackupAndRestore
  from mindspore.train.callback._flops_collector import FlopsUtilizationCollector
- from mindspore.train.callback._mindio_ttp import MindIOTTPAdapter
+ from mindspore.train.callback._tft_register import TFTRegister

  __all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "FlopsUtilizationCollector",
  "SummaryCollector", "CheckpointConfig", "RunContext", "LearningRateScheduler", "SummaryLandscape",
  "History", "LambdaCallback", "ReduceLROnPlateau", "EarlyStopping", "OnRequestExit", "BackupAndRestore",
- "MindIOTTPAdapter"]
+ "TFTRegister"]
mindspore/train/callback/_callback.py CHANGED
@@ -123,7 +123,7 @@ class Callback:
  recording current attributes. Users can add custimized attributes to the information.
  Training process can also be stopped by calling `request_stop` method. For details
  of custom Callback, please check
- `Callback tutorial <https://www.mindspore.cn/tutorials/en/master/advanced/model/
+ `Callback tutorial <https://www.mindspore.cn/docs/en/master/model_train/train_process/model/
  callback.html#customized-callback-mechanism>`_.

  Examples:
@@ -493,7 +493,7 @@ class RunContext:
  `RunContext.original_args()` and add extra attributes to the information, but also can stop the
  training process by calling `request_stop` method. For details of custom Callback,
  please check
- `Callback Mechanism <https://www.mindspore.cn/tutorials/en/master/advanced/model/callback.html>`_.
+ `Callback Mechanism <https://www.mindspore.cn/docs/en/master/model_train/train_process/model/callback.html>`_.

  `RunContext.original_args()` holds the model context information as a dictionary variable, and
  different attributes of the dictionary are stored in training or eval process. Details are as follows:
@@ -575,7 +575,7 @@

  Tutorial Examples:
  - `Callback Mechanism - Customized Callback Mechanism
- <https://mindspore.cn/tutorials/en/master/advanced/model/callback.html#customized-callback-mechanism>`_
+ <https://mindspore.cn/docs/en/master/model_train/train_process/model/callback.html#customized-callback-mechanism>`_
  """
  return self._original_args

@@ -588,7 +588,7 @@

  Tutorial Examples:
  - `Callback Mechanism - Customized Training Termination Time
- <https://mindspore.cn/tutorials/en/master/advanced/model/callback.html#
+ <https://mindspore.cn/docs/en/master/model_train/train_process/model/callback.html#
  customized-training-termination-time>`_
  """
  self._stop_requested = True