mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.11__cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (589) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +61 -95
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/__init__.py +4 -2
  268. mindspore/nn/wrap/cell_wrapper.py +87 -34
  269. mindspore/nn/wrap/grad_reducer.py +8 -5
  270. mindspore/nn/wrap/loss_scale.py +105 -42
  271. mindspore/numpy/array_creations.py +1 -2
  272. mindspore/numpy/array_ops.py +3 -2
  273. mindspore/numpy/utils_const.py +5 -5
  274. mindspore/offline_debug/convert_async.py +2 -2
  275. mindspore/ops/_grad_experimental/__init__.py +0 -5
  276. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  277. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  278. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  279. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  280. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  281. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  282. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  283. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  284. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  285. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  286. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  287. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  288. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  289. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  290. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  291. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  292. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  293. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  294. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  295. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  296. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  297. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  298. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  299. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  300. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  301. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  302. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  303. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  304. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  305. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  306. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  307. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  309. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  310. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  311. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  312. mindspore/ops/_primitive_cache.py +1 -1
  313. mindspore/ops/_tracefunc.py +45 -13
  314. mindspore/ops/_utils/utils.py +6 -1
  315. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  316. mindspore/ops/_vmap/vmap_base.py +3 -3
  317. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  318. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  319. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  320. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  321. mindspore/ops/arg_dtype_cast.py +54 -0
  322. mindspore/ops/composite/base.py +37 -10
  323. mindspore/ops/composite/math_ops.py +5 -4
  324. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  325. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  326. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  327. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  328. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  330. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  331. mindspore/ops/deprecated.py +304 -0
  332. mindspore/ops/function/__init__.py +4 -1
  333. mindspore/ops/function/array_func.py +174 -193
  334. mindspore/ops/function/clip_func.py +81 -13
  335. mindspore/ops/function/debug_func.py +1 -1
  336. mindspore/ops/function/grad/grad_func.py +18 -9
  337. mindspore/ops/function/image_func.py +10 -4
  338. mindspore/ops/function/linalg_func.py +5 -5
  339. mindspore/ops/function/math_func.py +575 -386
  340. mindspore/ops/function/nn_func.py +568 -260
  341. mindspore/ops/function/random_func.py +88 -57
  342. mindspore/ops/function/sparse_func.py +1 -1
  343. mindspore/ops/function/sparse_unary_func.py +14 -12
  344. mindspore/ops/function/vmap_func.py +6 -5
  345. mindspore/ops/functional.py +15 -10
  346. mindspore/ops/op_info_register.py +244 -25
  347. mindspore/ops/operations/__init__.py +31 -19
  348. mindspore/ops/operations/_grad_ops.py +71 -7
  349. mindspore/ops/operations/_inner_ops.py +350 -17
  350. mindspore/ops/operations/_quant_ops.py +4 -8
  351. mindspore/ops/operations/_sequence_ops.py +42 -0
  352. mindspore/ops/operations/array_ops.py +68 -282
  353. mindspore/ops/operations/comm_ops.py +107 -59
  354. mindspore/ops/operations/custom_ops.py +94 -70
  355. mindspore/ops/operations/debug_ops.py +8 -4
  356. mindspore/ops/operations/image_ops.py +18 -12
  357. mindspore/ops/operations/inner_ops.py +26 -3
  358. mindspore/ops/operations/math_ops.py +192 -144
  359. mindspore/ops/operations/nn_ops.py +857 -489
  360. mindspore/ops/operations/other_ops.py +0 -22
  361. mindspore/ops/operations/random_ops.py +53 -111
  362. mindspore/ops/operations/sparse_ops.py +3 -1
  363. mindspore/ops/primitive.py +24 -18
  364. mindspore/parallel/_auto_parallel_context.py +68 -8
  365. mindspore/parallel/_cost_model_context.py +2 -2
  366. mindspore/parallel/_offload_context.py +17 -3
  367. mindspore/parallel/_parallel_serialization.py +12 -5
  368. mindspore/parallel/_ps_context.py +12 -0
  369. mindspore/parallel/_tensor.py +18 -13
  370. mindspore/parallel/_transformer/layers.py +5 -3
  371. mindspore/parallel/_transformer/loss.py +1 -0
  372. mindspore/parallel/_transformer/moe.py +2 -2
  373. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  374. mindspore/parallel/_transformer/transformer.py +23 -3
  375. mindspore/parallel/_utils.py +11 -7
  376. mindspore/parallel/algo_parameter_config.py +85 -5
  377. mindspore/parallel/checkpoint_transform.py +19 -12
  378. mindspore/parallel/shard.py +21 -14
  379. mindspore/profiler/common/struct_type.py +3 -3
  380. mindspore/profiler/common/util.py +4 -2
  381. mindspore/profiler/envprofiling.py +1 -1
  382. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  383. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  384. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  385. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  386. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  387. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  388. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  389. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  390. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  391. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  392. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  393. mindspore/profiler/parser/flops_parser.py +15 -11
  394. mindspore/profiler/parser/framework_parser.py +38 -22
  395. mindspore/profiler/parser/hccl_parser.py +16 -12
  396. mindspore/profiler/parser/integrator.py +22 -11
  397. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  398. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  399. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  400. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  401. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  402. mindspore/profiler/parser/optime_parser.py +1 -1
  403. mindspore/profiler/parser/profiler_info.py +21 -2
  404. mindspore/profiler/parser/step_trace_parser.py +11 -14
  405. mindspore/profiler/profiling.py +179 -89
  406. mindspore/rewrite/api/node.py +102 -19
  407. mindspore/rewrite/api/node_type.py +5 -1
  408. mindspore/rewrite/api/pattern_engine.py +1 -1
  409. mindspore/rewrite/api/scoped_value.py +9 -17
  410. mindspore/rewrite/api/symbol_tree.py +131 -47
  411. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  412. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  413. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  414. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  415. mindspore/rewrite/common/rewrite_elog.py +5 -1
  416. mindspore/rewrite/namer.py +33 -24
  417. mindspore/rewrite/namespace.py +14 -5
  418. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  419. mindspore/rewrite/node/call_function.py +79 -0
  420. mindspore/rewrite/node/cell_container.py +135 -0
  421. mindspore/rewrite/node/control_flow.py +88 -0
  422. mindspore/rewrite/{node.py → node/node.py} +273 -234
  423. mindspore/rewrite/node/node_manager.py +254 -0
  424. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  425. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  426. mindspore/rewrite/parsers/assign_parser.py +216 -221
  427. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  428. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  429. mindspore/rewrite/parsers/constant_parser.py +9 -6
  430. mindspore/rewrite/parsers/container_parser.py +9 -7
  431. mindspore/rewrite/parsers/for_parser.py +42 -21
  432. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  433. mindspore/rewrite/parsers/if_parser.py +28 -24
  434. mindspore/rewrite/parsers/module_parser.py +196 -25
  435. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  436. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  437. mindspore/rewrite/parsers/return_parser.py +6 -6
  438. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  439. mindspore/rewrite/sparsify/utils.py +1 -1
  440. mindspore/rewrite/symbol_tree.py +523 -578
  441. mindspore/rewrite/symbol_tree_builder.py +9 -193
  442. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  443. mindspore/run_check/_check_version.py +6 -4
  444. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  445. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  446. mindspore/scipy/linalg.py +1 -1
  447. mindspore/scipy/ops.py +55 -5
  448. mindspore/scipy/optimize/__init__.py +3 -2
  449. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  450. mindspore/scipy/optimize/minimize.py +7 -3
  451. mindspore/train/_utils.py +7 -3
  452. mindspore/train/amp.py +323 -123
  453. mindspore/train/anf_ir_pb2.py +14 -2
  454. mindspore/train/callback/_backup_and_restore.py +2 -12
  455. mindspore/train/callback/_callback.py +29 -4
  456. mindspore/train/callback/_checkpoint.py +23 -8
  457. mindspore/train/callback/_early_stop.py +2 -2
  458. mindspore/train/callback/_landscape.py +4 -4
  459. mindspore/train/callback/_loss_monitor.py +2 -2
  460. mindspore/train/callback/_on_request_exit.py +2 -2
  461. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  462. mindspore/train/callback/_summary_collector.py +15 -8
  463. mindspore/train/callback/_time_monitor.py +58 -5
  464. mindspore/train/data_sink.py +5 -11
  465. mindspore/train/dataset_helper.py +84 -57
  466. mindspore/train/loss_scale_manager.py +2 -2
  467. mindspore/train/metrics/__init__.py +3 -3
  468. mindspore/train/metrics/cosine_similarity.py +1 -1
  469. mindspore/train/metrics/hausdorff_distance.py +3 -2
  470. mindspore/train/metrics/mean_surface_distance.py +3 -2
  471. mindspore/train/metrics/metric.py +39 -19
  472. mindspore/train/metrics/roc.py +2 -2
  473. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  474. mindspore/train/mind_ir_pb2.py +85 -36
  475. mindspore/train/model.py +187 -47
  476. mindspore/train/serialization.py +487 -161
  477. mindspore/train/summary/_summary_adapter.py +1 -1
  478. mindspore/train/summary/_writer_pool.py +3 -2
  479. mindspore/train/summary/summary_record.py +37 -17
  480. mindspore/train/train_thor/convert_utils.py +3 -3
  481. mindspore/train/train_thor/dataset_helper.py +1 -1
  482. mindspore/version.py +1 -1
  483. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  486. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  487. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  489. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  490. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  491. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  492. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  493. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  494. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  495. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  497. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  498. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  499. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  500. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  501. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  502. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  503. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  504. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  505. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  506. mindspore/_extends/graph_kernel/expander.py +0 -80
  507. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  508. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  509. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  510. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  511. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  512. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  513. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  514. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  516. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  517. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  518. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  519. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  520. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  521. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  522. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  523. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  524. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  525. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  526. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  527. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  528. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  529. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  530. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  531. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  532. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  534. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  535. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  536. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  537. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  538. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  540. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  543. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  544. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  545. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  546. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  547. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  548. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  549. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  550. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  551. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  552. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  553. mindspore/dataset/datapreprocess/__init__.py +0 -20
  554. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  555. mindspore/include/api/net.h +0 -142
  556. mindspore/nn/lr_scheduler.py +0 -262
  557. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  558. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  559. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  560. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  561. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  562. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  563. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  564. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  565. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  566. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  567. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  568. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  569. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  570. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  571. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  575. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  578. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  579. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  580. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  582. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  583. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  585. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  586. mindspore/rewrite/node_visitor.py +0 -44
  587. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  588. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  589. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,152 +1,147 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """sgd"""
16
- from __future__ import absolute_import
17
-
18
- from mindspore.ops import operations as P
19
- from mindspore.common.tensor import Tensor
20
- import mindspore.common.dtype as mstype
21
- from mindspore import _checkparam as Validator
22
- from mindspore.nn.optim_ex.optimizer import Optimizer
23
-
24
-
25
- class SGD(Optimizer):
26
- """
27
- Stochastic Gradient Descent optimizer.
28
-
29
- .. math::
30
- v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
31
-
32
- If nesterov is True:
33
-
34
- .. math::
35
- p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
36
-
37
- If nesterov is False:
38
-
39
- .. math::
40
- p_{t+1} = p_{t} - lr \ast v_{t+1}
41
-
42
- To be noticed, for the first step, :math:`v_{t+1} = gradient`.
43
-
44
- Here : where p, v and u denote the parameters, accum, and momentum respectively.
45
-
46
- .. warning::
47
- This is an experimental optimizer API that is subject to change.
48
- This module must be used with lr scheduler module in `LRScheduler Class
49
- <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#lrscheduler>`_ .
50
-
51
- Args:
52
- params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
53
- parameter groups.
54
- lr (Union[int, float, Tensor]): learning rate.
55
- momentum (Union[int, float], optional): momentum factor. Default: ``0``.
56
- weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
57
- dampening (Union[int, float], optional): dampening for momentum. Default: ``0``.
58
- nesterov (bool, optional): enables Nesterov momentum. Default: ``False``.
59
-
60
- Keyword Args:
61
- maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
62
- Default: ``False``.
63
-
64
- Inputs:
65
- - **gradients** (tuple[Tensor]) - The gradients of `params`.
66
-
67
- Raises:
68
- ValueError: If the learning rate is not int, float or Tensor.
69
- ValueError: If the learning rate is less than 0.
70
- ValueError: If the `momentum` or `weight_decay` value is less than 0.0.
71
- ValueError: If the `momentum`, `dampening` or `weight_decay` value is not int or float.
72
- ValueError: If the `nesterov` and `maximize` is not bool.
73
- ValueError: If the `nesterov` is true, `momentum` is not positive or `dampening` is not 0.0.
74
-
75
- Supported Platforms:
76
- ``Ascend`` ``GPU`` ``CPU``
77
-
78
- Examples:
79
- >>> import mindspore
80
- >>> from mindspore import nn
81
- >>> # Define the network structure of LeNet5. Refer to
82
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
83
- >>> net = LeNet5()
84
- >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
85
- >>> optimizer = nn.optim_ex.SGD(net.trainable_params(), lr=0.1)
86
- >>> def forward_fn(data, label):
87
- ... logits = net(data)
88
- ... loss = loss_fn(logits, label)
89
- ... return loss, logits
90
- >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
91
- >>> def train_step(data, label):
92
- ... (loss, _), grads = grad_fn(data, label)
93
- ... optimizer(grads)
94
- ... return loss
95
- """
96
- def __init__(self, params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False, *,
97
- maximize=False):
98
- Validator.check_value_type("lr", lr, [float, int, Tensor], self.cls_name)
99
- if lr < 0.0:
100
- raise ValueError("Invalid learning rate: {}".format(lr))
101
- Validator.check_value_type("momentum", momentum, [int, float], self.cls_name)
102
- if momentum < 0.0:
103
- raise ValueError("Invalid momentum value: {}".format(momentum))
104
- momentum = float(momentum)
105
- Validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
106
- Validator.check_value_type("maximize", maximize, [bool], self.cls_name)
107
-
108
- defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
109
- weight_decay=weight_decay, nesterov=nesterov,
110
- maximize=maximize, grad_centralization=False)
111
- super(SGD, self).__init__(params, defaults)
112
- for group in self.param_groups:
113
- Validator.check_value_type("dampening", group["dampening"], [int, float], self.cls_name)
114
- group["dampening"] = float(group["dampening"])
115
- if nesterov and (momentum <= 0.0 or dampening != 0.0):
116
- raise ValueError("For 'SGD', if 'nesterov' is true, 'momentum' must be > 0.0 and 'dampening' must "
117
- "equal to 0.0, but got 'momentum' {}, 'dampening' {}".format(momentum, dampening))
118
- self.accum = self.parameters.clone(prefix="accum", init='zeros')
119
- self.stat = self.parameters.clone(prefix="stat", init='ones')
120
- self.op_cast = P.Cast()
121
-
122
- def construct(self, gradients):
123
- for group_id, group in enumerate(self.param_groups):
124
- params = []
125
- grads = []
126
- accums = []
127
- stats = []
128
- params, grads, accums, stats = self._init_group(group, gradients, params, grads,
129
- accums, stats, group_id)
130
- opt = P.SGD(group["dampening"], group["weight_decay"], group["nesterov"])
131
- lr = group["lr"]
132
- if isinstance(lr, float):
133
- lr = self.op_cast(group["lr"], mstype.float32)
134
- momentum = self.op_cast(group["momentum"], mstype.float32)
135
- self.apply_sgd(opt, params, grads, lr, accums, momentum, stats, group["maximize"],
136
- group["grad_centralization"])
137
-
138
- def apply_sgd(self, opt, params, grads, lr, accums, momentum, stats, maximize, grad_centralization):
139
- grads = self._gradients_centralization(grad_centralization, grads)
140
-
141
- for i, param in enumerate(params):
142
- grad = grads[i] if not maximize else -grads[i]
143
- opt(param, grad, lr, accums[i], momentum, stats[i])
144
-
145
- def _init_group(self, group, gradients, params, accums, grads, stats, group_id):
146
- p_id = self.group_start_id[group_id]
147
- for i, param in enumerate(group["params"]):
148
- params.append(param)
149
- grads.append(gradients[p_id+i])
150
- accums.append(self.accum[p_id+i])
151
- stats.append(self.stat[p_id+i])
152
- return params, grads, accums, stats
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """sgd"""
16
+ from __future__ import absolute_import
17
+
18
+ from mindspore.ops import functional as F, composite as C, operations as P
19
+ from mindspore.common.tensor import Tensor
20
+ import mindspore.common.dtype as mstype
21
+ from mindspore import _checkparam as Validator
22
+ from mindspore.experimental.optim.optimizer import Optimizer
23
+
24
+ _sgd_opt = C.MultitypeFuncGraph("sgd_opt")
25
+
26
+
27
+ @_sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",)
28
+ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat):
29
+ """Apply sgd optimizer to the weight parameter using Tensor."""
30
+ success = True
31
+ success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
32
+ return success
33
+
34
+
35
+ class SGD(Optimizer):
36
+ r"""
37
+ Stochastic Gradient Descent optimizer.
38
+
39
+ .. math::
40
+ v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
41
+
42
+ If nesterov is True:
43
+
44
+ .. math::
45
+ p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
46
+
47
+ If nesterov is False:
48
+
49
+ .. math::
50
+ p_{t+1} = p_{t} - lr \ast v_{t+1}
51
+
52
+ To be noticed, for the first step, :math:`v_{t+1} = gradient`.
53
+
54
+ Here : where p, v and u denote the parameters, accum, and momentum respectively.
55
+
56
+ .. warning::
57
+ This is an experimental optimizer API that is subject to change.
58
+ This module must be used with lr scheduler module in `LRScheduler Class
59
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.experimental.html#lrscheduler-class>`_ .
60
+
61
+ Args:
62
+ params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
63
+ parameter groups.
64
+ lr (Union[int, float, Tensor]): learning rate.
65
+ momentum (Union[int, float], optional): momentum factor. Default: ``0``.
66
+ weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
67
+ dampening (Union[int, float], optional): dampening for momentum. Default: ``0``.
68
+ nesterov (bool, optional): enables Nesterov momentum. Default: ``False``.
69
+
70
+ Keyword Args:
71
+ maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
72
+ Default: ``False``.
73
+
74
+ Inputs:
75
+ - **gradients** (tuple[Tensor]) - The gradients of `params`.
76
+
77
+ Raises:
78
+ ValueError: If the learning rate is not int, float or Tensor.
79
+ ValueError: If the learning rate is less than 0.
80
+ ValueError: If the `momentum` or `weight_decay` value is less than 0.0.
81
+ ValueError: If the `momentum`, `dampening` or `weight_decay` value is not int or float.
82
+ ValueError: If the `nesterov` and `maximize` is not bool.
83
+ ValueError: If the `nesterov` is true, `momentum` is not positive or `dampening` is not 0.0.
84
+
85
+ Supported Platforms:
86
+ ``Ascend`` ``GPU`` ``CPU``
87
+
88
+ Examples:
89
+ >>> import mindspore
90
+ >>> from mindspore import nn
91
+ >>> from mindspore.experimental import optim
92
+ >>> # Define the network structure of LeNet5. Refer to
93
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
94
+ >>> net = LeNet5()
95
+ >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
96
+ >>> optimizer = optim.SGD(net.trainable_params(), lr=0.1)
97
+ >>> def forward_fn(data, label):
98
+ ... logits = net(data)
99
+ ... loss = loss_fn(logits, label)
100
+ ... return loss, logits
101
+ >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
102
+ >>> def train_step(data, label):
103
+ ... (loss, _), grads = grad_fn(data, label)
104
+ ... optimizer(grads)
105
+ ... return loss
106
+ """
107
+ def __init__(self, params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False, *,
108
+ maximize=False):
109
+ Validator.check_value_type("lr", lr, [float, int, Tensor], self.cls_name)
110
+ if lr < 0.0:
111
+ raise ValueError("Invalid learning rate: {}".format(lr))
112
+ Validator.check_value_type("momentum", momentum, [int, float], self.cls_name)
113
+ if momentum < 0.0:
114
+ raise ValueError("Invalid momentum value: {}".format(momentum))
115
+ momentum = float(momentum)
116
+ Validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
117
+ Validator.check_value_type("maximize", maximize, [bool], self.cls_name)
118
+
119
+ defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
120
+ weight_decay=weight_decay, nesterov=nesterov,
121
+ maximize=maximize, grad_centralization=False)
122
+ super(SGD, self).__init__(params, defaults)
123
+ for group in self.param_groups:
124
+ Validator.check_value_type("dampening", group.get("dampening"), [int, float], self.cls_name)
125
+ group["dampening"] = float(group.get("dampening"))
126
+ if nesterov and (momentum <= 0.0 or dampening != 0.0):
127
+ raise ValueError("For 'SGD', if 'nesterov' is true, 'momentum' must be > 0.0 and 'dampening' must "
128
+ "equal to 0.0, but got 'momentum' {}, 'dampening' {}".format(momentum, dampening))
129
+ self.accum = self.parameters.clone(prefix="accum", init='zeros')
130
+ self.stat = self.parameters.clone(prefix="stat", init='ones')
131
+ self.op_cast = P.Cast()
132
+
133
+ def construct(self, gradients):
134
+ for group_id, group in enumerate(self.param_groups):
135
+ opt = P.SGD(group.get("dampening"), group.get("weight_decay"), group.get("nesterov"))
136
+ lr = group.get("lr")
137
+ if isinstance(lr, float):
138
+ lr = self.op_cast(group.get("lr"), mstype.float32)
139
+ maximize = group.get("maximize")
140
+ momentum = self.op_cast(group.get("momentum"), mstype.float32)
141
+ start_id = self.group_start_id[group_id]
142
+ end_id = self.group_start_id[group_id+1]
143
+ grads = gradients[start_id: end_id] if not maximize else -gradients[start_id: end_id]
144
+ self.hyper_map(F.partial(_sgd_opt, opt, momentum, lr), grads,
145
+ self.parameters[start_id: end_id], self.accum[start_id: end_id],
146
+ self.stat[start_id: end_id])
147
+ return True
mindspore/gen_ops.py ADDED
@@ -0,0 +1,273 @@
1
+ # Copyright 2023-2025 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ Generate operator definition from ops.yaml
17
+ """
18
+ import sys
19
+ import os
20
+ import yaml
21
+
22
+
23
+ def generate_py_op_func(yaml_data, doc_data):
24
+ """
25
+ generate python operator function
26
+ """
27
+ gen_py = ''
28
+
29
+ op_desc_dict = {}
30
+ for operator_name, operator_desc in doc_data.items():
31
+ desc = operator_desc.get("description")
32
+ op_desc_dict[operator_name] = desc
33
+
34
+ for operator_name, operator_data in yaml_data.items():
35
+ description = op_desc_dict.get(operator_name)
36
+ args = operator_data.get('args')
37
+ func_name = operator_data.get('func_name')
38
+ if func_name is None:
39
+ func_name = operator_name
40
+
41
+ class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
42
+ func_args = []
43
+ primitive_init_args = []
44
+ input_args = []
45
+ for arg_name, arg_info in args.items():
46
+ dtype = arg_info.get('dtype')
47
+ init_value = arg_info.get('init')
48
+ if init_value:
49
+ if dtype == 'str':
50
+ init_value = '"' + init_value + '"'
51
+ func_args.append(f"""{arg_name}={init_value}""")
52
+ primitive_init_args.append(arg_name)
53
+ else:
54
+ func_args.append(arg_name)
55
+ input_args.append(arg_name)
56
+
57
+ function_code = f"""
58
+ def {func_name}({', '.join(arg for arg in func_args)}):
59
+ \"\"\"
60
+ {description}
61
+ \"\"\"
62
+ {operator_name}_op = _get_cache_prim(P.{class_name})({', '.join(arg_name for arg_name in primitive_init_args)})
63
+ return {operator_name}_op({', '.join(arg_name for arg_name in input_args)})
64
+ """
65
+ gen_py += function_code
66
+
67
+ return gen_py
68
+
69
+
70
+ def generate_py_primitive(yaml_data):
71
+ """
72
+ generate python primitive
73
+ """
74
+ gen_py = ''
75
+ for operator_name, operator_data in yaml_data.items():
76
+ args = operator_data.get('args')
77
+ func_name = operator_data.get('func_name')
78
+ if func_name is None:
79
+ func_name = operator_name
80
+
81
+ class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
82
+
83
+ init_args_with_default = []
84
+ init_args = []
85
+ args_assign = []
86
+ for arg_name, arg_info in args.items():
87
+ dtype = arg_info.get('dtype')
88
+ type_cast = arg_info.get('type_cast')
89
+ type_cast_set = None
90
+ if type_cast:
91
+ type_cast_set = {ct.strip() for ct in type_cast.split(",")}
92
+
93
+ init_value = arg_info.get('init')
94
+ if init_value is None:
95
+ continue
96
+
97
+ if dtype == 'str':
98
+ init_value = '"' + init_value + '"'
99
+ init_args_with_default.append(f"""{arg_name}={init_value}""")
100
+ init_args.append(arg_name)
101
+
102
+ assign_str = f""" self.{arg_name} = """
103
+
104
+ if type_cast_set:
105
+ assign_str += f'type_it({arg_name}, '
106
+ type_cast_list = []
107
+
108
+ if 'int' in type_cast_set:
109
+ type_cast_list.append('INT')
110
+ if 'tuple[int]' in type_cast_list:
111
+ type_cast_list.append('TUPLE')
112
+ #add more type cast kind here
113
+
114
+ assign_str += 'TypeCastKind.' + '_OR_'.join(ct for ct in type_cast_list)
115
+ if dtype == 'tuple[int]':
116
+ assign_str += '_TO_TUPLE)'
117
+ if dtype == 'list[int]':
118
+ assign_str += '_TO_LIST)'
119
+ else:
120
+ assign_str += arg_name
121
+ args_assign.append(assign_str)
122
+
123
+ args_assign = '\n'.join([assign for assign in args_assign])
124
+ primitive_code = f"""
125
+ class {class_name}(Primitive):
126
+ def __init__(self, {', '.join(init_args_with_default)}):
127
+ {args_assign}
128
+ def __call__(self, *args):
129
+ super.__call__(self, *args, {', '.join([f'self.{arg}' for arg in init_args])})
130
+ """
131
+
132
+ gen_py += primitive_code
133
+ return gen_py
134
+
135
+
136
+ def generate_cc_opdef(yaml_data):
137
+ """
138
+ generate OpDef
139
+ """
140
+ gen_cc = ''
141
+ opdef_map_str = f"""
142
+ std::unordered_map<std::string, OpDefPtr> gOpDefTable = {{"""
143
+
144
+ for operator_name, operator_data in yaml_data.items():
145
+ args = operator_data.get('args')
146
+ returns = operator_data.get('returns')
147
+ func_name = operator_data.get('func_name')
148
+ if func_name is None:
149
+ func_name = operator_name
150
+
151
+ class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
152
+ opdef_map_str += f"""
153
+ {{"{operator_name}", &g{class_name}}},"""
154
+
155
+ opdef_cc = f"""
156
+ OpDef g{class_name} = {{
157
+ .name_ = "{operator_name}","""
158
+ opdef_cc += f"""
159
+ .args_ = {{"""
160
+
161
+ for arg_name, arg_info in args.items():
162
+ dtype = arg_info.get('dtype')
163
+ init = arg_info.get('init')
164
+ if init is None:
165
+ init = 0
166
+ else:
167
+ init = 1
168
+ cc_dtype_str = 'DT_' + dtype.replace('[', '_').replace(']', '').replace('tuple', 'array').replace(
169
+ 'list', 'array').upper()
170
+ cc_dtype_str.replace('TUPLE', 'ARRAY').replace('LIST', 'ARRAY')
171
+ opdef_cc += f"""
172
+ {{.arg_name_ = "{arg_name}", .arg_dtype_ = {cc_dtype_str}, .as_init_arg_ = {init}}},"""
173
+ opdef_cc += f"""
174
+ }},"""
175
+
176
+ opdef_cc += f"""
177
+ .returns_ = {{"""
178
+
179
+ for return_name, return_info in returns.items():
180
+ return_dtype = return_info.get('dtype')
181
+ cc_return_type_str = 'DT_' + return_dtype.replace('[', '_').replace(']', '').replace(
182
+ 'tuple', 'array').replace('list', 'array').upper()
183
+ opdef_cc += f"""
184
+ {{.arg_name_ = "{return_name}", .arg_dtype_ = {cc_return_type_str}}},"""
185
+
186
+ opdef_cc += f"""
187
+ }},"""
188
+
189
+ opdef_cc += f"""
190
+ }};"""
191
+ gen_cc += opdef_cc
192
+
193
+ opdef_map_str += f"""
194
+ }};"""
195
+ gen_cc += opdef_map_str
196
+ return gen_cc
197
+
198
+
199
+ if __name__ == "__main__":
200
+ work_path = ''
201
+ if len(sys.argv) > 1:
202
+ work_path = sys.argv[1]
203
+
204
+ yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops.yaml')
205
+ doc_yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_doc.yaml')
206
+ op_py_path = os.path.join(work_path, 'mindspore/python/mindspore/gen_ops_def.py')
207
+ op_cc_path = os.path.join(work_path, 'mindspore/core/ops/gen_ops_def.cc')
208
+
209
+ yaml_str = None
210
+ with open(yaml_path, 'r') as yaml_file:
211
+ yaml_str = yaml.safe_load(yaml_file)
212
+
213
+ doc_str = None
214
+ with open(doc_yaml_path, 'r') as doc_file:
215
+ doc_str = yaml.safe_load(doc_file)
216
+
217
+ cc_code = generate_cc_opdef(yaml_str)
218
+ cc_code += f"""
219
+ }} // namespace mindspore::ops"""
220
+
221
+ py_licence_str = f"""# Copyright 2023 Huawei Technologies Co., Ltd
222
+ #
223
+ # Licensed under the Apache License, Version 2.0 (the "License");
224
+ # you may not use this file except in compliance with the License.
225
+ # You may obtain a copy of the License at
226
+ #
227
+ # http://www.apache.org/licenses/LICENSE-2.0
228
+ #
229
+ # Unless required by applicable law or agreed to in writing, software
230
+ # distributed under the License is distributed on an "AS IS" BASIS,
231
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
232
+ # See the License for the specific language governing permissions and
233
+ # limitations under the License.
234
+ # ============================================================================
235
+ """
236
+ pyheader = f"""
237
+ \"\"\"Operators definition generated by gen_os.py, includes functions and primitive classes.\"\"\"
238
+
239
+ from mindspore.ops.primitive import Primitive
240
+ from mindspore.ops import operations as P
241
+ from mindspore.ops import functional as F
242
+ from mindspore.ops._primitive_cache import _get_cache_prim
243
+ from mindspore.ops.arg_dtype_cast import TypeCastKind, type_it
244
+ """
245
+ cc_license_str = f"""/**
246
+ * Copyright 2023 Huawei Technologies Co., Ltd
247
+ *
248
+ * Licensed under the Apache License, Version 2.0 (the "License");
249
+ * you may not use this file except in compliance with the License.
250
+ * You may obtain a copy of the License at
251
+ *
252
+ * http://www.apache.org/licenses/LICENSE-2.0
253
+ *
254
+ * Unless required by applicable law or agreed to in writing, software
255
+ * distributed under the License is distributed on an "AS IS" BASIS,
256
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
257
+ * See the License for the specific language governing permissions and
258
+ * limitations under the License.
259
+ */"""
260
+
261
+ ccheader = f"""
262
+ #include "op_def.h"
263
+ namespace mindspore::ops {{
264
+ """
265
+ py_prim = generate_py_primitive(yaml_str)
266
+ py_func = generate_py_op_func(yaml_str, doc_str)
267
+ py_file = None
268
+ with open(op_py_path, 'w') as py_file:
269
+ py_file.write(py_licence_str + pyheader + py_prim + py_func)
270
+
271
+ cc_file = None
272
+ with open(op_cc_path, 'w') as cc_file:
273
+ cc_file.write(cc_license_str + ccheader + cc_code)
mindspore/include/OWNERS CHANGED
@@ -1,6 +1,5 @@
1
1
  approvers:
2
2
  - jpc_chenjianping #
3
- - zhoufeng54
4
3
  - zhang_xue_tong
5
4
  reviewers:
6
5
  - lx0095
@@ -38,7 +38,8 @@ enum class DataType : int {
38
38
  kNumberTypeFloat16 = 42,
39
39
  kNumberTypeFloat32 = 43,
40
40
  kNumberTypeFloat64 = 44,
41
- kNumberTypeEnd = 46,
41
+ kNumberTypeBFloat16 = 46,
42
+ kNumberTypeEnd = 53,
42
43
  // add new enum here
43
44
  kInvalidType = INT32_MAX,
44
45
  };
@@ -24,38 +24,23 @@
24
24
  #include "include/api/types.h"
25
25
 
26
26
  namespace mindspore {
27
- class NetData;
28
- class Net;
29
-
30
27
  class MS_API Graph {
31
28
  public:
32
29
  class GraphData;
33
- enum Type : uint32_t {
34
- kExpressionGraph = 0, ///< graph as expression - can auto grad
35
- kExecutableGraph = 1, ///< graph is loaded as is
36
- kUnknownTypeGraph = 0xffffffff
37
- };
38
30
  Graph();
39
31
  explicit Graph(const std::shared_ptr<GraphData> &graph_data);
40
32
  explicit Graph(std::shared_ptr<GraphData> &&graph_data);
41
33
  explicit Graph(std::nullptr_t);
42
34
  ~Graph();
43
- explicit Graph(Type executable);
44
- explicit Graph(Net *net);
45
35
 
46
36
  enum ModelType ModelType() const;
47
37
  bool operator==(std::nullptr_t) const;
48
38
  bool operator!=(std::nullptr_t) const;
49
- bool IsExecutable() { return graph_type_ == kExecutableGraph; }
50
39
 
51
40
  private:
52
41
  friend class GraphCell;
53
42
  friend class ModelImpl;
54
- friend class NetImpl;
55
- friend class Model;
56
43
  std::shared_ptr<GraphData> graph_data_;
57
- std::shared_ptr<NetData> net_data_;
58
- Type graph_type_ = kExecutableGraph;
59
44
  };
60
45
  } // namespace mindspore
61
46
  #endif // MINDSPORE_INCLUDE_API_GRAPH_H
@@ -35,6 +35,8 @@ class MS_API Kernel : public IKernel<schema::Primitive> {
35
35
  Initialize();
36
36
  }
37
37
  virtual ~Kernel() = default;
38
+
39
+ int InferShape() override;
38
40
  /// \brief obtain kernel's type.
39
41
  ///
40
42
  /// \return kernel's type.