mindspore-2.4.10-cp39-cp39-win_amd64.whl → mindspore-2.6.0rc1-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (577)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +13 -6
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -0
  7. mindspore/_checkparam.py +3 -38
  8. mindspore/_deprecated/__init__.py +17 -0
  9. mindspore/_deprecated/jit.py +198 -0
  10. mindspore/_extends/builtin_operations.py +1 -1
  11. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  12. mindspore/_extends/parse/__init__.py +6 -7
  13. mindspore/_extends/parse/compile_config.py +83 -0
  14. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +46 -197
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +217 -98
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +11 -5
  27. mindspore/avcodec-59.dll +0 -0
  28. mindspore/avdevice-59.dll +0 -0
  29. mindspore/avfilter-8.dll +0 -0
  30. mindspore/avformat-59.dll +0 -0
  31. mindspore/avutil-57.dll +0 -0
  32. mindspore/boost/__init__.py +2 -2
  33. mindspore/boost/base.py +3 -7
  34. mindspore/boost/boost_cell_wrapper.py +138 -43
  35. mindspore/common/__init__.py +6 -3
  36. mindspore/common/_grad_function.py +56 -0
  37. mindspore/common/_pijit_context.py +14 -5
  38. mindspore/common/_register_for_tensor.py +1 -2
  39. mindspore/common/_stub_tensor.py +30 -14
  40. mindspore/common/_tensor_cpp_method.py +17 -0
  41. mindspore/common/_tensor_docs.py +4760 -0
  42. mindspore/common/api.py +435 -371
  43. mindspore/common/auto_dynamic_shape.py +41 -44
  44. mindspore/common/dtype.py +39 -36
  45. mindspore/common/dump.py +9 -6
  46. mindspore/common/file_system.py +9 -1
  47. mindspore/common/generator.py +2 -0
  48. mindspore/common/hook_handle.py +6 -2
  49. mindspore/common/initializer.py +13 -10
  50. mindspore/common/jit_begin_end.py +94 -0
  51. mindspore/common/jit_config.py +6 -1
  52. mindspore/common/jit_context.py +76 -0
  53. mindspore/common/jit_trace.py +378 -0
  54. mindspore/common/lazy_inline.py +9 -3
  55. mindspore/common/mindir_util.py +10 -2
  56. mindspore/common/mutable.py +5 -4
  57. mindspore/common/parameter.py +135 -52
  58. mindspore/common/seed.py +2 -2
  59. mindspore/common/sparse_tensor.py +23 -17
  60. mindspore/common/tensor.py +951 -1992
  61. mindspore/communication/__init__.py +7 -5
  62. mindspore/communication/_comm_helper.py +52 -2
  63. mindspore/communication/comm_func.py +240 -181
  64. mindspore/communication/management.py +95 -26
  65. mindspore/context.py +314 -566
  66. mindspore/dataset/__init__.py +65 -37
  67. mindspore/dataset/audio/__init__.py +2 -8
  68. mindspore/dataset/audio/transforms.py +3 -17
  69. mindspore/dataset/callback/ds_callback.py +2 -1
  70. mindspore/dataset/core/config.py +87 -6
  71. mindspore/dataset/engine/cache_admin.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +6 -5
  73. mindspore/dataset/engine/datasets.py +292 -267
  74. mindspore/dataset/engine/datasets_audio.py +22 -8
  75. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  76. mindspore/dataset/engine/datasets_text.py +78 -48
  77. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  78. mindspore/dataset/engine/datasets_vision.py +120 -44
  79. mindspore/dataset/engine/iterators.py +283 -63
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/obs/util.py +8 -0
  82. mindspore/dataset/engine/queue.py +40 -0
  83. mindspore/dataset/engine/samplers.py +289 -43
  84. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  85. mindspore/dataset/engine/validators.py +53 -11
  86. mindspore/dataset/text/__init__.py +7 -6
  87. mindspore/dataset/text/transforms.py +6 -5
  88. mindspore/dataset/text/utils.py +3 -3
  89. mindspore/dataset/transforms/__init__.py +0 -9
  90. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  91. mindspore/dataset/transforms/transforms.py +31 -14
  92. mindspore/dataset/utils/browse_dataset.py +1 -1
  93. mindspore/dataset/vision/__init__.py +2 -9
  94. mindspore/dataset/vision/transforms.py +202 -158
  95. mindspore/dataset/vision/utils.py +7 -5
  96. mindspore/dataset/vision/validators.py +1 -2
  97. mindspore/device_context/__init__.py +21 -0
  98. mindspore/device_context/ascend/__init__.py +25 -0
  99. mindspore/device_context/ascend/device.py +72 -0
  100. mindspore/device_context/ascend/op_debug.py +153 -0
  101. mindspore/device_context/ascend/op_precision.py +193 -0
  102. mindspore/device_context/ascend/op_tuning.py +123 -0
  103. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  104. mindspore/device_context/cpu/device.py +62 -0
  105. mindspore/device_context/cpu/op_tuning.py +43 -0
  106. mindspore/device_context/gpu/__init__.py +21 -0
  107. mindspore/device_context/gpu/device.py +70 -0
  108. mindspore/device_context/gpu/op_precision.py +67 -0
  109. mindspore/device_context/gpu/op_tuning.py +175 -0
  110. mindspore/device_manager.py +170 -0
  111. mindspore/experimental/es/embedding_service.py +35 -27
  112. mindspore/experimental/llm_boost/__init__.py +1 -0
  113. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  114. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  115. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  116. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  117. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  118. mindspore/experimental/llm_boost/register.py +1 -0
  119. mindspore/experimental/map_parameter.py +4 -4
  120. mindspore/experimental/optim/adadelta.py +6 -6
  121. mindspore/experimental/optim/adagrad.py +4 -4
  122. mindspore/experimental/optim/adam.py +7 -0
  123. mindspore/experimental/optim/adamax.py +4 -4
  124. mindspore/experimental/optim/adamw.py +4 -0
  125. mindspore/experimental/optim/asgd.py +1 -1
  126. mindspore/experimental/optim/lr_scheduler.py +73 -46
  127. mindspore/experimental/optim/radam.py +34 -31
  128. mindspore/experimental/optim/rprop.py +1 -1
  129. mindspore/experimental/optim/sgd.py +1 -1
  130. mindspore/hal/contiguous_tensors_handle.py +6 -10
  131. mindspore/hal/device.py +55 -53
  132. mindspore/hal/event.py +52 -52
  133. mindspore/hal/memory.py +157 -117
  134. mindspore/hal/stream.py +150 -109
  135. mindspore/include/api/context.h +0 -1
  136. mindspore/include/dataset/constants.h +7 -4
  137. mindspore/include/dataset/execute.h +2 -2
  138. mindspore/jpeg62.dll +0 -0
  139. mindspore/log.py +50 -0
  140. mindspore/mindrecord/__init__.py +21 -8
  141. mindspore/mindrecord/config.py +17 -316
  142. mindspore/mindrecord/filereader.py +1 -9
  143. mindspore/mindrecord/filewriter.py +5 -15
  144. mindspore/mindrecord/mindpage.py +1 -9
  145. mindspore/mindspore_backend_common.dll +0 -0
  146. mindspore/mindspore_backend_manager.dll +0 -0
  147. mindspore/mindspore_common.dll +0 -0
  148. mindspore/mindspore_core.dll +0 -0
  149. mindspore/mindspore_dump.dll +0 -0
  150. mindspore/mindspore_frontend.dll +0 -0
  151. mindspore/mindspore_memory_pool.dll +0 -0
  152. mindspore/mindspore_ms_backend.dll +0 -0
  153. mindspore/mindspore_ops.dll +0 -0
  154. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  155. mindspore/mindspore_ops_kernel_common.dll +0 -0
  156. mindspore/mindspore_profiler.dll +0 -0
  157. mindspore/mindspore_pyboost.dll +0 -0
  158. mindspore/mindspore_pynative.dll +0 -0
  159. mindspore/mindspore_res_manager.dll +0 -0
  160. mindspore/mindspore_runtime_pipeline.dll +0 -0
  161. mindspore/mint/__init__.py +796 -759
  162. mindspore/mint/distributed/__init__.py +70 -4
  163. mindspore/mint/distributed/distributed.py +2679 -44
  164. mindspore/mint/linalg/__init__.py +8 -0
  165. mindspore/mint/nn/__init__.py +743 -22
  166. mindspore/mint/nn/functional.py +716 -23
  167. mindspore/mint/nn/layer/__init__.py +21 -4
  168. mindspore/mint/nn/layer/_functions.py +334 -0
  169. mindspore/mint/nn/layer/activation.py +276 -1
  170. mindspore/mint/nn/layer/basic.py +123 -0
  171. mindspore/mint/nn/layer/conv.py +921 -0
  172. mindspore/mint/nn/layer/normalization.py +223 -28
  173. mindspore/mint/nn/layer/padding.py +797 -0
  174. mindspore/mint/nn/layer/pooling.py +235 -0
  175. mindspore/mint/optim/__init__.py +3 -1
  176. mindspore/mint/optim/adam.py +223 -0
  177. mindspore/mint/optim/adamw.py +26 -19
  178. mindspore/mint/optim/sgd.py +171 -0
  179. mindspore/mint/special/__init__.py +2 -1
  180. mindspore/multiprocessing/__init__.py +5 -0
  181. mindspore/nn/__init__.py +4 -1
  182. mindspore/nn/cell.py +1370 -189
  183. mindspore/nn/dynamic_lr.py +2 -1
  184. mindspore/nn/layer/activation.py +29 -27
  185. mindspore/nn/layer/basic.py +51 -35
  186. mindspore/nn/layer/channel_shuffle.py +3 -3
  187. mindspore/nn/layer/container.py +1 -1
  188. mindspore/nn/layer/conv.py +22 -17
  189. mindspore/nn/layer/embedding.py +12 -11
  190. mindspore/nn/layer/normalization.py +56 -49
  191. mindspore/nn/layer/padding.py +4 -3
  192. mindspore/nn/layer/pooling.py +120 -42
  193. mindspore/nn/layer/rnn_cells.py +1 -1
  194. mindspore/nn/layer/rnns.py +2 -1
  195. mindspore/nn/layer/timedistributed.py +5 -5
  196. mindspore/nn/layer/transformer.py +59 -36
  197. mindspore/nn/learning_rate_schedule.py +8 -4
  198. mindspore/nn/loss/loss.py +58 -55
  199. mindspore/nn/optim/ada_grad.py +7 -5
  200. mindspore/nn/optim/adadelta.py +11 -9
  201. mindspore/nn/optim/adafactor.py +1 -1
  202. mindspore/nn/optim/adam.py +17 -13
  203. mindspore/nn/optim/adamax.py +8 -7
  204. mindspore/nn/optim/adasum.py +5 -5
  205. mindspore/nn/optim/asgd.py +1 -1
  206. mindspore/nn/optim/ftrl.py +11 -9
  207. mindspore/nn/optim/lamb.py +1 -1
  208. mindspore/nn/optim/lars.py +1 -4
  209. mindspore/nn/optim/lazyadam.py +12 -10
  210. mindspore/nn/optim/momentum.py +7 -6
  211. mindspore/nn/optim/optimizer.py +3 -3
  212. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  213. mindspore/nn/optim/rmsprop.py +13 -12
  214. mindspore/nn/optim/rprop.py +11 -9
  215. mindspore/nn/optim/sgd.py +9 -6
  216. mindspore/nn/optim/tft_wrapper.py +5 -2
  217. mindspore/nn/optim/thor.py +2 -1
  218. mindspore/nn/probability/bijector/bijector.py +17 -11
  219. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  220. mindspore/nn/probability/bijector/invert.py +2 -2
  221. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  222. mindspore/nn/probability/bijector/softplus.py +3 -2
  223. mindspore/nn/probability/distribution/beta.py +3 -3
  224. mindspore/nn/probability/distribution/categorical.py +1 -1
  225. mindspore/nn/probability/distribution/cauchy.py +4 -2
  226. mindspore/nn/probability/distribution/exponential.py +6 -7
  227. mindspore/nn/probability/distribution/gamma.py +2 -2
  228. mindspore/nn/probability/distribution/gumbel.py +2 -2
  229. mindspore/nn/probability/distribution/half_normal.py +5 -3
  230. mindspore/nn/probability/distribution/logistic.py +5 -3
  231. mindspore/nn/probability/distribution/poisson.py +1 -1
  232. mindspore/nn/probability/distribution/uniform.py +5 -3
  233. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  234. mindspore/nn/reinforcement/tensor_array.py +1 -1
  235. mindspore/nn/utils/init.py +13 -11
  236. mindspore/nn/wrap/__init__.py +6 -6
  237. mindspore/nn/wrap/cell_wrapper.py +181 -122
  238. mindspore/nn/wrap/grad_reducer.py +45 -36
  239. mindspore/nn/wrap/loss_scale.py +6 -7
  240. mindspore/numpy/array_creations.py +63 -65
  241. mindspore/numpy/array_ops.py +149 -144
  242. mindspore/numpy/logic_ops.py +41 -42
  243. mindspore/numpy/math_ops.py +365 -363
  244. mindspore/numpy/utils.py +17 -18
  245. mindspore/numpy/utils_const.py +5 -6
  246. mindspore/opencv_core452.dll +0 -0
  247. mindspore/opencv_imgcodecs452.dll +0 -0
  248. mindspore/opencv_imgproc452.dll +0 -0
  249. mindspore/ops/__init__.py +5 -3
  250. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  251. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  252. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  253. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  254. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  255. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  256. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  257. mindspore/ops/_register_for_op.py +0 -11
  258. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  259. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  260. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  261. mindspore/ops/_vmap/vmap_base.py +0 -2
  262. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  263. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  264. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  265. mindspore/ops/auto_generate/__init__.py +4 -3
  266. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  267. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  268. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  269. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  270. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  271. mindspore/ops/composite/__init__.py +2 -1
  272. mindspore/ops/composite/base.py +20 -25
  273. mindspore/ops/composite/math_ops.py +6 -16
  274. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  275. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  276. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  277. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  278. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  279. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  280. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  281. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  282. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  283. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  284. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  285. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  286. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  287. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  288. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  289. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  290. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  291. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  292. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  293. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  294. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  295. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  297. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  301. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  302. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  304. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  305. mindspore/ops/function/__init__.py +40 -2
  306. mindspore/ops/function/_add_attr_func.py +58 -0
  307. mindspore/ops/function/array_func.py +2089 -2403
  308. mindspore/ops/function/clip_func.py +80 -23
  309. mindspore/ops/function/debug_func.py +57 -57
  310. mindspore/ops/function/grad/__init__.py +1 -0
  311. mindspore/ops/function/grad/grad_func.py +104 -71
  312. mindspore/ops/function/image_func.py +2 -2
  313. mindspore/ops/function/linalg_func.py +47 -78
  314. mindspore/ops/function/math_func.py +4501 -3802
  315. mindspore/ops/function/nn_func.py +1726 -620
  316. mindspore/ops/function/other_func.py +159 -1
  317. mindspore/ops/function/parameter_func.py +18 -84
  318. mindspore/ops/function/random_func.py +440 -387
  319. mindspore/ops/function/reshard_func.py +4 -70
  320. mindspore/ops/function/sparse_func.py +3 -3
  321. mindspore/ops/function/sparse_unary_func.py +6 -6
  322. mindspore/ops/function/spectral_func.py +25 -58
  323. mindspore/ops/function/vmap_func.py +24 -17
  324. mindspore/ops/functional.py +22 -7
  325. mindspore/ops/functional_overload.py +1440 -0
  326. mindspore/ops/op_info_register.py +32 -244
  327. mindspore/ops/operations/__init__.py +13 -7
  328. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  329. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  330. mindspore/ops/operations/_grad_ops.py +2 -43
  331. mindspore/ops/operations/_infer_ops.py +2 -1
  332. mindspore/ops/operations/_inner_ops.py +43 -84
  333. mindspore/ops/operations/_ms_kernel.py +4 -10
  334. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  335. mindspore/ops/operations/_scalar_ops.py +3 -2
  336. mindspore/ops/operations/_sequence_ops.py +1 -1
  337. mindspore/ops/operations/_tensor_array.py +1 -1
  338. mindspore/ops/operations/array_ops.py +81 -324
  339. mindspore/ops/operations/comm_ops.py +154 -108
  340. mindspore/ops/operations/custom_ops.py +232 -78
  341. mindspore/ops/operations/debug_ops.py +153 -59
  342. mindspore/ops/operations/inner_ops.py +7 -5
  343. mindspore/ops/operations/linalg_ops.py +1 -57
  344. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  345. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  346. mindspore/ops/operations/math_ops.py +32 -234
  347. mindspore/ops/operations/nn_ops.py +210 -498
  348. mindspore/ops/operations/other_ops.py +62 -9
  349. mindspore/ops/operations/random_ops.py +13 -7
  350. mindspore/ops/operations/reshard_ops.py +1 -1
  351. mindspore/ops/operations/sparse_ops.py +2 -2
  352. mindspore/ops/primitive.py +66 -53
  353. mindspore/ops/tensor_method.py +1888 -0
  354. mindspore/ops_generate/__init__.py +0 -5
  355. mindspore/ops_generate/aclnn/__init__.py +0 -0
  356. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  357. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  358. mindspore/ops_generate/api/__init__.py +0 -0
  359. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  360. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  361. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  362. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  363. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  364. mindspore/ops_generate/api/gen_api.py +103 -0
  365. mindspore/ops_generate/api/op_api_proto.py +235 -0
  366. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  367. mindspore/ops_generate/common/__init__.py +0 -0
  368. mindspore/ops_generate/common/base_generator.py +11 -0
  369. mindspore/ops_generate/common/gen_constants.py +91 -0
  370. mindspore/ops_generate/common/gen_utils.py +348 -0
  371. mindspore/ops_generate/common/op_proto.py +473 -0
  372. mindspore/ops_generate/common/template.py +523 -0
  373. mindspore/ops_generate/gen_ops.py +22 -1069
  374. mindspore/ops_generate/op_def/__init__.py +0 -0
  375. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  376. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  377. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  378. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  379. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  380. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  381. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  382. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  383. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  384. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  385. mindspore/ops_generate/pyboost/__init__.py +0 -0
  386. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  387. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  388. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  389. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  390. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  391. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  392. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  393. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  394. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  395. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  396. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  397. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  398. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  399. mindspore/ops_generate/resources/__init__.py +0 -0
  400. mindspore/ops_generate/resources/resource_list.py +30 -0
  401. mindspore/ops_generate/resources/resource_loader.py +36 -0
  402. mindspore/ops_generate/resources/resource_manager.py +64 -0
  403. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  404. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  405. mindspore/parallel/__init__.py +7 -3
  406. mindspore/parallel/_auto_parallel_context.py +152 -34
  407. mindspore/parallel/_cell_wrapper.py +130 -15
  408. mindspore/parallel/_parallel_serialization.py +107 -5
  409. mindspore/parallel/_ps_context.py +1 -1
  410. mindspore/parallel/_recovery_context.py +7 -2
  411. mindspore/parallel/_tensor.py +142 -18
  412. mindspore/parallel/_utils.py +199 -23
  413. mindspore/parallel/algo_parameter_config.py +4 -4
  414. mindspore/parallel/auto_parallel.py +732 -0
  415. mindspore/parallel/checkpoint_convert.py +159 -0
  416. mindspore/parallel/checkpoint_transform.py +698 -35
  417. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  418. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  419. mindspore/parallel/cluster/run.py +21 -4
  420. mindspore/parallel/function/__init__.py +24 -0
  421. mindspore/parallel/function/reshard_func.py +259 -0
  422. mindspore/parallel/nn/__init__.py +25 -0
  423. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  424. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  425. mindspore/parallel/parameter_broadcast.py +25 -14
  426. mindspore/parallel/shard.py +137 -58
  427. mindspore/parallel/transform_safetensors.py +363 -305
  428. mindspore/profiler/__init__.py +22 -5
  429. mindspore/profiler/analysis/__init__.py +0 -0
  430. mindspore/profiler/analysis/parser/__init__.py +0 -0
  431. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  432. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  433. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  434. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  435. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  436. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  437. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  438. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  439. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  440. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  441. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  442. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  443. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  444. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  445. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  446. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  447. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  448. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  449. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  450. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  451. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  452. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  453. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  454. mindspore/profiler/analysis/task_manager.py +131 -0
  455. mindspore/profiler/analysis/time_converter.py +84 -0
  456. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  457. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  458. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  459. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  460. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  461. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  462. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  463. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  464. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  465. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  466. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  467. mindspore/profiler/analysis/work_flow.py +73 -0
  468. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  469. mindspore/profiler/common/command_executor.py +90 -0
  470. mindspore/profiler/common/constant.py +186 -3
  471. mindspore/profiler/common/file_manager.py +208 -0
  472. mindspore/profiler/common/log.py +130 -0
  473. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  474. mindspore/profiler/common/path_manager.py +395 -0
  475. mindspore/profiler/common/process_bar.py +168 -0
  476. mindspore/profiler/common/process_pool.py +9 -3
  477. mindspore/profiler/common/profiler_context.py +500 -0
  478. mindspore/profiler/common/profiler_info.py +304 -0
  479. mindspore/profiler/common/profiler_meta_data.py +74 -0
  480. mindspore/profiler/common/profiler_output_path.py +284 -0
  481. mindspore/profiler/common/profiler_parameters.py +251 -0
  482. mindspore/profiler/common/profiler_path_manager.py +179 -0
  483. mindspore/profiler/common/record_function.py +76 -0
  484. mindspore/profiler/common/tlv_decoder.py +76 -0
  485. mindspore/profiler/common/util.py +75 -2
  486. mindspore/profiler/dynamic_profiler.py +341 -75
  487. mindspore/profiler/envprofiler.py +163 -0
  488. mindspore/profiler/experimental_config.py +197 -0
  489. mindspore/profiler/mstx.py +242 -0
  490. mindspore/profiler/platform/__init__.py +21 -0
  491. mindspore/profiler/platform/base_profiler.py +40 -0
  492. mindspore/profiler/platform/cpu_profiler.py +124 -0
  493. mindspore/profiler/platform/gpu_profiler.py +74 -0
  494. mindspore/profiler/platform/npu_profiler.py +335 -0
  495. mindspore/profiler/profiler.py +1073 -90
  496. mindspore/profiler/profiler_action_controller.py +187 -0
  497. mindspore/profiler/profiler_interface.py +118 -0
  498. mindspore/profiler/schedule.py +243 -0
  499. mindspore/rewrite/api/node.py +15 -13
  500. mindspore/rewrite/api/symbol_tree.py +2 -3
  501. mindspore/run_check/_check_version.py +27 -20
  502. mindspore/run_check/run_check.py +1 -1
  503. mindspore/runtime/__init__.py +37 -0
  504. mindspore/runtime/device.py +27 -0
  505. mindspore/runtime/event.py +209 -0
  506. mindspore/runtime/executor.py +177 -0
  507. mindspore/runtime/memory.py +409 -0
  508. mindspore/runtime/stream.py +460 -0
  509. mindspore/runtime/thread_bind_core.py +401 -0
  510. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  511. mindspore/swresample-4.dll +0 -0
  512. mindspore/swscale-6.dll +0 -0
  513. mindspore/tinyxml2.dll +0 -0
  514. mindspore/train/__init__.py +8 -8
  515. mindspore/train/_utils.py +88 -25
  516. mindspore/train/amp.py +9 -5
  517. mindspore/train/callback/__init__.py +2 -2
  518. mindspore/train/callback/_callback.py +2 -16
  519. mindspore/train/callback/_checkpoint.py +53 -55
  520. mindspore/train/callback/_cluster_monitor.py +14 -18
  521. mindspore/train/callback/_early_stop.py +1 -1
  522. mindspore/train/callback/_flops_collector.py +103 -68
  523. mindspore/train/callback/_history.py +8 -5
  524. mindspore/train/callback/_lambda_callback.py +2 -2
  525. mindspore/train/callback/_landscape.py +0 -3
  526. mindspore/train/callback/_loss_monitor.py +2 -1
  527. mindspore/train/callback/_on_request_exit.py +6 -5
  528. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  529. mindspore/train/callback/_summary_collector.py +52 -19
  530. mindspore/train/callback/_time_monitor.py +2 -1
  531. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  532. mindspore/train/data_sink.py +25 -2
  533. mindspore/train/dataset_helper.py +15 -16
  534. mindspore/train/loss_scale_manager.py +8 -7
  535. mindspore/train/metrics/accuracy.py +3 -3
  536. mindspore/train/metrics/confusion_matrix.py +9 -9
  537. mindspore/train/metrics/error.py +3 -3
  538. mindspore/train/metrics/hausdorff_distance.py +4 -4
  539. mindspore/train/metrics/mean_surface_distance.py +3 -3
  540. mindspore/train/metrics/metric.py +0 -12
  541. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  542. mindspore/train/metrics/precision.py +11 -10
  543. mindspore/train/metrics/recall.py +9 -9
  544. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  545. mindspore/train/mind_ir_pb2.py +174 -46
  546. mindspore/train/model.py +184 -113
  547. mindspore/train/serialization.py +622 -978
  548. mindspore/train/summary/_summary_adapter.py +2 -2
  549. mindspore/train/summary/summary_record.py +2 -3
  550. mindspore/train/train_thor/model_thor.py +1 -1
  551. mindspore/turbojpeg.dll +0 -0
  552. mindspore/utils/__init__.py +6 -3
  553. mindspore/utils/dryrun.py +140 -0
  554. mindspore/utils/hooks.py +81 -0
  555. mindspore/utils/runtime_execution_order_check.py +550 -0
  556. mindspore/utils/utils.py +138 -4
  557. mindspore/version.py +1 -1
  558. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  559. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +562 -393
  560. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  561. mindspore/_install_custom.py +0 -43
  562. mindspore/common/_register_for_adapter.py +0 -74
  563. mindspore/common/_tensor_overload.py +0 -139
  564. mindspore/mindspore_np_dtype.dll +0 -0
  565. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  566. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  567. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  568. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  569. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  570. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  571. mindspore/ops_generate/gen_utils.py +0 -209
  572. mindspore/ops_generate/op_proto.py +0 -145
  573. mindspore/ops_generate/template.py +0 -261
  574. mindspore/profiler/envprofiling.py +0 -254
  575. mindspore/profiler/profiling.py +0 -1926
  576. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  577. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/transformer.py CHANGED
@@ -26,12 +26,12 @@ from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer, XavierNormal, XavierUniform, \
     HeUniform, Uniform, _calculate_fan_in_and_fan_out
-from mindspore.ops.function.nn_func import multi_head_attention_forward
 from mindspore.nn.cell import Cell
 from .basic import Dense, Dropout
 from .activation import ReLU, GELU
 from .normalization import LayerNorm
 from .container import CellList
+
 __all__ = ['MultiheadAttention', 'TransformerEncoderLayer', 'TransformerDecoderLayer',
            'TransformerEncoder', 'TransformerDecoder', 'Transformer']

@@ -54,16 +54,16 @@ class MultiheadAttention(Cell):
         embed_dim (int): Total dimension of MultiheadAttention.
         num_heads (int): Number of attention heads. Note that `embed_dim` will be split
             across `num_heads` (i.e. each head will have dimension `embed_dim // num_heads`).
-        dropout (float): Dropout probability of `attn_output_weights`. Default: ``0.0``.
-        has_bias (bool): Whether adds bias to input / output projection layers. Default: ``True``.
-        add_bias_kv (bool): Whether adds bias to the key and value sequences at axis=0. Default: ``False``.
-        add_zero_attn (bool): Whether adds a new batch of zeros to the key and value sequences at axis=1.
+        dropout (float, optional): Dropout probability of `attn_output_weights`. Default: ``0.0``.
+        has_bias (bool, optional): Whether adds bias to input / output projection layers. Default: ``True``.
+        add_bias_kv (bool, optional): Whether adds bias to the key and value sequences at axis=0. Default: ``False``.
+        add_zero_attn (bool, optional): Whether adds a new batch of zeros to the key and value sequences at axis=1.
             Default: ``False``.
-        kdim (int): Total number of features for keys. Default: ``None`` (`kdim=embed_dim`).
-        vdim (int): Total number of features for values. Default: ``None`` (`vdim=embed_dim`).
-        batch_first (bool): If ``True``, then the input and output shape are :math:`(batch, seq, feature)` ,
+        kdim (int, optional): Total number of features for keys. Default: ``None`` (`kdim=embed_dim`).
+        vdim (int, optional): Total number of features for values. Default: ``None`` (`vdim=embed_dim`).
+        batch_first (bool, optional): If ``True``, then the input and output shape are :math:`(batch, seq, feature)` ,
             else :math:`(seq, batch, feature)` . Default: ``False``.
-        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .
+        dtype (:class:`mindspore.dtype`, optional): Data type of Parameter. Default: ``mstype.float32`` .

     Inputs:
         - **query** (Tensor) - The query embeddings. If `query` is unbatched, the shape is :math:`(L, E_q)`,
@@ -85,7 +85,7 @@ class MultiheadAttention(Cell):
           For a binary mask, a ``True`` value indicates that the corresponding `key` value will be ignored for
           the purpose of attention. For a float mask, it will be directly added to the corresponding `key` value.
           Supported float types: float16, float32, float64. Default: ``None``.
-        - **need_weights** (bool) - Whether returns `attn_output_weights` in addition to `attn_outputs`.
+        - **need_weights** (bool, optional) - Whether returns `attn_output_weights` in addition to `attn_outputs`.
           Default: ``True``.
         - **attn_mask** (Tensor, optional) - If specified, a 2D or 3D mask preventing attention to certain positions.
           Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num_heads}, L, S)`, where :math:`N` is the
@@ -94,7 +94,8 @@ class MultiheadAttention(Cell):
           in the batch. For a binary mask, a ``True`` value indicates that the corresponding position is not allowed
           to attend. For a float mask, the mask values will be added to the attention weight.
           Supported float types: float16, float32, float64. Default: ``None``.
-        - **average_attn_weights** (bool) - If true, indicates that the returned `attn_weights` should be averaged
+        - **average_attn_weights** (bool, optional) - If true, indicates that
+          the returned `attn_weights` should be averaged
           across heads. Otherwise, `attn_weights` are provided separately per head. Note that this flag only
           has an effect when `need_weights=True`. Default: ``True`` (i.e. average weights across heads)

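For orientation, here is a minimal self-attention sketch against the interface documented in this hunk; the shapes and values are illustrative, not taken from the diff:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

# 16-dim embeddings split across 4 heads, batch dimension first
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
query = Tensor(np.random.rand(2, 5, 16), ms.float32)  # (batch, seq, feature)
# Self-attention: query, key and value are the same tensor.
# need_weights defaults to True, so attention weights come back too.
attn_output, attn_weights = mha(query, query, query)
print(attn_output.shape)   # (2, 5, 16)
print(attn_weights.shape)  # (2, 5, 5): averaged across heads (average_attn_weights=True)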
@@ -212,7 +213,7 @@ class MultiheadAttention(Cell):
             query, key, value = [x.swapaxes(1, 0) for x in (query, key, value)]

         if not self._qkv_same_embed_dim:
-            attn_output, attn_output_weights = multi_head_attention_forward(
+            attn_output, attn_output_weights = ops.function.nn_func.multi_head_attention_forward(
                 query, key, value, self.embed_dim, self.num_heads,
                 self.in_proj_weight, self.in_proj_bias,
                 self.bias_k, self.bias_v, self.add_zero_attn,
@@ -224,7 +225,7 @@ class MultiheadAttention(Cell):
                 v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights,
                 k_is_v=self.k_is_v, q_is_k=self.q_is_k, dtype=self.dtype)
         else:
-            attn_output, attn_output_weights = multi_head_attention_forward(
+            attn_output, attn_output_weights = ops.function.nn_func.multi_head_attention_forward(
                 query, key, value, self.embed_dim, self.num_heads,
                 self.in_proj_weight, self.in_proj_bias,
                 self.bias_k, self.bias_v, self.add_zero_attn,
@@ -328,7 +329,7 @@ class TransformerEncoderLayer(Cell):
         self.activation1 = activation

         if not isinstance(activation, str) and not isinstance(activation, Cell) \
-            and not callable(activation):
+                and not callable(activation):
             raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
                              f" but get {activation}.")
         if isinstance(activation, Cell) and (not isinstance(activation, ReLU) and \
@@ -360,15 +361,23 @@ class TransformerEncoderLayer(Cell):
             raise AssertionError(
                 "only bool and floating types of key_padding_mask are supported")

-        x = src
+        input_data = src
+
         if self.norm_first:
-            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
-            x = x + self._ff_block(self.norm2(x))
+            normed_input = self.norm1(input_data)
+            sa_block_result = self._sa_block(normed_input, src_mask, src_key_padding_mask)
+            input_data = input_data + sa_block_result
+            normed_updated_input = self.norm2(input_data)
+            ff_block_result = self._ff_block(normed_updated_input)
+            input_data = input_data + ff_block_result
         else:
-            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
-            x = self.norm2(x + self._ff_block(x))
+            sa_block_result = self._sa_block(input_data, src_mask, src_key_padding_mask)
+            normed_sa_result = self.norm1(input_data + sa_block_result)
+            input_data = normed_sa_result
+            ff_block_result = self._ff_block(input_data)
+            input_data = self.norm2(input_data + ff_block_result)

-        return x
+        return input_data

     def _sa_block(self, x, attn_mask, key_padding_mask):
         x = self.self_attn(x, x, x,
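The rewrite above unrolls the pre-norm (`norm_first=True`) and post-norm residual paths into named intermediates without changing the arithmetic. A usage sketch with assumed illustrative sizes:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

# norm_first=True applies LayerNorm before each sub-block (pre-norm);
# the default False normalizes after each residual addition (post-norm).
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, norm_first=True)
src = Tensor(np.random.rand(10, 32, 512), ms.float32)  # (seq, batch, feature)
out = encoder_layer(src)
print(out.shape)  # (10, 32, 512)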
@@ -480,7 +489,7 @@ class TransformerDecoderLayer(Cell):
         self.activation1 = activation

         if not isinstance(activation, str) and not isinstance(activation, Cell) \
-            and not callable(activation):
+                and not callable(activation):
             raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
                              f" but get {activation}.")
         if isinstance(activation, Cell) and (not isinstance(activation, ReLU) and \
@@ -507,17 +516,29 @@ class TransformerDecoderLayer(Cell):
     def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                   memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                   memory_key_padding_mask: Optional[Tensor] = None):
-        x = tgt
+        input_data = tgt
+
         if self.norm_first:
-            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
-            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)
-            x = x + self._ff_block(self.norm3(x))
+            normed_input = self.norm1(input_data)
+            sa_block_result = self._sa_block(normed_input, tgt_mask, tgt_key_padding_mask)
+            input_data = input_data + sa_block_result
+            normed_updated_input_1 = self.norm2(input_data)
+            mha_block_result = self._mha_block(normed_updated_input_1, memory, memory_mask, memory_key_padding_mask)
+            input_data = input_data + mha_block_result
+            normed_updated_input_2 = self.norm3(input_data)
+            ff_block_result = self._ff_block(normed_updated_input_2)
+            input_data = input_data + ff_block_result
         else:
-            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
-            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
-            x = self.norm3(x + self._ff_block(x))
+            sa_block_result = self._sa_block(input_data, tgt_mask, tgt_key_padding_mask)
+            normed_sa_result = self.norm1(input_data + sa_block_result)
+            input_data = normed_sa_result
+            mha_block_result = self._mha_block(input_data, memory, memory_mask, memory_key_padding_mask)
+            normed_mha_result = self.norm2(input_data + mha_block_result)
+            input_data = normed_mha_result
+            ff_block_result = self._ff_block(input_data)
+            input_data = self.norm3(input_data + ff_block_result)

-        return x
+        return input_data

     def _sa_block(self, x, attn_mask, key_padding_mask):
         x = self.self_attn(x, x, x,
@@ -670,17 +691,19 @@ class TransformerDecoder(Cell):
     def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                   memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                   memory_key_padding_mask: Optional[Tensor] = None):
-        output = tgt
+        processed_output = tgt
         for mod in self.layers:
-            output = mod(output, memory, tgt_mask=tgt_mask,
-                         memory_mask=memory_mask,
-                         tgt_key_padding_mask=tgt_key_padding_mask,
-                         memory_key_padding_mask=memory_key_padding_mask)
+            layer_output = mod(processed_output, memory,
+                               tgt_mask=tgt_mask,
+                               memory_mask=memory_mask,
+                               tgt_key_padding_mask=tgt_key_padding_mask,
+                               memory_key_padding_mask=memory_key_padding_mask)
+            processed_output = layer_output

         if self.norm is not None:
-            output = self.norm(output)
+            processed_output = self.norm(processed_output)

-        return output
+        return processed_output


 class Transformer(Cell):
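The loop above simply feeds each layer's output into the next layer, then applies the optional final norm. A minimal stacking sketch (sizes assumed for illustration):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
memory = Tensor(np.random.rand(10, 32, 512), ms.float32)  # encoder output
tgt = Tensor(np.random.rand(20, 32, 512), ms.float32)     # target sequence
out = decoder(tgt, memory)  # each layer consumes the previous layer's output
print(out.shape)  # (20, 32, 512)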
mindspore/nn/learning_rate_schedule.py CHANGED
@@ -80,7 +80,8 @@ class ExponentialDecayLR(LearningRateSchedule):
         learning_rate (float): The initial value of learning rate.
         decay_rate (float): The decay rate.
         decay_steps (int): Number of steps to decay over.
-        is_stair (bool): If true, learning rate is decayed once every `decay_steps` time. Default: ``False`` .
+        is_stair (bool, optional): If ``True``, learning rate is decayed once every `decay_steps` time.
+            Default: ``False`` .

     Inputs:
         - **global_step** (Tensor) - The current step number. :math:`current\_step` in the above formula.
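A quick numeric check of the documented behavior (values illustrative): the schedule computes learning_rate * decay_rate ** (current_step / decay_steps), and `is_stair` floors the exponent so the rate only drops every `decay_steps` steps:

import mindspore as ms
from mindspore import nn, Tensor

lr = nn.ExponentialDecayLR(learning_rate=0.1, decay_rate=0.9, decay_steps=4)
print(lr(Tensor(2, ms.int32)))  # 0.1 * 0.9 ** (2 / 4) ≈ 0.0948683

stair_lr = nn.ExponentialDecayLR(0.1, 0.9, 4, is_stair=True)
print(stair_lr(Tensor(6, ms.int32)))  # exponent floored: 0.1 * 0.9 ** (6 // 4) = 0.09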
@@ -223,7 +224,9 @@ class InverseDecayLR(LearningRateSchedule):
         learning_rate (float): The initial value of learning rate.
         decay_rate (float): The decay rate.
         decay_steps (int): Number of steps to decay over.
-        is_stair (bool): If true, learning rate decay once every `decay_steps` times. Default: ``False`` .
+        is_stair (bool, optional): If true, learning rate decay once every `decay_steps` times.
+            If False, the learning rate
+            decays for every step. Default: ``False`` .

     Inputs:
         - **global_step** (Tensor) - The current step number.
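Same idea for the inverse-decay schedule, which computes learning_rate / (1 + decay_rate * current_step / decay_steps); with `is_stair=False` it decays every step (values illustrative):

import mindspore as ms
from mindspore import nn, Tensor

lr = nn.InverseDecayLR(learning_rate=0.1, decay_rate=0.9, decay_steps=4, is_stair=False)
print(lr(Tensor(2, ms.int32)))  # 0.1 / (1 + 0.9 * 2 / 4) ≈ 0.0689655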
@@ -454,8 +457,9 @@ class WarmUpLR(LearningRateSchedule):
        tmp\_step= \min(current\_step, warmup\_steps)

     Args:
-        learning_rate (float): The initial value of learning rate.
-        warmup_steps (int): The warm up steps of learning rate.
+        learning_rate (float): The initial value of learning rate. The value of `learning_rate` must be greater than 0.
+        warmup_steps (int): The warm up steps of learning rate. The value of `warmup_steps` must be greater than
+            or equal to 1.

     Inputs:
         - **global_step** (Tensor) - The current step number. Shape is :math:`()`.
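The warm-up schedule ramps linearly and then holds: warmup_lr = learning_rate * min(current_step, warmup_steps) / warmup_steps. A sketch with assumed values:

import mindspore as ms
from mindspore import nn, Tensor

lr = nn.WarmUpLR(learning_rate=0.1, warmup_steps=4)
print(lr(Tensor(2, ms.int32)))  # 0.1 * 2 / 4 = 0.05
print(lr(Tensor(6, ms.int32)))  # capped at 0.1 once current_step >= warmup_steps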
mindspore/nn/loss/loss.py CHANGED
@@ -24,8 +24,6 @@ from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _inner_ops as inner
-from mindspore.ops.operations.nn_ops import MultiMarginLoss as MultiMarginLossOp
-from mindspore.ops.operations.nn_ops import MultilabelMarginLoss as MultilabelMarginLossOp
 from mindspore.ops import functional as F
 from mindspore import nn
 from mindspore.ops.primitive import constexpr, _primexpr
@@ -33,7 +31,6 @@ from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
 from mindspore import _checkparam as validator
 from mindspore import context
-from mindspore.ops.auto_generate import l1_loss_ext_op


 class LossBase(Cell):
@@ -130,7 +127,8 @@ class LossBase(Cell):
     Args:
         x (Tensor): Tensor of shape :math:`(N, *)` where :math:`*` means, any number of
             additional dimensions.
-        weights (Union[float, Tensor]): Optional `Tensor` whose rank is either 0, or the same rank as inputs,
+        weights (Union[float, Tensor], optional): Weights. When `weights` is a Tensor,
+            the rank is either 0, or the same rank as inputs,
             and must be broadcastable to inputs (i.e., all dimensions must be either `1`,
             or the same as the corresponding inputs dimension). Default: ``1.0`` .

@@ -319,7 +317,7 @@ class L1LossExt(LossBase):
         self.reduction = reduction

     def construct(self, logits, labels):
-        return l1_loss_ext_op(logits, labels, self.reduction)
+        return ops.auto_generate.l1_loss_ext_op(logits, labels, self.reduction)


 class MSELoss(LossBase):
@@ -620,7 +618,8 @@ class MarginRankingLoss(LossBase):

 class SmoothL1Loss(LossBase):
     r"""
-    SmoothL1 loss function, if the absolute error element-wise between the predicted value and the target value
+    SmoothL1 loss function. Compare the error value element-wise and
+    if the absolute error between the predicted value and the target value
     is less than the set threshold `beta`, the square term is used, otherwise the absolute error term is used.

     Given two input :math:`x,\ y`, the SmoothL1Loss can be described as follows:
@@ -628,11 +627,11 @@ class SmoothL1Loss(LossBase):
     .. math::
         L_{i} =
         \begin{cases}
-        \frac{0.5 (x_i - y_i)^{2}}{\beta}, & \text{if } |x_i - y_i| < {\beta} \\
-        |x_i - y_i| - 0.5 {\beta}, & \text{otherwise.}
+        \frac{0.5 (x_i - y_i)^{2}}{\text{beta}}, & \text{if } |x_i - y_i| < \text{beta} \\
+        |x_i - y_i| - 0.5 * {\text{beta}}, & \text{otherwise.}
         \end{cases}

-    Where :math:`{\beta}` represents the threshold `beta`.
+    Where :math:`{\text{beta}}` represents the threshold `beta`.

     If `reduction` is not `none`, then:

@@ -653,8 +652,11 @@ class SmoothL1Loss(LossBase):
     robust to outliers, and the loss function has better robustness.

     Args:
-        beta (float): The loss function calculates the threshold of the transformation between L1Loss and L2Loss.
-            Default: ``1.0`` .
+        beta (number, optional): The loss function calculates the threshold of the transformation
+            between L1Loss and L2Loss. Default: ``1.0`` .
+
+            - Ascend: The value should be equal to or greater than zero.
+            - CPU/GPU: The value should be greater than zero.
         reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
             ``'sum'`` . Default: ``'none'`` .

@@ -663,22 +665,28 @@ class SmoothL1Loss(LossBase):
         - ``'sum'``: the output elements will be summed.

     Inputs:
-        - **logits** (Tensor) - Predictive value. Tensor of any dimension. Data type must be one of float16 or
-          float32.
-        - **labels** (Tensor) - Ground truth data, same shape and dtype as the `logits`.
+        - **logits** (Tensor) - Predictive value. Tensor of any dimension. Supported dtypes:
+
+          - Ascend: float16, float32, bfloat16.
+          - CPU/GPU: float16, float32, float64.
+
+        - **labels** (Tensor) - Ground truth data.
+
+          - CPU/Ascend: has the same shape as the `logits`,
+            `logits` and `labels` comply with the implicit type conversion rules to make the data types consistent.
+          - GPU: has the same shape and dtype as the `logits`.

     Outputs:
         Tensor, if `reduction` is ``'none'``, then output is a tensor with the same shape as `logits`.
         Otherwise the shape of output tensor is :math:`()`.

     Raises:
-        TypeError: If `beta` is not a float.
-        ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
-        TypeError: If `logits` or `labels` are not Tensor.
-        TypeError: If dtype of `logits` or `labels` is neither float16 not float32.
-        TypeError: If dtype of `logits` is not the same as `labels`.
-        ValueError: If `beta` is less than or equal to 0.
+        TypeError: If input `logits` or `labels` are not Tensor.
+        RuntimeError: If dtype of `logits` or `labels` is not one of float16, float32, float64, bfloat16.
         ValueError: If shape of `logits` is not the same as `labels`.
+        ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
+        TypeError: If `beta` is not a float, int or bool.
+        RuntimeError: If `beta` is less than or equal to 0.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
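A worked check of the piecewise definition with the defaults beta=1.0 and reduction='none' (inputs illustrative):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

loss = nn.SmoothL1Loss()
logits = Tensor(np.array([1, 2, 3]), ms.float32)
labels = Tensor(np.array([1, 2, 2]), ms.float32)
# |x - y| = [0, 0, 1]: the first two fall in the quadratic branch
# (0.5 * d**2 / beta = 0), the last in the linear branch (1 - 0.5 * beta = 0.5).
print(loss(logits, labels))  # [0.  0.  0.5]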
@@ -728,16 +736,19 @@ class SoftMarginLoss(LossBase):
         - ``'sum'``: the output elements will be summed.

     Inputs:
-        - **logits** (Tensor) - Predict data. Data type must be float16 or float32.
-        - **labels** (Tensor) - Ground truth data, with the same type and shape as `logits`.
+        - **logits** (Tensor) - Predict data. Data type must be float16, float32,
+          bfloat16 (Among them, the Atlas training series products do not support bfloat16).
+        - **labels** (Tensor) - Ground truth data, with the same shape as `logits`.
+          In GE mode, the data type should be the same as `logits`.

     Outputs:
-        Tensor or Scalar, if `reduction` is ``"none"``, its shape is the same as `logits`.
+        Tensor or Scalar, if `reduction` is ``'none'``, its shape is the same as `logits`.
         Otherwise, a scalar value will be returned.

     Raises:
         TypeError: If `logits` or `labels` is not a Tensor.
-        TypeError: If dtype of `logits` or `labels` is neither float16 nor float32.
+        TypeError: If dtype of `logits` or `labels` is not float16, float32,
+          bfloat16 (Among them, the Atlas training series products do not support bfloat16).
         ValueError: If shape of `logits` is not the same as `labels`.
         ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.

@@ -758,10 +769,10 @@ class SoftMarginLoss(LossBase):

     def __init__(self, reduction='mean'):
         super(SoftMarginLoss, self).__init__()
-        self.soft_margin_loss = P.SoftMarginLoss(reduction)
+        self.reduction = reduction

     def construct(self, logits, labels):
-        return self.soft_margin_loss(logits, labels)
+        return F.soft_margin_loss(logits, labels, self.reduction)


 class SoftmaxCrossEntropyWithLogits(LossBase):
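The refactor above swaps the raw P.SoftMarginLoss primitive for the functional F.soft_margin_loss with an explicit reduction argument; behavior is unchanged. A numeric sketch of loss = mean(log(1 + exp(-labels * logits))) (inputs illustrative):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

loss = nn.SoftMarginLoss()  # reduction='mean'
logits = Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]), ms.float32)
labels = Tensor(np.array([[-1, 1], [1, -1]]), ms.float32)
print(loss(logits, labels))  # ≈ 0.6764238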
@@ -809,8 +820,8 @@ class SoftmaxCrossEntropyWithLogits(LossBase):

     Raises:
         TypeError: If `sparse` is not a bool.
-        TypeError: If `sparse` is True and dtype of `labels` is neither int32 nor int64.
-        TypeError: If `sparse` is False and dtype of `labels` is neither float16 not float32.
+        TypeError: If `sparse` is ``True`` and dtype of `labels` is neither int32 nor int64.
+        TypeError: If `sparse` is ``False`` and dtype of `labels` is neither float16 not float32.
         ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.

     Supported Platforms:
@@ -889,8 +900,8 @@ class DiceLoss(LossBase):
     :math:`pred` represent `logits`, :math:`true` represent `labels` .

     Args:
-        smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
-            Default: ``1e-5`` .
+        smooth (float, optional): A term added to the denominator to improve numerical stability.
+            Should be greater than 0. Default: ``1e-5`` .

     Inputs:
         - **logits** (Tensor) - Input predicted value. The data type must be float16 or float32.
@@ -934,11 +945,12 @@ class DiceLoss(LossBase):
         if label.dtype == mstype.uint8:
             raise TypeError(f"For '{self.cls_name}', the dtype of 'labels' can not be uint8.")
         intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
-        unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
-                   self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
+        unionset_part1 = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1)))
+        unionset_part2 = self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
+        unionset = ops.add(unionset_part1, unionset_part2)

-        single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
-        dice_loss = 1 - single_dice_coeff
+        single_dice_coeff = (2 * intersection) / ops.add(unionset, self.smooth)
+        dice_loss = ops.sub(1, single_dice_coeff)

         return dice_loss

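The rewrite keeps the arithmetic identical, dice = 2 * <p, t> / (|p|^2 + |t|^2 + smooth), only splitting it into explicit intermediates for the functional ops. A numeric check (inputs illustrative):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

loss = nn.DiceLoss(smooth=1e-5)
logits = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), ms.float32)
labels = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), ms.float32)
# intersection = 0.5 + 0.3 + 0.6 = 1.4; unionset = 1.56 + 3.0
# loss = 1 - 2 * 1.4 / (4.56 + 1e-5) ≈ 0.38596
print(loss(logits, labels))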
@@ -1054,7 +1066,7 @@ class MultiClassDiceLoss(LossBase):
             dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
             if self.weights is not None:
                 _check_weights(self.weights.shape[0], label.shape[1], self.cls_name)
-                dice_loss *= self.weights[i]
+                dice_loss = dice_loss * self.weights[i]
             total_loss += dice_loss

         return total_loss / label.shape[1]
@@ -1631,7 +1643,7 @@ class MultiMarginLoss(LossBase):
     def __init__(self, p=1, margin=1.0, reduction='mean', weight=None):
         """Initialize MultiMarginLoss."""
         super(MultiMarginLoss, self).__init__()
-        self.multi_margin_loss = MultiMarginLossOp(p=p, margin=margin, reduction=reduction)
+        self.multi_margin_loss = ops.MultiMarginLoss(p=p, margin=margin, reduction=reduction)
         self.weight = weight

     def construct(self, x, target, weight=None):
@@ -1718,22 +1730,11 @@ class BCELoss(LossBase):
     def __init__(self, weight=None, reduction='mean'):
         """Initialize BCELoss."""
         super(BCELoss, self).__init__(reduction)
-        self.binary_cross_entropy = P.BinaryCrossEntropy(reduction=reduction)
-        self.weight_one = weight is None
-        if not self.weight_one:
-            self.weight = weight
-        else:
-            self.ones = P.OnesLike()
+        self.reduction = reduction
+        self.weight = weight
 
     def construct(self, logits, labels):
-        _check_is_tensor('logits', logits, self.cls_name)
-        _check_is_tensor('labels', labels, self.cls_name)
-        if self.weight_one:
-            weight = self.ones(logits)
-        else:
-            weight = self.weight
-        loss = self.binary_cross_entropy(logits, labels, weight)
-        return loss
+        return F.binary_cross_entropy(logits, labels, self.weight, self.reduction)
 
 
 class CosineEmbeddingLoss(LossBase):
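The rewritten BCELoss above delegates to the functional API instead of building a ones-like weight tensor itself. A hedged usage sketch (hypothetical values; `F` in the hunk presumably aliases the functional ops namespace, where `binary_cross_entropy` accepts a `None` weight):

    import numpy as np
    from mindspore import Tensor, nn, ops

    logits = Tensor(np.array([0.1, 0.8, 0.6], dtype=np.float32))
    labels = Tensor(np.array([0.0, 1.0, 1.0], dtype=np.float32))

    # The module form and the functional form should now agree; with
    # weight=None every element is implicitly weighted by 1, which
    # replaces the old P.OnesLike() path.
    module_loss = nn.BCELoss(reduction='mean')(logits, labels)
    functional_loss = ops.binary_cross_entropy(logits, labels, None, 'mean')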
@@ -1887,7 +1888,7 @@ class MultilabelMarginLoss(LossBase):
 
     def __init__(self, reduction='mean'):
         super(MultilabelMarginLoss, self).__init__()
-        self.multilabel_margin_loss = MultilabelMarginLossOp(reduction=reduction)
+        self.multilabel_margin_loss = ops.MultilabelMarginLoss(reduction=reduction)
 
     def construct(self, x, target):
         loss, _ = self.multilabel_margin_loss(x, target)
@@ -2265,7 +2266,8 @@ class TripletMarginLoss(LossBase):
           - ``'mean'``: compute and return the mean of elements in the output.
           - ``'sum'``: the output elements will be summed.
 
-        margin (Union[Tensor, float]): Make a margin between the positive pair and the negative pair.
+        margin (Union[Tensor, float]): Make a margin between the positive pair and the negative pair. The length
+            of the shape of `margin` must be 0.
            Default: ``1.0`` .
 
     Inputs:
@@ -2275,7 +2277,8 @@ class TripletMarginLoss(LossBase):
           shape as `x`. :math:`p` in the above formula.
         - **negative** (Tensor) - A sample belonging to the different class from `x`, with the same type and shape
           as `x`. :math:`n` in the above formula.
-        - **margin** (Union[Tensor, float]) - Make a margin between the positive pair and the negative pair.
+        - **margin** (Union[Tensor, float]) - Make a margin between the positive pair and the negative pair. The
+          length of the shape of `margin` must be 0.
           Default: ``1.0`` .
 
     Outputs:
@@ -2576,7 +2579,7 @@ class KLDivLoss(LossBase):
     the updating formulas of KLDivLoss algorithm are as follows,
 
     .. math::
-        L(x, target) = target \cdot (\log target - x)
+        L(x, target) = target \cdot (\log target - \log x)
 
     Then,
 
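A minimal NumPy sketch of the corrected pointwise term (hypothetical values; whether the operator expects `x` as probabilities or log-probabilities is not restated here, the sketch simply evaluates the formula as written):

    import numpy as np

    x = np.array([0.2, 0.5, 0.3], dtype=np.float32)       # predicted distribution
    target = np.array([0.1, 0.6, 0.3], dtype=np.float32)  # reference distribution

    # Pointwise term l_i = target_i * (log(target_i) - log(x_i)); a
    # 'mean' or 'sum' reduction is then applied over these terms.
    pointwise = target * (np.log(target) - np.log(x))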
@@ -2870,7 +2873,7 @@ class HingeEmbeddingLoss(LossBase):
     where :math:`L = \{l_1,\dots,l_N\}^\top`.
 
     Args:
-        margin (float, int): Threshold defined by Hinge Embedding Loss :math:`margin`.
+        margin (float, int, optional): Threshold defined by Hinge Embedding Loss :math:`margin`.
            Represented as :math:`\Delta` in the formula. Default: ``1.0`` .
        reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
            ``'sum'`` . Default: ``'mean'`` .
@@ -78,6 +78,7 @@ class Adagrad(Optimizer):
     :math:`state\_sum` stands for the accumulated squared sum of the gradients :math:`accum`.
     :math:`g` stands for `grads`, :math:`\lambda` stands for `weight_decay`.
     :math:`\gamma` stands for `learning_rate`, :math:`w` stands for `params`.
+    :math:`t` represents the current `step`.
 
     Note:
         If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without
@@ -112,8 +113,8 @@ class Adagrad(Optimizer):
            If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
            one group of `params`.
 
-        accum (float): The starting value for :math:`h`, must be zero or positive values. Default: ``0.1`` .
-        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.001`` .
+        accum (float, optional): The starting value for :math:`h`, must be zero or a positive value. Default: ``0.1`` .
+        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``0.001`` .
 
            - float: The fixed learning rate value. Must be equal to or greater than 0.
@@ -129,13 +130,14 @@ class Adagrad(Optimizer):
            <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
            with step as the input to get the learning rate of current step.
 
-        update_slots (bool): Whether the :math:`h` will be updated. Default: ``True`` .
-        loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value.
+        update_slots (bool, optional): Whether the :math:`h` will be updated. Default: ``True`` .
+        loss_scale (float, optional): Value for the loss scale. It must be greater than 0.0. In general,
+            use the default value.
            Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
            Default: ``1.0`` .
-        weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+        weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .
 
            - float: The fixed weight decay value. Must be equal to or greater than 0.
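Putting the symbols above together, a hedged NumPy sketch of a single Adagrad step (assumes `update_slots=True` and ignores weight decay and loss scaling; all values are hypothetical):

    import numpy as np

    w = np.array([0.5, -0.3], dtype=np.float32)   # params (w)
    h = np.full_like(w, 0.1)                      # accum (h), starting value 0.1
    g = np.array([0.2, -0.1], dtype=np.float32)   # grads (g)
    lr = 0.001                                    # learning_rate (gamma)

    h = h + g * g                # accumulate the squared gradients
    w = w - lr * g / np.sqrt(h)  # apply the Adagrad parameter update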
@@ -68,8 +68,8 @@ class Adadelta(Optimizer):
 
     Args:
         params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
-            `params` is a list of `dict`, the string "params", "lr", "weight_decay", "grad_centralization" and
-            "order_params" are the keys can be parsed.
+            `params` is a list of `dict`, the strings `"params"`, `"lr"`, `"weight_decay"`, `"grad_centralization"` and
+            `"order_params"` are the keys that can be parsed.
 
            - params: Required. Parameters in current group. The value must be a list of `Parameter`.
 
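A hedged sketch of the grouped-parameter convention these keys describe (hypothetical single-layer network; key names taken from the docstring above):

    import mindspore.nn as nn

    net = nn.Dense(4, 2)  # hypothetical network
    weights = [p for p in net.trainable_params() if 'weight' in p.name]
    biases = [p for p in net.trainable_params() if 'weight' not in p.name]
    group_params = [
        {'params': weights, 'lr': 0.5, 'weight_decay': 0.01},  # per-group overrides
        {'params': biases},                                    # falls back to defaults
        {'order_params': net.trainable_params()},              # fixes the update order
    ]
    optimizer = nn.Adadelta(group_params, learning_rate=1.0)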
@@ -93,7 +93,7 @@ class Adadelta(Optimizer):
            If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
            one group of `params`.
 
-        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``1.0`` .
+        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``1.0`` .
 
            - float: The fixed learning rate value. Must be equal to or greater than 0.
 
@@ -109,14 +109,16 @@ class Adadelta(Optimizer):
            <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
            with step as the input to get the learning rate of current step.
 
-        rho (float): Decay rate, must be in range [0.0, 1.0]. Default: ``0.9`` .
-        epsilon (float): A small value added for numerical stability, must be non-negative. Default: ``1e-6`` .
-        loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value.
+        rho (float, optional): Decay rate, must be in range [0.0, 1.0]. Default: ``0.9`` .
+        epsilon (float, optional): A small value added for numerical stability, must be non-negative.
+            Default: ``1e-6`` .
+        loss_scale (float, optional): Value for the loss scale. It must be greater than 0.0. In general,
+            use the default value.
            Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
            Default: ``1.0`` .
-        weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+        weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .
 
            - float: The fixed weight decay value. Must be equal to or greater than 0.
 
@@ -134,9 +136,9 @@ class Adadelta(Optimizer):
 
     Raises:
         TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
-        TypeError: If element of `parameters` is neither Parameter nor dict.
+        TypeError: If an element of `params` is neither a Parameter nor a dict.
         TypeError: If `rho`, `epsilon` or `loss_scale` is not a float.
-        TypeError: If `weight_decay` is neither float nor int.
+        TypeError: If `weight_decay` is not a float, an int or a Cell.
        ValueError: If `rho` is not in range [0.0, 1.0].
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `learning_rate`, `epsilon` or `weight_decay` is less than 0.
@@ -406,7 +406,7 @@ class AdaFactor(Optimizer):
         """
         return False
 
-    @jit
+    @jit(backend="ms_backend")
    def construct(self, gradients):
        gradients = self.flatten_gradients(gradients)
        lr = self.get_lr()
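For context, a hedged sketch of the decorator form now used by `construct` (it assumes `mindspore.jit` accepts the `backend` keyword in this release, as the hunk above indicates; the function and values are hypothetical):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, jit

    @jit(backend="ms_backend")  # compile the function with the MindSpore backend
    def scaled_sum(x, y):
        return (x + y) * 0.5

    out = scaled_sum(Tensor(np.ones(3), ms.float32), Tensor(np.ones(3), ms.float32))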