mindspore-2.5.0-cp39-cp39-win_amd64.whl → mindspore-2.6.0rc1-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (491)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +24 -193
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +97 -74
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +1915 -3287
  46. mindspore/common/api.py +341 -354
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/hook_handle.py +5 -3
  52. mindspore/common/initializer.py +10 -6
  53. mindspore/common/jit_begin_end.py +94 -0
  54. mindspore/common/jit_config.py +6 -1
  55. mindspore/common/jit_context.py +76 -0
  56. mindspore/common/jit_trace.py +378 -0
  57. mindspore/common/lazy_inline.py +2 -2
  58. mindspore/common/mutable.py +5 -4
  59. mindspore/common/parameter.py +106 -39
  60. mindspore/common/seed.py +2 -2
  61. mindspore/common/sparse_tensor.py +23 -17
  62. mindspore/common/tensor.py +297 -714
  63. mindspore/communication/__init__.py +7 -5
  64. mindspore/communication/_comm_helper.py +47 -2
  65. mindspore/communication/comm_func.py +70 -53
  66. mindspore/communication/management.py +83 -17
  67. mindspore/context.py +214 -560
  68. mindspore/dataset/__init__.py +44 -20
  69. mindspore/dataset/audio/__init__.py +2 -8
  70. mindspore/dataset/audio/transforms.py +3 -17
  71. mindspore/dataset/core/config.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +1 -1
  73. mindspore/dataset/engine/datasets.py +102 -120
  74. mindspore/dataset/engine/datasets_audio.py +22 -22
  75. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  76. mindspore/dataset/engine/datasets_text.py +78 -85
  77. mindspore/dataset/engine/datasets_user_defined.py +108 -76
  78. mindspore/dataset/engine/datasets_vision.py +111 -108
  79. mindspore/dataset/engine/iterators.py +5 -3
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/samplers.py +279 -57
  82. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  83. mindspore/dataset/engine/validators.py +10 -0
  84. mindspore/dataset/text/__init__.py +7 -6
  85. mindspore/dataset/text/transforms.py +6 -5
  86. mindspore/dataset/text/utils.py +3 -3
  87. mindspore/dataset/transforms/__init__.py +0 -9
  88. mindspore/dataset/transforms/transforms.py +3 -3
  89. mindspore/dataset/utils/browse_dataset.py +1 -1
  90. mindspore/dataset/vision/__init__.py +2 -9
  91. mindspore/dataset/vision/transforms.py +202 -158
  92. mindspore/dataset/vision/utils.py +7 -5
  93. mindspore/device_context/ascend/op_debug.py +60 -1
  94. mindspore/device_context/ascend/op_tuning.py +0 -4
  95. mindspore/device_manager.py +39 -3
  96. mindspore/dnnl.dll +0 -0
  97. mindspore/dpcmi.dll +0 -0
  98. mindspore/experimental/es/embedding_service.py +35 -27
  99. mindspore/experimental/map_parameter.py +4 -4
  100. mindspore/experimental/optim/adadelta.py +22 -26
  101. mindspore/experimental/optim/adagrad.py +4 -4
  102. mindspore/experimental/optim/adam.py +4 -0
  103. mindspore/experimental/optim/adamax.py +4 -4
  104. mindspore/experimental/optim/adamw.py +4 -0
  105. mindspore/experimental/optim/asgd.py +1 -1
  106. mindspore/experimental/optim/lr_scheduler.py +40 -22
  107. mindspore/experimental/optim/radam.py +5 -5
  108. mindspore/experimental/optim/rprop.py +1 -1
  109. mindspore/experimental/optim/sgd.py +1 -1
  110. mindspore/hal/contiguous_tensors_handle.py +6 -10
  111. mindspore/hal/device.py +55 -81
  112. mindspore/hal/event.py +38 -55
  113. mindspore/hal/memory.py +93 -144
  114. mindspore/hal/stream.py +81 -125
  115. mindspore/include/dataset/constants.h +7 -4
  116. mindspore/include/dataset/execute.h +2 -2
  117. mindspore/jpeg62.dll +0 -0
  118. mindspore/log.py +40 -2
  119. mindspore/mindrecord/__init__.py +20 -7
  120. mindspore/mindspore_backend_common.dll +0 -0
  121. mindspore/mindspore_backend_manager.dll +0 -0
  122. mindspore/mindspore_common.dll +0 -0
  123. mindspore/mindspore_core.dll +0 -0
  124. mindspore/mindspore_dump.dll +0 -0
  125. mindspore/mindspore_frontend.dll +0 -0
  126. mindspore/mindspore_glog.dll +0 -0
  127. mindspore/mindspore_memory_pool.dll +0 -0
  128. mindspore/mindspore_ms_backend.dll +0 -0
  129. mindspore/mindspore_ops.dll +0 -0
  130. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  131. mindspore/mindspore_ops_kernel_common.dll +0 -0
  132. mindspore/mindspore_profiler.dll +0 -0
  133. mindspore/mindspore_pyboost.dll +0 -0
  134. mindspore/mindspore_pynative.dll +0 -0
  135. mindspore/mindspore_res_manager.dll +0 -0
  136. mindspore/mindspore_runtime_pipeline.dll +0 -0
  137. mindspore/mint/__init__.py +131 -700
  138. mindspore/mint/distributed/__init__.py +5 -1
  139. mindspore/mint/distributed/distributed.py +194 -109
  140. mindspore/mint/linalg/__init__.py +2 -0
  141. mindspore/mint/nn/__init__.py +280 -18
  142. mindspore/mint/nn/functional.py +282 -64
  143. mindspore/mint/nn/layer/__init__.py +4 -0
  144. mindspore/mint/nn/layer/_functions.py +7 -3
  145. mindspore/mint/nn/layer/activation.py +120 -13
  146. mindspore/mint/nn/layer/conv.py +218 -24
  147. mindspore/mint/nn/layer/normalization.py +15 -16
  148. mindspore/mint/nn/layer/padding.py +1 -1
  149. mindspore/mint/nn/layer/pooling.py +66 -1
  150. mindspore/mint/optim/__init__.py +2 -1
  151. mindspore/mint/optim/sgd.py +171 -0
  152. mindspore/msobj140.dll +0 -0
  153. mindspore/mspdb140.dll +0 -0
  154. mindspore/mspdbcore.dll +0 -0
  155. mindspore/mspdbst.dll +0 -0
  156. mindspore/mspft140.dll +0 -0
  157. mindspore/msvcdis140.dll +0 -0
  158. mindspore/msvcp140_1.dll +0 -0
  159. mindspore/msvcp140_2.dll +0 -0
  160. mindspore/msvcp140_atomic_wait.dll +0 -0
  161. mindspore/msvcp140_codecvt_ids.dll +0 -0
  162. mindspore/nn/__init__.py +4 -1
  163. mindspore/nn/cell.py +1250 -176
  164. mindspore/nn/layer/activation.py +23 -21
  165. mindspore/nn/layer/basic.py +22 -16
  166. mindspore/nn/layer/container.py +1 -1
  167. mindspore/nn/layer/conv.py +22 -17
  168. mindspore/nn/layer/embedding.py +9 -8
  169. mindspore/nn/layer/normalization.py +48 -42
  170. mindspore/nn/layer/pooling.py +75 -31
  171. mindspore/nn/layer/transformer.py +11 -10
  172. mindspore/nn/learning_rate_schedule.py +4 -2
  173. mindspore/nn/loss/loss.py +27 -19
  174. mindspore/nn/optim/ada_grad.py +6 -5
  175. mindspore/nn/optim/adadelta.py +9 -7
  176. mindspore/nn/optim/adafactor.py +1 -1
  177. mindspore/nn/optim/adam.py +16 -12
  178. mindspore/nn/optim/adamax.py +8 -7
  179. mindspore/nn/optim/adasum.py +5 -5
  180. mindspore/nn/optim/asgd.py +1 -1
  181. mindspore/nn/optim/ftrl.py +11 -9
  182. mindspore/nn/optim/lamb.py +1 -1
  183. mindspore/nn/optim/lazyadam.py +12 -10
  184. mindspore/nn/optim/momentum.py +7 -6
  185. mindspore/nn/optim/optimizer.py +2 -2
  186. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  187. mindspore/nn/optim/rmsprop.py +13 -12
  188. mindspore/nn/optim/rprop.py +9 -7
  189. mindspore/nn/optim/sgd.py +9 -6
  190. mindspore/nn/optim/tft_wrapper.py +5 -2
  191. mindspore/nn/probability/bijector/bijector.py +17 -11
  192. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  193. mindspore/nn/probability/bijector/invert.py +2 -2
  194. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  195. mindspore/nn/probability/bijector/softplus.py +3 -2
  196. mindspore/nn/probability/distribution/beta.py +3 -3
  197. mindspore/nn/probability/distribution/categorical.py +1 -1
  198. mindspore/nn/probability/distribution/cauchy.py +4 -2
  199. mindspore/nn/probability/distribution/exponential.py +6 -7
  200. mindspore/nn/probability/distribution/gamma.py +2 -2
  201. mindspore/nn/probability/distribution/gumbel.py +2 -2
  202. mindspore/nn/probability/distribution/half_normal.py +5 -3
  203. mindspore/nn/probability/distribution/logistic.py +5 -3
  204. mindspore/nn/probability/distribution/poisson.py +1 -1
  205. mindspore/nn/probability/distribution/uniform.py +5 -3
  206. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  207. mindspore/nn/reinforcement/tensor_array.py +1 -1
  208. mindspore/nn/wrap/__init__.py +6 -6
  209. mindspore/nn/wrap/cell_wrapper.py +178 -117
  210. mindspore/nn/wrap/grad_reducer.py +45 -36
  211. mindspore/nn/wrap/loss_scale.py +3 -3
  212. mindspore/numpy/array_creations.py +3 -3
  213. mindspore/numpy/array_ops.py +1 -1
  214. mindspore/numpy/math_ops.py +4 -4
  215. mindspore/numpy/utils.py +1 -2
  216. mindspore/numpy/utils_const.py +1 -2
  217. mindspore/opencv_core452.dll +0 -0
  218. mindspore/opencv_imgcodecs452.dll +0 -0
  219. mindspore/opencv_imgproc452.dll +0 -0
  220. mindspore/ops/__init__.py +3 -2
  221. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  222. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  223. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  224. mindspore/ops/_register_for_op.py +0 -11
  225. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  226. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  227. mindspore/ops/_vmap/vmap_array_ops.py +7 -6
  228. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  229. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  230. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  231. mindspore/ops/auto_generate/__init__.py +4 -3
  232. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
  233. mindspore/ops/auto_generate/gen_extend_func.py +281 -135
  234. mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
  235. mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
  236. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  237. mindspore/ops/composite/__init__.py +2 -1
  238. mindspore/ops/composite/base.py +19 -24
  239. mindspore/ops/composite/math_ops.py +6 -16
  240. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  241. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
  242. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  243. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  244. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  248. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  249. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  250. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  251. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  252. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  254. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  255. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  256. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  257. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  259. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  260. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  263. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  264. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  267. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  268. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  271. mindspore/ops/function/__init__.py +28 -2
  272. mindspore/ops/function/_add_attr_func.py +58 -0
  273. mindspore/ops/function/array_func.py +1629 -2345
  274. mindspore/ops/function/clip_func.py +38 -45
  275. mindspore/ops/function/debug_func.py +36 -44
  276. mindspore/ops/function/grad/__init__.py +1 -0
  277. mindspore/ops/function/grad/grad_func.py +104 -71
  278. mindspore/ops/function/image_func.py +1 -1
  279. mindspore/ops/function/linalg_func.py +46 -78
  280. mindspore/ops/function/math_func.py +3035 -3705
  281. mindspore/ops/function/nn_func.py +676 -241
  282. mindspore/ops/function/other_func.py +159 -1
  283. mindspore/ops/function/parameter_func.py +17 -30
  284. mindspore/ops/function/random_func.py +204 -361
  285. mindspore/ops/function/reshard_func.py +4 -70
  286. mindspore/ops/function/sparse_func.py +3 -3
  287. mindspore/ops/function/sparse_unary_func.py +5 -5
  288. mindspore/ops/function/spectral_func.py +25 -58
  289. mindspore/ops/function/vmap_func.py +24 -17
  290. mindspore/ops/functional.py +6 -4
  291. mindspore/ops/functional_overload.py +547 -4
  292. mindspore/ops/op_info_register.py +32 -244
  293. mindspore/ops/operations/__init__.py +10 -5
  294. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  295. mindspore/ops/operations/_grad_ops.py +1 -10
  296. mindspore/ops/operations/_inner_ops.py +5 -76
  297. mindspore/ops/operations/_ms_kernel.py +4 -10
  298. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  299. mindspore/ops/operations/_scalar_ops.py +3 -2
  300. mindspore/ops/operations/_sequence_ops.py +1 -1
  301. mindspore/ops/operations/_tensor_array.py +1 -1
  302. mindspore/ops/operations/array_ops.py +37 -22
  303. mindspore/ops/operations/comm_ops.py +150 -107
  304. mindspore/ops/operations/custom_ops.py +221 -23
  305. mindspore/ops/operations/debug_ops.py +115 -16
  306. mindspore/ops/operations/inner_ops.py +1 -1
  307. mindspore/ops/operations/linalg_ops.py +1 -58
  308. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  309. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  310. mindspore/ops/operations/math_ops.py +21 -18
  311. mindspore/ops/operations/nn_ops.py +65 -191
  312. mindspore/ops/operations/other_ops.py +62 -9
  313. mindspore/ops/operations/random_ops.py +13 -7
  314. mindspore/ops/operations/reshard_ops.py +1 -1
  315. mindspore/ops/operations/sparse_ops.py +2 -2
  316. mindspore/ops/primitive.py +43 -32
  317. mindspore/ops/tensor_method.py +232 -13
  318. mindspore/ops_generate/__init__.py +0 -5
  319. mindspore/ops_generate/aclnn/__init__.py +0 -0
  320. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  321. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  322. mindspore/ops_generate/api/__init__.py +0 -0
  323. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  324. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  325. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  326. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  327. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  328. mindspore/ops_generate/api/gen_api.py +103 -0
  329. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  330. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  331. mindspore/ops_generate/common/__init__.py +0 -0
  332. mindspore/ops_generate/common/gen_constants.py +91 -0
  333. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  334. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  335. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  336. mindspore/ops_generate/gen_ops.py +23 -325
  337. mindspore/ops_generate/op_def/__init__.py +0 -0
  338. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  339. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  340. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
  341. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  342. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  343. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  344. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  345. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  346. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  347. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  348. mindspore/ops_generate/pyboost/__init__.py +0 -0
  349. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  350. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  351. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  352. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  353. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  354. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  355. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  356. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  357. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  358. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  359. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  360. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  361. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  362. mindspore/ops_generate/resources/__init__.py +0 -0
  363. mindspore/ops_generate/resources/resource_list.py +30 -0
  364. mindspore/ops_generate/resources/resource_loader.py +36 -0
  365. mindspore/ops_generate/resources/resource_manager.py +64 -0
  366. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  367. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  368. mindspore/parallel/__init__.py +6 -2
  369. mindspore/parallel/_auto_parallel_context.py +133 -6
  370. mindspore/parallel/_cell_wrapper.py +130 -15
  371. mindspore/parallel/_parallel_serialization.py +95 -4
  372. mindspore/parallel/_ps_context.py +1 -1
  373. mindspore/parallel/_recovery_context.py +7 -2
  374. mindspore/parallel/_tensor.py +142 -18
  375. mindspore/parallel/_utils.py +198 -25
  376. mindspore/parallel/algo_parameter_config.py +3 -3
  377. mindspore/parallel/auto_parallel.py +732 -0
  378. mindspore/parallel/checkpoint_convert.py +159 -0
  379. mindspore/parallel/checkpoint_transform.py +656 -37
  380. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  381. mindspore/parallel/cluster/run.py +1 -1
  382. mindspore/parallel/function/__init__.py +24 -0
  383. mindspore/parallel/function/reshard_func.py +259 -0
  384. mindspore/parallel/nn/__init__.py +25 -0
  385. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  386. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  387. mindspore/parallel/parameter_broadcast.py +24 -13
  388. mindspore/parallel/shard.py +137 -61
  389. mindspore/parallel/transform_safetensors.py +287 -95
  390. mindspore/pgodb140.dll +0 -0
  391. mindspore/pgort140.dll +0 -0
  392. mindspore/profiler/__init__.py +9 -5
  393. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  394. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  395. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
  397. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  398. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  399. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  400. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  401. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  402. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  403. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  404. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  405. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  406. mindspore/profiler/common/constant.py +12 -0
  407. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  408. mindspore/profiler/common/path_manager.py +24 -0
  409. mindspore/profiler/common/profiler_context.py +26 -2
  410. mindspore/profiler/common/profiler_meta_data.py +74 -0
  411. mindspore/profiler/common/profiler_parameters.py +59 -18
  412. mindspore/profiler/common/profiler_path_manager.py +66 -7
  413. mindspore/profiler/dynamic_profiler.py +112 -79
  414. mindspore/profiler/envprofiler.py +26 -1
  415. mindspore/profiler/experimental_config.py +197 -0
  416. mindspore/profiler/mstx.py +57 -14
  417. mindspore/profiler/platform/npu_profiler.py +33 -7
  418. mindspore/profiler/profiler.py +541 -45
  419. mindspore/profiler/profiler_action_controller.py +1 -1
  420. mindspore/profiler/profiler_interface.py +4 -0
  421. mindspore/profiler/schedule.py +57 -22
  422. mindspore/rewrite/api/node.py +15 -13
  423. mindspore/rewrite/api/symbol_tree.py +1 -1
  424. mindspore/run_check/_check_version.py +25 -14
  425. mindspore/run_check/run_check.py +1 -1
  426. mindspore/runtime/__init__.py +2 -2
  427. mindspore/runtime/executor.py +40 -11
  428. mindspore/runtime/memory.py +25 -8
  429. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  430. mindspore/swresample-4.dll +0 -0
  431. mindspore/swscale-6.dll +0 -0
  432. mindspore/tbbmalloc.dll +0 -0
  433. mindspore/tinyxml2.dll +0 -0
  434. mindspore/train/__init__.py +8 -8
  435. mindspore/train/_utils.py +35 -7
  436. mindspore/train/amp.py +1 -1
  437. mindspore/train/callback/__init__.py +2 -2
  438. mindspore/train/callback/_callback.py +2 -16
  439. mindspore/train/callback/_checkpoint.py +24 -40
  440. mindspore/train/callback/_cluster_monitor.py +14 -18
  441. mindspore/train/callback/_flops_collector.py +2 -3
  442. mindspore/train/callback/_history.py +7 -4
  443. mindspore/train/callback/_lambda_callback.py +2 -2
  444. mindspore/train/callback/_landscape.py +0 -3
  445. mindspore/train/callback/_loss_monitor.py +2 -1
  446. mindspore/train/callback/_on_request_exit.py +6 -5
  447. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  448. mindspore/train/callback/_summary_collector.py +8 -13
  449. mindspore/train/callback/_time_monitor.py +2 -1
  450. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
  451. mindspore/train/data_sink.py +25 -2
  452. mindspore/train/dataset_helper.py +4 -5
  453. mindspore/train/loss_scale_manager.py +8 -7
  454. mindspore/train/metrics/accuracy.py +3 -3
  455. mindspore/train/metrics/confusion_matrix.py +9 -9
  456. mindspore/train/metrics/error.py +3 -3
  457. mindspore/train/metrics/hausdorff_distance.py +4 -4
  458. mindspore/train/metrics/mean_surface_distance.py +3 -3
  459. mindspore/train/metrics/metric.py +0 -12
  460. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  461. mindspore/train/metrics/precision.py +8 -6
  462. mindspore/train/metrics/recall.py +9 -9
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  464. mindspore/train/mind_ir_pb2.py +19 -12
  465. mindspore/train/model.py +176 -103
  466. mindspore/train/serialization.py +246 -988
  467. mindspore/train/summary/_summary_adapter.py +2 -2
  468. mindspore/train/summary/summary_record.py +1 -1
  469. mindspore/turbojpeg.dll +0 -0
  470. mindspore/utils/__init__.py +3 -2
  471. mindspore/utils/dryrun.py +4 -2
  472. mindspore/utils/hooks.py +81 -0
  473. mindspore/utils/utils.py +138 -4
  474. mindspore/vcmeta.dll +0 -0
  475. mindspore/vcruntime140.dll +0 -0
  476. mindspore/vcruntime140_1.dll +0 -0
  477. mindspore/version.py +1 -1
  478. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
  479. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
  480. mindspore/_install_custom.py +0 -43
  481. mindspore/common/_register_for_adapter.py +0 -74
  482. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  483. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  484. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  485. mindspore/ops_generate/gen_constants.py +0 -190
  486. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  487. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  488. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  489. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  490. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -15,24 +15,27 @@
 """Checkpoint related classes and functions."""
 
 import os
+from mindspore.utils import _tft_handler
 from mindspore.train.serialization import save_checkpoint
-from mindspore.parallel._utils import _get_device_num
-from mindspore import _checkparam as Validator
 from mindspore.train.callback._callback import Callback
-from mindspore import context
+from mindspore import context, ops
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore.communication import get_rank, get_group_size
 from mindspore import log as logger
 from mindspore.train.serialization import _get_cur_rank_dp
 from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post, _tft_sem_enable
+from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm
 from mindspore._c_expression import clean_tdt_channel
 from mindspore._c_expression import send_recv, reset_params
 from mindspore._c_expression import CollectiveManager
 from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
-from mindspore._c_expression import Tensor as Tensor_
+from mindspore._c_expression import TensorPy as Tensor_
+from mindspore.ops.operations.manually_defined._inner import TensorReport
 import mindspore
 import mindspore.common.dtype as mstype
+from mindspore.parallel._recovery_context import _set_recovery_context
+
 
 def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
     """ Common func to generate ckpt dir name."""
@@ -40,30 +43,38 @@ def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
     mid_dir = f"tft_saved_checkpoints-step_{str(step)}{tmp}"
     return os.path.join(ckpt_save_path, mid_dir)
 
+
 def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
     """ Callback used for TFT save ckpt function when errors occur."""
     logger.info("Enter _save_checkpoint_on_failure function")
-    if not cb_ctx._is_params_consistent(): # pylint: disable=W0212
+    if not cb_ctx._is_params_consistent():  # pylint: disable=W0212
         raise RuntimeError("Can't save parameters, because they are left in inconsistent state!")
+    cb_params = args
+    # we record the current step and epoch num in on_train_step_end, so we can just reset it here
+    cb_params.cur_step_num = cb_ctx.cur_step_num
+    cb_params.cur_epoch_num = cb_ctx.cur_epoch_num
+    if cb_params.optimizer is not None:
+        cb_params.optimizer.global_step = cb_ctx.global_step
+    if hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
+        cb_params.network.optimizer.global_step = cb_ctx.global_step
+    append_dict = {}
+    append_dict["__exception_save__"] = True
+    # if user has provided a custom save callback, use it
+    if cb_ctx.save_cb:
+        cb_ctx.save_cb(cb_params, append_dict)
+        logger.info("Finish _save_checkpoint_on_failure function")
+        return
 
+    # if user has not provided a custom save callback, use default save logic
     ckpt_save_path = cb_ctx.ckpt_save_path
-    cb_params = args
     cur_rank = get_rank()
-    cur_step_num = cb_params.cur_step_num
+    step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
     cur_epoch_num = cb_params.cur_epoch_num
-    batch_num = cb_params.batch_num
-    if cur_step_num > step:
-        cur_epoch_num = (step - 1) // batch_num + 1
-    step_num_in_epoch = int((step - 1) % batch_num + 1)
-
-    append_dict = {}
     append_dict["epoch_num"] = cur_epoch_num
-    append_dict["step_num"] = step
+    append_dict["step_num"] = cb_params.cur_step_num
     append_dict["cur_rank"] = cur_rank
-    append_dict["batch_num"] = batch_num
-    append_dict["__exception_save__"] = True
-
-    append_dict["global_step"] = Parameter([cb_ctx.global_step])
+    append_dict["batch_num"] = cb_params.batch_num
+    append_dict["global_step"] = cb_ctx.global_step
     outputs = cb_params.net_outputs
     if isinstance(outputs, (tuple, list)) and len(outputs) >= 3:
         append_dict["loss_scale"] = outputs[2]
@@ -76,49 +87,63 @@ def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
                     integrated_save=False, append_dict=append_dict)
     logger.info("Finish _save_checkpoint_on_failure function")
 
+
 def _rename_save_result(step, cb_ctx):
     """ Callback used for TFT rename function after ckpt save callback was finished and successful."""
     logger.info("Enter _rename_save_result function")
+    if cb_ctx.save_cb:
+        logger.info("User's save callback is provided, skip rename")
+        return
     tmp_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, True)
     fin_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, False)
 
     os.rename(tmp_dir, fin_dir)
     logger.info("Finish _rename_save_result function")
 
+
 def _tft_exit_cb(ctx):
+    """Callback used for TFT exit function."""
     logger.error("Enter mindio ttp exit process, which means other ranks occur exception, check other ranks' logs!")
     _tft_sem_post()
-    os._exit(1) # pylint: disable=W0212
+    os._exit(1)  # pylint: disable=W0212
 
 
 def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
     """ Callback used for TFT repair function."""
-    logger.info("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
-    if(repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value\
-       or repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value):
-        logger.info("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
+    logger.warning("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
+    if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
+                                       cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value)):
+        logger.warning("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
         _repair_device(cb_ctx.device_id)
 
-    if(repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value\
-       or repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value):
-        logger.info("Enter _tft_repair_callback SEND_RECV repair type: \
-            {}, src_rank:{}, dst_rank: {}".format(repair_info["repair_type"], repair_info["src"], repair_info["dst"]))
+    if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
+                                       cb_ctx.tft.RepairType.RT_SEND.value,
+                                       cb_ctx.tft.RepairType.RT_RECV_REPAIR.value)):
+        logger.warning("Enter _tft_repair_callback SEND_RECV repair type:{}, src_rank:{}, dst_rank: {}".format(
+            repair_info["repair_type"], repair_info["src"], repair_info["dst"]))
         cb_params = args
-        src_rank = repair_info["src"][0]
-        dst_rank = repair_info["dst"][0]
-        if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
-            raise ValueError("Call send_recv failed.")
-    logger.info("Finish _tft_repair_callback")
+        if repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value:
+            for i in range(len(repair_info["src"])):
+                src_rank = repair_info["src"][i]
+                dst_rank = repair_info["dst"][i]
+                if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
+                    raise ValueError("Call send_recv failed.")
+        else:
+            src_rank = repair_info["src"][0]
+            dst_rank = repair_info["dst"][0]
+            if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
+                raise ValueError("Call send_recv failed.")
+    logger.warning("Finish _tft_repair_callback")
 
 
 def _tft_clean_callback(is_uce_error, args, ctx):
     """ Callback used for TFT clean function."""
-    logger.info("Enter _tft_clean_callback")
+    logger.warning("Enter _tft_clean_callback")
     ret = 0
     if is_uce_error:
         _get_uce_mem_info(ctx.device_id)
         err_strategy = _get_uce_process_strategy()
-        logger.info("_tft_clean_callback err_strategy: {}".format(err_strategy))
+        logger.warning("_tft_clean_callback err_strategy: {}".format(err_strategy))
         if err_strategy == "RS_UCE_HIGHLEVEL":
             ret = 0
         elif err_strategy == "RS_UCE_LOWLEVEL":
@@ -126,37 +151,49 @@ def _tft_clean_callback(is_uce_error, args, ctx):
         else:
             ret = 1
     clean_tdt_channel()
-    logger.info("Enter _tft_clean_callback resume_hccl_comm")
+    logger.warning("Enter _tft_clean_callback resume_hccl_comm")
     CollectiveManager.get_instance().resume_hccl_comm()
-    logger.info("Finish _tft_clean_callback, ret: {}".format(ret))
+    logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
     return ret
 
 
 def _tft_stop_callback(args, cb_ctx):
     """ Callback used for TFT stop function."""
-    logger.info("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
+    logger.warning("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
     _stop_device(cb_ctx.device_id)
-    if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()): # pylint: disable=W0212
+    if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()):  # pylint: disable=W0212
         raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
     cb_ctx.is_uce_rank = False
+    if cb_ctx.tft.tft_get_repair_type() == "recover":
+        logger.warning(f"Reset limit step")
+        cb_ctx.tft.tft_reset_limit_step()
     logger.info("Finish _tft_stop_callback")
 
 
-class TFTRegister(Callback):
+def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
+    """Callback used for TFT Rebuild Group function."""
+    logger.warning(f"Enter _tft_rebuild_sub_groups, device id: ".format(ctx.device_id))
+    _finalize_comm()
+    _rebuild_world_group()
+    _rebuild_sub_group()
+    _set_recovery_context(is_arf=True)
+    logger.warning("Enter _tft_rebuild_sub_groups ok ")
+
+
+class TrainFaultTolerance(Callback):
     """
     This callback is used to enable the TFT feature
-    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_.
-    This callback will execute TFT operations during training process, such as TFT init, report and exception handle.
+    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_
+    and will execute TFT operations during training process, such as TFT init, report and exception handle.
 
     Note:
         Required for Ascend graph mode only. And sink size must be less than or equal to 1.
 
     Args:
-        ctrl_rank_id (int): TFT controller's running rank_id, used for init TFT controller.
-        ctrl_ip (str): TFT controller's ip address, used for init TFT controller.
-        ctrl_port (int): TFT controller's ip port, used for init TFT controller and processor.
-        ckpt_save_path (str): Checkpoint save directory when failure occurs, checkpoint file will save to directory
-            named ttp_saved_checkpoints-step_{cur_step_num} under this directory.
+        ckpt_save_path (str): Checkpoint save directory when failure occurs. When saved,
+            a new directory named 'ttp_saved_checkpoints-step_{cur_step_num}'
+            is created in that directory. Default: ``None``.
+        kwargs (dict): Other dictionary type parameters.
 
     Raises:
         Exception: TFT init failed.
@@ -168,7 +205,7 @@ class TFTRegister(Callback):
 
         It's recommended to use the msrun startup method.
         Please see the `msrun start up
-        <https://www.mindspore.cn/docs/en/master/model_train/parallel/msrun_launcher.html>`_
+        <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
        for more details.
 
         This example should be run with 4 devices.
@@ -181,7 +218,7 @@ class TFTRegister(Callback):
         >>> from mindspore import nn, ops, Parameter, train
         >>> from mindspore.communication import init, get_rank
         >>> from mindspore.common.initializer import initializer, HeUniform
-        >>> from mindspore.train import Model, TFTRegister
+        >>> from mindspore.train import Model, TrainFaultTolerance
         >>> from mindspore import dataset as ds
         >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
         >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
@@ -252,43 +289,52 @@ class TFTRegister(Callback):
         >>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
         >>> loss_fn = nn.CrossEntropyLoss()
         >>>
-        >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
+        >>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 4)
         >>> net_with_loss.set_train()
         >>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
-        >>> tft_cb = TFTRegister(0, "192.168.0.1", 2000, "./tft_checkpoint/")
+        >>> tft_cb = TrainFaultTolerance()
         >>> loss_cb = train.LossMonitor(1)
         >>> model.train(1, dataset, callbacks=[tft_cb, loss_cb])
     """
 
-    def __init__(self, ctrl_rank_id, ctrl_ip, ctrl_port, ckpt_save_path):
-        super(TFTRegister, self).__init__()
-
-        tft_env = os.getenv("MS_ENABLE_TFT", "")
-        if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env):
-            raise ValueError("MindIO TFT regitster need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1}']!")
-        mode = context.get_context("mode")
-        device_target = context.get_context("device_target")
-        if device_target != "Ascend" or mode != context.GRAPH_MODE:
-            raise ValueError("MindIO adataper only support on Ascend device with GRAPH Mode!")
-
-        # let it raise errors if not install mindio_tft package
-        from mindio_ttp import framework_ttp as tft
-        self.tft = tft
-        self.global_step = 0
-        Validator.check_non_negative_int(ctrl_port)
+    def __init__(self, ckpt_save_path=None, **kwargs):
+        super(TrainFaultTolerance, self).__init__()
+        self.save_cb = kwargs.get("ckpt_save_fn", None)
+        self.ckpt_save_path = ckpt_save_path
+        if self.save_cb is None and self.ckpt_save_path is None:
+            raise ValueError("TrainFaultTolerance construct need to set ckpt_save_fn or ckpt_save_path!")
+        self.tft = _tft_handler.get_tft()
+        self._check_init()
+        self.global_step = None
+        self.learning_rate = None
         self.has_init_replica = False
         self.is_uce_rank = False
-        self._controller_ip = ctrl_ip
-        self._controller_rank_id = ctrl_rank_id
-        self._controller_port = ctrl_port
         self.cb_params = None
+        self.initial_step = kwargs.get("initial_step", 0)
         self.device_id = context.get_context("device_id")
-        self._init_tft()
-        self.ckpt_save_path = ckpt_save_path
         self.assign = mindspore.ops.Assign()
         self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
         self.s1 = mindspore.hal.Stream()
+        self.cur_step_num = 0
+        self.cur_epoch_num = 0
         _tft_sem_enable()
+        self._tft_register()
+
+    def _check_init(self):
+        """Check if the mindio-ttp had inited"""
+        if self.tft is None:
+            tft_env = os.getenv("MS_ENABLE_TFT", "")
+            if "ARF:1" in tft_env:
+                raise ValueError("Must init by _tft_handler.init(config=params) if use ARF.")
+            logger.warning(f"TFT handle not init, try to init")
+            _tft_handler.init(config=None)
+            self.tft = _tft_handler.get_tft()
+            logger.warning(f"TFT handle init ok.")
+        mode = context.get_context("mode")
+        device_target = context.get_context("device_target")
+        if device_target != "Ascend" or mode != context.GRAPH_MODE:
+            raise ValueError(f"MindIO adataper only support on Ascend device with GRAPH Mode!"
+                             f"device:{device_target}, run mode: {mode}")
 
     def _is_params_consistent(self):
         for key, param in self.cb_params.train_network.parameters_and_names():
@@ -300,7 +346,7 @@ class TFTRegister(Callback):
         return False
 
     def _set_tft_optimizer_replica(self, run_context):
-        """ set Mindio TFT optimizer replica info, used internal. """
+        """ Set Mindio TFT optimizer replica info, used internal. """
         cur_rank = get_rank()
         cb_params = run_context.original_args()
         train_network = cb_params.train_network
@@ -322,33 +368,49 @@ class TFTRegister(Callback):
         ]
         self.tft.tft_set_optimizer_replica(cur_rank, replica_info)
 
-    def _init_tft(self):
-        """ Init Mindio TFT, used internal. """
-        logger.info("Begin to init tft.")
+    @classmethod
+    def get_optimizer_wrapper(cls, origin_opt_cls):
+        """
+        Optimizer wrapper func when using tft.
+
+        Args:
+            origin_opt_cls (Class): origin optimizer class.
+        """
+
+        class TFTOptSubCls(origin_opt_cls):
+            """
+            Optimizer wrapper class when using tft.
+            """
+
+            def __init__(self, *args, **kwargs):
+                super(TFTOptSubCls, self).__init__(*args, **kwargs)
+                self.report = TensorReport()
+                self.report_end = TensorReport()
+                self.report_end.add_prim_attr("side_effect_mem", True).add_prim_attr("optimizer_end", True)
+                self.depend = ops.Depend()
+                self.allreduce_sum = ops.AllReduce()
+                self.allreduce_sum.add_prim_attr("tft_report_before", True)
+                self.tft_g_one_flag = Parameter(Tensor([1], dtype=mstype.int32))
+
+            def construct(self, gradients, **kwargs):
+                tft_g_one_flag = self.depend(self.tft_g_one_flag, gradients)
+                self.tft_g_one_flag = self.allreduce_sum(tft_g_one_flag)
+                grads = self.depend(gradients, self.report("tft_report", self.tft_g_one_flag))
+                opt_ret = super(TFTOptSubCls, self).construct(grads, **kwargs)
+                self.report_end("tft_report", self.tft_g_one_flag)
+                return opt_ret
+
+        return TFTOptSubCls
+
+    def _tft_register(self):
+        """Register callback functions."""
         self.tft.tft_register_save_ckpt_handler(_save_checkpoint_on_failure, self)
         self.tft.tft_register_rename_handler(_rename_save_result, self)
         self.tft.tft_register_exit_handler(_tft_exit_cb, self)
         self.tft.tft_register_stop_handler(_tft_stop_callback, self)
         self.tft.tft_register_clean_handler(_tft_clean_callback, self)
         self.tft.tft_register_repair_handler(_tft_repair_callback, self)
-
-        world_size = _get_device_num()
-        cur_rank = get_rank()
-        enable_local_copy = False
-        enable_arf = False
-        enable_tls = False
-        tls_key_dir = ""
-
-        if cur_rank == self._controller_rank_id:
-            logger.info(f"Begin to start tft controller on rank_id:{cur_rank}")
-            self.tft.tft_init_controller(cur_rank, world_size, enable_local_copy, enable_arf)
-            self.tft.tft_start_controller(self._controller_ip, self._controller_port, enable_tls, tls_key_dir)
-            logger.info("Finish start tft controller.")
-
-        logger.info("Begin to start tft processor.")
-        self.tft.tft_init_processor(cur_rank, world_size, enable_local_copy, enable_tls, tls_key_dir)
-        self.tft.tft_start_processor(self._controller_ip, self._controller_port)
-        logger.info("Finished start tft processor.")
+        self.tft.tft_register_rebuild_group_handler(_tft_rebuild_sub_groups, self)
 
     def _reset_acc_grads(self):
         accu_grad_params = map(lambda e: e[1],
@@ -360,7 +422,7 @@ class TFTRegister(Callback):
 
     def on_train_step_end(self, run_context):
         """
-        And report status to MindIO TFT after every step finished.
+        Report status to MindIO TFT after every step finished.
 
         Args:
             run_context (RunContext): Context of the train running. Refer to
@@ -371,17 +433,27 @@ class TFTRegister(Callback):
         self._set_tft_optimizer_replica(run_context)
         cb_params = run_context.original_args()
         logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
+        self.cur_step_num = cb_params.cur_step_num
+        self.cur_epoch_num = cb_params.cur_epoch_num
         if cb_params.optimizer is not None:
-            self.global_step = int(cb_params.optimizer.global_step.data)
+            self.global_step = cb_params.optimizer.global_step.clone()
             self.assign(cb_params.optimizer.tft_g_one_flag, self.g_one)
-        else:
-            self.global_step = int(cb_params.network.optimizer.global_step.data)
+        elif hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
+            self.global_step = cb_params.network.optimizer.global_step.clone()
             self.assign(cb_params.network.optimizer.tft_g_one_flag, self.g_one)
-        self.tft.tft_end_updating_os(cb_params.cur_step_num)
+        else:
+            raise ValueError("TFT feature need optimizer or network's optimizer!")
+        self.tft.tft_end_updating_os(cb_params.cur_step_num + self.initial_step)
         logger.info("END Set optimizer finish step status to TFT.")
 
-
     def on_train_begin(self, run_context):
+        """
+        Register train params to MindIO TFT on train beginning.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
         cb_params = run_context.original_args()
         sink_size = cb_params.get("sink_size", 0)
         if sink_size > 1:
@@ -391,7 +463,11 @@
         self.cb_params = cb_params
 
     def end(self, run_context):
-        cur_rank = get_rank()
-        if cur_rank == self._controller_rank_id:
-            self.tft.tft_destroy_controller()
-        self.tft.tft_destroy_processor()
+        """
+        Unregister MindIO TFT on train end.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
+        _tft_handler.unregister_tft()
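Taken together, these hunks replace TFTRegister's explicit controller/processor bootstrap (rank id, IP, port) with TrainFaultTolerance, which pulls an already-initialized handle from mindspore.utils._tft_handler. A rough migration sketch, grounded only in the old and new signatures shown above (values and the optimizer class choice are illustrative):

    import mindspore.nn as nn
    from mindspore.train import TrainFaultTolerance

    # 2.5.0: tft_cb = TFTRegister(0, "192.168.0.1", 2000, "./tft_checkpoint/")
    # 2.6.0rc1: only a save directory (or a ckpt_save_fn kwarg) is required
    tft_cb = TrainFaultTolerance(ckpt_save_path="./tft_checkpoint/")

    # The new classmethod wraps an optimizer class so each step is reported to TFT
    TFTAdamW = TrainFaultTolerance.get_optimizer_wrapper(nn.AdamWeightDecay)
    optimizer = TFTAdamW(net.trainable_params(), learning_rate=1e-3)  # net: any training Cell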
@@ -98,6 +98,29 @@ def _get_next_op(dataset, ori_next_op, is_info_queue):
     return next_op, (key, dataset_shapes, dataset_types)
 
 
+def _get_jit_func(sink_fun, jit_config):
+    """
+    Get the jit function.
+    """
+    jit_config_dict = jit_config.jit_config_dict
+    jit_level = jit_config_dict['jit_level']
+    if jit_level == "":
+        jit_level = "O0"
+    backend = ""
+    if jit_level == "O2":
+        jit_level = "O0"
+        backend = "GE"
+    if "backend" in jit_config_dict:
+        backend = jit_config_dict["backend"]
+    fullgraph = False
+    if jit_config_dict['jit_syntax_level'] == "STRICT":
+        fullgraph = True
+    exc_mode = jit_config_dict['exc_mode']
+    infer_boost = jit_config_dict['infer_boost']
+    return jit(sink_fun, jit_level=jit_level, backend=backend, fullgraph=fullgraph, exc_mode=exc_mode,
+               infer_boost=infer_boost)
+
+
 def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
     """
     get the sink function.
@@ -107,7 +130,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
         if jit_config is None:
             dst_sink_fun = sink_fun
         else:
-            dst_sink_fun = jit(sink_fun, jit_config=jit_config)
+            dst_sink_fun = _get_jit_func(sink_fun, jit_config)
         dataset.__sink_fun__ = dst_sink_fun
 
         return dataset.__sink_fun__
@@ -119,7 +142,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
         if jit_config is None:
             dst_sink_fun = sink_fun
         else:
-            dst_sink_fun = jit(sink_fun, jit_config=jit_config)
+            dst_sink_fun = _get_jit_func(sink_fun, jit_config)
         dataset.__sink_aux__.sink_funcs[key] = dst_sink_fun
 
         return dst_sink_fun
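In effect, _get_jit_func unpacks the old JitConfig object into the keyword form of mindspore.jit. A condensed equivalent of the same mapping (compressed from the hunk above for readability; cfg stands for jit_config.jit_config_dict):

    from mindspore import jit

    def wrap_sink(sink_fun, cfg):
        # cfg is jit_config.jit_config_dict as read by _get_jit_func
        jit_level = cfg['jit_level'] or "O0"
        backend = cfg.get('backend', "GE" if jit_level == "O2" else "")
        if jit_level == "O2":        # O2 maps to O0 on the GE backend
            jit_level = "O0"
        fullgraph = cfg['jit_syntax_level'] == "STRICT"
        return jit(sink_fun, jit_level=jit_level, backend=backend, fullgraph=fullgraph,
                   exc_mode=cfg['exc_mode'], infer_boost=cfg['infer_boost'])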
@@ -214,8 +214,7 @@ def _get_dataset_aux(dataset):
214
214
 
215
215
  def connect_network_with_dataset(network, dataset_helper):
216
216
  """
217
- Connect the `network` with dataset in `dataset_helper`. Only supported in `sink mode
218
- <https://mindspore.cn/docs/en/master/model_train/train_process/train_optimize.html>`_,
217
+ Connect the `network` with dataset in `dataset_helper`. Only supported in sink mode,
219
218
  (dataset_sink_mode=True).
220
219
 
221
220
  Args:
@@ -335,11 +334,11 @@ class DatasetHelper:
335
334
  dataset_sink_mode (bool): If the value is True, GetNext is employed to fetch the data at device through the
336
335
  dataset pipeline, otherwise fetch the data at host by iterating through the dataset.
337
336
  Default: ``True``.
338
- sink_size (int): Control the amount of data in each sink.
337
+ sink_size (int): Control the amount of data in each sink. Must be -1 or positive.
339
338
  If sink_size=-1, sink the complete dataset for each epoch.
340
339
  If sink_size>0, sink sink_size data for each epoch.
341
- Default: -1.
342
- epoch_num (int): The number of passes of the entire dataset to be sent. Default: 1.
340
+ Default: ``-1``.
341
+ epoch_num (int): The number of passes of the entire dataset to be sent. Default: ``1``.
343
342
 
344
343
  Examples:
345
344
  >>> import numpy as np
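As a quick usage sketch for the clarified sink_size semantics (the toy dataset and host-side iteration are illustrative):

    import numpy as np
    import mindspore.dataset as ds
    from mindspore.train import DatasetHelper

    data = ds.NumpySlicesDataset(np.ones((32, 4), np.float32),
                                 column_names=["x"], shuffle=False).batch(8)
    # sink_size=-1 sinks the whole dataset per epoch; sink_size>0 sinks that many steps
    helper = DatasetHelper(data, dataset_sink_mode=False, sink_size=-1, epoch_num=1)
    for item in helper:   # 4 batches of shape (8, 4)
        pass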
@@ -51,9 +51,10 @@ class FixedLossScaleManager(LossScaleManager):
     inherits from :class:`mindspore.amp.LossScaleManager`.
 
     Args:
-        loss_scale (float): Magnification factor of gradients. Note that if `drop_overflow_update` is set to ``False`` ,
+        loss_scale (float, optional): Magnification factor of gradients.
+            Note that if `drop_overflow_update` is set to ``False`` ,
             the value of `loss_scale` in optimizer should be set to the same as here. Default: ``128.0`` .
-        drop_overflow_update (bool): Whether to execute optimizer if there is an overflow.
+        drop_overflow_update (bool, optional): Whether to execute optimizer if there is an overflow.
             If ``True`` , the optimizer will
             not executed when overflow occurs. Default: ``True`` .
 
@@ -110,8 +111,8 @@
 
         Returns:
             None or :class:`mindspore.nn.FixedLossScaleUpdateCell`. Instance of
-            :class:`mindspore.nn.FixedLossScaleUpdateCell` when `drop_overflow_update` is True. None when
-            `drop_overflow_update` is False.
+            :class:`mindspore.nn.FixedLossScaleUpdateCell` when `drop_overflow_update` is ``True``. None when
+            `drop_overflow_update` is ``False``.
         """
         if not self._drop_overflow_update:
             return None
@@ -124,9 +125,9 @@ class DynamicLossScaleManager(LossScaleManager):
     adjusted, inherits from :class:`mindspore.amp.LossScaleManager`.
 
     Args:
-        init_loss_scale (float): Initialize loss scale. Default: ``2 ** 24`` .
-        scale_factor (int): Coefficient of increase and decrease. Default: ``2`` .
-        scale_window (int): Maximum continuous normal steps when there is no overflow. Default: ``2000`` .
+        init_loss_scale (float, optional): Initialize loss scale. Default: ``2 ** 24`` .
+        scale_factor (int, optional): Coefficient of increase and decrease. Default: ``2`` .
+        scale_window (int, optional): Maximum continuous normal steps when there is no overflow. Default: ``2000`` .
 
     Supported Platforms:
         ``Ascend`` ``GPU``
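A short usage sketch consistent with the defaults documented above (optimizer/model wiring omitted; assuming both managers are importable from mindspore.amp, as their docstrings suggest):

    from mindspore.amp import FixedLossScaleManager, DynamicLossScaleManager

    # Fixed scaling: with drop_overflow_update=False the optimizer must use the same loss_scale
    fixed_mgr = FixedLossScaleManager(loss_scale=128.0, drop_overflow_update=False)
    assert fixed_mgr.get_update_cell() is None  # None, since drop_overflow_update is False

    # Dynamic scaling: grows/shrinks by scale_factor, checked over scale_window clean steps
    dyn_mgr = DynamicLossScaleManager(init_loss_scale=2 ** 24, scale_factor=2, scale_window=2000)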
@@ -45,11 +45,11 @@ class Accuracy(EvaluationBase):
         >>> from mindspore import Tensor
         >>> from mindspore.train import Accuracy
         >>>
-        >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mindspore.float32)
-        >>> y = Tensor(np.array([1, 0, 1]), mindspore.float32)
+        >>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mindspore.float32)
+        >>> y_true = Tensor(np.array([1, 0, 1]), mindspore.float32)
         >>> metric = Accuracy('classification')
         >>> metric.clear()
-        >>> metric.update(x, y)
+        >>> metric.update(y_pred, y_true)
         >>> accuracy = metric.eval()
         >>> print(accuracy)
         0.6666666666666666
@@ -23,15 +23,15 @@ from mindspore.train.metrics.metric import Metric, rearrange_inputs
 
 class ConfusionMatrix(Metric):
     """
-    Computes the confusion matrix, which is commonly used to evaluate the performance of classification models,
+    Computes the Confusion Matrix, which is commonly used to evaluate the performance of classification models,
     including binary classification and multiple classification.
 
-    If you only need confusion matrix, use this class. If you want to calculate other metrics, such as 'PPV',
+    If you only need Confusion Matrix, use this class. If you want to calculate other metrics, such as 'PPV',
     'TPR', 'TNR', etc., use class :class:`mindspore.train.ConfusionMatrixMetric` .
 
     Args:
         num_classes (int): Number of classes in the dataset.
-        normalize (str): Normalization mode for confusion matrix. Default: ``"no_norm"`` . Choose from:
+        normalize (str): Normalization mode for Confusion Matrix. Default: ``"no_norm"`` . Choose from:
 
             - ``"no_norm"`` : No Normalization is used. Default: ``None``.
             - ``"target"`` : Normalization based on target value.
@@ -78,7 +78,7 @@ class ConfusionMatrix(Metric):
     @rearrange_inputs
     def update(self, *inputs):
         """
-        Update state with y_pred and y.
+        Update state with `y_pred` and `y`.
 
         Args:
             inputs(tuple): Input `y_pred` and `y`. `y_pred` and `y` are a `Tensor`, list or numpy.ndarray.
@@ -88,7 +88,7 @@ class ConfusionMatrix(Metric):
 
         Raises:
             ValueError: If the number of inputs is not 2.
-            ValueError: If the dim of y_pred and y are not equal.
+            ValueError: If the dims of `y_pred` and `y` are not equal.
         """
         if len(inputs) != 2:
             raise ValueError("For 'ConfusionMatrix.update', it needs 2 inputs (predicted value, true value), "
@@ -151,8 +151,8 @@ class ConfusionMatrixMetric(Metric):
     batch, class channel and iteration are collected. All metrics supported by the interface are listed in comments
     of `metric_name`.
 
-    If you want to calculate metrics related to confusion matrix, such as 'PPV', 'TPR', 'TNR', use this class.
-    If you only want to calculate confusion matrix, please use :class:`mindspore.train.ConfusionMatrix` .
+    - If you want to calculate metrics related to confusion matrix, such as 'PPV', 'TPR', 'TNR', use this class.
+    - If you only want to calculate confusion matrix, please use :class:`mindspore.train.ConfusionMatrix` .
 
     Args:
         skip_channel (bool): Whether to skip the measurement calculation on the first channel of the predicted output.
@@ -163,9 +163,9 @@ class ConfusionMatrixMetric(Metric):
             "threat score", "accuracy", "balanced accuracy", "f1 score",
             "matthews correlation coefficient", "fowlkes mallows index", "informedness", "markedness"].
             Default: ``"sensitivity"`` .
-        calculation_method (bool): If true, the measurement for each sample will be calculated first.
+        calculation_method (bool): If ``True``, the measurement for each sample will be calculated first.
             If not, the confusion matrix of all samples will be accumulated first.
-            As for classification task, 'calculation_method' should be False. Default: ``False`` .
+            As for classification task, 'calculation_method' should be ``False``. Default: ``False`` .
         decrease (str): The reduction method on data batch. `decrease` takes effect only when calculation_method
             is True. Default: ``"mean"`` . Choose from:
             ["none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"].