mindspore 2.5.0__cp310-cp310-win_amd64.whl → 2.6.0rc1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (491)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +24 -193
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +97 -74
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +1915 -3287
  46. mindspore/common/api.py +341 -354
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/hook_handle.py +5 -3
  52. mindspore/common/initializer.py +10 -6
  53. mindspore/common/jit_begin_end.py +94 -0
  54. mindspore/common/jit_config.py +6 -1
  55. mindspore/common/jit_context.py +76 -0
  56. mindspore/common/jit_trace.py +378 -0
  57. mindspore/common/lazy_inline.py +2 -2
  58. mindspore/common/mutable.py +5 -4
  59. mindspore/common/parameter.py +106 -39
  60. mindspore/common/seed.py +2 -2
  61. mindspore/common/sparse_tensor.py +23 -17
  62. mindspore/common/tensor.py +297 -714
  63. mindspore/communication/__init__.py +7 -5
  64. mindspore/communication/_comm_helper.py +47 -2
  65. mindspore/communication/comm_func.py +70 -53
  66. mindspore/communication/management.py +83 -17
  67. mindspore/context.py +214 -560
  68. mindspore/dataset/__init__.py +44 -20
  69. mindspore/dataset/audio/__init__.py +2 -8
  70. mindspore/dataset/audio/transforms.py +3 -17
  71. mindspore/dataset/core/config.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +1 -1
  73. mindspore/dataset/engine/datasets.py +102 -120
  74. mindspore/dataset/engine/datasets_audio.py +22 -22
  75. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  76. mindspore/dataset/engine/datasets_text.py +78 -85
  77. mindspore/dataset/engine/datasets_user_defined.py +108 -76
  78. mindspore/dataset/engine/datasets_vision.py +111 -108
  79. mindspore/dataset/engine/iterators.py +5 -3
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/samplers.py +279 -57
  82. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  83. mindspore/dataset/engine/validators.py +10 -0
  84. mindspore/dataset/text/__init__.py +7 -6
  85. mindspore/dataset/text/transforms.py +6 -5
  86. mindspore/dataset/text/utils.py +3 -3
  87. mindspore/dataset/transforms/__init__.py +0 -9
  88. mindspore/dataset/transforms/transforms.py +3 -3
  89. mindspore/dataset/utils/browse_dataset.py +1 -1
  90. mindspore/dataset/vision/__init__.py +2 -9
  91. mindspore/dataset/vision/transforms.py +202 -158
  92. mindspore/dataset/vision/utils.py +7 -5
  93. mindspore/device_context/ascend/op_debug.py +60 -1
  94. mindspore/device_context/ascend/op_tuning.py +0 -4
  95. mindspore/device_manager.py +39 -3
  96. mindspore/dnnl.dll +0 -0
  97. mindspore/dpcmi.dll +0 -0
  98. mindspore/experimental/es/embedding_service.py +35 -27
  99. mindspore/experimental/map_parameter.py +4 -4
  100. mindspore/experimental/optim/adadelta.py +22 -26
  101. mindspore/experimental/optim/adagrad.py +4 -4
  102. mindspore/experimental/optim/adam.py +4 -0
  103. mindspore/experimental/optim/adamax.py +4 -4
  104. mindspore/experimental/optim/adamw.py +4 -0
  105. mindspore/experimental/optim/asgd.py +1 -1
  106. mindspore/experimental/optim/lr_scheduler.py +40 -22
  107. mindspore/experimental/optim/radam.py +5 -5
  108. mindspore/experimental/optim/rprop.py +1 -1
  109. mindspore/experimental/optim/sgd.py +1 -1
  110. mindspore/hal/contiguous_tensors_handle.py +6 -10
  111. mindspore/hal/device.py +55 -81
  112. mindspore/hal/event.py +38 -55
  113. mindspore/hal/memory.py +93 -144
  114. mindspore/hal/stream.py +81 -125
  115. mindspore/include/dataset/constants.h +7 -4
  116. mindspore/include/dataset/execute.h +2 -2
  117. mindspore/jpeg62.dll +0 -0
  118. mindspore/log.py +40 -2
  119. mindspore/mindrecord/__init__.py +20 -7
  120. mindspore/mindspore_backend_common.dll +0 -0
  121. mindspore/mindspore_backend_manager.dll +0 -0
  122. mindspore/mindspore_common.dll +0 -0
  123. mindspore/mindspore_core.dll +0 -0
  124. mindspore/mindspore_dump.dll +0 -0
  125. mindspore/mindspore_frontend.dll +0 -0
  126. mindspore/mindspore_glog.dll +0 -0
  127. mindspore/mindspore_memory_pool.dll +0 -0
  128. mindspore/mindspore_ms_backend.dll +0 -0
  129. mindspore/mindspore_ops.dll +0 -0
  130. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  131. mindspore/mindspore_ops_kernel_common.dll +0 -0
  132. mindspore/mindspore_profiler.dll +0 -0
  133. mindspore/mindspore_pyboost.dll +0 -0
  134. mindspore/mindspore_pynative.dll +0 -0
  135. mindspore/mindspore_res_manager.dll +0 -0
  136. mindspore/mindspore_runtime_pipeline.dll +0 -0
  137. mindspore/mint/__init__.py +131 -700
  138. mindspore/mint/distributed/__init__.py +5 -1
  139. mindspore/mint/distributed/distributed.py +194 -109
  140. mindspore/mint/linalg/__init__.py +2 -0
  141. mindspore/mint/nn/__init__.py +280 -18
  142. mindspore/mint/nn/functional.py +282 -64
  143. mindspore/mint/nn/layer/__init__.py +4 -0
  144. mindspore/mint/nn/layer/_functions.py +7 -3
  145. mindspore/mint/nn/layer/activation.py +120 -13
  146. mindspore/mint/nn/layer/conv.py +218 -24
  147. mindspore/mint/nn/layer/normalization.py +15 -16
  148. mindspore/mint/nn/layer/padding.py +1 -1
  149. mindspore/mint/nn/layer/pooling.py +66 -1
  150. mindspore/mint/optim/__init__.py +2 -1
  151. mindspore/mint/optim/sgd.py +171 -0
  152. mindspore/msobj140.dll +0 -0
  153. mindspore/mspdb140.dll +0 -0
  154. mindspore/mspdbcore.dll +0 -0
  155. mindspore/mspdbst.dll +0 -0
  156. mindspore/mspft140.dll +0 -0
  157. mindspore/msvcdis140.dll +0 -0
  158. mindspore/msvcp140_1.dll +0 -0
  159. mindspore/msvcp140_2.dll +0 -0
  160. mindspore/msvcp140_atomic_wait.dll +0 -0
  161. mindspore/msvcp140_codecvt_ids.dll +0 -0
  162. mindspore/nn/__init__.py +4 -1
  163. mindspore/nn/cell.py +1250 -176
  164. mindspore/nn/layer/activation.py +23 -21
  165. mindspore/nn/layer/basic.py +22 -16
  166. mindspore/nn/layer/container.py +1 -1
  167. mindspore/nn/layer/conv.py +22 -17
  168. mindspore/nn/layer/embedding.py +9 -8
  169. mindspore/nn/layer/normalization.py +48 -42
  170. mindspore/nn/layer/pooling.py +75 -31
  171. mindspore/nn/layer/transformer.py +11 -10
  172. mindspore/nn/learning_rate_schedule.py +4 -2
  173. mindspore/nn/loss/loss.py +27 -19
  174. mindspore/nn/optim/ada_grad.py +6 -5
  175. mindspore/nn/optim/adadelta.py +9 -7
  176. mindspore/nn/optim/adafactor.py +1 -1
  177. mindspore/nn/optim/adam.py +16 -12
  178. mindspore/nn/optim/adamax.py +8 -7
  179. mindspore/nn/optim/adasum.py +5 -5
  180. mindspore/nn/optim/asgd.py +1 -1
  181. mindspore/nn/optim/ftrl.py +11 -9
  182. mindspore/nn/optim/lamb.py +1 -1
  183. mindspore/nn/optim/lazyadam.py +12 -10
  184. mindspore/nn/optim/momentum.py +7 -6
  185. mindspore/nn/optim/optimizer.py +2 -2
  186. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  187. mindspore/nn/optim/rmsprop.py +13 -12
  188. mindspore/nn/optim/rprop.py +9 -7
  189. mindspore/nn/optim/sgd.py +9 -6
  190. mindspore/nn/optim/tft_wrapper.py +5 -2
  191. mindspore/nn/probability/bijector/bijector.py +17 -11
  192. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  193. mindspore/nn/probability/bijector/invert.py +2 -2
  194. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  195. mindspore/nn/probability/bijector/softplus.py +3 -2
  196. mindspore/nn/probability/distribution/beta.py +3 -3
  197. mindspore/nn/probability/distribution/categorical.py +1 -1
  198. mindspore/nn/probability/distribution/cauchy.py +4 -2
  199. mindspore/nn/probability/distribution/exponential.py +6 -7
  200. mindspore/nn/probability/distribution/gamma.py +2 -2
  201. mindspore/nn/probability/distribution/gumbel.py +2 -2
  202. mindspore/nn/probability/distribution/half_normal.py +5 -3
  203. mindspore/nn/probability/distribution/logistic.py +5 -3
  204. mindspore/nn/probability/distribution/poisson.py +1 -1
  205. mindspore/nn/probability/distribution/uniform.py +5 -3
  206. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  207. mindspore/nn/reinforcement/tensor_array.py +1 -1
  208. mindspore/nn/wrap/__init__.py +6 -6
  209. mindspore/nn/wrap/cell_wrapper.py +178 -117
  210. mindspore/nn/wrap/grad_reducer.py +45 -36
  211. mindspore/nn/wrap/loss_scale.py +3 -3
  212. mindspore/numpy/array_creations.py +3 -3
  213. mindspore/numpy/array_ops.py +1 -1
  214. mindspore/numpy/math_ops.py +4 -4
  215. mindspore/numpy/utils.py +1 -2
  216. mindspore/numpy/utils_const.py +1 -2
  217. mindspore/opencv_core452.dll +0 -0
  218. mindspore/opencv_imgcodecs452.dll +0 -0
  219. mindspore/opencv_imgproc452.dll +0 -0
  220. mindspore/ops/__init__.py +3 -2
  221. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  222. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  223. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  224. mindspore/ops/_register_for_op.py +0 -11
  225. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  226. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  227. mindspore/ops/_vmap/vmap_array_ops.py +7 -6
  228. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  229. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  230. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  231. mindspore/ops/auto_generate/__init__.py +4 -3
  232. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
  233. mindspore/ops/auto_generate/gen_extend_func.py +281 -135
  234. mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
  235. mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
  236. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  237. mindspore/ops/composite/__init__.py +2 -1
  238. mindspore/ops/composite/base.py +19 -24
  239. mindspore/ops/composite/math_ops.py +6 -16
  240. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  241. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
  242. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  243. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  244. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  248. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  249. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  250. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  251. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  252. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  254. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  255. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  256. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  257. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  259. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  260. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  263. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  264. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  267. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  268. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  271. mindspore/ops/function/__init__.py +28 -2
  272. mindspore/ops/function/_add_attr_func.py +58 -0
  273. mindspore/ops/function/array_func.py +1629 -2345
  274. mindspore/ops/function/clip_func.py +38 -45
  275. mindspore/ops/function/debug_func.py +36 -44
  276. mindspore/ops/function/grad/__init__.py +1 -0
  277. mindspore/ops/function/grad/grad_func.py +104 -71
  278. mindspore/ops/function/image_func.py +1 -1
  279. mindspore/ops/function/linalg_func.py +46 -78
  280. mindspore/ops/function/math_func.py +3035 -3705
  281. mindspore/ops/function/nn_func.py +676 -241
  282. mindspore/ops/function/other_func.py +159 -1
  283. mindspore/ops/function/parameter_func.py +17 -30
  284. mindspore/ops/function/random_func.py +204 -361
  285. mindspore/ops/function/reshard_func.py +4 -70
  286. mindspore/ops/function/sparse_func.py +3 -3
  287. mindspore/ops/function/sparse_unary_func.py +5 -5
  288. mindspore/ops/function/spectral_func.py +25 -58
  289. mindspore/ops/function/vmap_func.py +24 -17
  290. mindspore/ops/functional.py +6 -4
  291. mindspore/ops/functional_overload.py +547 -4
  292. mindspore/ops/op_info_register.py +32 -244
  293. mindspore/ops/operations/__init__.py +10 -5
  294. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  295. mindspore/ops/operations/_grad_ops.py +1 -10
  296. mindspore/ops/operations/_inner_ops.py +5 -76
  297. mindspore/ops/operations/_ms_kernel.py +4 -10
  298. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  299. mindspore/ops/operations/_scalar_ops.py +3 -2
  300. mindspore/ops/operations/_sequence_ops.py +1 -1
  301. mindspore/ops/operations/_tensor_array.py +1 -1
  302. mindspore/ops/operations/array_ops.py +37 -22
  303. mindspore/ops/operations/comm_ops.py +150 -107
  304. mindspore/ops/operations/custom_ops.py +221 -23
  305. mindspore/ops/operations/debug_ops.py +115 -16
  306. mindspore/ops/operations/inner_ops.py +1 -1
  307. mindspore/ops/operations/linalg_ops.py +1 -58
  308. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  309. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  310. mindspore/ops/operations/math_ops.py +21 -18
  311. mindspore/ops/operations/nn_ops.py +65 -191
  312. mindspore/ops/operations/other_ops.py +62 -9
  313. mindspore/ops/operations/random_ops.py +13 -7
  314. mindspore/ops/operations/reshard_ops.py +1 -1
  315. mindspore/ops/operations/sparse_ops.py +2 -2
  316. mindspore/ops/primitive.py +43 -32
  317. mindspore/ops/tensor_method.py +232 -13
  318. mindspore/ops_generate/__init__.py +0 -5
  319. mindspore/ops_generate/aclnn/__init__.py +0 -0
  320. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  321. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  322. mindspore/ops_generate/api/__init__.py +0 -0
  323. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  324. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  325. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  326. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  327. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  328. mindspore/ops_generate/api/gen_api.py +103 -0
  329. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  330. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  331. mindspore/ops_generate/common/__init__.py +0 -0
  332. mindspore/ops_generate/common/gen_constants.py +91 -0
  333. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  334. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  335. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  336. mindspore/ops_generate/gen_ops.py +23 -325
  337. mindspore/ops_generate/op_def/__init__.py +0 -0
  338. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  339. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  340. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
  341. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  342. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  343. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  344. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  345. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  346. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  347. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  348. mindspore/ops_generate/pyboost/__init__.py +0 -0
  349. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  350. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  351. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  352. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  353. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  354. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  355. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  356. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  357. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  358. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  359. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  360. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  361. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  362. mindspore/ops_generate/resources/__init__.py +0 -0
  363. mindspore/ops_generate/resources/resource_list.py +30 -0
  364. mindspore/ops_generate/resources/resource_loader.py +36 -0
  365. mindspore/ops_generate/resources/resource_manager.py +64 -0
  366. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  367. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  368. mindspore/parallel/__init__.py +6 -2
  369. mindspore/parallel/_auto_parallel_context.py +133 -6
  370. mindspore/parallel/_cell_wrapper.py +130 -15
  371. mindspore/parallel/_parallel_serialization.py +95 -4
  372. mindspore/parallel/_ps_context.py +1 -1
  373. mindspore/parallel/_recovery_context.py +7 -2
  374. mindspore/parallel/_tensor.py +142 -18
  375. mindspore/parallel/_utils.py +198 -25
  376. mindspore/parallel/algo_parameter_config.py +3 -3
  377. mindspore/parallel/auto_parallel.py +732 -0
  378. mindspore/parallel/checkpoint_convert.py +159 -0
  379. mindspore/parallel/checkpoint_transform.py +656 -37
  380. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  381. mindspore/parallel/cluster/run.py +1 -1
  382. mindspore/parallel/function/__init__.py +24 -0
  383. mindspore/parallel/function/reshard_func.py +259 -0
  384. mindspore/parallel/nn/__init__.py +25 -0
  385. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  386. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  387. mindspore/parallel/parameter_broadcast.py +24 -13
  388. mindspore/parallel/shard.py +137 -61
  389. mindspore/parallel/transform_safetensors.py +287 -95
  390. mindspore/pgodb140.dll +0 -0
  391. mindspore/pgort140.dll +0 -0
  392. mindspore/profiler/__init__.py +9 -5
  393. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  394. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  395. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
  397. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  398. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  399. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  400. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  401. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  402. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  403. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  404. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  405. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  406. mindspore/profiler/common/constant.py +12 -0
  407. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  408. mindspore/profiler/common/path_manager.py +24 -0
  409. mindspore/profiler/common/profiler_context.py +26 -2
  410. mindspore/profiler/common/profiler_meta_data.py +74 -0
  411. mindspore/profiler/common/profiler_parameters.py +59 -18
  412. mindspore/profiler/common/profiler_path_manager.py +66 -7
  413. mindspore/profiler/dynamic_profiler.py +112 -79
  414. mindspore/profiler/envprofiler.py +26 -1
  415. mindspore/profiler/experimental_config.py +197 -0
  416. mindspore/profiler/mstx.py +57 -14
  417. mindspore/profiler/platform/npu_profiler.py +33 -7
  418. mindspore/profiler/profiler.py +541 -45
  419. mindspore/profiler/profiler_action_controller.py +1 -1
  420. mindspore/profiler/profiler_interface.py +4 -0
  421. mindspore/profiler/schedule.py +57 -22
  422. mindspore/rewrite/api/node.py +15 -13
  423. mindspore/rewrite/api/symbol_tree.py +1 -1
  424. mindspore/run_check/_check_version.py +25 -14
  425. mindspore/run_check/run_check.py +1 -1
  426. mindspore/runtime/__init__.py +2 -2
  427. mindspore/runtime/executor.py +40 -11
  428. mindspore/runtime/memory.py +25 -8
  429. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  430. mindspore/swresample-4.dll +0 -0
  431. mindspore/swscale-6.dll +0 -0
  432. mindspore/tbbmalloc.dll +0 -0
  433. mindspore/tinyxml2.dll +0 -0
  434. mindspore/train/__init__.py +8 -8
  435. mindspore/train/_utils.py +35 -7
  436. mindspore/train/amp.py +1 -1
  437. mindspore/train/callback/__init__.py +2 -2
  438. mindspore/train/callback/_callback.py +2 -16
  439. mindspore/train/callback/_checkpoint.py +24 -40
  440. mindspore/train/callback/_cluster_monitor.py +14 -18
  441. mindspore/train/callback/_flops_collector.py +2 -3
  442. mindspore/train/callback/_history.py +7 -4
  443. mindspore/train/callback/_lambda_callback.py +2 -2
  444. mindspore/train/callback/_landscape.py +0 -3
  445. mindspore/train/callback/_loss_monitor.py +2 -1
  446. mindspore/train/callback/_on_request_exit.py +6 -5
  447. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  448. mindspore/train/callback/_summary_collector.py +8 -13
  449. mindspore/train/callback/_time_monitor.py +2 -1
  450. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
  451. mindspore/train/data_sink.py +25 -2
  452. mindspore/train/dataset_helper.py +4 -5
  453. mindspore/train/loss_scale_manager.py +8 -7
  454. mindspore/train/metrics/accuracy.py +3 -3
  455. mindspore/train/metrics/confusion_matrix.py +9 -9
  456. mindspore/train/metrics/error.py +3 -3
  457. mindspore/train/metrics/hausdorff_distance.py +4 -4
  458. mindspore/train/metrics/mean_surface_distance.py +3 -3
  459. mindspore/train/metrics/metric.py +0 -12
  460. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  461. mindspore/train/metrics/precision.py +8 -6
  462. mindspore/train/metrics/recall.py +9 -9
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  464. mindspore/train/mind_ir_pb2.py +19 -12
  465. mindspore/train/model.py +176 -103
  466. mindspore/train/serialization.py +246 -988
  467. mindspore/train/summary/_summary_adapter.py +2 -2
  468. mindspore/train/summary/summary_record.py +1 -1
  469. mindspore/turbojpeg.dll +0 -0
  470. mindspore/utils/__init__.py +3 -2
  471. mindspore/utils/dryrun.py +4 -2
  472. mindspore/utils/hooks.py +81 -0
  473. mindspore/utils/utils.py +138 -4
  474. mindspore/vcmeta.dll +0 -0
  475. mindspore/vcruntime140.dll +0 -0
  476. mindspore/vcruntime140_1.dll +0 -0
  477. mindspore/version.py +1 -1
  478. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
  479. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
  480. mindspore/_install_custom.py +0 -43
  481. mindspore/common/_register_for_adapter.py +0 -74
  482. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  483. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  484. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  485. mindspore/ops_generate/gen_constants.py +0 -190
  486. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  487. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  488. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  489. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  490. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/pooling.py CHANGED
@@ -32,7 +32,7 @@ from mindspore.ops.auto_generate import avg_pool1d_ext
  __all__ = ['AvgPool3d', 'MaxPool3d', 'AvgPool2d', 'MaxPool2d', 'AvgPool1d', 'MaxPool1d', 'FractionalMaxPool2d',
  'FractionalMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
  'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'LPPool1d',
- 'LPPool2d', 'AvgPool2dExt', 'MaxPool2dExt', 'AvgPool1dExt']
+ 'LPPool2d', 'AvgPool2dExt', 'AvgPool3dExt', 'MaxPool2dExt', 'AvgPool1dExt']


  class _PoolNd(Cell):
@@ -299,11 +299,12 @@ class MaxPool3d(_PoolNd):
  For Atlas training series products, this interface is not supported.

  Args:
- kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
+ kernel_size (Union[int, tuple[int]], optional): The size of kernel used to take the maximum value,
  is an int number or a single element tuple that represents depth, height and width of the kernel, or a tuple
  of three int numbers that represent depth, height and width respectively.
  The value must be a positive integer. Default: ``1`` .
- stride (Union[int, tuple[int]]): The moving stride of pooling operation, an int number or a single element tuple
+ stride (Union[int, tuple[int]], optional): The moving stride of pooling operation,
+ an int number or a single element tuple
  that represents the moving stride of pooling kernel in the directions of depth, height and the width,
  or a tuple of three int numbers that represent depth, height and width of movement respectively.
  The value must be a positive integer. If the value is None, the default value `kernel_size` is used.
@@ -324,18 +325,19 @@ class MaxPool3d(_PoolNd):
  in the depth, height and width dimension is determined by the `padding` parameter.
  If this mode is set, `padding` must be greater than or equal to 0.

- padding (Union(int, tuple[int], list[int])): Pooling padding value. Default: ``0`` .
+ padding (Union(int, tuple[int], list[int]), optional): Pooling padding value. Default: ``0`` .
  `padding` can only be an integer or a tuple/list containing one or three integers.
  If `padding` is an integer or a tuple/list containing one integer, it will be padded in six directions of
  front, back, top, bottom, left and right of the input. If `padding` is a tuple/list containing three
  integers, it will be padded in front and back of the input `padding[0]` times, up and down `padding[1]`
  times, and left and right of the input `padding[2]` times.
- dilation (Union(int, tuple[int])): The spacing between the elements of the kernel in convolution,
+ dilation (Union(int, tuple[int]), optional): The spacing between the elements of the kernel in convolution,
  used to increase the receptive field of the pooling operation. If it is a tuple, it must contain one or
  three integers. Default: ``1`` .
- return_indices (bool): If ``True`` , output is a Tuple of 2 Tensors, representing the maxpool result and where
+ return_indices (bool, optional): If ``True`` , output is a Tuple of 2 Tensors,
+ representing the maxpool result and where
  the max values are generated. Otherwise, only the maxpool result is returned. Default: ``False`` .
- ceil_mode (bool): If ``True``, use ceil to calculate output shape.
+ ceil_mode (bool, optional): If ``True``, use ceil to calculate output shape.
  If ``False``, use ceil to calculate output shape. Default: ``False`` .

  Inputs:
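Illustrative only (not part of the diff): a minimal sketch of the MaxPool3d arguments documented above, on a 5D :math:`(N, C, D, H, W)` input; the printed shape follows the usual floor((D + 2*padding - dilation*(kernel_size-1) - 1)/stride + 1) arithmetic.

>>> import mindspore as ms
>>> from mindspore import nn
>>> pool = nn.MaxPool3d(kernel_size=3, stride=2, pad_mode="pad", padding=1)
>>> x = ms.ops.randn(1, 2, 8, 8, 8).astype(ms.float32)
>>> print(pool(x).shape)
(1, 2, 4, 4, 4)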
@@ -713,9 +715,9 @@ class MaxPool1d(_PoolNd):
  \text{input}(N_i, C_j, s_0 \times l + n)

  Args:
- kernel_size (int): The size of kernel used to take the max value, Default: ``1`` .
- stride (int): The distance of kernel moving, an int number that represents
- the width of movement is stride, Default: ``1`` .
+ kernel_size (int, optional): The size of kernel used to take the max value. Default: ``1`` .
+ stride (int, optional): The distance of kernel moving, an int number that represents
+ the width of movement is stride. Default: ``1`` .
  pad_mode (str, optional): Specifies the padding mode with a padding value of 0. It can be set to:
  ``"same"`` , ``"valid"`` or ``"pad"`` . Default: ``"valid"`` .

@@ -731,24 +733,25 @@ class MaxPool1d(_PoolNd):
  at the begin and end is determined by the `padding` parameter.
  If this mode is set, `padding` must be greater than or equal to 0.

- padding (Union(int, tuple[int], list[int])): Padding value for the pooling. Default value is ``0``.
+ padding (Union(int, tuple[int], list[int]), optional): Padding value for the pooling. Default value is ``0``.
  padding can only be an integer or a tuple/list containing a single integer, in which case padding times or
  padding[0] times are padded on both sides of the input.
- dilation (Union(int, tuple[int])): The spacing between the elements of the kernel in convolution,
+ dilation (Union(int, tuple[int]), optional): The spacing between the elements of the kernel in convolution,
  used to increase the receptive field of the pooling operation. If it is a tuple, its length can only be 1.
  Default: ``1`` .
- return_indices (bool): If ``True`` , the function will return both the result of max pooling and the indices of
+ return_indices (bool, optional): If ``True`` , the function will return
+ both the result of max pooling and the indices of
  the max elements. Default: ``False`` .
- ceil_mode (bool): If True, use ceil to compute the output shape instead of floor. Default: ``False`` .
+ ceil_mode (bool, optional): If True, use ceil to compute the output shape instead of floor. Default: ``False`` .

  Inputs:
  - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`.

  Outputs:
- If `return_indices` is False, output is a Tensor, with shape :math:`(N, C_{out}, L_{out})` or
+ If `return_indices` is ``False``, output is a Tensor, with shape :math:`(N, C_{out}, L_{out})` or
  :math:`(C_{out}, L_{out})`. It has the same data type as `x`.

- If `return_indices` is True, output is a Tuple of 2 Tensors, representing the maxpool result and where
+ If `return_indices` is ``True``, output is a Tuple of 2 Tensors, representing the maxpool result and where
  the max values are generated.

  - **output** (Tensor) - Maxpooling result, with shape :math:`(N, C_{out}, L_{out})` or
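Illustrative only (not part of the diff): a minimal sketch of the `return_indices` behaviour described above; with `return_indices=True` the cell returns a (values, indices) pair.

>>> import mindspore as ms
>>> from mindspore import nn
>>> pool = nn.MaxPool1d(kernel_size=2, stride=2, pad_mode="pad", return_indices=True)
>>> x = ms.ops.randn(1, 3, 8).astype(ms.float32)
>>> out, indices = pool(x)
>>> print(out.shape, indices.shape)
(1, 3, 4) (1, 3, 4)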
@@ -1021,6 +1024,47 @@ class AvgPool3d(_PoolNd):
  return out


+ class AvgPool3dExt(Cell):
+ r"""
+ Applies a 3D average pooling over an input Tensor which can be regarded as
+ a composition of 3D input planes.
+
+ .. warning::
+ This is an experimental API that is subject to change or deletion.
+
+ For details, please refer to :func:`mindspore.mint.nn.functional.avg_pool3d`.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> import mindspore as ms
+ >>> pool = ms.nn.AvgPool3dExt(kernel_size=3, stride=1)
+ >>> x = ms.ops.randn(1, 2, 4, 4, 5).astype(ms.float32)
+ >>> output = pool(x)
+ >>> print(output.shape)
+ (1, 2, 2, 2, 3)
+ >>> x1 = ms.ops.randn(6, 5, 7, 7, 5).astype(ms.float32)
+ >>> pool2 = ms.nn.AvgPool3dExt(4, stride=2, padding=(2, 2, 1), divisor_override=10)
+ >>> output2 = pool2(x1)
+ >>> print(output2.shape)
+ (6, 5, 4, 4, 2)
+ """
+ def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
+ count_include_pad=True, divisor_override=None):
+ super(AvgPool3dExt, self).__init__()
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.ceil_mode = ceil_mode
+ self.count_include_pad = count_include_pad
+ self.divisor_override = divisor_override
+
+ def construct(self, input):
+ return ops.function.nn_func.avg_pool3d_ext(input, self.kernel_size, self.stride, self.padding,
+ self.ceil_mode, self.count_include_pad, self.divisor_override)
+
+
  class AvgPool1dExt(Cell):
  r"""
  Applies a 1D average pooling over an input Tensor which can be regarded as
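The hunk above adds the experimental `AvgPool3dExt` cell as a thin wrapper over `ops.function.nn_func.avg_pool3d_ext`; its docstring points to :func:`mindspore.mint.nn.functional.avg_pool3d` for details. A minimal sketch of that functional form (illustrative only, assuming the mint API referenced in the docstring; Ascend-only per the docstring):

>>> import mindspore as ms
>>> from mindspore import mint
>>> x = ms.ops.randn(1, 2, 4, 4, 5).astype(ms.float32)
>>> y = mint.nn.functional.avg_pool3d(x, kernel_size=3, stride=1)  # same shapes as the AvgPool3dExt docstring example
>>> print(y.shape)
(1, 2, 2, 2, 3)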
@@ -1270,8 +1314,8 @@ class AvgPool1d(_PoolNd):
  This interface currently does not support Atlas A2 training series products.

  Args:
- kernel_size (int): The size of kernel window used to take the average value, Default: ``1`` .
- stride (int): The distance of kernel moving, an int number that represents
+ kernel_size (int, optional): The size of kernel window used to take the average value, Default: ``1`` .
+ stride (int, optional): The distance of kernel moving, an int number that represents
  the width of movement is strides, Default: ``1`` .
  pad_mode (str, optional): Specifies the padding mode with a padding value of 0. It can be set to:
  ``"same"`` , ``"valid"`` or ``"pad"`` . Default: ``"valid"`` .
@@ -1282,17 +1326,20 @@ class AvgPool1d(_PoolNd):
  uniformly distributed around the input, if it is odd, the excess padding is goes to the right side.
  If this mode is set, `padding` must be 0.
  - ``"valid"``: No padding is applied to the input, and the output returns the maximum
- possible length. Extra pixels that could not complete a full stride will
- be discarded. If this mode is set, `padding` must be 0.
+ possible length. If a full stride cannot be formed, the extra pixels will be discarded.
+ If this mode is set, `padding` must be 0.
  - ``"pad"``: Pad the input with a specified amount. In this mode, the amount of padding
  at the begin and end is determined by the `padding` parameter.
  If this mode is set, `padding` must be greater than or equal to 0.

- padding (Union(int, tuple[int], list[int])): Pooling padding value, only ``"pad"`` mode can be set to non-zero.
+ padding (Union(int, tuple[int], list[int]), optional): Pooling padding value,
+ only ``"pad"`` mode can be set to non-zero.
  Default: ``0`` . padding can only be an integer or a tuple/list containing a single integer, in which case
  padding times or padding[0] times are padded on both sides of the input.
- ceil_mode (bool): If ``True`` , use ceil to compute the output shape instead of floor. Default: ``False`` .
- count_include_pad (bool): If ``True`` , averaging calculation will include the zero-padding. Default: ``True`` .
+ ceil_mode (bool, optional): If ``True`` , use ceil to compute the output shape instead of floor.
+ Default: ``False`` .
+ count_include_pad (bool, optional): If ``True`` , averaging calculation will include the zero-padding.
+ Default: ``True`` .

  Inputs:
  - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`.
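Illustrative only (not part of the diff): a minimal sketch of the `padding`, `ceil_mode` and `count_include_pad` options documented above; with `ceil_mode=False` the output length follows floor((L + 2*padding - kernel_size)/stride) + 1.

>>> import mindspore as ms
>>> from mindspore import nn
>>> pool = nn.AvgPool1d(kernel_size=3, stride=2, pad_mode="pad", padding=1, count_include_pad=False)
>>> x = ms.ops.randn(2, 4, 10).astype(ms.float32)
>>> print(pool(x).shape)
(2, 4, 5)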
@@ -1728,13 +1775,14 @@ class AdaptiveMaxPool2d(Cell):
  \end{align}

  Note:
- Ascend platform only supports float16 type for input.
+ In KBK mode, `output_size` does not support mutable.

  Args:
  output_size (Union[int, tuple]): The target output size. `output_size` can be a tuple :math:`(H, W)`,
  or an int H for :math:`(H, H)`. :math:`H` and :math:`W` can be int or None.
  If it is None, it means the output size is the same as the input size.
- return_indices (bool): If `return_indices` is ``True`` , the indices of max value would be output.
+ return_indices (bool, optional): Whether to output the index of the maximum value.
+ If `return_indices` is ``True`` , the indices of max value would be output.
  Default: ``False`` .

  Inputs:
@@ -1797,15 +1845,11 @@ class AdaptiveMaxPool2d(Cell):
  def __init__(self, output_size, return_indices=False):
  """Initialize AdaptiveMaxPool2d."""
  super(AdaptiveMaxPool2d, self).__init__()
- validator.check_value_type('return_indices', return_indices, [bool], self.cls_name)
- self.adaptive_max_pool2d = ops.AdaptiveMaxPool2D(output_size)
+ self.output_size = output_size
  self.return_indices = return_indices

  def construct(self, input):
- output = self.adaptive_max_pool2d(input)
- if self.return_indices:
- return output
- return output[0]
+ return ops.adaptive_max_pool2d(input, self.output_size, self.return_indices)


  class AdaptiveMaxPool3d(Cell):
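The hunk above drops the eagerly constructed `ops.AdaptiveMaxPool2D` primitive (and its `return_indices` validator call) in favour of the functional `ops.adaptive_max_pool2d`, so `output_size` and `return_indices` are simply forwarded at call time. Caller-visible behaviour stays the same; a minimal illustrative sketch (not from the package):

>>> import mindspore as ms
>>> from mindspore import nn
>>> pool = nn.AdaptiveMaxPool2d(output_size=(2, 2), return_indices=True)
>>> x = ms.ops.randn(1, 3, 8, 8).astype(ms.float32)
>>> out, idx = pool(x)
>>> print(out.shape, idx.shape)
(1, 3, 2, 2) (1, 3, 2, 2)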
mindspore/nn/layer/transformer.py CHANGED
@@ -54,16 +54,16 @@ class MultiheadAttention(Cell):
  embed_dim (int): Total dimension of MultiheadAttention.
  num_heads (int): Number of attention heads. Note that `embed_dim` will be split
  across `num_heads` (i.e. each head will have dimension `embed_dim // num_heads`).
- dropout (float): Dropout probability of `attn_output_weights`. Default: ``0.0``.
- has_bias (bool): Whether adds bias to input / output projection layers. Default: ``True``.
- add_bias_kv (bool): Whether adds bias to the key and value sequences at axis=0. Default: ``False``.
- add_zero_attn (bool): Whether adds a new batch of zeros to the key and value sequences at axis=1.
+ dropout (float, optional): Dropout probability of `attn_output_weights`. Default: ``0.0``.
+ has_bias (bool, optional): Whether adds bias to input / output projection layers. Default: ``True``.
+ add_bias_kv (bool, optional): Whether adds bias to the key and value sequences at axis=0. Default: ``False``.
+ add_zero_attn (bool, optional): Whether adds a new batch of zeros to the key and value sequences at axis=1.
  Default: ``False``.
- kdim (int): Total number of features for keys. Default: ``None`` (`kdim=embed_dim`).
- vdim (int): Total number of features for values. Default: ``None`` (`vdim=embed_dim`).
- batch_first (bool): If ``True``, then the input and output shape are :math:`(batch, seq, feature)` ,
+ kdim (int, optional): Total number of features for keys. Default: ``None`` (`kdim=embed_dim`).
+ vdim (int, optional): Total number of features for values. Default: ``None`` (`vdim=embed_dim`).
+ batch_first (bool, optional): If ``True``, then the input and output shape are :math:`(batch, seq, feature)` ,
  else :math:`(seq, batch, feature)` . Default: ``False``.
- dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .
+ dtype (:class:`mindspore.dtype`, optional): Data type of Parameter. Default: ``mstype.float32`` .

  Inputs:
  - **query** (Tensor) - The query embeddings. If `query` is unbatched, the shape is :math:`(L, E_q)`,
@@ -85,7 +85,7 @@ class MultiheadAttention(Cell):
  For a binary mask, a ``True`` value indicates that the corresponding `key` value will be ignored for
  the purpose of attention. For a float mask, it will be directly added to the corresponding `key` value.
  Supported float types: float16, float32, float64. Default: ``None``.
- - **need_weights** (bool) - Whether returns `attn_output_weights` in addition to `attn_outputs`.
+ - **need_weights** (bool, optional) - Whether returns `attn_output_weights` in addition to `attn_outputs`.
  Default: ``True``.
  - **attn_mask** (Tensor, optional) - If specified, a 2D or 3D mask preventing attention to certain positions.
  Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num_heads}, L, S)`, where :math:`N` is the
@@ -94,7 +94,8 @@ class MultiheadAttention(Cell):
  in the batch. For a binary mask, a ``True`` value indicates that the corresponding position is not allowed
  to attend. For a float mask, the mask values will be added to the attention weight.
  Supported float types: float16, float32, float64. Default: ``None``.
- - **average_attn_weights** (bool) - If true, indicates that the returned `attn_weights` should be averaged
+ - **average_attn_weights** (bool, optional) - If true, indicates that
+ the returned `attn_weights` should be averaged
  across heads. Otherwise, `attn_weights` are provided separately per head. Note that this flag only
  has an effect when `need_weights=True`. Default: ``True`` (i.e. average weights across heads)
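Illustrative only (not part of the diff): with `batch_first=True` and the default `need_weights=True` / `average_attn_weights=True` inputs documented above, the cell returns the attention output plus one head-averaged :math:`(N, L, S)` weight map.

>>> import mindspore as ms
>>> from mindspore import nn
>>> attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
>>> q = ms.ops.randn(2, 5, 16).astype(ms.float32)
>>> out, weights = attn(q, q, q)
>>> print(out.shape, weights.shape)
(2, 5, 16) (2, 5, 5)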

mindspore/nn/learning_rate_schedule.py CHANGED
@@ -80,7 +80,8 @@ class ExponentialDecayLR(LearningRateSchedule):
  learning_rate (float): The initial value of learning rate.
  decay_rate (float): The decay rate.
  decay_steps (int): Number of steps to decay over.
- is_stair (bool): If true, learning rate is decayed once every `decay_steps` time. Default: ``False`` .
+ is_stair (bool, optional): If ``True``, learning rate is decayed once every `decay_steps` time.
+ Default: ``False`` .

  Inputs:
  - **global_step** (Tensor) - The current step number. :math:`current\_step` in the above formula.
@@ -223,7 +224,8 @@ class InverseDecayLR(LearningRateSchedule):
  learning_rate (float): The initial value of learning rate.
  decay_rate (float): The decay rate.
  decay_steps (int): Number of steps to decay over.
- is_stair (bool): If true, learning rate decay once every `decay_steps` times. If False, the learning rate
+ is_stair (bool, optional): If true, learning rate decay once every `decay_steps` times.
+ If False, the learning rate
  decays for every step. Default: ``False`` .

  Inputs:
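Illustrative only (not part of the diff): a small sketch of the `is_stair` flag documented for both schedules; with `is_stair=True` the rate only drops once every `decay_steps` steps.

>>> import mindspore as ms
>>> from mindspore import nn
>>> lr = nn.ExponentialDecayLR(0.1, decay_rate=0.9, decay_steps=4, is_stair=True)
>>> result = lr(ms.Tensor(5, ms.int32))  # 0.1 * 0.9 ** floor(5 / 4) -> 0.09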
mindspore/nn/loss/loss.py CHANGED
@@ -127,7 +127,8 @@ class LossBase(Cell):
  Args:
  x (Tensor): Tensor of shape :math:`(N, *)` where :math:`*` means, any number of
  additional dimensions.
- weights (Union[float, Tensor]): Optional `Tensor` whose rank is either 0, or the same rank as inputs,
+ weights (Union[float, Tensor], optional): Weights. When `weights` is a Tensor,
+ the rank is either 0, or the same rank as inputs,
  and must be broadcastable to inputs (i.e., all dimensions must be either `1`,
  or the same as the corresponding inputs dimension). Default: ``1.0`` .
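Illustrative only (not part of the diff): the `weights` argument above is consumed by `LossBase.get_loss` when building custom losses; the `WeightedMAE` subclass below is hypothetical.

>>> import mindspore as ms
>>> from mindspore import nn, ops
>>> class WeightedMAE(nn.LossBase):  # hypothetical example subclass
...     def construct(self, logits, labels):
...         x = ops.abs(logits - labels)
...         return self.get_loss(x, weights=2.0)  # scalar weight broadcast over x
...
>>> loss = WeightedMAE()
>>> out = loss(ms.Tensor([1.0, 2.0], ms.float32), ms.Tensor([1.5, 1.0], ms.float32))
>>> print(out.shape)  # default 'mean' reduction -> scalar
()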

@@ -617,7 +618,8 @@ class MarginRankingLoss(LossBase):

  class SmoothL1Loss(LossBase):
  r"""
- SmoothL1 loss function, if the absolute error element-wise between the predicted value and the target value
+ SmoothL1 loss function. Compare the error value element-wise and
+ if the absolute error between the predicted value and the target value
  is less than the set threshold `beta`, the square term is used, otherwise the absolute error term is used.

  Given two input :math:`x,\ y`, the SmoothL1Loss can be described as follows:
@@ -667,11 +669,13 @@ class SmoothL1Loss(LossBase):

  - Ascend: float16, float32, bfloat16.
  - CPU/GPU: float16, float32, float64.
+
  - **labels** (Tensor) - Ground truth data.

  - CPU/Ascend: has the same shape as the `logits`,
  `logits` and `labels` comply with the implicit type conversion rules to make the data types consistent.
  - GPU: has the same shape and dtype as the `logits`.
+
  Outputs:
  Tensor, if `reduction` is ``'none'``, then output is a tensor with the same shape as `logits`.
  Otherwise the shape of output tensor is :math:`()`.
@@ -732,16 +736,19 @@ class SoftMarginLoss(LossBase):
  - ``'sum'``: the output elements will be summed.

  Inputs:
- - **logits** (Tensor) - Predict data. Data type must be float16 or float32.
- - **labels** (Tensor) - Ground truth data, with the same type and shape as `logits`.
+ - **logits** (Tensor) - Predict data. Data type must be float16, float32,
+ bfloat16 (Among them, the Atlas training series products do not support bfloat16).
+ - **labels** (Tensor) - Ground truth data, with the same shape as `logits`.
+ In GE mode, the data type should be the same as `logits`.

  Outputs:
- Tensor or Scalar, if `reduction` is ``"none"``, its shape is the same as `logits`.
+ Tensor or Scalar, if `reduction` is ``'none'``, its shape is the same as `logits`.
  Otherwise, a scalar value will be returned.

  Raises:
  TypeError: If `logits` or `labels` is not a Tensor.
- TypeError: If dtype of `logits` or `labels` is neither float16 nor float32.
+ TypeError: If dtype of `logits` or `labels` is not float16, float32,
+ bfloat16 (Among them, the Atlas training series products do not support bfloat16).
  ValueError: If shape of `logits` is not the same as `labels`.
  ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.

@@ -762,10 +769,10 @@ class SoftMarginLoss(LossBase):

  def __init__(self, reduction='mean'):
  super(SoftMarginLoss, self).__init__()
- self.soft_margin_loss = P.SoftMarginLoss(reduction)
+ self.reduction = reduction

  def construct(self, logits, labels):
- return self.soft_margin_loss(logits, labels)
+ return F.soft_margin_loss(logits, labels, self.reduction)


  class SoftmaxCrossEntropyWithLogits(LossBase):
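The hunk above is the same primitive-to-functional move seen in `AdaptiveMaxPool2d`: `P.SoftMarginLoss` is no longer built in `__init__`, and `construct` forwards to the functional soft-margin loss with the stored `reduction`. Caller-visible behaviour is unchanged; a minimal illustrative sketch (not from the package):

>>> import mindspore as ms
>>> from mindspore import nn
>>> loss = nn.SoftMarginLoss(reduction='mean')
>>> logits = ms.Tensor([[0.3, 0.7], [0.5, 0.5]], ms.float32)
>>> labels = ms.Tensor([[-1.0, 1.0], [1.0, -1.0]], ms.float32)
>>> print(loss(logits, labels).shape)  # 'mean' reduces to a scalar
()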
@@ -813,8 +820,8 @@ class SoftmaxCrossEntropyWithLogits(LossBase):

  Raises:
  TypeError: If `sparse` is not a bool.
- TypeError: If `sparse` is True and dtype of `labels` is neither int32 nor int64.
- TypeError: If `sparse` is False and dtype of `labels` is neither float16 not float32.
+ TypeError: If `sparse` is ``True`` and dtype of `labels` is neither int32 nor int64.
+ TypeError: If `sparse` is ``False`` and dtype of `labels` is neither float16 not float32.
  ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.

  Supported Platforms:
@@ -893,8 +900,8 @@ class DiceLoss(LossBase):
  :math:`pred` represent `logits`, :math:`true` represent `labels` .

  Args:
- smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
- Default: ``1e-5`` .
+ smooth (float, optional): A term added to the denominator to improve numerical stability.
+ Should be greater than 0. Default: ``1e-5`` .

  Inputs:
  - **logits** (Tensor) - Input predicted value. The data type must be float16 or float32.
@@ -938,11 +945,12 @@ class DiceLoss(LossBase):
  if label.dtype == mstype.uint8:
  raise TypeError(f"For '{self.cls_name}', the dtype of 'labels' can not be uint8.")
  intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
- unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
- self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
+ unionset_part1 = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1)))
+ unionset_part2 = self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
+ unionset = ops.add(unionset_part1, unionset_part2)

- single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
- dice_loss = 1 - single_dice_coeff
+ single_dice_coeff = (2 * intersection) / ops.add(unionset, self.smooth)
+ dice_loss = ops.sub(1, single_dice_coeff)

  return dice_loss
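The DiceLoss change above only swaps Python operators for explicit `ops.add` / `ops.sub` calls; the computed value, 1 - 2*sum(pred*true) / (sum(pred^2) + sum(true^2) + smooth), is unchanged. Illustrative usage (shape only, the loss value is not asserted here):

>>> import mindspore as ms
>>> from mindspore import nn
>>> loss = nn.DiceLoss(smooth=1e-5)
>>> logits = ms.Tensor([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]], ms.float32)
>>> labels = ms.Tensor([[0, 1], [1, 0], [0, 1]], ms.float32)
>>> print(loss(logits, labels).shape)  # scalar dice loss
()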

@@ -1058,7 +1066,7 @@ class MultiClassDiceLoss(LossBase):
  dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
  if self.weights is not None:
  _check_weights(self.weights.shape[0], label.shape[1], self.cls_name)
- dice_loss *= self.weights[i]
+ dice_loss = dice_loss * self.weights[i]
  total_loss += dice_loss

  return total_loss / label.shape[1]
@@ -2571,7 +2579,7 @@ class KLDivLoss(LossBase):
  the updating formulas of KLDivLoss algorithm are as follows,

  .. math::
- L(x, target) = target \cdot (\log target - x)
+ L(x, target) = target \cdot (\log target - \log x)

  Then,

@@ -2865,7 +2873,7 @@ class HingeEmbeddingLoss(LossBase):
  where :math:`L = \{l_1,\dots,l_N\}^\top`.

  Args:
- margin (float, int): Threshold defined by Hinge Embedding Loss :math:`margin`.
+ margin (float, int, optional): Threshold defined by Hinge Embedding Loss :math:`margin`.
  Represented as :math:`\Delta` in the formula. Default: ``1.0`` .
  reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
  ``'sum'`` . Default: ``'mean'`` .
mindspore/nn/optim/ada_grad.py CHANGED
@@ -113,8 +113,8 @@ class Adagrad(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- accum (float): The starting value for :math:`h`, must be zero or positive values. Default: ``0.1`` .
- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.001`` .
+ accum (float, optional): The starting value for :math:`h`, must be zero or positive values. Default: ``0.1`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``0.001`` .

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -130,13 +130,14 @@ class Adagrad(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of current step.

- update_slots (bool): Whether the :math:`h` will be updated. Default: ``True`` .
- loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value.
+ update_slots (bool, optional): Whether the :math:`h` will be updated. Default: ``True`` .
+ loss_scale (float, optional): Value for the loss scale. It must be greater than 0.0. In general,
+ use the default value.
  Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
  Default: ``1.0`` .
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .

  - float: The fixed weight decay value. Must be equal to or greater than 0.
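Illustrative only (not part of the diff): the usual group-`params` pattern behind the dictionary keys and `order_params` wording above, shown for `nn.Adagrad`.

>>> import mindspore as ms
>>> from mindspore import nn
>>> net = nn.Dense(3, 2)
>>> weights = [p for p in net.trainable_params() if 'weight' in p.name]
>>> biases = [p for p in net.trainable_params() if 'bias' in p.name]
>>> group_params = [{'params': weights, 'lr': 0.05, 'weight_decay': 0.01},
...                 {'params': biases, 'lr': 0.01},
...                 {'order_params': net.trainable_params()}]
>>> optim = nn.Adagrad(group_params, learning_rate=0.001, accum=0.1)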

mindspore/nn/optim/adadelta.py CHANGED
@@ -68,8 +68,8 @@ class Adadelta(Optimizer):

  Args:
  params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
- `params` is a list of `dict`, the string "params", "lr", "weight_decay", "grad_centralization" and
- "order_params" are the keys can be parsed.
+ `params` is a list of `dict`, the string `"params"`, `"lr"`, `"weight_decay"`, `"grad_centralization"` and
+ `"order_params"` are the keys can be parsed.

  - params: Required. Parameters in current group. The value must be a list of `Parameter`.

@@ -93,7 +93,7 @@ class Adadelta(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``1.0`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``1.0`` .

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -109,14 +109,16 @@ class Adadelta(Optimizer):
109
109
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
110
110
  with step as the input to get the learning rate of current step.
111
111
 
112
- rho (float): Decay rate, must be in range [0.0, 1.0]. Default: ``0.9`` .
113
- epsilon (float): A small value added for numerical stability, must be non-negative. Default: ``1e-6`` .
114
- loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value.
112
+ rho (float, optional): Decay rate, must be in range [0.0, 1.0]. Default: ``0.9`` .
113
+ epsilon (float, optional): A small value added for numerical stability, must be non-negative.
114
+ Default: ``1e-6`` .
115
+ loss_scale (float, optional): Value for the loss scale. It must be greater than 0.0. In general,
116
+ use the default value.
115
117
  Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
116
118
  `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
117
119
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
118
120
  Default: ``1.0`` .
119
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
121
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .
120
122
 
121
123
  - float: The fixed weight decay value. Must be equal to or greater than 0.
122
124
 
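As with Adagrad, the annotated Adadelta arguments correspond to the public constructor; a sketch under the same toy-network assumption, using the defaults documented in the hunks above:

    # Sketch only; keyword values are the documented Adadelta defaults.
    import mindspore.nn as nn

    net = nn.Dense(16, 4)
    optim = nn.Adadelta(net.trainable_params(), learning_rate=1.0,
                        rho=0.9, epsilon=1e-6, loss_scale=1.0, weight_decay=0.0)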
@@ -406,7 +406,7 @@ class AdaFactor(Optimizer):
  """
  return False

- @jit
+ @jit(backend="ms_backend")
  def construct(self, gradients):
  gradients = self.flatten_gradients(gradients)
  lr = self.get_lr()
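The `@jit` to `@jit(backend="ms_backend")` change above recurs in AdamWeightDecay, AdamOffload and ASGD below. A hedged sketch of the same decorator on a user-defined function, assuming `mindspore.jit` accepts the `backend` keyword these hunks rely on:

    # Assumes mindspore.jit accepts backend="ms_backend", as the decorated construct() methods suggest.
    import mindspore as ms
    from mindspore import Tensor, ops

    @ms.jit(backend="ms_backend")
    def scaled_add(x, y):
        # compiled with the graph backend selected above
        return ops.add(x, y) * 0.5

    out = scaled_add(Tensor([1.0, 2.0]), Tensor([3.0, 4.0]))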
@@ -566,7 +566,7 @@ class Adam(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``1e-3`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``1e-3`` .

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -582,23 +582,26 @@ class Adam(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of current step.

- beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
+ beta1 (float, optional): The exponential decay rate for the 1st moment estimations.
+ Should be in range (0.0, 1.0).
  Default: ``0.9`` .
- beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
+ beta2 (float, optional): The exponential decay rate for the 2nd moment estimations.
+ Should be in range (0.0, 1.0).
  Default: ``0.999`` .
- eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+ eps (float, optional): Term added to the denominator to improve numerical stability. Should be greater than 0.
  Default: ``1e-8`` .
- use_locking (bool): Whether to enable a lock to protect the updating process of variable tensors.
+ use_locking (bool, optional): Whether to enable a lock to protect the updating process of variable tensors.
  If ``true`` , updates of the `w`, `m`, and `v` tensors will be protected by a lock.
  If ``false`` , the result is unpredictable. Default: ``False`` .
- use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
+ use_nesterov (bool, optional): Whether to use Nesterov Accelerated Gradient (NAG) algorithm
+ to update the gradients.
  If ``true`` , update the gradients using NAG.
  If ``false`` , update the gradients without using NAG. Default: ``False`` .
- use_amsgrad (bool): Whether to use Amsgrad algorithm to update the gradients.
+ use_amsgrad (bool, optional): Whether to use Amsgrad algorithm to update the gradients.
  If ``true`` , update the gradients using Amsgrad.
  If ``false`` , update the gradients without using Amsgrad. Default: ``False`` .

- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .

  - float: The fixed weight decay value. Must be equal to or greater than 0.

@@ -607,11 +610,12 @@ class Adam(Optimizer):
  - Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
  the Cell with step as the input to get the weight decay value of current step.

- loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
+ loss_scale (float, optional): A floating point value for the loss scale.
+ Should be greater than 0. In general, use the
  default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
- Default: 1.0.
+ Default: ``1.0``.

  kwargs:

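A usage sketch of the Adam arguments annotated above (toy network assumed; all keyword values are the documented defaults):

    # Sketch only; keyword values mirror the Adam docstring above.
    import mindspore.nn as nn

    net = nn.Dense(16, 4)
    optim = nn.Adam(net.trainable_params(), learning_rate=1e-3,
                    beta1=0.9, beta2=0.999, eps=1e-8,
                    use_locking=False, use_nesterov=False, use_amsgrad=False,
                    weight_decay=0.0, loss_scale=1.0)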
@@ -1024,7 +1028,7 @@ class AdamWeightDecay(Optimizer):
  self.fused_opt = P.AdamWeightDecay()
  self.use_fused_opt = True

- @jit
+ @jit(backend="ms_backend")
  def construct(self, gradients):
  gradients = self.flatten_gradients(gradients)
  weight_decay = self.get_weight_decay()
@@ -1244,7 +1248,7 @@ class AdamOffload(Optimizer):
  self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
  self.opt.set_device("CPU")

- @jit
+ @jit(backend="ms_backend")
  def construct(self, gradients):
  params = self._parameters
  moment1 = self.moment1
@@ -118,12 +118,12 @@ class AdaMax(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of current step.

- beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
- Default: ``0.9`` .
- beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
- Default: ``0.999`` .
- eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
- Default: ``1e-08`` .
+ beta1 (float, optional): The exponential decay rate for the 1st moment estimations.
+ Should be in range (0.0, 1.0). Default: ``0.9`` .
+ beta2 (float, optional): The exponential decay rate for the 2nd moment estimations.
+ Should be in range (0.0, 1.0). Default: ``0.999`` .
+ eps (float, optional): Term added to the denominator to improve numerical stability. Should be greater than 0.
+ Default: ``1e-08`` .

  weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .

@@ -134,7 +134,8 @@ class AdaMax(Optimizer):
  - Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
  the Cell with step as the input to get the weight decay value of current step.

- loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
+ loss_scale (float, optional): A floating point value for the loss scale. Should be greater than 0.
+ In general, use the
  default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
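The same pattern applies to the AdaMax arguments annotated above; a sketch with the documented defaults (the `learning_rate` value here is only illustrative):

    # Sketch only; beta1/beta2/eps/weight_decay/loss_scale are the documented defaults.
    import mindspore.nn as nn

    net = nn.Dense(16, 4)
    optim = nn.AdaMax(net.trainable_params(), learning_rate=0.001,
                      beta1=0.9, beta2=0.999, eps=1e-08,
                      weight_decay=0.0, loss_scale=1.0)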
@@ -420,17 +420,17 @@ class AdaSumByGradWrapCell(Cell):
  and the subscripts represent different devices in the data-parallel dimension.

  Note:
- When using AdaSum, the number of traning cards needs to be a power of 2 and at least 16 cards are required.
- Currently, the optimizer sharding and pipeline parallel is not supported when using AdaSum.
- It is recommended to using AdaSumByGradWrapCell in semi auto parallel/auto parallel mode. In data parallel
- mode, we recommend to using mindspore.boost to applying AdaSum.
+ - It is recommended to using AdaSumByGradWrapCell in semi auto parallel/auto parallel mode. In data parallel
+ mode, we recommend to using mindspore.boost to applying AdaSum.
+ - When using AdaSum, the number of traning cards needs to be a power of 2 and at least 16 cards are required.
+ Currently, the optimizer sharding and pipeline parallel is not supported when using AdaSum.

  Args:
  optimizer (Union[Cell]): Optimizer for updating the weights. The construct function of the optimizer
  requires only one input.

  Inputs:
- - **grads** (Tuple(Tensor)) - Tuple of gradients, same with the input of passed optimizer.
+ - **grads** (Tuple[Tensor]) - Tuple of gradients, same with the input of passed optimizer.

  Raises:
  RuntimeError: If `parallel_mode` uses `stand_alone` mode, AdaSum only supports use in distributed scenarios.
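A wiring sketch for `AdaSumByGradWrapCell`: the wrapper takes a constructed optimizer and then receives the gradient tuple as its only input. The base optimizer below is illustrative, and a real run still needs the device constraints from the Note above (a power-of-two card count, at least 16).

    # Wiring sketch only; AdaSum itself requires >= 16 devices (a power of 2), per the Note above.
    import mindspore.nn as nn

    net = nn.Dense(16, 4)
    base_opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    opt = nn.AdaSumByGradWrapCell(base_opt)   # construct(grads) applies AdaSum before the update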
@@ -180,7 +180,7 @@ class ASGD(Optimizer):
  self.cast = P.Cast()
  self.squeeze = P.Squeeze()

- @jit
+ @jit(backend="ms_backend")
  def construct(self, gradients):
  gradients = self.flatten_gradients(gradients)
  gradients = self.decay_weight(gradients)
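All of the `construct(self, gradients)` methods touched above receive their gradient tuple from a training wrapper; a sketch of the usual wiring, with ASGD standing in for any of the optimizers in this file:

    # Sketch of the wiring that feeds gradients into an optimizer's construct().
    import mindspore.nn as nn

    net = nn.Dense(16, 4)
    loss_fn = nn.MSELoss()
    net_with_loss = nn.WithLossCell(net, loss_fn)
    optim = nn.ASGD(net.trainable_params(), learning_rate=0.1)
    train_step = nn.TrainOneStepCell(net_with_loss, optim)  # computes grads, then calls optim(grads)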