mindspore-2.4.10-cp39-cp39-win_amd64.whl → mindspore-2.6.0-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (579)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +13 -6
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -0
  7. mindspore/_checkparam.py +3 -38
  8. mindspore/_deprecated/__init__.py +17 -0
  9. mindspore/_deprecated/jit.py +198 -0
  10. mindspore/_extends/builtin_operations.py +1 -1
  11. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  12. mindspore/_extends/parse/__init__.py +6 -7
  13. mindspore/_extends/parse/compile_config.py +83 -0
  14. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +47 -198
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +229 -99
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +11 -5
  27. mindspore/avcodec-59.dll +0 -0
  28. mindspore/avdevice-59.dll +0 -0
  29. mindspore/avfilter-8.dll +0 -0
  30. mindspore/avformat-59.dll +0 -0
  31. mindspore/avutil-57.dll +0 -0
  32. mindspore/boost/__init__.py +2 -2
  33. mindspore/boost/base.py +3 -7
  34. mindspore/boost/boost_cell_wrapper.py +138 -43
  35. mindspore/common/__init__.py +6 -3
  36. mindspore/common/_grad_function.py +56 -0
  37. mindspore/common/_pijit_context.py +14 -5
  38. mindspore/common/_register_for_tensor.py +1 -2
  39. mindspore/common/_stub_tensor.py +30 -14
  40. mindspore/common/_tensor_cpp_method.py +17 -0
  41. mindspore/common/_tensor_docs.py +4760 -0
  42. mindspore/common/api.py +480 -372
  43. mindspore/common/auto_dynamic_shape.py +41 -44
  44. mindspore/common/dtype.py +39 -36
  45. mindspore/common/dump.py +9 -6
  46. mindspore/common/file_system.py +9 -1
  47. mindspore/common/generator.py +5 -0
  48. mindspore/common/hook_handle.py +6 -2
  49. mindspore/common/initializer.py +13 -10
  50. mindspore/common/jit_begin_end.py +94 -0
  51. mindspore/common/jit_config.py +6 -1
  52. mindspore/common/jit_context.py +76 -0
  53. mindspore/common/jit_trace.py +378 -0
  54. mindspore/common/lazy_inline.py +9 -3
  55. mindspore/common/mindir_util.py +10 -2
  56. mindspore/common/mutable.py +5 -4
  57. mindspore/common/parameter.py +135 -52
  58. mindspore/common/seed.py +2 -2
  59. mindspore/common/sparse_tensor.py +23 -17
  60. mindspore/common/tensor.py +975 -1981
  61. mindspore/communication/__init__.py +7 -5
  62. mindspore/communication/_comm_helper.py +52 -2
  63. mindspore/communication/comm_func.py +240 -181
  64. mindspore/communication/management.py +95 -26
  65. mindspore/context.py +324 -573
  66. mindspore/dataset/__init__.py +65 -37
  67. mindspore/dataset/audio/__init__.py +2 -8
  68. mindspore/dataset/audio/transforms.py +3 -17
  69. mindspore/dataset/callback/ds_callback.py +2 -1
  70. mindspore/dataset/core/config.py +87 -6
  71. mindspore/dataset/engine/cache_admin.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +6 -5
  73. mindspore/dataset/engine/datasets.py +292 -267
  74. mindspore/dataset/engine/datasets_audio.py +22 -8
  75. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  76. mindspore/dataset/engine/datasets_text.py +78 -48
  77. mindspore/dataset/engine/datasets_user_defined.py +183 -117
  78. mindspore/dataset/engine/datasets_vision.py +120 -44
  79. mindspore/dataset/engine/iterators.py +283 -63
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/obs/util.py +8 -0
  82. mindspore/dataset/engine/queue.py +40 -0
  83. mindspore/dataset/engine/samplers.py +289 -43
  84. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  85. mindspore/dataset/engine/validators.py +53 -11
  86. mindspore/dataset/text/__init__.py +7 -6
  87. mindspore/dataset/text/transforms.py +6 -5
  88. mindspore/dataset/text/utils.py +3 -3
  89. mindspore/dataset/transforms/__init__.py +0 -9
  90. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  91. mindspore/dataset/transforms/transforms.py +31 -14
  92. mindspore/dataset/utils/browse_dataset.py +1 -1
  93. mindspore/dataset/vision/__init__.py +2 -9
  94. mindspore/dataset/vision/transforms.py +202 -158
  95. mindspore/dataset/vision/utils.py +7 -5
  96. mindspore/dataset/vision/validators.py +1 -2
  97. mindspore/device_context/__init__.py +21 -0
  98. mindspore/device_context/ascend/__init__.py +25 -0
  99. mindspore/device_context/ascend/device.py +72 -0
  100. mindspore/device_context/ascend/op_debug.py +153 -0
  101. mindspore/device_context/ascend/op_precision.py +193 -0
  102. mindspore/device_context/ascend/op_tuning.py +123 -0
  103. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  104. mindspore/device_context/cpu/device.py +62 -0
  105. mindspore/device_context/cpu/op_tuning.py +43 -0
  106. mindspore/device_context/gpu/__init__.py +21 -0
  107. mindspore/device_context/gpu/device.py +70 -0
  108. mindspore/device_context/gpu/op_precision.py +67 -0
  109. mindspore/device_context/gpu/op_tuning.py +175 -0
  110. mindspore/device_manager.py +170 -0
  111. mindspore/dnnl.dll +0 -0
  112. mindspore/experimental/es/embedding_service.py +35 -27
  113. mindspore/experimental/llm_boost/__init__.py +1 -0
  114. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  115. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
  116. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  117. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  118. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  119. mindspore/experimental/llm_boost/register.py +1 -0
  120. mindspore/experimental/map_parameter.py +4 -4
  121. mindspore/experimental/optim/adadelta.py +6 -6
  122. mindspore/experimental/optim/adagrad.py +4 -4
  123. mindspore/experimental/optim/adam.py +7 -0
  124. mindspore/experimental/optim/adamax.py +4 -4
  125. mindspore/experimental/optim/adamw.py +4 -0
  126. mindspore/experimental/optim/asgd.py +1 -1
  127. mindspore/experimental/optim/lr_scheduler.py +73 -46
  128. mindspore/experimental/optim/radam.py +34 -31
  129. mindspore/experimental/optim/rprop.py +1 -1
  130. mindspore/experimental/optim/sgd.py +1 -1
  131. mindspore/hal/contiguous_tensors_handle.py +6 -10
  132. mindspore/hal/device.py +55 -53
  133. mindspore/hal/event.py +52 -52
  134. mindspore/hal/memory.py +179 -120
  135. mindspore/hal/stream.py +150 -109
  136. mindspore/include/api/context.h +0 -1
  137. mindspore/include/dataset/constants.h +7 -4
  138. mindspore/include/dataset/execute.h +2 -2
  139. mindspore/jpeg62.dll +0 -0
  140. mindspore/log.py +50 -0
  141. mindspore/mindrecord/__init__.py +21 -8
  142. mindspore/mindrecord/config.py +17 -316
  143. mindspore/mindrecord/filereader.py +1 -9
  144. mindspore/mindrecord/filewriter.py +5 -15
  145. mindspore/mindrecord/mindpage.py +1 -9
  146. mindspore/mindspore_backend_common.dll +0 -0
  147. mindspore/mindspore_backend_manager.dll +0 -0
  148. mindspore/mindspore_common.dll +0 -0
  149. mindspore/mindspore_core.dll +0 -0
  150. mindspore/mindspore_dump.dll +0 -0
  151. mindspore/mindspore_frontend.dll +0 -0
  152. mindspore/mindspore_glog.dll +0 -0
  153. mindspore/mindspore_memory_pool.dll +0 -0
  154. mindspore/mindspore_ms_backend.dll +0 -0
  155. mindspore/mindspore_ops.dll +0 -0
  156. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  157. mindspore/mindspore_ops_kernel_common.dll +0 -0
  158. mindspore/mindspore_profiler.dll +0 -0
  159. mindspore/mindspore_pyboost.dll +0 -0
  160. mindspore/mindspore_pynative.dll +0 -0
  161. mindspore/mindspore_res_manager.dll +0 -0
  162. mindspore/mindspore_runtime_pipeline.dll +0 -0
  163. mindspore/mint/__init__.py +798 -761
  164. mindspore/mint/distributed/__init__.py +70 -4
  165. mindspore/mint/distributed/distributed.py +2679 -44
  166. mindspore/mint/linalg/__init__.py +8 -0
  167. mindspore/mint/nn/__init__.py +743 -22
  168. mindspore/mint/nn/functional.py +716 -23
  169. mindspore/mint/nn/layer/__init__.py +21 -4
  170. mindspore/mint/nn/layer/_functions.py +334 -0
  171. mindspore/mint/nn/layer/activation.py +276 -1
  172. mindspore/mint/nn/layer/basic.py +123 -0
  173. mindspore/mint/nn/layer/conv.py +933 -0
  174. mindspore/mint/nn/layer/normalization.py +223 -28
  175. mindspore/mint/nn/layer/padding.py +797 -0
  176. mindspore/mint/nn/layer/pooling.py +235 -0
  177. mindspore/mint/optim/__init__.py +3 -1
  178. mindspore/mint/optim/adam.py +223 -0
  179. mindspore/mint/optim/adamw.py +26 -19
  180. mindspore/mint/optim/sgd.py +171 -0
  181. mindspore/mint/special/__init__.py +2 -1
  182. mindspore/multiprocessing/__init__.py +5 -0
  183. mindspore/nn/__init__.py +4 -1
  184. mindspore/nn/cell.py +1373 -192
  185. mindspore/nn/dynamic_lr.py +2 -1
  186. mindspore/nn/layer/activation.py +29 -27
  187. mindspore/nn/layer/basic.py +51 -35
  188. mindspore/nn/layer/channel_shuffle.py +3 -3
  189. mindspore/nn/layer/container.py +1 -1
  190. mindspore/nn/layer/conv.py +53 -42
  191. mindspore/nn/layer/embedding.py +12 -11
  192. mindspore/nn/layer/normalization.py +56 -49
  193. mindspore/nn/layer/padding.py +4 -3
  194. mindspore/nn/layer/pooling.py +120 -42
  195. mindspore/nn/layer/rnn_cells.py +1 -1
  196. mindspore/nn/layer/rnns.py +2 -1
  197. mindspore/nn/layer/timedistributed.py +5 -5
  198. mindspore/nn/layer/transformer.py +59 -36
  199. mindspore/nn/learning_rate_schedule.py +8 -4
  200. mindspore/nn/loss/loss.py +58 -55
  201. mindspore/nn/optim/ada_grad.py +7 -5
  202. mindspore/nn/optim/adadelta.py +11 -9
  203. mindspore/nn/optim/adafactor.py +1 -1
  204. mindspore/nn/optim/adam.py +19 -15
  205. mindspore/nn/optim/adamax.py +8 -7
  206. mindspore/nn/optim/adasum.py +5 -5
  207. mindspore/nn/optim/asgd.py +3 -1
  208. mindspore/nn/optim/ftrl.py +11 -9
  209. mindspore/nn/optim/lamb.py +1 -1
  210. mindspore/nn/optim/lars.py +1 -4
  211. mindspore/nn/optim/lazyadam.py +12 -10
  212. mindspore/nn/optim/momentum.py +7 -6
  213. mindspore/nn/optim/optimizer.py +3 -3
  214. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  215. mindspore/nn/optim/rmsprop.py +13 -12
  216. mindspore/nn/optim/rprop.py +11 -9
  217. mindspore/nn/optim/sgd.py +9 -6
  218. mindspore/nn/optim/tft_wrapper.py +5 -2
  219. mindspore/nn/optim/thor.py +2 -1
  220. mindspore/nn/probability/bijector/bijector.py +17 -11
  221. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  222. mindspore/nn/probability/bijector/invert.py +2 -2
  223. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  224. mindspore/nn/probability/bijector/softplus.py +3 -2
  225. mindspore/nn/probability/distribution/beta.py +3 -3
  226. mindspore/nn/probability/distribution/categorical.py +1 -1
  227. mindspore/nn/probability/distribution/cauchy.py +4 -2
  228. mindspore/nn/probability/distribution/exponential.py +6 -7
  229. mindspore/nn/probability/distribution/gamma.py +2 -2
  230. mindspore/nn/probability/distribution/gumbel.py +2 -2
  231. mindspore/nn/probability/distribution/half_normal.py +5 -3
  232. mindspore/nn/probability/distribution/logistic.py +5 -3
  233. mindspore/nn/probability/distribution/poisson.py +1 -1
  234. mindspore/nn/probability/distribution/uniform.py +5 -3
  235. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  236. mindspore/nn/reinforcement/tensor_array.py +1 -1
  237. mindspore/nn/utils/init.py +13 -11
  238. mindspore/nn/wrap/__init__.py +6 -6
  239. mindspore/nn/wrap/cell_wrapper.py +181 -122
  240. mindspore/nn/wrap/grad_reducer.py +45 -36
  241. mindspore/nn/wrap/loss_scale.py +6 -7
  242. mindspore/numpy/array_creations.py +63 -65
  243. mindspore/numpy/array_ops.py +149 -144
  244. mindspore/numpy/logic_ops.py +41 -42
  245. mindspore/numpy/math_ops.py +361 -359
  246. mindspore/numpy/utils.py +17 -18
  247. mindspore/numpy/utils_const.py +5 -6
  248. mindspore/opencv_core452.dll +0 -0
  249. mindspore/opencv_imgcodecs452.dll +0 -0
  250. mindspore/opencv_imgproc452.dll +0 -0
  251. mindspore/ops/__init__.py +5 -3
  252. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  253. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  256. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  257. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  258. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  259. mindspore/ops/_register_for_op.py +0 -11
  260. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  261. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  262. mindspore/ops/_vmap/vmap_array_ops.py +52 -25
  263. mindspore/ops/_vmap/vmap_base.py +0 -2
  264. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  265. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  266. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  267. mindspore/ops/auto_generate/__init__.py +4 -3
  268. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
  269. mindspore/ops/auto_generate/gen_extend_func.py +757 -185
  270. mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
  271. mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
  272. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  273. mindspore/ops/composite/__init__.py +2 -1
  274. mindspore/ops/composite/base.py +20 -25
  275. mindspore/ops/composite/math_ops.py +6 -16
  276. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  277. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  278. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  279. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  280. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  281. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  282. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  283. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  284. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  285. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  286. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  287. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  288. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  289. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  290. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  291. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  292. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  293. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  294. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  295. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  299. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  301. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  302. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  303. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  304. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  305. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  306. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  307. mindspore/ops/function/__init__.py +40 -2
  308. mindspore/ops/function/_add_attr_func.py +58 -0
  309. mindspore/ops/function/array_func.py +2089 -2403
  310. mindspore/ops/function/clip_func.py +80 -23
  311. mindspore/ops/function/debug_func.py +57 -57
  312. mindspore/ops/function/grad/__init__.py +1 -0
  313. mindspore/ops/function/grad/grad_func.py +104 -71
  314. mindspore/ops/function/image_func.py +2 -2
  315. mindspore/ops/function/linalg_func.py +47 -78
  316. mindspore/ops/function/math_func.py +4351 -3813
  317. mindspore/ops/function/nn_func.py +1712 -637
  318. mindspore/ops/function/other_func.py +159 -1
  319. mindspore/ops/function/parameter_func.py +18 -84
  320. mindspore/ops/function/random_func.py +452 -387
  321. mindspore/ops/function/reshard_func.py +4 -70
  322. mindspore/ops/function/sparse_func.py +3 -3
  323. mindspore/ops/function/sparse_unary_func.py +6 -6
  324. mindspore/ops/function/spectral_func.py +25 -58
  325. mindspore/ops/function/vmap_func.py +26 -18
  326. mindspore/ops/functional.py +23 -7
  327. mindspore/ops/functional_overload.py +1548 -0
  328. mindspore/ops/op_info_register.py +32 -244
  329. mindspore/ops/operations/__init__.py +23 -15
  330. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  331. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  332. mindspore/ops/operations/_grad_ops.py +2 -43
  333. mindspore/ops/operations/_infer_ops.py +2 -1
  334. mindspore/ops/operations/_inner_ops.py +43 -84
  335. mindspore/ops/operations/_ms_kernel.py +4 -10
  336. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  337. mindspore/ops/operations/_scalar_ops.py +3 -2
  338. mindspore/ops/operations/_sequence_ops.py +1 -1
  339. mindspore/ops/operations/_tensor_array.py +1 -1
  340. mindspore/ops/operations/array_ops.py +81 -324
  341. mindspore/ops/operations/comm_ops.py +154 -108
  342. mindspore/ops/operations/custom_ops.py +298 -87
  343. mindspore/ops/operations/debug_ops.py +157 -59
  344. mindspore/ops/operations/inner_ops.py +7 -5
  345. mindspore/ops/operations/linalg_ops.py +1 -57
  346. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  347. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  348. mindspore/ops/operations/math_ops.py +32 -234
  349. mindspore/ops/operations/nn_ops.py +212 -531
  350. mindspore/ops/operations/other_ops.py +62 -9
  351. mindspore/ops/operations/random_ops.py +13 -7
  352. mindspore/ops/operations/reshard_ops.py +1 -1
  353. mindspore/ops/operations/sparse_ops.py +2 -2
  354. mindspore/ops/primitive.py +66 -53
  355. mindspore/ops/tensor_method.py +1895 -0
  356. mindspore/ops_generate/__init__.py +0 -5
  357. mindspore/ops_generate/aclnn/__init__.py +0 -0
  358. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  359. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  360. mindspore/ops_generate/api/__init__.py +0 -0
  361. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  362. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  363. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  364. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  365. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  366. mindspore/ops_generate/api/gen_api.py +103 -0
  367. mindspore/ops_generate/api/op_api_proto.py +235 -0
  368. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  369. mindspore/ops_generate/common/__init__.py +0 -0
  370. mindspore/ops_generate/common/base_generator.py +11 -0
  371. mindspore/ops_generate/common/gen_constants.py +91 -0
  372. mindspore/ops_generate/common/gen_utils.py +348 -0
  373. mindspore/ops_generate/common/op_proto.py +473 -0
  374. mindspore/ops_generate/common/template.py +523 -0
  375. mindspore/ops_generate/gen_ops.py +22 -1069
  376. mindspore/ops_generate/op_def/__init__.py +0 -0
  377. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  378. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  379. mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
  380. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  381. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  382. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  383. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  384. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  385. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  386. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  387. mindspore/ops_generate/pyboost/__init__.py +0 -0
  388. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  389. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  390. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  391. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  392. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  393. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  394. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  395. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  396. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  397. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  398. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  399. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  400. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  401. mindspore/ops_generate/resources/__init__.py +0 -0
  402. mindspore/ops_generate/resources/resource_list.py +30 -0
  403. mindspore/ops_generate/resources/resource_loader.py +36 -0
  404. mindspore/ops_generate/resources/resource_manager.py +64 -0
  405. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  406. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  407. mindspore/parallel/__init__.py +7 -3
  408. mindspore/parallel/_auto_parallel_context.py +159 -40
  409. mindspore/parallel/_cell_wrapper.py +132 -15
  410. mindspore/parallel/_parallel_serialization.py +107 -5
  411. mindspore/parallel/_ps_context.py +1 -1
  412. mindspore/parallel/_recovery_context.py +7 -2
  413. mindspore/parallel/_tensor.py +142 -18
  414. mindspore/parallel/_utils.py +199 -23
  415. mindspore/parallel/algo_parameter_config.py +4 -4
  416. mindspore/parallel/auto_parallel.py +732 -0
  417. mindspore/parallel/checkpoint_convert.py +159 -0
  418. mindspore/parallel/checkpoint_transform.py +700 -35
  419. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  420. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  421. mindspore/parallel/cluster/run.py +21 -4
  422. mindspore/parallel/function/__init__.py +24 -0
  423. mindspore/parallel/function/reshard_func.py +258 -0
  424. mindspore/parallel/nn/__init__.py +25 -0
  425. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  426. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  427. mindspore/parallel/parameter_broadcast.py +25 -14
  428. mindspore/parallel/shard.py +137 -59
  429. mindspore/parallel/transform_safetensors.py +364 -305
  430. mindspore/profiler/__init__.py +22 -5
  431. mindspore/profiler/analysis/__init__.py +0 -0
  432. mindspore/profiler/analysis/parser/__init__.py +0 -0
  433. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  434. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  435. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  436. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  437. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  438. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  439. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  440. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  441. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
  442. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  443. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  444. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  445. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  446. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  447. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  448. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  449. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  450. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  451. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  452. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  453. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  454. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  455. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  456. mindspore/profiler/analysis/task_manager.py +131 -0
  457. mindspore/profiler/analysis/time_converter.py +84 -0
  458. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  459. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  460. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  461. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  462. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  463. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  464. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  465. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  466. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  467. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  468. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  469. mindspore/profiler/analysis/work_flow.py +73 -0
  470. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  471. mindspore/profiler/common/command_executor.py +90 -0
  472. mindspore/profiler/common/constant.py +186 -3
  473. mindspore/profiler/common/file_manager.py +208 -0
  474. mindspore/profiler/common/log.py +130 -0
  475. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  476. mindspore/profiler/common/path_manager.py +395 -0
  477. mindspore/profiler/common/process_bar.py +168 -0
  478. mindspore/profiler/common/process_pool.py +9 -3
  479. mindspore/profiler/common/profiler_context.py +500 -0
  480. mindspore/profiler/common/profiler_info.py +304 -0
  481. mindspore/profiler/common/profiler_meta_data.py +74 -0
  482. mindspore/profiler/common/profiler_output_path.py +284 -0
  483. mindspore/profiler/common/profiler_parameters.py +251 -0
  484. mindspore/profiler/common/profiler_path_manager.py +179 -0
  485. mindspore/profiler/common/record_function.py +76 -0
  486. mindspore/profiler/common/tlv_decoder.py +76 -0
  487. mindspore/profiler/common/util.py +75 -2
  488. mindspore/profiler/dynamic_profiler.py +341 -75
  489. mindspore/profiler/envprofiler.py +163 -0
  490. mindspore/profiler/experimental_config.py +197 -0
  491. mindspore/profiler/mstx.py +242 -0
  492. mindspore/profiler/platform/__init__.py +21 -0
  493. mindspore/profiler/platform/base_profiler.py +40 -0
  494. mindspore/profiler/platform/cpu_profiler.py +124 -0
  495. mindspore/profiler/platform/gpu_profiler.py +74 -0
  496. mindspore/profiler/platform/npu_profiler.py +335 -0
  497. mindspore/profiler/profiler.py +1073 -90
  498. mindspore/profiler/profiler_action_controller.py +187 -0
  499. mindspore/profiler/profiler_interface.py +118 -0
  500. mindspore/profiler/schedule.py +243 -0
  501. mindspore/rewrite/api/node.py +15 -13
  502. mindspore/rewrite/api/symbol_tree.py +2 -3
  503. mindspore/run_check/_check_version.py +27 -20
  504. mindspore/run_check/run_check.py +1 -1
  505. mindspore/runtime/__init__.py +37 -0
  506. mindspore/runtime/device.py +27 -0
  507. mindspore/runtime/event.py +209 -0
  508. mindspore/runtime/executor.py +177 -0
  509. mindspore/runtime/memory.py +416 -0
  510. mindspore/runtime/stream.py +460 -0
  511. mindspore/runtime/thread_bind_core.py +401 -0
  512. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  513. mindspore/swresample-4.dll +0 -0
  514. mindspore/swscale-6.dll +0 -0
  515. mindspore/tinyxml2.dll +0 -0
  516. mindspore/train/__init__.py +8 -8
  517. mindspore/train/_utils.py +96 -27
  518. mindspore/train/amp.py +9 -5
  519. mindspore/train/callback/__init__.py +2 -2
  520. mindspore/train/callback/_callback.py +2 -16
  521. mindspore/train/callback/_checkpoint.py +53 -55
  522. mindspore/train/callback/_cluster_monitor.py +14 -18
  523. mindspore/train/callback/_early_stop.py +1 -1
  524. mindspore/train/callback/_flops_collector.py +103 -68
  525. mindspore/train/callback/_history.py +8 -5
  526. mindspore/train/callback/_lambda_callback.py +2 -2
  527. mindspore/train/callback/_landscape.py +0 -3
  528. mindspore/train/callback/_loss_monitor.py +2 -1
  529. mindspore/train/callback/_on_request_exit.py +6 -5
  530. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  531. mindspore/train/callback/_summary_collector.py +52 -19
  532. mindspore/train/callback/_time_monitor.py +2 -1
  533. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
  534. mindspore/train/data_sink.py +25 -2
  535. mindspore/train/dataset_helper.py +15 -16
  536. mindspore/train/loss_scale_manager.py +8 -7
  537. mindspore/train/metrics/accuracy.py +3 -3
  538. mindspore/train/metrics/confusion_matrix.py +9 -9
  539. mindspore/train/metrics/error.py +3 -3
  540. mindspore/train/metrics/hausdorff_distance.py +4 -4
  541. mindspore/train/metrics/mean_surface_distance.py +3 -3
  542. mindspore/train/metrics/metric.py +0 -12
  543. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  544. mindspore/train/metrics/precision.py +11 -10
  545. mindspore/train/metrics/recall.py +9 -9
  546. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  547. mindspore/train/mind_ir_pb2.py +174 -46
  548. mindspore/train/model.py +269 -136
  549. mindspore/train/serialization.py +622 -978
  550. mindspore/train/summary/_summary_adapter.py +2 -2
  551. mindspore/train/summary/summary_record.py +2 -3
  552. mindspore/train/train_thor/model_thor.py +1 -1
  553. mindspore/turbojpeg.dll +0 -0
  554. mindspore/utils/__init__.py +6 -3
  555. mindspore/utils/dryrun.py +140 -0
  556. mindspore/utils/hooks.py +81 -0
  557. mindspore/utils/runtime_execution_order_check.py +552 -0
  558. mindspore/utils/utils.py +138 -4
  559. mindspore/version.py +1 -1
  560. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
  561. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +564 -395
  562. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
  563. mindspore/_install_custom.py +0 -43
  564. mindspore/common/_register_for_adapter.py +0 -74
  565. mindspore/common/_tensor_overload.py +0 -139
  566. mindspore/mindspore_np_dtype.dll +0 -0
  567. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  568. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  569. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  570. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  571. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  572. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  573. mindspore/ops_generate/gen_utils.py +0 -209
  574. mindspore/ops_generate/op_proto.py +0 -145
  575. mindspore/ops_generate/template.py +0 -261
  576. mindspore/profiler/envprofiling.py +0 -254
  577. mindspore/profiler/profiling.py +0 -1926
  578. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  579. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
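
Beyond the per-file churn, the list shows two new public packages, mindspore/runtime (device, event, executor, memory, stream, thread_bind_core) and mindspore/device_context (per-backend device, op_tuning and op_precision options), plus a new mindspore/device_manager.py, while mindspore/context.py loses roughly 250 lines. A hedged sketch of what this reorganization suggests for user code; the call names below are inferred from the new file names, not confirmed by this diff, so verify against the 2.6.0 documentation:

    import mindspore as ms

    # Device selection through the new device manager (device_manager.py)
    # instead of context flags; assumed entry point.
    ms.set_device("Ascend")

    # Stream/event control under mindspore.runtime (runtime/stream.py,
    # runtime/event.py above), mirroring the older mindspore.hal interfaces.
    stream = ms.runtime.Stream()
    with ms.runtime.StreamCtx(stream):
        pass  # work launched here would target `stream`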
@@ -15,24 +15,27 @@
15
15
  """Checkpoint related classes and functions."""
16
16
 
17
17
  import os
18
+ from mindspore.utils import _tft_handler
18
19
  from mindspore.train.serialization import save_checkpoint
19
- from mindspore.parallel._utils import _get_device_num
20
- from mindspore import _checkparam as Validator
21
20
  from mindspore.train.callback._callback import Callback
22
- from mindspore import context
21
+ from mindspore import context, ops
23
22
  from mindspore.common.parameter import Parameter
24
23
  from mindspore.common.tensor import Tensor
25
24
  from mindspore.communication import get_rank, get_group_size
26
25
  from mindspore import log as logger
27
26
  from mindspore.train.serialization import _get_cur_rank_dp
28
- from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post
27
+ from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post, _tft_sem_enable
28
+ from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm
29
29
  from mindspore._c_expression import clean_tdt_channel
30
- from mindspore._c_expression import send_recv
30
+ from mindspore._c_expression import send_recv, reset_params
31
31
  from mindspore._c_expression import CollectiveManager
32
32
  from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
33
- from mindspore._c_expression import Tensor as Tensor_
33
+ from mindspore._c_expression import TensorPy as Tensor_
34
+ from mindspore.ops.operations.manually_defined._inner import TensorReport
34
35
  import mindspore
35
36
  import mindspore.common.dtype as mstype
37
+ from mindspore.parallel._recovery_context import _set_recovery_context
38
+
36
39
 
37
40
  def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
38
41
  """ Common func to generate ckpt dir name."""
@@ -40,30 +43,38 @@ def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
40
43
  mid_dir = f"tft_saved_checkpoints-step_{str(step)}{tmp}"
41
44
  return os.path.join(ckpt_save_path, mid_dir)
42
45
 
46
+
43
47
  def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
44
48
  """ Callback used for TFT save ckpt function when errors occur."""
45
49
  logger.info("Enter _save_checkpoint_on_failure function")
46
- if not cb_ctx._is_params_consistent(): # pylint: disable=W0212
50
+ if not cb_ctx._is_params_consistent(): # pylint: disable=W0212
47
51
  raise RuntimeError("Can't save parameters, because they are left in inconsistent state!")
52
+ cb_params = args
53
+ # we record the current step and epoch num in on_train_step_end, so we can just reset it here
54
+ cb_params.cur_step_num = cb_ctx.cur_step_num
55
+ cb_params.cur_epoch_num = cb_ctx.cur_epoch_num
56
+ if cb_params.optimizer is not None:
57
+ cb_params.optimizer.global_step = cb_ctx.global_step
58
+ if hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
59
+ cb_params.network.optimizer.global_step = cb_ctx.global_step
60
+ append_dict = {}
61
+ append_dict["__exception_save__"] = True
62
+ # if user has provided a custom save callback, use it
63
+ if cb_ctx.save_cb:
64
+ cb_ctx.save_cb(cb_params, append_dict)
65
+ logger.info("Finish _save_checkpoint_on_failure function")
66
+ return
48
67
 
68
+ # if user has not provided a custom save callback, use default save logic
49
69
  ckpt_save_path = cb_ctx.ckpt_save_path
50
- cb_params = args
51
70
  cur_rank = get_rank()
52
- cur_step_num = cb_params.cur_step_num
71
+ step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
53
72
  cur_epoch_num = cb_params.cur_epoch_num
54
- batch_num = cb_params.batch_num
55
- if cur_step_num > step:
56
- cur_epoch_num = (step - 1) // batch_num + 1
57
- step_num_in_epoch = int((step - 1) % batch_num + 1)
58
-
59
- append_dict = {}
60
73
  append_dict["epoch_num"] = cur_epoch_num
61
- append_dict["step_num"] = step
74
+ append_dict["step_num"] = cb_params.cur_step_num
62
75
  append_dict["cur_rank"] = cur_rank
63
- append_dict["batch_num"] = batch_num
64
- append_dict["__exception_save__"] = True
65
-
66
- append_dict["global_step"] = Parameter([cb_ctx.global_step])
76
+ append_dict["batch_num"] = cb_params.batch_num
77
+ append_dict["global_step"] = cb_ctx.global_step
67
78
  outputs = cb_params.net_outputs
68
79
  if isinstance(outputs, (tuple, list)) and len(outputs) >= 3:
69
80
  append_dict["loss_scale"] = outputs[2]
@@ -76,47 +87,63 @@ def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
76
87
  integrated_save=False, append_dict=append_dict)
77
88
  logger.info("Finish _save_checkpoint_on_failure function")
78
89
 
90
+
79
91
  def _rename_save_result(step, cb_ctx):
80
92
  """ Callback used for TFT rename function after ckpt save callback was finished and successful."""
81
93
  logger.info("Enter _rename_save_result function")
94
+ if cb_ctx.save_cb:
95
+ logger.info("User's save callback is provided, skip rename")
96
+ return
82
97
  tmp_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, True)
83
98
  fin_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, False)
84
99
 
85
100
  os.rename(tmp_dir, fin_dir)
86
101
  logger.info("Finish _rename_save_result function")
87
102
 
103
+
88
104
  def _tft_exit_cb(ctx):
105
+ """Callback used for TFT exit function."""
89
106
  logger.error("Enter mindio ttp exit process, which means other ranks occur exception, check other ranks' logs!")
90
107
  _tft_sem_post()
91
- os._exit(1) # pylint: disable=W0212
108
+ os._exit(1) # pylint: disable=W0212
109
+
92
110
 
93
111
  def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
94
112
  """ Callback used for TFT repair function."""
95
- logger.info("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
96
- if(repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value\
97
- or repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value):
98
- logger.info("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
113
+ logger.warning("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
114
+ if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
115
+ cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value)):
116
+ logger.warning("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
99
117
  _repair_device(cb_ctx.device_id)
100
118
 
101
- if(repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value\
102
- or repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value):
103
- logger.info("Enter _tft_repair_callback SEND_RECV repair type: \
104
- {}, src_rank:{}, dst_rank: {}".format(repair_info["repair_type"], repair_info["src"], repair_info["dst"]))
119
+ if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
120
+ cb_ctx.tft.RepairType.RT_SEND.value,
121
+ cb_ctx.tft.RepairType.RT_RECV_REPAIR.value)):
122
+ logger.warning("Enter _tft_repair_callback SEND_RECV repair type:{}, src_rank:{}, dst_rank: {}".format(
123
+ repair_info["repair_type"], repair_info["src"], repair_info["dst"]))
105
124
  cb_params = args
106
- src_rank = repair_info["src"][0]
107
- dst_rank = repair_info["dst"][0]
108
- send_recv(cb_params.network.trainable_params(), src_rank, dst_rank)
109
- logger.info("Finish _tft_repair_callback")
125
+ if repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value:
126
+ for i in range(len(repair_info["src"])):
127
+ src_rank = repair_info["src"][i]
128
+ dst_rank = repair_info["dst"][i]
129
+ if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
130
+ raise ValueError("Call send_recv failed.")
131
+ else:
132
+ src_rank = repair_info["src"][0]
133
+ dst_rank = repair_info["dst"][0]
134
+ if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
135
+ raise ValueError("Call send_recv failed.")
136
+ logger.warning("Finish _tft_repair_callback")
110
137
 
111
138
 
112
139
  def _tft_clean_callback(is_uce_error, args, ctx):
113
140
  """ Callback used for TFT clean function."""
114
- logger.info("Enter _tft_clean_callback")
141
+ logger.warning("Enter _tft_clean_callback")
115
142
  ret = 0
116
143
  if is_uce_error:
117
144
  _get_uce_mem_info(ctx.device_id)
118
145
  err_strategy = _get_uce_process_strategy()
119
- logger.info("_tft_clean_callback err_strategy: {}".format(err_strategy))
146
+ logger.warning("_tft_clean_callback err_strategy: {}".format(err_strategy))
120
147
  if err_strategy == "RS_UCE_HIGHLEVEL":
121
148
  ret = 0
122
149
  elif err_strategy == "RS_UCE_LOWLEVEL":
@@ -124,59 +151,81 @@ def _tft_clean_callback(is_uce_error, args, ctx):
124
151
  else:
125
152
  ret = 1
126
153
  clean_tdt_channel()
127
- logger.info("Enter _tft_clean_callback resume_hccl_comm")
154
+ logger.warning("Enter _tft_clean_callback resume_hccl_comm")
128
155
  CollectiveManager.get_instance().resume_hccl_comm()
129
- logger.info("Finish _tft_clean_callback, ret: {}".format(ret))
156
+ logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
130
157
  return ret
131
158
 
132
159
 
133
160
  def _tft_stop_callback(args, cb_ctx):
134
161
  """ Callback used for TFT stop function."""
135
- logger.info("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
162
+ logger.warning("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
136
163
  _stop_device(cb_ctx.device_id)
137
- if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()): # pylint: disable=W0212
164
+ if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()): # pylint: disable=W0212
138
165
  raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
139
166
  cb_ctx.is_uce_rank = False
167
+ if cb_ctx.tft.tft_get_repair_type() == "recover":
168
+ logger.warning(f"Reset limit step")
169
+ cb_ctx.tft.tft_reset_limit_step()
140
170
  logger.info("Finish _tft_stop_callback")
141
171
 
142
172
 
143
- class TFTRegister(Callback):
173
+ def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
174
+ """Callback used for TFT Rebuild Group function."""
175
+ logger.warning(f"Enter _tft_rebuild_sub_groups, device id: ".format(ctx.device_id))
176
+ _finalize_comm()
177
+ _rebuild_world_group()
178
+ _rebuild_sub_group()
179
+ _set_recovery_context(is_arf=True)
180
+ logger.warning("Enter _tft_rebuild_sub_groups ok ")
181
+
182
+
183
+ class TrainFaultTolerance(Callback):
144
184
  """
145
185
  This callback is used to enable the TFT feature
146
- `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_.
147
- This callback will execute TFT operations during training process, such as TFT init, report and exception handle.
186
+ `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_
187
+ and will execute TFT operations during training process, such as TFT init, report and exception handle.
148
188
 
149
189
  Note:
150
190
  Required for Ascend graph mode only. And sink size must be less than or equal to 1.
151
191
 
152
192
  Args:
153
- ctrl_rank_id (int): TFT controller's running rank_id, used for init TFT controller.
154
- ctrl_ip (str): TFT controller's ip address, used for init TFT controller.
155
- ctrl_port (int): TFT controller's ip port, used for init TFT controller and processor.
156
- ckpt_save_path (str): Checkpoint save directory when failure occurs, checkpoint file will save to directory
157
- named ttp_saved_checkpoints-step_{cur_step_num} under this directory.
193
+ ckpt_save_path (str): Checkpoint save directory when failure occurs. When saved,
194
+ a new directory named 'ttp_saved_checkpoints-step_{cur_step_num}'
195
+ is created in that directory. Default: ``None``.
196
+ kwargs (dict): Other dictionary type parameters.
158
197
 
159
198
  Raises:
160
199
  Exception: TFT init failed.
161
200
  ModuleNotFoundError: Mindio TFT whl package is not installed.
162
201
 
163
202
  Examples:
203
+ .. note::
204
+ Before running the following examples, you need to configure the communication environment variables.
205
+
206
+ It's recommended to use the msrun startup method.
207
+ Please see the `msrun start up
208
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
209
+ for more details.
210
+
211
+ This example should be run with 4 devices.
212
+
164
213
  >>> import numpy as np
165
214
  >>> import os
166
215
  >>> import math
167
216
  >>> import mindspore as ms
168
217
  >>> import mindspore.dataset as ds
169
218
  >>> from mindspore import nn, ops, Parameter, train
170
- >>> from mindspore.communication import init
219
+ >>> from mindspore.communication import init, get_rank
171
220
  >>> from mindspore.common.initializer import initializer, HeUniform
172
- >>> from mindspore.train import Model, TFTRegister
221
+ >>> from mindspore.train import Model, TrainFaultTolerance
173
222
  >>> from mindspore import dataset as ds
174
223
  >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
175
224
  >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
176
225
  >>> init()
177
226
  >>> ms.set_seed(1)
178
227
  >>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
179
- >>> "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
228
+ ... "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
180
229
  >>> class MatMulCell(nn.Cell):
181
230
  ... def __init__(self, param=None, shape=None):
182
231
  ... super().__init__()
@@ -234,48 +283,74 @@ class TFTRegister(Callback):
234
283
  ... dataset = dataset.batch(batch_size)
235
284
  ... return dataset
236
285
  >>>
237
- >>> data_set = create_dataset(32)
286
+ >>> dataset = create_dataset(32)
238
287
  >>>
239
288
  >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
240
289
  >>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
241
290
  >>> loss_fn = nn.CrossEntropyLoss()
242
291
  >>>
243
- >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
292
+ >>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 4)
244
293
  >>> net_with_loss.set_train()
245
- >>> model = Model(net_with_loss, optimizer=optimizer)
246
- >>> tft_cb = TFTRegister("192.168.0.1", 2000, "./tft_checkpoint/")
294
+ >>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
295
+ >>> tft_cb = TrainFaultTolerance()
247
296
  >>> loss_cb = train.LossMonitor(1)
248
297
  >>> model.train(1, dataset, callbacks=[tft_cb, loss_cb])
249
298
  """
250
299
 
251
- def __init__(self, ctrl_rank_id, ctrl_ip, ctrl_port, ckpt_save_path):
252
- super(TFTRegister, self).__init__()
253
-
254
- tft_env = os.getenv("MS_ENABLE_TFT", "")
255
- if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env):
256
- raise ValueError("MindIO TFT regitster need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1}']!")
257
- mode = context.get_context("mode")
258
- device_target = context.get_context("device_target")
259
- if device_target != "Ascend" or mode != context.GRAPH_MODE:
260
- raise ValueError("MindIO adataper only support on Ascend device with GRAPH Mode!")
261
-
262
- # let it raise errors if not install mindio_tft package
263
- from mindio_ttp import framework_ttp as tft
264
- self.tft = tft
265
- self.is_uce_rank = False
266
- self.global_step = 0
267
- Validator.check_non_negative_int(ctrl_port)
268
- self.has_init_replica = False
269
- self._controller_ip = ctrl_ip
270
- self._controller_rank_id = ctrl_rank_id
271
- self._controller_port = ctrl_port
300
+ def __init__(self, ckpt_save_path=None, **kwargs):
301
+ super(TrainFaultTolerance, self).__init__()
302
+ self.save_cb = kwargs.get("ckpt_save_fn", None)
303
+ self.ckpt_save_path = ckpt_save_path
304
+ if self.save_cb is None and self.ckpt_save_path is None:
305
+ raise ValueError("TrainFaultTolerance construct need to set ckpt_save_fn or ckpt_save_path!")
272
306
  self.cb_params = None
307
+ self.initial_step = kwargs.get("initial_step", 0)
273
308
  self.device_id = context.get_context("device_id")
274
- self._init_tft()
275
- self.ckpt_save_path = ckpt_save_path
309
+ self.cur_step_num = 0
310
+ self.cur_epoch_num = 0
311
+ # For TREError(Training Result Error) scene, parameter `ckpt_load_fn` must be provided to load checkpoint
312
+ # from file for resuming training, the `ckpt_load_fn` is a function, prototype of which is:
313
+ # `def load_checkpoint() -> tuple(dict, bool)`, the return value is a tuple containing 2 values,
314
+ # i.e. (param_dict, remove_redundancy)
315
+ self.ckpt_load_func = kwargs.get("ckpt_load_fn", None)
316
+ self.tft = _tft_handler.get_tft()
317
+ if self._only_enable_tre():
318
+ return
319
+ self._check_init()
320
+ self.global_step = None
321
+ self.learning_rate = None
322
+ self.has_init_replica = False
323
+ self.is_uce_rank = False
324
+
276
325
  self.assign = mindspore.ops.Assign()
277
326
  self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
278
327
  self.s1 = mindspore.hal.Stream()
328
+ _tft_sem_enable()
329
+ self._tft_register()
330
+
331
+ def _only_enable_tre(self):
332
+ """Check if only configured MS_ENABLE_TFT='{TRE:1}'"""
333
+ env_enable = os.getenv("MS_ENABLE_TFT", "")
334
+ non_tre_flags = ["TTP:1", "UCE:1", "ARF:1"]
335
+ if any(flag in env_enable for flag in non_tre_flags):
336
+ return False
337
+ return "TRE:1" in env_enable
338
+
339
+ def _check_init(self):
340
+ """Check if the mindio-ttp had inited"""
341
+ if self.tft is None:
342
+ tft_env = os.getenv("MS_ENABLE_TFT", "")
343
+ if "ARF:1" in tft_env:
344
+ raise ValueError("Must init by _tft_handler.init(config=params) if use ARF.")
345
+ logger.warning(f"TFT handle not init, try to init")
346
+ _tft_handler.init(config=None)
347
+ self.tft = _tft_handler.get_tft()
348
+ logger.warning(f"TFT handle init ok.")
349
+ mode = context.get_context("mode")
350
+ device_target = context.get_context("device_target")
351
+ if device_target != "Ascend" or mode != context.GRAPH_MODE:
352
+ raise ValueError(f"MindIO adataper only support on Ascend device with GRAPH Mode!"
353
+ f"device:{device_target}, run mode: {mode}")
279
354
 
280
355
  def _is_params_consistent(self):
281
356
  for key, param in self.cb_params.train_network.parameters_and_names():
@@ -287,7 +362,7 @@ class TFTRegister(Callback):
287
362
  return False
288
363
 
289
364
  def _set_tft_optimizer_replica(self, run_context):
290
- """ set Mindio TFT optimizer replica info, used internal. """
365
+ """ Set Mindio TFT optimizer replica info, used internal. """
291
366
  cur_rank = get_rank()
292
367
  cb_params = run_context.original_args()
293
368
  train_network = cb_params.train_network
@@ -309,59 +384,98 @@ class TFTRegister(Callback):
309
384
  ]
310
385
  self.tft.tft_set_optimizer_replica(cur_rank, replica_info)
311
386
 
312
- def _init_tft(self):
313
- """ Init Mindio TFT, used internal. """
314
- logger.info("Begin to init tft.")
387
+ @classmethod
388
+ def get_optimizer_wrapper(cls, origin_opt_cls):
389
+ """
390
+ Optimizer wrapper func when using tft.
391
+
392
+ Args:
393
+ origin_opt_cls (Class): origin optimizer class.
394
+ """
395
+
396
+ class TFTOptSubCls(origin_opt_cls):
397
+ """
398
+ Optimizer wrapper class when using tft.
399
+ """
400
+
401
+ def __init__(self, *args, **kwargs):
402
+ super(TFTOptSubCls, self).__init__(*args, **kwargs)
403
+ self.report = TensorReport()
404
+ self.report_end = TensorReport()
405
+ self.report_end.add_prim_attr("side_effect_mem", True).add_prim_attr("optimizer_end", True)
406
+ self.depend = ops.Depend()
407
+ self.allreduce_sum = ops.AllReduce()
408
+ self.allreduce_sum.add_prim_attr("tft_report_before", True)
409
+ self.tft_g_one_flag = Parameter(Tensor([1], dtype=mstype.int32))
410
+
411
+ def construct(self, gradients, **kwargs):
412
+ tft_g_one_flag = self.depend(self.tft_g_one_flag, gradients)
413
+ self.tft_g_one_flag = self.allreduce_sum(tft_g_one_flag)
414
+ grads = self.depend(gradients, self.report("tft_report", self.tft_g_one_flag))
415
+ opt_ret = super(TFTOptSubCls, self).construct(grads, **kwargs)
416
+ self.report_end("tft_report", self.tft_g_one_flag)
417
+ return opt_ret
418
+
419
+ return TFTOptSubCls
420
+
421
+ def _tft_register(self):
422
+ """Register callback functions."""
315
423
  self.tft.tft_register_save_ckpt_handler(_save_checkpoint_on_failure, self)
316
424
  self.tft.tft_register_rename_handler(_rename_save_result, self)
317
425
  self.tft.tft_register_exit_handler(_tft_exit_cb, self)
318
426
  self.tft.tft_register_stop_handler(_tft_stop_callback, self)
319
427
  self.tft.tft_register_clean_handler(_tft_clean_callback, self)
320
428
  self.tft.tft_register_repair_handler(_tft_repair_callback, self)
429
+ self.tft.tft_register_rebuild_group_handler(_tft_rebuild_sub_groups, self)
321
430
 
322
- world_size = _get_device_num()
323
- cur_rank = get_rank()
324
- enable_local_copy = False
325
- enable_arf = False
326
- enable_tls = False
327
- tls_key_dir = ""
328
-
329
- if cur_rank == self._controller_rank_id:
330
- logger.info(f"Begin to start tft controller on rank_id:{cur_rank}")
331
- self.tft.tft_init_controller(cur_rank, world_size, enable_local_copy, enable_arf)
332
- self.tft.tft_start_controller(self._controller_ip, self._controller_port, enable_tls, tls_key_dir)
333
- logger.info("Finish start tft controller.")
334
-
335
- logger.info("Begin to start tft processor.")
336
- self.tft.tft_init_processor(cur_rank, world_size, enable_local_copy, enable_tls, tls_key_dir)
337
- self.tft.tft_start_processor(self._controller_ip, self._controller_port)
338
- logger.info("Finished start tft processor.")
431
+ def _reset_acc_grads(self):
432
+ accu_grad_params = map(lambda e: e[1],
433
+ filter(lambda e: e[1].name.startswith('accu_grads'),
434
+ self.cb_params.train_network.parameters_and_names()))
435
+ accu_grad_list = list(accu_grad_params)
436
+ if reset_params(accu_grad_list) != 0:
437
+ raise ValueError("Call reset_params failed.")
339
438
 
340
439
  def on_train_step_end(self, run_context):
341
440
  """
342
- And report status to MindIO TFT after every step finished.
441
+ Report status to MindIO TFT after every step finished.
343
442
 
344
443
  Args:
345
444
  run_context (RunContext): Context of the train running. Refer to
346
445
  :class:`mindspore.train.RunContext` for detail.
347
446
  """
447
+ if self._only_enable_tre():
448
+ return
348
449
  if self.has_init_replica is False:
349
450
  self.has_init_replica = True
350
451
  self._set_tft_optimizer_replica(run_context)
351
452
  cb_params = run_context.original_args()
352
453
  logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
353
- self.tft.tft_end_updating_os(cb_params.cur_step_num)
454
+ self.cur_step_num = cb_params.cur_step_num
455
+ self.cur_epoch_num = cb_params.cur_epoch_num
354
456
  if cb_params.optimizer is not None:
355
- self.global_step = int(cb_params.optimizer.global_step.data)
457
+ self.global_step = cb_params.optimizer.global_step.clone()
356
458
  self.assign(cb_params.optimizer.tft_g_one_flag, self.g_one)
357
- else:
358
- self.global_step = int(cb_params.network.optimizer.global_step.data)
459
+ elif hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
460
+ self.global_step = cb_params.network.optimizer.global_step.clone()
359
461
  self.assign(cb_params.network.optimizer.tft_g_one_flag, self.g_one)
462
+ else:
463
+ raise ValueError("TFT feature need optimizer or network's optimizer!")
464
+ self.tft.tft_end_updating_os(cb_params.cur_step_num + self.initial_step)
360
465
  logger.info("END Set optimizer finish step status to TFT.")
361
466
 
362
-
363
467
     def on_train_begin(self, run_context):
+        """
+        Register train params to MindIO TFT when training begins.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
         cb_params = run_context.original_args()
+        if self._only_enable_tre():
+            self.cb_params = cb_params
+            return
         sink_size = cb_params.get("sink_size", 0)
         if sink_size > 1:
             raise ValueError("TFT feature doesn't support sink_size > 1.")
@@ -370,7 +484,13 @@ class TFTRegister(Callback):
         self.cb_params = cb_params
 
     def end(self, run_context):
-        cur_rank = get_rank()
-        if cur_rank == self._controller_rank_id:
-            self.tft.tft_destroy_controller()
-        self.tft.tft_destroy_processor()
+        """
+        Unregister MindIO TFT when training ends.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
+        if self._only_enable_tre():
+            return
+        _tft_handler.unregister_tft()
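All three hooks now short-circuit when only TRE is enabled, leaving the MindIO TFT registration and teardown untouched. A sketch of the same early-return guard pattern in a standalone callback; the `only_tre` flag is hypothetical and merely mirrors `_only_enable_tre()`:

```python
from mindspore.train import Callback


class FeatureGatedCallback(Callback):
    """Illustrative guard pattern only; not the MindSpore implementation."""

    def __init__(self, only_tre=False):
        super().__init__()
        self._only_tre = only_tre  # hypothetical stand-in for _only_enable_tre()

    def on_train_step_end(self, run_context):
        if self._only_tre:
            return  # TRE-only mode: skip all TFT reporting
        cb_params = run_context.original_args()
        print("report step", cb_params.cur_step_num)
```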
@@ -98,6 +98,29 @@ def _get_next_op(dataset, ori_next_op, is_info_queue):
     return next_op, (key, dataset_shapes, dataset_types)
 
 
+def _get_jit_func(sink_fun, jit_config):
+    """
+    Get the jit function.
+    """
+    jit_config_dict = jit_config.jit_config_dict
+    jit_level = jit_config_dict['jit_level']
+    if jit_level == "":
+        jit_level = "O0"
+    backend = ""
+    if jit_level == "O2":
+        jit_level = "O0"
+        backend = "GE"
+    if "backend" in jit_config_dict:
+        backend = jit_config_dict["backend"]
+    fullgraph = False
+    if jit_config_dict['jit_syntax_level'] == "STRICT":
+        fullgraph = True
+    exc_mode = jit_config_dict['exc_mode']
+    infer_boost = jit_config_dict['infer_boost']
+    return jit(sink_fun, jit_level=jit_level, backend=backend, fullgraph=fullgraph, exc_mode=exc_mode,
+               infer_boost=infer_boost)
+
+
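`_get_jit_func` replaces the removed `jit(sink_fun, jit_config=jit_config)` call by unpacking the `JitConfig` dict into the newer keyword API: an empty `jit_level` falls back to `O0`, `O2` is rewritten as `O0` on the `GE` backend, and `jit_syntax_level == "STRICT"` maps to `fullgraph=True`. A pure-Python sketch of just that mapping, with no compilation involved:

```python
def map_jit_config(jit_config_dict):
    """Mirror _get_jit_func's translation of a JitConfig dict into jit() kwargs."""
    jit_level = jit_config_dict.get('jit_level', '') or "O0"  # empty string means O0
    backend = ""
    if jit_level == "O2":  # O2 is expressed as O0 on the GE backend
        jit_level, backend = "O0", "GE"
    backend = jit_config_dict.get('backend', backend)  # an explicit backend wins
    fullgraph = jit_config_dict.get('jit_syntax_level') == "STRICT"
    return dict(jit_level=jit_level, backend=backend, fullgraph=fullgraph)


print(map_jit_config({'jit_level': 'O2', 'jit_syntax_level': 'STRICT'}))
# {'jit_level': 'O0', 'backend': 'GE', 'fullgraph': True}
```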
 def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
     """
     get the sink function.
@@ -107,7 +130,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
         if jit_config is None:
             dst_sink_fun = sink_fun
         else:
-            dst_sink_fun = jit(sink_fun, jit_config=jit_config)
+            dst_sink_fun = _get_jit_func(sink_fun, jit_config)
         dataset.__sink_fun__ = dst_sink_fun
 
         return dataset.__sink_fun__
@@ -119,7 +142,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
         if jit_config is None:
             dst_sink_fun = sink_fun
         else:
-            dst_sink_fun = jit(sink_fun, jit_config=jit_config)
+            dst_sink_fun = _get_jit_func(sink_fun, jit_config)
         dataset.__sink_aux__.sink_funcs[key] = dst_sink_fun
 
         return dst_sink_fun
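Both branches of `_get_sink_fun` compile the sink function at most once and cache the result on the dataset, either as `__sink_fun__` or per shape/type key in `__sink_aux__.sink_funcs`. A generic sketch of that compile-once-per-key caching; all names here are illustrative:

```python
_sink_cache = {}


def get_or_compile(fn, key, compile_fn):
    """Compile fn at most once per key, as _get_sink_fun does per shape/type signature."""
    if key not in _sink_cache:
        _sink_cache[key] = compile_fn(fn)
    return _sink_cache[key]


# The key plays the role of the dataset's shape/type signature; compile_fn stands
# in for _get_jit_func.
square = get_or_compile(lambda x: x * x, key="float32:(2, 3)", compile_fn=lambda f: f)
print(square(4))  # 16
```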
@@ -15,7 +15,6 @@
 """Dataset help for minddata dataset"""
 from __future__ import absolute_import
 
-import os
 import math
 import copy
 
@@ -25,9 +24,11 @@ from mindspore.common._auto_dynamic import is_auto_dynamic, convert_new_shapes
 from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.api import _cell_graph_executor, _is_args_fullmode, ARG_SPECIFIED
 from mindspore.common._utils import is_shape_unknown
+from mindspore.dataset.core import config as dataset_config
 from mindspore.dataset.engine import offload
 from mindspore import context, nn
-from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
+from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
+    _construct_tensor_list, enable_data_broadcast
 from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_to_full, \
     _to_full_shapes, _get_pipeline_stages, _change_symbols_for_parallel, _is_in_auto_parallel_mode, \
     _origin_shapes, _dynamic_shape_for_dataset
@@ -213,8 +214,7 @@ def _get_dataset_aux(dataset):
 
 def connect_network_with_dataset(network, dataset_helper):
     """
-    Connect the `network` with dataset in `dataset_helper`. Only supported in `sink mode
-    <https://mindspore.cn/docs/en/master/model_train/train_process/train_optimize.html>`_,
+    Connect the `network` with the dataset in `dataset_helper`. Only supported in sink mode
     (dataset_sink_mode=True).
 
     Args:
@@ -263,16 +263,14 @@ def connect_network_with_dataset(network, dataset_helper):
             "The dataset has been connected to other network, please check the code.")
     is_dynamic = bool(network.get_inputs())
     queue_name = dataset.__transfer_dataset__.queue_name
+
     # In pipeline parallel, some stages have no GetNext, should not get in.
+    # Don't enable the dynamic shape (multi-subgraph) feature in pipeline-parallel or
+    # dataset-broadcast mode; otherwise get_data_info will get stuck, since some ranks
+    # do not consume data.
     use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
+    data_broadcast = enable_data_broadcast()
 
-    # temp env to disable dynamic feature of sink size 1
-    dynamic_sink1_env = os.getenv("MS_DEV_DYNAMIC_SINK1", None)
-    dynamic_sink1 = True
-    if dynamic_sink1_env and dynamic_sink1_env.strip() in ['False', 'false']:
-        dynamic_sink1 = False
-
-    if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel and dynamic_sink1:
+    if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel and not data_broadcast:
         dataset_types, dataset_shapes = dataset_helper.get_data_info()
         # Need to do full_batch for shapes which also do in the _DatasetIterMSLoopSink
         if _need_to_full():
@@ -314,7 +312,7 @@ def connect_network_with_dataset(network, dataset_helper):
         aux.__shape_type__ = str(dataset_types) + str(dataset_shapes)
 
     if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic) and \
-            not use_pipeline_parallel and dynamic_sink1:
+            not use_pipeline_parallel and not data_broadcast:
         dataset_helper.get_data_info()
     network.add_flags(sink_mode=True)
     return network
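The env-var switch `MS_DEV_DYNAMIC_SINK1` is gone; the dynamic-shape sink path is now gated on pipeline parallelism and the new `enable_data_broadcast()` helper. A small sketch of the combined predicate; the function itself is illustrative, and only `get_auto_parallel_context` is the real API used above:

```python
from mindspore import context


def dynamic_sink_allowed(is_dynamic_scenario, data_broadcast):
    """Illustrative predicate mirroring the gate in connect_network_with_dataset."""
    use_pipeline_parallel = context.get_auto_parallel_context("pipeline_stages") > 1
    # Skip the multi-subgraph dynamic-shape path when some ranks never consume data.
    return is_dynamic_scenario and not use_pipeline_parallel and not data_broadcast


print(dynamic_sink_allowed(True, False))  # True on a single-stage setup
```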
@@ -336,11 +334,11 @@ class DatasetHelper:
         dataset_sink_mode (bool): If the value is True, GetNext is employed to fetch the data at device through the
             dataset pipeline, otherwise fetch the data at host by iterating through the dataset.
             Default: ``True``.
-        sink_size (int): Control the amount of data in each sink.
+        sink_size (int): Control the amount of data in each sink. Must be -1 or positive.
             If sink_size=-1, sink the complete dataset for each epoch.
             If sink_size>0, sink sink_size data for each epoch.
-            Default: -1.
-        epoch_num (int): The number of passes of the entire dataset to be sent. Default: 1.
+            Default: ``-1``.
+        epoch_num (int): The number of passes of the entire dataset to be sent. Default: ``1``.
 
     Examples:
         >>> import numpy as np
@@ -686,8 +684,9 @@ class _DatasetIterNormal:
         self.dataset = dataset
         self.device_num = _get_device_num()
         self.global_rank = _get_global_rank()
+        do_copy = dataset_config.get_iterator_mode()["do_copy"]
         self.iter = self.dataset.create_tuple_iterator(
-            num_epochs=epoch_num, do_copy=True)
+            num_epochs=epoch_num, do_copy=do_copy)
 
     def __iter__(self):
         return self
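`_DatasetIterNormal` no longer hardcodes `do_copy=True`; it reads the flag from the dataset config via `dataset_config.get_iterator_mode()["do_copy"]`. A minimal sketch of iterating with an explicit `do_copy` flag:

```python
import numpy as np
import mindspore.dataset as ds

data = ds.NumpySlicesDataset({"x": np.arange(6, dtype=np.float32)}, shuffle=False)

# With do_copy=False the iterator may hand out tensors backed by the pipeline's
# buffers instead of fresh copies; the helper above now takes this flag from the
# dataset config rather than forcing True.
for (x,) in data.create_tuple_iterator(num_epochs=1, do_copy=False):
    print(x)
```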