mindspore 2.1.0__cp37-cp37m-win_amd64.whl → 2.2.11__cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (511)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +4 -1
  5. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -1
  9. mindspore/_checkparam.py +23 -29
  10. mindspore/_extends/graph_kernel/__init__.py +0 -1
  11. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  12. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  13. mindspore/_extends/graph_kernel/splitter.py +4 -11
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  21. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  22. mindspore/_extends/parse/__init__.py +13 -15
  23. mindspore/_extends/parse/namespace.py +7 -33
  24. mindspore/_extends/parse/parser.py +67 -72
  25. mindspore/_extends/parse/resources.py +1 -1
  26. mindspore/_extends/parse/standard_method.py +86 -106
  27. mindspore/_extends/parse/trope.py +1 -1
  28. mindspore/_extends/remote/kernel_build_server.py +25 -7
  29. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  30. mindspore/_install_custom.py +43 -0
  31. mindspore/amp.py +47 -11
  32. mindspore/atlprov.dll +0 -0
  33. mindspore/boost/boost.py +1 -8
  34. mindspore/boost/boost_cell_wrapper.py +3 -2
  35. mindspore/boost/grad_accumulation.py +1 -1
  36. mindspore/boost/group_loss_scale_manager.py +8 -7
  37. mindspore/c1.dll +0 -0
  38. mindspore/c1xx.dll +0 -0
  39. mindspore/c2.dll +0 -0
  40. mindspore/common/__init__.py +5 -3
  41. mindspore/common/_jit_fallback_utils.py +6 -0
  42. mindspore/common/_register_for_adapter.py +2 -0
  43. mindspore/common/_register_for_tensor.py +2 -2
  44. mindspore/common/_stub_tensor.py +13 -0
  45. mindspore/common/_utils.py +29 -0
  46. mindspore/common/api.py +174 -259
  47. mindspore/common/auto_dynamic_shape.py +494 -0
  48. mindspore/common/dtype.py +18 -11
  49. mindspore/common/dump.py +6 -4
  50. mindspore/common/initializer.py +14 -14
  51. mindspore/common/jit_config.py +33 -15
  52. mindspore/common/lazy_inline.py +126 -7
  53. mindspore/common/mindir_util.py +101 -0
  54. mindspore/common/parameter.py +51 -41
  55. mindspore/common/seed.py +4 -4
  56. mindspore/common/sparse_tensor.py +13 -14
  57. mindspore/common/tensor.py +243 -165
  58. mindspore/communication/__init__.py +7 -4
  59. mindspore/communication/_comm_helper.py +83 -4
  60. mindspore/communication/management.py +152 -84
  61. mindspore/config/op_info.config +14 -3
  62. mindspore/context.py +152 -61
  63. mindspore/dataset/__init__.py +5 -5
  64. mindspore/dataset/audio/__init__.py +2 -2
  65. mindspore/dataset/audio/transforms.py +52 -52
  66. mindspore/dataset/callback/ds_callback.py +16 -2
  67. mindspore/dataset/core/config.py +68 -51
  68. mindspore/dataset/engine/cache_client.py +33 -7
  69. mindspore/dataset/engine/datasets.py +250 -112
  70. mindspore/dataset/engine/datasets_audio.py +43 -211
  71. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  72. mindspore/dataset/engine/datasets_text.py +43 -67
  73. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  74. mindspore/dataset/engine/datasets_vision.py +219 -1029
  75. mindspore/dataset/engine/iterators.py +11 -4
  76. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  77. mindspore/dataset/engine/obs/util.py +3 -0
  78. mindspore/dataset/engine/samplers.py +1 -1
  79. mindspore/dataset/engine/validators.py +19 -5
  80. mindspore/dataset/text/__init__.py +3 -3
  81. mindspore/dataset/text/transforms.py +101 -127
  82. mindspore/dataset/text/utils.py +205 -138
  83. mindspore/dataset/transforms/__init__.py +1 -1
  84. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  85. mindspore/dataset/transforms/transforms.py +95 -40
  86. mindspore/dataset/utils/browse_dataset.py +8 -2
  87. mindspore/dataset/utils/line_reader.py +17 -19
  88. mindspore/dataset/vision/__init__.py +3 -3
  89. mindspore/dataset/vision/c_transforms.py +6 -3
  90. mindspore/dataset/vision/transforms.py +409 -287
  91. mindspore/dataset/vision/utils.py +13 -14
  92. mindspore/dataset/vision/validators.py +11 -1
  93. mindspore/dnnl.dll +0 -0
  94. mindspore/dpcmi.dll +0 -0
  95. mindspore/experimental/map_parameter.py +14 -0
  96. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  97. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  98. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  99. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  100. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  101. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  102. mindspore/gen_ops.py +273 -0
  103. mindspore/include/OWNERS +0 -1
  104. mindspore/include/api/data_type.h +2 -1
  105. mindspore/include/api/graph.h +0 -15
  106. mindspore/include/api/kernel.h +2 -0
  107. mindspore/include/api/kernel_api.h +37 -12
  108. mindspore/include/api/model.h +17 -14
  109. mindspore/include/api/status.h +8 -3
  110. mindspore/include/api/types.h +37 -4
  111. mindspore/include/c_api/ms/abstract.h +67 -0
  112. mindspore/include/c_api/ms/attribute.h +197 -0
  113. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  114. mindspore/include/c_api/ms/base/macros.h +32 -0
  115. mindspore/include/c_api/ms/base/status.h +33 -0
  116. mindspore/include/c_api/ms/base/types.h +282 -0
  117. mindspore/include/c_api/ms/context.h +102 -0
  118. mindspore/include/c_api/ms/graph.h +160 -0
  119. mindspore/include/c_api/ms/node.h +606 -0
  120. mindspore/include/c_api/ms/tensor.h +161 -0
  121. mindspore/include/c_api/ms/value.h +84 -0
  122. mindspore/include/dataset/constants.h +6 -5
  123. mindspore/include/dataset/execute.h +23 -13
  124. mindspore/include/dataset/text.h +26 -26
  125. mindspore/include/dataset/transforms.h +13 -13
  126. mindspore/include/dataset/vision.h +60 -60
  127. mindspore/include/dataset/vision_ascend.h +5 -6
  128. mindspore/include/dataset/vision_lite.h +17 -17
  129. mindspore/jpeg62.dll +0 -0
  130. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  131. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  132. mindspore/mindspore_backend.dll +0 -0
  133. mindspore/mindspore_common.dll +0 -0
  134. mindspore/mindspore_core.dll +0 -0
  135. mindspore/mindspore_glog.dll +0 -0
  136. mindspore/mindspore_shared_lib.dll +0 -0
  137. mindspore/msobj140.dll +0 -0
  138. mindspore/mspdb140.dll +0 -0
  139. mindspore/mspdbcore.dll +0 -0
  140. mindspore/mspdbst.dll +0 -0
  141. mindspore/mspft140.dll +0 -0
  142. mindspore/msvcdis140.dll +0 -0
  143. mindspore/msvcp140_1.dll +0 -0
  144. mindspore/msvcp140_2.dll +0 -0
  145. mindspore/msvcp140_atomic_wait.dll +0 -0
  146. mindspore/msvcp140_codecvt_ids.dll +0 -0
  147. mindspore/nn/__init__.py +0 -2
  148. mindspore/nn/cell.py +313 -74
  149. mindspore/nn/dynamic_lr.py +21 -21
  150. mindspore/nn/layer/activation.py +22 -30
  151. mindspore/nn/layer/basic.py +15 -13
  152. mindspore/nn/layer/channel_shuffle.py +1 -1
  153. mindspore/nn/layer/container.py +271 -9
  154. mindspore/nn/layer/conv.py +323 -204
  155. mindspore/nn/layer/dense.py +8 -5
  156. mindspore/nn/layer/embedding.py +33 -27
  157. mindspore/nn/layer/flash_attention.py +61 -95
  158. mindspore/nn/layer/image.py +8 -6
  159. mindspore/nn/layer/math.py +16 -25
  160. mindspore/nn/layer/normalization.py +107 -66
  161. mindspore/nn/layer/padding.py +1 -1
  162. mindspore/nn/layer/pooling.py +131 -109
  163. mindspore/nn/layer/rnn_cells.py +27 -22
  164. mindspore/nn/layer/rnns.py +13 -16
  165. mindspore/nn/layer/thor_layer.py +1 -1
  166. mindspore/nn/layer/transformer.py +221 -154
  167. mindspore/nn/learning_rate_schedule.py +9 -1
  168. mindspore/nn/loss/loss.py +235 -174
  169. mindspore/nn/optim/ada_grad.py +2 -1
  170. mindspore/nn/optim/adadelta.py +1 -0
  171. mindspore/nn/optim/adafactor.py +2 -1
  172. mindspore/nn/optim/adam.py +7 -4
  173. mindspore/nn/optim/adamax.py +3 -2
  174. mindspore/nn/optim/adasum.py +2 -2
  175. mindspore/nn/optim/asgd.py +2 -3
  176. mindspore/nn/optim/ftrl.py +6 -5
  177. mindspore/nn/optim/lamb.py +7 -4
  178. mindspore/nn/optim/lars.py +1 -1
  179. mindspore/nn/optim/lazyadam.py +5 -3
  180. mindspore/nn/optim/momentum.py +2 -1
  181. mindspore/nn/optim/optimizer.py +53 -4
  182. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  183. mindspore/nn/optim/rmsprop.py +4 -3
  184. mindspore/nn/optim/rprop.py +23 -12
  185. mindspore/nn/optim/sgd.py +26 -11
  186. mindspore/nn/optim/thor.py +9 -7
  187. mindspore/nn/probability/bijector/bijector.py +5 -5
  188. mindspore/nn/probability/bijector/power_transform.py +27 -27
  189. mindspore/nn/probability/bijector/softplus.py +3 -3
  190. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  191. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  192. mindspore/nn/probability/distribution/beta.py +3 -3
  193. mindspore/nn/probability/distribution/categorical.py +7 -7
  194. mindspore/nn/probability/distribution/cauchy.py +0 -1
  195. mindspore/nn/probability/distribution/distribution.py +3 -3
  196. mindspore/nn/probability/distribution/gamma.py +3 -3
  197. mindspore/nn/probability/distribution/geometric.py +4 -4
  198. mindspore/nn/probability/distribution/gumbel.py +4 -4
  199. mindspore/nn/probability/distribution/log_normal.py +2 -2
  200. mindspore/nn/probability/distribution/logistic.py +2 -2
  201. mindspore/nn/probability/distribution/poisson.py +4 -4
  202. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  203. mindspore/nn/probability/distribution/uniform.py +6 -6
  204. mindspore/nn/wrap/__init__.py +4 -2
  205. mindspore/nn/wrap/cell_wrapper.py +87 -34
  206. mindspore/nn/wrap/grad_reducer.py +8 -5
  207. mindspore/nn/wrap/loss_scale.py +105 -42
  208. mindspore/numpy/array_creations.py +1 -2
  209. mindspore/numpy/array_ops.py +3 -2
  210. mindspore/numpy/utils_const.py +5 -5
  211. mindspore/opencv_core452.dll +0 -0
  212. mindspore/opencv_imgcodecs452.dll +0 -0
  213. mindspore/opencv_imgproc452.dll +0 -0
  214. mindspore/ops/_grad_experimental/__init__.py +0 -5
  215. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  216. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  217. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  218. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  219. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  220. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  221. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  222. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  223. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  224. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  225. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  226. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  227. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  228. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  229. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  230. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  231. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  232. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  233. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  234. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  235. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  236. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  237. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  238. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  239. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  240. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  241. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  242. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  243. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  244. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  245. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  246. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  247. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  248. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  249. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  250. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  251. mindspore/ops/_primitive_cache.py +1 -1
  252. mindspore/ops/_tracefunc.py +45 -13
  253. mindspore/ops/_utils/utils.py +6 -1
  254. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  255. mindspore/ops/_vmap/vmap_base.py +3 -3
  256. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  257. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  258. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  259. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  260. mindspore/ops/arg_dtype_cast.py +54 -0
  261. mindspore/ops/composite/base.py +37 -10
  262. mindspore/ops/composite/math_ops.py +5 -4
  263. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  264. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  265. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  266. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  267. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  268. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  270. mindspore/ops/deprecated.py +304 -0
  271. mindspore/ops/function/__init__.py +4 -1
  272. mindspore/ops/function/array_func.py +174 -193
  273. mindspore/ops/function/clip_func.py +81 -13
  274. mindspore/ops/function/debug_func.py +1 -1
  275. mindspore/ops/function/grad/grad_func.py +18 -9
  276. mindspore/ops/function/image_func.py +10 -4
  277. mindspore/ops/function/linalg_func.py +5 -5
  278. mindspore/ops/function/math_func.py +575 -386
  279. mindspore/ops/function/nn_func.py +568 -260
  280. mindspore/ops/function/random_func.py +88 -57
  281. mindspore/ops/function/sparse_func.py +1 -1
  282. mindspore/ops/function/sparse_unary_func.py +14 -12
  283. mindspore/ops/function/vmap_func.py +6 -5
  284. mindspore/ops/functional.py +15 -10
  285. mindspore/ops/op_info_register.py +244 -25
  286. mindspore/ops/operations/__init__.py +31 -19
  287. mindspore/ops/operations/_grad_ops.py +71 -7
  288. mindspore/ops/operations/_inner_ops.py +350 -17
  289. mindspore/ops/operations/_quant_ops.py +4 -8
  290. mindspore/ops/operations/_sequence_ops.py +42 -0
  291. mindspore/ops/operations/array_ops.py +68 -282
  292. mindspore/ops/operations/comm_ops.py +107 -59
  293. mindspore/ops/operations/custom_ops.py +94 -70
  294. mindspore/ops/operations/debug_ops.py +8 -4
  295. mindspore/ops/operations/image_ops.py +18 -12
  296. mindspore/ops/operations/inner_ops.py +26 -3
  297. mindspore/ops/operations/math_ops.py +192 -144
  298. mindspore/ops/operations/nn_ops.py +857 -489
  299. mindspore/ops/operations/other_ops.py +0 -22
  300. mindspore/ops/operations/random_ops.py +53 -111
  301. mindspore/ops/operations/sparse_ops.py +3 -1
  302. mindspore/ops/primitive.py +24 -18
  303. mindspore/parallel/_auto_parallel_context.py +68 -8
  304. mindspore/parallel/_cost_model_context.py +2 -2
  305. mindspore/parallel/_offload_context.py +17 -3
  306. mindspore/parallel/_parallel_serialization.py +12 -5
  307. mindspore/parallel/_ps_context.py +12 -0
  308. mindspore/parallel/_tensor.py +18 -13
  309. mindspore/parallel/_transformer/layers.py +5 -3
  310. mindspore/parallel/_transformer/loss.py +1 -0
  311. mindspore/parallel/_transformer/moe.py +2 -2
  312. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  313. mindspore/parallel/_transformer/transformer.py +23 -3
  314. mindspore/parallel/_utils.py +11 -7
  315. mindspore/parallel/algo_parameter_config.py +85 -5
  316. mindspore/parallel/checkpoint_transform.py +19 -12
  317. mindspore/parallel/shard.py +21 -14
  318. mindspore/pgodb140.dll +0 -0
  319. mindspore/pgort140.dll +0 -0
  320. mindspore/profiler/common/struct_type.py +3 -3
  321. mindspore/profiler/common/util.py +4 -2
  322. mindspore/profiler/envprofiling.py +1 -1
  323. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  324. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  325. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  326. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  327. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  328. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  329. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  330. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  331. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  332. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  333. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  334. mindspore/profiler/parser/flops_parser.py +15 -11
  335. mindspore/profiler/parser/framework_parser.py +38 -22
  336. mindspore/profiler/parser/hccl_parser.py +16 -12
  337. mindspore/profiler/parser/integrator.py +22 -11
  338. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  339. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  340. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  341. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  342. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  343. mindspore/profiler/parser/optime_parser.py +1 -1
  344. mindspore/profiler/parser/profiler_info.py +21 -2
  345. mindspore/profiler/parser/step_trace_parser.py +11 -14
  346. mindspore/profiler/profiling.py +179 -89
  347. mindspore/rewrite/api/node.py +102 -19
  348. mindspore/rewrite/api/node_type.py +5 -1
  349. mindspore/rewrite/api/pattern_engine.py +1 -1
  350. mindspore/rewrite/api/scoped_value.py +9 -17
  351. mindspore/rewrite/api/symbol_tree.py +131 -47
  352. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  353. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  354. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  355. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  356. mindspore/rewrite/common/rewrite_elog.py +5 -1
  357. mindspore/rewrite/namer.py +33 -24
  358. mindspore/rewrite/namespace.py +14 -5
  359. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  360. mindspore/rewrite/node/call_function.py +79 -0
  361. mindspore/rewrite/node/cell_container.py +135 -0
  362. mindspore/rewrite/node/control_flow.py +88 -0
  363. mindspore/rewrite/{node.py → node/node.py} +273 -234
  364. mindspore/rewrite/node/node_manager.py +254 -0
  365. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  366. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  367. mindspore/rewrite/parsers/assign_parser.py +216 -221
  368. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  369. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  370. mindspore/rewrite/parsers/constant_parser.py +9 -6
  371. mindspore/rewrite/parsers/container_parser.py +9 -7
  372. mindspore/rewrite/parsers/for_parser.py +42 -21
  373. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  374. mindspore/rewrite/parsers/if_parser.py +28 -24
  375. mindspore/rewrite/parsers/module_parser.py +196 -25
  376. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  377. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  378. mindspore/rewrite/parsers/return_parser.py +6 -6
  379. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  380. mindspore/rewrite/sparsify/utils.py +1 -1
  381. mindspore/rewrite/symbol_tree.py +523 -578
  382. mindspore/rewrite/symbol_tree_builder.py +9 -193
  383. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  384. mindspore/run_check/_check_version.py +6 -4
  385. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  386. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  387. mindspore/tbbmalloc.dll +0 -0
  388. mindspore/tinyxml2.dll +0 -0
  389. mindspore/train/_utils.py +7 -3
  390. mindspore/train/amp.py +323 -123
  391. mindspore/train/anf_ir_pb2.py +14 -2
  392. mindspore/train/callback/_backup_and_restore.py +2 -12
  393. mindspore/train/callback/_callback.py +29 -4
  394. mindspore/train/callback/_checkpoint.py +23 -8
  395. mindspore/train/callback/_early_stop.py +2 -2
  396. mindspore/train/callback/_landscape.py +4 -4
  397. mindspore/train/callback/_loss_monitor.py +2 -2
  398. mindspore/train/callback/_on_request_exit.py +2 -2
  399. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  400. mindspore/train/callback/_summary_collector.py +15 -8
  401. mindspore/train/callback/_time_monitor.py +58 -5
  402. mindspore/train/data_sink.py +5 -11
  403. mindspore/train/dataset_helper.py +84 -57
  404. mindspore/train/loss_scale_manager.py +2 -2
  405. mindspore/train/metrics/__init__.py +3 -3
  406. mindspore/train/metrics/cosine_similarity.py +1 -1
  407. mindspore/train/metrics/hausdorff_distance.py +3 -2
  408. mindspore/train/metrics/mean_surface_distance.py +3 -2
  409. mindspore/train/metrics/metric.py +39 -19
  410. mindspore/train/metrics/roc.py +2 -2
  411. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  412. mindspore/train/mind_ir_pb2.py +85 -36
  413. mindspore/train/model.py +187 -47
  414. mindspore/train/serialization.py +487 -161
  415. mindspore/train/summary/_summary_adapter.py +1 -1
  416. mindspore/train/summary/_writer_pool.py +3 -2
  417. mindspore/train/summary/summary_record.py +37 -17
  418. mindspore/train/train_thor/convert_utils.py +3 -3
  419. mindspore/train/train_thor/dataset_helper.py +1 -1
  420. mindspore/turbojpeg.dll +0 -0
  421. mindspore/vcmeta.dll +0 -0
  422. mindspore/vcruntime140.dll +0 -0
  423. mindspore/vcruntime140_1.dll +0 -0
  424. mindspore/version.py +1 -1
  425. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
  426. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +429 -486
  427. mindspore/_extends/graph_kernel/expander.py +0 -80
  428. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  429. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  430. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  431. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  432. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  433. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  434. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  435. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  436. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  437. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  438. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  439. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  440. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  441. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  442. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  443. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  444. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  445. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  446. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  447. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  448. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  449. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  450. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  451. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  452. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  453. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  454. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  455. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  456. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  457. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  458. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  459. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  460. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  461. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  462. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  463. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  464. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  465. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  466. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  467. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  468. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  469. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  470. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  471. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  472. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  473. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  474. mindspore/dataset/datapreprocess/__init__.py +0 -20
  475. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  476. mindspore/include/api/net.h +0 -142
  477. mindspore/nn/lr_scheduler.py +0 -262
  478. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  479. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  480. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  481. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  482. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  483. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  484. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  485. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  486. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  487. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  488. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  489. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  490. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  491. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  492. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  493. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  494. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  495. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  496. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  497. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  498. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  499. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  500. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  501. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  502. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  503. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  504. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  505. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  506. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  507. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  508. mindspore/rewrite/node_visitor.py +0 -44
  509. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  510. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  511. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,350 +0,0 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """the base class of flash attention"""
16
- from abc import ABCMeta
17
- from abc import abstractmethod
18
- from functools import partial
19
- from collections import defaultdict
20
-
21
- import te.platform as tbe_platform
22
- from tbe import tik
23
- from tbe.common.platform import get_soc_spec
24
-
25
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
26
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
27
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import GM
28
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import INT8
29
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import MASK_FILL_VALUE
30
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
31
- from mindspore.ops._op_impl._custom_op.flash_attention.tik_ops_utils import TikOpsUtils
32
- from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingPara
33
- from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingStrategy
34
- from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.sparse_tiling import SparseTiling
35
-
36
-
37
class FlashAttention(metaclass=ABCMeta):
    """Abstract base class shared by the TIK FlashAttention kernels.

    The base class records the input shapes, asks a tiling strategy for the
    row/column block sizes, declares the common global-memory inputs, splits
    the ``B * Tr`` (batch*head, row-block) tasks over the available cores and
    finally builds the kernel.  Subclasses supply the per-core computation and
    the custom inputs/outputs through the abstract methods.
    """

    def __init__(self, q, k, v, dim_mask, attn_mask, dropout_mask, alibi_mask, kernel_name,
                 tiling_stgy_cls,
                 prev_block_num=65536,
                 next_block_num=65536,
                 high_precision=False,
                 disable_debug=True):
        """
        Init parameter shape
        :param q: with shape: (B, h, N, d)
        :param k: with shape: (B, h, N, d)
        :param v: with shape: (B, h, N, d)
        :param dim_mask: with shape: (x, ) x equals length last dim that not padded to 16.
        :param attn_mask: with shape: (1, N, N) or (B, N, N); None when not used
        :param dropout_mask: with shape: (B, h, N, N); None when not used
        :param alibi_mask: with shape: (B, h, 1, N); None when not used
        :param kernel_name: name passed to ``BuildCCE`` when the kernel is built
        :param tiling_stgy_cls: class called as ``cls(Nq, N, d)`` to create the tiling
            strategy; ``SparseTiling`` is used when None
        :param prev_block_num: stored on the instance as ``self.prev_block_num``
        :param next_block_num: stored on the instance as ``self.next_block_num``
        :param high_precision: selects FP32 (True) or FP16 (False) as ``self.precision_type``
        :param disable_debug: forwarded to ``tik.Tik``
        """
        self.tik_instance = tik.Tik(disable_debug=disable_debug)
        self.core_num = get_soc_spec(tbe_platform.CORE_NUM)
        self.M = tbe_platform.get_soc_spec(tbe_platform.L1_SIZE)
        self.kernel_name = kernel_name
        # data_move pre-bound to a single contiguous burst (nburst=1, no strides).
        self.cont_data_mv_1_bust = partial(self.tik_instance.data_move, sid=0, nburst=1,
                                           src_stride=0,
                                           dst_stride=0)
        self.tik_ops_utils = TikOpsUtils(self.tik_instance)
        self.parser_input_shape(alibi_mask, attn_mask, dim_mask, dropout_mask, k, q, v)

        batch_size, h, Nq, d = self.q_shape
        self.head_num = h
        # Batch and head are folded into one leading dimension for task scheduling.
        self.B, self.Nq, self.d = batch_size * h, Nq, d
        self.N = self.k_shape[2]

        self.l_shape = [batch_size, h, self.Nq]
        self.m_shape = [batch_size, h, self.Nq]
        self.O_shape = [batch_size, h, self.Nq, self.d]
        self.actual_d = self.dim_mask_shape[0]

        self.K0 = 16
        self.prev_block_num = prev_block_num
        self.next_block_num = next_block_num
        self.high_precision = high_precision
        self.precision_type = FP32 if self.high_precision else FP16

        if tiling_stgy_cls is None:
            tiling_stgy_cls = SparseTiling
        self.tiling_stgy: TilingStrategy = tiling_stgy_cls(self.Nq, self.N, self.d)

        # Tiling parameters, filled in by init().
        self.Br = None
        self.last_Br = None
        self.Bc = None
        self.last_Bc = None
        self.Tr = None
        self.Tc = None
        # Global-memory tensors, created by define_common_inputs().
        self.Q_gm = None
        self.K_gm = None
        self.V_gm = None
        self.dim_mask_gm = None
        self.att_mask_gm = None
        self.drop_mask_gm = None
        self.alibi_mask_gm = None

    @staticmethod
    def get_gm_offset(batch_start, batch_idx, h, w, block_h, block_idx):
        """
        Get the flat offset of a row block inside a (batch, h, w) tensor.

        :param batch_start: first batch handled by the current core
        :param batch_idx: batch index relative to ``batch_start``
        :param h: number of rows per batch
        :param w: number of columns per row
        :param block_h: number of rows per block
        :param block_idx: index of the row block inside the batch
        :return: element offset of the first element of the block
        """
        gm_offset = (batch_start + batch_idx) * h * w + block_idx * block_h * w
        return gm_offset

    @staticmethod
    def get_alibi_gm_offset(batch_start, batch_idx, w, block_w, block_idx):
        """
        Get the flat offset of a column block inside a (batch, w) alibi mask.

        :param batch_start: first batch handled by the current core
        :param batch_idx: batch index relative to ``batch_start``
        :param w: number of columns per batch
        :param block_w: number of columns per block
        :param block_idx: index of the column block inside the batch
        :return: element offset of the first element of the block
        """
        gm_offset = (batch_start + batch_idx) * w + block_idx * block_w
        return gm_offset

    @staticmethod
    def get_drop_mask_gm_offset(batch_start, batch_idx, h, w, block_h, block_h_idx, block_w, block_w_idx):
        """
        Get the flat offset of a 2-D block inside a (batch, h, w) dropout mask.

        :param batch_start: first batch handled by the current core
        :param batch_idx: batch index relative to ``batch_start``
        :param h: number of rows per batch
        :param w: number of columns per row
        :param block_h: number of rows per block
        :param block_h_idx: index of the block along the row axis
        :param block_w: number of columns per block
        :param block_w_idx: index of the block along the column axis
        :return: element offset of the top-left element of the block
        """
        gm_offset = (batch_start + batch_idx) * h * w + block_h_idx * (w * block_h) + block_w_idx * block_w
        return gm_offset

    @abstractmethod
    def define_custom_inputs(self):
        """define custom inputs"""
        raise NotImplementedError

    @abstractmethod
    def define_outputs(self):
        """define outputs"""
        raise NotImplementedError

    @abstractmethod
    def collect_inputs(self):
        """collect inputs"""
        raise NotImplementedError

    @abstractmethod
    def collect_outputs(self):
        """collect outputs"""
        raise NotImplementedError

    def get_core_bath_info(self):
        """
        Get batch start and batch number of each NPU core.

        The ``B * Tr`` tasks (task ``i`` is batch ``i // Tr``, row block ``i % Tr``)
        are dealt out to the cores in consecutive, near-equal runs.
        ``self.core_num`` is reduced when there are fewer tasks than cores.

        :return: Tensor([[core_1_batch_start, core_1_batch_num],...,[core_m_batch_start,
        core_m_batch_num]]), Tensor([[[core_1_batch_1_Tr_start, core_1_batch_1_Tr_end],...[core_1_batch_n_Tr_start,
        core_1_batch_n_Tr_end]],...,[[core_m_batch_1_Tr_start, core_m_batch_1_Tr_end],...[core_m_batch_n_Tr_start,
        core_m_batch_n_Tr_end]]
        (each Tr_end is exclusive)
        :raises ValueError: if there is no task to schedule (``B * Tr <= 0``)
        """
        total_task_num = self.B * self.Tr
        if total_task_num <= 0:
            # Without this check core_num would be clamped to 0 and divmod would
            # fail with an unhelpful ZeroDivisionError.
            raise ValueError("No task to schedule: 'B * Tr' must be positive, but got B={} and Tr={}."
                             .format(self.B, self.Tr))
        if self.core_num > total_task_num:
            self.core_num = total_task_num

        avg_task_num_per_core, remain_task = divmod(total_task_num, self.core_num)

        # Host-side schedule.  Entry i belongs to core i and maps each batch the
        # core touches to its inclusive [tr_start, tr_end] range.  Tasks are
        # visited in increasing order, so the first value seen is the minimum
        # and the last value seen is the maximum.
        core_tr_ranges = []
        task_start = 0
        for core_idx in range(self.core_num):
            cur_core_task_num = avg_task_num_per_core
            if core_idx < remain_task:
                cur_core_task_num += 1
            task_end = task_start + cur_core_task_num
            tr_ranges = {}
            for task_idx in range(task_start, task_end):
                batch_idx, tr_idx = divmod(task_idx, self.Tr)
                if batch_idx in tr_ranges:
                    tr_ranges[batch_idx][1] = tr_idx
                else:
                    tr_ranges[batch_idx] = [tr_idx, tr_idx]
            core_tr_ranges.append(tr_ranges)
            task_start = task_end

        core_idx_to_batch_info = self.tik_instance.Tensor(
            "int32", (self.core_num, 2), name="core_idx_to_batch_info", scope=UB
        )
        core_idx_to_tr_info = self.tik_instance.Tensor(
            "int32", (self.core_num, self.B, 2), name="core_idx_to_tr_info", scope=UB
        )
        # core_num <= total_task_num, so every core owns at least one task and
        # tr_ranges is never empty here.
        for core_idx, tr_ranges in enumerate(core_tr_ranges):
            batch_ids = list(tr_ranges)
            batch_start, batch_end = batch_ids[0], batch_ids[-1]
            core_idx_to_batch_info[core_idx, 0] = batch_start
            core_idx_to_batch_info[core_idx, 1] = batch_end - batch_start + 1
            for batch_idx, (tr_start, tr_end) in tr_ranges.items():
                core_idx_to_tr_info[core_idx, batch_idx, 0] = tr_start
                core_idx_to_tr_info[core_idx, batch_idx, 1] = tr_end + 1
        return core_idx_to_batch_info, core_idx_to_tr_info

    def get_attn_mask_gm_offset(self, batch_start, batch_idx, h, w, block_h, block_h_idx, block_w, block_w_idx):
        """
        Get the flat offset of a 2-D block inside the attention mask.

        When the first dimension of the mask is 1 the same mask is used for
        every batch; otherwise the folded batch*head index is divided by
        ``self.head_num`` to pick the mask of the owning batch.

        :param batch_start: first (batch*head) index handled by the current core
        :param batch_idx: index relative to ``batch_start``
        :param h: number of rows per mask
        :param w: number of columns per row
        :param block_h: number of rows per block
        :param block_h_idx: index of the block along the row axis
        :param block_w: number of columns per block
        :param block_w_idx: index of the block along the column axis
        :return: element offset of the top-left element of the block
        """
        if self.att_mask_shape[0] == 1:
            gm_offset = block_h_idx * (w * block_h) + block_w_idx * block_w
        else:
            gm_offset = ((batch_start + batch_idx) // self.head_num) * h * w \
                        + block_h_idx * (w * block_h) + block_w_idx * block_w
        return gm_offset

    def parser_input_shape(self, alibi_mask, attn_mask, dim_mask, dropout_mask, k, q, v):
        """
        Record the shape of every input on the instance.

        Inputs are read as ``x["shape"]`` when ``q`` is a dict and as ``x.shape``
        otherwise.  For each optional mask a ``has_*`` flag is set; the matching
        ``*_shape`` attribute is None when the mask is absent.

        :param alibi_mask: alibi mask description, or None
        :param attn_mask: attention mask description, or None
        :param dim_mask: dim mask description
        :param dropout_mask: dropout mask description, or None
        :param k: key description
        :param q: query description; its type decides how all shapes are read
        :param v: value description
        """
        self.has_attn_mask = attn_mask is not None
        self.has_drop_mask = dropout_mask is not None
        self.has_alibi_mask = alibi_mask is not None

        if isinstance(q, dict):
            def shape_of(tensor_desc):
                return tensor_desc["shape"]
        else:
            def shape_of(tensor_desc):
                return tensor_desc.shape

        self.q_shape = shape_of(q)
        self.k_shape = shape_of(k)
        self.v_shape = shape_of(v)
        self.dim_mask_shape = shape_of(dim_mask)
        # Always define the optional shapes so that later reads never raise
        # AttributeError when a mask was not supplied.
        self.att_mask_shape = shape_of(attn_mask) if self.has_attn_mask else None
        self.drop_mask_shape = shape_of(dropout_mask) if self.has_drop_mask else None
        self.alibi_mask_shape = shape_of(alibi_mask) if self.has_alibi_mask else None

    def define_inputs_outputs(self):
        """Define the common inputs, then the subclass inputs, then the outputs."""
        self.define_common_inputs()

        self.define_custom_inputs()

        self.define_outputs()

    def init(self):
        """Fetch the tiling parameters from the strategy and declare all kernel tensors."""
        tiling_para: TilingPara = self.tiling_stgy.tiling()

        self.Br = tiling_para.Br
        self.last_Br = tiling_para.last_Br
        self.Bc = tiling_para.Bc
        self.last_Bc = tiling_para.last_Bc
        self.Tr = tiling_para.Tr
        self.Tc = tiling_para.Tc

        self.define_inputs_outputs()

    def define_common_inputs(self):
        """Define the GM tensors for Q, K, V, the dim mask and every mask that was supplied."""
        self.Q_gm = self.tik_instance.Tensor(FP16, self.q_shape, name="Q_gm", scope=GM)
        self.K_gm = self.tik_instance.Tensor(FP16, self.k_shape, name="K_gm", scope=GM)
        self.V_gm = self.tik_instance.Tensor(FP16, self.v_shape, name="V_gm", scope=GM)
        self.dim_mask_gm = self.tik_instance.Tensor(INT8, self.dim_mask_shape, name="mask_gm",
                                                    scope=GM)
        if self.has_attn_mask:
            self.att_mask_gm = self.tik_instance.Tensor(FP16, self.att_mask_shape,
                                                        name="att_mask_gm", scope=GM)
        if self.has_drop_mask:
            self.drop_mask_gm = self.tik_instance.Tensor(FP16, self.drop_mask_shape,
                                                         name="drop_mask_gm", scope=GM)
        if self.has_alibi_mask:
            self.alibi_mask_gm = self.tik_instance.Tensor(FP16, self.alibi_mask_shape,
                                                          name="alibi_mask_gm", scope=GM)

    def do_alibi_mask(self, Sij_ub, alibi_mask_gm_offset, m_aligned, n_aligned):
        """
        Load one row of the alibi mask from gm to ub, broadcast it and add it to Sij in place.

        :param Sij_ub: UB tensor updated in place
        :param alibi_mask_gm_offset: element offset into ``self.alibi_mask_gm``
        :param m_aligned: number of rows of the broadcast mask
        :param n_aligned: number of columns loaded and broadcast
        """
        with self.tik_instance.new_stmt_scope(disable_sync=False):
            alibi_mask_ub = self.tik_instance.Tensor(FP16, (1, n_aligned),
                                                     scope=UB, name="alibi_mask_ub")
            # NOTE(review): positional args are presumably sid, nburst, burst, src_stride,
            # dst_stride, and "// 16" presumably converts FP16 elements to 32-byte
            # blocks - confirm against the TIK data_move reference.
            self.tik_instance.data_move(alibi_mask_ub, self.alibi_mask_gm[alibi_mask_gm_offset], 0, 1,
                                        n_aligned // 16, 0, 0)
            alibi_mask_ub_broadcast = self.tik_ops_utils.broadcast_row(alibi_mask_ub, (m_aligned, n_aligned))
            self.tik_instance.h_add(Sij_ub, Sij_ub, alibi_mask_ub_broadcast)

    def do_att_mask(self, Sij_ub, attn_mask_gm_offset, q_blk_height, kv_blk_height,
                    q_blk_h_aligned, kv_blk_h_aligned):
        """
        Load an attn mask block from gm to ub, multiply it by MASK_FILL_VALUE and add it to Sij in place.

        :param Sij_ub: UB tensor updated in place
        :param attn_mask_gm_offset: element offset into ``self.att_mask_gm``
        :param q_blk_height: number of mask rows to load
        :param kv_blk_height: number of mask columns to load per row
        :param q_blk_h_aligned: row count of the UB buffer
        :param kv_blk_h_aligned: column count of the UB buffer
        """
        with self.tik_instance.new_stmt_scope(disable_sync=False):
            att_mask_ub = self.tik_instance.Tensor(FP16, (q_blk_h_aligned, kv_blk_h_aligned),
                                                   scope=UB, name="att_mask_ub")
            # One burst per mask row; the source stride skips the rest of the
            # N-wide row.  NOTE(review): argument order assumed as in the TIK
            # data_move reference - confirm.
            self.tik_instance.data_move(att_mask_ub, self.att_mask_gm[attn_mask_gm_offset], 0,
                                        q_blk_height, kv_blk_height // 16, (self.N - kv_blk_height) // 16, 0)
            self.tik_instance.h_mul(att_mask_ub, att_mask_ub, MASK_FILL_VALUE)
            self.tik_instance.h_add(Sij_ub, Sij_ub, att_mask_ub)

    def do_dropout_mask(self, Pij_ub, dropout_mask_gm_offset, kv_blk_h_aligned, kv_blk_height,
                        q_blk_h_aligned, q_blk_height, precision_type=FP16):
        """
        Load a drop mask block from gm to ub, then multiply Pij by it in place.

        :param Pij_ub: UB tensor updated in place
        :param dropout_mask_gm_offset: element offset into ``self.drop_mask_gm``
        :param kv_blk_h_aligned: column count of the UB buffer
        :param kv_blk_height: number of mask columns to load per row
        :param q_blk_h_aligned: row count of the UB buffer
        :param q_blk_height: number of mask rows to load
        :param precision_type: when FP32 the mask is cast to FP32 before the multiply
        """
        with self.tik_instance.new_stmt_scope(disable_sync=False):
            dropout_mask_ub = self.tik_instance.Tensor(FP16, (q_blk_h_aligned, kv_blk_h_aligned),
                                                       scope=UB, name="drop_mask_ub")
            # Same strided row-by-row load as in do_att_mask.
            self.tik_instance.data_move(dropout_mask_ub, self.drop_mask_gm[dropout_mask_gm_offset], 0,
                                        q_blk_height, kv_blk_height // 16, (self.N - kv_blk_height) // 16, 0)
            if precision_type == FP32:
                dropout_mask_ub_fp32 = self.tik_instance.Tensor(FP32, (q_blk_h_aligned, kv_blk_h_aligned),
                                                                scope=UB, name="dropout_mask_ub_fp32")
                self.tik_instance.h_cast(dropout_mask_ub_fp32, dropout_mask_ub, "none")
                self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub_fp32)
            else:
                self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub)

    @abstractmethod
    def compute_one_core(self, batch_start_s, batch_num_s, core_idx_to_tr_info, core_idx):
        """
        Compute the tasks assigned to one core.

        :param batch_start_s: scalar holding the first batch of the core
        :param batch_num_s: scalar holding the number of batches of the core
        :param core_idx_to_tr_info: per-core, per-batch row-block ranges from ``get_core_bath_info``
        :param core_idx: index of the current core
        """
        raise NotImplementedError

    def compute_process(self):
        """
        The compute process of FlashAttention.

        Initialises the tiling and tensors, distributes the tasks over the
        cores, emits the per-core computation and builds the kernel.
        """
        self.init()

        core_idx_to_batch_info, core_idx_to_tr_info = self.get_core_bath_info()
        with self.tik_instance.for_range(begint=0, endt=self.core_num, name="core_index",
                                         block_num=self.core_num) as core_idx:
            batch_start_s = self.tik_instance.Scalar("int32", name="batch_start_s")
            batch_num_s = self.tik_instance.Scalar("int32", name="batch_num_s")

            batch_start_s.set_as(core_idx_to_batch_info[core_idx, 0])
            batch_num_s.set_as(core_idx_to_batch_info[core_idx, 1])

            self.compute_one_core(batch_start_s, batch_num_s, core_idx_to_tr_info, core_idx)

        self.tik_instance.BuildCCE(
            kernel_name=self.kernel_name,
            inputs=self.collect_inputs(),
            outputs=self.collect_outputs(),
            config={"dump_cce_code": False, "save_temp_cce_file": True, "enable_const_fold": True},
            enable_l2=True
        )