mindspore 2.1.0__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. (The original page linked to more details here; the link was not preserved.)

Files changed (539)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  172. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  173. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  174. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  175. mindspore/nn/__init__.py +0 -2
  176. mindspore/nn/cell.py +316 -74
  177. mindspore/nn/dynamic_lr.py +21 -21
  178. mindspore/nn/layer/activation.py +21 -28
  179. mindspore/nn/layer/basic.py +15 -13
  180. mindspore/nn/layer/channel_shuffle.py +1 -1
  181. mindspore/nn/layer/container.py +271 -9
  182. mindspore/nn/layer/conv.py +310 -207
  183. mindspore/nn/layer/dense.py +8 -5
  184. mindspore/nn/layer/embedding.py +33 -27
  185. mindspore/nn/layer/flash_attention.py +82 -41
  186. mindspore/nn/layer/image.py +8 -6
  187. mindspore/nn/layer/math.py +13 -18
  188. mindspore/nn/layer/normalization.py +107 -66
  189. mindspore/nn/layer/padding.py +1 -1
  190. mindspore/nn/layer/pooling.py +131 -109
  191. mindspore/nn/layer/rnn_cells.py +22 -17
  192. mindspore/nn/layer/rnns.py +13 -16
  193. mindspore/nn/layer/thor_layer.py +1 -1
  194. mindspore/nn/layer/transformer.py +221 -154
  195. mindspore/nn/learning_rate_schedule.py +9 -1
  196. mindspore/nn/loss/loss.py +235 -174
  197. mindspore/nn/optim/ada_grad.py +2 -1
  198. mindspore/nn/optim/adadelta.py +1 -0
  199. mindspore/nn/optim/adafactor.py +2 -1
  200. mindspore/nn/optim/adam.py +7 -4
  201. mindspore/nn/optim/adamax.py +3 -2
  202. mindspore/nn/optim/adasum.py +2 -2
  203. mindspore/nn/optim/asgd.py +2 -3
  204. mindspore/nn/optim/ftrl.py +6 -5
  205. mindspore/nn/optim/lamb.py +7 -4
  206. mindspore/nn/optim/lars.py +1 -1
  207. mindspore/nn/optim/lazyadam.py +5 -3
  208. mindspore/nn/optim/momentum.py +2 -1
  209. mindspore/nn/optim/optimizer.py +53 -4
  210. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  211. mindspore/nn/optim/rmsprop.py +4 -3
  212. mindspore/nn/optim/rprop.py +23 -12
  213. mindspore/nn/optim/sgd.py +26 -11
  214. mindspore/nn/optim/thor.py +9 -7
  215. mindspore/nn/probability/bijector/bijector.py +5 -5
  216. mindspore/nn/probability/bijector/power_transform.py +27 -27
  217. mindspore/nn/probability/bijector/softplus.py +3 -3
  218. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  219. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  220. mindspore/nn/probability/distribution/beta.py +3 -3
  221. mindspore/nn/probability/distribution/categorical.py +7 -7
  222. mindspore/nn/probability/distribution/cauchy.py +0 -1
  223. mindspore/nn/probability/distribution/distribution.py +3 -3
  224. mindspore/nn/probability/distribution/gamma.py +3 -3
  225. mindspore/nn/probability/distribution/geometric.py +4 -4
  226. mindspore/nn/probability/distribution/gumbel.py +4 -4
  227. mindspore/nn/probability/distribution/log_normal.py +2 -2
  228. mindspore/nn/probability/distribution/logistic.py +2 -2
  229. mindspore/nn/probability/distribution/poisson.py +4 -4
  230. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  231. mindspore/nn/probability/distribution/uniform.py +6 -6
  232. mindspore/nn/wrap/cell_wrapper.py +78 -34
  233. mindspore/nn/wrap/grad_reducer.py +8 -5
  234. mindspore/nn/wrap/loss_scale.py +105 -42
  235. mindspore/numpy/array_creations.py +1 -2
  236. mindspore/numpy/array_ops.py +3 -2
  237. mindspore/offline_debug/convert_async.py +2 -2
  238. mindspore/ops/_grad_experimental/__init__.py +0 -5
  239. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  240. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  241. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  242. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  243. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  244. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  245. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  246. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  248. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  249. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  250. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  251. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  252. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  253. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  254. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  255. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  256. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  257. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  258. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  259. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  260. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  261. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  262. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  263. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  264. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  265. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  266. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  267. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  268. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  269. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  270. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  271. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  272. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  273. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  274. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  275. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  276. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  277. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  278. mindspore/ops/_primitive_cache.py +1 -1
  279. mindspore/ops/_tracefunc.py +45 -13
  280. mindspore/ops/_utils/utils.py +4 -1
  281. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  282. mindspore/ops/_vmap/vmap_base.py +3 -3
  283. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  284. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  285. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  286. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  287. mindspore/ops/arg_dtype_cast.py +54 -0
  288. mindspore/ops/composite/base.py +37 -10
  289. mindspore/ops/composite/math_ops.py +5 -4
  290. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  291. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  292. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  293. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  294. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  295. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  297. mindspore/ops/deprecated.py +304 -0
  298. mindspore/ops/function/__init__.py +4 -1
  299. mindspore/ops/function/array_func.py +167 -189
  300. mindspore/ops/function/clip_func.py +81 -13
  301. mindspore/ops/function/debug_func.py +1 -1
  302. mindspore/ops/function/grad/grad_func.py +18 -8
  303. mindspore/ops/function/image_func.py +10 -4
  304. mindspore/ops/function/linalg_func.py +5 -5
  305. mindspore/ops/function/math_func.py +575 -386
  306. mindspore/ops/function/nn_func.py +470 -251
  307. mindspore/ops/function/random_func.py +86 -56
  308. mindspore/ops/function/sparse_func.py +1 -1
  309. mindspore/ops/function/sparse_unary_func.py +14 -12
  310. mindspore/ops/function/vmap_func.py +6 -5
  311. mindspore/ops/functional.py +15 -10
  312. mindspore/ops/op_info_register.py +235 -19
  313. mindspore/ops/operations/__init__.py +25 -17
  314. mindspore/ops/operations/_grad_ops.py +52 -7
  315. mindspore/ops/operations/_inner_ops.py +213 -12
  316. mindspore/ops/operations/_quant_ops.py +4 -8
  317. mindspore/ops/operations/_sequence_ops.py +42 -0
  318. mindspore/ops/operations/array_ops.py +64 -280
  319. mindspore/ops/operations/comm_ops.py +105 -57
  320. mindspore/ops/operations/custom_ops.py +10 -3
  321. mindspore/ops/operations/debug_ops.py +8 -4
  322. mindspore/ops/operations/image_ops.py +18 -12
  323. mindspore/ops/operations/math_ops.py +185 -138
  324. mindspore/ops/operations/nn_ops.py +716 -492
  325. mindspore/ops/operations/other_ops.py +0 -22
  326. mindspore/ops/operations/random_ops.py +53 -111
  327. mindspore/ops/operations/sparse_ops.py +3 -1
  328. mindspore/ops/primitive.py +24 -18
  329. mindspore/parallel/_auto_parallel_context.py +68 -8
  330. mindspore/parallel/_cost_model_context.py +2 -2
  331. mindspore/parallel/_offload_context.py +17 -3
  332. mindspore/parallel/_parallel_serialization.py +2 -2
  333. mindspore/parallel/_ps_context.py +12 -0
  334. mindspore/parallel/_tensor.py +14 -12
  335. mindspore/parallel/_transformer/layers.py +5 -3
  336. mindspore/parallel/_transformer/loss.py +1 -0
  337. mindspore/parallel/_transformer/moe.py +2 -2
  338. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  339. mindspore/parallel/_transformer/transformer.py +23 -3
  340. mindspore/parallel/_utils.py +11 -7
  341. mindspore/parallel/algo_parameter_config.py +85 -5
  342. mindspore/parallel/checkpoint_transform.py +6 -10
  343. mindspore/parallel/shard.py +4 -4
  344. mindspore/profiler/common/struct_type.py +3 -3
  345. mindspore/profiler/common/util.py +3 -2
  346. mindspore/profiler/envprofiling.py +1 -1
  347. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  348. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  349. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  350. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  351. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  352. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  353. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  354. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  355. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  356. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  357. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  358. mindspore/profiler/parser/flops_parser.py +15 -11
  359. mindspore/profiler/parser/framework_parser.py +37 -21
  360. mindspore/profiler/parser/hccl_parser.py +16 -12
  361. mindspore/profiler/parser/integrator.py +22 -11
  362. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  363. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  364. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  365. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  366. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  367. mindspore/profiler/parser/optime_parser.py +1 -1
  368. mindspore/profiler/parser/profiler_info.py +2 -2
  369. mindspore/profiler/parser/step_trace_parser.py +11 -14
  370. mindspore/profiler/profiling.py +139 -71
  371. mindspore/rewrite/api/node.py +102 -19
  372. mindspore/rewrite/api/node_type.py +5 -1
  373. mindspore/rewrite/api/scoped_value.py +9 -17
  374. mindspore/rewrite/api/symbol_tree.py +131 -47
  375. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  376. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  377. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  378. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  379. mindspore/rewrite/common/rewrite_elog.py +5 -1
  380. mindspore/rewrite/namer.py +33 -24
  381. mindspore/rewrite/namespace.py +14 -5
  382. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  383. mindspore/rewrite/node/call_function.py +79 -0
  384. mindspore/rewrite/node/cell_container.py +135 -0
  385. mindspore/rewrite/node/control_flow.py +88 -0
  386. mindspore/rewrite/{node.py → node/node.py} +273 -234
  387. mindspore/rewrite/node/node_manager.py +254 -0
  388. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  389. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  390. mindspore/rewrite/parsers/assign_parser.py +216 -221
  391. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  392. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  393. mindspore/rewrite/parsers/constant_parser.py +9 -6
  394. mindspore/rewrite/parsers/container_parser.py +9 -7
  395. mindspore/rewrite/parsers/for_parser.py +36 -15
  396. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  397. mindspore/rewrite/parsers/if_parser.py +28 -24
  398. mindspore/rewrite/parsers/module_parser.py +196 -25
  399. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  400. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  401. mindspore/rewrite/parsers/return_parser.py +6 -6
  402. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  403. mindspore/rewrite/sparsify/utils.py +1 -1
  404. mindspore/rewrite/symbol_tree.py +525 -577
  405. mindspore/rewrite/symbol_tree_builder.py +9 -193
  406. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  407. mindspore/run_check/_check_version.py +2 -2
  408. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  409. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  410. mindspore/scipy/linalg.py +1 -1
  411. mindspore/scipy/optimize/minimize.py +7 -3
  412. mindspore/train/_utils.py +7 -3
  413. mindspore/train/amp.py +323 -123
  414. mindspore/train/anf_ir_pb2.py +14 -2
  415. mindspore/train/callback/_backup_and_restore.py +2 -12
  416. mindspore/train/callback/_callback.py +29 -4
  417. mindspore/train/callback/_checkpoint.py +23 -8
  418. mindspore/train/callback/_early_stop.py +2 -2
  419. mindspore/train/callback/_landscape.py +4 -4
  420. mindspore/train/callback/_loss_monitor.py +2 -2
  421. mindspore/train/callback/_on_request_exit.py +2 -2
  422. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  423. mindspore/train/callback/_summary_collector.py +14 -7
  424. mindspore/train/callback/_time_monitor.py +58 -5
  425. mindspore/train/data_sink.py +5 -11
  426. mindspore/train/dataset_helper.py +83 -57
  427. mindspore/train/loss_scale_manager.py +2 -2
  428. mindspore/train/metrics/__init__.py +3 -3
  429. mindspore/train/metrics/cosine_similarity.py +1 -1
  430. mindspore/train/metrics/hausdorff_distance.py +3 -2
  431. mindspore/train/metrics/mean_surface_distance.py +3 -2
  432. mindspore/train/metrics/metric.py +39 -19
  433. mindspore/train/metrics/roc.py +2 -2
  434. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  435. mindspore/train/mind_ir_pb2.py +85 -36
  436. mindspore/train/model.py +185 -45
  437. mindspore/train/serialization.py +390 -150
  438. mindspore/train/summary/_writer_pool.py +3 -2
  439. mindspore/train/summary/summary_record.py +14 -10
  440. mindspore/train/train_thor/convert_utils.py +3 -3
  441. mindspore/train/train_thor/dataset_helper.py +1 -1
  442. mindspore/version.py +1 -1
  443. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  444. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +447 -507
  445. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  446. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  447. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  448. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  449. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  450. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  451. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  452. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  453. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  454. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  455. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  456. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  457. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  458. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  459. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  460. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  461. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  462. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  463. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  464. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  465. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  466. mindspore/_extends/graph_kernel/expander.py +0 -80
  467. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  468. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  469. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  470. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  471. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  472. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  473. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  474. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  475. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  476. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  477. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  478. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  479. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  480. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  481. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  482. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  483. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  484. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  485. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  486. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  487. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  488. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  489. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  490. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  491. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  492. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  493. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  494. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  496. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  498. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  499. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  500. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  501. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  502. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  503. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  504. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  505. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  506. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  507. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  508. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  509. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  510. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  511. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  512. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  513. mindspore/dataset/datapreprocess/__init__.py +0 -20
  514. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  515. mindspore/include/api/net.h +0 -142
  516. mindspore/nn/lr_scheduler.py +0 -262
  517. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  518. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  519. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  520. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  521. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  522. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  523. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  524. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  525. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  526. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  529. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  530. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  531. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  532. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  533. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  537. mindspore/rewrite/node_visitor.py +0 -44
  538. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  539. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -25,7 +25,6 @@ from tbe.common.platform import get_soc_spec
25
25
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
26
26
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
27
27
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import GM
28
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import INT8
29
28
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import MASK_FILL_VALUE
30
29
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
31
30
  from mindspore.ops._op_impl._custom_op.flash_attention.tik_ops_utils import TikOpsUtils
@@ -37,7 +36,7 @@ from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.sparse_ti
37
36
  class FlashAttention(metaclass=ABCMeta):
38
37
  """The base class of FlashAttention"""
39
38
 
40
- def __init__(self, q, k, v, dim_mask, attn_mask, dropout_mask, alibi_mask, kernel_name,
39
+ def __init__(self, q, k, v, attn_mask, dropout_mask, alibi_mask, kernel_name,
41
40
  tiling_stgy_cls,
42
41
  prev_block_num=65536,
43
42
  next_block_num=65536,
@@ -48,7 +47,6 @@ class FlashAttention(metaclass=ABCMeta):
48
47
  :param q: with shape: (B, h, N, d)
49
48
  :param k: with shape: (B, h, N, d)
50
49
  :param v: with shape: (B, h, N, d)
51
- :param dim_mask: with shape: (x, ) x equals length last dim that not padded to 16.
52
50
  :param attn_mask: with shape: (1, N, N) or (B, N, N)
53
51
  :param dropout_mask: with shape: (B, h, N, N)
54
52
  :param alibi_mask: with shape: (B, h, 1, N)
@@ -66,19 +64,25 @@ class FlashAttention(metaclass=ABCMeta):
66
64
  src_stride=0,
67
65
  dst_stride=0)
68
66
  self.tik_ops_utils = TikOpsUtils(self.tik_instance)
69
- self.parser_input_shape(alibi_mask, attn_mask, dim_mask, dropout_mask, k, q, v)
70
-
71
- batch_size, h, Nq, d = self.q_shape
67
+ self.parse_input_shape(alibi_mask, attn_mask, dropout_mask, k, q, v)
68
+ # NZ
69
+ _, _, N1, M1, M0, N0 = self.q_shape
70
+ self.M1 = M1
71
+ self.N1 = N1
72
+ self.M0 = M0
73
+ self.N0 = N0
74
+ self.d = N1 * N0
75
+ # ND
76
+ batch_size, h, Nq, actual_d = self.q_ori_shape
72
77
  self.head_num = h
73
- self.B, self.Nq, self.d = batch_size * h, Nq, d
74
- self.N = self.k_shape[2]
78
+ self.B, self.Nq = batch_size * h, Nq
79
+ self.N = self.k_ori_shape[2]
80
+ self.actual_d = actual_d
75
81
 
76
82
  self.l_shape = [batch_size, h, self.Nq]
77
83
  self.m_shape = [batch_size, h, self.Nq]
78
- self.O_shape = [batch_size, h, self.Nq, self.d]
79
- self.actual_d = self.dim_mask_shape[0]
84
+ self.O_shape = self.q_shape
80
85
 
81
- self.K0 = 16
82
86
  self.prev_block_num = prev_block_num
83
87
  self.next_block_num = next_block_num
84
88
  self.high_precision = high_precision
@@ -86,7 +90,6 @@ class FlashAttention(metaclass=ABCMeta):
86
90
  self.precision_type = FP32
87
91
  else:
88
92
  self.precision_type = FP16
89
-
90
93
  if tiling_stgy_cls is None:
91
94
  self.tiling_stgy = SparseTiling(self.Nq, self.N, self.d)
92
95
  else:
@@ -106,9 +109,9 @@ class FlashAttention(metaclass=ABCMeta):
106
109
  self.alibi_mask_gm = None
107
110
 
108
111
  @staticmethod
109
- def get_gm_offset(batch_start, batch_idx, h, w, block_h, block_idx):
110
- """get gm offset"""
111
- gm_offset = (batch_start + batch_idx) * h * w + block_idx * block_h * w
112
+ def get_l_m_gm_offset(batch_start, batch_idx, h, block_h, block_idx):
113
+ """get l m gm offset"""
114
+ gm_offset = (batch_start + batch_idx) * h + block_idx * block_h
112
115
  return gm_offset
113
116
 
114
117
  @staticmethod
@@ -143,108 +146,159 @@ class FlashAttention(metaclass=ABCMeta):
143
146
  """collect outputs"""
144
147
  raise NotImplementedError
145
148
 
146
- def get_core_bath_info(self):
147
- """
148
- Get batch start and batch number of each NPU core.
149
- :return: Tensor([[core_1_batch_start, core_1_batch_num],...,[core_m_batch_start,
150
- core_m_batch_num]]), Tensor([[[core_1_batch_1_Tr_start, core_1_batch_1_Tr_end],...[core_1_batch_n_Tr_start,
151
- core_1_batch_n_Tr_end]],...,[[core_m_batch_1_Tr_start, core_m_batch_1_Tr_end],...[core_m_batch_n_Tr_start,
152
- core_m_batch_n_Tr_end]]
153
- """
154
- if self.core_num > self.B * self.Tr:
155
- self.core_num = self.B * self.Tr
156
-
157
- task_idx_to_batch_tr_idx = dict()
158
- for task_idx in range(self.B * self.Tr):
159
- batch_idx = task_idx // self.Tr
160
- tr_idx = task_idx % self.Tr
161
- task_idx_to_batch_tr_idx[task_idx] = [batch_idx, tr_idx]
149
+ @abstractmethod
150
+ def compute_one_core(self, batch_start_s, batch_num_s, core_idx_to_tr_info, core_idx):
151
+ """compute one core"""
152
+ raise NotImplementedError
162
153
 
163
- core_idx_to_batch_idx = defaultdict(lambda: [100000, -1])
164
- core_idx_to_tr_idx = defaultdict(lambda: defaultdict(lambda: [100000, -1]))
165
- task_start = 0
166
- avg_task_num_per_core, remain_task = divmod(self.B * self.Tr, self.core_num)
154
+ @abstractmethod
155
+ def prepare_global_ones(self):
156
+ """prepare global ones"""
157
+ raise NotImplementedError
167
158
 
168
- for core_idx in range(self.core_num):
169
- cur_core_task_num = avg_task_num_per_core
170
- if core_idx < remain_task:
171
- cur_core_task_num += 1
172
- task_end = task_start + cur_core_task_num
173
- for task_idx in range(task_start, task_end):
174
- try:
175
- batch_idx, tr_idx = task_idx_to_batch_tr_idx[task_idx]
176
- except KeyError:
177
- raise ValueError("The argument 'task_idx' is not valid.")
178
- batch_start_end_pair = core_idx_to_batch_idx[core_idx]
179
- if batch_idx < batch_start_end_pair[0]:
180
- batch_start_end_pair[0] = batch_idx
181
- if batch_idx > batch_start_end_pair[1]:
182
- batch_start_end_pair[1] = batch_idx
183
- tr_start_end_pair = core_idx_to_tr_idx[core_idx][batch_idx]
184
- if tr_idx < tr_start_end_pair[0]:
185
- tr_start_end_pair[0] = tr_idx
186
- if tr_idx > tr_start_end_pair[1]:
187
- tr_start_end_pair[1] = tr_idx
188
- task_start = task_end
159
+ def get_gm_offset(self, batch_start, batch_idx, h, w, block_h, block_idx):
160
+ """get gm offset"""
161
+ gm_offset = (batch_start + batch_idx) * h * w + block_idx * block_h * self.N0
162
+ return gm_offset
189
163
 
164
+ def get_cur_tr_block_num(self, tr_idx):
165
+ """get cur tr block_num"""
166
+ cur_prev_block_num = min(tr_idx, self.prev_block_num)
167
+ cur_next_block_num = min(self.next_block_num, self.Tc - tr_idx - 1)
168
+ block_num = cur_prev_block_num + 1 + cur_next_block_num
169
+ return block_num
170
+
171
+ def get_total_block_num(self):
172
+ """get total block num"""
173
+ block_num = 0
174
+ for b_idx in range(self.B):
175
+ for tr_idx in range(self.Tr):
176
+ block_num += self.get_cur_tr_block_num(tr_idx)
177
+ return block_num
178
+
179
+ def update_core_task_map(self,
180
+ core_b_map,
181
+ core_b_tr_map,
182
+ core_idx,
183
+ b_start,
184
+ b_end,
185
+ tr_start,
186
+ tr_end):
187
+ """update core task map"""
188
+ core_b_map[core_idx][0] = min(core_b_map[core_idx][0], b_start)
189
+ if tr_end == 0: # 跨head,但跨过的head不会被当前的core处理
190
+ core_b_map[core_idx][1] = max(core_b_map[core_idx][1], b_end - 1)
191
+ else:
192
+ core_b_map[core_idx][1] = max(core_b_map[core_idx][1], b_end)
193
+ for b_idx in range(b_start, b_end + 1):
194
+ if b_idx == b_end and tr_end == 0: # 跨head,但跨过的head不会被当前的core处理
195
+ break
196
+ elif b_idx == b_start and b_idx == b_end: # 没跨head
197
+ core_b_tr_map[core_idx][b_idx] = (tr_start, tr_end)
198
+ elif b_idx == b_start: # 跨head,第一个head
199
+ core_b_tr_map[core_idx][b_idx] = (tr_start, self.Tr)
200
+ elif b_idx == b_end: # 跨head,最后一个head
201
+ core_b_tr_map[core_idx][b_idx] = (0, tr_end)
202
+ else: # 跨head,中间的head
203
+ core_b_tr_map[core_idx][b_idx] = (0, self.Tr)
204
+
205
+ def convert_py_dict_to_tik_tensor(self, core_b_map, core_b_tr_map):
206
+ """convert py dict to tik tensor"""
207
+ # python dict -> tik tensor
208
+ # [batch_start, batch_idx_end] -> [batch_start, batch_num]
209
+ # [tr_start, tr_idx_end] -> [tr_start, tr_idx_end)
190
210
  core_idx_to_batch_info = self.tik_instance.Tensor(
191
211
  "int32", (self.core_num, 2), name="core_idx_to_batch_info", scope=UB
192
212
  )
193
213
  core_idx_to_tr_info = self.tik_instance.Tensor(
194
214
  "int32", (self.core_num, self.B, 2), name="core_idx_to_tr_info", scope=UB
195
215
  )
196
- for core_idx in core_idx_to_batch_idx:
197
- batch_start, batch_end = core_idx_to_batch_idx[core_idx]
216
+ for core_idx in core_b_map.keys():
217
+ batch_start, batch_end = core_b_map[core_idx]
198
218
  core_idx_to_batch_info[core_idx, 0] = batch_start
199
219
  core_idx_to_batch_info[core_idx, 1] = batch_end - batch_start + 1
200
- for batch_idx in core_idx_to_tr_idx[core_idx]:
201
- tr_start, tr_end = core_idx_to_tr_idx[core_idx][batch_idx]
220
+ for batch_idx in core_b_tr_map[core_idx].keys():
221
+ tr_start, tr_end = core_b_tr_map[core_idx][batch_idx]
202
222
  core_idx_to_tr_info[core_idx, batch_idx, 0] = tr_start
203
- core_idx_to_tr_info[core_idx, batch_idx, 1] = tr_end + 1
223
+ core_idx_to_tr_info[core_idx, batch_idx, 1] = tr_end
224
+
225
+ return core_idx_to_batch_info, core_idx_to_tr_info
226
+
227
+ def get_core_task_info(self):
228
+ """
229
+ Get batch start and batch number of each NPU core.
230
+ :return: Tensor([[core_1_batch_start, core_1_batch_num],...,[core_m_batch_start,
231
+ core_m_batch_num]]), Tensor([[[core_1_batch_1_Tr_start, core_1_batch_1_Tr_end],...[core_1_batch_n_Tr_start,
232
+ core_1_batch_n_Tr_end]],...,[[core_m_batch_1_Tr_start, core_m_batch_1_Tr_end],...[core_m_batch_n_Tr_start,
233
+ core_m_batch_n_Tr_end]]
234
+ """
235
+ if self.core_num > self.B * self.Tr:
236
+ self.core_num = self.B * self.Tr
237
+
238
+ total_blk_num = self.get_total_block_num()
239
+ b_start = 0
240
+ tr_start = 0
241
+ remain_blk_num = total_blk_num
242
+ core_b_map = defaultdict(lambda: [100000, -1])
243
+ core_b_tr_map = defaultdict(lambda: defaultdict(list))
244
+ for core_idx in range(self.core_num):
245
+ cur_core_blk_num = 0
246
+ cur_each_core_blk_num = remain_blk_num // (self.core_num - core_idx)
247
+ cur_core_finished = False
248
+ b_end = b_start
249
+ tr_end = tr_start
250
+ while b_end < self.B:
251
+ while tr_end < self.Tr:
252
+ cur_tr_blk_num = self.get_cur_tr_block_num(tr_end)
253
+ if abs(cur_core_blk_num - cur_each_core_blk_num) <= \
254
+ (cur_core_blk_num + cur_tr_blk_num - cur_each_core_blk_num):
255
+ self.update_core_task_map(core_b_map, core_b_tr_map, core_idx, b_start, b_end, tr_start, tr_end)
256
+ remain_blk_num -= cur_core_blk_num
257
+ cur_core_finished = True
258
+ break
259
+ else:
260
+ cur_core_blk_num += cur_tr_blk_num
261
+ tr_end += 1
262
+ if tr_end == self.Tr:
263
+ tr_end = 0
264
+ b_end += 1
265
+ if cur_core_finished:
266
+ b_start = b_end
267
+ tr_start = tr_end
268
+ break
269
+ core_idx_to_batch_info, core_idx_to_tr_info = self.convert_py_dict_to_tik_tensor(core_b_map, core_b_tr_map)
204
270
  return core_idx_to_batch_info, core_idx_to_tr_info
205
271
 
206
272
  def get_attn_mask_gm_offset(self, batch_start, batch_idx, h, w, block_h, block_h_idx, block_w, block_w_idx):
207
273
  """get attn mask gm offset"""
208
274
  if self.att_mask_shape[0] == 1:
209
- gm_offset = block_h_idx * (w * block_h) + block_w_idx * block_w
275
+ gm_offset = block_w_idx * (h * block_w) + block_h_idx * block_h * self.N0
210
276
  else:
211
277
  gm_offset = ((batch_start + batch_idx) // self.head_num) * h * w \
212
- + block_h_idx * (w * block_h) + block_w_idx * block_w
278
+ + block_w_idx * (h * block_w) + block_h_idx * block_h * self.N0
213
279
  return gm_offset
214
280
 
215
- def parser_input_shape(self, alibi_mask, attn_mask, dim_mask, dropout_mask, k, q, v):
281
+ def parse_input_shape(self, alibi_mask, attn_mask, dropout_mask, k, q, v):
216
282
  """parser input shape"""
217
283
  self.has_attn_mask = False
218
284
  self.has_drop_mask = False
219
285
  self.has_alibi_mask = False
220
- if isinstance(q, dict):
221
- self.q_shape = q["shape"]
222
- self.k_shape = k["shape"]
223
- self.v_shape = v["shape"]
224
- self.dim_mask_shape = dim_mask["shape"]
225
- if attn_mask is not None:
226
- self.has_attn_mask = True
227
- self.att_mask_shape = attn_mask["shape"]
228
- if dropout_mask is not None:
229
- self.has_drop_mask = True
230
- self.drop_mask_shape = dropout_mask["shape"]
231
- if alibi_mask is not None:
232
- self.has_alibi_mask = True
233
- self.alibi_mask_shape = alibi_mask["shape"]
234
- else:
235
- self.q_shape = q.shape
236
- self.k_shape = k.shape
237
- self.v_shape = v.shape
238
- self.dim_mask_shape = dim_mask.shape
239
- if attn_mask is not None:
240
- self.has_attn_mask = True
241
- self.att_mask_shape = attn_mask.shape
242
- if dropout_mask is not None:
243
- self.has_drop_mask = True
244
- self.drop_mask_shape = dropout_mask.shape
245
- if alibi_mask is not None:
246
- self.has_alibi_mask = True
247
- self.alibi_mask_shape = alibi_mask.shape
286
+ # NZ
287
+ self.q_shape = q["shape"]
288
+ self.k_shape = k["shape"]
289
+ self.v_shape = v["shape"]
290
+ # ND
291
+ self.q_ori_shape = q["ori_shape"]
292
+ self.k_ori_shape = k["ori_shape"]
293
+ if attn_mask is not None:
294
+ self.has_attn_mask = True
295
+ self.att_mask_shape = attn_mask["shape"]
296
+ if dropout_mask is not None:
297
+ self.has_drop_mask = True
298
+ self.drop_mask_shape = dropout_mask["shape"]
299
+ if alibi_mask is not None:
300
+ self.has_alibi_mask = True
301
+ self.alibi_mask_shape = alibi_mask["shape"]
248
302
 
249
303
  def define_inputs_outputs(self):
250
304
  """define inputs outputs"""
@@ -272,8 +326,6 @@ class FlashAttention(metaclass=ABCMeta):
272
326
  self.Q_gm = self.tik_instance.Tensor(FP16, self.q_shape, name="Q_gm", scope=GM)
273
327
  self.K_gm = self.tik_instance.Tensor(FP16, self.k_shape, name="K_gm", scope=GM)
274
328
  self.V_gm = self.tik_instance.Tensor(FP16, self.v_shape, name="V_gm", scope=GM)
275
- self.dim_mask_gm = self.tik_instance.Tensor(INT8, self.dim_mask_shape, name="mask_gm",
276
- scope=GM)
277
329
  if self.has_attn_mask:
278
330
  self.att_mask_gm = self.tik_instance.Tensor(FP16, self.att_mask_shape,
279
331
  name="att_mask_gm", scope=GM)
@@ -294,43 +346,47 @@ class FlashAttention(metaclass=ABCMeta):
294
346
  alibi_mask_ub_broadcast = self.tik_ops_utils.broadcast_row(alibi_mask_ub, (m_aligned, n_aligned))
295
347
  self.tik_instance.h_add(Sij_ub, Sij_ub, alibi_mask_ub_broadcast)
296
348
 
297
- def do_att_mask(self, Sij_ub, attn_mask_gm_offset, q_blk_height, kv_blk_height,
349
+ def do_att_mask(self, Sij_ub_N1MN0, attn_mask_gm_offset, q_blk_height, kv_blk_height,
298
350
  q_blk_h_aligned, kv_blk_h_aligned):
299
351
  """load attn mask from gm to ub, then mul it by MASK_FILL_VALUE and add Sij"""
300
352
  with self.tik_instance.new_stmt_scope(disable_sync=False):
301
- att_mask_ub = self.tik_instance.Tensor(FP16, (q_blk_h_aligned, kv_blk_h_aligned),
353
+ att_mask_ub = self.tik_instance.Tensor(FP16, (kv_blk_h_aligned // self.N0, q_blk_h_aligned, self.N0),
302
354
  scope=UB, name="att_mask_ub")
303
355
  self.tik_instance.data_move(att_mask_ub, self.att_mask_gm[attn_mask_gm_offset], 0,
304
- q_blk_height, kv_blk_height // 16, (self.N - kv_blk_height) // 16, 0)
356
+ kv_blk_height // self.N0, q_blk_height * self.N0 // 16,
357
+ (self.Nq - q_blk_height) * self.N0 // 16, 0)
305
358
  self.tik_instance.h_mul(att_mask_ub, att_mask_ub, MASK_FILL_VALUE)
306
- self.tik_instance.h_add(Sij_ub, Sij_ub, att_mask_ub)
359
+ self.tik_instance.h_add(Sij_ub_N1MN0, Sij_ub_N1MN0, att_mask_ub)
307
360
 
308
361
  def do_dropout_mask(self, Pij_ub, dropout_mask_gm_offset, kv_blk_h_aligned, kv_blk_height,
309
- q_blk_h_aligned, q_blk_height, precision_type=FP16):
362
+ q_blk_h_aligned, q_blk_height, precision_type=FP16, workspace=None):
310
363
  """load drop mask from gm to ub, then mul it by Pij"""
311
364
  with self.tik_instance.new_stmt_scope(disable_sync=False):
312
365
  dropout_mask_ub = self.tik_instance.Tensor(FP16, (q_blk_h_aligned, kv_blk_h_aligned),
313
366
  scope=UB, name="drop_mask_ub")
314
367
  self.tik_instance.data_move(dropout_mask_ub, self.drop_mask_gm[dropout_mask_gm_offset], 0,
315
368
  q_blk_height, kv_blk_height // 16, (self.N - kv_blk_height) // 16, 0)
369
+ dropout_mask_ub = dropout_mask_ub.reshape((kv_blk_height // self.N0, q_blk_height, self.N0))
316
370
  if precision_type == FP32:
317
- dropout_mask_ub_fp32 = self.tik_instance.Tensor(FP32, (q_blk_h_aligned, kv_blk_h_aligned),
371
+ dropout_mask_ub_fp32 = self.tik_instance.Tensor(FP32,
372
+ (kv_blk_h_aligned // self.N0, q_blk_h_aligned, self.N0),
318
373
  scope=UB, name="dropout_mask_ub_fp32")
319
374
  self.tik_instance.h_cast(dropout_mask_ub_fp32, dropout_mask_ub, "none")
320
- self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub_fp32)
375
+ if workspace is None:
376
+ self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub_fp32)
377
+ else:
378
+ self.tik_instance.h_mul(workspace, Pij_ub, dropout_mask_ub_fp32)
321
379
  else:
322
- self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub)
323
-
324
- @abstractmethod
325
- def compute_one_core(self, batch_start_s, batch_num_s, core_idx_to_tr_info, core_idx):
326
- """compute one core"""
327
- raise NotImplementedError
380
+ if workspace is None:
381
+ self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub)
382
+ else:
383
+ self.tik_instance.h_mul(workspace, Pij_ub, dropout_mask_ub)
328
384
 
329
385
  def compute_process(self):
330
386
  """The compute process of FlashAttention"""
331
387
  self.init()
332
-
333
- core_idx_to_batch_info, core_idx_to_tr_info = self.get_core_bath_info()
388
+ self.prepare_global_ones()
389
+ core_idx_to_batch_info, core_idx_to_tr_info = self.get_core_task_info()
334
390
  with self.tik_instance.for_range(begint=0, endt=self.core_num, name="core_index",
335
391
  block_num=self.core_num) as core_idx:
336
392
  batch_start_s = self.tik_instance.Scalar("int32", name="batch_start_s")