mindspore-2.1.0-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.0-cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore has been flagged as potentially problematic; consult the package registry's advisory page for details.

Files changed (550)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  172. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  173. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  174. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  175. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  176. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  177. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  180. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  181. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  182. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  183. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  184. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  185. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  186. mindspore/nn/__init__.py +0 -2
  187. mindspore/nn/cell.py +316 -74
  188. mindspore/nn/dynamic_lr.py +21 -21
  189. mindspore/nn/layer/activation.py +21 -28
  190. mindspore/nn/layer/basic.py +15 -13
  191. mindspore/nn/layer/channel_shuffle.py +1 -1
  192. mindspore/nn/layer/container.py +271 -9
  193. mindspore/nn/layer/conv.py +310 -207
  194. mindspore/nn/layer/dense.py +8 -5
  195. mindspore/nn/layer/embedding.py +33 -27
  196. mindspore/nn/layer/flash_attention.py +82 -41
  197. mindspore/nn/layer/image.py +8 -6
  198. mindspore/nn/layer/math.py +13 -18
  199. mindspore/nn/layer/normalization.py +107 -66
  200. mindspore/nn/layer/padding.py +1 -1
  201. mindspore/nn/layer/pooling.py +131 -109
  202. mindspore/nn/layer/rnn_cells.py +22 -17
  203. mindspore/nn/layer/rnns.py +13 -16
  204. mindspore/nn/layer/thor_layer.py +1 -1
  205. mindspore/nn/layer/transformer.py +221 -154
  206. mindspore/nn/learning_rate_schedule.py +9 -1
  207. mindspore/nn/loss/loss.py +235 -174
  208. mindspore/nn/optim/ada_grad.py +2 -1
  209. mindspore/nn/optim/adadelta.py +1 -0
  210. mindspore/nn/optim/adafactor.py +2 -1
  211. mindspore/nn/optim/adam.py +7 -4
  212. mindspore/nn/optim/adamax.py +3 -2
  213. mindspore/nn/optim/adasum.py +2 -2
  214. mindspore/nn/optim/asgd.py +2 -3
  215. mindspore/nn/optim/ftrl.py +6 -5
  216. mindspore/nn/optim/lamb.py +7 -4
  217. mindspore/nn/optim/lars.py +1 -1
  218. mindspore/nn/optim/lazyadam.py +5 -3
  219. mindspore/nn/optim/momentum.py +2 -1
  220. mindspore/nn/optim/optimizer.py +53 -4
  221. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  222. mindspore/nn/optim/rmsprop.py +4 -3
  223. mindspore/nn/optim/rprop.py +23 -12
  224. mindspore/nn/optim/sgd.py +26 -11
  225. mindspore/nn/optim/thor.py +9 -7
  226. mindspore/nn/probability/bijector/bijector.py +5 -5
  227. mindspore/nn/probability/bijector/power_transform.py +27 -27
  228. mindspore/nn/probability/bijector/softplus.py +3 -3
  229. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  230. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  231. mindspore/nn/probability/distribution/beta.py +3 -3
  232. mindspore/nn/probability/distribution/categorical.py +7 -7
  233. mindspore/nn/probability/distribution/cauchy.py +0 -1
  234. mindspore/nn/probability/distribution/distribution.py +3 -3
  235. mindspore/nn/probability/distribution/gamma.py +3 -3
  236. mindspore/nn/probability/distribution/geometric.py +4 -4
  237. mindspore/nn/probability/distribution/gumbel.py +4 -4
  238. mindspore/nn/probability/distribution/log_normal.py +2 -2
  239. mindspore/nn/probability/distribution/logistic.py +2 -2
  240. mindspore/nn/probability/distribution/poisson.py +4 -4
  241. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  242. mindspore/nn/probability/distribution/uniform.py +6 -6
  243. mindspore/nn/wrap/cell_wrapper.py +78 -34
  244. mindspore/nn/wrap/grad_reducer.py +8 -5
  245. mindspore/nn/wrap/loss_scale.py +105 -42
  246. mindspore/numpy/array_creations.py +1 -2
  247. mindspore/numpy/array_ops.py +3 -2
  248. mindspore/offline_debug/convert_async.py +2 -2
  249. mindspore/ops/_grad_experimental/__init__.py +0 -5
  250. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  251. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  252. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  253. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  256. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  257. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  259. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  260. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  261. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  262. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  263. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  264. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  265. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  266. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  267. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  268. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  269. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  270. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  271. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  272. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  273. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  274. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  275. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  276. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  277. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  278. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  279. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  280. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  281. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  282. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  283. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  284. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  285. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  286. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  287. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  288. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  289. mindspore/ops/_primitive_cache.py +1 -1
  290. mindspore/ops/_tracefunc.py +45 -13
  291. mindspore/ops/_utils/utils.py +4 -1
  292. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  293. mindspore/ops/_vmap/vmap_base.py +3 -3
  294. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  295. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  296. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  297. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  298. mindspore/ops/arg_dtype_cast.py +54 -0
  299. mindspore/ops/composite/base.py +37 -10
  300. mindspore/ops/composite/math_ops.py +5 -4
  301. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  302. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  303. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  304. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  305. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  308. mindspore/ops/deprecated.py +304 -0
  309. mindspore/ops/function/__init__.py +4 -1
  310. mindspore/ops/function/array_func.py +167 -189
  311. mindspore/ops/function/clip_func.py +81 -13
  312. mindspore/ops/function/debug_func.py +1 -1
  313. mindspore/ops/function/grad/grad_func.py +18 -8
  314. mindspore/ops/function/image_func.py +10 -4
  315. mindspore/ops/function/linalg_func.py +5 -5
  316. mindspore/ops/function/math_func.py +575 -386
  317. mindspore/ops/function/nn_func.py +470 -251
  318. mindspore/ops/function/random_func.py +86 -56
  319. mindspore/ops/function/sparse_func.py +1 -1
  320. mindspore/ops/function/sparse_unary_func.py +14 -12
  321. mindspore/ops/function/vmap_func.py +6 -5
  322. mindspore/ops/functional.py +15 -10
  323. mindspore/ops/op_info_register.py +235 -19
  324. mindspore/ops/operations/__init__.py +25 -17
  325. mindspore/ops/operations/_grad_ops.py +52 -7
  326. mindspore/ops/operations/_inner_ops.py +213 -12
  327. mindspore/ops/operations/_quant_ops.py +4 -8
  328. mindspore/ops/operations/_sequence_ops.py +42 -0
  329. mindspore/ops/operations/array_ops.py +64 -280
  330. mindspore/ops/operations/comm_ops.py +105 -57
  331. mindspore/ops/operations/custom_ops.py +10 -3
  332. mindspore/ops/operations/debug_ops.py +8 -4
  333. mindspore/ops/operations/image_ops.py +18 -12
  334. mindspore/ops/operations/math_ops.py +185 -138
  335. mindspore/ops/operations/nn_ops.py +716 -492
  336. mindspore/ops/operations/other_ops.py +0 -22
  337. mindspore/ops/operations/random_ops.py +53 -111
  338. mindspore/ops/operations/sparse_ops.py +3 -1
  339. mindspore/ops/primitive.py +24 -18
  340. mindspore/parallel/_auto_parallel_context.py +68 -8
  341. mindspore/parallel/_cost_model_context.py +2 -2
  342. mindspore/parallel/_offload_context.py +17 -3
  343. mindspore/parallel/_parallel_serialization.py +2 -2
  344. mindspore/parallel/_ps_context.py +12 -0
  345. mindspore/parallel/_tensor.py +14 -12
  346. mindspore/parallel/_transformer/layers.py +5 -3
  347. mindspore/parallel/_transformer/loss.py +1 -0
  348. mindspore/parallel/_transformer/moe.py +2 -2
  349. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  350. mindspore/parallel/_transformer/transformer.py +23 -3
  351. mindspore/parallel/_utils.py +11 -7
  352. mindspore/parallel/algo_parameter_config.py +85 -5
  353. mindspore/parallel/checkpoint_transform.py +6 -10
  354. mindspore/parallel/shard.py +4 -4
  355. mindspore/profiler/common/struct_type.py +3 -3
  356. mindspore/profiler/common/util.py +3 -2
  357. mindspore/profiler/envprofiling.py +1 -1
  358. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  359. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  360. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  361. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  362. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  363. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  364. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  365. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  366. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  367. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  368. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  369. mindspore/profiler/parser/flops_parser.py +15 -11
  370. mindspore/profiler/parser/framework_parser.py +37 -21
  371. mindspore/profiler/parser/hccl_parser.py +16 -12
  372. mindspore/profiler/parser/integrator.py +22 -11
  373. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  374. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  375. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  376. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  377. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  378. mindspore/profiler/parser/optime_parser.py +1 -1
  379. mindspore/profiler/parser/profiler_info.py +2 -2
  380. mindspore/profiler/parser/step_trace_parser.py +11 -14
  381. mindspore/profiler/profiling.py +139 -71
  382. mindspore/rewrite/api/node.py +102 -19
  383. mindspore/rewrite/api/node_type.py +5 -1
  384. mindspore/rewrite/api/scoped_value.py +9 -17
  385. mindspore/rewrite/api/symbol_tree.py +131 -47
  386. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  387. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  388. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  389. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  390. mindspore/rewrite/common/rewrite_elog.py +5 -1
  391. mindspore/rewrite/namer.py +33 -24
  392. mindspore/rewrite/namespace.py +14 -5
  393. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  394. mindspore/rewrite/node/call_function.py +79 -0
  395. mindspore/rewrite/node/cell_container.py +135 -0
  396. mindspore/rewrite/node/control_flow.py +88 -0
  397. mindspore/rewrite/{node.py → node/node.py} +273 -234
  398. mindspore/rewrite/node/node_manager.py +254 -0
  399. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  400. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  401. mindspore/rewrite/parsers/assign_parser.py +216 -221
  402. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  403. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  404. mindspore/rewrite/parsers/constant_parser.py +9 -6
  405. mindspore/rewrite/parsers/container_parser.py +9 -7
  406. mindspore/rewrite/parsers/for_parser.py +36 -15
  407. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  408. mindspore/rewrite/parsers/if_parser.py +28 -24
  409. mindspore/rewrite/parsers/module_parser.py +196 -25
  410. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  411. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  412. mindspore/rewrite/parsers/return_parser.py +6 -6
  413. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  414. mindspore/rewrite/sparsify/utils.py +1 -1
  415. mindspore/rewrite/symbol_tree.py +525 -577
  416. mindspore/rewrite/symbol_tree_builder.py +9 -193
  417. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  418. mindspore/run_check/_check_version.py +2 -2
  419. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  420. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  421. mindspore/scipy/linalg.py +1 -1
  422. mindspore/scipy/optimize/minimize.py +7 -3
  423. mindspore/train/_utils.py +7 -3
  424. mindspore/train/amp.py +323 -123
  425. mindspore/train/anf_ir_pb2.py +14 -2
  426. mindspore/train/callback/_backup_and_restore.py +2 -12
  427. mindspore/train/callback/_callback.py +29 -4
  428. mindspore/train/callback/_checkpoint.py +23 -8
  429. mindspore/train/callback/_early_stop.py +2 -2
  430. mindspore/train/callback/_landscape.py +4 -4
  431. mindspore/train/callback/_loss_monitor.py +2 -2
  432. mindspore/train/callback/_on_request_exit.py +2 -2
  433. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  434. mindspore/train/callback/_summary_collector.py +14 -7
  435. mindspore/train/callback/_time_monitor.py +58 -5
  436. mindspore/train/data_sink.py +5 -11
  437. mindspore/train/dataset_helper.py +83 -57
  438. mindspore/train/loss_scale_manager.py +2 -2
  439. mindspore/train/metrics/__init__.py +3 -3
  440. mindspore/train/metrics/cosine_similarity.py +1 -1
  441. mindspore/train/metrics/hausdorff_distance.py +3 -2
  442. mindspore/train/metrics/mean_surface_distance.py +3 -2
  443. mindspore/train/metrics/metric.py +39 -19
  444. mindspore/train/metrics/roc.py +2 -2
  445. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  446. mindspore/train/mind_ir_pb2.py +85 -36
  447. mindspore/train/model.py +185 -45
  448. mindspore/train/serialization.py +390 -150
  449. mindspore/train/summary/_writer_pool.py +3 -2
  450. mindspore/train/summary/summary_record.py +14 -10
  451. mindspore/train/train_thor/convert_utils.py +3 -3
  452. mindspore/train/train_thor/dataset_helper.py +1 -1
  453. mindspore/version.py +1 -1
  454. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  455. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
  456. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  457. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  458. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  459. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  460. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  461. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  462. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  463. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  464. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  465. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  466. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  467. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  468. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  469. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  470. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  471. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  472. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  473. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  474. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  475. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  476. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  477. mindspore/_extends/graph_kernel/expander.py +0 -80
  478. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  479. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  480. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  481. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  482. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  483. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  484. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  485. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  486. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  487. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  488. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  489. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  490. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  491. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  492. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  493. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  494. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  495. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  496. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  497. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  498. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  499. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  500. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  501. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  502. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  503. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  504. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  505. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  506. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  507. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  509. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  510. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  511. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  512. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  514. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  515. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  516. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  517. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  518. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  519. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  520. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  521. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  522. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  523. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  524. mindspore/dataset/datapreprocess/__init__.py +0 -20
  525. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  526. mindspore/include/api/net.h +0 -142
  527. mindspore/nn/lr_scheduler.py +0 -262
  528. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  529. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  530. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  531. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  532. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  533. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  537. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  541. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  542. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  543. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  546. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  547. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  548. mindspore/rewrite/node_visitor.py +0 -44
  549. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  550. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -1,259 +1,252 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """optimizer"""
16
- from __future__ import absolute_import
17
- from collections import defaultdict
18
- from typing import Iterable
19
- from mindspore.ops import functional as F, composite as C, operations as P
20
- from mindspore.ops.operations import _inner_ops as inner
21
- from mindspore.nn.cell import Cell
22
- from mindspore.common.parameter import Parameter, ParameterTuple
23
- from mindspore.common import Tensor
24
- from mindspore.common.sparse_tensor import RowTensorInner
25
- import mindspore.common.dtype as mstype
26
- from mindspore import _checkparam as validator
27
- from mindspore import log as logger
28
-
29
-
30
- __all__ = ['Optimizer']
31
-
32
-
33
- class Optimizer(Cell):
34
- r"""
35
- Base class for all optimizers.
36
-
37
- .. warning::
38
- This is an experimental optimizer API that is subject to change.
39
- This module must be used with lr scheduler module in `LRScheduler Class
40
- <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#lrscheduler>`_ .
41
-
42
- Args:
43
- params (Union[list(Parameter), list(dict)]): an iterable of :class:`mindspore.Parameter` or
44
- dict. Specifies what Tensors should be optimized.
45
- defaults: (dict): a dict containing default values of optimization
46
- options (used when a parameter group doesn't specify them).
47
-
48
- Raises:
49
- TypeError: If `learning_rate` is not one of int, float, Tensor.
50
- TypeError: If element of `parameters` is neither Parameter nor dict.
51
- TypeError: If `weight_decay` is neither float nor int.
52
- ValueError: If `weight_decay` is less than 0.
53
- ValueError: If `learning_rate` is a Tensor, but the dimension of tensor is greater than 1.
54
-
55
- Supported Platforms:
56
- ``Ascend`` ``GPU`` ``CPU``
57
- """
58
- def __init__(self, params, defaults):
59
- super(Optimizer, self).__init__(auto_prefix=False)
60
-
61
- param_groups = self._parameters_base_check(params, "params")
62
- self.defaults = defaults
63
- self.state = defaultdict(dict)
64
- self.param_groups = []
65
- self.parameters = []
66
- self.map_ = C.Map()
67
- self.group_start_id = [0]
68
- if not isinstance(param_groups[0], dict):
69
- param_groups = [{'params': param_groups}]
70
-
71
- for i, param_group in enumerate(param_groups):
72
- self.add_param_group(i, param_group)
73
- self.group_start_id.append(self.group_start_id[-1] + len(param_group["params"]))
74
- self.parameters = ParameterTuple(self.parameters)
75
-
76
- def __repr__(self):
77
- format_string = self.__class__.__name__ + ' ('
78
- for i, group in enumerate(self.param_groups):
79
- format_string += '\n'
80
- format_string += 'Parameter Group {0}\n'.format(i)
81
- for key in sorted(group.keys()):
82
- if key != 'params':
83
- format_string += ' {0}: {1}\n'.format(key, group[key].value()) \
84
- if key == "lr" and isinstance(group[key], Parameter) \
85
- else ' {0}: {1}\n'.format(key, group[key])
86
- format_string += ')'
87
- return format_string
88
-
89
- def add_param_group(self, group_id, param_group):
90
- r"""
91
- Add a param group to the `Optimizer.param_groups`.
92
-
93
- Args:
94
- group_id(int): Specifies the group index.
95
- param_group (dict): Specifies what Parameters should be optimized along with group
96
- specific optimization options.
97
- """
98
- param_group = self._preprocess_param_group(param_group)
99
- self.parameters += param_group["params"]
100
-
101
- for name, default in self.defaults.items():
102
- if name not in param_group:
103
- param_group.setdefault(name, default)
104
-
105
- lr = self._build_single_lr(param_group.get("lr"), 'learning_rate_group_' + str(group_id))
106
- weight_decay = self._preprocess_weight_decay(param_group.get("weight_decay", 0.0))
107
- param_group["lr"] = lr
108
- param_group["weight_decay"] = weight_decay
109
- param_group["grad_centralization"] = self._preprocess_grad_centralization(
110
- param_group.get('grad_centralization', False))
111
- self.param_groups.append(param_group)
112
-
113
- @staticmethod
114
- def _parameters_base_check(parameters, param_info):
115
- """Parameters base check."""
116
- if parameters is None:
117
- raise ValueError(f"For 'Optimizer', the argument {param_info} can not be None.")
118
- if not isinstance(parameters, Iterable):
119
- raise TypeError(f"For 'Optimizer', the argument {param_info} must be Iterable type, "
120
- f"but got {type(parameters)}.")
121
- parameters = list(parameters)
122
-
123
- if not parameters:
124
- raise ValueError(f"For 'Optimizer', the argument {param_info} must not be empty.")
125
- return parameters
126
-
127
- def _decay_weight(self, weight_decay, params, gradients):
128
- """Apply weight decay."""
129
- if weight_decay != 0.:
130
- weight_decay = Tensor(weight_decay, mstype.float32)
131
- gradients = self.map_(F.partial(_apply_decay, weight_decay), params, gradients)
132
- return gradients
133
-
134
- def _gradients_centralization(self, grad_centralization, gradients):
135
- """Apply gradients centralization."""
136
- if grad_centralization:
137
- return self.map_(_apply_grad_centralization, gradients)
138
- return gradients
139
-
140
- def _preprocess_param_group(self, param_group):
141
- """Preprocess param groups."""
142
- if not isinstance(param_group, dict):
143
- raise TypeError('Param group must be a dict.')
144
-
145
- params = param_group['params']
146
- if isinstance(params, Parameter):
147
- param_group['params'] = [params]
148
- elif isinstance(params, set):
149
- raise TypeError('Optimizer parameters need to be organized in ordered collections, but '
150
- 'the ordering of tensors in sets will change between runs. '
151
- 'Please use a list instead.')
152
- else:
153
- param_group['params'] = list(params)
154
-
155
- for param in param_group['params']:
156
- if not isinstance(param, Parameter):
157
- raise TypeError("Optimizer can only optimize Parameters, but one of the params is " + type(param))
158
-
159
- if len(param_group['params']) != len(set(param_group['params'])):
160
- logger.warning("Optimizer contains a parameter group with duplicate parameters.")
161
-
162
- param_set = set()
163
- for group in self.param_groups:
164
- param_set.update(set(group['params']))
165
- if not param_set.isdisjoint(set(param_group['params'])):
166
- raise ValueError("some parameters appear in more than one parameter group.")
167
- return param_group
168
-
169
- def _build_single_lr(self, learning_rate, name):
170
- """Check lr value, and convert lr to a float or a Tensor."""
171
- if isinstance(learning_rate, (float, int)):
172
- learning_rate = float(learning_rate)
173
- validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
174
- return Parameter(Tensor(learning_rate, mstype.float32), name)
175
-
176
- if isinstance(learning_rate, Tensor):
177
- if learning_rate.ndim == 0:
178
- return Parameter(learning_rate.astype(mstype.float32), name)
179
- raise ValueError(f"For 'Optimizer', if 'learning_rate' is a Tensor, "
180
- f"then it should be scalar Tensor")
181
-
182
- raise TypeError("For 'Optimizer', the argument 'learning_rate' must be int, float or Tensor, "
183
- "but got {}.".format(type(learning_rate)))
184
-
185
- def _preprocess_weight_decay(self, weight_decay):
186
- """preprocess weight decay"""
187
- if isinstance(weight_decay, (float, int)):
188
- weight_decay = float(weight_decay)
189
- validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
190
- else:
191
- raise TypeError("For 'Optimizer', the argument 'Weight_decay' must be int or "
192
- "float.but got {}".format(type(weight_decay)))
193
- return weight_decay
194
-
195
- @staticmethod
196
- def _preprocess_grad_centralization(grad_centralization):
197
- """ Preprocess gradient centralization. """
198
- if not isinstance(grad_centralization, bool):
199
- raise TypeError("For 'Optimizer', the 'gradients_centralization' must be bool type, "
200
- "but got {}.".format(type(grad_centralization)))
201
- return grad_centralization
202
-
203
- def construct(self, *hyper_params):
204
- raise NotImplementedError
205
-
206
-
207
- op_add = P.AddN()
208
- op_gather = P.Gather()
209
- op_mul = P.Mul()
210
- op_gc = inner.Centralization()
211
-
212
- _apply_decay = C.MultitypeFuncGraph("apply_decay")
213
- _apply_grad_centralization = C.MultitypeFuncGraph("apply_grad_centralization")
214
-
215
-
216
- @_apply_decay.register("Tensor", "Tensor", "RowTensor")
217
- def _tensor_apply_decay_with_sparse(weight_decay, weight, gradient):
218
- """Get grad with weight_decay."""
219
- indices = gradient.indices
220
- values = op_add((op_gather(weight, indices, 0) * F.cast(weight_decay, F.dtype(weight)), gradient.values))
221
- shape = gradient.dense_shape
222
- return RowTensorInner(indices, values, shape)
223
-
224
-
225
- @_apply_decay.register("Tensor", "Tensor", "Tensor")
226
- def _tensor_apply_decay(weight_decay, weight, gradient):
227
- """Get grad with weight_decay."""
228
- return op_add((op_mul(weight, F.cast(weight_decay, F.dtype(weight))), gradient))
229
-
230
-
231
- @_apply_grad_centralization.register("RowTensor")
232
- def _tensor_apply_grad_centralization_with_sparse(gradient):
233
- """Get grad with grad_centralization."""
234
- indices = gradient.indices
235
- shape = gradient.dense_shape
236
- grad_shape = F.shape(gradient)
237
- axis = []
238
- for i in range(1, len(grad_shape)):
239
- axis.append(i)
240
- if len(axis) >= 1:
241
- if grad_shape[1] % 16 != 0:
242
- return gradient
243
- values = op_gc(gradient.values, axis)
244
- return RowTensorInner(indices, values, shape)
245
- return gradient
246
-
247
-
248
- @_apply_grad_centralization.register("Tensor")
249
- def _tensor_apply_grad_centralization(gradient):
250
- """Get grad with grad_centralization."""
251
- axis = []
252
- grad_shape = F.shape(gradient)
253
- for i in range(1, len(grad_shape)):
254
- axis.append(i)
255
- if len(axis) >= 1:
256
- if grad_shape[1] % 16 != 0:
257
- return gradient
258
- return op_gc(gradient, axis)
259
- return gradient
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """optimizer"""
16
+ from __future__ import absolute_import
17
+ from collections import defaultdict
18
+ from typing import Iterable
19
+ from mindspore.ops import functional as F, composite as C, operations as P
20
+
21
+ from mindspore.nn.cell import Cell
22
+ from mindspore.common.parameter import Parameter, ParameterTuple
23
+ from mindspore.common import Tensor
24
+ import mindspore.common.dtype as mstype
25
+ from mindspore import _checkparam as validator
26
+ from mindspore import log as logger
27
+
28
+
29
+ __all__ = ['Optimizer']
30
+
31
+
32
+ class Optimizer(Cell):
33
+ r"""
34
+ Base class for all optimizers.
35
+
36
+ .. warning::
37
+ This is an experimental optimizer API that is subject to change.
38
+ This module must be used with lr scheduler module in `LRScheduler Class
39
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.experimental.html#lrscheduler-class>`_ .
40
+
41
+ Args:
42
+ params (Union[list(Parameter), list(dict)]): an iterable of :class:`mindspore.Parameter` or
43
+ dict. Specifies what Tensors should be optimized.
44
+ defaults (dict): a dict containing default values of optimization
45
+ options (used when a parameter group doesn't specify them).
46
+
47
+ Raises:
48
+ TypeError: If `learning_rate` is not one of int, float, Tensor.
49
+ TypeError: If element of `parameters` is neither Parameter nor dict.
50
+ TypeError: If `weight_decay` is neither float nor int.
51
+ ValueError: If `weight_decay` is less than 0.
52
+ ValueError: If `learning_rate` is a Tensor, but the dimension of tensor is greater than 1.
53
+
54
+ Supported Platforms:
55
+ ``Ascend`` ``GPU`` ``CPU``
56
+
57
+ Examples:
58
+ >>> import numpy as np
59
+ >>> import mindspore
60
+ >>> from mindspore import nn, Tensor, Parameter
61
+ >>> from mindspore import ops
62
+ >>> from mindspore.experimental import optim
63
+ >>>
64
+ >>> class MySGD(optim.Optimizer):
65
+ ... def __init__(self, params, lr):
66
+ ... defaults = dict(lr=lr)
67
+ ... super(MySGD, self).__init__(params, defaults)
68
+ ...
69
+ ... def construct(self, gradients):
70
+ ... for group_id, group in enumerate(self.param_groups):
71
+ ... id = self.group_start_id[group_id]
72
+ ... for i, param in enumerate(group["params"]):
73
+ ... next_param = param + gradients[id+i] * group["lr"]
74
+ ... ops.assign(param, next_param)
75
+ >>>
76
+ >>> net = nn.Dense(8, 2)
77
+ >>> data = Tensor(np.random.rand(20, 8).astype(np.float32))
78
+ >>> label = Tensor(np.random.rand(20, 2).astype(np.float32))
79
+ >>>
80
+ >>> optimizer = MySGD(net.trainable_params(), 0.01)
81
+ >>> optimizer.add_param_group({"params": Parameter([0.01, 0.02])})
82
+ >>>
83
+ >>> criterion = nn.MAELoss(reduction="mean")
84
+ >>>
85
+ >>> def forward_fn(data, label):
86
+ ... logits = net(data)
87
+ ... loss = criterion(logits, label)
88
+ ... return loss, logits
89
+ >>>
90
+ >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
91
+ >>>
92
+ >>> def train_step(data, label):
93
+ ... (loss, _), grads = grad_fn(data, label)
94
+ ... optimizer(grads)
95
+ ... print(loss)
96
+ >>>
97
+ >>> train_step(data, label)
98
+ """
99
+ def __init__(self, params, defaults):
100
+ super(Optimizer, self).__init__(auto_prefix=False)
101
+
102
+ param_groups = self._parameters_base_check(params, "params")
103
+ self.defaults = defaults
104
+ self.state = defaultdict(dict)
105
+ self.param_groups = []
106
+ self.parameters = []
107
+ self.map_ = C.Map()
108
+ self.group_start_id = [0]
109
+ if not isinstance(param_groups[0], dict):
110
+ param_groups = [{'params': param_groups}]
111
+
112
+ for param_group in param_groups:
113
+ self.add_param_group(param_group)
114
+ self.parameters = ParameterTuple(self.parameters)
115
+ self.hyper_map = C.HyperMap()
116
+ self.enable_tuple_broaden = True
117
+
118
+ def __repr__(self):
119
+ format_string = self.__class__.__name__ + ' ('
120
+ for i, group in enumerate(self.param_groups):
121
+ format_string += '\n'
122
+ format_string += 'Parameter Group {0}\n'.format(i)
123
+ for key in sorted(group.keys()):
124
+ if key != 'params':
125
+ format_string += ' {0}: {1}\n'.format(key, group[key].value()) \
126
+ if key == "lr" and isinstance(group[key], Parameter) \
127
+ else ' {0}: {1}\n'.format(key, group[key])
128
+ format_string += ')'
129
+ return format_string
130
+
131
+ def add_param_group(self, param_group):
132
+ r"""
133
+ Add a param group to the `Optimizer.param_groups`.
134
+
135
+ Args:
136
+ param_group (dict): Specifies what Parameters should be optimized along with group
137
+ specific optimization options.
138
+ """
139
+ group_id = len(self.param_groups)
140
+ param_group = self._preprocess_param_group(param_group)
141
+ self.parameters += tuple(param_group.get("params"))
142
+
143
+ for name, default in self.defaults.items():
144
+ if name not in param_group:
145
+ param_group.setdefault(name, default)
146
+
147
+ lr = self._build_single_lr(param_group.get("lr"), 'learning_rate_group_' + str(group_id))
148
+ weight_decay = self._preprocess_weight_decay(param_group.get("weight_decay", 0.0))
149
+ param_group["lr"] = lr
150
+ param_group["weight_decay"] = weight_decay
151
+ self.param_groups.append(param_group)
152
+ self.group_start_id.append(self.group_start_id[-1] + len(param_group.get("params")))
153
+
154
+ @staticmethod
155
+ def _parameters_base_check(parameters, param_info):
156
+ """Parameters base check."""
157
+ if parameters is None:
158
+ raise ValueError(f"For 'Optimizer', the argument {param_info} can not be None.")
159
+ if not isinstance(parameters, Iterable):
160
+ raise TypeError(f"For 'Optimizer', the argument {param_info} must be Iterable type, "
161
+ f"but got {type(parameters)}.")
162
+ parameters = list(parameters)
163
+
164
+ if not parameters:
165
+ raise ValueError(f"For 'Optimizer', the argument {param_info} must not be empty.")
166
+ return parameters
167
+
168
+ def _decay_weight(self, weight_decay, params, gradients):
169
+ """Apply weight decay."""
170
+ if weight_decay != 0.:
171
+ weight_decay = Tensor(weight_decay, mstype.float32)
172
+ gradients = self.map_(F.partial(_apply_decay, weight_decay), params, gradients)
173
+ return gradients
174
+
175
+ def _preprocess_param_group(self, param_group):
176
+ """Preprocess param groups."""
177
+ if not isinstance(param_group, dict):
178
+ raise TypeError('Param group must be a dict.')
179
+
180
+ params = param_group['params']
181
+ if isinstance(params, Parameter):
182
+ param_group['params'] = [params]
183
+ elif isinstance(params, set):
184
+ raise TypeError('Optimizer parameters need to be organized in ordered collections, but '
185
+ 'the ordering of tensors in sets will change between runs. '
186
+ 'Please use a list instead.')
187
+ else:
188
+ param_group['params'] = list(params)
189
+
190
+ for param in param_group['params']:
191
+ if not isinstance(param, Parameter):
192
+ raise TypeError("Optimizer can only optimize Parameters, but one of the params is " + type(param))
193
+
194
+ if len(param_group['params']) != len(set(param_group['params'])):
195
+ logger.warning("Optimizer contains a parameter group with duplicate parameters.")
196
+
197
+ param_set = set()
198
+ for group in self.param_groups:
199
+ param_set.update(set(group['params']))
200
+ if not param_set.isdisjoint(set(param_group['params'])):
201
+ raise ValueError("some parameters appear in more than one parameter group.")
202
+ return param_group
203
+
204
+ def _build_single_lr(self, learning_rate, name):
205
+ """Check lr value, and convert lr to a float or a Tensor."""
206
+ if isinstance(learning_rate, (float, int)):
207
+ learning_rate = float(learning_rate)
208
+ validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
209
+ return Parameter(Tensor(learning_rate, mstype.float32), name)
210
+
211
+ if isinstance(learning_rate, Tensor):
212
+ if learning_rate.ndim == 0:
213
+ return Parameter(learning_rate.astype(mstype.float32), name)
214
+ raise ValueError(f"For 'Optimizer', if 'learning_rate' is a Tensor, "
215
+ f"then it should be scalar Tensor")
216
+
217
+ raise TypeError("For 'Optimizer', the argument 'learning_rate' must be int, float or Tensor, "
218
+ "but got {}.".format(type(learning_rate)))
219
+
220
+ def _preprocess_weight_decay(self, weight_decay):
221
+ """preprocess weight decay"""
222
+ if isinstance(weight_decay, (float, int)):
223
+ weight_decay = float(weight_decay)
224
+ validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
225
+ else:
226
+ raise TypeError("For 'Optimizer', the argument 'Weight_decay' must be int or "
227
+ "float.but got {}".format(type(weight_decay)))
228
+ return weight_decay
229
+
230
+ def construct(self, *hyper_params):
231
+ raise NotImplementedError
232
+
233
+ op_add = P.AddN()
234
+ op_gather = P.Gather()
235
+ op_mul = P.Mul()
236
+
237
+ _apply_decay = C.MultitypeFuncGraph("apply_decay")
238
+
239
+
240
+ @_apply_decay.register("Tensor", "Tensor", "RowTensor")
241
+ def _tensor_apply_decay_with_sparse(weight_decay, weight, gradient):
242
+ """Get grad with weight_decay."""
243
+ indices = gradient.indices
244
+ values = op_add((op_gather(weight, indices, 0) * F.cast(weight_decay, F.dtype(weight)), gradient.values))
245
+ shape = gradient.dense_shape
246
+ return RowTensorInner(indices, values, shape)
247
+
248
+
249
+ @_apply_decay.register("Tensor", "Tensor", "Tensor")
250
+ def _tensor_apply_decay(weight_decay, weight, gradient):
251
+ """Get grad with weight_decay."""
252
+ return op_add((op_mul(weight, F.cast(weight_decay, F.dtype(weight))), gradient))