mindspore-2.5.0-cp310-cp310-win_amd64.whl → mindspore-2.6.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (493)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +25 -194
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +109 -75
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +2014 -3386
  46. mindspore/common/api.py +386 -355
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/generator.py +3 -0
  52. mindspore/common/hook_handle.py +5 -3
  53. mindspore/common/initializer.py +10 -6
  54. mindspore/common/jit_begin_end.py +94 -0
  55. mindspore/common/jit_config.py +6 -1
  56. mindspore/common/jit_context.py +76 -0
  57. mindspore/common/jit_trace.py +378 -0
  58. mindspore/common/lazy_inline.py +2 -2
  59. mindspore/common/mutable.py +5 -4
  60. mindspore/common/parameter.py +106 -39
  61. mindspore/common/seed.py +2 -2
  62. mindspore/common/sparse_tensor.py +23 -17
  63. mindspore/common/tensor.py +332 -714
  64. mindspore/communication/__init__.py +7 -5
  65. mindspore/communication/_comm_helper.py +47 -2
  66. mindspore/communication/comm_func.py +70 -53
  67. mindspore/communication/management.py +83 -17
  68. mindspore/context.py +228 -571
  69. mindspore/dataset/__init__.py +44 -20
  70. mindspore/dataset/audio/__init__.py +2 -8
  71. mindspore/dataset/audio/transforms.py +3 -17
  72. mindspore/dataset/core/config.py +3 -3
  73. mindspore/dataset/engine/cache_client.py +1 -1
  74. mindspore/dataset/engine/datasets.py +102 -120
  75. mindspore/dataset/engine/datasets_audio.py +22 -22
  76. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  77. mindspore/dataset/engine/datasets_text.py +78 -85
  78. mindspore/dataset/engine/datasets_user_defined.py +109 -77
  79. mindspore/dataset/engine/datasets_vision.py +111 -108
  80. mindspore/dataset/engine/iterators.py +5 -3
  81. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  82. mindspore/dataset/engine/samplers.py +279 -57
  83. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  84. mindspore/dataset/engine/validators.py +10 -0
  85. mindspore/dataset/text/__init__.py +7 -6
  86. mindspore/dataset/text/transforms.py +6 -5
  87. mindspore/dataset/text/utils.py +3 -3
  88. mindspore/dataset/transforms/__init__.py +0 -9
  89. mindspore/dataset/transforms/transforms.py +3 -3
  90. mindspore/dataset/utils/browse_dataset.py +1 -1
  91. mindspore/dataset/vision/__init__.py +2 -9
  92. mindspore/dataset/vision/transforms.py +202 -158
  93. mindspore/dataset/vision/utils.py +7 -5
  94. mindspore/device_context/ascend/op_debug.py +60 -1
  95. mindspore/device_context/ascend/op_tuning.py +0 -4
  96. mindspore/device_manager.py +39 -3
  97. mindspore/dnnl.dll +0 -0
  98. mindspore/dpcmi.dll +0 -0
  99. mindspore/experimental/es/embedding_service.py +35 -27
  100. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
  101. mindspore/experimental/map_parameter.py +4 -4
  102. mindspore/experimental/optim/adadelta.py +22 -26
  103. mindspore/experimental/optim/adagrad.py +4 -4
  104. mindspore/experimental/optim/adam.py +4 -0
  105. mindspore/experimental/optim/adamax.py +4 -4
  106. mindspore/experimental/optim/adamw.py +4 -0
  107. mindspore/experimental/optim/asgd.py +1 -1
  108. mindspore/experimental/optim/lr_scheduler.py +40 -22
  109. mindspore/experimental/optim/radam.py +5 -5
  110. mindspore/experimental/optim/rprop.py +1 -1
  111. mindspore/experimental/optim/sgd.py +1 -1
  112. mindspore/hal/contiguous_tensors_handle.py +6 -10
  113. mindspore/hal/device.py +55 -81
  114. mindspore/hal/event.py +38 -55
  115. mindspore/hal/memory.py +115 -147
  116. mindspore/hal/stream.py +81 -125
  117. mindspore/include/dataset/constants.h +7 -4
  118. mindspore/include/dataset/execute.h +2 -2
  119. mindspore/jpeg62.dll +0 -0
  120. mindspore/log.py +40 -2
  121. mindspore/mindrecord/__init__.py +20 -7
  122. mindspore/mindspore_backend_common.dll +0 -0
  123. mindspore/mindspore_backend_manager.dll +0 -0
  124. mindspore/mindspore_common.dll +0 -0
  125. mindspore/mindspore_core.dll +0 -0
  126. mindspore/mindspore_dump.dll +0 -0
  127. mindspore/mindspore_frontend.dll +0 -0
  128. mindspore/mindspore_glog.dll +0 -0
  129. mindspore/mindspore_memory_pool.dll +0 -0
  130. mindspore/mindspore_ms_backend.dll +0 -0
  131. mindspore/mindspore_ops.dll +0 -0
  132. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  133. mindspore/mindspore_ops_kernel_common.dll +0 -0
  134. mindspore/mindspore_profiler.dll +0 -0
  135. mindspore/mindspore_pyboost.dll +0 -0
  136. mindspore/mindspore_pynative.dll +0 -0
  137. mindspore/mindspore_res_manager.dll +0 -0
  138. mindspore/mindspore_runtime_pipeline.dll +0 -0
  139. mindspore/mint/__init__.py +133 -702
  140. mindspore/mint/distributed/__init__.py +5 -1
  141. mindspore/mint/distributed/distributed.py +198 -113
  142. mindspore/mint/linalg/__init__.py +2 -0
  143. mindspore/mint/nn/__init__.py +280 -18
  144. mindspore/mint/nn/functional.py +282 -64
  145. mindspore/mint/nn/layer/__init__.py +4 -0
  146. mindspore/mint/nn/layer/_functions.py +7 -3
  147. mindspore/mint/nn/layer/activation.py +120 -13
  148. mindspore/mint/nn/layer/conv.py +234 -28
  149. mindspore/mint/nn/layer/normalization.py +15 -16
  150. mindspore/mint/nn/layer/padding.py +1 -1
  151. mindspore/mint/nn/layer/pooling.py +66 -1
  152. mindspore/mint/optim/__init__.py +2 -1
  153. mindspore/mint/optim/sgd.py +171 -0
  154. mindspore/msobj140.dll +0 -0
  155. mindspore/mspdb140.dll +0 -0
  156. mindspore/mspdbcore.dll +0 -0
  157. mindspore/mspdbst.dll +0 -0
  158. mindspore/mspft140.dll +0 -0
  159. mindspore/msvcdis140.dll +0 -0
  160. mindspore/msvcp140_1.dll +0 -0
  161. mindspore/msvcp140_2.dll +0 -0
  162. mindspore/msvcp140_atomic_wait.dll +0 -0
  163. mindspore/msvcp140_codecvt_ids.dll +0 -0
  164. mindspore/nn/__init__.py +4 -1
  165. mindspore/nn/cell.py +1253 -179
  166. mindspore/nn/layer/activation.py +23 -21
  167. mindspore/nn/layer/basic.py +22 -16
  168. mindspore/nn/layer/container.py +1 -1
  169. mindspore/nn/layer/conv.py +53 -42
  170. mindspore/nn/layer/embedding.py +9 -8
  171. mindspore/nn/layer/normalization.py +48 -42
  172. mindspore/nn/layer/pooling.py +75 -31
  173. mindspore/nn/layer/transformer.py +11 -10
  174. mindspore/nn/learning_rate_schedule.py +4 -2
  175. mindspore/nn/loss/loss.py +27 -19
  176. mindspore/nn/optim/ada_grad.py +6 -5
  177. mindspore/nn/optim/adadelta.py +9 -7
  178. mindspore/nn/optim/adafactor.py +1 -1
  179. mindspore/nn/optim/adam.py +18 -14
  180. mindspore/nn/optim/adamax.py +8 -7
  181. mindspore/nn/optim/adasum.py +5 -5
  182. mindspore/nn/optim/asgd.py +3 -1
  183. mindspore/nn/optim/ftrl.py +11 -9
  184. mindspore/nn/optim/lamb.py +1 -1
  185. mindspore/nn/optim/lazyadam.py +12 -10
  186. mindspore/nn/optim/momentum.py +7 -6
  187. mindspore/nn/optim/optimizer.py +2 -2
  188. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  189. mindspore/nn/optim/rmsprop.py +13 -12
  190. mindspore/nn/optim/rprop.py +9 -7
  191. mindspore/nn/optim/sgd.py +9 -6
  192. mindspore/nn/optim/tft_wrapper.py +5 -2
  193. mindspore/nn/probability/bijector/bijector.py +17 -11
  194. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  195. mindspore/nn/probability/bijector/invert.py +2 -2
  196. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  197. mindspore/nn/probability/bijector/softplus.py +3 -2
  198. mindspore/nn/probability/distribution/beta.py +3 -3
  199. mindspore/nn/probability/distribution/categorical.py +1 -1
  200. mindspore/nn/probability/distribution/cauchy.py +4 -2
  201. mindspore/nn/probability/distribution/exponential.py +6 -7
  202. mindspore/nn/probability/distribution/gamma.py +2 -2
  203. mindspore/nn/probability/distribution/gumbel.py +2 -2
  204. mindspore/nn/probability/distribution/half_normal.py +5 -3
  205. mindspore/nn/probability/distribution/logistic.py +5 -3
  206. mindspore/nn/probability/distribution/poisson.py +1 -1
  207. mindspore/nn/probability/distribution/uniform.py +5 -3
  208. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  209. mindspore/nn/reinforcement/tensor_array.py +1 -1
  210. mindspore/nn/wrap/__init__.py +6 -6
  211. mindspore/nn/wrap/cell_wrapper.py +178 -117
  212. mindspore/nn/wrap/grad_reducer.py +45 -36
  213. mindspore/nn/wrap/loss_scale.py +3 -3
  214. mindspore/numpy/array_creations.py +3 -3
  215. mindspore/numpy/array_ops.py +1 -1
  216. mindspore/numpy/utils.py +1 -2
  217. mindspore/numpy/utils_const.py +1 -2
  218. mindspore/opencv_core452.dll +0 -0
  219. mindspore/opencv_imgcodecs452.dll +0 -0
  220. mindspore/opencv_imgproc452.dll +0 -0
  221. mindspore/ops/__init__.py +3 -2
  222. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  223. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  224. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  225. mindspore/ops/_register_for_op.py +0 -11
  226. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  227. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  228. mindspore/ops/_vmap/vmap_array_ops.py +32 -6
  229. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  230. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  231. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  232. mindspore/ops/auto_generate/__init__.py +4 -3
  233. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
  234. mindspore/ops/auto_generate/gen_extend_func.py +286 -208
  235. mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
  236. mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
  237. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  238. mindspore/ops/composite/__init__.py +2 -1
  239. mindspore/ops/composite/base.py +19 -24
  240. mindspore/ops/composite/math_ops.py +6 -16
  241. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  242. mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
  243. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  244. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  248. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  249. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  250. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  251. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  252. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  254. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  255. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  256. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  257. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  259. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  260. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  263. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  264. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  267. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  268. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  271. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  272. mindspore/ops/function/__init__.py +28 -2
  273. mindspore/ops/function/_add_attr_func.py +58 -0
  274. mindspore/ops/function/array_func.py +1631 -2347
  275. mindspore/ops/function/clip_func.py +38 -45
  276. mindspore/ops/function/debug_func.py +36 -44
  277. mindspore/ops/function/grad/__init__.py +1 -0
  278. mindspore/ops/function/grad/grad_func.py +104 -71
  279. mindspore/ops/function/image_func.py +1 -1
  280. mindspore/ops/function/linalg_func.py +46 -78
  281. mindspore/ops/function/math_func.py +3024 -3855
  282. mindspore/ops/function/nn_func.py +678 -274
  283. mindspore/ops/function/other_func.py +159 -1
  284. mindspore/ops/function/parameter_func.py +17 -30
  285. mindspore/ops/function/random_func.py +216 -361
  286. mindspore/ops/function/reshard_func.py +4 -70
  287. mindspore/ops/function/sparse_func.py +3 -3
  288. mindspore/ops/function/sparse_unary_func.py +5 -5
  289. mindspore/ops/function/spectral_func.py +25 -58
  290. mindspore/ops/function/vmap_func.py +26 -18
  291. mindspore/ops/functional.py +8 -5
  292. mindspore/ops/functional_overload.py +655 -4
  293. mindspore/ops/op_info_register.py +32 -244
  294. mindspore/ops/operations/__init__.py +21 -14
  295. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  296. mindspore/ops/operations/_grad_ops.py +1 -10
  297. mindspore/ops/operations/_inner_ops.py +5 -76
  298. mindspore/ops/operations/_ms_kernel.py +4 -10
  299. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  300. mindspore/ops/operations/_scalar_ops.py +3 -2
  301. mindspore/ops/operations/_sequence_ops.py +1 -1
  302. mindspore/ops/operations/_tensor_array.py +1 -1
  303. mindspore/ops/operations/array_ops.py +39 -24
  304. mindspore/ops/operations/comm_ops.py +150 -107
  305. mindspore/ops/operations/custom_ops.py +287 -32
  306. mindspore/ops/operations/debug_ops.py +119 -16
  307. mindspore/ops/operations/inner_ops.py +1 -1
  308. mindspore/ops/operations/linalg_ops.py +1 -58
  309. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  310. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  311. mindspore/ops/operations/math_ops.py +21 -18
  312. mindspore/ops/operations/nn_ops.py +67 -224
  313. mindspore/ops/operations/other_ops.py +62 -9
  314. mindspore/ops/operations/random_ops.py +13 -7
  315. mindspore/ops/operations/reshard_ops.py +1 -1
  316. mindspore/ops/operations/sparse_ops.py +2 -2
  317. mindspore/ops/primitive.py +43 -32
  318. mindspore/ops/tensor_method.py +243 -17
  319. mindspore/ops_generate/__init__.py +0 -5
  320. mindspore/ops_generate/aclnn/__init__.py +0 -0
  321. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  322. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  323. mindspore/ops_generate/api/__init__.py +0 -0
  324. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  325. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  326. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  327. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  328. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  329. mindspore/ops_generate/api/gen_api.py +103 -0
  330. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  331. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  332. mindspore/ops_generate/common/__init__.py +0 -0
  333. mindspore/ops_generate/common/gen_constants.py +91 -0
  334. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  335. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  336. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  337. mindspore/ops_generate/gen_ops.py +23 -325
  338. mindspore/ops_generate/op_def/__init__.py +0 -0
  339. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  340. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  341. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
  342. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  343. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  344. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  345. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  346. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  347. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  348. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  349. mindspore/ops_generate/pyboost/__init__.py +0 -0
  350. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  351. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  352. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  353. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  354. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  355. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  356. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  357. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  358. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  359. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  360. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  361. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  362. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  363. mindspore/ops_generate/resources/__init__.py +0 -0
  364. mindspore/ops_generate/resources/resource_list.py +30 -0
  365. mindspore/ops_generate/resources/resource_loader.py +36 -0
  366. mindspore/ops_generate/resources/resource_manager.py +64 -0
  367. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  368. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  369. mindspore/parallel/__init__.py +6 -2
  370. mindspore/parallel/_auto_parallel_context.py +140 -12
  371. mindspore/parallel/_cell_wrapper.py +132 -15
  372. mindspore/parallel/_parallel_serialization.py +95 -4
  373. mindspore/parallel/_ps_context.py +1 -1
  374. mindspore/parallel/_recovery_context.py +7 -2
  375. mindspore/parallel/_tensor.py +142 -18
  376. mindspore/parallel/_utils.py +198 -25
  377. mindspore/parallel/algo_parameter_config.py +3 -3
  378. mindspore/parallel/auto_parallel.py +732 -0
  379. mindspore/parallel/checkpoint_convert.py +159 -0
  380. mindspore/parallel/checkpoint_transform.py +658 -37
  381. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  382. mindspore/parallel/cluster/run.py +1 -1
  383. mindspore/parallel/function/__init__.py +24 -0
  384. mindspore/parallel/function/reshard_func.py +258 -0
  385. mindspore/parallel/nn/__init__.py +25 -0
  386. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  387. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  388. mindspore/parallel/parameter_broadcast.py +24 -13
  389. mindspore/parallel/shard.py +137 -62
  390. mindspore/parallel/transform_safetensors.py +288 -95
  391. mindspore/pgodb140.dll +0 -0
  392. mindspore/pgort140.dll +0 -0
  393. mindspore/profiler/__init__.py +9 -5
  394. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  395. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  397. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
  398. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  399. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  400. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  401. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  402. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  403. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  404. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  405. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  406. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  407. mindspore/profiler/common/constant.py +12 -0
  408. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  409. mindspore/profiler/common/path_manager.py +24 -0
  410. mindspore/profiler/common/profiler_context.py +26 -2
  411. mindspore/profiler/common/profiler_meta_data.py +74 -0
  412. mindspore/profiler/common/profiler_parameters.py +59 -18
  413. mindspore/profiler/common/profiler_path_manager.py +66 -7
  414. mindspore/profiler/dynamic_profiler.py +112 -79
  415. mindspore/profiler/envprofiler.py +26 -1
  416. mindspore/profiler/experimental_config.py +197 -0
  417. mindspore/profiler/mstx.py +57 -14
  418. mindspore/profiler/platform/npu_profiler.py +33 -7
  419. mindspore/profiler/profiler.py +541 -45
  420. mindspore/profiler/profiler_action_controller.py +1 -1
  421. mindspore/profiler/profiler_interface.py +4 -0
  422. mindspore/profiler/schedule.py +57 -22
  423. mindspore/rewrite/api/node.py +15 -13
  424. mindspore/rewrite/api/symbol_tree.py +1 -1
  425. mindspore/run_check/_check_version.py +25 -14
  426. mindspore/run_check/run_check.py +1 -1
  427. mindspore/runtime/__init__.py +2 -2
  428. mindspore/runtime/executor.py +40 -11
  429. mindspore/runtime/memory.py +37 -13
  430. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  431. mindspore/swresample-4.dll +0 -0
  432. mindspore/swscale-6.dll +0 -0
  433. mindspore/tbbmalloc.dll +0 -0
  434. mindspore/tinyxml2.dll +0 -0
  435. mindspore/train/__init__.py +8 -8
  436. mindspore/train/_utils.py +43 -9
  437. mindspore/train/amp.py +1 -1
  438. mindspore/train/callback/__init__.py +2 -2
  439. mindspore/train/callback/_callback.py +2 -16
  440. mindspore/train/callback/_checkpoint.py +24 -40
  441. mindspore/train/callback/_cluster_monitor.py +14 -18
  442. mindspore/train/callback/_flops_collector.py +2 -3
  443. mindspore/train/callback/_history.py +7 -4
  444. mindspore/train/callback/_lambda_callback.py +2 -2
  445. mindspore/train/callback/_landscape.py +0 -3
  446. mindspore/train/callback/_loss_monitor.py +2 -1
  447. mindspore/train/callback/_on_request_exit.py +6 -5
  448. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  449. mindspore/train/callback/_summary_collector.py +8 -13
  450. mindspore/train/callback/_time_monitor.py +2 -1
  451. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
  452. mindspore/train/data_sink.py +25 -2
  453. mindspore/train/dataset_helper.py +4 -5
  454. mindspore/train/loss_scale_manager.py +8 -7
  455. mindspore/train/metrics/accuracy.py +3 -3
  456. mindspore/train/metrics/confusion_matrix.py +9 -9
  457. mindspore/train/metrics/error.py +3 -3
  458. mindspore/train/metrics/hausdorff_distance.py +4 -4
  459. mindspore/train/metrics/mean_surface_distance.py +3 -3
  460. mindspore/train/metrics/metric.py +0 -12
  461. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  462. mindspore/train/metrics/precision.py +8 -6
  463. mindspore/train/metrics/recall.py +9 -9
  464. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  465. mindspore/train/mind_ir_pb2.py +19 -12
  466. mindspore/train/model.py +262 -127
  467. mindspore/train/serialization.py +246 -988
  468. mindspore/train/summary/_summary_adapter.py +2 -2
  469. mindspore/train/summary/summary_record.py +1 -1
  470. mindspore/turbojpeg.dll +0 -0
  471. mindspore/utils/__init__.py +3 -2
  472. mindspore/utils/dryrun.py +4 -2
  473. mindspore/utils/hooks.py +81 -0
  474. mindspore/utils/runtime_execution_order_check.py +2 -0
  475. mindspore/utils/utils.py +138 -4
  476. mindspore/vcmeta.dll +0 -0
  477. mindspore/vcruntime140.dll +0 -0
  478. mindspore/vcruntime140_1.dll +0 -0
  479. mindspore/version.py +1 -1
  480. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
  481. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
  482. mindspore/_install_custom.py +0 -43
  483. mindspore/common/_register_for_adapter.py +0 -74
  484. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  485. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  486. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  487. mindspore/ops_generate/gen_constants.py +0 -190
  488. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  489. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  490. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  492. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
  493. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
@@ -107,22 +107,18 @@ class SummaryCollector(Callback):
               The first output will be treated as the loss and it will be averaged. Default: ``True`` .
             - collect_graph (bool): Whether to collect the computational graph. Currently, only
               training computational graph is collected. Default: ``True`` .
-            - collect_train_lineage (bool): Whether to collect lineage data for the training phase,
-              this field will be displayed on the `lineage page \
-              <https://www.mindspore.cn/mindinsight/docs/en/master/lineage_and_scalars_comparison.html>`_
-              of MindInsight. Default: ``True`` .
-            - collect_eval_lineage (bool): Whether to collect lineage data for the evaluation phase,
-              this field will be displayed on the `lineage page
-              <https://www.mindspore.cn/mindinsight/docs/en/master/lineage_and_scalars_comparison.html>`_
-              of MindInsight. Default: ``True`` .
+            - collect_train_lineage (bool): Whether to collect lineage data for the training phase.
+              Default: ``True`` .
+            - collect_eval_lineage (bool): Whether to collect lineage data for the evaluation phase.
+              Default: ``True`` .
             - collect_input_data (bool): Whether to collect dataset for each training.
               Currently only image data is supported.
               If there are multiple columns of data in the dataset, the first column should be image data.
               Default: ``True`` .
             - collect_dataset_graph (bool): Whether to collect dataset graph for the training phase.
               Default: ``True`` .
-            - histogram_regular (Union[str, None]): Collect weight and bias for parameter distribution page
-              and displayed in MindInsight. This field allows regular strings to control which parameters to collect.
+            - histogram_regular (Union[str, None]): Collect weight and bias for parameter distribution page.
+              This field allows regular strings to control which parameters to collect.
               It is not recommended to collect too many parameters at once, as it can affect performance.
               Note that if you collect too many parameters and run out of memory, the training will fail.
               Default: ``None`` , it means only the first five parameters are collected.
@@ -153,8 +149,7 @@ class SummaryCollector(Callback):
              True: it means that after specified data is set, non-specified data is collected as the default behavior.
              False: it means that after specified data is set, only the specified data is collected,
              and the others are not collected. Default: ``True`` .
-        custom_lineage_data (Union[dict, None]): Allows you to customize the data and present it on the MingInsight
-            `lineage page <https://www.mindspore.cn/mindinsight/docs/en/master/lineage_and_scalars_comparison.html>`_ .
+        custom_lineage_data (Union[dict, None]): Allows you to customize the data.
             In the custom data, the type of the key supports str, and the type of value supports str, int
             and float. Default: ``None`` , it means there is no custom data.
         collect_tensor_freq (Optional[int]): The same semantics as the `collect_freq`, but controls TensorSummary only.
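
For orientation, a minimal sketch of passing the two arguments documented above (the directory, regex, and lineage values are illustrative, not taken from the diff):

    from mindspore.train import SummaryCollector

    # Hypothetical values; only the argument and key names come from the docstring.
    summary_cb = SummaryCollector(
        summary_dir="./summary_dir",
        collect_specified_data={"collect_train_lineage": True, "histogram_regular": "^conv1.*"},
        custom_lineage_data={"version": "v1", "lr": 0.01},  # keys: str; values: str/int/float
    )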
@@ -168,7 +163,7 @@ class SummaryCollector(Callback):
             affect the number of steps TensorSummary will be collected.
             Default: ``None`` , which means to follow the behavior as described above.
         max_file_size (Optional[int]): The maximum size in bytes of each file that can be written to the disk.
-            For example, to write not larger than 4GB, specify `max_file_size=4*1024**3`.
+            For example, to write not larger than 4GB, specify `max_file_size=4*1024**3` .
             Default: ``None`` , which means no limit.
         export_options (Union[None, dict]): Perform custom operations on the export data.
             Note that the size of export files is not limited by the max_file_size.
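
The `max_file_size` value is plain byte arithmetic; a short sketch of the 4GB example above (directory name illustrative):

    from mindspore.train import SummaryCollector

    GIB = 1024 ** 3  # 1 GiB in bytes
    # Cap each summary file at 4 GB (4 * 1024**3 bytes), per the docstring example.
    summary_cb = SummaryCollector(summary_dir="./summary_dir", max_file_size=4 * GIB)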
@@ -28,7 +28,8 @@ class TimeMonitor(Callback):
     Args:
         data_size (int): How many steps are the intervals between print information each time.
             if the program get `batch_num` during training, `data_size` will be set to `batch_num`,
-            otherwise `data_size` will be used. Default: ``None`` .
+            otherwise `data_size` will be used. If the program does not get `batch_num` during training,
+            meanwhile `data_size` does not set, the program will report an error. Default: ``None`` .

         data_time (bool): Whether to show the average time of fetching data in Host.
             Note that data fetch and network compute are processed sequentially in non dataset sink mode, while
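
A minimal TimeMonitor usage sketch matching the clarified `data_size` semantics (the step count is illustrative):

    from mindspore.train import TimeMonitor

    steps_per_epoch = 1875  # illustrative; batch_num read from the dataset takes precedence
    time_cb = TimeMonitor(data_size=steps_per_epoch)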
@@ -15,24 +15,27 @@
 """Checkpoint related classes and functions."""

 import os
+from mindspore.utils import _tft_handler
 from mindspore.train.serialization import save_checkpoint
-from mindspore.parallel._utils import _get_device_num
-from mindspore import _checkparam as Validator
 from mindspore.train.callback._callback import Callback
-from mindspore import context
+from mindspore import context, ops
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore.communication import get_rank, get_group_size
 from mindspore import log as logger
 from mindspore.train.serialization import _get_cur_rank_dp
 from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post, _tft_sem_enable
+from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm
 from mindspore._c_expression import clean_tdt_channel
 from mindspore._c_expression import send_recv, reset_params
 from mindspore._c_expression import CollectiveManager
 from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
-from mindspore._c_expression import Tensor as Tensor_
+from mindspore._c_expression import TensorPy as Tensor_
+from mindspore.ops.operations.manually_defined._inner import TensorReport
 import mindspore
 import mindspore.common.dtype as mstype
+from mindspore.parallel._recovery_context import _set_recovery_context
+


 def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
     """ Common func to generate ckpt dir name."""
@@ -40,30 +43,38 @@ def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
     mid_dir = f"tft_saved_checkpoints-step_{str(step)}{tmp}"
     return os.path.join(ckpt_save_path, mid_dir)

+
 def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
     """ Callback used for TFT save ckpt function when errors occur."""
     logger.info("Enter _save_checkpoint_on_failure function")
-    if not cb_ctx._is_params_consistent(): # pylint: disable=W0212
+    if not cb_ctx._is_params_consistent():  # pylint: disable=W0212
         raise RuntimeError("Can't save parameters, because they are left in inconsistent state!")
+    cb_params = args
+    # we record the current step and epoch num in on_train_step_end, so we can just reset it here
+    cb_params.cur_step_num = cb_ctx.cur_step_num
+    cb_params.cur_epoch_num = cb_ctx.cur_epoch_num
+    if cb_params.optimizer is not None:
+        cb_params.optimizer.global_step = cb_ctx.global_step
+    if hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
+        cb_params.network.optimizer.global_step = cb_ctx.global_step
+    append_dict = {}
+    append_dict["__exception_save__"] = True
+    # if user has provided a custom save callback, use it
+    if cb_ctx.save_cb:
+        cb_ctx.save_cb(cb_params, append_dict)
+        logger.info("Finish _save_checkpoint_on_failure function")
+        return

+    # if user has not provided a custom save callback, use default save logic
     ckpt_save_path = cb_ctx.ckpt_save_path
-    cb_params = args
     cur_rank = get_rank()
-    cur_step_num = cb_params.cur_step_num
+    step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
     cur_epoch_num = cb_params.cur_epoch_num
-    batch_num = cb_params.batch_num
-    if cur_step_num > step:
-        cur_epoch_num = (step - 1) // batch_num + 1
-        step_num_in_epoch = int((step - 1) % batch_num + 1)
-
-    append_dict = {}
     append_dict["epoch_num"] = cur_epoch_num
-    append_dict["step_num"] = step
+    append_dict["step_num"] = cb_params.cur_step_num
     append_dict["cur_rank"] = cur_rank
-    append_dict["batch_num"] = batch_num
-    append_dict["__exception_save__"] = True
-
-    append_dict["global_step"] = Parameter([cb_ctx.global_step])
+    append_dict["batch_num"] = cb_params.batch_num
+    append_dict["global_step"] = cb_ctx.global_step
     outputs = cb_params.net_outputs
     if isinstance(outputs, (tuple, list)) and len(outputs) >= 3:
         append_dict["loss_scale"] = outputs[2]
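
The new one-line modular expression replaces the old conditional recomputation; a worked sketch of the arithmetic with made-up numbers:

    # Illustrative values only.
    cur_step_num = 7   # global step counter across epochs
    batch_num = 4      # steps per epoch
    step_num_in_epoch = int((cur_step_num - 1) % batch_num + 1)  # -> 3, third step of the epoch
    cur_epoch_num = (cur_step_num - 1) // batch_num + 1          # -> 2, as the removed code derived it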
@@ -76,49 +87,63 @@ def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
                     integrated_save=False, append_dict=append_dict)
     logger.info("Finish _save_checkpoint_on_failure function")

+
 def _rename_save_result(step, cb_ctx):
     """ Callback used for TFT rename function after ckpt save callback was finished and successful."""
     logger.info("Enter _rename_save_result function")
+    if cb_ctx.save_cb:
+        logger.info("User's save callback is provided, skip rename")
+        return
     tmp_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, True)
     fin_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, False)

     os.rename(tmp_dir, fin_dir)
     logger.info("Finish _rename_save_result function")

+
 def _tft_exit_cb(ctx):
+    """Callback used for TFT exit function."""
     logger.error("Enter mindio ttp exit process, which means other ranks occur exception, check other ranks' logs!")
     _tft_sem_post()
-    os._exit(1) # pylint: disable=W0212
+    os._exit(1)  # pylint: disable=W0212


 def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
     """ Callback used for TFT repair function."""
-    logger.info("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
-    if(repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value\
-       or repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value):
-        logger.info("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
+    logger.warning("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
+    if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
+                                       cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value)):
+        logger.warning("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
         _repair_device(cb_ctx.device_id)

-    if(repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value\
-       or repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value):
-        logger.info("Enter _tft_repair_callback SEND_RECV repair type: \
-            {}, src_rank:{}, dst_rank: {}".format(repair_info["repair_type"], repair_info["src"], repair_info["dst"]))
+    if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
+                                       cb_ctx.tft.RepairType.RT_SEND.value,
+                                       cb_ctx.tft.RepairType.RT_RECV_REPAIR.value)):
+        logger.warning("Enter _tft_repair_callback SEND_RECV repair type:{}, src_rank:{}, dst_rank: {}".format(
+            repair_info["repair_type"], repair_info["src"], repair_info["dst"]))
         cb_params = args
-        src_rank = repair_info["src"][0]
-        dst_rank = repair_info["dst"][0]
-        if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
-            raise ValueError("Call send_recv failed.")
-    logger.info("Finish _tft_repair_callback")
+        if repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value:
+            for i in range(len(repair_info["src"])):
+                src_rank = repair_info["src"][i]
+                dst_rank = repair_info["dst"][i]
+                if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
+                    raise ValueError("Call send_recv failed.")
+        else:
+            src_rank = repair_info["src"][0]
+            dst_rank = repair_info["dst"][0]
+            if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
+                raise ValueError("Call send_recv failed.")
+    logger.warning("Finish _tft_repair_callback")


 def _tft_clean_callback(is_uce_error, args, ctx):
     """ Callback used for TFT clean function."""
-    logger.info("Enter _tft_clean_callback")
+    logger.warning("Enter _tft_clean_callback")
     ret = 0
     if is_uce_error:
         _get_uce_mem_info(ctx.device_id)
         err_strategy = _get_uce_process_strategy()
-        logger.info("_tft_clean_callback err_strategy: {}".format(err_strategy))
+        logger.warning("_tft_clean_callback err_strategy: {}".format(err_strategy))
         if err_strategy == "RS_UCE_HIGHLEVEL":
             ret = 0
         elif err_strategy == "RS_UCE_LOWLEVEL":
@@ -126,37 +151,49 @@ def _tft_clean_callback(is_uce_error, args, ctx):
         else:
             ret = 1
     clean_tdt_channel()
-    logger.info("Enter _tft_clean_callback resume_hccl_comm")
+    logger.warning("Enter _tft_clean_callback resume_hccl_comm")
     CollectiveManager.get_instance().resume_hccl_comm()
-    logger.info("Finish _tft_clean_callback, ret: {}".format(ret))
+    logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
     return ret


 def _tft_stop_callback(args, cb_ctx):
     """ Callback used for TFT stop function."""
-    logger.info("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
+    logger.warning("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
     _stop_device(cb_ctx.device_id)
-    if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()): # pylint: disable=W0212
+    if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()):  # pylint: disable=W0212
         raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
     cb_ctx.is_uce_rank = False
+    if cb_ctx.tft.tft_get_repair_type() == "recover":
+        logger.warning(f"Reset limit step")
+        cb_ctx.tft.tft_reset_limit_step()
     logger.info("Finish _tft_stop_callback")


-class TFTRegister(Callback):
+def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
+    """Callback used for TFT Rebuild Group function."""
+    logger.warning(f"Enter _tft_rebuild_sub_groups, device id: ".format(ctx.device_id))
+    _finalize_comm()
+    _rebuild_world_group()
+    _rebuild_sub_group()
+    _set_recovery_context(is_arf=True)
+    logger.warning("Enter _tft_rebuild_sub_groups ok ")
+
+
+class TrainFaultTolerance(Callback):
     """
     This callback is used to enable the TFT feature
-    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_.
-    This callback will execute TFT operations during training process, such as TFT init, report and exception handle.
+    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_
+    and will execute TFT operations during training process, such as TFT init, report and exception handle.

     Note:
         Required for Ascend graph mode only. And sink size must be less than or equal to 1.

     Args:
-        ctrl_rank_id (int): TFT controller's running rank_id, used for init TFT controller.
-        ctrl_ip (str): TFT controller's ip address, used for init TFT controller.
-        ctrl_port (int): TFT controller's ip port, used for init TFT controller and processor.
-        ckpt_save_path (str): Checkpoint save directory when failure occurs, checkpoint file will save to directory
-            named ttp_saved_checkpoints-step_{cur_step_num} under this directory.
+        ckpt_save_path (str): Checkpoint save directory when failure occurs. When saved,
+            a new directory named 'ttp_saved_checkpoints-step_{cur_step_num}'
+            is created in that directory. Default: ``None``.
+        kwargs (dict): Other dictionary type parameters.

     Raises:
         Exception: TFT init failed.
@@ -168,7 +205,7 @@ class TFTRegister(Callback):

     It's recommended to use the msrun startup method.
     Please see the `msrun start up
-    <https://www.mindspore.cn/docs/en/master/model_train/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
     for more details.

     This example should be run with 4 devices.
@@ -181,7 +218,7 @@ class TFTRegister(Callback):
     >>> from mindspore import nn, ops, Parameter, train
     >>> from mindspore.communication import init, get_rank
     >>> from mindspore.common.initializer import initializer, HeUniform
-    >>> from mindspore.train import Model, TFTRegister
+    >>> from mindspore.train import Model, TrainFaultTolerance
     >>> from mindspore import dataset as ds
     >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
     >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
@@ -252,43 +289,68 @@ class TFTRegister(Callback):
     >>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
     >>> loss_fn = nn.CrossEntropyLoss()
     >>>
-    >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
+    >>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 4)
     >>> net_with_loss.set_train()
     >>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
-    >>> tft_cb = TFTRegister(0, "192.168.0.1", 2000, "./tft_checkpoint/")
+    >>> tft_cb = TrainFaultTolerance()
     >>> loss_cb = train.LossMonitor(1)
     >>> model.train(1, dataset, callbacks=[tft_cb, loss_cb])
     """

-    def __init__(self, ctrl_rank_id, ctrl_ip, ctrl_port, ckpt_save_path):
-        super(TFTRegister, self).__init__()
-
-        tft_env = os.getenv("MS_ENABLE_TFT", "")
-        if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env):
-            raise ValueError("MindIO TFT regitster need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1}']!")
-        mode = context.get_context("mode")
-        device_target = context.get_context("device_target")
-        if device_target != "Ascend" or mode != context.GRAPH_MODE:
-            raise ValueError("MindIO adataper only support on Ascend device with GRAPH Mode!")
-
-        # let it raise errors if not install mindio_tft package
-        from mindio_ttp import framework_ttp as tft
-        self.tft = tft
-        self.global_step = 0
-        Validator.check_non_negative_int(ctrl_port)
-        self.has_init_replica = False
-        self.is_uce_rank = False
-        self._controller_ip = ctrl_ip
-        self._controller_rank_id = ctrl_rank_id
-        self._controller_port = ctrl_port
+    def __init__(self, ckpt_save_path=None, **kwargs):
+        super(TrainFaultTolerance, self).__init__()
+        self.save_cb = kwargs.get("ckpt_save_fn", None)
+        self.ckpt_save_path = ckpt_save_path
+        if self.save_cb is None and self.ckpt_save_path is None:
+            raise ValueError("TrainFaultTolerance construct need to set ckpt_save_fn or ckpt_save_path!")
         self.cb_params = None
+        self.initial_step = kwargs.get("initial_step", 0)
         self.device_id = context.get_context("device_id")
-        self._init_tft()
-        self.ckpt_save_path = ckpt_save_path
+        self.cur_step_num = 0
+        self.cur_epoch_num = 0
+        # For TREError(Training Result Error) scene, parameter `ckpt_load_fn` must be provided to load checkpoint
+        # from file for resuming training, the `ckpt_load_fn` is a function, prototype of which is:
+        # `def load_checkpoint() -> tuple(dict, bool)`, the return value is a tuple containing 2 values,
+        # i.e. (param_dict, remove_redundancy)
+        self.ckpt_load_func = kwargs.get("ckpt_load_fn", None)
+        self.tft = _tft_handler.get_tft()
+        if self._only_enable_tre():
+            return
+        self._check_init()
+        self.global_step = None
+        self.learning_rate = None
+        self.has_init_replica = False
+        self.is_uce_rank = False
+
         self.assign = mindspore.ops.Assign()
         self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
         self.s1 = mindspore.hal.Stream()
         _tft_sem_enable()
+        self._tft_register()
+
+    def _only_enable_tre(self):
+        """Check if only configured MS_ENABLE_TFT='{TRE:1}'"""
+        env_enable = os.getenv("MS_ENABLE_TFT", "")
+        non_tre_flags = ["TTP:1", "UCE:1", "ARF:1"]
+        if any(flag in env_enable for flag in non_tre_flags):
+            return False
+        return "TRE:1" in env_enable
+
+    def _check_init(self):
+        """Check if the mindio-ttp had inited"""
+        if self.tft is None:
+            tft_env = os.getenv("MS_ENABLE_TFT", "")
+            if "ARF:1" in tft_env:
+                raise ValueError("Must init by _tft_handler.init(config=params) if use ARF.")
+            logger.warning(f"TFT handle not init, try to init")
+            _tft_handler.init(config=None)
+            self.tft = _tft_handler.get_tft()
+            logger.warning(f"TFT handle init ok.")
+        mode = context.get_context("mode")
+        device_target = context.get_context("device_target")
+        if device_target != "Ascend" or mode != context.GRAPH_MODE:
+            raise ValueError(f"MindIO adataper only support on Ascend device with GRAPH Mode!"
+                             f"device:{device_target}, run mode: {mode}")

     def _is_params_consistent(self):
         for key, param in self.cb_params.train_network.parameters_and_names():
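
The rewritten constructor accepts either a save directory or user-supplied callbacks via `kwargs`; a hedged construction sketch (the callback bodies are placeholders, only the keyword names `ckpt_save_fn`, `ckpt_load_fn`, and `initial_step` come from the code above):

    def my_save(cb_params, append_dict):
        pass  # persist a checkpoint however the caller prefers

    def my_load():
        return {}, False  # (param_dict, remove_redundancy), per the comment above

    tft_cb = TrainFaultTolerance(ckpt_save_path="./tft_ckpt/")
    tft_cb_custom = TrainFaultTolerance(ckpt_save_fn=my_save, ckpt_load_fn=my_load, initial_step=0)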
@@ -300,7 +362,7 @@ class TFTRegister(Callback):
         return False

     def _set_tft_optimizer_replica(self, run_context):
-        """ set Mindio TFT optimizer replica info, used internal. """
+        """ Set Mindio TFT optimizer replica info, used internal. """
         cur_rank = get_rank()
         cb_params = run_context.original_args()
         train_network = cb_params.train_network
@@ -322,33 +384,49 @@ class TFTRegister(Callback):
         ]
         self.tft.tft_set_optimizer_replica(cur_rank, replica_info)

-    def _init_tft(self):
-        """ Init Mindio TFT, used internal. """
-        logger.info("Begin to init tft.")
+    @classmethod
+    def get_optimizer_wrapper(cls, origin_opt_cls):
+        """
+        Optimizer wrapper func when using tft.
+
+        Args:
+            origin_opt_cls (Class): origin optimizer class.
+        """
+
+        class TFTOptSubCls(origin_opt_cls):
+            """
+            Optimizer wrapper class when using tft.
+            """
+
+            def __init__(self, *args, **kwargs):
+                super(TFTOptSubCls, self).__init__(*args, **kwargs)
+                self.report = TensorReport()
+                self.report_end = TensorReport()
+                self.report_end.add_prim_attr("side_effect_mem", True).add_prim_attr("optimizer_end", True)
+                self.depend = ops.Depend()
+                self.allreduce_sum = ops.AllReduce()
+                self.allreduce_sum.add_prim_attr("tft_report_before", True)
+                self.tft_g_one_flag = Parameter(Tensor([1], dtype=mstype.int32))
+
+            def construct(self, gradients, **kwargs):
+                tft_g_one_flag = self.depend(self.tft_g_one_flag, gradients)
+                self.tft_g_one_flag = self.allreduce_sum(tft_g_one_flag)
+                grads = self.depend(gradients, self.report("tft_report", self.tft_g_one_flag))
+                opt_ret = super(TFTOptSubCls, self).construct(grads, **kwargs)
+                self.report_end("tft_report", self.tft_g_one_flag)
+                return opt_ret
+
+        return TFTOptSubCls
+
+    def _tft_register(self):
+        """Register callback functions."""
         self.tft.tft_register_save_ckpt_handler(_save_checkpoint_on_failure, self)
         self.tft.tft_register_rename_handler(_rename_save_result, self)
         self.tft.tft_register_exit_handler(_tft_exit_cb, self)
         self.tft.tft_register_stop_handler(_tft_stop_callback, self)
         self.tft.tft_register_clean_handler(_tft_clean_callback, self)
         self.tft.tft_register_repair_handler(_tft_repair_callback, self)
-
-        world_size = _get_device_num()
-        cur_rank = get_rank()
-        enable_local_copy = False
-        enable_arf = False
-        enable_tls = False
-        tls_key_dir = ""
-
-        if cur_rank == self._controller_rank_id:
-            logger.info(f"Begin to start tft controller on rank_id:{cur_rank}")
-            self.tft.tft_init_controller(cur_rank, world_size, enable_local_copy, enable_arf)
-            self.tft.tft_start_controller(self._controller_ip, self._controller_port, enable_tls, tls_key_dir)
-            logger.info("Finish start tft controller.")
-
-        logger.info("Begin to start tft processor.")
-        self.tft.tft_init_processor(cur_rank, world_size, enable_local_copy, enable_tls, tls_key_dir)
-        self.tft.tft_start_processor(self._controller_ip, self._controller_port)
-        logger.info("Finished start tft processor.")
+        self.tft.tft_register_rebuild_group_handler(_tft_rebuild_sub_groups, self)

     def _reset_acc_grads(self):
         accu_grad_params = map(lambda e: e[1],
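
The new `get_optimizer_wrapper` classmethod stands in for the removed controller/processor bootstrap; a usage sketch, assuming an existing optimizer class such as `nn.SGD`:

    import mindspore.nn as nn

    # Returns a subclass of nn.SGD whose construct() reports step state to TFT.
    TFTSGD = TrainFaultTolerance.get_optimizer_wrapper(nn.SGD)
    # opt = TFTSGD(net.trainable_params(), learning_rate=0.01)  # used like the original optimizer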
@@ -360,29 +438,44 @@ class TFTRegister(Callback):

     def on_train_step_end(self, run_context):
         """
-        And report status to MindIO TFT after every step finished.
+        Report status to MindIO TFT after every step finished.

         Args:
             run_context (RunContext): Context of the train running. Refer to
                 :class:`mindspore.train.RunContext` for detail.
         """
+        if self._only_enable_tre():
+            return
         if self.has_init_replica is False:
             self.has_init_replica = True
             self._set_tft_optimizer_replica(run_context)
         cb_params = run_context.original_args()
         logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
+        self.cur_step_num = cb_params.cur_step_num
+        self.cur_epoch_num = cb_params.cur_epoch_num
         if cb_params.optimizer is not None:
-            self.global_step = int(cb_params.optimizer.global_step.data)
+            self.global_step = cb_params.optimizer.global_step.clone()
             self.assign(cb_params.optimizer.tft_g_one_flag, self.g_one)
-        else:
-            self.global_step = int(cb_params.network.optimizer.global_step.data)
+        elif hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
+            self.global_step = cb_params.network.optimizer.global_step.clone()
             self.assign(cb_params.network.optimizer.tft_g_one_flag, self.g_one)
-        self.tft.tft_end_updating_os(cb_params.cur_step_num)
+        else:
+            raise ValueError("TFT feature need optimizer or network's optimizer!")
+        self.tft.tft_end_updating_os(cb_params.cur_step_num + self.initial_step)
         logger.info("END Set optimizer finish step status to TFT.")

-
     def on_train_begin(self, run_context):
+        """
+        Register train params to MindIO TFT on train beginning.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
         cb_params = run_context.original_args()
+        if self._only_enable_tre():
+            self.cb_params = cb_params
+            return
         sink_size = cb_params.get("sink_size", 0)
         if sink_size > 1:
             raise ValueError("TFT feature doesn't support sink_size > 1.")
391
484
  self.cb_params = cb_params
392
485
 
393
486
  def end(self, run_context):
394
- cur_rank = get_rank()
395
- if cur_rank == self._controller_rank_id:
396
- self.tft.tft_destroy_controller()
397
- self.tft.tft_destroy_processor()
487
+ """
488
+ Unregister MindIO TFT on train end.
489
+
490
+ Args:
491
+ run_context (RunContext): Context of the train running. Refer to
492
+ :class:`mindspore.train.RunContext` for detail.
493
+ """
494
+ if self._only_enable_tre():
495
+ return
496
+ _tft_handler.unregister_tft()
@@ -98,6 +98,29 @@ def _get_next_op(dataset, ori_next_op, is_info_queue):
     return next_op, (key, dataset_shapes, dataset_types)


+def _get_jit_func(sink_fun, jit_config):
+    """
+    Get the jit function.
+    """
+    jit_config_dict = jit_config.jit_config_dict
+    jit_level = jit_config_dict['jit_level']
+    if jit_level == "":
+        jit_level = "O0"
+    backend = ""
+    if jit_level == "O2":
+        jit_level = "O0"
+        backend = "GE"
+    if "backend" in jit_config_dict:
+        backend = jit_config_dict["backend"]
+    fullgraph = False
+    if jit_config_dict['jit_syntax_level'] == "STRICT":
+        fullgraph = True
+    exc_mode = jit_config_dict['exc_mode']
+    infer_boost = jit_config_dict['infer_boost']
+    return jit(sink_fun, jit_level=jit_level, backend=backend, fullgraph=fullgraph, exc_mode=exc_mode,
+               infer_boost=infer_boost)
+
+
 def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
     """
     get the sink function.
@@ -107,7 +130,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
         if jit_config is None:
             dst_sink_fun = sink_fun
         else:
-            dst_sink_fun = jit(sink_fun, jit_config=jit_config)
+            dst_sink_fun = _get_jit_func(sink_fun, jit_config)
         dataset.__sink_fun__ = dst_sink_fun

         return dataset.__sink_fun__
@@ -119,7 +142,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
     if jit_config is None:
         dst_sink_fun = sink_fun
     else:
-        dst_sink_fun = jit(sink_fun, jit_config=jit_config)
+        dst_sink_fun = _get_jit_func(sink_fun, jit_config)
     dataset.__sink_aux__.sink_funcs[key] = dst_sink_fun

     return dst_sink_fun
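
Both call sites now expand the `JitConfig` into explicit `jit` keywords instead of passing it wholesale; a sketch of the mapping `_get_jit_func` performs (assuming a standard `JitConfig`):

    from mindspore import JitConfig

    cfg = JitConfig(jit_level="O2")  # O2 is re-expressed as jit_level="O0" plus backend="GE"
    d = cfg.jit_config_dict
    # _get_jit_func(fn, cfg) then calls roughly:
    # jit(fn, jit_level="O0", backend="GE",
    #     fullgraph=(d["jit_syntax_level"] == "STRICT"),
    #     exc_mode=d["exc_mode"], infer_boost=d["infer_boost"])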
@@ -214,8 +214,7 @@ def _get_dataset_aux(dataset):

 def connect_network_with_dataset(network, dataset_helper):
     """
-    Connect the `network` with dataset in `dataset_helper`. Only supported in `sink mode
-    <https://mindspore.cn/docs/en/master/model_train/train_process/train_optimize.html>`_,
+    Connect the `network` with dataset in `dataset_helper`. Only supported in sink mode,
     (dataset_sink_mode=True).

     Args:
@@ -335,11 +334,11 @@ class DatasetHelper:
         dataset_sink_mode (bool): If the value is True, GetNext is employed to fetch the data at device through the
             dataset pipeline, otherwise fetch the data at host by iterating through the dataset.
             Default: ``True``.
-        sink_size (int): Control the amount of data in each sink.
+        sink_size (int): Control the amount of data in each sink. Must be -1 or positive.
             If sink_size=-1, sink the complete dataset for each epoch.
             If sink_size>0, sink sink_size data for each epoch.
-            Default: -1.
-        epoch_num (int): The number of passes of the entire dataset to be sent. Default: 1.
+            Default: ``-1``.
+        epoch_num (int): The number of passes of the entire dataset to be sent. Default: ``1``.

     Examples:
         >>> import numpy as np
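
A minimal DatasetHelper construction matching the clarified defaults (the dataset object is assumed to exist):

    from mindspore.train import DatasetHelper

    # `train_dataset` is assumed to be a mindspore.dataset pipeline.
    helper = DatasetHelper(train_dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1)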