mindspore 2.1.0__cp39-none-any.whl → 2.2.10__cp39-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (569)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-39-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  196. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  197. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  198. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  199. mindspore/nn/__init__.py +0 -2
  200. mindspore/nn/cell.py +313 -74
  201. mindspore/nn/dynamic_lr.py +21 -21
  202. mindspore/nn/layer/activation.py +22 -30
  203. mindspore/nn/layer/basic.py +15 -13
  204. mindspore/nn/layer/channel_shuffle.py +1 -1
  205. mindspore/nn/layer/container.py +271 -9
  206. mindspore/nn/layer/conv.py +323 -204
  207. mindspore/nn/layer/dense.py +8 -5
  208. mindspore/nn/layer/embedding.py +33 -27
  209. mindspore/nn/layer/flash_attention.py +141 -88
  210. mindspore/nn/layer/image.py +8 -6
  211. mindspore/nn/layer/math.py +16 -25
  212. mindspore/nn/layer/normalization.py +107 -66
  213. mindspore/nn/layer/padding.py +1 -1
  214. mindspore/nn/layer/pooling.py +131 -109
  215. mindspore/nn/layer/rnn_cells.py +27 -22
  216. mindspore/nn/layer/rnns.py +13 -16
  217. mindspore/nn/layer/thor_layer.py +1 -1
  218. mindspore/nn/layer/transformer.py +221 -154
  219. mindspore/nn/learning_rate_schedule.py +9 -1
  220. mindspore/nn/loss/loss.py +235 -174
  221. mindspore/nn/optim/ada_grad.py +2 -1
  222. mindspore/nn/optim/adadelta.py +1 -0
  223. mindspore/nn/optim/adafactor.py +2 -1
  224. mindspore/nn/optim/adam.py +7 -4
  225. mindspore/nn/optim/adamax.py +3 -2
  226. mindspore/nn/optim/adasum.py +2 -2
  227. mindspore/nn/optim/asgd.py +2 -3
  228. mindspore/nn/optim/ftrl.py +6 -5
  229. mindspore/nn/optim/lamb.py +7 -4
  230. mindspore/nn/optim/lars.py +1 -1
  231. mindspore/nn/optim/lazyadam.py +5 -3
  232. mindspore/nn/optim/momentum.py +2 -1
  233. mindspore/nn/optim/optimizer.py +53 -4
  234. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  235. mindspore/nn/optim/rmsprop.py +4 -3
  236. mindspore/nn/optim/rprop.py +23 -12
  237. mindspore/nn/optim/sgd.py +26 -11
  238. mindspore/nn/optim/thor.py +9 -7
  239. mindspore/nn/probability/bijector/bijector.py +5 -5
  240. mindspore/nn/probability/bijector/power_transform.py +27 -27
  241. mindspore/nn/probability/bijector/softplus.py +3 -3
  242. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  243. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  244. mindspore/nn/probability/distribution/beta.py +3 -3
  245. mindspore/nn/probability/distribution/categorical.py +7 -7
  246. mindspore/nn/probability/distribution/cauchy.py +0 -1
  247. mindspore/nn/probability/distribution/distribution.py +3 -3
  248. mindspore/nn/probability/distribution/gamma.py +3 -3
  249. mindspore/nn/probability/distribution/geometric.py +4 -4
  250. mindspore/nn/probability/distribution/gumbel.py +4 -4
  251. mindspore/nn/probability/distribution/log_normal.py +2 -2
  252. mindspore/nn/probability/distribution/logistic.py +2 -2
  253. mindspore/nn/probability/distribution/poisson.py +4 -4
  254. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  255. mindspore/nn/probability/distribution/uniform.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +84 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  274. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  275. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  276. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  277. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  278. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  279. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  280. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  281. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  283. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  284. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  285. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  286. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  287. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  288. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  289. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  290. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  291. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  292. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  293. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  294. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  295. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  297. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  298. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  299. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  300. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  301. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  302. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  303. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  304. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  305. mindspore/ops/_primitive_cache.py +1 -1
  306. mindspore/ops/_tracefunc.py +45 -13
  307. mindspore/ops/_utils/utils.py +6 -1
  308. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  309. mindspore/ops/_vmap/vmap_base.py +3 -3
  310. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  311. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  312. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  313. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  314. mindspore/ops/arg_dtype_cast.py +54 -0
  315. mindspore/ops/composite/base.py +37 -10
  316. mindspore/ops/composite/math_ops.py +5 -4
  317. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  318. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  319. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  320. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  321. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  324. mindspore/ops/deprecated.py +304 -0
  325. mindspore/ops/function/__init__.py +4 -1
  326. mindspore/ops/function/array_func.py +174 -193
  327. mindspore/ops/function/clip_func.py +81 -13
  328. mindspore/ops/function/debug_func.py +1 -1
  329. mindspore/ops/function/grad/grad_func.py +18 -9
  330. mindspore/ops/function/image_func.py +10 -4
  331. mindspore/ops/function/linalg_func.py +5 -5
  332. mindspore/ops/function/math_func.py +575 -386
  333. mindspore/ops/function/nn_func.py +568 -260
  334. mindspore/ops/function/random_func.py +88 -57
  335. mindspore/ops/function/sparse_func.py +1 -1
  336. mindspore/ops/function/sparse_unary_func.py +14 -12
  337. mindspore/ops/function/vmap_func.py +6 -5
  338. mindspore/ops/functional.py +15 -10
  339. mindspore/ops/op_info_register.py +244 -25
  340. mindspore/ops/operations/__init__.py +28 -19
  341. mindspore/ops/operations/_grad_ops.py +72 -7
  342. mindspore/ops/operations/_inner_ops.py +350 -17
  343. mindspore/ops/operations/_quant_ops.py +4 -8
  344. mindspore/ops/operations/_sequence_ops.py +42 -0
  345. mindspore/ops/operations/array_ops.py +68 -282
  346. mindspore/ops/operations/comm_ops.py +107 -59
  347. mindspore/ops/operations/custom_ops.py +94 -70
  348. mindspore/ops/operations/debug_ops.py +8 -4
  349. mindspore/ops/operations/image_ops.py +18 -12
  350. mindspore/ops/operations/inner_ops.py +26 -3
  351. mindspore/ops/operations/math_ops.py +189 -141
  352. mindspore/ops/operations/nn_ops.py +794 -489
  353. mindspore/ops/operations/other_ops.py +0 -22
  354. mindspore/ops/operations/random_ops.py +53 -111
  355. mindspore/ops/operations/sparse_ops.py +3 -1
  356. mindspore/ops/primitive.py +24 -18
  357. mindspore/parallel/_auto_parallel_context.py +68 -8
  358. mindspore/parallel/_cost_model_context.py +2 -2
  359. mindspore/parallel/_offload_context.py +17 -3
  360. mindspore/parallel/_parallel_serialization.py +12 -5
  361. mindspore/parallel/_ps_context.py +12 -0
  362. mindspore/parallel/_tensor.py +18 -13
  363. mindspore/parallel/_transformer/layers.py +5 -3
  364. mindspore/parallel/_transformer/loss.py +1 -0
  365. mindspore/parallel/_transformer/moe.py +2 -2
  366. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  367. mindspore/parallel/_transformer/transformer.py +23 -3
  368. mindspore/parallel/_utils.py +11 -7
  369. mindspore/parallel/algo_parameter_config.py +85 -5
  370. mindspore/parallel/checkpoint_transform.py +19 -12
  371. mindspore/parallel/shard.py +21 -14
  372. mindspore/profiler/common/struct_type.py +3 -3
  373. mindspore/profiler/common/util.py +4 -2
  374. mindspore/profiler/envprofiling.py +1 -1
  375. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  376. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  377. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  378. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  379. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  380. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  381. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  382. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  383. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  384. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  385. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  386. mindspore/profiler/parser/flops_parser.py +15 -11
  387. mindspore/profiler/parser/framework_parser.py +38 -22
  388. mindspore/profiler/parser/hccl_parser.py +16 -12
  389. mindspore/profiler/parser/integrator.py +22 -11
  390. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  391. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  392. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  393. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  394. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  395. mindspore/profiler/parser/optime_parser.py +1 -1
  396. mindspore/profiler/parser/profiler_info.py +21 -2
  397. mindspore/profiler/parser/step_trace_parser.py +11 -14
  398. mindspore/profiler/profiling.py +179 -89
  399. mindspore/rewrite/api/node.py +102 -19
  400. mindspore/rewrite/api/node_type.py +5 -1
  401. mindspore/rewrite/api/pattern_engine.py +1 -1
  402. mindspore/rewrite/api/scoped_value.py +9 -17
  403. mindspore/rewrite/api/symbol_tree.py +131 -47
  404. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  405. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  406. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  407. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  408. mindspore/rewrite/common/rewrite_elog.py +5 -1
  409. mindspore/rewrite/namer.py +33 -24
  410. mindspore/rewrite/namespace.py +14 -5
  411. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  412. mindspore/rewrite/node/call_function.py +79 -0
  413. mindspore/rewrite/node/cell_container.py +135 -0
  414. mindspore/rewrite/node/control_flow.py +88 -0
  415. mindspore/rewrite/{node.py → node/node.py} +273 -234
  416. mindspore/rewrite/node/node_manager.py +254 -0
  417. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  418. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  419. mindspore/rewrite/parsers/assign_parser.py +216 -221
  420. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  421. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  422. mindspore/rewrite/parsers/constant_parser.py +9 -6
  423. mindspore/rewrite/parsers/container_parser.py +9 -7
  424. mindspore/rewrite/parsers/for_parser.py +36 -15
  425. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  426. mindspore/rewrite/parsers/if_parser.py +28 -24
  427. mindspore/rewrite/parsers/module_parser.py +196 -25
  428. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  429. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  430. mindspore/rewrite/parsers/return_parser.py +6 -6
  431. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  432. mindspore/rewrite/sparsify/utils.py +1 -1
  433. mindspore/rewrite/symbol_tree.py +523 -578
  434. mindspore/rewrite/symbol_tree_builder.py +9 -193
  435. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  436. mindspore/run_check/_check_version.py +6 -4
  437. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  438. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  439. mindspore/scipy/linalg.py +1 -1
  440. mindspore/scipy/optimize/minimize.py +7 -3
  441. mindspore/train/_utils.py +7 -3
  442. mindspore/train/amp.py +323 -123
  443. mindspore/train/anf_ir_pb2.py +14 -2
  444. mindspore/train/callback/_backup_and_restore.py +2 -12
  445. mindspore/train/callback/_callback.py +29 -4
  446. mindspore/train/callback/_checkpoint.py +23 -8
  447. mindspore/train/callback/_early_stop.py +2 -2
  448. mindspore/train/callback/_landscape.py +4 -4
  449. mindspore/train/callback/_loss_monitor.py +2 -2
  450. mindspore/train/callback/_on_request_exit.py +2 -2
  451. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  452. mindspore/train/callback/_summary_collector.py +15 -8
  453. mindspore/train/callback/_time_monitor.py +58 -5
  454. mindspore/train/data_sink.py +5 -11
  455. mindspore/train/dataset_helper.py +84 -57
  456. mindspore/train/loss_scale_manager.py +2 -2
  457. mindspore/train/metrics/__init__.py +3 -3
  458. mindspore/train/metrics/cosine_similarity.py +1 -1
  459. mindspore/train/metrics/hausdorff_distance.py +3 -2
  460. mindspore/train/metrics/mean_surface_distance.py +3 -2
  461. mindspore/train/metrics/metric.py +39 -19
  462. mindspore/train/metrics/roc.py +2 -2
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  464. mindspore/train/mind_ir_pb2.py +85 -36
  465. mindspore/train/model.py +187 -47
  466. mindspore/train/serialization.py +487 -161
  467. mindspore/train/summary/_summary_adapter.py +1 -1
  468. mindspore/train/summary/_writer_pool.py +3 -2
  469. mindspore/train/summary/summary_record.py +37 -17
  470. mindspore/train/train_thor/convert_utils.py +3 -3
  471. mindspore/train/train_thor/dataset_helper.py +1 -1
  472. mindspore/version.py +1 -1
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  474. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +477 -517
  475. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  476. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  477. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  478. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  479. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  480. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  481. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  482. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  483. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  484. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  485. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  486. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  487. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  488. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  489. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  490. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  491. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  492. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  493. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  494. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  495. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  496. mindspore/_extends/graph_kernel/expander.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  498. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  499. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  500. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  501. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  502. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  503. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  504. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  505. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  507. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  508. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  509. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  510. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  511. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  514. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  516. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  517. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  518. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  519. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  520. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  521. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  522. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  523. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  524. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  526. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  527. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  528. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  529. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  531. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  533. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  534. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  535. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  536. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  537. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  538. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  539. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  540. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  541. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  542. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  543. mindspore/dataset/datapreprocess/__init__.py +0 -20
  544. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  545. mindspore/include/api/net.h +0 -142
  546. mindspore/nn/lr_scheduler.py +0 -262
  547. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  548. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  549. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  550. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  551. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  552. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  553. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  554. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  555. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  556. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  559. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  567. mindspore/rewrite/node_visitor.py +0 -44
  568. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  569. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -25,7 +25,6 @@ from tbe.common.platform import get_soc_spec
25
25
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
26
26
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
27
27
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import GM
28
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import INT8
29
28
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import MASK_FILL_VALUE
30
29
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
31
30
  from mindspore.ops._op_impl._custom_op.flash_attention.tik_ops_utils import TikOpsUtils
@@ -37,7 +36,7 @@ from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.sparse_ti
37
36
  class FlashAttention(metaclass=ABCMeta):
38
37
  """The base class of FlashAttention"""
39
38
 
40
- def __init__(self, q, k, v, dim_mask, attn_mask, dropout_mask, alibi_mask, kernel_name,
39
+ def __init__(self, q, k, v, attn_mask, dropout_mask, alibi_mask, kernel_name,
41
40
  tiling_stgy_cls,
42
41
  prev_block_num=65536,
43
42
  next_block_num=65536,
@@ -48,7 +47,6 @@ class FlashAttention(metaclass=ABCMeta):
48
47
  :param q: with shape: (B, h, N, d)
49
48
  :param k: with shape: (B, h, N, d)
50
49
  :param v: with shape: (B, h, N, d)
51
- :param dim_mask: with shape: (x, ) x equals length last dim that not padded to 16.
52
50
  :param attn_mask: with shape: (1, N, N) or (B, N, N)
53
51
  :param dropout_mask: with shape: (B, h, N, N)
54
52
  :param alibi_mask: with shape: (B, h, 1, N)
@@ -66,19 +64,25 @@ class FlashAttention(metaclass=ABCMeta):
66
64
  src_stride=0,
67
65
  dst_stride=0)
68
66
  self.tik_ops_utils = TikOpsUtils(self.tik_instance)
69
- self.parser_input_shape(alibi_mask, attn_mask, dim_mask, dropout_mask, k, q, v)
70
-
71
- batch_size, h, Nq, d = self.q_shape
67
+ self.parse_input_shape(alibi_mask, attn_mask, dropout_mask, k, q, v)
68
+ # NZ
69
+ _, _, N1, M1, M0, N0 = self.q_shape
70
+ self.M1 = M1
71
+ self.N1 = N1
72
+ self.M0 = M0
73
+ self.N0 = N0
74
+ self.d = N1 * N0
75
+ # ND
76
+ batch_size, h, Nq, actual_d = self.q_ori_shape
72
77
  self.head_num = h
73
- self.B, self.Nq, self.d = batch_size * h, Nq, d
74
- self.N = self.k_shape[2]
78
+ self.B, self.Nq = batch_size * h, Nq
79
+ self.N = self.k_ori_shape[2]
80
+ self.actual_d = actual_d
75
81
 
76
82
  self.l_shape = [batch_size, h, self.Nq]
77
83
  self.m_shape = [batch_size, h, self.Nq]
78
- self.O_shape = [batch_size, h, self.Nq, self.d]
79
- self.actual_d = self.dim_mask_shape[0]
84
+ self.O_shape = self.q_shape
80
85
 
81
- self.K0 = 16
82
86
  self.prev_block_num = prev_block_num
83
87
  self.next_block_num = next_block_num
84
88
  self.high_precision = high_precision
@@ -86,7 +90,6 @@ class FlashAttention(metaclass=ABCMeta):
86
90
  self.precision_type = FP32
87
91
  else:
88
92
  self.precision_type = FP16
89
-
90
93
  if tiling_stgy_cls is None:
91
94
  self.tiling_stgy = SparseTiling(self.Nq, self.N, self.d)
92
95
  else:
@@ -106,9 +109,9 @@ class FlashAttention(metaclass=ABCMeta):
106
109
  self.alibi_mask_gm = None
107
110
 
108
111
  @staticmethod
109
- def get_gm_offset(batch_start, batch_idx, h, w, block_h, block_idx):
110
- """get gm offset"""
111
- gm_offset = (batch_start + batch_idx) * h * w + block_idx * block_h * w
112
+ def get_l_m_gm_offset(batch_start, batch_idx, h, block_h, block_idx):
113
+ """get l m gm offset"""
114
+ gm_offset = (batch_start + batch_idx) * h + block_idx * block_h
112
115
  return gm_offset
113
116
 
114
117
  @staticmethod
@@ -143,108 +146,159 @@ class FlashAttention(metaclass=ABCMeta):
143
146
  """collect outputs"""
144
147
  raise NotImplementedError
145
148
 
146
- def get_core_bath_info(self):
147
- """
148
- Get batch start and batch number of each NPU core.
149
- :return: Tensor([[core_1_batch_start, core_1_batch_num],...,[core_m_batch_start,
150
- core_m_batch_num]]), Tensor([[[core_1_batch_1_Tr_start, core_1_batch_1_Tr_end],...[core_1_batch_n_Tr_start,
151
- core_1_batch_n_Tr_end]],...,[[core_m_batch_1_Tr_start, core_m_batch_1_Tr_end],...[core_m_batch_n_Tr_start,
152
- core_m_batch_n_Tr_end]]
153
- """
154
- if self.core_num > self.B * self.Tr:
155
- self.core_num = self.B * self.Tr
156
-
157
- task_idx_to_batch_tr_idx = dict()
158
- for task_idx in range(self.B * self.Tr):
159
- batch_idx = task_idx // self.Tr
160
- tr_idx = task_idx % self.Tr
161
- task_idx_to_batch_tr_idx[task_idx] = [batch_idx, tr_idx]
149
+ @abstractmethod
150
+ def compute_one_core(self, batch_start_s, batch_num_s, core_idx_to_tr_info, core_idx):
151
+ """compute one core"""
152
+ raise NotImplementedError
162
153
 
163
- core_idx_to_batch_idx = defaultdict(lambda: [100000, -1])
164
- core_idx_to_tr_idx = defaultdict(lambda: defaultdict(lambda: [100000, -1]))
165
- task_start = 0
166
- avg_task_num_per_core, remain_task = divmod(self.B * self.Tr, self.core_num)
154
+ @abstractmethod
155
+ def prepare_global_ones(self):
156
+ """prepare global ones"""
157
+ raise NotImplementedError
167
158
 
168
- for core_idx in range(self.core_num):
169
- cur_core_task_num = avg_task_num_per_core
170
- if core_idx < remain_task:
171
- cur_core_task_num += 1
172
- task_end = task_start + cur_core_task_num
173
- for task_idx in range(task_start, task_end):
174
- try:
175
- batch_idx, tr_idx = task_idx_to_batch_tr_idx[task_idx]
176
- except KeyError:
177
- raise ValueError("The argument 'task_idx' is not valid.")
178
- batch_start_end_pair = core_idx_to_batch_idx[core_idx]
179
- if batch_idx < batch_start_end_pair[0]:
180
- batch_start_end_pair[0] = batch_idx
181
- if batch_idx > batch_start_end_pair[1]:
182
- batch_start_end_pair[1] = batch_idx
183
- tr_start_end_pair = core_idx_to_tr_idx[core_idx][batch_idx]
184
- if tr_idx < tr_start_end_pair[0]:
185
- tr_start_end_pair[0] = tr_idx
186
- if tr_idx > tr_start_end_pair[1]:
187
- tr_start_end_pair[1] = tr_idx
188
- task_start = task_end
159
+ def get_gm_offset(self, batch_start, batch_idx, h, w, block_h, block_idx):
160
+ """get gm offset"""
161
+ gm_offset = (batch_start + batch_idx) * h * w + block_idx * block_h * self.N0
162
+ return gm_offset
189
163
 
164
+ def get_cur_tr_block_num(self, tr_idx):
165
+ """get cur tr block_num"""
166
+ cur_prev_block_num = min(tr_idx, self.prev_block_num)
167
+ cur_next_block_num = min(self.next_block_num, self.Tc - tr_idx - 1)
168
+ block_num = cur_prev_block_num + 1 + cur_next_block_num
169
+ return block_num
170
+
171
+ def get_total_block_num(self):
172
+ """get total block num"""
173
+ block_num = 0
174
+ for b_idx in range(self.B):
175
+ for tr_idx in range(self.Tr):
176
+ block_num += self.get_cur_tr_block_num(tr_idx)
177
+ return block_num
178
+
179
+ def update_core_task_map(self,
180
+ core_b_map,
181
+ core_b_tr_map,
182
+ core_idx,
183
+ b_start,
184
+ b_end,
185
+ tr_start,
186
+ tr_end):
187
+ """update core task map"""
188
+ core_b_map[core_idx][0] = min(core_b_map[core_idx][0], b_start)
189
+ if tr_end == 0: # 跨head,但跨过的head不会被当前的core处理
190
+ core_b_map[core_idx][1] = max(core_b_map[core_idx][1], b_end - 1)
191
+ else:
192
+ core_b_map[core_idx][1] = max(core_b_map[core_idx][1], b_end)
193
+ for b_idx in range(b_start, b_end + 1):
194
+ if b_idx == b_end and tr_end == 0: # 跨head,但跨过的head不会被当前的core处理
195
+ break
196
+ elif b_idx == b_start and b_idx == b_end: # 没跨head
197
+ core_b_tr_map[core_idx][b_idx] = (tr_start, tr_end)
198
+ elif b_idx == b_start: # 跨head,第一个head
199
+ core_b_tr_map[core_idx][b_idx] = (tr_start, self.Tr)
200
+ elif b_idx == b_end: # 跨head,最后一个head
201
+ core_b_tr_map[core_idx][b_idx] = (0, tr_end)
202
+ else: # 跨head,中间的head
203
+ core_b_tr_map[core_idx][b_idx] = (0, self.Tr)
204
+
205
+ def convert_py_dict_to_tik_tensor(self, core_b_map, core_b_tr_map):
206
+ """convert py dict to tik tensor"""
207
+ # python dict -> tik tensor
208
+ # [batch_start, batch_idx_end] -> [batch_start, batch_num]
209
+ # [tr_start, tr_idx_end] -> [tr_start, tr_idx_end)
190
210
  core_idx_to_batch_info = self.tik_instance.Tensor(
191
211
  "int32", (self.core_num, 2), name="core_idx_to_batch_info", scope=UB
192
212
  )
193
213
  core_idx_to_tr_info = self.tik_instance.Tensor(
194
214
  "int32", (self.core_num, self.B, 2), name="core_idx_to_tr_info", scope=UB
195
215
  )
196
- for core_idx in core_idx_to_batch_idx:
197
- batch_start, batch_end = core_idx_to_batch_idx[core_idx]
216
+ for core_idx in core_b_map.keys():
217
+ batch_start, batch_end = core_b_map[core_idx]
198
218
  core_idx_to_batch_info[core_idx, 0] = batch_start
199
219
  core_idx_to_batch_info[core_idx, 1] = batch_end - batch_start + 1
200
- for batch_idx in core_idx_to_tr_idx[core_idx]:
201
- tr_start, tr_end = core_idx_to_tr_idx[core_idx][batch_idx]
220
+ for batch_idx in core_b_tr_map[core_idx].keys():
221
+ tr_start, tr_end = core_b_tr_map[core_idx][batch_idx]
202
222
  core_idx_to_tr_info[core_idx, batch_idx, 0] = tr_start
203
- core_idx_to_tr_info[core_idx, batch_idx, 1] = tr_end + 1
223
+ core_idx_to_tr_info[core_idx, batch_idx, 1] = tr_end
224
+
225
+ return core_idx_to_batch_info, core_idx_to_tr_info
226
+
227
+ def get_core_task_info(self):
228
+ """
229
+ Get batch start and batch number of each NPU core.
230
+ :return: Tensor([[core_1_batch_start, core_1_batch_num],...,[core_m_batch_start,
231
+ core_m_batch_num]]), Tensor([[[core_1_batch_1_Tr_start, core_1_batch_1_Tr_end],...[core_1_batch_n_Tr_start,
232
+ core_1_batch_n_Tr_end]],...,[[core_m_batch_1_Tr_start, core_m_batch_1_Tr_end],...[core_m_batch_n_Tr_start,
233
+ core_m_batch_n_Tr_end]]
234
+ """
235
+ if self.core_num > self.B * self.Tr:
236
+ self.core_num = self.B * self.Tr
237
+
238
+ total_blk_num = self.get_total_block_num()
239
+ b_start = 0
240
+ tr_start = 0
241
+ remain_blk_num = total_blk_num
242
+ core_b_map = defaultdict(lambda: [100000, -1])
243
+ core_b_tr_map = defaultdict(lambda: defaultdict(list))
244
+ for core_idx in range(self.core_num):
245
+ cur_core_blk_num = 0
246
+ cur_each_core_blk_num = remain_blk_num // (self.core_num - core_idx)
247
+ cur_core_finished = False
248
+ b_end = b_start
249
+ tr_end = tr_start
250
+ while b_end < self.B:
251
+ while tr_end < self.Tr:
252
+ cur_tr_blk_num = self.get_cur_tr_block_num(tr_end)
253
+ if abs(cur_core_blk_num - cur_each_core_blk_num) <= \
254
+ (cur_core_blk_num + cur_tr_blk_num - cur_each_core_blk_num):
255
+ self.update_core_task_map(core_b_map, core_b_tr_map, core_idx, b_start, b_end, tr_start, tr_end)
256
+ remain_blk_num -= cur_core_blk_num
257
+ cur_core_finished = True
258
+ break
259
+ else:
260
+ cur_core_blk_num += cur_tr_blk_num
261
+ tr_end += 1
262
+ if tr_end == self.Tr:
263
+ tr_end = 0
264
+ b_end += 1
265
+ if cur_core_finished:
266
+ b_start = b_end
267
+ tr_start = tr_end
268
+ break
269
+ core_idx_to_batch_info, core_idx_to_tr_info = self.convert_py_dict_to_tik_tensor(core_b_map, core_b_tr_map)
204
270
  return core_idx_to_batch_info, core_idx_to_tr_info
205
271
 
206
272
  def get_attn_mask_gm_offset(self, batch_start, batch_idx, h, w, block_h, block_h_idx, block_w, block_w_idx):
207
273
  """get attn mask gm offset"""
208
274
  if self.att_mask_shape[0] == 1:
209
- gm_offset = block_h_idx * (w * block_h) + block_w_idx * block_w
275
+ gm_offset = block_w_idx * (h * block_w) + block_h_idx * block_h * self.N0
210
276
  else:
211
277
  gm_offset = ((batch_start + batch_idx) // self.head_num) * h * w \
212
- + block_h_idx * (w * block_h) + block_w_idx * block_w
278
+ + block_w_idx * (h * block_w) + block_h_idx * block_h * self.N0
213
279
  return gm_offset
214
280
 
215
- def parser_input_shape(self, alibi_mask, attn_mask, dim_mask, dropout_mask, k, q, v):
281
+ def parse_input_shape(self, alibi_mask, attn_mask, dropout_mask, k, q, v):
216
282
  """parser input shape"""
217
283
  self.has_attn_mask = False
218
284
  self.has_drop_mask = False
219
285
  self.has_alibi_mask = False
220
- if isinstance(q, dict):
221
- self.q_shape = q["shape"]
222
- self.k_shape = k["shape"]
223
- self.v_shape = v["shape"]
224
- self.dim_mask_shape = dim_mask["shape"]
225
- if attn_mask is not None:
226
- self.has_attn_mask = True
227
- self.att_mask_shape = attn_mask["shape"]
228
- if dropout_mask is not None:
229
- self.has_drop_mask = True
230
- self.drop_mask_shape = dropout_mask["shape"]
231
- if alibi_mask is not None:
232
- self.has_alibi_mask = True
233
- self.alibi_mask_shape = alibi_mask["shape"]
234
- else:
235
- self.q_shape = q.shape
236
- self.k_shape = k.shape
237
- self.v_shape = v.shape
238
- self.dim_mask_shape = dim_mask.shape
239
- if attn_mask is not None:
240
- self.has_attn_mask = True
241
- self.att_mask_shape = attn_mask.shape
242
- if dropout_mask is not None:
243
- self.has_drop_mask = True
244
- self.drop_mask_shape = dropout_mask.shape
245
- if alibi_mask is not None:
246
- self.has_alibi_mask = True
247
- self.alibi_mask_shape = alibi_mask.shape
286
+ # NZ
287
+ self.q_shape = q["shape"]
288
+ self.k_shape = k["shape"]
289
+ self.v_shape = v["shape"]
290
+ # ND
291
+ self.q_ori_shape = q["ori_shape"]
292
+ self.k_ori_shape = k["ori_shape"]
293
+ if attn_mask is not None:
294
+ self.has_attn_mask = True
295
+ self.att_mask_shape = attn_mask["shape"]
296
+ if dropout_mask is not None:
297
+ self.has_drop_mask = True
298
+ self.drop_mask_shape = dropout_mask["shape"]
299
+ if alibi_mask is not None:
300
+ self.has_alibi_mask = True
301
+ self.alibi_mask_shape = alibi_mask["shape"]
248
302
 
249
303
  def define_inputs_outputs(self):
250
304
  """define inputs outputs"""
@@ -272,8 +326,6 @@ class FlashAttention(metaclass=ABCMeta):
272
326
  self.Q_gm = self.tik_instance.Tensor(FP16, self.q_shape, name="Q_gm", scope=GM)
273
327
  self.K_gm = self.tik_instance.Tensor(FP16, self.k_shape, name="K_gm", scope=GM)
274
328
  self.V_gm = self.tik_instance.Tensor(FP16, self.v_shape, name="V_gm", scope=GM)
275
- self.dim_mask_gm = self.tik_instance.Tensor(INT8, self.dim_mask_shape, name="mask_gm",
276
- scope=GM)
277
329
  if self.has_attn_mask:
278
330
  self.att_mask_gm = self.tik_instance.Tensor(FP16, self.att_mask_shape,
279
331
  name="att_mask_gm", scope=GM)
@@ -294,43 +346,47 @@ class FlashAttention(metaclass=ABCMeta):
294
346
  alibi_mask_ub_broadcast = self.tik_ops_utils.broadcast_row(alibi_mask_ub, (m_aligned, n_aligned))
295
347
  self.tik_instance.h_add(Sij_ub, Sij_ub, alibi_mask_ub_broadcast)
296
348
 
297
- def do_att_mask(self, Sij_ub, attn_mask_gm_offset, q_blk_height, kv_blk_height,
349
+ def do_att_mask(self, Sij_ub_N1MN0, attn_mask_gm_offset, q_blk_height, kv_blk_height,
298
350
  q_blk_h_aligned, kv_blk_h_aligned):
299
351
  """load attn mask from gm to ub, then mul it by MASK_FILL_VALUE and add Sij"""
300
352
  with self.tik_instance.new_stmt_scope(disable_sync=False):
301
- att_mask_ub = self.tik_instance.Tensor(FP16, (q_blk_h_aligned, kv_blk_h_aligned),
353
+ att_mask_ub = self.tik_instance.Tensor(FP16, (kv_blk_h_aligned // self.N0, q_blk_h_aligned, self.N0),
302
354
  scope=UB, name="att_mask_ub")
303
355
  self.tik_instance.data_move(att_mask_ub, self.att_mask_gm[attn_mask_gm_offset], 0,
304
- q_blk_height, kv_blk_height // 16, (self.N - kv_blk_height) // 16, 0)
356
+ kv_blk_height // self.N0, q_blk_height * self.N0 // 16,
357
+ (self.Nq - q_blk_height) * self.N0 // 16, 0)
305
358
  self.tik_instance.h_mul(att_mask_ub, att_mask_ub, MASK_FILL_VALUE)
306
- self.tik_instance.h_add(Sij_ub, Sij_ub, att_mask_ub)
359
+ self.tik_instance.h_add(Sij_ub_N1MN0, Sij_ub_N1MN0, att_mask_ub)
307
360
 
308
361
  def do_dropout_mask(self, Pij_ub, dropout_mask_gm_offset, kv_blk_h_aligned, kv_blk_height,
309
- q_blk_h_aligned, q_blk_height, precision_type=FP16):
362
+ q_blk_h_aligned, q_blk_height, precision_type=FP16, workspace=None):
310
363
  """load drop mask from gm to ub, then mul it by Pij"""
311
364
  with self.tik_instance.new_stmt_scope(disable_sync=False):
312
365
  dropout_mask_ub = self.tik_instance.Tensor(FP16, (q_blk_h_aligned, kv_blk_h_aligned),
313
366
  scope=UB, name="drop_mask_ub")
314
367
  self.tik_instance.data_move(dropout_mask_ub, self.drop_mask_gm[dropout_mask_gm_offset], 0,
315
368
  q_blk_height, kv_blk_height // 16, (self.N - kv_blk_height) // 16, 0)
369
+ dropout_mask_ub = dropout_mask_ub.reshape((kv_blk_height // self.N0, q_blk_height, self.N0))
316
370
  if precision_type == FP32:
317
- dropout_mask_ub_fp32 = self.tik_instance.Tensor(FP32, (q_blk_h_aligned, kv_blk_h_aligned),
371
+ dropout_mask_ub_fp32 = self.tik_instance.Tensor(FP32,
372
+ (kv_blk_h_aligned // self.N0, q_blk_h_aligned, self.N0),
318
373
  scope=UB, name="dropout_mask_ub_fp32")
319
374
  self.tik_instance.h_cast(dropout_mask_ub_fp32, dropout_mask_ub, "none")
320
- self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub_fp32)
375
+ if workspace is None:
376
+ self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub_fp32)
377
+ else:
378
+ self.tik_instance.h_mul(workspace, Pij_ub, dropout_mask_ub_fp32)
321
379
  else:
322
- self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub)
323
-
324
- @abstractmethod
325
- def compute_one_core(self, batch_start_s, batch_num_s, core_idx_to_tr_info, core_idx):
326
- """compute one core"""
327
- raise NotImplementedError
380
+ if workspace is None:
381
+ self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub)
382
+ else:
383
+ self.tik_instance.h_mul(workspace, Pij_ub, dropout_mask_ub)
328
384
 
329
385
  def compute_process(self):
330
386
  """The compute process of FlashAttention"""
331
387
  self.init()
332
-
333
- core_idx_to_batch_info, core_idx_to_tr_info = self.get_core_bath_info()
388
+ self.prepare_global_ones()
389
+ core_idx_to_batch_info, core_idx_to_tr_info = self.get_core_task_info()
334
390
  with self.tik_instance.for_range(begint=0, endt=self.core_num, name="core_index",
335
391
  block_num=self.core_num) as core_idx:
336
392
  batch_start_s = self.tik_instance.Scalar("int32", name="batch_start_s")