mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of mindspore has been flagged as potentially problematic.

Files changed (577)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +13 -6
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -0
  7. mindspore/_checkparam.py +3 -38
  8. mindspore/_deprecated/__init__.py +17 -0
  9. mindspore/_deprecated/jit.py +198 -0
  10. mindspore/_extends/builtin_operations.py +1 -1
  11. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  12. mindspore/_extends/parse/__init__.py +6 -7
  13. mindspore/_extends/parse/compile_config.py +83 -0
  14. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +46 -197
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +217 -98
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +11 -5
  27. mindspore/avcodec-59.dll +0 -0
  28. mindspore/avdevice-59.dll +0 -0
  29. mindspore/avfilter-8.dll +0 -0
  30. mindspore/avformat-59.dll +0 -0
  31. mindspore/avutil-57.dll +0 -0
  32. mindspore/boost/__init__.py +2 -2
  33. mindspore/boost/base.py +3 -7
  34. mindspore/boost/boost_cell_wrapper.py +138 -43
  35. mindspore/common/__init__.py +6 -3
  36. mindspore/common/_grad_function.py +56 -0
  37. mindspore/common/_pijit_context.py +14 -5
  38. mindspore/common/_register_for_tensor.py +1 -2
  39. mindspore/common/_stub_tensor.py +30 -14
  40. mindspore/common/_tensor_cpp_method.py +17 -0
  41. mindspore/common/_tensor_docs.py +4760 -0
  42. mindspore/common/api.py +435 -371
  43. mindspore/common/auto_dynamic_shape.py +41 -44
  44. mindspore/common/dtype.py +39 -36
  45. mindspore/common/dump.py +9 -6
  46. mindspore/common/file_system.py +9 -1
  47. mindspore/common/generator.py +2 -0
  48. mindspore/common/hook_handle.py +6 -2
  49. mindspore/common/initializer.py +13 -10
  50. mindspore/common/jit_begin_end.py +94 -0
  51. mindspore/common/jit_config.py +6 -1
  52. mindspore/common/jit_context.py +76 -0
  53. mindspore/common/jit_trace.py +378 -0
  54. mindspore/common/lazy_inline.py +9 -3
  55. mindspore/common/mindir_util.py +10 -2
  56. mindspore/common/mutable.py +5 -4
  57. mindspore/common/parameter.py +135 -52
  58. mindspore/common/seed.py +2 -2
  59. mindspore/common/sparse_tensor.py +23 -17
  60. mindspore/common/tensor.py +951 -1992
  61. mindspore/communication/__init__.py +7 -5
  62. mindspore/communication/_comm_helper.py +52 -2
  63. mindspore/communication/comm_func.py +240 -181
  64. mindspore/communication/management.py +95 -26
  65. mindspore/context.py +314 -566
  66. mindspore/dataset/__init__.py +65 -37
  67. mindspore/dataset/audio/__init__.py +2 -8
  68. mindspore/dataset/audio/transforms.py +3 -17
  69. mindspore/dataset/callback/ds_callback.py +2 -1
  70. mindspore/dataset/core/config.py +87 -6
  71. mindspore/dataset/engine/cache_admin.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +6 -5
  73. mindspore/dataset/engine/datasets.py +292 -267
  74. mindspore/dataset/engine/datasets_audio.py +22 -8
  75. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  76. mindspore/dataset/engine/datasets_text.py +78 -48
  77. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  78. mindspore/dataset/engine/datasets_vision.py +120 -44
  79. mindspore/dataset/engine/iterators.py +283 -63
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/obs/util.py +8 -0
  82. mindspore/dataset/engine/queue.py +40 -0
  83. mindspore/dataset/engine/samplers.py +289 -43
  84. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  85. mindspore/dataset/engine/validators.py +53 -11
  86. mindspore/dataset/text/__init__.py +7 -6
  87. mindspore/dataset/text/transforms.py +6 -5
  88. mindspore/dataset/text/utils.py +3 -3
  89. mindspore/dataset/transforms/__init__.py +0 -9
  90. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  91. mindspore/dataset/transforms/transforms.py +31 -14
  92. mindspore/dataset/utils/browse_dataset.py +1 -1
  93. mindspore/dataset/vision/__init__.py +2 -9
  94. mindspore/dataset/vision/transforms.py +202 -158
  95. mindspore/dataset/vision/utils.py +7 -5
  96. mindspore/dataset/vision/validators.py +1 -2
  97. mindspore/device_context/__init__.py +21 -0
  98. mindspore/device_context/ascend/__init__.py +25 -0
  99. mindspore/device_context/ascend/device.py +72 -0
  100. mindspore/device_context/ascend/op_debug.py +153 -0
  101. mindspore/device_context/ascend/op_precision.py +193 -0
  102. mindspore/device_context/ascend/op_tuning.py +123 -0
  103. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  104. mindspore/device_context/cpu/device.py +62 -0
  105. mindspore/device_context/cpu/op_tuning.py +43 -0
  106. mindspore/device_context/gpu/__init__.py +21 -0
  107. mindspore/device_context/gpu/device.py +70 -0
  108. mindspore/device_context/gpu/op_precision.py +67 -0
  109. mindspore/device_context/gpu/op_tuning.py +175 -0
  110. mindspore/device_manager.py +170 -0
  111. mindspore/experimental/es/embedding_service.py +35 -27
  112. mindspore/experimental/llm_boost/__init__.py +1 -0
  113. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  114. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  115. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  116. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  117. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  118. mindspore/experimental/llm_boost/register.py +1 -0
  119. mindspore/experimental/map_parameter.py +4 -4
  120. mindspore/experimental/optim/adadelta.py +6 -6
  121. mindspore/experimental/optim/adagrad.py +4 -4
  122. mindspore/experimental/optim/adam.py +7 -0
  123. mindspore/experimental/optim/adamax.py +4 -4
  124. mindspore/experimental/optim/adamw.py +4 -0
  125. mindspore/experimental/optim/asgd.py +1 -1
  126. mindspore/experimental/optim/lr_scheduler.py +73 -46
  127. mindspore/experimental/optim/radam.py +34 -31
  128. mindspore/experimental/optim/rprop.py +1 -1
  129. mindspore/experimental/optim/sgd.py +1 -1
  130. mindspore/hal/contiguous_tensors_handle.py +6 -10
  131. mindspore/hal/device.py +55 -53
  132. mindspore/hal/event.py +52 -52
  133. mindspore/hal/memory.py +157 -117
  134. mindspore/hal/stream.py +150 -109
  135. mindspore/include/api/context.h +0 -1
  136. mindspore/include/dataset/constants.h +7 -4
  137. mindspore/include/dataset/execute.h +2 -2
  138. mindspore/jpeg62.dll +0 -0
  139. mindspore/log.py +50 -0
  140. mindspore/mindrecord/__init__.py +21 -8
  141. mindspore/mindrecord/config.py +17 -316
  142. mindspore/mindrecord/filereader.py +1 -9
  143. mindspore/mindrecord/filewriter.py +5 -15
  144. mindspore/mindrecord/mindpage.py +1 -9
  145. mindspore/mindspore_backend_common.dll +0 -0
  146. mindspore/mindspore_backend_manager.dll +0 -0
  147. mindspore/mindspore_common.dll +0 -0
  148. mindspore/mindspore_core.dll +0 -0
  149. mindspore/mindspore_dump.dll +0 -0
  150. mindspore/mindspore_frontend.dll +0 -0
  151. mindspore/mindspore_memory_pool.dll +0 -0
  152. mindspore/mindspore_ms_backend.dll +0 -0
  153. mindspore/mindspore_ops.dll +0 -0
  154. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  155. mindspore/mindspore_ops_kernel_common.dll +0 -0
  156. mindspore/mindspore_profiler.dll +0 -0
  157. mindspore/mindspore_pyboost.dll +0 -0
  158. mindspore/mindspore_pynative.dll +0 -0
  159. mindspore/mindspore_res_manager.dll +0 -0
  160. mindspore/mindspore_runtime_pipeline.dll +0 -0
  161. mindspore/mint/__init__.py +796 -759
  162. mindspore/mint/distributed/__init__.py +70 -4
  163. mindspore/mint/distributed/distributed.py +2679 -44
  164. mindspore/mint/linalg/__init__.py +8 -0
  165. mindspore/mint/nn/__init__.py +743 -22
  166. mindspore/mint/nn/functional.py +716 -23
  167. mindspore/mint/nn/layer/__init__.py +21 -4
  168. mindspore/mint/nn/layer/_functions.py +334 -0
  169. mindspore/mint/nn/layer/activation.py +276 -1
  170. mindspore/mint/nn/layer/basic.py +123 -0
  171. mindspore/mint/nn/layer/conv.py +921 -0
  172. mindspore/mint/nn/layer/normalization.py +223 -28
  173. mindspore/mint/nn/layer/padding.py +797 -0
  174. mindspore/mint/nn/layer/pooling.py +235 -0
  175. mindspore/mint/optim/__init__.py +3 -1
  176. mindspore/mint/optim/adam.py +223 -0
  177. mindspore/mint/optim/adamw.py +26 -19
  178. mindspore/mint/optim/sgd.py +171 -0
  179. mindspore/mint/special/__init__.py +2 -1
  180. mindspore/multiprocessing/__init__.py +5 -0
  181. mindspore/nn/__init__.py +4 -1
  182. mindspore/nn/cell.py +1370 -189
  183. mindspore/nn/dynamic_lr.py +2 -1
  184. mindspore/nn/layer/activation.py +29 -27
  185. mindspore/nn/layer/basic.py +51 -35
  186. mindspore/nn/layer/channel_shuffle.py +3 -3
  187. mindspore/nn/layer/container.py +1 -1
  188. mindspore/nn/layer/conv.py +22 -17
  189. mindspore/nn/layer/embedding.py +12 -11
  190. mindspore/nn/layer/normalization.py +56 -49
  191. mindspore/nn/layer/padding.py +4 -3
  192. mindspore/nn/layer/pooling.py +120 -42
  193. mindspore/nn/layer/rnn_cells.py +1 -1
  194. mindspore/nn/layer/rnns.py +2 -1
  195. mindspore/nn/layer/timedistributed.py +5 -5
  196. mindspore/nn/layer/transformer.py +59 -36
  197. mindspore/nn/learning_rate_schedule.py +8 -4
  198. mindspore/nn/loss/loss.py +58 -55
  199. mindspore/nn/optim/ada_grad.py +7 -5
  200. mindspore/nn/optim/adadelta.py +11 -9
  201. mindspore/nn/optim/adafactor.py +1 -1
  202. mindspore/nn/optim/adam.py +17 -13
  203. mindspore/nn/optim/adamax.py +8 -7
  204. mindspore/nn/optim/adasum.py +5 -5
  205. mindspore/nn/optim/asgd.py +1 -1
  206. mindspore/nn/optim/ftrl.py +11 -9
  207. mindspore/nn/optim/lamb.py +1 -1
  208. mindspore/nn/optim/lars.py +1 -4
  209. mindspore/nn/optim/lazyadam.py +12 -10
  210. mindspore/nn/optim/momentum.py +7 -6
  211. mindspore/nn/optim/optimizer.py +3 -3
  212. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  213. mindspore/nn/optim/rmsprop.py +13 -12
  214. mindspore/nn/optim/rprop.py +11 -9
  215. mindspore/nn/optim/sgd.py +9 -6
  216. mindspore/nn/optim/tft_wrapper.py +5 -2
  217. mindspore/nn/optim/thor.py +2 -1
  218. mindspore/nn/probability/bijector/bijector.py +17 -11
  219. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  220. mindspore/nn/probability/bijector/invert.py +2 -2
  221. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  222. mindspore/nn/probability/bijector/softplus.py +3 -2
  223. mindspore/nn/probability/distribution/beta.py +3 -3
  224. mindspore/nn/probability/distribution/categorical.py +1 -1
  225. mindspore/nn/probability/distribution/cauchy.py +4 -2
  226. mindspore/nn/probability/distribution/exponential.py +6 -7
  227. mindspore/nn/probability/distribution/gamma.py +2 -2
  228. mindspore/nn/probability/distribution/gumbel.py +2 -2
  229. mindspore/nn/probability/distribution/half_normal.py +5 -3
  230. mindspore/nn/probability/distribution/logistic.py +5 -3
  231. mindspore/nn/probability/distribution/poisson.py +1 -1
  232. mindspore/nn/probability/distribution/uniform.py +5 -3
  233. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  234. mindspore/nn/reinforcement/tensor_array.py +1 -1
  235. mindspore/nn/utils/init.py +13 -11
  236. mindspore/nn/wrap/__init__.py +6 -6
  237. mindspore/nn/wrap/cell_wrapper.py +181 -122
  238. mindspore/nn/wrap/grad_reducer.py +45 -36
  239. mindspore/nn/wrap/loss_scale.py +6 -7
  240. mindspore/numpy/array_creations.py +63 -65
  241. mindspore/numpy/array_ops.py +149 -144
  242. mindspore/numpy/logic_ops.py +41 -42
  243. mindspore/numpy/math_ops.py +365 -363
  244. mindspore/numpy/utils.py +17 -18
  245. mindspore/numpy/utils_const.py +5 -6
  246. mindspore/opencv_core452.dll +0 -0
  247. mindspore/opencv_imgcodecs452.dll +0 -0
  248. mindspore/opencv_imgproc452.dll +0 -0
  249. mindspore/ops/__init__.py +5 -3
  250. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  251. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  252. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  253. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  254. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  255. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  256. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  257. mindspore/ops/_register_for_op.py +0 -11
  258. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  259. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  260. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  261. mindspore/ops/_vmap/vmap_base.py +0 -2
  262. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  263. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  264. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  265. mindspore/ops/auto_generate/__init__.py +4 -3
  266. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  267. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  268. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  269. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  270. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  271. mindspore/ops/composite/__init__.py +2 -1
  272. mindspore/ops/composite/base.py +20 -25
  273. mindspore/ops/composite/math_ops.py +6 -16
  274. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  275. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  276. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  277. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  278. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  279. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  280. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  281. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  282. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  283. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  284. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  285. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  286. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  287. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  288. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  289. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  290. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  291. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  292. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  293. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  294. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  295. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  297. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  301. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  302. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  304. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  305. mindspore/ops/function/__init__.py +40 -2
  306. mindspore/ops/function/_add_attr_func.py +58 -0
  307. mindspore/ops/function/array_func.py +2089 -2403
  308. mindspore/ops/function/clip_func.py +80 -23
  309. mindspore/ops/function/debug_func.py +57 -57
  310. mindspore/ops/function/grad/__init__.py +1 -0
  311. mindspore/ops/function/grad/grad_func.py +104 -71
  312. mindspore/ops/function/image_func.py +2 -2
  313. mindspore/ops/function/linalg_func.py +47 -78
  314. mindspore/ops/function/math_func.py +4501 -3802
  315. mindspore/ops/function/nn_func.py +1726 -620
  316. mindspore/ops/function/other_func.py +159 -1
  317. mindspore/ops/function/parameter_func.py +18 -84
  318. mindspore/ops/function/random_func.py +440 -387
  319. mindspore/ops/function/reshard_func.py +4 -70
  320. mindspore/ops/function/sparse_func.py +3 -3
  321. mindspore/ops/function/sparse_unary_func.py +6 -6
  322. mindspore/ops/function/spectral_func.py +25 -58
  323. mindspore/ops/function/vmap_func.py +24 -17
  324. mindspore/ops/functional.py +22 -7
  325. mindspore/ops/functional_overload.py +1440 -0
  326. mindspore/ops/op_info_register.py +32 -244
  327. mindspore/ops/operations/__init__.py +13 -7
  328. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  329. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  330. mindspore/ops/operations/_grad_ops.py +2 -43
  331. mindspore/ops/operations/_infer_ops.py +2 -1
  332. mindspore/ops/operations/_inner_ops.py +43 -84
  333. mindspore/ops/operations/_ms_kernel.py +4 -10
  334. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  335. mindspore/ops/operations/_scalar_ops.py +3 -2
  336. mindspore/ops/operations/_sequence_ops.py +1 -1
  337. mindspore/ops/operations/_tensor_array.py +1 -1
  338. mindspore/ops/operations/array_ops.py +81 -324
  339. mindspore/ops/operations/comm_ops.py +154 -108
  340. mindspore/ops/operations/custom_ops.py +232 -78
  341. mindspore/ops/operations/debug_ops.py +153 -59
  342. mindspore/ops/operations/inner_ops.py +7 -5
  343. mindspore/ops/operations/linalg_ops.py +1 -57
  344. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  345. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  346. mindspore/ops/operations/math_ops.py +32 -234
  347. mindspore/ops/operations/nn_ops.py +210 -498
  348. mindspore/ops/operations/other_ops.py +62 -9
  349. mindspore/ops/operations/random_ops.py +13 -7
  350. mindspore/ops/operations/reshard_ops.py +1 -1
  351. mindspore/ops/operations/sparse_ops.py +2 -2
  352. mindspore/ops/primitive.py +66 -53
  353. mindspore/ops/tensor_method.py +1888 -0
  354. mindspore/ops_generate/__init__.py +0 -5
  355. mindspore/ops_generate/aclnn/__init__.py +0 -0
  356. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  357. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  358. mindspore/ops_generate/api/__init__.py +0 -0
  359. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  360. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  361. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  362. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  363. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  364. mindspore/ops_generate/api/gen_api.py +103 -0
  365. mindspore/ops_generate/api/op_api_proto.py +235 -0
  366. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  367. mindspore/ops_generate/common/__init__.py +0 -0
  368. mindspore/ops_generate/common/base_generator.py +11 -0
  369. mindspore/ops_generate/common/gen_constants.py +91 -0
  370. mindspore/ops_generate/common/gen_utils.py +348 -0
  371. mindspore/ops_generate/common/op_proto.py +473 -0
  372. mindspore/ops_generate/common/template.py +523 -0
  373. mindspore/ops_generate/gen_ops.py +22 -1069
  374. mindspore/ops_generate/op_def/__init__.py +0 -0
  375. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  376. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  377. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  378. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  379. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  380. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  381. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  382. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  383. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  384. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  385. mindspore/ops_generate/pyboost/__init__.py +0 -0
  386. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  387. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  388. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  389. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  390. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  391. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  392. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  393. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  394. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  395. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  396. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  397. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  398. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  399. mindspore/ops_generate/resources/__init__.py +0 -0
  400. mindspore/ops_generate/resources/resource_list.py +30 -0
  401. mindspore/ops_generate/resources/resource_loader.py +36 -0
  402. mindspore/ops_generate/resources/resource_manager.py +64 -0
  403. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  404. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  405. mindspore/parallel/__init__.py +7 -3
  406. mindspore/parallel/_auto_parallel_context.py +152 -34
  407. mindspore/parallel/_cell_wrapper.py +130 -15
  408. mindspore/parallel/_parallel_serialization.py +107 -5
  409. mindspore/parallel/_ps_context.py +1 -1
  410. mindspore/parallel/_recovery_context.py +7 -2
  411. mindspore/parallel/_tensor.py +142 -18
  412. mindspore/parallel/_utils.py +199 -23
  413. mindspore/parallel/algo_parameter_config.py +4 -4
  414. mindspore/parallel/auto_parallel.py +732 -0
  415. mindspore/parallel/checkpoint_convert.py +159 -0
  416. mindspore/parallel/checkpoint_transform.py +698 -35
  417. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  418. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  419. mindspore/parallel/cluster/run.py +21 -4
  420. mindspore/parallel/function/__init__.py +24 -0
  421. mindspore/parallel/function/reshard_func.py +259 -0
  422. mindspore/parallel/nn/__init__.py +25 -0
  423. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  424. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  425. mindspore/parallel/parameter_broadcast.py +25 -14
  426. mindspore/parallel/shard.py +137 -58
  427. mindspore/parallel/transform_safetensors.py +363 -305
  428. mindspore/profiler/__init__.py +22 -5
  429. mindspore/profiler/analysis/__init__.py +0 -0
  430. mindspore/profiler/analysis/parser/__init__.py +0 -0
  431. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  432. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  433. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  434. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  435. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  436. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  437. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  438. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  439. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  440. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  441. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  442. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  443. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  444. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  445. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  446. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  447. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  448. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  449. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  450. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  451. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  452. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  453. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  454. mindspore/profiler/analysis/task_manager.py +131 -0
  455. mindspore/profiler/analysis/time_converter.py +84 -0
  456. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  457. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  458. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  459. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  460. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  461. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  462. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  463. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  464. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  465. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  466. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  467. mindspore/profiler/analysis/work_flow.py +73 -0
  468. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  469. mindspore/profiler/common/command_executor.py +90 -0
  470. mindspore/profiler/common/constant.py +186 -3
  471. mindspore/profiler/common/file_manager.py +208 -0
  472. mindspore/profiler/common/log.py +130 -0
  473. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  474. mindspore/profiler/common/path_manager.py +395 -0
  475. mindspore/profiler/common/process_bar.py +168 -0
  476. mindspore/profiler/common/process_pool.py +9 -3
  477. mindspore/profiler/common/profiler_context.py +500 -0
  478. mindspore/profiler/common/profiler_info.py +304 -0
  479. mindspore/profiler/common/profiler_meta_data.py +74 -0
  480. mindspore/profiler/common/profiler_output_path.py +284 -0
  481. mindspore/profiler/common/profiler_parameters.py +251 -0
  482. mindspore/profiler/common/profiler_path_manager.py +179 -0
  483. mindspore/profiler/common/record_function.py +76 -0
  484. mindspore/profiler/common/tlv_decoder.py +76 -0
  485. mindspore/profiler/common/util.py +75 -2
  486. mindspore/profiler/dynamic_profiler.py +341 -75
  487. mindspore/profiler/envprofiler.py +163 -0
  488. mindspore/profiler/experimental_config.py +197 -0
  489. mindspore/profiler/mstx.py +242 -0
  490. mindspore/profiler/platform/__init__.py +21 -0
  491. mindspore/profiler/platform/base_profiler.py +40 -0
  492. mindspore/profiler/platform/cpu_profiler.py +124 -0
  493. mindspore/profiler/platform/gpu_profiler.py +74 -0
  494. mindspore/profiler/platform/npu_profiler.py +335 -0
  495. mindspore/profiler/profiler.py +1073 -90
  496. mindspore/profiler/profiler_action_controller.py +187 -0
  497. mindspore/profiler/profiler_interface.py +118 -0
  498. mindspore/profiler/schedule.py +243 -0
  499. mindspore/rewrite/api/node.py +15 -13
  500. mindspore/rewrite/api/symbol_tree.py +2 -3
  501. mindspore/run_check/_check_version.py +27 -20
  502. mindspore/run_check/run_check.py +1 -1
  503. mindspore/runtime/__init__.py +37 -0
  504. mindspore/runtime/device.py +27 -0
  505. mindspore/runtime/event.py +209 -0
  506. mindspore/runtime/executor.py +177 -0
  507. mindspore/runtime/memory.py +409 -0
  508. mindspore/runtime/stream.py +460 -0
  509. mindspore/runtime/thread_bind_core.py +401 -0
  510. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  511. mindspore/swresample-4.dll +0 -0
  512. mindspore/swscale-6.dll +0 -0
  513. mindspore/tinyxml2.dll +0 -0
  514. mindspore/train/__init__.py +8 -8
  515. mindspore/train/_utils.py +88 -25
  516. mindspore/train/amp.py +9 -5
  517. mindspore/train/callback/__init__.py +2 -2
  518. mindspore/train/callback/_callback.py +2 -16
  519. mindspore/train/callback/_checkpoint.py +53 -55
  520. mindspore/train/callback/_cluster_monitor.py +14 -18
  521. mindspore/train/callback/_early_stop.py +1 -1
  522. mindspore/train/callback/_flops_collector.py +103 -68
  523. mindspore/train/callback/_history.py +8 -5
  524. mindspore/train/callback/_lambda_callback.py +2 -2
  525. mindspore/train/callback/_landscape.py +0 -3
  526. mindspore/train/callback/_loss_monitor.py +2 -1
  527. mindspore/train/callback/_on_request_exit.py +6 -5
  528. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  529. mindspore/train/callback/_summary_collector.py +52 -19
  530. mindspore/train/callback/_time_monitor.py +2 -1
  531. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  532. mindspore/train/data_sink.py +25 -2
  533. mindspore/train/dataset_helper.py +15 -16
  534. mindspore/train/loss_scale_manager.py +8 -7
  535. mindspore/train/metrics/accuracy.py +3 -3
  536. mindspore/train/metrics/confusion_matrix.py +9 -9
  537. mindspore/train/metrics/error.py +3 -3
  538. mindspore/train/metrics/hausdorff_distance.py +4 -4
  539. mindspore/train/metrics/mean_surface_distance.py +3 -3
  540. mindspore/train/metrics/metric.py +0 -12
  541. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  542. mindspore/train/metrics/precision.py +11 -10
  543. mindspore/train/metrics/recall.py +9 -9
  544. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  545. mindspore/train/mind_ir_pb2.py +174 -46
  546. mindspore/train/model.py +184 -113
  547. mindspore/train/serialization.py +622 -978
  548. mindspore/train/summary/_summary_adapter.py +2 -2
  549. mindspore/train/summary/summary_record.py +2 -3
  550. mindspore/train/train_thor/model_thor.py +1 -1
  551. mindspore/turbojpeg.dll +0 -0
  552. mindspore/utils/__init__.py +6 -3
  553. mindspore/utils/dryrun.py +140 -0
  554. mindspore/utils/hooks.py +81 -0
  555. mindspore/utils/runtime_execution_order_check.py +550 -0
  556. mindspore/utils/utils.py +138 -4
  557. mindspore/version.py +1 -1
  558. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  559. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +562 -393
  560. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  561. mindspore/_install_custom.py +0 -43
  562. mindspore/common/_register_for_adapter.py +0 -74
  563. mindspore/common/_tensor_overload.py +0 -139
  564. mindspore/mindspore_np_dtype.dll +0 -0
  565. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  566. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  567. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  568. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  569. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  570. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  571. mindspore/ops_generate/gen_utils.py +0 -209
  572. mindspore/ops_generate/op_proto.py +0 -145
  573. mindspore/ops_generate/template.py +0 -261
  574. mindspore/profiler/envprofiling.py +0 -254
  575. mindspore/profiler/profiling.py +0 -1926
  576. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  577. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
- # Copyright 2023 Huawei Technologies Co., Ltd
+ # Copyright 2023-2025 Huawei Technologies Co., Ltd
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -23,20 +23,21 @@ import numpy as np
  from mindspore.ops import signature as sig
  from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register, PrimitiveWithInfer
  from mindspore.ops._primitive_cache import _get_cache_prim
- from mindspore.ops.auto_generate import gen_arg_handler as handler
+ from mindspore.ops._utils import arg_handler as handler
+ from mindspore.ops._utils.arg_dtype_cast import DtypeToEnum
  from mindspore.common import Tensor, CSRTensor, COOTensor
  from mindspore.common._stub_tensor import _convert_stub
  from mindspore._c_expression import typing
- from mindspore._c_expression import Tensor as Tensor_
- from mindspore._c_expression import pyboost_cast, pyboost_tile, pyboost_zeros, pyboost_ones
+ from mindspore._c_expression import TensorPy as Tensor_
+ from mindspore._c_expression import pyboost_cast, pyboost_tile, pyboost_zeros, pyboost_ones, pyboost_type_as
  from mindspore.common import dtype as mstype
  from mindspore.common._utils import is_shape_unknown
  from mindspore import _checkparam as validator
  from mindspore.ops.operations.manually_defined._inner import ScalarCast
- from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
  from mindspore.common.initializer import Zero
  from mindspore.common.parameter import Parameter
- from mindspore.ops.auto_generate.gen_ops_prim import FlashAttentionScore
+ from mindspore.ops.auto_generate.gen_ops_prim import FlashAttentionScore, FusedInferAttentionScore
+ from mindspore.common.jit_context import jit_context


  dtype_to_type_id = DtypeToEnum()
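For downstream code that relied on the old private location of the argument handlers, a minimal compatibility sketch (both import paths are taken from the hunk above; they are internal MindSpore modules, so this is an assumption rather than a supported API):

    # Hedged compatibility shim: try the 2.6.0rc1 location first, then the 2.4.10 one.
    try:
        from mindspore.ops._utils import arg_handler as handler             # 2.6.0rc1 layout
    except ImportError:
        from mindspore.ops.auto_generate import gen_arg_handler as handler  # 2.4.10 layout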
@@ -527,6 +528,64 @@ class ScalarBool(Primitive):
  return bool(x)


+ class ScalarMax(Primitive):
+ r"""
+ Return the maximum of two input scalars.
+
+ .. note::
+ The inputs can be constant/variable value. Usage is the same as 'max' in Python.
+ This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
+
+ Inputs:
+ - **x** (Scalar) - A constant or variable scalar.
+ - **y** (Scalar) - A constant or variable scalar.
+
+ Outputs:
+ Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
+
+ Raises:
+ TypeError: If `x` and `y` are not scalar.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
+ """
+ @prim_attr_register
+ def __init__(self):
+ """Initialize ScalarMax"""
+
+ def __call__(self, x, y):
+ return max(x, y)
+
+
+ class ScalarMin(Primitive):
+ r"""
+ Return the minimum of two input scalars.
+
+ .. note::
+ The inputs can be constant/variable value. Usage is the same as 'min' in Python.
+ This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
+
+ Inputs:
+ - **x** (Scalar) - A constant or variable scalar.
+ - **y** (Scalar) - A constant or variable scalar.
+
+ Outputs:
+ Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
+
+ Raises:
+ TypeError: If `x` and `y` are not scalar.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
+ """
+ @prim_attr_register
+ def __init__(self):
+ """Initialize ScalarMin"""
+
+ def __call__(self, x, y):
+ return min(x, y)
+
+
  scalar_div = ScalarDiv()
  scalar_mod = ScalarMod()
  scalar_add = ScalarAdd()
@@ -543,6 +602,8 @@ scalar_log = ScalarLog()
  scalar_pow = ScalarPow()
  scalar_uadd = ScalarUadd()
  scalar_usub = ScalarUsub()
+ scalar_max = ScalarMax()
+ scalar_min = ScalarMin()


  class BatchNorm(Primitive):
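In eager (PyNative) execution the two new scalar primitives simply defer to Python's built-ins, as their __call__ bodies above show; a MindSpore-free behavioural sketch:

    # Behavioural sketch only: no MindSpore import required.
    def scalar_max(x, y):  # mirrors ScalarMax.__call__
        return max(x, y)

    def scalar_min(x, y):  # mirrors ScalarMin.__call__
        return min(x, y)

    assert scalar_max(2, 3.5) == 3.5
    assert scalar_min(2, 3.5) == 2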
@@ -570,31 +631,28 @@
  - For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.

  Args:
- is_training (bool): If `is_training` is ``True`` , `mean` and `variance` are computed during training.
+ is_training (bool, optional): If `is_training` is ``True`` ,
+ `mean` and `variance` are computed during training.
  If `is_training` is ``False`` , they're loaded from checkpoint during inference. Default: ``False`` .
- epsilon (float): A small value added for numerical stability. Default: ``1e-5``, value must be (0, 1] .
- momentum (float): The hyper parameter to compute moving average for running_mean and running_var
+ epsilon (float, optional): A small value added for numerical stability.
+ Default: ``1e-5``, value must be (0, 1] .
+ momentum (float, optional): The hyper parameter to compute moving average for running_mean and running_var
  (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
  Momentum value must be [0, 1]. Default: ``0.1`` .
- data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``, and the ``'NHWC'`` format
+ data_format (str, optional): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``,
+ and the ``'NHWC'`` format
  is only supported in GPU target. Default: ``"NCHW"`` .

  Inputs:
- If `is_training` is ``False`` , inputs are Tensors.
-
- - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
- - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
- - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
- - **mean** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
- - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
-
- If `is_training` is ``True`` , `scale`, `bias`, `mean` and `variance` are Parameters.
-
  - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
- - **scale** (Parameter) - Parameter of shape :math:`(C,)`, with float16 or float32 data type.
- - **bias** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`.
- - **mean** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`.
- - **variance** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`.
+ - **scale** (Union[Parameter, Tensor]) - Tensor or Parameter of shape :math:`(C,)`,
+ with float16 or float32 data type.
+ - **bias** (Union[Parameter, Tensor]) - Tensor or Parameter of shape :math:`(C,)`,
+ has the same data type with `scale`.
+ - **mean** (Union[Parameter, Tensor]) - Tensor or Parameter of shape :math:`(C,)`,
+ has the same data type with `scale`.
+ - **variance** (Union[Parameter, Tensor]) - Tensor or Parameter of shape :math:`(C,)`,
+ has the same data type with `scale`.

  Outputs:
  Tuple of 5 Tensors, the normalized inputs and the updated parameters.
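With the relaxed Inputs description above, scale/bias/mean/variance may be plain Tensors rather than Parameters. A minimal inference sketch, assuming the five-input BatchNorm primitive interface described in this docstring (not verified against 2.6.0rc1):

    # Hedged sketch: inference-mode BatchNorm fed with plain Tensors.
    import numpy as np
    from mindspore import Tensor, ops

    x = Tensor(np.random.randn(4, 3).astype(np.float32))
    scale = Tensor(np.ones(3, np.float32))
    bias = Tensor(np.zeros(3, np.float32))
    mean = Tensor(np.zeros(3, np.float32))
    variance = Tensor(np.ones(3, np.float32))
    y = ops.BatchNorm(is_training=False)(x, scale, bias, mean, variance)[0]
    print(y.shape)  # (4, 3)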
@@ -794,29 +852,21 @@

  def rank(input_x):
  """
- Returns the rank of a tensor.
-
- Returns a 0-D int32 Tensor representing the rank of input; the rank of a tensor
- is the number of indices required to uniquely select each element of the tensor.
+ Return the rank of a tensor.

  Args:
- input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number.
+ input_x (Tensor): The input tensor.

  Returns:
- Tensor. 0-D int32 Tensor representing the rank of input, i.e., :math:`R`. The data type is an int.
-
- Raises:
- TypeError: If `input_x` is not a Tensor.
+ Tensor

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
  >>> import mindspore
- >>> import numpy as np
- >>> from mindspore import Tensor, ops
- >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
- >>> output = ops.rank(input_tensor)
+ >>> input_tensor = mindspore.tensor([[2, 2], [2, 2]], mindspore.float32)
+ >>> output = mindspore.ops.rank(input_tensor)
  >>> print(output)
  2
  >>> print(type(output))
@@ -938,10 +988,6 @@ class Tile(Primitive):

  Refer to :func:`mindspore.ops.tile` for more details.

- Note:
- On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
- where more than 4 dimensions are repeated simultaneously.
-
  Inputs:
  - **input** (Tensor) - The tensor whose elements need to be repeated. Set the shape of input tensor as
  :math:`(x_1, x_2, ..., x_S)` .
@@ -949,6 +995,10 @@
  the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
  Only constant value is allowed.

+ .. note::
+ On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
+ where more than 4 dimensions are repeated simultaneously.
+
  Outputs:
  Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`,
  the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
@@ -1005,7 +1055,16 @@
  """Initialize."""

  def __call__(self, input, dims):
- return _convert_stub(pyboost_tile(self, [input, dims]))
+ # Add for jit context.
+ if jit_context() and jit_context().compiled:
+ return None
+ res = _convert_stub(pyboost_tile(self, [input, dims]))
+ # Add for jit context.
+ if jit_context():
+ if validator.is_stub_tensor(res):
+ res = res.stub_sync()
+ return jit_context().run_op(self, res, input, dims)
+ return res

  # pylint: disable=missing-docstring
  def check_elim(self, *args):
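The same jit-context guard reappears below in Cast, Ones and Zeros. A generic sketch of the pattern, using only names that appear in this diff (jit_context, validator, stub tensors); the helper itself is hypothetical, not a public API:

    # Hypothetical helper illustrating the guard the diff wraps around each pyboost call.
    from mindspore.common.jit_context import jit_context
    from mindspore import _checkparam as validator

    def run_with_jit_context(prim, pyboost_call, *args):
        ctx = jit_context()
        if ctx and ctx.compiled:        # graph already captured: skip eager execution
            return None
        res = pyboost_call(*args)       # normal eager (pyboost) execution
        if ctx:                         # tracing: sync stub tensors, record the op
            if validator.is_stub_tensor(res):
                res = res.stub_sync()
            return ctx.run_op(prim, res, *args)
        return res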
@@ -1026,26 +1085,14 @@

  def tile(input, dims):
  r"""
- Creates a new tensor by replicating `input` `dims` times. The i'th dimension of
- output tensor has `input.shape[i] * dims[i]` elements, and the values of `input`
- are replicated `dims[i]` times along the i'th dimension.
+ Creates a new tensor by repeating the elements in the input tensor `dims` times.

- Note:
- On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
- where more than 4 dimensions are repeated simultaneously.
-
- Args:
- input (Tensor): The tensor whose elements need to be repeated. Set the shape of input tensor as
- :math:`(x_1, x_2, ..., x_S)` .
-
- dims (tuple[int]): The parameter that specifies the number of replications,
- the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
- Only constant value is allowed.
-
- Returns:
- Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`,
- the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
+ The i'th dimension of output tensor has `input.shape[i] * dims[i]` elements, and the values of `input`
+ are repeated `dims[i]` times along the i'th dimension.

+ Note:
+ - On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
+ where more than 4 dimensions are repeated simultaneously.
  - If `input.dim = d`, then the shape of their corresponding positions can be multiplied, and
  the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`.
  - If `input.dim < d`, prepend 1 to the shape of `input` until their lengths are consistent.
@@ -1056,40 +1103,39 @@ def tile(input, dims):
  `dims` as :math:`(1, ..., y_1, y_2, ..., y_S)`, then the shape of their corresponding positions
  can be multiplied, and the shape of Outputs is :math:`(x_1*1, ..., x_R*y_R, x_S*y_S)`.

- Raises:
- TypeError: If `dims` is not a tuple or its elements are not all int.
- ValueError: If the elements of `dims` are not all greater than or equal to 0.
+ Args:
+ input (Tensor): The input tensor.
+ dims (tuple[int]): The specified number of repetitions in each dimension.
+
+ Returns:
+ Tensor

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
  >>> import mindspore
- >>> import numpy as np
- >>> from mindspore import Tensor, ops
- >>> input = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
- >>> dims = (2, 3)
- >>> output = ops.tile(input, dims)
- >>> print(output)
- [[1. 2. 1. 2. 1. 2.]
- [3. 4. 3. 4. 3. 4.]
- [1. 2. 1. 2. 1. 2.]
- [3. 4. 3. 4. 3. 4.]]
- >>> dims = (2, 3, 2)
- >>> output = ops.tile(input, dims)
- >>> print(output)
- [[[1. 2. 1. 2.]
- [3. 4. 3. 4.]
- [1. 2. 1. 2.]
- [3. 4. 3. 4.]
- [1. 2. 1. 2.]
- [3. 4. 3. 4.]]
- [[1. 2. 1. 2.]
- [3. 4. 3. 4.]
- [1. 2. 1. 2.]
- [3. 4. 3. 4.]
- [1. 2. 1. 2.]
- [3. 4. 3. 4.]]]
+ >>> input = mindspore.tensor([[1, 2], [3, 4]])
+ >>> mindspore.ops.tile(input, (2, 3))
+ Tensor(shape=[4, 6], dtype=Int64, value=
+ [[1, 2, 1, 2, 1, 2],
+ [3, 4, 3, 4, 3, 4],
+ [1, 2, 1, 2, 1, 2],
+ [3, 4, 3, 4, 3, 4]])
+ >>> mindspore.ops.tile(input, (2, 3, 2))
+ Tensor(shape=[2, 6, 4], dtype=Int64, value=
+ [[[1, 2, 1, 2],
+ [3, 4, 3, 4],
+ [1, 2, 1, 2],
+ [3, 4, 3, 4],
+ [1, 2, 1, 2],
+ [3, 4, 3, 4]],
+ [[1, 2, 1, 2],
+ [3, 4, 3, 4],
+ [1, 2, 1, 2],
+ [3, 4, 3, 4],
+ [1, 2, 1, 2],
+ [3, 4, 3, 4]]])
  """
  tile_op = _get_cache_prim(Tile)()
  return tile_op(input, dims)
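The shape rule spelled out in the rewritten docstring matches numpy.tile, so it can be checked without MindSpore:

    # numpy.tile obeys the same rule: prepend 1s to the input shape when it is
    # shorter than dims, then multiply the shapes elementwise.
    import numpy as np

    x = np.array([[1, 2], [3, 4]])
    print(np.tile(x, (2, 3)).shape)     # (4, 6)    -> (x1*y1, x2*y2)
    print(np.tile(x, (2, 3, 2)).shape)  # (2, 6, 4) -> input treated as shape (1, 2, 2)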
@@ -1176,17 +1222,78 @@
  if data.dtype == dtype:
  return (True, x)
  if isinstance(x, Tensor) and x.dtype == dtype:
- x.set_cast_dtype()
  return (True, x)
  if isinstance(x, numbers.Number):
  return (True, Tensor(x, dtype=dtype))
  return (False, None)

  def __call__(self, input_x, dtype):
+ # Add for jit context.
+ if jit_context() and jit_context().compiled:
+ return None
  should_elim, output = self.check_elim(input_x, dtype)
  if should_elim:
  return output
- return _convert_stub(pyboost_cast(self, [input_x, dtype_to_type_id('Cast', 'dtype', dtype)]))
+ res = _convert_stub(pyboost_cast(self, [input_x, dtype_to_type_id('Cast', 'dtype', dtype)]))
+ # Add for jit context.
+ if jit_context():
+ if validator.is_stub_tensor(res):
+ res = res.stub_sync()
+ return jit_context().run_op(self, res, input_x, dtype)
+ return res
+
+
+ class TypeAs(Primitive):
+ """
+ Returns first input tensor cast to the type of the with the second input tensor.
+
+ .. warning::
+ This is an experimental API that is subject to change or deletion.
+
+ Note:
+ When converting complex numbers to boolean type, the imaginary part of the complex number is not
+ taken into account. As long as the real part is non-zero, it returns True; otherwise, it returns False.
+
+ Inputs:
+ - **input** (Tensor) - The shape of tensor is :math:`(x_0, x_1, ..., x_R)`.
+ The tensor whose data type is to be converted.
+ - **other ** (Tensor) - The shape of tensor is :math:`(x_0, x_1, ..., x_R)`.
+ The tensor whose data type is specified.
+
+ Outputs:
+ Tensor, the shape of tensor is the same as `input`, :math:`(x_0, x_1, ..., x_R)`.
+
+ Raises:
+ TypeError: If `input` is not a Tensor.
+ TypeError: If `other` is not a Tensor.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> import mindspore
+ >>> import numpy as np
+ >>> from mindspore import Tensor, ops
+ >>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
+ >>> input = Tensor(input_np)
+ >>> other_np = np.random.randn(2, 3, 4).astype(np.int32)
+ >>> other = Tensor(other_np)
+ >>> type_as = ops.TypeAs()
+ >>> output = type_as(input, other)
+ >>> print(output.dtype)
+ Int32
+ >>> print(output.shape)
+ (2, 3, 4, 5)
+ """
+
+ @prim_attr_register
+ def __init__(self):
+ pass
+
+ def __call__(self, input, other):
+ if input.dtype == other.dtype:
+ return input
+ return _convert_stub(pyboost_type_as(self, [input, other]))


  def to_sequence(val):
@@ -1506,15 +1613,52 @@
  return Tensor(np.tile(input.asnumpy(), dims))


+ def infer_value_for_EqualExt(x, y):
+ """Infer value for EqualExt op."""
+ if x is None or y is None:
+ return None
+ result = np.equal(x.asnumpy(), y.asnumpy())
+ value = False
+ if result.all():
+ value = True
+ return Tensor(value)
+
+
  def infer_value_for_Concat(tensors, axis):
  """Infer value for Concat op."""
  if not tensors or None in tensors or axis is None:
  return None

- tensor_to_concat = [x.asnumpy() if x.dtype != mstype.bfloat16 else x.float().asnumpy() for x in tensors]
+ tensor_to_concat = [x.asnumpy() for x in tensors]
  return Tensor(np.concatenate(tensor_to_concat, axis), dtype=tensors[0].dtype)


+ def infer_value_for_GatherD(input, dim, index):
+ """Infer value for GatherD op."""
+ if input is None or dim is None or index is None:
+ return None
+
+ input_np = input.asnumpy()
+ index_np = index.asnumpy()
+
+ index_shape = index_np.shape
+ multi_index = [np.indices(index_shape)[i] for i in range(len(index_shape))]
+ multi_index[dim] = index_np
+
+ output = input_np[tuple(multi_index)]
+ return Tensor(output, dtype=input.dtype)
+
+
+ def infer_value_for_Softmax(input, axis):
+ """Infer value for Softmax op."""
+ if input is None or axis is None:
+ return None
+
+ e_input = np.exp(input.asnumpy())
+ output = e_input / np.sum(e_input, axis=axis, keepdims=True)
+ return Tensor(output, dtype=input.dtype)
+
+
  def infer_value_for_ReduceSum(input_x, axis, keep_dims, skip_mode):
  """Infer value for ReduceSum op."""
  value = None
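A MindSpore-free check of the gather-along-dimension trick used by infer_value_for_GatherD above (replace the dim-th coordinate grid with the index array, then fancy-index); it matches numpy.take_along_axis:

    import numpy as np

    x = np.array([[1, 2], [3, 4]])
    index = np.array([[0, 0], [1, 0]])
    dim = 1
    multi_index = list(np.indices(index.shape))    # coordinate grids for every axis
    multi_index[dim] = index                       # substitute the gathered axis
    print(x[tuple(multi_index)])                   # [[1 1] [4 3]]
    print(np.take_along_axis(x, index, axis=dim))  # same result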
@@ -1562,6 +1706,20 @@ def _infer_value_for_Reduce(input_x, axis, keep_dims, prim_name):
  return value


+ def infer_value_for_Arange(start, end, step, dtype=None):
+ """Infer value for Arange op."""
+ if start is None or end is None or step is None:
+ return None
+ np_dtype = np.int64
+ if dtype is None:
+ has_float = any(isinstance(i, float) for i in [start, end, step])
+ if has_float:
+ np_dtype = np.float32
+ else:
+ np_dtype = mstype.dtype_to_nptype(typing.type_id_to_type(dtype))
+ return Tensor(np.arange(start, end, step, dtype=np_dtype))
+
+
  def _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, prim_name):
  """Infer value for Common ReduceExtand op."""
  value = None
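A small sketch of the default-dtype rule encoded in infer_value_for_Arange above: with no explicit dtype the result is int64 unless any of start/end/step is a Python float, in which case it is float32:

    import numpy as np

    def arange_default_dtype(start, end, step):
        # same rule as the helper above when dtype is None
        has_float = any(isinstance(i, float) for i in [start, end, step])
        return np.float32 if has_float else np.int64

    print(np.arange(0, 5, 1, dtype=arange_default_dtype(0, 5, 1)).dtype)      # int64
    print(np.arange(0, 5, 0.5, dtype=arange_default_dtype(0, 5, 0.5)).dtype)  # float32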
@@ -1633,6 +1791,95 @@ def infer_value_for_Cast(x, dst_type_enum=None):
  return value


+ def infer_value_for_LinalgVectorNorm(input_x, ord, dim, keepdim, dtype):
+ """Infer value for linalg_vector_norm op.
+ Current version numpy is not support numpy.linalg.vector_norm.
+ So using numpy.linalg.norm.
+ """
+ if input_x is None or ord is None:
+ return None
+ if ord != 0:
+ out = np.power(np.sum(np.power(np.abs(input_x.asnumpy()), ord), axis=dim, keepdims=keepdim), 1/ord)
+ else:
+ out = np.sum(input_x.asnumpy() != 0, axis=dim, keepdims=keepdim)
+ if dtype is None:
+ return Tensor(out)
+ dtype_for_ms = typing.type_id_to_type(dtype)
+ return Tensor(out, dtype=dtype_for_ms)
+
+
+ def infer_value_for_LpNormV2(input_x, p=2, dim=None, keepdim=False, eps=1e-12):
+ """Infer value for linalg_vector_norm op.
+ Current version numpy is not support numpy.linalg.vector_norm.
+ So using numpy.linalg.norm.
+ """
+ if input_x is None:
+ return None
+ return Tensor(np.linalg.norm(input_x.asnumpy(), axis=dim, keepdims=keepdim,
+ ord=p))
+
+
+ def infer_value_for_Svd(input_x, full_matrices, compute_uv):
+ """Infer value for Svd op."""
+ if input_x is None:
+ return None
+ if bool(compute_uv):
+ s, u, v = np.linalg.svd(input_x.asnumpy(), full_matrices=full_matrices, compute_uv=True)
+ return Tensor(s), Tensor(u), Tensor(v)
+ s = np.linalg.svd(input_x.asnumpy(), full_matrices=full_matrices, compute_uv=False)
+ return Tensor(s), np.zeros(1), np.zeros(1)
+
+
+ def infer_value_for_Div(input_x, other_x):
+ """Infer value for Div op."""
+ if input_x is None or other_x is None:
+ return None
+ return Tensor(np.true_divide(input_x.asnumpy(), other_x.asnumpy()))
+
+
+ def infer_value_for_Divs(input_x, other_x):
+ """Infer value for Divs op."""
+ if input_x is None or other_x is None:
+ return None
+ tmp = np.true_divide(input_x.asnumpy(), other_x)
+ if not input_x.shape:
+ # tensor scalar has a special rule for data type promote
+ if input_x.dtype in (mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, mstype.int8, mstype.int16,
+ mstype.int32, mstype.int64):
+ res = Tensor(tmp, dtype=mstype.float32)
+ else:
+ res = Tensor(tmp, dtype=input_x.dtype)
+ else:
+ res = Tensor(tmp)
+ return res
+
+
+ def infer_value_for_DivMod(input_x, other_x, rounding_mode):
+ """Infer value for DivMod op."""
+ if input_x is None or other_x is None:
+ return None
+ if rounding_mode == 1:
+ # trunc
+ return Tensor(np.trunc(np.true_divide(input_x.asnumpy(), other_x.asnumpy())))
+ if rounding_mode == 2:
+ # floor
+ return Tensor(np.floor_divide(input_x.asnumpy(), other_x.asnumpy()))
+ return None
+
+
+ def infer_value_for_DivMods(input_x, other_x, rounding_mode):
+ """Infer value for DivMods op."""
+ if input_x is None or other_x is None:
+ return None
+ if rounding_mode == 1:
+ # trunc
+ return Tensor(np.trunc(np.true_divide(input_x.asnumpy(), other_x)))
+ if rounding_mode == 2:
+ # floor
+ return Tensor(np.floor_divide(input_x.asnumpy(), other_x))
+ return None
+
+
  def infer_value_for_ReduceMax(input_x, axis, keep_dims):
  """Infer value for ReduceMax op."""
  return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMax')
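A quick check of the two rounding modes handled by infer_value_for_DivMod and infer_value_for_DivMods above (1 = trunc, 2 = floor); they only differ for negative quotients:

    import numpy as np

    a, b = np.array([7, -7]), np.array([2, 2])
    print(np.trunc(np.true_divide(a, b)))  # [ 3. -3.]  round toward zero (mode 1)
    print(np.floor_divide(a, b))           # [ 3 -4]    round toward -inf  (mode 2)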
@@ -1791,7 +2038,7 @@ class Ones(Primitive):
  Tensor, whose dtype and size are defined by input.

  Raises:
- TypeError: If `shape` is neither an int nor an tuple/list/Tensor of int.
+ TypeError: If `shape` is neither an int nor a tuple/list/Tensor of int.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
@@ -1821,13 +2068,23 @@ class Ones(Primitive):
1821
2068
  pass
1822
2069
 
1823
2070
  def __call__(self, size, type=None):
1824
- return _convert_stub(pyboost_ones(self, [size, type if type is None \
2071
+ # Add for jit context.
2072
+ if jit_context() and jit_context().compiled:
2073
+ return None
2074
+ res = _convert_stub(pyboost_ones(self, [size, type if type is None \
1825
2075
  else handler.dtype_to_type_id('Ones', 'type', type)]))
2076
+ # Add for jit context.
2077
+ if jit_context():
2078
+ if validator.is_stub_tensor(res):
2079
+ res = res.stub_sync()
2080
+ return jit_context().run_op(self, res, size, type if type is None \
2081
+ else handler.dtype_to_type_id('Ones', 'type', type))
2082
+ return res
1826
2083
 
1827
2084
 
1828
2085
  class Zeros(Primitive):
1829
2086
  r"""
1830
- Zeros will be deprecated in the future. Please use class `mindspore.ops.zeros` instead.
2087
+ Zeros will be deprecated in the future. Please use class :func:`mindspore.ops.zeros` instead.
1831
2088
 
1832
2089
  Creates a tensor filled with value zeros.
1833
2090
 
@@ -1845,7 +2102,7 @@ class Zeros(Primitive):
1845
2102
  Tensor, whose dtype and size are defined by input.
1846
2103
 
1847
2104
  Raises:
1848
- TypeError: If `shape` is neither an int nor an tuple/list/Tensor of int.
2105
+ TypeError: If `shape` is neither an int nor a tuple/list/Tensor of int.
1849
2106
 
1850
2107
  Supported Platforms:
1851
2108
  ``Ascend`` ``GPU`` ``CPU``
@@ -1871,8 +2128,18 @@ class Zeros(Primitive):
1871
2128
  pass
1872
2129
 
1873
2130
  def __call__(self, size, type=None):
1874
- return _convert_stub(pyboost_zeros(self, [size, type if type is None else \
2131
+ # Add for jit context.
2132
+ if jit_context() and jit_context().compiled:
2133
+ return None
2134
+ res = _convert_stub(pyboost_zeros(self, [size, type if type is None else \
1875
2135
  handler.dtype_to_type_id('Zeros', 'type', type)]))
2136
+ # Add for jit context.
2137
+ if jit_context():
2138
+ if validator.is_stub_tensor(res):
2139
+ res = res.stub_sync()
2140
+ return jit_context().run_op(self, res, size, type if type is None else \
2141
+ handler.dtype_to_type_id('Zeros', 'type', type))
2142
+ return res
1876
2143
 
1877
2144
 
1878
2145
  def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mask=None, padding_mask=None,
@@ -1880,116 +2147,132 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
1880
2147
  scalar_value=1.0, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0,
1881
2148
  input_layout='BSH', sparse_mode=0):
1882
2149
  r"""
1883
- The interface is not open to the public, just for internal use,
2150
+ Implement self-attention calculations in training scenarios.
2151
+
2152
+ - B: Batch size. Value range 1 to 2k.
2153
+ - S1: Sequence length of `query`. Value range 1 to 512k.
2154
+ - S2: Sequence length of `key` and `value`. Value range 1 to 512k.
2155
+ - N1: Num heads of `query`. Value range 1 to 256.
2156
+ - N2: Num heads of `key` and `value`, and N2 must be a factor of N1.
2157
+ - D: Head size. The value must be a multiple of 16, with a maximum value of 512.
2157
+ - H1: Hidden size of `query`, which equals N1 * D.
2158
+ - H2: Hidden size of `key` and `value`, which equals N2 * D.
2160
+
2161
+ The self attention calculation formula is defined as:
1884
2162
 
1885
2163
  .. math::
1886
2164
  \begin{array}{ll} \\
1887
- y = Dropout(Softmax(Mask(scale_value \mul (real_shift + query * key), attn_mask), -1), keep\_prob) \\
1888
- \mul value \\
2165
+ \text{attention\_out} = \operatorname{Dropout}(\operatorname{Softmax}(\operatorname{Mask}(
2166
+ \text{scale} * (\text{query} * \text{key}^{\top}) + \text{pse}, \text{atten\_mask})),
2167
+ \text{keep\_prob}) * \text{value}
1889
2168
  \end{array}
1890
2169
 
1891
- B -- Batch size. Value range 1 to 2k.
1892
- S1 -- Sequence length of query. Value range 1 to 512k.
1893
- S2 -- Sequence length of key and value. Value range 1 to 512k.
1894
- N1 -- Num heads of query. Value range 1 to 256.
1895
- N2 -- Num heads of key and value, and N2 must be a factor of N1.
1896
- D -- Head size. The value ranges is a multiple of 16, with the max value of 512.
1897
- H1 -- Hidden size of query, which equals to N1 * D.
1898
- H2 -- Hidden size of key and value, which equals to N2 * D.
1899
-
1900
2170
  .. warning::
1901
- This is an experimental API that is subject to change or deletion. Only support on Atlas A2 training series.
2171
+ - This is an experimental API that is subject to change or deletion.
2172
+ - Only supported on Atlas A2 training series.
1902
2173
 
1903
2174
  Args:
1904
- query (Tensor[float16, bfloat16]): The query tensor. Input tensor of shape :math:`(B, S1, H1)`,
1905
- `(B, N1, S1, D)`, `(S1, B, H1)`, `(B, S1, N1, D)` or `(T1, N1, D)`.
1906
- key (Tensor[float16, bfloat16]): The key tensor. Input tensor of shape :math:`(B, S2, H2)`,
1907
- `(B, N2, S2, D)`, `(S2, B, H2)`, `(B, S2, N2, D)` or `(T2, N2, D)`.
1908
- value (Tensor[float16, bfloat16]): The value tensor. Input tensor of shape :math:`(B, S2, H2)`,
1909
- `(B, N2, S2, D)`, `(S2, B, H2)`, `(B, S2, N2, D)` or `(T2, N2, D)`. The key and value have the same shape.
1910
- head_num (int): The head num of query, equal to N1.
1911
- real_shift (Union[Tensor[float16, bfloat16], None]): Also known as pse. The position embedding code. If S
1912
- is greater than 1024 and the mask of the lower triangle is used, enter only the inverse 1024 lines of
1913
- the lower triangle for memory optimization. Input tensor of shape :math:`(B, N1, S1, S2)`,
1914
- `(1, N1, S1, S2)`, `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`.
1915
-
1916
- - ALiBi scenario: real_shift must meet the ALiBi rule, and sparse_mode is 2 or 3 for the lower triangle.
1917
- In this scenario, real_shift is `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`.
1918
- - Non-ALiBi scenario: real_shift is `(B, N1, S1, S2)`, `(1, N1, S1, S2)`.
1919
-
1920
- The shape of `real_shift` should be `(B, N1, 1024, S2)` and `(1, N1, 1024, S2)` when input_layout is
1921
- `TND`.
1922
- drop_mask (Union[Tensor[uint8], None]): The dropout mask tensor. Input tensor of shape :math:
1923
- `(B, N1, S1, S2 // 8) or None`. S2 is a multiple of 8 when not None.
1924
- padding_mask (None): Reserved parameter. Not implemented yet.
1925
- attn_mask (Union[Tensor[uint8], Tensor[bool], None]): The attention mask tensor. For each element, 0
1926
- indicates retention and 1 indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`,
1927
- `(B, 1, S1, S2)`, `(S1, S2)` or `(2048, 2048)`. In compression scenario, sparse_mode is 2, 3, or 4,
1928
- attn_mask must be `(2048, 2048)`. When sparse_mode is 5, attn_mask must be `(B, N1, S1, S2)`,
1929
- `(B, 1, S1, S2)`. When sparse_mode is 0 and 1, attn_mask should be `(B, N1, S1, S2)`, `(B, 1, S1, S2)`,
1930
- `(S1, S2)`.
1931
- prefix (Union[List[int64], Tuple[int64] None]): N value of each Batch in the prefix sparse calculation
1932
- scenario. Input tensor of shape :math:`(B,)`. B max value 32. Not none only when sparse_mode is 5.
2175
+ query (Tensor): The query tensor. Input tensor of shape :math:`(B, S1, H1)`,
2176
+ :math:`(B, N1, S1, D)`, :math:`(S1, B, H1)`, :math:`(B, S1, N1, D)` or :math:`(T1, N1, D)`.
2177
+ The supported dtype is float16 and bfloat16.
2178
+ key (Tensor): The key tensor with the same dtype as `query`. Supported shape: :math:`(B, S2, H2)`,
2179
+ :math:`(B, N2, S2, D)`, :math:`(S2, B, H2)`, :math:`(B, S2, N2, D)` or :math:`(T2, N2, D)`.
2180
+ value (Tensor): The value tensor with the same dtype and shape as `key`.
2181
+ head_num (int): The head num of `query`, equal to N1.
2182
+ real_shift (Tensor, optional): The position embedding code, also known as pse. It has the same
2183
+ dtype as `query`.
2184
+ Default: ``None``.
2185
+ If S is greater than 1024 and the mask of the lower triangle is used, only the last 1024 rows of
2186
+ the lower triangle are used for memory optimization. Input tensor of shape :math:`(B, N1, S1, S2)`,
2187
+ :math:`(1, N1, S1, S2)`, :math:`(B, N1, 1024, S2)`, :math:`(1, N1, 1024, S2)`.
2188
+
2189
+ - ALiBi scenario: `real_shift` must meet the ALiBi rule, and sparse_mode is 2 or 3 for the lower triangle.
2190
+ In this scenario, `real_shift` is :math:`(B, N1, 1024, S2)`, :math:`(1, N1, 1024, S2)`.
2191
+ - Non-ALiBi scenario: `real_shift` is :math:`(B, N1, S1, S2)`, :math:`(1, N1, S1, S2)`.
2192
+ - input_layout is TND: shape should be :math:`(B, N1, 1024, S2)` and :math:`(1, N1, 1024, S2)`.
2193
+
2194
+ drop_mask (Tensor, optional): The dropout mask tensor of uint8. Input tensor of shape
2195
+ :math:`(B, N1, S1, S2 // 8)` or None. `S2` must be a multiple of 8 when not None. Default: ``None``.
2196
+ padding_mask (Tensor, optional): Reserved parameter. Not implemented yet. Default: ``None``.
2197
+ attn_mask (Tensor, optional): The attention mask tensor of bool or uint8. For each element, 0/False
2198
+ indicates retention and 1/True indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`,
2199
+ :math:`(B, 1, S1, S2)`, :math:`(S1, S2)` or :math:`(2048, 2048)`.
2200
+ Default: ``None``.
2201
+
2202
+ - In compression scenario, `sparse_mode` is 2, 3, or 4, `attn_mask` must be :math:`(2048, 2048)`.
2203
+ - When `sparse_mode` is 5, `attn_mask` should be :math:`(B, N1, S1, S2)`, :math:`(B, 1, S1, S2)`.
2204
+ - When `sparse_mode` is 0 and 1, `attn_mask` should be :math:`(B, N1, S1, S2)`, :math:`(B, 1, S1, S2)`,
2205
+ :math:`(S1, S2)`.
2206
+
2207
+ prefix (Union[Tensor, tuple[int], list[int]], optional): N value of each Batch in the prefix sparse calculation
2208
+ scenario. Input tensor of shape :math:`(B,)`. The maximum value of B is 32. It is not None only when `sparse_mode` is 5.
2209
+ Default: ``None``.
1933
2210
  If S1 > S2, N ranges from 0 to S2. If S1 <= S2, N ranges from S2 - S1 to S2.
1934
- actual_seq_qlen (Union[List[int64], Tuple[int64], None]): Size of query corresponding to each batch, array
1935
- with increasing values and the last value equal to T1.
1936
- actual_seq_kvlen (Union[List[int64], Tuple[int64], None]): Size of key and value corresponding to each batch,
1937
- array with increasing values and the last value equal to T2.
1938
- keep_prob (float): The keep probability of dropout. Value range is (0.0, 1.0]. Default: 1.0. when keep_prob
1939
- is 1.0, drop_mask should be none.
1940
- scale_value (float): The scale factor of score. Generally, the value is 1.0 / (D ** 0.5). Default: 1.0.
1941
- pre_tokens (int): Parameter for sparse computation, represents how many tokens are counted forward.
1942
- When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
1943
- next_tokens (int): Parameter for sparse computation, represents how many tokens are counted backward.
1944
- When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
1945
- The value of pre_tokens corresponds to S1, and the value of next_tokens corresponds to S2. They define the
1946
- valid area on the attn_mask matrix. It must ensure that the band is not empty.
2211
+ actual_seq_qlen (Union[Tensor, tuple[int], list[int]], optional): Size of query corresponding to each batch,
2212
+ array with increasing values and the last value equal to T1.
2213
+ Default: ``None``.
2214
+ actual_seq_kvlen (Union[Tensor, tuple[int], list[int]], optional): Size of key and value corresponding
2215
+ to each batch, array with increasing values and the last value equal to T2.
2216
+ Default: ``None``.
2217
+ keep_prob (double, optional): The keep probability of dropout. Value range is (0.0, 1.0]. When `keep_prob`
2218
+ is 1.0, `drop_mask` should be None.
2219
+ Default: ``1.0``.
2220
+ scalar_value (double, optional): The scale factor of score. Generally, the value is 1.0 / (D ** 0.5).
2221
+ Default: ``1.0``.
2222
+ pre_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted forward.
2223
+ When `sparse_mode` is set to 1, 2, 3, or 5, this parameter does not take effect.
2224
+ Default: ``2147483647``.
2225
+ next_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted backward.
2226
+ When `sparse_mode` is set to 1, 2, 3, or 5, this parameter does not take effect. Default: ``2147483647``.
2227
+ The value of `pre_tokens` corresponds to S1, and the value of `next_tokens` corresponds to S2.
2228
+ They define the valid area on the `attn_mask` matrix. It must ensure that the band is not empty.
1947
2229
  The following values are not allowed:
1948
2230
 
1949
2231
  - pre_tokens < 0 and next_tokens < 0.
1950
2232
  - (pre_tokens < 0 and next_tokens >= 0) and (next_tokens < abs(pre_tokens) or abs(pre_tokens) >= S2).
1951
2233
  - (pre_tokens >= 0 and next_tokens < 0) and (abs(next_tokens) > pre_tokens or abs(next_tokens) >= S1).
1952
2234
 
1953
- inner_precise (int): The parameter is reserved and not implemented yet. Default: 0.
1954
- input_layout (str): Specifies the layout of input `query`, key and value. The value can be "BSH", "BNSD",
1955
- "SBH", "BSND" or "TND". "TND" is an experimental format. Default: "BSH".
2235
+ inner_precise (int, optional): The parameter is reserved and not implemented yet. Default: ``0``.
2236
+ input_layout (str, optional): Specifies the layout of input `query`, `key` and `value`. The value can be
2237
+ "BSH", "BNSD", "SBH", "BSND" or "TND". "TND" is an experimental format. Default: ``"BSH"``.
1956
2238
  When input_layout is "TND", the following restrictions must be met.
1957
- There are two lists that represent the length of the input sequence: list_seq_q and list_seq_k. Each
2239
+ Assume there are two lists that represent the length of the input sequence: list_seq_q and list_seq_k. Each
1958
2240
  value in the list indicates the length of the sequence in the batch. For example, list_seq_q = [4, 2, 6],
1959
2241
  list_seq_k = [10, 3, 9]. The element of list indicate S. T1 is sum(list_seq_q) = 12, T2 is
1960
2242
  sum(list_seq_k) = 22.
1961
2243
  max_seqlen_q = max(list_seq_q), max_seqlen_k = max(list_seq_k).
1962
2244
  qk_pointer = sum(list_seq_q * list_seq_k), which is the sum of the element multiplication.
1963
2245
 
1964
- - The lengths of two lists are the same, and size of list is batch. batch is less than or equal to 1024.
1965
- - When input_layout is "TND", actual_seq_qlen and actual_seq_kvlen must be not none.
2246
+ - The lengths of the two lists must be the same, and the size of each list is the batch size, which is less than or equal to
2247
+ 1024.
2248
+ - When `input_layout` is "TND", `actual_seq_qlen` and `actual_seq_kvlen` must be not none.
1966
2249
  Otherwise, they are none.
1967
- - The actual_seq_qlen and actual_seq_kvlen are the cumulative sum of sequence of key/value, so they must
2250
+ - The `actual_seq_qlen` and `actual_seq_kvlen` are the cumulative sums of the query and key/value sequence lengths (for the example above, (4, 6, 12) and (10, 13, 22)), so they must
1968
2251
  be non-decreasing.
1969
- - If real_shift is not none, list_seq_q and list_seq_k must be same. The maximum value of list_seq_q and
1970
- list_seq_k is greater than 1024. Real_shift should be `(B, N1, 1024, S2)` and `(1, N1, 1024, S2)`, and
1971
- S2 is equal to max_seqlen_k.
1972
- - Attn mask must be a lower trianglar matrix, so sparse_mode should be 2 or 3. The shape of attn_mask
1973
- should be `(2048, 2048)`.
1974
- - The shape of drop_mask is (qk_pointer * N1 // 8,).
1975
- - Prefix is none.
1976
- - Next_tokens is 0, and pre_tokens is not less than max_seqlen_q.
1977
- - When sparse_mode is 3, S1 of each batch should be less than or equal to S2.
2252
+ - If `real_shift` is not None, list_seq_q and list_seq_k must be the same. The maximum value of list_seq_q and
2253
+ list_seq_k is greater than 1024. `real_shift` should be :math:`(B, N1, 1024, S2)` and
2254
+ :math:`(1, N1, 1024, S2)`, and S2 is equal to max_seqlen_k.
2255
+ - `attn_mask` must be a lower triangular matrix, so `sparse_mode` should be 2 or 3. The shape of `attn_mask`
2256
+ should be :math:`(2048, 2048)`.
2257
+ - The shape of `drop_mask` is :math:`(qk\_pointer * N1 // 8,)`.
2258
+ - `prefix` is none.
2259
+ - `next_tokens` is 0, and `pre_tokens` is not less than max_seqlen_q.
2260
+ - When `sparse_mode` is 3, S1 of each batch should be less than or equal to S2.
1978
2261
  - 0 should not exist in list_seq_k.
1979
2262
 
1980
- sparse_mode (int): Indicates sparse mode. Default 0.
2263
+ sparse_mode (int, optional): Indicates sparse mode. Default: ``0``.
1981
2264
 
1982
- - 0: Indicates the defaultMask mode. If attn_mask is not passed, the mask operation is not performed,
1983
- and preTokens and nextTokens(internally assigned as INT_MAX) are ignored. If passed in, the full
1984
- attn_mask matrix (S1 * S2) needs to be passed in, indicating that the part between preTokens and
1985
- nextTokens needs to be calculated.
1986
- - 1: Represents allMask, that is, passing in the complete attn_mask matrix.
2265
+ - 0: Indicates the defaultMask mode. If `attn_mask` is not passed, the mask operation is not performed,
2266
+ `next_tokens` and `pre_tokens` (internally assigned as INT_MAX) are ignored. If passed in, the full
2267
+ `attn_mask` matrix (S1 * S2) needs to be passed in, indicating that the part between `next_tokens` and
2268
+ `pre_tokens` needs to be calculated.
2269
+ - 1: Represents allMask, that is, passing in the complete `attn_mask` matrix.
1987
2270
  - 2: Representing the leftUpCausal mode corresponds to the lower triangle scenario divided by the left
1988
- vertex, and the optimized attn_mask matrix (2048*2048) is required.
2271
+ vertex, and the optimized `attn_mask` matrix (2048*2048) is required.
1989
2272
  - 3: Representing the rightDownCausal model corresponds to the lower triangle scene divided by the lower
1990
- right vertex, and the optimized attn_mask matrix (2048*2048) is required.
1991
- - 4: Represents the band scenario, that is, the part between counting preTokens and nextTokens, and the
1992
- optimized attn_mask matrix (2048*2048) is required.
2273
+ right vertex, and the optimized `attn_mask` matrix (2048*2048) is required.
2274
+ - 4: Represents the band scenario, that is, the part between counting `next_tokens` and `pre_tokens`,
2275
+ and the optimized `attn_mask` matrix (2048*2048) is required.
1993
2276
  - 5: Represents the prefix scenario, that is, on the basis of rightDownCausal, a matrix with length S1 and
1994
2277
  width N is added to the left side. The value of N is obtained by the new input prefix, and the N value
1995
2278
  of each Batch axis is different, not implemented yet.
@@ -1998,8 +2281,27 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
1998
2281
  - 8: Represents the block_local scenario, not implemented yet.
1999
2282
 
2000
2283
  Returns:
2001
- attention_out (Tensor[float16, bfloat16]), The output of attention, its shape, and data type are the same
2002
- as the query.
2284
+ attention_out (Tensor) - The output of attention. It has the same shape and dtype as `query`.
2285
+
2286
+ Raises:
2287
+ TypeError: Dtype of `query` is not float16 or bfloat16.
2288
+ TypeError: `query`, `key` and `value` don't have the same dtype.
2289
+ TypeError: Dtype of `attn_mask` is not bool or uint8.
2290
+ TypeError: Dtype of `real_shift` is different from that of `query`.
2291
+ TypeError: `scalar_value` or `keep_prob` is not a double number.
2292
+ TypeError: `input_layout` is not a string.
2293
+ TypeError: `num_key_value_heads` is not an int.
2294
+ TypeError: `sparse_mode` is not an int.
2295
+ TypeError: `real_shift` is not Tensor type.
2296
+ TypeError: `drop_mask` is not Tensor type.
2297
+ TypeError: `padding_mask` is not Tensor type.
2298
+ TypeError: `attn_mask` is not Tensor type.
2299
+ ValueError: `input_layout` is a string but not valid.
2300
+ RuntimeError: `head_num` is not divisible by `N2`.
2301
+ RuntimeError: `head_num` is not greater than 0.
2302
+ RuntimeError: `attn_mask` shape is not valid.
2303
+ RuntimeError: The specified value of `sparse_mode` is invalid.
2304
+ RuntimeError: D-axis of `query`, `key` and `value` is not the same.
2003
2305
 
2004
2306
  Supported Platforms:
2005
2307
  ``Ascend``
@@ -2023,6 +2325,452 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
2023
2325
  actual_seq_kvlen)[3]
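A minimal numpy sketch of the attention formula documented above for flash_attention_score, using assumed toy shapes and omitting dropout and pse; it illustrates the math only, not the Ascend operator itself.

.. code-block:: python

    # Reference computation of Softmax(Mask(scale * Q @ K^T, mask)) @ V on assumed toy shapes.
    import numpy as np

    B, N1, S1, S2, D = 1, 2, 4, 6, 16
    scale = 1.0 / D ** 0.5
    query = np.random.rand(B, N1, S1, D).astype(np.float32)
    key = np.random.rand(B, N1, S2, D).astype(np.float32)
    value = np.random.rand(B, N1, S2, D).astype(np.float32)
    attn_mask = np.triu(np.ones((S1, S2), dtype=bool), k=1)   # 1/True means "discard"

    scores = scale * (query @ key.transpose(0, 1, 3, 2))      # (B, N1, S1, S2)
    scores = np.where(attn_mask, -np.inf, scores)             # apply the mask
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = probs / probs.sum(axis=-1, keepdims=True)         # softmax over the last axis
    attention_out = probs @ value                             # (B, N1, S1, D)
    print(attention_out.shape)                                # (1, 2, 4, 16)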
2024
2326
 
2025
2327
 
2328
+ def fused_infer_attention_score(query, key, value, *, pse_shift=None, atten_mask=None, actual_seq_lengths=None,
2329
+ actual_seq_lengths_kv=None, dequant_scale1=None, quant_scale1=None, dequant_scale2=None,
2330
+ quant_scale2=None, quant_offset2=None, antiquant_scale=None, antiquant_offset=None,
2331
+ key_antiquant_scale=None, key_antiquant_offset=None, value_antiquant_scale=None,
2332
+ value_antiquant_offset=None, block_table=None, query_padding_size=None,
2333
+ kv_padding_size=None, key_shared_prefix=None, value_shared_prefix=None,
2334
+ actual_shared_prefix_len=None, num_heads=1, scale=1.0, pre_tokens=2147483647,
2335
+ next_tokens=2147483647, input_layout='BSH', num_key_value_heads=0, sparse_mode=0,
2336
+ inner_precise=1, block_size=0, antiquant_mode=0, key_antiquant_mode=0,
2337
+ value_antiquant_mode=0, softmax_lse_flag=False):
2338
+ r"""
2339
+ This is a FlashAttention function designed for both incremental and full inference scenarios. It supports full
2340
+ inference scenarios (PromptFlashAttention) as well as incremental inference scenarios (IncreFlashAttention).
2341
+ When the S dimension of the query tensor (Q_S) equals 1, it enters the IncreFlashAttention branch; otherwise,
2342
+ it enters the PromptFlashAttention branch.
2343
+
2344
+ .. math::
2345
+
2346
+ Attention(Q,K,V) = Softmax(\frac{QK^{T}}{\sqrt{d}})V
2347
+
2348
+ .. warning::
2349
+ - This is an experimental API that is subject to change or deletion.
2350
+ - For Ascend, only the Atlas A2 training series products and Atlas 800I A2 inference products are currently
2351
+ supported.
2352
+
2353
+ Note:
2354
+ - The data layout formats of query, key and value can be interpreted from multiple dimensions, as shown below:
2355
+
2356
+ - B, Batch size. Represents the batch size of the input samples.
2357
+ - S, Sequence length. Represents the sequence length of the input samples. S1 represents the sequence length
2358
+ of the query, and S2 represents the sequence length of the key/value.
2359
+ - H, Head size. Represents the size of the hidden layer.
2360
+ - N, Head nums. Represents the number of attention heads.
2361
+ - D, Head dims. Represents the smallest unit size of the hidden layer, satisfying :math:`D = H / N`.
2362
+
2363
+ Args:
2364
+ query (Tensor): The query input of the attention structure, with data type of float16, bfloat16 or int8.
2365
+ Input tensor of shape :math:`(B, S, H)`, :math:`(B, N, S, D)`, or :math:`(B, S, N, D)`.
2366
+ key (Union[Tensor, tuple[Tensor], list[Tensor]]): The key input of the attention structure, with data type
2367
+ of float16, bfloat16 or int8. Input tensor of shape :math:`(B, S, H)`, :math:`(B, N, S, D)`, or
2368
+ :math:`(B, S, N, D)`.
2369
+ value (Union[Tensor, tuple[Tensor], list[Tensor]]): The value input of the attention structure, with data
2370
+ type of float16, bfloat16 or int8. Input tensor of shape :math:`(B, S, H)`, :math:`(B, N, S, D)`, or
2371
+ :math:`(B, S, N, D)`.
2372
+
2373
+ Keyword Args:
2374
+ pse_shift (Tensor, optional): The padding mask tensor with data type of float16 or bfloat16.
2375
+ Default: ``None``.
2376
+
2377
+ - When Q_S is not 1, if pse_shift is of type float16, the query must be of type float16 or int8.
2378
+ If pse_shift is of type bfloat16, the query must also be of type bfloat16. The input shape
2379
+ must be either :math:`(B, N, Q\_S, KV\_S)` or :math:`(1, N, Q\_S, KV\_S)`, where Q_S corresponds to the
2380
+ S dimension of the query shape, and KV_S corresponds to the S dimension of the key and value shapes.
2381
+ For scenarios where the KV_S of pse_shift is not 32-aligned, it is recommended to pad it
2382
+ to 32 bytes to improve performance. The padding values for the extra portions are not restricted.
2383
+ - When Q_S is 1, if pse_shift is of type float16, the query must also be of type float16.
2384
+ If pse_shift is of type bfloat16, the query must be of type bfloat16. The input shape must be
2385
+ :math:`(B, N, 1, KV\_S)` or :math:`(1, N, 1, KV\_S)`, where KV_S corresponds to the S dimension of the
2386
+ key/value shapes. For scenarios where the KV\_S of pse_shift is not 32-aligned, it is recommended
2387
+ to pad it to 32 bytes to improve performance. The padding values for the extra portions are not
2388
+ restricted.
2389
+
2390
+ atten_mask (Tensor, optional): The attention mask tensor for the result of query*key with data type of int8,
2391
+ uint8 or bool. For each element, 0 indicates retention and 1 indicates discard.
2392
+ Default: ``None``.
2393
+
2394
+ - When Q_S is not 1, the recommended input shapes are Q_S,KV_S; B,Q_S,KV_S; 1,Q_S,KV_S; B,1,Q_S,KV_S
2395
+ or 1,1,Q_S,KV_S.
2396
+ - When Q_S is 1, the recommended input shapes are B,KV_S; B,1,KV_S or B,1,1,KV_S.
2397
+
2398
+ actual_seq_lengths (Union[tuple[int], list[int], Tensor], optional): Describe actual sequence length of the
2399
+ query with data type of int64. If this parameter is not specified, it can be set to None, indicating that
2400
+ it matches the S dimension of the query shape. Constraint: The effective sequence length for each batch in
2401
+ this parameter should not exceed the corresponding batch's sequence length in the query. When Q_S is 1, this
2402
+ parameter is ignored.
2403
+ Default: ``None``.
2404
+ actual_seq_lengths_kv (Union[tuple[int], list[int], Tensor], optional): Describe actual sequence length of the
2405
+ key and value with data type of int64. If this parameter is not specified, it can be set to None,
2406
+ indicating that it matches the S dimension of the key and value shape. Constraint: The effective sequence
2407
+ length for each batch in this parameter should not exceed the corresponding batch's sequence length in the
2408
+ key and value.
2409
+ Default: ``None``.
2410
+ dequant_scale1 (Tensor, optional): Quantization factors for inverse quantization after BMM1 with data type of
2411
+ uint64. Supports per-tensor mode. If not used, set it to None.
2412
+ Default: ``None``.
2413
+ quant_scale1 (Tensor, optional): Quantization factors for quantization before BMM2 with data type of float32.
2414
+ Supports per-tensor mode. If not used, set it to None.
2415
+ Default: ``None``.
2416
+ dequant_scale2 (Tensor, optional): Quantization factors for inverse quantization after BMM2 with data type of
2417
+ uint64. Supports per-tensor mode. If not used, set it to None.
2418
+ Default: ``None``.
2419
+ quant_scale2 (Tensor, optional): Quantization factors for output quantization with data type of float32,
2420
+ bfloat16. Supports per-tensor and per-channel modes. If not used, set it to None.
2421
+ Default: ``None``.
2422
+ quant_offset2 (Tensor, optional): Quantization offset for output quantization with data type of float32,
2423
+ bfloat16. Supports per-tensor and per-channel modes. If not used, set it to None.
2424
+ Default: ``None``.
2425
+
2426
+ For scenarios where the input is int8 and the output is int8: the parameters dequant_scale1, quant_scale1,
2427
+ dequant_scale2, and quant_scale2 must all be provided. The parameter quant_offset2 is optional and defaults
2428
+ to 0 if not specified.
2429
+
2430
+ - When the output is int8 and quant_scale2 and quant_offset2 are per-channel, left padding, Ring Attention,
2431
+ or D-axis misalignment (not 32-aligned) scenarios are not supported.
2432
+ - When the output is int8, scenarios with sparse_mode as band and pre_tokens/next_tokens being negative are
2433
+ not supported.
2434
+ - When the output is int8, if quant_offset2 is not None and not an empty tensor, and the sparse_mode, pre_tokens,
2435
+ and next_tokens meet the following conditions, certain rows of the matrix may not participate in
2436
+ calculations, leading to errors. This scenario will be intercepted (solution: if this scenario should
2437
+ not be intercepted, quantization should be performed outside the FIA interface, not enabled inside the
2438
+ FIA interface):
2439
+
2440
+ - sparse_mode = 0, if atten_mask is a not None and each batch's
2441
+ actual_seq_lengths - actual_seq_lengths_kv - pre_tokens > 0 or next_tokens < 0, it will meet the
2442
+ interception condition.
2443
+ - sparse_mode = 1 or 2, no interception condition will occur.
2444
+ - sparse_mode = 3, if each batch's actual_seq_lengths - actual_seq_lengths_kv < 0, it will meet the
2445
+ interception condition.
2446
+ - sparse_mode = 4, if pre_tokens < 0 or each batch's
2447
+ next_tokens + actual_seq_lengths - actual_seq_lengths_kv < 0, it will meet the interception
2448
+ condition.
2449
+
2450
+ For scenarios where the input is int8 and the output is float16: the parameters dequant_scale1,
2451
+ quant_scale1, and dequant_scale2 must all be provided.
2452
+
2453
+ For scenarios where the input is entirely float16 or bfloat16 and the output is int8: the parameter
2454
+ quant_scale2 must be provided. The parameter quant_offset2 is optional and defaults to 0 if not specified.
2455
+
2456
+ The parameters quant_scale2 and quant_offset2 support both per-tensor and per-channel modes and two data
2457
+ types: float32 and bfloat16. If quant_offset2 is provided, its type and shape must match those of
2458
+ quant_scale2. When the input is bfloat16, both float32 and bfloat16 are supported; otherwise, only float32
2459
+ is supported. For per-channel mode: When the output layout is BSH, the product of all dimensions in
2460
+ quant_scale2 must equal H. For other layouts, the product must equal N * D. When the output layout is BSH,
2461
+ it is recommended to set the shape of quant_scale2 as :math:`(1, 1, H)` or :math:`(H)`. When the output
2462
+ layout is BNSD, it is recommended to set the shape as :math:`(1, N, 1, D)` or :math:`(N, D)`. When the
2463
+ output layout is BSND, it is recommended to set the shape as :math:`(1, 1, N, D)` or :math:`(N, D)`.
2464
+
2465
+ antiquant_scale (Tensor, optional): Inverse quantization factors with data type of float16, float32 or bfloat16.
2466
+ Only float16 is supported when Q_S > 1. Supports per-tensor, per-channel and per-token modes.
2467
+ Default: ``None``.
2468
+ antiquant_offset (Tensor, optional): Inverse quantization offset with data type of float16, float32 or bfloat16.
2469
+ Only float16 is supported when Q_S > 1. Supports per-tensor, per-channel and per-token modes.
2470
+ Default: ``None``.
2471
+
2472
+ Constraints for antiquant_scale and antiquant_offset parameters:
2473
+
2474
+ - Supports three modes: per-channel, per-tensor, and per-token:
2475
+
2476
+ - Per-channel mode: The shape of both parameters in the BNSD layout is :math:`(2, N, 1, D)`, the shape
2477
+ in the BSND layout is :math:`(2, N, D)`, and the shape in the BSH layout is :math:`(2, H)`, where 2
2478
+ corresponds to the key and value, and N represents num_key_value_heads. The parameter data type is
2479
+ the same as the query data type, and antiquant_mode should be set to 0.
2480
+ - Per-tensor mode: The shape of both parameters is :math:`(2)`, the data type is the same as the query
2481
+ data type, and antiquant_mode should be set to 0.
2482
+ - Per-token mode: The shape of both parameters is :math:`(2, B, S)`, the data type is fixed to float32,
2483
+ and antiquant_mode should be set to 1.
2484
+
2485
+ - Supports both symmetric and asymmetric quantization:
2486
+
2487
+ - Asymmetric quantization mode: Both antiquant_scale and antiquant_offset must be provided.
2488
+ - Symmetric quantization mode: antiquant_offset can be empty (``None``). If antiquant_offset is empty,
2489
+ symmetric quantization is performed. If antiquant_offset is provided, asymmetric quantization is
2490
+ performed.
2491
+
2492
+ key_antiquant_scale (Tensor, optional): Inverse quantization factors for the key, with data type of float16,
2493
+ float32 or bfloat16, when the KV fake quantization parameters are separated.
2494
+ Supports per-tensor, per-channel and per-token modes.
2495
+ Default: ``None``. Invalid when Q_S > 1.
2496
+ key_antiquant_offset (Tensor, optional): Inverse quantization offset for the key, with data type of float16,
2497
+ float32 or bfloat16, when the KV fake quantization parameters are separated.
2498
+ Supports per-tensor, per-channel and per-token modes.
2499
+ Default: ``None``. Invalid when Q_S > 1.
2500
+ value_antiquant_scale (Tensor, optional): Inverse quantization factors for the value, with data type of
2501
+ float16, float32 or bfloat16, when the KV fake quantization parameters are separated.
2502
+ Supports per-tensor, per-channel and per-token modes.
2503
+ Default: ``None``. Invalid when Q_S > 1.
2504
+ value_antiquant_offset (Tensor, optional): Inverse quantization offset for the value, with data type of
2505
+ float16, float32 or bfloat16, when the KV fake quantization parameters are separated.
2506
+ Supports per-tensor, per-channel and per-token modes.
2507
+ Default: ``None``. Invalid when Q_S > 1.
2508
+ block_table (Tensor, optional): Block mapping table in KV cache for PageAttention, with data type of int32.
2509
+ If not used, set it to None.
2510
+ Default: ``None``. Invalid when Q_S > 1.
2511
+ query_padding_size (Tensor, optional): The query padding size with data type of int64. Indicates whether the
2512
+ data in each batch of the query is right-aligned, and how many elements are right-aligned.
2513
+ Default: ``None``. Invalid when Q_S is 1.
2514
+ kv_padding_size (Tensor, optional): The key and value padding size with data type of int64. Indicates whether
2515
+ the data in each batch of the key and value is right-aligned, and how many elements are right-aligned.
2516
+ Default: ``None``. Invalid when Q_S is 1.
2517
+ key_shared_prefix (Tensor, optional): Shared prefix of the key. This is a reserved parameter and is not yet
2518
+ enabled. Default: ``None``.
2519
+ value_shared_prefix (Tensor, optional): Shared prefix of the value. This is a reserved parameter and is not yet
2520
+ enabled. Default: ``None``.
2521
+ actual_shared_prefix_len (Union[tuple[int], list[int], Tensor], optional): Describe the actual length of shared
2522
+ prefix. This is a reserved parameter and is not yet enabled.
2523
+ Default: ``None``.
2524
+ num_heads (int, optional): The number of heads in the query, equal to N when input_layout is BNSD.
2525
+ Default: ``1``.
2526
+ scale (double, optional): The scale value indicating the scale coefficient, which serves as the scalar value for
2527
+ the Muls in the calculation. Generally, the value is :math:`1.0 / \sqrt{d}`. Default: ``1.0``.
2528
+ pre_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted forward.
2529
+ Default: ``2147483647``. Invalid when Q_S is 1.
2530
+ next_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted backward.
2531
+ Default: ``2147483647``. Invalid when Q_S is 1.
2532
+ input_layout (str, optional): Specifies the layout of input query, key and value. BSH, BNSD, BSND or
2533
+ BNSD_BSND is supported. When the layout is BNSD_BSND, it means the input is in the BNSD format and
2534
+ the output is in the BSND format, this is only supported when Q_S > 1.
2535
+ Default: ``BSH``.
2536
+ num_key_value_heads (int, optional): Head numbers of key/value which are used in GQA (Grouped-Query Attention)
2537
+ scenario. Default: ``0``. A value of 0 means it is equal to the number of key/value heads. The num_heads
2538
+ must be divisible by num_key_value_heads, and the ratio of num_heads to num_key_value_heads must not be
2539
+ greater than 64. When the layout is BNSD, the num_key_value_heads must also equal the N dimension of
2540
+ the key/value shapes, otherwise, an execution error will occur.
2541
+ sparse_mode (int, optional): Indicates sparse mode. Default: ``0``. Invalid when Q_S is 1.
2542
+
2543
+ - 0: Indicates the defaultMask mode. If atten_mask is not passed, the mask operation is not performed,
2544
+ and pre_tokens and next_tokens(internally assigned as INT_MAX) are ignored. If passed in, the complete
2545
+ atten_mask matrix (S1 * S2) also must be passed in, indicating that the part between pre_tokens and
2546
+ next_tokens needs to be calculated.
2547
+ - 1: Represents allMask. The complete atten_mask matrix (S1 * S2) is required.
2548
+ - 2: Represents the mask in leftUpCausal mode. The optimized atten_mask matrix (2048*2048) is required.
2549
+ - 3: Represents the mask in rightDownCausal mode, corresponding to the lower triangular scenario divided by
2550
+ the right vertex. The optimized atten_mask matrix (2048*2048) is required.
2551
+ - 4: Represents the mask in band mode, that is, the part between counting pre_tokens and next_tokens. The
2552
+ optimized atten_mask matrix (2048*2048) is required.
2553
+ - 5: Represents the prefix scenario, not implemented yet.
2554
+ - 6: Represents the global scenario, not implemented yet.
2555
+ - 7: Represents the dilated scenario, not implemented yet.
2556
+ - 8: Represents the block_local scenario, not implemented yet.
2557
+
2558
+ inner_precise (int, optional): There are four modes: 0, 1, 2, and 3, represented by 2 bits: bit 0 (bit0)
2559
+ represents the choice for high precision or high performance, and bit 1 (bit1) indicates whether row-wise
2560
+ invalidity correction is applied.
2561
+
2562
+ - 0: Enable high-precise mode, without row-wise invalidity correction.
2563
+ - 1: High-performance mode, without row-wise invalidity correction.
2564
+ - 2: Enable high-precise mode, with row-wise invalidity correction.
2565
+ - 3: High-performance mode, with row-wise invalidity correction.
2566
+
2567
+ When Q_S > 1, if sparse_mode is 0 or 1 and a user-defined mask is provided, it is recommended to enable
2568
+ row-wise invalidity correction. Only 0 and 1 are supported when Q_S is 1. Default: ``1``.
2569
+
2570
+ High-precise and high-performance are only effective for float16 inputs; Row invalidity correction
2571
+ is effective for float16, bfloat16, and int8 inputs.
2572
+ Currently, 0 and 1 are reserved configuration values. If there is a situation where an entire row in the
2573
+ "mask portion involved in computation" is all 1s, precision may degrade. In such cases, you can try
2574
+ setting this parameter to 2 or 3 to enable row invalidity correction for improved precision. However,
2575
+ this configuration will result in decreased performance.
2576
+ If the function can detect the presence of invalid row scenarios, e.g. in cases where sparse_mode is 3
2577
+ and S_q > S_kv, it will automatically enable row invalidity computation.
2578
+
2579
+ block_size (int, optional): Maximum number of tokens per block in the KV cache block for PageAttention.
2580
+ Default: ``0``. Invalid when Q_S > 1.
2581
+ antiquant_mode (int, optional): Fake-quantization mode, 0: per-channel (per-channel includes per-tensor),
2582
+ 1: per-token. The per-channel and per-tensor modes can be distinguished by the dimension of the input
2583
+ shape. When the dimension is 1, it runs in per-tensor mode; otherwise, it runs in per-channel mode.
2584
+ Default: ``0``. Invalid when Q_S > 1.
2585
+ key_antiquant_mode (int, optional): Fake-quantization mode for the key. 0: per-channel (per-channel includes
2586
+ per-tensor), 1: per-token. Default: ``0``. Invalid when Q_S > 1.
2587
+ value_antiquant_mode (int, optional): Fake-quantization mode for the value. 0: per-channel (per-channel includes
2588
+ per-tensor), 1: per-token. Default: ``0``. Invalid when Q_S > 1.
2589
+ softmax_lse_flag (bool, optional): Whether to output softmax_lse. Default: ``False``.
2590
+
2591
+ Returns:
2592
+ attention_out (Tensor), the attention score with data type of float16, bfloat16 or int8. When the input_layout
2593
+ is BNSD_BSND, the shape is :math:`(B, S, N, D)`. In all other cases, the shape is consistent with the
2594
+ input query shape.
2595
+
2596
+ softmax_lse (Tensor), the softmax_lse with data type of float32, obtained by taking the lse (log, sum and exp)
2597
+ of the result of query*key. Specifically, the Ring Attention algorithm first takes the max of the result of
2598
+ query*key, obtaining softmax_max. The result of query*key is then subtracted by softmax_max, followed by
2599
+ taking exp, and then the sum is computed to obtain softmax_sum. Finally, the log of softmax_sum is taken,
2600
+ and softmax_max is added to obtain softmax_lse. The softmax_lse is only calculated when softmax_lse_flag
2601
+ is True, and the shape would be :math:`(B, N, Q\_S, 1)`. If softmax_lse_flag is False, then a tensor with
2602
+ shape :math:`(1)` filled with zeros would be returned. In graph mode with JitConfig set to O2, please ensure
2603
+ that the softmax_lse_flag is enabled before using softmax_lse; otherwise, an exception will occur.
2604
+
2605
+ Constraints:
2606
+ - Full Inference Scenario (Q_S > 1):
2607
+
2608
+ - Query, key, and value inputs functional usage restrictions:
2609
+
2610
+ - The B axis supports values less than or equal to 65535. If the input type includes int8, or
2611
+ if the input type is float16 or bfloat16 and the D axis is not 16-aligned, the B axis is only
2612
+ supported up to 128.
2613
+ - The N axis supports values less than or equal to 256, and the D axis supports values less than
2614
+ or equal to 512.
2615
+ - The S axis supports values less than or equal to 20,971,520 (20M). In some long sequence
2616
+ scenarios, if the computation load is too large, it may cause a timeout in the PFA operator
2617
+ (AICore error type with errorStr: "timeout or trap error"). In this case, it is recommended to
2618
+ perform an S split. Note: The computational load is affected by B, S, N, D, etc.; the larger the
2619
+ values, the greater the computational load. Typical long sequence timeout scenarios (where the
2620
+ product of B, S, N, and D is large) include, but are not limited to:
2621
+
2622
+ 1. B=1, Q_N=20, Q_S=2097152, D=256, KV_N=1, KV_S=2097152;
2623
+ 2. B=1, Q_N=2, Q_S=20971520, D=256, KV_N=2, KV_S=20971520;
2624
+ 3. B=20, Q_N=1, Q_S=2097152, D=256, KV_N=1, KV_S=2097152;
2625
+ 4. B=1, Q_N=10, Q_S=2097152, D=512, KV_N=1, KV_S=2097152.
2626
+
2627
+ - When the query, key, value, or attention_out type includes int8, the D axis must be 32-aligned.
2628
+ If all types are float16 or bfloat16, the D axis must be 16-aligned.
2629
+
2630
+ - The sparse_mode parameter currently only supports values 0, 1, 2, 3, and 4. Using any other values
2631
+ will result in an error.
2632
+
2633
+ - When sparse_mode = 0, if the atten_mask is None, or if the atten_mask is provided in the left
2634
+ padding scenario, the input parameters pre_tokens and next_tokens are ignored.
2635
+ - When sparse_mode = 2, 3, or 4, the shape of the atten_mask must be S,S or 1,S,S or 1,1,S,S, where
2636
+ S must be fixed at 2048, and the user must ensure the atten_mask is a lower triangular matrix. If
2637
+ no atten_mask is provided or if the shape is incorrect, an error will occur.
2638
+ - In sparse_mode = 1, 2, 3 scenarios, the pre_tokens and next_tokens inputs are ignored and assigned
2639
+ according to the relevant rules.
2640
+
2641
+ - The KV cache de-quantization only supports queries of type float16, where int8 keys and values are
2642
+ de-quantized to float16. If the product of the data range of the input key/value and the antiquant_scale is
2643
+ within the range of (-1, 1), high-performance mode can guarantee precision; otherwise,
2644
+ high-precision mode should be enabled to ensure accuracy.
2645
+
2646
+ - Query left padding scenario:
2647
+
2648
+ - In the query left padding scenario, the formula for calculating the starting point of the query
2649
+ transport is: Q_S - query_padding_size - actual_seq_lengths. The formula for the
2650
+ ending point of the query transport is: Q_S - query_padding_size. The query transport
2651
+ starting point must not be less than 0, and the ending point must not exceed Q_S; otherwise,
2652
+ the results will be incorrect.
2653
+ - If the kv_padding_size in the query left padding scenario is less than 0, it will be set to 0.
2654
+ - The query left padding scenario must be enabled together with the actual_seq_lengths parameter,
2655
+ otherwise, the default is the query right padding scenario.
2656
+ - The query left padding scenario does not support PageAttention and cannot be enabled together with
2657
+ the block_table parameter.
2658
+
2659
+ - KV left padding scenario:
2660
+
2661
+ - In the KV left padding scenario, the formula for calculating the starting point of the key and
2662
+ value transport is: KV_S - kv_padding_size - actual_seq_lengths_kv. The formula
2663
+ for the ending point of the key and value transport is: KV_S - kv_padding_size. The
2664
+ key and value transport starting point must not be less than 0, and the ending point must not
2665
+ exceed KV_S; otherwise, the results will be incorrect.
2666
+ - If the kv_padding_size in the KV left padding scenario is less than 0, it will be set to 0.
2667
+ - The KV left padding scenario must be enabled together with the actual_seq_lengths_kv parameter,
2668
+ otherwise, the default is the KV right padding scenario.
2669
+ - The KV left padding scenario does not support PageAttention and cannot be enabled together with
2670
+ the block_table parameter.
2671
+
2672
+ - pse_shift functional usage restrictions:
2673
+
2674
+ - This function is supported when the query data type is float16, bfloat16, or int8.
2675
+ - If the query data type is float16 and pse_shift is enabled, it will force high-precision mode,
2676
+ inheriting the limitations of high-precision mode.
2677
+ - Q_S must be greater than or equal to the length of the query S, and KV_S must be greater than
2678
+ or equal to the length of the key S.
2679
+
2680
+ - KV fake quantization parameter separation is not currently supported.
2681
+
2682
+ - Incremental Inference Scenario (Q_S is 1):
2683
+
2684
+ - Query, key, and value inputs functional usage restrictions:
2685
+
2686
+ - The B axis supports values less than or equal to 65,536.
2687
+ - The N axis supports values less than or equal to 256.
2688
+ - The D axis supports values less than or equal to 512.
2689
+ - Scenarios where the input types of query, key, and value are all int8 are not supported.
2690
+
2691
+ - Page attention scenario:
2692
+
2693
+ - The necessary condition to enable page attention is that the block_table exists and is valid.
2694
+ The key and value are arranged in contiguous memory according to the indices in the block_table.
2695
+ The key and value dtypes supported are float16, bfloat16, and int8. In this scenario, the
2696
+ input_layout parameter for key and value is invalid.
2697
+ - block_size is a user-defined parameter, and its value will affect the performance of page
2698
+ attention. When enabling page attention, a non-zero value for block_size must be provided, and
2699
+ the maximum value for block_size is 512.
2700
+ - If the input types of key and value are float16 or bfloat16, they must be 16-aligned. If the
2701
+ input types are int8, they must be 32-aligned, with 128 being recommended. In general, page
2702
+ attention can increase throughput but may lead to a performance decrease.
2703
+ - In the page attention enabled scenario, when the KV cache layout is (blocknum, block_size, H) and
2704
+ num_key_value_heads * D exceeds 64K, an error will be reported due to hardware
2705
+ instruction constraints. This can be resolved by enabling GQA (reducing num_key_value_heads) or
2706
+ adjusting the KV cache layout to (blocknum, num_key_value_heads, block_size, D).
2707
+ - The product of all dimensions of the shape of the key and value tensors in the page attention
2708
+ scenario must not exceed the representable range of int32.
2709
+
2710
+ - In the page attention enabled scenario, the input S must be greater than or equal to
2711
+ max_block_num_per_seq * block_size.
2712
+
2713
+ - The following are also supported in the page attention scenario: enabling attention mask
2714
+ (e.g., mask shape = (B, 1, 1, S)), enabling pse_shift (e.g., pse_shift shape = (B, N, 1, S)),
2715
+ and enabling fake quantization in per-token mode (e.g., antiquant_scale and antiquant_offset
2716
+ shapes = (2, B, S)).
2717
+
2718
+ - KV left padding scenario:
2719
+
2720
+ - In the KV left padding scenario, the formula for calculating the starting point of the KV cache
2721
+ transport is: KV_S - kv_padding_size - actual_seq_lengths. The formula for the endpoint of the
2722
+ KV cache transport is: KV_S - kv_padding_size. If the starting point or endpoint of the KV cache
2723
+ is less than 0, the returned data result will be all zeros.
2724
+ - If kv_padding_size is less than 0 in the KV left padding scenario, it will be set to 0.
2725
+ - The KV left padding scenario must be enabled together with the actual_seq_lengths parameter,
2726
+ otherwise, it defaults to the KV right padding scenario.
2727
+ - The KV left padding scenario must be enabled together with the atten_mask parameter, and the
2728
+ atten_mask must be correctly applied to hide invalid data. Otherwise, accuracy issues may arise.
2729
+
2730
+ - pse_shift functional usage restrictions:
2731
+
2732
+ - The data type of pse_shift must match the data type of the query.
2733
+ - Only the D axis alignment is supported, meaning the D axis must be divisible by 16.
2734
+
2735
+ - KV fake quantization parameter separation:
2736
+
2737
+ - key_antiquant_mode and value_antiquant_mode must be consistent.
2738
+ - key_antiquant_scale and value_antiquant_scale must either both be empty or both non-empty.
2739
+ - key_antiquant_offset and value_antiquant_offset must either both be empty or both non-empty.
2740
+ - When both key_antiquant_scale and value_antiquant_scale are non-empty, their shapes must be
2741
+ consistent.
2742
+ - When both key_antiquant_offset and value_antiquant_offset are non-empty, their shapes must be
2743
+ consistent.
2744
+
2745
+
2746
+ Supported Platforms:
2747
+ ``Ascend``
2748
+
2749
+ Examples:
2750
+ >>> from mindspore import ops
2751
+ >>> from mindspore import Tensor
2752
+ >>> import numpy as np
2753
+ >>> B, N, S, D = 1, 8, 1024, 128
2754
+ >>> query = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
2755
+ >>> key = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
2756
+ >>> value = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
2757
+ >>> out = ops.fused_infer_attention_score(query, key, value, num_heads=N, input_layout='BNSD')
2758
+ >>> print(out[0].shape)
2759
+ (1, 8, 1024, 128)
2760
+ """
2761
+ fias_op = _get_cache_prim(FusedInferAttentionScore)(num_heads, scale, pre_tokens, next_tokens, input_layout,
2762
+ num_key_value_heads, sparse_mode, inner_precise, block_size,
2763
+ antiquant_mode, softmax_lse_flag, key_antiquant_mode,
2764
+ value_antiquant_mode)
2765
+ key_list = key if isinstance(key, (tuple, list)) else [key]
2766
+ value_list = value if isinstance(value, (tuple, list)) else [value]
2767
+ return fias_op(query, key_list, value_list, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv,
2768
+ dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale,
2769
+ antiquant_offset, block_table, query_padding_size, kv_padding_size, key_antiquant_scale,
2770
+ key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix,
2771
+ value_shared_prefix, actual_shared_prefix_len)
2772
+
2773
+
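A minimal numpy sketch of the softmax_lse quantity described in the Returns section of fused_infer_attention_score above (take the max, subtract it, exp, sum, log, then add the max back), using assumed toy shapes.

.. code-block:: python

    # Numerically stable log-sum-exp of query*key^T, following the softmax_lse description above.
    import numpy as np

    B, N, Q_S, KV_S, D = 1, 2, 4, 6, 8
    query = np.random.rand(B, N, Q_S, D).astype(np.float32)
    key = np.random.rand(B, N, KV_S, D).astype(np.float32)

    scores = query @ key.transpose(0, 1, 3, 2)            # (B, N, Q_S, KV_S)
    softmax_max = scores.max(axis=-1, keepdims=True)      # row-wise max
    softmax_sum = np.exp(scores - softmax_max).sum(axis=-1, keepdims=True)
    softmax_lse = np.log(softmax_sum) + softmax_max       # (B, N, Q_S, 1)
    print(softmax_lse.shape)                              # (1, 2, 4, 1)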
2026
2774
  class WhileLoop(Primitive):
2027
2775
  """
2028
2776
  Provide a useful op for reducing compilation times of while loop.
@@ -2195,7 +2943,7 @@ class Scan(Primitive):
2195
2943
 
2196
2944
  class ForiLoop(Primitive):
2197
2945
  """
2198
- Provide a useful op for loop from lower to upper.
2946
+ Performs a loop operation within the specified range.
2199
2947
  The execution logic of the ForiLoop operator can be roughly represented by the following code:
2200
2948
 
2201
2949
  .. code-block:: python