mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,193 @@
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """The impl of flash attention"""
16
+ from __future__ import absolute_import
17
+ import mindspore.ops as ops
18
+ import mindspore.common.dtype as mstype
19
+ from mindspore.ops import Custom
20
+ from mindspore.ops import DataType
21
+ from mindspore.ops import TBERegOp
22
+ from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_bwd import flash_attention_grad
23
+ from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_fwd import flash_attention
24
+ from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
25
+
26
+ KERNEL_NAME = "flash_attention"
27
+
28
+ cus_flash_atten_op_info = TBERegOp("FlashAttentionPrimitive") \
29
+ .fusion_type("OPAQUE") \
30
+ .partial_flag(True) \
31
+ .async_flag(False) \
32
+ .binfile_name("flash_attention.so") \
33
+ .compute_cost(10) \
34
+ .kernel_name(KERNEL_NAME) \
35
+ .attr("prev_block_num", "required", "int", "all", "65536") \
36
+ .attr("next_block_num", "required", "int", "all", "65536") \
37
+ .attr("high_precision", "required", "bool", "all", "false") \
38
+ .attr("tiling_stgy_name", "required", "str", "all", "sparse") \
39
+ .input(0, "query", False, "required", "all") \
40
+ .input(1, "key", False, "required", "all") \
41
+ .input(2, "value", False, "required", "all") \
42
+ .input(3, "attn_mask", False, "optional", "all") \
43
+ .input(4, "dropout_mask", False, "optional", "all") \
44
+ .input(5, "alibi_mask", False, "optional", "all") \
45
+ .output(0, "output", False, "required", "all") \
46
+ .output(1, "rowsum", False, "required", "all") \
47
+ .output(2, "rowmax", False, "required", "all") \
48
+ .dtype_format(DataType.F16_FracNZ,
49
+ DataType.F16_FracNZ,
50
+ DataType.F16_FracNZ,
51
+ DataType.F16_FracNZ,
52
+ DataType.F16_Default,
53
+ DataType.F16_FracNZ,
54
+ DataType.F16_FracNZ,
55
+ DataType.F16_Default,
56
+ DataType.F16_Default) \
57
+ .dtype_format(DataType.F16_FracNZ,
58
+ DataType.F16_FracNZ,
59
+ DataType.F16_FracNZ,
60
+ DataType.F16_FracNZ,
61
+ DataType.F16_Default,
62
+ DataType.F16_FracNZ,
63
+ DataType.F16_FracNZ,
64
+ DataType.F32_Default,
65
+ DataType.F16_Default) \
66
+ .get_op_info()
67
+
68
+ GRAD_KERNEL_NAME = "flash_attention_grad"
69
+
70
+ cus_flash_atten_grad_op_info = TBERegOp("FlashAttentionGradPrimitive") \
71
+ .fusion_type("OPAQUE") \
72
+ .partial_flag(True) \
73
+ .async_flag(False) \
74
+ .binfile_name("flash_attention_grad.so") \
75
+ .compute_cost(10) \
76
+ .kernel_name(GRAD_KERNEL_NAME) \
77
+ .attr("prev_block_num", "required", "int", "all", "65536") \
78
+ .attr("next_block_num", "required", "int", "all", "65536") \
79
+ .attr("high_precision", "required", "bool", "all", "false") \
80
+ .attr("tiling_stgy_name", "required", "str", "all", "sparse") \
81
+ .input(0, "query", False, "required", "all") \
82
+ .input(1, "key", False, "required", "all") \
83
+ .input(2, "value", False, "required", "all") \
84
+ .input(3, "output", False, "required", "all") \
85
+ .input(4, "do", False, "required", "all") \
86
+ .input(5, "rowsum", False, "required", "all") \
87
+ .input(6, "rowmax", False, "required", "all") \
88
+ .input(7, "attn_mask", False, "optional", "all") \
89
+ .input(8, "dropout_mask", False, "optional", "all") \
90
+ .input(9, "alibi_mask", False, "optional", "all") \
91
+ .output(0, "dq", False, "required", "all") \
92
+ .output(1, "dk", False, "required", "all") \
93
+ .output(2, "dv", False, "required", "all") \
94
+ .dtype_format(DataType.F16_FracNZ,
95
+ DataType.F16_FracNZ,
96
+ DataType.F16_FracNZ,
97
+ DataType.F16_FracNZ,
98
+ DataType.F16_FracNZ,
99
+ DataType.F16_Default,
100
+ DataType.F16_Default,
101
+ DataType.F16_FracNZ,
102
+ DataType.F16_Default,
103
+ DataType.F16_FracNZ,
104
+ DataType.F32_FracNZ,
105
+ DataType.F32_FracNZ,
106
+ DataType.F32_FracNZ) \
107
+ .dtype_format(DataType.F16_FracNZ,
108
+ DataType.F16_FracNZ,
109
+ DataType.F16_FracNZ,
110
+ DataType.F16_FracNZ,
111
+ DataType.F16_FracNZ,
112
+ DataType.F32_Default,
113
+ DataType.F16_Default,
114
+ DataType.F16_FracNZ,
115
+ DataType.F16_Default,
116
+ DataType.F16_FracNZ,
117
+ DataType.F32_FracNZ,
118
+ DataType.F32_FracNZ,
119
+ DataType.F32_FracNZ) \
120
+ .get_op_info()
121
+
122
+
123
+ def get_flash_attention_grad(prev_block_num=65536, next_block_num=65536,
124
+ tiling_stgy_name='sparse', high_precision=False):
125
+ """get flash attention grad"""
126
+
127
+ def infer_shape(q_shape, k_shape, v_shape, o_shape, do_shape, l_shape, m_shape,
128
+ att_mask_shape, dropout_mask_shape, alibi_mask_shape):
129
+ return q_shape, k_shape, v_shape
130
+
131
+ def infer_dtype(q_dtype, k_dtype, v_dtype, o_dytpe, do_dtype, l_dtype, m_dtype,
132
+ attn_mask_dtype, dropout_mask_dtype, alibi_mask_type):
133
+ return mstype.float32, mstype.float32, mstype.float32
134
+
135
+ fa_grad = Custom(flash_attention_grad, out_shape=infer_shape,
136
+ out_dtype=infer_dtype, func_type="tbe", reg_info=cus_flash_atten_grad_op_info)
137
+ fa_grad.add_prim_attr("prev_block_num", prev_block_num)
138
+ fa_grad.add_prim_attr("next_block_num", next_block_num)
139
+ fa_grad.add_prim_attr("high_precision", high_precision)
140
+ fa_grad.add_prim_attr("tiling_stgy_name", tiling_stgy_name)
141
+ fa_grad.init_prim_io_names(
142
+ inputs=["query", "key", "value", "output", "do", "rowsum", "rowmax", "attn_mask", "dropout_mask",
143
+ "alibi_mask"],
144
+ outputs=["dq", "dk", "dv"]
145
+ )
146
+
147
+ def bprop(query, key, value, attn_mask, dropout_mask, alibi_mask, out, douts):
148
+ output, rowsum, rowmax = out
149
+ dout, _, _ = douts
150
+ dq, dk, dv = fa_grad(query, key, value, output, dout, rowsum, rowmax, attn_mask, dropout_mask,
151
+ alibi_mask)
152
+ dq = ops.cast(dq, mstype.float16)
153
+ dk = ops.cast(dk, mstype.float16)
154
+ dv = ops.cast(dv, mstype.float16)
155
+ return dq, dk, dv, zeros_like(attn_mask), \
156
+ zeros_like(dropout_mask), zeros_like(alibi_mask)
157
+
158
+ return bprop
159
+
160
+
161
+ def get_flash_attention(prev_block_num=65536, next_block_num=65536, tiling_stgy_name='sparse', high_precision=False):
162
+ """get_flash_attention"""
163
+
164
+ def infer_shape(q_shape, k_shape, v_shape, attn_mask_shape=None,
165
+ dropout_mask_shape=None, alibi_mask_shape=None):
166
+ """infer shape"""
167
+ batch, hidden_size, seq_len, _ = q_shape
168
+ l_shape = (batch, hidden_size, seq_len)
169
+ m_shape = (batch, hidden_size, seq_len)
170
+ return q_shape, l_shape, m_shape
171
+
172
+ def infer_dtype(q_dtype, k_dtype, v_dtype, attn_mask_dtype=None,
173
+ dropout_mask_dtype=None, alibi_mask_type=None):
174
+ """infer type"""
175
+ l_dtype = mstype.float16
176
+ if high_precision:
177
+ l_dtype = mstype.float32
178
+ return q_dtype, l_dtype, q_dtype
179
+
180
+ fa_grad = get_flash_attention_grad(prev_block_num, next_block_num, tiling_stgy_name, high_precision)
181
+ fa_forward = Custom(flash_attention, out_shape=infer_shape,
182
+ out_dtype=infer_dtype, func_type="tbe", bprop=fa_grad,
183
+ reg_info=cus_flash_atten_op_info)
184
+ fa_forward.add_prim_attr("prev_block_num", prev_block_num)
185
+ fa_forward.add_prim_attr("next_block_num", next_block_num)
186
+ fa_forward.add_prim_attr("high_precision", high_precision)
187
+ fa_forward.add_prim_attr("tiling_stgy_name", tiling_stgy_name)
188
+ fa_forward.init_prim_io_names(
189
+ inputs=["query", "key", "value", "attn_mask", "dropout_mask", "alibi_mask"],
190
+ outputs=["output", "rowsum", "rowmax"]
191
+ )
192
+
193
+ return fa_forward
@@ -0,0 +1,435 @@
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """Common utilities for tik ops"""
16
+ from functools import partial
17
+
18
+ from mindspore.ops._op_impl._custom_op.flash_attention.constants import DTYPE_SIZE
19
+ from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
20
+ from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
21
+ from mindspore.ops._op_impl._custom_op.flash_attention.constants import L0C
22
+ from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
23
+
24
+
25
+ class TikOpsUtils:
26
+ """Utility class for tik ops"""
27
+
28
+ def __init__(self, tik_instance):
29
+ self.tik_instance = tik_instance
30
+ self.dtype = "float16"
31
+ self.cont_data_mv_1_bust = partial(self.tik_instance.data_move, sid=0, nburst=1,
32
+ src_stride=0,
33
+ dst_stride=0)
34
+
35
+ def MK_TO_K1MK0(self, mk_input_tensor, workspace_tensor=None):
36
+ """change data shape from (M, K) to (K1, M, K0), K1 = K // K0, the effect is equivalent to:
37
+ new_tensor = np.stack(np.hsplit(mk_input_tensor, K1), axis=0)
38
+
39
+ :param mk_input_tensor: input tensor in GM with shape: (M, K)
40
+ :param workspace_tensor: workspace tensor with shape: (K1, M, K0). If it is None, the input
41
+ tensor will be changed, otherwise the new data will be copied to the workspace tensor,
42
+ and input tensor will stay unchanged.
43
+ :return: Tensor with shape (K1,M, K0)
44
+ """
45
+ dtype = mk_input_tensor.dtype
46
+ m, k = mk_input_tensor.shape
47
+ K0 = 16
48
+ K1 = k // K0
49
+ M = self.up_align_to_K0(m)
50
+ try:
51
+ dtype_size = DTYPE_SIZE[dtype]
52
+ except KeyError:
53
+ raise ValueError("The argument 'dtype' is not valid.")
54
+ if workspace_tensor is not None:
55
+ with self.tik_instance.for_range(0, K1) as i:
56
+ self.tik_instance.data_move(
57
+ workspace_tensor[i * M * K0:],
58
+ mk_input_tensor[i * K0:],
59
+ 0,
60
+ M,
61
+ K0 * dtype_size // 32,
62
+ (K1 - 1) * K0 * dtype_size // 32,
63
+ 0,
64
+ )
65
+ return workspace_tensor.reshape((K1, M, K0))
66
+
67
+ with self.tik_instance.new_stmt_scope(disable_sync=False):
68
+ tmp_ub = self.tik_instance.Tensor(dtype, (K1, M, K0), name="tmp_ub", scope=UB)
69
+ # data_move(m,k) --> (k1,m,K0)
70
+ with self.tik_instance.for_range(0, K1) as i:
71
+ self.tik_instance.data_move(
72
+ tmp_ub[i * M * K0:],
73
+ mk_input_tensor[i * K0:],
74
+ 0,
75
+ M,
76
+ K0 * dtype_size // 32,
77
+ (K1 - 1) * K0 * dtype_size // 32,
78
+ 0,
79
+ )
80
+ self.cont_data_mv_1_bust(
81
+ dst=mk_input_tensor, src=tmp_ub, burst=K1 * M * K0 * dtype_size // 32)
82
+ return mk_input_tensor.reshape((K1, M, K0))
83
+
84
+ def transpose_matrix(self, src_ub, dst_ub, N, nk0=False):
85
+ """ transpose matrix, default support shape: (16, n) -> (n, 16)
86
+ if nk0 is true, support shape: (n, 16) -> (16, n)
87
+ """
88
+ K0 = 16
89
+ rep_times = N // K0
90
+ if nk0:
91
+ src_list = [src_ub[16 * i] for i in range(16)]
92
+ dst_list = [dst_ub[N * i] for i in range(16)]
93
+ else:
94
+ src_list = [src_ub[N * i] for i in range(16)]
95
+ dst_list = [dst_ub[16 * i] for i in range(16)]
96
+
97
+ dst_rep_stride = K0
98
+ src_rep_stride = 1
99
+ if rep_times == 1:
100
+ dst_rep_stride = 0
101
+ src_rep_stride = 0
102
+
103
+ if nk0:
104
+ src_rep_stride, dst_rep_stride = dst_rep_stride, src_rep_stride
105
+
106
+ self.tik_instance.vec_trans_scatter(
107
+ False, False, dst_list, src_list, rep_times, dst_rep_stride, src_rep_stride
108
+ )
109
+ return dst_ub
110
+
111
def KN_TO_K1NK0(self, kn_input_tensor, workspace_tensor=None):
    """Change the data layout from (K, N) to (K1, N, K0), with K0 = 16 and K1 = K // K0.

    The effect is equivalent to:
    new_tensor = np.reshape(kn_input_tensor, newshape=(K1, K0, N)).swapaxes(1, 2)

    :param kn_input_tensor: input tensor with shape (K, N)
    :param workspace_tensor: optional workspace tensor with shape (K1, N, K0).
        When it is None the result is written back over kn_input_tensor;
        otherwise the result is written to the workspace tensor and the
        input tensor is not written by this method.
    :return: the tensor that received the result, reshaped to (K1, N, K0)
    :raises ValueError: if the input dtype is not a key of DTYPE_SIZE
    """
    dtype = kn_input_tensor.dtype
    k, n = kn_input_tensor.shape
    K0 = 16
    K1 = k // K0
    N = n
    try:
        dtype_size = DTYPE_SIZE[dtype]
    except KeyError:
        raise ValueError("The argument 'dtype' is not valid.")

    # Every (K0, N) slab is written to the same destination tensor.
    target = kn_input_tensor if workspace_tensor is None else workspace_tensor
    # Number of 32-byte blocks in one (K0, N) slab.
    slab_blocks = K0 * N * dtype_size // 32

    with self.tik_instance.for_range(0, K1) as index:
        transposed_ub = self.tik_instance.Tensor(dtype, (N, K0), UB, "k1nk0_ub")
        slab_ub = self.tik_instance.Tensor(dtype, (K0, N), UB, "src_ub")
        slab_offset = index * K0 * N
        # Load one slab, transpose it (K0, N) -> (N, K0), then store it.
        self.cont_data_mv_1_bust(dst=slab_ub, src=kn_input_tensor[slab_offset],
                                 burst=slab_blocks)
        transposed_ub = self.transpose_matrix(slab_ub, transposed_ub, N)
        self.cont_data_mv_1_bust(dst=target[slab_offset], src=transposed_ub,
                                 burst=slab_blocks)

    return target.reshape((K1, N, K0))
147
+
148
def N1MN0_TO_MN(self, N1MN0_input):
    """Change the data layout from (N1, M, N0) to (M, N) in place, with N = N1 * N0.

    The effect is equivalent to:
    N1MN0_input = np.concatenate(list(map(np.squeeze, np.split(N1MN0_input, N1))), axis=1)

    :param N1MN0_input: input tensor with shape (N1, M, N0); it is overwritten
        with the rearranged data
    :return: N1MN0_input reshaped to (M, N1 * N0)
    :raises ValueError: if the input dtype is not a key of DTYPE_SIZE
    """
    dtype = N1MN0_input.dtype
    N1, M, N0 = N1MN0_input.shape
    try:
        dtype_size = DTYPE_SIZE[dtype]
    except KeyError:
        raise ValueError("The argument 'dtype' is not valid.")

    tik = self.tik_instance
    with tik.new_stmt_scope(disable_sync=False):
        staging_ub = tik.Tensor(dtype, (M, N1 * N0), name="tmp_ub", scope=UB)
        row_blocks = N0 * dtype_size // 32
        dst_gap_blocks = (N1 - 1) * N0 * dtype_size // 32
        # data_move (n1, m, n0) --> (m, n): slab i becomes column group i.
        with tik.for_range(0, N1) as i:
            tik.data_move(
                staging_ub[i * N0:],
                N1MN0_input[i * M * N0:],
                0,
                M,
                row_blocks,
                0,
                dst_gap_blocks,
            )
        # Copy the rearranged data back over the input tensor.
        self.cont_data_mv_1_bust(dst=N1MN0_input, src=staging_ub,
                                 burst=M * N1 * N0 * dtype_size // 32)
    return N1MN0_input.reshape((M, N1 * N0))
177
+
178
def broadcast(self, vec_ub, shape):
    """Broadcast a vector to a matrix so that row i is filled with vec_ub[i].

    The work is done in three stages: replicate the vector into 16 rows by
    repeated doubling, transpose that (16, M) tile to (M, 16), then widen
    every row from 16 to N columns by repeated doubling.

    :param vec_ub: a tensor in UB with shape of (M,), and dtype is float16
    :param shape: the target shape, a tuple with value (M, N), M and N are integer multiples of 16
    :return: a tensor in UB with shape of (M, N)
    """
    M, N = shape
    dst_ub = self.tik_instance.Tensor(FP16, shape, name="dst_ub", scope=UB)

    with self.tik_instance.new_stmt_scope(disable_sync=False):
        # (M,) -> (2, M) -> (4, M) -> (8, M) -> (16, M)
        tmp_ub1 = self.tik_instance.Tensor(FP16, (16, M), name="tmp_ub1", scope=UB)
        # Seed row 0 with the vector (M // 16 blocks of 16 fp16 elements).
        self.tik_instance.data_move(tmp_ub1, vec_ub, 0, 1, M // 16, 0, 0)
        times = self.tik_instance.Scalar("int32", name="times", init_value=1)
        with self.tik_instance.for_range(begint=0, endt=16):
            with self.tik_instance.if_scope(times <= 8):
                # Copy the first `times` rows right behind themselves,
                # doubling the number of filled rows.
                offset = times * M
                burst = times * M // 16
                self.cont_data_mv_1_bust(dst=tmp_ub1[offset], src=tmp_ub1, burst=burst)
            with self.tik_instance.else_scope():
                # All 16 rows are filled; leave the loop early.
                self.tik_instance.tik_break()
            times.set_as(times * 2)

        # (16, M) -> (M, 16)
        tmp_ub2 = self.tik_instance.Tensor(FP16, (M, 16), name="tmp_ub2", scope=UB)
        tmp_ub2_transposed = self.transpose_matrix(tmp_ub1, tmp_ub2, M)

        # (M, 16) -> (M, 32) -> (M, 64) -> ... -> (M, N)
        # Place each 16-element row at the head of the matching N-wide row
        # of dst_ub (M bursts of one block, destination gap N // 16 - 1).
        self.tik_instance.data_move(dst_ub, tmp_ub2_transposed, 0, M, 1, 0, N // 16 - 1)
        times.set_as(1)
        with self.tik_instance.for_range(begint=0, endt=N):
            # Number of columns already filled in every row.
            offset = times * 16
            with self.tik_instance.if_scope(offset * 2 <= N):
                # Enough room to double: copy the filled prefix of each row
                # right behind itself.
                burst = offset // 16
                src_stride = N // 16 - burst
                dst_stride = N // 16 - burst
                self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, M, burst, src_stride,
                                            dst_stride)
            with self.tik_instance.else_scope():
                # Less than a full doubling remains: fill the leftover
                # columns (if any) and stop.
                burst = (N - offset) // 16
                src_stride = N // 16 - burst
                dst_stride = N // 16 - burst
                with self.tik_instance.if_scope(burst > 0):
                    self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, M, burst, src_stride,
                                                dst_stride)
                self.tik_instance.tik_break()
            times.set_as(times * 2)
    return dst_ub
226
+
227
def broadcast_row(self, vec_ub, shape):
    """Broadcast a row vector to a matrix so that every row equals vec_ub.

    :param vec_ub: source tensor; its first N elements are copied
        (presumably float16 with N a multiple of 16, since the copy length
        is N // 16 blocks -- confirm against callers)
    :param shape: the target shape, a tuple (M, N)
    :return: a newly allocated FP16 tensor in UB with shape (M, N)
    """
    M, N = shape
    dst_ub = self.tik_instance.Tensor(FP16, shape, name="dst_ub", scope=UB)
    # Seed row 0 with the vector.
    self.tik_instance.data_move(dst_ub, vec_ub, 0, 1, N // 16, 0, 0)
    times = self.tik_instance.Scalar("int32", name="times", init_value=1)
    # (1, N) -> (2, N) -> (4, N) -> ... -> (M, N)
    with self.tik_instance.for_range(begint=0, endt=M):
        with self.tik_instance.if_scope(times * 2 <= M):
            # Enough room to double: copy the `times` filled rows right
            # behind themselves.
            burst = times * N // 16
            offset = times * N
            self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, 1, burst, 0, 0)
        with self.tik_instance.else_scope():
            # Fewer than `times` rows remain: fill the leftover rows (if
            # any) and stop.
            burst = (M - times) * N // 16
            offset = times * N
            with self.tik_instance.if_scope(burst > 0):
                self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, 1, burst, 0, 0)
            self.tik_instance.tik_break()
        times.set_as(times * 2)
    return dst_ub
247
+
248
def get_K0(self, dtype=None):
    """Return the number of elements of ``dtype`` that fit in 32 bytes.

    :param dtype: key of DTYPE_SIZE; falls back to ``self.dtype`` when None
    :return: ``32 // DTYPE_SIZE[dtype]``
    :raises ValueError: if the dtype is not a key of DTYPE_SIZE
    """
    effective_dtype = self.dtype if dtype is None else dtype
    try:
        bytes_per_element = DTYPE_SIZE[effective_dtype]
    except KeyError:
        raise ValueError("The argument 'dtype' is not valid.")
    return 32 // bytes_per_element
257
+
258
def up_align_to_K0(self, n, dtype=None):
    """Round ``n`` up to the next multiple of K0 (elements per 32 bytes).

    :param n: the length to align
    :param dtype: key of DTYPE_SIZE; falls back to ``self.dtype`` when None
    :return: the smallest multiple of K0 that is not less than ``n``
    :raises ValueError: if the dtype is not a key of DTYPE_SIZE
    """
    # get_K0 performs the same dtype fallback and DTYPE_SIZE validation.
    K0 = self.get_K0(dtype)
    return (n + K0 - 1) // K0 * K0
268
+
269
def calc_vec_rec(self, vec_ub, vec_len):
    """Compute the element-wise reciprocal of a vector with ``vec_rec_high_preci``.

    The vector is processed in two steps: first all full repeats of
    ``mask_len`` elements, then the remaining tail (if any).

    :param vec_ub: input tensor; its dtype must be a key of DTYPE_SIZE
    :param vec_len: number of valid elements in vec_ub (used in host-side
        ``if`` statements, so it is expected to be a plain Python int)
    :return: a newly allocated UB tensor of length ``up_align_to_K0(vec_len)``
        holding the reciprocals
    :raises ValueError: if the dtype of vec_ub is not a key of DTYPE_SIZE
    """
    dtype = vec_ub.dtype
    vec_len_aligned = self.up_align_to_K0(vec_len)
    vec_rec_ub = self.tik_instance.Tensor(dtype, (vec_len_aligned,), scope=UB, name="li_new_rec_ub")
    try:
        dtype_size = DTYPE_SIZE[dtype]
    except KeyError:
        raise ValueError("The argument 'dtype' is not valid.")
    mask_len = 256 // dtype_size   # elements handled by one repeat (256 bytes)
    block_len = 32 // dtype_size   # elements in one 32-byte block
    work_size = 8 // dtype_size    # work-tensor size multiplier

    with self.tik_instance.new_stmt_scope(disable_sync=False):
        repeat_times = vec_len // mask_len
        if repeat_times > 0:
            dst_rep_stride = 8
            src_rep_stride = 8

            src_extent_size = (repeat_times - 1) * src_rep_stride * block_len + mask_len
            wk_size_unit = ((src_extent_size + block_len - 1) // block_len) * block_len
            wk_size = work_size * wk_size_unit
            # Define the work tensor.
            work_tensor_ub = self.tik_instance.Tensor(
                "float32", (wk_size,), name="work_tensor_ub", scope=UB
            )
            # If the work tensor is indexed, it must be written as work_tensor[index:].
            self.tik_instance.vec_rec_high_preci(
                mask_len,
                vec_rec_ub[0:],
                vec_ub[0:],
                work_tensor_ub[0:],
                repeat_times,
                dst_rep_stride,
                src_rep_stride,
            )

        # The tail starts right after the elements consumed by the full
        # repeats. This offset must follow mask_len: it is 128 only for
        # 2-byte dtypes, so a hard-coded 128 is wrong for e.g. float32.
        tail_offset = repeat_times * mask_len
        tail_len = vec_len - tail_offset
        if tail_len > 0:
            wk_size = work_size * ((tail_len + block_len - 1) // block_len) * block_len
            work_tensor_ub2 = self.tik_instance.Tensor(
                "float32", (wk_size,), name="work_tensor_ub2", scope=UB
            )
            self.tik_instance.vec_rec_high_preci(
                tail_len,
                vec_rec_ub[tail_offset:],
                vec_ub[tail_offset:],
                work_tensor_ub2[0:],
                1,
                0,
                0,
            )
    return vec_rec_ub
322
+
323
def row_sum_cube_impl(self, matrix_l1_K1MK0_ed, right_all_one_matrix_l1, rowsum_ub, m, k, precision_type):
    """Compute the row sums of a matrix with the cube unit.

    The matrix is right-multiplied by an all-ones matrix, and the first
    column of the product is taken as the row sum.

    :param matrix_l1_K1MK0_ed: input tensor with shape (K1, M, K0)
    :param right_all_one_matrix_l1: input tensor with shape (K, 16)
    :param rowsum_ub: output tensor stores the row sum of input tensor
    :param m: actual tensor height
    :param k: actual tensor width
    :param precision_type: dtype of the matmul result; FP32 selects the
        element-by-element copy, anything else the transpose-based copy
    :return: rowsum_ub
    """
    _, padded_rows, _ = matrix_l1_K1MK0_ed.shape
    tik = self.tik_instance
    # Use matmul to get the row sums: the result has shape (m, 16) and the
    # first number of every row is the one we want.
    with tik.new_stmt_scope(disable_sync=False):
        product_n1mn0 = self.matmul_compute(matrix_l1_K1MK0_ed, right_all_one_matrix_l1, m, k, 16,
                                            N1MN0_to_MN=False, precision_type=precision_type)
        product_mn = product_n1mn0.reshape((padded_rows, 16))
        if precision_type == FP32:
            for row in range(0, m):
                row_total = tik.Scalar(FP32, init_value=product_mn[row, 0])
                rowsum_ub[row].set_as(row_total)
        else:
            # Transpose the product first and then take one row; this
            # replaces the row-by-row form: lij_ub[i].set_as(row_sum_ub_MN_ed[i, 0])
            product_t = tik.Tensor(FP16, (16, padded_rows), name="row_sum_ub_trans", scope=UB)
            product_t = self.transpose_matrix(product_mn, product_t, padded_rows, True)
            self.cont_data_mv_1_bust(dst=rowsum_ub, src=product_t, burst=padded_rows // 16)

    return rowsum_ub
349
+
350
def matmul_compute(self, A_l1, B_l1, m, k, n, N1MN0_to_MN=True, precision_type=FP16):
    """Calculate the matrix multiplication A_l1 * B_l1, move the result to
    C_ub, and optionally rearrange C_ub.

    :param A_l1: input tensor in L1 with shape of (K1, M, K0)
    :param B_l1: input tensor in L1 with shape of (K1, N, K0)
    :param m: the actual number of rows of A_l1
    :param k: the actual number of cols of A_l1
    :param n: the actual number of cols of B_l1
    :param N1MN0_to_MN: Whether reorder the result tensor.
    :param precision_type: dtype of the result tensor C_ub
    :return: C_ub with tensor with shape of (M, N) if N1MN0_to_MN else (N1, M, N0)
    :raises ValueError: if FP32 is not a key of DTYPE_SIZE
    """
    M = self.up_align_to_K0(m)
    N = self.up_align_to_K0(n)
    n1mn0_shape = (N // 16, M, 16)  # n1mn0 (n0=16)
    tik = self.tik_instance
    C_ub = tik.Tensor(precision_type, n1mn0_shape, name="C_ub", scope=UB)
    try:
        acc_elem_bytes = DTYPE_SIZE[FP32]
    except KeyError:
        raise ValueError("The argument 'dtype' is not valid.")

    with tik.new_stmt_scope(disable_sync=False):
        # matmul, accumulated in an FP32 tensor in L0C
        C_l0c = tik.Tensor(FP32, n1mn0_shape, scope=L0C, name="C_l0c")
        tik.matmul(C_l0c, A_l1, B_l1, m, k, n)
        # L0C -> UB, fp32 -> fp16 (tensor_mov can convert on the fly)
        tik.tensor_mov(C_ub, C_l0c, "m", 1, M * N * acc_elem_bytes // 1024, 0, 0)

    return self.N1MN0_TO_MN(C_ub) if N1MN0_to_MN else C_ub
379
+
380
def move_vector_from_gm_to_ub(self, dst_tensor, src_tensor, gm_offset, vec_len):
    """Load a vector of ``vec_len`` elements from GM to UB.

    Whole 32-byte blocks are copied with one contiguous move. A tail that
    does not fill a block is handled by backing the address up: the last
    full block ending at ``vec_len`` is loaded into a staging tensor and
    copied element by element.

    :param dst_tensor: destination tensor, written from element 0
    :param src_tensor: source tensor; its dtype must be a key of DTYPE_SIZE
    :param gm_offset: element offset of the vector inside src_tensor
    :param vec_len: number of elements to load
    :return: None
    :raises ValueError: if the dtype of src_tensor is not a key of DTYPE_SIZE
    """
    dtype = src_tensor.dtype
    try:
        dtype_size = DTYPE_SIZE[dtype]
    except KeyError:
        raise ValueError("The argument 'src_tensor dtype' is not valid.")
    a_burst_num = 32 // dtype_size  # elements in one 32-byte block
    full_tik_blk_num, tail_num = divmod(vec_len, a_burst_num)
    with self.tik_instance.if_scope(full_tik_blk_num > 0):
        self.cont_data_mv_1_bust(dst=dst_tensor, src=src_tensor[gm_offset],
                                 burst=full_tik_blk_num)
    # Back the address up to handle the tail data.
    with self.tik_instance.if_scope(tail_num > 0):
        offset = vec_len - a_burst_num
        # The staging block must use the source dtype: its length
        # (a_burst_num) is derived from that dtype, so a fixed FP16 tensor
        # would not hold one 32-byte block for other element sizes.
        last_blk_ub = self.tik_instance.Tensor(dtype, (a_burst_num,), name="last_blk_ub", scope=UB)
        self.cont_data_mv_1_bust(dst=last_blk_ub, src=src_tensor[gm_offset + offset], burst=1)
        # offset is not 32-byte aligned, so data_move cannot be used here.
        with self.tik_instance.for_range(0, a_burst_num) as idx:
            dst_tensor[offset + idx].set_as(last_blk_ub[idx])
403
+
404
def move_vector_from_ub_to_gm(self, dst_tensor, src_tensor, gm_offset, block_h):
    """Write a vector of ``block_h`` elements from UB back to GM.

    Whole 32-byte blocks are copied with one contiguous move. A tail that
    does not fill a block is handled by backing the address up: the last
    ``a_burst_num`` elements are gathered into a staging tensor and written
    as one block ending at ``block_h``.

    :param dst_tensor: destination tensor, written from ``gm_offset``
    :param src_tensor: source tensor; its dtype must be a key of DTYPE_SIZE
    :param gm_offset: element offset of the vector inside dst_tensor
    :param block_h: number of elements to write
    :return: None
    :raises ValueError: if the dtype of src_tensor is not a key of DTYPE_SIZE
    """
    dtype = src_tensor.dtype
    try:
        dtype_size = DTYPE_SIZE[dtype]
    except KeyError:
        raise ValueError("The argument 'src_tensor dtype' is not valid.")
    a_burst_num = 32 // dtype_size  # elements in one 32-byte block
    full_tik_blk_num = block_h // a_burst_num
    with self.tik_instance.if_scope(full_tik_blk_num > 0):
        self.cont_data_mv_1_bust(dst=dst_tensor[gm_offset], src=src_tensor,
                                 burst=full_tik_blk_num)
    tail_num = block_h % a_burst_num
    with self.tik_instance.if_scope(tail_num > 0):
        offset = block_h - a_burst_num
        # The staging block must use the source dtype: its length
        # (a_burst_num) is derived from that dtype, so a fixed FP16 tensor
        # would not hold one 32-byte block for other element sizes.
        tmp_ub = self.tik_instance.Tensor(dtype, (a_burst_num,), name="tmp_ub", scope=UB)
        with self.tik_instance.for_range(0, a_burst_num) as idx:
            tmp_ub[idx].set_as(src_tensor[offset + idx])
        self.cont_data_mv_1_bust(dst=dst_tensor[gm_offset + offset], src=tmp_ub, burst=1)
428
+
429
def scale_compute_vector(self, Sij_ub, dim):
    """Multiply Sij_ub in place by ``dim ** -0.5``.

    :param Sij_ub: tensor to scale; it is both source and destination of h_mul
    :param dim: the value whose inverse square root is the scale factor
    :return: Sij_ub
    """
    inv_sqrt_dim = dim ** -0.5
    tik = self.tik_instance
    factor = tik.Scalar(dtype=FP16)
    factor.set_as(inv_sqrt_dim)
    tik.h_mul(Sij_ub, Sij_ub, factor)
    return Sij_ub
@@ -0,0 +1,45 @@
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """tiling for sparse """
16
+ from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingPara
17
+ from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingStrategy
18
+
19
+
20
class SparseTiling(TilingStrategy):
    """A tiling strategy implementation for sparse shape"""

    @classmethod
    def strategy_name(cls):
        """Return the name under which this strategy is identified."""
        return "sparse"

    def tiling(self) -> TilingPara:
        """Split Nq and N into blocks of at most 128 and build the tiling parameters.

        Sets Br/Bc (block sizes), Tr/Tc (block counts, including a partial
        last block) and last_Br/last_Bc (size of the last block) on self,
        then returns ``self.gen_tiling_para()``.
        """
        self.Br = min(128, self.Nq)
        self.Bc = min(128, self.N)

        # Full-block counts and the leftover that does not fill a block.
        self.Tr, row_leftover = divmod(self.Nq, self.Br)
        self.Tc, col_leftover = divmod(self.N, self.Bc)

        if row_leftover:
            # A partial trailing block exists: count it and record its size.
            self.last_Br = row_leftover
            self.Tr += 1
        else:
            self.last_Br = self.Br

        if col_leftover:
            self.last_Bc = col_leftover
            self.Tc += 1
        else:
            self.last_Bc = self.Bc

        return self.gen_tiling_para()