mindspore 2.1.0__cp37-none-any.whl → 2.2.10__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the package's registry page for more details.

Files changed (569)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  196. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  197. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  198. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  199. mindspore/nn/__init__.py +0 -2
  200. mindspore/nn/cell.py +313 -74
  201. mindspore/nn/dynamic_lr.py +21 -21
  202. mindspore/nn/layer/activation.py +22 -30
  203. mindspore/nn/layer/basic.py +15 -13
  204. mindspore/nn/layer/channel_shuffle.py +1 -1
  205. mindspore/nn/layer/container.py +271 -9
  206. mindspore/nn/layer/conv.py +323 -204
  207. mindspore/nn/layer/dense.py +8 -5
  208. mindspore/nn/layer/embedding.py +33 -27
  209. mindspore/nn/layer/flash_attention.py +141 -88
  210. mindspore/nn/layer/image.py +8 -6
  211. mindspore/nn/layer/math.py +16 -25
  212. mindspore/nn/layer/normalization.py +107 -66
  213. mindspore/nn/layer/padding.py +1 -1
  214. mindspore/nn/layer/pooling.py +131 -109
  215. mindspore/nn/layer/rnn_cells.py +27 -22
  216. mindspore/nn/layer/rnns.py +13 -16
  217. mindspore/nn/layer/thor_layer.py +1 -1
  218. mindspore/nn/layer/transformer.py +221 -154
  219. mindspore/nn/learning_rate_schedule.py +9 -1
  220. mindspore/nn/loss/loss.py +235 -174
  221. mindspore/nn/optim/ada_grad.py +2 -1
  222. mindspore/nn/optim/adadelta.py +1 -0
  223. mindspore/nn/optim/adafactor.py +2 -1
  224. mindspore/nn/optim/adam.py +7 -4
  225. mindspore/nn/optim/adamax.py +3 -2
  226. mindspore/nn/optim/adasum.py +2 -2
  227. mindspore/nn/optim/asgd.py +2 -3
  228. mindspore/nn/optim/ftrl.py +6 -5
  229. mindspore/nn/optim/lamb.py +7 -4
  230. mindspore/nn/optim/lars.py +1 -1
  231. mindspore/nn/optim/lazyadam.py +5 -3
  232. mindspore/nn/optim/momentum.py +2 -1
  233. mindspore/nn/optim/optimizer.py +53 -4
  234. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  235. mindspore/nn/optim/rmsprop.py +4 -3
  236. mindspore/nn/optim/rprop.py +23 -12
  237. mindspore/nn/optim/sgd.py +26 -11
  238. mindspore/nn/optim/thor.py +9 -7
  239. mindspore/nn/probability/bijector/bijector.py +5 -5
  240. mindspore/nn/probability/bijector/power_transform.py +27 -27
  241. mindspore/nn/probability/bijector/softplus.py +3 -3
  242. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  243. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  244. mindspore/nn/probability/distribution/beta.py +3 -3
  245. mindspore/nn/probability/distribution/categorical.py +7 -7
  246. mindspore/nn/probability/distribution/cauchy.py +0 -1
  247. mindspore/nn/probability/distribution/distribution.py +3 -3
  248. mindspore/nn/probability/distribution/gamma.py +3 -3
  249. mindspore/nn/probability/distribution/geometric.py +4 -4
  250. mindspore/nn/probability/distribution/gumbel.py +4 -4
  251. mindspore/nn/probability/distribution/log_normal.py +2 -2
  252. mindspore/nn/probability/distribution/logistic.py +2 -2
  253. mindspore/nn/probability/distribution/poisson.py +4 -4
  254. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  255. mindspore/nn/probability/distribution/uniform.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +84 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  274. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  275. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  276. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  277. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  278. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  279. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  280. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  281. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  283. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  284. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  285. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  286. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  287. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  288. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  289. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  290. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  291. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  292. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  293. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  294. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  295. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  297. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  298. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  299. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  300. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  301. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  302. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  303. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  304. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  305. mindspore/ops/_primitive_cache.py +1 -1
  306. mindspore/ops/_tracefunc.py +45 -13
  307. mindspore/ops/_utils/utils.py +6 -1
  308. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  309. mindspore/ops/_vmap/vmap_base.py +3 -3
  310. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  311. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  312. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  313. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  314. mindspore/ops/arg_dtype_cast.py +54 -0
  315. mindspore/ops/composite/base.py +37 -10
  316. mindspore/ops/composite/math_ops.py +5 -4
  317. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  318. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  319. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  320. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  321. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  324. mindspore/ops/deprecated.py +304 -0
  325. mindspore/ops/function/__init__.py +4 -1
  326. mindspore/ops/function/array_func.py +174 -193
  327. mindspore/ops/function/clip_func.py +81 -13
  328. mindspore/ops/function/debug_func.py +1 -1
  329. mindspore/ops/function/grad/grad_func.py +18 -9
  330. mindspore/ops/function/image_func.py +10 -4
  331. mindspore/ops/function/linalg_func.py +5 -5
  332. mindspore/ops/function/math_func.py +575 -386
  333. mindspore/ops/function/nn_func.py +568 -260
  334. mindspore/ops/function/random_func.py +88 -57
  335. mindspore/ops/function/sparse_func.py +1 -1
  336. mindspore/ops/function/sparse_unary_func.py +14 -12
  337. mindspore/ops/function/vmap_func.py +6 -5
  338. mindspore/ops/functional.py +15 -10
  339. mindspore/ops/op_info_register.py +244 -25
  340. mindspore/ops/operations/__init__.py +28 -19
  341. mindspore/ops/operations/_grad_ops.py +72 -7
  342. mindspore/ops/operations/_inner_ops.py +350 -17
  343. mindspore/ops/operations/_quant_ops.py +4 -8
  344. mindspore/ops/operations/_sequence_ops.py +42 -0
  345. mindspore/ops/operations/array_ops.py +68 -282
  346. mindspore/ops/operations/comm_ops.py +107 -59
  347. mindspore/ops/operations/custom_ops.py +94 -70
  348. mindspore/ops/operations/debug_ops.py +8 -4
  349. mindspore/ops/operations/image_ops.py +18 -12
  350. mindspore/ops/operations/inner_ops.py +26 -3
  351. mindspore/ops/operations/math_ops.py +189 -141
  352. mindspore/ops/operations/nn_ops.py +794 -489
  353. mindspore/ops/operations/other_ops.py +0 -22
  354. mindspore/ops/operations/random_ops.py +53 -111
  355. mindspore/ops/operations/sparse_ops.py +3 -1
  356. mindspore/ops/primitive.py +24 -18
  357. mindspore/parallel/_auto_parallel_context.py +68 -8
  358. mindspore/parallel/_cost_model_context.py +2 -2
  359. mindspore/parallel/_offload_context.py +17 -3
  360. mindspore/parallel/_parallel_serialization.py +12 -5
  361. mindspore/parallel/_ps_context.py +12 -0
  362. mindspore/parallel/_tensor.py +18 -13
  363. mindspore/parallel/_transformer/layers.py +5 -3
  364. mindspore/parallel/_transformer/loss.py +1 -0
  365. mindspore/parallel/_transformer/moe.py +2 -2
  366. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  367. mindspore/parallel/_transformer/transformer.py +23 -3
  368. mindspore/parallel/_utils.py +11 -7
  369. mindspore/parallel/algo_parameter_config.py +85 -5
  370. mindspore/parallel/checkpoint_transform.py +19 -12
  371. mindspore/parallel/shard.py +21 -14
  372. mindspore/profiler/common/struct_type.py +3 -3
  373. mindspore/profiler/common/util.py +4 -2
  374. mindspore/profiler/envprofiling.py +1 -1
  375. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  376. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  377. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  378. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  379. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  380. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  381. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  382. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  383. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  384. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  385. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  386. mindspore/profiler/parser/flops_parser.py +15 -11
  387. mindspore/profiler/parser/framework_parser.py +38 -22
  388. mindspore/profiler/parser/hccl_parser.py +16 -12
  389. mindspore/profiler/parser/integrator.py +22 -11
  390. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  391. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  392. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  393. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  394. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  395. mindspore/profiler/parser/optime_parser.py +1 -1
  396. mindspore/profiler/parser/profiler_info.py +21 -2
  397. mindspore/profiler/parser/step_trace_parser.py +11 -14
  398. mindspore/profiler/profiling.py +179 -89
  399. mindspore/rewrite/api/node.py +102 -19
  400. mindspore/rewrite/api/node_type.py +5 -1
  401. mindspore/rewrite/api/pattern_engine.py +1 -1
  402. mindspore/rewrite/api/scoped_value.py +9 -17
  403. mindspore/rewrite/api/symbol_tree.py +131 -47
  404. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  405. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  406. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  407. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  408. mindspore/rewrite/common/rewrite_elog.py +5 -1
  409. mindspore/rewrite/namer.py +33 -24
  410. mindspore/rewrite/namespace.py +14 -5
  411. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  412. mindspore/rewrite/node/call_function.py +79 -0
  413. mindspore/rewrite/node/cell_container.py +135 -0
  414. mindspore/rewrite/node/control_flow.py +88 -0
  415. mindspore/rewrite/{node.py → node/node.py} +273 -234
  416. mindspore/rewrite/node/node_manager.py +254 -0
  417. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  418. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  419. mindspore/rewrite/parsers/assign_parser.py +216 -221
  420. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  421. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  422. mindspore/rewrite/parsers/constant_parser.py +9 -6
  423. mindspore/rewrite/parsers/container_parser.py +9 -7
  424. mindspore/rewrite/parsers/for_parser.py +36 -15
  425. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  426. mindspore/rewrite/parsers/if_parser.py +28 -24
  427. mindspore/rewrite/parsers/module_parser.py +196 -25
  428. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  429. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  430. mindspore/rewrite/parsers/return_parser.py +6 -6
  431. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  432. mindspore/rewrite/sparsify/utils.py +1 -1
  433. mindspore/rewrite/symbol_tree.py +523 -578
  434. mindspore/rewrite/symbol_tree_builder.py +9 -193
  435. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  436. mindspore/run_check/_check_version.py +6 -4
  437. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  438. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  439. mindspore/scipy/linalg.py +1 -1
  440. mindspore/scipy/optimize/minimize.py +7 -3
  441. mindspore/train/_utils.py +7 -3
  442. mindspore/train/amp.py +323 -123
  443. mindspore/train/anf_ir_pb2.py +14 -2
  444. mindspore/train/callback/_backup_and_restore.py +2 -12
  445. mindspore/train/callback/_callback.py +29 -4
  446. mindspore/train/callback/_checkpoint.py +23 -8
  447. mindspore/train/callback/_early_stop.py +2 -2
  448. mindspore/train/callback/_landscape.py +4 -4
  449. mindspore/train/callback/_loss_monitor.py +2 -2
  450. mindspore/train/callback/_on_request_exit.py +2 -2
  451. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  452. mindspore/train/callback/_summary_collector.py +15 -8
  453. mindspore/train/callback/_time_monitor.py +58 -5
  454. mindspore/train/data_sink.py +5 -11
  455. mindspore/train/dataset_helper.py +84 -57
  456. mindspore/train/loss_scale_manager.py +2 -2
  457. mindspore/train/metrics/__init__.py +3 -3
  458. mindspore/train/metrics/cosine_similarity.py +1 -1
  459. mindspore/train/metrics/hausdorff_distance.py +3 -2
  460. mindspore/train/metrics/mean_surface_distance.py +3 -2
  461. mindspore/train/metrics/metric.py +39 -19
  462. mindspore/train/metrics/roc.py +2 -2
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  464. mindspore/train/mind_ir_pb2.py +85 -36
  465. mindspore/train/model.py +187 -47
  466. mindspore/train/serialization.py +487 -161
  467. mindspore/train/summary/_summary_adapter.py +1 -1
  468. mindspore/train/summary/_writer_pool.py +3 -2
  469. mindspore/train/summary/summary_record.py +37 -17
  470. mindspore/train/train_thor/convert_utils.py +3 -3
  471. mindspore/train/train_thor/dataset_helper.py +1 -1
  472. mindspore/version.py +1 -1
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  474. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +477 -517
  475. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  476. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  477. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  478. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  479. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  480. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  481. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  482. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  483. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  484. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  485. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  486. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  487. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  488. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  489. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  490. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  491. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  492. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  493. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  494. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  495. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  496. mindspore/_extends/graph_kernel/expander.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  498. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  499. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  500. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  501. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  502. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  503. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  504. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  505. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  507. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  508. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  509. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  510. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  511. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  514. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  516. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  517. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  518. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  519. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  520. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  521. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  522. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  523. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  524. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  526. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  527. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  528. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  529. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  531. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  533. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  534. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  535. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  536. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  537. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  538. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  539. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  540. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  541. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  542. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  543. mindspore/dataset/datapreprocess/__init__.py +0 -20
  544. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  545. mindspore/include/api/net.h +0 -142
  546. mindspore/nn/lr_scheduler.py +0 -262
  547. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  548. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  549. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  550. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  551. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  552. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  553. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  554. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  555. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  556. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  559. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  567. mindspore/rewrite/node_visitor.py +0 -44
  568. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  569. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -1,203 +1,181 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """adamw"""
16
- from __future__ import absolute_import
17
-
18
- from mindspore.ops import functional as F, operations as P
19
- from mindspore.common.parameter import Parameter, ParameterTuple
20
- from mindspore.common.tensor import Tensor
21
- import mindspore.common.dtype as mstype
22
- from mindspore.nn.optim_ex.optimizer import Optimizer
23
- from mindspore import ops
24
-
25
-
26
- class AdamW(Optimizer):
27
- r"""
28
- Implements Adam Weight Decay algorithm.
29
-
30
- .. math::
31
- \begin{aligned}
32
- &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
33
- \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
34
- \: \epsilon \text{ (epsilon)} \\
35
- &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
36
- \: \textit{maximize} \\
37
- &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
38
- \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
39
- &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
40
- &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
41
- &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
42
- &\hspace{5mm}\textbf{else} \\
43
- &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
44
- &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
45
- &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
46
- &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
47
- &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
48
- &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
49
- &\hspace{5mm}\textbf{if} \: amsgrad \\
50
- &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
51
- \widehat{v_t}) \\
52
- &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
53
- \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
54
- &\hspace{5mm}\textbf{else} \\
55
- &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
56
- \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
57
- &\bf{return} \: \theta_t \\[-1.ex]
58
- \end{aligned}
59
-
60
- .. warning::
61
- This is an experimental optimizer API that is subject to change.
62
- This module must be used with lr scheduler module in `LRScheduler Class
63
- <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#lrscheduler>`_ .
64
-
65
- Args:
66
- params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
67
- parameter groups
68
- lr (Union[int, float, Tensor], optional): learning rate. Default: ``1e-3``.
69
- betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
70
- Default: ``(0.9, 0.999)``.
71
- eps (float, optional): term added to the denominator to improve
72
- numerical stability. Default: ``1e-8``.
73
- weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
74
- amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.
75
-
76
- Keyword Args:
77
- maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
78
- Default: ``False``.
79
-
80
- Inputs:
81
- - **gradients** (tuple[Tensor]) - The gradients of `params`.
82
-
83
- Raises:
84
- ValueError: If the learning rate is not int, float or Tensor.
85
- ValueError: If the learning rate is less than 0.
86
- ValueError: If the `eps` is less than 0.0.
87
- ValueError: If the `betas` not in the range of 0-1.
88
- ValueError: If the `weight_decay` is less than 0.
89
-
90
- Supported Platforms:
91
- ``Ascend`` ``GPU`` ``CPU``
92
-
93
- Examples:
94
- >>> import mindspore
95
- >>> from mindspore import nn
96
- >>> # Define the network structure of LeNet5. Refer to
97
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
98
- >>> net = LeNet5()
99
- >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
100
- >>> optimizer = nn.optim_ex.AdamW(net.trainable_params(), lr=0.1)
101
- >>> def forward_fn(data, label):
102
- ... logits = net(data)
103
- ... loss = loss_fn(logits, label)
104
- ... return loss, logits
105
- >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
106
- >>> def train_step(data, label):
107
- ... (loss, _), grads = grad_fn(data, label)
108
- ... optimizer(grads)
109
- ... return loss
110
- """
111
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
112
- weight_decay=1e-2, amsgrad=False, *, maximize=False):
113
- if lr < 0.0:
114
- raise ValueError("Invalid learning rate: {}".format(lr))
115
- if eps < 0.0:
116
- raise ValueError("Invalid epsilon value: {}".format(eps))
117
- if not 0.0 <= betas[0] < 1.0:
118
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
119
- if not 0.0 <= betas[1] < 1.0:
120
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
121
- if weight_decay < 0.0:
122
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
123
-
124
- defaults = dict(lr=lr, betas=betas, eps=eps,
125
- weight_decay=weight_decay, amsgrad=amsgrad,
126
- maximize=maximize)
127
- super(AdamW, self).__init__(params, defaults)
128
-
129
- self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
130
- self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
131
- self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
132
- self.state_step = ParameterTuple(Parameter(Tensor(0, mstype.int32), "step_"+str(i))
133
- for i in range(len(self.parameters)))
134
- self.increase_tensor = Tensor(1, mstype.int32)
135
-
136
- self.op_mul = P.Mul()
137
- self.assignadd = P.AssignAdd()
138
- self.op_pow = P.Pow()
139
- self.op_sqrt = P.Sqrt()
140
- self.op_maximum = P.Maximum()
141
- self.op_cast = P.Cast()
142
-
143
- def construct(self, gradients):
144
- for group_id, group in enumerate(self.param_groups):
145
- params = []
146
- grads = []
147
- exp_avgs = []
148
- exp_avg_sqs = []
149
- max_exp_avg_sqs = []
150
- state_steps = []
151
- amsgrad = group["amsgrad"]
152
- beta1, beta2 = group['betas']
153
- params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps = \
154
- self._init_group(group, gradients, params, grads, amsgrad, exp_avgs,
155
- exp_avg_sqs, max_exp_avg_sqs, state_steps, group_id)
156
-
157
- self.apply_adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
158
- amsgrad, beta1, beta2, group['lr'], group['weight_decay'], group['eps'],
159
- group["maximize"], group["grad_centralization"])
160
-
161
- def apply_adamw(self, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
162
- amsgrad, beta1, beta2, lr, weight_decay, eps, maximize, grad_centralization):
163
- grads = self._gradients_centralization(grad_centralization, grads)
164
- for i, param in enumerate(params):
165
- grad = grads[i] if not maximize else -grads[i]
166
- exp_avg = exp_avgs[i]
167
- exp_avg_sq = exp_avg_sqs[i]
168
- step_t = state_steps[i]
169
-
170
- next_param = self.op_mul(param, F.tuple_to_array((1.0,)) - lr * weight_decay)
171
- F.assign(exp_avg, self.op_mul(exp_avg, beta1) + self.op_mul(grad, 1-beta1))
172
- F.assign(exp_avg_sq, ops.addcmul(self.op_mul(exp_avg_sq, beta2), grad, grad, 1-beta2))
173
- step_t = F.depend(step_t, self.assignadd(step_t, self.increase_tensor))
174
-
175
- bias_correction1 = F.tuple_to_array((1.0,)) - self.op_pow(beta1, step_t)
176
- bias_correction2 = F.tuple_to_array((1.0,)) - self.op_pow(beta2, step_t)
177
- step_size = lr / bias_correction1
178
- bias_correction2_sqrt = self.op_sqrt(bias_correction2)
179
-
180
- if amsgrad:
181
- next_max_exp_avg = self.op_maximum(max_exp_avg_sqs[i], exp_avg_sq)
182
- denom = self.op_sqrt(next_max_exp_avg) / bias_correction2_sqrt + eps
183
- F.assign(max_exp_avg_sqs[i], next_max_exp_avg)
184
- else:
185
- denom = self.op_sqrt(exp_avg_sq) / bias_correction2_sqrt + eps
186
-
187
- return_param = next_param - self.op_mul(exp_avg / denom, step_size)
188
- F.assign(param, return_param)
189
-
190
- def _init_group(self, group, gradients, params, grads, amsgrad, exp_avgs, exp_avg_sqs,
191
- max_exp_avg_sqs, state_steps, group_id):
192
- """ Initialize group params. """
193
- p_id = self.group_start_id[group_id]
194
- for i, param in enumerate(group["params"]):
195
- grad = gradients[p_id+i]
196
- grads.append(grad)
197
- params.append(param)
198
- exp_avgs.append(self.exp_avg[p_id+i])
199
- exp_avg_sqs.append(self.exp_avg_sq[p_id+i])
200
- if amsgrad:
201
- max_exp_avg_sqs.append(self.max_exp_avg_sq[p_id+i])
202
- state_steps.append(self.state_step[p_id+i])
203
- return params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """adamw"""
16
+ from __future__ import absolute_import
17
+
18
+ from mindspore.ops import functional as F, composite as C, operations as P
19
+ from mindspore.common.parameter import Parameter
20
+ from mindspore.common.tensor import Tensor
21
+ import mindspore.common.dtype as mstype
22
+ from mindspore.experimental.optim.optimizer import Optimizer
23
+ from mindspore import ops
24
+
25
+ _adamw_opt = C.MultitypeFuncGraph("adamw_opt")
26
+
27
+ op_mul = P.Mul()
28
+ op_pow = P.Pow()
29
+ op_sqrt = P.Sqrt()
30
+ op_maximum = P.Maximum()
31
+
32
+
33
+ @_adamw_opt.register("Float", "Tensor", "Bool", "Float", "Tensor", "Float", "Float", "Tensor", "Tensor",
34
+ "Tensor", "Tensor", "Tensor")
35
+ def _run_adamw_opt(weight_decay, lr, amsgrad, eps, state_step, beta1, beta2, param, grad,
36
+ exp_avg, exp_avg_sq, max_exp_avg_sq):
37
+ """Apply adamw optimizer to the weight parameter."""
38
+ success = True
39
+ next_param = op_mul(param, 1 - lr * weight_decay)
40
+ F.assign(exp_avg, op_mul(exp_avg, beta1) + op_mul(grad, 1 - beta1))
41
+ F.assign(exp_avg_sq, ops.addcmul(op_mul(exp_avg_sq, beta2), grad, grad, 1 - beta2))
42
+ bias_correction1 = 1 - op_pow(beta1, state_step)
43
+ bias_correction2 = 1 - op_pow(beta2, state_step)
44
+ step_size = lr / bias_correction1
45
+ bias_correction2_sqrt = op_sqrt(bias_correction2)
46
+
47
+ if amsgrad:
48
+ next_max_exp_avg = op_maximum(max_exp_avg_sq, exp_avg_sq)
49
+ denom = op_sqrt(next_max_exp_avg) / bias_correction2_sqrt + eps
50
+ F.assign(max_exp_avg_sq, next_max_exp_avg)
51
+ else:
52
+ denom = op_sqrt(exp_avg_sq) / bias_correction2_sqrt + eps
53
+
54
+ return_param = next_param - op_mul(exp_avg / denom, step_size)
55
+ F.assign(param, return_param)
56
+ return success
57
+
58
+
59
+ class AdamW(Optimizer):
60
+ r"""
61
+ Implements Adam Weight Decay algorithm.
62
+
63
+ .. math::
64
+ \begin{aligned}
65
+ &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
66
+ \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
67
+ \: \epsilon \text{ (epsilon)} \\
68
+ &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
69
+ \: \textit{maximize} \\
70
+ &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
71
+ \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
72
+ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
73
+ &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
74
+ &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
75
+ &\hspace{5mm}\textbf{else} \\
76
+ &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
77
+ &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
78
+ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
79
+ &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
80
+ &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
81
+ &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
82
+ &\hspace{5mm}\textbf{if} \: amsgrad \\
83
+ &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
84
+ \widehat{v_t}) \\
85
+ &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
86
+ \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
87
+ &\hspace{5mm}\textbf{else} \\
88
+ &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
89
+ \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
90
+ &\bf{return} \: \theta_t \\[-1.ex]
91
+ \end{aligned}
92
+
93
+ .. warning::
94
+ This is an experimental optimizer API that is subject to change.
95
+ This module must be used with lr scheduler module in `LRScheduler Class
96
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.experimental.html#lrscheduler-class>`_ .
97
+
98
+ Args:
99
+ params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
100
+ parameter groups
101
+ lr (Union[int, float, Tensor], optional): learning rate. Default: ``1e-3``.
102
+ betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
103
+ Default: ``(0.9, 0.999)``.
104
+ eps (float, optional): term added to the denominator to improve
105
+ numerical stability. Default: ``1e-8``.
106
+ weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
107
+ amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.
108
+
109
+ Keyword Args:
110
+ maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
111
+ Default: ``False``.
112
+
113
+ Inputs:
114
+ - **gradients** (tuple[Tensor]) - The gradients of `params`.
115
+
116
+ Raises:
117
+ ValueError: If the learning rate is not int, float or Tensor.
118
+ ValueError: If the learning rate is less than 0.
119
+ ValueError: If the `eps` is less than 0.0.
120
+ ValueError: If the `betas` not in the range of 0-1.
121
+ ValueError: If the `weight_decay` is less than 0.
122
+
123
+ Supported Platforms:
124
+ ``Ascend`` ``GPU`` ``CPU``
125
+
126
+ Examples:
127
+ >>> import mindspore
128
+ >>> from mindspore import nn
129
+ >>> from mindspore.experimental import optim
130
+ >>> # Define the network structure of LeNet5. Refer to
131
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
132
+ >>> net = LeNet5()
133
+ >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
134
+ >>> optimizer = optim.AdamW(net.trainable_params(), lr=0.1)
135
+ >>> def forward_fn(data, label):
136
+ ... logits = net(data)
137
+ ... loss = loss_fn(logits, label)
138
+ ... return loss, logits
139
+ >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
140
+ >>> def train_step(data, label):
141
+ ... (loss, _), grads = grad_fn(data, label)
142
+ ... optimizer(grads)
143
+ ... return loss
144
+ """
145
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
146
+ weight_decay=1e-2, amsgrad=False, *, maximize=False):
147
+ if lr < 0.0:
148
+ raise ValueError("Invalid learning rate: {}".format(lr))
149
+ if eps < 0.0:
150
+ raise ValueError("Invalid epsilon value: {}".format(eps))
151
+ if not 0.0 <= betas[0] < 1.0:
152
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
153
+ if not 0.0 <= betas[1] < 1.0:
154
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
155
+ if weight_decay < 0.0:
156
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
157
+
158
+ defaults = dict(lr=lr, betas=betas, eps=eps,
159
+ weight_decay=weight_decay, amsgrad=amsgrad,
160
+ maximize=maximize)
161
+ super(AdamW, self).__init__(params, defaults)
162
+
163
+ self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
164
+ self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
165
+ self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
166
+ self.state_step = Parameter(Tensor(0, mstype.int32), "state_step")
167
+ self.increase_tensor = Tensor(1, mstype.int32)
168
+ self.assignadd = P.AssignAdd()
169
+
170
+ def construct(self, gradients):
171
+ self.assignadd(self.state_step, self.increase_tensor)
172
+ for group_id, group in enumerate(self.param_groups):
173
+ beta1, beta2 = group['betas']
174
+ start_id = self.group_start_id[group_id]
175
+ end_id = self.group_start_id[group_id + 1]
176
+ grads = gradients[start_id: end_id] if not group.get("maximize") else -gradients[start_id: end_id]
177
+ self.hyper_map(F.partial(_adamw_opt, group.get("weight_decay"), group.get("lr"), group.get("amsgrad"),
178
+ group.get("eps"), self.state_step, beta1, beta2),
179
+ self.parameters[start_id: end_id], grads, self.exp_avg[start_id: end_id],
180
+ self.exp_avg_sq[start_id: end_id], self.max_exp_avg_sq[start_id: end_id])
181
+ return True