mindspore 2.1.0__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the registry's advisory page for more details.

Files changed (539)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  172. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  173. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  174. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  175. mindspore/nn/__init__.py +0 -2
  176. mindspore/nn/cell.py +316 -74
  177. mindspore/nn/dynamic_lr.py +21 -21
  178. mindspore/nn/layer/activation.py +21 -28
  179. mindspore/nn/layer/basic.py +15 -13
  180. mindspore/nn/layer/channel_shuffle.py +1 -1
  181. mindspore/nn/layer/container.py +271 -9
  182. mindspore/nn/layer/conv.py +310 -207
  183. mindspore/nn/layer/dense.py +8 -5
  184. mindspore/nn/layer/embedding.py +33 -27
  185. mindspore/nn/layer/flash_attention.py +82 -41
  186. mindspore/nn/layer/image.py +8 -6
  187. mindspore/nn/layer/math.py +13 -18
  188. mindspore/nn/layer/normalization.py +107 -66
  189. mindspore/nn/layer/padding.py +1 -1
  190. mindspore/nn/layer/pooling.py +131 -109
  191. mindspore/nn/layer/rnn_cells.py +22 -17
  192. mindspore/nn/layer/rnns.py +13 -16
  193. mindspore/nn/layer/thor_layer.py +1 -1
  194. mindspore/nn/layer/transformer.py +221 -154
  195. mindspore/nn/learning_rate_schedule.py +9 -1
  196. mindspore/nn/loss/loss.py +235 -174
  197. mindspore/nn/optim/ada_grad.py +2 -1
  198. mindspore/nn/optim/adadelta.py +1 -0
  199. mindspore/nn/optim/adafactor.py +2 -1
  200. mindspore/nn/optim/adam.py +7 -4
  201. mindspore/nn/optim/adamax.py +3 -2
  202. mindspore/nn/optim/adasum.py +2 -2
  203. mindspore/nn/optim/asgd.py +2 -3
  204. mindspore/nn/optim/ftrl.py +6 -5
  205. mindspore/nn/optim/lamb.py +7 -4
  206. mindspore/nn/optim/lars.py +1 -1
  207. mindspore/nn/optim/lazyadam.py +5 -3
  208. mindspore/nn/optim/momentum.py +2 -1
  209. mindspore/nn/optim/optimizer.py +53 -4
  210. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  211. mindspore/nn/optim/rmsprop.py +4 -3
  212. mindspore/nn/optim/rprop.py +23 -12
  213. mindspore/nn/optim/sgd.py +26 -11
  214. mindspore/nn/optim/thor.py +9 -7
  215. mindspore/nn/probability/bijector/bijector.py +5 -5
  216. mindspore/nn/probability/bijector/power_transform.py +27 -27
  217. mindspore/nn/probability/bijector/softplus.py +3 -3
  218. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  219. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  220. mindspore/nn/probability/distribution/beta.py +3 -3
  221. mindspore/nn/probability/distribution/categorical.py +7 -7
  222. mindspore/nn/probability/distribution/cauchy.py +0 -1
  223. mindspore/nn/probability/distribution/distribution.py +3 -3
  224. mindspore/nn/probability/distribution/gamma.py +3 -3
  225. mindspore/nn/probability/distribution/geometric.py +4 -4
  226. mindspore/nn/probability/distribution/gumbel.py +4 -4
  227. mindspore/nn/probability/distribution/log_normal.py +2 -2
  228. mindspore/nn/probability/distribution/logistic.py +2 -2
  229. mindspore/nn/probability/distribution/poisson.py +4 -4
  230. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  231. mindspore/nn/probability/distribution/uniform.py +6 -6
  232. mindspore/nn/wrap/cell_wrapper.py +78 -34
  233. mindspore/nn/wrap/grad_reducer.py +8 -5
  234. mindspore/nn/wrap/loss_scale.py +105 -42
  235. mindspore/numpy/array_creations.py +1 -2
  236. mindspore/numpy/array_ops.py +3 -2
  237. mindspore/offline_debug/convert_async.py +2 -2
  238. mindspore/ops/_grad_experimental/__init__.py +0 -5
  239. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  240. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  241. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  242. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  243. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  244. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  245. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  246. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  248. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  249. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  250. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  251. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  252. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  253. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  254. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  255. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  256. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  257. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  258. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  259. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  260. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  261. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  262. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  263. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  264. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  265. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  266. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  267. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  268. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  269. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  270. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  271. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  272. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  273. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  274. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  275. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  276. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  277. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  278. mindspore/ops/_primitive_cache.py +1 -1
  279. mindspore/ops/_tracefunc.py +45 -13
  280. mindspore/ops/_utils/utils.py +4 -1
  281. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  282. mindspore/ops/_vmap/vmap_base.py +3 -3
  283. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  284. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  285. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  286. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  287. mindspore/ops/arg_dtype_cast.py +54 -0
  288. mindspore/ops/composite/base.py +37 -10
  289. mindspore/ops/composite/math_ops.py +5 -4
  290. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  291. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  292. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  293. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  294. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  295. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  297. mindspore/ops/deprecated.py +304 -0
  298. mindspore/ops/function/__init__.py +4 -1
  299. mindspore/ops/function/array_func.py +167 -189
  300. mindspore/ops/function/clip_func.py +81 -13
  301. mindspore/ops/function/debug_func.py +1 -1
  302. mindspore/ops/function/grad/grad_func.py +18 -8
  303. mindspore/ops/function/image_func.py +10 -4
  304. mindspore/ops/function/linalg_func.py +5 -5
  305. mindspore/ops/function/math_func.py +575 -386
  306. mindspore/ops/function/nn_func.py +470 -251
  307. mindspore/ops/function/random_func.py +86 -56
  308. mindspore/ops/function/sparse_func.py +1 -1
  309. mindspore/ops/function/sparse_unary_func.py +14 -12
  310. mindspore/ops/function/vmap_func.py +6 -5
  311. mindspore/ops/functional.py +15 -10
  312. mindspore/ops/op_info_register.py +235 -19
  313. mindspore/ops/operations/__init__.py +25 -17
  314. mindspore/ops/operations/_grad_ops.py +52 -7
  315. mindspore/ops/operations/_inner_ops.py +213 -12
  316. mindspore/ops/operations/_quant_ops.py +4 -8
  317. mindspore/ops/operations/_sequence_ops.py +42 -0
  318. mindspore/ops/operations/array_ops.py +64 -280
  319. mindspore/ops/operations/comm_ops.py +105 -57
  320. mindspore/ops/operations/custom_ops.py +10 -3
  321. mindspore/ops/operations/debug_ops.py +8 -4
  322. mindspore/ops/operations/image_ops.py +18 -12
  323. mindspore/ops/operations/math_ops.py +185 -138
  324. mindspore/ops/operations/nn_ops.py +716 -492
  325. mindspore/ops/operations/other_ops.py +0 -22
  326. mindspore/ops/operations/random_ops.py +53 -111
  327. mindspore/ops/operations/sparse_ops.py +3 -1
  328. mindspore/ops/primitive.py +24 -18
  329. mindspore/parallel/_auto_parallel_context.py +68 -8
  330. mindspore/parallel/_cost_model_context.py +2 -2
  331. mindspore/parallel/_offload_context.py +17 -3
  332. mindspore/parallel/_parallel_serialization.py +2 -2
  333. mindspore/parallel/_ps_context.py +12 -0
  334. mindspore/parallel/_tensor.py +14 -12
  335. mindspore/parallel/_transformer/layers.py +5 -3
  336. mindspore/parallel/_transformer/loss.py +1 -0
  337. mindspore/parallel/_transformer/moe.py +2 -2
  338. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  339. mindspore/parallel/_transformer/transformer.py +23 -3
  340. mindspore/parallel/_utils.py +11 -7
  341. mindspore/parallel/algo_parameter_config.py +85 -5
  342. mindspore/parallel/checkpoint_transform.py +6 -10
  343. mindspore/parallel/shard.py +4 -4
  344. mindspore/profiler/common/struct_type.py +3 -3
  345. mindspore/profiler/common/util.py +3 -2
  346. mindspore/profiler/envprofiling.py +1 -1
  347. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  348. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  349. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  350. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  351. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  352. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  353. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  354. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  355. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  356. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  357. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  358. mindspore/profiler/parser/flops_parser.py +15 -11
  359. mindspore/profiler/parser/framework_parser.py +37 -21
  360. mindspore/profiler/parser/hccl_parser.py +16 -12
  361. mindspore/profiler/parser/integrator.py +22 -11
  362. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  363. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  364. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  365. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  366. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  367. mindspore/profiler/parser/optime_parser.py +1 -1
  368. mindspore/profiler/parser/profiler_info.py +2 -2
  369. mindspore/profiler/parser/step_trace_parser.py +11 -14
  370. mindspore/profiler/profiling.py +139 -71
  371. mindspore/rewrite/api/node.py +102 -19
  372. mindspore/rewrite/api/node_type.py +5 -1
  373. mindspore/rewrite/api/scoped_value.py +9 -17
  374. mindspore/rewrite/api/symbol_tree.py +131 -47
  375. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  376. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  377. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  378. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  379. mindspore/rewrite/common/rewrite_elog.py +5 -1
  380. mindspore/rewrite/namer.py +33 -24
  381. mindspore/rewrite/namespace.py +14 -5
  382. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  383. mindspore/rewrite/node/call_function.py +79 -0
  384. mindspore/rewrite/node/cell_container.py +135 -0
  385. mindspore/rewrite/node/control_flow.py +88 -0
  386. mindspore/rewrite/{node.py → node/node.py} +273 -234
  387. mindspore/rewrite/node/node_manager.py +254 -0
  388. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  389. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  390. mindspore/rewrite/parsers/assign_parser.py +216 -221
  391. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  392. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  393. mindspore/rewrite/parsers/constant_parser.py +9 -6
  394. mindspore/rewrite/parsers/container_parser.py +9 -7
  395. mindspore/rewrite/parsers/for_parser.py +36 -15
  396. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  397. mindspore/rewrite/parsers/if_parser.py +28 -24
  398. mindspore/rewrite/parsers/module_parser.py +196 -25
  399. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  400. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  401. mindspore/rewrite/parsers/return_parser.py +6 -6
  402. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  403. mindspore/rewrite/sparsify/utils.py +1 -1
  404. mindspore/rewrite/symbol_tree.py +525 -577
  405. mindspore/rewrite/symbol_tree_builder.py +9 -193
  406. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  407. mindspore/run_check/_check_version.py +2 -2
  408. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  409. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  410. mindspore/scipy/linalg.py +1 -1
  411. mindspore/scipy/optimize/minimize.py +7 -3
  412. mindspore/train/_utils.py +7 -3
  413. mindspore/train/amp.py +323 -123
  414. mindspore/train/anf_ir_pb2.py +14 -2
  415. mindspore/train/callback/_backup_and_restore.py +2 -12
  416. mindspore/train/callback/_callback.py +29 -4
  417. mindspore/train/callback/_checkpoint.py +23 -8
  418. mindspore/train/callback/_early_stop.py +2 -2
  419. mindspore/train/callback/_landscape.py +4 -4
  420. mindspore/train/callback/_loss_monitor.py +2 -2
  421. mindspore/train/callback/_on_request_exit.py +2 -2
  422. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  423. mindspore/train/callback/_summary_collector.py +14 -7
  424. mindspore/train/callback/_time_monitor.py +58 -5
  425. mindspore/train/data_sink.py +5 -11
  426. mindspore/train/dataset_helper.py +83 -57
  427. mindspore/train/loss_scale_manager.py +2 -2
  428. mindspore/train/metrics/__init__.py +3 -3
  429. mindspore/train/metrics/cosine_similarity.py +1 -1
  430. mindspore/train/metrics/hausdorff_distance.py +3 -2
  431. mindspore/train/metrics/mean_surface_distance.py +3 -2
  432. mindspore/train/metrics/metric.py +39 -19
  433. mindspore/train/metrics/roc.py +2 -2
  434. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  435. mindspore/train/mind_ir_pb2.py +85 -36
  436. mindspore/train/model.py +185 -45
  437. mindspore/train/serialization.py +390 -150
  438. mindspore/train/summary/_writer_pool.py +3 -2
  439. mindspore/train/summary/summary_record.py +14 -10
  440. mindspore/train/train_thor/convert_utils.py +3 -3
  441. mindspore/train/train_thor/dataset_helper.py +1 -1
  442. mindspore/version.py +1 -1
  443. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  444. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +447 -507
  445. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  446. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  447. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  448. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  449. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  450. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  451. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  452. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  453. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  454. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  455. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  456. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  457. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  458. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  459. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  460. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  461. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  462. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  463. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  464. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  465. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  466. mindspore/_extends/graph_kernel/expander.py +0 -80
  467. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  468. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  469. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  470. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  471. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  472. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  473. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  474. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  475. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  476. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  477. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  478. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  479. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  480. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  481. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  482. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  483. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  484. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  485. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  486. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  487. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  488. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  489. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  490. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  491. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  492. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  493. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  494. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  496. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  498. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  499. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  500. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  501. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  502. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  503. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  504. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  505. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  506. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  507. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  508. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  509. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  510. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  511. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  512. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  513. mindspore/dataset/datapreprocess/__init__.py +0 -20
  514. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  515. mindspore/include/api/net.h +0 -142
  516. mindspore/nn/lr_scheduler.py +0 -262
  517. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  518. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  519. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  520. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  521. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  522. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  523. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  524. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  525. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  526. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  529. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  530. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  531. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  532. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  533. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  537. mindspore/rewrite/node_visitor.py +0 -44
  538. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  539. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -21,15 +21,222 @@ import inspect
21
21
  import json
22
22
  import os
23
23
  import functools
24
+ import platform
25
+ import hashlib
26
+ import shutil
24
27
 
25
28
  from mindspore._c_expression import Oplib
26
29
  from mindspore import _checkparam as validator
30
+ from mindspore import log as logger
31
+
32
+ if platform.system() == "Linux":
33
+ import fcntl
27
34
 
28
35
  # path of built-in op info register.
29
36
  BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl"
30
37
  BUILT_IN_CUSTOM_OPS_REGISTER_PATH = "mindspore/ops/_op_impl/_custom_op"
31
38
 
32
39
 
40
+ def _get_reg_info_attr(op_info, attr_name):
41
+ """get attr value"""
42
+ for _, item in enumerate(op_info.get("attr", [])):
43
+ if item.get("name") == attr_name:
44
+ return item.get("defaultValue")
45
+ return None
46
+
47
+
48
+ class _CustomInstaller:
49
+ """save custom op registration information to a json file which will be used by GE"""
50
+ reg_info_hash = [] # used to avoid writing the same reg info to file multiple times
51
+ copied_paths = [] # used to avoid copying the same file multiple times
52
+
53
+ def __init__(self, op_info, func=None):
54
+ self.op_info = op_info
55
+ self.func = func
56
+ self.op_type = op_info.get("op_name") if not func else func.__name__
57
+ vendor_name = "ms"
58
+ custom_dir = os.path.join(os.path.realpath("./"), "vendors", vendor_name)
59
+ self._set_env(custom_dir)
60
+ op_impl_dir = os.path.join(custom_dir, "op_impl")
61
+ self.ai_core_config_dir = os.path.join(op_impl_dir, "ai_core", "tbe", "config")
62
+ self.ai_core_impl_dir = os.path.join(op_impl_dir, "ai_core", "tbe", vendor_name + "_impl")
63
+ self.ai_cpu_config_dir = os.path.join(op_impl_dir, "cpu", "config")
64
+ self.ai_cpu_impl_dir = os.path.join(op_impl_dir, "cpu", "aicpu_kernel", "impl")
65
+
66
+ @staticmethod
67
+ def _set_env(custom_opp_path):
68
+ """set custom file path to env"""
69
+ if not os.environ.get("ASCEND_CUSTOM_OPP_PATH"):
70
+ os.environ["ASCEND_CUSTOM_OPP_PATH"] = custom_opp_path
71
+ else:
72
+ paths = os.environ["ASCEND_CUSTOM_OPP_PATH"].split(':')
73
+ if custom_opp_path not in paths:
74
+ os.environ["ASCEND_CUSTOM_OPP_PATH"] = custom_opp_path + ':' + os.environ["ASCEND_CUSTOM_OPP_PATH"]
75
+
76
+ @staticmethod
77
+ def _create_dir(*dir_names):
78
+ """create directory"""
79
+ for dir_name in dir_names:
80
+ if not os.path.isdir(dir_name):
81
+ try:
82
+ os.makedirs(dir_name, exist_ok=True)
83
+ except OSError as err:
84
+ if err.errno == 17: # File exists
85
+ pass
86
+ else:
87
+ raise err
88
+
89
+ @staticmethod
90
+ def _copy_file(src_path, dst_dir):
91
+ """copy file"""
92
+ if not os.path.exists(src_path) or src_path in _CustomInstaller.copied_paths:
93
+ return
94
+ _CustomInstaller.copied_paths.append(src_path)
95
+ if os.path.isfile(src_path):
96
+ lock_file = os.path.join(dst_dir, "file.lock")
97
+ with open(lock_file, "w") as f:
98
+ fcntl.flock(f.fileno(), fcntl.LOCK_EX)
99
+ shutil.copy(src_path, dst_dir)
100
+
101
+ def _check(self):
102
+ """check if the reg info need written"""
103
+ if platform.system() != "Linux":
104
+ return False
105
+ if not os.environ.get("MS_DEV_CUSTOM_OPP_PATH"):
106
+ # only process the first time import the mindspore module
107
+ return False
108
+ if self.op_info.get("target") in ["GPU", "CPU"]:
109
+ return False
110
+ sha256 = hashlib.sha256()
111
+ value = json.dumps(self.op_info, sort_keys=True).encode()
112
+ sha256.update(value)
113
+ hash_value = sha256.hexdigest()
114
+ if hash_value in _CustomInstaller.reg_info_hash:
115
+ return False
116
+ _CustomInstaller.reg_info_hash.append(hash_value)
117
+ return True
118
+
119
+ def _find_ai_cpu_so_path(self, so_file):
120
+ """find the absolute path of so"""
121
+ current_path = os.path.dirname(os.path.abspath(__file__))
122
+ search_paths = [current_path + "/../lib", current_path + "/../lib/plugin/ascend"]
123
+ for path in search_paths:
124
+ so_path = os.path.join(path, so_file)
125
+ if os.path.exists(so_path):
126
+ return so_path
127
+ logger.warning("For Custom op '{}', can not find the aicpu so file '{}' in the following directories:\n{}"
128
+ .format(self.op_type, so_file, "\n".join(search_paths)))
129
+ return ""
130
+
131
+ def _gen_ai_core_reg_info(self, imply_path, func_name):
132
+ """generate reg info"""
133
+
134
+ def _get_dtype_format(idx):
135
+ data_type = []
136
+ data_format = []
137
+ for _, dtype_format in enumerate(self.op_info.get("dtype_format", [])):
138
+ if not dtype_format[idx][0]:
139
+ data_type = None
140
+ else:
141
+ data_type.append(dtype_format[idx][0])
142
+ if not dtype_format[idx][1]:
143
+ data_format = None
144
+ else:
145
+ if dtype_format[idx][1] == "DefaultFormat":
146
+ data_format.append("ND")
147
+ else:
148
+ data_format.append(dtype_format[idx][1])
149
+ return data_type, data_format
150
+
151
+ op_info = {"opFile": {"value": os.path.splitext(os.path.basename(imply_path))[0]},
152
+ "opInterface": {"value": func_name}}
153
+ # attr
154
+ attrs_name = []
155
+ for _, item in enumerate(self.op_info.get("attr", [])):
156
+ attr_name = item.get("name")
157
+ attrs_name.append(attr_name)
158
+ key = "attr_" + attr_name
159
+ op_info[key] = {}
160
+ for k, v in item.items():
161
+ if k != "name":
162
+ op_info[key][k] = v
163
+ if attrs_name:
164
+ op_info["attr"] = {"list": ",".join(attrs_name)}
165
+ # input and output
166
+ inputs = self.op_info.get("inputs", [])
167
+ outputs = self.op_info.get("outputs", [])
168
+ input_num = len(inputs)
169
+ output_num = len(outputs)
170
+ for i in range(input_num + output_num):
171
+ item = inputs[i] if i < input_num else outputs[i - input_num]
172
+ key = "input" if i < input_num else "output"
173
+ key += str(item.get("index"))
174
+ op_info[key] = {"name": item.get("name"),
175
+ "paramType": item.get("paramType", "required"),
176
+ "shape": item.get("shape", "all")}
177
+ dtype, formats = _get_dtype_format(i)
178
+ if dtype:
179
+ op_info[key]["dtype"] = ",".join(dtype)
180
+ if formats:
181
+ op_info[key]["format"] = ",".join(formats)
182
+ return op_info
183
+
184
+ def _gen_ai_cpu_reg_info(self, so_file):
185
+ """generate reg info"""
186
+ op_info = {"opInfo": {"computeCost": "100",
187
+ "engine": "DNN_VM_AICPU",
188
+ "flagAsync": "False",
189
+ "flagPartial": "False",
190
+ "functionName": "RunCpuKernel",
191
+ "kernelSo": so_file,
192
+ "opKernelLib": "CUSTAICPUKernel",
193
+ "userDefined": "True"}}
194
+ return op_info
195
+
196
+ def _save_op_info(self, dst_dir, file_name, op_info):
197
+ """save op info file"""
198
+ repo = {}
199
+ save_path = os.path.join(dst_dir, file_name)
200
+ lock_file = os.path.join(dst_dir, "file.lock")
201
+ with open(lock_file, "w") as f:
202
+ fcntl.flock(f.fileno(), fcntl.LOCK_EX)
203
+ if os.path.isfile(save_path):
204
+ with open(save_path, 'r') as fr:
205
+ json_str = fr.read()
206
+ json_str = "{}" if json_str == "" else json_str
207
+ repo = json.loads(json_str)
208
+ repo.update({self.op_type: op_info})
209
+ with os.fdopen(os.open(save_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as fw:
210
+ json.dump(repo, fw, sort_keys=True, indent=4, separators=(',', ':'))
211
+
212
+ def run(self):
213
+ """save reg info to file"""
214
+ if not self._check():
215
+ return
216
+ so_name = _get_reg_info_attr(self.op_info, "cust_aicpu")
217
+ if so_name:
218
+ _CustomInstaller._create_dir(self.ai_cpu_config_dir, self.ai_cpu_impl_dir)
219
+ # copy so file
220
+ so_file = "lib" + so_name + ".so"
221
+ imply_path = self._find_ai_cpu_so_path(so_file)
222
+ self._copy_file(imply_path, self.ai_cpu_impl_dir)
223
+ # generate and copy reg info file
224
+ op_info = self._gen_ai_cpu_reg_info(so_file)
225
+ self._save_op_info(self.ai_cpu_config_dir, "cust_aicpu_kernel.json", op_info)
226
+ else:
227
+ _CustomInstaller._create_dir(self.ai_core_config_dir, self.ai_core_impl_dir)
228
+ # copy dsl file
229
+ imply_path = os.path.realpath(inspect.getfile(self.func))
230
+ self._copy_file(imply_path, self.ai_core_impl_dir)
231
+ # generate and copy reg info file
232
+ op_info = self._gen_ai_core_reg_info(imply_path, self.func.__name__)
233
+ self._copy_file(imply_path, self.ai_core_impl_dir)
234
+ for arc_name in ["ascend910", "ascend910b"]:
235
+ arc_dir = os.path.join(self.ai_core_config_dir, arc_name)
236
+ _CustomInstaller._create_dir(arc_dir)
237
+ self._save_op_info(arc_dir, "aic-{}-ops-info.json".format(arc_name), op_info)
238
+
239
+
33
240
  def op_info_register(op_info):
34
241
  r"""
35
242
  A decorator which is used to register an operator.
@@ -125,6 +332,12 @@ def custom_info_register(*reg_info):
125
332
 
126
333
  def decorator(func):
127
334
  setattr(func, "reg_info", reg_info)
335
+ if reg_info:
336
+ used_reg_info = reg_info[0]
337
+ if isinstance(used_reg_info, dict):
338
+ # ai_cpu should be parsed inside CustomRegOp, skip it here
339
+ if not _get_reg_info_attr(used_reg_info, "cust_aicpu"):
340
+ _CustomInstaller(used_reg_info, func).run()
128
341
 
129
342
  @functools.wraps(func)
130
343
  def wrapper(*args, **kwargs):
@@ -140,7 +353,7 @@ class RegOp:
140
353
  Base class for op info register.
141
354
 
142
355
  Args:
143
- op_name (str): Name of op.
356
+ op_name (str): Name of operator.
144
357
  """
145
358
 
146
359
  def __init__(self, op_name=""):
@@ -446,10 +659,10 @@ class AkgCpuRegOp(AkgRegOp):
446
659
 
447
660
  class AiCPURegOp(CpuRegOp):
448
661
  r"""
449
- Class for AiCPU operator information register.
662
+ Class for AiCPU operator information registration.
450
663
 
451
664
  Args:
452
- op_name (str):kernel name.
665
+ op_name (str): Name of operator.
453
666
 
454
667
  Examples:
455
668
  >>> from mindspore.ops import AiCPURegOp, DataType
@@ -481,14 +694,15 @@ class AiCPURegOp(CpuRegOp):
481
694
 
482
695
  class TBERegOp(RegOp):
483
696
  r"""
484
- Class for TBE operator information register.
697
+ Class for TBE operator information registration. TBE (Tensor Boost Engine) is the Ascend operator development
698
+ tool, which is extended on the basis of the TVM framework to develop custom operators.
485
699
 
486
700
  Args:
487
- op_name (str):kernel name.
701
+ op_name (str): Name of operator.
488
702
 
489
703
  Examples:
490
- >>> import mindspore.ops as ops
491
- >>> op_name_op_info = ops.TBERegOp("OpName") \
704
+ >>> from mindspore.ops import TBERegOp, DataType
705
+ >>> op_name_op_info = TBERegOp("OpName") \
492
706
  ... .fusion_type("ELEMWISE") \
493
707
  ... .async_flag(False) \
494
708
  ... .binfile_name("op_name.so") \
@@ -505,14 +719,14 @@ class TBERegOp(RegOp):
505
719
  ... .input(0, "x2", None, "required", None) \
506
720
  ... .input(1, "axis", None, "required", None) \
507
721
  ... .output(0, "y", True, "required", "all") \
508
- ... .real_input_index([1, 0])
509
- ... .input_to_attr_index([2])
510
- ... .unknown_shape_formats(["ND", "ND", "ND", "ND"])
722
+ ... .real_input_index([1, 0]) \
723
+ ... .input_to_attr_index([2]) \
724
+ ... .unknown_shape_formats(["ND", "ND", "ND", "ND"]) \
511
725
  ... .reshape_type("NC") \
512
726
  ... .is_dynamic_format(True) \
513
- ... .dtype_format(DataType.F16_None, DataType.F16_None) \
514
- ... .dtype_format(DataType.F32_None, DataType.F32_None) \
515
- ... .dtype_format(DataType.I32_None, DataType.I32_None) \
727
+ ... .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None) \
728
+ ... .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None) \
729
+ ... .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None, DataType.I32_None) \
516
730
  ... .get_op_info()
517
731
  >>>
518
732
  """
@@ -830,7 +1044,7 @@ class CustomRegOp(RegOp):
830
1044
 
831
1045
  Tutorial Examples:
832
1046
  - `Custom Operators (Custom-based) - Defining Custom Operator of aicpu Type
833
- <https://mindspore.cn/tutorials/experts/en/r2.1/operation/op_custom.html#
1047
+ <https://mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html#
834
1048
  defining-custom-operator-of-aicpu-type>`_
835
1049
  """
836
1050
  param_list = [index, name, param_type]
@@ -870,7 +1084,7 @@ class CustomRegOp(RegOp):
870
1084
 
871
1085
  Tutorial Examples:
872
1086
  - `Custom Operators (Custom-based) - Defining Custom Operator of aicpu Type
873
- <https://mindspore.cn/tutorials/experts/en/r2.1/operation/op_custom.html#
1087
+ <https://mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html#
874
1088
  defining-custom-operator-of-aicpu-type>`_
875
1089
  """
876
1090
  param_list = [index, name, param_type]
@@ -898,7 +1112,7 @@ class CustomRegOp(RegOp):
898
1112
 
899
1113
  Tutorial Examples:
900
1114
  - `Custom Operators (Custom-based) - Defining Custom Operator of aicpu Type
901
- <https://mindspore.cn/tutorials/experts/en/r2.1/operation/op_custom.html#
1115
+ <https://mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html#
902
1116
  defining-custom-operator-of-aicpu-type>`_
903
1117
  """
904
1118
  io_nums = len(self.inputs) + len(self.outputs)
@@ -955,7 +1169,7 @@ class CustomRegOp(RegOp):
955
1169
 
956
1170
  Tutorial Examples:
957
1171
  - `Custom Operators (Custom-based) - Defining Custom Operator of aicpu Type
958
- <https://mindspore.cn/tutorials/experts/en/r2.1/operation/op_custom.html#
1172
+ <https://mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html#
959
1173
  defining-custom-operator-of-aicpu-type>`_
960
1174
  """
961
1175
  param_list = [name, param_type, value_type, default_value]
@@ -981,7 +1195,7 @@ class CustomRegOp(RegOp):
981
1195
 
982
1196
  Tutorial Examples:
983
1197
  - `Custom Operators (Custom-based) - Defining Custom Operator of aicpu Type
984
- <https://mindspore.cn/tutorials/experts/en/r2.1/operation/op_custom.html#
1198
+ <https://mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html#
985
1199
  defining-custom-operator-of-aicpu-type>`_
986
1200
  """
987
1201
  if target is not None:
@@ -996,7 +1210,7 @@ class CustomRegOp(RegOp):
996
1210
 
997
1211
  Tutorial Examples:
998
1212
  - `Custom Operators (Custom-based) - Defining Custom Operator of aicpu Type
999
- <https://mindspore.cn/tutorials/experts/en/r2.1/operation/op_custom.html#
1213
+ <https://mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html#
1000
1214
  defining-custom-operator-of-aicpu-type>`_
1001
1215
  """
1002
1216
  op_info = {}
@@ -1004,6 +1218,8 @@ class CustomRegOp(RegOp):
1004
1218
  if isinstance(k, str) and k.endswith('_'):
1005
1219
  k = k.rstrip('_')
1006
1220
  op_info[k] = v
1221
+ if _get_reg_info_attr(op_info, "cust_aicpu"):
1222
+ _CustomInstaller(op_info).run()
1007
1223
  return op_info
1008
1224
 
1009
1225
 
@@ -21,7 +21,7 @@ A collection of operators to build neural networks or to compute functions.
21
21
 
22
22
  from ._embedding_cache_ops import (CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter,
23
23
  MapUniform, DynamicAssign, PadAndShift)
24
- from ._inner_ops import (MatmulDDS, DSDMatmul, Cummin, ExtractImagePatches)
24
+ from ._inner_ops import (MatmulDDS, DSDMatmul, Cummin, ExtractImagePatches, SelectView, CopyWithSlice)
25
25
  from ._quant_ops import *
26
26
  from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
27
27
  CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
@@ -29,20 +29,20 @@ from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg
29
29
  LoadIm2Col, UpdateThorGradient, CholeskyTrsm,
30
30
  DetTriangle, ProdForceSeA)
31
31
  from ._ms_kernel import (ms_kernel, kernel)
32
- from .array_ops import (ArgMaxWithValue, ArgMinWithValue, Argmax, Argmin, BatchToSpace, BatchToSpaceND,
32
+ from .array_ops import (ArgMaxWithValue, ArgMinWithValue, Argmax, Argmin, BatchToSpace,
33
33
  BatchToSpaceNDV2, BroadcastTo, Cast, Coalesce, Concat, Cummax, DType, DepthToSpace, Diag,
34
- DiagPart, DynamicShape, EditDistance, EmbeddingLookup, ExpandDims, ExtractVolumePatches,
35
- Eye, Fill, Gather, GatherD, GatherNd, GatherV2, Identity, Im2Col, InvertPermutation,
34
+ DiagPart, EditDistance, EmbeddingLookup, ExpandDims, ExtractVolumePatches,
35
+ Eye, Fill, Gather, GatherD, GatherNd, Identity, Im2Col, InvertPermutation,
36
36
  LowerBound, Lstsq, MaskedFill, MaskedSelect, Meshgrid, Mvlgamma, Ones, OnesLike,
37
- Pack, Padding, ParallelConcat, PopulationCount, Range, Rank, Reshape, ResizeNearestNeighbor,
38
- ReverseSequence, ReverseV2, Rint, ScalarToArray, ScalarToTensor, ScatterAdd,
37
+ Padding, ParallelConcat, PopulationCount, Range, Rank, Reshape, ResizeNearestNeighbor,
38
+ ReverseSequence, ReverseV2, Rint, ScalarToTensor, ScatterAdd,
39
39
  ScatterDiv, ScatterMax, ScatterMin, ScatterMul, ScatterNd, ScatterNdAdd, ScatterNdDiv,
40
- ScatterNdMax, ScatterNdMin, ScatterNdSub, ScatterNdUpdate, ScatterNonAliasingAdd, ScatterSub,
40
+ ScatterNdMax, ScatterNdMin, ScatterNdSub, ScatterNdUpdate, ScatterSub,
41
41
  ScatterUpdate, SearchSorted, Select, Shape, Size, Slice, Sort, SpaceToBatch, SpaceToBatchND,
42
42
  SpaceToDepth, SparseGatherV2, Split, SplitV, Squeeze, Stack, StridedSlice, TensorScatterAdd,
43
43
  TensorScatterDiv, TensorScatterMax, TensorScatterMin, TensorScatterMul, TensorScatterSub,
44
44
  TensorScatterUpdate, TensorShape, Tile, TopK, TransShape, Transpose, TupleToArray, Unique,
45
- UniqueWithPad, Unpack, UnsortedSegmentMax, UnsortedSegmentMin, UnsortedSegmentProd,
45
+ UniqueWithPad, UnsortedSegmentMax, UnsortedSegmentMin, UnsortedSegmentProd,
46
46
  UnsortedSegmentSum, Unstack, UpperBound, Zeros, ZerosLike, AffineGrid, Bincount, CheckNumerics,
47
47
  HammingWindow, IdentityN, IndexFill, LeftShift, ListDiff, LogSpace, MatrixBandPart,
48
48
  MatrixDiagPartV3, MatrixDiagV3, MatrixSetDiagV3, NonZero, Expand, Col2Im, ConjugateTranspose,
@@ -69,7 +69,7 @@ from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerA
69
69
  from .linalg_ops import (Svd, Geqrf)
70
70
  from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
71
71
  BitwiseAnd, BitwiseOr, Ger, BitwiseXor, Inv, Invert, ApproximateEqual,
72
- InplaceAdd, InplaceSub, InplaceUpdate, InplaceUpdateV2,
72
+ InplaceAdd, InplaceSub, InplaceUpdateV2,
73
73
  ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cdist, ReduceAny,
74
74
  Cos, Cross, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod,
75
75
  Ceil, Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
@@ -79,7 +79,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
79
79
  NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
80
80
  Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
81
81
  Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Addcdiv,
82
- Addcmul, Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps,
82
+ Addcmul, Square, Sub, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps,
83
83
  Tan, MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc, IsClose, LuSolve,
84
84
  CholeskyInverse, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselK1, BesselK1e, BesselY0,
85
85
  BesselY1, Bucketize, Cauchy, Cholesky, CholeskySolve, Betainc,
@@ -92,14 +92,14 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
92
92
  from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam,
93
93
  ApplyMomentum, BatchNorm, BiasAdd, Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose,
94
94
  DepthwiseConv2dNative,
95
- DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
96
- InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
97
- GeLU, Gelu, FastGeLU, FastGelu, Elu, CeLU,
95
+ Dropout, Dropout2D, Dropout3D, Flatten,
96
+ InstanceNorm,
97
+ GeLU, FastGeLU, Elu, CeLU,
98
98
  GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCLossV2Grad, CTCGreedyDecoder,
99
99
  LogSoftmax, MaxPool3D, AvgPool3D,
100
100
  MaxPool, DataFormatDimMap,
101
101
  AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
102
- MaxPoolWithArgmax, MaxPoolWithArgmaxV2, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2,
102
+ MaxPoolWithArgmaxV2, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2,
103
103
  HSwish, HSigmoid,
104
104
  ResizeBilinear, Sigmoid, SeLU, HShrink, ApplyKerasMomentum,
105
105
  SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
@@ -115,13 +115,13 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
115
115
  ApplyAdamWithAmsgrad, ApplyAdamWithAmsgradV2, AdaptiveAvgPool3D, AdaptiveMaxPool2D,
116
116
  AdaptiveMaxPool3D,
117
117
  GridSampler3D, MaxPool3DWithArgmax, MaxUnpool2D, NuclearNorm, NthElement, MultilabelMarginLoss,
118
- Dilation2D, DataFormatVecPermute, DeformableOffsets, FractionalAvgPool,
118
+ Dilation2D, DataFormatVecPermute, DeformableOffsets, Dense, FractionalAvgPool,
119
119
  FractionalMaxPool, FractionalMaxPool3DWithFixedKsize, FractionalMaxPoolWithFixedKsize,
120
120
  GridSampler2D, TripletMarginLoss, UpsampleNearest3D, UpsampleTrilinear3D, PadV3, ChannelShuffle,
121
121
  GLU, MaxUnpool3D, Pdist)
122
122
  from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
123
123
  ConfusionMatrix, UpdateState, Load, StopGradient,
124
- CheckValid, Partial, Depend, identity, Push, Pull, PyExecute, PyFunc, _DynamicLossScale,
124
+ CheckValid, Partial, Depend, Push, Pull, PyExecute, PyFunc, _DynamicLossScale,
125
125
  SampleDistortedBoundingBoxV2)
126
126
  from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, RandomGamma, Poisson, UniformInt, UniformReal,
127
127
  RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@@ -129,8 +129,13 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, RandomGamm
129
129
  ParameterizedTruncatedNormal, RandomPoisson, MultinomialWithReplacement, RandomShuffle,
130
130
  RandpermV2)
131
131
  from .rl_ops import (BufferAppend, BufferGetItem, BufferSample)
132
- from .sparse_ops import (SparseToDense, SparseTensorDenseMatmul, SparseTensorDenseAdd, SparseSlice)
132
+ from .sparse_ops import (
133
+ SparseToDense, SparseTensorDenseMatmul, SparseTensorDenseAdd, SparseSlice)
133
134
  from .spectral_ops import (BartlettWindow, BlackmanWindow)
135
+ from ..deprecated import (identity, DropoutDoMask, MaxPoolWithArgmax,
136
+ BNTrainingReduce, BNTrainingUpdate, DropoutGenMask, Gelu, FastGelu,
137
+ TensorAdd, InplaceUpdate, ScatterNonAliasingAdd,
138
+ BatchToSpaceND, Unpack, GatherV2, DynamicShape, ScalarToArray, Pack)
134
139
 
135
140
  __all__ = [
136
141
  'HSVToRGB',
@@ -616,7 +621,10 @@ __all__ = [
616
621
  "CumulativeLogsumexp",
617
622
  "DataFormatVecPermute",
618
623
  "DeformableOffsets",
624
+ "Dense",
619
625
  "ExtractImagePatches",
626
+ "SelectView",
627
+ "CopyWithSlice",
620
628
  "FillDiagonal",
621
629
  "Fills",
622
630
  "Gcd",
@@ -390,7 +390,7 @@ class Conv2DBackpropFilter(Primitive):
390
390
  stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1).
391
391
  dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1).
392
392
  group (int): Splits input into groups. Default: 1.
393
- data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW',\
393
+ data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW', \
394
394
  default is 'NCHW'.
395
395
 
396
396
  Returns:
@@ -636,7 +636,7 @@ class EinsumGrad(PrimitiveWithInfer):
636
636
 
637
637
  @prim_attr_register
638
638
  def __init__(self, equation):
639
- self.add_prim_attr('equation', equation)
639
+ pass
640
640
 
641
641
  def infer_shape(self, x_shapes, dout_shape):
642
642
  out_shape = ()
@@ -1521,9 +1521,11 @@ class LSTMGrad(Primitive):
1521
1521
  """Computes the data and weight gradients of LSTM."""
1522
1522
 
1523
1523
  @prim_attr_register
1524
- def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1524
+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout, proj_size=0):
1525
1525
  self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1526
1526
  self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1527
+ self.proj_size = validator.check_int_range(proj_size, 0, hidden_size, validator.INC_LEFT,
1528
+ 'proj_size', self.name)
1527
1529
  self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1528
1530
  self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1529
1531
  self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
@@ -2573,7 +2575,12 @@ class MultilabelMarginLossGrad(Primitive):
2573
2575
  Compute the gradients of MultilabelMarginLoss operation.
2574
2576
 
2575
2577
  Args:
2576
- reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean".
2578
+ reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
2579
+ ``'sum'`` . Default: ``'mean'`` .
2580
+
2581
+ - ``'none'``: no reduction will be applied.
2582
+ - ``'mean'``: compute and return the mean of elements in the output.
2583
+ - ``'sum'``: the output elements will be summed.
2577
2584
 
2578
2585
  Inputs:
2579
2586
  - **y_grad** (Tensor) - The gradients of loss to output of MultilabelMarginLoss function, with
@@ -2595,7 +2602,7 @@ class MultilabelMarginLossGrad(Primitive):
2595
2602
  TypeError: If dtype of `y_grad` is not the same as `x`.
2596
2603
  ValueError: If length of shape of `x` is neither 1 nor 2.
2597
2604
  ValueError: If shape of `x` is not the same as `target`.
2598
- ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
2605
+ ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
2599
2606
  ValueError: If shape of `y_grad` is not the same as forward output `y`.
2600
2607
 
2601
2608
  Supported Platforms:
@@ -2862,7 +2869,9 @@ class Dilation2DBackpropFilter(Primitive):
2862
2869
  self.pad_mode = validator.check_string(self.pad_mode, ["SAME", "VALID", 'same', "valid"], "pad_mode", self.name)
2863
2870
  self.add_prim_attr("pad_mode", self.pad_mode.upper())
2864
2871
  self.stride = _check_format_stride_or_dilation("stride", stride, self.name, self.data_format)
2865
- if self.stride[2] < 1 or self.stride[2] > 255 or self.stride[3] < 1 or self.stride[3] > 255:
2872
+ def is_in_range(x):
2873
+ return 1 <= x <= 255
2874
+ if not is_in_range(self.stride[2]) or not is_in_range(self.stride[3]):
2866
2875
  raise ValueError(f"For '{self.name}', size of stride is not supported, "
2867
2876
  f'stride should be in the range of [1, 255], '
2868
2877
  f'but got stride_h: `{self.stride[2]}`, stride_w: `{self.stride[3]}`.')
@@ -2917,7 +2926,12 @@ class MultiMarginLossGrad(Primitive):
2917
2926
  Args:
2918
2927
  p (int): Optional. The norm degree for pairwise distance.Should be 1 or 2. Default: 1.
2919
2928
  margin (float): Optional. A parameter to change pairwise distance. Default: 1.0.
2920
- reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean".
2929
+ reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
2930
+ ``'sum'`` . Default: ``'mean'`` .
2931
+
2932
+ - ``'none'``: no reduction will be applied.
2933
+ - ``'mean'``: compute and return the weighted mean of elements in the output.
2934
+ - ``'sum'``: the output elements will be summed.
2921
2935
 
2922
2936
  Inputs:
2923
2937
  - **y_grad** (Tensor) - If it's not a scalar, the shape of 'y_grad' :math:`(N, C)`.
@@ -3818,3 +3832,34 @@ class WKVGrad(Primitive):
3818
3832
  """Initialize WKVGrad."""
3819
3833
  self.init_prim_io_names(inputs=["time_first", "time_decay", "key", "value", "gy"],
3820
3834
  outputs=["gw", "gu", "gk", "gv"])
3835
+
3836
+
3837
+ class FlashAttentionScoreGrad(Primitive):
3838
+ r"""
3839
+ Calculates the gradient of FlashAttentionScore operation.
3840
+ .. warning::
3841
+ This is an experimental API that is subject to change or deletion.
3842
+
3843
+ Supported Platforms:
3844
+ ``Ascend``
3845
+ """
3846
+ @prim_attr_register
3847
+ def __init__(self, head_num, keep_prob=1.0, scale_value=1.0, pre_tokens=65536, next_tokens=65536, inner_precise=1,
3848
+ input_layout='BSH'):
3849
+ """Initialize FlashAttentionScoreGrad."""
3850
+ validator.check_value_type('head_num', head_num, [int], self.name)
3851
+ validator.check_value_type('keep_prob', keep_prob, [int, float], self.name)
3852
+ validator.check_float(keep_prob, 0.0, validator.GE, "keep_prob", self.name)
3853
+ validator.check_float(keep_prob, 1.0, validator.LE, "keep_prob", self.name)
3854
+ validator.check_value_type('scale_value', scale_value, [float], self.name)
3855
+ validator.check_value_type('pre_tokens', pre_tokens, [int], self.name)
3856
+ validator.check_value_type('next_tokens', next_tokens, [int], self.name)
3857
+ validator.check_value_type('inner_precise', inner_precise, [int], self.name)
3858
+ if inner_precise not in [0, 1]:
3859
+ raise ValueError(f"Attribute 'inner_precise' must be either 0 or 1, but got {inner_precise}")
3860
+ validator.check_value_type('input_layout', input_layout, [str], self.name)
3861
+ if input_layout not in ["BSH"]:
3862
+ raise ValueError(f"Attribute 'input_layout' must be either 'bsh' or 'sbh', but got {input_layout}")
3863
+ self.init_prim_io_names(inputs=['query', 'key', 'value', 'attn_mask', 'attention_in', 'softmax_max',
3864
+ 'softmax_sum', 'dy', 'drop_mask', 'real_shift', "padding_mask", 'softmax_out'],
3865
+ outputs=['dq', 'dk', 'dv'])