mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -125,18 +125,20 @@ class EmbeddingServiceOut:
125
125
 
126
126
  class EmbeddingService:
127
127
  r"""
128
- Currently, ES(EmbeddingService) feature can only create one object which can support model training and inference
128
+ ES(EmbeddingService) feature can support model training and inference
129
129
  for PS embedding and data_parallel embedding, and provide unified embedding management, storage,
130
130
  and computing capabilities for training and inference.
131
131
  PS embedding refer to tables that vocab_size more than 100,000, and recommended to store them on the
132
132
  Parameter Server (PS). Data_parallel embedding refer to tables that vocab_size less than 100,000, and recommended
133
133
  to store them on device.
134
134
 
135
+ Currently, ES feature can only create one instance of EmbeddingService object.
136
+
135
137
  .. warning::
136
138
  This is an experimental EmbeddingService API that is subject to change.
137
139
 
138
140
  .. note::
139
- This API needs to call 'mindspore.communication.init()' before,
141
+ This API needs to call :func:`mindspore.communication.init` before,
140
142
  and it can take effect after the dynamic networking is completed.
141
143
 
142
144
  Raises:
@@ -241,24 +243,26 @@ class EmbeddingService:
241
243
  name (str): The embedding table name.
242
244
  init_vocabulary_size (int): The size of embedding table.
243
245
  embedding_dim (int): The embedding dim of data in embedding table.
244
- max_feature_count (int): The count of keys when look up for PS.
245
- initializer (Initializer): The initialization strategy for the PS embedding, default is ``Uniform``.
246
- embedding_type (str): The embedding type, configurable parameters ["PS", "data_parallel"],
246
+ max_feature_count (int, optional): The count of keys when look up for PS. Default: ``None``.
247
+ initializer (Initializer, optional): The initialization strategy for the PS embedding,
248
+ default is ``Uniform(scale=0.01)``.
249
+ embedding_type (str, optional): The embedding type, configurable parameters ["PS", "data_parallel"],
247
250
  ``"PS"`` means initializing PS embedding, ``"data_parallel"`` means initializing data_parallel
248
251
  embedding, and default is ``"PS"``.
249
- ev_option (EmbeddingVariableOption): Properties of the PS embedding,
252
+ ev_option (EmbeddingVariableOption, optional): Properties of the PS embedding,
250
253
  is a EmbeddingVariableOption obj which returned by embedding_variable_option function.
251
254
  Default is ``None``.
252
- multihot_lens (int): The param only use when allow_merge is enabled, and not support now.
255
+ multihot_lens (int, optional): The param only use when `allow_merge` is enabled, and not support now.
253
256
  Default is ``None``.
254
- optimizer (str): The type of optimizer in the train mode for PS embedding,
257
+ optimizer (str, optional): The type of optimizer in the train mode for PS embedding,
255
258
  cannot be shared among each PS embedding, and currently only ``"Adam"``, ``"Ftrl"``, ``"SGD"`` and
256
259
  ``"RMSProp"`` are supported, and default is ``None``.
257
- allow_merge (bool): Whether to enable merge data_parallel embeddings, currently only be False,
260
+ allow_merge (bool, optional): Whether to enable merge data_parallel embeddings, currently only be False,
258
261
  and default is ``False``.
259
- optimizer_param (float): The "initialize accumulator value" param of optimizer which configured by user,
262
+ optimizer_param (float, optional): The "initialize accumulator value" param
263
+ of optimizer which configured by user,
260
264
  representing the init value of moment accumulator, and default is ``None``.
261
- mode (str): Run mode, configurable parameters ["train", "predict", "export"],
265
+ mode (str, optional): Run mode, configurable parameters ["train", "predict", "export"],
262
266
  ``"train"`` means train mode, ``"predict"`` means predict mode, ``"export"`` mean export mode,
263
267
  and default is ``"train"``.
264
268
 
@@ -345,8 +349,9 @@ class EmbeddingService:
345
349
 
346
350
  Args:
347
351
  padding_key (int): The value for padding key, must be a genuine and legal hash key.
348
- mask (bool): Whether to update padding key. If set to false, it will not be updated. Default is ``True``.
349
- mask_zero (bool): Whether to update padding key when key is 0. Default is ``False``.
352
+ mask (bool, optional): Whether to update padding key. If set to false, it will not be updated.
353
+ Default is ``True``.
354
+ mask_zero (bool, optional): Whether to update padding key when key is 0. Default is ``False``.
350
355
 
351
356
  Returns:
352
357
  PaddingParamsOption object.
@@ -368,7 +373,7 @@ class EmbeddingService:
368
373
 
369
374
  Args:
370
375
  completion_key (int): The value for completion key.
371
- mask (bool): Whether to update completion key. If set to false, it will not be updated,
376
+ mask (bool, optional): Whether to update completion key. If set to false, it will not be updated,
372
377
  and default is ``True``.
373
378
 
374
379
  Returns:
@@ -396,10 +401,11 @@ class EmbeddingService:
396
401
 
397
402
  Args:
398
403
  filter_freq (int): The frequency threshold value for feature admission.
399
- default_key (int): The key that number of occurrences does not reach the threshold,
400
- return value of default key as the corresponding value when look up embedding, and default is ``None``.
401
- default_value (int/float): The key that number of occurrences does not reach the threshold,
402
- return default value which length value is embedding dim, and default is ``None``.
404
+ default_key (int, optional): The key that number of occurrences does not reach the threshold,
405
+ return value of `default_key` as the corresponding value when look up embedding,
406
+ and default is ``None``.
407
+ default_value (Union[int, float], optional): The key that number of occurrences does not
408
+ reach the threshold, return default value which length value is embedding dim, and default is ``None``.
403
409
 
404
410
  Returns:
405
411
  CounterFilter object.
@@ -460,16 +466,17 @@ class EmbeddingService:
460
466
  Set variable option for PS embedding.
461
467
 
462
468
  Args:
463
- filter_option (CounterFilter): The option of counter filter. Default is ``None``.
464
- padding_option (PaddingParamsOption): The option of padding key. Default is ``None``.
465
- evict_option (EvictOption): The option evict. Default is ``None``.
466
- completion_option (CompletionKeyOption): The option of completion key. Default is ``None``.
467
- storage_option (None): Reserved option, currently not supported. Default is ``None``.
468
- feature_freezing_option (None): Reserved option, currently not supported. Default is ``None``.
469
- communication_option (None): Reserved option, currently not supported. Default is ``None``.
469
+ filter_option (CounterFilter, optional): The option of counter filter. Default is ``None``.
470
+ padding_option (PaddingParamsOption, optional): The option of padding key. Default is ``None``.
471
+ evict_option (EvictOption, optional): The option evict. Default is ``None``.
472
+ completion_option (CompletionKeyOption, optional): The option of completion key. Default is ``None``.
473
+ storage_option (None, optional): Reserved option, currently not supported. Default is ``None``.
474
+ feature_freezing_option (None, optional): Reserved option, currently not supported. Default is ``None``.
475
+ communication_option (None, optional): Reserved option, currently not supported. Default is ``None``.
470
476
 
471
477
  Returns:
472
- EmbeddingVariableOption object, used as the ev_option parameter for embedding_init.
478
+ EmbeddingVariableOption object, used as the ev_option parameter for
479
+ :func:`mindspore.experimental.es.EmbeddingService.embedding_init` .
473
480
 
474
481
  Raises:
475
482
  TypeError: If value of "filter_option" is not None and the type of "filter_option" is not CounterFilter.
@@ -501,7 +508,8 @@ class EmbeddingService:
501
508
 
502
509
  .. note::
503
510
  This function can only be executed by rank 0.
504
- Need to call embedding_variable_option to set evict_option for each PS embedding before export.
511
+ Need to call :func:`mindspore.experimental.es.EmbeddingService.embedding_variable_option`
512
+ to set evict_option for each PS embedding before export.
505
513
 
506
514
  Args:
507
515
  file_path (str): The path to export embedding ckpt, and the last character cannot be ``"/"``.
@@ -16,6 +16,7 @@
16
16
  from __future__ import absolute_import
17
17
 
18
18
  from mindspore.experimental.llm_boost.atb import LlamaBoost, QwenBoost
19
+ from mindspore.experimental.llm_boost.ascend_native import *
19
20
  from mindspore.experimental.llm_boost.register import LlmBoostRegister
20
21
 
21
22
  __all__ = ["LlmBoostRegister"]
@@ -0,0 +1,22 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ Provide llm boost for inference, such as LlamaBoost.
17
+ """
18
+ from __future__ import absolute_import
19
+
20
+ from mindspore.experimental.llm_boost.ascend_native.llama_boost_ascend_native import LlamaBoostAscendNative
21
+
22
+ __all__ = ['LlamaBoostAscendNative']
@@ -0,0 +1,211 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """AscendNative Llama Boost APIs."""
16
+
17
+ import os
18
+ import numpy as np
19
+ from mindspore.common import Tensor, dtype
20
+ from mindspore.experimental.llm_boost.ascend_native.llm_boost import LLMBoost
21
+ from mindspore.experimental.llm_boost.register import LlmBoostRegister, LlmBoostType
22
+
23
+
24
+ def RoundUp(val: int, align: int) -> int:
25
+ if align == 0:
26
+ return 0
27
+ return -(val // -align) * align
28
+
29
+
30
+ def ConvertTensor(nd_mat: np.ndarray, transpose: bool = True, nd2nz: bool = True) -> np.ndarray:
31
+ """ Transforms tensor format from Nd to Nz """
32
+ if transpose:
33
+ nd_mat = np.transpose(nd_mat)
34
+ if not nd2nz:
35
+ return nd_mat
36
+ block_size = (16, 16)
37
+ r = RoundUp(nd_mat.shape[0], block_size[0])
38
+ c = RoundUp(nd_mat.shape[1], block_size[1])
39
+ r_pad = r - nd_mat.shape[0]
40
+ c_pad = c - nd_mat.shape[1]
41
+ nd_mat = np.pad(nd_mat, ((0, r_pad), (0, c_pad)))
42
+ nz_mat = np.transpose(np.reshape(
43
+ nd_mat, (r, c // block_size[1], block_size[1])), (1, 0, 2))
44
+ nz_mat = nz_mat.reshape(r, c)
45
+ return nz_mat
46
+
47
+
48
+ @LlmBoostRegister.register(LlmBoostType.ASCEND_NATIVE, "Llama")
49
+ class LlamaBoostAscendNative(LLMBoost):
50
+ r"""
51
+ Implements an Llama model in a single kernel.
52
+ it forwards the python functions to the C++ binded object
53
+ """
54
+ def _get_from_dict(self, dictionary, name):
55
+ """ internal function to get a specific tensor from the dictionary """
56
+ all_relevant_layers = [value for key, value in dictionary.items() if name in key]
57
+ if all_relevant_layers:
58
+ return all_relevant_layers[0].asnumpy()
59
+ return None
60
+
61
+ def _get_quant_triplet_from_dict(self, dictionary, name):
62
+ """ internal function to get a weight triple tensor from the dictionary """
63
+ weights = self._get_from_dict(dictionary, name + "._handler.weight")
64
+ scale = self._get_from_dict(dictionary, name + "._weight_quantizer.scale")
65
+ offset = self._get_from_dict(dictionary, name + "._weight_quantizer.zp_neg")
66
+ return weights, scale, offset
67
+
68
+ def _prepare_single_layer(self, ckpt, config, id):
69
+ """ prepares the dictionary of weights of a single layer """
70
+ prefix = 'model.layers.' + str(id)
71
+ is_last = (id == config.num_layers-1)
72
+ layer = 'layers.' + str(id) + '.'
73
+ l_dict = {key: value for key, value in ckpt.items() if layer in key}
74
+ if config.n_kv_heads is None:
75
+ config.n_kv_heads = config.num_heads
76
+ start = 0
77
+ end = config.hidden_size
78
+ kv_start = 0
79
+ kv_end = int(config.hidden_size*config.n_kv_heads/config.num_heads)
80
+ ffn_hid = [value for key, value in l_dict.items() if "w3" in key][0].shape[0]
81
+ ffn_start = 0
82
+ ffn_end = ffn_hid
83
+ rank_size = int(os.getenv('RANK_SIZE', '1'))
84
+ #Emir if (config.parallel_mode != 2): # 2 - AUTO_PARALLEL
85
+ hid_size = end
86
+ kv_hid_size = kv_end
87
+ embed_size = config.vocab_size
88
+ rank_id = int(os.getenv('RANK_ID', '0'))
89
+ if (hid_size % rank_size == 0) and (ffn_hid % rank_size == 0) and (embed_size % rank_size == 0):
90
+ start = int(rank_id * hid_size / rank_size)
91
+ end = int((rank_id + 1) * hid_size / rank_size)
92
+ kv_start = int(rank_id * kv_hid_size / rank_size)
93
+ kv_end = int((rank_id + 1) * kv_hid_size / rank_size)
94
+ ffn_start = int(rank_id * ffn_hid / rank_size)
95
+ ffn_end = int((rank_id + 1) * ffn_hid / rank_size)
96
+ else:
97
+ raise RuntimeError("hidden size and ffn hidden size must be divided by rank size without remainder. \
98
+ hidden_size: ", hid_size, " ffn_hidden_size: ", ffn_hid, " rank_size: ", rank_size)
99
+ quant = (self._get_from_dict(l_dict, "_weight_quantizer") is not None)
100
+ unite_qkv = (config.num_heads == config.n_kv_heads)
101
+ self.dictionary[prefix + ".attention_norm.weight"] = \
102
+ Tensor(self._get_from_dict(l_dict, "attention_norm"), dtype=dtype.float16)
103
+ self.dictionary[prefix + ".ffn_norm.weight"] = \
104
+ Tensor(self._get_from_dict(l_dict, "ffn_norm"), dtype=dtype.float16)
105
+ if is_last:
106
+ self.dictionary['lm_head.weight'] = Tensor(ConvertTensor(ckpt['lm_head.weight'].asnumpy()[:, start:end]))
107
+
108
+ if not quant:
109
+ self._pack_attn_weights(l_dict, prefix, start, end, kv_start, kv_end, unite_qkv)
110
+ self._pack_ffn_weights(l_dict, prefix, ffn_start, ffn_end)
111
+ else:
112
+ self._pack_attn_quant_weights(l_dict, prefix, start, end, kv_start, kv_end, unite_qkv)
113
+ self._pack_ffn_quant_weights(l_dict, prefix, ffn_start, ffn_end)
114
+
115
+ def _pack_attn_weights(self, l_dict, prefix, start, end, kv_start, kv_end, unite_qkv):
116
+ """ prepares the dictionary of weights of an attention block """
117
+ wq = self._get_from_dict(l_dict, "wq")[start:end, :]
118
+ wk = self._get_from_dict(l_dict, "wk")[kv_start:kv_end, :]
119
+ wv = self._get_from_dict(l_dict, "wv")[kv_start:kv_end, :]
120
+ self.dictionary[prefix + ".attention.wo.weight"] = \
121
+ Tensor(ConvertTensor(self._get_from_dict(l_dict, "wo")[:, start:end]))
122
+ if unite_qkv:
123
+ self.dictionary[prefix + ".attention.wqkv.weight"] = Tensor(ConvertTensor(np.concatenate((wq, wk, wv))))
124
+ else:
125
+ self.dictionary[prefix + ".attention.wq.weight"] = Tensor(ConvertTensor(wq))
126
+ self.dictionary[prefix + ".attention.wkv.weight"] = Tensor(ConvertTensor(np.concatenate((wk, wv))))
127
+
128
+ def _pack_ffn_weights(self, l_dict, prefix, ffn_start, ffn_end):
129
+ """ prepares the dictionary of weights of an ffn block """
130
+ self.dictionary[prefix + ".feed_forward.w2.weight"] = \
131
+ Tensor(ConvertTensor(self._get_from_dict(l_dict, "w2")[:, ffn_start:ffn_end]))
132
+ w1 = self._get_from_dict(l_dict, "w1")[ffn_start:ffn_end, :]
133
+ w3 = self._get_from_dict(l_dict, "w3")[ffn_start:ffn_end, :]
134
+ self.dictionary[prefix + ".feed_forward.w13.weight"] = Tensor(ConvertTensor(np.concatenate((w1, w3))))
135
+
136
+ def _pack_attn_quant_weights(self, l_dict, prefix, start, end, kv_start, kv_end, unite_qkv):
137
+ """ prepares the dictionary of weights of a quantized attention block """
138
+ wq, wq_scale, wq_offset = self._get_quant_triplet_from_dict(l_dict, "wq")
139
+ wk, wk_scale, wk_offset = self._get_quant_triplet_from_dict(l_dict, "wk")
140
+ wv, wv_scale, wv_offset = self._get_quant_triplet_from_dict(l_dict, "wv")
141
+ wo, wo_scale, wo_offset = self._get_quant_triplet_from_dict(l_dict, "wo")
142
+ self.dictionary[prefix + ".attention.wo.weight"] = Tensor(ConvertTensor(wo[:, start:end], nd2nz=False))
143
+ self.dictionary[prefix + ".attention.wo.weight.scale"] = Tensor(wo_scale[start:end])
144
+ self.dictionary[prefix + ".attention.wo.weight.offset"] = Tensor(wo_offset[start:end])
145
+
146
+ if unite_qkv:
147
+ self.dictionary[prefix + ".attention.wqkv.weight"] = \
148
+ Tensor(ConvertTensor(np.concatenate((wq[start:end, :], wk[kv_start:kv_end, :], wv[kv_start:kv_end, :])),
149
+ nd2nz=False))
150
+ self.dictionary[prefix + ".attention.wqkv.weight.scale"] = \
151
+ Tensor(np.concatenate((wq_scale[start:end], wk_scale[kv_start:kv_end], wv_scale[kv_start:kv_end])))
152
+ self.dictionary[prefix + ".attention.wqkv.weight.offset"] = \
153
+ Tensor(np.concatenate((wq_offset[start:end], wk_offset[kv_start:kv_end], wv_offset[kv_start:kv_end])))
154
+ else:
155
+ self.dictionary[prefix + ".attention.wq.weight"] = Tensor(ConvertTensor(wq[start:end, :], nd2nz=False))
156
+ self.dictionary[prefix + ".attention.wq.weight.scale"] = Tensor(wq_scale[start:end])
157
+ self.dictionary[prefix + ".attention.wq.weight.offset"] = Tensor(wq_offset[start:end])
158
+ self.dictionary[prefix + ".attention.wkv.weight"] = \
159
+ Tensor(ConvertTensor(np.concatenate((wk[kv_start:kv_end, :], wv[kv_start:kv_end, :])), nd2nz=False))
160
+ self.dictionary[prefix + ".attention.wkv.weight.scale"] = \
161
+ Tensor(np.concatenate((wk_scale[kv_start:kv_end], wv_scale[kv_start:kv_end])))
162
+ self.dictionary[prefix + ".attention.wkv.weight.offset"] = \
163
+ Tensor(np.concatenate((wk_offset[kv_start:kv_end], wv_offset[kv_start:kv_end])))
164
+
165
+ def _pack_ffn_quant_weights(self, l_dict, prefix, ffn_start, ffn_end):
166
+ """ prepares the dictionary of weights of a quantized ffn block """
167
+ w1, w1_scale, w1_offset = self._get_quant_triplet_from_dict(l_dict, "w1")
168
+ w2, w2_scale, w2_offset = self._get_quant_triplet_from_dict(l_dict, "w2")
169
+ w3, w3_scale, w3_offset = self._get_quant_triplet_from_dict(l_dict, "w3")
170
+ self.dictionary[prefix + ".feed_forward.w2.weight"] = Tensor(ConvertTensor(w2[:, ffn_start:ffn_end],
171
+ nd2nz=False))
172
+ self.dictionary[prefix + ".feed_forward.w2.weight.scale"] = Tensor(w2_scale[ffn_start:ffn_end])
173
+ self.dictionary[prefix + ".feed_forward.w2.weight.offset"] = Tensor(w2_offset[ffn_start:ffn_end])
174
+
175
+ self.dictionary[prefix + ".feed_forward.w13.weight"] = \
176
+ Tensor(ConvertTensor(np.concatenate((w1[ffn_start:ffn_end, :], w3[ffn_start:ffn_end, :])), nd2nz=False))
177
+ self.dictionary[prefix + ".feed_forward.w13.weight.scale"] = \
178
+ Tensor(np.concatenate((w1_scale[ffn_start:ffn_end], w3_scale[ffn_start:ffn_end])))
179
+ self.dictionary[prefix + ".feed_forward.w13.weight.offset"] = \
180
+ Tensor(np.concatenate((w1_offset[ffn_start:ffn_end], w3_offset[ffn_start:ffn_end])))
181
+
182
+ def _prepare_cos_sin_arrays(self, config, theta=10000):
183
+ """ prepares the cosine and sine arrays """
184
+ head_dim = config.hidden_size // config.num_heads
185
+ max_position_embedding = \
186
+ config.max_position_embedding if config.max_position_embedding is not None else config.seq_length
187
+ freqs_base = np.arange(0, head_dim, 2)[: (head_dim // 2)].astype(np.float32)
188
+ freqs = 1.0 / (theta ** (freqs_base / head_dim))
189
+ t = np.arange(0, max_position_embedding, 1).astype(np.float32)
190
+ freqs = np.outer(t, freqs)
191
+ emb = np.concatenate((freqs, freqs), axis=-1)
192
+ freqs_cos = Tensor(np.cos(emb), dtype=dtype.float16)
193
+ sin = np.sin(emb)
194
+
195
+ sin[:, :int(emb.shape[1]/2)] = -sin[:, :int(emb.shape[1]/2)]
196
+ self.dictionary['model.cos.weight'] = freqs_cos
197
+ freqs_sin = Tensor(sin, dtype=dtype.float16)
198
+ self.dictionary['model.sin.weight'] = freqs_sin
199
+
200
+ def set_weights(self, ckpt_dict):
201
+ """ load the checkpoint """
202
+ self.dictionary = {}
203
+ self.dictionary['model.tok_embeddings.embedding_weight'] = \
204
+ Tensor(ckpt_dict['model.tok_embeddings.embedding_weight'].asnumpy())
205
+ self.dictionary['model.norm_out.weight'] = \
206
+ Tensor(ckpt_dict['model.norm_out.weight'].asnumpy(), dtype=dtype.float16)
207
+ self._prepare_cos_sin_arrays(self.config)
208
+ for layer_id in range(self.config.num_layers):
209
+ self._prepare_single_layer(ckpt_dict, self.config, layer_id)
210
+
211
+ self.binder.set_weights_map(self.dictionary)
@@ -0,0 +1,52 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """LLMBoost APIs."""
16
+
17
+ from mindspore.common import Tensor
18
+
19
+ class LLMBoost():
20
+ r"""
21
+ Implements an LLM in a single kernel.
22
+ it forwards the python function to the C++ binded object
23
+ """
24
+ def __init__(self, config):
25
+ r"""
26
+ initialize the parameters of the llm binder.
27
+ config is simply the config object of the model
28
+ """
29
+ from mindspore._c_expression import LlmBoostBinder
30
+ self.config = config
31
+ self.binder = LlmBoostBinder("AscendNative", config.model_type)
32
+ self.binder.init_model(config.to_dict())
33
+
34
+ def init(self):
35
+ """
36
+ Initialize the object
37
+ returns True if object needs input manipulation by mindformers
38
+ """
39
+ return False
40
+
41
+ def set_kvcache(self, k_caches=None, v_caches=None):
42
+ return
43
+
44
+ def forward(self, input_ids, batch_valid_length, position_ids=None):
45
+ ret = self.binder.forward([input_ids, batch_valid_length], "nothing really")
46
+ return Tensor(ret[0])
47
+
48
+ def set_weights(self, ckpt_dict):
49
+ self.binder.set_weights_map(ckpt_dict)
50
+
51
+ def add_flags(self, is_first_iteration=False):
52
+ self.binder.add_flags(is_first_iteration=is_first_iteration)
@@ -112,8 +112,7 @@ class AtbBoostBase:
112
112
 
113
113
  def _convert_qkv_concat_weight(self, param_dict):
114
114
  """convert qkv concat weight"""
115
- assume_num_layers = 500
116
- for i in range(assume_num_layers):
115
+ for i in range(self.num_layers):
117
116
  # qkv weight concat
118
117
  wq_weight_name = f"model.layers.{i}.attention.wq.weight"
119
118
  wk_weight_name = f"model.layers.{i}.attention.wk.weight"
@@ -151,7 +150,7 @@ class AtbBoostBase:
151
150
  logger.info(f"transform: {qkv_concat_weight_name}")
152
151
  logger.info(f"transform: {gate_hidden_concat_weight_name}")
153
152
 
154
- for i in range(assume_num_layers):
153
+ for i in range(self.num_layers):
155
154
  # qkv bias concat
156
155
  wq_bias_name = f"model.layers.{i}.attention.wq.bias"
157
156
  wk_bias_name = f"model.layers.{i}.attention.wk.bias"
@@ -43,7 +43,11 @@ class LlamaBoost(AtbBoostBase):
43
43
  )
44
44
 
45
45
  def init(self):
46
- """set param"""
46
+ """
47
+ Initialize the object
48
+ returns True if object needs input manipulation by mindformers
49
+ """
50
+
47
51
  coder_param = {
48
52
  "normEps": self.config.rms_norm_eps,
49
53
  "normType": NormType.RMS_NORM,
@@ -93,6 +97,7 @@ class LlamaBoost(AtbBoostBase):
93
97
  }
94
98
  self.atb_encoder_operation.init(json.dumps({**encoder_param}))
95
99
  self.atb_decoder_operation.init(json.dumps({**decoder_param}))
100
+ return True
96
101
 
97
102
  def _prepare_inputs(
98
103
  self,
@@ -23,6 +23,7 @@ class LlmBoostType:
23
23
  pass
24
24
 
25
25
  BUILDIN = 'BuildIn'
26
+ ASCEND_NATIVE = 'LLMBoost'
26
27
 
27
28
 
28
29
  class LlmBoostRegister:
@@ -23,7 +23,7 @@ from copy import copy
23
23
  import numbers
24
24
  import mindspore as ms
25
25
  from mindspore.common.parameter import Parameter, _get_unique_parameter_key
26
- from mindspore._c_expression import Tensor as Tensor_
26
+ from mindspore._c_expression import TensorPy as Tensor_
27
27
  from mindspore._c_expression import MapTensor_
28
28
  from mindspore.ops.operations import _map_tensor_ops
29
29
 
@@ -78,12 +78,12 @@ class MapParameter(Parameter):
78
78
  if value_dtype is not None:
79
79
  if isinstance(value_shape, numbers.Number):
80
80
  value_shape = (value_shape,)
81
- data = Tensor_(value_dtype, value_shape)
81
+ data = Tensor_(dtype=value_dtype, shape=value_shape)
82
82
  elif value_tensor is not None:
83
- data = Tensor_(value_tensor.dtype, value_tensor.shape)
83
+ data = Tensor_(dtype=value_tensor.dtype, shape=value_tensor.shape)
84
84
  else:
85
85
  # default
86
- data = Tensor_(ms.float32, (1,))
86
+ data = Tensor_(dtype=ms.float32, shape=(1,))
87
87
  obj = Tensor_.__new__(cls)
88
88
  Tensor_.__init__(obj, data)
89
89
  # Compatible attributes with Parameter.
@@ -37,14 +37,14 @@ class Adadelta(Optimizer):
37
37
  Implements Adadelta algorithm.
38
38
 
39
39
  .. math::
40
- \begin{aligned}
41
- &\rule{150mm}{0.4pt} \\
40
+ \begin{aligned}
41
+ &\rule{180mm}{0.4pt} \\
42
42
  &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)},
43
43
  \: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)},
44
44
  \: \lambda \text{ (weight decay)} \\
45
45
  &\textbf{initialize} : v_0 \leftarrow 0 \: \text{ (square avg)},
46
46
  \: u_0 \leftarrow 0 \: \text{ (accumulate variables)} \\[-1.ex]
47
- &\rule{110mm}{0.4pt} \\
47
+ &\rule{180mm}{0.4pt} \\
48
48
  &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
49
49
  &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
50
50
  &\hspace{5mm}if \: \lambda \neq 0 \\
@@ -55,10 +55,10 @@ class Adadelta(Optimizer):
55
55
  &\hspace{5mm} u_t \leftarrow u_{t-1} \rho +
56
56
  \Delta x^2_t (1 - \rho) \\
57
57
  &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \Delta x_t \\
58
- &\rule{110mm}{0.4pt} \\[-1.ex]
58
+ &\rule{180mm}{0.4pt} \\[-1.ex]
59
59
  &\bf{return} \: \theta_t \\[-1.ex]
60
- &\rule{110mm}{0.4pt} \\[-1.ex]
61
- \end{aligned}
60
+ &\rule{180mm}{0.4pt} \\[-1.ex]
61
+ \end{aligned}
62
62
 
63
63
  .. warning::
64
64
  This is an experimental optimizer API that is subject to change.
@@ -38,12 +38,12 @@ class Adagrad(Optimizer):
38
38
 
39
39
  .. math::
40
40
  \begin{aligned}
41
- &\rule{110mm}{0.4pt} \\
41
+ &\rule{160mm}{0.4pt} \\
42
42
  &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
43
43
  \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
44
44
  &\hspace{12mm} \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)}\\
45
45
  &\textbf{initialize} : state\_sum_0 \leftarrow 0 \\[-1.ex]
46
- &\rule{110mm}{0.4pt} \\
46
+ &\rule{160mm}{0.4pt} \\
47
47
  &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
48
48
  &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
49
49
  &\hspace{5mm} \tilde{\gamma} \leftarrow \gamma / (1 +(t-1) \eta) \\
@@ -52,9 +52,9 @@ class Adagrad(Optimizer):
52
52
  &\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\
53
53
  &\hspace{5mm}\theta_t \leftarrow
54
54
  \theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon} \\
55
- &\rule{110mm}{0.4pt} \\[-1.ex]
55
+ &\rule{160mm}{0.4pt} \\[-1.ex]
56
56
  &\bf{return} \: \theta_t \\[-1.ex]
57
- &\rule{110mm}{0.4pt} \\[-1.ex]
57
+ &\rule{160mm}{0.4pt} \\[-1.ex]
58
58
  \end{aligned}
59
59
 
60
60
  .. warning::
@@ -49,12 +49,14 @@ class Adam(Optimizer):
49
49
 
50
50
  .. math::
51
51
  \begin{aligned}
52
+ &\rule{180mm}{0.4pt} \\
52
53
  &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
53
54
  \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
54
55
  &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
55
56
  \:\textit{maximize} \\
56
57
  &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
57
58
  v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
59
+ &\rule{180mm}{0.4pt} \\
58
60
  &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
59
61
  &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
60
62
  &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
@@ -74,10 +76,15 @@ class Adam(Optimizer):
74
76
  &\hspace{5mm}\textbf{else} \\
75
77
  &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
76
78
  \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
79
+ &\rule{180mm}{0.4pt} \\[-1.ex]
77
80
  &\bf{return} \: \theta_t \\[-1.ex]
81
+ &\rule{180mm}{0.4pt} \\[-1.ex]
78
82
  \end{aligned}
79
83
 
80
84
  .. warning::
85
+ The implementation formula of this optimizer interface is not completely consistent with that in the paper.
86
+ If you want to use an interface that is completely consistent, it is recommended to use
87
+ :class:`mindspore.mint.optim.Adam`, which currently only supports Ascend.
81
88
  This is an experimental optimizer API that is subject to change.
82
89
  This module must be used with lr scheduler module in `LRScheduler Class
83
90
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_ .
@@ -43,14 +43,14 @@ class Adamax(Optimizer):
43
43
 
44
44
  .. math::
45
45
  \begin{aligned}
46
- &\rule{110mm}{0.4pt} \\
46
+ &\rule{180mm}{0.4pt} \\
47
47
  &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
48
48
  \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)},
49
49
  \: \lambda \text{ (weight decay)}, \\
50
50
  &\hspace{13mm} \epsilon \text{ (epsilon)} \\
51
51
  &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
52
52
  u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex]
53
- &\rule{110mm}{0.4pt} \\
53
+ &\rule{180mm}{0.4pt} \\
54
54
  &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
55
55
  &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
56
56
  &\hspace{5mm}if \: \lambda \neq 0 \\
@@ -58,9 +58,9 @@ class Adamax(Optimizer):
58
58
  &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
59
59
  &\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\
60
60
  &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\
61
- &\rule{110mm}{0.4pt} \\[-1.ex]
61
+ &\rule{180mm}{0.4pt} \\[-1.ex]
62
62
  &\bf{return} \: \theta_t \\[-1.ex]
63
- &\rule{110mm}{0.4pt} \\[-1.ex]
63
+ &\rule{180mm}{0.4pt} \\[-1.ex]
64
64
  \end{aligned}
65
65
 
66
66
  .. warning::