mindspore 2.1.0-cp38-cp38-manylinux1_x86_64.whl → 2.2.11-cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Further details were linked from the original registry page.

Files changed (589)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +61 -95
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/__init__.py +4 -2
  268. mindspore/nn/wrap/cell_wrapper.py +87 -34
  269. mindspore/nn/wrap/grad_reducer.py +8 -5
  270. mindspore/nn/wrap/loss_scale.py +105 -42
  271. mindspore/numpy/array_creations.py +1 -2
  272. mindspore/numpy/array_ops.py +3 -2
  273. mindspore/numpy/utils_const.py +5 -5
  274. mindspore/offline_debug/convert_async.py +2 -2
  275. mindspore/ops/_grad_experimental/__init__.py +0 -5
  276. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  277. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  278. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  279. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  280. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  281. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  282. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  283. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  284. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  285. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  286. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  287. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  288. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  289. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  290. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  291. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  292. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  293. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  294. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  295. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  296. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  297. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  298. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  299. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  300. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  301. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  302. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  303. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  304. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  305. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  306. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  307. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  309. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  310. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  311. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  312. mindspore/ops/_primitive_cache.py +1 -1
  313. mindspore/ops/_tracefunc.py +45 -13
  314. mindspore/ops/_utils/utils.py +6 -1
  315. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  316. mindspore/ops/_vmap/vmap_base.py +3 -3
  317. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  318. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  319. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  320. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  321. mindspore/ops/arg_dtype_cast.py +54 -0
  322. mindspore/ops/composite/base.py +37 -10
  323. mindspore/ops/composite/math_ops.py +5 -4
  324. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  325. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  326. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  327. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  328. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  330. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  331. mindspore/ops/deprecated.py +304 -0
  332. mindspore/ops/function/__init__.py +4 -1
  333. mindspore/ops/function/array_func.py +174 -193
  334. mindspore/ops/function/clip_func.py +81 -13
  335. mindspore/ops/function/debug_func.py +1 -1
  336. mindspore/ops/function/grad/grad_func.py +18 -9
  337. mindspore/ops/function/image_func.py +10 -4
  338. mindspore/ops/function/linalg_func.py +5 -5
  339. mindspore/ops/function/math_func.py +575 -386
  340. mindspore/ops/function/nn_func.py +568 -260
  341. mindspore/ops/function/random_func.py +88 -57
  342. mindspore/ops/function/sparse_func.py +1 -1
  343. mindspore/ops/function/sparse_unary_func.py +14 -12
  344. mindspore/ops/function/vmap_func.py +6 -5
  345. mindspore/ops/functional.py +15 -10
  346. mindspore/ops/op_info_register.py +244 -25
  347. mindspore/ops/operations/__init__.py +31 -19
  348. mindspore/ops/operations/_grad_ops.py +71 -7
  349. mindspore/ops/operations/_inner_ops.py +350 -17
  350. mindspore/ops/operations/_quant_ops.py +4 -8
  351. mindspore/ops/operations/_sequence_ops.py +42 -0
  352. mindspore/ops/operations/array_ops.py +68 -282
  353. mindspore/ops/operations/comm_ops.py +107 -59
  354. mindspore/ops/operations/custom_ops.py +94 -70
  355. mindspore/ops/operations/debug_ops.py +8 -4
  356. mindspore/ops/operations/image_ops.py +18 -12
  357. mindspore/ops/operations/inner_ops.py +26 -3
  358. mindspore/ops/operations/math_ops.py +192 -144
  359. mindspore/ops/operations/nn_ops.py +857 -489
  360. mindspore/ops/operations/other_ops.py +0 -22
  361. mindspore/ops/operations/random_ops.py +53 -111
  362. mindspore/ops/operations/sparse_ops.py +3 -1
  363. mindspore/ops/primitive.py +24 -18
  364. mindspore/parallel/_auto_parallel_context.py +68 -8
  365. mindspore/parallel/_cost_model_context.py +2 -2
  366. mindspore/parallel/_offload_context.py +17 -3
  367. mindspore/parallel/_parallel_serialization.py +12 -5
  368. mindspore/parallel/_ps_context.py +12 -0
  369. mindspore/parallel/_tensor.py +18 -13
  370. mindspore/parallel/_transformer/layers.py +5 -3
  371. mindspore/parallel/_transformer/loss.py +1 -0
  372. mindspore/parallel/_transformer/moe.py +2 -2
  373. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  374. mindspore/parallel/_transformer/transformer.py +23 -3
  375. mindspore/parallel/_utils.py +11 -7
  376. mindspore/parallel/algo_parameter_config.py +85 -5
  377. mindspore/parallel/checkpoint_transform.py +19 -12
  378. mindspore/parallel/shard.py +21 -14
  379. mindspore/profiler/common/struct_type.py +3 -3
  380. mindspore/profiler/common/util.py +4 -2
  381. mindspore/profiler/envprofiling.py +1 -1
  382. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  383. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  384. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  385. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  386. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  387. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  388. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  389. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  390. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  391. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  392. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  393. mindspore/profiler/parser/flops_parser.py +15 -11
  394. mindspore/profiler/parser/framework_parser.py +38 -22
  395. mindspore/profiler/parser/hccl_parser.py +16 -12
  396. mindspore/profiler/parser/integrator.py +22 -11
  397. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  398. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  399. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  400. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  401. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  402. mindspore/profiler/parser/optime_parser.py +1 -1
  403. mindspore/profiler/parser/profiler_info.py +21 -2
  404. mindspore/profiler/parser/step_trace_parser.py +11 -14
  405. mindspore/profiler/profiling.py +179 -89
  406. mindspore/rewrite/api/node.py +102 -19
  407. mindspore/rewrite/api/node_type.py +5 -1
  408. mindspore/rewrite/api/pattern_engine.py +1 -1
  409. mindspore/rewrite/api/scoped_value.py +9 -17
  410. mindspore/rewrite/api/symbol_tree.py +131 -47
  411. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  412. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  413. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  414. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  415. mindspore/rewrite/common/rewrite_elog.py +5 -1
  416. mindspore/rewrite/namer.py +33 -24
  417. mindspore/rewrite/namespace.py +14 -5
  418. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  419. mindspore/rewrite/node/call_function.py +79 -0
  420. mindspore/rewrite/node/cell_container.py +135 -0
  421. mindspore/rewrite/node/control_flow.py +88 -0
  422. mindspore/rewrite/{node.py → node/node.py} +273 -234
  423. mindspore/rewrite/node/node_manager.py +254 -0
  424. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  425. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  426. mindspore/rewrite/parsers/assign_parser.py +216 -221
  427. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  428. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  429. mindspore/rewrite/parsers/constant_parser.py +9 -6
  430. mindspore/rewrite/parsers/container_parser.py +9 -7
  431. mindspore/rewrite/parsers/for_parser.py +42 -21
  432. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  433. mindspore/rewrite/parsers/if_parser.py +28 -24
  434. mindspore/rewrite/parsers/module_parser.py +196 -25
  435. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  436. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  437. mindspore/rewrite/parsers/return_parser.py +6 -6
  438. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  439. mindspore/rewrite/sparsify/utils.py +1 -1
  440. mindspore/rewrite/symbol_tree.py +523 -578
  441. mindspore/rewrite/symbol_tree_builder.py +9 -193
  442. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  443. mindspore/run_check/_check_version.py +6 -4
  444. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  445. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  446. mindspore/scipy/linalg.py +1 -1
  447. mindspore/scipy/ops.py +55 -5
  448. mindspore/scipy/optimize/__init__.py +3 -2
  449. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  450. mindspore/scipy/optimize/minimize.py +7 -3
  451. mindspore/train/_utils.py +7 -3
  452. mindspore/train/amp.py +323 -123
  453. mindspore/train/anf_ir_pb2.py +14 -2
  454. mindspore/train/callback/_backup_and_restore.py +2 -12
  455. mindspore/train/callback/_callback.py +29 -4
  456. mindspore/train/callback/_checkpoint.py +23 -8
  457. mindspore/train/callback/_early_stop.py +2 -2
  458. mindspore/train/callback/_landscape.py +4 -4
  459. mindspore/train/callback/_loss_monitor.py +2 -2
  460. mindspore/train/callback/_on_request_exit.py +2 -2
  461. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  462. mindspore/train/callback/_summary_collector.py +15 -8
  463. mindspore/train/callback/_time_monitor.py +58 -5
  464. mindspore/train/data_sink.py +5 -11
  465. mindspore/train/dataset_helper.py +84 -57
  466. mindspore/train/loss_scale_manager.py +2 -2
  467. mindspore/train/metrics/__init__.py +3 -3
  468. mindspore/train/metrics/cosine_similarity.py +1 -1
  469. mindspore/train/metrics/hausdorff_distance.py +3 -2
  470. mindspore/train/metrics/mean_surface_distance.py +3 -2
  471. mindspore/train/metrics/metric.py +39 -19
  472. mindspore/train/metrics/roc.py +2 -2
  473. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  474. mindspore/train/mind_ir_pb2.py +85 -36
  475. mindspore/train/model.py +187 -47
  476. mindspore/train/serialization.py +487 -161
  477. mindspore/train/summary/_summary_adapter.py +1 -1
  478. mindspore/train/summary/_writer_pool.py +3 -2
  479. mindspore/train/summary/summary_record.py +37 -17
  480. mindspore/train/train_thor/convert_utils.py +3 -3
  481. mindspore/train/train_thor/dataset_helper.py +1 -1
  482. mindspore/version.py +1 -1
  483. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  486. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  487. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  489. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  490. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  491. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  492. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  493. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  494. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  495. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  497. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  498. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  499. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  500. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  501. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  502. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  503. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  504. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  505. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  506. mindspore/_extends/graph_kernel/expander.py +0 -80
  507. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  508. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  509. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  510. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  511. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  512. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  513. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  514. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  516. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  517. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  518. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  519. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  520. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  521. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  522. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  523. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  524. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  525. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  526. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  527. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  528. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  529. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  530. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  531. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  532. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  534. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  535. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  536. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  537. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  538. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  540. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  543. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  544. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  545. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  546. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  547. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  548. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  549. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  550. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  551. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  552. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  553. mindspore/dataset/datapreprocess/__init__.py +0 -20
  554. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  555. mindspore/include/api/net.h +0 -142
  556. mindspore/nn/lr_scheduler.py +0 -262
  557. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  558. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  559. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  560. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  561. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  562. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  563. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  564. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  565. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  566. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  567. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  568. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  569. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  570. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  571. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  575. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  578. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  579. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  580. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  582. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  583. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  585. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  586. mindspore/rewrite/node_visitor.py +0 -44
  587. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  588. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  589. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,199 +0,0 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """The impl of flash attention"""
16
- from __future__ import absolute_import
17
- import mindspore.ops as ops
18
- from mindspore import dtype as mstype
19
- from mindspore.ops import Custom
20
- from mindspore.ops import DataType
21
- from mindspore.ops import TBERegOp
22
- from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_bwd import flash_attention_grad
23
- from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_fwd import flash_attention
24
- from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
25
-
26
KERNEL_NAME = "flash_attention"

# TBE registration info for the forward flash-attention kernel.
# Each dtype_format row has 10 entries: one per declared input
# (query, key, value, dim_mask, attn_mask, dropout_mask, alibi_mask)
# followed by one per declared output (output, rowsum, rowmax).
cus_flash_atten_op_info = (
    TBERegOp("FlashAttentionPrimitive")
    .fusion_type("OPAQUE")
    .partial_flag(True)
    .async_flag(False)
    .binfile_name("flash_attention.so")
    .compute_cost(10)
    .kernel_name(KERNEL_NAME)
    .attr("prev_block_num", "required", "int", "all", "65536")
    .attr("next_block_num", "required", "int", "all", "65536")
    .attr("high_precision", "required", "bool", "all", "false")
    .attr("tiling_stgy_name", "required", "str", "all", "sparse")
    .input(0, "query", False, "required", "all")
    .input(1, "key", False, "required", "all")
    .input(2, "value", False, "required", "all")
    .input(3, "dim_mask", False, "required", "all")
    .input(4, "attn_mask", False, "optional", "all")
    .input(5, "dropout_mask", False, "optional", "all")
    .input(6, "alibi_mask", False, "optional", "all")
    .output(0, "output", False, "required", "all")
    .output(1, "rowsum", False, "required", "all")
    .output(2, "rowmax", False, "required", "all")
    # Variant 1: everything float16 except the int8 dim_mask.
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.I8_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default)
    # Variant 2: same as variant 1 but the ninth slot (rowsum) is float32.
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.I8_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default,
                  DataType.F16_Default, DataType.F32_Default, DataType.F16_Default)
    .get_op_info()
)
70
-
71
GRAD_KERNEL_NAME = "flash_attention_grad"

# TBE registration info for the flash-attention gradient kernel.
# Each dtype_format row has 14 entries: one per declared input
# (query, key, value, output, do, rowsum, rowmax, dim_mask, attn_mask,
# dropout_mask, alibi_mask) followed by one per declared output (dq, dk, dv).
cus_flash_atten_grad_op_info = (
    TBERegOp("FlashAttentionGradPrimitive")
    .fusion_type("OPAQUE")
    .partial_flag(True)
    .async_flag(False)
    .binfile_name("flash_attention_grad.so")
    .compute_cost(10)
    .kernel_name(GRAD_KERNEL_NAME)
    .attr("prev_block_num", "required", "int", "all", "65536")
    .attr("next_block_num", "required", "int", "all", "65536")
    .attr("high_precision", "required", "bool", "all", "false")
    .attr("tiling_stgy_name", "required", "str", "all", "sparse")
    .input(0, "query", False, "required", "all")
    .input(1, "key", False, "required", "all")
    .input(2, "value", False, "required", "all")
    .input(3, "output", False, "required", "all")
    .input(4, "do", False, "required", "all")
    .input(5, "rowsum", False, "required", "all")
    .input(6, "rowmax", False, "required", "all")
    .input(7, "dim_mask", False, "required", "all")
    .input(8, "attn_mask", False, "optional", "all")
    .input(9, "dropout_mask", False, "optional", "all")
    .input(10, "alibi_mask", False, "optional", "all")
    .output(0, "dq", False, "required", "all")
    .output(1, "dk", False, "required", "all")
    .output(2, "dv", False, "required", "all")
    # Variant 1: float16 inputs (int8 dim_mask), float32 gradients.
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.I8_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default)
    # Variant 2: same as variant 1 but the sixth slot (rowsum) is float32.
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F32_Default,
                  DataType.F16_Default, DataType.I8_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default)
    .get_op_info()
)
127
-
128
-
129
def get_flash_attention_grad(prev_block_num=65536, next_block_num=65536,
                             tiling_stgy_name='sparse', high_precision=False):
    """Build the bprop function used by the forward flash-attention Custom op.

    A gradient ``Custom`` primitive is created once and captured by the returned
    closure.

    :param prev_block_num: value stored in the "prev_block_num" primitive attribute
    :param next_block_num: value stored in the "next_block_num" primitive attribute
    :param tiling_stgy_name: value stored in the "tiling_stgy_name" primitive attribute
    :param high_precision: value stored in the "high_precision" primitive attribute
    :return: ``bprop(query, key, value, dim_mask, attn_mask, dropout_mask, alibi_mask, out, douts)``
    """

    def infer_shape(q_shape, k_shape, v_shape, o_shape, do_shape, l_shape, m_shape,
                    dim_mask_shape, att_mask_shape, dropout_mask_shape, alibi_mask_shape):
        # dq / dk / dv mirror the shapes of query / key / value.
        return q_shape, k_shape, v_shape

    def infer_dtype(q_dtype, k_dtype, v_dtype, o_dytpe, do_dtype, l_dtype, m_dtype,
                    dim_mask_dtype, attn_mask_dtype, dropout_mask_dtype, alibi_mask_type):
        # The gradient kernel always emits float32 dq / dk / dv.
        return (mstype.float32,) * 3

    grad_op = Custom(flash_attention_grad, out_shape=infer_shape, out_dtype=infer_dtype,
                     func_type="tbe", reg_info=cus_flash_atten_grad_op_info)
    prim_attrs = {
        "prev_block_num": prev_block_num,
        "next_block_num": next_block_num,
        "high_precision": high_precision,
        "tiling_stgy_name": tiling_stgy_name,
    }
    for attr_name, attr_value in prim_attrs.items():
        grad_op.add_prim_attr(attr_name, attr_value)
    grad_op.init_prim_io_names(
        inputs=["query", "key", "value", "output", "do", "rowsum", "rowmax", "dim_mask",
                "attn_mask", "dropout_mask", "alibi_mask"],
        outputs=["dq", "dk", "dv"],
    )

    def bprop(query, key, value, dim_mask, attn_mask, dropout_mask, alibi_mask, out, douts):
        # Forward outputs are (output, rowsum, rowmax); only the gradient of
        # `output` is fed to the gradient kernel.
        attn_out, row_sum, row_max = out
        d_attn_out, _, _ = douts
        grad_q, grad_k, grad_v = grad_op(query, key, value, attn_out, d_attn_out, row_sum, row_max,
                                         dim_mask, attn_mask, dropout_mask, alibi_mask)
        # The kernel produces float32 gradients; cast back to float16.
        grad_q = ops.cast(grad_q, mstype.float16)
        grad_k = ops.cast(grad_k, mstype.float16)
        grad_v = ops.cast(grad_v, mstype.float16)
        # The mask inputs receive zero gradients.
        return (grad_q, grad_k, grad_v, zeros_like(dim_mask), zeros_like(attn_mask),
                zeros_like(dropout_mask), zeros_like(alibi_mask))

    return bprop
165
-
166
-
167
def get_flash_attention(prev_block_num=65536, next_block_num=65536, tiling_stgy_name='sparse', high_precision=False):
    """Create the forward flash-attention Custom primitive with its bprop attached.

    :param prev_block_num: value stored in the "prev_block_num" primitive attribute
    :param next_block_num: value stored in the "next_block_num" primitive attribute
    :param tiling_stgy_name: value stored in the "tiling_stgy_name" primitive attribute
    :param high_precision: when True, the second output (rowsum) is inferred as float32
        instead of float16; also stored as the "high_precision" primitive attribute
    :return: a ``Custom`` primitive producing (output, rowsum, rowmax)
    """
    # dtype of the rowsum output, fixed by `high_precision` for this primitive.
    rowsum_dtype = mstype.float32 if high_precision else mstype.float16

    def infer_shape(q_shape, k_shape, v_shape, dim_mask_shape, attn_mask_shape=None,
                    dropout_mask_shape=None, alibi_mask_shape=None):
        """infer shape"""
        batch, hidden_size, seq_len, _ = q_shape
        # rowsum and rowmax drop the last axis of query.
        stat_shape = (batch, hidden_size, seq_len)
        return q_shape, stat_shape, stat_shape

    def infer_dtype(q_dtype, k_dtype, v_dtype, dim_mask_dtype, attn_mask_dtype=None,
                    dropout_mask_dtype=None, alibi_mask_type=None):
        """infer type"""
        return q_dtype, rowsum_dtype, q_dtype

    forward_op = Custom(flash_attention, out_shape=infer_shape, out_dtype=infer_dtype,
                        func_type="tbe",
                        bprop=get_flash_attention_grad(prev_block_num, next_block_num,
                                                       tiling_stgy_name, high_precision),
                        reg_info=cus_flash_atten_op_info)
    prim_attrs = {
        "prev_block_num": prev_block_num,
        "next_block_num": next_block_num,
        "high_precision": high_precision,
        "tiling_stgy_name": tiling_stgy_name,
    }
    for attr_name, attr_value in prim_attrs.items():
        forward_op.add_prim_attr(attr_name, attr_value)
    forward_op.init_prim_io_names(
        inputs=["query", "key", "value", "dim_mask", "attn_mask", "dropout_mask", "alibi_mask"],
        outputs=["output", "rowsum", "rowmax"],
    )

    return forward_op
@@ -1,446 +0,0 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """the common about tik ops"""
16
- from functools import partial
17
-
18
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import DTYPE_SIZE
19
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
20
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
21
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import L0C
22
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import L1
23
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
24
-
25
-
26
class TikOpsUtils:
    """Utility routines shared by the TIK flash-attention kernels.

    Wraps a TIK instance and provides layout conversions between row-major and
    fractal (K1, M, K0) forms, broadcasts, vector reciprocal, cube-based
    matmul / row-sum, and GM <-> UB vector moves.
    """

    def __init__(self, tik_instance):
        self.tik_instance = tik_instance
        # Default dtype used by get_K0 / up_align_to_K0 when no dtype is given.
        self.dtype = "float16"
        # data_move preset for a single contiguous burst with no strides;
        # callers only supply dst, src and burst.
        self.cont_data_mv_1_bust = partial(self.tik_instance.data_move, sid=0, nburst=1,
                                           src_stride=0,
                                           dst_stride=0)

    @staticmethod
    def _get_dtype_size(dtype, arg_desc="dtype"):
        """Look up the byte size of one element of ``dtype`` in ``DTYPE_SIZE``.

        :param dtype: dtype string used as a key of ``DTYPE_SIZE``
        :param arg_desc: argument description embedded in the error message
        :return: element size in bytes
        :raises ValueError: if ``dtype`` is not a key of ``DTYPE_SIZE``
        """
        try:
            return DTYPE_SIZE[dtype]
        except KeyError as err:
            raise ValueError(f"The argument '{arg_desc}' is not valid.") from err

    def MK_TO_K1MK0(self, mk_input_tensor, workspace_tensor=None):
        """Change data shape from (M, K) to (K1, M, K0), K1 = K // K0. The effect is equivalent to:
            new_tensor = np.stack(np.hsplit(mk_input_tensor, K1), axis=0)

        :param mk_input_tensor: input tensor in GM with shape (M, K)
        :param workspace_tensor: optional workspace tensor with shape (K1, M, K0). If it is None,
            the input tensor itself is rearranged (through a temporary UB buffer); otherwise the
            rearranged data is written to the workspace tensor and the input stays unchanged.
        :return: Tensor with shape (K1, M, K0)
        :raises ValueError: if the input dtype is not in ``DTYPE_SIZE``
        """
        dtype = mk_input_tensor.dtype
        m, k = mk_input_tensor.shape
        K0 = 16
        K1 = k // K0
        M = self.up_align_to_K0(m)
        dtype_size = self._get_dtype_size(dtype)
        if workspace_tensor is not None:
            with self.tik_instance.for_range(0, K1) as i:
                # M bursts of K0 elements; the source stride skips the other K1 - 1 column blocks.
                self.tik_instance.data_move(
                    workspace_tensor[i * M * K0:],
                    mk_input_tensor[i * K0:],
                    0,
                    M,
                    K0 * dtype_size // 32,
                    (K1 - 1) * K0 * dtype_size // 32,
                    0,
                )
            return workspace_tensor.reshape((K1, M, K0))

        with self.tik_instance.new_stmt_scope(disable_sync=False):
            tmp_ub = self.tik_instance.Tensor(dtype, (K1, M, K0), name="tmp_ub", scope=UB)
            # data_move (m, k) --> (k1, m, K0)
            with self.tik_instance.for_range(0, K1) as i:
                self.tik_instance.data_move(
                    tmp_ub[i * M * K0:],
                    mk_input_tensor[i * K0:],
                    0,
                    M,
                    K0 * dtype_size // 32,
                    (K1 - 1) * K0 * dtype_size // 32,
                    0,
                )
            # Write the rearranged data back over the input tensor.
            self.cont_data_mv_1_bust(
                dst=mk_input_tensor, src=tmp_ub, burst=K1 * M * K0 * dtype_size // 32)
        return mk_input_tensor.reshape((K1, M, K0))

    def transpose_matrix(self, src_ub, dst_ub, N, nk0=False):
        """Transpose a matrix with vec_trans_scatter.

        By default it supports shape (16, N) -> (N, 16); if ``nk0`` is True, it supports
        (N, 16) -> (16, N).

        :return: ``dst_ub``
        """
        K0 = 16
        rep_times = N // K0
        if nk0:
            src_list = [src_ub[16 * i] for i in range(16)]
            dst_list = [dst_ub[N * i] for i in range(16)]
        else:
            src_list = [src_ub[N * i] for i in range(16)]
            dst_list = [dst_ub[16 * i] for i in range(16)]

        dst_rep_stride = K0
        src_rep_stride = 1
        if rep_times == 1:
            # A single repeat needs no stride between repeats.
            dst_rep_stride = 0
            src_rep_stride = 0

        if nk0:
            src_rep_stride, dst_rep_stride = dst_rep_stride, src_rep_stride

        self.tik_instance.vec_trans_scatter(
            False, False, dst_list, src_list, rep_times, dst_rep_stride, src_rep_stride
        )
        return dst_ub

    def KN_TO_K1NK0(self, kn_input_tensor, workspace_tensor=None):
        """Change data shape from (K, N) to (K1, N, K0), K1 = K // K0. The effect is equivalent to:
            new_tensor = np.reshape(kn_input_tensor, newshape=(K1, K0, N)).swapaxes(1, 2)

        :param kn_input_tensor: input tensor with shape (K, N)
        :param workspace_tensor: optional workspace tensor with shape (K1, N, K0). If it is None,
            the input tensor itself is rearranged; otherwise the rearranged data is written to
            the workspace tensor and the input stays unchanged.
        :return: Tensor with shape (K1, N, K0)
        :raises ValueError: if the input dtype is not in ``DTYPE_SIZE``
        """
        dtype = kn_input_tensor.dtype
        k, n = kn_input_tensor.shape
        K0 = 16
        K1 = k // K0
        N = n
        dtype_size = self._get_dtype_size(dtype)
        with self.tik_instance.for_range(0, K1) as index:
            k1nk0_ub = self.tik_instance.Tensor(dtype, (N, K0), UB, "k1nk0_ub")
            src_ub = self.tik_instance.Tensor(dtype, (K0, N), UB, "src_ub")
            burst_len = K0 * N * dtype_size // 32
            # Load one (K0, N) slab, transpose it to (N, K0) and store it back.
            self.cont_data_mv_1_bust(dst=src_ub, src=kn_input_tensor[index * K0 * N],
                                     burst=burst_len)
            k1nk0_ub = self.transpose_matrix(src_ub, k1nk0_ub, N)
            if workspace_tensor is None:
                self.cont_data_mv_1_bust(dst=kn_input_tensor[index * K0 * N], src=k1nk0_ub,
                                         burst=burst_len)
            else:
                self.cont_data_mv_1_bust(dst=workspace_tensor[index * K0 * N], src=k1nk0_ub,
                                         burst=burst_len)
        if workspace_tensor is None:
            return kn_input_tensor.reshape((K1, N, K0))

        return workspace_tensor.reshape((K1, N, K0))

    def N1MN0_TO_MN(self, N1MN0_input):
        """Change data shape from (N1, M, N0) to (M, N), N0 = 16, N = N1 * N0. The effect is equivalent to:
            N1MN0_input = np.concatenate(list(map(np.squeeze, np.split(N1MN0_input, N1))), axis=1)

        The input tensor is overwritten with the rearranged data.

        :param N1MN0_input: input tensor with shape (N1, M, N0) in GM or L1
        :return: the input tensor reshaped to (M, N1 * N0)
        :raises ValueError: if the input dtype is not in ``DTYPE_SIZE``
        """
        dtype = N1MN0_input.dtype
        N1, M, N0 = N1MN0_input.shape
        dtype_size = self._get_dtype_size(dtype)
        with self.tik_instance.new_stmt_scope(disable_sync=False):
            tmp_ub = self.tik_instance.Tensor(dtype, (M, N1 * N0), name="tmp_ub", scope=UB)
            # data_move (n1, m, n0) --> (m, n): the destination stride skips the other N1 - 1 blocks.
            with self.tik_instance.for_range(0, N1) as i:
                self.tik_instance.data_move(
                    tmp_ub[i * N0:],
                    N1MN0_input[i * M * N0:],
                    0,
                    M,
                    N0 * dtype_size // 32,
                    0,
                    (N1 - 1) * N0 * dtype_size // 32,
                )
            # Move the rearranged data back out.
            self.cont_data_mv_1_bust(dst=N1MN0_input, src=tmp_ub, burst=M * N1 * N0 * dtype_size // 32)
        return N1MN0_input.reshape((M, N1 * N0))

    def broadcast(self, vec_ub, shape):
        """Broadcast a vector to a matrix along the columns.

        :param vec_ub: a tensor in UB with shape (M,) and dtype float16
        :param shape: the target shape, a tuple (M, N); M and N are integer multiples of 16
        :return: a new float16 tensor in UB with shape (M, N)
        """
        M, N = shape
        dst_ub = self.tik_instance.Tensor(FP16, shape, name="dst_ub", scope=UB)

        with self.tik_instance.new_stmt_scope(disable_sync=False):
            # (M,) -> (2, M) -> (4, M) -> (8, M) -> (16, M): double the filled rows each step.
            tmp_ub1 = self.tik_instance.Tensor(FP16, (16, M), name="tmp_ub1", scope=UB)
            self.tik_instance.data_move(tmp_ub1, vec_ub, 0, 1, M // 16, 0, 0)
            times = self.tik_instance.Scalar("int32", name="times", init_value=1)
            with self.tik_instance.for_range(begint=0, endt=16):
                with self.tik_instance.if_scope(times <= 8):
                    offset = times * M
                    burst = times * M // 16
                    self.cont_data_mv_1_bust(dst=tmp_ub1[offset], src=tmp_ub1, burst=burst)
                with self.tik_instance.else_scope():
                    self.tik_instance.tik_break()
                times.set_as(times * 2)

            # (16, M) -> (M, 16)
            tmp_ub2 = self.tik_instance.Tensor(FP16, (M, 16), name="tmp_ub2", scope=UB)
            tmp_ub2_transposed = self.transpose_matrix(tmp_ub1, tmp_ub2, M)

            # (M, 16) -> (M, 32) -> (M, 64) -> ... -> (M, N): double the filled columns each step.
            self.tik_instance.data_move(dst_ub, tmp_ub2_transposed, 0, M, 1, 0, N // 16 - 1)
            times.set_as(1)
            with self.tik_instance.for_range(begint=0, endt=N):
                offset = times * 16
                with self.tik_instance.if_scope(offset * 2 <= N):
                    burst = offset // 16
                    src_stride = N // 16 - burst
                    dst_stride = N // 16 - burst
                    self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, M, burst, src_stride,
                                                dst_stride)
                with self.tik_instance.else_scope():
                    # Fill the remaining columns (if any) and stop.
                    burst = (N - offset) // 16
                    src_stride = N // 16 - burst
                    dst_stride = N // 16 - burst
                    with self.tik_instance.if_scope(burst > 0):
                        self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, M, burst, src_stride,
                                                    dst_stride)
                    self.tik_instance.tik_break()
                times.set_as(times * 2)
        return dst_ub

    def broadcast_row(self, vec_ub, shape):
        """Broadcast a row vector of length N to a float16 (M, N) matrix in UB."""
        M, N = shape
        dst_ub = self.tik_instance.Tensor(FP16, shape, name="dst_ub", scope=UB)
        self.tik_instance.data_move(dst_ub, vec_ub, 0, 1, N // 16, 0, 0)
        times = self.tik_instance.Scalar("int32", name="times", init_value=1)
        # (1, N) -> (2, N) -> (4, N) -> ... -> (M, N): double the filled rows each step.
        with self.tik_instance.for_range(begint=0, endt=M):
            with self.tik_instance.if_scope(times * 2 <= M):
                burst = times * N // 16
                offset = times * N
                self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, 1, burst, 0, 0)
            with self.tik_instance.else_scope():
                # Fill the remaining rows (if any) and stop.
                burst = (M - times) * N // 16
                offset = times * N
                with self.tik_instance.if_scope(burst > 0):
                    self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, 1, burst, 0, 0)
                self.tik_instance.tik_break()
            times.set_as(times * 2)
        return dst_ub

    def get_K0(self, dtype=None):
        """Return the number of ``dtype`` elements in a 32-byte block (defaults to ``self.dtype``)."""
        if dtype is None:
            dtype = self.dtype
        dtype_size = self._get_dtype_size(dtype)
        return 32 // dtype_size

    def up_align_to_K0(self, n, dtype=None):
        """Round ``n`` up to a multiple of the per-32-byte element count of ``dtype``
        (defaults to ``self.dtype``)."""
        if dtype is None:
            dtype = self.dtype
        dtype_size = self._get_dtype_size(dtype)
        K0 = 32 // dtype_size
        return (n + K0 - 1) // K0 * K0

    def calc_vec_rec(self, vec_ub, vec_len):
        """Compute the element-wise reciprocal of a vector in UB.

        :param vec_ub: input vector in UB
        :param vec_len: number of valid elements in ``vec_ub`` (a Python int)
        :return: a new UB tensor of length ``up_align_to_K0(vec_len)`` holding ``1 / vec_ub``
        :raises ValueError: if the dtype of ``vec_ub`` is not in ``DTYPE_SIZE``
        """
        dtype = vec_ub.dtype
        vec_len_aligned = self.up_align_to_K0(vec_len)
        vec_rec_ub = self.tik_instance.Tensor(dtype, (vec_len_aligned,), scope=UB, name="li_new_rec_ub")
        dtype_size = self._get_dtype_size(dtype)
        full_mask_len = 256 // dtype_size  # elements covered by one full repeat
        block_len = 32 // dtype_size
        work_size = 8 // dtype_size

        with self.tik_instance.new_stmt_scope(disable_sync=False):
            repeat_times = vec_len // full_mask_len
            if repeat_times > 0:
                dst_rep_stride = 8
                src_rep_stride = 8

                src_extent_size = (repeat_times - 1) * src_rep_stride * block_len + full_mask_len
                wk_size_unit = ((src_extent_size + block_len - 1) // block_len) * block_len
                wk_size = work_size * wk_size_unit
                # Work tensor required by vec_rec_high_preci.
                work_tensor_ub = self.tik_instance.Tensor(
                    "float32", (wk_size,), name="work_tensor_ub", scope=UB
                )
                # If the work tensor is indexed, it must be written as work_tensor[index:].
                self.tik_instance.vec_rec_high_preci(
                    full_mask_len,
                    vec_rec_ub[0:],
                    vec_ub[0:],
                    work_tensor_ub[0:],
                    repeat_times,
                    dst_rep_stride,
                    src_rep_stride,
                )

            tail_len = vec_len - repeat_times * full_mask_len
            if tail_len > 0:
                # The tail starts right after the full repeats; this depends on the dtype
                # (128 elements per repeat for float16, 64 for float32).
                tail_offset = repeat_times * full_mask_len
                wk_size = work_size * ((tail_len + block_len - 1) // block_len) * block_len
                work_tensor_ub2 = self.tik_instance.Tensor(
                    "float32", (wk_size,), name="work_tensor_ub2", scope=UB
                )
                self.tik_instance.vec_rec_high_preci(
                    tail_len,
                    vec_rec_ub[tail_offset:],
                    vec_ub[tail_offset:],
                    work_tensor_ub2[0:],
                    1,
                    0,
                    0,
                )
        return vec_rec_ub

    def row_sum_cube_impl(self, matrix_l1_K1MK0_ed, rowsum_ub, m, k, precision_type):
        """Compute matrix row sums on the cube unit by right-multiplying with an all-ones matrix.

        :param matrix_l1_K1MK0_ed: input tensor with shape (K1, M, K0)
        :param rowsum_ub: output tensor that receives the row sums of the input tensor
        :param m: actual tensor height
        :param k: actual tensor width
        :param precision_type: dtype of the matmul result (FP32 or FP16)
        :return: ``rowsum_ub``
        """
        K1, M, K0 = matrix_l1_K1MK0_ed.shape
        K = K1 * K0

        # Build the all-ones right matrix. The cube unit cannot handle shape (n, 1), so
        # use (n, 16); an all-ones matrix needs no fractal rearrangement.
        right_all_one_matrix_ub = self.tik_instance.Tensor(
            FP16, (K, 16), name="right_all_one_matrix_ub", scope=UB
        )
        self.tik_instance.h_duplicate(right_all_one_matrix_ub, 1.0)
        right_all_one_matrix_l1 = self.tik_instance.Tensor(
            FP16, (K1 * K0, 16), name="right_all_one_matrix_l1", scope=L1
        )
        self.cont_data_mv_1_bust(dst=right_all_one_matrix_l1, src=right_all_one_matrix_ub, burst=K)

        # Use matmul to compute the row sums; the result has shape (m, 16) and the
        # first element of each row is taken.
        with self.tik_instance.new_stmt_scope(disable_sync=False):
            row_sum_ub_N1MN0 = self.matmul_compute(matrix_l1_K1MK0_ed, right_all_one_matrix_l1, m, k, 16,
                                                   N1MN0_to_MN=False, precision_type=precision_type)
            row_sum_ub_MN_ed = row_sum_ub_N1MN0.reshape((M, 16))
            if precision_type == FP32:
                for idx in range(0, m):
                    cur_row_sum = self.tik_instance.Scalar(FP32, init_value=row_sum_ub_MN_ed[idx, 0])
                    rowsum_ub[idx].set_as(cur_row_sum)
            else:
                # Transpose (M, 16) -> (16, M); its first row is column 0 of the result.
                row_sum_ub_trans = self.tik_instance.Tensor(FP16, (16, M), name="row_sum_ub_trans", scope=UB)
                row_sum_ub_trans = self.transpose_matrix(row_sum_ub_MN_ed, row_sum_ub_trans, M, True)
                self.cont_data_mv_1_bust(dst=rowsum_ub, src=row_sum_ub_trans, burst=M // 16)

        return rowsum_ub

    def matmul_compute(self, A_l1, B_l1, m, k, n, N1MN0_to_MN=True, precision_type=FP16):
        """Calculate the matrix product A_l1 * B_l1, move the result to C_ub, then optionally rearrange C_ub.

        :param A_l1: input tensor in L1 with shape (K1, M, K0)
        :param B_l1: input tensor in L1 with shape (K1, N, K0)
        :param m: the actual number of rows of A_l1
        :param k: the actual number of cols of A_l1
        :param n: the actual number of cols of B_l1
        :param N1MN0_to_MN: whether to reorder the result tensor
        :param precision_type: dtype of C_ub
        :return: C_ub with shape (M, N) if N1MN0_to_MN else (N1, M, N0)
        """
        M = self.up_align_to_K0(m)
        N = self.up_align_to_K0(n)
        C_ub = self.tik_instance.Tensor(precision_type, (N // 16, M, 16), name="C_ub", scope=UB)
        dtype_size = self._get_dtype_size(FP32)
        with self.tik_instance.new_stmt_scope(disable_sync=False):
            # matmul
            C_l0c = self.tik_instance.Tensor(
                FP32, (N // 16, M, 16), scope=L0C, name="C_l0c"
            )  # n1mn0 (n0=16)
            self.tik_instance.matmul(C_l0c, A_l1, B_l1, m, k, n)
            # L0C -> UB; tensor_mov can convert on the fly (fp32 -> precision_type).
            self.tik_instance.tensor_mov(C_ub, C_l0c, "m", 1, M * N * dtype_size // 1024, 0, 0)
        if N1MN0_to_MN:
            return self.N1MN0_TO_MN(C_ub)
        return C_ub

    def move_vector_from_gm_to_ub(self, dst_tensor, src_tensor, gm_offset, vec_len):
        """Load ``vec_len`` elements of ``src_tensor`` (GM), starting at ``gm_offset``, into ``dst_tensor`` (UB).

        Full 32-byte blocks are copied with a single data_move. A partial tail is handled by
        moving the address back: the last full block ending at ``vec_len`` is staged in UB and
        copied element-wise.
        NOTE(review): the tail path computes ``vec_len - a_burst_num`` as an offset, so it
        presumably expects ``vec_len`` to be at least one 32-byte block long — confirm callers.

        :param dst_tensor: destination tensor in UB
        :param src_tensor: source tensor in GM
        :param gm_offset: element offset into ``src_tensor``
        :param vec_len: number of elements to load
        :raises ValueError: if the dtype of ``src_tensor`` is not in ``DTYPE_SIZE``
        """
        dtype_size = self._get_dtype_size(src_tensor.dtype, "src_tensor dtype")
        a_burst_num = 32 // dtype_size  # elements per 32-byte block
        full_tik_blk_num, tail_num = divmod(vec_len, a_burst_num)
        with self.tik_instance.if_scope(full_tik_blk_num > 0):
            self.cont_data_mv_1_bust(dst=dst_tensor, src=src_tensor[gm_offset],
                                     burst=full_tik_blk_num)
        # Handle the tail data by moving the address back.
        with self.tik_instance.if_scope(tail_num > 0):
            offset = vec_len - a_burst_num
            # The staging buffer uses the source dtype so the 32-byte burst fits it exactly.
            last_blk_ub = self.tik_instance.Tensor(src_tensor.dtype, (a_burst_num,), name="last_blk_ub",
                                                   scope=UB)
            self.cont_data_mv_1_bust(dst=last_blk_ub, src=src_tensor[gm_offset + offset], burst=1)
            # offset is not 32-byte aligned, so data_move cannot be used; copy element-wise.
            with self.tik_instance.for_range(0, a_burst_num) as idx:
                dst_tensor[offset + idx].set_as(last_blk_ub[idx])

    def move_vector_from_ub_to_gm(self, dst_tensor, src_tensor, gm_offset, block_h):
        """Write ``block_h`` elements of ``src_tensor`` (UB) back to ``dst_tensor`` (GM) at ``gm_offset``.

        Full 32-byte blocks are written with a single data_move; a partial tail is written by
        staging the last ``a_burst_num`` elements (moving the address back) and issuing one burst.
        NOTE(review): the tail path computes ``block_h - a_burst_num`` as an offset, so it
        presumably expects ``block_h`` to be at least one 32-byte block long — confirm callers.

        :param dst_tensor: destination tensor in GM
        :param src_tensor: source tensor in UB
        :param gm_offset: element offset into ``dst_tensor``
        :param block_h: number of elements to write
        :raises ValueError: if the dtype of ``src_tensor`` is not in ``DTYPE_SIZE``
        """
        dtype_size = self._get_dtype_size(src_tensor.dtype, "src_tensor dtype")
        a_burst_num = 32 // dtype_size  # elements per 32-byte block
        full_tik_blk_num = block_h // a_burst_num
        with self.tik_instance.if_scope(full_tik_blk_num > 0):
            self.cont_data_mv_1_bust(dst=dst_tensor[gm_offset], src=src_tensor,
                                     burst=full_tik_blk_num)
        tail_num = block_h % a_burst_num
        with self.tik_instance.if_scope(tail_num > 0):
            offset = block_h - a_burst_num
            # The staging buffer uses the source dtype so the 32-byte burst fits it exactly.
            tmp_ub = self.tik_instance.Tensor(src_tensor.dtype, (a_burst_num,), name="tmp_ub", scope=UB)
            with self.tik_instance.for_range(0, a_burst_num) as idx:
                tmp_ub[idx].set_as(src_tensor[offset + idx])
            self.cont_data_mv_1_bust(dst=dst_tensor[gm_offset + offset], src=tmp_ub, burst=1)

    def scale_compute_vector(self, Sij_ub, dim):
        """Multiply ``Sij_ub`` in place by ``dim ** -0.5`` and return it."""
        scale_value = dim ** -0.5
        # Use the tensor's own dtype so the scalar operand matches it.
        scale = self.tik_instance.Scalar(dtype=Sij_ub.dtype)
        scale.set_as(scale_value)
        self.tik_instance.h_mul(Sij_ub, Sij_ub, scale)
        return Sij_ub