mindspore-2.4.10-cp39-cp39-win_amd64.whl → mindspore-2.6.0rc1-cp39-cp39-win_amd64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (577)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +13 -6
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -0
  7. mindspore/_checkparam.py +3 -38
  8. mindspore/_deprecated/__init__.py +17 -0
  9. mindspore/_deprecated/jit.py +198 -0
  10. mindspore/_extends/builtin_operations.py +1 -1
  11. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  12. mindspore/_extends/parse/__init__.py +6 -7
  13. mindspore/_extends/parse/compile_config.py +83 -0
  14. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +46 -197
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +217 -98
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +11 -5
  27. mindspore/avcodec-59.dll +0 -0
  28. mindspore/avdevice-59.dll +0 -0
  29. mindspore/avfilter-8.dll +0 -0
  30. mindspore/avformat-59.dll +0 -0
  31. mindspore/avutil-57.dll +0 -0
  32. mindspore/boost/__init__.py +2 -2
  33. mindspore/boost/base.py +3 -7
  34. mindspore/boost/boost_cell_wrapper.py +138 -43
  35. mindspore/common/__init__.py +6 -3
  36. mindspore/common/_grad_function.py +56 -0
  37. mindspore/common/_pijit_context.py +14 -5
  38. mindspore/common/_register_for_tensor.py +1 -2
  39. mindspore/common/_stub_tensor.py +30 -14
  40. mindspore/common/_tensor_cpp_method.py +17 -0
  41. mindspore/common/_tensor_docs.py +4760 -0
  42. mindspore/common/api.py +435 -371
  43. mindspore/common/auto_dynamic_shape.py +41 -44
  44. mindspore/common/dtype.py +39 -36
  45. mindspore/common/dump.py +9 -6
  46. mindspore/common/file_system.py +9 -1
  47. mindspore/common/generator.py +2 -0
  48. mindspore/common/hook_handle.py +6 -2
  49. mindspore/common/initializer.py +13 -10
  50. mindspore/common/jit_begin_end.py +94 -0
  51. mindspore/common/jit_config.py +6 -1
  52. mindspore/common/jit_context.py +76 -0
  53. mindspore/common/jit_trace.py +378 -0
  54. mindspore/common/lazy_inline.py +9 -3
  55. mindspore/common/mindir_util.py +10 -2
  56. mindspore/common/mutable.py +5 -4
  57. mindspore/common/parameter.py +135 -52
  58. mindspore/common/seed.py +2 -2
  59. mindspore/common/sparse_tensor.py +23 -17
  60. mindspore/common/tensor.py +951 -1992
  61. mindspore/communication/__init__.py +7 -5
  62. mindspore/communication/_comm_helper.py +52 -2
  63. mindspore/communication/comm_func.py +240 -181
  64. mindspore/communication/management.py +95 -26
  65. mindspore/context.py +314 -566
  66. mindspore/dataset/__init__.py +65 -37
  67. mindspore/dataset/audio/__init__.py +2 -8
  68. mindspore/dataset/audio/transforms.py +3 -17
  69. mindspore/dataset/callback/ds_callback.py +2 -1
  70. mindspore/dataset/core/config.py +87 -6
  71. mindspore/dataset/engine/cache_admin.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +6 -5
  73. mindspore/dataset/engine/datasets.py +292 -267
  74. mindspore/dataset/engine/datasets_audio.py +22 -8
  75. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  76. mindspore/dataset/engine/datasets_text.py +78 -48
  77. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  78. mindspore/dataset/engine/datasets_vision.py +120 -44
  79. mindspore/dataset/engine/iterators.py +283 -63
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/obs/util.py +8 -0
  82. mindspore/dataset/engine/queue.py +40 -0
  83. mindspore/dataset/engine/samplers.py +289 -43
  84. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  85. mindspore/dataset/engine/validators.py +53 -11
  86. mindspore/dataset/text/__init__.py +7 -6
  87. mindspore/dataset/text/transforms.py +6 -5
  88. mindspore/dataset/text/utils.py +3 -3
  89. mindspore/dataset/transforms/__init__.py +0 -9
  90. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  91. mindspore/dataset/transforms/transforms.py +31 -14
  92. mindspore/dataset/utils/browse_dataset.py +1 -1
  93. mindspore/dataset/vision/__init__.py +2 -9
  94. mindspore/dataset/vision/transforms.py +202 -158
  95. mindspore/dataset/vision/utils.py +7 -5
  96. mindspore/dataset/vision/validators.py +1 -2
  97. mindspore/device_context/__init__.py +21 -0
  98. mindspore/device_context/ascend/__init__.py +25 -0
  99. mindspore/device_context/ascend/device.py +72 -0
  100. mindspore/device_context/ascend/op_debug.py +153 -0
  101. mindspore/device_context/ascend/op_precision.py +193 -0
  102. mindspore/device_context/ascend/op_tuning.py +123 -0
  103. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  104. mindspore/device_context/cpu/device.py +62 -0
  105. mindspore/device_context/cpu/op_tuning.py +43 -0
  106. mindspore/device_context/gpu/__init__.py +21 -0
  107. mindspore/device_context/gpu/device.py +70 -0
  108. mindspore/device_context/gpu/op_precision.py +67 -0
  109. mindspore/device_context/gpu/op_tuning.py +175 -0
  110. mindspore/device_manager.py +170 -0
  111. mindspore/experimental/es/embedding_service.py +35 -27
  112. mindspore/experimental/llm_boost/__init__.py +1 -0
  113. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  114. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  115. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  116. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  117. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  118. mindspore/experimental/llm_boost/register.py +1 -0
  119. mindspore/experimental/map_parameter.py +4 -4
  120. mindspore/experimental/optim/adadelta.py +6 -6
  121. mindspore/experimental/optim/adagrad.py +4 -4
  122. mindspore/experimental/optim/adam.py +7 -0
  123. mindspore/experimental/optim/adamax.py +4 -4
  124. mindspore/experimental/optim/adamw.py +4 -0
  125. mindspore/experimental/optim/asgd.py +1 -1
  126. mindspore/experimental/optim/lr_scheduler.py +73 -46
  127. mindspore/experimental/optim/radam.py +34 -31
  128. mindspore/experimental/optim/rprop.py +1 -1
  129. mindspore/experimental/optim/sgd.py +1 -1
  130. mindspore/hal/contiguous_tensors_handle.py +6 -10
  131. mindspore/hal/device.py +55 -53
  132. mindspore/hal/event.py +52 -52
  133. mindspore/hal/memory.py +157 -117
  134. mindspore/hal/stream.py +150 -109
  135. mindspore/include/api/context.h +0 -1
  136. mindspore/include/dataset/constants.h +7 -4
  137. mindspore/include/dataset/execute.h +2 -2
  138. mindspore/jpeg62.dll +0 -0
  139. mindspore/log.py +50 -0
  140. mindspore/mindrecord/__init__.py +21 -8
  141. mindspore/mindrecord/config.py +17 -316
  142. mindspore/mindrecord/filereader.py +1 -9
  143. mindspore/mindrecord/filewriter.py +5 -15
  144. mindspore/mindrecord/mindpage.py +1 -9
  145. mindspore/mindspore_backend_common.dll +0 -0
  146. mindspore/mindspore_backend_manager.dll +0 -0
  147. mindspore/mindspore_common.dll +0 -0
  148. mindspore/mindspore_core.dll +0 -0
  149. mindspore/mindspore_dump.dll +0 -0
  150. mindspore/mindspore_frontend.dll +0 -0
  151. mindspore/mindspore_memory_pool.dll +0 -0
  152. mindspore/mindspore_ms_backend.dll +0 -0
  153. mindspore/mindspore_ops.dll +0 -0
  154. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  155. mindspore/mindspore_ops_kernel_common.dll +0 -0
  156. mindspore/mindspore_profiler.dll +0 -0
  157. mindspore/mindspore_pyboost.dll +0 -0
  158. mindspore/mindspore_pynative.dll +0 -0
  159. mindspore/mindspore_res_manager.dll +0 -0
  160. mindspore/mindspore_runtime_pipeline.dll +0 -0
  161. mindspore/mint/__init__.py +796 -759
  162. mindspore/mint/distributed/__init__.py +70 -4
  163. mindspore/mint/distributed/distributed.py +2679 -44
  164. mindspore/mint/linalg/__init__.py +8 -0
  165. mindspore/mint/nn/__init__.py +743 -22
  166. mindspore/mint/nn/functional.py +716 -23
  167. mindspore/mint/nn/layer/__init__.py +21 -4
  168. mindspore/mint/nn/layer/_functions.py +334 -0
  169. mindspore/mint/nn/layer/activation.py +276 -1
  170. mindspore/mint/nn/layer/basic.py +123 -0
  171. mindspore/mint/nn/layer/conv.py +921 -0
  172. mindspore/mint/nn/layer/normalization.py +223 -28
  173. mindspore/mint/nn/layer/padding.py +797 -0
  174. mindspore/mint/nn/layer/pooling.py +235 -0
  175. mindspore/mint/optim/__init__.py +3 -1
  176. mindspore/mint/optim/adam.py +223 -0
  177. mindspore/mint/optim/adamw.py +26 -19
  178. mindspore/mint/optim/sgd.py +171 -0
  179. mindspore/mint/special/__init__.py +2 -1
  180. mindspore/multiprocessing/__init__.py +5 -0
  181. mindspore/nn/__init__.py +4 -1
  182. mindspore/nn/cell.py +1370 -189
  183. mindspore/nn/dynamic_lr.py +2 -1
  184. mindspore/nn/layer/activation.py +29 -27
  185. mindspore/nn/layer/basic.py +51 -35
  186. mindspore/nn/layer/channel_shuffle.py +3 -3
  187. mindspore/nn/layer/container.py +1 -1
  188. mindspore/nn/layer/conv.py +22 -17
  189. mindspore/nn/layer/embedding.py +12 -11
  190. mindspore/nn/layer/normalization.py +56 -49
  191. mindspore/nn/layer/padding.py +4 -3
  192. mindspore/nn/layer/pooling.py +120 -42
  193. mindspore/nn/layer/rnn_cells.py +1 -1
  194. mindspore/nn/layer/rnns.py +2 -1
  195. mindspore/nn/layer/timedistributed.py +5 -5
  196. mindspore/nn/layer/transformer.py +59 -36
  197. mindspore/nn/learning_rate_schedule.py +8 -4
  198. mindspore/nn/loss/loss.py +58 -55
  199. mindspore/nn/optim/ada_grad.py +7 -5
  200. mindspore/nn/optim/adadelta.py +11 -9
  201. mindspore/nn/optim/adafactor.py +1 -1
  202. mindspore/nn/optim/adam.py +17 -13
  203. mindspore/nn/optim/adamax.py +8 -7
  204. mindspore/nn/optim/adasum.py +5 -5
  205. mindspore/nn/optim/asgd.py +1 -1
  206. mindspore/nn/optim/ftrl.py +11 -9
  207. mindspore/nn/optim/lamb.py +1 -1
  208. mindspore/nn/optim/lars.py +1 -4
  209. mindspore/nn/optim/lazyadam.py +12 -10
  210. mindspore/nn/optim/momentum.py +7 -6
  211. mindspore/nn/optim/optimizer.py +3 -3
  212. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  213. mindspore/nn/optim/rmsprop.py +13 -12
  214. mindspore/nn/optim/rprop.py +11 -9
  215. mindspore/nn/optim/sgd.py +9 -6
  216. mindspore/nn/optim/tft_wrapper.py +5 -2
  217. mindspore/nn/optim/thor.py +2 -1
  218. mindspore/nn/probability/bijector/bijector.py +17 -11
  219. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  220. mindspore/nn/probability/bijector/invert.py +2 -2
  221. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  222. mindspore/nn/probability/bijector/softplus.py +3 -2
  223. mindspore/nn/probability/distribution/beta.py +3 -3
  224. mindspore/nn/probability/distribution/categorical.py +1 -1
  225. mindspore/nn/probability/distribution/cauchy.py +4 -2
  226. mindspore/nn/probability/distribution/exponential.py +6 -7
  227. mindspore/nn/probability/distribution/gamma.py +2 -2
  228. mindspore/nn/probability/distribution/gumbel.py +2 -2
  229. mindspore/nn/probability/distribution/half_normal.py +5 -3
  230. mindspore/nn/probability/distribution/logistic.py +5 -3
  231. mindspore/nn/probability/distribution/poisson.py +1 -1
  232. mindspore/nn/probability/distribution/uniform.py +5 -3
  233. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  234. mindspore/nn/reinforcement/tensor_array.py +1 -1
  235. mindspore/nn/utils/init.py +13 -11
  236. mindspore/nn/wrap/__init__.py +6 -6
  237. mindspore/nn/wrap/cell_wrapper.py +181 -122
  238. mindspore/nn/wrap/grad_reducer.py +45 -36
  239. mindspore/nn/wrap/loss_scale.py +6 -7
  240. mindspore/numpy/array_creations.py +63 -65
  241. mindspore/numpy/array_ops.py +149 -144
  242. mindspore/numpy/logic_ops.py +41 -42
  243. mindspore/numpy/math_ops.py +365 -363
  244. mindspore/numpy/utils.py +17 -18
  245. mindspore/numpy/utils_const.py +5 -6
  246. mindspore/opencv_core452.dll +0 -0
  247. mindspore/opencv_imgcodecs452.dll +0 -0
  248. mindspore/opencv_imgproc452.dll +0 -0
  249. mindspore/ops/__init__.py +5 -3
  250. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  251. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  252. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  253. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  254. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  255. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  256. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  257. mindspore/ops/_register_for_op.py +0 -11
  258. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  259. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  260. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  261. mindspore/ops/_vmap/vmap_base.py +0 -2
  262. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  263. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  264. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  265. mindspore/ops/auto_generate/__init__.py +4 -3
  266. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  267. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  268. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  269. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  270. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  271. mindspore/ops/composite/__init__.py +2 -1
  272. mindspore/ops/composite/base.py +20 -25
  273. mindspore/ops/composite/math_ops.py +6 -16
  274. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  275. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  276. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  277. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  278. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  279. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  280. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  281. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  282. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  283. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  284. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  285. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  286. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  287. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  288. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  289. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  290. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  291. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  292. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  293. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  294. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  295. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  297. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  301. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  302. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  304. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  305. mindspore/ops/function/__init__.py +40 -2
  306. mindspore/ops/function/_add_attr_func.py +58 -0
  307. mindspore/ops/function/array_func.py +2089 -2403
  308. mindspore/ops/function/clip_func.py +80 -23
  309. mindspore/ops/function/debug_func.py +57 -57
  310. mindspore/ops/function/grad/__init__.py +1 -0
  311. mindspore/ops/function/grad/grad_func.py +104 -71
  312. mindspore/ops/function/image_func.py +2 -2
  313. mindspore/ops/function/linalg_func.py +47 -78
  314. mindspore/ops/function/math_func.py +4501 -3802
  315. mindspore/ops/function/nn_func.py +1726 -620
  316. mindspore/ops/function/other_func.py +159 -1
  317. mindspore/ops/function/parameter_func.py +18 -84
  318. mindspore/ops/function/random_func.py +440 -387
  319. mindspore/ops/function/reshard_func.py +4 -70
  320. mindspore/ops/function/sparse_func.py +3 -3
  321. mindspore/ops/function/sparse_unary_func.py +6 -6
  322. mindspore/ops/function/spectral_func.py +25 -58
  323. mindspore/ops/function/vmap_func.py +24 -17
  324. mindspore/ops/functional.py +22 -7
  325. mindspore/ops/functional_overload.py +1440 -0
  326. mindspore/ops/op_info_register.py +32 -244
  327. mindspore/ops/operations/__init__.py +13 -7
  328. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  329. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  330. mindspore/ops/operations/_grad_ops.py +2 -43
  331. mindspore/ops/operations/_infer_ops.py +2 -1
  332. mindspore/ops/operations/_inner_ops.py +43 -84
  333. mindspore/ops/operations/_ms_kernel.py +4 -10
  334. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  335. mindspore/ops/operations/_scalar_ops.py +3 -2
  336. mindspore/ops/operations/_sequence_ops.py +1 -1
  337. mindspore/ops/operations/_tensor_array.py +1 -1
  338. mindspore/ops/operations/array_ops.py +81 -324
  339. mindspore/ops/operations/comm_ops.py +154 -108
  340. mindspore/ops/operations/custom_ops.py +232 -78
  341. mindspore/ops/operations/debug_ops.py +153 -59
  342. mindspore/ops/operations/inner_ops.py +7 -5
  343. mindspore/ops/operations/linalg_ops.py +1 -57
  344. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  345. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  346. mindspore/ops/operations/math_ops.py +32 -234
  347. mindspore/ops/operations/nn_ops.py +210 -498
  348. mindspore/ops/operations/other_ops.py +62 -9
  349. mindspore/ops/operations/random_ops.py +13 -7
  350. mindspore/ops/operations/reshard_ops.py +1 -1
  351. mindspore/ops/operations/sparse_ops.py +2 -2
  352. mindspore/ops/primitive.py +66 -53
  353. mindspore/ops/tensor_method.py +1888 -0
  354. mindspore/ops_generate/__init__.py +0 -5
  355. mindspore/ops_generate/aclnn/__init__.py +0 -0
  356. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  357. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  358. mindspore/ops_generate/api/__init__.py +0 -0
  359. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  360. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  361. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  362. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  363. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  364. mindspore/ops_generate/api/gen_api.py +103 -0
  365. mindspore/ops_generate/api/op_api_proto.py +235 -0
  366. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  367. mindspore/ops_generate/common/__init__.py +0 -0
  368. mindspore/ops_generate/common/base_generator.py +11 -0
  369. mindspore/ops_generate/common/gen_constants.py +91 -0
  370. mindspore/ops_generate/common/gen_utils.py +348 -0
  371. mindspore/ops_generate/common/op_proto.py +473 -0
  372. mindspore/ops_generate/common/template.py +523 -0
  373. mindspore/ops_generate/gen_ops.py +22 -1069
  374. mindspore/ops_generate/op_def/__init__.py +0 -0
  375. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  376. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  377. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  378. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  379. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  380. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  381. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  382. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  383. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  384. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  385. mindspore/ops_generate/pyboost/__init__.py +0 -0
  386. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  387. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  388. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  389. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  390. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  391. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  392. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  393. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  394. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  395. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  396. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  397. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  398. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  399. mindspore/ops_generate/resources/__init__.py +0 -0
  400. mindspore/ops_generate/resources/resource_list.py +30 -0
  401. mindspore/ops_generate/resources/resource_loader.py +36 -0
  402. mindspore/ops_generate/resources/resource_manager.py +64 -0
  403. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  404. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  405. mindspore/parallel/__init__.py +7 -3
  406. mindspore/parallel/_auto_parallel_context.py +152 -34
  407. mindspore/parallel/_cell_wrapper.py +130 -15
  408. mindspore/parallel/_parallel_serialization.py +107 -5
  409. mindspore/parallel/_ps_context.py +1 -1
  410. mindspore/parallel/_recovery_context.py +7 -2
  411. mindspore/parallel/_tensor.py +142 -18
  412. mindspore/parallel/_utils.py +199 -23
  413. mindspore/parallel/algo_parameter_config.py +4 -4
  414. mindspore/parallel/auto_parallel.py +732 -0
  415. mindspore/parallel/checkpoint_convert.py +159 -0
  416. mindspore/parallel/checkpoint_transform.py +698 -35
  417. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  418. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  419. mindspore/parallel/cluster/run.py +21 -4
  420. mindspore/parallel/function/__init__.py +24 -0
  421. mindspore/parallel/function/reshard_func.py +259 -0
  422. mindspore/parallel/nn/__init__.py +25 -0
  423. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  424. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  425. mindspore/parallel/parameter_broadcast.py +25 -14
  426. mindspore/parallel/shard.py +137 -58
  427. mindspore/parallel/transform_safetensors.py +363 -305
  428. mindspore/profiler/__init__.py +22 -5
  429. mindspore/profiler/analysis/__init__.py +0 -0
  430. mindspore/profiler/analysis/parser/__init__.py +0 -0
  431. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  432. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  433. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  434. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  435. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  436. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  437. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  438. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  439. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  440. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  441. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  442. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  443. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  444. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  445. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  446. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  447. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  448. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  449. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  450. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  451. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  452. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  453. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  454. mindspore/profiler/analysis/task_manager.py +131 -0
  455. mindspore/profiler/analysis/time_converter.py +84 -0
  456. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  457. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  458. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  459. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  460. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  461. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  462. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  463. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  464. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  465. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  466. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  467. mindspore/profiler/analysis/work_flow.py +73 -0
  468. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  469. mindspore/profiler/common/command_executor.py +90 -0
  470. mindspore/profiler/common/constant.py +186 -3
  471. mindspore/profiler/common/file_manager.py +208 -0
  472. mindspore/profiler/common/log.py +130 -0
  473. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  474. mindspore/profiler/common/path_manager.py +395 -0
  475. mindspore/profiler/common/process_bar.py +168 -0
  476. mindspore/profiler/common/process_pool.py +9 -3
  477. mindspore/profiler/common/profiler_context.py +500 -0
  478. mindspore/profiler/common/profiler_info.py +304 -0
  479. mindspore/profiler/common/profiler_meta_data.py +74 -0
  480. mindspore/profiler/common/profiler_output_path.py +284 -0
  481. mindspore/profiler/common/profiler_parameters.py +251 -0
  482. mindspore/profiler/common/profiler_path_manager.py +179 -0
  483. mindspore/profiler/common/record_function.py +76 -0
  484. mindspore/profiler/common/tlv_decoder.py +76 -0
  485. mindspore/profiler/common/util.py +75 -2
  486. mindspore/profiler/dynamic_profiler.py +341 -75
  487. mindspore/profiler/envprofiler.py +163 -0
  488. mindspore/profiler/experimental_config.py +197 -0
  489. mindspore/profiler/mstx.py +242 -0
  490. mindspore/profiler/platform/__init__.py +21 -0
  491. mindspore/profiler/platform/base_profiler.py +40 -0
  492. mindspore/profiler/platform/cpu_profiler.py +124 -0
  493. mindspore/profiler/platform/gpu_profiler.py +74 -0
  494. mindspore/profiler/platform/npu_profiler.py +335 -0
  495. mindspore/profiler/profiler.py +1073 -90
  496. mindspore/profiler/profiler_action_controller.py +187 -0
  497. mindspore/profiler/profiler_interface.py +118 -0
  498. mindspore/profiler/schedule.py +243 -0
  499. mindspore/rewrite/api/node.py +15 -13
  500. mindspore/rewrite/api/symbol_tree.py +2 -3
  501. mindspore/run_check/_check_version.py +27 -20
  502. mindspore/run_check/run_check.py +1 -1
  503. mindspore/runtime/__init__.py +37 -0
  504. mindspore/runtime/device.py +27 -0
  505. mindspore/runtime/event.py +209 -0
  506. mindspore/runtime/executor.py +177 -0
  507. mindspore/runtime/memory.py +409 -0
  508. mindspore/runtime/stream.py +460 -0
  509. mindspore/runtime/thread_bind_core.py +401 -0
  510. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  511. mindspore/swresample-4.dll +0 -0
  512. mindspore/swscale-6.dll +0 -0
  513. mindspore/tinyxml2.dll +0 -0
  514. mindspore/train/__init__.py +8 -8
  515. mindspore/train/_utils.py +88 -25
  516. mindspore/train/amp.py +9 -5
  517. mindspore/train/callback/__init__.py +2 -2
  518. mindspore/train/callback/_callback.py +2 -16
  519. mindspore/train/callback/_checkpoint.py +53 -55
  520. mindspore/train/callback/_cluster_monitor.py +14 -18
  521. mindspore/train/callback/_early_stop.py +1 -1
  522. mindspore/train/callback/_flops_collector.py +103 -68
  523. mindspore/train/callback/_history.py +8 -5
  524. mindspore/train/callback/_lambda_callback.py +2 -2
  525. mindspore/train/callback/_landscape.py +0 -3
  526. mindspore/train/callback/_loss_monitor.py +2 -1
  527. mindspore/train/callback/_on_request_exit.py +6 -5
  528. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  529. mindspore/train/callback/_summary_collector.py +52 -19
  530. mindspore/train/callback/_time_monitor.py +2 -1
  531. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  532. mindspore/train/data_sink.py +25 -2
  533. mindspore/train/dataset_helper.py +15 -16
  534. mindspore/train/loss_scale_manager.py +8 -7
  535. mindspore/train/metrics/accuracy.py +3 -3
  536. mindspore/train/metrics/confusion_matrix.py +9 -9
  537. mindspore/train/metrics/error.py +3 -3
  538. mindspore/train/metrics/hausdorff_distance.py +4 -4
  539. mindspore/train/metrics/mean_surface_distance.py +3 -3
  540. mindspore/train/metrics/metric.py +0 -12
  541. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  542. mindspore/train/metrics/precision.py +11 -10
  543. mindspore/train/metrics/recall.py +9 -9
  544. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  545. mindspore/train/mind_ir_pb2.py +174 -46
  546. mindspore/train/model.py +184 -113
  547. mindspore/train/serialization.py +622 -978
  548. mindspore/train/summary/_summary_adapter.py +2 -2
  549. mindspore/train/summary/summary_record.py +2 -3
  550. mindspore/train/train_thor/model_thor.py +1 -1
  551. mindspore/turbojpeg.dll +0 -0
  552. mindspore/utils/__init__.py +6 -3
  553. mindspore/utils/dryrun.py +140 -0
  554. mindspore/utils/hooks.py +81 -0
  555. mindspore/utils/runtime_execution_order_check.py +550 -0
  556. mindspore/utils/utils.py +138 -4
  557. mindspore/version.py +1 -1
  558. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  559. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +562 -393
  560. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  561. mindspore/_install_custom.py +0 -43
  562. mindspore/common/_register_for_adapter.py +0 -74
  563. mindspore/common/_tensor_overload.py +0 -139
  564. mindspore/mindspore_np_dtype.dll +0 -0
  565. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  566. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  567. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  568. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  569. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  570. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  571. mindspore/ops_generate/gen_utils.py +0 -209
  572. mindspore/ops_generate/op_proto.py +0 -145
  573. mindspore/ops_generate/template.py +0 -261
  574. mindspore/profiler/envprofiling.py +0 -254
  575. mindspore/profiler/profiling.py +0 -1926
  576. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  577. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py}

@@ -15,24 +15,27 @@
 """Checkpoint related classes and functions."""
 
 import os
+from mindspore.utils import _tft_handler
 from mindspore.train.serialization import save_checkpoint
-from mindspore.parallel._utils import _get_device_num
-from mindspore import _checkparam as Validator
 from mindspore.train.callback._callback import Callback
-from mindspore import context
+from mindspore import context, ops
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore.communication import get_rank, get_group_size
 from mindspore import log as logger
 from mindspore.train.serialization import _get_cur_rank_dp
-from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post
+from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post, _tft_sem_enable
+from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm
 from mindspore._c_expression import clean_tdt_channel
-from mindspore._c_expression import send_recv
+from mindspore._c_expression import send_recv, reset_params
 from mindspore._c_expression import CollectiveManager
 from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
-from mindspore._c_expression import Tensor as Tensor_
+from mindspore._c_expression import TensorPy as Tensor_
+from mindspore.ops.operations.manually_defined._inner import TensorReport
 import mindspore
 import mindspore.common.dtype as mstype
+from mindspore.parallel._recovery_context import _set_recovery_context
+
 
 def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
     """ Common func to generate ckpt dir name."""
@@ -40,30 +43,38 @@ def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
     mid_dir = f"tft_saved_checkpoints-step_{str(step)}{tmp}"
     return os.path.join(ckpt_save_path, mid_dir)
 
+
 def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
     """ Callback used for TFT save ckpt function when errors occur."""
     logger.info("Enter _save_checkpoint_on_failure function")
-    if not cb_ctx._is_params_consistent(): # pylint: disable=W0212
+    if not cb_ctx._is_params_consistent():  # pylint: disable=W0212
         raise RuntimeError("Can't save parameters, because they are left in inconsistent state!")
+    cb_params = args
+    # we record the current step and epoch num in on_train_step_end, so we can just reset it here
+    cb_params.cur_step_num = cb_ctx.cur_step_num
+    cb_params.cur_epoch_num = cb_ctx.cur_epoch_num
+    if cb_params.optimizer is not None:
+        cb_params.optimizer.global_step = cb_ctx.global_step
+    if hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
+        cb_params.network.optimizer.global_step = cb_ctx.global_step
+    append_dict = {}
+    append_dict["__exception_save__"] = True
+    # if user has provided a custom save callback, use it
+    if cb_ctx.save_cb:
+        cb_ctx.save_cb(cb_params, append_dict)
+        logger.info("Finish _save_checkpoint_on_failure function")
+        return
 
+    # if user has not provided a custom save callback, use default save logic
     ckpt_save_path = cb_ctx.ckpt_save_path
-    cb_params = args
     cur_rank = get_rank()
-    cur_step_num = cb_params.cur_step_num
+    step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
     cur_epoch_num = cb_params.cur_epoch_num
-    batch_num = cb_params.batch_num
-    if cur_step_num > step:
-        cur_epoch_num = (step - 1) // batch_num + 1
-    step_num_in_epoch = int((step - 1) % batch_num + 1)
-
-    append_dict = {}
     append_dict["epoch_num"] = cur_epoch_num
-    append_dict["step_num"] = step
+    append_dict["step_num"] = cb_params.cur_step_num
     append_dict["cur_rank"] = cur_rank
-    append_dict["batch_num"] = batch_num
-    append_dict["__exception_save__"] = True
-
-    append_dict["global_step"] = Parameter([cb_ctx.global_step])
+    append_dict["batch_num"] = cb_params.batch_num
+    append_dict["global_step"] = cb_ctx.global_step
     outputs = cb_params.net_outputs
     if isinstance(outputs, (tuple, list)) and len(outputs) >= 3:
         append_dict["loss_scale"] = outputs[2]
@@ -76,47 +87,63 @@ def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
                     integrated_save=False, append_dict=append_dict)
     logger.info("Finish _save_checkpoint_on_failure function")
 
+
 def _rename_save_result(step, cb_ctx):
     """ Callback used for TFT rename function after ckpt save callback was finished and successful."""
     logger.info("Enter _rename_save_result function")
+    if cb_ctx.save_cb:
+        logger.info("User's save callback is provided, skip rename")
+        return
     tmp_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, True)
     fin_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, False)
 
     os.rename(tmp_dir, fin_dir)
     logger.info("Finish _rename_save_result function")
 
+
 def _tft_exit_cb(ctx):
+    """Callback used for TFT exit function."""
     logger.error("Enter mindio ttp exit process, which means other ranks occur exception, check other ranks' logs!")
     _tft_sem_post()
-    os._exit(1) # pylint: disable=W0212
+    os._exit(1)  # pylint: disable=W0212
+
 
 def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
     """ Callback used for TFT repair function."""
-    logger.info("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
-    if(repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value\
-       or repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value):
-        logger.info("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
+    logger.warning("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
+    if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
+                                       cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value)):
+        logger.warning("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
         _repair_device(cb_ctx.device_id)
 
-    if(repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value\
-       or repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value):
-        logger.info("Enter _tft_repair_callback SEND_RECV repair type: \
-            {}, src_rank:{}, dst_rank: {}".format(repair_info["repair_type"], repair_info["src"], repair_info["dst"]))
+    if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
+                                       cb_ctx.tft.RepairType.RT_SEND.value,
+                                       cb_ctx.tft.RepairType.RT_RECV_REPAIR.value)):
+        logger.warning("Enter _tft_repair_callback SEND_RECV repair type:{}, src_rank:{}, dst_rank: {}".format(
+            repair_info["repair_type"], repair_info["src"], repair_info["dst"]))
         cb_params = args
-        src_rank = repair_info["src"][0]
-        dst_rank = repair_info["dst"][0]
-        send_recv(cb_params.network.trainable_params(), src_rank, dst_rank)
-    logger.info("Finish _tft_repair_callback")
+        if repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value:
+            for i in range(len(repair_info["src"])):
+                src_rank = repair_info["src"][i]
+                dst_rank = repair_info["dst"][i]
+                if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
+                    raise ValueError("Call send_recv failed.")
+        else:
+            src_rank = repair_info["src"][0]
+            dst_rank = repair_info["dst"][0]
+            if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
+                raise ValueError("Call send_recv failed.")
+    logger.warning("Finish _tft_repair_callback")
 
 
 def _tft_clean_callback(is_uce_error, args, ctx):
     """ Callback used for TFT clean function."""
-    logger.info("Enter _tft_clean_callback")
+    logger.warning("Enter _tft_clean_callback")
     ret = 0
     if is_uce_error:
         _get_uce_mem_info(ctx.device_id)
         err_strategy = _get_uce_process_strategy()
-        logger.info("_tft_clean_callback err_strategy: {}".format(err_strategy))
+        logger.warning("_tft_clean_callback err_strategy: {}".format(err_strategy))
         if err_strategy == "RS_UCE_HIGHLEVEL":
             ret = 0
         elif err_strategy == "RS_UCE_LOWLEVEL":
@@ -124,59 +151,81 @@ def _tft_clean_callback(is_uce_error, args, ctx):
         else:
             ret = 1
     clean_tdt_channel()
-    logger.info("Enter _tft_clean_callback resume_hccl_comm")
+    logger.warning("Enter _tft_clean_callback resume_hccl_comm")
     CollectiveManager.get_instance().resume_hccl_comm()
-    logger.info("Finish _tft_clean_callback, ret: {}".format(ret))
+    logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
     return ret
 
 
 def _tft_stop_callback(args, cb_ctx):
     """ Callback used for TFT stop function."""
-    logger.info("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
+    logger.warning("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
     _stop_device(cb_ctx.device_id)
-    if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()): # pylint: disable=W0212
+    if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()):  # pylint: disable=W0212
         raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
     cb_ctx.is_uce_rank = False
+    if cb_ctx.tft.tft_get_repair_type() == "recover":
+        logger.warning(f"Reset limit step")
+        cb_ctx.tft.tft_reset_limit_step()
     logger.info("Finish _tft_stop_callback")
 
 
-class TFTRegister(Callback):
+def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
+    """Callback used for TFT Rebuild Group function."""
+    logger.warning(f"Enter _tft_rebuild_sub_groups, device id: ".format(ctx.device_id))
+    _finalize_comm()
+    _rebuild_world_group()
+    _rebuild_sub_group()
+    _set_recovery_context(is_arf=True)
+    logger.warning("Enter _tft_rebuild_sub_groups ok ")
+
+
+class TrainFaultTolerance(Callback):
     """
     This callback is used to enable the TFT feature
-    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_.
-    This callback will execute TFT operations during training process, such as TFT init, report and exception handle.
+    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_
+    and will execute TFT operations during training process, such as TFT init, report and exception handle.
 
     Note:
         Required for Ascend graph mode only. And sink size must be less than or equal to 1.
 
     Args:
-        ctrl_rank_id (int): TFT controller's running rank_id, used for init TFT controller.
-        ctrl_ip (str): TFT controller's ip address, used for init TFT controller.
-        ctrl_port (int): TFT controller's ip port, used for init TFT controller and processor.
-        ckpt_save_path (str): Checkpoint save directory when failure occurs, checkpoint file will save to directory
-            named ttp_saved_checkpoints-step_{cur_step_num} under this directory.
+        ckpt_save_path (str): Checkpoint save directory when failure occurs. When saved,
+            a new directory named 'ttp_saved_checkpoints-step_{cur_step_num}'
+            is created in that directory. Default: ``None``.
+        kwargs (dict): Other dictionary type parameters.
 
     Raises:
         Exception: TFT init failed.
        ModuleNotFoundError: Mindio TFT whl package is not installed.
 
     Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            It's recommended to use the msrun startup method.
+            Please see the `msrun start up
+            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+            for more details.
+
+            This example should be run with 4 devices.
+
        >>> import numpy as np
        >>> import os
        >>> import math
        >>> import mindspore as ms
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn, ops, Parameter, train
-        >>> from mindspore.communication import init
+        >>> from mindspore.communication import init, get_rank
        >>> from mindspore.common.initializer import initializer, HeUniform
-        >>> from mindspore.train import Model, TFTRegister
+        >>> from mindspore.train import Model, TrainFaultTolerance
        >>> from mindspore import dataset as ds
        >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
        >>> init()
        >>> ms.set_seed(1)
        >>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
-        >>>     "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
+        ...     "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
        >>> class MatMulCell(nn.Cell):
        ...     def __init__(self, param=None, shape=None):
        ...         super().__init__()
@@ -234,48 +283,58 @@ class TFTRegister(Callback):
        ...     dataset = dataset.batch(batch_size)
        ...     return dataset
        >>>
-        >>> data_set = create_dataset(32)
+        >>> dataset = create_dataset(32)
        >>>
        >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
        >>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>>
-        >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
+        >>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 4)
        >>> net_with_loss.set_train()
-        >>> model = Model(net_with_loss, optimizer=optimizer)
-        >>> tft_cb = TFTRegister("192.168.0.1", 2000, "./tft_checkpoint/")
+        >>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
+        >>> tft_cb = TrainFaultTolerance()
        >>> loss_cb = train.LossMonitor(1)
        >>> model.train(1, dataset, callbacks=[tft_cb, loss_cb])
    """
 
-    def __init__(self, ctrl_rank_id, ctrl_ip, ctrl_port, ckpt_save_path):
-        super(TFTRegister, self).__init__()
-
-        tft_env = os.getenv("MS_ENABLE_TFT", "")
-        if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env):
-            raise ValueError("MindIO TFT regitster need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1}']!")
-        mode = context.get_context("mode")
-        device_target = context.get_context("device_target")
-        if device_target != "Ascend" or mode != context.GRAPH_MODE:
-            raise ValueError("MindIO adataper only support on Ascend device with GRAPH Mode!")
-
-        # let it raise errors if not install mindio_tft package
-        from mindio_ttp import framework_ttp as tft
-        self.tft = tft
-        self.is_uce_rank = False
-        self.global_step = 0
-        Validator.check_non_negative_int(ctrl_port)
+    def __init__(self, ckpt_save_path=None, **kwargs):
+        super(TrainFaultTolerance, self).__init__()
+        self.save_cb = kwargs.get("ckpt_save_fn", None)
+        self.ckpt_save_path = ckpt_save_path
+        if self.save_cb is None and self.ckpt_save_path is None:
+            raise ValueError("TrainFaultTolerance construct need to set ckpt_save_fn or ckpt_save_path!")
+        self.tft = _tft_handler.get_tft()
+        self._check_init()
+        self.global_step = None
+        self.learning_rate = None
         self.has_init_replica = False
-        self._controller_ip = ctrl_ip
-        self._controller_rank_id = ctrl_rank_id
-        self._controller_port = ctrl_port
+        self.is_uce_rank = False
        self.cb_params = None
+        self.initial_step = kwargs.get("initial_step", 0)
        self.device_id = context.get_context("device_id")
-        self._init_tft()
-        self.ckpt_save_path = ckpt_save_path
        self.assign = mindspore.ops.Assign()
        self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
        self.s1 = mindspore.hal.Stream()
+        self.cur_step_num = 0
+        self.cur_epoch_num = 0
+        _tft_sem_enable()
+        self._tft_register()
+
+    def _check_init(self):
+        """Check if the mindio-ttp had inited"""
+        if self.tft is None:
+            tft_env = os.getenv("MS_ENABLE_TFT", "")
+            if "ARF:1" in tft_env:
+                raise ValueError("Must init by _tft_handler.init(config=params) if use ARF.")
+            logger.warning(f"TFT handle not init, try to init")
+            _tft_handler.init(config=None)
+            self.tft = _tft_handler.get_tft()
+            logger.warning(f"TFT handle init ok.")
+        mode = context.get_context("mode")
+        device_target = context.get_context("device_target")
+        if device_target != "Ascend" or mode != context.GRAPH_MODE:
+            raise ValueError(f"MindIO adataper only support on Ascend device with GRAPH Mode!"
+                             f"device:{device_target}, run mode: {mode}")
 
     def _is_params_consistent(self):
        for key, param in self.cb_params.train_network.parameters_and_names():
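This constructor change is breaking: the 2.4 signature TFTRegister(ctrl_rank_id, ctrl_ip, ctrl_port, ckpt_save_path) is gone, and controller/processor bootstrap has moved out of the callback into mindspore.utils._tft_handler (initialized on demand by _check_init above). A hedged before/after sketch, using only the signatures visible in this diff:

    # mindspore 2.4.x
    from mindspore.train import TFTRegister
    tft_cb = TFTRegister(0, "192.168.0.1", 2000, "./tft_checkpoint/")

    # mindspore 2.6.0rc1: controller rank/ip/port arguments are gone;
    # either ckpt_save_path or a ckpt_save_fn kwarg must be supplied.
    from mindspore.train import TrainFaultTolerance
    tft_cb = TrainFaultTolerance(ckpt_save_path="./tft_checkpoint/")

The MS_ENABLE_TFT environment switch is still consulted; per _check_init, "ARF:1" additionally requires an explicit _tft_handler.init(config=params) before the callback is constructed.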
@@ -287,7 +346,7 @@ class TFTRegister(Callback):
         return False
 
     def _set_tft_optimizer_replica(self, run_context):
-        """ set Mindio TFT optimizer replica info, used internal. """
+        """ Set Mindio TFT optimizer replica info, used internal. """
         cur_rank = get_rank()
         cb_params = run_context.original_args()
         train_network = cb_params.train_network
@@ -309,37 +368,61 @@
         ]
         self.tft.tft_set_optimizer_replica(cur_rank, replica_info)
 
-    def _init_tft(self):
-        """ Init Mindio TFT, used internal. """
-        logger.info("Begin to init tft.")
+    @classmethod
+    def get_optimizer_wrapper(cls, origin_opt_cls):
+        """
+        Optimizer wrapper func when using tft.
+
+        Args:
+            origin_opt_cls (Class): origin optimizer class.
+        """
+
+        class TFTOptSubCls(origin_opt_cls):
+            """
+            Optimizer wrapper class when using tft.
+            """
+
+            def __init__(self, *args, **kwargs):
+                super(TFTOptSubCls, self).__init__(*args, **kwargs)
+                self.report = TensorReport()
+                self.report_end = TensorReport()
+                self.report_end.add_prim_attr("side_effect_mem", True).add_prim_attr("optimizer_end", True)
+                self.depend = ops.Depend()
+                self.allreduce_sum = ops.AllReduce()
+                self.allreduce_sum.add_prim_attr("tft_report_before", True)
+                self.tft_g_one_flag = Parameter(Tensor([1], dtype=mstype.int32))
+
+            def construct(self, gradients, **kwargs):
+                tft_g_one_flag = self.depend(self.tft_g_one_flag, gradients)
+                self.tft_g_one_flag = self.allreduce_sum(tft_g_one_flag)
+                grads = self.depend(gradients, self.report("tft_report", self.tft_g_one_flag))
+                opt_ret = super(TFTOptSubCls, self).construct(grads, **kwargs)
+                self.report_end("tft_report", self.tft_g_one_flag)
+                return opt_ret
+
+        return TFTOptSubCls
+
+    def _tft_register(self):
+        """Register callback functions."""
         self.tft.tft_register_save_ckpt_handler(_save_checkpoint_on_failure, self)
         self.tft.tft_register_rename_handler(_rename_save_result, self)
         self.tft.tft_register_exit_handler(_tft_exit_cb, self)
         self.tft.tft_register_stop_handler(_tft_stop_callback, self)
         self.tft.tft_register_clean_handler(_tft_clean_callback, self)
         self.tft.tft_register_repair_handler(_tft_repair_callback, self)
+        self.tft.tft_register_rebuild_group_handler(_tft_rebuild_sub_groups, self)
 
-        world_size = _get_device_num()
-        cur_rank = get_rank()
-        enable_local_copy = False
-        enable_arf = False
-        enable_tls = False
-        tls_key_dir = ""
-
-        if cur_rank == self._controller_rank_id:
-            logger.info(f"Begin to start tft controller on rank_id:{cur_rank}")
-            self.tft.tft_init_controller(cur_rank, world_size, enable_local_copy, enable_arf)
-            self.tft.tft_start_controller(self._controller_ip, self._controller_port, enable_tls, tls_key_dir)
-            logger.info("Finish start tft controller.")
-
-        logger.info("Begin to start tft processor.")
-        self.tft.tft_init_processor(cur_rank, world_size, enable_local_copy, enable_tls, tls_key_dir)
-        self.tft.tft_start_processor(self._controller_ip, self._controller_port)
-        logger.info("Finished start tft processor.")
+    def _reset_acc_grads(self):
+        accu_grad_params = map(lambda e: e[1],
+                               filter(lambda e: e[1].name.startswith('accu_grads'),
+                                      self.cb_params.train_network.parameters_and_names()))
+        accu_grad_list = list(accu_grad_params)
+        if reset_params(accu_grad_list) != 0:
+            raise ValueError("Call reset_params failed.")
 
     def on_train_step_end(self, run_context):
         """
-        And report status to MindIO TFT after every step finished.
+        Report status to MindIO TFT after every step finished.
 
         Args:
             run_context (RunContext): Context of the train running. Refer to
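The new get_optimizer_wrapper classmethod dynamically subclasses an optimizer so that every construct() call all-reduces a liveness flag and emits TensorReport markers before and after the parent update, which is how MindIO learns that the optimizer step completed. A hedged usage sketch (the diff does not state whether this supersedes the nn.OptTFTWrapper used in the docstring example; net is assumed to be an existing Cell):

    from mindspore import nn
    from mindspore.train import TrainFaultTolerance

    # wrap the optimizer class, then instantiate it as usual
    TFTSGD = TrainFaultTolerance.get_optimizer_wrapper(nn.SGD)
    optimizer = TFTSGD(net.trainable_params(), learning_rate=1e-2)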
@@ -350,17 +433,27 @@
         self._set_tft_optimizer_replica(run_context)
         cb_params = run_context.original_args()
         logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
-        self.tft.tft_end_updating_os(cb_params.cur_step_num)
+        self.cur_step_num = cb_params.cur_step_num
+        self.cur_epoch_num = cb_params.cur_epoch_num
         if cb_params.optimizer is not None:
-            self.global_step = int(cb_params.optimizer.global_step.data)
+            self.global_step = cb_params.optimizer.global_step.clone()
             self.assign(cb_params.optimizer.tft_g_one_flag, self.g_one)
-        else:
-            self.global_step = int(cb_params.network.optimizer.global_step.data)
+        elif hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
+            self.global_step = cb_params.network.optimizer.global_step.clone()
             self.assign(cb_params.network.optimizer.tft_g_one_flag, self.g_one)
+        else:
+            raise ValueError("TFT feature need optimizer or network's optimizer!")
+        self.tft.tft_end_updating_os(cb_params.cur_step_num + self.initial_step)
         logger.info("END Set optimizer finish step status to TFT.")
 
-
     def on_train_begin(self, run_context):
+        """
+        Register train params to MindIO TFT on train beginning.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
         cb_params = run_context.original_args()
         sink_size = cb_params.get("sink_size", 0)
         if sink_size > 1:
@@ -370,7 +463,11 @@
         self.cb_params = cb_params
 
     def end(self, run_context):
-        cur_rank = get_rank()
-        if cur_rank == self._controller_rank_id:
-            self.tft.tft_destroy_controller()
-        self.tft.tft_destroy_processor()
+        """
+        Unregister MindIO TFT on train end.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
+        _tft_handler.unregister_tft()
@@ -98,6 +98,29 @@ def _get_next_op(dataset, ori_next_op, is_info_queue):
     return next_op, (key, dataset_shapes, dataset_types)
 
 
+def _get_jit_func(sink_fun, jit_config):
+    """
+    Get the jit function.
+    """
+    jit_config_dict = jit_config.jit_config_dict
+    jit_level = jit_config_dict['jit_level']
+    if jit_level == "":
+        jit_level = "O0"
+    backend = ""
+    if jit_level == "O2":
+        jit_level = "O0"
+        backend = "GE"
+    if "backend" in jit_config_dict:
+        backend = jit_config_dict["backend"]
+    fullgraph = False
+    if jit_config_dict['jit_syntax_level'] == "STRICT":
+        fullgraph = True
+    exc_mode = jit_config_dict['exc_mode']
+    infer_boost = jit_config_dict['infer_boost']
+    return jit(sink_fun, jit_level=jit_level, backend=backend, fullgraph=fullgraph, exc_mode=exc_mode,
+               infer_boost=infer_boost)
+
+
 def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
     """
     get the sink function.
@@ -107,7 +130,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
         if jit_config is None:
             dst_sink_fun = sink_fun
         else:
-            dst_sink_fun = jit(sink_fun, jit_config=jit_config)
+            dst_sink_fun = _get_jit_func(sink_fun, jit_config)
         dataset.__sink_fun__ = dst_sink_fun
 
         return dataset.__sink_fun__
@@ -119,7 +142,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
         if jit_config is None:
            dst_sink_fun = sink_fun
         else:
-            dst_sink_fun = jit(sink_fun, jit_config=jit_config)
+            dst_sink_fun = _get_jit_func(sink_fun, jit_config)
         dataset.__sink_aux__.sink_funcs[key] = dst_sink_fun
 
         return dst_sink_fun
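
`_get_jit_func` replaces the old `jit(sink_fun, jit_config=jit_config)` call by unpacking the JitConfig dict into keyword arguments of the functional `jit` API: an empty `jit_level` falls back to ``O0``, ``O2`` is rewritten as ``O0`` on the ``GE`` backend, and a ``STRICT`` syntax level becomes `fullgraph=True`. A pure-Python sketch of that translation table (no MindSpore import; `exc_mode` and `infer_boost` pass through unchanged and are omitted here):

def translate_jit_config(jit_config_dict):
    # Mirror of _get_jit_func's mapping, for inspection without MindSpore.
    jit_level = jit_config_dict.get('jit_level', '') or "O0"   # "" falls back to O0
    backend = ""
    if jit_level == "O2":                 # O2 is expressed as O0 on the GE backend
        jit_level, backend = "O0", "GE"
    backend = jit_config_dict.get("backend", backend)          # an explicit backend wins
    fullgraph = jit_config_dict.get('jit_syntax_level') == "STRICT"
    return {"jit_level": jit_level, "backend": backend, "fullgraph": fullgraph}

assert translate_jit_config({'jit_level': "O2", 'jit_syntax_level': "LAX"}) == \
    {"jit_level": "O0", "backend": "GE", "fullgraph": False}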
@@ -15,7 +15,6 @@
 """Dataset help for minddata dataset"""
 from __future__ import absolute_import
 
-import os
 import math
 import copy
 
@@ -25,9 +24,11 @@ from mindspore.common._auto_dynamic import is_auto_dynamic, convert_new_shapes
 from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.api import _cell_graph_executor, _is_args_fullmode, ARG_SPECIFIED
 from mindspore.common._utils import is_shape_unknown
+from mindspore.dataset.core import config as dataset_config
 from mindspore.dataset.engine import offload
 from mindspore import context, nn
-from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
+from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
+    _construct_tensor_list, enable_data_broadcast
 from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_to_full, \
     _to_full_shapes, _get_pipeline_stages, _change_symbols_for_parallel, _is_in_auto_parallel_mode, \
     _origin_shapes, _dynamic_shape_for_dataset
@@ -213,8 +214,7 @@ def _get_dataset_aux(dataset):
 
 def connect_network_with_dataset(network, dataset_helper):
     """
-    Connect the `network` with dataset in `dataset_helper`. Only supported in `sink mode
-    <https://mindspore.cn/docs/en/master/model_train/train_process/train_optimize.html>`_,
+    Connect the `network` with dataset in `dataset_helper`. Only supported in sink mode,
     (dataset_sink_mode=True).
 
     Args:
@@ -263,16 +263,14 @@ def connect_network_with_dataset(network, dataset_helper):
             "The dataset has been connected to other network, please check the code.")
     is_dynamic = bool(network.get_inputs())
     queue_name = dataset.__transfer_dataset__.queue_name
+
     # In pipeline parallel, some stages have no GetNext, should not get in.
+    # Don't enable dynamic shape(multi-subgraph) feature in pp/dataset_broadcast mode,
+    # otherwise get_data_info will stuck since some rank do not consume data.
     use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
+    data_broadcast = enable_data_broadcast()
 
-    # temp env to disable dynamic feature of sink size 1
-    dynamic_sink1_env = os.getenv("MS_DEV_DYNAMIC_SINK1", None)
-    dynamic_sink1 = True
-    if dynamic_sink1_env and dynamic_sink1_env.strip() in ['False', 'false']:
-        dynamic_sink1 = False
-
-    if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel and dynamic_sink1:
+    if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel and not data_broadcast:
         dataset_types, dataset_shapes = dataset_helper.get_data_info()
         # Need to do full_batch for shapes which also do in the _DatasetIterMSLoopSink
         if _need_to_full():
@@ -314,7 +312,7 @@ def connect_network_with_dataset(network, dataset_helper):
         aux.__shape_type__ = str(dataset_types) + str(dataset_shapes)
 
     if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic) and \
-            not use_pipeline_parallel and dynamic_sink1:
+            not use_pipeline_parallel and not data_broadcast:
         dataset_helper.get_data_info()
     network.add_flags(sink_mode=True)
     return network
@@ -336,11 +334,11 @@ class DatasetHelper:
         dataset_sink_mode (bool): If the value is True, GetNext is employed to fetch the data at device through the
             dataset pipeline, otherwise fetch the data at host by iterating through the dataset.
             Default: ``True``.
-        sink_size (int): Control the amount of data in each sink.
+        sink_size (int): Control the amount of data in each sink. Must be -1 or positive.
             If sink_size=-1, sink the complete dataset for each epoch.
             If sink_size>0, sink sink_size data for each epoch.
-            Default: -1.
-        epoch_num (int): The number of passes of the entire dataset to be sent. Default: 1.
+            Default: ``-1``.
+        epoch_num (int): The number of passes of the entire dataset to be sent. Default: ``1``.
 
     Examples:
         >>> import numpy as np
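
The Examples block is truncated in this hunk. As an illustration of the `sink_size` semantics documented above, a minimal host-side sketch (not the package's own example; the `NumpySlicesDataset` and column name are chosen for this sketch):

import numpy as np
import mindspore.dataset as ds
from mindspore.train import DatasetHelper

data = {"x": np.random.rand(100, 2).astype(np.float32)}
dataset = ds.NumpySlicesDataset(data, shuffle=False)

# sink_size=-1 sends the full 100 samples per epoch; sink_size=10 would send
# only 10. dataset_sink_mode=False keeps the loop on host for this sketch.
helper = DatasetHelper(dataset, dataset_sink_mode=False, sink_size=-1, epoch_num=1)
for item in helper:
    pass  # each item is a tuple of tensors produced by the pipeline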
@@ -686,8 +684,9 @@ class _DatasetIterNormal:
         self.dataset = dataset
         self.device_num = _get_device_num()
         self.global_rank = _get_global_rank()
+        do_copy = dataset_config.get_iterator_mode()["do_copy"]
         self.iter = self.dataset.create_tuple_iterator(
-            num_epochs=epoch_num, do_copy=True)
+            num_epochs=epoch_num, do_copy=do_copy)
 
     def __iter__(self):
         return self
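
`_DatasetIterNormal` now takes `do_copy` from the global iterator-mode configuration instead of hard-coding ``True``. A sketch of the same lookup at user level, assuming (as the hunk does) that `get_iterator_mode()` returns a dict with a ``"do_copy"`` key:

import numpy as np
import mindspore.dataset as ds
from mindspore.dataset.core import config as dataset_config

dataset = ds.NumpySlicesDataset({"x": np.arange(4, dtype=np.float32)})

# Honor the globally configured iterator mode, as _DatasetIterNormal now does.
do_copy = dataset_config.get_iterator_mode()["do_copy"]
iterator = dataset.create_tuple_iterator(num_epochs=1, do_copy=do_copy)
for row in iterator:
    pass  # row is a tuple of tensors; do_copy controls host-side copying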
@@ -51,9 +51,10 @@ class FixedLossScaleManager(LossScaleManager):
     inherits from :class:`mindspore.amp.LossScaleManager`.
 
     Args:
-        loss_scale (float): Magnification factor of gradients. Note that if `drop_overflow_update` is set to ``False`` ,
+        loss_scale (float, optional): Magnification factor of gradients.
+            Note that if `drop_overflow_update` is set to ``False`` ,
             the value of `loss_scale` in optimizer should be set to the same as here. Default: ``128.0`` .
-        drop_overflow_update (bool): Whether to execute optimizer if there is an overflow.
+        drop_overflow_update (bool, optional): Whether to execute optimizer if there is an overflow.
             If ``True`` , the optimizer will
             not executed when overflow occurs. Default: ``True`` .
@@ -110,8 +111,8 @@ class FixedLossScaleManager(LossScaleManager):
 
         Returns:
             None or :class:`mindspore.nn.FixedLossScaleUpdateCell`. Instance of
-            :class:`mindspore.nn.FixedLossScaleUpdateCell` when `drop_overflow_update` is True. None when
-            `drop_overflow_update` is False.
+            :class:`mindspore.nn.FixedLossScaleUpdateCell` when `drop_overflow_update` is ``True``. None when
+            `drop_overflow_update` is ``False``.
         """
         if not self._drop_overflow_update:
             return None
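
A quick check of the return contract documented above, using the public `mindspore.amp` re-export of this class:

from mindspore.amp import FixedLossScaleManager

# drop_overflow_update=False: gradients are applied unconditionally, so no
# update cell is needed and get_update_cell() returns None.
assert FixedLossScaleManager(loss_scale=128.0,
                             drop_overflow_update=False).get_update_cell() is None

# drop_overflow_update=True: a mindspore.nn.FixedLossScaleUpdateCell is returned.
cell = FixedLossScaleManager(loss_scale=128.0,
                             drop_overflow_update=True).get_update_cell()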
@@ -124,9 +125,9 @@ class DynamicLossScaleManager(LossScaleManager):
     adjusted, inherits from :class:`mindspore.amp.LossScaleManager`.
 
     Args:
-        init_loss_scale (float): Initialize loss scale. Default: ``2 ** 24`` .
-        scale_factor (int): Coefficient of increase and decrease. Default: ``2`` .
-        scale_window (int): Maximum continuous normal steps when there is no overflow. Default: ``2000`` .
+        init_loss_scale (float, optional): Initialize loss scale. Default: ``2 ** 24`` .
+        scale_factor (int, optional): Coefficient of increase and decrease. Default: ``2`` .
+        scale_window (int, optional): Maximum continuous normal steps when there is no overflow. Default: ``2000`` .
 
     Supported Platforms:
         ``Ascend`` ``GPU``
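
The three arguments control a standard dynamic loss-scaling policy: shrink by `scale_factor` on overflow, grow by `scale_factor` after `scale_window` consecutive clean steps. A simplified sketch of that rule (not the class's exact code; the clamp to at least 1 mirrors common implementations):

def adjust_loss_scale(scale, overflow, good_steps,
                      scale_factor=2, scale_window=2000):
    """Return the new (scale, good_steps) pair after one training step."""
    if overflow:
        return max(scale / scale_factor, 1), 0   # shrink and reset the window
    good_steps += 1
    if good_steps >= scale_window:
        return scale * scale_factor, 0           # grow after a clean window
    return scale, good_steps

scale, good = 2 ** 24, 0
scale, good = adjust_loss_scale(scale, overflow=True, good_steps=good)
assert scale == 2 ** 23 and good == 0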