mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -566,7 +566,7 @@ class Adam(Optimizer):
   If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
   one group of `params`.
 
- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``1e-3`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``1e-3`` .
 
   - float: The fixed learning rate value. Must be equal to or greater than 0.
 
@@ -582,23 +582,26 @@ class Adam(Optimizer):
   <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
   with step as the input to get the learning rate of current step.
 
- beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
+ beta1 (float, optional): The exponential decay rate for the 1st moment estimations.
+ Should be in range (0.0, 1.0).
   Default: ``0.9`` .
- beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
+ beta2 (float, optional): The exponential decay rate for the 2nd moment estimations.
+ Should be in range (0.0, 1.0).
   Default: ``0.999`` .
- eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+ eps (float, optional): Term added to the denominator to improve numerical stability. Should be greater than 0.
   Default: ``1e-8`` .
- use_locking (bool): Whether to enable a lock to protect the updating process of variable tensors.
+ use_locking (bool, optional): Whether to enable a lock to protect the updating process of variable tensors.
   If ``true`` , updates of the `w`, `m`, and `v` tensors will be protected by a lock.
   If ``false`` , the result is unpredictable. Default: ``False`` .
- use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
+ use_nesterov (bool, optional): Whether to use Nesterov Accelerated Gradient (NAG) algorithm
+ to update the gradients.
   If ``true`` , update the gradients using NAG.
   If ``false`` , update the gradients without using NAG. Default: ``False`` .
- use_amsgrad (bool): Whether to use Amsgrad algorithm to update the gradients.
+ use_amsgrad (bool, optional): Whether to use Amsgrad algorithm to update the gradients.
   If ``true`` , update the gradients using Amsgrad.
   If ``false`` , update the gradients without using Amsgrad. Default: ``False`` .
 
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .
 
   - float: The fixed weight decay value. Must be equal to or greater than 0.
 
@@ -607,11 +610,12 @@ class Adam(Optimizer):
   - Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
   the Cell with step as the input to get the weight decay value of current step.
 
- loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
+ loss_scale (float, optional): A floating point value for the loss scale.
+ Should be greater than 0. In general, use the
   default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
   `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
   `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
- Default: 1.0.
+ Default: ``1.0``.
 
   kwargs:
 
@@ -633,7 +637,7 @@ class Adam(Optimizer):
   Raises:
   KeyError: If kwargs got keys other than 'use_lazy' or 'use_offload'.
   TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
- TypeError: If element of `parameters` is neither Parameter nor dict.
+ TypeError: If element of `params` is neither Parameter nor dict.
   TypeError: If `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
   TypeError: If `weight_decay` is neither float nor int.
   TypeError: If `use_locking`, `use_nesterov`, `use_amsgrad`, `use_lazy` or `use_offload` is not a bool.
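
The hunks above only reword the Adam docstring (keyword arguments are now marked optional and lines are rewrapped); the constructor itself is unchanged. For reference, a minimal construction sketch with the documented defaults written out explicitly — the `nn.Dense` network is a hypothetical stand-in and every keyword could equally be omitted:

import mindspore.nn as nn

net = nn.Dense(4, 2)  # toy network, only used to supply trainable parameters
opt = nn.Adam(net.trainable_params(),
              learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
              use_locking=False, use_nesterov=False, use_amsgrad=False,
              weight_decay=0.0, loss_scale=1.0)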
@@ -1024,7 +1028,7 @@ class AdamWeightDecay(Optimizer):
   self.fused_opt = P.AdamWeightDecay()
   self.use_fused_opt = True
 
- @jit
+ @jit(backend="ms_backend")
   def construct(self, gradients):
   gradients = self.flatten_gradients(gradients)
   weight_decay = self.get_weight_decay()
@@ -1244,7 +1248,7 @@ class AdamOffload(Optimizer):
   self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
   self.opt.set_device("CPU")
 
- @jit
+ @jit(backend="ms_backend")
   def construct(self, gradients):
   params = self._parameters
   moment1 = self.moment1
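
The only functional change in the two hunks above (repeated later for ASGD, Lamb, Momentum and ProximalAdagrad) is that the `construct` methods now pass `backend="ms_backend"` to the `@jit` decorator instead of using it bare. A minimal sketch of that decorator form on a free function, assuming `mindspore.jit` in 2.6 accepts the `backend` keyword exactly as used in these hunks:

import numpy as np
import mindspore
from mindspore import Tensor, ops

# 2.4.x used a bare @mindspore.jit; 2.6 selects the compile backend explicitly.
@mindspore.jit(backend="ms_backend")
def scaled_add(x, y):
    return ops.add(x, y) * 0.5

out = scaled_add(Tensor(np.ones((2, 2), np.float32)),
                 Tensor(np.ones((2, 2), np.float32)))
print(out.shape)  # (2, 2)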
@@ -118,12 +118,12 @@ class AdaMax(Optimizer):
   <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
   with step as the input to get the learning rate of current step.
 
- beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
- Default: ``0.9`` .
- beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
- Default: ``0.999`` .
- eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
- Default: ``1e-08`` .
+ beta1 (float, optional): The exponential decay rate for the 1st moment estimations.
+ Should be in range (0.0, 1.0). Default: ``0.9`` .
+ beta2 (float, optional): The exponential decay rate for the 2nd moment estimations.
+ Should be in range (0.0, 1.0). Default: ``0.999`` .
+ eps (float, optional): Term added to the denominator to improve numerical stability. Should be greater than 0.
+ Default: ``1e-08`` .
 
   weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
 
@@ -134,7 +134,8 @@ class AdaMax(Optimizer):
   - Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
   the Cell with step as the input to get the weight decay value of current step.
 
- loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
+ loss_scale (float, optional): A floating point value for the loss scale. Should be greater than 0.
+ In general, use the
   default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
   `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
   `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
@@ -420,17 +420,17 @@ class AdaSumByGradWrapCell(Cell):
   and the subscripts represent different devices in the data-parallel dimension.
 
   Note:
- When using AdaSum, the number of traning cards needs to be a power of 2 and at least 16 cards are required.
- Currently, the optimizer sharding and pipeline parallel is not supported when using AdaSum.
- It is recommended to using AdaSumByGradWrapCell in semi auto parallel/auto parallel mode. In data parallel
- mode, we recommend to using mindspore.boost to applying AdaSum.
+ - It is recommended to using AdaSumByGradWrapCell in semi auto parallel/auto parallel mode. In data parallel
+ mode, we recommend to using mindspore.boost to applying AdaSum.
+ - When using AdaSum, the number of traning cards needs to be a power of 2 and at least 16 cards are required.
+ Currently, the optimizer sharding and pipeline parallel is not supported when using AdaSum.
 
   Args:
   optimizer (Union[Cell]): Optimizer for updating the weights. The construct function of the optimizer
   requires only one input.
 
   Inputs:
- - **grads** (Tuple(Tensor)) - Tuple of gradients, same with the input of passed optimizer.
+ - **grads** (Tuple[Tensor]) - Tuple of gradients, same with the input of passed optimizer.
 
   Raises:
   RuntimeError: If `parallel_mode` uses `stand_alone` mode, AdaSum only supports use in distributed scenarios.
@@ -180,7 +180,7 @@ class ASGD(Optimizer):
   self.cast = P.Cast()
   self.squeeze = P.Squeeze()
 
- @jit
+ @jit(backend="ms_backend")
   def construct(self, gradients):
   gradients = self.flatten_gradients(gradients)
   gradients = self.decay_weight(gradients)
@@ -228,21 +228,23 @@ class FTRL(Optimizer):
   If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
   one group of `params`.
 
- initial_accum (float): The starting value for accumulators `m`, must be zero or positive values.
+ initial_accum (float, optional): The starting value for accumulators `m`, must be zero or positive values.
   Default: ``0.1`` .
- learning_rate (float): The learning rate value, must be zero or positive, dynamic learning rate is currently
- not supported. Default: ``0.001`` .
- lr_power (float): Learning rate power controls how the learning rate decreases during training, must be less
+ learning_rate (float, optional): The learning rate value, must be zero or positive, dynamic learning rate
+ is currently not supported. Default: ``0.001`` .
+ lr_power (float, optional): Learning rate power controls how the learning rate decreases during training,
+ must be less
   than or equal to zero. Use fixed learning rate if lr_power is zero. Default: ``-0.5`` .
- l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
- l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
- use_locking (bool): If true, use locks for updating operation. Default: ``False`` .
- loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value.
+ l1 (float, optional): l1 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
+ l2 (float, optional): l2 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
+ use_locking (bool, optional): If true, use locks for updating operation. Default: ``False`` .
+ loss_scale (float, optional): Value for the loss scale. It must be greater than 0.0. In general,
+ use the default value.
   Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
   `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
   `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
   Default: ``1.0`` .
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .
 
   - float: The fixed weight decay value. Must be equal to or greater than 0.
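
As with Adam, the FTRL hunk above only annotates keyword arguments as optional. A construction-only sketch with the documented defaults spelled out; the toy network is a hypothetical stand-in:

from mindspore import nn

net = nn.Dense(4, 2)  # toy network, only used to supply trainable parameters
opt = nn.FTRL(net.trainable_params(),
              initial_accum=0.1, learning_rate=0.001, lr_power=-0.5,
              l1=0.0, l2=0.0, use_locking=False,
              loss_scale=1.0, weight_decay=0.0)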

@@ -265,7 +265,7 @@ class Lamb(Optimizer):
   self.moments2 = self.params.clone(prefix="lamb_v", init='zeros')
   self.device_ascend = context.get_context("device_target") == "Ascend"
 
- @jit
+ @jit(backend="ms_backend")
   def construct(self, gradients):
   weight_decay = self.get_weight_decay()
   lr = self.get_lr()
@@ -82,7 +82,7 @@ class LARS(Optimizer):
   &\hline \\[-1.ex]
   \end{array}
 
- :math:`w` represents the network parameters, :math:`g` represents `gradients`,
+ :math:`w` represents the network's params, :math:`g` represents `gradients`,
   :math:`t` represents the current step, :math:`\lambda` represents `weight_decay` in `optimizer`,
   :math:`\gamma` represents `learning_rate` in `optimizer`, :math:`\eta` represents `coefficient`.
 
@@ -98,9 +98,6 @@ class LARS(Optimizer):
   - **gradients** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is the
   as same as the `params` in the optimizer.
 
- Outputs:
- Union[Tensor[bool], tuple[Parameter]], it depends on the output of `optimizer`.
-
   Supported Platforms:
   ``Ascend``
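
The LARS hunks above drop the documented Outputs block and reword the math description; the wrapper usage itself does not change. A construction-only sketch, hedged: LARS is documented for Ascend and wraps an existing optimizer whose `construct` takes only gradients, and the toy network is a hypothetical stand-in:

from mindspore import nn

net = nn.Dense(4, 2)  # toy network, only used to supply trainable parameters
base = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
opt = nn.LARS(base)  # layer-wise adaptive scaling applied on top of the inner optimizer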

@@ -321,7 +321,7 @@ class LazyAdam(Optimizer):
   If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
   one group of `params`.
 
- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``1e-3`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``1e-3`` .
 
   - float: The fixed learning rate value. Must be equal to or greater than 0.
 
@@ -337,20 +337,21 @@ class LazyAdam(Optimizer):
   <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
   with step as the input to get the learning rate of current step.
 
- beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
- Default: ``0.9`` .
- beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
- Default: ``0.999`` .
- eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+ beta1 (float, optional): The exponential decay rate for the 1st moment estimations.
+ Should be in range (0.0, 1.0). Default: ``0.9`` .
+ beta2 (float, optional): The exponential decay rate for the 2nd moment estimations.
+ Should be in range (0.0, 1.0). Default: ``0.999`` .
+ eps (float, optional): Term added to the denominator to improve numerical stability. Should be greater than 0.
   Default: ``1e-8`` .
- use_locking (bool): Whether to enable a lock to protect the updating process of variable tensors.
+ use_locking (bool, optional): Whether to enable a lock to protect the updating process of variable tensors.
   If ``true`` , updates of the `w`, `m`, and `v` tensors will be protected by a lock.
   If ``false`` , the result is unpredictable. Default: ``False`` .
- use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
+ use_nesterov (bool, optional): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to
+ update the gradients.
   If ``true`` , update the gradients using NAG.
   If ``false`` , update the gradients without using NAG. Default: ``False`` .
 
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .
 
   - float: The fixed weight decay value. Must be equal to or greater than 0.
 
@@ -359,7 +360,8 @@ class LazyAdam(Optimizer):
   - Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
   the Cell with step as the input to get the weight decay value of current step.
 
- loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. In general,
+ loss_scale (float, optional): A floating point value for the loss scale. Should be equal to or greater than 1.
+ In general,
   use the default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update`
   in `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
   `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
@@ -103,7 +103,7 @@ class Momentum(Optimizer):
   If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
   one group of `params`.
 
- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]):
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional):
 
   - float: The fixed learning rate value. Must be equal to or greater than 0.
 
@@ -119,10 +119,10 @@ class Momentum(Optimizer):
   <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
   with step as the input to get the learning rate of current step.
 
- momentum (float): Hyperparameter of type float, means momentum for the moving average.
+ momentum (float, optional): Hyperparameter of type float, means momentum for the moving average.
   It must be at least 0.0.
 
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .
 
   - float: The fixed weight decay value. Must be equal to or greater than 0.
 
@@ -131,12 +131,13 @@ class Momentum(Optimizer):
   - Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
   the Cell with step as the input to get the weight decay value of current step.
 
- loss_scale (float): A floating point value for the loss scale. It must be greater than 0.0. In general, use the
+ loss_scale (float, optional): A floating point value for the loss scale. It must be greater than 0.0.
+ In general, use the
   default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
   `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
   `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
   Default: ``1.0`` .
- use_nesterov (bool): Enable Nesterov momentum. Default: ``False`` .
+ use_nesterov (bool, optional): Enable Nesterov momentum. Default: ``False`` .
 
   Inputs:
   - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -199,7 +200,7 @@ class Momentum(Optimizer):
   self._get_distributed_optimizer_list("momentum", use_nesterov=self.use_nesterov)
   self.use_dist_optimizer = self._use_distibuted_optimizer()
 
- @jit
+ @jit(backend="ms_backend")
   def construct(self, gradients):
   params = self.params
   moments = self.moments
@@ -122,10 +122,10 @@ class Optimizer(Cell):
   If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
   one group of `params`.
 
- weight_decay (Union[float, int]): An int or a floating point value for the weight decay.
+ weight_decay (Union[float, int], optional): An int or a floating point value for the weight decay.
   It must be equal to or greater than 0.
   If the type of `weight_decay` input is int, it will be converted to float. Default: ``0.0`` .
- loss_scale (float): A floating point value for the loss scale. It must be greater than 0. If the
+ loss_scale (float, optional): A floating point value for the loss scale. It must be greater than 0. If the
   type of `loss_scale` input is int, it will be converted to float. In general, use the default value. Only
   when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
   `FixedLossScaleManager` is set to ``False`` , this value needs to be the same as the `loss_scale` in
@@ -848,7 +848,7 @@ class Optimizer(Cell):
   optim_result(bool): The results of updating parameters. This input is used to ensure that the parameters are
   updated before they are broadcast.
   Returns:
- bool, the status flag.
+ The broadcast parameters.
   """
   # If rank_id is 0, 1, 2, 3, there are param0 ~ param7,
   # then the value is[(param0, param4), (param1, param5), (param2, param6), (param3, param7)]
@@ -83,8 +83,8 @@ class ProximalAdagrad(Optimizer):

  Args:
  params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
- `params` is a list of `dict`, the string "params", "lr", "weight_decay", "grad_centralization" and
- "order_params" are the keys can be parsed.
+ `params` is a list of `dict`, the string `"params"`, `"lr"`, `"weight_decay"`, `"grad_centralization"` and
+ `"order_params"` are the keys can be parsed.

  - params: Required. Parameters in current group. The value must be a list of `Parameter`.

@@ -108,8 +108,9 @@ class ProximalAdagrad(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- accum (float): The starting value for accumulators `accum`, must be zero or positive values. Default: ``0.1`` .
- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.001`` .
+ accum (float, optional): The starting value for accumulators `accum`, must be zero or positive values.
+ Default: ``0.1`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``0.001`` .

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -125,15 +126,16 @@ class ProximalAdagrad(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of the current step.

- l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
- l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
- use_locking (bool): If true, use locks for updating operation. Default: ``False`` .
- loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value.
+ l1 (float, optional): l1 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
+ l2 (float, optional): l2 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
+ use_locking (bool, optional): If ``True``, use locks for updating operation. Default: ``False`` .
+ loss_scale (float, optional): Value for the loss scale. It must be greater than 0.0. In general,
+ use the default value.
  Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
  Default: ``1.0`` .
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .

  - float: The fixed weight decay value. Must be equal to or greater than 0.

@@ -199,7 +201,7 @@ class ProximalAdagrad(Optimizer):
  self.opt = P.ApplyProximalAdagrad(use_locking=use_locking)
  self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking)

- @jit
+ @jit(backend="ms_backend")
  def construct(self, grads):
  params = self._parameters
  accum = self.accum
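Apart from the documentation wording, the functional change above is that `construct` is now compiled with `@jit(backend="ms_backend")`. For orientation, a hedged construction sketch with a placeholder network and illustrative argument values:

    from mindspore import nn

    net = nn.Dense(16, 10)  # placeholder network, for illustration only
    opt = nn.ProximalAdagrad(net.trainable_params(),
                             accum=0.1,            # starting accumulator value, >= 0
                             learning_rate=0.001,
                             l1=0.0, l2=0.0,       # regularization strengths, >= 0
                             weight_decay=0.0)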
@@ -92,9 +92,9 @@ class RMSProp(Optimizer):
  :math:`t` represents the current step.

  Note:
- If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without
- 'beta' or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When
- parameters are grouped, each group can set `weight_decay`. If not, the `weight_decay` in optimizer will be
+ If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters
+ without 'beta' or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight.
+ When parameters are grouped, each group can set `weight_decay`. If not, the `weight_decay` in optimizer will be
  applied.

  Args:
@@ -124,7 +124,7 @@ class RMSProp(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.1`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``0.1`` .

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -140,21 +140,22 @@ class RMSProp(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of the current step.

- decay (float): Decay rate. Should be equal to or greater than 0. Default: ``0.9`` .
- momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or
- greater than 0. Default: ``0.0`` .
- epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than
+ decay (float, optional): Decay rate. Should be equal to or greater than 0. Default: ``0.9`` .
+ momentum (float, optional): Hyperparameter of type float, means momentum for the moving average.
+ Should be equal to or greater than 0. Default: ``0.0`` .
+ epsilon (float, optional): Term added to the denominator to improve numerical stability. Should be greater than
  0. Default: ``1e-10`` .
- use_locking (bool): Whether to enable a lock to protect the updating process of variable tensors.
+ use_locking (bool, optional): Whether to enable a lock to protect the updating process of variable tensors.
  Default: ``False`` .
- centered (bool): If True, gradients are normalized by the estimated variance of the gradient.
+ centered (bool, optional): If True, gradients are normalized by the estimated variance of the gradient.
  Default: ``False`` .
- loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
+ loss_scale (float, optional): A floating point value for the loss scale. Should be greater than 0. In general,
+ use the
  default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
  Default: ``1.0`` .
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .

  - float: The fixed weight decay value. Must be equal to or greater than 0.
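The Note above describes how grouping parameters changes the weight-decay strategy. A hedged sketch of grouped parameters passed to `nn.RMSProp`, with a placeholder network and illustrative values:

    from mindspore import nn

    net = nn.Dense(16, 10)  # placeholder network, for illustration only
    decayed = [p for p in net.trainable_params()
               if 'beta' not in p.name and 'gamma' not in p.name]
    no_decay = [p for p in net.trainable_params()
                if 'beta' in p.name or 'gamma' in p.name]
    group_params = [{'params': decayed, 'weight_decay': 1e-4},   # decayed group
                    {'params': no_decay, 'weight_decay': 0.0},   # no decay
                    {'order_params': net.trainable_params()}]
    opt = nn.RMSProp(group_params, learning_rate=0.1, decay=0.9,
                     momentum=0.9, epsilon=1e-10, centered=False)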

@@ -44,8 +44,8 @@ class Rprop(Optimizer):
  &\hspace{15mm} w_{t} \leftarrow w_{t-1}- \Delta_{t} \mathrm{sign}(g_t) \\
  \end{gather*}

- :math:`\Delta_{min/max}` represents the min/max step size, :math:`\eta_{+/-}` represents the factors of
- etaminus and etaplus, :math:`g` represents `gradients`, :math:`w` represents `parameters`.
+ :math:`g` represents `gradients`, :math:`w` represents `parameters`, :math:`\Delta_{min/max}` represents the
+ min/max step size, :math:`\eta_{+/-}` represents the factors of etaminus and etaplus.

  Note:
  If parameters are not grouped, the `weight_decay` in optimizer will be applied on the parameters without 'beta'
@@ -58,8 +58,8 @@ class Rprop(Optimizer):

  Args:
  params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
- `parameters` is a list of `dict`, the "params", "lr", "weight_decay", "grad_centralization" and
- "order_params" are the keys can be parsed.
+ `parameters` is a list of `dict`, the `"params"`, `"lr"`, `"weight_decay"`, `"grad_centralization"` and
+ `"order_params"` are the keys can be parsed.

  - params: Required. Parameters in current group. The value must be a list of `Parameter`.

@@ -83,7 +83,8 @@ class Rprop(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Learning_rate. Default: ``0.1`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Learning_rate.
+ Default: ``0.1`` .

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -99,11 +100,12 @@ class Rprop(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of current step.

- etas (tuple[float, float]): The factor of multiplicative increasing or
+ etas (tuple[float, float], optional): The factor of multiplicative increasing or
  descreasing(etaminus, etaplus). Default: ``(0.5, 1.2)`` .
- step_sizes(tuple[float, float]): The allowed minimal and maximal step size(min_step_sizes, max_step_size).
+ step_sizes(tuple[float, float], optional): The allowed minimal and maximal
+ step size(min_step_sizes, max_step_size).
  Default: ``(1e-6, 50.)`` .
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .

  - float: The fixed weight decay value. Must be equal to or greater than 0.

@@ -199,7 +201,7 @@ class Rprop(Optimizer):
  self.select = P.Select()
  self.ones_like = P.OnesLike()

- @jit
+ @jit(backend="ms_backend")
  def construct(self, gradients):
  gradients = self.flatten_gradients(gradients)
  gradients = self.decay_weight(gradients)
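The Rprop changes above reorder the symbol definitions, mark optional arguments, and switch `construct` to `@jit(backend="ms_backend")`. A hedged construction sketch with a placeholder network and illustrative values:

    from mindspore import nn

    net = nn.Dense(16, 10)  # placeholder network, for illustration only
    opt = nn.Rprop(net.trainable_params(),
                   learning_rate=0.1,
                   etas=(0.5, 1.2),         # (etaminus, etaplus) multiplicative factors
                   step_sizes=(1e-6, 50.),  # (min_step_size, max_step_size)
                   weight_decay=0.0)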
mindspore/nn/optim/sgd.py CHANGED
@@ -90,7 +90,7 @@ class SGD(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.1`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``0.1`` .

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -106,12 +106,15 @@ class SGD(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of current step.

- momentum (float): A floating point value the momentum. must be at least 0.0. Default: ``0.0`` .
- dampening (float): A floating point value of dampening for momentum. must be at least 0.0. Default: ``0.0`` .
- weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: ``0.0`` .
- nesterov (bool): Enables the Nesterov momentum. If use nesterov, momentum must be positive,
+ momentum (float, optional): A floating point value the momentum. must be at least 0.0. Default: ``0.0`` .
+ dampening (float, optional): A floating point value of dampening for momentum. must be at least 0.0.
+ Default: ``0.0`` .
+ weight_decay (float, optional): Weight decay (L2 penalty). It must be equal to or greater than 0.
+ Default: ``0.0`` .
+ nesterov (bool, optional): Enables the Nesterov momentum. If use nesterov, momentum must be positive,
  and dampening must be equal to 0.0. Default: ``False`` .
- loss_scale (float): A floating point value for the loss scale, which must be larger than 0.0. In general, use
+ loss_scale (float, optional): A floating point value for the loss scale, which must be larger than 0.0.
+ In general, use
  the default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
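A hedged sketch of the constraint spelled out above, that `nesterov=True` requires a positive `momentum` and `dampening` equal to 0.0 (placeholder network, illustrative values):

    from mindspore import nn

    net = nn.Dense(16, 10)  # placeholder network, for illustration only

    # nesterov=True is only valid with momentum > 0 and dampening == 0.0
    opt = nn.SGD(net.trainable_params(), learning_rate=0.1,
                 momentum=0.9, dampening=0.0, weight_decay=1e-4,
                 nesterov=True)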
@@ -67,14 +67,16 @@ class OptTFTWrapper(Optimizer):
  raise TypeError(f"For 'OptTFTWrapper', the argument 'opt' must be Optimizer type, " f"but got {type(opt)}.")
  super(OptTFTWrapper, self).__init__(opt.learning_rate, opt._parameters) # pylint: disable=W0212
  tft_env = os.getenv("MS_ENABLE_TFT", "")
- if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env):
- raise ValueError("MindIO TFT regitster need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1}']!")
+ if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env) and ("ARF:1" not in tft_env):
+ raise ValueError("MindIO TFT regitster need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1,ARF:1}']!")
  mode = context.get_context("mode")
  device_target = context.get_context("device_target")
  if device_target != "Ascend" or mode != context.GRAPH_MODE:
  raise ValueError("MindIO adataper only support on Ascend device with GRAPH Mode!")
  self.opt = opt
  self.report = TensorReport()
+ self.report_end = TensorReport()
+ self.report_end.add_prim_attr("side_effect_mem", True).add_prim_attr("optimizer_end", True)
  self.depend = ops.Depend()
  self.allreduce_sum = ops.AllReduce()
  self.allreduce_sum.add_prim_attr("tft_report_before", True)
@@ -121,4 +123,5 @@ class OptTFTWrapper(Optimizer):

  grads = self.depend(gradients, self.report("tft_report", self.tft_g_one_flag))
  opt_ret = self.opt(grads)
+ self.report_end("tft_report", self.tft_g_one_flag)
  return opt_ret
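The guard above now also accepts `ARF:1` in `MS_ENABLE_TFT`, and a second `TensorReport` marks the end of the optimizer step. A hedged sketch of the preconditions the wrapper enforces, assuming `OptTFTWrapper` is exported from `mindspore.nn` as the class in this file; the network and values are placeholders:

    import os
    import mindspore as ms
    from mindspore import nn

    # At least one of TTP/UCE/ARF must be switched on before the wrapper is
    # built, and only Ascend with graph mode is accepted.
    os.environ["MS_ENABLE_TFT"] = "{TTP:1,UCE:1,ARF:1}"
    ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")

    net = nn.Dense(16, 10)  # placeholder network, for illustration only
    base_opt = nn.SGD(net.trainable_params(), learning_rate=0.01)
    tft_opt = nn.OptTFTWrapper(base_opt)  # assumed export path; see the class above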
@@ -21,6 +21,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
  from mindspore.common.initializer import initializer
  from mindspore.common.parameter import Parameter, ParameterTuple
  from mindspore.common.tensor import Tensor
+ from mindspore.common import set_recursion_limit
  import mindspore.ops as ops
  import mindspore.nn as nn
  import mindspore.common.dtype as mstype
@@ -355,7 +356,7 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
  ... amp_level="O2", keep_batchnorm_fp32=False)

  """
- context.set_context(max_call_depth=10000)
+ set_recursion_limit(10000)
  ConvertNetUtils().convert_to_thor_net(net)
  if context.get_context("device_target") == "Ascend":
  return ThorAscend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
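The `thor` helper now raises the graph recursion limit through the public `set_recursion_limit` API instead of `context.set_context(max_call_depth=...)`. The equivalent call in user code, when deep recursion is expected, is simply:

    import mindspore as ms

    # Replaces the former context.set_context(max_call_depth=10000) call
    ms.set_recursion_limit(10000)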
@@ -41,15 +41,20 @@ class Bijector(Cell):
  param (dict): The parameters used to initialize the Bijector. Default: ``None`` .

  Note:
- `dtype` of bijector represents the type of the distributions that the bijector could operate on.
- When `dtype` is None, there is no enforcement on the type of input value except that the input value
- has to be float type. During initialization, when `dtype` is None, there is no enforcement on the dtype
- of the parameters. All parameters should have the same float type, otherwise a TypeError will be raised.
- Specifically, the parameter type will follow the dtype of the input value, i.e. parameters of the bijector
- will be casted into the same type as input value when `dtype` is None.
- When `dtype` is specified, it is forcing the parameters and input value to be the same dtype as `dtype`.
- When the type of parameters or the type of the input value is not the same as `dtype`, a TypeError will be
- raised. Only subtype of mindspore.float_type can be used to specify bijector's `dtype`.
+ - `dtype` of bijector represents the type of the distributions that the bijector could operate on.
+ - When `dtype` is None, there is no enforcement on the type of input value except that the input value
+ has to be float type. During initialization, when `dtype` is None, there is no enforcement on the dtype
+ of the parameters. All parameters should have the same float type, otherwise a TypeError will be raised.
+
+ Specifically, the parameter type will follow the dtype of the input value.
+
+ - Parameters of the bijector will be casted into the same type as input value when `dtype` is None.
+
+ - When `dtype` is specified, it is forcing the parameters and input value to be the same dtype as `dtype`.
+ When the type of parameters or the type of the input value is not the same as `dtype`, a TypeError will be
+ raised.
+
+ - Only subtype of mindspore.float_type can be used to specify bijector's `dtype`.

  Supported Platforms:
  ``Ascend`` ``GPU``
@@ -226,7 +231,8 @@ class Bijector(Cell):

  def cast_param_by_value(self, value, para):
  """
- Cast the parameter(s) of the bijector to be the same type of input_value.
+ Converts the data type of `para` in the input to the same type as `value`.
+ Typically used by subclasses of Bijector to convert data types of their own parameters.

  Args:
  value (Tensor): input value.
@@ -276,7 +282,7 @@ class Bijector(Cell):
  **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

  Returns:
- Tensor, the value of logarithm of the derivative of the forward transformation.
+ Tensor, outputs the value of a random variable after mapping.
  """
  return self._forward_log_jacobian(value, *args, **kwargs)
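For orientation on the reworked dtype note and the `forward_log_jacobian` return description, a hedged sketch using the concrete `ScalarAffine` bijector; the values are illustrative only:

    import mindspore as ms
    from mindspore import Tensor
    import mindspore.nn.probability.bijector as msb

    # With dtype left as None, the bijector's parameters follow the float type
    # of the input value (see the Note above).
    bijector = msb.ScalarAffine(scale=2.0, shift=1.0)
    x = Tensor([1.0, 2.0, 3.0], ms.float32)
    y = bijector.forward(x)                     # elementwise 2 * x + 1
    log_det = bijector.forward_log_jacobian(x)  # log|d(forward)/dx| = log(2)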