mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.10__cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (580)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +141 -88
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/cell_wrapper.py +84 -34
  268. mindspore/nn/wrap/grad_reducer.py +8 -5
  269. mindspore/nn/wrap/loss_scale.py +105 -42
  270. mindspore/numpy/array_creations.py +1 -2
  271. mindspore/numpy/array_ops.py +3 -2
  272. mindspore/numpy/utils_const.py +5 -5
  273. mindspore/offline_debug/convert_async.py +2 -2
  274. mindspore/ops/_grad_experimental/__init__.py +0 -5
  275. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  276. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  277. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  278. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  279. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  280. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  281. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  282. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  283. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  284. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  285. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  286. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  287. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  288. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  289. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  290. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  291. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  292. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  293. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  294. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  295. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  296. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  297. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  298. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  299. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  300. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  301. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  302. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  303. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  304. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  305. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  306. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  307. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  309. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  310. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  311. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  312. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  313. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  314. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  315. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  316. mindspore/ops/_primitive_cache.py +1 -1
  317. mindspore/ops/_tracefunc.py +45 -13
  318. mindspore/ops/_utils/utils.py +6 -1
  319. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  320. mindspore/ops/_vmap/vmap_base.py +3 -3
  321. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  322. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  323. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  324. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  325. mindspore/ops/arg_dtype_cast.py +54 -0
  326. mindspore/ops/composite/base.py +37 -10
  327. mindspore/ops/composite/math_ops.py +5 -4
  328. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  329. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  330. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  331. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  333. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  334. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  335. mindspore/ops/deprecated.py +304 -0
  336. mindspore/ops/function/__init__.py +4 -1
  337. mindspore/ops/function/array_func.py +174 -193
  338. mindspore/ops/function/clip_func.py +81 -13
  339. mindspore/ops/function/debug_func.py +1 -1
  340. mindspore/ops/function/grad/grad_func.py +18 -9
  341. mindspore/ops/function/image_func.py +10 -4
  342. mindspore/ops/function/linalg_func.py +5 -5
  343. mindspore/ops/function/math_func.py +575 -386
  344. mindspore/ops/function/nn_func.py +568 -260
  345. mindspore/ops/function/random_func.py +88 -57
  346. mindspore/ops/function/sparse_func.py +1 -1
  347. mindspore/ops/function/sparse_unary_func.py +14 -12
  348. mindspore/ops/function/vmap_func.py +6 -5
  349. mindspore/ops/functional.py +15 -10
  350. mindspore/ops/op_info_register.py +244 -25
  351. mindspore/ops/operations/__init__.py +28 -19
  352. mindspore/ops/operations/_grad_ops.py +72 -7
  353. mindspore/ops/operations/_inner_ops.py +350 -17
  354. mindspore/ops/operations/_quant_ops.py +4 -8
  355. mindspore/ops/operations/_sequence_ops.py +42 -0
  356. mindspore/ops/operations/array_ops.py +68 -282
  357. mindspore/ops/operations/comm_ops.py +107 -59
  358. mindspore/ops/operations/custom_ops.py +94 -70
  359. mindspore/ops/operations/debug_ops.py +8 -4
  360. mindspore/ops/operations/image_ops.py +18 -12
  361. mindspore/ops/operations/inner_ops.py +26 -3
  362. mindspore/ops/operations/math_ops.py +189 -141
  363. mindspore/ops/operations/nn_ops.py +794 -489
  364. mindspore/ops/operations/other_ops.py +0 -22
  365. mindspore/ops/operations/random_ops.py +53 -111
  366. mindspore/ops/operations/sparse_ops.py +3 -1
  367. mindspore/ops/primitive.py +24 -18
  368. mindspore/parallel/_auto_parallel_context.py +68 -8
  369. mindspore/parallel/_cost_model_context.py +2 -2
  370. mindspore/parallel/_offload_context.py +17 -3
  371. mindspore/parallel/_parallel_serialization.py +12 -5
  372. mindspore/parallel/_ps_context.py +12 -0
  373. mindspore/parallel/_tensor.py +18 -13
  374. mindspore/parallel/_transformer/layers.py +5 -3
  375. mindspore/parallel/_transformer/loss.py +1 -0
  376. mindspore/parallel/_transformer/moe.py +2 -2
  377. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  378. mindspore/parallel/_transformer/transformer.py +23 -3
  379. mindspore/parallel/_utils.py +11 -7
  380. mindspore/parallel/algo_parameter_config.py +85 -5
  381. mindspore/parallel/checkpoint_transform.py +19 -12
  382. mindspore/parallel/shard.py +21 -14
  383. mindspore/profiler/common/struct_type.py +3 -3
  384. mindspore/profiler/common/util.py +4 -2
  385. mindspore/profiler/envprofiling.py +1 -1
  386. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  387. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  388. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  389. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  390. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  391. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  392. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  393. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  394. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  395. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  396. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  397. mindspore/profiler/parser/flops_parser.py +15 -11
  398. mindspore/profiler/parser/framework_parser.py +38 -22
  399. mindspore/profiler/parser/hccl_parser.py +16 -12
  400. mindspore/profiler/parser/integrator.py +22 -11
  401. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  402. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  403. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  404. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  405. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  406. mindspore/profiler/parser/optime_parser.py +1 -1
  407. mindspore/profiler/parser/profiler_info.py +21 -2
  408. mindspore/profiler/parser/step_trace_parser.py +11 -14
  409. mindspore/profiler/profiling.py +179 -89
  410. mindspore/rewrite/api/node.py +102 -19
  411. mindspore/rewrite/api/node_type.py +5 -1
  412. mindspore/rewrite/api/pattern_engine.py +1 -1
  413. mindspore/rewrite/api/scoped_value.py +9 -17
  414. mindspore/rewrite/api/symbol_tree.py +131 -47
  415. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  416. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  417. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  418. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  419. mindspore/rewrite/common/rewrite_elog.py +5 -1
  420. mindspore/rewrite/namer.py +33 -24
  421. mindspore/rewrite/namespace.py +14 -5
  422. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  423. mindspore/rewrite/node/call_function.py +79 -0
  424. mindspore/rewrite/node/cell_container.py +135 -0
  425. mindspore/rewrite/node/control_flow.py +88 -0
  426. mindspore/rewrite/{node.py → node/node.py} +273 -234
  427. mindspore/rewrite/node/node_manager.py +254 -0
  428. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  429. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  430. mindspore/rewrite/parsers/assign_parser.py +216 -221
  431. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  432. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  433. mindspore/rewrite/parsers/constant_parser.py +9 -6
  434. mindspore/rewrite/parsers/container_parser.py +9 -7
  435. mindspore/rewrite/parsers/for_parser.py +36 -15
  436. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  437. mindspore/rewrite/parsers/if_parser.py +28 -24
  438. mindspore/rewrite/parsers/module_parser.py +196 -25
  439. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  440. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  441. mindspore/rewrite/parsers/return_parser.py +6 -6
  442. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  443. mindspore/rewrite/sparsify/utils.py +1 -1
  444. mindspore/rewrite/symbol_tree.py +523 -578
  445. mindspore/rewrite/symbol_tree_builder.py +9 -193
  446. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  447. mindspore/run_check/_check_version.py +6 -4
  448. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  449. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  450. mindspore/scipy/linalg.py +1 -1
  451. mindspore/scipy/optimize/minimize.py +7 -3
  452. mindspore/train/_utils.py +7 -3
  453. mindspore/train/amp.py +323 -123
  454. mindspore/train/anf_ir_pb2.py +14 -2
  455. mindspore/train/callback/_backup_and_restore.py +2 -12
  456. mindspore/train/callback/_callback.py +29 -4
  457. mindspore/train/callback/_checkpoint.py +23 -8
  458. mindspore/train/callback/_early_stop.py +2 -2
  459. mindspore/train/callback/_landscape.py +4 -4
  460. mindspore/train/callback/_loss_monitor.py +2 -2
  461. mindspore/train/callback/_on_request_exit.py +2 -2
  462. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  463. mindspore/train/callback/_summary_collector.py +15 -8
  464. mindspore/train/callback/_time_monitor.py +58 -5
  465. mindspore/train/data_sink.py +5 -11
  466. mindspore/train/dataset_helper.py +84 -57
  467. mindspore/train/loss_scale_manager.py +2 -2
  468. mindspore/train/metrics/__init__.py +3 -3
  469. mindspore/train/metrics/cosine_similarity.py +1 -1
  470. mindspore/train/metrics/hausdorff_distance.py +3 -2
  471. mindspore/train/metrics/mean_surface_distance.py +3 -2
  472. mindspore/train/metrics/metric.py +39 -19
  473. mindspore/train/metrics/roc.py +2 -2
  474. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  475. mindspore/train/mind_ir_pb2.py +85 -36
  476. mindspore/train/model.py +187 -47
  477. mindspore/train/serialization.py +487 -161
  478. mindspore/train/summary/_summary_adapter.py +1 -1
  479. mindspore/train/summary/_writer_pool.py +3 -2
  480. mindspore/train/summary/summary_record.py +37 -17
  481. mindspore/train/train_thor/convert_utils.py +3 -3
  482. mindspore/train/train_thor/dataset_helper.py +1 -1
  483. mindspore/version.py +1 -1
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
  486. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  487. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  489. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  490. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  491. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  492. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  493. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  494. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  495. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  497. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  498. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  499. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  500. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  501. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  502. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  503. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  504. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  505. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  506. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  507. mindspore/_extends/graph_kernel/expander.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  509. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  510. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  511. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  512. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  513. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  514. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  515. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  516. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  517. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  518. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  519. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  520. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  521. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  522. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  523. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  524. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  525. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  526. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  527. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  528. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  529. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  530. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  531. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  532. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  534. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  535. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  536. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  537. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  538. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  539. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  540. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  543. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  544. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  545. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  546. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  547. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  548. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  549. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  550. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  551. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  552. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  553. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  554. mindspore/dataset/datapreprocess/__init__.py +0 -20
  555. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  556. mindspore/include/api/net.h +0 -142
  557. mindspore/nn/lr_scheduler.py +0 -262
  558. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  559. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  560. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  561. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  562. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  563. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  567. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  578. mindspore/rewrite/node_visitor.py +0 -44
  579. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  580. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -1,259 +1,252 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """optimizer"""
16
- from __future__ import absolute_import
17
- from collections import defaultdict
18
- from typing import Iterable
19
- from mindspore.ops import functional as F, composite as C, operations as P
20
- from mindspore.ops.operations import _inner_ops as inner
21
- from mindspore.nn.cell import Cell
22
- from mindspore.common.parameter import Parameter, ParameterTuple
23
- from mindspore.common import Tensor
24
- from mindspore.common.sparse_tensor import RowTensorInner
25
- import mindspore.common.dtype as mstype
26
- from mindspore import _checkparam as validator
27
- from mindspore import log as logger
28
-
29
-
30
- __all__ = ['Optimizer']
31
-
32
-
33
class Optimizer(Cell):
    r"""
    Base class for all optimizers.

    .. warning::
        This is an experimental optimizer API that is subject to change.
        This module must be used with lr scheduler module in `LRScheduler Class
        <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#lrscheduler>`_ .

    Args:
        params (Union[list(Parameter), list(dict)]): an iterable of :class:`mindspore.Parameter` or
            dict. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).

    Raises:
        TypeError: If `params` is not iterable.
        TypeError: If `learning_rate` is not one of int, float, Tensor.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `weight_decay` is neither float nor int.
        ValueError: If `params` is None or empty.
        ValueError: If `weight_decay` is less than 0.
        ValueError: If `learning_rate` is a Tensor, but is not a scalar (zero-dimensional) Tensor.
        ValueError: If a parameter appears in more than one parameter group.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    def __init__(self, params, defaults):
        super(Optimizer, self).__init__(auto_prefix=False)

        param_groups = self._parameters_base_check(params, "params")
        self.defaults = defaults
        self.state = defaultdict(dict)
        self.param_groups = []
        self.parameters = []
        self.map_ = C.Map()
        # group_start_id[k] is the offset of group k's first parameter inside
        # the flat `self.parameters` sequence; the last entry is the total count.
        self.group_start_id = [0]
        # A flat iterable of Parameters is treated as a single group.
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for i, param_group in enumerate(param_groups):
            self.add_param_group(i, param_group)
        self.parameters = ParameterTuple(self.parameters)

    def __repr__(self):
        """Return the class name followed by every group's options ('params' excluded)."""
        pieces = [self.__class__.__name__ + ' (']
        for i, group in enumerate(self.param_groups):
            pieces.append('\n')
            pieces.append('Parameter Group {0}\n'.format(i))
            for key in sorted(group.keys()):
                if key == 'params':
                    continue
                shown = group[key]
                # The learning rate is stored as a Parameter; show its value.
                if key == "lr" and isinstance(shown, Parameter):
                    shown = shown.value()
                pieces.append(' {0}: {1}\n'.format(key, shown))
        pieces.append(')')
        return ''.join(pieces)

    def add_param_group(self, group_id, param_group):
        r"""
        Add a param group to the `Optimizer.param_groups`.

        Args:
            group_id(int): Specifies the group index.
            param_group (dict): Specifies what Parameters should be optimized along with group
                specific optimization options.
        """
        param_group = self._preprocess_param_group(param_group)
        # `self.parameters` is a list while __init__ runs and a tuple afterwards;
        # adding a tuple works for both, adding a list only for the former.
        self.parameters += tuple(param_group["params"])

        # Options the group does not set fall back to the optimizer defaults.
        for name, default in self.defaults.items():
            param_group.setdefault(name, default)

        lr = self._build_single_lr(param_group.get("lr"), 'learning_rate_group_' + str(group_id))
        weight_decay = self._preprocess_weight_decay(param_group.get("weight_decay", 0.0))
        param_group["lr"] = lr
        param_group["weight_decay"] = weight_decay
        param_group["grad_centralization"] = self._preprocess_grad_centralization(
            param_group.get('grad_centralization', False))
        self.param_groups.append(param_group)
        # Keep the offsets in step with the groups no matter who calls this method.
        self.group_start_id.append(self.group_start_id[-1] + len(param_group["params"]))

    @staticmethod
    def _parameters_base_check(parameters, param_info):
        """Check that `parameters` is a non-empty iterable and return it as a list.

        `param_info` is the argument name used in the error messages.
        """
        if parameters is None:
            raise ValueError(f"For 'Optimizer', the argument {param_info} can not be None.")
        if not isinstance(parameters, Iterable):
            raise TypeError(f"For 'Optimizer', the argument {param_info} must be Iterable type, "
                            f"but got {type(parameters)}.")
        parameters = list(parameters)

        if not parameters:
            raise ValueError(f"For 'Optimizer', the argument {param_info} must not be empty.")
        return parameters

    def _decay_weight(self, weight_decay, params, gradients):
        """Apply weight decay to `gradients`; they are returned untouched when the decay is zero."""
        if weight_decay != 0.:
            weight_decay = Tensor(weight_decay, mstype.float32)
            gradients = self.map_(F.partial(_apply_decay, weight_decay), params, gradients)
        return gradients

    def _gradients_centralization(self, grad_centralization, gradients):
        """Apply gradients centralization when `grad_centralization` is truthy."""
        if grad_centralization:
            return self.map_(_apply_grad_centralization, gradients)
        return gradients

    def _preprocess_param_group(self, param_group):
        """Normalise `param_group['params']` to a list of Parameters and validate it.

        The passed dict is updated in place and returned.

        Raises:
            TypeError: If `param_group` is not a dict, its params are a set, or an
                element is not a Parameter.
            ValueError: If a parameter already belongs to a registered group.
        """
        if not isinstance(param_group, dict):
            raise TypeError('Param group must be a dict.')

        params = param_group['params']
        if isinstance(params, Parameter):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('Optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. '
                            'Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, Parameter):
                # Format the type into the message: concatenating str and type
                # would fail with an unrelated error text.
                raise TypeError(f"Optimizer can only optimize Parameters, but one of the params is {type(param)}")

        if len(param_group['params']) != len(set(param_group['params'])):
            logger.warning("Optimizer contains a parameter group with duplicate parameters.")

        registered = set()
        for group in self.param_groups:
            registered.update(group['params'])
        if not registered.isdisjoint(param_group['params']):
            raise ValueError("some parameters appear in more than one parameter group.")
        return param_group

    def _build_single_lr(self, learning_rate, name):
        """Validate `learning_rate` and wrap it in a float32 Parameter named `name`.

        Accepts an int/float (checked with `check_non_negative_float`) or a
        zero-dimensional Tensor.

        Raises:
            ValueError: If `learning_rate` is a Tensor with `ndim` other than 0.
            TypeError: If `learning_rate` is not an int, float or Tensor.
        """
        if isinstance(learning_rate, (float, int)):
            learning_rate = float(learning_rate)
            validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
            return Parameter(Tensor(learning_rate, mstype.float32), name)

        if isinstance(learning_rate, Tensor):
            if learning_rate.ndim == 0:
                return Parameter(learning_rate.astype(mstype.float32), name)
            raise ValueError("For 'Optimizer', if 'learning_rate' is a Tensor, "
                             "then it should be scalar Tensor")

        raise TypeError("For 'Optimizer', the argument 'learning_rate' must be int, float or Tensor, "
                        "but got {}.".format(type(learning_rate)))

    def _preprocess_weight_decay(self, weight_decay):
        """Validate `weight_decay` (int or float, checked as non-negative) and return it as a float."""
        if not isinstance(weight_decay, (float, int)):
            raise TypeError("For 'Optimizer', the argument 'weight_decay' must be int or "
                            "float, but got {}.".format(type(weight_decay)))
        weight_decay = float(weight_decay)
        validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
        return weight_decay

    @staticmethod
    def _preprocess_grad_centralization(grad_centralization):
        """Check that the gradient centralization flag is a bool and return it."""
        if not isinstance(grad_centralization, bool):
            raise TypeError("For 'Optimizer', the 'gradients_centralization' must be bool type, "
                            "but got {}.".format(type(grad_centralization)))
        return grad_centralization

    def construct(self, *hyper_params):
        """Perform one update step; concrete optimizers must override this."""
        raise NotImplementedError
205
-
206
-
207
# Type-dispatched graphs; the overloads are registered by the decorated
# functions that follow.
_apply_decay = C.MultitypeFuncGraph("apply_decay")
_apply_grad_centralization = C.MultitypeFuncGraph("apply_grad_centralization")

# Operator instances shared by the overloads below.
op_add = P.AddN()
op_gather = P.Gather()
op_mul = P.Mul()
op_gc = inner.Centralization()
214
-
215
-
216
@_apply_decay.register("Tensor", "Tensor", "RowTensor")
def _tensor_apply_decay_with_sparse(weight_decay, weight, gradient):
    """Add the weight-decay term to a RowTensor gradient.

    `weight` is gathered at `gradient.indices`, scaled by `weight_decay` cast
    to the dtype of `weight`, and summed with `gradient.values`. The result
    keeps the indices and dense shape of `gradient`.
    """
    row_ids = gradient.indices
    decay_factor = F.cast(weight_decay, F.dtype(weight))
    gathered_weight = op_gather(weight, row_ids, 0)
    decayed_values = op_add((gathered_weight * decay_factor, gradient.values))
    return RowTensorInner(row_ids, decayed_values, gradient.dense_shape)
223
-
224
-
225
@_apply_decay.register("Tensor", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, weight, gradient):
    """Return `gradient` plus `weight` scaled by `weight_decay` (cast to the dtype of `weight`)."""
    decay_factor = F.cast(weight_decay, F.dtype(weight))
    penalty = op_mul(weight, decay_factor)
    return op_add((penalty, gradient))
229
-
230
-
231
@_apply_grad_centralization.register("RowTensor")
def _tensor_apply_grad_centralization_with_sparse(gradient):
    """Centralize the values of a RowTensor gradient over every axis but the first.

    The gradient is returned unchanged when it has fewer than two dimensions
    or when its second dimension is not a multiple of 16.
    """
    grad_shape = F.shape(gradient)
    axis = []
    for dim in range(1, len(grad_shape)):
        axis.append(dim)

    # Nothing to centralize over.
    if len(axis) < 1:
        return gradient
    # NOTE(review): the reason for the multiple-of-16 requirement is not
    # visible in this file; confirm against the Centralization operator.
    if grad_shape[1] % 16 != 0:
        return gradient

    centralized = op_gc(gradient.values, axis)
    return RowTensorInner(gradient.indices, centralized, gradient.dense_shape)
246
-
247
-
248
@_apply_grad_centralization.register("Tensor")
def _tensor_apply_grad_centralization(gradient):
    """Centralize a dense gradient over every axis but the first.

    The gradient is returned unchanged when it has fewer than two dimensions
    or when its second dimension is not a multiple of 16.
    """
    grad_shape = F.shape(gradient)
    axis = []
    for dim in range(1, len(grad_shape)):
        axis.append(dim)

    # Nothing to centralize over.
    if len(axis) < 1:
        return gradient
    # NOTE(review): the reason for the multiple-of-16 requirement is not
    # visible in this file; confirm against the Centralization operator.
    if grad_shape[1] % 16 != 0:
        return gradient
    return op_gc(gradient, axis)
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """optimizer"""
16
from __future__ import absolute_import
from collections import defaultdict
from typing import Iterable

from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import Tensor
from mindspore.common.sparse_tensor import RowTensorInner
import mindspore.common.dtype as mstype
from mindspore import _checkparam as validator
from mindspore import log as logger
27
+
28
+
29
+ __all__ = ['Optimizer']
30
+
31
+
32
class Optimizer(Cell):
    r"""
    Base class for all optimizers.

    .. warning::
        This is an experimental optimizer API that is subject to change.
        This module must be used with lr scheduler module in `LRScheduler Class
        <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.experimental.html#lrscheduler-class>`_ .

    Args:
        params (Union[list(Parameter), list(dict)]): an iterable of :class:`mindspore.Parameter` or
            dict. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).

    Raises:
        TypeError: If `params` is not iterable.
        TypeError: If `learning_rate` is not one of int, float, Tensor.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `weight_decay` is neither float nor int.
        ValueError: If `params` is None or empty.
        ValueError: If `weight_decay` is less than 0.
        ValueError: If `learning_rate` is a Tensor, but is not a scalar (zero-dimensional) Tensor.
        ValueError: If a parameter appears in more than one parameter group.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import nn, Tensor, Parameter
        >>> from mindspore import ops
        >>> from mindspore.experimental import optim
        >>>
        >>> class MySGD(optim.Optimizer):
        ...     def __init__(self, params, lr):
        ...         defaults = dict(lr=lr)
        ...         super(MySGD, self).__init__(params, defaults)
        ...
        ...     def construct(self, gradients):
        ...         for group_id, group in enumerate(self.param_groups):
        ...             start_id = self.group_start_id[group_id]
        ...             for i, param in enumerate(group["params"]):
        ...                 next_param = param - gradients[start_id + i] * group["lr"]
        ...                 ops.assign(param, next_param)
        >>>
        >>> net = nn.Dense(8, 2)
        >>> data = Tensor(np.random.rand(20, 8).astype(np.float32))
        >>> label = Tensor(np.random.rand(20, 2).astype(np.float32))
        >>>
        >>> optimizer = MySGD(net.trainable_params(), 0.01)
        >>> optimizer.add_param_group({"params": Parameter([0.01, 0.02])})
        >>>
        >>> criterion = nn.MAELoss(reduction="mean")
        >>>
        >>> def forward_fn(data, label):
        ...     logits = net(data)
        ...     loss = criterion(logits, label)
        ...     return loss, logits
        >>>
        >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
        >>>
        >>> def train_step(data, label):
        ...     (loss, _), grads = grad_fn(data, label)
        ...     optimizer(grads)
        ...     print(loss)
        >>>
        >>> train_step(data, label)
    """
    def __init__(self, params, defaults):
        super(Optimizer, self).__init__(auto_prefix=False)

        param_groups = self._parameters_base_check(params, "params")
        self.defaults = defaults
        self.state = defaultdict(dict)
        self.param_groups = []
        self.parameters = []
        self.map_ = C.Map()
        # group_start_id[k] is the offset of group k's first parameter inside
        # the flat `self.parameters` sequence; the last entry is the total count.
        self.group_start_id = [0]
        # A flat iterable of Parameters is treated as a single group.
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)
        self.parameters = ParameterTuple(self.parameters)
        self.hyper_map = C.HyperMap()
        self.enable_tuple_broaden = True

    def __repr__(self):
        """Return the class name followed by every group's options ('params' excluded)."""
        pieces = [self.__class__.__name__ + ' (']
        for i, group in enumerate(self.param_groups):
            pieces.append('\n')
            pieces.append('Parameter Group {0}\n'.format(i))
            for key in sorted(group.keys()):
                if key == 'params':
                    continue
                shown = group[key]
                # The learning rate is stored as a Parameter; show its value.
                if key == "lr" and isinstance(shown, Parameter):
                    shown = shown.value()
                pieces.append(' {0}: {1}\n'.format(key, shown))
        pieces.append(')')
        return ''.join(pieces)

    def add_param_group(self, param_group):
        r"""
        Add a param group to the `Optimizer.param_groups`.

        Args:
            param_group (dict): Specifies what Parameters should be optimized along with group
                specific optimization options.
        """
        # The new group is appended, so its index is the current group count.
        group_id = len(self.param_groups)
        param_group = self._preprocess_param_group(param_group)
        # `self.parameters` is a list while __init__ runs and a tuple afterwards;
        # adding a tuple works for both.
        self.parameters += tuple(param_group.get("params"))

        # Options the group does not set fall back to the optimizer defaults.
        for name, default in self.defaults.items():
            param_group.setdefault(name, default)

        lr = self._build_single_lr(param_group.get("lr"), 'learning_rate_group_' + str(group_id))
        weight_decay = self._preprocess_weight_decay(param_group.get("weight_decay", 0.0))
        param_group["lr"] = lr
        param_group["weight_decay"] = weight_decay
        self.param_groups.append(param_group)
        self.group_start_id.append(self.group_start_id[-1] + len(param_group.get("params")))

    @staticmethod
    def _parameters_base_check(parameters, param_info):
        """Check that `parameters` is a non-empty iterable and return it as a list.

        `param_info` is the argument name used in the error messages.
        """
        if parameters is None:
            raise ValueError(f"For 'Optimizer', the argument {param_info} can not be None.")
        if not isinstance(parameters, Iterable):
            raise TypeError(f"For 'Optimizer', the argument {param_info} must be Iterable type, "
                            f"but got {type(parameters)}.")
        parameters = list(parameters)

        if not parameters:
            raise ValueError(f"For 'Optimizer', the argument {param_info} must not be empty.")
        return parameters

    def _decay_weight(self, weight_decay, params, gradients):
        """Apply weight decay to `gradients`; they are returned untouched when the decay is zero."""
        if weight_decay != 0.:
            weight_decay = Tensor(weight_decay, mstype.float32)
            gradients = self.map_(F.partial(_apply_decay, weight_decay), params, gradients)
        return gradients

    def _preprocess_param_group(self, param_group):
        """Normalise `param_group['params']` to a list of Parameters and validate it.

        The passed dict is updated in place and returned.

        Raises:
            TypeError: If `param_group` is not a dict, its params are a set, or an
                element is not a Parameter.
            ValueError: If a parameter already belongs to a registered group.
        """
        if not isinstance(param_group, dict):
            raise TypeError('Param group must be a dict.')

        params = param_group['params']
        if isinstance(params, Parameter):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('Optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. '
                            'Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, Parameter):
                # Format the type into the message: concatenating str and type
                # would fail with an unrelated error text.
                raise TypeError(f"Optimizer can only optimize Parameters, but one of the params is {type(param)}")

        if len(param_group['params']) != len(set(param_group['params'])):
            logger.warning("Optimizer contains a parameter group with duplicate parameters.")

        registered = set()
        for group in self.param_groups:
            registered.update(group['params'])
        if not registered.isdisjoint(param_group['params']):
            raise ValueError("some parameters appear in more than one parameter group.")
        return param_group

    def _build_single_lr(self, learning_rate, name):
        """Validate `learning_rate` and wrap it in a float32 Parameter named `name`.

        Accepts an int/float (checked with `check_non_negative_float`) or a
        zero-dimensional Tensor.

        Raises:
            ValueError: If `learning_rate` is a Tensor with `ndim` other than 0.
            TypeError: If `learning_rate` is not an int, float or Tensor.
        """
        if isinstance(learning_rate, (float, int)):
            learning_rate = float(learning_rate)
            validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
            return Parameter(Tensor(learning_rate, mstype.float32), name)

        if isinstance(learning_rate, Tensor):
            if learning_rate.ndim == 0:
                return Parameter(learning_rate.astype(mstype.float32), name)
            raise ValueError("For 'Optimizer', if 'learning_rate' is a Tensor, "
                             "then it should be scalar Tensor")

        raise TypeError("For 'Optimizer', the argument 'learning_rate' must be int, float or Tensor, "
                        "but got {}.".format(type(learning_rate)))

    def _preprocess_weight_decay(self, weight_decay):
        """Validate `weight_decay` (int or float, checked as non-negative) and return it as a float."""
        if not isinstance(weight_decay, (float, int)):
            raise TypeError("For 'Optimizer', the argument 'weight_decay' must be int or "
                            "float, but got {}.".format(type(weight_decay)))
        weight_decay = float(weight_decay)
        validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
        return weight_decay

    def construct(self, *hyper_params):
        """Perform one update step; concrete optimizers must override this."""
        raise NotImplementedError
232
+
233
# Type-dispatched graph; its overloads are registered by the decorated
# functions that follow.
_apply_decay = C.MultitypeFuncGraph("apply_decay")

# Operator instances shared by the overloads below.
op_add = P.AddN()
op_gather = P.Gather()
op_mul = P.Mul()
238
+
239
+
240
@_apply_decay.register("Tensor", "Tensor", "RowTensor")
def _tensor_apply_decay_with_sparse(weight_decay, weight, gradient):
    """Add the weight-decay term to a RowTensor gradient.

    `weight` is gathered at `gradient.indices`, scaled by `weight_decay` cast
    to the dtype of `weight`, and summed with `gradient.values`. The result
    keeps the indices and dense shape of `gradient`.

    Relies on `RowTensorInner` being imported at module level from
    `mindspore.common.sparse_tensor`; without that import this overload
    fails with a NameError when it is first used.
    """
    indices = gradient.indices
    decay_factor = F.cast(weight_decay, F.dtype(weight))
    values = op_add((op_gather(weight, indices, 0) * decay_factor, gradient.values))
    shape = gradient.dense_shape
    return RowTensorInner(indices, values, shape)
247
+
248
+
249
@_apply_decay.register("Tensor", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, weight, gradient):
    """Return `gradient` plus `weight` scaled by `weight_decay` (cast to the dtype of `weight`)."""
    decay_factor = F.cast(weight_decay, F.dtype(weight))
    penalty = op_mul(weight, decay_factor)
    return op_add((penalty, gradient))