mindspore 2.1.0__cp37-none-any.whl → 2.2.11__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (577)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  164. mindspore/lib/libmindspore_shared_lib.so +0 -0
  165. mindspore/lib/libnnacl.so +0 -0
  166. mindspore/lib/libopencv_core.so.4.5 +0 -0
  167. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  169. mindspore/lib/libps_cache.so +0 -0
  170. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  183. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  187. mindspore/lib/plugin/ascend/libakg.so +0 -0
  188. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  189. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  190. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  191. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  193. mindspore/lib/plugin/cpu/libakg.so +0 -0
  194. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  196. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  197. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  198. mindspore/nn/__init__.py +0 -2
  199. mindspore/nn/cell.py +313 -74
  200. mindspore/nn/dynamic_lr.py +21 -21
  201. mindspore/nn/layer/activation.py +22 -30
  202. mindspore/nn/layer/basic.py +15 -13
  203. mindspore/nn/layer/channel_shuffle.py +1 -1
  204. mindspore/nn/layer/container.py +271 -9
  205. mindspore/nn/layer/conv.py +323 -204
  206. mindspore/nn/layer/dense.py +8 -5
  207. mindspore/nn/layer/embedding.py +33 -27
  208. mindspore/nn/layer/flash_attention.py +61 -95
  209. mindspore/nn/layer/image.py +8 -6
  210. mindspore/nn/layer/math.py +16 -25
  211. mindspore/nn/layer/normalization.py +107 -66
  212. mindspore/nn/layer/padding.py +1 -1
  213. mindspore/nn/layer/pooling.py +131 -109
  214. mindspore/nn/layer/rnn_cells.py +27 -22
  215. mindspore/nn/layer/rnns.py +13 -16
  216. mindspore/nn/layer/thor_layer.py +1 -1
  217. mindspore/nn/layer/transformer.py +221 -154
  218. mindspore/nn/learning_rate_schedule.py +9 -1
  219. mindspore/nn/loss/loss.py +235 -174
  220. mindspore/nn/optim/ada_grad.py +2 -1
  221. mindspore/nn/optim/adadelta.py +1 -0
  222. mindspore/nn/optim/adafactor.py +2 -1
  223. mindspore/nn/optim/adam.py +7 -4
  224. mindspore/nn/optim/adamax.py +3 -2
  225. mindspore/nn/optim/adasum.py +2 -2
  226. mindspore/nn/optim/asgd.py +2 -3
  227. mindspore/nn/optim/ftrl.py +6 -5
  228. mindspore/nn/optim/lamb.py +7 -4
  229. mindspore/nn/optim/lars.py +1 -1
  230. mindspore/nn/optim/lazyadam.py +5 -3
  231. mindspore/nn/optim/momentum.py +2 -1
  232. mindspore/nn/optim/optimizer.py +53 -4
  233. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  234. mindspore/nn/optim/rmsprop.py +4 -3
  235. mindspore/nn/optim/rprop.py +23 -12
  236. mindspore/nn/optim/sgd.py +26 -11
  237. mindspore/nn/optim/thor.py +9 -7
  238. mindspore/nn/probability/bijector/bijector.py +5 -5
  239. mindspore/nn/probability/bijector/power_transform.py +27 -27
  240. mindspore/nn/probability/bijector/softplus.py +3 -3
  241. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  242. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  243. mindspore/nn/probability/distribution/beta.py +3 -3
  244. mindspore/nn/probability/distribution/categorical.py +7 -7
  245. mindspore/nn/probability/distribution/cauchy.py +0 -1
  246. mindspore/nn/probability/distribution/distribution.py +3 -3
  247. mindspore/nn/probability/distribution/gamma.py +3 -3
  248. mindspore/nn/probability/distribution/geometric.py +4 -4
  249. mindspore/nn/probability/distribution/gumbel.py +4 -4
  250. mindspore/nn/probability/distribution/log_normal.py +2 -2
  251. mindspore/nn/probability/distribution/logistic.py +2 -2
  252. mindspore/nn/probability/distribution/poisson.py +4 -4
  253. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  254. mindspore/nn/probability/distribution/uniform.py +6 -6
  255. mindspore/nn/wrap/__init__.py +4 -2
  256. mindspore/nn/wrap/cell_wrapper.py +87 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  274. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  275. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  276. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  277. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  278. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  279. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  280. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  281. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  283. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  284. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  285. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  286. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  287. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  288. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  289. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  290. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  291. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  292. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  293. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  294. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  295. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  297. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  298. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  299. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  300. mindspore/ops/_primitive_cache.py +1 -1
  301. mindspore/ops/_tracefunc.py +45 -13
  302. mindspore/ops/_utils/utils.py +6 -1
  303. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  304. mindspore/ops/_vmap/vmap_base.py +3 -3
  305. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  306. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  307. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  308. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  309. mindspore/ops/arg_dtype_cast.py +54 -0
  310. mindspore/ops/composite/base.py +37 -10
  311. mindspore/ops/composite/math_ops.py +5 -4
  312. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  313. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  314. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  315. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  316. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  317. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  319. mindspore/ops/deprecated.py +304 -0
  320. mindspore/ops/function/__init__.py +4 -1
  321. mindspore/ops/function/array_func.py +174 -193
  322. mindspore/ops/function/clip_func.py +81 -13
  323. mindspore/ops/function/debug_func.py +1 -1
  324. mindspore/ops/function/grad/grad_func.py +18 -9
  325. mindspore/ops/function/image_func.py +10 -4
  326. mindspore/ops/function/linalg_func.py +5 -5
  327. mindspore/ops/function/math_func.py +575 -386
  328. mindspore/ops/function/nn_func.py +568 -260
  329. mindspore/ops/function/random_func.py +88 -57
  330. mindspore/ops/function/sparse_func.py +1 -1
  331. mindspore/ops/function/sparse_unary_func.py +14 -12
  332. mindspore/ops/function/vmap_func.py +6 -5
  333. mindspore/ops/functional.py +15 -10
  334. mindspore/ops/op_info_register.py +244 -25
  335. mindspore/ops/operations/__init__.py +31 -19
  336. mindspore/ops/operations/_grad_ops.py +71 -7
  337. mindspore/ops/operations/_inner_ops.py +350 -17
  338. mindspore/ops/operations/_quant_ops.py +4 -8
  339. mindspore/ops/operations/_sequence_ops.py +42 -0
  340. mindspore/ops/operations/array_ops.py +68 -282
  341. mindspore/ops/operations/comm_ops.py +107 -59
  342. mindspore/ops/operations/custom_ops.py +94 -70
  343. mindspore/ops/operations/debug_ops.py +8 -4
  344. mindspore/ops/operations/image_ops.py +18 -12
  345. mindspore/ops/operations/inner_ops.py +26 -3
  346. mindspore/ops/operations/math_ops.py +192 -144
  347. mindspore/ops/operations/nn_ops.py +857 -489
  348. mindspore/ops/operations/other_ops.py +0 -22
  349. mindspore/ops/operations/random_ops.py +53 -111
  350. mindspore/ops/operations/sparse_ops.py +3 -1
  351. mindspore/ops/primitive.py +24 -18
  352. mindspore/parallel/_auto_parallel_context.py +68 -8
  353. mindspore/parallel/_cost_model_context.py +2 -2
  354. mindspore/parallel/_offload_context.py +17 -3
  355. mindspore/parallel/_parallel_serialization.py +12 -5
  356. mindspore/parallel/_ps_context.py +12 -0
  357. mindspore/parallel/_tensor.py +18 -13
  358. mindspore/parallel/_transformer/layers.py +5 -3
  359. mindspore/parallel/_transformer/loss.py +1 -0
  360. mindspore/parallel/_transformer/moe.py +2 -2
  361. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  362. mindspore/parallel/_transformer/transformer.py +23 -3
  363. mindspore/parallel/_utils.py +11 -7
  364. mindspore/parallel/algo_parameter_config.py +85 -5
  365. mindspore/parallel/checkpoint_transform.py +19 -12
  366. mindspore/parallel/shard.py +21 -14
  367. mindspore/profiler/common/struct_type.py +3 -3
  368. mindspore/profiler/common/util.py +4 -2
  369. mindspore/profiler/envprofiling.py +1 -1
  370. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  371. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  372. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  373. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  374. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  375. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  376. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  377. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  378. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  379. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  380. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  381. mindspore/profiler/parser/flops_parser.py +15 -11
  382. mindspore/profiler/parser/framework_parser.py +38 -22
  383. mindspore/profiler/parser/hccl_parser.py +16 -12
  384. mindspore/profiler/parser/integrator.py +22 -11
  385. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  386. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  387. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  388. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  389. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  390. mindspore/profiler/parser/optime_parser.py +1 -1
  391. mindspore/profiler/parser/profiler_info.py +21 -2
  392. mindspore/profiler/parser/step_trace_parser.py +11 -14
  393. mindspore/profiler/profiling.py +179 -89
  394. mindspore/rewrite/api/node.py +102 -19
  395. mindspore/rewrite/api/node_type.py +5 -1
  396. mindspore/rewrite/api/pattern_engine.py +1 -1
  397. mindspore/rewrite/api/scoped_value.py +9 -17
  398. mindspore/rewrite/api/symbol_tree.py +131 -47
  399. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  400. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  401. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  402. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  403. mindspore/rewrite/common/rewrite_elog.py +5 -1
  404. mindspore/rewrite/namer.py +33 -24
  405. mindspore/rewrite/namespace.py +14 -5
  406. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  407. mindspore/rewrite/node/call_function.py +79 -0
  408. mindspore/rewrite/node/cell_container.py +135 -0
  409. mindspore/rewrite/node/control_flow.py +88 -0
  410. mindspore/rewrite/{node.py → node/node.py} +273 -234
  411. mindspore/rewrite/node/node_manager.py +254 -0
  412. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  413. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  414. mindspore/rewrite/parsers/assign_parser.py +216 -221
  415. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  416. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  417. mindspore/rewrite/parsers/constant_parser.py +9 -6
  418. mindspore/rewrite/parsers/container_parser.py +9 -7
  419. mindspore/rewrite/parsers/for_parser.py +42 -21
  420. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  421. mindspore/rewrite/parsers/if_parser.py +28 -24
  422. mindspore/rewrite/parsers/module_parser.py +196 -25
  423. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  424. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  425. mindspore/rewrite/parsers/return_parser.py +6 -6
  426. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  427. mindspore/rewrite/sparsify/utils.py +1 -1
  428. mindspore/rewrite/symbol_tree.py +523 -578
  429. mindspore/rewrite/symbol_tree_builder.py +9 -193
  430. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  431. mindspore/run_check/_check_version.py +6 -4
  432. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  433. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  434. mindspore/scipy/linalg.py +1 -1
  435. mindspore/scipy/ops.py +55 -5
  436. mindspore/scipy/optimize/__init__.py +3 -2
  437. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  438. mindspore/scipy/optimize/minimize.py +7 -3
  439. mindspore/train/_utils.py +7 -3
  440. mindspore/train/amp.py +323 -123
  441. mindspore/train/anf_ir_pb2.py +14 -2
  442. mindspore/train/callback/_backup_and_restore.py +2 -12
  443. mindspore/train/callback/_callback.py +29 -4
  444. mindspore/train/callback/_checkpoint.py +23 -8
  445. mindspore/train/callback/_early_stop.py +2 -2
  446. mindspore/train/callback/_landscape.py +4 -4
  447. mindspore/train/callback/_loss_monitor.py +2 -2
  448. mindspore/train/callback/_on_request_exit.py +2 -2
  449. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  450. mindspore/train/callback/_summary_collector.py +15 -8
  451. mindspore/train/callback/_time_monitor.py +58 -5
  452. mindspore/train/data_sink.py +5 -11
  453. mindspore/train/dataset_helper.py +84 -57
  454. mindspore/train/loss_scale_manager.py +2 -2
  455. mindspore/train/metrics/__init__.py +3 -3
  456. mindspore/train/metrics/cosine_similarity.py +1 -1
  457. mindspore/train/metrics/hausdorff_distance.py +3 -2
  458. mindspore/train/metrics/mean_surface_distance.py +3 -2
  459. mindspore/train/metrics/metric.py +39 -19
  460. mindspore/train/metrics/roc.py +2 -2
  461. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  462. mindspore/train/mind_ir_pb2.py +85 -36
  463. mindspore/train/model.py +187 -47
  464. mindspore/train/serialization.py +487 -161
  465. mindspore/train/summary/_summary_adapter.py +1 -1
  466. mindspore/train/summary/_writer_pool.py +3 -2
  467. mindspore/train/summary/summary_record.py +37 -17
  468. mindspore/train/train_thor/convert_utils.py +3 -3
  469. mindspore/train/train_thor/dataset_helper.py +1 -1
  470. mindspore/version.py +1 -1
  471. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  472. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +476 -527
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  474. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  475. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  476. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  477. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  478. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  479. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  480. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  481. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  482. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  483. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  484. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  485. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  486. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  487. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  488. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  489. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  490. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  491. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  492. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  493. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  494. mindspore/_extends/graph_kernel/expander.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  496. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  497. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  498. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  499. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  500. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  501. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  502. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  503. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  504. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  505. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  507. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  508. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  509. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  510. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  511. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  513. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  514. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  515. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  516. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  517. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  518. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  519. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  520. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  521. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  522. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  523. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  524. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  526. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  527. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  528. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  529. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  531. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  533. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  534. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  535. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  536. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  537. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  538. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  540. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  541. mindspore/dataset/datapreprocess/__init__.py +0 -20
  542. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  543. mindspore/include/api/net.h +0 -142
  544. mindspore/nn/lr_scheduler.py +0 -262
  545. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  546. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  547. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  548. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  549. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  550. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  551. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  552. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  553. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  554. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  555. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  556. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  557. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  558. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  559. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  563. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  574. mindspore/rewrite/node_visitor.py +0 -44
  575. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  576. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  577. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,203 +1,181 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """adamw"""
16
- from __future__ import absolute_import
17
-
18
- from mindspore.ops import functional as F, operations as P
19
- from mindspore.common.parameter import Parameter, ParameterTuple
20
- from mindspore.common.tensor import Tensor
21
- import mindspore.common.dtype as mstype
22
- from mindspore.nn.optim_ex.optimizer import Optimizer
23
- from mindspore import ops
24
-
25
-
26
- class AdamW(Optimizer):
27
- r"""
28
- Implements Adam Weight Decay algorithm.
29
-
30
- .. math::
31
- \begin{aligned}
32
- &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
33
- \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
34
- \: \epsilon \text{ (epsilon)} \\
35
- &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
36
- \: \textit{maximize} \\
37
- &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
38
- \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
39
- &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
40
- &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
41
- &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
42
- &\hspace{5mm}\textbf{else} \\
43
- &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
44
- &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
45
- &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
46
- &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
47
- &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
48
- &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
49
- &\hspace{5mm}\textbf{if} \: amsgrad \\
50
- &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
51
- \widehat{v_t}) \\
52
- &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
53
- \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
54
- &\hspace{5mm}\textbf{else} \\
55
- &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
56
- \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
57
- &\bf{return} \: \theta_t \\[-1.ex]
58
- \end{aligned}
59
-
60
- .. warning::
61
- This is an experimental optimizer API that is subject to change.
62
- This module must be used with lr scheduler module in `LRScheduler Class
63
- <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#lrscheduler>`_ .
64
-
65
- Args:
66
- params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
67
- parameter groups
68
- lr (Union[int, float, Tensor], optional): learning rate. Default: ``1e-3``.
69
- betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
70
- Default: ``(0.9, 0.999)``.
71
- eps (float, optional): term added to the denominator to improve
72
- numerical stability. Default: ``1e-8``.
73
- weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
74
- amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.
75
-
76
- Keyword Args:
77
- maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
78
- Default: ``False``.
79
-
80
- Inputs:
81
- - **gradients** (tuple[Tensor]) - The gradients of `params`.
82
-
83
- Raises:
84
- ValueError: If the learning rate is not int, float or Tensor.
85
- ValueError: If the learning rate is less than 0.
86
- ValueError: If the `eps` is less than 0.0.
87
- ValueError: If the `betas` not in the range of 0-1.
88
- ValueError: If the `weight_decay` is less than 0.
89
-
90
- Supported Platforms:
91
- ``Ascend`` ``GPU`` ``CPU``
92
-
93
- Examples:
94
- >>> import mindspore
95
- >>> from mindspore import nn
96
- >>> # Define the network structure of LeNet5. Refer to
97
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
98
- >>> net = LeNet5()
99
- >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
100
- >>> optimizer = nn.optim_ex.AdamW(net.trainable_params(), lr=0.1)
101
- >>> def forward_fn(data, label):
102
- ... logits = net(data)
103
- ... loss = loss_fn(logits, label)
104
- ... return loss, logits
105
- >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
106
- >>> def train_step(data, label):
107
- ... (loss, _), grads = grad_fn(data, label)
108
- ... optimizer(grads)
109
- ... return loss
110
- """
111
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
112
- weight_decay=1e-2, amsgrad=False, *, maximize=False):
113
- if lr < 0.0:
114
- raise ValueError("Invalid learning rate: {}".format(lr))
115
- if eps < 0.0:
116
- raise ValueError("Invalid epsilon value: {}".format(eps))
117
- if not 0.0 <= betas[0] < 1.0:
118
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
119
- if not 0.0 <= betas[1] < 1.0:
120
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
121
- if weight_decay < 0.0:
122
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
123
-
124
- defaults = dict(lr=lr, betas=betas, eps=eps,
125
- weight_decay=weight_decay, amsgrad=amsgrad,
126
- maximize=maximize)
127
- super(AdamW, self).__init__(params, defaults)
128
-
129
- self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
130
- self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
131
- self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
132
- self.state_step = ParameterTuple(Parameter(Tensor(0, mstype.int32), "step_"+str(i))
133
- for i in range(len(self.parameters)))
134
- self.increase_tensor = Tensor(1, mstype.int32)
135
-
136
- self.op_mul = P.Mul()
137
- self.assignadd = P.AssignAdd()
138
- self.op_pow = P.Pow()
139
- self.op_sqrt = P.Sqrt()
140
- self.op_maximum = P.Maximum()
141
- self.op_cast = P.Cast()
142
-
143
- def construct(self, gradients):
144
- for group_id, group in enumerate(self.param_groups):
145
- params = []
146
- grads = []
147
- exp_avgs = []
148
- exp_avg_sqs = []
149
- max_exp_avg_sqs = []
150
- state_steps = []
151
- amsgrad = group["amsgrad"]
152
- beta1, beta2 = group['betas']
153
- params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps = \
154
- self._init_group(group, gradients, params, grads, amsgrad, exp_avgs,
155
- exp_avg_sqs, max_exp_avg_sqs, state_steps, group_id)
156
-
157
- self.apply_adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
158
- amsgrad, beta1, beta2, group['lr'], group['weight_decay'], group['eps'],
159
- group["maximize"], group["grad_centralization"])
160
-
161
- def apply_adamw(self, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
162
- amsgrad, beta1, beta2, lr, weight_decay, eps, maximize, grad_centralization):
163
- grads = self._gradients_centralization(grad_centralization, grads)
164
- for i, param in enumerate(params):
165
- grad = grads[i] if not maximize else -grads[i]
166
- exp_avg = exp_avgs[i]
167
- exp_avg_sq = exp_avg_sqs[i]
168
- step_t = state_steps[i]
169
-
170
- next_param = self.op_mul(param, F.tuple_to_array((1.0,)) - lr * weight_decay)
171
- F.assign(exp_avg, self.op_mul(exp_avg, beta1) + self.op_mul(grad, 1-beta1))
172
- F.assign(exp_avg_sq, ops.addcmul(self.op_mul(exp_avg_sq, beta2), grad, grad, 1-beta2))
173
- step_t = F.depend(step_t, self.assignadd(step_t, self.increase_tensor))
174
-
175
- bias_correction1 = F.tuple_to_array((1.0,)) - self.op_pow(beta1, step_t)
176
- bias_correction2 = F.tuple_to_array((1.0,)) - self.op_pow(beta2, step_t)
177
- step_size = lr / bias_correction1
178
- bias_correction2_sqrt = self.op_sqrt(bias_correction2)
179
-
180
- if amsgrad:
181
- next_max_exp_avg = self.op_maximum(max_exp_avg_sqs[i], exp_avg_sq)
182
- denom = self.op_sqrt(next_max_exp_avg) / bias_correction2_sqrt + eps
183
- F.assign(max_exp_avg_sqs[i], next_max_exp_avg)
184
- else:
185
- denom = self.op_sqrt(exp_avg_sq) / bias_correction2_sqrt + eps
186
-
187
- return_param = next_param - self.op_mul(exp_avg / denom, step_size)
188
- F.assign(param, return_param)
189
-
190
- def _init_group(self, group, gradients, params, grads, amsgrad, exp_avgs, exp_avg_sqs,
191
- max_exp_avg_sqs, state_steps, group_id):
192
- """ Initialize group params. """
193
- p_id = self.group_start_id[group_id]
194
- for i, param in enumerate(group["params"]):
195
- grad = gradients[p_id+i]
196
- grads.append(grad)
197
- params.append(param)
198
- exp_avgs.append(self.exp_avg[p_id+i])
199
- exp_avg_sqs.append(self.exp_avg_sq[p_id+i])
200
- if amsgrad:
201
- max_exp_avg_sqs.append(self.max_exp_avg_sq[p_id+i])
202
- state_steps.append(self.state_step[p_id+i])
203
- return params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """adamw"""
16
+ from __future__ import absolute_import
17
+
18
+ from mindspore.ops import functional as F, composite as C, operations as P
19
+ from mindspore.common.parameter import Parameter
20
+ from mindspore.common.tensor import Tensor
21
+ import mindspore.common.dtype as mstype
22
+ from mindspore.experimental.optim.optimizer import Optimizer
23
+ from mindspore import ops
24
+
25
+ _adamw_opt = C.MultitypeFuncGraph("adamw_opt")
26
+
27
+ op_mul = P.Mul()
28
+ op_pow = P.Pow()
29
+ op_sqrt = P.Sqrt()
30
+ op_maximum = P.Maximum()
31
+
32
+
33
+ @_adamw_opt.register("Float", "Tensor", "Bool", "Float", "Tensor", "Float", "Float", "Tensor", "Tensor",
34
+ "Tensor", "Tensor", "Tensor")
35
+ def _run_adamw_opt(weight_decay, lr, amsgrad, eps, state_step, beta1, beta2, param, grad,
36
+ exp_avg, exp_avg_sq, max_exp_avg_sq):
37
+ """Apply adamw optimizer to the weight parameter."""
38
+ success = True
39
+ next_param = op_mul(param, 1 - lr * weight_decay)
40
+ F.assign(exp_avg, op_mul(exp_avg, beta1) + op_mul(grad, 1 - beta1))
41
+ F.assign(exp_avg_sq, ops.addcmul(op_mul(exp_avg_sq, beta2), grad, grad, 1 - beta2))
42
+ bias_correction1 = 1 - op_pow(beta1, state_step)
43
+ bias_correction2 = 1 - op_pow(beta2, state_step)
44
+ step_size = lr / bias_correction1
45
+ bias_correction2_sqrt = op_sqrt(bias_correction2)
46
+
47
+ if amsgrad:
48
+ next_max_exp_avg = op_maximum(max_exp_avg_sq, exp_avg_sq)
49
+ denom = op_sqrt(next_max_exp_avg) / bias_correction2_sqrt + eps
50
+ F.assign(max_exp_avg_sq, next_max_exp_avg)
51
+ else:
52
+ denom = op_sqrt(exp_avg_sq) / bias_correction2_sqrt + eps
53
+
54
+ return_param = next_param - op_mul(exp_avg / denom, step_size)
55
+ F.assign(param, return_param)
56
+ return success
57
+
58
+
59
+ class AdamW(Optimizer):
60
+ r"""
61
+ Implements Adam Weight Decay algorithm.
62
+
63
+ .. math::
64
+ \begin{aligned}
65
+ &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
66
+ \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
67
+ \: \epsilon \text{ (epsilon)} \\
68
+ &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
69
+ \: \textit{maximize} \\
70
+ &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
71
+ \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
72
+ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
73
+ &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
74
+ &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
75
+ &\hspace{5mm}\textbf{else} \\
76
+ &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
77
+ &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
78
+ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
79
+ &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
80
+ &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
81
+ &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
82
+ &\hspace{5mm}\textbf{if} \: amsgrad \\
83
+ &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
84
+ \widehat{v_t}) \\
85
+ &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
86
+ \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
87
+ &\hspace{5mm}\textbf{else} \\
88
+ &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
89
+ \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
90
+ &\bf{return} \: \theta_t \\[-1.ex]
91
+ \end{aligned}
92
+
93
+ .. warning::
94
+ This is an experimental optimizer API that is subject to change.
95
+ This module must be used with lr scheduler module in `LRScheduler Class
96
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.experimental.html#lrscheduler-class>`_ .
97
+
98
+ Args:
99
+ params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
100
+ parameter groups
101
+ lr (Union[int, float, Tensor], optional): learning rate. Default: ``1e-3``.
102
+ betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
103
+ Default: ``(0.9, 0.999)``.
104
+ eps (float, optional): term added to the denominator to improve
105
+ numerical stability. Default: ``1e-8``.
106
+ weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
107
+ amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.
108
+
109
+ Keyword Args:
110
+ maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
111
+ Default: ``False``.
112
+
113
+ Inputs:
114
+ - **gradients** (tuple[Tensor]) - The gradients of `params`.
115
+
116
+ Raises:
117
+ ValueError: If the learning rate is not int, float or Tensor.
118
+ ValueError: If the learning rate is less than 0.
119
+ ValueError: If the `eps` is less than 0.0.
120
+ ValueError: If the `betas` not in the range of 0-1.
121
+ ValueError: If the `weight_decay` is less than 0.
122
+
123
+ Supported Platforms:
124
+ ``Ascend`` ``GPU`` ``CPU``
125
+
126
+ Examples:
127
+ >>> import mindspore
128
+ >>> from mindspore import nn
129
+ >>> from mindspore.experimental import optim
130
+ >>> # Define the network structure of LeNet5. Refer to
131
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
132
+ >>> net = LeNet5()
133
+ >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
134
+ >>> optimizer = optim.AdamW(net.trainable_params(), lr=0.1)
135
+ >>> def forward_fn(data, label):
136
+ ... logits = net(data)
137
+ ... loss = loss_fn(logits, label)
138
+ ... return loss, logits
139
+ >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
140
+ >>> def train_step(data, label):
141
+ ... (loss, _), grads = grad_fn(data, label)
142
+ ... optimizer(grads)
143
+ ... return loss
144
+ """
145
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
146
+ weight_decay=1e-2, amsgrad=False, *, maximize=False):
147
+ if lr < 0.0:
148
+ raise ValueError("Invalid learning rate: {}".format(lr))
149
+ if eps < 0.0:
150
+ raise ValueError("Invalid epsilon value: {}".format(eps))
151
+ if not 0.0 <= betas[0] < 1.0:
152
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
153
+ if not 0.0 <= betas[1] < 1.0:
154
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
155
+ if weight_decay < 0.0:
156
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
157
+
158
+ defaults = dict(lr=lr, betas=betas, eps=eps,
159
+ weight_decay=weight_decay, amsgrad=amsgrad,
160
+ maximize=maximize)
161
+ super(AdamW, self).__init__(params, defaults)
162
+
163
+ self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
164
+ self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
165
+ self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
166
+ self.state_step = Parameter(Tensor(0, mstype.int32), "state_step")
167
+ self.increase_tensor = Tensor(1, mstype.int32)
168
+ self.assignadd = P.AssignAdd()
169
+
170
+ def construct(self, gradients):
171
+ self.assignadd(self.state_step, self.increase_tensor)
172
+ for group_id, group in enumerate(self.param_groups):
173
+ beta1, beta2 = group['betas']
174
+ start_id = self.group_start_id[group_id]
175
+ end_id = self.group_start_id[group_id + 1]
176
+ grads = gradients[start_id: end_id] if not group.get("maximize") else -gradients[start_id: end_id]
177
+ self.hyper_map(F.partial(_adamw_opt, group.get("weight_decay"), group.get("lr"), group.get("amsgrad"),
178
+ group.get("eps"), self.state_step, beta1, beta2),
179
+ self.parameters[start_id: end_id], grads, self.exp_avg[start_id: end_id],
180
+ self.exp_avg_sq[start_id: end_id], self.max_exp_avg_sq[start_id: end_id])
181
+ return True