mindspore 2.1.0__cp39-cp39-win_amd64.whl → 2.2.10__cp39-cp39-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (505):
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +4 -1
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -1
  9. mindspore/_checkparam.py +23 -29
  10. mindspore/_extends/graph_kernel/__init__.py +0 -1
  11. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  12. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  13. mindspore/_extends/graph_kernel/splitter.py +4 -11
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  21. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  22. mindspore/_extends/parse/__init__.py +12 -15
  23. mindspore/_extends/parse/namespace.py +7 -33
  24. mindspore/_extends/parse/parser.py +61 -71
  25. mindspore/_extends/parse/resources.py +1 -1
  26. mindspore/_extends/parse/standard_method.py +74 -104
  27. mindspore/_extends/parse/trope.py +1 -1
  28. mindspore/_extends/remote/kernel_build_server.py +25 -7
  29. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  30. mindspore/_install_custom.py +43 -0
  31. mindspore/amp.py +47 -11
  32. mindspore/atlprov.dll +0 -0
  33. mindspore/boost/boost.py +1 -8
  34. mindspore/boost/boost_cell_wrapper.py +3 -2
  35. mindspore/boost/grad_accumulation.py +1 -1
  36. mindspore/boost/group_loss_scale_manager.py +8 -7
  37. mindspore/c1.dll +0 -0
  38. mindspore/c1xx.dll +0 -0
  39. mindspore/c2.dll +0 -0
  40. mindspore/common/__init__.py +5 -3
  41. mindspore/common/_jit_fallback_utils.py +6 -0
  42. mindspore/common/_register_for_adapter.py +2 -0
  43. mindspore/common/_register_for_tensor.py +2 -2
  44. mindspore/common/_stub_tensor.py +13 -0
  45. mindspore/common/_utils.py +13 -0
  46. mindspore/common/api.py +174 -259
  47. mindspore/common/auto_dynamic_shape.py +494 -0
  48. mindspore/common/dtype.py +18 -11
  49. mindspore/common/dump.py +6 -4
  50. mindspore/common/initializer.py +14 -14
  51. mindspore/common/jit_config.py +33 -15
  52. mindspore/common/lazy_inline.py +126 -7
  53. mindspore/common/mindir_util.py +101 -0
  54. mindspore/common/parameter.py +51 -41
  55. mindspore/common/seed.py +4 -4
  56. mindspore/common/sparse_tensor.py +13 -14
  57. mindspore/common/tensor.py +243 -165
  58. mindspore/communication/__init__.py +7 -4
  59. mindspore/communication/_comm_helper.py +83 -4
  60. mindspore/communication/management.py +152 -84
  61. mindspore/config/op_info.config +14 -3
  62. mindspore/context.py +152 -61
  63. mindspore/dataset/__init__.py +5 -5
  64. mindspore/dataset/audio/__init__.py +2 -2
  65. mindspore/dataset/audio/transforms.py +52 -52
  66. mindspore/dataset/callback/ds_callback.py +16 -2
  67. mindspore/dataset/core/config.py +68 -51
  68. mindspore/dataset/engine/cache_client.py +28 -5
  69. mindspore/dataset/engine/datasets.py +250 -112
  70. mindspore/dataset/engine/datasets_audio.py +43 -211
  71. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  72. mindspore/dataset/engine/datasets_text.py +43 -67
  73. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  74. mindspore/dataset/engine/datasets_vision.py +219 -1029
  75. mindspore/dataset/engine/iterators.py +11 -4
  76. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  77. mindspore/dataset/engine/obs/util.py +3 -0
  78. mindspore/dataset/engine/samplers.py +1 -1
  79. mindspore/dataset/engine/validators.py +19 -5
  80. mindspore/dataset/text/__init__.py +3 -3
  81. mindspore/dataset/text/transforms.py +101 -127
  82. mindspore/dataset/text/utils.py +205 -138
  83. mindspore/dataset/transforms/__init__.py +1 -1
  84. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  85. mindspore/dataset/transforms/transforms.py +95 -40
  86. mindspore/dataset/utils/browse_dataset.py +8 -2
  87. mindspore/dataset/utils/line_reader.py +17 -19
  88. mindspore/dataset/vision/__init__.py +3 -3
  89. mindspore/dataset/vision/c_transforms.py +6 -3
  90. mindspore/dataset/vision/transforms.py +409 -287
  91. mindspore/dataset/vision/utils.py +13 -14
  92. mindspore/dataset/vision/validators.py +11 -1
  93. mindspore/dnnl.dll +0 -0
  94. mindspore/dpcmi.dll +0 -0
  95. mindspore/experimental/map_parameter.py +14 -0
  96. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  97. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  98. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  99. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  100. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  101. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  102. mindspore/gen_ops.py +273 -0
  103. mindspore/include/OWNERS +0 -1
  104. mindspore/include/api/data_type.h +2 -1
  105. mindspore/include/api/graph.h +0 -15
  106. mindspore/include/api/kernel.h +2 -0
  107. mindspore/include/api/kernel_api.h +37 -12
  108. mindspore/include/api/model.h +17 -14
  109. mindspore/include/api/status.h +8 -3
  110. mindspore/include/api/types.h +37 -4
  111. mindspore/include/c_api/ms/abstract.h +67 -0
  112. mindspore/include/c_api/ms/attribute.h +197 -0
  113. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  114. mindspore/include/c_api/ms/base/macros.h +32 -0
  115. mindspore/include/c_api/ms/base/status.h +33 -0
  116. mindspore/include/c_api/ms/base/types.h +282 -0
  117. mindspore/include/c_api/ms/context.h +102 -0
  118. mindspore/include/c_api/ms/graph.h +160 -0
  119. mindspore/include/c_api/ms/node.h +606 -0
  120. mindspore/include/c_api/ms/tensor.h +161 -0
  121. mindspore/include/c_api/ms/value.h +84 -0
  122. mindspore/include/dataset/constants.h +6 -5
  123. mindspore/include/dataset/execute.h +23 -13
  124. mindspore/include/dataset/text.h +26 -26
  125. mindspore/include/dataset/transforms.h +13 -13
  126. mindspore/include/dataset/vision.h +60 -60
  127. mindspore/include/dataset/vision_ascend.h +5 -6
  128. mindspore/include/dataset/vision_lite.h +17 -17
  129. mindspore/jpeg62.dll +0 -0
  130. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  131. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  132. mindspore/mindspore_backend.dll +0 -0
  133. mindspore/mindspore_common.dll +0 -0
  134. mindspore/mindspore_core.dll +0 -0
  135. mindspore/mindspore_glog.dll +0 -0
  136. mindspore/mindspore_shared_lib.dll +0 -0
  137. mindspore/msobj140.dll +0 -0
  138. mindspore/mspdb140.dll +0 -0
  139. mindspore/mspdbcore.dll +0 -0
  140. mindspore/mspdbst.dll +0 -0
  141. mindspore/mspft140.dll +0 -0
  142. mindspore/msvcdis140.dll +0 -0
  143. mindspore/msvcp140_1.dll +0 -0
  144. mindspore/msvcp140_2.dll +0 -0
  145. mindspore/msvcp140_atomic_wait.dll +0 -0
  146. mindspore/msvcp140_codecvt_ids.dll +0 -0
  147. mindspore/nn/__init__.py +0 -2
  148. mindspore/nn/cell.py +313 -74
  149. mindspore/nn/dynamic_lr.py +21 -21
  150. mindspore/nn/layer/activation.py +22 -30
  151. mindspore/nn/layer/basic.py +15 -13
  152. mindspore/nn/layer/channel_shuffle.py +1 -1
  153. mindspore/nn/layer/container.py +271 -9
  154. mindspore/nn/layer/conv.py +323 -204
  155. mindspore/nn/layer/dense.py +8 -5
  156. mindspore/nn/layer/embedding.py +33 -27
  157. mindspore/nn/layer/flash_attention.py +141 -88
  158. mindspore/nn/layer/image.py +8 -6
  159. mindspore/nn/layer/math.py +16 -25
  160. mindspore/nn/layer/normalization.py +107 -66
  161. mindspore/nn/layer/padding.py +1 -1
  162. mindspore/nn/layer/pooling.py +131 -109
  163. mindspore/nn/layer/rnn_cells.py +27 -22
  164. mindspore/nn/layer/rnns.py +13 -16
  165. mindspore/nn/layer/thor_layer.py +1 -1
  166. mindspore/nn/layer/transformer.py +221 -154
  167. mindspore/nn/learning_rate_schedule.py +9 -1
  168. mindspore/nn/loss/loss.py +235 -174
  169. mindspore/nn/optim/ada_grad.py +2 -1
  170. mindspore/nn/optim/adadelta.py +1 -0
  171. mindspore/nn/optim/adafactor.py +2 -1
  172. mindspore/nn/optim/adam.py +7 -4
  173. mindspore/nn/optim/adamax.py +3 -2
  174. mindspore/nn/optim/adasum.py +2 -2
  175. mindspore/nn/optim/asgd.py +2 -3
  176. mindspore/nn/optim/ftrl.py +6 -5
  177. mindspore/nn/optim/lamb.py +7 -4
  178. mindspore/nn/optim/lars.py +1 -1
  179. mindspore/nn/optim/lazyadam.py +5 -3
  180. mindspore/nn/optim/momentum.py +2 -1
  181. mindspore/nn/optim/optimizer.py +53 -4
  182. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  183. mindspore/nn/optim/rmsprop.py +4 -3
  184. mindspore/nn/optim/rprop.py +23 -12
  185. mindspore/nn/optim/sgd.py +26 -11
  186. mindspore/nn/optim/thor.py +9 -7
  187. mindspore/nn/probability/bijector/bijector.py +5 -5
  188. mindspore/nn/probability/bijector/power_transform.py +27 -27
  189. mindspore/nn/probability/bijector/softplus.py +3 -3
  190. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  191. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  192. mindspore/nn/probability/distribution/beta.py +3 -3
  193. mindspore/nn/probability/distribution/categorical.py +7 -7
  194. mindspore/nn/probability/distribution/cauchy.py +0 -1
  195. mindspore/nn/probability/distribution/distribution.py +3 -3
  196. mindspore/nn/probability/distribution/gamma.py +3 -3
  197. mindspore/nn/probability/distribution/geometric.py +4 -4
  198. mindspore/nn/probability/distribution/gumbel.py +4 -4
  199. mindspore/nn/probability/distribution/log_normal.py +2 -2
  200. mindspore/nn/probability/distribution/logistic.py +2 -2
  201. mindspore/nn/probability/distribution/poisson.py +4 -4
  202. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  203. mindspore/nn/probability/distribution/uniform.py +6 -6
  204. mindspore/nn/wrap/cell_wrapper.py +84 -34
  205. mindspore/nn/wrap/grad_reducer.py +8 -5
  206. mindspore/nn/wrap/loss_scale.py +105 -42
  207. mindspore/numpy/array_creations.py +1 -2
  208. mindspore/numpy/array_ops.py +3 -2
  209. mindspore/numpy/utils_const.py +5 -5
  210. mindspore/opencv_core452.dll +0 -0
  211. mindspore/opencv_imgcodecs452.dll +0 -0
  212. mindspore/opencv_imgproc452.dll +0 -0
  213. mindspore/ops/_grad_experimental/__init__.py +0 -5
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  215. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  216. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  217. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  218. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  219. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  220. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  222. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  224. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  225. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  226. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  227. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  228. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  229. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  230. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  231. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  232. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  233. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  234. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  235. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  236. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  237. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  238. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  239. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  240. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  241. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  242. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  243. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  244. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  245. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  246. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  247. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  248. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  249. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  250. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  251. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  252. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  253. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  254. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  255. mindspore/ops/_primitive_cache.py +1 -1
  256. mindspore/ops/_tracefunc.py +45 -13
  257. mindspore/ops/_utils/utils.py +6 -1
  258. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  259. mindspore/ops/_vmap/vmap_base.py +3 -3
  260. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  261. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  262. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  263. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  264. mindspore/ops/arg_dtype_cast.py +54 -0
  265. mindspore/ops/composite/base.py +37 -10
  266. mindspore/ops/composite/math_ops.py +5 -4
  267. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  268. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  269. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  270. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  271. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  272. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  273. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  274. mindspore/ops/deprecated.py +304 -0
  275. mindspore/ops/function/__init__.py +4 -1
  276. mindspore/ops/function/array_func.py +174 -193
  277. mindspore/ops/function/clip_func.py +81 -13
  278. mindspore/ops/function/debug_func.py +1 -1
  279. mindspore/ops/function/grad/grad_func.py +18 -9
  280. mindspore/ops/function/image_func.py +10 -4
  281. mindspore/ops/function/linalg_func.py +5 -5
  282. mindspore/ops/function/math_func.py +575 -386
  283. mindspore/ops/function/nn_func.py +568 -260
  284. mindspore/ops/function/random_func.py +88 -57
  285. mindspore/ops/function/sparse_func.py +1 -1
  286. mindspore/ops/function/sparse_unary_func.py +14 -12
  287. mindspore/ops/function/vmap_func.py +6 -5
  288. mindspore/ops/functional.py +15 -10
  289. mindspore/ops/op_info_register.py +244 -25
  290. mindspore/ops/operations/__init__.py +28 -19
  291. mindspore/ops/operations/_grad_ops.py +72 -7
  292. mindspore/ops/operations/_inner_ops.py +350 -17
  293. mindspore/ops/operations/_quant_ops.py +4 -8
  294. mindspore/ops/operations/_sequence_ops.py +42 -0
  295. mindspore/ops/operations/array_ops.py +68 -282
  296. mindspore/ops/operations/comm_ops.py +107 -59
  297. mindspore/ops/operations/custom_ops.py +94 -70
  298. mindspore/ops/operations/debug_ops.py +8 -4
  299. mindspore/ops/operations/image_ops.py +18 -12
  300. mindspore/ops/operations/inner_ops.py +26 -3
  301. mindspore/ops/operations/math_ops.py +189 -141
  302. mindspore/ops/operations/nn_ops.py +794 -489
  303. mindspore/ops/operations/other_ops.py +0 -22
  304. mindspore/ops/operations/random_ops.py +53 -111
  305. mindspore/ops/operations/sparse_ops.py +3 -1
  306. mindspore/ops/primitive.py +24 -18
  307. mindspore/parallel/_auto_parallel_context.py +68 -8
  308. mindspore/parallel/_cost_model_context.py +2 -2
  309. mindspore/parallel/_offload_context.py +17 -3
  310. mindspore/parallel/_parallel_serialization.py +12 -5
  311. mindspore/parallel/_ps_context.py +12 -0
  312. mindspore/parallel/_tensor.py +18 -13
  313. mindspore/parallel/_transformer/layers.py +5 -3
  314. mindspore/parallel/_transformer/loss.py +1 -0
  315. mindspore/parallel/_transformer/moe.py +2 -2
  316. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  317. mindspore/parallel/_transformer/transformer.py +23 -3
  318. mindspore/parallel/_utils.py +11 -7
  319. mindspore/parallel/algo_parameter_config.py +85 -5
  320. mindspore/parallel/checkpoint_transform.py +19 -12
  321. mindspore/parallel/shard.py +21 -14
  322. mindspore/pgodb140.dll +0 -0
  323. mindspore/pgort140.dll +0 -0
  324. mindspore/profiler/common/struct_type.py +3 -3
  325. mindspore/profiler/common/util.py +4 -2
  326. mindspore/profiler/envprofiling.py +1 -1
  327. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  328. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  329. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  330. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  331. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  332. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  333. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  334. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  335. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  336. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  337. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  338. mindspore/profiler/parser/flops_parser.py +15 -11
  339. mindspore/profiler/parser/framework_parser.py +38 -22
  340. mindspore/profiler/parser/hccl_parser.py +16 -12
  341. mindspore/profiler/parser/integrator.py +22 -11
  342. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  343. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  344. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  345. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  346. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  347. mindspore/profiler/parser/optime_parser.py +1 -1
  348. mindspore/profiler/parser/profiler_info.py +21 -2
  349. mindspore/profiler/parser/step_trace_parser.py +11 -14
  350. mindspore/profiler/profiling.py +179 -89
  351. mindspore/rewrite/api/node.py +102 -19
  352. mindspore/rewrite/api/node_type.py +5 -1
  353. mindspore/rewrite/api/pattern_engine.py +1 -1
  354. mindspore/rewrite/api/scoped_value.py +9 -17
  355. mindspore/rewrite/api/symbol_tree.py +131 -47
  356. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  357. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  358. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  359. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  360. mindspore/rewrite/common/rewrite_elog.py +5 -1
  361. mindspore/rewrite/namer.py +33 -24
  362. mindspore/rewrite/namespace.py +14 -5
  363. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  364. mindspore/rewrite/node/call_function.py +79 -0
  365. mindspore/rewrite/node/cell_container.py +135 -0
  366. mindspore/rewrite/node/control_flow.py +88 -0
  367. mindspore/rewrite/{node.py → node/node.py} +273 -234
  368. mindspore/rewrite/node/node_manager.py +254 -0
  369. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  370. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  371. mindspore/rewrite/parsers/assign_parser.py +216 -221
  372. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  373. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  374. mindspore/rewrite/parsers/constant_parser.py +9 -6
  375. mindspore/rewrite/parsers/container_parser.py +9 -7
  376. mindspore/rewrite/parsers/for_parser.py +36 -15
  377. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  378. mindspore/rewrite/parsers/if_parser.py +28 -24
  379. mindspore/rewrite/parsers/module_parser.py +196 -25
  380. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  381. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  382. mindspore/rewrite/parsers/return_parser.py +6 -6
  383. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  384. mindspore/rewrite/sparsify/utils.py +1 -1
  385. mindspore/rewrite/symbol_tree.py +523 -578
  386. mindspore/rewrite/symbol_tree_builder.py +9 -193
  387. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  388. mindspore/run_check/_check_version.py +6 -4
  389. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  390. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  391. mindspore/tbbmalloc.dll +0 -0
  392. mindspore/tinyxml2.dll +0 -0
  393. mindspore/train/_utils.py +7 -3
  394. mindspore/train/amp.py +323 -123
  395. mindspore/train/anf_ir_pb2.py +14 -2
  396. mindspore/train/callback/_backup_and_restore.py +2 -12
  397. mindspore/train/callback/_callback.py +29 -4
  398. mindspore/train/callback/_checkpoint.py +23 -8
  399. mindspore/train/callback/_early_stop.py +2 -2
  400. mindspore/train/callback/_landscape.py +4 -4
  401. mindspore/train/callback/_loss_monitor.py +2 -2
  402. mindspore/train/callback/_on_request_exit.py +2 -2
  403. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  404. mindspore/train/callback/_summary_collector.py +15 -8
  405. mindspore/train/callback/_time_monitor.py +58 -5
  406. mindspore/train/data_sink.py +5 -11
  407. mindspore/train/dataset_helper.py +84 -57
  408. mindspore/train/loss_scale_manager.py +2 -2
  409. mindspore/train/metrics/__init__.py +3 -3
  410. mindspore/train/metrics/cosine_similarity.py +1 -1
  411. mindspore/train/metrics/hausdorff_distance.py +3 -2
  412. mindspore/train/metrics/mean_surface_distance.py +3 -2
  413. mindspore/train/metrics/metric.py +39 -19
  414. mindspore/train/metrics/roc.py +2 -2
  415. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  416. mindspore/train/mind_ir_pb2.py +85 -36
  417. mindspore/train/model.py +187 -47
  418. mindspore/train/serialization.py +487 -161
  419. mindspore/train/summary/_summary_adapter.py +1 -1
  420. mindspore/train/summary/_writer_pool.py +3 -2
  421. mindspore/train/summary/summary_record.py +37 -17
  422. mindspore/train/train_thor/convert_utils.py +3 -3
  423. mindspore/train/train_thor/dataset_helper.py +1 -1
  424. mindspore/turbojpeg.dll +0 -0
  425. mindspore/vcmeta.dll +0 -0
  426. mindspore/vcruntime140.dll +0 -0
  427. mindspore/vcruntime140_1.dll +0 -0
  428. mindspore/version.py +1 -1
  429. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +5 -3
  430. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +433 -479
  431. mindspore/_extends/graph_kernel/expander.py +0 -80
  432. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  433. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  434. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  435. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  436. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  437. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  438. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  439. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  440. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  441. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  442. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  443. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  444. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  445. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  446. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  447. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  448. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  449. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  450. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  451. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  452. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  453. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  454. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  455. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  456. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  457. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  458. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  459. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  460. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  461. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  462. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  463. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  464. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  465. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  466. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  467. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  468. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  469. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  470. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  471. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  472. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  473. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  474. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  475. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  476. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  477. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  478. mindspore/dataset/datapreprocess/__init__.py +0 -20
  479. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  480. mindspore/include/api/net.h +0 -142
  481. mindspore/nn/lr_scheduler.py +0 -262
  482. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  483. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  484. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  485. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  486. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  487. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  488. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  489. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  490. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  491. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  492. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  493. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  494. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  495. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  496. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  497. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  498. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  499. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  500. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  501. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  502. mindspore/rewrite/node_visitor.py +0 -44
  503. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  504. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
  505. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -31,14 +31,14 @@ class FlashAttentionBwd(FlashAttention):
31
31
  `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/pdf/2205.14135.pdf>`
32
32
  """
33
33
 
34
- def __init__(self, query, key, value, output, dO, rowsum, rowmax, dim_mask, attn_mask, dropout_mask, alibi_mask,
34
+ def __init__(self, query, key, value, output, dO, rowsum, rowmax, attn_mask, dropout_mask, alibi_mask,
35
35
  prev_block_num,
36
36
  next_block_num,
37
37
  high_precision,
38
38
  kernel_name,
39
39
  tiling_stgy: TilingStrategy,
40
40
  disable_debug):
41
- super().__init__(query, key, value, dim_mask, attn_mask, dropout_mask, alibi_mask, kernel_name,
41
+ super().__init__(query, key, value, attn_mask, dropout_mask, alibi_mask, kernel_name,
42
42
  tiling_stgy, prev_block_num, next_block_num, high_precision, disable_debug)
43
43
 
44
44
  if isinstance(query, dict):
@@ -76,7 +76,7 @@ class FlashAttentionBwd(FlashAttention):
76
76
  """
77
77
  input_gm_list = [
78
78
  self.Q_gm, self.K_gm, self.V_gm, self.O_gm, self.dO_gm, self.l_gm,
79
- self.m_gm, self.dim_mask_gm
79
+ self.m_gm
80
80
  ]
81
81
  if self.has_attn_mask:
82
82
  input_gm_list.append(self.att_mask_gm)
@@ -86,15 +86,22 @@ class FlashAttentionBwd(FlashAttention):
86
86
  input_gm_list.append(self.alibi_mask_gm)
87
87
  return input_gm_list
88
88
 
89
+ def prepare_global_ones(self):
90
+ """Prepare global ones tensor in L1 for cube impl row_sum"""
91
+ self.ones_l1 = self.tik_instance.Tensor(FP16, (self.d, 16), name="ones_l1", scope=L1)
92
+ with self.tik_instance.new_stmt_scope(disable_sync=False):
93
+ ones_ub = self.tik_instance.Tensor(FP16, (self.d, 16), name="ones_ub", scope=UB)
94
+ self.tik_instance.h_duplicate(ones_ub, 1.0)
95
+ self.cont_data_mv_1_bust(dst=self.ones_l1, src=ones_ub, burst=self.d)
96
+
89
97
  def compute_Pij(self, Qi_l1_K1MK0_ed, KjT_l1_K1NK0_ed, m, k, n, lm_gm_offset, attn_mask_gm_offset,
90
98
  dropout_mask_gm_offset, alibi_mask_gm_offset):
91
99
  """Refer to Algorithm 4 line11-14 in FlashAttention implement Pij computation"""
92
100
  m_aligned = self.tik_ops_utils.up_align_to_K0(m)
93
101
  n_aligned = self.tik_ops_utils.up_align_to_K0(n)
94
- # Sij <- Qi * KjT
95
- Sij_ub = self.tik_ops_utils.matmul_compute(Qi_l1_K1MK0_ed, KjT_l1_K1NK0_ed, m, k, n)
96
- if self.has_drop_mask:
97
- Pij_drop_ed_ub = self.tik_instance.Tensor(FP16, (m_aligned, n_aligned), name="Pij_drop_ed_ub", scope=UB)
102
+ Sij_ub = self.tik_ops_utils.matmul_compute(Qi_l1_K1MK0_ed, KjT_l1_K1NK0_ed, m, k, n, N1MN0_to_MN=False)
103
+ Pij_drop_ed_ub = self.tik_instance.Tensor(FP16, (n_aligned // self.N0, m_aligned, self.N0),
104
+ name="Pij_drop_ed_ub", scope=UB)
98
105
 
99
106
  with self.tik_instance.new_stmt_scope(disable_sync=False):
100
107
  if self.has_alibi_mask:
@@ -107,35 +114,42 @@ class FlashAttentionBwd(FlashAttention):
107
114
  mi_ub = self.tik_instance.Tensor(FP16, (m_aligned,), name="mi_ub", scope=UB)
108
115
  self.tik_ops_utils.move_vector_from_gm_to_ub(li_ub, self.l_gm, lm_gm_offset, m)
109
116
  self.tik_ops_utils.move_vector_from_gm_to_ub(mi_ub, self.m_gm, lm_gm_offset, m)
117
+ n1 = n_aligned // self.N0
110
118
  with self.tik_instance.new_stmt_scope(disable_sync=False):
111
- broadcast_mi_ub = self.tik_ops_utils.broadcast(mi_ub, (m_aligned, n_aligned))
112
- self.tik_instance.h_sub(Sij_ub, Sij_ub, broadcast_mi_ub)
119
+ broadcast_mi_ub = self.tik_ops_utils.broadcast(mi_ub, (m, self.N0))
120
+ broadcast_mi_ub = broadcast_mi_ub.reshape((1, m, self.N0))
121
+ for idx in range(n1):
122
+ self.tik_instance.h_sub(Sij_ub[idx, :, :], Sij_ub[idx, :, :], broadcast_mi_ub)
113
123
  li_rec_ub = self.tik_ops_utils.calc_vec_rec(li_ub, m)
114
124
  with self.tik_instance.new_stmt_scope(disable_sync=False):
115
125
  if self.high_precision:
116
- Sij_ub_fp32 = self.tik_instance.Tensor(FP32, (m_aligned, n_aligned), name="Sij_ub_fp32", scope=UB)
126
+ # fp16 -> fp32
127
+ Sij_ub_fp32 = self.tik_instance.Tensor(FP32, (n_aligned // self.N0, m_aligned, self.N0),
128
+ name="Sij_ub_fp32", scope=UB)
117
129
  self.tik_instance.h_cast(Sij_ub_fp32, Sij_ub, "none")
118
130
  self.tik_instance.h_exp(Sij_ub_fp32, Sij_ub_fp32)
119
- with self.tik_instance.for_range(0, m) as idx:
120
- cur_row_sum_rec = self.tik_instance.Scalar(FP32, name="cur_row_sum_rec",
121
- init_value=li_rec_ub[idx])
122
- self.tik_instance.h_mul(Sij_ub_fp32[idx, :], Sij_ub_fp32[idx, :], cur_row_sum_rec)
131
+ cur_row_sum_rec = self.tik_instance.Tensor(FP32, (m_aligned, self.N0), name="cur_row_sum_rec",
132
+ scope=UB)
133
+ for i in range(m_aligned):
134
+ src_scalar = self.tik_instance.Scalar(init_value=li_rec_ub[i], dtype=FP32)
135
+ self.tik_instance.h_duplicate(cur_row_sum_rec[i, :], src_scalar)
136
+ cur_row_sum_rec = cur_row_sum_rec.reshape((1, m_aligned, self.N0))
137
+ with self.tik_instance.for_range(0, n_aligned // self.N0) as idx:
138
+ self.tik_instance.h_mul(Sij_ub_fp32[idx, :, :], Sij_ub_fp32[idx, :, :], cur_row_sum_rec)
139
+ # fp32 -> fp16
123
140
  self.tik_instance.h_cast(Sij_ub, Sij_ub_fp32, "none")
124
141
  else:
125
142
  self.tik_instance.h_exp(Sij_ub, Sij_ub)
126
- broadcast_li_rec_ub = self.tik_ops_utils.broadcast(li_rec_ub, (m_aligned, n_aligned))
127
- self.tik_instance.h_mul(Sij_ub, Sij_ub, broadcast_li_rec_ub)
143
+ broadcast_li_rec_ub = self.tik_ops_utils.broadcast(li_rec_ub, (m_aligned, self.N0))
144
+ broadcast_li_rec_ub = broadcast_li_rec_ub.reshape((1, m_aligned, self.N0))
145
+ for idx in range(n1):
146
+ self.tik_instance.h_mul(Sij_ub[idx, :, :], Sij_ub[idx, :, :], broadcast_li_rec_ub)
128
147
 
129
- # dropout_mask
130
148
  if self.has_drop_mask:
131
- with self.tik_instance.new_stmt_scope(disable_sync=False):
132
- dropout_mask_ub = self.tik_instance.Tensor(FP16, (m_aligned, n_aligned),
133
- scope=UB, name="drop_mask_ub")
134
- self.tik_instance.data_move(dropout_mask_ub, self.drop_mask_gm[dropout_mask_gm_offset], 0,
135
- m_aligned, n_aligned // 16, (self.N - n_aligned) // 16, 0)
136
- self.tik_instance.h_mul(Pij_drop_ed_ub, Sij_ub, dropout_mask_ub)
149
+ self.do_dropout_mask(Sij_ub, dropout_mask_gm_offset, n_aligned, n, m_aligned, m,
150
+ workspace=Pij_drop_ed_ub)
137
151
  else:
138
- Pij_drop_ed_ub = Sij_ub
152
+ self.cont_data_mv_1_bust(dst=Pij_drop_ed_ub, src=Sij_ub, burst=m_aligned * n_aligned // 16)
139
153
 
140
154
  return Sij_ub, Pij_drop_ed_ub
141
155
 
@@ -143,13 +157,16 @@ class FlashAttentionBwd(FlashAttention):
143
157
  """Refer to Algorithm 4 line19 in FlashAttention implement Di computation"""
144
158
  q_blk_height_aligned = self.tik_ops_utils.up_align_to_K0(q_blk_height)
145
159
  with self.tik_instance.new_stmt_scope(disable_sync=False):
146
- Oi_ub = self.tik_instance.Tensor(FP16, (q_blk_height_aligned, self.d), name="Oi_ub", scope=UB)
147
- self.cont_data_mv_1_bust(dst=Oi_ub, src=self.O_gm[qo_gm_offset],
148
- burst=q_blk_height * self.d // 16)
160
+ Oi_ub = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_height_aligned, self.N0),
161
+ scope=UB, name="Oi_ub")
162
+ self.tik_instance.data_move(dst=Oi_ub, src=self.O_gm[qo_gm_offset],
163
+ sid=0, nburst=self.N1, burst=q_blk_height * self.N0 // 16,
164
+ src_stride=(self.Nq - q_blk_height) * self.N0 // 16, dst_stride=0)
149
165
  self.tik_instance.h_mul(Oi_ub, dOi_ub, Oi_ub)
150
- dOi_Oi_l1 = self.tik_instance.Tensor(FP16, (q_blk_height_aligned, self.d), name="dOi_Oi_l1", scope=L1)
151
- dOi_Oi_l1_K1MK0_ed = self.tik_ops_utils.MK_TO_K1MK0(Oi_ub, workspace_tensor=dOi_Oi_l1)
152
- self.tik_ops_utils.row_sum_cube_impl(dOi_Oi_l1_K1MK0_ed, Di_ub, q_blk_height,
166
+ dOi_Oi_l1_K1MK0 = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_height_aligned, self.N0),
167
+ name="dOi_Oi_l1_K1MK0", scope=L1)
168
+ self.cont_data_mv_1_bust(dst=dOi_Oi_l1_K1MK0, src=Oi_ub, burst=q_blk_height_aligned * self.d // 16)
169
+ self.tik_ops_utils.row_sum_cube_impl(dOi_Oi_l1_K1MK0, self.ones_l1, Di_ub, q_blk_height,
153
170
  self.actual_d, precision_type=FP16)
154
171
 
155
172
  def compute_dSij(self, Pij_ub, dOi_l1_K1MK0_ed, VjT_K1NK0_ed, Di_ub, kv_blk_height, q_blk_height,
@@ -157,7 +174,7 @@ class FlashAttentionBwd(FlashAttention):
157
174
  """Refer to Algorithm 4 line20 in FlashAttention implement dSij computation"""
158
175
  with self.tik_instance.new_stmt_scope(disable_sync=False):
159
176
  dPij_ub = self.tik_ops_utils.matmul_compute(dOi_l1_K1MK0_ed, VjT_K1NK0_ed,
160
- q_blk_height, self.actual_d, kv_blk_height)
177
+ q_blk_height, self.actual_d, kv_blk_height, N1MN0_to_MN=False)
161
178
  q_blk_height_aligned = self.tik_ops_utils.up_align_to_K0(q_blk_height)
162
179
  kv_blk_height_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
163
180
  # dropout_mask
@@ -166,8 +183,11 @@ class FlashAttentionBwd(FlashAttention):
166
183
  q_blk_height_aligned, q_blk_height)
167
184
  # dPij - Di
168
185
  with self.tik_instance.new_stmt_scope(disable_sync=False):
169
- broadcast_Di_ub = self.tik_ops_utils.broadcast(Di_ub, (q_blk_height_aligned, kv_blk_height_aligned))
170
- self.tik_instance.h_sub(dPij_ub, dPij_ub, broadcast_Di_ub)
186
+ broadcast_Di_ub = self.tik_ops_utils.broadcast(Di_ub, (q_blk_height_aligned, self.N0))
187
+ broadcast_Di_ub = broadcast_Di_ub.reshape((1, q_blk_height_aligned, self.N0))
188
+ n1 = kv_blk_height_aligned // self.N0
189
+ for idx in range(n1):
190
+ self.tik_instance.h_sub(dPij_ub[idx, :, :], dPij_ub[idx, :, :], broadcast_Di_ub)
171
191
  self.tik_instance.h_mul(Pij_ub, Pij_ub, dPij_ub)
172
192
  return Pij_ub
173
193
 
@@ -181,10 +201,12 @@ class FlashAttentionBwd(FlashAttention):
181
201
  with self.tik_instance.new_stmt_scope(disable_sync=False):
182
202
  PijT_Oi_ub = self.tik_ops_utils.matmul_compute(PijT_l1_K1MK0_ed, dOi_l1_K1NK0_ed,
183
203
  kv_blk_height, q_blk_height,
184
- self.actual_d, precision_type=FP32)
204
+ self.actual_d, N1MN0_to_MN=False,
205
+ precision_type=FP32)
185
206
  self.tik_instance.set_atomic_add(1)
186
- self.cont_data_mv_1_bust(dst=self.dV_gm[kv_gm_offset], src=PijT_Oi_ub,
187
- burst=kv_blk_height * self.d // 8)
207
+ self.tik_instance.data_move(dst=self.dV_gm[kv_gm_offset], src=PijT_Oi_ub, sid=0,
208
+ nburst=self.N1, burst=kv_blk_height * self.N0 // 8,
209
+ src_stride=0, dst_stride=(self.Nq - kv_blk_height) * self.N0 // 8)
188
210
  self.tik_instance.set_atomic_add(0)
189
211
 
190
212
  def update_dQi(self,
@@ -197,10 +219,11 @@ class FlashAttentionBwd(FlashAttention):
197
219
  with self.tik_instance.new_stmt_scope(disable_sync=False):
198
220
  dSij_Kj_ub = self.tik_ops_utils.matmul_compute(dSij_l1_K1MK0_ed, Kj_l1_K1NK0_ed,
199
221
  q_blk_height, kv_blk_height,
200
- self.actual_d, precision_type=FP32)
222
+ self.actual_d, N1MN0_to_MN=False, precision_type=FP32)
201
223
  self.tik_instance.set_atomic_add(1)
202
- self.cont_data_mv_1_bust(dst=self.dQ_gm[qo_gm_offset], src=dSij_Kj_ub,
203
- burst=q_blk_height * self.d // 8)
224
+ self.tik_instance.data_move(dst=self.dQ_gm[qo_gm_offset], src=dSij_Kj_ub, sid=0,
225
+ nburst=self.d // self.N0, burst=q_blk_height * self.N0 // 8,
226
+ src_stride=0, dst_stride=(self.Nq - q_blk_height) * self.N0 // 8)
204
227
  self.tik_instance.set_atomic_add(0)
205
228
 
206
229
  def update_dKj(self,
@@ -213,10 +236,11 @@ class FlashAttentionBwd(FlashAttention):
213
236
  with self.tik_instance.new_stmt_scope(disable_sync=False):
214
237
  dSijT_Qi_ub = self.tik_ops_utils.matmul_compute(dSijT_l1_K1MK0_ed, Qi_l1_K1NK0_ed,
215
238
  kv_blk_height, q_blk_height,
216
- self.actual_d, precision_type=FP32)
239
+ self.actual_d, N1MN0_to_MN=False, precision_type=FP32)
217
240
  self.tik_instance.set_atomic_add(1)
218
- self.cont_data_mv_1_bust(dst=self.dK_gm[kv_gm_offset], src=dSijT_Qi_ub,
219
- burst=kv_blk_height * self.d // 8)
241
+ self.tik_instance.data_move(dst=self.dK_gm[kv_gm_offset], src=dSijT_Qi_ub, sid=0,
242
+ nburst=self.d // self.N0, burst=kv_blk_height * self.N0 // 8,
243
+ src_stride=0, dst_stride=(self.Nq - kv_blk_height) * self.N0 // 8)
220
244
  self.tik_instance.set_atomic_add(0)
221
245
 
222
246
  def compute_in_each_kv_block(self, batch_start, batch_idx, kv_blk_idx, kv_blk_height,
@@ -225,23 +249,34 @@ class FlashAttentionBwd(FlashAttention):
225
249
  kv_blk_height_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
226
250
  kv_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.N, self.d,
227
251
  self.Bc, kv_blk_idx)
228
- Kj_l1_1 = self.tik_instance.Tensor(FP16, (kv_blk_height_aligned, self.d), name="Kj_l1_1",
229
- scope=L1)
252
+ # load KjT
253
+ Kj_l1_1_K1MK0 = self.tik_instance.Tensor(FP16, (self.d // self.N0, kv_blk_height_aligned, self.N0),
254
+ name="Kj_l1_1_K1MK0",
255
+ scope=L1)
256
+ self.tik_instance.data_move(dst=Kj_l1_1_K1MK0, src=self.K_gm[kv_gm_offset],
257
+ sid=0, nburst=self.N1, burst=kv_blk_height_aligned * self.N0 // 16,
258
+ src_stride=(self.N - kv_blk_height_aligned) * self.N0 // 16, dst_stride=0)
259
+
260
+ # load Kj
230
261
  Kj_l1_2 = self.tik_instance.Tensor(FP16, (kv_blk_height_aligned, self.d), name="Kj_l1_2",
231
262
  scope=L1)
232
263
  with self.tik_instance.new_stmt_scope(disable_sync=False):
233
- Kj_ub = self.tik_instance.Tensor(FP16, (kv_blk_height_aligned, self.d), name="Kj_ub", scope=UB)
234
- self.cont_data_mv_1_bust(dst=Kj_ub, src=self.K_gm[kv_gm_offset],
235
- burst=kv_blk_height * self.d // 16)
236
- KjT_l1_K1NK0_ed = self.tik_ops_utils.MK_TO_K1MK0(Kj_ub, workspace_tensor=Kj_l1_1)
237
- Kj_l1_K1NK0_ed = self.tik_ops_utils.KN_TO_K1NK0(Kj_ub, workspace_tensor=Kj_l1_2)
238
-
239
- Vj_l1 = self.tik_instance.Tensor(FP16, (kv_blk_height_aligned, self.d), name="Vj_l1", scope=L1)
240
- with self.tik_instance.new_stmt_scope(disable_sync=False):
241
- Vj_ub = self.tik_instance.Tensor(FP16, (kv_blk_height_aligned, self.d), name="Vj_ub", scope=UB)
242
- self.cont_data_mv_1_bust(dst=Vj_ub, src=self.V_gm[kv_gm_offset],
243
- burst=kv_blk_height * self.d // 16)
244
- VjT_l1_K1NK0_ed = self.tik_ops_utils.MK_TO_K1MK0(Vj_ub, workspace_tensor=Vj_l1)
264
+ Kj_ub = self.tik_instance.Tensor(FP16, (self.d // self.N0, kv_blk_height_aligned, self.N0),
265
+ name="Kj_ub", scope=UB)
266
+ self.tik_instance.data_move(dst=Kj_ub, src=self.K_gm[kv_gm_offset],
267
+ sid=0, nburst=self.N1, burst=kv_blk_height_aligned * self.N0 // 16,
268
+ src_stride=(self.N - kv_blk_height_aligned) * self.N0 // 16, dst_stride=0)
269
+ # (N1, K, N0) -> (K, N)
270
+ Kj_ub = self.tik_ops_utils.N1MN0_TO_MN(Kj_ub)
271
+ # (K, N) -> (K1, N, K0)
272
+ Kj_l1_2_K1NK0_ed = self.tik_ops_utils.KN_TO_K1NK0(Kj_ub, workspace_tensor=Kj_l1_2)
273
+
274
+ # load VjT
275
+ Vj_l1 = self.tik_instance.Tensor(FP16, (self.d // self.N0, kv_blk_height_aligned, self.N0), name="Vj_l1",
276
+ scope=L1)
277
+ self.tik_instance.data_move(dst=Vj_l1, src=self.V_gm[kv_gm_offset],
278
+ sid=0, nburst=self.N1, burst=kv_blk_height_aligned * self.N0 // 16,
279
+ src_stride=(self.N - kv_blk_height_aligned) * self.N0 // 16, dst_stride=0)
245
280
 
246
281
  tr_start_s = self.tik_instance.Scalar("int32", name="tr_start_s")
247
282
  tr_end_s = self.tik_instance.Scalar("int32", name="tr_end_s")
@@ -251,9 +286,9 @@ class FlashAttentionBwd(FlashAttention):
251
286
  with self.tik_instance.if_scope(tik.all(kv_blk_idx - self.next_block_num <= q_blk_idx,
252
287
  q_blk_idx <= kv_blk_idx + self.prev_block_num)):
253
288
  with self.tik_instance.if_scope(q_blk_idx != self.Tr - 1):
254
- self.compute_in_each_q_block(KjT_l1_K1NK0_ed,
255
- Kj_l1_K1NK0_ed,
256
- VjT_l1_K1NK0_ed,
289
+ self.compute_in_each_q_block(Kj_l1_1_K1MK0,
290
+ Kj_l1_2_K1NK0_ed,
291
+ Vj_l1,
257
292
  batch_idx,
258
293
  batch_start,
259
294
  kv_gm_offset,
@@ -262,9 +297,9 @@ class FlashAttentionBwd(FlashAttention):
262
297
  kv_blk_idx,
263
298
  q_blk_idx)
264
299
  with self.tik_instance.else_scope():
265
- self.compute_in_each_q_block(KjT_l1_K1NK0_ed,
266
- Kj_l1_K1NK0_ed,
267
- VjT_l1_K1NK0_ed,
300
+ self.compute_in_each_q_block(Kj_l1_1_K1MK0,
301
+ Kj_l1_2_K1NK0_ed,
302
+ Vj_l1,
268
303
  batch_idx,
269
304
  batch_start,
270
305
  kv_gm_offset,
@@ -280,19 +315,28 @@ class FlashAttentionBwd(FlashAttention):
280
315
  kv_blk_height_alig = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
281
316
  q_blk_height_alig = self.tik_ops_utils.up_align_to_K0(q_blk_height)
282
317
 
283
- Qi_l1_left = self.tik_instance.Tensor(FP16, (q_blk_height_alig, self.d), name="Qi_l1_left",
284
- scope=L1) # for Qi*KjT (line 11 of the algorithm)
285
- Qi_l1_right = self.tik_instance.Tensor(FP16, (q_blk_height_alig, self.d), name="Qi_l1_right",
286
- scope=L1) # for dSijT*Qi (line 22 of the algorithm)
287
318
  qo_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.Nq, self.d, self.Br, q_blk_idx)
319
+ Qi_l1_K1MK0 = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_height_alig, self.N0),
320
+ name="Qi_l1_K1MK0",
321
+ scope=L1)
322
+ self.tik_instance.data_move(dst=Qi_l1_K1MK0, src=self.Q_gm[qo_gm_offset],
323
+ sid=0, nburst=self.N1, burst=q_blk_height_alig * self.N0 // 16,
324
+ src_stride=(self.Nq - q_blk_height_alig) * self.N0 // 16, dst_stride=0)
325
+
326
+ Qi_l1_right = self.tik_instance.Tensor(FP16, (q_blk_height_alig, self.d), name="Qi_l1_right",
327
+ scope=L1)
288
328
  with self.tik_instance.new_stmt_scope(disable_sync=False):
289
- Qi_ub = self.tik_instance.Tensor(FP16, (q_blk_height_alig, self.d), name="Qi_ub", scope=UB)
290
- self.cont_data_mv_1_bust(dst=Qi_ub, src=self.Q_gm[qo_gm_offset],
291
- burst=q_blk_height * self.d // 16)
292
- Qi_l1_K1MK0_ed = self.tik_ops_utils.MK_TO_K1MK0(Qi_ub, workspace_tensor=Qi_l1_left)
329
+ Qi_ub = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_height_alig, self.N0),
330
+ name="Qi_ub", scope=UB)
331
+ self.tik_instance.data_move(dst=Qi_ub, src=self.Q_gm[qo_gm_offset],
332
+ sid=0, nburst=self.N1, burst=q_blk_height_alig * self.N0 // 16,
333
+ src_stride=(self.N - q_blk_height_alig) * self.N0 // 16, dst_stride=0)
334
+ # (N1, K, N0) -> (K, N)
335
+ Qi_ub = self.tik_ops_utils.N1MN0_TO_MN(Qi_ub)
336
+ # (K, N) -> (K1, N, K0)
293
337
  Qi_l1_K1NK0_ed = self.tik_ops_utils.KN_TO_K1NK0(Qi_ub, workspace_tensor=Qi_l1_right)
294
338
 
295
- lm_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.Nq, 1, self.Br, q_blk_idx)
339
+ lm_gm_offset = self.get_l_m_gm_offset(batch_start, batch_idx, self.Nq, self.Br, q_blk_idx)
296
340
  attn_mask_gm_offset, dropout_mask_gm_offset, alibi_mask_gm_offset = None, None, None
297
341
  if self.has_attn_mask:
298
342
  attn_mask_gm_offset = self.get_attn_mask_gm_offset(batch_start, batch_idx, self.Nq, self.N,
@@ -302,41 +346,55 @@ class FlashAttentionBwd(FlashAttention):
302
346
  self.Br, q_blk_idx, self.Bc, kv_blk_idx)
303
347
  if self.has_alibi_mask:
304
348
  alibi_mask_gm_offset = self.get_alibi_gm_offset(batch_start, batch_idx, self.N, self.Bc, kv_blk_idx)
305
- Pij_ub, Pij_drop_ed_ub = self.compute_Pij(Qi_l1_K1MK0_ed, KjT_l1_K1NK0_ed,
349
+ Pij_ub, Pij_drop_ed_ub = self.compute_Pij(Qi_l1_K1MK0, KjT_l1_K1NK0_ed,
306
350
  q_blk_height, self.actual_d, kv_blk_height,
307
351
  lm_gm_offset, attn_mask_gm_offset,
308
352
  dropout_mask_gm_offset, alibi_mask_gm_offset)
309
353
 
310
354
  dOi_l1_right = self.tik_instance.Tensor(FP16, (q_blk_height_alig, self.d), name="dOi_l1_right",
311
355
  scope=L1)
312
- dOi_l1_left = self.tik_instance.Tensor(FP16, (q_blk_height_alig, self.d), name="dOi_l1_left",
313
- scope=L1)
314
356
  Di_ub = self.tik_instance.Tensor(FP16, (q_blk_height_alig,), name="Di_ub", scope=UB)
315
357
  with self.tik_instance.new_stmt_scope(disable_sync=False):
316
- dOi_ub = self.tik_instance.Tensor(FP16, (q_blk_height_alig, self.d), name="dOi_ub", scope=UB)
317
- self.cont_data_mv_1_bust(dst=dOi_ub, src=self.dO_gm[qo_gm_offset],
318
- burst=q_blk_height * self.d // 16)
358
+ dOi_ub = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_height_alig, self.N0),
359
+ name="dOi_ub", scope=UB)
360
+ self.tik_instance.data_move(dst=dOi_ub, src=self.dO_gm[qo_gm_offset],
361
+ sid=0, nburst=self.N1, burst=q_blk_height_alig * self.N0 // 16,
362
+ src_stride=(self.Nq - q_blk_height_alig) * self.N0 // 16, dst_stride=0)
363
+
319
364
  self.compute_Di(Di_ub, dOi_ub, qo_gm_offset, q_blk_height)
365
+ # (N1, K, N0) -> (K, N)
366
+ dOi_ub = self.tik_ops_utils.N1MN0_TO_MN(dOi_ub)
367
+ # (K, N) -> (K1, N, K0)
320
368
  dOi_l1_K1NK0_ed = self.tik_ops_utils.KN_TO_K1NK0(dOi_ub, workspace_tensor=dOi_l1_right)
321
- dOi_l1_K1MK0_ed = self.tik_ops_utils.MK_TO_K1MK0(dOi_ub, workspace_tensor=dOi_l1_left)
322
369
 
370
+ dOi_l1_K1MK0 = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_height_alig, self.N0),
371
+ name="dOi_l1_K1MK0",
372
+ scope=L1)
373
+
374
+ self.tik_instance.data_move(dst=dOi_l1_K1MK0, src=self.dO_gm[qo_gm_offset],
375
+ sid=0, nburst=self.N1, burst=q_blk_height_alig * self.N0 // 16,
376
+ src_stride=(self.Nq - q_blk_height_alig) * self.N0 // 16, dst_stride=0)
323
377
  Pij_l1 = self.tik_instance.Tensor(FP16, (q_blk_height_alig, kv_blk_height_alig), name="Pij_l1", scope=L1)
378
+ Pij_drop_ed_ub = self.tik_ops_utils.N1MN0_TO_MN(Pij_drop_ed_ub)
324
379
  PijT_l1_K1MK0_ed = self.tik_ops_utils.KN_TO_K1NK0(Pij_drop_ed_ub, workspace_tensor=Pij_l1)
325
380
  self.update_dVj(PijT_l1_K1MK0_ed, dOi_l1_K1NK0_ed,
326
381
  kv_gm_offset, kv_blk_height, q_blk_height)
327
- dSij_l1_1 = self.tik_instance.Tensor(FP16, (q_blk_height_alig, kv_blk_height_alig),
328
- name="dSij_l1_1", scope=L1)
382
+ # (L1: 512K)
383
+ dSij_l1_K1MK0_ed = self.tik_instance.Tensor(FP16, (kv_blk_height_alig // self.N0, q_blk_height_alig, self.N0),
384
+ name="dSij_l1_1", scope=L1)
329
385
  dSij_l1_2 = self.tik_instance.Tensor(FP16, (q_blk_height_alig, kv_blk_height_alig),
330
386
  name="dSij_l1_2", scope=L1)
331
387
  with self.tik_instance.new_stmt_scope(disable_sync=False):
332
388
  dSij_ub = self.compute_dSij(Pij_ub,
333
- dOi_l1_K1MK0_ed,
389
+ dOi_l1_K1MK0,
334
390
  VjT_l1_K1NK0_ed,
335
391
  Di_ub,
336
392
  kv_blk_height,
337
393
  q_blk_height,
338
394
  dropout_mask_gm_offset)
339
- dSij_l1_K1MK0_ed = self.tik_ops_utils.MK_TO_K1MK0(dSij_ub, workspace_tensor=dSij_l1_1)
395
+ self.cont_data_mv_1_bust(dst=dSij_l1_K1MK0_ed, src=dSij_ub,
396
+ burst=kv_blk_height_alig * q_blk_height_alig // 16)
397
+ dSij_ub = self.tik_ops_utils.N1MN0_TO_MN(dSij_ub)
340
398
  dSijT_l1_K1MK0_ed = self.tik_ops_utils.KN_TO_K1NK0(dSij_ub, workspace_tensor=dSij_l1_2)
341
399
  self.update_dQi(dSij_l1_K1MK0_ed, Kj_l1_K1NK0_ed,
342
400
  qo_gm_offset, q_blk_height, kv_blk_height)
@@ -358,10 +416,11 @@ class FlashAttentionBwd(FlashAttention):
358
416
  """collect all output gm tensors into output_gm_list,
359
417
  the output list should keep order with the para order in Primitive and init
360
418
  """
361
- return [self.dQ_gm, self.dK_gm, self.dV_gm]
419
+ output_gm_list = [self.dQ_gm, self.dK_gm, self.dV_gm]
420
+ return output_gm_list
362
421
 
363
422
 
364
- def flash_attention_grad(Query, Key, Value, Output, dO, rowsum, rowmax, dim_mask, attn_mask, dropout_mask, alibi_mask,
423
+ def flash_attention_grad(Query, Key, Value, Output, dO, rowsum, rowmax, attn_mask, dropout_mask, alibi_mask,
365
424
  dq, dk, dv,
366
425
  prev_block_num=65536,
367
426
  next_block_num=65536,
@@ -381,7 +440,6 @@ def flash_attention_grad(Query, Key, Value, Output, dO, rowsum, rowmax, dim_mask
381
440
  dO: dict. shape and dtype of input, only support float16
382
441
  rowsum: dict. shape and dtype of input, only support float16
383
442
  rowmax: dict. shape and dtype of input, only support float16
384
- dim_mask: dict. shape and dtype of input, only support int8
385
443
  attn_mask: dict. shape and dtype of input, only support float16
386
444
  dropout_mask: dict. shape and dtype of input, only support float16
387
445
  alibi_mask: dict. shape and dtype of input, only support float16
@@ -398,7 +456,7 @@ def flash_attention_grad(Query, Key, Value, Output, dO, rowsum, rowmax, dim_mask
398
456
  -------
399
457
  tik_instance
400
458
  """
401
- fa_grad = FlashAttentionBwd(Query, Key, Value, Output, dO, rowsum, rowmax, dim_mask, attn_mask, dropout_mask,
459
+ fa_grad = FlashAttentionBwd(Query, Key, Value, Output, dO, rowsum, rowmax, attn_mask, dropout_mask,
402
460
  alibi_mask, prev_block_num=prev_block_num,
403
461
  next_block_num=next_block_num,
404
462
  high_precision=high_precision,