mindspore 2.1.0-cp37-cp37m-win_amd64.whl → 2.2.11-cp37-cp37m-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the linked release notice on the registry page for more details.

Files changed (511)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +4 -1
  5. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -1
  9. mindspore/_checkparam.py +23 -29
  10. mindspore/_extends/graph_kernel/__init__.py +0 -1
  11. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  12. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  13. mindspore/_extends/graph_kernel/splitter.py +4 -11
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  21. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  22. mindspore/_extends/parse/__init__.py +13 -15
  23. mindspore/_extends/parse/namespace.py +7 -33
  24. mindspore/_extends/parse/parser.py +67 -72
  25. mindspore/_extends/parse/resources.py +1 -1
  26. mindspore/_extends/parse/standard_method.py +86 -106
  27. mindspore/_extends/parse/trope.py +1 -1
  28. mindspore/_extends/remote/kernel_build_server.py +25 -7
  29. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  30. mindspore/_install_custom.py +43 -0
  31. mindspore/amp.py +47 -11
  32. mindspore/atlprov.dll +0 -0
  33. mindspore/boost/boost.py +1 -8
  34. mindspore/boost/boost_cell_wrapper.py +3 -2
  35. mindspore/boost/grad_accumulation.py +1 -1
  36. mindspore/boost/group_loss_scale_manager.py +8 -7
  37. mindspore/c1.dll +0 -0
  38. mindspore/c1xx.dll +0 -0
  39. mindspore/c2.dll +0 -0
  40. mindspore/common/__init__.py +5 -3
  41. mindspore/common/_jit_fallback_utils.py +6 -0
  42. mindspore/common/_register_for_adapter.py +2 -0
  43. mindspore/common/_register_for_tensor.py +2 -2
  44. mindspore/common/_stub_tensor.py +13 -0
  45. mindspore/common/_utils.py +29 -0
  46. mindspore/common/api.py +174 -259
  47. mindspore/common/auto_dynamic_shape.py +494 -0
  48. mindspore/common/dtype.py +18 -11
  49. mindspore/common/dump.py +6 -4
  50. mindspore/common/initializer.py +14 -14
  51. mindspore/common/jit_config.py +33 -15
  52. mindspore/common/lazy_inline.py +126 -7
  53. mindspore/common/mindir_util.py +101 -0
  54. mindspore/common/parameter.py +51 -41
  55. mindspore/common/seed.py +4 -4
  56. mindspore/common/sparse_tensor.py +13 -14
  57. mindspore/common/tensor.py +243 -165
  58. mindspore/communication/__init__.py +7 -4
  59. mindspore/communication/_comm_helper.py +83 -4
  60. mindspore/communication/management.py +152 -84
  61. mindspore/config/op_info.config +14 -3
  62. mindspore/context.py +152 -61
  63. mindspore/dataset/__init__.py +5 -5
  64. mindspore/dataset/audio/__init__.py +2 -2
  65. mindspore/dataset/audio/transforms.py +52 -52
  66. mindspore/dataset/callback/ds_callback.py +16 -2
  67. mindspore/dataset/core/config.py +68 -51
  68. mindspore/dataset/engine/cache_client.py +33 -7
  69. mindspore/dataset/engine/datasets.py +250 -112
  70. mindspore/dataset/engine/datasets_audio.py +43 -211
  71. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  72. mindspore/dataset/engine/datasets_text.py +43 -67
  73. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  74. mindspore/dataset/engine/datasets_vision.py +219 -1029
  75. mindspore/dataset/engine/iterators.py +11 -4
  76. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  77. mindspore/dataset/engine/obs/util.py +3 -0
  78. mindspore/dataset/engine/samplers.py +1 -1
  79. mindspore/dataset/engine/validators.py +19 -5
  80. mindspore/dataset/text/__init__.py +3 -3
  81. mindspore/dataset/text/transforms.py +101 -127
  82. mindspore/dataset/text/utils.py +205 -138
  83. mindspore/dataset/transforms/__init__.py +1 -1
  84. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  85. mindspore/dataset/transforms/transforms.py +95 -40
  86. mindspore/dataset/utils/browse_dataset.py +8 -2
  87. mindspore/dataset/utils/line_reader.py +17 -19
  88. mindspore/dataset/vision/__init__.py +3 -3
  89. mindspore/dataset/vision/c_transforms.py +6 -3
  90. mindspore/dataset/vision/transforms.py +409 -287
  91. mindspore/dataset/vision/utils.py +13 -14
  92. mindspore/dataset/vision/validators.py +11 -1
  93. mindspore/dnnl.dll +0 -0
  94. mindspore/dpcmi.dll +0 -0
  95. mindspore/experimental/map_parameter.py +14 -0
  96. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  97. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  98. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  99. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  100. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  101. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  102. mindspore/gen_ops.py +273 -0
  103. mindspore/include/OWNERS +0 -1
  104. mindspore/include/api/data_type.h +2 -1
  105. mindspore/include/api/graph.h +0 -15
  106. mindspore/include/api/kernel.h +2 -0
  107. mindspore/include/api/kernel_api.h +37 -12
  108. mindspore/include/api/model.h +17 -14
  109. mindspore/include/api/status.h +8 -3
  110. mindspore/include/api/types.h +37 -4
  111. mindspore/include/c_api/ms/abstract.h +67 -0
  112. mindspore/include/c_api/ms/attribute.h +197 -0
  113. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  114. mindspore/include/c_api/ms/base/macros.h +32 -0
  115. mindspore/include/c_api/ms/base/status.h +33 -0
  116. mindspore/include/c_api/ms/base/types.h +282 -0
  117. mindspore/include/c_api/ms/context.h +102 -0
  118. mindspore/include/c_api/ms/graph.h +160 -0
  119. mindspore/include/c_api/ms/node.h +606 -0
  120. mindspore/include/c_api/ms/tensor.h +161 -0
  121. mindspore/include/c_api/ms/value.h +84 -0
  122. mindspore/include/dataset/constants.h +6 -5
  123. mindspore/include/dataset/execute.h +23 -13
  124. mindspore/include/dataset/text.h +26 -26
  125. mindspore/include/dataset/transforms.h +13 -13
  126. mindspore/include/dataset/vision.h +60 -60
  127. mindspore/include/dataset/vision_ascend.h +5 -6
  128. mindspore/include/dataset/vision_lite.h +17 -17
  129. mindspore/jpeg62.dll +0 -0
  130. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  131. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  132. mindspore/mindspore_backend.dll +0 -0
  133. mindspore/mindspore_common.dll +0 -0
  134. mindspore/mindspore_core.dll +0 -0
  135. mindspore/mindspore_glog.dll +0 -0
  136. mindspore/mindspore_shared_lib.dll +0 -0
  137. mindspore/msobj140.dll +0 -0
  138. mindspore/mspdb140.dll +0 -0
  139. mindspore/mspdbcore.dll +0 -0
  140. mindspore/mspdbst.dll +0 -0
  141. mindspore/mspft140.dll +0 -0
  142. mindspore/msvcdis140.dll +0 -0
  143. mindspore/msvcp140_1.dll +0 -0
  144. mindspore/msvcp140_2.dll +0 -0
  145. mindspore/msvcp140_atomic_wait.dll +0 -0
  146. mindspore/msvcp140_codecvt_ids.dll +0 -0
  147. mindspore/nn/__init__.py +0 -2
  148. mindspore/nn/cell.py +313 -74
  149. mindspore/nn/dynamic_lr.py +21 -21
  150. mindspore/nn/layer/activation.py +22 -30
  151. mindspore/nn/layer/basic.py +15 -13
  152. mindspore/nn/layer/channel_shuffle.py +1 -1
  153. mindspore/nn/layer/container.py +271 -9
  154. mindspore/nn/layer/conv.py +323 -204
  155. mindspore/nn/layer/dense.py +8 -5
  156. mindspore/nn/layer/embedding.py +33 -27
  157. mindspore/nn/layer/flash_attention.py +61 -95
  158. mindspore/nn/layer/image.py +8 -6
  159. mindspore/nn/layer/math.py +16 -25
  160. mindspore/nn/layer/normalization.py +107 -66
  161. mindspore/nn/layer/padding.py +1 -1
  162. mindspore/nn/layer/pooling.py +131 -109
  163. mindspore/nn/layer/rnn_cells.py +27 -22
  164. mindspore/nn/layer/rnns.py +13 -16
  165. mindspore/nn/layer/thor_layer.py +1 -1
  166. mindspore/nn/layer/transformer.py +221 -154
  167. mindspore/nn/learning_rate_schedule.py +9 -1
  168. mindspore/nn/loss/loss.py +235 -174
  169. mindspore/nn/optim/ada_grad.py +2 -1
  170. mindspore/nn/optim/adadelta.py +1 -0
  171. mindspore/nn/optim/adafactor.py +2 -1
  172. mindspore/nn/optim/adam.py +7 -4
  173. mindspore/nn/optim/adamax.py +3 -2
  174. mindspore/nn/optim/adasum.py +2 -2
  175. mindspore/nn/optim/asgd.py +2 -3
  176. mindspore/nn/optim/ftrl.py +6 -5
  177. mindspore/nn/optim/lamb.py +7 -4
  178. mindspore/nn/optim/lars.py +1 -1
  179. mindspore/nn/optim/lazyadam.py +5 -3
  180. mindspore/nn/optim/momentum.py +2 -1
  181. mindspore/nn/optim/optimizer.py +53 -4
  182. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  183. mindspore/nn/optim/rmsprop.py +4 -3
  184. mindspore/nn/optim/rprop.py +23 -12
  185. mindspore/nn/optim/sgd.py +26 -11
  186. mindspore/nn/optim/thor.py +9 -7
  187. mindspore/nn/probability/bijector/bijector.py +5 -5
  188. mindspore/nn/probability/bijector/power_transform.py +27 -27
  189. mindspore/nn/probability/bijector/softplus.py +3 -3
  190. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  191. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  192. mindspore/nn/probability/distribution/beta.py +3 -3
  193. mindspore/nn/probability/distribution/categorical.py +7 -7
  194. mindspore/nn/probability/distribution/cauchy.py +0 -1
  195. mindspore/nn/probability/distribution/distribution.py +3 -3
  196. mindspore/nn/probability/distribution/gamma.py +3 -3
  197. mindspore/nn/probability/distribution/geometric.py +4 -4
  198. mindspore/nn/probability/distribution/gumbel.py +4 -4
  199. mindspore/nn/probability/distribution/log_normal.py +2 -2
  200. mindspore/nn/probability/distribution/logistic.py +2 -2
  201. mindspore/nn/probability/distribution/poisson.py +4 -4
  202. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  203. mindspore/nn/probability/distribution/uniform.py +6 -6
  204. mindspore/nn/wrap/__init__.py +4 -2
  205. mindspore/nn/wrap/cell_wrapper.py +87 -34
  206. mindspore/nn/wrap/grad_reducer.py +8 -5
  207. mindspore/nn/wrap/loss_scale.py +105 -42
  208. mindspore/numpy/array_creations.py +1 -2
  209. mindspore/numpy/array_ops.py +3 -2
  210. mindspore/numpy/utils_const.py +5 -5
  211. mindspore/opencv_core452.dll +0 -0
  212. mindspore/opencv_imgcodecs452.dll +0 -0
  213. mindspore/opencv_imgproc452.dll +0 -0
  214. mindspore/ops/_grad_experimental/__init__.py +0 -5
  215. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  216. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  217. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  218. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  219. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  220. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  221. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  222. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  223. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  224. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  225. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  226. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  227. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  228. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  229. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  230. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  231. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  232. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  233. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  234. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  235. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  236. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  237. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  238. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  239. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  240. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  241. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  242. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  243. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  244. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  245. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  246. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  247. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  248. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  249. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  250. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  251. mindspore/ops/_primitive_cache.py +1 -1
  252. mindspore/ops/_tracefunc.py +45 -13
  253. mindspore/ops/_utils/utils.py +6 -1
  254. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  255. mindspore/ops/_vmap/vmap_base.py +3 -3
  256. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  257. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  258. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  259. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  260. mindspore/ops/arg_dtype_cast.py +54 -0
  261. mindspore/ops/composite/base.py +37 -10
  262. mindspore/ops/composite/math_ops.py +5 -4
  263. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  264. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  265. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  266. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  267. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  268. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  270. mindspore/ops/deprecated.py +304 -0
  271. mindspore/ops/function/__init__.py +4 -1
  272. mindspore/ops/function/array_func.py +174 -193
  273. mindspore/ops/function/clip_func.py +81 -13
  274. mindspore/ops/function/debug_func.py +1 -1
  275. mindspore/ops/function/grad/grad_func.py +18 -9
  276. mindspore/ops/function/image_func.py +10 -4
  277. mindspore/ops/function/linalg_func.py +5 -5
  278. mindspore/ops/function/math_func.py +575 -386
  279. mindspore/ops/function/nn_func.py +568 -260
  280. mindspore/ops/function/random_func.py +88 -57
  281. mindspore/ops/function/sparse_func.py +1 -1
  282. mindspore/ops/function/sparse_unary_func.py +14 -12
  283. mindspore/ops/function/vmap_func.py +6 -5
  284. mindspore/ops/functional.py +15 -10
  285. mindspore/ops/op_info_register.py +244 -25
  286. mindspore/ops/operations/__init__.py +31 -19
  287. mindspore/ops/operations/_grad_ops.py +71 -7
  288. mindspore/ops/operations/_inner_ops.py +350 -17
  289. mindspore/ops/operations/_quant_ops.py +4 -8
  290. mindspore/ops/operations/_sequence_ops.py +42 -0
  291. mindspore/ops/operations/array_ops.py +68 -282
  292. mindspore/ops/operations/comm_ops.py +107 -59
  293. mindspore/ops/operations/custom_ops.py +94 -70
  294. mindspore/ops/operations/debug_ops.py +8 -4
  295. mindspore/ops/operations/image_ops.py +18 -12
  296. mindspore/ops/operations/inner_ops.py +26 -3
  297. mindspore/ops/operations/math_ops.py +192 -144
  298. mindspore/ops/operations/nn_ops.py +857 -489
  299. mindspore/ops/operations/other_ops.py +0 -22
  300. mindspore/ops/operations/random_ops.py +53 -111
  301. mindspore/ops/operations/sparse_ops.py +3 -1
  302. mindspore/ops/primitive.py +24 -18
  303. mindspore/parallel/_auto_parallel_context.py +68 -8
  304. mindspore/parallel/_cost_model_context.py +2 -2
  305. mindspore/parallel/_offload_context.py +17 -3
  306. mindspore/parallel/_parallel_serialization.py +12 -5
  307. mindspore/parallel/_ps_context.py +12 -0
  308. mindspore/parallel/_tensor.py +18 -13
  309. mindspore/parallel/_transformer/layers.py +5 -3
  310. mindspore/parallel/_transformer/loss.py +1 -0
  311. mindspore/parallel/_transformer/moe.py +2 -2
  312. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  313. mindspore/parallel/_transformer/transformer.py +23 -3
  314. mindspore/parallel/_utils.py +11 -7
  315. mindspore/parallel/algo_parameter_config.py +85 -5
  316. mindspore/parallel/checkpoint_transform.py +19 -12
  317. mindspore/parallel/shard.py +21 -14
  318. mindspore/pgodb140.dll +0 -0
  319. mindspore/pgort140.dll +0 -0
  320. mindspore/profiler/common/struct_type.py +3 -3
  321. mindspore/profiler/common/util.py +4 -2
  322. mindspore/profiler/envprofiling.py +1 -1
  323. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  324. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  325. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  326. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  327. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  328. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  329. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  330. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  331. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  332. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  333. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  334. mindspore/profiler/parser/flops_parser.py +15 -11
  335. mindspore/profiler/parser/framework_parser.py +38 -22
  336. mindspore/profiler/parser/hccl_parser.py +16 -12
  337. mindspore/profiler/parser/integrator.py +22 -11
  338. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  339. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  340. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  341. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  342. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  343. mindspore/profiler/parser/optime_parser.py +1 -1
  344. mindspore/profiler/parser/profiler_info.py +21 -2
  345. mindspore/profiler/parser/step_trace_parser.py +11 -14
  346. mindspore/profiler/profiling.py +179 -89
  347. mindspore/rewrite/api/node.py +102 -19
  348. mindspore/rewrite/api/node_type.py +5 -1
  349. mindspore/rewrite/api/pattern_engine.py +1 -1
  350. mindspore/rewrite/api/scoped_value.py +9 -17
  351. mindspore/rewrite/api/symbol_tree.py +131 -47
  352. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  353. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  354. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  355. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  356. mindspore/rewrite/common/rewrite_elog.py +5 -1
  357. mindspore/rewrite/namer.py +33 -24
  358. mindspore/rewrite/namespace.py +14 -5
  359. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  360. mindspore/rewrite/node/call_function.py +79 -0
  361. mindspore/rewrite/node/cell_container.py +135 -0
  362. mindspore/rewrite/node/control_flow.py +88 -0
  363. mindspore/rewrite/{node.py → node/node.py} +273 -234
  364. mindspore/rewrite/node/node_manager.py +254 -0
  365. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  366. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  367. mindspore/rewrite/parsers/assign_parser.py +216 -221
  368. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  369. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  370. mindspore/rewrite/parsers/constant_parser.py +9 -6
  371. mindspore/rewrite/parsers/container_parser.py +9 -7
  372. mindspore/rewrite/parsers/for_parser.py +42 -21
  373. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  374. mindspore/rewrite/parsers/if_parser.py +28 -24
  375. mindspore/rewrite/parsers/module_parser.py +196 -25
  376. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  377. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  378. mindspore/rewrite/parsers/return_parser.py +6 -6
  379. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  380. mindspore/rewrite/sparsify/utils.py +1 -1
  381. mindspore/rewrite/symbol_tree.py +523 -578
  382. mindspore/rewrite/symbol_tree_builder.py +9 -193
  383. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  384. mindspore/run_check/_check_version.py +6 -4
  385. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  386. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  387. mindspore/tbbmalloc.dll +0 -0
  388. mindspore/tinyxml2.dll +0 -0
  389. mindspore/train/_utils.py +7 -3
  390. mindspore/train/amp.py +323 -123
  391. mindspore/train/anf_ir_pb2.py +14 -2
  392. mindspore/train/callback/_backup_and_restore.py +2 -12
  393. mindspore/train/callback/_callback.py +29 -4
  394. mindspore/train/callback/_checkpoint.py +23 -8
  395. mindspore/train/callback/_early_stop.py +2 -2
  396. mindspore/train/callback/_landscape.py +4 -4
  397. mindspore/train/callback/_loss_monitor.py +2 -2
  398. mindspore/train/callback/_on_request_exit.py +2 -2
  399. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  400. mindspore/train/callback/_summary_collector.py +15 -8
  401. mindspore/train/callback/_time_monitor.py +58 -5
  402. mindspore/train/data_sink.py +5 -11
  403. mindspore/train/dataset_helper.py +84 -57
  404. mindspore/train/loss_scale_manager.py +2 -2
  405. mindspore/train/metrics/__init__.py +3 -3
  406. mindspore/train/metrics/cosine_similarity.py +1 -1
  407. mindspore/train/metrics/hausdorff_distance.py +3 -2
  408. mindspore/train/metrics/mean_surface_distance.py +3 -2
  409. mindspore/train/metrics/metric.py +39 -19
  410. mindspore/train/metrics/roc.py +2 -2
  411. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  412. mindspore/train/mind_ir_pb2.py +85 -36
  413. mindspore/train/model.py +187 -47
  414. mindspore/train/serialization.py +487 -161
  415. mindspore/train/summary/_summary_adapter.py +1 -1
  416. mindspore/train/summary/_writer_pool.py +3 -2
  417. mindspore/train/summary/summary_record.py +37 -17
  418. mindspore/train/train_thor/convert_utils.py +3 -3
  419. mindspore/train/train_thor/dataset_helper.py +1 -1
  420. mindspore/turbojpeg.dll +0 -0
  421. mindspore/vcmeta.dll +0 -0
  422. mindspore/vcruntime140.dll +0 -0
  423. mindspore/vcruntime140_1.dll +0 -0
  424. mindspore/version.py +1 -1
  425. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
  426. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +429 -486
  427. mindspore/_extends/graph_kernel/expander.py +0 -80
  428. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  429. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  430. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  431. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  432. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  433. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  434. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  435. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  436. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  437. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  438. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  439. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  440. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  441. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  442. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  443. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  444. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  445. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  446. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  447. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  448. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  449. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  450. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  451. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  452. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  453. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  454. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  455. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  456. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  457. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  458. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  459. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  460. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  461. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  462. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  463. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  464. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  465. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  466. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  467. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  468. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  469. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  470. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  471. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  472. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  473. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  474. mindspore/dataset/datapreprocess/__init__.py +0 -20
  475. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  476. mindspore/include/api/net.h +0 -142
  477. mindspore/nn/lr_scheduler.py +0 -262
  478. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  479. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  480. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  481. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  482. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  483. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  484. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  485. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  486. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  487. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  488. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  489. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  490. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  491. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  492. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  493. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  494. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  495. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  496. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  497. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  498. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  499. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  500. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  501. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  502. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  503. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  504. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  505. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  506. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  507. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  508. mindspore/rewrite/node_visitor.py +0 -44
  509. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  510. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  511. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,203 +1,181 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """adamw"""
16
- from __future__ import absolute_import
17
-
18
- from mindspore.ops import functional as F, operations as P
19
- from mindspore.common.parameter import Parameter, ParameterTuple
20
- from mindspore.common.tensor import Tensor
21
- import mindspore.common.dtype as mstype
22
- from mindspore.nn.optim_ex.optimizer import Optimizer
23
- from mindspore import ops
24
-
25
-
26
- class AdamW(Optimizer):
27
- r"""
28
- Implements Adam Weight Decay algorithm.
29
-
30
- .. math::
31
- \begin{aligned}
32
- &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
33
- \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
34
- \: \epsilon \text{ (epsilon)} \\
35
- &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
36
- \: \textit{maximize} \\
37
- &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
38
- \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
39
- &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
40
- &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
41
- &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
42
- &\hspace{5mm}\textbf{else} \\
43
- &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
44
- &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
45
- &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
46
- &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
47
- &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
48
- &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
49
- &\hspace{5mm}\textbf{if} \: amsgrad \\
50
- &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
51
- \widehat{v_t}) \\
52
- &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
53
- \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
54
- &\hspace{5mm}\textbf{else} \\
55
- &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
56
- \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
57
- &\bf{return} \: \theta_t \\[-1.ex]
58
- \end{aligned}
59
-
60
- .. warning::
61
- This is an experimental optimizer API that is subject to change.
62
- This module must be used with lr scheduler module in `LRScheduler Class
63
- <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#lrscheduler>`_ .
64
-
65
- Args:
66
- params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
67
- parameter groups
68
- lr (Union[int, float, Tensor], optional): learning rate. Default: ``1e-3``.
69
- betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
70
- Default: ``(0.9, 0.999)``.
71
- eps (float, optional): term added to the denominator to improve
72
- numerical stability. Default: ``1e-8``.
73
- weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
74
- amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.
75
-
76
- Keyword Args:
77
- maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
78
- Default: ``False``.
79
-
80
- Inputs:
81
- - **gradients** (tuple[Tensor]) - The gradients of `params`.
82
-
83
- Raises:
84
- ValueError: If the learning rate is not int, float or Tensor.
85
- ValueError: If the learning rate is less than 0.
86
- ValueError: If the `eps` is less than 0.0.
87
- ValueError: If the `betas` not in the range of 0-1.
88
- ValueError: If the `weight_decay` is less than 0.
89
-
90
- Supported Platforms:
91
- ``Ascend`` ``GPU`` ``CPU``
92
-
93
- Examples:
94
- >>> import mindspore
95
- >>> from mindspore import nn
96
- >>> # Define the network structure of LeNet5. Refer to
97
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
98
- >>> net = LeNet5()
99
- >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
100
- >>> optimizer = nn.optim_ex.AdamW(net.trainable_params(), lr=0.1)
101
- >>> def forward_fn(data, label):
102
- ... logits = net(data)
103
- ... loss = loss_fn(logits, label)
104
- ... return loss, logits
105
- >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
106
- >>> def train_step(data, label):
107
- ... (loss, _), grads = grad_fn(data, label)
108
- ... optimizer(grads)
109
- ... return loss
110
- """
111
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
112
- weight_decay=1e-2, amsgrad=False, *, maximize=False):
113
- if lr < 0.0:
114
- raise ValueError("Invalid learning rate: {}".format(lr))
115
- if eps < 0.0:
116
- raise ValueError("Invalid epsilon value: {}".format(eps))
117
- if not 0.0 <= betas[0] < 1.0:
118
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
119
- if not 0.0 <= betas[1] < 1.0:
120
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
121
- if weight_decay < 0.0:
122
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
123
-
124
- defaults = dict(lr=lr, betas=betas, eps=eps,
125
- weight_decay=weight_decay, amsgrad=amsgrad,
126
- maximize=maximize)
127
- super(AdamW, self).__init__(params, defaults)
128
-
129
- self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
130
- self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
131
- self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
132
- self.state_step = ParameterTuple(Parameter(Tensor(0, mstype.int32), "step_"+str(i))
133
- for i in range(len(self.parameters)))
134
- self.increase_tensor = Tensor(1, mstype.int32)
135
-
136
- self.op_mul = P.Mul()
137
- self.assignadd = P.AssignAdd()
138
- self.op_pow = P.Pow()
139
- self.op_sqrt = P.Sqrt()
140
- self.op_maximum = P.Maximum()
141
- self.op_cast = P.Cast()
142
-
143
- def construct(self, gradients):
144
- for group_id, group in enumerate(self.param_groups):
145
- params = []
146
- grads = []
147
- exp_avgs = []
148
- exp_avg_sqs = []
149
- max_exp_avg_sqs = []
150
- state_steps = []
151
- amsgrad = group["amsgrad"]
152
- beta1, beta2 = group['betas']
153
- params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps = \
154
- self._init_group(group, gradients, params, grads, amsgrad, exp_avgs,
155
- exp_avg_sqs, max_exp_avg_sqs, state_steps, group_id)
156
-
157
- self.apply_adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
158
- amsgrad, beta1, beta2, group['lr'], group['weight_decay'], group['eps'],
159
- group["maximize"], group["grad_centralization"])
160
-
161
- def apply_adamw(self, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
162
- amsgrad, beta1, beta2, lr, weight_decay, eps, maximize, grad_centralization):
163
- grads = self._gradients_centralization(grad_centralization, grads)
164
- for i, param in enumerate(params):
165
- grad = grads[i] if not maximize else -grads[i]
166
- exp_avg = exp_avgs[i]
167
- exp_avg_sq = exp_avg_sqs[i]
168
- step_t = state_steps[i]
169
-
170
- next_param = self.op_mul(param, F.tuple_to_array((1.0,)) - lr * weight_decay)
171
- F.assign(exp_avg, self.op_mul(exp_avg, beta1) + self.op_mul(grad, 1-beta1))
172
- F.assign(exp_avg_sq, ops.addcmul(self.op_mul(exp_avg_sq, beta2), grad, grad, 1-beta2))
173
- step_t = F.depend(step_t, self.assignadd(step_t, self.increase_tensor))
174
-
175
- bias_correction1 = F.tuple_to_array((1.0,)) - self.op_pow(beta1, step_t)
176
- bias_correction2 = F.tuple_to_array((1.0,)) - self.op_pow(beta2, step_t)
177
- step_size = lr / bias_correction1
178
- bias_correction2_sqrt = self.op_sqrt(bias_correction2)
179
-
180
- if amsgrad:
181
- next_max_exp_avg = self.op_maximum(max_exp_avg_sqs[i], exp_avg_sq)
182
- denom = self.op_sqrt(next_max_exp_avg) / bias_correction2_sqrt + eps
183
- F.assign(max_exp_avg_sqs[i], next_max_exp_avg)
184
- else:
185
- denom = self.op_sqrt(exp_avg_sq) / bias_correction2_sqrt + eps
186
-
187
- return_param = next_param - self.op_mul(exp_avg / denom, step_size)
188
- F.assign(param, return_param)
189
-
190
- def _init_group(self, group, gradients, params, grads, amsgrad, exp_avgs, exp_avg_sqs,
191
- max_exp_avg_sqs, state_steps, group_id):
192
- """ Initialize group params. """
193
- p_id = self.group_start_id[group_id]
194
- for i, param in enumerate(group["params"]):
195
- grad = gradients[p_id+i]
196
- grads.append(grad)
197
- params.append(param)
198
- exp_avgs.append(self.exp_avg[p_id+i])
199
- exp_avg_sqs.append(self.exp_avg_sq[p_id+i])
200
- if amsgrad:
201
- max_exp_avg_sqs.append(self.max_exp_avg_sq[p_id+i])
202
- state_steps.append(self.state_step[p_id+i])
203
- return params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """adamw"""
16
+ from __future__ import absolute_import
17
+
18
+ from mindspore.ops import functional as F, composite as C, operations as P
19
+ from mindspore.common.parameter import Parameter
20
+ from mindspore.common.tensor import Tensor
21
+ import mindspore.common.dtype as mstype
22
+ from mindspore.experimental.optim.optimizer import Optimizer
23
+ from mindspore import ops
24
+
25
+ _adamw_opt = C.MultitypeFuncGraph("adamw_opt")
26
+
27
+ op_mul = P.Mul()
28
+ op_pow = P.Pow()
29
+ op_sqrt = P.Sqrt()
30
+ op_maximum = P.Maximum()
31
+
32
+
33
+ @_adamw_opt.register("Float", "Tensor", "Bool", "Float", "Tensor", "Float", "Float", "Tensor", "Tensor",
34
+ "Tensor", "Tensor", "Tensor")
35
+ def _run_adamw_opt(weight_decay, lr, amsgrad, eps, state_step, beta1, beta2, param, grad,
36
+ exp_avg, exp_avg_sq, max_exp_avg_sq):
37
+ """Apply adamw optimizer to the weight parameter."""
38
+ success = True
39
+ next_param = op_mul(param, 1 - lr * weight_decay)
40
+ F.assign(exp_avg, op_mul(exp_avg, beta1) + op_mul(grad, 1 - beta1))
41
+ F.assign(exp_avg_sq, ops.addcmul(op_mul(exp_avg_sq, beta2), grad, grad, 1 - beta2))
42
+ bias_correction1 = 1 - op_pow(beta1, state_step)
43
+ bias_correction2 = 1 - op_pow(beta2, state_step)
44
+ step_size = lr / bias_correction1
45
+ bias_correction2_sqrt = op_sqrt(bias_correction2)
46
+
47
+ if amsgrad:
48
+ next_max_exp_avg = op_maximum(max_exp_avg_sq, exp_avg_sq)
49
+ denom = op_sqrt(next_max_exp_avg) / bias_correction2_sqrt + eps
50
+ F.assign(max_exp_avg_sq, next_max_exp_avg)
51
+ else:
52
+ denom = op_sqrt(exp_avg_sq) / bias_correction2_sqrt + eps
53
+
54
+ return_param = next_param - op_mul(exp_avg / denom, step_size)
55
+ F.assign(param, return_param)
56
+ return success
57
+
58
+
59
+ class AdamW(Optimizer):
60
+ r"""
61
+ Implements Adam Weight Decay algorithm.
62
+
63
+ .. math::
64
+ \begin{aligned}
65
+ &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
66
+ \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
67
+ \: \epsilon \text{ (epsilon)} \\
68
+ &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
69
+ \: \textit{maximize} \\
70
+ &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
71
+ \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
72
+ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
73
+ &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
74
+ &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
75
+ &\hspace{5mm}\textbf{else} \\
76
+ &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
77
+ &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
78
+ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
79
+ &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
80
+ &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
81
+ &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
82
+ &\hspace{5mm}\textbf{if} \: amsgrad \\
83
+ &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
84
+ \widehat{v_t}) \\
85
+ &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
86
+ \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
87
+ &\hspace{5mm}\textbf{else} \\
88
+ &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
89
+ \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
90
+ &\bf{return} \: \theta_t \\[-1.ex]
91
+ \end{aligned}
92
+
93
+ .. warning::
94
+ This is an experimental optimizer API that is subject to change.
95
+ This module must be used with lr scheduler module in `LRScheduler Class
96
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.experimental.html#lrscheduler-class>`_ .
97
+
98
+ Args:
99
+ params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
100
+ parameter groups
101
+ lr (Union[int, float, Tensor], optional): learning rate. Default: ``1e-3``.
102
+ betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
103
+ Default: ``(0.9, 0.999)``.
104
+ eps (float, optional): term added to the denominator to improve
105
+ numerical stability. Default: ``1e-8``.
106
+ weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
107
+ amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.
108
+
109
+ Keyword Args:
110
+ maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
111
+ Default: ``False``.
112
+
113
+ Inputs:
114
+ - **gradients** (tuple[Tensor]) - The gradients of `params`.
115
+
116
+ Raises:
117
+ ValueError: If the learning rate is not int, float or Tensor.
118
+ ValueError: If the learning rate is less than 0.
119
+ ValueError: If the `eps` is less than 0.0.
120
+ ValueError: If the `betas` not in the range of 0-1.
121
+ ValueError: If the `weight_decay` is less than 0.
122
+
123
+ Supported Platforms:
124
+ ``Ascend`` ``GPU`` ``CPU``
125
+
126
+ Examples:
127
+ >>> import mindspore
128
+ >>> from mindspore import nn
129
+ >>> from mindspore.experimental import optim
130
+ >>> # Define the network structure of LeNet5. Refer to
131
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
132
+ >>> net = LeNet5()
133
+ >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
134
+ >>> optimizer = optim.AdamW(net.trainable_params(), lr=0.1)
135
+ >>> def forward_fn(data, label):
136
+ ... logits = net(data)
137
+ ... loss = loss_fn(logits, label)
138
+ ... return loss, logits
139
+ >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
140
+ >>> def train_step(data, label):
141
+ ... (loss, _), grads = grad_fn(data, label)
142
+ ... optimizer(grads)
143
+ ... return loss
144
+ """
145
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
146
+ weight_decay=1e-2, amsgrad=False, *, maximize=False):
147
+ if lr < 0.0:
148
+ raise ValueError("Invalid learning rate: {}".format(lr))
149
+ if eps < 0.0:
150
+ raise ValueError("Invalid epsilon value: {}".format(eps))
151
+ if not 0.0 <= betas[0] < 1.0:
152
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
153
+ if not 0.0 <= betas[1] < 1.0:
154
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
155
+ if weight_decay < 0.0:
156
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
157
+
158
+ defaults = dict(lr=lr, betas=betas, eps=eps,
159
+ weight_decay=weight_decay, amsgrad=amsgrad,
160
+ maximize=maximize)
161
+ super(AdamW, self).__init__(params, defaults)
162
+
163
+ self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
164
+ self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
165
+ self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
166
+ self.state_step = Parameter(Tensor(0, mstype.int32), "state_step")
167
+ self.increase_tensor = Tensor(1, mstype.int32)
168
+ self.assignadd = P.AssignAdd()
169
+
170
+ def construct(self, gradients):
171
+ self.assignadd(self.state_step, self.increase_tensor)
172
+ for group_id, group in enumerate(self.param_groups):
173
+ beta1, beta2 = group['betas']
174
+ start_id = self.group_start_id[group_id]
175
+ end_id = self.group_start_id[group_id + 1]
176
+ grads = gradients[start_id: end_id] if not group.get("maximize") else -gradients[start_id: end_id]
177
+ self.hyper_map(F.partial(_adamw_opt, group.get("weight_decay"), group.get("lr"), group.get("amsgrad"),
178
+ group.get("eps"), self.state_step, beta1, beta2),
179
+ self.parameters[start_id: end_id], grads, self.exp_avg[start_id: end_id],
180
+ self.exp_avg_sq[start_id: end_id], self.max_exp_avg_sq[start_id: end_id])
181
+ return True