mindspore-2.1.0-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.10-cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged as possibly problematic.

Files changed (580)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +141 -88
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/cell_wrapper.py +84 -34
  268. mindspore/nn/wrap/grad_reducer.py +8 -5
  269. mindspore/nn/wrap/loss_scale.py +105 -42
  270. mindspore/numpy/array_creations.py +1 -2
  271. mindspore/numpy/array_ops.py +3 -2
  272. mindspore/numpy/utils_const.py +5 -5
  273. mindspore/offline_debug/convert_async.py +2 -2
  274. mindspore/ops/_grad_experimental/__init__.py +0 -5
  275. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  276. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  277. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  278. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  279. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  280. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  281. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  282. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  283. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  284. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  285. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  286. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  287. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  288. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  289. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  290. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  291. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  292. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  293. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  294. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  295. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  296. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  297. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  298. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  299. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  300. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  301. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  302. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  303. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  304. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  305. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  306. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  307. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  309. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  310. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  311. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  312. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  313. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  314. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  315. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  316. mindspore/ops/_primitive_cache.py +1 -1
  317. mindspore/ops/_tracefunc.py +45 -13
  318. mindspore/ops/_utils/utils.py +6 -1
  319. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  320. mindspore/ops/_vmap/vmap_base.py +3 -3
  321. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  322. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  323. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  324. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  325. mindspore/ops/arg_dtype_cast.py +54 -0
  326. mindspore/ops/composite/base.py +37 -10
  327. mindspore/ops/composite/math_ops.py +5 -4
  328. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  329. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  330. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  331. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  333. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  334. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  335. mindspore/ops/deprecated.py +304 -0
  336. mindspore/ops/function/__init__.py +4 -1
  337. mindspore/ops/function/array_func.py +174 -193
  338. mindspore/ops/function/clip_func.py +81 -13
  339. mindspore/ops/function/debug_func.py +1 -1
  340. mindspore/ops/function/grad/grad_func.py +18 -9
  341. mindspore/ops/function/image_func.py +10 -4
  342. mindspore/ops/function/linalg_func.py +5 -5
  343. mindspore/ops/function/math_func.py +575 -386
  344. mindspore/ops/function/nn_func.py +568 -260
  345. mindspore/ops/function/random_func.py +88 -57
  346. mindspore/ops/function/sparse_func.py +1 -1
  347. mindspore/ops/function/sparse_unary_func.py +14 -12
  348. mindspore/ops/function/vmap_func.py +6 -5
  349. mindspore/ops/functional.py +15 -10
  350. mindspore/ops/op_info_register.py +244 -25
  351. mindspore/ops/operations/__init__.py +28 -19
  352. mindspore/ops/operations/_grad_ops.py +72 -7
  353. mindspore/ops/operations/_inner_ops.py +350 -17
  354. mindspore/ops/operations/_quant_ops.py +4 -8
  355. mindspore/ops/operations/_sequence_ops.py +42 -0
  356. mindspore/ops/operations/array_ops.py +68 -282
  357. mindspore/ops/operations/comm_ops.py +107 -59
  358. mindspore/ops/operations/custom_ops.py +94 -70
  359. mindspore/ops/operations/debug_ops.py +8 -4
  360. mindspore/ops/operations/image_ops.py +18 -12
  361. mindspore/ops/operations/inner_ops.py +26 -3
  362. mindspore/ops/operations/math_ops.py +189 -141
  363. mindspore/ops/operations/nn_ops.py +794 -489
  364. mindspore/ops/operations/other_ops.py +0 -22
  365. mindspore/ops/operations/random_ops.py +53 -111
  366. mindspore/ops/operations/sparse_ops.py +3 -1
  367. mindspore/ops/primitive.py +24 -18
  368. mindspore/parallel/_auto_parallel_context.py +68 -8
  369. mindspore/parallel/_cost_model_context.py +2 -2
  370. mindspore/parallel/_offload_context.py +17 -3
  371. mindspore/parallel/_parallel_serialization.py +12 -5
  372. mindspore/parallel/_ps_context.py +12 -0
  373. mindspore/parallel/_tensor.py +18 -13
  374. mindspore/parallel/_transformer/layers.py +5 -3
  375. mindspore/parallel/_transformer/loss.py +1 -0
  376. mindspore/parallel/_transformer/moe.py +2 -2
  377. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  378. mindspore/parallel/_transformer/transformer.py +23 -3
  379. mindspore/parallel/_utils.py +11 -7
  380. mindspore/parallel/algo_parameter_config.py +85 -5
  381. mindspore/parallel/checkpoint_transform.py +19 -12
  382. mindspore/parallel/shard.py +21 -14
  383. mindspore/profiler/common/struct_type.py +3 -3
  384. mindspore/profiler/common/util.py +4 -2
  385. mindspore/profiler/envprofiling.py +1 -1
  386. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  387. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  388. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  389. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  390. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  391. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  392. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  393. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  394. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  395. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  396. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  397. mindspore/profiler/parser/flops_parser.py +15 -11
  398. mindspore/profiler/parser/framework_parser.py +38 -22
  399. mindspore/profiler/parser/hccl_parser.py +16 -12
  400. mindspore/profiler/parser/integrator.py +22 -11
  401. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  402. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  403. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  404. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  405. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  406. mindspore/profiler/parser/optime_parser.py +1 -1
  407. mindspore/profiler/parser/profiler_info.py +21 -2
  408. mindspore/profiler/parser/step_trace_parser.py +11 -14
  409. mindspore/profiler/profiling.py +179 -89
  410. mindspore/rewrite/api/node.py +102 -19
  411. mindspore/rewrite/api/node_type.py +5 -1
  412. mindspore/rewrite/api/pattern_engine.py +1 -1
  413. mindspore/rewrite/api/scoped_value.py +9 -17
  414. mindspore/rewrite/api/symbol_tree.py +131 -47
  415. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  416. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  417. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  418. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  419. mindspore/rewrite/common/rewrite_elog.py +5 -1
  420. mindspore/rewrite/namer.py +33 -24
  421. mindspore/rewrite/namespace.py +14 -5
  422. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  423. mindspore/rewrite/node/call_function.py +79 -0
  424. mindspore/rewrite/node/cell_container.py +135 -0
  425. mindspore/rewrite/node/control_flow.py +88 -0
  426. mindspore/rewrite/{node.py → node/node.py} +273 -234
  427. mindspore/rewrite/node/node_manager.py +254 -0
  428. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  429. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  430. mindspore/rewrite/parsers/assign_parser.py +216 -221
  431. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  432. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  433. mindspore/rewrite/parsers/constant_parser.py +9 -6
  434. mindspore/rewrite/parsers/container_parser.py +9 -7
  435. mindspore/rewrite/parsers/for_parser.py +36 -15
  436. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  437. mindspore/rewrite/parsers/if_parser.py +28 -24
  438. mindspore/rewrite/parsers/module_parser.py +196 -25
  439. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  440. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  441. mindspore/rewrite/parsers/return_parser.py +6 -6
  442. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  443. mindspore/rewrite/sparsify/utils.py +1 -1
  444. mindspore/rewrite/symbol_tree.py +523 -578
  445. mindspore/rewrite/symbol_tree_builder.py +9 -193
  446. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  447. mindspore/run_check/_check_version.py +6 -4
  448. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  449. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  450. mindspore/scipy/linalg.py +1 -1
  451. mindspore/scipy/optimize/minimize.py +7 -3
  452. mindspore/train/_utils.py +7 -3
  453. mindspore/train/amp.py +323 -123
  454. mindspore/train/anf_ir_pb2.py +14 -2
  455. mindspore/train/callback/_backup_and_restore.py +2 -12
  456. mindspore/train/callback/_callback.py +29 -4
  457. mindspore/train/callback/_checkpoint.py +23 -8
  458. mindspore/train/callback/_early_stop.py +2 -2
  459. mindspore/train/callback/_landscape.py +4 -4
  460. mindspore/train/callback/_loss_monitor.py +2 -2
  461. mindspore/train/callback/_on_request_exit.py +2 -2
  462. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  463. mindspore/train/callback/_summary_collector.py +15 -8
  464. mindspore/train/callback/_time_monitor.py +58 -5
  465. mindspore/train/data_sink.py +5 -11
  466. mindspore/train/dataset_helper.py +84 -57
  467. mindspore/train/loss_scale_manager.py +2 -2
  468. mindspore/train/metrics/__init__.py +3 -3
  469. mindspore/train/metrics/cosine_similarity.py +1 -1
  470. mindspore/train/metrics/hausdorff_distance.py +3 -2
  471. mindspore/train/metrics/mean_surface_distance.py +3 -2
  472. mindspore/train/metrics/metric.py +39 -19
  473. mindspore/train/metrics/roc.py +2 -2
  474. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  475. mindspore/train/mind_ir_pb2.py +85 -36
  476. mindspore/train/model.py +187 -47
  477. mindspore/train/serialization.py +487 -161
  478. mindspore/train/summary/_summary_adapter.py +1 -1
  479. mindspore/train/summary/_writer_pool.py +3 -2
  480. mindspore/train/summary/summary_record.py +37 -17
  481. mindspore/train/train_thor/convert_utils.py +3 -3
  482. mindspore/train/train_thor/dataset_helper.py +1 -1
  483. mindspore/version.py +1 -1
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
  486. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  487. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  489. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  490. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  491. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  492. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  493. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  494. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  495. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  497. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  498. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  499. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  500. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  501. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  502. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  503. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  504. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  505. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  506. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  507. mindspore/_extends/graph_kernel/expander.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  509. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  510. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  511. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  512. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  513. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  514. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  515. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  516. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  517. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  518. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  519. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  520. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  521. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  522. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  523. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  524. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  525. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  526. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  527. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  528. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  529. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  530. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  531. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  532. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  534. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  535. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  536. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  537. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  538. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  539. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  540. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  543. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  544. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  545. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  546. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  547. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  548. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  549. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  550. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  551. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  552. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  553. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  554. mindspore/dataset/datapreprocess/__init__.py +0 -20
  555. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  556. mindspore/include/api/net.h +0 -142
  557. mindspore/nn/lr_scheduler.py +0 -262
  558. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  559. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  560. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  561. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  562. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  563. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  567. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  578. mindspore/rewrite/node_visitor.py +0 -44
  579. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  580. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
--- a/mindspore/nn/layer/dense.py
+++ b/mindspore/nn/layer/dense.py
@@ -77,6 +77,7 @@ class BiDense(Cell):
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter.
             The values of str refer to the function `initializer`. Default: ``None`` .
         has_bias (bool): Specifies whether the layer uses :math:`\text{bias}` vector. Default: ``True`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Shape:
         - **input1** - :math:`(*, H_{in1})` where :math:`H_{in1}=\text{in1_channels}` and
@@ -90,8 +91,8 @@ class BiDense(Cell):
           are the same shape as the inputs.
 
     Dtype:
-        - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2**.
-        - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input1**.
+        - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2** .
+        - **input2** (Tensor) - The dtype must be float16 or float32 and be same as **input1** .
         - **output** (Tensor) - With the same dtype as the inputs.
 
     Weights:
@@ -133,7 +134,8 @@ class BiDense(Cell):
                  out_channels,
                  weight_init=None,
                  bias_init=None,
-                 has_bias=True):
+                 has_bias=True,
+                 dtype=mstype.float32):
         super().__init__()
         self.in_channels = Validator.check_positive_int(in1_channels, "in1_channels", self.cls_name)
         self.in_channels = Validator.check_positive_int(in2_channels, "in2_channels", self.cls_name)
@@ -156,7 +158,8 @@ class BiDense(Cell):
                              f"equal to 'in2_channels'. But got 'weight_init': {weight_init}, "
                              f"'out_channels': {out_channels}, 'in_channels': {in1_channels}, "
                              f"'in2_channels': {in2_channels}")
-        self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels)), 'weight')
+        self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels), dtype=dtype),
+                                'weight')
 
         if self.has_bias:
             if bias_init is None:
@@ -166,7 +169,7 @@ class BiDense(Cell):
                 raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
                                  f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
                                  f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
-            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
+            self.bias = Parameter(initializer(bias_init, [out_channels], dtype=dtype), name="bias")
             self.bias_add = P.BiasAdd()
         self.matmul = P.MatMul()
 
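The common thread in the dense.py hunks above is a new `dtype` keyword on BiDense that is forwarded to the weight and bias initializers. A minimal sketch of the resulting 2.2 API, with shapes chosen purely for illustration:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    # BiDense in 2.2.10 forwards dtype to initializer(), so the weight and
    # bias Parameters are created directly in the requested precision.
    layer = nn.BiDense(in1_channels=20, in2_channels=30, out_channels=40, dtype=ms.float16)
    x1 = Tensor(np.random.randn(8, 20), ms.float16)
    x2 = Tensor(np.random.randn(8, 30), ms.float16)
    out = layer(x1, x2)  # shape (8, 40), dtype float16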
--- a/mindspore/nn/layer/embedding.py
+++ b/mindspore/nn/layer/embedding.py
@@ -64,11 +64,13 @@ class Embedding(Cell):
         embedding_size (int): The size of each embedding vector.
         use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: ``False`` .
         embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
-            Refer to class `initializer` for the values of string when a string
-            is specified. Default: ``'normal'`` .
+            Refer to class `mindspore.common.initializer
+            <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
+            for the values of string when a string is specified. Default: ``'normal'`` .
         dtype (:class:`mindspore.dtype`): Data type of `x`. Default: ``mstype.float32`` .
         padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
             will be initialized to zero. Default: ``None`` . The feature is inactivated.
+
     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{x_length})`. The elements of
           the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
@@ -145,9 +147,8 @@ class Embedding(Cell):
         return output
 
     def extend_repr(self):
-        s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
-            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
-        return s
+        return f'vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, ' \
+               f'embedding_table={self.embedding_table}, dtype={self.dtype}, padding_idx={self.padding_idx}'
 
 
 @_primexpr
@@ -190,6 +191,7 @@ class EmbeddingLookup(Cell):
             parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
             optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
             memory, so suggests setting a reasonable value to avoid insufficient memory.
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@@ -231,7 +233,7 @@ class EmbeddingLookup(Cell):
 
     def __init__(self, vocab_size, embedding_size, param_init='normal',
                  target='CPU', slice_mode='batch_slice', manual_shapes=None,
-                 max_norm=None, sparse=True, vocab_cache_size=0):
+                 max_norm=None, sparse=True, vocab_cache_size=0, dtype=mstype.float32):
         """Initialize EmbeddingLookup."""
         super(EmbeddingLookup, self).__init__()
         Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
@@ -255,8 +257,8 @@ class EmbeddingLookup(Cell):
         if enable_ps:
             self._process_vocab_cache(slice_mode)
         self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
-        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
-                                         name='embedding_table')
+        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
+                                                     dtype=dtype), name='embedding_table')
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.gather_revert = P.Gather()
@@ -267,7 +269,7 @@ class EmbeddingLookup(Cell):
         if is_auto_parallel:
             self.unique = P.Unique().shard(((1,),))
         if self.cache_enable and enable_ps:
-            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init)
+            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init, dtype=dtype)
             if is_auto_parallel:
                 self.unique.add_prim_attr('cache_enable', True)
         indices_shape_size = 2
@@ -310,8 +312,8 @@ class EmbeddingLookup(Cell):
         else:
             if is_auto_parallel:
                 support_mode = ["field_slice", "table_row_slice", "table_column_slice", "batch_slice"]
-                raise ValueError("For '{}', the 'slice_mode' must be in {}, "
-                                 "but got \"{}\".".format(self.cls_name, support_mode, slice_mode))
+                raise ValueError(f"For '{self.cls_name}', the 'slice_mode' must be in {support_mode}, "
+                                 f"but got \"{slice_mode}\".")
         if self.cache_enable and not enable_ps:
             raise ValueError(f"For '{self.cls_name}', haven't supported cache enable for not ps mode.")
         self.embedding_table.unique = self.forward_unique
@@ -354,7 +356,8 @@ class EmbeddingLookup(Cell):
             if _is_role_worker():
                 self.vocab_size = self.vocab_cache_size
 
-    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init):
+    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init,
+                                       dtype=mstype.float32):
         """PS embeddingLookup cache enable set."""
         if self.sparse:
             self.forward_unique = True
@@ -368,10 +371,10 @@ class EmbeddingLookup(Cell):
         if _enable_distributed_mindrt():
             self.rank_id = get_rank()
             if self.is_ps_server:
-                self._slice_pserver_embeddings("zeros")
+                self._slice_pserver_embeddings("zeros", dtype=dtype)
                 self._set_cache_enable_and_key_for_pserver(param_key)
 
-    def _slice_pserver_embeddings(self, param_init):
+    def _slice_pserver_embeddings(self, param_init, dtype=mstype.float32):
         '''
         Method to slice embedding tables on Parameter Servers.
         It helps to train with a large scale embedding table and is used only in Parameter Server training mode.
@@ -399,7 +402,7 @@ class EmbeddingLookup(Cell):
         for i in range(server_num):
             self.embedding_table_list.append(Parameter(initializer(param_init,
                                                                    [self.embedding_table_vocab_dim_list[i],
-                                                                    self.embedding_size]),
+                                                                    self.embedding_size], dtype=dtype),
                                                        name="embedding_table_server_" + str(i)))
 
             self.embedding_offset.append(offset)
@@ -505,12 +508,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
             :class:`mindspore.nn.EmbeddingLookup`. Default: ``'batch_slice'``.
         feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
             Default: ``None`` .
-        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
-            or None. Default: ``None`` .
+        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32.
+            Default: ``None`` .
         sparse (bool): Using sparse mode. When 'target' is set to ``'CPU'`` , 'sparse' has to be true.
             Default: ``True`` .
         operator (str): The pooling method for the features in one field. Support ``'SUM'`` , ``'MEAN'`` and
             ``'MAX'`` . Default: ``'SUM'`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
@@ -529,12 +533,12 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
         TypeError: If `vocab_size` or `embedding_size` or `field_size` is not an int.
         TypeError: If `sparse` is not a bool or `feature_num_list` is not a tuple.
         ValueError: If `vocab_size` or `embedding_size` or `field_size` is less than 1.
-        ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
-        ValueError: If `slice_mode` is not one of 'batch_slice', 'field_slice', 'table_row_slice',
-            'table_column_slice'.
-        ValueError: If `sparse` is False and `target` is 'CPU'.
-        ValueError: If `slice_mode` is 'field_slice' and `feature_num_list` is None.
-        ValueError: If `operator` is not one of 'SUM', 'MAX', 'MEAN'.
+        ValueError: If `target` is neither ``'CPU'`` nor ``'DEVICE'``.
+        ValueError: If `slice_mode` is not one of ``'batch_slice'``, ``'field_slice'``, ``'table_row_slice'``,
+            ``'table_column_slice'`` .
+        ValueError: If `sparse` is False and `target` is ``'CPU'`` .
+        ValueError: If `slice_mode` is ``'field_slice'`` and `feature_num_list` is None.
+        ValueError: If `operator` is not one of ``'SUM'``, ``'MAX'``, ``'MEAN'`` .
 
     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -555,10 +559,11 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
     OPERATOR_MAX = 'MAX'
 
     def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
-                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
+                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM',
+                 dtype=mstype.float32):
         """Initialize MultiFieldEmbeddingLookup."""
         super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
-                                                        slice_mode, feature_num_list, max_norm, sparse)
+                                                        slice_mode, feature_num_list, max_norm, sparse, dtype=dtype)
         self.field_size = Validator.check_positive_int(field_size, 'field_size', self.cls_name)
         self.operator = operator
 
@@ -622,8 +627,9 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
                 self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
             else:
                 if is_auto_parallel:
-                    raise ValueError("For '{}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice' and \
-                        'table_column_slice'], but got {}".format(self.cls_name, str(slice_mode)))
+                    raise ValueError(
+                        f"For '{self.cls_name}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice' "
+                        f"and 'table_column_slice'], but got {str(slice_mode)}.")
 
             # Min value for fp32
             self.negative_inf_value = -3.402823466E+38
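All of the embedding.py hunks carry the same dtype plumbing: the new constructor argument is threaded through initializer() for the main table, the parameter-server slices, and the MultiFieldEmbeddingLookup super() call. A short sketch of the resulting API (vocabulary size and shapes are illustrative):

    import mindspore as ms
    from mindspore import Tensor, nn

    # dtype reaches the embedding_table initializer directly; the same keyword
    # flows into MultiFieldEmbeddingLookup through its super().__init__ call.
    lookup = nn.EmbeddingLookup(vocab_size=1000, embedding_size=64,
                                param_init='normal', target='CPU', dtype=ms.float16)
    indices = Tensor([[1, 5], [7, 9]], ms.int32)
    vectors = lookup(indices)  # shape (2, 2, 64), dtype float16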
--- a/mindspore/nn/layer/flash_attention.py
+++ b/mindspore/nn/layer/flash_attention.py
@@ -17,12 +17,13 @@ A FlashAttention Layer.
 """
 import math
 
-import mindspore.numpy as mnp
-from mindspore import ops
-from mindspore.common import dtype as mstype
+import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
+from mindspore import ops
 from mindspore.nn.cell import Cell
 from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention
+from mindspore.ops.operations.nn_ops import FlashAttentionScore
+from mindspore._c_expression import MSContext
 
 __all__ = ['FlashAttention']
 
@@ -56,14 +57,15 @@ class FlashAttention(Cell):
             Default True
         alibi(bool): This parameter indicates whether the flashattention supports the Alibi.
             Default: False
+        use_mqa(bool): Using MHA if True, only take effect under 910B. Default: False.
 
 
     Inputs:
         - **query** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
         - **key** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
         - **value** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
-        - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` [batch_size, seq_length,
-          seq_length]): A matrix to pass masked information.
+        - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` `mstype.uint8`
+          [batch_size, seq_length, seq_length]): A matrix to pass masked information.
 
     Outputs:
         A Tensor. The output of the attention with shape [batch_size, head_num, seq_length, head_dim]
@@ -92,6 +94,7 @@ class FlashAttention(Cell):
 
     def __init__(self,
                  head_dim,
+                 head_num,
                  dropout_rate=0.0,
                  prev_block_num=65536,
                  next_block_num=65536,
@@ -100,27 +103,63 @@ class FlashAttention(Cell):
                  mp=1,
                  high_precision=False,
                  have_attention_mask_batch=True,
-                 alibi=False
+                 alibi=False,
+                 use_mqa=False
                  ):
         super(FlashAttention, self).__init__()
 
-        self.flash_attention = get_flash_attention(
-            prev_block_num=prev_block_num,
-            next_block_num=next_block_num,
-            tiling_stgy_name=tiling_stgy_name,
-            high_precision=high_precision
-        )
-        self.flash_attention.add_prim_attr("primitive_target", "Ascend")
         scaling_constant = math.sqrt(head_dim)
-        if scaling_constant != 0:
-            self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
-        else:
+        if scaling_constant == 0:
             raise ValueError("the scaling constant must not be 0.")
-        self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
-        self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
         self.dropout_rate = dropout_rate
-        self.have_attention_mask_batch = have_attention_mask_batch
-        self.alibi = alibi
+        self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "ascend910"
+        if self.is_910A:
+            self.scale_factor = Tensor([1. / math.sqrt(scaling_constant)], dtype=mstype.float16)
+            self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
+            self.ones = ops.Ones()
+            self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
+            self.have_attention_mask_batch = have_attention_mask_batch
+            self.alibi = alibi
+            self.flash_attention = get_flash_attention(
+                prev_block_num=prev_block_num,
+                next_block_num=next_block_num,
+                tiling_stgy_name=tiling_stgy_name,
+                high_precision=high_precision
+            )
+            self.flash_attention.add_prim_attr("primitive_target", "Ascend")
+            fa_strategies = ((dp, mp, 1, 1),
+                             (dp, mp, 1, 1),
+                             (dp, mp, 1, 1))
+            self.shard(fa_strategies)
+        else:
+            if alibi:
+                raise ValueError(f"When soc_version is not Ascend910A, alibi must be False")
+            self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
+            self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
+            self.reshape = ops.Reshape()
+            self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
+            self.zeros = ops.Zeros()
+            self.attn_cast = ops.Cast()
+            if use_mqa:
+                fa_strategies = ((dp, mp, 1, 1),
+                                 (dp, 1, 1, 1),
+                                 (dp, 1, 1, 1),
+                                 (dp, 1, 1, 1))
+            else:
+                fa_strategies = ((dp, mp, 1, 1),
+                                 (dp, mp, 1, 1),
+                                 (dp, mp, 1, 1),
+                                 (dp, 1, 1, 1))
+            if dropout_rate > 1e-5:
+                fa_strategies += ((dp, mp, 1, 1),)
+            self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
+                                                       next_tokens=next_block_num,
+                                                       keep_prob=1 - dropout_rate,
+                                                       scale_value=1. / scaling_constant,
+                                                       inner_precise=0 if high_precision else 1,
+                                                       input_layout="BNSD").shard(fa_strategies)
+
+        self.dropout_rate = dropout_rate
         if self.dropout_rate > 1e-5:
             self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
             self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))
@@ -136,49 +175,49 @@ class FlashAttention(Cell):
             such as MatMul. Default: None.
         :return:
         """
-        if in_strategy is not None:
-            shard_stgy = list(in_strategy)
-            shard_stgy.insert(3, (1,))  # dim_mask
-            shard_stgy = tuple(shard_stgy)
-        else:
-            # default: dp=1, mp=1, construct inputs only contain query, key, value
-            shard_stgy = (
-                (1, 1, 1, 1),
-                (1, 1, 1, 1),
-                (1, 1, 1, 1),
-                (1,),  # dim_mask
-            )
-        self.flash_attention.shard(shard_stgy)
-        dp = shard_stgy[0][0]
-        mp = shard_stgy[0][1]
-        self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
-        inputs_tensor_map = [
-            [3, 2, 1, 0],
-            [3, 2, 1, 0],
-            [3, 2, 1, 0],
-            [-1]
-        ]
-        if self.have_attention_mask_batch:
-            inputs_tensor_map.append([3, 1, 0])
-        else:
-            inputs_tensor_map.append([-1, 1, 0])
+        if self.is_910A:
+            if in_strategy is None:
+                # default: dp=1, mp=1, construct inputs only contain query, key, value
+                in_strategy = (
+                    (1, 1, 1, 1),
+                    (1, 1, 1, 1),
+                    (1, 1, 1, 1),
+                )
+            self.flash_attention.shard(in_strategy)
+            dp = in_strategy[0][0]
+            mp = in_strategy[0][1]
+            self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
+            inputs_tensor_map = [
+                [3, 2, 1, 0],
+                [3, 2, 1, 0],
+                [3, 2, 1, 0],
+            ]
+            if self.have_attention_mask_batch:
+                inputs_tensor_map.append([3, 1, 0])
+            else:
+                inputs_tensor_map.append([-1, 1, 0])
 
-        # dropout_mask
-        if self.dropout_rate > 1e-5:
-            inputs_tensor_map.append([3, 2, 1, 0])
+            input_empty_args_num = 2
+            # dropout_mask
+            if self.dropout_rate > 1e-5:
+                input_empty_args_num -= 1
+                inputs_tensor_map.append([3, 2, 1, 0])
 
-        if self.alibi:
-            inputs_tensor_map.append([3, 2, 1, 0])
+            if self.alibi:
+                input_empty_args_num -= 1
+                inputs_tensor_map.append([3, 2, 1, 0])
 
-        self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
+            self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
 
-        self.flash_attention.add_prim_attr("outputs_tensor_map", [
-            [3, 2, 1, 0],  # O
-            [3, 2, 1],  # L
-            [3, 2, 1]  # M
-        ])
-        self.flash_attention.add_prim_attr("as_loss_divisor", 0)
-        self.flash_attention.add_prim_attr("empty_mirror_ops", 1)
+            self.flash_attention.add_prim_attr("outputs_tensor_map", [
+                [3, 2, 1, 0],  # O
+                [3, 2, 1],  # L
+                [3, 2, 1]  # M
+            ])
+            self.flash_attention.add_prim_attr("as_loss_divisor", 0)
+            self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)
+        else:
+            self.flash_attention.shard(in_strategy)
 
     def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
         """FlashAttention forward
@@ -189,35 +228,49 @@ class FlashAttention(Cell):
         :param alibi_mask: [bsz, head_num, 1, seq_len], if not None
         :return: output [bsz, head_num, seq_len, head_dim]
         """
-        query = self.scale_mul(query, self.scale_factor)
         bsz, head_num, seq_len, head_dim = query.shape
-        _, k_head_num, k_seq_len, _ = key.shape
-        _, v_head_num, v_seq_len, _ = value.shape
-        if head_num != k_head_num or head_num != v_head_num:
-            raise ValueError(
-                "the head_num of query, key and value must be the same, "
-                "If different head_num are used, users need to change themselves to be same by tile.")
-        if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
-            raise ValueError(
-                "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
-        if self.dropout_rate > 1e-5:
-            drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
-            tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
-            ones = self.fill_v2(tensor_shape, self.tensor_one)
-            ones = self.depend(ones, query)
-            drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
-        else:
-            drop_mask = None
-        if head_dim > 304:
-            raise ValueError(
-                "the head_dim must be less than 304, otherwise the ub would be OOM.")
-        if head_dim % 16 != 0:
-            padding_size = 16 - head_dim % 16
-            query = mnp.pad(query, ((0, 0), (0, 0), (0, 0), (0, padding_size)), constant_values=0)
-            key = mnp.pad(key, ((0, 0), (0, 0), (0, 0), (0, padding_size)), constant_values=0)
-            value = mnp.pad(value, ((0, 0), (0, 0), (0, 0), (0, padding_size)), constant_values=0)
-            output, _, _ = self.flash_attention(query, key, value, self.dim_mask, attn_mask, drop_mask, alibi_mask)
-            output = ops.slice(output, [0, 0, 0, 0], [bsz, head_num, seq_len, head_dim])
+        if self.is_910A:
+            _, k_head_num, k_seq_len, _ = key.shape
+            _, v_head_num, v_seq_len, _ = value.shape
+            if head_num != k_head_num or head_num != v_head_num:
+                raise ValueError(
+                    "the head_num of query, key and value must be the same, "
+                    "If different head_num are used, users need to change themselves to be same by tile.")
+            if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
+                raise ValueError(
+                    "query, key, value seq_len must be a multiple of 16, "
+                    "and the seq_len between key and value must be equal.")
+            # 910A -- FlashAttentionPrimtive
+            if head_dim > 304:
+                raise ValueError(
+                    "the head_dim must be less than 304, otherwise the ub would be OOM.")
+            if self.dropout_rate > 1e-5:
+                drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
+                tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
+                ones = self.fill_v2(tensor_shape, self.tensor_one)
+                ones = self.depend(ones, query)
+                drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
+            else:
+                drop_mask = None
+            query = self.scale_mul(query, self.scale_factor)
+            key = self.scale_mul(key, self.scale_factor)
+            attn_mask = self.cast(attn_mask, mstype.float16)
+            output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
         else:
-            output, _, _ = self.flash_attention(query, key, value, self.dim_mask, attn_mask, drop_mask, alibi_mask)
+            # 910B -- FlashAttentionScore
+            if self.dropout_rate > 1e-5:
+                drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
+                                              (bsz, head_num, seq_len, seq_len // 8))
+            else:
+                drop_mask_bits = None
+            # (B, S, S) -> (B, 1, S, S)
+            attn_mask = self.cast(self.reshape(attn_mask, (bsz, 1, seq_len, seq_len)), mstype.uint8)
+            output, _, _ = self.flash_attention(query,
+                                                key,
+                                                value,
+                                                attn_mask,
+                                                drop_mask_bits,
+                                                None,
+                                                None,
+                                                None)
         return output
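The net effect of the flash_attention.py rewrite is a runtime split on the Ascend SoC version: on 910A the layer keeps the TIK-based custom kernel from get_flash_attention, while on other SoCs (910B) it delegates to the new FlashAttentionScore primitive, and head_num becomes a required constructor argument. A minimal usage sketch (Ascend-only; shapes are illustrative, and the import uses the internal module path shown in the hunk above):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor
    from mindspore.nn.layer.flash_attention import FlashAttention

    # head_num is new and mandatory in 2.2.10; use_mqa only affects the
    # shard strategy on the FlashAttentionScore (910B) path.
    fa = FlashAttention(head_dim=128, head_num=16, dropout_rate=0.0)
    b, n, s, d = 2, 16, 1024, 128
    q = Tensor(np.random.randn(b, n, s, d), ms.float16)
    k = Tensor(np.random.randn(b, n, s, d), ms.float16)
    v = Tensor(np.random.randn(b, n, s, d), ms.float16)
    mask = Tensor(np.zeros((b, s, s)), ms.float16)  # cast to fp16/uint8 internally per branch
    out = fa(q, k, v, mask)  # [b, n, s, d]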
--- a/mindspore/nn/layer/image.py
+++ b/mindspore/nn/layer/image.py
@@ -83,17 +83,17 @@ class ImageGradients(Cell):
         _check_input_4d(F.shape(images), "images", self.cls_name)
         batch_size, depth, height, width = P.Shape()(images)
         if height == 1:
-            dy = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
+            dy = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
         else:
             dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
-            dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
+            dy_last = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
             dy = P.Concat(2)((dy, dy_last))
 
         if width == 1:
-            dx = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
+            dx = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
         else:
             dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
-            dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
+            dx_last = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
             dx = P.Concat(3)((dx, dx_last))
         return dy, dx
 
@@ -571,7 +571,8 @@ class PixelShuffle(Cell):
     <https://arxiv.org/abs/1609.05158>`_ .
 
     Typically, the input is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
-    :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor and * is zero or more batch dimensions.
+    :math:`(*, C, H \times r, W \times r)`,
+    where :math:`r` is an upscale factor and :math:`*` is zero or more batch dimensions.
 
     Note:
         The dimension of input Tensor on Ascend should be less than 7.
@@ -621,7 +622,8 @@ class PixelUnshuffle(Cell):
     <https://arxiv.org/abs/1609.05158>`_ .
 
     Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape
-    :math:`(*, C \times r^2, H, W)` , where r is a downscale factor and * is zero or more batch dimensions.
+    :math:`(*, C \times r^2, H, W)` ,
+    where :math:`r` is a downscale factor and :math:`*` is zero or more batch dimensions.
 
     Args:
         downscale_factor (int): factor to unshuffle the input, and is a positive integer.
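The ImageGradients hunk is a mechanical migration from the deprecated P.Fill primitive to the functional F.fill (mindspore.ops.fill); the new mindspore/ops/deprecated.py in the file list above belongs to the same cleanup. The replacement pattern, sketched outside the diff for clarity:

    import mindspore as ms
    from mindspore import ops

    # 2.1 style: zeros = ops.Fill()(ms.float32, (2, 3), 0)   # primitive call, now deprecated
    # 2.2 style: the functional form takes the same (type, shape, value) triple.
    zeros = ops.fill(type=ms.float32, shape=(2, 3), value=0)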