mindspore-2.5.0-cp311-cp311-win_amd64.whl → mindspore-2.6.0-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (493)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +25 -194
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +109 -75
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +2014 -3386
  46. mindspore/common/api.py +386 -355
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/generator.py +3 -0
  52. mindspore/common/hook_handle.py +5 -3
  53. mindspore/common/initializer.py +10 -6
  54. mindspore/common/jit_begin_end.py +94 -0
  55. mindspore/common/jit_config.py +6 -1
  56. mindspore/common/jit_context.py +76 -0
  57. mindspore/common/jit_trace.py +378 -0
  58. mindspore/common/lazy_inline.py +2 -2
  59. mindspore/common/mutable.py +5 -4
  60. mindspore/common/parameter.py +106 -39
  61. mindspore/common/seed.py +2 -2
  62. mindspore/common/sparse_tensor.py +23 -17
  63. mindspore/common/tensor.py +332 -714
  64. mindspore/communication/__init__.py +7 -5
  65. mindspore/communication/_comm_helper.py +47 -2
  66. mindspore/communication/comm_func.py +70 -53
  67. mindspore/communication/management.py +83 -17
  68. mindspore/context.py +228 -571
  69. mindspore/dataset/__init__.py +44 -20
  70. mindspore/dataset/audio/__init__.py +2 -8
  71. mindspore/dataset/audio/transforms.py +3 -17
  72. mindspore/dataset/core/config.py +3 -3
  73. mindspore/dataset/engine/cache_client.py +1 -1
  74. mindspore/dataset/engine/datasets.py +102 -120
  75. mindspore/dataset/engine/datasets_audio.py +22 -22
  76. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  77. mindspore/dataset/engine/datasets_text.py +78 -85
  78. mindspore/dataset/engine/datasets_user_defined.py +109 -77
  79. mindspore/dataset/engine/datasets_vision.py +111 -108
  80. mindspore/dataset/engine/iterators.py +5 -3
  81. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  82. mindspore/dataset/engine/samplers.py +279 -57
  83. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  84. mindspore/dataset/engine/validators.py +10 -0
  85. mindspore/dataset/text/__init__.py +7 -6
  86. mindspore/dataset/text/transforms.py +6 -5
  87. mindspore/dataset/text/utils.py +3 -3
  88. mindspore/dataset/transforms/__init__.py +0 -9
  89. mindspore/dataset/transforms/transforms.py +3 -3
  90. mindspore/dataset/utils/browse_dataset.py +1 -1
  91. mindspore/dataset/vision/__init__.py +2 -9
  92. mindspore/dataset/vision/transforms.py +202 -158
  93. mindspore/dataset/vision/utils.py +7 -5
  94. mindspore/device_context/ascend/op_debug.py +60 -1
  95. mindspore/device_context/ascend/op_tuning.py +0 -4
  96. mindspore/device_manager.py +39 -3
  97. mindspore/dnnl.dll +0 -0
  98. mindspore/dpcmi.dll +0 -0
  99. mindspore/experimental/es/embedding_service.py +35 -27
  100. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
  101. mindspore/experimental/map_parameter.py +4 -4
  102. mindspore/experimental/optim/adadelta.py +22 -26
  103. mindspore/experimental/optim/adagrad.py +4 -4
  104. mindspore/experimental/optim/adam.py +4 -0
  105. mindspore/experimental/optim/adamax.py +4 -4
  106. mindspore/experimental/optim/adamw.py +4 -0
  107. mindspore/experimental/optim/asgd.py +1 -1
  108. mindspore/experimental/optim/lr_scheduler.py +40 -22
  109. mindspore/experimental/optim/radam.py +5 -5
  110. mindspore/experimental/optim/rprop.py +1 -1
  111. mindspore/experimental/optim/sgd.py +1 -1
  112. mindspore/hal/contiguous_tensors_handle.py +6 -10
  113. mindspore/hal/device.py +55 -81
  114. mindspore/hal/event.py +38 -55
  115. mindspore/hal/memory.py +115 -147
  116. mindspore/hal/stream.py +81 -125
  117. mindspore/include/dataset/constants.h +7 -4
  118. mindspore/include/dataset/execute.h +2 -2
  119. mindspore/jpeg62.dll +0 -0
  120. mindspore/log.py +40 -2
  121. mindspore/mindrecord/__init__.py +20 -7
  122. mindspore/mindspore_backend_common.dll +0 -0
  123. mindspore/mindspore_backend_manager.dll +0 -0
  124. mindspore/mindspore_common.dll +0 -0
  125. mindspore/mindspore_core.dll +0 -0
  126. mindspore/mindspore_dump.dll +0 -0
  127. mindspore/mindspore_frontend.dll +0 -0
  128. mindspore/mindspore_glog.dll +0 -0
  129. mindspore/mindspore_memory_pool.dll +0 -0
  130. mindspore/mindspore_ms_backend.dll +0 -0
  131. mindspore/mindspore_ops.dll +0 -0
  132. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  133. mindspore/mindspore_ops_kernel_common.dll +0 -0
  134. mindspore/mindspore_profiler.dll +0 -0
  135. mindspore/mindspore_pyboost.dll +0 -0
  136. mindspore/mindspore_pynative.dll +0 -0
  137. mindspore/mindspore_res_manager.dll +0 -0
  138. mindspore/mindspore_runtime_pipeline.dll +0 -0
  139. mindspore/mint/__init__.py +133 -702
  140. mindspore/mint/distributed/__init__.py +5 -1
  141. mindspore/mint/distributed/distributed.py +198 -113
  142. mindspore/mint/linalg/__init__.py +2 -0
  143. mindspore/mint/nn/__init__.py +280 -18
  144. mindspore/mint/nn/functional.py +282 -64
  145. mindspore/mint/nn/layer/__init__.py +4 -0
  146. mindspore/mint/nn/layer/_functions.py +7 -3
  147. mindspore/mint/nn/layer/activation.py +120 -13
  148. mindspore/mint/nn/layer/conv.py +234 -28
  149. mindspore/mint/nn/layer/normalization.py +15 -16
  150. mindspore/mint/nn/layer/padding.py +1 -1
  151. mindspore/mint/nn/layer/pooling.py +66 -1
  152. mindspore/mint/optim/__init__.py +2 -1
  153. mindspore/mint/optim/sgd.py +171 -0
  154. mindspore/msobj140.dll +0 -0
  155. mindspore/mspdb140.dll +0 -0
  156. mindspore/mspdbcore.dll +0 -0
  157. mindspore/mspdbst.dll +0 -0
  158. mindspore/mspft140.dll +0 -0
  159. mindspore/msvcdis140.dll +0 -0
  160. mindspore/msvcp140_1.dll +0 -0
  161. mindspore/msvcp140_2.dll +0 -0
  162. mindspore/msvcp140_atomic_wait.dll +0 -0
  163. mindspore/msvcp140_codecvt_ids.dll +0 -0
  164. mindspore/nn/__init__.py +4 -1
  165. mindspore/nn/cell.py +1253 -179
  166. mindspore/nn/layer/activation.py +23 -21
  167. mindspore/nn/layer/basic.py +22 -16
  168. mindspore/nn/layer/container.py +1 -1
  169. mindspore/nn/layer/conv.py +53 -42
  170. mindspore/nn/layer/embedding.py +9 -8
  171. mindspore/nn/layer/normalization.py +48 -42
  172. mindspore/nn/layer/pooling.py +75 -31
  173. mindspore/nn/layer/transformer.py +11 -10
  174. mindspore/nn/learning_rate_schedule.py +4 -2
  175. mindspore/nn/loss/loss.py +27 -19
  176. mindspore/nn/optim/ada_grad.py +6 -5
  177. mindspore/nn/optim/adadelta.py +9 -7
  178. mindspore/nn/optim/adafactor.py +1 -1
  179. mindspore/nn/optim/adam.py +18 -14
  180. mindspore/nn/optim/adamax.py +8 -7
  181. mindspore/nn/optim/adasum.py +5 -5
  182. mindspore/nn/optim/asgd.py +3 -1
  183. mindspore/nn/optim/ftrl.py +11 -9
  184. mindspore/nn/optim/lamb.py +1 -1
  185. mindspore/nn/optim/lazyadam.py +12 -10
  186. mindspore/nn/optim/momentum.py +7 -6
  187. mindspore/nn/optim/optimizer.py +2 -2
  188. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  189. mindspore/nn/optim/rmsprop.py +13 -12
  190. mindspore/nn/optim/rprop.py +9 -7
  191. mindspore/nn/optim/sgd.py +9 -6
  192. mindspore/nn/optim/tft_wrapper.py +5 -2
  193. mindspore/nn/probability/bijector/bijector.py +17 -11
  194. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  195. mindspore/nn/probability/bijector/invert.py +2 -2
  196. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  197. mindspore/nn/probability/bijector/softplus.py +3 -2
  198. mindspore/nn/probability/distribution/beta.py +3 -3
  199. mindspore/nn/probability/distribution/categorical.py +1 -1
  200. mindspore/nn/probability/distribution/cauchy.py +4 -2
  201. mindspore/nn/probability/distribution/exponential.py +6 -7
  202. mindspore/nn/probability/distribution/gamma.py +2 -2
  203. mindspore/nn/probability/distribution/gumbel.py +2 -2
  204. mindspore/nn/probability/distribution/half_normal.py +5 -3
  205. mindspore/nn/probability/distribution/logistic.py +5 -3
  206. mindspore/nn/probability/distribution/poisson.py +1 -1
  207. mindspore/nn/probability/distribution/uniform.py +5 -3
  208. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  209. mindspore/nn/reinforcement/tensor_array.py +1 -1
  210. mindspore/nn/wrap/__init__.py +6 -6
  211. mindspore/nn/wrap/cell_wrapper.py +178 -117
  212. mindspore/nn/wrap/grad_reducer.py +45 -36
  213. mindspore/nn/wrap/loss_scale.py +3 -3
  214. mindspore/numpy/array_creations.py +3 -3
  215. mindspore/numpy/array_ops.py +1 -1
  216. mindspore/numpy/utils.py +1 -2
  217. mindspore/numpy/utils_const.py +1 -2
  218. mindspore/opencv_core452.dll +0 -0
  219. mindspore/opencv_imgcodecs452.dll +0 -0
  220. mindspore/opencv_imgproc452.dll +0 -0
  221. mindspore/ops/__init__.py +3 -2
  222. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  223. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  224. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  225. mindspore/ops/_register_for_op.py +0 -11
  226. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  227. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  228. mindspore/ops/_vmap/vmap_array_ops.py +32 -6
  229. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  230. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  231. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  232. mindspore/ops/auto_generate/__init__.py +4 -3
  233. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
  234. mindspore/ops/auto_generate/gen_extend_func.py +286 -208
  235. mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
  236. mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
  237. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  238. mindspore/ops/composite/__init__.py +2 -1
  239. mindspore/ops/composite/base.py +19 -24
  240. mindspore/ops/composite/math_ops.py +6 -16
  241. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  242. mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
  243. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  244. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  248. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  249. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  250. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  251. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  252. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  254. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  255. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  256. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  257. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  259. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  260. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  263. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  264. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  267. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  268. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  271. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  272. mindspore/ops/function/__init__.py +28 -2
  273. mindspore/ops/function/_add_attr_func.py +58 -0
  274. mindspore/ops/function/array_func.py +1631 -2347
  275. mindspore/ops/function/clip_func.py +38 -45
  276. mindspore/ops/function/debug_func.py +36 -44
  277. mindspore/ops/function/grad/__init__.py +1 -0
  278. mindspore/ops/function/grad/grad_func.py +104 -71
  279. mindspore/ops/function/image_func.py +1 -1
  280. mindspore/ops/function/linalg_func.py +46 -78
  281. mindspore/ops/function/math_func.py +3024 -3855
  282. mindspore/ops/function/nn_func.py +678 -274
  283. mindspore/ops/function/other_func.py +159 -1
  284. mindspore/ops/function/parameter_func.py +17 -30
  285. mindspore/ops/function/random_func.py +216 -361
  286. mindspore/ops/function/reshard_func.py +4 -70
  287. mindspore/ops/function/sparse_func.py +3 -3
  288. mindspore/ops/function/sparse_unary_func.py +5 -5
  289. mindspore/ops/function/spectral_func.py +25 -58
  290. mindspore/ops/function/vmap_func.py +26 -18
  291. mindspore/ops/functional.py +8 -5
  292. mindspore/ops/functional_overload.py +655 -4
  293. mindspore/ops/op_info_register.py +32 -244
  294. mindspore/ops/operations/__init__.py +21 -14
  295. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  296. mindspore/ops/operations/_grad_ops.py +1 -10
  297. mindspore/ops/operations/_inner_ops.py +5 -76
  298. mindspore/ops/operations/_ms_kernel.py +4 -10
  299. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  300. mindspore/ops/operations/_scalar_ops.py +3 -2
  301. mindspore/ops/operations/_sequence_ops.py +1 -1
  302. mindspore/ops/operations/_tensor_array.py +1 -1
  303. mindspore/ops/operations/array_ops.py +39 -24
  304. mindspore/ops/operations/comm_ops.py +150 -107
  305. mindspore/ops/operations/custom_ops.py +287 -32
  306. mindspore/ops/operations/debug_ops.py +119 -16
  307. mindspore/ops/operations/inner_ops.py +1 -1
  308. mindspore/ops/operations/linalg_ops.py +1 -58
  309. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  310. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  311. mindspore/ops/operations/math_ops.py +21 -18
  312. mindspore/ops/operations/nn_ops.py +67 -224
  313. mindspore/ops/operations/other_ops.py +62 -9
  314. mindspore/ops/operations/random_ops.py +13 -7
  315. mindspore/ops/operations/reshard_ops.py +1 -1
  316. mindspore/ops/operations/sparse_ops.py +2 -2
  317. mindspore/ops/primitive.py +43 -32
  318. mindspore/ops/tensor_method.py +243 -17
  319. mindspore/ops_generate/__init__.py +0 -5
  320. mindspore/ops_generate/aclnn/__init__.py +0 -0
  321. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  322. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  323. mindspore/ops_generate/api/__init__.py +0 -0
  324. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  325. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  326. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  327. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  328. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  329. mindspore/ops_generate/api/gen_api.py +103 -0
  330. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  331. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  332. mindspore/ops_generate/common/__init__.py +0 -0
  333. mindspore/ops_generate/common/gen_constants.py +91 -0
  334. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  335. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  336. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  337. mindspore/ops_generate/gen_ops.py +23 -325
  338. mindspore/ops_generate/op_def/__init__.py +0 -0
  339. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  340. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  341. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
  342. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  343. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  344. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  345. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  346. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  347. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  348. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  349. mindspore/ops_generate/pyboost/__init__.py +0 -0
  350. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  351. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  352. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  353. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  354. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  355. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  356. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  357. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  358. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  359. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  360. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  361. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  362. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  363. mindspore/ops_generate/resources/__init__.py +0 -0
  364. mindspore/ops_generate/resources/resource_list.py +30 -0
  365. mindspore/ops_generate/resources/resource_loader.py +36 -0
  366. mindspore/ops_generate/resources/resource_manager.py +64 -0
  367. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  368. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  369. mindspore/parallel/__init__.py +6 -2
  370. mindspore/parallel/_auto_parallel_context.py +140 -12
  371. mindspore/parallel/_cell_wrapper.py +132 -15
  372. mindspore/parallel/_parallel_serialization.py +95 -4
  373. mindspore/parallel/_ps_context.py +1 -1
  374. mindspore/parallel/_recovery_context.py +7 -2
  375. mindspore/parallel/_tensor.py +142 -18
  376. mindspore/parallel/_utils.py +198 -25
  377. mindspore/parallel/algo_parameter_config.py +3 -3
  378. mindspore/parallel/auto_parallel.py +732 -0
  379. mindspore/parallel/checkpoint_convert.py +159 -0
  380. mindspore/parallel/checkpoint_transform.py +658 -37
  381. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  382. mindspore/parallel/cluster/run.py +1 -1
  383. mindspore/parallel/function/__init__.py +24 -0
  384. mindspore/parallel/function/reshard_func.py +258 -0
  385. mindspore/parallel/nn/__init__.py +25 -0
  386. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  387. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  388. mindspore/parallel/parameter_broadcast.py +24 -13
  389. mindspore/parallel/shard.py +137 -62
  390. mindspore/parallel/transform_safetensors.py +288 -95
  391. mindspore/pgodb140.dll +0 -0
  392. mindspore/pgort140.dll +0 -0
  393. mindspore/profiler/__init__.py +9 -5
  394. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  395. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  397. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
  398. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  399. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  400. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  401. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  402. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  403. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  404. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  405. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  406. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  407. mindspore/profiler/common/constant.py +12 -0
  408. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  409. mindspore/profiler/common/path_manager.py +24 -0
  410. mindspore/profiler/common/profiler_context.py +26 -2
  411. mindspore/profiler/common/profiler_meta_data.py +74 -0
  412. mindspore/profiler/common/profiler_parameters.py +59 -18
  413. mindspore/profiler/common/profiler_path_manager.py +66 -7
  414. mindspore/profiler/dynamic_profiler.py +112 -79
  415. mindspore/profiler/envprofiler.py +26 -1
  416. mindspore/profiler/experimental_config.py +197 -0
  417. mindspore/profiler/mstx.py +57 -14
  418. mindspore/profiler/platform/npu_profiler.py +33 -7
  419. mindspore/profiler/profiler.py +541 -45
  420. mindspore/profiler/profiler_action_controller.py +1 -1
  421. mindspore/profiler/profiler_interface.py +4 -0
  422. mindspore/profiler/schedule.py +57 -22
  423. mindspore/rewrite/api/node.py +15 -13
  424. mindspore/rewrite/api/symbol_tree.py +1 -1
  425. mindspore/run_check/_check_version.py +25 -14
  426. mindspore/run_check/run_check.py +1 -1
  427. mindspore/runtime/__init__.py +2 -2
  428. mindspore/runtime/executor.py +40 -11
  429. mindspore/runtime/memory.py +37 -13
  430. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  431. mindspore/swresample-4.dll +0 -0
  432. mindspore/swscale-6.dll +0 -0
  433. mindspore/tbbmalloc.dll +0 -0
  434. mindspore/tinyxml2.dll +0 -0
  435. mindspore/train/__init__.py +8 -8
  436. mindspore/train/_utils.py +43 -9
  437. mindspore/train/amp.py +1 -1
  438. mindspore/train/callback/__init__.py +2 -2
  439. mindspore/train/callback/_callback.py +2 -16
  440. mindspore/train/callback/_checkpoint.py +24 -40
  441. mindspore/train/callback/_cluster_monitor.py +14 -18
  442. mindspore/train/callback/_flops_collector.py +2 -3
  443. mindspore/train/callback/_history.py +7 -4
  444. mindspore/train/callback/_lambda_callback.py +2 -2
  445. mindspore/train/callback/_landscape.py +0 -3
  446. mindspore/train/callback/_loss_monitor.py +2 -1
  447. mindspore/train/callback/_on_request_exit.py +6 -5
  448. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  449. mindspore/train/callback/_summary_collector.py +8 -13
  450. mindspore/train/callback/_time_monitor.py +2 -1
  451. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
  452. mindspore/train/data_sink.py +25 -2
  453. mindspore/train/dataset_helper.py +4 -5
  454. mindspore/train/loss_scale_manager.py +8 -7
  455. mindspore/train/metrics/accuracy.py +3 -3
  456. mindspore/train/metrics/confusion_matrix.py +9 -9
  457. mindspore/train/metrics/error.py +3 -3
  458. mindspore/train/metrics/hausdorff_distance.py +4 -4
  459. mindspore/train/metrics/mean_surface_distance.py +3 -3
  460. mindspore/train/metrics/metric.py +0 -12
  461. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  462. mindspore/train/metrics/precision.py +8 -6
  463. mindspore/train/metrics/recall.py +9 -9
  464. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  465. mindspore/train/mind_ir_pb2.py +19 -12
  466. mindspore/train/model.py +262 -127
  467. mindspore/train/serialization.py +246 -988
  468. mindspore/train/summary/_summary_adapter.py +2 -2
  469. mindspore/train/summary/summary_record.py +1 -1
  470. mindspore/turbojpeg.dll +0 -0
  471. mindspore/utils/__init__.py +3 -2
  472. mindspore/utils/dryrun.py +4 -2
  473. mindspore/utils/hooks.py +81 -0
  474. mindspore/utils/runtime_execution_order_check.py +2 -0
  475. mindspore/utils/utils.py +138 -4
  476. mindspore/vcmeta.dll +0 -0
  477. mindspore/vcruntime140.dll +0 -0
  478. mindspore/vcruntime140_1.dll +0 -0
  479. mindspore/version.py +1 -1
  480. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
  481. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
  482. mindspore/_install_custom.py +0 -43
  483. mindspore/common/_register_for_adapter.py +0 -74
  484. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  485. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  486. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  487. mindspore/ops_generate/gen_constants.py +0 -190
  488. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  489. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  490. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  492. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
  493. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/train/model.py CHANGED
@@ -27,7 +27,6 @@ import time
 import numpy as np
 
 import mindspore
-import mindspore.dataset as ds
 from mindspore import log as logger
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
@@ -36,7 +35,7 @@ from mindspore.train.metrics import get_metrics, get_metric_fn
 from mindspore._checkparam import check_input_data, check_output_data
 from mindspore import _checkparam as Validator
 from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor,\
-    TFTRegister
+    TrainFaultTolerance
 from mindspore.train.callback import __all__ as internal_cb_names
 from mindspore.train.callback._cluster_monitor import ClusterMonitor
 from mindspore import context
@@ -57,7 +56,11 @@ from mindspore.dataset.core.config import get_debug_mode
 from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
 from mindspore.train import amp
 from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
+from mindspore._c_expression import _get_optimzer_timestamps
+from mindspore._c_expression import clean_tdt_channel
 
+from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
+from .serialization import load_param_into_net
 
 def _transfer_tensor_to_tuple(inputs):
     """
@@ -91,6 +94,7 @@ def _save_final_ckpt(func):
     """
     Decorator function, which saves the current checkpoint when an exception occurs during training.
     """
+
     @wraps(func)
     def wrapper(self, *args, **kwargs):
         obj = None
@@ -107,7 +111,7 @@ def _save_final_ckpt(func):
                 # pylint: disable=W0212
                 prefix = _chg_ckpt_file_name_if_same_exist(obj._directory, obj._exception_prefix, True)
                 cur_ckpoint_file = prefix + "-" + str(self._current_epoch_num) + "_" \
-                    + str(self._current_step_num) + "_breakpoint.ckpt"
+                                   + str(self._current_step_num) + "_breakpoint.ckpt"
                 cur_file = os.path.join(obj._directory, cur_ckpoint_file)
                 if "epoch_num" in obj._append_dict:
                     obj._append_dict["epoch_num"] = obj._append_epoch_num + self._current_epoch_num
@@ -118,88 +122,172 @@ def _save_final_ckpt(func):
                 raise e
         else:
             func(self, *args, **kwargs)
+
     return wrapper
 
+
+def _handle_exception_info(obj, uce_env, tft, e):
+    """handle exception info"""
+    logger.info("uce wrapper caught RuntimeError")
+    if not uce_env:
+        logger.error("uce wrapper caught RuntimeError but uce not enable, enter MindIO TTP process.",
+                     exc_info=True)
+        if tft:
+            tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+        raise e
+    e_str = str(e)
+    logger.warning("uce wrapper caught RuntimeError e_str:{}".format(e_str))
+    if "UCEError" in e_str:
+        logger.info("uce wrapper report UCEError")
+        obj.is_uce_rank = True
+        # if error is HBM_MULTI_BIT_ECC_ERROR
+        if "error_code=507054" in e_str:
+            hbm_error_time, optimize_start, optimizer_end = _get_optimzer_timestamps()
+            can_repair = tft.tft_can_do_uce_repair(hbm_error_time, optimize_start, optimizer_end)
+            logger.info(f"UCEError of type HBM_MULTI_BIT_ECC_ERROR occurs, \
+                hbm_error_time={hbm_error_time}, optimize_start={optimize_start}, \
+                optimizer_end={optimizer_end}, can_repair={can_repair}")
+            if not can_repair:
+                logger.error(f"Caught UCEError of type HBM_MULTI_BIT_ECC_ERROR but can not repair, "
+                             f"hbm_error_time={hbm_error_time}, optimize_start={optimize_start}, "
+                             f"optimizer_end={optimizer_end}", exc_info=True)
+                tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+                raise e
+        tft.tft_report_error(tft.ReportState.RS_UCE.value)
+    elif "ForceStopError" in e_str:
+        logger.warning("uce wrapper caught RuntimeError ForceStopError")
+        force_stop_err = tft.ReportState.RS_NORMAL.value
+        tft.tft_report_error(force_stop_err)
+    elif "ARF FINISH" in e_str:
+        logger.warning(f"ARF FINISH")
+        _set_recovery_context(is_arf=True)
+        tft.tft_report_error(tft.ReportState.RS_PREREPAIR_FINISH.value)
+    else:
+        logger.error("uce wrapper caught other RuntimeError, enter MindIO TTP process.", exc_info=True)
+        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+        raise e
+
+
+def _handle_training_result_error(model, tft_obj):
+    """
+    Handle training result error for resuming training.
+    """
+    ckpt_load_fn = tft_obj.ckpt_load_func
+    train_network = tft_obj.cb_params.train_network
+    logger.warning("Process training result error start.")
+    # 1. Clear tdt channel
+    logger.warning("Clean tdt channel.")
+    clean_tdt_channel()
+
+    # 2. Load checkpoint
+    logger.warning("Load checkpoint.")
+    new_param_dict, remove_redundancy = ckpt_load_fn()
+    param_not_load, ckpt_not_load = load_param_into_net(train_network, new_param_dict, True, remove_redundancy)
+    logger.warning(f"param_not_load: {param_not_load}")
+    logger.warning(f"ckpt_not_load: {ckpt_not_load}")
+    resume_epoch = new_param_dict.get('epoch_num')
+    resume_step = new_param_dict.get('step_num')
+    model._initial_step = int(resume_step.asnumpy())
+    logger.warning("Process training result error end.")
+    return (resume_epoch, resume_step)
+
+
+def _calc_cb_initial_step(org_epoch, org_step, *args, **kwargs):
+    """calculate initial step for callback"""
+    train_dataset = args[1]
+    dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
+    sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
+
+    cb_initial_step = 0
+    if dataset_sink_mode:
+        train_dataset.set_init_step(org_epoch)
+        dataset_size = train_dataset.get_dataset_size()
+        if sink_size != -1:
+            cb_initial_step = org_epoch * sink_size + org_step
+        else:
+            cb_initial_step = org_epoch * dataset_size + org_step
+    else:
+        train_dataset.set_init_step(org_step)
+        cb_initial_step = org_step
+        if hasattr(train_dataset, '_dataset_helper'):
+            dataset_helper = train_dataset._dataset_helper
+            _reset_training_dataset(cb_initial_step, dataset_helper.iter.dataset.get_dataset_size())
+    return cb_initial_step
+
+
+def _update_ckpt_callback_info(resume_train_step, **kwargs):
+    """
+    Update checkpoint callback internal state
+    """
+    ckpt_obj = None
+    if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), ModelCheckpoint):
+        ckpt_obj = kwargs.get('callbacks')
+    if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
+        for item in kwargs.get('callbacks'):
+            if isinstance(item, ModelCheckpoint):
+                ckpt_obj = item
+    if ckpt_obj is not None:
+        ckpt_obj._last_triggered_step = 0
+        ckpt_obj._append_step_num = resume_train_step
+
+
 def _handle_tft(func):
     """
     Decorator function, which starts uce handle process when an exception occurs during training.
     """
+
     @wraps(func)
     def wrapper(self, *args, **kwargs):
         obj = None
-        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TFTRegister):
+        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TrainFaultTolerance):
             obj = kwargs.get('callbacks')
         if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
             for item in kwargs.get('callbacks'):
-                if isinstance(item, TFTRegister):
+                if isinstance(item, TrainFaultTolerance):
                     obj = item
         if obj:
             tft = obj.tft
             tft_env = os.getenv("MS_ENABLE_TFT", "")
-            uce_env = "UCE:1" in tft_env
+            uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env
+            tre_env = "TRE:1" in tft_env
             while True:
                 try:
                     return func(self, *args, **kwargs)
                 except RuntimeError as e:
-                    logger.info("uce wrapper caught RuntimeError")
-                    if not uce_env:
-                        logger.error("uce wrapper caught RuntimeError but uce not enable, enter MindIO TTP process.",
-                                     exc_info=True)
-                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
-                        raise e
-                    e_str = str(e)
-                    logger.info("uce wrapper caught RuntimeError e_str:{}".format(e_str))
-                    if "UCEError" in e_str:
-                        logger.info("uce wrapper report UCEError")
-                        obj.is_uce_rank = True
-                        tft.tft_report_error(tft.ReportState.RS_UCE.value)
-                    elif "ForceStopError" in e_str:
-                        logger.info("uce wrapper caught RuntimeError ForceStopError")
-                        force_stop_err = tft.ReportState.RS_NORMAL.value
-                        tft.tft_report_error(force_stop_err)
+                    if tre_env and 'TREError' in str(e):
+                        _, resume_step = _handle_training_result_error(self, obj)
+                        repair_step = int(resume_step.asnumpy())
+                        _update_ckpt_callback_info(repair_step, **kwargs)
+                        logger.warning(f'Resume training after TREError from step {repair_step}.')
                     else:
-                        logger.error("uce wrapper caught other RuntimeError, enter MindIO TTP process.", exc_info=True)
-                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
-                        raise e
-                    ret = tft.tft_wait_next_action()
-                    if ret == tft.Action.EXIT.value:
-                        raise e
-                    repair_step = tft.tft_get_repair_step()
-                    logger.info("uce wrapper caught repair finish REPAIR STEP: {} batch_num: \
-                        {}".format(repair_step, self.batch_num))
-                    initial_epoch = int(repair_step/self.batch_num)
+                        _handle_exception_info(obj, uce_env, tft, e)
+                        ret = tft.tft_wait_next_action()
+                        if ret == tft.Action.EXIT.value:
+                            raise e
+                        repair_step = tft.tft_get_repair_step()
+                        logger.warning(
+                            "uce wrapper caught repair finish REPAIR STEP: {} batch_num:{}".format(repair_step,
+                                                                                                   self.batch_num))
+                        initial_epoch = int(repair_step / self.batch_num)
                     initial_step = repair_step % self.batch_num
                     kwargs["initial_epoch"] = initial_epoch
-
-                    train_dataset = args[1]
-                    dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
-                    sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
-
-                    cb_initial_step = 0
-                    if dataset_sink_mode:
-                        train_dataset.set_init_step(initial_epoch)
-                        dataset_size = train_dataset.get_dataset_size()
-                        if sink_size != -1:
-                            cb_initial_step = initial_epoch * sink_size + initial_step
-                        else:
-                            cb_initial_step = initial_epoch * dataset_size + initial_step
-                    else:
-                        train_dataset.set_init_step(initial_step)
-                        cb_initial_step = initial_step
-
-                    kwargs["initial_step"] = cb_initial_step
+                    cb_initial_step = _calc_cb_initial_step(initial_epoch, initial_step, *args, **kwargs)
+                    if not self.enable_tre:
+                        kwargs["initial_step"] = cb_initial_step
                     # reset all accu grads to zero
                     obj._reset_acc_grads()
-
-                    logger.info("uce wrapper repair complete \
-                        initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch, cb_initial_step))
+                    logger.warning(
+                        "uce wrapper repair complete initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch,
+                                                                                                     cb_initial_step))
                     continue
                 except BaseException as e:
-                    logger.error("uce wrapper caught BaseException error, enter MindIO TTP process.", exc_info=True)
-                    tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+                    if tft:
+                        logger.error("uce wrapper caught BaseException error, enter MindIO TTP process.", exc_info=True)
+                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
                     raise e
         else:
             return func(self, *args, **kwargs)
+
     return wrapper
 
 
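The reworked `_handle_tft` wrapper above gates recovery on plain substring checks against the `MS_ENABLE_TFT` environment variable and, with the new `_calc_cb_initial_step` helper, converts the repair step reported by MindIO back into epoch/step offsets. A minimal standalone sketch of that arithmetic (illustrative names only, not MindSpore code; `batch_num` stands in for the per-epoch dataset size used by the real helper):

    import os

    def parse_tft_flags(tft_env=None):
        # Mirrors the wrapper's substring checks, e.g. MS_ENABLE_TFT containing "UCE:1" or "TRE:1".
        tft_env = os.getenv("MS_ENABLE_TFT", "") if tft_env is None else tft_env
        return {"uce_or_arf": "UCE:1" in tft_env or "ARF:1" in tft_env,
                "tre": "TRE:1" in tft_env}

    def repair_offsets(repair_step, batch_num, sink_size=-1, dataset_sink_mode=True):
        # Split the global repair step into an epoch/step pair, then rebuild the
        # callback-visible initial step the same way _calc_cb_initial_step does.
        initial_epoch = repair_step // batch_num
        initial_step = repair_step % batch_num
        if dataset_sink_mode:
            per_epoch = sink_size if sink_size != -1 else batch_num
            cb_initial_step = initial_epoch * per_epoch + initial_step
        else:
            cb_initial_step = initial_step
        return initial_epoch, initial_step, cb_initial_step

    # Example: repair step 2500 with 1000 batches per epoch resumes at epoch 2, step 500.
    assert repair_offsets(2500, 1000) == (2, 500, 2500)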
@@ -216,7 +304,7 @@ def _check_tft():
     if ms_mode != mindspore.GRAPH_MODE:
         raise ValueError("TFT is only supported in GRAPH_MODE")
     jit_level = context.get_context("jit_level")
-    if jit_level == "O2" and "UCE:1" in tft_env:
+    if jit_level == "O2" and ("UCE:1" in tft_env or "ARF:1" in tft_env):
         raise ValueError("TFT is not supported when using jit_level == O2")
 
 
@@ -406,12 +494,13 @@ class Model:
               the accuracy is reduced by less than 3%.
 
         If you want to config boost mode by yourself, you can set boost_config_dict as `boost.py`.
-        In order for this function to work, you need to set the optimizer, eval_network or metric parameters
-        at the same time.
+        In order for this function to work, you need to set the parameter `optimizer`, along with
+        at least one of the parameter `eval_network` or performance `metrics`.
 
         Notice: The current optimization enabled by default only applies to some networks, and not all networks
         can obtain the same benefits. It is recommended to enable this function on
-        the Graph mode + Ascend platform, and for better acceleration, refer to the documentation to configure
+        the Graph mode + Ascend platform, and for better acceleration,
+        refer to :class:`mindspore.boost.AutoBoost` to configure
         boost_config_dict.
 
         Examples:
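As a reference for the boost requirement reworded above, a hedged sketch of a boost-enabled `Model` that supplies both an optimizer and metrics; `LeNet5` and `dataset` are placeholders, as in the class docstring:

    >>> from mindspore import nn
    >>> from mindspore.train import Model
    >>> net = LeNet5()
    >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
    >>> # boost_level="O1" only takes effect because `optimizer` and `metrics` are both supplied.
    >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"}, boost_level="O1")
    >>> model.train(2, dataset)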
@@ -436,6 +525,7 @@ class Model:
     def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None,
                  amp_level="O0", boost_level="O0", **kwargs):
         self._network = network
+        _init_auto_parallel_context(self._network)
         self._loss_fn = loss_fn
         self._optimizer = optimizer
         self._loss_scale_manager = None
@@ -470,6 +560,9 @@ class Model:
         self._lite_infer = True # if backend lite infer fails, set False
         self._mindspore_lite_model_group_id = id(self) & 0xFFFF
         self.batch_num = -1
+        self.enable_tre = "TRE:1" in os.getenv("MS_ENABLE_TFT", "")
+        self._initial_step = None
+        _clear_auto_parallel_context(self._network)
 
     def _check_for_graph_cell(self, kwargs):
         """Check for graph cell"""
@@ -668,7 +761,7 @@ class Model:
         logger.info("Begin to connect network with dataset.")
         network = connect_network_with_dataset(network, dataset_helper)
 
-        if _get_recovery_context("enable_recovery") and is_train:
+        if (_get_recovery_context("enable_recovery") or self.enable_tre) and is_train:
             _set_training_dataset(dataset_helper)
 
         network.set_train(is_train)
@@ -765,7 +858,7 @@ class Model:
                     break
                 logger.warning(f"Waiting for the dataset warmup, current device queue size: {mbuf_size}")
 
-    def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
+    def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True):
         """
         Initialize compute graphs and data graphs with the sink mode.
 
@@ -794,7 +887,6 @@ class Model:
             if not isinstance(train_dataset, mindspore.dataset.Dataset):
                 raise TypeError("The type of 'train_dataset' must be `Dataset`, "
                                 "but got {}.".format(type(train_dataset)))
-
             vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
                        "Begin to check parameter broadcast in model.build().")
             logger.info("Begin to check parameter broadcast in model.build() procedure.")
@@ -807,23 +899,24 @@ class Model:
             train_dataset.__no_send__ = True
             train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
                                                                         dataset=train_dataset,
-                                                                        dataset_sink_mode=True,
+                                                                        dataset_sink_mode=sink_mode,
                                                                         sink_size=sink_size)
             vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to warmup dataset in model.build().")
-            logger.info("Begin to warmup dataset in model.build() procedure.")
-            self._warmup_dataset(epoch, train_dataset, sink_size)
+            if sink_mode:
+                logger.info("Begin to warmup dataset in model.build() procedure.")
+                self._warmup_dataset(epoch, train_dataset, sink_size)
 
-            # Since dataset pipeline has been triggered, delete flag
-            delattr(train_dataset, "__no_send__")
+                # Since dataset pipeline has been triggered, delete flag
+                delattr(train_dataset, "__no_send__")
 
-            # Waiting for the dataset warmup ready
-            vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
-                       "Begin waiting for dataset warmup in model.build().")
-            logger.info("Begin waiting for dataset warmup in model.build() procedure.")
-            self._waiting_for_dataset_warmup_ready(train_dataset)
-            vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
-                       "The dataset warmup was successful in model.build().")
-            logger.info("The dataset warmup was successful in model.build() procedure.")
+                # Waiting for the dataset warmup ready
+                vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                           "Begin waiting for dataset warmup in model.build().")
+                logger.info("Begin waiting for dataset warmup in model.build() procedure.")
+                self._waiting_for_dataset_warmup_ready(train_dataset)
+                vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                           "The dataset warmup was successful in model.build().")
+                logger.info("The dataset warmup was successful in model.build() procedure.")
 
             if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
                 train_network.add_flags_recursive(is_first_iteration=True)
@@ -833,6 +926,7 @@ class Model:
                 logger.info("Begin to compile train network in model.build() procedure.")
                 train_network.compile(*inputs)
                 self._train_network.parameter_layout_dict = train_network.parameter_layout_dict
+                train_dataset.reset()
                 break
 
         if valid_dataset:
@@ -846,7 +940,7 @@ class Model:
             valid_dataset.__no_send__ = True
             valid_dataset_helper, eval_network = self._exec_preprocess(is_train=False,
                                                                        dataset=valid_dataset,
-                                                                       dataset_sink_mode=True)
+                                                                       dataset_sink_mode=sink_mode)
             if context.get_auto_parallel_context("pipeline_stages") > 1:
                 eval_network.add_flags_recursive(is_first_iteration=False)
             for inputs in valid_dataset_helper:
@@ -854,6 +948,7 @@ class Model:
                            "Begin to compile eval network in model.build().")
                 logger.info("Begin to compile eval network in model.build() procedure.")
                 eval_network.compile(*inputs)
+                valid_dataset.reset()
                 break
 
     @staticmethod
@@ -922,6 +1017,8 @@ class Model:
         cb_params.last_save_ckpt_step = None
         cb_params.latest_ckpt_file = None
         cb_params.loss_scale_mananger = self._loss_scale_manager
+        cb_params.is_arf = _get_recovery_context("is_arf")
+        cb_params.initial_step = self._initial_step
 
         # build callback list
         with _CallbackManager(callbacks) as list_callback:
@@ -1026,6 +1123,9 @@ class Model:
                 need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
                 if need_exec_callback_step_end:
                     list_callback.on_train_step_end(run_context)
+                if cb_params.is_arf:
+                    cb_params.is_arf = False
+                    _set_recovery_context(is_arf=False)
 
                 # Embedding cache server only run one step.
                 if is_embedding_cache_server:
@@ -1056,7 +1156,7 @@ class Model:
                 if should_stop:
                     break
 
-            need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset")\
+            need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset") \
                                       and not _get_recovery_context("latest_ckpt_file")
             self.epoch_iter += 1
             if need_reset_to_beginning:
@@ -1100,7 +1200,7 @@ class Model:
        Check whether enable recovery and execution mode consistency.
        """
 
-        enable_recovery = _get_recovery_context("enable_recovery")
+        enable_recovery = _get_recovery_context("enable_recovery") and context.get_context("device_target") == "GPU"
        if not enable_recovery:
            self.enable_recovery = False
        else:
@@ -1117,6 +1217,8 @@ class Model:
            dataset_size (int): The number of batches in a dataset.
            sink_size (int): Control the amount of data in each sink. Default: -1.
        """
+        if context.get_context("device_target") != "GPU":
+            return
        if not self.enable_recovery:
            self.need_load_ckpt = False
 
@@ -1145,7 +1247,7 @@ class Model:
                load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
            except BaseException as e:
                os.remove(cb_params.latest_ckpt_file)
-                raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
+                raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: " \
                                   + cb_params.latest_ckpt_file) from e
            _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
            self.need_load_ckpt = False
@@ -1235,6 +1337,9 @@ class Model:
                    self._loss_scale_manager.update_loss_scale(overflow)
 
                list_callback.on_train_step_end(run_context)
+                if cb_params.is_arf:
+                    cb_params.is_arf = False
+                    _set_recovery_context(is_arf=False)
                # Embedding cache server only run one step.
                if is_embedding_cache_server:
                    break
@@ -1332,10 +1437,9 @@ class Model:
            ...                   loss_scale_manager=loss_scale_manager)
            >>> model.train(2, dataset)
        """
+        _init_auto_parallel_context(self._network)
        _check_tft()
        device_target = context.get_context("device_target")
-        # prepare dataset for obfuscated model
-        train_dataset = self._prepare_obf_dataset(train_dataset)
        if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
            logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
            dataset_sink_mode = False
@@ -1391,6 +1495,8 @@ class Model:
        if _enable_distributed_mindrt():
            _reset_op_id_with_offset()
 
+        _clear_auto_parallel_context(self._network)
+
    @staticmethod
    def _check_sink_mode_for_ds_debug_mode(dataset_sink_mode):
        if get_debug_mode() and dataset_sink_mode:
@@ -1484,11 +1590,8 @@ class Model:
            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
            >>> model.fit(2, train_dataset, valid_dataset)
-
-        Tutorial Examples:
-            - `Advanced Encapsulation: Model - Train and Save Model
-              <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
        """
+        _init_auto_parallel_context(self._network)
        device_target = context.get_context("device_target")
        if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
            logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
@@ -1540,8 +1643,9 @@ class Model:
                             valid_dataset=valid_dataset,
                             valid_frequency=valid_frequency,
                             valid_dataset_sink_mode=valid_dataset_sink_mode)
+        _clear_auto_parallel_context(self._network)
 
-    def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
+    def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True):
        """
        Build computational graphs and data graphs with the sink mode.
 
@@ -1560,6 +1664,7 @@ class Model:
                will be built, and `metrics` in `Model` can not be None. Default: ``None`` .
            sink_size (int): Control the number of steps for each sinking. Default: ``-1`` .
            epoch (int): Control the training epochs. Default: ``1`` .
+            sink_mode (bool): Determines whether to pass the data through dataset channel. Default: ``True`` .
 
        Examples:
            >>> from mindspore import nn
@@ -1580,16 +1685,18 @@ class Model:
            >>> model.build(dataset, epoch=2)
            >>> model.train(2, dataset)
        """
+        _init_auto_parallel_context(self._network)
        epoch = Validator.check_positive_int(epoch)
        if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
            self._train_network.check_names_and_refresh_name()
            self._train_network._is_check_and_refresh = True
        vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to init dataset in model.build().")
        logger.info("Begin to init dataset in model.build() procedure.")
-        self._init(train_dataset, valid_dataset, sink_size, epoch)
+        self._init(train_dataset, valid_dataset, sink_size, epoch, sink_mode)
        vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
                   "The model.build() which contains dataset warmup and network compile is success.")
        logger.info("The model.build() which contains dataset warmup and network compile is success.")
+        _clear_auto_parallel_context(self._network)
 
 
    def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
@@ -1759,12 +1866,8 @@ class Model:
1759
1866
  >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
1760
1867
  >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
1761
1868
  >>> acc = model.eval(dataset, dataset_sink_mode=False)
1762
-
1763
- Tutorial Examples:
1764
- - `Advanced Encapsulation: Model - Train and Save Model
1765
- <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
1766
1869
  """
1767
- valid_dataset = self._prepare_obf_dataset(valid_dataset)
1870
+ _init_auto_parallel_context(self._network)
1768
1871
  dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
1769
1872
 
1770
1873
  _device_number_check(self._parallel_mode, self._device_number)
@@ -1809,6 +1912,7 @@ class Model:
1809
1912
  # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
1810
1913
  if _enable_distributed_mindrt():
1811
1914
  _reset_op_id_with_offset()
1915
+ _clear_auto_parallel_context(self._network)
1812
1916
 
1813
1917
  return eval_result
1814
1918
 
@@ -1821,7 +1925,8 @@ class Model:
1821
1925
  The predict data, can be a single tensor,
1822
1926
  a list of tensor, or a tuple of tensor.
1823
1927
 
1824
- config (dict, optional) - The config parameter is enabled when the backend is ‘lite’.
1928
+ config (dict, optional): The config parameter is enabled when the backend is ‘lite’.
1929
+
1825
1930
  The config includes two parts: config_path (configPath, str) and config_item (str, dict).
1826
1931
  When the config_item is set, its priority is higher than the config_path. Set the ranking
1827
1932
  table file for inference. The content of the configuration file is as follows:
@@ -1831,6 +1936,16 @@ class Model:
1831
1936
  For example: "/home/user/config.ini". Default value: ``"" `` , here is the content of the
1832
1937
  config.ini file:
1833
1938
 
1939
+ The config has 3 forms:
1940
+ 1. configPath defines the path of the configuration file, which is used to pass user-defined
1941
+ options during model building. Default value: ``"" ``.
1942
+
1943
+ .. code-block::
1944
+
1945
+ config = {"configPath" : "/home/user/config.ini"}
1946
+
1947
+ Here is the content of the config.ini file:
1948
+
1834
1949
  .. code-block::
1835
1950
 
1836
1951
  [ascend_context]
@@ -1839,20 +1954,15 @@ class Model:
  [op_name1] = data_type:float16 (operator named op_name1 is set to data type float16)
  [op_name2] = data_type:float32 (operator named op_name2 is set to data type float32)

- When only the config_path is configured, it is done as follows:
-
- .. code-block::
-
- config = {"configPath" : "/home/user/config.ini"}
-
- When only the config_dict is configured, it is done as follows:
+ 2. Set the user-defined options in a parameter dictionary, as follows:

  .. code-block::

  config = {"ascend_context" : {"rank_table_file" : "path_b"},
  "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}

- When both the `config_path` and the `config_dict` are configured, it is done as follows:
+ 3. Configure both the `configPath` and the parameter dictionary; the parameter dictionary takes
+ priority over the content of the configuration file. It is done as follows:

  .. code-block::

@@ -1860,12 +1970,13 @@ class Model:
  "ascend_context" : {"rank_table_file" : "path_b"},
  "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}

- Note that both the "configPath" is configured in the config_dict and the config_item,
- in this case, the path_b in the config_dict takes precedence.
+ Note that the file referenced by "configPath" sets "rank_table_file = [path_a]", while the dictionary
+ sets "ascend_context" : {"rank_table_file" : "path_b"}; in this case, path_b takes precedence.

  Returns:
  Tensor, array(s) of predictions.
  """
+
  def _get_lite_context(lite_context_input):
  # use default lite context parameters for now
  device_target = context.get_context("device_target").lower()
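As an aside on the `config` forms documented in the hunks above: the sketch below shows how each form could be passed to prediction with the lite backend. It is a minimal illustration, assuming `Model.predict` accepts the `backend` and `config` keyword arguments described in this docstring; the network, input shape, and file paths are placeholders rather than values taken from this diff.

.. code-block:: python

    import numpy as np
    import mindspore as ms
    from mindspore import nn
    from mindspore.train import Model

    class SmallNet(nn.Cell):
        """Tiny placeholder network; stands in for a real exported model."""
        def __init__(self):
            super().__init__()
            self.dense = nn.Dense(32, 10)

        def construct(self, x):
            return self.dense(x)

    model = Model(SmallNet())
    inputs = ms.Tensor(np.ones([8, 32]), ms.float32)

    # Form 1: only a configuration file.
    cfg_file = {"configPath": "/home/user/config.ini"}

    # Form 2: only in-line options.
    cfg_dict = {"ascend_context": {"rank_table_file": "path_b"},
                "execution_plan": {"op_name1": "data_type:float16"}}

    # Form 3: both; the in-line dictionary overrides the file contents.
    cfg_both = {"configPath": "/home/user/config.ini",
                "ascend_context": {"rank_table_file": "path_b"}}

    # Assumed call shape, following the docstring above.
    result = model.predict(inputs, backend="lite", config=cfg_both)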
@@ -1899,7 +2010,7 @@ class Model:
  if not self._mindspore_lite:
  self._mindspore_lite = importlib.import_module('mindspore_lite')

- use_past = False # default execute full model inference
+ use_past = False # default execute full model inference
  model_group_id = None
  if self._predict_network.get_flags().__contains__("is_first_iteration"):
  is_first_iteration = self._predict_network.get_flags()['is_first_iteration']
@@ -2012,6 +2123,7 @@ class Model:
  >>> model = Model(LeNet5())
  >>> result = model.predict(input_data)
  """
+ _init_auto_parallel_context(self._network)
  if backend not in ['lite', None]:
  raise ValueError(f"For Model.predict, `backend` should be 'lite' or None, but got {backend}")
  if backend == "lite" and self._lite_infer:
@@ -2027,6 +2139,7 @@ class Model:
  except BaseException as e:
  self._lite_infer = False
  logger.warning(f"Lite inference failed, {e.__str__()}, fallback to original inference!")
+ _clear_auto_parallel_context(self._network)

  def _check_input_data():
  """Input data check."""
@@ -2092,7 +2205,9 @@ class Model:

  def infer_train_layout(self, train_dataset, dataset_sink_mode=True, sink_size=-1):
  """
- Generate parameter layout for the train network in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
+ Generate parameter layout for the train network when using `AutoParallel(cell)`
+ to enable parallel mode.
+
  Only dataset sink mode is supported for now.

  .. warning::
@@ -2111,9 +2226,9 @@ class Model:
  Configure pynative mode or CPU, the training process will be performed with
  dataset not sink. Default: ``True`` .
  sink_size (int): Control the number of steps for each sinking.
+ If dataset_sink_mode is False, set sink_size as invalid.
  If sink_size = -1, sink the complete dataset for each epoch.
  If sink_size > 0, sink sink_size data for each epoch.
- If dataset_sink_mode is False, set sink_size as invalid.
  Default: ``-1`` .

  Returns:
@@ -2127,10 +2242,10 @@ class Model:
  >>> from mindspore import Tensor, nn
  >>> from mindspore.train import Model
  >>> from mindspore.communication import init
+ >>> from mindspore.parallel.auto_parallel import AutoParallel
  >>>
  >>> ms.set_context(mode=ms.GRAPH_MODE)
  >>> init()
- >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
  >>>
  >>> # Create the dataset taking MNIST as an example. Refer to
  >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
@@ -2138,13 +2253,15 @@ class Model:
  >>> # Define the network structure of LeNet5. Refer to
  >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
+ >>> parallel_net = AutoParallel(net)
  >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  >>> loss_scale_manager = ms.FixedLossScaleManager()
  >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
- >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
+ >>> model = Model(parallel_net, loss_fn=loss, optimizer=optim, metrics=None,
  ... loss_scale_manager=loss_scale_manager)
  >>> layout_dict = model.infer_train_layout(dataset)
  """
+ _init_auto_parallel_context(self._network)
  self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)

  train_dataset.__no_send__ = True
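The docstring change above is part of the same interface shift seen throughout this file: instead of calling `ms.set_auto_parallel_context(parallel_mode=...)`, the network is wrapped with `AutoParallel` before being handed to `Model`. Below is a condensed, self-contained sketch of that wiring built only from calls that appear in these examples; the toy network and dataset are stand-ins for the LeNet5/MNIST pieces, and it still assumes a multi-device launch.

.. code-block:: python

    import numpy as np
    import mindspore as ms
    import mindspore.dataset as ds
    from mindspore import nn
    from mindspore.communication import init
    from mindspore.parallel.auto_parallel import AutoParallel
    from mindspore.train import Model

    ms.set_context(mode=ms.GRAPH_MODE)
    init()  # assumes a distributed launch (e.g. msrun); a single-process run fails here

    class TinyNet(nn.Cell):
        """Placeholder network; stands in for LeNet5 in the docstring example."""
        def __init__(self):
            super().__init__()
            self.dense = nn.Dense(32, 10)

        def construct(self, x):
            return self.dense(x)

    # Toy sinkable dataset: 64 samples of (feature, label).
    features = np.random.randn(64, 32).astype(np.float32)
    labels = np.random.randint(0, 10, (64,)).astype(np.int32)
    dataset = ds.NumpySlicesDataset((features, labels), column_names=["data", "label"],
                                    shuffle=False).batch(8)

    net = TinyNet()
    parallel_net = AutoParallel(net, parallel_mode="semi_auto")  # replaces set_auto_parallel_context(...)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(parallel_net, loss_fn=loss, optimizer=optim)

    layout_dict = model.infer_train_layout(dataset)  # dict: parameter name -> layout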
@@ -2156,11 +2273,13 @@ class Model:
  train_network.compile(*inputs)
  break
  train_dataset.__model_hash__ = hash(self)
+ _clear_auto_parallel_context(self._network)
  return train_network.parameter_layout_dict

  def infer_predict_layout(self, *predict_data, skip_backend_compile=False):
  """
- Generate parameter layout for the predict network in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
+ Generate parameter layout for the predict network when using `AutoParallel(cell)`
+ to enable parallel mode.

  Data could be a single tensor or multiple tensors.

@@ -2183,21 +2302,47 @@ class Model:
  RuntimeError: If not in GRAPH_MODE.

  Examples:
- >>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
- >>> # mindspore.cn.
  >>> import numpy as np
- >>> import mindspore as ms
+ >>> import mindspore.nn as nn
  >>> from mindspore import Tensor
  >>> from mindspore.train import Model
+ >>> from mindspore.ops import operations as P
+ >>> from mindspore import context
  >>> from mindspore.communication import init
+ >>> from mindspore.parallel.auto_parallel import AutoParallel
+ >>>
+ >>> class Net(nn.Cell):
+ >>> def __init__(self):
+ >>> super(Net, self).__init__()
+ >>> self.fc1 = nn.Dense(128, 768, activation='relu')
+ >>> self.fc2 = nn.Dense(128, 768, activation='relu')
+ >>> self.fc3 = nn.Dense(128, 768, activation='relu')
+ >>> self.fc4 = nn.Dense(768, 768, activation='relu')
+ >>> self.relu4 = nn.ReLU()
+ >>> self.relu5 = nn.ReLU()
+ >>> self.transpose = P.Transpose()
+ >>> self.matmul1 = P.MatMul()
+ >>> self.matmul2 = P.MatMul()
+ >>>
+ >>> def construct(self, x):
+ >>> q = self.fc1(x)
+ >>> k = self.fc2(x)
+ >>> v = self.fc3(x)
+ >>> k = self.transpose(k, (1, 0))
+ >>> c = self.relu4(self.matmul1(q, k))
+ >>> s = self.relu5(self.matmul2(c, v))
+ >>> s = self.fc4(s)
+ >>> return s
  >>>
  >>> ms.set_context(mode=ms.GRAPH_MODE)
  >>> init()
- >>> ms.set_auto_parallel_context(full_batch=True, parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
- >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
- >>> model = Model(Net())
- >>> predict_map = model.infer_predict_layout(input_data)
+ >>> inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+ >>> net = Net()
+ >>> parallel_net = AutoParallel(net, parallel_mode='semi_auto')
+ >>> model = Model(parallel_net)
+ >>> predict_map = model.infer_predict_layout(inputs)
  """
+ _init_auto_parallel_context(self._network)
  if context.get_context("mode") != context.GRAPH_MODE:
  raise RuntimeError("Pre-compile process that generate parameter layout for the predict network "
  "only supports GRAPH MODE and Ascend target currently.")
@@ -2217,6 +2362,7 @@ class Model:
  predict_net.phase = origin_phase
  else:
  predict_net.compile(*predict_data)
+ _clear_auto_parallel_context(self._network)
  return predict_net.parameter_layout_dict

  def _flush_from_cache(self, cb_params):
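For context on how the layout returned by `infer_predict_layout` is typically consumed: a common follow-up is distributed checkpoint loading, where each rank restores only its own parameter slices. The sketch below is an assumption-laden illustration, not code from this diff; it presumes `mindspore.load_distributed_checkpoint` with its `checkpoint_filenames` and `predict_strategy` parameters, uses placeholder checkpoint paths, and passes the `AutoParallel`-wrapped cell as the network.

.. code-block:: python

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn
    from mindspore.communication import init
    from mindspore.parallel.auto_parallel import AutoParallel
    from mindspore.train import Model

    ms.set_context(mode=ms.GRAPH_MODE)
    init()  # assumes a distributed launch, as in the docstring example

    class DenseNet(nn.Cell):
        """Minimal stand-in for the Net defined in the docstring example."""
        def __init__(self):
            super().__init__()
            self.fc = nn.Dense(128, 768)

        def construct(self, x):
            return self.fc(x)

    parallel_net = AutoParallel(DenseNet(), parallel_mode="semi_auto")
    model = Model(parallel_net)
    inputs = Tensor(np.ones([32, 128]).astype(np.float32))

    # Pre-compile the sharded predict graph and get each parameter's layout.
    predict_layout = model.infer_predict_layout(inputs)

    # Typical next step (assumed API): restore per-rank checkpoint slices with that layout.
    ckpt_files = ["./ckpt/rank_0.ckpt", "./ckpt/rank_1.ckpt"]  # placeholder paths
    ms.load_distributed_checkpoint(network=parallel_net,
                                   checkpoint_filenames=ckpt_files,
                                   predict_strategy=predict_layout)

    outputs = model.predict(inputs)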
@@ -2256,16 +2402,5 @@ class Model:
  """
  return self._eval_network

- def _prepare_obf_dataset(self, dataset):
- if not hasattr(self._network, 'obf_ratios'):
- return dataset
- data_size = dataset.get_dataset_size()
- obf_ratio_dataset = []
- for _ in range(data_size):
- obf_ratio_dataset.append(self._network.obf_ratios)
- obf_ratio_dataset = ds.NumpySlicesDataset(data=obf_ratio_dataset, column_names=["y_obf"])
- dataset = ds.zip((dataset, obf_ratio_dataset))
- return dataset
-

  __all__ = ["Model"]