mindspore-2.1.0-cp39-cp39-win_amd64.whl → mindspore-2.2.10-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (505)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +4 -1
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -1
  9. mindspore/_checkparam.py +23 -29
  10. mindspore/_extends/graph_kernel/__init__.py +0 -1
  11. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  12. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  13. mindspore/_extends/graph_kernel/splitter.py +4 -11
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  21. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  22. mindspore/_extends/parse/__init__.py +12 -15
  23. mindspore/_extends/parse/namespace.py +7 -33
  24. mindspore/_extends/parse/parser.py +61 -71
  25. mindspore/_extends/parse/resources.py +1 -1
  26. mindspore/_extends/parse/standard_method.py +74 -104
  27. mindspore/_extends/parse/trope.py +1 -1
  28. mindspore/_extends/remote/kernel_build_server.py +25 -7
  29. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  30. mindspore/_install_custom.py +43 -0
  31. mindspore/amp.py +47 -11
  32. mindspore/atlprov.dll +0 -0
  33. mindspore/boost/boost.py +1 -8
  34. mindspore/boost/boost_cell_wrapper.py +3 -2
  35. mindspore/boost/grad_accumulation.py +1 -1
  36. mindspore/boost/group_loss_scale_manager.py +8 -7
  37. mindspore/c1.dll +0 -0
  38. mindspore/c1xx.dll +0 -0
  39. mindspore/c2.dll +0 -0
  40. mindspore/common/__init__.py +5 -3
  41. mindspore/common/_jit_fallback_utils.py +6 -0
  42. mindspore/common/_register_for_adapter.py +2 -0
  43. mindspore/common/_register_for_tensor.py +2 -2
  44. mindspore/common/_stub_tensor.py +13 -0
  45. mindspore/common/_utils.py +13 -0
  46. mindspore/common/api.py +174 -259
  47. mindspore/common/auto_dynamic_shape.py +494 -0
  48. mindspore/common/dtype.py +18 -11
  49. mindspore/common/dump.py +6 -4
  50. mindspore/common/initializer.py +14 -14
  51. mindspore/common/jit_config.py +33 -15
  52. mindspore/common/lazy_inline.py +126 -7
  53. mindspore/common/mindir_util.py +101 -0
  54. mindspore/common/parameter.py +51 -41
  55. mindspore/common/seed.py +4 -4
  56. mindspore/common/sparse_tensor.py +13 -14
  57. mindspore/common/tensor.py +243 -165
  58. mindspore/communication/__init__.py +7 -4
  59. mindspore/communication/_comm_helper.py +83 -4
  60. mindspore/communication/management.py +152 -84
  61. mindspore/config/op_info.config +14 -3
  62. mindspore/context.py +152 -61
  63. mindspore/dataset/__init__.py +5 -5
  64. mindspore/dataset/audio/__init__.py +2 -2
  65. mindspore/dataset/audio/transforms.py +52 -52
  66. mindspore/dataset/callback/ds_callback.py +16 -2
  67. mindspore/dataset/core/config.py +68 -51
  68. mindspore/dataset/engine/cache_client.py +28 -5
  69. mindspore/dataset/engine/datasets.py +250 -112
  70. mindspore/dataset/engine/datasets_audio.py +43 -211
  71. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  72. mindspore/dataset/engine/datasets_text.py +43 -67
  73. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  74. mindspore/dataset/engine/datasets_vision.py +219 -1029
  75. mindspore/dataset/engine/iterators.py +11 -4
  76. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  77. mindspore/dataset/engine/obs/util.py +3 -0
  78. mindspore/dataset/engine/samplers.py +1 -1
  79. mindspore/dataset/engine/validators.py +19 -5
  80. mindspore/dataset/text/__init__.py +3 -3
  81. mindspore/dataset/text/transforms.py +101 -127
  82. mindspore/dataset/text/utils.py +205 -138
  83. mindspore/dataset/transforms/__init__.py +1 -1
  84. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  85. mindspore/dataset/transforms/transforms.py +95 -40
  86. mindspore/dataset/utils/browse_dataset.py +8 -2
  87. mindspore/dataset/utils/line_reader.py +17 -19
  88. mindspore/dataset/vision/__init__.py +3 -3
  89. mindspore/dataset/vision/c_transforms.py +6 -3
  90. mindspore/dataset/vision/transforms.py +409 -287
  91. mindspore/dataset/vision/utils.py +13 -14
  92. mindspore/dataset/vision/validators.py +11 -1
  93. mindspore/dnnl.dll +0 -0
  94. mindspore/dpcmi.dll +0 -0
  95. mindspore/experimental/map_parameter.py +14 -0
  96. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  97. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  98. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  99. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  100. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  101. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  102. mindspore/gen_ops.py +273 -0
  103. mindspore/include/OWNERS +0 -1
  104. mindspore/include/api/data_type.h +2 -1
  105. mindspore/include/api/graph.h +0 -15
  106. mindspore/include/api/kernel.h +2 -0
  107. mindspore/include/api/kernel_api.h +37 -12
  108. mindspore/include/api/model.h +17 -14
  109. mindspore/include/api/status.h +8 -3
  110. mindspore/include/api/types.h +37 -4
  111. mindspore/include/c_api/ms/abstract.h +67 -0
  112. mindspore/include/c_api/ms/attribute.h +197 -0
  113. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  114. mindspore/include/c_api/ms/base/macros.h +32 -0
  115. mindspore/include/c_api/ms/base/status.h +33 -0
  116. mindspore/include/c_api/ms/base/types.h +282 -0
  117. mindspore/include/c_api/ms/context.h +102 -0
  118. mindspore/include/c_api/ms/graph.h +160 -0
  119. mindspore/include/c_api/ms/node.h +606 -0
  120. mindspore/include/c_api/ms/tensor.h +161 -0
  121. mindspore/include/c_api/ms/value.h +84 -0
  122. mindspore/include/dataset/constants.h +6 -5
  123. mindspore/include/dataset/execute.h +23 -13
  124. mindspore/include/dataset/text.h +26 -26
  125. mindspore/include/dataset/transforms.h +13 -13
  126. mindspore/include/dataset/vision.h +60 -60
  127. mindspore/include/dataset/vision_ascend.h +5 -6
  128. mindspore/include/dataset/vision_lite.h +17 -17
  129. mindspore/jpeg62.dll +0 -0
  130. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  131. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  132. mindspore/mindspore_backend.dll +0 -0
  133. mindspore/mindspore_common.dll +0 -0
  134. mindspore/mindspore_core.dll +0 -0
  135. mindspore/mindspore_glog.dll +0 -0
  136. mindspore/mindspore_shared_lib.dll +0 -0
  137. mindspore/msobj140.dll +0 -0
  138. mindspore/mspdb140.dll +0 -0
  139. mindspore/mspdbcore.dll +0 -0
  140. mindspore/mspdbst.dll +0 -0
  141. mindspore/mspft140.dll +0 -0
  142. mindspore/msvcdis140.dll +0 -0
  143. mindspore/msvcp140_1.dll +0 -0
  144. mindspore/msvcp140_2.dll +0 -0
  145. mindspore/msvcp140_atomic_wait.dll +0 -0
  146. mindspore/msvcp140_codecvt_ids.dll +0 -0
  147. mindspore/nn/__init__.py +0 -2
  148. mindspore/nn/cell.py +313 -74
  149. mindspore/nn/dynamic_lr.py +21 -21
  150. mindspore/nn/layer/activation.py +22 -30
  151. mindspore/nn/layer/basic.py +15 -13
  152. mindspore/nn/layer/channel_shuffle.py +1 -1
  153. mindspore/nn/layer/container.py +271 -9
  154. mindspore/nn/layer/conv.py +323 -204
  155. mindspore/nn/layer/dense.py +8 -5
  156. mindspore/nn/layer/embedding.py +33 -27
  157. mindspore/nn/layer/flash_attention.py +141 -88
  158. mindspore/nn/layer/image.py +8 -6
  159. mindspore/nn/layer/math.py +16 -25
  160. mindspore/nn/layer/normalization.py +107 -66
  161. mindspore/nn/layer/padding.py +1 -1
  162. mindspore/nn/layer/pooling.py +131 -109
  163. mindspore/nn/layer/rnn_cells.py +27 -22
  164. mindspore/nn/layer/rnns.py +13 -16
  165. mindspore/nn/layer/thor_layer.py +1 -1
  166. mindspore/nn/layer/transformer.py +221 -154
  167. mindspore/nn/learning_rate_schedule.py +9 -1
  168. mindspore/nn/loss/loss.py +235 -174
  169. mindspore/nn/optim/ada_grad.py +2 -1
  170. mindspore/nn/optim/adadelta.py +1 -0
  171. mindspore/nn/optim/adafactor.py +2 -1
  172. mindspore/nn/optim/adam.py +7 -4
  173. mindspore/nn/optim/adamax.py +3 -2
  174. mindspore/nn/optim/adasum.py +2 -2
  175. mindspore/nn/optim/asgd.py +2 -3
  176. mindspore/nn/optim/ftrl.py +6 -5
  177. mindspore/nn/optim/lamb.py +7 -4
  178. mindspore/nn/optim/lars.py +1 -1
  179. mindspore/nn/optim/lazyadam.py +5 -3
  180. mindspore/nn/optim/momentum.py +2 -1
  181. mindspore/nn/optim/optimizer.py +53 -4
  182. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  183. mindspore/nn/optim/rmsprop.py +4 -3
  184. mindspore/nn/optim/rprop.py +23 -12
  185. mindspore/nn/optim/sgd.py +26 -11
  186. mindspore/nn/optim/thor.py +9 -7
  187. mindspore/nn/probability/bijector/bijector.py +5 -5
  188. mindspore/nn/probability/bijector/power_transform.py +27 -27
  189. mindspore/nn/probability/bijector/softplus.py +3 -3
  190. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  191. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  192. mindspore/nn/probability/distribution/beta.py +3 -3
  193. mindspore/nn/probability/distribution/categorical.py +7 -7
  194. mindspore/nn/probability/distribution/cauchy.py +0 -1
  195. mindspore/nn/probability/distribution/distribution.py +3 -3
  196. mindspore/nn/probability/distribution/gamma.py +3 -3
  197. mindspore/nn/probability/distribution/geometric.py +4 -4
  198. mindspore/nn/probability/distribution/gumbel.py +4 -4
  199. mindspore/nn/probability/distribution/log_normal.py +2 -2
  200. mindspore/nn/probability/distribution/logistic.py +2 -2
  201. mindspore/nn/probability/distribution/poisson.py +4 -4
  202. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  203. mindspore/nn/probability/distribution/uniform.py +6 -6
  204. mindspore/nn/wrap/cell_wrapper.py +84 -34
  205. mindspore/nn/wrap/grad_reducer.py +8 -5
  206. mindspore/nn/wrap/loss_scale.py +105 -42
  207. mindspore/numpy/array_creations.py +1 -2
  208. mindspore/numpy/array_ops.py +3 -2
  209. mindspore/numpy/utils_const.py +5 -5
  210. mindspore/opencv_core452.dll +0 -0
  211. mindspore/opencv_imgcodecs452.dll +0 -0
  212. mindspore/opencv_imgproc452.dll +0 -0
  213. mindspore/ops/_grad_experimental/__init__.py +0 -5
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  215. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  216. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  217. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  218. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  219. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  220. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  222. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  224. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  225. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  226. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  227. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  228. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  229. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  230. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  231. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  232. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  233. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  234. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  235. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  236. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  237. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  238. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  239. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  240. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  241. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  242. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  243. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  244. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  245. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  246. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  247. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  248. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  249. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  250. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  251. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  252. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  253. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  254. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  255. mindspore/ops/_primitive_cache.py +1 -1
  256. mindspore/ops/_tracefunc.py +45 -13
  257. mindspore/ops/_utils/utils.py +6 -1
  258. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  259. mindspore/ops/_vmap/vmap_base.py +3 -3
  260. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  261. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  262. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  263. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  264. mindspore/ops/arg_dtype_cast.py +54 -0
  265. mindspore/ops/composite/base.py +37 -10
  266. mindspore/ops/composite/math_ops.py +5 -4
  267. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  268. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  269. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  270. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  271. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  272. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  273. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  274. mindspore/ops/deprecated.py +304 -0
  275. mindspore/ops/function/__init__.py +4 -1
  276. mindspore/ops/function/array_func.py +174 -193
  277. mindspore/ops/function/clip_func.py +81 -13
  278. mindspore/ops/function/debug_func.py +1 -1
  279. mindspore/ops/function/grad/grad_func.py +18 -9
  280. mindspore/ops/function/image_func.py +10 -4
  281. mindspore/ops/function/linalg_func.py +5 -5
  282. mindspore/ops/function/math_func.py +575 -386
  283. mindspore/ops/function/nn_func.py +568 -260
  284. mindspore/ops/function/random_func.py +88 -57
  285. mindspore/ops/function/sparse_func.py +1 -1
  286. mindspore/ops/function/sparse_unary_func.py +14 -12
  287. mindspore/ops/function/vmap_func.py +6 -5
  288. mindspore/ops/functional.py +15 -10
  289. mindspore/ops/op_info_register.py +244 -25
  290. mindspore/ops/operations/__init__.py +28 -19
  291. mindspore/ops/operations/_grad_ops.py +72 -7
  292. mindspore/ops/operations/_inner_ops.py +350 -17
  293. mindspore/ops/operations/_quant_ops.py +4 -8
  294. mindspore/ops/operations/_sequence_ops.py +42 -0
  295. mindspore/ops/operations/array_ops.py +68 -282
  296. mindspore/ops/operations/comm_ops.py +107 -59
  297. mindspore/ops/operations/custom_ops.py +94 -70
  298. mindspore/ops/operations/debug_ops.py +8 -4
  299. mindspore/ops/operations/image_ops.py +18 -12
  300. mindspore/ops/operations/inner_ops.py +26 -3
  301. mindspore/ops/operations/math_ops.py +189 -141
  302. mindspore/ops/operations/nn_ops.py +794 -489
  303. mindspore/ops/operations/other_ops.py +0 -22
  304. mindspore/ops/operations/random_ops.py +53 -111
  305. mindspore/ops/operations/sparse_ops.py +3 -1
  306. mindspore/ops/primitive.py +24 -18
  307. mindspore/parallel/_auto_parallel_context.py +68 -8
  308. mindspore/parallel/_cost_model_context.py +2 -2
  309. mindspore/parallel/_offload_context.py +17 -3
  310. mindspore/parallel/_parallel_serialization.py +12 -5
  311. mindspore/parallel/_ps_context.py +12 -0
  312. mindspore/parallel/_tensor.py +18 -13
  313. mindspore/parallel/_transformer/layers.py +5 -3
  314. mindspore/parallel/_transformer/loss.py +1 -0
  315. mindspore/parallel/_transformer/moe.py +2 -2
  316. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  317. mindspore/parallel/_transformer/transformer.py +23 -3
  318. mindspore/parallel/_utils.py +11 -7
  319. mindspore/parallel/algo_parameter_config.py +85 -5
  320. mindspore/parallel/checkpoint_transform.py +19 -12
  321. mindspore/parallel/shard.py +21 -14
  322. mindspore/pgodb140.dll +0 -0
  323. mindspore/pgort140.dll +0 -0
  324. mindspore/profiler/common/struct_type.py +3 -3
  325. mindspore/profiler/common/util.py +4 -2
  326. mindspore/profiler/envprofiling.py +1 -1
  327. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  328. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  329. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  330. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  331. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  332. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  333. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  334. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  335. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  336. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  337. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  338. mindspore/profiler/parser/flops_parser.py +15 -11
  339. mindspore/profiler/parser/framework_parser.py +38 -22
  340. mindspore/profiler/parser/hccl_parser.py +16 -12
  341. mindspore/profiler/parser/integrator.py +22 -11
  342. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  343. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  344. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  345. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  346. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  347. mindspore/profiler/parser/optime_parser.py +1 -1
  348. mindspore/profiler/parser/profiler_info.py +21 -2
  349. mindspore/profiler/parser/step_trace_parser.py +11 -14
  350. mindspore/profiler/profiling.py +179 -89
  351. mindspore/rewrite/api/node.py +102 -19
  352. mindspore/rewrite/api/node_type.py +5 -1
  353. mindspore/rewrite/api/pattern_engine.py +1 -1
  354. mindspore/rewrite/api/scoped_value.py +9 -17
  355. mindspore/rewrite/api/symbol_tree.py +131 -47
  356. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  357. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  358. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  359. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  360. mindspore/rewrite/common/rewrite_elog.py +5 -1
  361. mindspore/rewrite/namer.py +33 -24
  362. mindspore/rewrite/namespace.py +14 -5
  363. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  364. mindspore/rewrite/node/call_function.py +79 -0
  365. mindspore/rewrite/node/cell_container.py +135 -0
  366. mindspore/rewrite/node/control_flow.py +88 -0
  367. mindspore/rewrite/{node.py → node/node.py} +273 -234
  368. mindspore/rewrite/node/node_manager.py +254 -0
  369. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  370. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  371. mindspore/rewrite/parsers/assign_parser.py +216 -221
  372. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  373. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  374. mindspore/rewrite/parsers/constant_parser.py +9 -6
  375. mindspore/rewrite/parsers/container_parser.py +9 -7
  376. mindspore/rewrite/parsers/for_parser.py +36 -15
  377. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  378. mindspore/rewrite/parsers/if_parser.py +28 -24
  379. mindspore/rewrite/parsers/module_parser.py +196 -25
  380. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  381. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  382. mindspore/rewrite/parsers/return_parser.py +6 -6
  383. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  384. mindspore/rewrite/sparsify/utils.py +1 -1
  385. mindspore/rewrite/symbol_tree.py +523 -578
  386. mindspore/rewrite/symbol_tree_builder.py +9 -193
  387. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  388. mindspore/run_check/_check_version.py +6 -4
  389. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  390. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  391. mindspore/tbbmalloc.dll +0 -0
  392. mindspore/tinyxml2.dll +0 -0
  393. mindspore/train/_utils.py +7 -3
  394. mindspore/train/amp.py +323 -123
  395. mindspore/train/anf_ir_pb2.py +14 -2
  396. mindspore/train/callback/_backup_and_restore.py +2 -12
  397. mindspore/train/callback/_callback.py +29 -4
  398. mindspore/train/callback/_checkpoint.py +23 -8
  399. mindspore/train/callback/_early_stop.py +2 -2
  400. mindspore/train/callback/_landscape.py +4 -4
  401. mindspore/train/callback/_loss_monitor.py +2 -2
  402. mindspore/train/callback/_on_request_exit.py +2 -2
  403. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  404. mindspore/train/callback/_summary_collector.py +15 -8
  405. mindspore/train/callback/_time_monitor.py +58 -5
  406. mindspore/train/data_sink.py +5 -11
  407. mindspore/train/dataset_helper.py +84 -57
  408. mindspore/train/loss_scale_manager.py +2 -2
  409. mindspore/train/metrics/__init__.py +3 -3
  410. mindspore/train/metrics/cosine_similarity.py +1 -1
  411. mindspore/train/metrics/hausdorff_distance.py +3 -2
  412. mindspore/train/metrics/mean_surface_distance.py +3 -2
  413. mindspore/train/metrics/metric.py +39 -19
  414. mindspore/train/metrics/roc.py +2 -2
  415. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  416. mindspore/train/mind_ir_pb2.py +85 -36
  417. mindspore/train/model.py +187 -47
  418. mindspore/train/serialization.py +487 -161
  419. mindspore/train/summary/_summary_adapter.py +1 -1
  420. mindspore/train/summary/_writer_pool.py +3 -2
  421. mindspore/train/summary/summary_record.py +37 -17
  422. mindspore/train/train_thor/convert_utils.py +3 -3
  423. mindspore/train/train_thor/dataset_helper.py +1 -1
  424. mindspore/turbojpeg.dll +0 -0
  425. mindspore/vcmeta.dll +0 -0
  426. mindspore/vcruntime140.dll +0 -0
  427. mindspore/vcruntime140_1.dll +0 -0
  428. mindspore/version.py +1 -1
  429. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +5 -3
  430. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +433 -479
  431. mindspore/_extends/graph_kernel/expander.py +0 -80
  432. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  433. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  434. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  435. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  436. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  437. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  438. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  439. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  440. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  441. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  442. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  443. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  444. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  445. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  446. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  447. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  448. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  449. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  450. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  451. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  452. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  453. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  454. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  455. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  456. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  457. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  458. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  459. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  460. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  461. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  462. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  463. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  464. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  465. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  466. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  467. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  468. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  469. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  470. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  471. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  472. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  473. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  474. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  475. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  476. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  477. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  478. mindspore/dataset/datapreprocess/__init__.py +0 -20
  479. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  480. mindspore/include/api/net.h +0 -142
  481. mindspore/nn/lr_scheduler.py +0 -262
  482. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  483. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  484. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  485. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  486. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  487. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  488. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  489. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  490. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  491. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  492. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  493. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  494. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  495. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  496. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  497. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  498. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  499. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  500. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  501. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  502. mindspore/rewrite/node_visitor.py +0 -44
  503. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  504. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
  505. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
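The hunk below corresponds to entry 98 above, mindspore/{nn/optim_ex → experimental/optim}/adamw.py (+181 -203): the experimental optimizers move from mindspore.nn.optim_ex to mindspore.experimental.optim. A minimal migration sketch, assuming only the public entry points shown in the renamed files and in the AdamW docstring example inside the diff (nn.Dense here is a stand-in for any Cell, such as the LeNet5 the docstring references):

    import mindspore.nn as nn
    net = nn.Dense(4, 2)  # stand-in network; the docstring example uses LeNet5

    # mindspore 2.1.0 (path removed in 2.2.10)
    # optimizer = nn.optim_ex.AdamW(net.trainable_params(), lr=1e-3)

    # mindspore 2.2.10 (new location)
    from mindspore.experimental import optim
    optimizer = optim.AdamW(net.trainable_params(), lr=1e-3)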
@@ -1,203 +1,181 @@
- # Copyright 2023 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """adamw"""
- from __future__ import absolute_import
-
- from mindspore.ops import functional as F, operations as P
- from mindspore.common.parameter import Parameter, ParameterTuple
- from mindspore.common.tensor import Tensor
- import mindspore.common.dtype as mstype
- from mindspore.nn.optim_ex.optimizer import Optimizer
- from mindspore import ops
-
-
- class AdamW(Optimizer):
-     r"""
-     Implements Adam Weight Decay algorithm.
-
-     .. math::
-         \begin{aligned}
-             &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
-                 \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
-                 \: \epsilon \text{ (epsilon)} \\
-             &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
-                 \: \textit{maximize} \\
-             &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
-                 \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
-             &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
-             &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
-             &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
-             &\hspace{5mm}\textbf{else} \\
-             &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
-             &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
-             &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
-             &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
-             &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
-             &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
-             &\hspace{5mm}\textbf{if} \: amsgrad \\
-             &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
-                 \widehat{v_t}) \\
-             &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
-                 \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
-             &\hspace{5mm}\textbf{else} \\
-             &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
-                 \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
-             &\bf{return} \: \theta_t \\[-1.ex]
-         \end{aligned}
-
-     .. warning::
-         This is an experimental optimizer API that is subject to change.
-         This module must be used with lr scheduler module in `LRScheduler Class
-         <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#lrscheduler>`_ .
-
-     Args:
-         params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
-             parameter groups
-         lr (Union[int, float, Tensor], optional): learning rate. Default: ``1e-3``.
-         betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
-             Default: ``(0.9, 0.999)``.
-         eps (float, optional): term added to the denominator to improve
-             numerical stability. Default: ``1e-8``.
-         weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
-         amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.
-
-     Keyword Args:
-         maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
-             Default: ``False``.
-
-     Inputs:
-         - **gradients** (tuple[Tensor]) - The gradients of `params`.
-
-     Raises:
-         ValueError: If the learning rate is not int, float or Tensor.
-         ValueError: If the learning rate is less than 0.
-         ValueError: If the `eps` is less than 0.0.
-         ValueError: If the `betas` not in the range of 0-1.
-         ValueError: If the `weight_decay` is less than 0.
-
-     Supported Platforms:
-         ``Ascend`` ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore
-         >>> from mindspore import nn
-         >>> # Define the network structure of LeNet5. Refer to
-         >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
-         >>> net = LeNet5()
-         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
-         >>> optimizer = nn.optim_ex.AdamW(net.trainable_params(), lr=0.1)
-         >>> def forward_fn(data, label):
-         ...     logits = net(data)
-         ...     loss = loss_fn(logits, label)
-         ...     return loss, logits
-         >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
-         >>> def train_step(data, label):
-         ...     (loss, _), grads = grad_fn(data, label)
-         ...     optimizer(grads)
-         ...     return loss
-     """
-     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
-                  weight_decay=1e-2, amsgrad=False, *, maximize=False):
-         if lr < 0.0:
-             raise ValueError("Invalid learning rate: {}".format(lr))
-         if eps < 0.0:
-             raise ValueError("Invalid epsilon value: {}".format(eps))
-         if not 0.0 <= betas[0] < 1.0:
-             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
-         if not 0.0 <= betas[1] < 1.0:
-             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
-         if weight_decay < 0.0:
-             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
-
-         defaults = dict(lr=lr, betas=betas, eps=eps,
-                         weight_decay=weight_decay, amsgrad=amsgrad,
-                         maximize=maximize)
-         super(AdamW, self).__init__(params, defaults)
-
-         self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
-         self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
-         self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
-         self.state_step = ParameterTuple(Parameter(Tensor(0, mstype.int32), "step_"+str(i))
-                                          for i in range(len(self.parameters)))
-         self.increase_tensor = Tensor(1, mstype.int32)
-
-         self.op_mul = P.Mul()
-         self.assignadd = P.AssignAdd()
-         self.op_pow = P.Pow()
-         self.op_sqrt = P.Sqrt()
-         self.op_maximum = P.Maximum()
-         self.op_cast = P.Cast()
-
-     def construct(self, gradients):
-         for group_id, group in enumerate(self.param_groups):
-             params = []
-             grads = []
-             exp_avgs = []
-             exp_avg_sqs = []
-             max_exp_avg_sqs = []
-             state_steps = []
-             amsgrad = group["amsgrad"]
-             beta1, beta2 = group['betas']
-             params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps = \
-                 self._init_group(group, gradients, params, grads, amsgrad, exp_avgs,
-                                  exp_avg_sqs, max_exp_avg_sqs, state_steps, group_id)
-
-             self.apply_adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
-                              amsgrad, beta1, beta2, group['lr'], group['weight_decay'], group['eps'],
-                              group["maximize"], group["grad_centralization"])
-
-     def apply_adamw(self, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
-                     amsgrad, beta1, beta2, lr, weight_decay, eps, maximize, grad_centralization):
-         grads = self._gradients_centralization(grad_centralization, grads)
-         for i, param in enumerate(params):
-             grad = grads[i] if not maximize else -grads[i]
-             exp_avg = exp_avgs[i]
-             exp_avg_sq = exp_avg_sqs[i]
-             step_t = state_steps[i]
-
-             next_param = self.op_mul(param, F.tuple_to_array((1.0,)) - lr * weight_decay)
-             F.assign(exp_avg, self.op_mul(exp_avg, beta1) + self.op_mul(grad, 1-beta1))
-             F.assign(exp_avg_sq, ops.addcmul(self.op_mul(exp_avg_sq, beta2), grad, grad, 1-beta2))
-             step_t = F.depend(step_t, self.assignadd(step_t, self.increase_tensor))
-
-             bias_correction1 = F.tuple_to_array((1.0,)) - self.op_pow(beta1, step_t)
-             bias_correction2 = F.tuple_to_array((1.0,)) - self.op_pow(beta2, step_t)
-             step_size = lr / bias_correction1
-             bias_correction2_sqrt = self.op_sqrt(bias_correction2)
-
-             if amsgrad:
-                 next_max_exp_avg = self.op_maximum(max_exp_avg_sqs[i], exp_avg_sq)
-                 denom = self.op_sqrt(next_max_exp_avg) / bias_correction2_sqrt + eps
-                 F.assign(max_exp_avg_sqs[i], next_max_exp_avg)
-             else:
-                 denom = self.op_sqrt(exp_avg_sq) / bias_correction2_sqrt + eps
-
-             return_param = next_param - self.op_mul(exp_avg / denom, step_size)
-             F.assign(param, return_param)
-
-     def _init_group(self, group, gradients, params, grads, amsgrad, exp_avgs, exp_avg_sqs,
-                     max_exp_avg_sqs, state_steps, group_id):
-         """ Initialize group params. """
-         p_id = self.group_start_id[group_id]
-         for i, param in enumerate(group["params"]):
-             grad = gradients[p_id+i]
-             grads.append(grad)
-             params.append(param)
-             exp_avgs.append(self.exp_avg[p_id+i])
-             exp_avg_sqs.append(self.exp_avg_sq[p_id+i])
-             if amsgrad:
-                 max_exp_avg_sqs.append(self.max_exp_avg_sq[p_id+i])
-             state_steps.append(self.state_step[p_id+i])
-         return params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps
+ # Copyright 2023 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """adamw"""
+ from __future__ import absolute_import
+
+ from mindspore.ops import functional as F, composite as C, operations as P
+ from mindspore.common.parameter import Parameter
+ from mindspore.common.tensor import Tensor
+ import mindspore.common.dtype as mstype
+ from mindspore.experimental.optim.optimizer import Optimizer
+ from mindspore import ops
+
+ _adamw_opt = C.MultitypeFuncGraph("adamw_opt")
+
+ op_mul = P.Mul()
+ op_pow = P.Pow()
+ op_sqrt = P.Sqrt()
+ op_maximum = P.Maximum()
+
+
+ @_adamw_opt.register("Float", "Tensor", "Bool", "Float", "Tensor", "Float", "Float", "Tensor", "Tensor",
+                      "Tensor", "Tensor", "Tensor")
+ def _run_adamw_opt(weight_decay, lr, amsgrad, eps, state_step, beta1, beta2, param, grad,
+                    exp_avg, exp_avg_sq, max_exp_avg_sq):
+     """Apply adamw optimizer to the weight parameter."""
+     success = True
+     next_param = op_mul(param, 1 - lr * weight_decay)
+     F.assign(exp_avg, op_mul(exp_avg, beta1) + op_mul(grad, 1 - beta1))
+     F.assign(exp_avg_sq, ops.addcmul(op_mul(exp_avg_sq, beta2), grad, grad, 1 - beta2))
+     bias_correction1 = 1 - op_pow(beta1, state_step)
+     bias_correction2 = 1 - op_pow(beta2, state_step)
+     step_size = lr / bias_correction1
+     bias_correction2_sqrt = op_sqrt(bias_correction2)
+
+     if amsgrad:
+         next_max_exp_avg = op_maximum(max_exp_avg_sq, exp_avg_sq)
+         denom = op_sqrt(next_max_exp_avg) / bias_correction2_sqrt + eps
+         F.assign(max_exp_avg_sq, next_max_exp_avg)
+     else:
+         denom = op_sqrt(exp_avg_sq) / bias_correction2_sqrt + eps
+
+     return_param = next_param - op_mul(exp_avg / denom, step_size)
+     F.assign(param, return_param)
+     return success
+
+
+ class AdamW(Optimizer):
+     r"""
+     Implements Adam Weight Decay algorithm.
+
+     .. math::
+         \begin{aligned}
+             &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
+                 \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
+                 \: \epsilon \text{ (epsilon)} \\
+             &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
+                 \: \textit{maximize} \\
+             &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
+                 \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
+             &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
+             &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
+             &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
+             &\hspace{5mm}\textbf{else} \\
+             &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
+             &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
+             &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
+             &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
+             &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
+             &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
+             &\hspace{5mm}\textbf{if} \: amsgrad \\
+             &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
+                 \widehat{v_t}) \\
+             &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
+                 \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
+             &\hspace{5mm}\textbf{else} \\
+             &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
+                 \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
+             &\bf{return} \: \theta_t \\[-1.ex]
+         \end{aligned}
+
+     .. warning::
+         This is an experimental optimizer API that is subject to change.
+         This module must be used with lr scheduler module in `LRScheduler Class
+         <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.experimental.html#lrscheduler-class>`_ .
+
+     Args:
+         params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
+             parameter groups
+         lr (Union[int, float, Tensor], optional): learning rate. Default: ``1e-3``.
+         betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
+             Default: ``(0.9, 0.999)``.
+         eps (float, optional): term added to the denominator to improve
+             numerical stability. Default: ``1e-8``.
+         weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
+         amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.
+
+     Keyword Args:
+         maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
+             Default: ``False``.
+
+     Inputs:
+         - **gradients** (tuple[Tensor]) - The gradients of `params`.
+
+     Raises:
+         ValueError: If the learning rate is not int, float or Tensor.
+         ValueError: If the learning rate is less than 0.
+         ValueError: If the `eps` is less than 0.0.
+         ValueError: If the `betas` not in the range of 0-1.
+         ValueError: If the `weight_decay` is less than 0.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> import mindspore
+         >>> from mindspore import nn
+         >>> from mindspore.experimental import optim
+         >>> # Define the network structure of LeNet5. Refer to
+         >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+         >>> net = LeNet5()
+         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+         >>> optimizer = optim.AdamW(net.trainable_params(), lr=0.1)
+         >>> def forward_fn(data, label):
+         ...     logits = net(data)
+         ...     loss = loss_fn(logits, label)
+         ...     return loss, logits
+         >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
+         >>> def train_step(data, label):
+         ...     (loss, _), grads = grad_fn(data, label)
+         ...     optimizer(grads)
+         ...     return loss
+     """
+     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                  weight_decay=1e-2, amsgrad=False, *, maximize=False):
+         if lr < 0.0:
+             raise ValueError("Invalid learning rate: {}".format(lr))
+         if eps < 0.0:
+             raise ValueError("Invalid epsilon value: {}".format(eps))
+         if not 0.0 <= betas[0] < 1.0:
+             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+         if weight_decay < 0.0:
+             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+         defaults = dict(lr=lr, betas=betas, eps=eps,
+                         weight_decay=weight_decay, amsgrad=amsgrad,
+                         maximize=maximize)
+         super(AdamW, self).__init__(params, defaults)
+
+         self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
+         self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
+         self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
+         self.state_step = Parameter(Tensor(0, mstype.int32), "state_step")
+         self.increase_tensor = Tensor(1, mstype.int32)
+         self.assignadd = P.AssignAdd()
+
+     def construct(self, gradients):
+         self.assignadd(self.state_step, self.increase_tensor)
+         for group_id, group in enumerate(self.param_groups):
+             beta1, beta2 = group['betas']
+             start_id = self.group_start_id[group_id]
+             end_id = self.group_start_id[group_id + 1]
+             grads = gradients[start_id: end_id] if not group.get("maximize") else -gradients[start_id: end_id]
+             self.hyper_map(F.partial(_adamw_opt, group.get("weight_decay"), group.get("lr"), group.get("amsgrad"),
+                                      group.get("eps"), self.state_step, beta1, beta2),
+                            self.parameters[start_id: end_id], grads, self.exp_avg[start_id: end_id],
+                            self.exp_avg_sq[start_id: end_id], self.max_exp_avg_sq[start_id: end_id])
+         return True
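Two changes stand out in the rewritten file: the per-parameter step counters (a ParameterTuple of step_<i> Parameters) collapse into a single global state_step, and the per-parameter Python loop in apply_adamw is replaced by a MultitypeFuncGraph update rule (_adamw_opt) that construct broadcasts over each parameter group via self.hyper_map and F.partial. A minimal sketch of that dispatch pattern, using a hypothetical _scale rule and toy tensors rather than the optimizer's own graph:

    import mindspore as ms
    from mindspore import Tensor
    from mindspore.ops import composite as C, functional as F

    # Register one implementation per input-type signature.
    _scale = C.MultitypeFuncGraph("scale")

    @_scale.register("Number", "Tensor")
    def _scale_tensor(factor, x):
        # Applied to every tensor that HyperMap hands in.
        return x * factor

    # Fix the leading scalar argument, then map the rule over a tuple of tensors,
    # mirroring how AdamW.construct fixes the per-group hyper-parameters.
    hyper_map = C.HyperMap()
    params = (Tensor(1.0, ms.float32), Tensor(2.0, ms.float32))
    scaled = hyper_map(F.partial(_scale, 0.5), params)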