mindspore 2.1.0__cp38-none-any.whl → 2.2.11__cp38-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Note: this version of mindspore has been flagged as a potentially problematic release.

Files changed (578)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  196. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  197. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  198. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  199. mindspore/nn/__init__.py +0 -2
  200. mindspore/nn/cell.py +313 -74
  201. mindspore/nn/dynamic_lr.py +21 -21
  202. mindspore/nn/layer/activation.py +22 -30
  203. mindspore/nn/layer/basic.py +15 -13
  204. mindspore/nn/layer/channel_shuffle.py +1 -1
  205. mindspore/nn/layer/container.py +271 -9
  206. mindspore/nn/layer/conv.py +323 -204
  207. mindspore/nn/layer/dense.py +8 -5
  208. mindspore/nn/layer/embedding.py +33 -27
  209. mindspore/nn/layer/flash_attention.py +61 -95
  210. mindspore/nn/layer/image.py +8 -6
  211. mindspore/nn/layer/math.py +16 -25
  212. mindspore/nn/layer/normalization.py +107 -66
  213. mindspore/nn/layer/padding.py +1 -1
  214. mindspore/nn/layer/pooling.py +131 -109
  215. mindspore/nn/layer/rnn_cells.py +27 -22
  216. mindspore/nn/layer/rnns.py +13 -16
  217. mindspore/nn/layer/thor_layer.py +1 -1
  218. mindspore/nn/layer/transformer.py +221 -154
  219. mindspore/nn/learning_rate_schedule.py +9 -1
  220. mindspore/nn/loss/loss.py +235 -174
  221. mindspore/nn/optim/ada_grad.py +2 -1
  222. mindspore/nn/optim/adadelta.py +1 -0
  223. mindspore/nn/optim/adafactor.py +2 -1
  224. mindspore/nn/optim/adam.py +7 -4
  225. mindspore/nn/optim/adamax.py +3 -2
  226. mindspore/nn/optim/adasum.py +2 -2
  227. mindspore/nn/optim/asgd.py +2 -3
  228. mindspore/nn/optim/ftrl.py +6 -5
  229. mindspore/nn/optim/lamb.py +7 -4
  230. mindspore/nn/optim/lars.py +1 -1
  231. mindspore/nn/optim/lazyadam.py +5 -3
  232. mindspore/nn/optim/momentum.py +2 -1
  233. mindspore/nn/optim/optimizer.py +53 -4
  234. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  235. mindspore/nn/optim/rmsprop.py +4 -3
  236. mindspore/nn/optim/rprop.py +23 -12
  237. mindspore/nn/optim/sgd.py +26 -11
  238. mindspore/nn/optim/thor.py +9 -7
  239. mindspore/nn/probability/bijector/bijector.py +5 -5
  240. mindspore/nn/probability/bijector/power_transform.py +27 -27
  241. mindspore/nn/probability/bijector/softplus.py +3 -3
  242. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  243. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  244. mindspore/nn/probability/distribution/beta.py +3 -3
  245. mindspore/nn/probability/distribution/categorical.py +7 -7
  246. mindspore/nn/probability/distribution/cauchy.py +0 -1
  247. mindspore/nn/probability/distribution/distribution.py +3 -3
  248. mindspore/nn/probability/distribution/gamma.py +3 -3
  249. mindspore/nn/probability/distribution/geometric.py +4 -4
  250. mindspore/nn/probability/distribution/gumbel.py +4 -4
  251. mindspore/nn/probability/distribution/log_normal.py +2 -2
  252. mindspore/nn/probability/distribution/logistic.py +2 -2
  253. mindspore/nn/probability/distribution/poisson.py +4 -4
  254. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  255. mindspore/nn/probability/distribution/uniform.py +6 -6
  256. mindspore/nn/wrap/__init__.py +4 -2
  257. mindspore/nn/wrap/cell_wrapper.py +87 -34
  258. mindspore/nn/wrap/grad_reducer.py +8 -5
  259. mindspore/nn/wrap/loss_scale.py +105 -42
  260. mindspore/numpy/array_creations.py +1 -2
  261. mindspore/numpy/array_ops.py +3 -2
  262. mindspore/numpy/utils_const.py +5 -5
  263. mindspore/offline_debug/convert_async.py +2 -2
  264. mindspore/ops/_grad_experimental/__init__.py +0 -5
  265. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  266. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  267. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  268. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  269. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  270. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  271. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  272. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  273. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  274. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  275. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  276. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  277. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  278. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  279. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  280. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  281. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  282. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  283. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  284. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  285. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  286. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  287. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  288. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  289. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  290. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  291. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  292. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  293. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  294. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  295. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  296. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  297. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  298. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  299. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  300. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  301. mindspore/ops/_primitive_cache.py +1 -1
  302. mindspore/ops/_tracefunc.py +45 -13
  303. mindspore/ops/_utils/utils.py +6 -1
  304. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  305. mindspore/ops/_vmap/vmap_base.py +3 -3
  306. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  307. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  308. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  309. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  310. mindspore/ops/arg_dtype_cast.py +54 -0
  311. mindspore/ops/composite/base.py +37 -10
  312. mindspore/ops/composite/math_ops.py +5 -4
  313. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  314. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  315. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  316. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  317. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  318. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  320. mindspore/ops/deprecated.py +304 -0
  321. mindspore/ops/function/__init__.py +4 -1
  322. mindspore/ops/function/array_func.py +174 -193
  323. mindspore/ops/function/clip_func.py +81 -13
  324. mindspore/ops/function/debug_func.py +1 -1
  325. mindspore/ops/function/grad/grad_func.py +18 -9
  326. mindspore/ops/function/image_func.py +10 -4
  327. mindspore/ops/function/linalg_func.py +5 -5
  328. mindspore/ops/function/math_func.py +575 -386
  329. mindspore/ops/function/nn_func.py +568 -260
  330. mindspore/ops/function/random_func.py +88 -57
  331. mindspore/ops/function/sparse_func.py +1 -1
  332. mindspore/ops/function/sparse_unary_func.py +14 -12
  333. mindspore/ops/function/vmap_func.py +6 -5
  334. mindspore/ops/functional.py +15 -10
  335. mindspore/ops/op_info_register.py +244 -25
  336. mindspore/ops/operations/__init__.py +31 -19
  337. mindspore/ops/operations/_grad_ops.py +71 -7
  338. mindspore/ops/operations/_inner_ops.py +350 -17
  339. mindspore/ops/operations/_quant_ops.py +4 -8
  340. mindspore/ops/operations/_sequence_ops.py +42 -0
  341. mindspore/ops/operations/array_ops.py +68 -282
  342. mindspore/ops/operations/comm_ops.py +107 -59
  343. mindspore/ops/operations/custom_ops.py +94 -70
  344. mindspore/ops/operations/debug_ops.py +8 -4
  345. mindspore/ops/operations/image_ops.py +18 -12
  346. mindspore/ops/operations/inner_ops.py +26 -3
  347. mindspore/ops/operations/math_ops.py +192 -144
  348. mindspore/ops/operations/nn_ops.py +857 -489
  349. mindspore/ops/operations/other_ops.py +0 -22
  350. mindspore/ops/operations/random_ops.py +53 -111
  351. mindspore/ops/operations/sparse_ops.py +3 -1
  352. mindspore/ops/primitive.py +24 -18
  353. mindspore/parallel/_auto_parallel_context.py +68 -8
  354. mindspore/parallel/_cost_model_context.py +2 -2
  355. mindspore/parallel/_offload_context.py +17 -3
  356. mindspore/parallel/_parallel_serialization.py +12 -5
  357. mindspore/parallel/_ps_context.py +12 -0
  358. mindspore/parallel/_tensor.py +18 -13
  359. mindspore/parallel/_transformer/layers.py +5 -3
  360. mindspore/parallel/_transformer/loss.py +1 -0
  361. mindspore/parallel/_transformer/moe.py +2 -2
  362. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  363. mindspore/parallel/_transformer/transformer.py +23 -3
  364. mindspore/parallel/_utils.py +11 -7
  365. mindspore/parallel/algo_parameter_config.py +85 -5
  366. mindspore/parallel/checkpoint_transform.py +19 -12
  367. mindspore/parallel/shard.py +21 -14
  368. mindspore/profiler/common/struct_type.py +3 -3
  369. mindspore/profiler/common/util.py +4 -2
  370. mindspore/profiler/envprofiling.py +1 -1
  371. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  372. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  373. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  374. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  375. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  376. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  377. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  378. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  379. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  380. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  381. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  382. mindspore/profiler/parser/flops_parser.py +15 -11
  383. mindspore/profiler/parser/framework_parser.py +38 -22
  384. mindspore/profiler/parser/hccl_parser.py +16 -12
  385. mindspore/profiler/parser/integrator.py +22 -11
  386. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  387. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  388. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  389. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  390. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  391. mindspore/profiler/parser/optime_parser.py +1 -1
  392. mindspore/profiler/parser/profiler_info.py +21 -2
  393. mindspore/profiler/parser/step_trace_parser.py +11 -14
  394. mindspore/profiler/profiling.py +179 -89
  395. mindspore/rewrite/api/node.py +102 -19
  396. mindspore/rewrite/api/node_type.py +5 -1
  397. mindspore/rewrite/api/pattern_engine.py +1 -1
  398. mindspore/rewrite/api/scoped_value.py +9 -17
  399. mindspore/rewrite/api/symbol_tree.py +131 -47
  400. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  401. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  402. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  403. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  404. mindspore/rewrite/common/rewrite_elog.py +5 -1
  405. mindspore/rewrite/namer.py +33 -24
  406. mindspore/rewrite/namespace.py +14 -5
  407. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  408. mindspore/rewrite/node/call_function.py +79 -0
  409. mindspore/rewrite/node/cell_container.py +135 -0
  410. mindspore/rewrite/node/control_flow.py +88 -0
  411. mindspore/rewrite/{node.py → node/node.py} +273 -234
  412. mindspore/rewrite/node/node_manager.py +254 -0
  413. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  414. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  415. mindspore/rewrite/parsers/assign_parser.py +216 -221
  416. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  417. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  418. mindspore/rewrite/parsers/constant_parser.py +9 -6
  419. mindspore/rewrite/parsers/container_parser.py +9 -7
  420. mindspore/rewrite/parsers/for_parser.py +42 -21
  421. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  422. mindspore/rewrite/parsers/if_parser.py +28 -24
  423. mindspore/rewrite/parsers/module_parser.py +196 -25
  424. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  425. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  426. mindspore/rewrite/parsers/return_parser.py +6 -6
  427. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  428. mindspore/rewrite/sparsify/utils.py +1 -1
  429. mindspore/rewrite/symbol_tree.py +523 -578
  430. mindspore/rewrite/symbol_tree_builder.py +9 -193
  431. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  432. mindspore/run_check/_check_version.py +6 -4
  433. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  434. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  435. mindspore/scipy/linalg.py +1 -1
  436. mindspore/scipy/ops.py +55 -5
  437. mindspore/scipy/optimize/__init__.py +3 -2
  438. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  439. mindspore/scipy/optimize/minimize.py +7 -3
  440. mindspore/train/_utils.py +7 -3
  441. mindspore/train/amp.py +323 -123
  442. mindspore/train/anf_ir_pb2.py +14 -2
  443. mindspore/train/callback/_backup_and_restore.py +2 -12
  444. mindspore/train/callback/_callback.py +29 -4
  445. mindspore/train/callback/_checkpoint.py +23 -8
  446. mindspore/train/callback/_early_stop.py +2 -2
  447. mindspore/train/callback/_landscape.py +4 -4
  448. mindspore/train/callback/_loss_monitor.py +2 -2
  449. mindspore/train/callback/_on_request_exit.py +2 -2
  450. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  451. mindspore/train/callback/_summary_collector.py +15 -8
  452. mindspore/train/callback/_time_monitor.py +58 -5
  453. mindspore/train/data_sink.py +5 -11
  454. mindspore/train/dataset_helper.py +84 -57
  455. mindspore/train/loss_scale_manager.py +2 -2
  456. mindspore/train/metrics/__init__.py +3 -3
  457. mindspore/train/metrics/cosine_similarity.py +1 -1
  458. mindspore/train/metrics/hausdorff_distance.py +3 -2
  459. mindspore/train/metrics/mean_surface_distance.py +3 -2
  460. mindspore/train/metrics/metric.py +39 -19
  461. mindspore/train/metrics/roc.py +2 -2
  462. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  463. mindspore/train/mind_ir_pb2.py +85 -36
  464. mindspore/train/model.py +187 -47
  465. mindspore/train/serialization.py +487 -161
  466. mindspore/train/summary/_summary_adapter.py +1 -1
  467. mindspore/train/summary/_writer_pool.py +3 -2
  468. mindspore/train/summary/summary_record.py +37 -17
  469. mindspore/train/train_thor/convert_utils.py +3 -3
  470. mindspore/train/train_thor/dataset_helper.py +1 -1
  471. mindspore/version.py +1 -1
  472. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +477 -528
  474. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  475. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  476. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  477. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  478. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  479. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  480. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  481. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  482. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  483. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  484. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  485. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  486. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  487. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  488. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  489. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  490. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  491. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  492. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  493. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  494. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  495. mindspore/_extends/graph_kernel/expander.py +0 -80
  496. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  497. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  498. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  499. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  500. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  501. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  502. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  503. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  504. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  505. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  506. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  507. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  508. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  509. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  510. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  511. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  512. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  513. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  514. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  515. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  516. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  517. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  518. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  519. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  520. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  521. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  522. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  523. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  524. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  525. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  526. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  527. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  528. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  529. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  531. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  532. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  533. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  534. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  535. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  536. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  537. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  538. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  539. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  540. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  541. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  542. mindspore/dataset/datapreprocess/__init__.py +0 -20
  543. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  544. mindspore/include/api/net.h +0 -142
  545. mindspore/nn/lr_scheduler.py +0 -262
  546. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  547. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  548. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  549. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  550. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  551. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  552. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  553. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  554. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  555. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  556. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  557. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  558. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  559. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  560. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  564. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  575. mindspore/rewrite/node_visitor.py +0 -44
  576. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  577. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  578. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -390,7 +390,7 @@ class Conv2DBackpropFilter(Primitive):
  stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1).
  dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1).
  group (int): Splits input into groups. Default: 1.
- data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW',\
+ data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW', \
  default is 'NCHW'.

  Returns:
@@ -636,7 +636,7 @@ class EinsumGrad(PrimitiveWithInfer):

  @prim_attr_register
  def __init__(self, equation):
- self.add_prim_attr('equation', equation)
+ pass

  def infer_shape(self, x_shapes, dout_shape):
  out_shape = ()
@@ -1521,9 +1521,11 @@ class LSTMGrad(Primitive):
  """Computes the data and weight gradients of LSTM."""

  @prim_attr_register
- def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout, proj_size=0):
  self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
  self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
+ self.proj_size = validator.check_int_range(proj_size, 0, hidden_size, validator.INC_LEFT,
+ 'proj_size', self.name)
  self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
  self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
  self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
@@ -2573,7 +2575,12 @@ class MultilabelMarginLossGrad(Primitive):
  Compute the gradients of MultilabelMarginLoss operation.

  Args:
- reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean".
+ reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
+ ``'sum'`` . Default: ``'mean'`` .
+
+ - ``'none'``: no reduction will be applied.
+ - ``'mean'``: compute and return the mean of elements in the output.
+ - ``'sum'``: the output elements will be summed.

  Inputs:
  - **y_grad** (Tensor) - The gradients of loss to output of MultilabelMarginLoss function, with
@@ -2595,7 +2602,7 @@ class MultilabelMarginLossGrad(Primitive):
  TypeError: If dtype of `y_grad` is not the same as `x`.
  ValueError: If length of shape of `x` is neither 1 nor 2.
  ValueError: If shape of `x` is not the same as `target`.
- ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
+ ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
  ValueError: If shape of `y_grad` is not the same as forward output `y`.

  Supported Platforms:
@@ -2862,7 +2869,9 @@ class Dilation2DBackpropFilter(Primitive):
  self.pad_mode = validator.check_string(self.pad_mode, ["SAME", "VALID", 'same', "valid"], "pad_mode", self.name)
  self.add_prim_attr("pad_mode", self.pad_mode.upper())
  self.stride = _check_format_stride_or_dilation("stride", stride, self.name, self.data_format)
- if self.stride[2] < 1 or self.stride[2] > 255 or self.stride[3] < 1 or self.stride[3] > 255:
+ def is_in_range(x):
+ return 1 <= x <= 255
+ if not is_in_range(self.stride[2]) or not is_in_range(self.stride[3]):
  raise ValueError(f"For '{self.name}', size of stride is not supported, "
  f'stride should be in the range of [1, 255], '
  f'but got stride_h: `{self.stride[2]}`, stride_w: `{self.stride[3]}`.')
@@ -2917,7 +2926,12 @@ class MultiMarginLossGrad(Primitive):
  Args:
  p (int): Optional. The norm degree for pairwise distance.Should be 1 or 2. Default: 1.
  margin (float): Optional. A parameter to change pairwise distance. Default: 1.0.
- reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean".
+ reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
+ ``'sum'`` . Default: ``'mean'`` .
+
+ - ``'none'``: no reduction will be applied.
+ - ``'mean'``: compute and return the weighted mean of elements in the output.
+ - ``'sum'``: the output elements will be summed.

  Inputs:
  - **y_grad** (Tensor) - If it's not a scalar, the shape of 'y_grad' :math:`(N, C)`.
@@ -3818,3 +3832,53 @@ class WKVGrad(Primitive):
  """Initialize WKVGrad."""
  self.init_prim_io_names(inputs=["time_first", "time_decay", "key", "value", "gy"],
  outputs=["gw", "gu", "gk", "gv"])
+
+
+ class FlashAttentionScoreGrad(Primitive):
+ r"""
+ Calculates the gradient of FlashAttentionScore operation.
+ .. warning::
+ This is an experimental API that is subject to change or deletion.
+
+ Supported Platforms:
+ ``Ascend``
+ """
+ @prim_attr_register
+ def __init__(self, head_num, keep_prob=1.0, scale_value=1.0, pre_tokens=65536, next_tokens=65536, inner_precise=1,
+ input_layout='BSH', sparse_mode=0):
+ """Initialize FlashAttentionScoreGrad."""
+ validator.check_value_type('head_num', head_num, [int], self.name)
+ validator.check_value_type('keep_prob', keep_prob, [int, float], self.name)
+ validator.check_float(keep_prob, 0.0, validator.GE, "keep_prob", self.name)
+ validator.check_float(keep_prob, 1.0, validator.LE, "keep_prob", self.name)
+ validator.check_value_type('scale_value', scale_value, [float], self.name)
+ validator.check_value_type('pre_tokens', pre_tokens, [int], self.name)
+ validator.check_value_type('next_tokens', next_tokens, [int], self.name)
+ validator.check_value_type('inner_precise', inner_precise, [int], self.name)
+ validator.check_value_type('sparse_mode', sparse_mode, [int], self.name)
+ if inner_precise not in [0, 1]:
+ raise ValueError(f"Attribute 'inner_precise' must be either 0 or 1, but got {inner_precise}")
+ validator.check_value_type('input_layout', input_layout, [str], self.name)
+ if input_layout not in ["BSH", "BNSD"]:
+ raise ValueError(f"Attribute 'input_layout' must be either 'BSH' or 'BNSD', but got {input_layout}")
+ self.init_prim_io_names(inputs=['query', 'key', 'value', 'dy', 'pse_shift', 'drop_mask', "padding_mask",
+ 'attn_mask', 'softmax_max', 'softmax_sum', 'softmax_out', 'attention_in',
+ 'prefix'],
+ outputs=['dq', 'dk', 'dv', 'dpse'])
+
+
+ class RmsNormGrad(Primitive):
+ r"""
+ Calculates the gradient of RmsNorm operation.
+ .. warning::
+ This is an experimental API that is subject to change or deletion.
+
+ Supported Platforms:
+ ``Ascend``
+ """
+
+ @prim_attr_register
+ def __init__(self):
+ """Initialize RmsNormGrad."""
+ self.init_prim_io_names(inputs=["dy", "x", "rstd", "gamma"],
+ outputs=["dx", "dgamma"])
@@ -23,16 +23,17 @@ from mindspore.common._stub_tensor import StubTensor
  from mindspore.ops import composite as C
  from mindspore.ops.operations.array_ops import Cast
  from mindspore.ops.operations._scalar_ops import bit_or, bit_and
+ from mindspore.ops.operations.comm_ops import ReduceOp
  from mindspore.ops import signature as sig
  from mindspore.ops.operations.math_ops import _infer_shape_reduce
- from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, _run_op
- from mindspore import context
+ from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
+ _run_op, _check_contains_variable
  from mindspore._c_expression import Tensor as Tensor_
  from mindspore._c_expression import typing
  from mindspore import _checkparam as validator
  from mindspore.common import dtype as mstype
  from mindspore.common.parameter import Parameter
- from mindspore.communication.management import GlobalComm
+ from mindspore.communication.management import GlobalComm, get_rank
  from mindspore.common.api import _pynative_executor
  from mindspore.common._register_for_adapter import ms_adapter_registry
  from mindspore import ops
@@ -74,11 +75,11 @@ class ExtractImagePatches(Primitive):
  - valid: Means that the taken patch area must be completely covered in the original image.

  Inputs:
- - **input_x** (Tensor) - A 4-D tensor whose shape is :math:`(in_batch, in_depth, in_row, in_col)`.
+ - **input_x** (Tensor) - A 4-D tensor whose shape is :math:`(in\_batch, in\_depth, in\_row, in\_col)`.

  Outputs:
  Tensor, a 4-D tensor whose data type is same as 'input_x', and the shape
- is :math:`(out_batch, out_depth, out_row, out_col)`,where the out_batch is the same as the in_batch
+ is :math:`(out\_batch, out\_depth, out\_row, out\_col)`,where the out_batch is the same as the in_batch
  and

  .. math::
@@ -121,7 +122,6 @@ class ExtractImagePatches(Primitive):
  validator.check_value_type('padding', padding, [str], self.name)
  self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
  self.add_prim_attr("padding", self.padding)
- self.is_ge = context.get_context("enable_ge")


  class Quant(PrimitiveWithInfer):
@@ -167,6 +167,7 @@ class Quant(PrimitiveWithInfer):
  self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
  self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
  "round_mode", self.name)
+ self.add_prim_attr("dst_type", mstype.int8)

  def infer_shape(self, x_shape):
  return x_shape
@@ -174,7 +175,7 @@ class Quant(PrimitiveWithInfer):
  def infer_dtype(self, x_type):
  validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
  validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
- return mstype.int8
+ return self.get_attr_dict()['dst_type']


  class Lamb(PrimitiveWithInfer):
@@ -491,7 +492,7 @@ class Receive(PrimitiveWithInfer):
  self.dtype = dtype
  self.group = group
  self.add_prim_attr("no_eliminate", True)
- valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
+ valid_type = [mstype.float16, mstype.bfloat16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
  args = {"dtype": dtype}
  validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)

@@ -502,6 +503,109 @@ class Receive(PrimitiveWithInfer):
  return self.get_attr_dict()['dtype']


+ class Reduce(PrimitiveWithInfer):
+ """
+ Reduces tensor across the processes in the specified communication group.
+
+ Note:
+ Only process with destination rank receives the reduced output.
+ Other processes only get a tensor with shape [1], which has no mathematical meaning.
+
+ Args:
+ dest_rank (int): Specifies the rank of the process that receives the reduced output.
+ op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
+ On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
+ group (str, optional): The communication group to work on.
+ Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+
+ Inputs:
+ - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+ Examples:
+ >>> import mindspore.ops as ops
+ >>> import mindspore.nn as nn
+ >>> from mindspore.communication import init
+ >>> from mindspore import Tensor
+ >>> import numpy as np
+ >>> # Launch 4 processes.
+ >>> init()
+ >>> class ReduceNet(nn.Cell):
+ >>> def __init__(self):
+ >>> super(Net, self).__init__()
+ >>> self.reduce = ops.Reduce(dest_rank=1)
+ >>>
+ >>> def construct(self, x):
+ >>> out = self.reduce(x)
+ >>> return out
+ >>> input = Tensor(np.ones([2, 8]).astype(np.float32))
+ >>> net = ReduceNet()
+ >>> output = net(input)
+ >>> print(output)
+ Process with rank 1: [[4. 4. 4. 4. 4. 4. 4. 4.]
+ [4. 4. 4. 4. 4. 4. 4. 4.]],
+ Other proesses: [0.].
+ """
+
+ @prim_attr_register
+ def __init__(self, dest_rank, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
+ self.dest_rank = dest_rank
+ self.op = op
+ self.group = group
+
+ def infer_shape(self, x_shape):
+ # The process with dest_rank returns the reduced output.
+ # Other processes only gets a tensor with shape [1], which has no mathematical meaning.
+ if self.dest_rank == get_rank():
+ return x_shape
+ return [1]
+
+ def infer_dtype(self, x_dtype):
+ return x_dtype
+
+
+ class Barrier(PrimitiveWithInfer):
+ """
+ Synchronizes all processes in the specified group.
+
+ Note:
+ After calling this collective operator,
+ this process will be blocked until all other processes in the group call this operator.
+
+ Args:
+ group (str, optional): The communication group to work on.
+ Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+
+ Examples:
+ >>> import mindspore.ops as ops
+ >>> import mindspore.nn as nn
+ >>> from mindspore.communication import init
+ >>> from mindspore import Tensor
+ >>> import numpy as np
+ >>> # Launch 4 processes.
+ >>> init()
+ >>> class BarrierNet(nn.Cell):
+ >>> def __init__(self):
+ >>> super(Net, self).__init__()
+ >>> self.barrier = ops.Barrier()
+ >>>
+ >>> def construct(self):
+ >>> self.barrier()
+ >>> net = BarrierNet()
+ >>> net()
+ """
+
+ @prim_attr_register
+ def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
+ self.group = group
+ self.add_prim_attr("side_effect_mem", True)
+
+ def infer_shape(self):
+ return [1]
+
+ def infer_dtype(self):
+ return mstype.float32
+
+
  class MatrixSetDiag(PrimitiveWithInfer):
  r"""
  Modifies the batched diagonal part of a batched tensor.
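A side note on the Reduce and Barrier docstrings added above: their Examples call super(Net, self).__init__() inside classes named ReduceNet and BarrierNet. A minimal corrected sketch of the Reduce example, assuming four processes have been launched and distributed communication is available, could look like this:

    # Illustrative sketch only; mirrors the docstring example above with the super() call fixed.
    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops as ops
    from mindspore import Tensor
    from mindspore.communication import init

    init()  # assumes the job was launched with 4 processes

    class ReduceNet(nn.Cell):
        def __init__(self):
            super(ReduceNet, self).__init__()
            self.reduce = ops.Reduce(dest_rank=1)

        def construct(self, x):
            return self.reduce(x)

    net = ReduceNet()
    output = net(Tensor(np.ones([2, 8]).astype(np.float32)))
    # Rank 1 receives the summed tensor; every other rank gets a shape-[1] placeholder.
    print(output)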
@@ -1843,16 +1947,32 @@ class Format(PrimitiveWithInfer):
  def __init__(self):
  self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])

+
  def __infer__(self, str_, *var):
- str_value = str_["value"]
+ def check_variable(str_, var):
+ if _check_contains_variable(str_['dtype'], str_['value']):
+ return True
+
+ for item in var:
+ if _check_contains_variable(item['dtype'], item['value']):
+ return True
+ return False
+
+
+ if check_variable(str_, var):
+ return {'dtype': mstype.string, 'shape': [], 'value': None}
+
+
+ str_value = str_['value']
+ kwargs = dict()
  var_value = list()
- if str_value is None and str_["dtype"] is not None:
- raise ValueError("str.format not support to input a variable.")
+
  for item in var:
- if item["value"] is None and item["dtype"] is not None:
- raise ValueError("str.format not support to input a variable.")
+ if isinstance(item["dtype"], typing.Keyword):
+ kwargs.update(item["value"])
  var_value.append(item["value"])
- value = str_value.format(*var_value)
+
+ value = str_value.format(*var_value, **kwargs)
  return {'dtype': mstype.string, 'shape': [], 'value': value}


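In plain Python terms, the rewritten Format.__infer__ above now forwards constant positional and keyword arguments to str.format, and only falls back to an unknown value when any input is still a variable. A minimal sketch of the constant path (ordinary Python, not MindSpore-specific):

    # What the constant branch effectively evaluates: both *args and **kwargs reach str.format.
    template = "{} + {} = {result}"
    args = [1, 2]
    kwargs = {"result": 3}
    print(template.format(*args, **kwargs))  # -> "1 + 2 = 3"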
@@ -2027,13 +2147,14 @@ class ClipByNorm(PrimitiveWithInfer):
  @prim_attr_register
  def __init__(self, axis=None):
  """Initialize ClipByNorm"""
+ self.axis_str = 'axis'
  self.axis = () if axis is None else axis
- validator.check_value_type('axis', self.axis, [int, tuple, list], self.name)
+ validator.check_value_type(self.axis_str, self.axis, [int, tuple, list], self.name)
  axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
  for i, value in enumerate(axis_check):
  validator.check_value_type('axis[%d]' % i, value, [int], self.name)
- self.init_attrs['axis'] = self.axis
- self.add_prim_attr('axis', self.axis)
+ self.init_attrs[self.axis_str] = self.axis
+ self.add_prim_attr(self.axis_str, self.axis)
  self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])

  def infer_shape(self, x_shape, clip_norm_shape):
@@ -2588,3 +2709,215 @@ class IsConstant(Primitive):

  def __call__(self, x):
  return True
+
+
+ class SelectView(Primitive):
+ r"""
+ Select tensor of view
+ """
+
+ @prim_attr_register
+ def __init__(self):
+ self.init_prim_io_names(inputs=['input_tensor', 'input_indices', 'axis'], outputs=['output'])
+
+
+ class CopyWithSlice(Primitive):
+ r"""
+ Copy data to discontinuous tensor
+ """
+ @prim_attr_register
+ def __init__(self):
+ self.add_prim_attr('side_effect_mem', True)
+ self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])
+
+
+ class FFN(Primitive):
+ r"""
+ The FFN computation is similar to Feed-Forward Network, it contains matmul + gelu + matmul.
+
+ Args:
+ activation (string): The activation type, set to 'fastgelu' or 'gelu'.
+ Only support 'fastgelu' for now. Default: "fastgelu".
+ inner_precise (int): The precise mode, set to 0 for high precision or 1 for high performance.
+ Only support 1 for now. Default: 0.
+
+ Inputs:
+ - **x** (Tensor) - The input tensor with data type of int8, float16.
+ Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
+ - **weight1** (Tensor) - The weight1 tensor with data type of float16.
+ Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
+ - **weight2** (Tensor) - The weight2 tensor with data type of float16.
+ Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
+ - **expert_tokens** (Tensor]) - The expert tokens tensor with data type of int64.
+ Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, .., 9)`
+ indicate that the 0th expert deals with 2 tokens, the 1th expert deals with 1 tokens,
+ the 2th expert do noting and so on.
+ - **bias1** (Tensor) - The bias1 tensor with data type of float16.
+ Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
+ - **bias2** (Tensor) - The bias2 tensor with data type of float16.
+ Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
+ - **scale** (Tensor) - The scale tensor with data type of float16. Not enable now.
+ - **offset** (Tensor) - The offset tensor with data type of float16. Not enable now.
+ - **deq_scale1** (Tensor) - The deq_scale1 tensor with data type of float16. Not enable now.
+ - **deq_scale2** (Tensor) - The deq_scale2 tensor with data type of float16. Not enable now.
+
+ Outputs:
+ Tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`. With data type of float16.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> from mindspore.ops.operations import _inner_ops
+ >>> b = 4
+ >>> s = 128
+ >>> h = 1024
+ >>> h_f = 4 * h
+ >>> e = 16
+ >>> x = Tensor(np.random.randn(b * s, h).astype(np.float16))
+ >>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
+ >>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
+ >>> expert_tokens = Tensor(np.random.randn(e).astype(np.int64))
+ >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
+ >>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
+ >>> ffn = _inner_ops.FFN("fastgelu", 1)
+ >>> output = ffn(x, w1, w2, expert_tokens, bias1, bias2)
+ >>> print(output)
+ """
+
+ @prim_attr_register
+ def __init__(self, activation, inner_precise):
+ """Initialize FFN."""
+ self.init_prim_io_names(inputs=["x", "weight1", "weight2", "expert_tokens", "bias1",
+ "bias2", "scale", "offset", "deq_scale1", "deq_scale2"],
+ outputs=["y"])
+ cls_name = self.name
+ validator.check_value_type("activation", activation, [str], cls_name)
+ validator.check_value_type("inner_precise", inner_precise, [int], cls_name)
+
+
+ class DecoderKVCache(Primitive):
2800
+ r"""
2801
+ The DecoderKVCache is used for decoding the KVCache of transformer network.
2802
+
2803
+ Args:
2804
+ cache (Tensor): The cahe tensor with data type of int8, uint8, int16, uint16, float16, float32 and int32.
2805
+ When seq_len_axis is 2, cache tensor of shape
2806
+ :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)`.
2807
+ When seq_len_axis is 1, cache tensor of shape
2808
+ :math:`(batch\_size, max\_seq\_length, num_head, hidden\_size)`.
2809
+ update (Tensor]): The tensor which is used to update the cache tensor. Same data type as cache tensor.
2810
+ When seq_len_axis is 2, update tensor of shape
2811
+ :math:`(batch\_size, num_head, update\_seq\_length, hidden\_size)`.
2812
+ When seq_len_axis is 1, update tensor of shape
2813
+ :math:`(batch\_size, update\_seq\_length, num_head, hidden\_size)`.
2814
+ valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
2815
+ Valid_seq_len tensor of shape :math:`(batch\_size)`.
2816
+ batch_index (Tensor): The batch_index tensor with data type of int64.
2817
+ Batch_index tensor of shape :math:`(1)`. Indicates which batch of the cache tensor is going to be updated.
2818
+ seq_len_axis (int64): The seq_len_axis indicates which axis is the sequence length; set to 1 or 2. Default: 2.
2819
+ new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
2820
+ New_max_seq_len tensor of shape :math:`(1)`.
2821
+ Indicates that the user wants to change the shape of the cache tensor from
2822
+ :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)` to
2823
+ :math:`(batch\_size * max\_seq\_length / new\_max\_seq\_length, num_head, new\_max\_seq\_length, hidden\_size)`
2824
+ to update the cache tensor. This does not actually change the shape of the `cache` tensor. Not available for now.
2825
+ cur_max_seq_len (Tensor): The cur_max_seq_len tensor with data type of int64.
2826
+ Cur_max_seq_len tensor of shape :math:`(1)`. Keeps the current seq_len of the cache tensor. Not available for now.
2828
+
2829
+ Outputs:
2830
+ A Tensor with the same data type and shape as the `cache` tensor.
2831
+
2832
+ Supported Platforms:
2833
+ ``Ascend``
2834
+
2835
+ Examples:
2836
+ >>> import numpy as np
+ >>> from mindspore import Tensor
+ >>> from mindspore.ops.operations import _inner_ops
2837
+ >>> b = 4
2838
+ >>> h = 40
2839
+ >>> max_s = 1024
2840
+ >>> s = 1
2841
+ >>> d = 128
2842
+ >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
2843
+ >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
2844
+ >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
2845
+ >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
2846
+ >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
2847
+ >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
2848
+ >>> decoder_kv_cache = _inner_ops.DecoderKVCache()
2849
+ >>> output = decoder_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
2850
+ >>> print(cache)
2851
+ """
2852
+ @prim_attr_register
2853
+ def __init__(self):
2854
+ """Initialize DecoderKVCache."""
2855
+ self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
2856
+ "new_max_seq_len", "cur_max_seq_len"],
2857
+ outputs=["out"])
2858
+ self.add_prim_attr('side_effect_mem', True)
2859
+
2860
+
2861
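The decode-phase behaviour can be pictured with a short NumPy sketch: for each batch, the `update` slice is written into `cache` starting at that batch's `valid_seq_len` position along the sequence axis. This is only a sketch, assuming `seq_len_axis == 2`; `batch_index`, `new_max_seq_len` and `cur_max_seq_len` are ignored here, and the real primitive updates `cache` in place on Ascend.

```python
import numpy as np

def decoder_kv_cache_reference(cache, update, valid_seq_len):
    """Sketch for seq_len_axis == 2: cache is (batch, num_head, max_seq, hidden)."""
    update_len = update.shape[2]                 # usually 1 token per decode step
    for b in range(cache.shape[0]):
        pos = int(valid_seq_len[b])              # next write position for this batch
        cache[b, :, pos:pos + update_len, :] = update[b]
    return cache                                  # mirrors the primitive's in-place update
```
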
+ class PromptKVCache(Primitive):
2862
+ r"""
2863
+ The PromptKVCache is used to prefill the KVCache of a transformer network.
2864
+
2865
+ Args:
2866
+ cache (Tensor): The cache tensor with data type of int8, uint8, int16, uint16, float16, float32 and int32.
2867
+ When seq_len_axis is 2, cache tensor of shape
2868
+ :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)`.
2869
+ When seq_len_axis is 1, cache tensor of shape
2870
+ :math:`(batch\_size, max\_seq\_length, num_head, hidden\_size)`.
2871
+ update (Tensor): The tensor which is used to update the cache tensor. Same data type as the cache tensor.
2872
+ When seq_len_axis is 2, update tensor of shape
2873
+ :math:`(batch\_size, num_head, update\_seq\_length, hidden\_size)`.
2874
+ When seq_len_axis is 1, update tensor of shape
2875
+ :math:`(batch\_size, update\_seq\_length, num_head, hidden\_size)`.
2876
+ valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
2877
+ Valid_seq_len tensor of shape :math:`(batch\_size)`.
2878
+ batch_index (Tensor): The batch_index tensor with data type of int64.
2879
+ Batch_index tensor of shape :math:`(1)`. Indicates which batch of the cache tensor is going to be updated.
2880
+ seq_len_axis (int64): The seq_len_axis indicates which axis is the sequence length; set to 1 or 2. Default: 2.
2881
+ new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
2882
+ New_max_seq_len tensor of shape :math:`(1)`.
2883
+ Indicates that the user wants to change the shape of the cache tensor from
2884
+ :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)` to
2885
+ :math:`(batch\_size * max\_seq\_length / new\_max\_seq\_length, num_head, new\_max\_seq\_length, hidden\_size)`
2886
+ to update the cache tensor. This does not actually change the shape of the `cache` tensor. Not available for now.
2887
+ cur_max_seq_len (Tensor): The cur_max_seq_len tensor with data type of int64.
2888
+ Cur_max_seq_len tensor of shape :math:`(1)`. Keeps the current seq_len of the cache tensor. Not available for now.
2890
+ align_mode (int64): The alignment mode of the valid tokens in the update tensor, 0 is 'right', 1 is 'left'. Default: 0.
2891
+
2892
+ Outputs:
2893
+ A Tensor with the same data type and shape as the `cache` tensor.
2894
+
2895
+ Supported Platforms:
2896
+ ``Ascend``
2897
+
2898
+ Examples:
2899
+ >>> import numpy as np
+ >>> from mindspore import Tensor
2900
+ >>> from mindspore.ops.operations import _inner_ops
2901
+ >>> b = 4
2902
+ >>> h = 40
2903
+ >>> max_s = 1024
2904
+ >>> s = 256
2905
+ >>> d = 128
2906
+ >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
2907
+ >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
2908
+ >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
2909
+ >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
2910
+ >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
2911
+ >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
2912
+ >>> prompt_kv_cache = _inner_ops.PromptKVCache(0)
2913
+ >>> output = prompt_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
2914
+ >>> print(cache)
2915
+ """
2916
+ @prim_attr_register
2917
+ def __init__(self, padding_mode="right"):
2918
+ """Initialize PromptKVCache."""
2919
+ self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
2920
+ "new_max_seq_len", "cur_max_seq_len"],
2921
+ outputs=["out"])
2922
+ self.add_prim_attr('side_effect_mem', True)
2923
+ self.padding_mode = padding_mode
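By contrast, the prefill path writes a whole (padded) prompt at once. The sketch below assumes that with `align_mode = 0` ('right' padding) the valid tokens sit at the front of `update`, with `align_mode = 1` ('left' padding) they sit at the back, and that each batch's prompt is written to the front of its cache; the handling of `batch_index` and the seq-length tensors is again omitted, so treat this purely as an illustration of the documented shapes.

```python
import numpy as np

def prompt_kv_cache_reference(cache, update, valid_seq_len, align_mode=0):
    """Sketch for seq_len_axis == 2: prefill each batch's cache with its prompt."""
    padded_len = update.shape[2]                 # padded prompt length
    for b in range(cache.shape[0]):
        n = int(valid_seq_len[b])                # number of valid prompt tokens
        if align_mode == 0:                      # 'right' padding: valid tokens come first
            valid = update[b, :, :n, :]
        else:                                    # 'left' padding: valid tokens come last
            valid = update[b, :, padded_len - n:, :]
        cache[b, :, :n, :] = valid
    return cache
```
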
@@ -269,7 +269,7 @@ class FakeLearnedScaleQuantPerLayer(PrimitiveWithInfer):
269
269
  - **quant_max** (Tensor) : Value of the quantization range.
270
270
 
271
271
  Outputs:
272
- - Tensor: Simulates quantize tensor of `input_x`,with the same type and shape as the `input_x`.
272
+ - Tensor: Simulates quantize tensor of `input_x`, with the same type and shape as the `input_x`.
273
273
 
274
274
  Examples:
275
275
  >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
@@ -419,7 +419,7 @@ class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer):
419
419
  - **quant_max** (Tensor) : Value of the quantization range.
420
420
 
421
421
  Outputs:
422
- - Tensor: Simulates quantize tensor of `input_x`,with the same type and shape as the `input_x`.
422
+ - Tensor: Simulates quantize tensor of `input_x`, with the same type and shape as the `input_x`.
423
423
 
424
424
  Examples:
425
425
  >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
@@ -975,7 +975,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
975
975
  >>> result = fake_quant(input_x, _min, _max)
976
976
  """
977
977
  support_quant_bit = [4, 7, 8]
978
- ascend_support_x_rank = [2, 4]
978
+ ascend_support_x_rank = [2, 3, 4]
979
979
 
980
980
  @prim_attr_register
981
981
  def __init__(self,
@@ -1008,11 +1008,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
1008
1008
  self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
1009
1009
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
1010
1010
  self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
1011
- if self.is_ascend:
1012
- self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
1013
- 'channel_axis', self.name)
1014
- else:
1015
- self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
1011
+ self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
1016
1012
  self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
1017
1013
 
1018
1014
  def infer_shape(self, x_shape, min_shape, max_shape):