mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (550):
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  172. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  173. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  174. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  175. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  176. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  177. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  180. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  181. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  182. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  183. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  184. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  185. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  186. mindspore/nn/__init__.py +0 -2
  187. mindspore/nn/cell.py +316 -74
  188. mindspore/nn/dynamic_lr.py +21 -21
  189. mindspore/nn/layer/activation.py +21 -28
  190. mindspore/nn/layer/basic.py +15 -13
  191. mindspore/nn/layer/channel_shuffle.py +1 -1
  192. mindspore/nn/layer/container.py +271 -9
  193. mindspore/nn/layer/conv.py +310 -207
  194. mindspore/nn/layer/dense.py +8 -5
  195. mindspore/nn/layer/embedding.py +33 -27
  196. mindspore/nn/layer/flash_attention.py +82 -41
  197. mindspore/nn/layer/image.py +8 -6
  198. mindspore/nn/layer/math.py +13 -18
  199. mindspore/nn/layer/normalization.py +107 -66
  200. mindspore/nn/layer/padding.py +1 -1
  201. mindspore/nn/layer/pooling.py +131 -109
  202. mindspore/nn/layer/rnn_cells.py +22 -17
  203. mindspore/nn/layer/rnns.py +13 -16
  204. mindspore/nn/layer/thor_layer.py +1 -1
  205. mindspore/nn/layer/transformer.py +221 -154
  206. mindspore/nn/learning_rate_schedule.py +9 -1
  207. mindspore/nn/loss/loss.py +235 -174
  208. mindspore/nn/optim/ada_grad.py +2 -1
  209. mindspore/nn/optim/adadelta.py +1 -0
  210. mindspore/nn/optim/adafactor.py +2 -1
  211. mindspore/nn/optim/adam.py +7 -4
  212. mindspore/nn/optim/adamax.py +3 -2
  213. mindspore/nn/optim/adasum.py +2 -2
  214. mindspore/nn/optim/asgd.py +2 -3
  215. mindspore/nn/optim/ftrl.py +6 -5
  216. mindspore/nn/optim/lamb.py +7 -4
  217. mindspore/nn/optim/lars.py +1 -1
  218. mindspore/nn/optim/lazyadam.py +5 -3
  219. mindspore/nn/optim/momentum.py +2 -1
  220. mindspore/nn/optim/optimizer.py +53 -4
  221. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  222. mindspore/nn/optim/rmsprop.py +4 -3
  223. mindspore/nn/optim/rprop.py +23 -12
  224. mindspore/nn/optim/sgd.py +26 -11
  225. mindspore/nn/optim/thor.py +9 -7
  226. mindspore/nn/probability/bijector/bijector.py +5 -5
  227. mindspore/nn/probability/bijector/power_transform.py +27 -27
  228. mindspore/nn/probability/bijector/softplus.py +3 -3
  229. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  230. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  231. mindspore/nn/probability/distribution/beta.py +3 -3
  232. mindspore/nn/probability/distribution/categorical.py +7 -7
  233. mindspore/nn/probability/distribution/cauchy.py +0 -1
  234. mindspore/nn/probability/distribution/distribution.py +3 -3
  235. mindspore/nn/probability/distribution/gamma.py +3 -3
  236. mindspore/nn/probability/distribution/geometric.py +4 -4
  237. mindspore/nn/probability/distribution/gumbel.py +4 -4
  238. mindspore/nn/probability/distribution/log_normal.py +2 -2
  239. mindspore/nn/probability/distribution/logistic.py +2 -2
  240. mindspore/nn/probability/distribution/poisson.py +4 -4
  241. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  242. mindspore/nn/probability/distribution/uniform.py +6 -6
  243. mindspore/nn/wrap/cell_wrapper.py +78 -34
  244. mindspore/nn/wrap/grad_reducer.py +8 -5
  245. mindspore/nn/wrap/loss_scale.py +105 -42
  246. mindspore/numpy/array_creations.py +1 -2
  247. mindspore/numpy/array_ops.py +3 -2
  248. mindspore/offline_debug/convert_async.py +2 -2
  249. mindspore/ops/_grad_experimental/__init__.py +0 -5
  250. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  251. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  252. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  253. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  256. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  257. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  259. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  260. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  261. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  262. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  263. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  264. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  265. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  266. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  267. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  268. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  269. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  270. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  271. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  272. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  273. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  274. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  275. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  276. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  277. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  278. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  279. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  280. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  281. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  282. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  283. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  284. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  285. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  286. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  287. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  288. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  289. mindspore/ops/_primitive_cache.py +1 -1
  290. mindspore/ops/_tracefunc.py +45 -13
  291. mindspore/ops/_utils/utils.py +4 -1
  292. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  293. mindspore/ops/_vmap/vmap_base.py +3 -3
  294. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  295. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  296. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  297. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  298. mindspore/ops/arg_dtype_cast.py +54 -0
  299. mindspore/ops/composite/base.py +37 -10
  300. mindspore/ops/composite/math_ops.py +5 -4
  301. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  302. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  303. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  304. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  305. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  308. mindspore/ops/deprecated.py +304 -0
  309. mindspore/ops/function/__init__.py +4 -1
  310. mindspore/ops/function/array_func.py +167 -189
  311. mindspore/ops/function/clip_func.py +81 -13
  312. mindspore/ops/function/debug_func.py +1 -1
  313. mindspore/ops/function/grad/grad_func.py +18 -8
  314. mindspore/ops/function/image_func.py +10 -4
  315. mindspore/ops/function/linalg_func.py +5 -5
  316. mindspore/ops/function/math_func.py +575 -386
  317. mindspore/ops/function/nn_func.py +470 -251
  318. mindspore/ops/function/random_func.py +86 -56
  319. mindspore/ops/function/sparse_func.py +1 -1
  320. mindspore/ops/function/sparse_unary_func.py +14 -12
  321. mindspore/ops/function/vmap_func.py +6 -5
  322. mindspore/ops/functional.py +15 -10
  323. mindspore/ops/op_info_register.py +235 -19
  324. mindspore/ops/operations/__init__.py +25 -17
  325. mindspore/ops/operations/_grad_ops.py +52 -7
  326. mindspore/ops/operations/_inner_ops.py +213 -12
  327. mindspore/ops/operations/_quant_ops.py +4 -8
  328. mindspore/ops/operations/_sequence_ops.py +42 -0
  329. mindspore/ops/operations/array_ops.py +64 -280
  330. mindspore/ops/operations/comm_ops.py +105 -57
  331. mindspore/ops/operations/custom_ops.py +10 -3
  332. mindspore/ops/operations/debug_ops.py +8 -4
  333. mindspore/ops/operations/image_ops.py +18 -12
  334. mindspore/ops/operations/math_ops.py +185 -138
  335. mindspore/ops/operations/nn_ops.py +716 -492
  336. mindspore/ops/operations/other_ops.py +0 -22
  337. mindspore/ops/operations/random_ops.py +53 -111
  338. mindspore/ops/operations/sparse_ops.py +3 -1
  339. mindspore/ops/primitive.py +24 -18
  340. mindspore/parallel/_auto_parallel_context.py +68 -8
  341. mindspore/parallel/_cost_model_context.py +2 -2
  342. mindspore/parallel/_offload_context.py +17 -3
  343. mindspore/parallel/_parallel_serialization.py +2 -2
  344. mindspore/parallel/_ps_context.py +12 -0
  345. mindspore/parallel/_tensor.py +14 -12
  346. mindspore/parallel/_transformer/layers.py +5 -3
  347. mindspore/parallel/_transformer/loss.py +1 -0
  348. mindspore/parallel/_transformer/moe.py +2 -2
  349. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  350. mindspore/parallel/_transformer/transformer.py +23 -3
  351. mindspore/parallel/_utils.py +11 -7
  352. mindspore/parallel/algo_parameter_config.py +85 -5
  353. mindspore/parallel/checkpoint_transform.py +6 -10
  354. mindspore/parallel/shard.py +4 -4
  355. mindspore/profiler/common/struct_type.py +3 -3
  356. mindspore/profiler/common/util.py +3 -2
  357. mindspore/profiler/envprofiling.py +1 -1
  358. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  359. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  360. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  361. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  362. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  363. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  364. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  365. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  366. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  367. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  368. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  369. mindspore/profiler/parser/flops_parser.py +15 -11
  370. mindspore/profiler/parser/framework_parser.py +37 -21
  371. mindspore/profiler/parser/hccl_parser.py +16 -12
  372. mindspore/profiler/parser/integrator.py +22 -11
  373. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  374. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  375. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  376. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  377. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  378. mindspore/profiler/parser/optime_parser.py +1 -1
  379. mindspore/profiler/parser/profiler_info.py +2 -2
  380. mindspore/profiler/parser/step_trace_parser.py +11 -14
  381. mindspore/profiler/profiling.py +139 -71
  382. mindspore/rewrite/api/node.py +102 -19
  383. mindspore/rewrite/api/node_type.py +5 -1
  384. mindspore/rewrite/api/scoped_value.py +9 -17
  385. mindspore/rewrite/api/symbol_tree.py +131 -47
  386. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  387. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  388. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  389. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  390. mindspore/rewrite/common/rewrite_elog.py +5 -1
  391. mindspore/rewrite/namer.py +33 -24
  392. mindspore/rewrite/namespace.py +14 -5
  393. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  394. mindspore/rewrite/node/call_function.py +79 -0
  395. mindspore/rewrite/node/cell_container.py +135 -0
  396. mindspore/rewrite/node/control_flow.py +88 -0
  397. mindspore/rewrite/{node.py → node/node.py} +273 -234
  398. mindspore/rewrite/node/node_manager.py +254 -0
  399. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  400. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  401. mindspore/rewrite/parsers/assign_parser.py +216 -221
  402. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  403. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  404. mindspore/rewrite/parsers/constant_parser.py +9 -6
  405. mindspore/rewrite/parsers/container_parser.py +9 -7
  406. mindspore/rewrite/parsers/for_parser.py +36 -15
  407. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  408. mindspore/rewrite/parsers/if_parser.py +28 -24
  409. mindspore/rewrite/parsers/module_parser.py +196 -25
  410. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  411. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  412. mindspore/rewrite/parsers/return_parser.py +6 -6
  413. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  414. mindspore/rewrite/sparsify/utils.py +1 -1
  415. mindspore/rewrite/symbol_tree.py +525 -577
  416. mindspore/rewrite/symbol_tree_builder.py +9 -193
  417. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  418. mindspore/run_check/_check_version.py +2 -2
  419. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  420. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  421. mindspore/scipy/linalg.py +1 -1
  422. mindspore/scipy/optimize/minimize.py +7 -3
  423. mindspore/train/_utils.py +7 -3
  424. mindspore/train/amp.py +323 -123
  425. mindspore/train/anf_ir_pb2.py +14 -2
  426. mindspore/train/callback/_backup_and_restore.py +2 -12
  427. mindspore/train/callback/_callback.py +29 -4
  428. mindspore/train/callback/_checkpoint.py +23 -8
  429. mindspore/train/callback/_early_stop.py +2 -2
  430. mindspore/train/callback/_landscape.py +4 -4
  431. mindspore/train/callback/_loss_monitor.py +2 -2
  432. mindspore/train/callback/_on_request_exit.py +2 -2
  433. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  434. mindspore/train/callback/_summary_collector.py +14 -7
  435. mindspore/train/callback/_time_monitor.py +58 -5
  436. mindspore/train/data_sink.py +5 -11
  437. mindspore/train/dataset_helper.py +83 -57
  438. mindspore/train/loss_scale_manager.py +2 -2
  439. mindspore/train/metrics/__init__.py +3 -3
  440. mindspore/train/metrics/cosine_similarity.py +1 -1
  441. mindspore/train/metrics/hausdorff_distance.py +3 -2
  442. mindspore/train/metrics/mean_surface_distance.py +3 -2
  443. mindspore/train/metrics/metric.py +39 -19
  444. mindspore/train/metrics/roc.py +2 -2
  445. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  446. mindspore/train/mind_ir_pb2.py +85 -36
  447. mindspore/train/model.py +185 -45
  448. mindspore/train/serialization.py +390 -150
  449. mindspore/train/summary/_writer_pool.py +3 -2
  450. mindspore/train/summary/summary_record.py +14 -10
  451. mindspore/train/train_thor/convert_utils.py +3 -3
  452. mindspore/train/train_thor/dataset_helper.py +1 -1
  453. mindspore/version.py +1 -1
  454. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  455. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
  456. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  457. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  458. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  459. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  460. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  461. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  462. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  463. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  464. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  465. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  466. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  467. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  468. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  469. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  470. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  471. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  472. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  473. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  474. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  475. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  476. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  477. mindspore/_extends/graph_kernel/expander.py +0 -80
  478. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  479. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  480. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  481. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  482. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  483. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  484. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  485. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  486. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  487. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  488. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  489. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  490. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  491. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  492. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  493. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  494. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  495. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  496. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  497. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  498. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  499. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  500. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  501. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  502. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  503. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  504. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  505. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  506. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  507. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  509. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  510. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  511. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  512. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  514. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  515. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  516. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  517. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  518. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  519. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  520. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  521. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  522. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  523. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  524. mindspore/dataset/datapreprocess/__init__.py +0 -20
  525. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  526. mindspore/include/api/net.h +0 -142
  527. mindspore/nn/lr_scheduler.py +0 -262
  528. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  529. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  530. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  531. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  532. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  533. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  537. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  541. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  542. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  543. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  546. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  547. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  548. mindspore/rewrite/node_visitor.py +0 -44
  549. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  550. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -1,152 +1,147 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """sgd"""
16
- from __future__ import absolute_import
17
-
18
- from mindspore.ops import operations as P
19
- from mindspore.common.tensor import Tensor
20
- import mindspore.common.dtype as mstype
21
- from mindspore import _checkparam as Validator
22
- from mindspore.nn.optim_ex.optimizer import Optimizer
23
-
24
-
25
- class SGD(Optimizer):
26
- """
27
- Stochastic Gradient Descent optimizer.
28
-
29
- .. math::
30
- v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
31
-
32
- If nesterov is True:
33
-
34
- .. math::
35
- p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
36
-
37
- If nesterov is False:
38
-
39
- .. math::
40
- p_{t+1} = p_{t} - lr \ast v_{t+1}
41
-
42
- To be noticed, for the first step, :math:`v_{t+1} = gradient`.
43
-
44
- Here : where p, v and u denote the parameters, accum, and momentum respectively.
45
-
46
- .. warning::
47
- This is an experimental optimizer API that is subject to change.
48
- This module must be used with lr scheduler module in `LRScheduler Class
49
- <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#lrscheduler>`_ .
50
-
51
- Args:
52
- params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
53
- parameter groups.
54
- lr (Union[int, float, Tensor]): learning rate.
55
- momentum (Union[int, float], optional): momentum factor. Default: ``0``.
56
- weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
57
- dampening (Union[int, float], optional): dampening for momentum. Default: ``0``.
58
- nesterov (bool, optional): enables Nesterov momentum. Default: ``False``.
59
-
60
- Keyword Args:
61
- maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
62
- Default: ``False``.
63
-
64
- Inputs:
65
- - **gradients** (tuple[Tensor]) - The gradients of `params`.
66
-
67
- Raises:
68
- ValueError: If the learning rate is not int, float or Tensor.
69
- ValueError: If the learning rate is less than 0.
70
- ValueError: If the `momentum` or `weight_decay` value is less than 0.0.
71
- ValueError: If the `momentum`, `dampening` or `weight_decay` value is not int or float.
72
- ValueError: If the `nesterov` and `maximize` is not bool.
73
- ValueError: If the `nesterov` is true, `momentum` is not positive or `dampening` is not 0.0.
74
-
75
- Supported Platforms:
76
- ``Ascend`` ``GPU`` ``CPU``
77
-
78
- Examples:
79
- >>> import mindspore
80
- >>> from mindspore import nn
81
- >>> # Define the network structure of LeNet5. Refer to
82
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
83
- >>> net = LeNet5()
84
- >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
85
- >>> optimizer = nn.optim_ex.SGD(net.trainable_params(), lr=0.1)
86
- >>> def forward_fn(data, label):
87
- ... logits = net(data)
88
- ... loss = loss_fn(logits, label)
89
- ... return loss, logits
90
- >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
91
- >>> def train_step(data, label):
92
- ... (loss, _), grads = grad_fn(data, label)
93
- ... optimizer(grads)
94
- ... return loss
95
- """
96
- def __init__(self, params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False, *,
97
- maximize=False):
98
- Validator.check_value_type("lr", lr, [float, int, Tensor], self.cls_name)
99
- if lr < 0.0:
100
- raise ValueError("Invalid learning rate: {}".format(lr))
101
- Validator.check_value_type("momentum", momentum, [int, float], self.cls_name)
102
- if momentum < 0.0:
103
- raise ValueError("Invalid momentum value: {}".format(momentum))
104
- momentum = float(momentum)
105
- Validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
106
- Validator.check_value_type("maximize", maximize, [bool], self.cls_name)
107
-
108
- defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
109
- weight_decay=weight_decay, nesterov=nesterov,
110
- maximize=maximize, grad_centralization=False)
111
- super(SGD, self).__init__(params, defaults)
112
- for group in self.param_groups:
113
- Validator.check_value_type("dampening", group["dampening"], [int, float], self.cls_name)
114
- group["dampening"] = float(group["dampening"])
115
- if nesterov and (momentum <= 0.0 or dampening != 0.0):
116
- raise ValueError("For 'SGD', if 'nesterov' is true, 'momentum' must be > 0.0 and 'dampening' must "
117
- "equal to 0.0, but got 'momentum' {}, 'dampening' {}".format(momentum, dampening))
118
- self.accum = self.parameters.clone(prefix="accum", init='zeros')
119
- self.stat = self.parameters.clone(prefix="stat", init='ones')
120
- self.op_cast = P.Cast()
121
-
122
- def construct(self, gradients):
123
- for group_id, group in enumerate(self.param_groups):
124
- params = []
125
- grads = []
126
- accums = []
127
- stats = []
128
- params, grads, accums, stats = self._init_group(group, gradients, params, grads,
129
- accums, stats, group_id)
130
- opt = P.SGD(group["dampening"], group["weight_decay"], group["nesterov"])
131
- lr = group["lr"]
132
- if isinstance(lr, float):
133
- lr = self.op_cast(group["lr"], mstype.float32)
134
- momentum = self.op_cast(group["momentum"], mstype.float32)
135
- self.apply_sgd(opt, params, grads, lr, accums, momentum, stats, group["maximize"],
136
- group["grad_centralization"])
137
-
138
- def apply_sgd(self, opt, params, grads, lr, accums, momentum, stats, maximize, grad_centralization):
139
- grads = self._gradients_centralization(grad_centralization, grads)
140
-
141
- for i, param in enumerate(params):
142
- grad = grads[i] if not maximize else -grads[i]
143
- opt(param, grad, lr, accums[i], momentum, stats[i])
144
-
145
- def _init_group(self, group, gradients, params, accums, grads, stats, group_id):
146
- p_id = self.group_start_id[group_id]
147
- for i, param in enumerate(group["params"]):
148
- params.append(param)
149
- grads.append(gradients[p_id+i])
150
- accums.append(self.accum[p_id+i])
151
- stats.append(self.stat[p_id+i])
152
- return params, grads, accums, stats
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """sgd"""
16
+ from __future__ import absolute_import
17
+
18
+ from mindspore.ops import functional as F, composite as C, operations as P
19
+ from mindspore.common.tensor import Tensor
20
+ import mindspore.common.dtype as mstype
21
+ from mindspore import _checkparam as Validator
22
+ from mindspore.experimental.optim.optimizer import Optimizer
23
+
24
+ _sgd_opt = C.MultitypeFuncGraph("sgd_opt")
25
+
26
+
27
+ @_sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",)
28
+ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat):
29
+ """Apply sgd optimizer to the weight parameter using Tensor."""
30
+ success = True
31
+ success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
32
+ return success
33
+
34
+
35
+ class SGD(Optimizer):
36
+ r"""
37
+ Stochastic Gradient Descent optimizer.
38
+
39
+ .. math::
40
+ v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
41
+
42
+ If nesterov is True:
43
+
44
+ .. math::
45
+ p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
46
+
47
+ If nesterov is False:
48
+
49
+ .. math::
50
+ p_{t+1} = p_{t} - lr \ast v_{t+1}
51
+
52
+ To be noticed, for the first step, :math:`v_{t+1} = gradient`.
53
+
54
+ Here : where p, v and u denote the parameters, accum, and momentum respectively.
55
+
56
+ .. warning::
57
+ This is an experimental optimizer API that is subject to change.
58
+ This module must be used with lr scheduler module in `LRScheduler Class
59
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.experimental.html#lrscheduler-class>`_ .
60
+
61
+ Args:
62
+ params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
63
+ parameter groups.
64
+ lr (Union[int, float, Tensor]): learning rate.
65
+ momentum (Union[int, float], optional): momentum factor. Default: ``0``.
66
+ weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
67
+ dampening (Union[int, float], optional): dampening for momentum. Default: ``0``.
68
+ nesterov (bool, optional): enables Nesterov momentum. Default: ``False``.
69
+
70
+ Keyword Args:
71
+ maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
72
+ Default: ``False``.
73
+
74
+ Inputs:
75
+ - **gradients** (tuple[Tensor]) - The gradients of `params`.
76
+
77
+ Raises:
78
+ ValueError: If the learning rate is not int, float or Tensor.
79
+ ValueError: If the learning rate is less than 0.
80
+ ValueError: If the `momentum` or `weight_decay` value is less than 0.0.
81
+ ValueError: If the `momentum`, `dampening` or `weight_decay` value is not int or float.
82
+ ValueError: If the `nesterov` and `maximize` is not bool.
83
+ ValueError: If the `nesterov` is true, `momentum` is not positive or `dampening` is not 0.0.
84
+
85
+ Supported Platforms:
86
+ ``Ascend`` ``GPU`` ``CPU``
87
+
88
+ Examples:
89
+ >>> import mindspore
90
+ >>> from mindspore import nn
91
+ >>> from mindspore.experimental import optim
92
+ >>> # Define the network structure of LeNet5. Refer to
93
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
94
+ >>> net = LeNet5()
95
+ >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
96
+ >>> optimizer = optim.SGD(net.trainable_params(), lr=0.1)
97
+ >>> def forward_fn(data, label):
98
+ ... logits = net(data)
99
+ ... loss = loss_fn(logits, label)
100
+ ... return loss, logits
101
+ >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
102
+ >>> def train_step(data, label):
103
+ ... (loss, _), grads = grad_fn(data, label)
104
+ ... optimizer(grads)
105
+ ... return loss
106
+ """
107
+ def __init__(self, params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False, *,
108
+ maximize=False):
109
+ Validator.check_value_type("lr", lr, [float, int, Tensor], self.cls_name)
110
+ if lr < 0.0:
111
+ raise ValueError("Invalid learning rate: {}".format(lr))
112
+ Validator.check_value_type("momentum", momentum, [int, float], self.cls_name)
113
+ if momentum < 0.0:
114
+ raise ValueError("Invalid momentum value: {}".format(momentum))
115
+ momentum = float(momentum)
116
+ Validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
117
+ Validator.check_value_type("maximize", maximize, [bool], self.cls_name)
118
+
119
+ defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
120
+ weight_decay=weight_decay, nesterov=nesterov,
121
+ maximize=maximize, grad_centralization=False)
122
+ super(SGD, self).__init__(params, defaults)
123
+ for group in self.param_groups:
124
+ Validator.check_value_type("dampening", group.get("dampening"), [int, float], self.cls_name)
125
+ group["dampening"] = float(group.get("dampening"))
126
+ if nesterov and (momentum <= 0.0 or dampening != 0.0):
127
+ raise ValueError("For 'SGD', if 'nesterov' is true, 'momentum' must be > 0.0 and 'dampening' must "
128
+ "equal to 0.0, but got 'momentum' {}, 'dampening' {}".format(momentum, dampening))
129
+ self.accum = self.parameters.clone(prefix="accum", init='zeros')
130
+ self.stat = self.parameters.clone(prefix="stat", init='ones')
131
+ self.op_cast = P.Cast()
132
+
133
+ def construct(self, gradients):
134
+ for group_id, group in enumerate(self.param_groups):
135
+ opt = P.SGD(group.get("dampening"), group.get("weight_decay"), group.get("nesterov"))
136
+ lr = group.get("lr")
137
+ if isinstance(lr, float):
138
+ lr = self.op_cast(group.get("lr"), mstype.float32)
139
+ maximize = group.get("maximize")
140
+ momentum = self.op_cast(group.get("momentum"), mstype.float32)
141
+ start_id = self.group_start_id[group_id]
142
+ end_id = self.group_start_id[group_id+1]
143
+ grads = gradients[start_id: end_id] if not maximize else -gradients[start_id: end_id]
144
+ self.hyper_map(F.partial(_sgd_opt, opt, momentum, lr), grads,
145
+ self.parameters[start_id: end_id], self.accum[start_id: end_id],
146
+ self.stat[start_id: end_id])
147
+ return True
mindspore/gen_ops.py ADDED
@@ -0,0 +1,273 @@
1
+ # Copyright 2023-2025 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ Generate operator definition from ops.yaml
17
+ """
18
+ import sys
19
+ import os
20
+ import yaml
21
+
22
+
23
+ def generate_py_op_func(yaml_data, doc_data):
24
+ """
25
+ generate python operator function
26
+ """
27
+ gen_py = ''
28
+
29
+ op_desc_dict = {}
30
+ for operator_name, operator_desc in doc_data.items():
31
+ desc = operator_desc.get("description")
32
+ op_desc_dict[operator_name] = desc
33
+
34
+ for operator_name, operator_data in yaml_data.items():
35
+ description = op_desc_dict.get(operator_name)
36
+ args = operator_data.get('args')
37
+ func_name = operator_data.get('func_name')
38
+ if func_name is None:
39
+ func_name = operator_name
40
+
41
+ class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
42
+ func_args = []
43
+ primitive_init_args = []
44
+ input_args = []
45
+ for arg_name, arg_info in args.items():
46
+ dtype = arg_info.get('dtype')
47
+ init_value = arg_info.get('init')
48
+ if init_value:
49
+ if dtype == 'str':
50
+ init_value = '"' + init_value + '"'
51
+ func_args.append(f"""{arg_name}={init_value}""")
52
+ primitive_init_args.append(arg_name)
53
+ else:
54
+ func_args.append(arg_name)
55
+ input_args.append(arg_name)
56
+
57
+ function_code = f"""
58
+ def {func_name}({', '.join(arg for arg in func_args)}):
59
+ \"\"\"
60
+ {description}
61
+ \"\"\"
62
+ {operator_name}_op = _get_cache_prim(P.{class_name})({', '.join(arg_name for arg_name in primitive_init_args)})
63
+ return {operator_name}_op({', '.join(arg_name for arg_name in input_args)})
64
+ """
65
+ gen_py += function_code
66
+
67
+ return gen_py
68
+
69
+
70
+ def generate_py_primitive(yaml_data):
71
+ """
72
+ generate python primitive
73
+ """
74
+ gen_py = ''
75
+ for operator_name, operator_data in yaml_data.items():
76
+ args = operator_data.get('args')
77
+ func_name = operator_data.get('func_name')
78
+ if func_name is None:
79
+ func_name = operator_name
80
+
81
+ class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
82
+
83
+ init_args_with_default = []
84
+ init_args = []
85
+ args_assign = []
86
+ for arg_name, arg_info in args.items():
87
+ dtype = arg_info.get('dtype')
88
+ type_cast = arg_info.get('type_cast')
89
+ type_cast_set = None
90
+ if type_cast:
91
+ type_cast_set = {ct.strip() for ct in type_cast.split(",")}
92
+
93
+ init_value = arg_info.get('init')
94
+ if init_value is None:
95
+ continue
96
+
97
+ if dtype == 'str':
98
+ init_value = '"' + init_value + '"'
99
+ init_args_with_default.append(f"""{arg_name}={init_value}""")
100
+ init_args.append(arg_name)
101
+
102
+ assign_str = f""" self.{arg_name} = """
103
+
104
+ if type_cast_set:
105
+ assign_str += f'type_it({arg_name}, '
106
+ type_cast_list = []
107
+
108
+ if 'int' in type_cast_set:
109
+ type_cast_list.append('INT')
110
+ if 'tuple[int]' in type_cast_list:
111
+ type_cast_list.append('TUPLE')
112
+ #add more type cast kind here
113
+
114
+ assign_str += 'TypeCastKind.' + '_OR_'.join(ct for ct in type_cast_list)
115
+ if dtype == 'tuple[int]':
116
+ assign_str += '_TO_TUPLE)'
117
+ if dtype == 'list[int]':
118
+ assign_str += '_TO_LIST)'
119
+ else:
120
+ assign_str += arg_name
121
+ args_assign.append(assign_str)
122
+
123
+ args_assign = '\n'.join(assign for assign in args_assign)
124
+ primitive_code = f"""
125
+ class {class_name}(Primitive):
126
+ def __init__(self, {', '.join(init_args_with_default)}):
127
+ {args_assign}
128
+ def __call__(self, *args):
129
+ super.__call__(self, *args, {', '.join([f'self.{arg}' for arg in init_args])})
130
+ """
131
+
132
+ gen_py += primitive_code
133
+ return gen_py
134
+
135
+
136
+ def generate_cc_opdef(yaml_data):
137
+ """
138
+ generate OpDef
139
+ """
140
+ gen_cc = ''
141
+ opdef_map_str = f"""
142
+ std::unordered_map<std::string, OpDefPtr> gOpDefTable = {{"""
143
+
144
+ for operator_name, operator_data in yaml_data.items():
145
+ args = operator_data.get('args')
146
+ returns = operator_data.get('returns')
147
+ func_name = operator_data.get('func_name')
148
+ if func_name is None:
149
+ func_name = operator_name
150
+
151
+ class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
152
+ opdef_map_str += f"""
153
+ {{"{operator_name}", &g{class_name}}},"""
154
+
155
+ opdef_cc = f"""
156
+ OpDef g{class_name} = {{
157
+ .name_ = "{operator_name}","""
158
+ opdef_cc += f"""
159
+ .args_ = {{"""
160
+
161
+ for arg_name, arg_info in args.items():
162
+ dtype = arg_info.get('dtype')
163
+ init = arg_info.get('init')
164
+ if init is None:
165
+ init = 0
166
+ else:
167
+ init = 1
168
+ cc_dtype_str = 'DT_' + dtype.replace('[', '_').replace(']', '').replace('tuple', 'array').replace(
169
+ 'list', 'array').upper()
170
+ cc_dtype_str.replace('TUPLE', 'ARRAY').replace('LIST', 'ARRAY')
171
+ opdef_cc += f"""
172
+ {{.arg_name_ = "{arg_name}", .arg_dtype_ = {cc_dtype_str}, .as_init_arg_ = {init}}},"""
173
+ opdef_cc += f"""
174
+ }},"""
175
+
176
+ opdef_cc += f"""
177
+ .returns_ = {{"""
178
+
179
+ for return_name, return_info in returns.items():
180
+ return_dtype = return_info.get('dtype')
181
+ cc_return_type_str = 'DT_' + return_dtype.replace('[', '_').replace(']', '').replace(
182
+ 'tuple', 'array').replace('list', 'array').upper()
183
+ opdef_cc += f"""
184
+ {{.arg_name_ = "{return_name}", .arg_dtype_ = {cc_return_type_str}}},"""
185
+
186
+ opdef_cc += f"""
187
+ }},"""
188
+
189
+ opdef_cc += f"""
190
+ }};"""
191
+ gen_cc += opdef_cc
192
+
193
+ opdef_map_str += f"""
194
+ }};"""
195
+ gen_cc += opdef_map_str
196
+ return gen_cc
197
+
198
+
199
+ if __name__ == "__main__":
200
+ work_path = ''
201
+ if len(sys.argv) > 1:
202
+ work_path = sys.argv[1]
203
+
204
+ yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops.yaml')
205
+ doc_yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_doc.yaml')
206
+ op_py_path = os.path.join(work_path, 'mindspore/python/mindspore/gen_ops_def.py')
207
+ op_cc_path = os.path.join(work_path, 'mindspore/core/ops/gen_ops_def.cc')
208
+
209
+ yaml_str = None
210
+ with open(yaml_path, 'r') as yaml_file:
211
+ yaml_str = yaml.safe_load(yaml_file)
212
+
213
+ doc_str = None
214
+ with open(doc_yaml_path, 'r') as doc_file:
215
+ doc_str = yaml.safe_load(doc_file)
216
+
217
+ cc_code = generate_cc_opdef(yaml_str)
218
+ cc_code += f"""
219
+ }} // namespace mindspore::ops"""
220
+
221
+ py_licence_str = f"""# Copyright 2023 Huawei Technologies Co., Ltd
222
+ #
223
+ # Licensed under the Apache License, Version 2.0 (the "License");
224
+ # you may not use this file except in compliance with the License.
225
+ # You may obtain a copy of the License at
226
+ #
227
+ # http://www.apache.org/licenses/LICENSE-2.0
228
+ #
229
+ # Unless required by applicable law or agreed to in writing, software
230
+ # distributed under the License is distributed on an "AS IS" BASIS,
231
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
232
+ # See the License for the specific language governing permissions and
233
+ # limitations under the License.
234
+ # ============================================================================
235
+ """
236
+ pyheader = f"""
237
+ \"\"\"Operators definition generated by gen_os.py, includes functions and primitive classes.\"\"\"
238
+
239
+ from mindspore.ops.primitive import Primitive
240
+ from mindspore.ops import operations as P
241
+ from mindspore.ops import functional as F
242
+ from mindspore.ops._primitive_cache import _get_cache_prim
243
+ from mindspore.ops.arg_dtype_cast import TypeCastKind, type_it
244
+ """
245
+ cc_license_str = f"""/**
246
+ * Copyright 2023 Huawei Technologies Co., Ltd
247
+ *
248
+ * Licensed under the Apache License, Version 2.0 (the "License");
249
+ * you may not use this file except in compliance with the License.
250
+ * You may obtain a copy of the License at
251
+ *
252
+ * http://www.apache.org/licenses/LICENSE-2.0
253
+ *
254
+ * Unless required by applicable law or agreed to in writing, software
255
+ * distributed under the License is distributed on an "AS IS" BASIS,
256
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
257
+ * See the License for the specific language governing permissions and
258
+ * limitations under the License.
259
+ */"""
260
+
261
+ ccheader = f"""
262
+ #include "op_def.h"
263
+ namespace mindspore::ops {{
264
+ """
265
+ py_prim = generate_py_primitive(yaml_str)
266
+ py_func = generate_py_op_func(yaml_str, doc_str)
267
+ py_file = None
268
+ with open(op_py_path, 'w') as py_file:
269
+ py_file.write(py_licence_str + pyheader + py_prim + py_func)
270
+
271
+ cc_file = None
272
+ with open(op_cc_path, 'w') as cc_file:
273
+ cc_file.write(cc_license_str + ccheader + cc_code)
mindspore/include/OWNERS CHANGED
@@ -1,6 +1,5 @@
1
1
  approvers:
2
2
  - jpc_chenjianping #
3
- - zhoufeng54
4
3
  - zhang_xue_tong
5
4
  reviewers:
6
5
  - lx0095
@@ -38,7 +38,8 @@ enum class DataType : int {
38
38
  kNumberTypeFloat16 = 42,
39
39
  kNumberTypeFloat32 = 43,
40
40
  kNumberTypeFloat64 = 44,
41
- kNumberTypeEnd = 46,
41
+ kNumberTypeBFloat16 = 46,
42
+ kNumberTypeEnd = 53,
42
43
  // add new enum here
43
44
  kInvalidType = INT32_MAX,
44
45
  };
@@ -24,38 +24,23 @@
24
24
  #include "include/api/types.h"
25
25
 
26
26
  namespace mindspore {
27
- class NetData;
28
- class Net;
29
-
30
27
  class MS_API Graph {
31
28
  public:
32
29
  class GraphData;
33
- enum Type : uint32_t {
34
- kExpressionGraph = 0, ///< graph as expression - can auto grad
35
- kExecutableGraph = 1, ///< graph is loaded as is
36
- kUnknownTypeGraph = 0xffffffff
37
- };
38
30
  Graph();
39
31
  explicit Graph(const std::shared_ptr<GraphData> &graph_data);
40
32
  explicit Graph(std::shared_ptr<GraphData> &&graph_data);
41
33
  explicit Graph(std::nullptr_t);
42
34
  ~Graph();
43
- explicit Graph(Type executable);
44
- explicit Graph(Net *net);
45
35
 
46
36
  enum ModelType ModelType() const;
47
37
  bool operator==(std::nullptr_t) const;
48
38
  bool operator!=(std::nullptr_t) const;
49
- bool IsExecutable() { return graph_type_ == kExecutableGraph; }
50
39
 
51
40
  private:
52
41
  friend class GraphCell;
53
42
  friend class ModelImpl;
54
- friend class NetImpl;
55
- friend class Model;
56
43
  std::shared_ptr<GraphData> graph_data_;
57
- std::shared_ptr<NetData> net_data_;
58
- Type graph_type_ = kExecutableGraph;
59
44
  };
60
45
  } // namespace mindspore
61
46
  #endif // MINDSPORE_INCLUDE_API_GRAPH_H
@@ -35,6 +35,8 @@ class MS_API Kernel : public IKernel<schema::Primitive> {
35
35
  Initialize();
36
36
  }
37
37
  virtual ~Kernel() = default;
38
+
39
+ int InferShape() override;
38
40
  /// \brief obtain kernel's type.
39
41
  ///
40
42
  /// \return kernel's type.