mindspore 2.3.0-cp310-cp310-win_amd64.whl → 2.4.1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (275)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/common/__init__.py +6 -4
  13. mindspore/common/_pijit_context.py +190 -0
  14. mindspore/common/_register_for_tensor.py +2 -1
  15. mindspore/common/_tensor_overload.py +139 -0
  16. mindspore/common/api.py +102 -87
  17. mindspore/common/dump.py +5 -6
  18. mindspore/common/generator.py +1 -7
  19. mindspore/common/hook_handle.py +14 -26
  20. mindspore/common/initializer.py +51 -15
  21. mindspore/common/mindir_util.py +2 -2
  22. mindspore/common/parameter.py +62 -15
  23. mindspore/common/recompute.py +39 -9
  24. mindspore/common/sparse_tensor.py +7 -3
  25. mindspore/common/tensor.py +183 -37
  26. mindspore/communication/__init__.py +1 -1
  27. mindspore/communication/_comm_helper.py +38 -3
  28. mindspore/communication/comm_func.py +315 -60
  29. mindspore/communication/management.py +14 -14
  30. mindspore/context.py +132 -22
  31. mindspore/dataset/__init__.py +1 -1
  32. mindspore/dataset/audio/__init__.py +1 -1
  33. mindspore/dataset/core/config.py +7 -0
  34. mindspore/dataset/core/validator_helpers.py +7 -0
  35. mindspore/dataset/engine/cache_client.py +1 -1
  36. mindspore/dataset/engine/datasets.py +72 -44
  37. mindspore/dataset/engine/datasets_audio.py +7 -7
  38. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  39. mindspore/dataset/engine/datasets_text.py +20 -20
  40. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  41. mindspore/dataset/engine/datasets_vision.py +33 -33
  42. mindspore/dataset/engine/iterators.py +29 -0
  43. mindspore/dataset/engine/obs/util.py +7 -0
  44. mindspore/dataset/engine/queue.py +114 -60
  45. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  46. mindspore/dataset/engine/validators.py +34 -14
  47. mindspore/dataset/text/__init__.py +1 -4
  48. mindspore/dataset/transforms/__init__.py +0 -3
  49. mindspore/dataset/utils/line_reader.py +2 -0
  50. mindspore/dataset/vision/__init__.py +1 -4
  51. mindspore/dataset/vision/utils.py +1 -1
  52. mindspore/dataset/vision/validators.py +2 -1
  53. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  54. mindspore/experimental/es/embedding_service.py +883 -0
  55. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  56. mindspore/experimental/llm_boost/__init__.py +21 -0
  57. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  58. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  59. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  60. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  61. mindspore/experimental/llm_boost/register.py +129 -0
  62. mindspore/experimental/llm_boost/utils.py +31 -0
  63. mindspore/experimental/optim/adamw.py +85 -0
  64. mindspore/experimental/optim/optimizer.py +3 -0
  65. mindspore/hal/__init__.py +3 -3
  66. mindspore/hal/contiguous_tensors_handle.py +175 -0
  67. mindspore/hal/stream.py +18 -0
  68. mindspore/include/api/model_group.h +13 -1
  69. mindspore/include/api/types.h +10 -10
  70. mindspore/include/dataset/config.h +2 -2
  71. mindspore/include/dataset/constants.h +2 -2
  72. mindspore/include/dataset/execute.h +2 -2
  73. mindspore/include/dataset/vision.h +4 -0
  74. mindspore/log.py +1 -1
  75. mindspore/mindrecord/filewriter.py +68 -51
  76. mindspore/mindspore_backend.dll +0 -0
  77. mindspore/mindspore_common.dll +0 -0
  78. mindspore/mindspore_core.dll +0 -0
  79. mindspore/mindspore_np_dtype.dll +0 -0
  80. mindspore/mindspore_ops.dll +0 -0
  81. mindspore/mint/__init__.py +983 -46
  82. mindspore/mint/distributed/__init__.py +31 -0
  83. mindspore/mint/distributed/distributed.py +254 -0
  84. mindspore/mint/nn/__init__.py +268 -23
  85. mindspore/mint/nn/functional.py +125 -19
  86. mindspore/mint/nn/layer/__init__.py +39 -0
  87. mindspore/mint/nn/layer/activation.py +133 -0
  88. mindspore/mint/nn/layer/normalization.py +477 -0
  89. mindspore/mint/nn/layer/pooling.py +110 -0
  90. mindspore/mint/optim/adamw.py +26 -13
  91. mindspore/mint/special/__init__.py +63 -0
  92. mindspore/multiprocessing/__init__.py +2 -1
  93. mindspore/nn/__init__.py +0 -1
  94. mindspore/nn/cell.py +276 -96
  95. mindspore/nn/layer/activation.py +211 -44
  96. mindspore/nn/layer/basic.py +137 -10
  97. mindspore/nn/layer/embedding.py +137 -2
  98. mindspore/nn/layer/normalization.py +101 -5
  99. mindspore/nn/layer/padding.py +34 -48
  100. mindspore/nn/layer/pooling.py +161 -7
  101. mindspore/nn/layer/transformer.py +3 -3
  102. mindspore/nn/loss/__init__.py +2 -2
  103. mindspore/nn/loss/loss.py +84 -6
  104. mindspore/nn/optim/__init__.py +2 -1
  105. mindspore/nn/optim/adadelta.py +1 -1
  106. mindspore/nn/optim/adam.py +1 -1
  107. mindspore/nn/optim/lamb.py +1 -1
  108. mindspore/nn/optim/tft_wrapper.py +124 -0
  109. mindspore/nn/wrap/cell_wrapper.py +12 -23
  110. mindspore/nn/wrap/grad_reducer.py +5 -5
  111. mindspore/nn/wrap/loss_scale.py +17 -3
  112. mindspore/numpy/__init__.py +1 -1
  113. mindspore/numpy/array_creations.py +65 -68
  114. mindspore/numpy/array_ops.py +64 -60
  115. mindspore/numpy/fft.py +610 -75
  116. mindspore/numpy/logic_ops.py +11 -10
  117. mindspore/numpy/math_ops.py +85 -84
  118. mindspore/numpy/utils_const.py +4 -4
  119. mindspore/opencv_core452.dll +0 -0
  120. mindspore/opencv_imgcodecs452.dll +0 -0
  121. mindspore/opencv_imgproc452.dll +0 -0
  122. mindspore/ops/__init__.py +6 -4
  123. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  124. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  125. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  126. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  127. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  128. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  129. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  130. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  131. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  132. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  133. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  134. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  135. mindspore/ops/composite/base.py +85 -48
  136. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  137. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  138. mindspore/ops/function/__init__.py +22 -0
  139. mindspore/ops/function/array_func.py +492 -153
  140. mindspore/ops/function/debug_func.py +113 -1
  141. mindspore/ops/function/fft_func.py +15 -2
  142. mindspore/ops/function/grad/grad_func.py +3 -2
  143. mindspore/ops/function/math_func.py +564 -207
  144. mindspore/ops/function/nn_func.py +817 -383
  145. mindspore/ops/function/other_func.py +3 -2
  146. mindspore/ops/function/random_func.py +402 -12
  147. mindspore/ops/function/reshard_func.py +13 -11
  148. mindspore/ops/function/sparse_unary_func.py +1 -1
  149. mindspore/ops/function/vmap_func.py +3 -2
  150. mindspore/ops/functional.py +24 -14
  151. mindspore/ops/op_info_register.py +3 -3
  152. mindspore/ops/operations/__init__.py +7 -2
  153. mindspore/ops/operations/_grad_ops.py +2 -76
  154. mindspore/ops/operations/_infer_ops.py +1 -1
  155. mindspore/ops/operations/_inner_ops.py +71 -94
  156. mindspore/ops/operations/array_ops.py +14 -146
  157. mindspore/ops/operations/comm_ops.py +63 -53
  158. mindspore/ops/operations/custom_ops.py +83 -19
  159. mindspore/ops/operations/debug_ops.py +42 -10
  160. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  161. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  162. mindspore/ops/operations/math_ops.py +12 -223
  163. mindspore/ops/operations/nn_ops.py +20 -114
  164. mindspore/ops/operations/other_ops.py +7 -4
  165. mindspore/ops/operations/random_ops.py +46 -1
  166. mindspore/ops/primitive.py +18 -6
  167. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  168. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  169. mindspore/ops_generate/gen_constants.py +36 -0
  170. mindspore/ops_generate/gen_ops.py +67 -52
  171. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  172. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  173. mindspore/ops_generate/op_proto.py +10 -3
  174. mindspore/ops_generate/pyboost_utils.py +14 -1
  175. mindspore/ops_generate/template.py +43 -21
  176. mindspore/parallel/__init__.py +3 -1
  177. mindspore/parallel/_auto_parallel_context.py +31 -9
  178. mindspore/parallel/_cell_wrapper.py +85 -0
  179. mindspore/parallel/_parallel_serialization.py +47 -19
  180. mindspore/parallel/_tensor.py +127 -13
  181. mindspore/parallel/_utils.py +53 -22
  182. mindspore/parallel/algo_parameter_config.py +5 -5
  183. mindspore/parallel/checkpoint_transform.py +46 -39
  184. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  185. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  186. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  187. mindspore/parallel/parameter_broadcast.py +3 -4
  188. mindspore/parallel/shard.py +162 -31
  189. mindspore/parallel/transform_safetensors.py +1146 -0
  190. mindspore/profiler/__init__.py +2 -1
  191. mindspore/profiler/common/constant.py +29 -0
  192. mindspore/profiler/common/registry.py +47 -0
  193. mindspore/profiler/common/util.py +28 -0
  194. mindspore/profiler/dynamic_profiler.py +694 -0
  195. mindspore/profiler/envprofiling.py +17 -19
  196. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  197. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  198. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  199. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  200. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  201. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  202. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  203. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  205. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  206. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  207. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  208. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  209. mindspore/profiler/parser/framework_parser.py +1 -391
  210. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  211. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  212. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  213. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  214. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  215. mindspore/profiler/parser/profiler_info.py +78 -6
  216. mindspore/profiler/profiler.py +153 -0
  217. mindspore/profiler/profiling.py +285 -413
  218. mindspore/rewrite/__init__.py +1 -2
  219. mindspore/rewrite/common/namespace.py +4 -4
  220. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  221. mindspore/run_check/_check_version.py +39 -104
  222. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  223. mindspore/train/__init__.py +4 -3
  224. mindspore/train/_utils.py +105 -19
  225. mindspore/train/amp.py +171 -53
  226. mindspore/train/callback/__init__.py +2 -2
  227. mindspore/train/callback/_callback.py +4 -4
  228. mindspore/train/callback/_checkpoint.py +97 -31
  229. mindspore/train/callback/_cluster_monitor.py +1 -1
  230. mindspore/train/callback/_flops_collector.py +1 -0
  231. mindspore/train/callback/_loss_monitor.py +3 -3
  232. mindspore/train/callback/_on_request_exit.py +145 -31
  233. mindspore/train/callback/_summary_collector.py +5 -5
  234. mindspore/train/callback/_tft_register.py +375 -0
  235. mindspore/train/dataset_helper.py +15 -3
  236. mindspore/train/metrics/metric.py +3 -3
  237. mindspore/train/metrics/roc.py +4 -4
  238. mindspore/train/mind_ir_pb2.py +44 -39
  239. mindspore/train/model.py +154 -58
  240. mindspore/train/serialization.py +342 -128
  241. mindspore/utils/__init__.py +21 -0
  242. mindspore/utils/utils.py +60 -0
  243. mindspore/version.py +1 -1
  244. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  245. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
  246. mindspore/include/c_api/ms/abstract.h +0 -67
  247. mindspore/include/c_api/ms/attribute.h +0 -197
  248. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  249. mindspore/include/c_api/ms/base/macros.h +0 -32
  250. mindspore/include/c_api/ms/base/status.h +0 -33
  251. mindspore/include/c_api/ms/base/types.h +0 -283
  252. mindspore/include/c_api/ms/context.h +0 -102
  253. mindspore/include/c_api/ms/graph.h +0 -160
  254. mindspore/include/c_api/ms/node.h +0 -606
  255. mindspore/include/c_api/ms/tensor.h +0 -161
  256. mindspore/include/c_api/ms/value.h +0 -84
  257. mindspore/mindspore_shared_lib.dll +0 -0
  258. mindspore/nn/extend/basic.py +0 -140
  259. mindspore/nn/extend/embedding.py +0 -143
  260. mindspore/nn/extend/layer/normalization.py +0 -109
  261. mindspore/nn/extend/pooling.py +0 -117
  262. mindspore/nn/layer/embedding_service.py +0 -531
  263. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  264. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  265. mindspore/ops/extend/__init__.py +0 -53
  266. mindspore/ops/extend/array_func.py +0 -218
  267. mindspore/ops/extend/math_func.py +0 -76
  268. mindspore/ops/extend/nn_func.py +0 -308
  269. mindspore/ops/silent_check.py +0 -162
  270. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  271. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  272. mindspore/train/callback/_mindio_ttp.py +0 -443
  273. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
  274. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  275. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/train/model.py CHANGED
@@ -36,7 +36,7 @@ from mindspore.train.metrics import get_metrics, get_metric_fn
  from mindspore._checkparam import check_input_data, check_output_data
  from mindspore import _checkparam as Validator
  from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor,\
- FlopsUtilizationCollector, MindIOTTPAdapter
+ FlopsUtilizationCollector, TFTRegister
  from mindspore.train.callback import __all__ as internal_cb_names
  from mindspore.train.callback._cluster_monitor import ClusterMonitor
  from mindspore import context
@@ -46,6 +46,7 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_
  from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
  _cache_enable, _enable_distributed_mindrt
  from mindspore.train.metrics import Loss
+ from mindspore.train._utils import vlog_print
  from mindspore import nn
  from mindspore.boost import AutoBoost
  from mindspore.context import ParallelMode
@@ -119,6 +120,101 @@ def _save_final_ckpt(func):
  func(self, *args, **kwargs)
  return wrapper

+ def _handle_tft(func):
+ """
+ Decorator function, which starts uce handle process when an exception occurs during training.
+ """
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ obj = None
+ if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TFTRegister):
+ obj = kwargs.get('callbacks')
+ if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
+ for item in kwargs.get('callbacks'):
+ if isinstance(item, TFTRegister):
+ obj = item
+ if obj:
+ tft = obj.tft
+ tft_env = os.getenv("MS_ENABLE_TFT", "")
+ uce_env = "UCE:1" in tft_env
+ while True:
+ try:
+ return func(self, *args, **kwargs)
+ except RuntimeError as e:
+ logger.info("uce wrapper caught RuntimeError")
+ if not uce_env:
+ logger.info("uce wrapper caught RuntimeError uce not enable")
+ tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+ raise e
+ e_str = str(e)
+ logger.info("uce wrapper caught RuntimeError e_str:{}".format(e_str))
+ if "UCEError" in e_str:
+ logger.info("uce wrapper report UCEError")
+ tft.tft_report_error(tft.ReportState.RS_UCE.value)
+ elif "ForceStopError" in e_str:
+ logger.info("uce wrapper caught RuntimeError ForceStopError")
+ force_stop_err = tft.ReportState.RS_NORMAL.value
+ tft.tft_report_error(force_stop_err)
+ else:
+ logger.info("uce wrapper caught RuntimeError rankid: {} OTHER ERROR")
+ tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+ raise e
+ ret = tft.tft_wait_next_action()
+ if ret == tft.Action.EXIT.value:
+ raise e
+ repair_step = tft.tft_get_repair_step()
+ logger.info("uce wrapper caught repair finish REPAIR STEP: {} batch_num: \
+ {}".format(repair_step, self.batch_num))
+ initial_epoch = int(repair_step/self.batch_num)
+ initial_step = repair_step % self.batch_num
+ kwargs["initial_epoch"] = initial_epoch
+
+ train_dataset = args[1]
+ dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
+ sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
+
+ cb_initial_step = 0
+ if dataset_sink_mode:
+ train_dataset.set_init_step(initial_epoch)
+ dataset_size = train_dataset.get_dataset_size()
+ if sink_size != -1:
+ cb_initial_step = initial_epoch * sink_size + initial_step
+ else:
+ cb_initial_step = initial_epoch * dataset_size + initial_step
+ else:
+ train_dataset.set_init_step(initial_step)
+ cb_initial_step = initial_step
+
+ kwargs["initial_step"] = cb_initial_step
+
+ logger.info("uce wrapper repair complete \
+ initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch, cb_initial_step))
+ continue
+ except BaseException as e:
+ logger.info("uce wrapper caught BaseException error")
+ tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+ raise e
+ else:
+ return func(self, *args, **kwargs)
+ return wrapper
+
+
+ def _check_tft():
+ """Check if TFT is supported"""
+ tft_env = os.getenv("MS_ENABLE_TFT")
+ device_target = context.get_context("device_target")
+ if tft_env and device_target == "Ascend":
+ from mindspore._c_expression import MSContext
+ ascend_target = MSContext.get_instance().get_ascend_soc_version()
+ if ascend_target == 'ascend910':
+ raise ValueError("TFT is not supported when using ascend910")
+ ms_mode = context.get_context("mode")
+ if ms_mode != mindspore.GRAPH_MODE:
+ raise ValueError("TFT is only supported in GRAPH_MODE")
+ jit_level = context.get_context("jit_level")
+ if jit_level == "O2" and "UCE:1" in tft_env:
+ raise ValueError("TFT is not supported when using jit_level == O2")
+

  def _append_ccae(callbacks):
  """Add cluster monitoring when CCAE is enabled."""
@@ -290,21 +386,11 @@ class Model:
  amp_level (str): Option for argument `level` in :func:`mindspore.amp.build_train_network`, level for mixed
  precision training. Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .

- - "O0": Do not change.
- - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
- The operators in the whitelist: [Conv1d, Conv2d, Conv3d, Conv1dTranspose, Conv2dTranspose,
- Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
- - "O2": Cast network to float16, keep BatchNorm run in float32, using dynamic loss scale.
- - "O3": Cast network to float16, the BatchNorm is also cast to float16, loss scale will not be used.
- - "auto": Set level to recommended level in different devices. Set level to "O2" on GPU, set
- level to "O3" on Ascend. The recommended level is chosen by the expert experience, not applicable to all
- scenarios. User should specify the level for special network.
-
- "O2" is recommended on GPU, "O3" is recommended on Ascend.
+ For details on `amp_level` , refer to :func:`mindspore.amp.auto_mixed_precision`.
+
  The BatchNorm strategy can be changed by `keep_batchnorm_fp32` settings in `kwargs`. `keep_batchnorm_fp32`
  must be a bool. The loss scale strategy can be changed by `loss_scale_manager` setting in `kwargs`.
  `loss_scale_manager` should be a subclass of :class:`mindspore.amp.LossScaleManager`.
- The more detailed explanation of `amp_level` setting can be found at `mindspore.amp.build_train_network`.

  boost_level (str): Option for argument `level` in `mindspore.boost`, level for boost mode
  training. Supports ["O0", "O1", "O2"]. Default: ``"O0"`` .
@@ -379,6 +465,7 @@ class Model:
  self._mindspore_lite = None
  self._lite_infer = True # if backend lite infer fails, set False
  self._mindspore_lite_model_group_id = id(self) & 0xFFFF
+ self.batch_num = -1

  def _check_for_graph_cell(self, kwargs):
  """Check for graph cell"""
@@ -568,9 +655,13 @@ class Model:
  dataset.__loop_size__ = 1

  if dataset_helper is None:
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to create DatasetHelper.")
+ logger.info("Begin to create DatasetHelper.")
  dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)

  if dataset_sink_mode:
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to connect network with dataset.")
+ logger.info("Begin to connect network with dataset.")
  network = connect_network_with_dataset(network, dataset_helper)

  if _get_recovery_context("enable_recovery") and is_train:
@@ -589,6 +680,10 @@ class Model:
  if self._backbone_is_train != is_train:
  network.set_train(is_train)
  self._backbone_is_train = is_train
+ # Mode train and eval are the same net, network will be set_grad in _build_train_network.
+ # But if mode just want to do predict or eval, must set network set_grad False
+ if not is_train:
+ network.set_grad(False)
  return network

  def _check_need_ckpt(self, callbacks):
@@ -687,6 +782,8 @@ class Model:
  if not train_dataset and not valid_dataset:
  raise ValueError("The argument 'train_dataset' and 'valid_dataset' can not both be None or empty.")

+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to check device number in model.build().")
+ logger.info("Begin to check device number in model.build() procedure.")
  _device_number_check(self._parallel_mode, self._device_number)

  if train_dataset:
@@ -694,27 +791,44 @@ class Model:
  raise TypeError("The type of 'train_dataset' must be `Dataset`, "
  "but got {}.".format(type(train_dataset)))

+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+ "Begin to check parameter broadcast in model.build().")
+ logger.info("Begin to check parameter broadcast in model.build() procedure.")
  _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
  if self._parameter_broadcast:
  self._train_network.set_broadcast_flag()

+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to exec preprocess in model.build().")
+ logger.info("Begin to exec preprocess in model.build() procedure.")
  train_dataset.__no_send__ = True
  train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
  dataset=train_dataset,
  dataset_sink_mode=True,
  sink_size=sink_size)
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to warmup dataset in model.build().")
+ logger.info("Begin to warmup dataset in model.build() procedure.")
  self._warmup_dataset(epoch, train_dataset, sink_size)

  # Since dataset pipeline has been triggered, delete flag
  delattr(train_dataset, "__no_send__")

  # Waiting for the dataset warmup ready
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+ "Begin waiting for dataset warmup in model.build().")
+ logger.info("Begin waiting for dataset warmup in model.build() procedure.")
  self._waiting_for_dataset_warmup_ready(train_dataset)
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+ "The dataset warmup was successful in model.build().")
+ logger.info("The dataset warmup was successful in model.build() procedure.")

  if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
  train_network.add_flags_recursive(is_first_iteration=True)
  for inputs in train_dataset_helper:
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+ "Begin to compile train network in model.build().")
+ logger.info("Begin to compile train network in model.build() procedure.")
  train_network.compile(*inputs)
+ self._train_network.parameter_layout_dict = train_network.parameter_layout_dict
  break

  if valid_dataset:
@@ -732,6 +846,9 @@ class Model:
  if context.get_auto_parallel_context("pipeline_stages") > 1:
  eval_network.add_flags_recursive(is_first_iteration=False)
  for inputs in valid_dataset_helper:
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+ "Begin to compile eval network in model.build().")
+ logger.info("Begin to compile eval network in model.build() procedure.")
  eval_network.compile(*inputs)
  break

@@ -746,9 +863,10 @@ class Model:

  return [callbacks]

+ @_handle_tft
  @_save_final_ckpt
  def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0,
- valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True):
+ valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True, initial_step=0):
  """
  Training.

@@ -772,12 +890,14 @@ class Model:
  self._train_network.set_broadcast_flag()

  cb_params = _InternalCallbackParam()
+ cb_params.cur_step_num = initial_step
  cb_params.train_network = self._train_network
  cb_params.epoch_num = epoch - initial_epoch
  if dataset_sink_mode and sink_size > 0:
  cb_params.batch_num = sink_size
  else:
  cb_params.batch_num = train_dataset.get_dataset_size()
+ self.batch_num = cb_params.batch_num
  cb_params.mode = "train"
  cb_params.loss_fn = self._loss_fn
  cb_params.optimizer = self._optimizer
@@ -801,16 +921,19 @@ class Model:
  epoch = 1
  cb_params.last_save_ckpt_step = None
  cb_params.latest_ckpt_file = None
+ cb_params.loss_scale_mananger = self._loss_scale_manager

  # build callback list
  with _CallbackManager(callbacks) as list_callback:
  self._check_reuse_dataset(train_dataset)
  if not dataset_sink_mode:
- self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch, valid_infos)
+ self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
+ valid_infos)
  elif context.get_context("device_target") == "CPU":
  logger.info("The CPU cannot support dataset sink mode currently."
  "So the training process will be performed with dataset not sink.")
- self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch, valid_infos)
+ self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
+ valid_infos)
  else:
  self._train_dataset_sink_process(epoch, train_dataset, list_callback,
  cb_params, sink_size, initial_epoch, valid_infos)
@@ -850,9 +973,7 @@ class Model:
  train_dataset.__total_batch__ = epoch * sink_size

  cb_params.sink_size = sink_size
- cb_params.cur_step_num = 0
  cb_params.dataset_sink_mode = True
-
  run_context = RunContext(cb_params)
  list_callback.on_train_begin(run_context)
  # used to stop training for early stop, such as stopAtTIme or stopATStep
@@ -861,7 +982,6 @@ class Model:
  dataset_helper = train_dataset._dataset_helper

  self.epoch_iter = 0
-
  self._check_enable_recovery()
  # Used to check whether need perform recovery for process which is restarted.
  self._check_need_load_ckpt(cb_params, dataset_size, sink_size)
@@ -997,7 +1117,6 @@ class Model:
  dataset_size (int): The number of batches in a dataset.
  sink_size (int): Control the amount of data in each sink. Default: -1.
  """
-
  if not self.enable_recovery:
  self.need_load_ckpt = False

@@ -1084,7 +1203,6 @@ class Model:
  dataset=train_dataset,
  dataset_sink_mode=False,
  epoch_num=epoch)
- cb_params.cur_step_num = 0
  cb_params.dataset_sink_mode = False
  run_context = RunContext(cb_params)
  list_callback.on_train_begin(run_context)
@@ -1106,7 +1224,6 @@ class Model:
  "returned by 'train_dataset'".format(len_element))
  cb_params.cur_step_num += 1
  self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
-
  cb_params.train_dataset_element = next_element
  list_callback.on_train_step_begin(run_context)
  self._check_network_mode(self._train_network, True)
@@ -1150,31 +1267,6 @@ class Model:

  list_callback.on_train_end(run_context)

- def _wrapper_train(self, callbacks):
- """
- This method used to wrap train function with ttp wrapper which will do event notify when
- exceptions throw.
-
- Args:
- callbacks (function): Callbacks passed by train method.
- """
-
- if not callbacks:
- return self._train
- cbs = callbacks if isinstance(callbacks, list) else [callbacks]
- obj = None
- _train_wrapper = None
- for item in cbs:
- if isinstance(item, MindIOTTPAdapter):
- obj = item
-
- if (obj is not None) and (obj.enable is True):
- logger.info("MindIO TTP is enable, so we wrapper ttp exception handdler for self train method.")
- _train_wrapper = obj.wrapper_ttp_persist(self._train)
-
- return self._train if not _train_wrapper else _train_wrapper
-
-
  def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=False, sink_size=-1, initial_epoch=0):
  """
  Training API.
@@ -1240,9 +1332,10 @@ class Model:
  ... loss_scale_manager=loss_scale_manager)
  >>> model.train(2, dataset)
  """
+ _check_tft()
+ device_target = context.get_context("device_target")
  # prepare dataset for obfuscated model
  train_dataset = self._prepare_obf_dataset(train_dataset)
- device_target = context.get_context("device_target")
  if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
  logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
  dataset_sink_mode = False
@@ -1283,16 +1376,14 @@ class Model:
  _device_number_check(self._parallel_mode, self._device_number)

  callbacks = _append_ccae(callbacks)
- _train_wrapper = None
  if callbacks:
  self._check_methods_for_custom_callbacks(callbacks, "train")
- _train_wrapper = self._wrapper_train(callbacks)
- _train_wrapper(epoch,
- train_dataset,
- callbacks=callbacks,
- dataset_sink_mode=dataset_sink_mode,
- sink_size=sink_size,
- initial_epoch=initial_epoch)
+ self._train(epoch,
+ train_dataset,
+ callbacks=callbacks,
+ dataset_sink_mode=dataset_sink_mode,
+ sink_size=sink_size,
+ initial_epoch=initial_epoch)

  # When it's distributed training and using MindRT,
  # the node id should be reset to start from 0.
@@ -1396,7 +1487,7 @@ class Model:

  Tutorial Examples:
  - `Advanced Encapsulation: Model - Train and Save Model
- <https://www.mindspore.cn/tutorials/en/master/advanced/model.html#training-and-saving-model>`_
+ <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
  """
  device_target = context.get_context("device_target")
  if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
@@ -1493,7 +1584,12 @@ class Model:
  if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
  self._train_network.check_names_and_refresh_name()
  self._train_network._is_check_and_refresh = True
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to init dataset in model.build().")
+ logger.info("Begin to init dataset in model.build() procedure.")
  self._init(train_dataset, valid_dataset, sink_size, epoch)
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+ "The model.build() which contains dataset warmup and network compile is success.")
+ logger.info("The model.build() which contains dataset warmup and network compile is success.")

  def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
  """
@@ -1663,7 +1759,7 @@ class Model:

  Tutorial Examples:
  - `Advanced Encapsulation: Model - Train and Save Model
- <https://www.mindspore.cn/tutorials/en/master/advanced/model.html#training-and-saving-model>`_
+ <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
  """
  valid_dataset = self._prepare_obf_dataset(valid_dataset)
  dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
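
The repair path added in `_handle_tft` maps the TFT-reported `repair_step` back to an epoch index, an in-epoch step, and a callback step counter before retrying `_train`. The following is a minimal standalone sketch of that arithmetic in plain Python with hypothetical sample values; `resume_position` is an illustrative helper, not a MindSpore API.

# Standalone sketch mirroring the resume arithmetic used by _handle_tft above.
# resume_position and the sample numbers below are hypothetical.

def resume_position(repair_step, batch_num, dataset_sink_mode=True,
                    sink_size=-1, dataset_size=None):
    """Map a global repair_step to (initial_epoch, initial_step, cb_initial_step)."""
    initial_epoch = repair_step // batch_num   # whole epochs already finished
    initial_step = repair_step % batch_num     # steps finished in the current epoch
    if dataset_sink_mode:
        # In sink mode the callback counter advances by sink_size per epoch,
        # or by the full dataset size when sink_size is -1.
        steps_per_epoch = sink_size if sink_size != -1 else dataset_size
        cb_initial_step = initial_epoch * steps_per_epoch + initial_step
    else:
        cb_initial_step = initial_step
    return initial_epoch, initial_step, cb_initial_step


# Example: a repair step of 2500 with 1000 batches per epoch resumes training
# at epoch 2, step 500 within that epoch.
print(resume_position(2500, batch_num=1000, dataset_sink_mode=False))  # (2, 500, 500)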