mindspore 2.4.10-cp310-cp310-win_amd64.whl → 2.6.0rc1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/train/model.py CHANGED
@@ -27,7 +27,6 @@ import time
 import numpy as np
 
 import mindspore
-import mindspore.dataset as ds
 from mindspore import log as logger
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
@@ -36,7 +35,7 @@ from mindspore.train.metrics import get_metrics, get_metric_fn
 from mindspore._checkparam import check_input_data, check_output_data
 from mindspore import _checkparam as Validator
 from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor,\
-    FlopsUtilizationCollector, TFTRegister
+    TrainFaultTolerance
 from mindspore.train.callback import __all__ as internal_cb_names
 from mindspore.train.callback._cluster_monitor import ClusterMonitor
 from mindspore import context
@@ -46,7 +45,7 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_
 from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
     _cache_enable, _enable_distributed_mindrt
 from mindspore.train.metrics import Loss
-from mindspore.train._utils import vlog_print
+from mindspore.log import vlog_print
 from mindspore import nn
 from mindspore.boost import AutoBoost
 from mindspore.context import ParallelMode
@@ -57,7 +56,9 @@ from mindspore.dataset.core.config import get_debug_mode
 from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
 from mindspore.train import amp
 from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
+from mindspore._c_expression import _get_optimzer_timestamps
 
+from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
 
 def _transfer_tensor_to_tuple(inputs):
     """
@@ -91,6 +92,7 @@ def _save_final_ckpt(func):
     """
     Decorator function, which saves the current checkpoint when an exception occurs during training.
     """
+
     @wraps(func)
     def wrapper(self, *args, **kwargs):
         obj = None
@@ -107,7 +109,7 @@ def _save_final_ckpt(func):
             # pylint: disable=W0212
             prefix = _chg_ckpt_file_name_if_same_exist(obj._directory, obj._exception_prefix, True)
             cur_ckpoint_file = prefix + "-" + str(self._current_epoch_num) + "_" \
-                + str(self._current_step_num) + "_breakpoint.ckpt"
+                               + str(self._current_step_num) + "_breakpoint.ckpt"
             cur_file = os.path.join(obj._directory, cur_ckpoint_file)
             if "epoch_num" in obj._append_dict:
                 obj._append_dict["epoch_num"] = obj._append_epoch_num + self._current_epoch_num
@@ -118,55 +120,82 @@ def _save_final_ckpt(func):
             raise e
     else:
         func(self, *args, **kwargs)
+
     return wrapper
 
 
+def _handle_exception_info(obj, uce_env, tft, e):
+    """handle exception info"""
+    logger.info("uce wrapper caught RuntimeError")
+    if not uce_env:
+        logger.error("uce wrapper caught RuntimeError but uce not enable, enter MindIO TTP process.",
+                     exc_info=True)
+        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+        raise e
+    e_str = str(e)
+    logger.warning("uce wrapper caught RuntimeError e_str:{}".format(e_str))
+    if "UCEError" in e_str:
+        logger.info("uce wrapper report UCEError")
+        obj.is_uce_rank = True
+        # if error is HBM_MULTI_BIT_ECC_ERROR
+        if "error_code=507054" in e_str:
+            hbm_error_time, optimize_start, optimizer_end = _get_optimzer_timestamps()
+            can_repair = tft.tft_can_do_uce_repair(hbm_error_time, optimize_start, optimizer_end)
+            logger.info(f"UCEError of type HBM_MULTI_BIT_ECC_ERROR occurs, \
+                hbm_error_time={hbm_error_time}, optimize_start={optimize_start}, \
+                optimizer_end={optimizer_end}, can_repair={can_repair}")
+            if not can_repair:
+                logger.error(f"Caught UCEError of type HBM_MULTI_BIT_ECC_ERROR but can not repair, "
+                             f"hbm_error_time={hbm_error_time}, optimize_start={optimize_start}, "
+                             f"optimizer_end={optimizer_end}", exc_info=True)
+                tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+                raise e
+        tft.tft_report_error(tft.ReportState.RS_UCE.value)
+    elif "ForceStopError" in e_str:
+        logger.warning("uce wrapper caught RuntimeError ForceStopError")
+        force_stop_err = tft.ReportState.RS_NORMAL.value
+        tft.tft_report_error(force_stop_err)
+    elif "ARF FINISH" in e_str:
+        logger.warning(f"ARF FINISH")
+        _set_recovery_context(is_arf=True)
+        tft.tft_report_error(tft.ReportState.RS_PREREPAIR_FINISH.value)
+    else:
+        logger.error("uce wrapper caught other RuntimeError, enter MindIO TTP process.", exc_info=True)
+        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+        raise e
+
+
 def _handle_tft(func):
     """
     Decorator function, which starts uce handle process when an exception occurs during training.
     """
+
     @wraps(func)
     def wrapper(self, *args, **kwargs):
         obj = None
-        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TFTRegister):
+        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TrainFaultTolerance):
             obj = kwargs.get('callbacks')
         if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
             for item in kwargs.get('callbacks'):
-                if isinstance(item, TFTRegister):
+                if isinstance(item, TrainFaultTolerance):
                     obj = item
         if obj:
             tft = obj.tft
             tft_env = os.getenv("MS_ENABLE_TFT", "")
-            uce_env = "UCE:1" in tft_env
+            uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env
             while True:
                 try:
                     return func(self, *args, **kwargs)
                 except RuntimeError as e:
-                    logger.info("uce wrapper caught RuntimeError")
-                    if not uce_env:
-                        logger.info("uce wrapper caught RuntimeError uce not enable")
-                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
-                        raise e
-                    e_str = str(e)
-                    logger.info("uce wrapper caught RuntimeError e_str:{}".format(e_str))
-                    if "UCEError" in e_str:
-                        obj.is_uce_rank = True
-                        logger.info("uce wrapper report UCEError")
-                        tft.tft_report_error(tft.ReportState.RS_UCE.value)
-                    elif "ForceStopError" in e_str:
-                        logger.info("uce wrapper caught RuntimeError ForceStopError")
-                        force_stop_err = tft.ReportState.RS_NORMAL.value
-                        tft.tft_report_error(force_stop_err)
-                    else:
-                        logger.info("uce wrapper caught RuntimeError rankid: {} OTHER ERROR")
-                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
-                        raise e
+                    _handle_exception_info(obj, uce_env, tft, e)
                     ret = tft.tft_wait_next_action()
                     if ret == tft.Action.EXIT.value:
                         raise e
                     repair_step = tft.tft_get_repair_step()
-                    logger.info("uce wrapper caught repair finish REPAIR STEP: {} batch_num: \
-                        {}".format(repair_step, self.batch_num))
-                    initial_epoch = int(repair_step/self.batch_num)
+                    logger.warning(
+                        "uce wrapper caught repair finish REPAIR STEP: {} batch_num:{}".format(repair_step,
+                                                                                               self.batch_num))
+                    initial_epoch = int(repair_step / self.batch_num)
                     initial_step = repair_step % self.batch_num
                     kwargs["initial_epoch"] = initial_epoch
 
@@ -187,16 +216,19 @@ def _handle_tft(func):
                     cb_initial_step = initial_step
 
                     kwargs["initial_step"] = cb_initial_step
-
-                    logger.info("uce wrapper repair complete \
-                        initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch, cb_initial_step))
+                    # reset all accu grads to zero
+                    obj._reset_acc_grads()
+                    logger.warning(
+                        "uce wrapper repair complete initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch,
+                                                                                                     cb_initial_step))
                     continue
                 except BaseException as e:
-                    logger.info("uce wrapper caught BaseException error")
+                    logger.error("uce wrapper caught BaseException error, enter MindIO TTP process.", exc_info=True)
                     tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
                     raise e
         else:
            return func(self, *args, **kwargs)
+
     return wrapper
 
 
@@ -213,7 +245,7 @@ def _check_tft():
     if ms_mode != mindspore.GRAPH_MODE:
         raise ValueError("TFT is only supported in GRAPH_MODE")
     jit_level = context.get_context("jit_level")
-    if jit_level == "O2" and "UCE:1" in tft_env:
+    if jit_level == "O2" and ("UCE:1" in tft_env or "ARF:1" in tft_env):
         raise ValueError("TFT is not supported when using jit_level == O2")
 
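The hunks above rework the train-fault-tolerance (TFT) path: error classification moves into the new _handle_exception_info helper, ARF events are now handled alongside UCE, and once MindIO reports a completed repair the wrapper resumes training from the repaired step. The resume arithmetic is the one non-obvious piece; below is a minimal sketch of it, using the repair_step and batch_num names from the diff (the MindIO TFT client itself is assumed and not shown):

# Sketch of the resume bookkeeping done in _handle_tft above (illustrative only;
# tft.tft_get_repair_step() and the surrounding MindIO client are assumed).
def resume_position(repair_step: int, batch_num: int) -> tuple:
    """Map a global repair step back to (initial_epoch, initial_step)."""
    initial_epoch = repair_step // batch_num  # whole epochs already completed
    initial_step = repair_step % batch_num    # steps completed within that epoch
    return initial_epoch, initial_step

# With 100 batches per epoch, a repair at global step 250 resumes
# training at epoch 2, step 50.
assert resume_position(250, 100) == (2, 50)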
@@ -403,12 +435,13 @@ class Model:
403
435
  the accuracy is reduced by less than 3%.
404
436
 
405
437
  If you want to config boost mode by yourself, you can set boost_config_dict as `boost.py`.
406
- In order for this function to work, you need to set the optimizer, eval_network or metric parameters
407
- at the same time.
438
+ In order for this function to work, you need to set the parameter `optimizer`, along with
439
+ at least one of the parameter `eval_network` or performance `metrics`.
408
440
 
409
441
  Notice: The current optimization enabled by default only applies to some networks, and not all networks
410
442
  can obtain the same benefits. It is recommended to enable this function on
411
- the Graph mode + Ascend platform, and for better acceleration, refer to the documentation to configure
443
+ the Graph mode + Ascend platform, and for better acceleration,
444
+ refer to :class:`mindspore.boost.AutoBoost` to configure
412
445
  boost_config_dict.
413
446
 
414
447
  Examples:
@@ -433,6 +466,7 @@ class Model:
433
466
  def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None,
434
467
  amp_level="O0", boost_level="O0", **kwargs):
435
468
  self._network = network
469
+ _init_auto_parallel_context(self._network)
436
470
  self._loss_fn = loss_fn
437
471
  self._optimizer = optimizer
438
472
  self._loss_scale_manager = None
@@ -467,6 +501,7 @@ class Model:
467
501
  self._lite_infer = True # if backend lite infer fails, set False
468
502
  self._mindspore_lite_model_group_id = id(self) & 0xFFFF
469
503
  self.batch_num = -1
504
+ _clear_auto_parallel_context(self._network)
470
505
 
471
506
  def _check_for_graph_cell(self, kwargs):
472
507
  """Check for graph cell"""
@@ -762,7 +797,7 @@ class Model:
762
797
  break
763
798
  logger.warning(f"Waiting for the dataset warmup, current device queue size: {mbuf_size}")
764
799
 
765
- def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
800
+ def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True):
766
801
  """
767
802
  Initialize compute graphs and data graphs with the sink mode.
768
803
 
@@ -791,7 +826,6 @@ class Model:
791
826
  if not isinstance(train_dataset, mindspore.dataset.Dataset):
792
827
  raise TypeError("The type of 'train_dataset' must be `Dataset`, "
793
828
  "but got {}.".format(type(train_dataset)))
794
-
795
829
  vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
796
830
  "Begin to check parameter broadcast in model.build().")
797
831
  logger.info("Begin to check parameter broadcast in model.build() procedure.")
@@ -804,23 +838,24 @@ class Model:
804
838
  train_dataset.__no_send__ = True
805
839
  train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
806
840
  dataset=train_dataset,
807
- dataset_sink_mode=True,
841
+ dataset_sink_mode=sink_mode,
808
842
  sink_size=sink_size)
809
843
  vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to warmup dataset in model.build().")
810
- logger.info("Begin to warmup dataset in model.build() procedure.")
811
- self._warmup_dataset(epoch, train_dataset, sink_size)
844
+ if sink_mode:
845
+ logger.info("Begin to warmup dataset in model.build() procedure.")
846
+ self._warmup_dataset(epoch, train_dataset, sink_size)
812
847
 
813
- # Since dataset pipeline has been triggered, delete flag
814
- delattr(train_dataset, "__no_send__")
848
+ # Since dataset pipeline has been triggered, delete flag
849
+ delattr(train_dataset, "__no_send__")
815
850
 
816
- # Waiting for the dataset warmup ready
817
- vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
818
- "Begin waiting for dataset warmup in model.build().")
819
- logger.info("Begin waiting for dataset warmup in model.build() procedure.")
820
- self._waiting_for_dataset_warmup_ready(train_dataset)
821
- vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
822
- "The dataset warmup was successful in model.build().")
823
- logger.info("The dataset warmup was successful in model.build() procedure.")
851
+ # Waiting for the dataset warmup ready
852
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
853
+ "Begin waiting for dataset warmup in model.build().")
854
+ logger.info("Begin waiting for dataset warmup in model.build() procedure.")
855
+ self._waiting_for_dataset_warmup_ready(train_dataset)
856
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
857
+ "The dataset warmup was successful in model.build().")
858
+ logger.info("The dataset warmup was successful in model.build() procedure.")
824
859
 
825
860
  if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
826
861
  train_network.add_flags_recursive(is_first_iteration=True)
@@ -830,6 +865,7 @@ class Model:
830
865
  logger.info("Begin to compile train network in model.build() procedure.")
831
866
  train_network.compile(*inputs)
832
867
  self._train_network.parameter_layout_dict = train_network.parameter_layout_dict
868
+ train_dataset.reset()
833
869
  break
834
870
 
835
871
  if valid_dataset:
@@ -843,7 +879,7 @@ class Model:
843
879
  valid_dataset.__no_send__ = True
844
880
  valid_dataset_helper, eval_network = self._exec_preprocess(is_train=False,
845
881
  dataset=valid_dataset,
846
- dataset_sink_mode=True)
882
+ dataset_sink_mode=sink_mode)
847
883
  if context.get_auto_parallel_context("pipeline_stages") > 1:
848
884
  eval_network.add_flags_recursive(is_first_iteration=False)
849
885
  for inputs in valid_dataset_helper:
@@ -851,6 +887,7 @@ class Model:
851
887
  "Begin to compile eval network in model.build().")
852
888
  logger.info("Begin to compile eval network in model.build() procedure.")
853
889
  eval_network.compile(*inputs)
890
+ valid_dataset.reset()
854
891
  break
855
892
 
856
893
  @staticmethod
@@ -908,10 +945,6 @@ class Model:
908
945
  cb_params.list_callback = self._transform_callbacks(callbacks)
909
946
  valid_infos = (valid_dataset, valid_frequency, valid_dataset_sink_mode)
910
947
  cb_params.list_callback.insert(0, _FrameworkProfilerCallback())
911
- if os.environ.get("ENABLE_FLOPS_UTILIZATION_COLLECTOR") == "1" and \
912
- FlopsUtilizationCollector not in cb_params.list_callback:
913
- cb_params.list_callback.insert(0, FlopsUtilizationCollector(
914
- cb_params.batch_num, full_flops=False))
915
948
  if context.get_context("mode") == context.PYNATIVE_MODE:
916
949
  cb_params.list_callback.insert(0, _StepSync())
917
950
  callbacks = cb_params.list_callback
@@ -923,6 +956,7 @@ class Model:
923
956
  cb_params.last_save_ckpt_step = None
924
957
  cb_params.latest_ckpt_file = None
925
958
  cb_params.loss_scale_mananger = self._loss_scale_manager
959
+ cb_params.is_arf = _get_recovery_context("is_arf")
926
960
 
927
961
  # build callback list
928
962
  with _CallbackManager(callbacks) as list_callback:
@@ -1027,6 +1061,9 @@ class Model:
1027
1061
  need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
1028
1062
  if need_exec_callback_step_end:
1029
1063
  list_callback.on_train_step_end(run_context)
1064
+ if cb_params.is_arf:
1065
+ cb_params.is_arf = False
1066
+ _set_recovery_context(is_arf=False)
1030
1067
 
1031
1068
  # Embedding cache server only run one step.
1032
1069
  if is_embedding_cache_server:
@@ -1057,7 +1094,7 @@ class Model:
1057
1094
  if should_stop:
1058
1095
  break
1059
1096
 
1060
- need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset")\
1097
+ need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset") \
1061
1098
  and not _get_recovery_context("latest_ckpt_file")
1062
1099
  self.epoch_iter += 1
1063
1100
  if need_reset_to_beginning:
@@ -1101,7 +1138,7 @@ class Model:
1101
1138
  Check whether enable recovery and execution mode consistency.
1102
1139
  """
1103
1140
 
1104
- enable_recovery = _get_recovery_context("enable_recovery")
1141
+ enable_recovery = _get_recovery_context("enable_recovery") and context.get_context("device_target") == "GPU"
1105
1142
  if not enable_recovery:
1106
1143
  self.enable_recovery = False
1107
1144
  else:
@@ -1118,6 +1155,8 @@ class Model:
1118
1155
  dataset_size (int): The number of batches in a dataset.
1119
1156
  sink_size (int): Control the amount of data in each sink. Default: -1.
1120
1157
  """
1158
+ if context.get_context("device_target") != "GPU":
1159
+ return
1121
1160
  if not self.enable_recovery:
1122
1161
  self.need_load_ckpt = False
1123
1162
 
@@ -1146,7 +1185,7 @@ class Model:
1146
1185
  load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
1147
1186
  except BaseException as e:
1148
1187
  os.remove(cb_params.latest_ckpt_file)
1149
- raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
1188
+ raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: " \
1150
1189
  + cb_params.latest_ckpt_file) from e
1151
1190
  _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
1152
1191
  self.need_load_ckpt = False
@@ -1236,6 +1275,9 @@ class Model:
1236
1275
  self._loss_scale_manager.update_loss_scale(overflow)
1237
1276
 
1238
1277
  list_callback.on_train_step_end(run_context)
1278
+ if cb_params.is_arf:
1279
+ cb_params.is_arf = False
1280
+ _set_recovery_context(is_arf=False)
1239
1281
  # Embedding cache server only run one step.
1240
1282
  if is_embedding_cache_server:
1241
1283
  break
@@ -1333,10 +1375,9 @@ class Model:
1333
1375
  ... loss_scale_manager=loss_scale_manager)
1334
1376
  >>> model.train(2, dataset)
1335
1377
  """
1378
+ _init_auto_parallel_context(self._network)
1336
1379
  _check_tft()
1337
1380
  device_target = context.get_context("device_target")
1338
- # prepare dataset for obfuscated model
1339
- train_dataset = self._prepare_obf_dataset(train_dataset)
1340
1381
  if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
1341
1382
  logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
1342
1383
  dataset_sink_mode = False
@@ -1392,6 +1433,8 @@ class Model:
1392
1433
  if _enable_distributed_mindrt():
1393
1434
  _reset_op_id_with_offset()
1394
1435
 
1436
+ _clear_auto_parallel_context(self._network)
1437
+
1395
1438
  @staticmethod
1396
1439
  def _check_sink_mode_for_ds_debug_mode(dataset_sink_mode):
1397
1440
  if get_debug_mode() and dataset_sink_mode:
@@ -1485,11 +1528,8 @@ class Model:
1485
1528
  >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
1486
1529
  >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
1487
1530
  >>> model.fit(2, train_dataset, valid_dataset)
1488
-
1489
- Tutorial Examples:
1490
- - `Advanced Encapsulation: Model - Train and Save Model
1491
- <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
1492
1531
  """
1532
+ _init_auto_parallel_context(self._network)
1493
1533
  device_target = context.get_context("device_target")
1494
1534
  if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
1495
1535
  logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
@@ -1541,8 +1581,9 @@ class Model:
1541
1581
  valid_dataset=valid_dataset,
1542
1582
  valid_frequency=valid_frequency,
1543
1583
  valid_dataset_sink_mode=valid_dataset_sink_mode)
1584
+ _clear_auto_parallel_context(self._network)
1544
1585
 
1545
- def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
1586
+ def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True):
1546
1587
  """
1547
1588
  Build computational graphs and data graphs with the sink mode.
1548
1589
 
@@ -1561,6 +1602,7 @@ class Model:
1561
1602
  will be built, and `metrics` in `Model` can not be None. Default: ``None`` .
1562
1603
  sink_size (int): Control the number of steps for each sinking. Default: ``-1`` .
1563
1604
  epoch (int): Control the training epochs. Default: ``1`` .
1605
+ sink_mode (bool): Determines whether to pass the data through dataset channel. Default: ``True`` .
1564
1606
 
1565
1607
  Examples:
1566
1608
  >>> from mindspore import nn
@@ -1581,20 +1623,22 @@ class Model:
1581
1623
  >>> model.build(dataset, epoch=2)
1582
1624
  >>> model.train(2, dataset)
1583
1625
  """
1626
+ _init_auto_parallel_context(self._network)
1584
1627
  epoch = Validator.check_positive_int(epoch)
1585
1628
  if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
1586
1629
  self._train_network.check_names_and_refresh_name()
1587
1630
  self._train_network._is_check_and_refresh = True
1588
1631
  vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to init dataset in model.build().")
1589
1632
  logger.info("Begin to init dataset in model.build() procedure.")
1590
- self._init(train_dataset, valid_dataset, sink_size, epoch)
1633
+ self._init(train_dataset, valid_dataset, sink_size, epoch, sink_mode)
1591
1634
  vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
1592
1635
  "The model.build() which contains dataset warmup and network compile is success.")
1593
1636
  logger.info("The model.build() which contains dataset warmup and network compile is success.")
1637
+ _clear_auto_parallel_context(self._network)
1594
1638
 
1595
1639
  def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
1596
1640
  """
1597
- Evaluation process in `mindspore.train.Model.fit`.
1641
+ Evaluation process in :func:`mindspore.train.Model.fit`.
1598
1642
 
1599
1643
  Args:
1600
1644
  valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
@@ -1670,6 +1714,9 @@ class Model:
                 cb_params.eval_results.update({"eval_loss": eval_loss})
         list_callback.on_eval_end(run_context)
 
+        dataset_helper.stop_send()
+        dataset_helper.release()
+
         return metrics
 
     def _eval_process(self, valid_dataset, list_callback=None, cb_params=None, add_eval_loss=False):
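The added `stop_send()` / `release()` calls explicitly shut down the dataset-sink channel once the evaluation inside `fit` finishes, instead of leaving teardown to garbage collection. The same pattern, written defensively for external code that drives a `DatasetHelper` itself; the helper construction below follows `mindspore.train.DatasetHelper`, and the exact constructor arguments should be treated as an assumption:

.. code-block:: python

    from mindspore.train import DatasetHelper

    def run_sinked_eval(network, valid_dataset):
        # A sketch: iterate a sink-mode DatasetHelper and always release the channel.
        dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=True)
        try:
            for inputs in dataset_helper:
                network(*inputs)
        finally:
            dataset_helper.stop_send()  # stop the background send thread
            dataset_helper.release()    # free the device queue resources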
@@ -1757,12 +1804,8 @@ class Model:
             >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
             >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
             >>> acc = model.eval(dataset, dataset_sink_mode=False)
-
-        Tutorial Examples:
-            - `Advanced Encapsulation: Model - Train and Save Model
-              <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
         """
-        valid_dataset = self._prepare_obf_dataset(valid_dataset)
+        _init_auto_parallel_context(self._network)
         dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
 
         _device_number_check(self._parallel_mode, self._device_number)
@@ -1780,10 +1823,6 @@ class Model:
         cb_params.mode = "eval"
         cb_params.cur_step_num = 0
         cb_params.list_callback = self._transform_callbacks(callbacks)
-        if os.environ.get("ENABLE_FLOPS_UTILIZATION_COLLECTOR") == "1" and \
-                FlopsUtilizationCollector not in cb_params.list_callback:
-            cb_params.list_callback.insert(0, FlopsUtilizationCollector(
-                cb_params.batch_num, full_flops=False))
         cb_params.network = self._network
 
         self._clear_metrics()
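With this hunk, `eval` no longer injects a `FlopsUtilizationCollector` automatically when `ENABLE_FLOPS_UTILIZATION_COLLECTOR=1` is set; callers who want FLOPs statistics now pass the callback themselves. A sketch, assuming `mindspore.train.FlopsUtilizationCollector` keeps the constructor shown in the removed lines (steps-per-epoch first, then `full_flops`):

.. code-block:: python

    from mindspore.train import FlopsUtilizationCollector

    # valid_dataset and model assumed set up as in the eval docstring example.
    flops_cb = FlopsUtilizationCollector(valid_dataset.get_dataset_size(), full_flops=False)
    acc = model.eval(valid_dataset, callbacks=[flops_cb], dataset_sink_mode=False)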
@@ -1811,6 +1850,7 @@ class Model:
         # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
         if _enable_distributed_mindrt():
             _reset_op_id_with_offset()
+        _clear_auto_parallel_context(self._network)
 
         return eval_result
 
@@ -1823,7 +1863,8 @@ class Model:
                 The predict data, can be a single tensor,
                 a list of tensor, or a tuple of tensor.
 
-            config (dict, optional) - The config parameter is enabled when the backend is ‘lite’.
+            config (dict, optional): The config parameter takes effect when the backend is 'lite'.
+
                 The config includes two parts: config_path (configPath, str) and config_item (str, dict).
                 When the config_item is set, its priority is higher than the config_path. Set the ranking
                 table file for inference. The content of the configuration file is as follows:
@@ -1833,6 +1874,16 @@ class Model:
                 For example: "/home/user/config.ini". Default value: ``"" `` , here is the content of the
                 config.ini file:
 
+                The config can take three forms:
+                1. configPath defines the path of the configuration file, which is used to pass user-defined
+                   options during model building. Default value: ``""``.
+
+                .. code-block::
+
+                    config = {"configPath" : "/home/user/config.ini"}
+
+                Here is the content of the config.ini file:
+
                 .. code-block::
 
                     [ascend_context]
@@ -1841,20 +1892,15 @@ class Model:
                     [op_name1] = data_type:float16 (operator named op_name1 is set to data type float16)
                     [op_name2] = data_type:float32 (operator named op_name2 is set to data type float32)
 
-                When only the config_path is configured, it is done as follows:
-
-                .. code-block::
-
-                    config = {"configPath" : "/home/user/config.ini"}
-
-                When only the config_dict is configured, it is done as follows:
+                2. Set the user-defined options in a parameter dictionary, as follows:
 
                 .. code-block::
 
                     config = {"ascend_context" : {"rank_table_file" : "path_b"},
                               "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}
 
-                When both the `config_path` and the `config_dict` are configured, it is done as follows:
+                3. Configure both the `configPath` and the parameter dictionary; the parameter
+                   dictionary takes priority over the content of the configuration file, as follows:
 
                 .. code-block::
 
@@ -1862,12 +1908,13 @@ class Model:
                               "ascend_context" : {"rank_table_file" : "path_b"},
                               "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}
 
-                Note that both the "configPath" is configured in the config_dict and the config_item,
-                in this case, the path_b in the config_dict takes precedence.
+                Note that the file behind "configPath" sets "rank_table_file = [path_a]" while the dictionary
+                sets "ascend_context" : {"rank_table_file" : "path_b"}; in this case, path_b takes precedence.
 
         Returns:
             Tensor, array(s) of predictions.
         """
+
         def _get_lite_context(lite_context_input):
             # use default lite context parameters for now
             device_target = context.get_context("device_target").lower()
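Putting the three forms together: a sketch of a 'lite'-backend predict call that combines a configuration file with dictionary overrides. Paths and operator names are placeholders, `input_data` is assumed to be a prepared `Tensor`, and passing `config` as a keyword follows the parameter documentation above:

.. code-block:: python

    config = {
        "configPath": "/home/user/config.ini",            # form 1: options from a file
        "ascend_context": {"rank_table_file": "path_b"},  # forms 2/3: dictionary overrides
        "execution_plan": {"op_name3": "data_type:float16"},
    }
    # Dictionary entries take precedence over the same keys in config.ini.
    result = model.predict(input_data, backend="lite", config=config)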
@@ -1901,7 +1948,7 @@ class Model:
             if not self._mindspore_lite:
                 self._mindspore_lite = importlib.import_module('mindspore_lite')
 
-            use_past = False # default execute full model inference
+            use_past = False  # default execute full model inference
             model_group_id = None
             if self._predict_network.get_flags().__contains__("is_first_iteration"):
                 is_first_iteration = self._predict_network.get_flags()['is_first_iteration']
@@ -2014,6 +2061,7 @@ class Model:
             >>> model = Model(LeNet5())
             >>> result = model.predict(input_data)
         """
+        _init_auto_parallel_context(self._network)
         if backend not in ['lite', None]:
             raise ValueError(f"For Model.predict, `backend` should be 'lite' or None, but got {backend}")
         if backend == "lite" and self._lite_infer:
@@ -2029,6 +2077,7 @@ class Model:
             except BaseException as e:
                 self._lite_infer = False
                 logger.warning(f"Lite inference failed, {e.__str__()}, fallback to original inference!")
+        _clear_auto_parallel_context(self._network)
 
         def _check_input_data():
             """Input data check."""
@@ -2094,7 +2143,9 @@ class Model:
 
     def infer_train_layout(self, train_dataset, dataset_sink_mode=True, sink_size=-1):
         """
-        Generate parameter layout for the train network in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
+        Generate parameter layout for the train network when using `AutoParallel(cell)`
+        to enable parallel mode.
+
         Only dataset sink mode is supported for now.
 
         .. warning::
@@ -2113,9 +2164,9 @@ class Model:
                 Configure pynative mode or CPU, the training process will be performed with
                 dataset not sink. Default: ``True`` .
             sink_size (int): Control the number of steps for each sinking.
+                If dataset_sink_mode is False, sink_size is invalid.
                 If sink_size = -1, sink the complete dataset for each epoch.
                 If sink_size > 0, sink sink_size data for each epoch.
-                If dataset_sink_mode is False, set sink_size as invalid.
                 Default: ``-1`` .
 
         Returns:
@@ -2129,10 +2180,10 @@ class Model:
             >>> from mindspore import Tensor, nn
             >>> from mindspore.train import Model
             >>> from mindspore.communication import init
+            >>> from mindspore.parallel.auto_parallel import AutoParallel
             >>>
             >>> ms.set_context(mode=ms.GRAPH_MODE)
             >>> init()
-            >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
             >>>
             >>> # Create the dataset taking MNIST as an example. Refer to
             >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
@@ -2140,13 +2191,15 @@ class Model:
             >>> # Define the network structure of LeNet5. Refer to
             >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
             >>> net = LeNet5()
+            >>> parallel_net = AutoParallel(net)
             >>> loss = nn.SoftmaxCrossEntropyWithLogits()
             >>> loss_scale_manager = ms.FixedLossScaleManager()
             >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
-            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
+            >>> model = Model(parallel_net, loss_fn=loss, optimizer=optim, metrics=None,
             ...               loss_scale_manager=loss_scale_manager)
             >>> layout_dict = model.infer_train_layout(dataset)
         """
+        _init_auto_parallel_context(self._network)
         self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)
 
         train_dataset.__no_send__ = True
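`infer_train_layout` compiles the train network against one batch and returns `train_network.parameter_layout_dict`, which maps parameter names to their sharding layout. A sketch of inspecting the result; any structure beyond "name maps to a layout record" is an assumption here:

.. code-block:: python

    layout_dict = model.infer_train_layout(dataset)
    for param_name, layout in layout_dict.items():
        # Each entry describes how one parameter is sliced across devices.
        print(param_name, layout)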
@@ -2158,11 +2211,13 @@ class Model:
             train_network.compile(*inputs)
             break
         train_dataset.__model_hash__ = hash(self)
+        _clear_auto_parallel_context(self._network)
         return train_network.parameter_layout_dict
 
     def infer_predict_layout(self, *predict_data, skip_backend_compile=False):
         """
-        Generate parameter layout for the predict network in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
+        Generate parameter layout for the predict network when using `AutoParallel(cell)`
+        to enable parallel mode.
 
         Data could be a single tensor or multiple tensors.
 
@@ -2185,21 +2240,47 @@ class Model:
             RuntimeError: If not in GRAPH_MODE.
 
         Examples:
-            >>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
-            >>> # mindspore.cn.
             >>> import numpy as np
-            >>> import mindspore as ms
+            >>> import mindspore.nn as nn
             >>> from mindspore import Tensor
             >>> from mindspore.train import Model
+            >>> from mindspore.ops import operations as P
+            >>> from mindspore import context
             >>> from mindspore.communication import init
+            >>> from mindspore.parallel.auto_parallel import AutoParallel
+            >>>
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.fc1 = nn.Dense(128, 768, activation='relu')
+            ...         self.fc2 = nn.Dense(128, 768, activation='relu')
+            ...         self.fc3 = nn.Dense(128, 768, activation='relu')
+            ...         self.fc4 = nn.Dense(768, 768, activation='relu')
+            ...         self.relu4 = nn.ReLU()
+            ...         self.relu5 = nn.ReLU()
+            ...         self.transpose = P.Transpose()
+            ...         self.matmul1 = P.MatMul()
+            ...         self.matmul2 = P.MatMul()
+            ...
+            ...     def construct(self, x):
+            ...         q = self.fc1(x)
+            ...         k = self.fc2(x)
+            ...         v = self.fc3(x)
+            ...         k = self.transpose(k, (1, 0))
+            ...         c = self.relu4(self.matmul1(q, k))
+            ...         s = self.relu5(self.matmul2(c, v))
+            ...         s = self.fc4(s)
+            ...         return s
             >>>
             >>> ms.set_context(mode=ms.GRAPH_MODE)
             >>> init()
-            >>> ms.set_auto_parallel_context(full_batch=True, parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
-            >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
-            >>> model = Model(Net())
-            >>> predict_map = model.infer_predict_layout(input_data)
+            >>> inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+            >>> net = Net()
+            >>> parallel_net = AutoParallel(net, parallel_mode='semi_auto')
+            >>> model = Model(parallel_net)
+            >>> predict_map = model.infer_predict_layout(inputs)
         """
+        _init_auto_parallel_context(self._network)
         if context.get_context("mode") != context.GRAPH_MODE:
             raise RuntimeError("Pre-compile process that generate parameter layout for the predict network "
                                "only supports GRAPH MODE and Ascend target currently.")
@@ -2219,6 +2300,7 @@ class Model:
             predict_net.phase = origin_phase
         else:
             predict_net.compile(*predict_data)
+        _clear_auto_parallel_context(self._network)
         return predict_net.parameter_layout_dict
 
     def _flush_from_cache(self, cb_params):
@@ -2258,16 +2340,5 @@ class Model:
         """
         return self._eval_network
 
-    def _prepare_obf_dataset(self, dataset):
-        if not hasattr(self._network, 'obf_ratios'):
-            return dataset
-        data_size = dataset.get_dataset_size()
-        obf_ratio_dataset = []
-        for _ in range(data_size):
-            obf_ratio_dataset.append(self._network.obf_ratios)
-        obf_ratio_dataset = ds.NumpySlicesDataset(data=obf_ratio_dataset, column_names=["y_obf"])
-        dataset = ds.zip((dataset, obf_ratio_dataset))
-        return dataset
-
 
 __all__ = ["Model"]