mindspore-2.1.0-cp39-cp39-win_amd64.whl → mindspore-2.2.10-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (505)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +4 -1
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -1
  9. mindspore/_checkparam.py +23 -29
  10. mindspore/_extends/graph_kernel/__init__.py +0 -1
  11. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  12. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  13. mindspore/_extends/graph_kernel/splitter.py +4 -11
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  21. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  22. mindspore/_extends/parse/__init__.py +12 -15
  23. mindspore/_extends/parse/namespace.py +7 -33
  24. mindspore/_extends/parse/parser.py +61 -71
  25. mindspore/_extends/parse/resources.py +1 -1
  26. mindspore/_extends/parse/standard_method.py +74 -104
  27. mindspore/_extends/parse/trope.py +1 -1
  28. mindspore/_extends/remote/kernel_build_server.py +25 -7
  29. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  30. mindspore/_install_custom.py +43 -0
  31. mindspore/amp.py +47 -11
  32. mindspore/atlprov.dll +0 -0
  33. mindspore/boost/boost.py +1 -8
  34. mindspore/boost/boost_cell_wrapper.py +3 -2
  35. mindspore/boost/grad_accumulation.py +1 -1
  36. mindspore/boost/group_loss_scale_manager.py +8 -7
  37. mindspore/c1.dll +0 -0
  38. mindspore/c1xx.dll +0 -0
  39. mindspore/c2.dll +0 -0
  40. mindspore/common/__init__.py +5 -3
  41. mindspore/common/_jit_fallback_utils.py +6 -0
  42. mindspore/common/_register_for_adapter.py +2 -0
  43. mindspore/common/_register_for_tensor.py +2 -2
  44. mindspore/common/_stub_tensor.py +13 -0
  45. mindspore/common/_utils.py +13 -0
  46. mindspore/common/api.py +174 -259
  47. mindspore/common/auto_dynamic_shape.py +494 -0
  48. mindspore/common/dtype.py +18 -11
  49. mindspore/common/dump.py +6 -4
  50. mindspore/common/initializer.py +14 -14
  51. mindspore/common/jit_config.py +33 -15
  52. mindspore/common/lazy_inline.py +126 -7
  53. mindspore/common/mindir_util.py +101 -0
  54. mindspore/common/parameter.py +51 -41
  55. mindspore/common/seed.py +4 -4
  56. mindspore/common/sparse_tensor.py +13 -14
  57. mindspore/common/tensor.py +243 -165
  58. mindspore/communication/__init__.py +7 -4
  59. mindspore/communication/_comm_helper.py +83 -4
  60. mindspore/communication/management.py +152 -84
  61. mindspore/config/op_info.config +14 -3
  62. mindspore/context.py +152 -61
  63. mindspore/dataset/__init__.py +5 -5
  64. mindspore/dataset/audio/__init__.py +2 -2
  65. mindspore/dataset/audio/transforms.py +52 -52
  66. mindspore/dataset/callback/ds_callback.py +16 -2
  67. mindspore/dataset/core/config.py +68 -51
  68. mindspore/dataset/engine/cache_client.py +28 -5
  69. mindspore/dataset/engine/datasets.py +250 -112
  70. mindspore/dataset/engine/datasets_audio.py +43 -211
  71. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  72. mindspore/dataset/engine/datasets_text.py +43 -67
  73. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  74. mindspore/dataset/engine/datasets_vision.py +219 -1029
  75. mindspore/dataset/engine/iterators.py +11 -4
  76. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  77. mindspore/dataset/engine/obs/util.py +3 -0
  78. mindspore/dataset/engine/samplers.py +1 -1
  79. mindspore/dataset/engine/validators.py +19 -5
  80. mindspore/dataset/text/__init__.py +3 -3
  81. mindspore/dataset/text/transforms.py +101 -127
  82. mindspore/dataset/text/utils.py +205 -138
  83. mindspore/dataset/transforms/__init__.py +1 -1
  84. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  85. mindspore/dataset/transforms/transforms.py +95 -40
  86. mindspore/dataset/utils/browse_dataset.py +8 -2
  87. mindspore/dataset/utils/line_reader.py +17 -19
  88. mindspore/dataset/vision/__init__.py +3 -3
  89. mindspore/dataset/vision/c_transforms.py +6 -3
  90. mindspore/dataset/vision/transforms.py +409 -287
  91. mindspore/dataset/vision/utils.py +13 -14
  92. mindspore/dataset/vision/validators.py +11 -1
  93. mindspore/dnnl.dll +0 -0
  94. mindspore/dpcmi.dll +0 -0
  95. mindspore/experimental/map_parameter.py +14 -0
  96. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  97. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  98. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  99. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  100. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  101. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  102. mindspore/gen_ops.py +273 -0
  103. mindspore/include/OWNERS +0 -1
  104. mindspore/include/api/data_type.h +2 -1
  105. mindspore/include/api/graph.h +0 -15
  106. mindspore/include/api/kernel.h +2 -0
  107. mindspore/include/api/kernel_api.h +37 -12
  108. mindspore/include/api/model.h +17 -14
  109. mindspore/include/api/status.h +8 -3
  110. mindspore/include/api/types.h +37 -4
  111. mindspore/include/c_api/ms/abstract.h +67 -0
  112. mindspore/include/c_api/ms/attribute.h +197 -0
  113. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  114. mindspore/include/c_api/ms/base/macros.h +32 -0
  115. mindspore/include/c_api/ms/base/status.h +33 -0
  116. mindspore/include/c_api/ms/base/types.h +282 -0
  117. mindspore/include/c_api/ms/context.h +102 -0
  118. mindspore/include/c_api/ms/graph.h +160 -0
  119. mindspore/include/c_api/ms/node.h +606 -0
  120. mindspore/include/c_api/ms/tensor.h +161 -0
  121. mindspore/include/c_api/ms/value.h +84 -0
  122. mindspore/include/dataset/constants.h +6 -5
  123. mindspore/include/dataset/execute.h +23 -13
  124. mindspore/include/dataset/text.h +26 -26
  125. mindspore/include/dataset/transforms.h +13 -13
  126. mindspore/include/dataset/vision.h +60 -60
  127. mindspore/include/dataset/vision_ascend.h +5 -6
  128. mindspore/include/dataset/vision_lite.h +17 -17
  129. mindspore/jpeg62.dll +0 -0
  130. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  131. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  132. mindspore/mindspore_backend.dll +0 -0
  133. mindspore/mindspore_common.dll +0 -0
  134. mindspore/mindspore_core.dll +0 -0
  135. mindspore/mindspore_glog.dll +0 -0
  136. mindspore/mindspore_shared_lib.dll +0 -0
  137. mindspore/msobj140.dll +0 -0
  138. mindspore/mspdb140.dll +0 -0
  139. mindspore/mspdbcore.dll +0 -0
  140. mindspore/mspdbst.dll +0 -0
  141. mindspore/mspft140.dll +0 -0
  142. mindspore/msvcdis140.dll +0 -0
  143. mindspore/msvcp140_1.dll +0 -0
  144. mindspore/msvcp140_2.dll +0 -0
  145. mindspore/msvcp140_atomic_wait.dll +0 -0
  146. mindspore/msvcp140_codecvt_ids.dll +0 -0
  147. mindspore/nn/__init__.py +0 -2
  148. mindspore/nn/cell.py +313 -74
  149. mindspore/nn/dynamic_lr.py +21 -21
  150. mindspore/nn/layer/activation.py +22 -30
  151. mindspore/nn/layer/basic.py +15 -13
  152. mindspore/nn/layer/channel_shuffle.py +1 -1
  153. mindspore/nn/layer/container.py +271 -9
  154. mindspore/nn/layer/conv.py +323 -204
  155. mindspore/nn/layer/dense.py +8 -5
  156. mindspore/nn/layer/embedding.py +33 -27
  157. mindspore/nn/layer/flash_attention.py +141 -88
  158. mindspore/nn/layer/image.py +8 -6
  159. mindspore/nn/layer/math.py +16 -25
  160. mindspore/nn/layer/normalization.py +107 -66
  161. mindspore/nn/layer/padding.py +1 -1
  162. mindspore/nn/layer/pooling.py +131 -109
  163. mindspore/nn/layer/rnn_cells.py +27 -22
  164. mindspore/nn/layer/rnns.py +13 -16
  165. mindspore/nn/layer/thor_layer.py +1 -1
  166. mindspore/nn/layer/transformer.py +221 -154
  167. mindspore/nn/learning_rate_schedule.py +9 -1
  168. mindspore/nn/loss/loss.py +235 -174
  169. mindspore/nn/optim/ada_grad.py +2 -1
  170. mindspore/nn/optim/adadelta.py +1 -0
  171. mindspore/nn/optim/adafactor.py +2 -1
  172. mindspore/nn/optim/adam.py +7 -4
  173. mindspore/nn/optim/adamax.py +3 -2
  174. mindspore/nn/optim/adasum.py +2 -2
  175. mindspore/nn/optim/asgd.py +2 -3
  176. mindspore/nn/optim/ftrl.py +6 -5
  177. mindspore/nn/optim/lamb.py +7 -4
  178. mindspore/nn/optim/lars.py +1 -1
  179. mindspore/nn/optim/lazyadam.py +5 -3
  180. mindspore/nn/optim/momentum.py +2 -1
  181. mindspore/nn/optim/optimizer.py +53 -4
  182. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  183. mindspore/nn/optim/rmsprop.py +4 -3
  184. mindspore/nn/optim/rprop.py +23 -12
  185. mindspore/nn/optim/sgd.py +26 -11
  186. mindspore/nn/optim/thor.py +9 -7
  187. mindspore/nn/probability/bijector/bijector.py +5 -5
  188. mindspore/nn/probability/bijector/power_transform.py +27 -27
  189. mindspore/nn/probability/bijector/softplus.py +3 -3
  190. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  191. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  192. mindspore/nn/probability/distribution/beta.py +3 -3
  193. mindspore/nn/probability/distribution/categorical.py +7 -7
  194. mindspore/nn/probability/distribution/cauchy.py +0 -1
  195. mindspore/nn/probability/distribution/distribution.py +3 -3
  196. mindspore/nn/probability/distribution/gamma.py +3 -3
  197. mindspore/nn/probability/distribution/geometric.py +4 -4
  198. mindspore/nn/probability/distribution/gumbel.py +4 -4
  199. mindspore/nn/probability/distribution/log_normal.py +2 -2
  200. mindspore/nn/probability/distribution/logistic.py +2 -2
  201. mindspore/nn/probability/distribution/poisson.py +4 -4
  202. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  203. mindspore/nn/probability/distribution/uniform.py +6 -6
  204. mindspore/nn/wrap/cell_wrapper.py +84 -34
  205. mindspore/nn/wrap/grad_reducer.py +8 -5
  206. mindspore/nn/wrap/loss_scale.py +105 -42
  207. mindspore/numpy/array_creations.py +1 -2
  208. mindspore/numpy/array_ops.py +3 -2
  209. mindspore/numpy/utils_const.py +5 -5
  210. mindspore/opencv_core452.dll +0 -0
  211. mindspore/opencv_imgcodecs452.dll +0 -0
  212. mindspore/opencv_imgproc452.dll +0 -0
  213. mindspore/ops/_grad_experimental/__init__.py +0 -5
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  215. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  216. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  217. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  218. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  219. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  220. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  222. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  224. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  225. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  226. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  227. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  228. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  229. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  230. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  231. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  232. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  233. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  234. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  235. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  236. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  237. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  238. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  239. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  240. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  241. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  242. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  243. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  244. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  245. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  246. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  247. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  248. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  249. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  250. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  251. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  252. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  253. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  254. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  255. mindspore/ops/_primitive_cache.py +1 -1
  256. mindspore/ops/_tracefunc.py +45 -13
  257. mindspore/ops/_utils/utils.py +6 -1
  258. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  259. mindspore/ops/_vmap/vmap_base.py +3 -3
  260. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  261. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  262. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  263. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  264. mindspore/ops/arg_dtype_cast.py +54 -0
  265. mindspore/ops/composite/base.py +37 -10
  266. mindspore/ops/composite/math_ops.py +5 -4
  267. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  268. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  269. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  270. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  271. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  272. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  273. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  274. mindspore/ops/deprecated.py +304 -0
  275. mindspore/ops/function/__init__.py +4 -1
  276. mindspore/ops/function/array_func.py +174 -193
  277. mindspore/ops/function/clip_func.py +81 -13
  278. mindspore/ops/function/debug_func.py +1 -1
  279. mindspore/ops/function/grad/grad_func.py +18 -9
  280. mindspore/ops/function/image_func.py +10 -4
  281. mindspore/ops/function/linalg_func.py +5 -5
  282. mindspore/ops/function/math_func.py +575 -386
  283. mindspore/ops/function/nn_func.py +568 -260
  284. mindspore/ops/function/random_func.py +88 -57
  285. mindspore/ops/function/sparse_func.py +1 -1
  286. mindspore/ops/function/sparse_unary_func.py +14 -12
  287. mindspore/ops/function/vmap_func.py +6 -5
  288. mindspore/ops/functional.py +15 -10
  289. mindspore/ops/op_info_register.py +244 -25
  290. mindspore/ops/operations/__init__.py +28 -19
  291. mindspore/ops/operations/_grad_ops.py +72 -7
  292. mindspore/ops/operations/_inner_ops.py +350 -17
  293. mindspore/ops/operations/_quant_ops.py +4 -8
  294. mindspore/ops/operations/_sequence_ops.py +42 -0
  295. mindspore/ops/operations/array_ops.py +68 -282
  296. mindspore/ops/operations/comm_ops.py +107 -59
  297. mindspore/ops/operations/custom_ops.py +94 -70
  298. mindspore/ops/operations/debug_ops.py +8 -4
  299. mindspore/ops/operations/image_ops.py +18 -12
  300. mindspore/ops/operations/inner_ops.py +26 -3
  301. mindspore/ops/operations/math_ops.py +189 -141
  302. mindspore/ops/operations/nn_ops.py +794 -489
  303. mindspore/ops/operations/other_ops.py +0 -22
  304. mindspore/ops/operations/random_ops.py +53 -111
  305. mindspore/ops/operations/sparse_ops.py +3 -1
  306. mindspore/ops/primitive.py +24 -18
  307. mindspore/parallel/_auto_parallel_context.py +68 -8
  308. mindspore/parallel/_cost_model_context.py +2 -2
  309. mindspore/parallel/_offload_context.py +17 -3
  310. mindspore/parallel/_parallel_serialization.py +12 -5
  311. mindspore/parallel/_ps_context.py +12 -0
  312. mindspore/parallel/_tensor.py +18 -13
  313. mindspore/parallel/_transformer/layers.py +5 -3
  314. mindspore/parallel/_transformer/loss.py +1 -0
  315. mindspore/parallel/_transformer/moe.py +2 -2
  316. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  317. mindspore/parallel/_transformer/transformer.py +23 -3
  318. mindspore/parallel/_utils.py +11 -7
  319. mindspore/parallel/algo_parameter_config.py +85 -5
  320. mindspore/parallel/checkpoint_transform.py +19 -12
  321. mindspore/parallel/shard.py +21 -14
  322. mindspore/pgodb140.dll +0 -0
  323. mindspore/pgort140.dll +0 -0
  324. mindspore/profiler/common/struct_type.py +3 -3
  325. mindspore/profiler/common/util.py +4 -2
  326. mindspore/profiler/envprofiling.py +1 -1
  327. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  328. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  329. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  330. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  331. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  332. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  333. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  334. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  335. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  336. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  337. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  338. mindspore/profiler/parser/flops_parser.py +15 -11
  339. mindspore/profiler/parser/framework_parser.py +38 -22
  340. mindspore/profiler/parser/hccl_parser.py +16 -12
  341. mindspore/profiler/parser/integrator.py +22 -11
  342. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  343. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  344. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  345. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  346. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  347. mindspore/profiler/parser/optime_parser.py +1 -1
  348. mindspore/profiler/parser/profiler_info.py +21 -2
  349. mindspore/profiler/parser/step_trace_parser.py +11 -14
  350. mindspore/profiler/profiling.py +179 -89
  351. mindspore/rewrite/api/node.py +102 -19
  352. mindspore/rewrite/api/node_type.py +5 -1
  353. mindspore/rewrite/api/pattern_engine.py +1 -1
  354. mindspore/rewrite/api/scoped_value.py +9 -17
  355. mindspore/rewrite/api/symbol_tree.py +131 -47
  356. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  357. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  358. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  359. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  360. mindspore/rewrite/common/rewrite_elog.py +5 -1
  361. mindspore/rewrite/namer.py +33 -24
  362. mindspore/rewrite/namespace.py +14 -5
  363. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  364. mindspore/rewrite/node/call_function.py +79 -0
  365. mindspore/rewrite/node/cell_container.py +135 -0
  366. mindspore/rewrite/node/control_flow.py +88 -0
  367. mindspore/rewrite/{node.py → node/node.py} +273 -234
  368. mindspore/rewrite/node/node_manager.py +254 -0
  369. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  370. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  371. mindspore/rewrite/parsers/assign_parser.py +216 -221
  372. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  373. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  374. mindspore/rewrite/parsers/constant_parser.py +9 -6
  375. mindspore/rewrite/parsers/container_parser.py +9 -7
  376. mindspore/rewrite/parsers/for_parser.py +36 -15
  377. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  378. mindspore/rewrite/parsers/if_parser.py +28 -24
  379. mindspore/rewrite/parsers/module_parser.py +196 -25
  380. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  381. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  382. mindspore/rewrite/parsers/return_parser.py +6 -6
  383. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  384. mindspore/rewrite/sparsify/utils.py +1 -1
  385. mindspore/rewrite/symbol_tree.py +523 -578
  386. mindspore/rewrite/symbol_tree_builder.py +9 -193
  387. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  388. mindspore/run_check/_check_version.py +6 -4
  389. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  390. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  391. mindspore/tbbmalloc.dll +0 -0
  392. mindspore/tinyxml2.dll +0 -0
  393. mindspore/train/_utils.py +7 -3
  394. mindspore/train/amp.py +323 -123
  395. mindspore/train/anf_ir_pb2.py +14 -2
  396. mindspore/train/callback/_backup_and_restore.py +2 -12
  397. mindspore/train/callback/_callback.py +29 -4
  398. mindspore/train/callback/_checkpoint.py +23 -8
  399. mindspore/train/callback/_early_stop.py +2 -2
  400. mindspore/train/callback/_landscape.py +4 -4
  401. mindspore/train/callback/_loss_monitor.py +2 -2
  402. mindspore/train/callback/_on_request_exit.py +2 -2
  403. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  404. mindspore/train/callback/_summary_collector.py +15 -8
  405. mindspore/train/callback/_time_monitor.py +58 -5
  406. mindspore/train/data_sink.py +5 -11
  407. mindspore/train/dataset_helper.py +84 -57
  408. mindspore/train/loss_scale_manager.py +2 -2
  409. mindspore/train/metrics/__init__.py +3 -3
  410. mindspore/train/metrics/cosine_similarity.py +1 -1
  411. mindspore/train/metrics/hausdorff_distance.py +3 -2
  412. mindspore/train/metrics/mean_surface_distance.py +3 -2
  413. mindspore/train/metrics/metric.py +39 -19
  414. mindspore/train/metrics/roc.py +2 -2
  415. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  416. mindspore/train/mind_ir_pb2.py +85 -36
  417. mindspore/train/model.py +187 -47
  418. mindspore/train/serialization.py +487 -161
  419. mindspore/train/summary/_summary_adapter.py +1 -1
  420. mindspore/train/summary/_writer_pool.py +3 -2
  421. mindspore/train/summary/summary_record.py +37 -17
  422. mindspore/train/train_thor/convert_utils.py +3 -3
  423. mindspore/train/train_thor/dataset_helper.py +1 -1
  424. mindspore/turbojpeg.dll +0 -0
  425. mindspore/vcmeta.dll +0 -0
  426. mindspore/vcruntime140.dll +0 -0
  427. mindspore/vcruntime140_1.dll +0 -0
  428. mindspore/version.py +1 -1
  429. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +5 -3
  430. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +433 -479
  431. mindspore/_extends/graph_kernel/expander.py +0 -80
  432. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  433. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  434. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  435. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  436. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  437. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  438. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  439. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  440. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  441. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  442. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  443. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  444. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  445. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  446. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  447. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  448. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  449. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  450. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  451. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  452. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  453. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  454. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  455. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  456. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  457. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  458. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  459. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  460. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  461. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  462. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  463. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  464. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  465. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  466. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  467. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  468. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  469. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  470. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  471. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  472. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  473. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  474. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  475. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  476. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  477. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  478. mindspore/dataset/datapreprocess/__init__.py +0 -20
  479. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  480. mindspore/include/api/net.h +0 -142
  481. mindspore/nn/lr_scheduler.py +0 -262
  482. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  483. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  484. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  485. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  486. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  487. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  488. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  489. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  490. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  491. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  492. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  493. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  494. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  495. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  496. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  497. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  498. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  499. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  500. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  501. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  502. mindspore/rewrite/node_visitor.py +0 -44
  503. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  504. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
  505. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0

mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py
@@ -33,12 +33,12 @@ class FlashAttentionFwd(FlashAttention):
     """

     def __init__(self, query, key, value,
-                 dim_mask, attn_mask, dropout_mask, alibi_mask,
+                 attn_mask, dropout_mask, alibi_mask,
                  kernel_name,
                  tiling_stgy: TilingStrategy,
                  prev_block_num=65536,
                  next_block_num=65536, high_precision=False, disable_debug=True):
-        super(FlashAttentionFwd, self).__init__(query, key, value, dim_mask, attn_mask, dropout_mask, alibi_mask,
+        super(FlashAttentionFwd, self).__init__(query, key, value, attn_mask, dropout_mask, alibi_mask,
                                                 kernel_name,
                                                 tiling_stgy, prev_block_num, next_block_num, high_precision,
                                                 disable_debug)
@@ -55,25 +55,49 @@ class FlashAttentionFwd(FlashAttention):
         self.O_gm = self.tik_instance.Tensor(FP16, self.O_shape, name="O_gm", scope=GM, is_atomic_add=True)
         if self.high_precision:
             self.O_gm_workspace = self.tik_instance.Tensor(FP32, self.O_shape, name="O_gm_workspace", scope=GM,
-                                                           is_workspace=True)
+                                                           is_workspace=True, is_atomic_add=True)
         self.l_gm = self.tik_instance.Tensor(self.precision_type, self.l_shape, name="l_gm", scope=GM,
                                              is_atomic_add=True)
         self.m_gm = self.tik_instance.Tensor(FP16, self.m_shape, name="m_gm", scope=GM, is_atomic_add=True)

+    def prepare_global_ones(self):
+        """Prepare global ones tensor in L1 for cube impl row_sum"""
+        Bc_aligned = (self.Bc + 15) // 16 * 16
+        last_Bc_aligned = (self.last_Bc + 15) // 16 * 16
+        self.ones_l1 = self.tik_instance.Tensor(FP16, (Bc_aligned, 16), name="ones_l1", scope=L1)
+        self.last_ones_l1 = self.tik_instance.Tensor(FP16, (last_Bc_aligned, 16), name="last_ones_l1", scope=L1)
+        with self.tik_instance.new_stmt_scope(disable_sync=False):
+            ones_ub = self.tik_instance.Tensor(FP16, (Bc_aligned, 16), name="ones_ub", scope=UB)
+            self.tik_instance.h_duplicate(ones_ub, 1.0)
+            self.cont_data_mv_1_bust(dst=self.ones_l1, src=ones_ub, burst=Bc_aligned)
+            last_ones_ub = self.tik_instance.Tensor(FP16, (last_Bc_aligned, 16), name="last_ones_ub", scope=UB)
+            self.tik_instance.h_duplicate(ones_ub, 1.0)
+            self.cont_data_mv_1_bust(dst=self.last_ones_l1, src=last_ones_ub, burst=last_Bc_aligned)
+
     def softmax_compute(self, Sij_ub, mij_ub, lij_ub, m, n):
         """Refer to Algorithm 2 line12"""
-        # mij = rowmax(Sij): compute the max of each row of Sij
-        self.tik_instance.h_reduce_max(mij_ub, Sij_ub[:, 0:n], 1)
         m_aligned = self.tik_ops_utils.up_align_to_K0(m)
         n_aligned = self.tik_ops_utils.up_align_to_K0(n)
+        n0 = 16
+        n1 = n // 16
+        # only support n % 16 == 0
+        with self.tik_instance.new_stmt_scope(disable_sync=False):
+            mn0_block_max = self.tik_instance.Tensor(FP16, (1, m, n0), name="mn0_block_max", scope=UB)
+            self.cont_data_mv_1_bust(dst=mn0_block_max, src=Sij_ub, burst=m)
+            with self.tik_instance.for_range(1, n1) as idx:
+                self.tik_instance.h_max(mn0_block_max, mn0_block_max, Sij_ub[idx, :, :])
+            mn0_block_max = mn0_block_max.reshape((m, n0))
+            self.tik_instance.h_reduce_max(mij_ub, mn0_block_max, 1)
         # Sij - mij
         with self.tik_instance.new_stmt_scope(disable_sync=False):
-            broadcast_mij_ub = self.tik_ops_utils.broadcast(mij_ub, (m_aligned, n_aligned))
-            self.tik_instance.h_sub(Sij_ub, Sij_ub, broadcast_mij_ub)
+            broadcast_mij_ub = self.tik_ops_utils.broadcast(mij_ub, (m, n0))
+            broadcast_mij_ub = broadcast_mij_ub.reshape((1, m, n0))
+            for idx in range(n1):
+                self.tik_instance.h_sub(Sij_ub[idx, :, :], Sij_ub[idx, :, :], broadcast_mij_ub)
         # exp
         if self.high_precision:
             Sij_ub_fp32 = self.tik_instance.Tensor(
-                FP32, (m_aligned, n_aligned), name="Sij_ub_fp32", scope=tik.scope_ubuf
+                FP32, (n_aligned // 16, m_aligned, 16), name="Sij_ub_fp32", scope=UB
             )
             with self.tik_instance.new_stmt_scope(disable_sync=False):
                 self.tik_instance.h_cast(Sij_ub_fp32, Sij_ub, "none")
@@ -83,15 +107,21 @@ class FlashAttentionFwd(FlashAttention):
83
107
  self.tik_instance.h_exp(Sij_ub, Sij_ub)
84
108
 
85
109
  # cube impl rowsum
86
- Sij_l1_K1MK0_ws = self.tik_instance.Tensor(FP16, (n_aligned // 16, m_aligned, 16),
87
- name="Sij_l1_K1MK0_ws", scope=L1)
88
- Sij_l1_K1MK0_ed = self.tik_ops_utils.MK_TO_K1MK0(Sij_ub, Sij_l1_K1MK0_ws)
89
- Sij_row_sum_ub = self.tik_ops_utils.row_sum_cube_impl(Sij_l1_K1MK0_ed, lij_ub, m, n, self.precision_type)
110
+ Sij_l1_K1MK0_ed = self.tik_instance.Tensor(FP16, (n_aligned // 16, m_aligned, 16),
111
+ name="Sij_l1_K1MK0_ed", scope=L1)
112
+ self.cont_data_mv_1_bust(dst=Sij_l1_K1MK0_ed, src=Sij_ub, burst=m * n // 16)
113
+ if n == self.Bc:
114
+ Sij_row_sum_ub = self.tik_ops_utils.row_sum_cube_impl(Sij_l1_K1MK0_ed, self.ones_l1,
115
+ lij_ub, m, n, self.precision_type)
116
+ else:
117
+ Sij_row_sum_ub = self.tik_ops_utils.row_sum_cube_impl(Sij_l1_K1MK0_ed, self.last_ones_l1,
118
+ lij_ub, m, n, self.precision_type)
90
119
 
91
120
  if self.high_precision:
92
121
  return Sij_ub_fp32, mij_ub, Sij_row_sum_ub
93
-
94
- return Sij_ub, mij_ub, Sij_row_sum_ub
122
+ if self.has_drop_mask:
123
+ return Sij_ub, mij_ub, Sij_row_sum_ub
124
+ return Sij_l1_K1MK0_ed, mij_ub, Sij_row_sum_ub
95
125
 
96
126
  def update_m_l(self, mi_old_ub, mij_ub, li_old_ub, lij_ub, vec_len):
97
127
  """Refer to Algorithm 2 line13
@@ -146,34 +176,36 @@ class FlashAttentionFwd(FlashAttention):
         :param block_h:
         :return: None
         """
-        vec_gm_offset = (batch_start + batch_idx) * self.Nq + q_blk_idx * self.Br
+        vec_gm_offset = self.get_l_m_gm_offset(batch_start, batch_idx, self.Nq, self.Br, q_blk_idx)
         o_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.Nq, self.d, self.Br, q_blk_idx)
         block_h_aligned = self.tik_ops_utils.up_align_to_K0(block_h)
         block_k_aligned_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
-        try:
-            dtype_size = DTYPE_SIZE[FP32]
-        except KeyError:
-            raise ValueError("The argument 'FP32' is not valid.")
+        n1 = block_k_aligned_aligned // self.N0
         with self.tik_instance.if_scope(tik.any(kv_blk_idx == 0, kv_blk_idx + self.prev_block_num == q_blk_idx)):
             self.tik_ops_utils.move_vector_from_ub_to_gm(self.l_gm, lij_ub, vec_gm_offset, block_h)
             self.tik_ops_utils.move_vector_from_ub_to_gm(self.m_gm, mij_ub, vec_gm_offset, block_h)
             li_new_rec_ub = self.tik_ops_utils.calc_vec_rec(lij_ub, block_h)
+            vec_ub = self.tik_instance.Tensor(FP32, (block_h, self.N0), name="vec_ub", scope=UB)
             for i in range(block_h):
                 src_scalar = self.tik_instance.Scalar(init_value=li_new_rec_ub[i], dtype=FP32)
-                self.tik_instance.h_mul(Pij_ub_fp32[i, :], Pij_ub_fp32[i, :], src_scalar)
-
+                self.tik_instance.h_duplicate(vec_ub[i, :], src_scalar)
+            vec_ub = vec_ub.reshape((1, block_h, self.N0))
+            with self.tik_instance.for_range(0, n1) as idx:
+                self.tik_instance.h_mul(Pij_ub_fp32[idx, :, :],
+                                        Pij_ub_fp32[idx, :, :],
+                                        vec_ub)
             self.tik_instance.h_cast(Pij_ub, Pij_ub_fp32, "none")
             Pij_l1_K1MK0_ed = self.tik_instance.Tensor(
                 FP16, (block_k_aligned_aligned // 16, block_h_aligned, 16), name="Pij_l1_K1MK0_ed", scope=L1
             )
-            Pij_l1_K1MK0_ed = self.tik_ops_utils.MK_TO_K1MK0(Pij_ub, workspace_tensor=Pij_l1_K1MK0_ed)
+            self.cont_data_mv_1_bust(dst=Pij_l1_K1MK0_ed, src=Pij_ub,
+                                     burst=block_k_aligned_aligned * block_h_aligned // 16)
             Pij_Vj_matmul_res_ub = self.tik_ops_utils.matmul_compute(Pij_l1_K1MK0_ed, Vj_l1_K1NK0_ed, block_h,
-                                                                     kv_blk_height, self.actual_d, N1MN0_to_MN=True,
+                                                                     kv_blk_height, self.actual_d, N1MN0_to_MN=False,
                                                                      precision_type=self.precision_type)  # Pij*Vj
-            self.cont_data_mv_1_bust(dst=self.O_gm_workspace[o_gm_offset],
-                                     src=Pij_Vj_matmul_res_ub,
-                                     burst=block_h * self.d * dtype_size // 32)
-
+            self.tik_instance.data_move(dst=self.O_gm_workspace[o_gm_offset], src=Pij_Vj_matmul_res_ub, sid=0,
+                                        nburst=self.N1, burst=block_h * self.N0 // 8,
+                                        src_stride=0, dst_stride=(self.Nq - block_h_aligned) * self.N0 // 8)
         with self.tik_instance.else_scope():
             mi_ub = self.tik_instance.Tensor(FP16, (block_h_aligned,), name="mi_old_ub", scope=UB)
             li_ub = self.tik_instance.Tensor(FP32, (block_h_aligned,), name="li_ub", scope=UB)
@@ -190,124 +222,59 @@ class FlashAttentionFwd(FlashAttention):

             li_new_rec_ub = self.tik_ops_utils.calc_vec_rec(li_new_ub, block_h)
             self.tik_instance.h_mul(exp_m_cur_fp32, exp_m_cur_fp32, li_new_rec_ub)
+            exp_m_cur_fp32_vec_ub = self.tik_instance.Tensor(FP32, (block_h, self.N0), name="exp_m_cur_fp32_vec_ub",
+                                                             scope=UB)
             for i in range(block_h):
                 src_scalar = self.tik_instance.Scalar(init_value=exp_m_cur_fp32[i], dtype=FP32)
-                self.tik_instance.h_mul(Pij_ub_fp32[i, :], Pij_ub_fp32[i, :], src_scalar)
-
+                self.tik_instance.h_duplicate(exp_m_cur_fp32_vec_ub[i, :], src_scalar)
+            exp_m_cur_fp32_vec_ub = exp_m_cur_fp32_vec_ub.reshape((1, block_h, self.N0))
+            with self.tik_instance.for_range(0, n1) as idx:
+                self.tik_instance.h_mul(Pij_ub_fp32[idx, :, :],
+                                        Pij_ub_fp32[idx, :, :],
+                                        exp_m_cur_fp32_vec_ub)
             self.tik_instance.h_cast(Pij_ub, Pij_ub_fp32, "none")
-            # ub -> l1
             Pij_l1_K1MK0_ed = self.tik_instance.Tensor(
                 FP16, (block_k_aligned_aligned // 16, block_h_aligned, 16), name="Pij_l1_K1MK0_ed", scope=L1
             )
-            Pij_l1_K1MK0_ed = self.tik_ops_utils.MK_TO_K1MK0(Pij_ub, workspace_tensor=Pij_l1_K1MK0_ed)
+            self.cont_data_mv_1_bust(dst=Pij_l1_K1MK0_ed, src=Pij_ub,
+                                     burst=block_k_aligned_aligned * block_h_aligned // 16)
             Pij_Vj_matmul_res_ub = self.tik_ops_utils.matmul_compute(Pij_l1_K1MK0_ed, Vj_l1_K1NK0_ed, block_h,
-                                                                     kv_blk_height, self.actual_d, N1MN0_to_MN=True,
+                                                                     kv_blk_height, self.actual_d, N1MN0_to_MN=False,
                                                                      precision_type=self.precision_type)  # Pij*Vj
-            Oi_ub = self.tik_instance.Tensor(FP32, (block_h_aligned, self.d), scope=UB, name="Oi_ub")
-            self.cont_data_mv_1_bust(dst=Oi_ub, src=self.O_gm_workspace[o_gm_offset],
-                                     burst=block_h * self.d * dtype_size // 32)
+            n1, m, n0 = Pij_Vj_matmul_res_ub.shape
+            Oi_ub = self.tik_instance.Tensor(FP32, (n1, m, n0), name="Oi_ub", scope=UB)
+            self.tik_instance.data_move(dst=Oi_ub, src=self.O_gm_workspace[o_gm_offset],
+                                        sid=0, nburst=self.N1, burst=m * self.N0 // 8,
+                                        src_stride=(self.Nq - m) * self.N0 // 8, dst_stride=0)

             self.tik_instance.h_mul(li_new_rec_ub, li_new_rec_ub, li_ub)
             self.tik_instance.h_mul(li_new_rec_ub, li_new_rec_ub, exp_m_old_fp32)
-
-            with self.tik_instance.new_stmt_scope(disable_sync=False):
-                with self.tik_instance.for_range(begint=0, endt=block_h) as i:
-                    src_scalar = self.tik_instance.Scalar(init_value=li_new_rec_ub[i], dtype=FP32)
-                    self.tik_instance.h_mul(Oi_ub[i, :], Oi_ub[i, :], src_scalar)
-
+            li_new_rec_vec_ub = self.tik_instance.Tensor(FP32, (block_h, self.N0), name="li_new_rec_vec_ub",
+                                                         scope=UB)
+            for i in range(block_h):
+                src_scalar = self.tik_instance.Scalar(init_value=li_new_rec_ub[i], dtype=FP32)
+                self.tik_instance.h_duplicate(li_new_rec_vec_ub[i, :], src_scalar)
+            li_new_rec_vec_ub = li_new_rec_vec_ub.reshape((1, block_h, self.N0))
+            with self.tik_instance.for_range(0, n1) as idx:
+                self.tik_instance.h_mul(Oi_ub[idx, :, :],
+                                        Oi_ub[idx, :, :],
+                                        li_new_rec_vec_ub)
             self.tik_instance.h_add(Oi_ub, Oi_ub, Pij_Vj_matmul_res_ub)
-            self.cont_data_mv_1_bust(dst=self.O_gm_workspace[o_gm_offset],
-                                     src=Oi_ub,
-                                     burst=block_h * self.d * dtype_size // 32)
-
-    def update_o_gm(self, block_h, li_new_rec_ub, o_gm_offset, ub_data):
-        """Load o from gm and update it, then write it back to gm"""
-        block_h_aligned = self.tik_ops_utils.up_align_to_K0(block_h)
-        half_block_h1 = self.tik_ops_utils.up_align_to_K0(block_h // 2)
-        half_block_h2 = block_h_aligned - half_block_h1
-        # double buffer: vec and mte3 parallel
-        with self.tik_instance.for_range(0, 2, thread_num=2) as t_idx:
-            with self.tik_instance.if_scope(t_idx == 0):
-                row_begin = 0
-                row_end = half_block_h1
-                broadcast_li_new_rec_ub = self.tik_ops_utils.broadcast(
-                    li_new_rec_ub[row_begin:row_end], (half_block_h1, self.d)
-                )
-                self.tik_instance.h_mul(ub_data[row_begin:row_end, :],
-                                        ub_data[row_begin:row_end, :],
-                                        broadcast_li_new_rec_ub)
-                if half_block_h1 <= block_h:
-                    self.cont_data_mv_1_bust(dst=self.O_gm[o_gm_offset],
-                                             src=ub_data[row_begin:row_end, :],
-                                             burst=half_block_h1 * self.d // 16)
-                else:
-                    self.cont_data_mv_1_bust(dst=self.O_gm[o_gm_offset],
-                                             src=ub_data[row_begin:row_end, :],
-                                             burst=block_h * self.d // 16)
-            with self.tik_instance.else_scope():
-                if half_block_h2 > 0:
-                    row_begin = half_block_h1
-                    row_end = row_begin + half_block_h2
-                    broadcast_li_new_rec_ub = self.tik_ops_utils.broadcast(
-                        li_new_rec_ub[row_begin:row_end], (half_block_h2, self.d)
-                    )
-                    self.tik_instance.h_mul(ub_data[row_begin:row_end, :],
-                                            ub_data[row_begin:row_end, :],
-                                            broadcast_li_new_rec_ub)
-                    cur_o_gm_offset = o_gm_offset + half_block_h1 * self.d
-                    self.cont_data_mv_1_bust(dst=self.O_gm[cur_o_gm_offset],
-                                             src=ub_data[row_begin:row_end, :],
-                                             burst=(block_h - half_block_h1) * self.d // 16)
-
-    def update_Oi(
-            self,
-            Oi_ub,
-            exp_mi_sub_mi_new,
-            Pij_Vj_ub,
-            exp_mij_sub_mi_new,
-            li_new_rec_ub,
-            li_ub,
-            o_gm_offset,
-            block_h
-    ):
-        """Refer to Algorithm 2 line15"""
-        block_h_aligned = self.tik_ops_utils.up_align_to_K0(block_h)
-        diag_exp_Oi_ub = self.diag_exp_Oi(li_ub, exp_mi_sub_mi_new, Oi_ub, block_h_aligned)
-        # exp_mij_sub_mi_new * Pij_Vj_ub
-        exp_Pij_Vj_ub = self.exp_Pij_Vj(exp_mij_sub_mi_new, Pij_Vj_ub, block_h_aligned)
-
-        # (diag(li)_exp_Oi + exp_P_V)
-        sum_diag_exp_Oi_and_exp_Pij_Vj_ub = diag_exp_Oi_ub
-        self.tik_instance.h_add(
-            sum_diag_exp_Oi_and_exp_Pij_Vj_ub,
-            sum_diag_exp_Oi_and_exp_Pij_Vj_ub,
-            exp_Pij_Vj_ub
-        )
-        self.update_o_gm(block_h, li_new_rec_ub, o_gm_offset, sum_diag_exp_Oi_and_exp_Pij_Vj_ub)
-
-    def diag_exp_Oi(self, li_ub, exp_mi_sub_mi_new, Oi_ub, block_h_aligned):
-        """Refer to Algorithm 2 line15
-        li * exp(mi - mi_new) * Oi
-        """
-        self.tik_instance.h_mul(exp_mi_sub_mi_new, exp_mi_sub_mi_new, li_ub)
-        diag_exp = exp_mi_sub_mi_new
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            broadcast_diag_exp = self.tik_ops_utils.broadcast(diag_exp, (block_h_aligned, self.d))
-            self.tik_instance.h_mul(Oi_ub, Oi_ub, broadcast_diag_exp)
-        return Oi_ub
+            self.tik_instance.data_move(dst=self.O_gm_workspace[o_gm_offset], src=Oi_ub, sid=0,
+                                        nburst=self.N1, burst=block_h * self.N0 // 8,
+                                        src_stride=0, dst_stride=(self.Nq - block_h_aligned) * self.N0 // 8)

     def exp_Pij_Vj(self, exp_mij_sub_mi_new, Pij_Vj_ub, block_h_aligned):
         """Refer to Algorithm 2 line15
         exp(mij - mi_new) * Pij * Vj
         """
         with self.tik_instance.new_stmt_scope(disable_sync=False):
-            broadcast_exp_mij_sub_mi_new = self.tik_ops_utils.broadcast(exp_mij_sub_mi_new,
-                                                                        (block_h_aligned, self.d))
+            broadcast_exp_mij_sub_mi_new = self.tik_ops_utils.broadcast(exp_mij_sub_mi_new, (block_h_aligned, self.d))
             self.tik_instance.h_mul(Pij_Vj_ub, Pij_Vj_ub, broadcast_exp_mij_sub_mi_new)
         return Pij_Vj_ub

     def update_o_m_l(self,
-                     Pij_ub,
+                     Pij_l1_K1MK0_ed,
                      Vj_l1_K1NK0_ed,
                      mij_ub,
                      lij_ub,
@@ -318,76 +285,89 @@ class FlashAttentionFwd(FlashAttention):
                      q_blk_idx,
                      block_h):
         """Refer to Algorithm 2 line13 and line15 in FlashAttention"""
-        vec_gm_offset = (batch_start + batch_idx) * self.Nq + q_blk_idx * self.Br
+        vec_gm_offset = self.get_l_m_gm_offset(batch_start, batch_idx, self.Nq, self.Br, q_blk_idx)
         o_gm_offset = self.get_gm_offset(
             batch_start, batch_idx, self.Nq, self.d, self.Br, q_blk_idx
         )
         block_h_aligned = self.tik_ops_utils.up_align_to_K0(block_h)
-        kv_blk_h_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
-        Pij_l1_K1MK0_ed = self.tik_instance.Tensor(
-            FP16, (kv_blk_h_aligned // 16, block_h_aligned, 16), name="Pij_l1", scope=L1
-        )
-        Pij_l1_K1MK0_ed = self.tik_ops_utils.MK_TO_K1MK0(Pij_ub, workspace_tensor=Pij_l1_K1MK0_ed)
+
         Pij_Vj_matmul_res_ub = self.tik_ops_utils.matmul_compute(Pij_l1_K1MK0_ed, Vj_l1_K1NK0_ed, block_h,
                                                                  kv_blk_height, self.actual_d,
-                                                                 N1MN0_to_MN=True)  # Pij*Vj
-        with self.tik_instance.if_scope(
-                tik.any(kv_blk_idx == 0, kv_blk_idx + self.prev_block_num == q_blk_idx)):
+                                                                 N1MN0_to_MN=False)  # Pij*Vj
+        n1, m, n0 = Pij_Vj_matmul_res_ub.shape
+        with self.tik_instance.if_scope(tik.any(kv_blk_idx == 0, kv_blk_idx + self.prev_block_num == q_blk_idx)):
             self.tik_ops_utils.move_vector_from_ub_to_gm(self.l_gm, lij_ub, vec_gm_offset, block_h)
             self.tik_ops_utils.move_vector_from_ub_to_gm(self.m_gm, mij_ub, vec_gm_offset, block_h)
             li_new_rec_ub = self.tik_ops_utils.calc_vec_rec(lij_ub, block_h)
-            self.update_o_gm(block_h, li_new_rec_ub, o_gm_offset, Pij_Vj_matmul_res_ub)
+            broadcast_li_new_rec_ub = self.tik_ops_utils.broadcast(li_new_rec_ub, (m, n0))
+            broadcast_li_new_rec_ub = broadcast_li_new_rec_ub.reshape((1, m, n0))
+            with self.tik_instance.for_range(0, n1) as idx:
+                self.tik_instance.h_mul(Pij_Vj_matmul_res_ub[idx, :, :],
+                                        Pij_Vj_matmul_res_ub[idx, :, :],
+                                        broadcast_li_new_rec_ub)
+            self.tik_instance.data_move(dst=self.O_gm[o_gm_offset], src=Pij_Vj_matmul_res_ub, sid=0,
+                                        nburst=self.N1, burst=block_h * self.N0 // 16,
+                                        src_stride=0, dst_stride=(self.Nq - block_h_aligned) * self.N0 // 16)
+
         with self.tik_instance.else_scope():
             mi_ub = self.tik_instance.Tensor(FP16, (block_h_aligned,), name="mi_old_ub", scope=UB)
             li_ub = self.tik_instance.Tensor(FP16, (block_h_aligned,), name="li_ub", scope=UB)
             self.tik_ops_utils.move_vector_from_gm_to_ub(mi_ub, self.m_gm, vec_gm_offset, block_h)
             self.tik_ops_utils.move_vector_from_gm_to_ub(li_ub, self.l_gm, vec_gm_offset, block_h)
-
-            # update l, m
             mi_new_ub, li_new_ub = self.update_m_l(mi_ub, mij_ub, li_ub, lij_ub, block_h)
             self.tik_ops_utils.move_vector_from_ub_to_gm(self.l_gm, li_new_ub, vec_gm_offset, block_h)
             self.tik_ops_utils.move_vector_from_ub_to_gm(self.m_gm, mi_new_ub, vec_gm_offset, block_h)
-
             exp_mi_sub_mi_new = mi_ub
             exp_mij_sub_mi_new = mij_ub
-            # load Oi into UB
-            Oi_ub = self.tik_instance.Tensor(FP16, (block_h_aligned, self.d), scope=UB, name="Oi_ub")
-            self.cont_data_mv_1_bust(dst=Oi_ub, src=self.O_gm[o_gm_offset],
-                                     burst=block_h * self.d // 16)

             li_new_rec_ub = self.tik_ops_utils.calc_vec_rec(li_new_ub, block_h)
-
-            self.update_Oi(
-                Oi_ub,
-                exp_mi_sub_mi_new,
-                Pij_Vj_matmul_res_ub,
-                exp_mij_sub_mi_new,
-                li_new_rec_ub,
-                li_ub,
-                o_gm_offset,
-                block_h
-            )
+            self.tik_instance.h_mul(li_ub, li_ub, exp_mi_sub_mi_new)
+            self.tik_instance.h_mul(li_ub, li_ub, li_new_rec_ub)
+            scale1 = li_ub
+            self.tik_instance.h_mul(exp_mij_sub_mi_new, exp_mij_sub_mi_new, li_new_rec_ub)
+            scale2 = exp_mij_sub_mi_new
+            Oi_ub = self.tik_instance.Tensor(FP16, (n1, m, n0), name="Oi_ub", scope=UB)
+            self.tik_instance.data_move(dst=Oi_ub, src=self.O_gm[o_gm_offset],
+                                        sid=0, nburst=self.N1, burst=m * self.N0 // 16,
+                                        src_stride=(self.Nq - m) * self.N0 // 16, dst_stride=0)
+            broadcast_scale1 = self.tik_ops_utils.broadcast(scale1, (m, n0))
+            broadcast_scale1 = broadcast_scale1.reshape((1, m, n0))
+            with self.tik_instance.for_range(0, n1) as idx:
+                self.tik_instance.h_mul(Oi_ub[idx, :, :], Oi_ub[idx, :, :], broadcast_scale1)
+            broadcast_scale2 = self.tik_ops_utils.broadcast(scale2, (m, n0))
+            broadcast_scale2 = broadcast_scale2.reshape((1, m, n0))
+            with self.tik_instance.for_range(0, n1) as idx:
+                self.tik_instance.h_mul(Pij_Vj_matmul_res_ub[idx, :, :],
+                                        Pij_Vj_matmul_res_ub[idx, :, :],
+                                        broadcast_scale2)
+            self.tik_instance.h_add(Oi_ub, Oi_ub, Pij_Vj_matmul_res_ub)
+            self.tik_instance.data_move(dst=self.O_gm[o_gm_offset], src=Oi_ub, sid=0,
+                                        nburst=self.N1, burst=block_h * self.N0 // 16,
+                                        src_stride=0, dst_stride=(self.Nq - block_h_aligned) * self.N0 // 16)

     def compute_in_each_kv_block(self, batch_start, batch_idx, kv_blk_idx, kv_blk_height,
                                  core_idx_to_tr_info, core_idx):
         """The forward computation in each outer loop"""
         kv_blk_height_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
-        # load Kj (kv_blk_idx_th block of K_gm), then reorder it for Q*KjT
-        Kj_l1 = self.tik_instance.Tensor(FP16, (kv_blk_height_aligned, self.d), name="Kj_l1", scope=L1)
-        kv_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.N, self.d, self.Bc,
-                                          kv_blk_idx)
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            Kj_ub = self.tik_instance.Tensor(FP16, (kv_blk_height_aligned, self.d), name="Kj_ub", scope=UB)
-            self.cont_data_mv_1_bust(dst=Kj_ub, src=self.K_gm[kv_gm_offset],
-                                     burst=kv_blk_height * self.d // 16)
-            KjT_l1_K1MK0_ed = self.tik_ops_utils.MK_TO_K1MK0(Kj_ub, workspace_tensor=Kj_l1)
+        kv_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.N, self.d, self.Bc, kv_blk_idx)
+        # load Kj (kv_blk_idx_th block of K_gm)
+        KjT_l1_K1MK0_ed = self.tik_instance.Tensor(FP16, (self.d // self.N0, kv_blk_height_aligned, self.N0),
+                                                   name="KjT_l1_K1MK0_ed", scope=L1)
+        self.tik_instance.data_move(dst=KjT_l1_K1MK0_ed, src=self.K_gm[kv_gm_offset],
+                                    sid=0, nburst=self.N1, burst=kv_blk_height_aligned * self.N0 // 16,
+                                    src_stride=(self.N - kv_blk_height_aligned) * self.N0 // 16, dst_stride=0)

         # load Vj (kv_blk_idx_th block of V_gm), then reorder for Pij*Vj
         Vj_l1 = self.tik_instance.Tensor(FP16, (kv_blk_height_aligned, self.d), name="Vj_l1", scope=L1)
         with self.tik_instance.new_stmt_scope(disable_sync=False):
-            Vj_ub = self.tik_instance.Tensor(FP16, (kv_blk_height_aligned, self.d), name="Vj_ub", scope=UB)
-            self.cont_data_mv_1_bust(dst=Vj_ub, src=self.V_gm[kv_gm_offset],
-                                     burst=kv_blk_height * self.d // 16)
+            Vj_ub = self.tik_instance.Tensor(FP16, (self.d // self.N0, kv_blk_height_aligned, self.N0),
+                                             name="Vj_ub", scope=UB)
+            self.tik_instance.data_move(dst=Vj_ub, src=self.V_gm[kv_gm_offset],
+                                        sid=0, nburst=self.N1, burst=kv_blk_height_aligned * self.N0 // 16,
+                                        src_stride=(self.N - kv_blk_height_aligned) * self.N0 // 16, dst_stride=0)
+            # (N1, K, N0) -> (K, N)
+            Vj_ub = self.tik_ops_utils.N1MN0_TO_MN(Vj_ub)
+            # (K, N) -> (K1, N, K0)
             Vj_l1_K1NK0_ed = self.tik_ops_utils.KN_TO_K1NK0(Vj_ub, workspace_tensor=Vj_l1)

         tr_start_s = self.tik_instance.Scalar("int32", name="tr_start_s")
@@ -413,46 +393,51 @@ class FlashAttentionFwd(FlashAttention):
         kv_blk_h_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
         q_blk_h_aligned = self.tik_ops_utils.up_align_to_K0(q_blk_height)
         # load Qi (q_blk_idx_th block of Q_gm), then reorder it fo Qi*KjT
-        Qi_l1 = self.tik_instance.Tensor(FP16, (q_blk_h_aligned, self.d), scope=L1, name="Qi_l1")
         q_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.Nq, self.d, self.Br, q_blk_idx)
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            Qi_ub = self.tik_instance.Tensor(FP16, (q_blk_h_aligned, self.d), scope=UB, name="Qi_ub")
-            self.cont_data_mv_1_bust(dst=Qi_ub, src=self.Q_gm[q_gm_offset],
-                                     burst=q_blk_height * self.d // 16)
-            Qi_l1_K1MK0_ed = self.tik_ops_utils.MK_TO_K1MK0(Qi_ub, workspace_tensor=Qi_l1)
+        Qi_l1_K1MK0_ed = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_h_aligned, self.N0),
+                                                  scope=L1, name="Qi_l1_K1MK0_ed")
+        self.tik_instance.data_move(dst=Qi_l1_K1MK0_ed, src=self.Q_gm[q_gm_offset],
+                                    sid=0, nburst=self.N1, burst=q_blk_h_aligned * self.N0 // 16,
+                                    src_stride=(self.Nq - q_blk_h_aligned) * self.N0 // 16, dst_stride=0)

         lij_ub = self.tik_instance.Tensor(self.precision_type, (q_blk_h_aligned,), scope=UB, name="lij_ub")
         mij_ub = self.tik_instance.Tensor(FP16, (q_blk_h_aligned,), scope=UB, name="mij_ub")
-        # QK^T Q shape: (q_blk_h_aligned, self.d), K^T shape: (self.d, kv_blk_h_aligned)
-        Sij_ub_MN_ed = self.tik_ops_utils.matmul_compute(Qi_l1_K1MK0_ed, KjT_l1_K1MK0_ed, m=q_blk_height,
+
+        Sij_ub_N1MN0 = self.tik_ops_utils.matmul_compute(Qi_l1_K1MK0_ed, KjT_l1_K1MK0_ed, m=q_blk_height,
                                                          k=self.actual_d, n=kv_blk_height,
-                                                         N1MN0_to_MN=True)  # Qi*KjT
+                                                         N1MN0_to_MN=False)  # Qi*KjT
         if self.has_alibi_mask:
             alibi_mask_gm_offset = self.get_alibi_gm_offset(batch_start, batch_idx, self.N, self.Bc, kv_blk_idx)
-            self.do_alibi_mask(Sij_ub_MN_ed, alibi_mask_gm_offset, q_blk_h_aligned, kv_blk_h_aligned)
+            self.do_alibi_mask(Sij_ub_N1MN0, alibi_mask_gm_offset, q_blk_h_aligned, kv_blk_h_aligned)

         # att_mask
         if self.has_attn_mask:
             attn_mask_gm_offset = self.get_attn_mask_gm_offset(batch_start, batch_idx, self.Nq, self.N,
                                                                self.Br, q_blk_idx, self.Bc, kv_blk_idx)
-            self.do_att_mask(Sij_ub_MN_ed, attn_mask_gm_offset, q_blk_height, kv_blk_height,
+            self.do_att_mask(Sij_ub_N1MN0, attn_mask_gm_offset, q_blk_height, kv_blk_height,
                              q_blk_h_aligned, kv_blk_h_aligned)

-        Pij_ub, mij_ub, lij_ub = self.softmax_compute(
-            Sij_ub_MN_ed, mij_ub, lij_ub, q_blk_height, kv_blk_height
+        Pij_N1MN0, mij_ub, lij_ub = self.softmax_compute(
+            Sij_ub_N1MN0, mij_ub, lij_ub, q_blk_height, kv_blk_height
         )  # self.high_precision=True, Pij_ub return type fp32
         # dropout_mask
         if self.has_drop_mask:
             dropout_mask_gm_offset = self.get_drop_mask_gm_offset(batch_start, batch_idx, self.Nq,
-                                                                  self.N, self.Br, q_blk_idx, self.Bc,
-                                                                  kv_blk_idx)
-            self.do_dropout_mask(Pij_ub, dropout_mask_gm_offset, kv_blk_h_aligned, kv_blk_height,
+                                                                  self.N, self.Br, q_blk_idx, self.Bc, kv_blk_idx)
+            self.do_dropout_mask(Pij_N1MN0, dropout_mask_gm_offset, kv_blk_h_aligned, kv_blk_height,
                                  q_blk_h_aligned, q_blk_height, precision_type=self.precision_type)
+        if not self.high_precision:
+            Pij_l1_K1MK0_ed = self.tik_instance.Tensor(FP16,
+                                                       (kv_blk_h_aligned // self.N0, q_blk_h_aligned, self.N0),
+                                                       name="Pij_l1_K1MK0_ed", scope=L1)
+            self.cont_data_mv_1_bust(dst=Pij_l1_K1MK0_ed, src=Pij_N1MN0,
+                                     burst=kv_blk_h_aligned * q_blk_h_aligned // 16)
+            Pij_N1MN0 = Pij_l1_K1MK0_ed
         if self.high_precision:
             self.update_o_m_l_fp32(
-                Pij_ub,
+                Pij_N1MN0,
                 Vj_l1_K1NK0_ed,
-                Sij_ub_MN_ed,
+                Sij_ub_N1MN0,
                 mij_ub,
                 lij_ub,
                 batch_start,
@@ -464,7 +449,7 @@ class FlashAttentionFwd(FlashAttention):
             )
         else:
             self.update_o_m_l(
-                Pij_ub,
+                Pij_N1MN0,
                 Vj_l1_K1NK0_ed,
                 mij_ub,
                 lij_ub,
@@ -523,7 +508,7 @@ class FlashAttentionFwd(FlashAttention):
        """collect all input gm tensors into input_gm_list,
        the input list should keep order with the para order in Primitive and init
        """
-        input_gm_list = [self.Q_gm, self.K_gm, self.V_gm, self.dim_mask_gm]
+        input_gm_list = [self.Q_gm, self.K_gm, self.V_gm]
         if self.has_attn_mask:
             input_gm_list.append(self.att_mask_gm)
         if self.has_drop_mask:
@@ -537,10 +522,11 @@ class FlashAttentionFwd(FlashAttention):
        """collect all output gm tensors into output_gm_list,
        the output list should keep order with the para order in Primitive and init
        """
-        return [self.O_gm, self.l_gm, self.m_gm]
+        output_gm_list = [self.O_gm, self.l_gm, self.m_gm]
+        return output_gm_list


-def flash_attention(query, key, value, dim_mask, attn_mask, dropout_mask, alibi_mask, output, rowsum, rowmax,
+def flash_attention(query, key, value, attn_mask, dropout_mask, alibi_mask, output, rowsum, rowmax,
                    prev_block_num=65536, next_block_num=65536, high_precision=False, tiling_stgy_name='sparse',
                    kernel_name="flash_attention", disable_debug=True):
    """
@@ -551,7 +537,6 @@ def flash_attention(query, key, value, dim_mask, attn_mask, dropout_mask, alibi_
    query : dict. shape and dtype of input, only support float16
    key : dict. shape and dtype of input, only support float16
    value: dict. shape and dtype of input, only support float16
-    dim_mask: dict. shape and dtype of input, only support int8
    attn_mask: dict. shape and dtype of input, only support float16
    dropout_mask: dict. shape and dtype of input, only support float16
    dropout_mask: dict. shape and dtype of input, only support float16
@@ -569,7 +554,7 @@ def flash_attention(query, key, value, dim_mask, attn_mask, dropout_mask, alibi_
    -------
    tik_instance
    """
-    fa = FlashAttentionFwd(query=query, key=key, value=value, dim_mask=dim_mask, attn_mask=attn_mask,
+    fa = FlashAttentionFwd(query=query, key=key, value=value, attn_mask=attn_mask,
                           dropout_mask=dropout_mask, alibi_mask=alibi_mask, kernel_name=kernel_name,
                           tiling_stgy=TilingStrategy.from_strategy_name(tiling_stgy_name),
                           prev_block_num=prev_block_num, next_block_num=next_block_num,