mindspore 2.1.0__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (539)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  172. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  173. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  174. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  175. mindspore/nn/__init__.py +0 -2
  176. mindspore/nn/cell.py +316 -74
  177. mindspore/nn/dynamic_lr.py +21 -21
  178. mindspore/nn/layer/activation.py +21 -28
  179. mindspore/nn/layer/basic.py +15 -13
  180. mindspore/nn/layer/channel_shuffle.py +1 -1
  181. mindspore/nn/layer/container.py +271 -9
  182. mindspore/nn/layer/conv.py +310 -207
  183. mindspore/nn/layer/dense.py +8 -5
  184. mindspore/nn/layer/embedding.py +33 -27
  185. mindspore/nn/layer/flash_attention.py +82 -41
  186. mindspore/nn/layer/image.py +8 -6
  187. mindspore/nn/layer/math.py +13 -18
  188. mindspore/nn/layer/normalization.py +107 -66
  189. mindspore/nn/layer/padding.py +1 -1
  190. mindspore/nn/layer/pooling.py +131 -109
  191. mindspore/nn/layer/rnn_cells.py +22 -17
  192. mindspore/nn/layer/rnns.py +13 -16
  193. mindspore/nn/layer/thor_layer.py +1 -1
  194. mindspore/nn/layer/transformer.py +221 -154
  195. mindspore/nn/learning_rate_schedule.py +9 -1
  196. mindspore/nn/loss/loss.py +235 -174
  197. mindspore/nn/optim/ada_grad.py +2 -1
  198. mindspore/nn/optim/adadelta.py +1 -0
  199. mindspore/nn/optim/adafactor.py +2 -1
  200. mindspore/nn/optim/adam.py +7 -4
  201. mindspore/nn/optim/adamax.py +3 -2
  202. mindspore/nn/optim/adasum.py +2 -2
  203. mindspore/nn/optim/asgd.py +2 -3
  204. mindspore/nn/optim/ftrl.py +6 -5
  205. mindspore/nn/optim/lamb.py +7 -4
  206. mindspore/nn/optim/lars.py +1 -1
  207. mindspore/nn/optim/lazyadam.py +5 -3
  208. mindspore/nn/optim/momentum.py +2 -1
  209. mindspore/nn/optim/optimizer.py +53 -4
  210. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  211. mindspore/nn/optim/rmsprop.py +4 -3
  212. mindspore/nn/optim/rprop.py +23 -12
  213. mindspore/nn/optim/sgd.py +26 -11
  214. mindspore/nn/optim/thor.py +9 -7
  215. mindspore/nn/probability/bijector/bijector.py +5 -5
  216. mindspore/nn/probability/bijector/power_transform.py +27 -27
  217. mindspore/nn/probability/bijector/softplus.py +3 -3
  218. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  219. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  220. mindspore/nn/probability/distribution/beta.py +3 -3
  221. mindspore/nn/probability/distribution/categorical.py +7 -7
  222. mindspore/nn/probability/distribution/cauchy.py +0 -1
  223. mindspore/nn/probability/distribution/distribution.py +3 -3
  224. mindspore/nn/probability/distribution/gamma.py +3 -3
  225. mindspore/nn/probability/distribution/geometric.py +4 -4
  226. mindspore/nn/probability/distribution/gumbel.py +4 -4
  227. mindspore/nn/probability/distribution/log_normal.py +2 -2
  228. mindspore/nn/probability/distribution/logistic.py +2 -2
  229. mindspore/nn/probability/distribution/poisson.py +4 -4
  230. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  231. mindspore/nn/probability/distribution/uniform.py +6 -6
  232. mindspore/nn/wrap/cell_wrapper.py +78 -34
  233. mindspore/nn/wrap/grad_reducer.py +8 -5
  234. mindspore/nn/wrap/loss_scale.py +105 -42
  235. mindspore/numpy/array_creations.py +1 -2
  236. mindspore/numpy/array_ops.py +3 -2
  237. mindspore/offline_debug/convert_async.py +2 -2
  238. mindspore/ops/_grad_experimental/__init__.py +0 -5
  239. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  240. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  241. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  242. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  243. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  244. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  245. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  246. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  248. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  249. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  250. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  251. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  252. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  253. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  254. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  255. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  256. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  257. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  258. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  259. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  260. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  261. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  262. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  263. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  264. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  265. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  266. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  267. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  268. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  269. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  270. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  271. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  272. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  273. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  274. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  275. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  276. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  277. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  278. mindspore/ops/_primitive_cache.py +1 -1
  279. mindspore/ops/_tracefunc.py +45 -13
  280. mindspore/ops/_utils/utils.py +4 -1
  281. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  282. mindspore/ops/_vmap/vmap_base.py +3 -3
  283. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  284. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  285. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  286. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  287. mindspore/ops/arg_dtype_cast.py +54 -0
  288. mindspore/ops/composite/base.py +37 -10
  289. mindspore/ops/composite/math_ops.py +5 -4
  290. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  291. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  292. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  293. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  294. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  295. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  297. mindspore/ops/deprecated.py +304 -0
  298. mindspore/ops/function/__init__.py +4 -1
  299. mindspore/ops/function/array_func.py +167 -189
  300. mindspore/ops/function/clip_func.py +81 -13
  301. mindspore/ops/function/debug_func.py +1 -1
  302. mindspore/ops/function/grad/grad_func.py +18 -8
  303. mindspore/ops/function/image_func.py +10 -4
  304. mindspore/ops/function/linalg_func.py +5 -5
  305. mindspore/ops/function/math_func.py +575 -386
  306. mindspore/ops/function/nn_func.py +470 -251
  307. mindspore/ops/function/random_func.py +86 -56
  308. mindspore/ops/function/sparse_func.py +1 -1
  309. mindspore/ops/function/sparse_unary_func.py +14 -12
  310. mindspore/ops/function/vmap_func.py +6 -5
  311. mindspore/ops/functional.py +15 -10
  312. mindspore/ops/op_info_register.py +235 -19
  313. mindspore/ops/operations/__init__.py +25 -17
  314. mindspore/ops/operations/_grad_ops.py +52 -7
  315. mindspore/ops/operations/_inner_ops.py +213 -12
  316. mindspore/ops/operations/_quant_ops.py +4 -8
  317. mindspore/ops/operations/_sequence_ops.py +42 -0
  318. mindspore/ops/operations/array_ops.py +64 -280
  319. mindspore/ops/operations/comm_ops.py +105 -57
  320. mindspore/ops/operations/custom_ops.py +10 -3
  321. mindspore/ops/operations/debug_ops.py +8 -4
  322. mindspore/ops/operations/image_ops.py +18 -12
  323. mindspore/ops/operations/math_ops.py +185 -138
  324. mindspore/ops/operations/nn_ops.py +716 -492
  325. mindspore/ops/operations/other_ops.py +0 -22
  326. mindspore/ops/operations/random_ops.py +53 -111
  327. mindspore/ops/operations/sparse_ops.py +3 -1
  328. mindspore/ops/primitive.py +24 -18
  329. mindspore/parallel/_auto_parallel_context.py +68 -8
  330. mindspore/parallel/_cost_model_context.py +2 -2
  331. mindspore/parallel/_offload_context.py +17 -3
  332. mindspore/parallel/_parallel_serialization.py +2 -2
  333. mindspore/parallel/_ps_context.py +12 -0
  334. mindspore/parallel/_tensor.py +14 -12
  335. mindspore/parallel/_transformer/layers.py +5 -3
  336. mindspore/parallel/_transformer/loss.py +1 -0
  337. mindspore/parallel/_transformer/moe.py +2 -2
  338. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  339. mindspore/parallel/_transformer/transformer.py +23 -3
  340. mindspore/parallel/_utils.py +11 -7
  341. mindspore/parallel/algo_parameter_config.py +85 -5
  342. mindspore/parallel/checkpoint_transform.py +6 -10
  343. mindspore/parallel/shard.py +4 -4
  344. mindspore/profiler/common/struct_type.py +3 -3
  345. mindspore/profiler/common/util.py +3 -2
  346. mindspore/profiler/envprofiling.py +1 -1
  347. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  348. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  349. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  350. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  351. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  352. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  353. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  354. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  355. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  356. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  357. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  358. mindspore/profiler/parser/flops_parser.py +15 -11
  359. mindspore/profiler/parser/framework_parser.py +37 -21
  360. mindspore/profiler/parser/hccl_parser.py +16 -12
  361. mindspore/profiler/parser/integrator.py +22 -11
  362. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  363. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  364. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  365. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  366. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  367. mindspore/profiler/parser/optime_parser.py +1 -1
  368. mindspore/profiler/parser/profiler_info.py +2 -2
  369. mindspore/profiler/parser/step_trace_parser.py +11 -14
  370. mindspore/profiler/profiling.py +139 -71
  371. mindspore/rewrite/api/node.py +102 -19
  372. mindspore/rewrite/api/node_type.py +5 -1
  373. mindspore/rewrite/api/scoped_value.py +9 -17
  374. mindspore/rewrite/api/symbol_tree.py +131 -47
  375. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  376. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  377. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  378. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  379. mindspore/rewrite/common/rewrite_elog.py +5 -1
  380. mindspore/rewrite/namer.py +33 -24
  381. mindspore/rewrite/namespace.py +14 -5
  382. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  383. mindspore/rewrite/node/call_function.py +79 -0
  384. mindspore/rewrite/node/cell_container.py +135 -0
  385. mindspore/rewrite/node/control_flow.py +88 -0
  386. mindspore/rewrite/{node.py → node/node.py} +273 -234
  387. mindspore/rewrite/node/node_manager.py +254 -0
  388. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  389. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  390. mindspore/rewrite/parsers/assign_parser.py +216 -221
  391. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  392. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  393. mindspore/rewrite/parsers/constant_parser.py +9 -6
  394. mindspore/rewrite/parsers/container_parser.py +9 -7
  395. mindspore/rewrite/parsers/for_parser.py +36 -15
  396. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  397. mindspore/rewrite/parsers/if_parser.py +28 -24
  398. mindspore/rewrite/parsers/module_parser.py +196 -25
  399. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  400. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  401. mindspore/rewrite/parsers/return_parser.py +6 -6
  402. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  403. mindspore/rewrite/sparsify/utils.py +1 -1
  404. mindspore/rewrite/symbol_tree.py +525 -577
  405. mindspore/rewrite/symbol_tree_builder.py +9 -193
  406. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  407. mindspore/run_check/_check_version.py +2 -2
  408. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  409. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  410. mindspore/scipy/linalg.py +1 -1
  411. mindspore/scipy/optimize/minimize.py +7 -3
  412. mindspore/train/_utils.py +7 -3
  413. mindspore/train/amp.py +323 -123
  414. mindspore/train/anf_ir_pb2.py +14 -2
  415. mindspore/train/callback/_backup_and_restore.py +2 -12
  416. mindspore/train/callback/_callback.py +29 -4
  417. mindspore/train/callback/_checkpoint.py +23 -8
  418. mindspore/train/callback/_early_stop.py +2 -2
  419. mindspore/train/callback/_landscape.py +4 -4
  420. mindspore/train/callback/_loss_monitor.py +2 -2
  421. mindspore/train/callback/_on_request_exit.py +2 -2
  422. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  423. mindspore/train/callback/_summary_collector.py +14 -7
  424. mindspore/train/callback/_time_monitor.py +58 -5
  425. mindspore/train/data_sink.py +5 -11
  426. mindspore/train/dataset_helper.py +83 -57
  427. mindspore/train/loss_scale_manager.py +2 -2
  428. mindspore/train/metrics/__init__.py +3 -3
  429. mindspore/train/metrics/cosine_similarity.py +1 -1
  430. mindspore/train/metrics/hausdorff_distance.py +3 -2
  431. mindspore/train/metrics/mean_surface_distance.py +3 -2
  432. mindspore/train/metrics/metric.py +39 -19
  433. mindspore/train/metrics/roc.py +2 -2
  434. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  435. mindspore/train/mind_ir_pb2.py +85 -36
  436. mindspore/train/model.py +185 -45
  437. mindspore/train/serialization.py +390 -150
  438. mindspore/train/summary/_writer_pool.py +3 -2
  439. mindspore/train/summary/summary_record.py +14 -10
  440. mindspore/train/train_thor/convert_utils.py +3 -3
  441. mindspore/train/train_thor/dataset_helper.py +1 -1
  442. mindspore/version.py +1 -1
  443. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  444. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +447 -507
  445. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  446. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  447. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  448. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  449. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  450. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  451. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  452. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  453. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  454. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  455. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  456. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  457. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  458. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  459. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  460. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  461. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  462. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  463. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  464. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  465. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  466. mindspore/_extends/graph_kernel/expander.py +0 -80
  467. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  468. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  469. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  470. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  471. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  472. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  473. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  474. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  475. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  476. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  477. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  478. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  479. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  480. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  481. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  482. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  483. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  484. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  485. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  486. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  487. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  488. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  489. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  490. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  491. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  492. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  493. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  494. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  496. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  498. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  499. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  500. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  501. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  502. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  503. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  504. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  505. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  506. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  507. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  508. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  509. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  510. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  511. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  512. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  513. mindspore/dataset/datapreprocess/__init__.py +0 -20
  514. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  515. mindspore/include/api/net.h +0 -142
  516. mindspore/nn/lr_scheduler.py +0 -262
  517. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  518. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  519. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  520. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  521. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  522. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  523. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  524. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  525. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  526. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  529. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  530. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  531. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  532. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  533. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  537. mindspore/rewrite/node_visitor.py +0 -44
  538. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  539. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -1,152 +1,147 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """sgd"""
16
- from __future__ import absolute_import
17
-
18
- from mindspore.ops import operations as P
19
- from mindspore.common.tensor import Tensor
20
- import mindspore.common.dtype as mstype
21
- from mindspore import _checkparam as Validator
22
- from mindspore.nn.optim_ex.optimizer import Optimizer
23
-
24
-
25
- class SGD(Optimizer):
26
- """
27
- Stochastic Gradient Descent optimizer.
28
-
29
- .. math::
30
- v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
31
-
32
- If nesterov is True:
33
-
34
- .. math::
35
- p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
36
-
37
- If nesterov is False:
38
-
39
- .. math::
40
- p_{t+1} = p_{t} - lr \ast v_{t+1}
41
-
42
- To be noticed, for the first step, :math:`v_{t+1} = gradient`.
43
-
44
- Here : where p, v and u denote the parameters, accum, and momentum respectively.
45
-
46
- .. warning::
47
- This is an experimental optimizer API that is subject to change.
48
- This module must be used with lr scheduler module in `LRScheduler Class
49
- <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#lrscheduler>`_ .
50
-
51
- Args:
52
- params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
53
- parameter groups.
54
- lr (Union[int, float, Tensor]): learning rate.
55
- momentum (Union[int, float], optional): momentum factor. Default: ``0``.
56
- weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
57
- dampening (Union[int, float], optional): dampening for momentum. Default: ``0``.
58
- nesterov (bool, optional): enables Nesterov momentum. Default: ``False``.
59
-
60
- Keyword Args:
61
- maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
62
- Default: ``False``.
63
-
64
- Inputs:
65
- - **gradients** (tuple[Tensor]) - The gradients of `params`.
66
-
67
- Raises:
68
- ValueError: If the learning rate is not int, float or Tensor.
69
- ValueError: If the learning rate is less than 0.
70
- ValueError: If the `momentum` or `weight_decay` value is less than 0.0.
71
- ValueError: If the `momentum`, `dampening` or `weight_decay` value is not int or float.
72
- ValueError: If the `nesterov` and `maximize` is not bool.
73
- ValueError: If the `nesterov` is true, `momentum` is not positive or `dampening` is not 0.0.
74
-
75
- Supported Platforms:
76
- ``Ascend`` ``GPU`` ``CPU``
77
-
78
- Examples:
79
- >>> import mindspore
80
- >>> from mindspore import nn
81
- >>> # Define the network structure of LeNet5. Refer to
82
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
83
- >>> net = LeNet5()
84
- >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
85
- >>> optimizer = nn.optim_ex.SGD(net.trainable_params(), lr=0.1)
86
- >>> def forward_fn(data, label):
87
- ... logits = net(data)
88
- ... loss = loss_fn(logits, label)
89
- ... return loss, logits
90
- >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
91
- >>> def train_step(data, label):
92
- ... (loss, _), grads = grad_fn(data, label)
93
- ... optimizer(grads)
94
- ... return loss
95
- """
96
- def __init__(self, params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False, *,
97
- maximize=False):
98
- Validator.check_value_type("lr", lr, [float, int, Tensor], self.cls_name)
99
- if lr < 0.0:
100
- raise ValueError("Invalid learning rate: {}".format(lr))
101
- Validator.check_value_type("momentum", momentum, [int, float], self.cls_name)
102
- if momentum < 0.0:
103
- raise ValueError("Invalid momentum value: {}".format(momentum))
104
- momentum = float(momentum)
105
- Validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
106
- Validator.check_value_type("maximize", maximize, [bool], self.cls_name)
107
-
108
- defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
109
- weight_decay=weight_decay, nesterov=nesterov,
110
- maximize=maximize, grad_centralization=False)
111
- super(SGD, self).__init__(params, defaults)
112
- for group in self.param_groups:
113
- Validator.check_value_type("dampening", group["dampening"], [int, float], self.cls_name)
114
- group["dampening"] = float(group["dampening"])
115
- if nesterov and (momentum <= 0.0 or dampening != 0.0):
116
- raise ValueError("For 'SGD', if 'nesterov' is true, 'momentum' must be > 0.0 and 'dampening' must "
117
- "equal to 0.0, but got 'momentum' {}, 'dampening' {}".format(momentum, dampening))
118
- self.accum = self.parameters.clone(prefix="accum", init='zeros')
119
- self.stat = self.parameters.clone(prefix="stat", init='ones')
120
- self.op_cast = P.Cast()
121
-
122
- def construct(self, gradients):
123
- for group_id, group in enumerate(self.param_groups):
124
- params = []
125
- grads = []
126
- accums = []
127
- stats = []
128
- params, grads, accums, stats = self._init_group(group, gradients, params, grads,
129
- accums, stats, group_id)
130
- opt = P.SGD(group["dampening"], group["weight_decay"], group["nesterov"])
131
- lr = group["lr"]
132
- if isinstance(lr, float):
133
- lr = self.op_cast(group["lr"], mstype.float32)
134
- momentum = self.op_cast(group["momentum"], mstype.float32)
135
- self.apply_sgd(opt, params, grads, lr, accums, momentum, stats, group["maximize"],
136
- group["grad_centralization"])
137
-
138
- def apply_sgd(self, opt, params, grads, lr, accums, momentum, stats, maximize, grad_centralization):
139
- grads = self._gradients_centralization(grad_centralization, grads)
140
-
141
- for i, param in enumerate(params):
142
- grad = grads[i] if not maximize else -grads[i]
143
- opt(param, grad, lr, accums[i], momentum, stats[i])
144
-
145
- def _init_group(self, group, gradients, params, accums, grads, stats, group_id):
146
- p_id = self.group_start_id[group_id]
147
- for i, param in enumerate(group["params"]):
148
- params.append(param)
149
- grads.append(gradients[p_id+i])
150
- accums.append(self.accum[p_id+i])
151
- stats.append(self.stat[p_id+i])
152
- return params, grads, accums, stats
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """sgd"""
16
+ from __future__ import absolute_import
17
+
18
+ from mindspore.ops import functional as F, composite as C, operations as P
19
+ from mindspore.common.tensor import Tensor
20
+ import mindspore.common.dtype as mstype
21
+ from mindspore import _checkparam as Validator
22
+ from mindspore.experimental.optim.optimizer import Optimizer
23
+
24
+ _sgd_opt = C.MultitypeFuncGraph("sgd_opt")
25
+
26
+
27
+ @_sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",)
28
+ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat):
29
+ """Apply sgd optimizer to the weight parameter using Tensor."""
30
+ success = True
31
+ success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
32
+ return success
33
+
34
+
35
+ class SGD(Optimizer):
36
+ r"""
37
+ Stochastic Gradient Descent optimizer.
38
+
39
+ .. math::
40
+ v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
41
+
42
+ If nesterov is True:
43
+
44
+ .. math::
45
+ p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
46
+
47
+ If nesterov is False:
48
+
49
+ .. math::
50
+ p_{t+1} = p_{t} - lr \ast v_{t+1}
51
+
52
+ To be noticed, for the first step, :math:`v_{t+1} = gradient`.
53
+
54
+ Here : where p, v and u denote the parameters, accum, and momentum respectively.
55
+
56
+ .. warning::
57
+ This is an experimental optimizer API that is subject to change.
58
+ This module must be used with lr scheduler module in `LRScheduler Class
59
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.experimental.html#lrscheduler-class>`_ .
60
+
61
+ Args:
62
+ params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
63
+ parameter groups.
64
+ lr (Union[int, float, Tensor]): learning rate.
65
+ momentum (Union[int, float], optional): momentum factor. Default: ``0``.
66
+ weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
67
+ dampening (Union[int, float], optional): dampening for momentum. Default: ``0``.
68
+ nesterov (bool, optional): enables Nesterov momentum. Default: ``False``.
69
+
70
+ Keyword Args:
71
+ maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
72
+ Default: ``False``.
73
+
74
+ Inputs:
75
+ - **gradients** (tuple[Tensor]) - The gradients of `params`.
76
+
77
+ Raises:
78
+ ValueError: If the learning rate is not int, float or Tensor.
79
+ ValueError: If the learning rate is less than 0.
80
+ ValueError: If the `momentum` or `weight_decay` value is less than 0.0.
81
+ ValueError: If the `momentum`, `dampening` or `weight_decay` value is not int or float.
82
+ ValueError: If the `nesterov` and `maximize` is not bool.
83
+ ValueError: If the `nesterov` is true, `momentum` is not positive or `dampening` is not 0.0.
84
+
85
+ Supported Platforms:
86
+ ``Ascend`` ``GPU`` ``CPU``
87
+
88
+ Examples:
89
+ >>> import mindspore
90
+ >>> from mindspore import nn
91
+ >>> from mindspore.experimental import optim
92
+ >>> # Define the network structure of LeNet5. Refer to
93
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
94
+ >>> net = LeNet5()
95
+ >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
96
+ >>> optimizer = optim.SGD(net.trainable_params(), lr=0.1)
97
+ >>> def forward_fn(data, label):
98
+ ... logits = net(data)
99
+ ... loss = loss_fn(logits, label)
100
+ ... return loss, logits
101
+ >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
102
+ >>> def train_step(data, label):
103
+ ... (loss, _), grads = grad_fn(data, label)
104
+ ... optimizer(grads)
105
+ ... return loss
106
+ """
107
+ def __init__(self, params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False, *,
108
+ maximize=False):
109
+ Validator.check_value_type("lr", lr, [float, int, Tensor], self.cls_name)
110
+ if lr < 0.0:
111
+ raise ValueError("Invalid learning rate: {}".format(lr))
112
+ Validator.check_value_type("momentum", momentum, [int, float], self.cls_name)
113
+ if momentum < 0.0:
114
+ raise ValueError("Invalid momentum value: {}".format(momentum))
115
+ momentum = float(momentum)
116
+ Validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
117
+ Validator.check_value_type("maximize", maximize, [bool], self.cls_name)
118
+
119
+ defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
120
+ weight_decay=weight_decay, nesterov=nesterov,
121
+ maximize=maximize, grad_centralization=False)
122
+ super(SGD, self).__init__(params, defaults)
123
+ for group in self.param_groups:
124
+ Validator.check_value_type("dampening", group.get("dampening"), [int, float], self.cls_name)
125
+ group["dampening"] = float(group.get("dampening"))
126
+ if nesterov and (momentum <= 0.0 or dampening != 0.0):
127
+ raise ValueError("For 'SGD', if 'nesterov' is true, 'momentum' must be > 0.0 and 'dampening' must "
128
+ "equal to 0.0, but got 'momentum' {}, 'dampening' {}".format(momentum, dampening))
129
+ self.accum = self.parameters.clone(prefix="accum", init='zeros')
130
+ self.stat = self.parameters.clone(prefix="stat", init='ones')
131
+ self.op_cast = P.Cast()
132
+
133
+ def construct(self, gradients):
134
+ for group_id, group in enumerate(self.param_groups):
135
+ opt = P.SGD(group.get("dampening"), group.get("weight_decay"), group.get("nesterov"))
136
+ lr = group.get("lr")
137
+ if isinstance(lr, float):
138
+ lr = self.op_cast(group.get("lr"), mstype.float32)
139
+ maximize = group.get("maximize")
140
+ momentum = self.op_cast(group.get("momentum"), mstype.float32)
141
+ start_id = self.group_start_id[group_id]
142
+ end_id = self.group_start_id[group_id+1]
143
+ grads = gradients[start_id: end_id] if not maximize else -gradients[start_id: end_id]
144
+ self.hyper_map(F.partial(_sgd_opt, opt, momentum, lr), grads,
145
+ self.parameters[start_id: end_id], self.accum[start_id: end_id],
146
+ self.stat[start_id: end_id])
147
+ return True
mindspore/gen_ops.py ADDED
@@ -0,0 +1,273 @@
1
+ # Copyright 2023-2025 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ Generate operator definition from ops.yaml
17
+ """
18
+ import sys
19
+ import os
20
+ import yaml
21
+
22
+
23
+ def generate_py_op_func(yaml_data, doc_data):
24
+ """
25
+ generate python operator function
26
+ """
27
+ gen_py = ''
28
+
29
+ op_desc_dict = {}
30
+ for operator_name, operator_desc in doc_data.items():
31
+ desc = operator_desc.get("description")
32
+ op_desc_dict[operator_name] = desc
33
+
34
+ for operator_name, operator_data in yaml_data.items():
35
+ description = op_desc_dict.get(operator_name)
36
+ args = operator_data.get('args')
37
+ func_name = operator_data.get('func_name')
38
+ if func_name is None:
39
+ func_name = operator_name
40
+
41
+ class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
42
+ func_args = []
43
+ primitive_init_args = []
44
+ input_args = []
45
+ for arg_name, arg_info in args.items():
46
+ dtype = arg_info.get('dtype')
47
+ init_value = arg_info.get('init')
48
+ if init_value:
49
+ if dtype == 'str':
50
+ init_value = '"' + init_value + '"'
51
+ func_args.append(f"""{arg_name}={init_value}""")
52
+ primitive_init_args.append(arg_name)
53
+ else:
54
+ func_args.append(arg_name)
55
+ input_args.append(arg_name)
56
+
57
+ function_code = f"""
58
+ def {func_name}({', '.join(arg for arg in func_args)}):
59
+ \"\"\"
60
+ {description}
61
+ \"\"\"
62
+ {operator_name}_op = _get_cache_prim(P.{class_name})({', '.join(arg_name for arg_name in primitive_init_args)})
63
+ return {operator_name}_op({', '.join(arg_name for arg_name in input_args)})
64
+ """
65
+ gen_py += function_code
66
+
67
+ return gen_py
68
+
69
+
70
+ def generate_py_primitive(yaml_data):
71
+ """
72
+ generate python primitive
73
+ """
74
+ gen_py = ''
75
+ for operator_name, operator_data in yaml_data.items():
76
+ args = operator_data.get('args')
77
+ func_name = operator_data.get('func_name')
78
+ if func_name is None:
79
+ func_name = operator_name
80
+
81
+ class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
82
+
83
+ init_args_with_default = []
84
+ init_args = []
85
+ args_assign = []
86
+ for arg_name, arg_info in args.items():
87
+ dtype = arg_info.get('dtype')
88
+ type_cast = arg_info.get('type_cast')
89
+ type_cast_set = None
90
+ if type_cast:
91
+ type_cast_set = {ct.strip() for ct in type_cast.split(",")}
92
+
93
+ init_value = arg_info.get('init')
94
+ if init_value is None:
95
+ continue
96
+
97
+ if dtype == 'str':
98
+ init_value = '"' + init_value + '"'
99
+ init_args_with_default.append(f"""{arg_name}={init_value}""")
100
+ init_args.append(arg_name)
101
+
102
+ assign_str = f""" self.{arg_name} = """
103
+
104
+ if type_cast_set:
105
+ assign_str += f'type_it({arg_name}, '
106
+ type_cast_list = []
107
+
108
+ if 'int' in type_cast_set:
109
+ type_cast_list.append('INT')
110
+ if 'tuple[int]' in type_cast_list:
111
+ type_cast_list.append('TUPLE')
112
+ #add more type cast kind here
113
+
114
+ assign_str += 'TypeCastKind.' + '_OR_'.join(ct for ct in type_cast_list)
115
+ if dtype == 'tuple[int]':
116
+ assign_str += '_TO_TUPLE)'
117
+ if dtype == 'list[int]':
118
+ assign_str += '_TO_LIST)'
119
+ else:
120
+ assign_str += arg_name
121
+ args_assign.append(assign_str)
122
+
123
+ args_assign = '\n'.join(assign for assign in args_assign)
124
+ primitive_code = f"""
125
+ class {class_name}(Primitive):
126
+ def __init__(self, {', '.join(init_args_with_default)}):
127
+ {args_assign}
128
+ def __call__(self, *args):
129
+ super.__call__(self, *args, {', '.join([f'self.{arg}' for arg in init_args])})
130
+ """
131
+
132
+ gen_py += primitive_code
133
+ return gen_py
134
+
135
+
136
+ def generate_cc_opdef(yaml_data):
137
+ """
138
+ generate OpDef
139
+ """
140
+ gen_cc = ''
141
+ opdef_map_str = f"""
142
+ std::unordered_map<std::string, OpDefPtr> gOpDefTable = {{"""
143
+
144
+ for operator_name, operator_data in yaml_data.items():
145
+ args = operator_data.get('args')
146
+ returns = operator_data.get('returns')
147
+ func_name = operator_data.get('func_name')
148
+ if func_name is None:
149
+ func_name = operator_name
150
+
151
+ class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
152
+ opdef_map_str += f"""
153
+ {{"{operator_name}", &g{class_name}}},"""
154
+
155
+ opdef_cc = f"""
156
+ OpDef g{class_name} = {{
157
+ .name_ = "{operator_name}","""
158
+ opdef_cc += f"""
159
+ .args_ = {{"""
160
+
161
+ for arg_name, arg_info in args.items():
162
+ dtype = arg_info.get('dtype')
163
+ init = arg_info.get('init')
164
+ if init is None:
165
+ init = 0
166
+ else:
167
+ init = 1
168
+ cc_dtype_str = 'DT_' + dtype.replace('[', '_').replace(']', '').replace('tuple', 'array').replace(
169
+ 'list', 'array').upper()
170
+ cc_dtype_str.replace('TUPLE', 'ARRAY').replace('LIST', 'ARRAY')
171
+ opdef_cc += f"""
172
+ {{.arg_name_ = "{arg_name}", .arg_dtype_ = {cc_dtype_str}, .as_init_arg_ = {init}}},"""
173
+ opdef_cc += f"""
174
+ }},"""
175
+
176
+ opdef_cc += f"""
177
+ .returns_ = {{"""
178
+
179
+ for return_name, return_info in returns.items():
180
+ return_dtype = return_info.get('dtype')
181
+ cc_return_type_str = 'DT_' + return_dtype.replace('[', '_').replace(']', '').replace(
182
+ 'tuple', 'array').replace('list', 'array').upper()
183
+ opdef_cc += f"""
184
+ {{.arg_name_ = "{return_name}", .arg_dtype_ = {cc_return_type_str}}},"""
185
+
186
+ opdef_cc += f"""
187
+ }},"""
188
+
189
+ opdef_cc += f"""
190
+ }};"""
191
+ gen_cc += opdef_cc
192
+
193
+ opdef_map_str += f"""
194
+ }};"""
195
+ gen_cc += opdef_map_str
196
+ return gen_cc
197
+
198
+
199
+ if __name__ == "__main__":
200
+ work_path = ''
201
+ if len(sys.argv) > 1:
202
+ work_path = sys.argv[1]
203
+
204
+ yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops.yaml')
205
+ doc_yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_doc.yaml')
206
+ op_py_path = os.path.join(work_path, 'mindspore/python/mindspore/gen_ops_def.py')
207
+ op_cc_path = os.path.join(work_path, 'mindspore/core/ops/gen_ops_def.cc')
208
+
209
+ yaml_str = None
210
+ with open(yaml_path, 'r') as yaml_file:
211
+ yaml_str = yaml.safe_load(yaml_file)
212
+
213
+ doc_str = None
214
+ with open(doc_yaml_path, 'r') as doc_file:
215
+ doc_str = yaml.safe_load(doc_file)
216
+
217
+ cc_code = generate_cc_opdef(yaml_str)
218
+ cc_code += f"""
219
+ }} // namespace mindspore::ops"""
220
+
221
+ py_licence_str = f"""# Copyright 2023 Huawei Technologies Co., Ltd
222
+ #
223
+ # Licensed under the Apache License, Version 2.0 (the "License");
224
+ # you may not use this file except in compliance with the License.
225
+ # You may obtain a copy of the License at
226
+ #
227
+ # http://www.apache.org/licenses/LICENSE-2.0
228
+ #
229
+ # Unless required by applicable law or agreed to in writing, software
230
+ # distributed under the License is distributed on an "AS IS" BASIS,
231
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
232
+ # See the License for the specific language governing permissions and
233
+ # limitations under the License.
234
+ # ============================================================================
235
+ """
236
+ pyheader = f"""
237
+ \"\"\"Operators definition generated by gen_os.py, includes functions and primitive classes.\"\"\"
238
+
239
+ from mindspore.ops.primitive import Primitive
240
+ from mindspore.ops import operations as P
241
+ from mindspore.ops import functional as F
242
+ from mindspore.ops._primitive_cache import _get_cache_prim
243
+ from mindspore.ops.arg_dtype_cast import TypeCastKind, type_it
244
+ """
245
+ cc_license_str = f"""/**
246
+ * Copyright 2023 Huawei Technologies Co., Ltd
247
+ *
248
+ * Licensed under the Apache License, Version 2.0 (the "License");
249
+ * you may not use this file except in compliance with the License.
250
+ * You may obtain a copy of the License at
251
+ *
252
+ * http://www.apache.org/licenses/LICENSE-2.0
253
+ *
254
+ * Unless required by applicable law or agreed to in writing, software
255
+ * distributed under the License is distributed on an "AS IS" BASIS,
256
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
257
+ * See the License for the specific language governing permissions and
258
+ * limitations under the License.
259
+ */"""
260
+
261
+ ccheader = f"""
262
+ #include "op_def.h"
263
+ namespace mindspore::ops {{
264
+ """
265
+ py_prim = generate_py_primitive(yaml_str)
266
+ py_func = generate_py_op_func(yaml_str, doc_str)
267
+ py_file = None
268
+ with open(op_py_path, 'w') as py_file:
269
+ py_file.write(py_licence_str + pyheader + py_prim + py_func)
270
+
271
+ cc_file = None
272
+ with open(op_cc_path, 'w') as cc_file:
273
+ cc_file.write(cc_license_str + ccheader + cc_code)
mindspore/include/OWNERS CHANGED
@@ -1,6 +1,5 @@
1
1
  approvers:
2
2
  - jpc_chenjianping #
3
- - zhoufeng54
4
3
  - zhang_xue_tong
5
4
  reviewers:
6
5
  - lx0095
@@ -38,7 +38,8 @@ enum class DataType : int {
38
38
  kNumberTypeFloat16 = 42,
39
39
  kNumberTypeFloat32 = 43,
40
40
  kNumberTypeFloat64 = 44,
41
- kNumberTypeEnd = 46,
41
+ kNumberTypeBFloat16 = 46,
42
+ kNumberTypeEnd = 53,
42
43
  // add new enum here
43
44
  kInvalidType = INT32_MAX,
44
45
  };
@@ -24,38 +24,23 @@
24
24
  #include "include/api/types.h"
25
25
 
26
26
  namespace mindspore {
27
- class NetData;
28
- class Net;
29
-
30
27
  class MS_API Graph {
31
28
  public:
32
29
  class GraphData;
33
- enum Type : uint32_t {
34
- kExpressionGraph = 0, ///< graph as expression - can auto grad
35
- kExecutableGraph = 1, ///< graph is loaded as is
36
- kUnknownTypeGraph = 0xffffffff
37
- };
38
30
  Graph();
39
31
  explicit Graph(const std::shared_ptr<GraphData> &graph_data);
40
32
  explicit Graph(std::shared_ptr<GraphData> &&graph_data);
41
33
  explicit Graph(std::nullptr_t);
42
34
  ~Graph();
43
- explicit Graph(Type executable);
44
- explicit Graph(Net *net);
45
35
 
46
36
  enum ModelType ModelType() const;
47
37
  bool operator==(std::nullptr_t) const;
48
38
  bool operator!=(std::nullptr_t) const;
49
- bool IsExecutable() { return graph_type_ == kExecutableGraph; }
50
39
 
51
40
  private:
52
41
  friend class GraphCell;
53
42
  friend class ModelImpl;
54
- friend class NetImpl;
55
- friend class Model;
56
43
  std::shared_ptr<GraphData> graph_data_;
57
- std::shared_ptr<NetData> net_data_;
58
- Type graph_type_ = kExecutableGraph;
59
44
  };
60
45
  } // namespace mindspore
61
46
  #endif // MINDSPORE_INCLUDE_API_GRAPH_H
@@ -35,6 +35,8 @@ class MS_API Kernel : public IKernel<schema::Primitive> {
35
35
  Initialize();
36
36
  }
37
37
  virtual ~Kernel() = default;
38
+
39
+ int InferShape() override;
38
40
  /// \brief obtain kernel's type.
39
41
  ///
40
42
  /// \return kernel's type.