mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +47 -198
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +229 -99
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +480 -372
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +5 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +975 -1981
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +324 -573
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +183 -117
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +179 -120
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +798 -761
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +933 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1373 -192
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +53 -42
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +19 -15
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +3 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +361 -359
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +52 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +757 -185
  287. mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
  288. mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4351 -3813
  334. mindspore/ops/function/nn_func.py +1712 -637
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +452 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +26 -18
  343. mindspore/ops/functional.py +23 -7
  344. mindspore/ops/functional_overload.py +1548 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +23 -15
  347. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +298 -87
  360. mindspore/ops/operations/debug_ops.py +157 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +212 -531
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1895 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +159 -40
  426. mindspore/parallel/_cell_wrapper.py +132 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +700 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +258 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -59
  446. mindspore/parallel/transform_safetensors.py +364 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +416 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +96 -27
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +269 -136
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +552 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/train/model.py CHANGED
@@ -27,7 +27,6 @@ import time
 import numpy as np
 
 import mindspore
-import mindspore.dataset as ds
 from mindspore import log as logger
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
@@ -36,7 +35,7 @@ from mindspore.train.metrics import get_metrics, get_metric_fn
 from mindspore._checkparam import check_input_data, check_output_data
 from mindspore import _checkparam as Validator
 from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor,\
-    FlopsUtilizationCollector, TFTRegister
+    TrainFaultTolerance
 from mindspore.train.callback import __all__ as internal_cb_names
 from mindspore.train.callback._cluster_monitor import ClusterMonitor
 from mindspore import context
@@ -46,7 +45,7 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_
 from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
     _cache_enable, _enable_distributed_mindrt
 from mindspore.train.metrics import Loss
-from mindspore.train._utils import vlog_print
+from mindspore.log import vlog_print
 from mindspore import nn
 from mindspore.boost import AutoBoost
 from mindspore.context import ParallelMode
@@ -57,7 +56,11 @@ from mindspore.dataset.core.config import get_debug_mode
 from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
 from mindspore.train import amp
 from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
+from mindspore._c_expression import _get_optimzer_timestamps
+from mindspore._c_expression import clean_tdt_channel
 
+from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
+from .serialization import load_param_into_net
 
 def _transfer_tensor_to_tuple(inputs):
     """
@@ -91,6 +94,7 @@ def _save_final_ckpt(func):
     """
     Decorator function, which saves the current checkpoint when an exception occurs during training.
     """
+
     @wraps(func)
     def wrapper(self, *args, **kwargs):
         obj = None
@@ -107,7 +111,7 @@ def _save_final_ckpt(func):
                 # pylint: disable=W0212
                 prefix = _chg_ckpt_file_name_if_same_exist(obj._directory, obj._exception_prefix, True)
                 cur_ckpoint_file = prefix + "-" + str(self._current_epoch_num) + "_" \
-                    + str(self._current_step_num) + "_breakpoint.ckpt"
+                                   + str(self._current_step_num) + "_breakpoint.ckpt"
                 cur_file = os.path.join(obj._directory, cur_ckpoint_file)
                 if "epoch_num" in obj._append_dict:
                     obj._append_dict["epoch_num"] = obj._append_epoch_num + self._current_epoch_num
@@ -118,85 +122,172 @@ def _save_final_ckpt(func):
                 raise e
         else:
             func(self, *args, **kwargs)
+
     return wrapper
 
+
+def _handle_exception_info(obj, uce_env, tft, e):
+    """handle exception info"""
+    logger.info("uce wrapper caught RuntimeError")
+    if not uce_env:
+        logger.error("uce wrapper caught RuntimeError but uce not enable, enter MindIO TTP process.",
+                     exc_info=True)
+        if tft:
+            tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+        raise e
+    e_str = str(e)
+    logger.warning("uce wrapper caught RuntimeError e_str:{}".format(e_str))
+    if "UCEError" in e_str:
+        logger.info("uce wrapper report UCEError")
+        obj.is_uce_rank = True
+        # if error is HBM_MULTI_BIT_ECC_ERROR
+        if "error_code=507054" in e_str:
+            hbm_error_time, optimize_start, optimizer_end = _get_optimzer_timestamps()
+            can_repair = tft.tft_can_do_uce_repair(hbm_error_time, optimize_start, optimizer_end)
+            logger.info(f"UCEError of type HBM_MULTI_BIT_ECC_ERROR occurs, \
+                hbm_error_time={hbm_error_time}, optimize_start={optimize_start}, \
+                optimizer_end={optimizer_end}, can_repair={can_repair}")
+            if not can_repair:
+                logger.error(f"Caught UCEError of type HBM_MULTI_BIT_ECC_ERROR but can not repair, "
+                             f"hbm_error_time={hbm_error_time}, optimize_start={optimize_start}, "
+                             f"optimizer_end={optimizer_end}", exc_info=True)
+                tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+                raise e
+        tft.tft_report_error(tft.ReportState.RS_UCE.value)
+    elif "ForceStopError" in e_str:
+        logger.warning("uce wrapper caught RuntimeError ForceStopError")
+        force_stop_err = tft.ReportState.RS_NORMAL.value
+        tft.tft_report_error(force_stop_err)
+    elif "ARF FINISH" in e_str:
+        logger.warning(f"ARF FINISH")
+        _set_recovery_context(is_arf=True)
+        tft.tft_report_error(tft.ReportState.RS_PREREPAIR_FINISH.value)
+    else:
+        logger.error("uce wrapper caught other RuntimeError, enter MindIO TTP process.", exc_info=True)
+        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+        raise e
+
+
+def _handle_training_result_error(model, tft_obj):
+    """
+    Handle training result error for resuming training.
+    """
+    ckpt_load_fn = tft_obj.ckpt_load_func
+    train_network = tft_obj.cb_params.train_network
+    logger.warning("Process training result error start.")
+    # 1. Clear tdt channel
+    logger.warning("Clean tdt channel.")
+    clean_tdt_channel()
+
+    # 2. Load checkpoint
+    logger.warning("Load checkpoint.")
+    new_param_dict, remove_redundancy = ckpt_load_fn()
+    param_not_load, ckpt_not_load = load_param_into_net(train_network, new_param_dict, True, remove_redundancy)
+    logger.warning(f"param_not_load: {param_not_load}")
+    logger.warning(f"ckpt_not_load: {ckpt_not_load}")
+    resume_epoch = new_param_dict.get('epoch_num')
+    resume_step = new_param_dict.get('step_num')
+    model._initial_step = int(resume_step.asnumpy())
+    logger.warning("Process training result error end.")
+    return (resume_epoch, resume_step)
+
+
+def _calc_cb_initial_step(org_epoch, org_step, *args, **kwargs):
+    """calculate initial step for callback"""
+    train_dataset = args[1]
+    dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
+    sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
+
+    cb_initial_step = 0
+    if dataset_sink_mode:
+        train_dataset.set_init_step(org_epoch)
+        dataset_size = train_dataset.get_dataset_size()
+        if sink_size != -1:
+            cb_initial_step = org_epoch * sink_size + org_step
+        else:
+            cb_initial_step = org_epoch * dataset_size + org_step
+    else:
+        train_dataset.set_init_step(org_step)
+        cb_initial_step = org_step
+        if hasattr(train_dataset, '_dataset_helper'):
+            dataset_helper = train_dataset._dataset_helper
+            _reset_training_dataset(cb_initial_step, dataset_helper.iter.dataset.get_dataset_size())
+    return cb_initial_step
+
+
+def _update_ckpt_callback_info(resume_train_step, **kwargs):
+    """
+    Update checkpoint callback internal state
+    """
+    ckpt_obj = None
+    if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), ModelCheckpoint):
+        ckpt_obj = kwargs.get('callbacks')
+    if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
+        for item in kwargs.get('callbacks'):
+            if isinstance(item, ModelCheckpoint):
+                ckpt_obj = item
+    if ckpt_obj is not None:
+        ckpt_obj._last_triggered_step = 0
+        ckpt_obj._append_step_num = resume_train_step
+
+
 def _handle_tft(func):
     """
     Decorator function, which starts uce handle process when an exception occurs during training.
     """
+
     @wraps(func)
     def wrapper(self, *args, **kwargs):
         obj = None
-        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TFTRegister):
+        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TrainFaultTolerance):
             obj = kwargs.get('callbacks')
         if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
             for item in kwargs.get('callbacks'):
-                if isinstance(item, TFTRegister):
+                if isinstance(item, TrainFaultTolerance):
                     obj = item
         if obj:
             tft = obj.tft
             tft_env = os.getenv("MS_ENABLE_TFT", "")
-            uce_env = "UCE:1" in tft_env
+            uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env
+            tre_env = "TRE:1" in tft_env
             while True:
                 try:
                     return func(self, *args, **kwargs)
                 except RuntimeError as e:
-                    logger.info("uce wrapper caught RuntimeError")
-                    if not uce_env:
-                        logger.info("uce wrapper caught RuntimeError uce not enable")
-                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
-                        raise e
-                    e_str = str(e)
-                    logger.info("uce wrapper caught RuntimeError e_str:{}".format(e_str))
-                    if "UCEError" in e_str:
-                        obj.is_uce_rank = True
-                        logger.info("uce wrapper report UCEError")
-                        tft.tft_report_error(tft.ReportState.RS_UCE.value)
-                    elif "ForceStopError" in e_str:
-                        logger.info("uce wrapper caught RuntimeError ForceStopError")
-                        force_stop_err = tft.ReportState.RS_NORMAL.value
-                        tft.tft_report_error(force_stop_err)
+                    if tre_env and 'TREError' in str(e):
+                        _, resume_step = _handle_training_result_error(self, obj)
+                        repair_step = int(resume_step.asnumpy())
+                        _update_ckpt_callback_info(repair_step, **kwargs)
+                        logger.warning(f'Resume training after TREError from step {repair_step}.')
                     else:
-                        logger.info("uce wrapper caught RuntimeError rankid: {} OTHER ERROR")
-                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
-                        raise e
-                    ret = tft.tft_wait_next_action()
-                    if ret == tft.Action.EXIT.value:
-                        raise e
-                    repair_step = tft.tft_get_repair_step()
-                    logger.info("uce wrapper caught repair finish REPAIR STEP: {} batch_num: \
-                        {}".format(repair_step, self.batch_num))
-                    initial_epoch = int(repair_step/self.batch_num)
+                        _handle_exception_info(obj, uce_env, tft, e)
+                        ret = tft.tft_wait_next_action()
+                        if ret == tft.Action.EXIT.value:
+                            raise e
+                        repair_step = tft.tft_get_repair_step()
+                        logger.warning(
+                            "uce wrapper caught repair finish REPAIR STEP: {} batch_num:{}".format(repair_step,
+                                                                                                   self.batch_num))
+                    initial_epoch = int(repair_step / self.batch_num)
                     initial_step = repair_step % self.batch_num
                     kwargs["initial_epoch"] = initial_epoch
-
-                    train_dataset = args[1]
-                    dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
-                    sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
-
-                    cb_initial_step = 0
-                    if dataset_sink_mode:
-                        train_dataset.set_init_step(initial_epoch)
-                        dataset_size = train_dataset.get_dataset_size()
-                        if sink_size != -1:
-                            cb_initial_step = initial_epoch * sink_size + initial_step
-                        else:
-                            cb_initial_step = initial_epoch * dataset_size + initial_step
-                    else:
-                        train_dataset.set_init_step(initial_step)
-                        cb_initial_step = initial_step
-
-                    kwargs["initial_step"] = cb_initial_step
-
-                    logger.info("uce wrapper repair complete \
-                        initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch, cb_initial_step))
+                    cb_initial_step = _calc_cb_initial_step(initial_epoch, initial_step, *args, **kwargs)
+                    if not self.enable_tre:
+                        kwargs["initial_step"] = cb_initial_step
+                    # reset all accu grads to zero
+                    obj._reset_acc_grads()
+                    logger.warning(
+                        "uce wrapper repair complete initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch,
+                                                                                                     cb_initial_step))
                     continue
                 except BaseException as e:
-                    logger.info("uce wrapper caught BaseException error")
-                    tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+                    if tft:
+                        logger.error("uce wrapper caught BaseException error, enter MindIO TTP process.", exc_info=True)
+                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
                     raise e
         else:
             return func(self, *args, **kwargs)
+
     return wrapper
 
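Note: the repair flow in this hunk is mostly integer arithmetic. MindIO hands back a global repair_step; _handle_tft splits it into an epoch/step pair using batch_num, and _calc_cb_initial_step folds that pair back into a global callback step (using sink_size instead of the dataset size when a fixed sink size is configured). A minimal standalone sketch, mirroring the helpers above but not itself a MindSpore API:

def split_repair_step(repair_step, batch_num):
    # As in _handle_tft: epoch index, plus step offset within that epoch.
    return repair_step // batch_num, repair_step % batch_num

def callback_initial_step(initial_epoch, initial_step, dataset_size,
                          sink_size=-1, dataset_sink_mode=True):
    # As in _calc_cb_initial_step: with sinking, an "epoch" is sink_size steps
    # when a fixed sink size is set, otherwise dataset_size steps; without
    # sinking the global offset is just the step itself.
    if dataset_sink_mode:
        steps_per_epoch = sink_size if sink_size != -1 else dataset_size
        return initial_epoch * steps_per_epoch + initial_step
    return initial_step

assert split_repair_step(2500, batch_num=1000) == (2, 500)
assert callback_initial_step(2, 500, dataset_size=1000) == 2500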
@@ -213,7 +304,7 @@ def _check_tft():
     if ms_mode != mindspore.GRAPH_MODE:
         raise ValueError("TFT is only supported in GRAPH_MODE")
     jit_level = context.get_context("jit_level")
-    if jit_level == "O2" and "UCE:1" in tft_env:
+    if jit_level == "O2" and ("UCE:1" in tft_env or "ARF:1" in tft_env):
         raise ValueError("TFT is not supported when using jit_level == O2")
 
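Note: both the wrapper above and _check_tft gate features with plain substring checks against MS_ENABLE_TFT, so TTP/UCE/ARF/TRE toggle independently. A standalone sketch of the gating (the example value of the variable is an assumption; the accepted format is defined by MindSpore's docs, not by this diff):

import os

tft_env = os.getenv("MS_ENABLE_TFT", "")            # e.g. "{TTP:1,UCE:1,TRE:1}"
uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env  # UCE repair path, extended to ARF in 2.6.0
tre_env = "TRE:1" in tft_env                        # training-result-error resume path

def check_tft(jit_level, tft_env):
    # Mirror of the tightened constraint: UCE/ARF repair cannot run under jit_level O2.
    if jit_level == "O2" and ("UCE:1" in tft_env or "ARF:1" in tft_env):
        raise ValueError("TFT is not supported when using jit_level == O2")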
@@ -403,12 +494,13 @@
             the accuracy is reduced by less than 3%.
 
             If you want to config boost mode by yourself, you can set boost_config_dict as `boost.py`.
-            In order for this function to work, you need to set the optimizer, eval_network or metric parameters
-            at the same time.
+            In order for this function to work, you need to set the parameter `optimizer`, along with
+            at least one of the parameter `eval_network` or performance `metrics`.
 
             Notice: The current optimization enabled by default only applies to some networks, and not all networks
             can obtain the same benefits. It is recommended to enable this function on
-            the Graph mode + Ascend platform, and for better acceleration, refer to the documentation to configure
+            the Graph mode + Ascend platform, and for better acceleration,
+            refer to :class:`mindspore.boost.AutoBoost` to configure
             boost_config_dict.
 
     Examples:
@@ -433,6 +525,7 @@
     def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None,
                  amp_level="O0", boost_level="O0", **kwargs):
         self._network = network
+        _init_auto_parallel_context(self._network)
         self._loss_fn = loss_fn
         self._optimizer = optimizer
         self._loss_scale_manager = None
@@ -467,6 +560,9 @@
         self._lite_infer = True  # if backend lite infer fails, set False
         self._mindspore_lite_model_group_id = id(self) & 0xFFFF
         self.batch_num = -1
+        self.enable_tre = "TRE:1" in os.getenv("MS_ENABLE_TFT", "")
+        self._initial_step = None
+        _clear_auto_parallel_context(self._network)
 
     def _check_for_graph_cell(self, kwargs):
         """Check for graph cell"""
@@ -665,7 +761,7 @@
         logger.info("Begin to connect network with dataset.")
         network = connect_network_with_dataset(network, dataset_helper)
 
-        if _get_recovery_context("enable_recovery") and is_train:
+        if (_get_recovery_context("enable_recovery") or self.enable_tre) and is_train:
             _set_training_dataset(dataset_helper)
 
         network.set_train(is_train)
@@ -762,7 +858,7 @@
                 break
             logger.warning(f"Waiting for the dataset warmup, current device queue size: {mbuf_size}")
 
-    def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
+    def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True):
         """
         Initialize compute graphs and data graphs with the sink mode.
 
@@ -791,7 +887,6 @@
         if not isinstance(train_dataset, mindspore.dataset.Dataset):
             raise TypeError("The type of 'train_dataset' must be `Dataset`, "
                             "but got {}.".format(type(train_dataset)))
-
         vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
                    "Begin to check parameter broadcast in model.build().")
         logger.info("Begin to check parameter broadcast in model.build() procedure.")
@@ -804,23 +899,24 @@
             train_dataset.__no_send__ = True
             train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
                                                                         dataset=train_dataset,
-                                                                        dataset_sink_mode=True,
+                                                                        dataset_sink_mode=sink_mode,
                                                                         sink_size=sink_size)
             vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to warmup dataset in model.build().")
-            logger.info("Begin to warmup dataset in model.build() procedure.")
-            self._warmup_dataset(epoch, train_dataset, sink_size)
+            if sink_mode:
+                logger.info("Begin to warmup dataset in model.build() procedure.")
+                self._warmup_dataset(epoch, train_dataset, sink_size)
 
-            # Since dataset pipeline has been triggered, delete flag
-            delattr(train_dataset, "__no_send__")
+                # Since dataset pipeline has been triggered, delete flag
+                delattr(train_dataset, "__no_send__")
 
-            # Waiting for the dataset warmup ready
-            vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
-                       "Begin waiting for dataset warmup in model.build().")
-            logger.info("Begin waiting for dataset warmup in model.build() procedure.")
-            self._waiting_for_dataset_warmup_ready(train_dataset)
-            vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
-                       "The dataset warmup was successful in model.build().")
-            logger.info("The dataset warmup was successful in model.build() procedure.")
+                # Waiting for the dataset warmup ready
+                vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                           "Begin waiting for dataset warmup in model.build().")
+                logger.info("Begin waiting for dataset warmup in model.build() procedure.")
+                self._waiting_for_dataset_warmup_ready(train_dataset)
+                vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                           "The dataset warmup was successful in model.build().")
+                logger.info("The dataset warmup was successful in model.build() procedure.")
 
         if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
             train_network.add_flags_recursive(is_first_iteration=True)
@@ -830,6 +926,7 @@
             logger.info("Begin to compile train network in model.build() procedure.")
             train_network.compile(*inputs)
             self._train_network.parameter_layout_dict = train_network.parameter_layout_dict
+            train_dataset.reset()
             break
 
         if valid_dataset:
@@ -843,7 +940,7 @@
             valid_dataset.__no_send__ = True
             valid_dataset_helper, eval_network = self._exec_preprocess(is_train=False,
                                                                        dataset=valid_dataset,
-                                                                       dataset_sink_mode=True)
+                                                                       dataset_sink_mode=sink_mode)
             if context.get_auto_parallel_context("pipeline_stages") > 1:
                 eval_network.add_flags_recursive(is_first_iteration=False)
             for inputs in valid_dataset_helper:
@@ -851,6 +948,7 @@
                            "Begin to compile eval network in model.build().")
                 logger.info("Begin to compile eval network in model.build() procedure.")
                 eval_network.compile(*inputs)
+                valid_dataset.reset()
                 break
 
     @staticmethod
@@ -908,10 +1006,6 @@
         cb_params.list_callback = self._transform_callbacks(callbacks)
         valid_infos = (valid_dataset, valid_frequency, valid_dataset_sink_mode)
         cb_params.list_callback.insert(0, _FrameworkProfilerCallback())
-        if os.environ.get("ENABLE_FLOPS_UTILIZATION_COLLECTOR") == "1" and \
-                FlopsUtilizationCollector not in cb_params.list_callback:
-            cb_params.list_callback.insert(0, FlopsUtilizationCollector(
-                cb_params.batch_num, full_flops=False))
         if context.get_context("mode") == context.PYNATIVE_MODE:
             cb_params.list_callback.insert(0, _StepSync())
         callbacks = cb_params.list_callback
@@ -923,6 +1017,8 @@
         cb_params.last_save_ckpt_step = None
         cb_params.latest_ckpt_file = None
         cb_params.loss_scale_mananger = self._loss_scale_manager
+        cb_params.is_arf = _get_recovery_context("is_arf")
+        cb_params.initial_step = self._initial_step
 
         # build callback list
         with _CallbackManager(callbacks) as list_callback:
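Note: cb_params.is_arf and cb_params.initial_step are new fields on the internal callback parameters, so user callbacks can detect an in-progress ARF repair or a TRE resume offset. A hedged sketch of a callback reading them (the callback itself is illustrative, not part of this diff):

from mindspore.train.callback import Callback

class ResumeAwareCallback(Callback):
    def on_train_begin(self, run_context):
        cb_params = run_context.original_args()
        # initial_step is filled from Model._initial_step after a TRE checkpoint
        # reload and stays None in a normal run; is_arf marks an ARF repair pass.
        if getattr(cb_params, "initial_step", None):
            print(f"resuming, global step offset = {cb_params.initial_step}")
        if getattr(cb_params, "is_arf", False):
            print("ARF repair in progress")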
@@ -1027,6 +1123,9 @@ class Model:
1027
1123
  need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
1028
1124
  if need_exec_callback_step_end:
1029
1125
  list_callback.on_train_step_end(run_context)
1126
+ if cb_params.is_arf:
1127
+ cb_params.is_arf = False
1128
+ _set_recovery_context(is_arf=False)
1030
1129
 
1031
1130
  # Embedding cache server only run one step.
1032
1131
  if is_embedding_cache_server:
@@ -1057,7 +1156,7 @@ class Model:
1057
1156
  if should_stop:
1058
1157
  break
1059
1158
 
1060
- need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset")\
1159
+ need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset") \
1061
1160
  and not _get_recovery_context("latest_ckpt_file")
1062
1161
  self.epoch_iter += 1
1063
1162
  if need_reset_to_beginning:
@@ -1101,7 +1200,7 @@ class Model:
1101
1200
  Check whether enable recovery and execution mode consistency.
1102
1201
  """
1103
1202
 
1104
- enable_recovery = _get_recovery_context("enable_recovery")
1203
+ enable_recovery = _get_recovery_context("enable_recovery") and context.get_context("device_target") == "GPU"
1105
1204
  if not enable_recovery:
1106
1205
  self.enable_recovery = False
1107
1206
  else:
@@ -1118,6 +1217,8 @@ class Model:
1118
1217
  dataset_size (int): The number of batches in a dataset.
1119
1218
  sink_size (int): Control the amount of data in each sink. Default: -1.
1120
1219
  """
1220
+ if context.get_context("device_target") != "GPU":
1221
+ return
1121
1222
  if not self.enable_recovery:
1122
1223
  self.need_load_ckpt = False
1123
1224
 
@@ -1146,7 +1247,7 @@ class Model:
1146
1247
  load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
1147
1248
  except BaseException as e:
1148
1249
  os.remove(cb_params.latest_ckpt_file)
1149
- raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
1250
+ raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: " \
1150
1251
  + cb_params.latest_ckpt_file) from e
1151
1252
  _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
1152
1253
  self.need_load_ckpt = False
@@ -1236,6 +1337,9 @@ class Model:
  self._loss_scale_manager.update_loss_scale(overflow)

  list_callback.on_train_step_end(run_context)
+ if cb_params.is_arf:
+     cb_params.is_arf = False
+     _set_recovery_context(is_arf=False)
  # Embedding cache server only run one step.
  if is_embedding_cache_server:
      break
@@ -1333,10 +1437,9 @@ class Model:
  ...     loss_scale_manager=loss_scale_manager)
  >>> model.train(2, dataset)
  """
+ _init_auto_parallel_context(self._network)
  _check_tft()
  device_target = context.get_context("device_target")
- # prepare dataset for obfuscated model
- train_dataset = self._prepare_obf_dataset(train_dataset)
  if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
      logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
      dataset_sink_mode = False
@@ -1392,6 +1495,8 @@ class Model:
  if _enable_distributed_mindrt():
      _reset_op_id_with_offset()

+ _clear_auto_parallel_context(self._network)
+
  @staticmethod
  def _check_sink_mode_for_ds_debug_mode(dataset_sink_mode):
      if get_debug_mode() and dataset_sink_mode:
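`train()` is now bracketed by `_init_auto_parallel_context` / `_clear_auto_parallel_context`, matching the `AutoParallel(cell)` front end that the updated docstrings below use in place of `set_auto_parallel_context`. A minimal sketch of the intended call pattern, assuming the imports shown in those docstrings and treating `LeNet5`, `loss`, `optim` and `dataset` as user-provided:

    # Sketch: training with the network wrapped in AutoParallel, as the
    # updated docstring examples in this file do.
    import mindspore as ms
    from mindspore.communication import init
    from mindspore.parallel.auto_parallel import AutoParallel
    from mindspore.train import Model

    ms.set_context(mode=ms.GRAPH_MODE)
    init()
    parallel_net = AutoParallel(LeNet5(), parallel_mode="semi_auto")
    model = Model(parallel_net, loss_fn=loss, optimizer=optim)
    model.train(2, dataset)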
@@ -1485,11 +1590,8 @@ class Model:
  >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
  >>> model.fit(2, train_dataset, valid_dataset)
-
- Tutorial Examples:
-     - `Advanced Encapsulation: Model - Train and Save Model
-       <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
  """
+ _init_auto_parallel_context(self._network)
  device_target = context.get_context("device_target")
  if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
      logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
@@ -1541,8 +1643,9 @@ class Model:
      valid_dataset=valid_dataset,
      valid_frequency=valid_frequency,
      valid_dataset_sink_mode=valid_dataset_sink_mode)
+ _clear_auto_parallel_context(self._network)

- def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
+ def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True):
      """
      Build computational graphs and data graphs with the sink mode.

@@ -1561,6 +1664,7 @@ class Model:
      will be built, and `metrics` in `Model` can not be None. Default: ``None`` .
  sink_size (int): Control the number of steps for each sinking. Default: ``-1`` .
  epoch (int): Control the training epochs. Default: ``1`` .
+ sink_mode (bool): Determines whether to pass the data through the dataset channel. Default: ``True`` .

  Examples:
      >>> from mindspore import nn
@@ -1581,20 +1685,22 @@ class Model:
  >>> model.build(dataset, epoch=2)
  >>> model.train(2, dataset)
  """
+ _init_auto_parallel_context(self._network)
  epoch = Validator.check_positive_int(epoch)
  if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
      self._train_network.check_names_and_refresh_name()
      self._train_network._is_check_and_refresh = True
  vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to init dataset in model.build().")
  logger.info("Begin to init dataset in model.build() procedure.")
- self._init(train_dataset, valid_dataset, sink_size, epoch)
+ self._init(train_dataset, valid_dataset, sink_size, epoch, sink_mode)
  vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
             "The model.build() which contains dataset warmup and network compile is success.")
  logger.info("The model.build() which contains dataset warmup and network compile is success.")
+ _clear_auto_parallel_context(self._network)

  def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
      """
-     Evaluation process in `mindspore.train.Model.fit`.
+     Evaluation process in :func:`mindspore.train.Model.fit`.

      Args:
          valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
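With `sink_mode` now forwarded into `self._init`, graphs can be pre-built without warming up a dataset sink channel. A minimal usage sketch based on the new signature (`net`, `loss`, `optim` and `dataset` as in the docstring example; the pairing with `dataset_sink_mode` at train time is an assumption):

    # Sketch: pre-compile graphs while skipping the data-sink warmup.
    model = Model(net, loss_fn=loss, optimizer=optim)
    model.build(dataset, epoch=2, sink_mode=False)  # keyword added in this change
    model.train(2, dataset, dataset_sink_mode=False)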
@@ -1670,6 +1776,9 @@ class Model:
      cb_params.eval_results.update({"eval_loss": eval_loss})
  list_callback.on_eval_end(run_context)

+ dataset_helper.stop_send()
+ dataset_helper.release()
+
  return metrics

  def _eval_process(self, valid_dataset, list_callback=None, cb_params=None, add_eval_loss=False):
@@ -1757,12 +1866,8 @@ class Model:
  >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
  >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
  >>> acc = model.eval(dataset, dataset_sink_mode=False)
-
- Tutorial Examples:
-     - `Advanced Encapsulation: Model - Train and Save Model
-       <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
  """
- valid_dataset = self._prepare_obf_dataset(valid_dataset)
+ _init_auto_parallel_context(self._network)
  dataset_sink_mode = Validator.check_bool(dataset_sink_mode)

  _device_number_check(self._parallel_mode, self._device_number)
@@ -1780,10 +1885,6 @@ class Model:
  cb_params.mode = "eval"
  cb_params.cur_step_num = 0
  cb_params.list_callback = self._transform_callbacks(callbacks)
- if os.environ.get("ENABLE_FLOPS_UTILIZATION_COLLECTOR") == "1" and \
-         FlopsUtilizationCollector not in cb_params.list_callback:
-     cb_params.list_callback.insert(0, FlopsUtilizationCollector(
-         cb_params.batch_num, full_flops=False))
  cb_params.network = self._network

  self._clear_metrics()
@@ -1811,6 +1912,7 @@ class Model:
  # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
  if _enable_distributed_mindrt():
      _reset_op_id_with_offset()
+ _clear_auto_parallel_context(self._network)

  return eval_result

@@ -1823,7 +1925,8 @@ class Model:
      The predict data, can be a single tensor,
      a list of tensor, or a tuple of tensor.

- config (dict, optional) - The config parameter is enabled when the backend is ‘lite’.
+ config (dict, optional): The config parameter is enabled when the backend is ‘lite’.
+
      The config includes two parts: config_path (configPath, str) and config_item (str, dict).
      When the config_item is set, its priority is higher than the config_path. Set the ranking
      table file for inference. The content of the configuration file is as follows:
@@ -1833,6 +1936,16 @@ class Model:
      For example: "/home/user/config.ini". Default value: ``"" `` , here is the content of the
      config.ini file:

+ The config has 3 forms:
+
+ 1. configPath defines the path of the configuration file, which is used to pass user-defined
+    options during model building. Default value: ``""``.
+
+ .. code-block::
+
+     config = {"configPath" : "/home/user/config.ini"}
+
+ Here is the content of the config.ini file:
+
  .. code-block::

      [ascend_context]
@@ -1841,20 +1954,15 @@ class Model:
      [op_name1] = data_type:float16 (operator named op_name1 is set to data type float16)
      [op_name2] = data_type:float32 (operator named op_name2 is set to data type float32)

- When only the config_path is configured, it is done as follows:
-
- .. code-block::
-
-     config = {"configPath" : "/home/user/config.ini"}
-
- When only the config_dict is configured, it is done as follows:
+ 2. Set the user-defined options in a parameter dictionary, as follows:

  .. code-block::

      config = {"ascend_context" : {"rank_table_file" : "path_b"},
                "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}

- When both the `config_path` and the `config_dict` are configured, it is done as follows:
+ 3. When both the `configPath` and the parameter dictionary are configured, the priority of the
+    parameter dictionary is higher than that of the content in the configuration file. It is done as follows:

  .. code-block::

@@ -1862,12 +1970,13 @@ class Model:
      "ascend_context" : {"rank_table_file" : "path_b"},
      "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}

- Note that both the "configPath" is configured in the config_dict and the config_item,
-     in this case, the path_b in the config_dict takes precedence.
+ Note that while the configuration file referenced by "configPath" sets "rank_table_file = [path_a]",
+     the dictionary sets "ascend_context" : {"rank_table_file" : "path_b"}; in this case, path_b takes precedence.

  Returns:
      Tensor, array(s) of predictions.
  """
+
  def _get_lite_context(lite_context_input):
      # use default lite context parameters for now
      device_target = context.get_context("device_target").lower()
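Putting the three documented `config` forms together, a minimal sketch of a lite-backend predict call, assuming `backend` and `config` are passed as keyword arguments as the Args section above describes (paths and option values are placeholders):

    # Sketch: Model.predict with the 'lite' backend and a merged config;
    # dictionary entries override the same keys from configPath.
    config = {
        "configPath": "/home/user/config.ini",            # form 1: file
        "ascend_context": {"rank_table_file": "path_b"},  # form 2: dict
        "execution_plan": {"op_name3": "data_type:float16"},
    }
    output = model.predict(inputs, backend="lite", config=config)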
@@ -1901,7 +2010,7 @@ class Model:
  if not self._mindspore_lite:
      self._mindspore_lite = importlib.import_module('mindspore_lite')

- use_past = False # default execute full model inference
+ use_past = False  # default execute full model inference
  model_group_id = None
  if self._predict_network.get_flags().__contains__("is_first_iteration"):
      is_first_iteration = self._predict_network.get_flags()['is_first_iteration']
@@ -2014,6 +2123,7 @@ class Model:
  >>> model = Model(LeNet5())
  >>> result = model.predict(input_data)
  """
+ _init_auto_parallel_context(self._network)
  if backend not in ['lite', None]:
      raise ValueError(f"For Model.predict, `backend` should be 'lite' or None, but got {backend}")
  if backend == "lite" and self._lite_infer:
@@ -2029,6 +2139,7 @@ class Model:
  except BaseException as e:
      self._lite_infer = False
      logger.warning(f"Lite inference failed, {e.__str__()}, fallback to original inference!")
+ _clear_auto_parallel_context(self._network)

  def _check_input_data():
      """Input data check."""
@@ -2094,7 +2205,9 @@ class Model:

  def infer_train_layout(self, train_dataset, dataset_sink_mode=True, sink_size=-1):
      """
-     Generate parameter layout for the train network in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
+     Generate parameter layout for the train network when using `AutoParallel(cell)`
+     to enable parallel mode.
+
      Only dataset sink mode is supported for now.

      .. warning::
@@ -2113,9 +2226,9 @@ class Model:
      Configure pynative mode or CPU, the training process will be performed with
      dataset not sink. Default: ``True`` .
  sink_size (int): Control the number of steps for each sinking.
+     If dataset_sink_mode is False, set sink_size as invalid.
      If sink_size = -1, sink the complete dataset for each epoch.
      If sink_size > 0, sink sink_size data for each epoch.
-     If dataset_sink_mode is False, set sink_size as invalid.
      Default: ``-1`` .

  Returns:
@@ -2129,10 +2242,10 @@ class Model:
  >>> from mindspore import Tensor, nn
  >>> from mindspore.train import Model
  >>> from mindspore.communication import init
+ >>> from mindspore.parallel.auto_parallel import AutoParallel
  >>>
  >>> ms.set_context(mode=ms.GRAPH_MODE)
  >>> init()
- >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
  >>>
  >>> # Create the dataset taking MNIST as an example. Refer to
  >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
@@ -2140,13 +2253,15 @@ class Model:
  >>> # Define the network structure of LeNet5. Refer to
  >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
+ >>> parallel_net = AutoParallel(net)
  >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  >>> loss_scale_manager = ms.FixedLossScaleManager()
  >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
- >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
+ >>> model = Model(parallel_net, loss_fn=loss, optimizer=optim, metrics=None,
  ...               loss_scale_manager=loss_scale_manager)
  >>> layout_dict = model.infer_train_layout(dataset)
  """
+ _init_auto_parallel_context(self._network)
  self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)

  train_dataset.__no_send__ = True
@@ -2158,11 +2273,13 @@ class Model:
      train_network.compile(*inputs)
      break
  train_dataset.__model_hash__ = hash(self)
+ _clear_auto_parallel_context(self._network)
  return train_network.parameter_layout_dict

  def infer_predict_layout(self, *predict_data, skip_backend_compile=False):
      """
-     Generate parameter layout for the predict network in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
+     Generate parameter layout for the predict network when using `AutoParallel(cell)`
+     to enable parallel mode.

      Data could be a single tensor or multiple tensors.

@@ -2185,21 +2302,47 @@ class Model:
      RuntimeError: If not in GRAPH_MODE.

  Examples:
- >>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
- >>> # mindspore.cn.
  >>> import numpy as np
  >>> import mindspore as ms
+ >>> import mindspore.nn as nn
  >>> from mindspore import Tensor
  >>> from mindspore.train import Model
+ >>> from mindspore.ops import operations as P
+ >>> from mindspore import context
  >>> from mindspore.communication import init
+ >>> from mindspore.parallel.auto_parallel import AutoParallel
+ >>>
+ >>> class Net(nn.Cell):
+ ...     def __init__(self):
+ ...         super(Net, self).__init__()
+ ...         self.fc1 = nn.Dense(128, 768, activation='relu')
+ ...         self.fc2 = nn.Dense(128, 768, activation='relu')
+ ...         self.fc3 = nn.Dense(128, 768, activation='relu')
+ ...         self.fc4 = nn.Dense(768, 768, activation='relu')
+ ...         self.relu4 = nn.ReLU()
+ ...         self.relu5 = nn.ReLU()
+ ...         self.transpose = P.Transpose()
+ ...         self.matmul1 = P.MatMul()
+ ...         self.matmul2 = P.MatMul()
+ ...
+ ...     def construct(self, x):
+ ...         q = self.fc1(x)
+ ...         k = self.fc2(x)
+ ...         v = self.fc3(x)
+ ...         k = self.transpose(k, (1, 0))
+ ...         c = self.relu4(self.matmul1(q, k))
+ ...         s = self.relu5(self.matmul2(c, v))
+ ...         s = self.fc4(s)
+ ...         return s
  >>>
  >>> ms.set_context(mode=ms.GRAPH_MODE)
  >>> init()
- >>> ms.set_auto_parallel_context(full_batch=True, parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
- >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
- >>> model = Model(Net())
- >>> predict_map = model.infer_predict_layout(input_data)
+ >>> inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+ >>> net = Net()
+ >>> parallel_net = AutoParallel(net, parallel_mode='semi_auto')
+ >>> model = Model(parallel_net)
+ >>> predict_map = model.infer_predict_layout(inputs)
  """
+ _init_auto_parallel_context(self._network)
  if context.get_context("mode") != context.GRAPH_MODE:
      raise RuntimeError("Pre-compile process that generate parameter layout for the predict network "
                         "only supports GRAPH MODE and Ascend target currently.")
@@ -2219,6 +2362,7 @@ class Model:
      predict_net.phase = origin_phase
  else:
      predict_net.compile(*predict_data)
+ _clear_auto_parallel_context(self._network)
  return predict_net.parameter_layout_dict

  def _flush_from_cache(self, cb_params):
@@ -2258,16 +2402,5 @@ class Model:
  """
  return self._eval_network

- def _prepare_obf_dataset(self, dataset):
-     if not hasattr(self._network, 'obf_ratios'):
-         return dataset
-     data_size = dataset.get_dataset_size()
-     obf_ratio_dataset = []
-     for _ in range(data_size):
-         obf_ratio_dataset.append(self._network.obf_ratios)
-     obf_ratio_dataset = ds.NumpySlicesDataset(data=obf_ratio_dataset, column_names=["y_obf"])
-     dataset = ds.zip((dataset, obf_ratio_dataset))
-     return dataset
-
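The removed helper silently zipped a `y_obf` column onto every dataset passed to `train` and `eval`. If that behaviour is still needed for obfuscated models, it can be reproduced outside `Model`; a minimal sketch lifted from the deleted lines, assuming `mindspore.dataset` imported as `ds` and a network exposing `obf_ratios`:

    # Sketch: user-level replacement for the deleted _prepare_obf_dataset.
    import mindspore.dataset as ds

    def prepare_obf_dataset(network, dataset):
        if not hasattr(network, 'obf_ratios'):
            return dataset
        data_size = dataset.get_dataset_size()
        obf_ratios = [network.obf_ratios for _ in range(data_size)]
        obf_column = ds.NumpySlicesDataset(data=obf_ratios, column_names=["y_obf"])
        return ds.zip((dataset, obf_column))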

  __all__ = ["Model"]