mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (884)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  201. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  202. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  203. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  204. mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
  205. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  206. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  207. mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
  208. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  209. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  210. mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
  211. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  212. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  213. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  214. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  215. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  216. mindspore/log.py +9 -6
  217. mindspore/mindrecord/filereader.py +33 -4
  218. mindspore/mindrecord/filewriter.py +70 -35
  219. mindspore/mindrecord/mindpage.py +40 -34
  220. mindspore/mindrecord/shardreader.py +1 -1
  221. mindspore/mindrecord/shardsegment.py +1 -1
  222. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  223. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  224. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  225. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  226. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  227. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  228. mindspore/nn/cell.py +463 -169
  229. mindspore/nn/dynamic_lr.py +47 -43
  230. mindspore/nn/layer/activation.py +225 -82
  231. mindspore/nn/layer/basic.py +121 -79
  232. mindspore/nn/layer/channel_shuffle.py +21 -21
  233. mindspore/nn/layer/combined.py +33 -26
  234. mindspore/nn/layer/container.py +277 -22
  235. mindspore/nn/layer/conv.py +441 -304
  236. mindspore/nn/layer/dense.py +19 -13
  237. mindspore/nn/layer/embedding.py +62 -49
  238. mindspore/nn/layer/flash_attention.py +264 -0
  239. mindspore/nn/layer/image.py +50 -39
  240. mindspore/nn/layer/math.py +62 -51
  241. mindspore/nn/layer/normalization.py +219 -167
  242. mindspore/nn/layer/padding.py +58 -70
  243. mindspore/nn/layer/pooling.py +334 -287
  244. mindspore/nn/layer/rnn_cells.py +53 -38
  245. mindspore/nn/layer/rnns.py +59 -56
  246. mindspore/nn/layer/thor_layer.py +52 -44
  247. mindspore/nn/layer/timedistributed.py +6 -4
  248. mindspore/nn/layer/transformer.py +284 -164
  249. mindspore/nn/learning_rate_schedule.py +34 -25
  250. mindspore/nn/loss/__init__.py +3 -2
  251. mindspore/nn/loss/loss.py +554 -311
  252. mindspore/nn/optim/ada_grad.py +12 -9
  253. mindspore/nn/optim/adadelta.py +14 -11
  254. mindspore/nn/optim/adafactor.py +19 -16
  255. mindspore/nn/optim/adam.py +62 -47
  256. mindspore/nn/optim/adamax.py +13 -10
  257. mindspore/nn/optim/adasum.py +12 -8
  258. mindspore/nn/optim/asgd.py +10 -9
  259. mindspore/nn/optim/ftrl.py +20 -17
  260. mindspore/nn/optim/lamb.py +16 -12
  261. mindspore/nn/optim/lars.py +8 -6
  262. mindspore/nn/optim/lazyadam.py +25 -20
  263. mindspore/nn/optim/momentum.py +10 -7
  264. mindspore/nn/optim/optimizer.py +61 -9
  265. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  266. mindspore/nn/optim/rmsprop.py +17 -13
  267. mindspore/nn/optim/rprop.py +30 -17
  268. mindspore/nn/optim/sgd.py +40 -23
  269. mindspore/nn/optim/thor.py +24 -26
  270. mindspore/nn/probability/bijector/bijector.py +11 -11
  271. mindspore/nn/probability/bijector/exp.py +1 -1
  272. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  273. mindspore/nn/probability/bijector/invert.py +1 -1
  274. mindspore/nn/probability/bijector/power_transform.py +29 -29
  275. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  276. mindspore/nn/probability/bijector/softplus.py +5 -5
  277. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  278. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  279. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  280. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  281. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  282. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  283. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  284. mindspore/nn/probability/distribution/beta.py +8 -8
  285. mindspore/nn/probability/distribution/categorical.py +23 -15
  286. mindspore/nn/probability/distribution/cauchy.py +5 -6
  287. mindspore/nn/probability/distribution/distribution.py +3 -3
  288. mindspore/nn/probability/distribution/exponential.py +4 -4
  289. mindspore/nn/probability/distribution/gamma.py +10 -10
  290. mindspore/nn/probability/distribution/geometric.py +8 -8
  291. mindspore/nn/probability/distribution/gumbel.py +8 -9
  292. mindspore/nn/probability/distribution/half_normal.py +5 -5
  293. mindspore/nn/probability/distribution/laplace.py +5 -5
  294. mindspore/nn/probability/distribution/log_normal.py +12 -11
  295. mindspore/nn/probability/distribution/logistic.py +8 -8
  296. mindspore/nn/probability/distribution/normal.py +6 -5
  297. mindspore/nn/probability/distribution/poisson.py +10 -11
  298. mindspore/nn/probability/distribution/student_t.py +8 -9
  299. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  300. mindspore/nn/probability/distribution/uniform.py +11 -11
  301. mindspore/nn/reinforcement/tensor_array.py +2 -2
  302. mindspore/nn/sparse/sparse.py +9 -9
  303. mindspore/nn/wrap/cell_wrapper.py +188 -63
  304. mindspore/nn/wrap/grad_reducer.py +21 -12
  305. mindspore/nn/wrap/loss_scale.py +136 -49
  306. mindspore/numpy/__init__.py +4 -4
  307. mindspore/numpy/array_creations.py +55 -56
  308. mindspore/numpy/array_ops.py +134 -35
  309. mindspore/numpy/logic_ops.py +66 -20
  310. mindspore/numpy/math_ops.py +142 -139
  311. mindspore/numpy/utils_const.py +2 -2
  312. mindspore/offline_debug/convert_async.py +2 -2
  313. mindspore/ops/_grad_experimental/__init__.py +7 -5
  314. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  315. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  316. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  317. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  318. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  319. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  320. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  321. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  322. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  323. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  324. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  325. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  326. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  327. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  328. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  329. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  330. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  331. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  332. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  333. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  334. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  335. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  336. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  337. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  338. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  339. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  340. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  341. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  342. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  343. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  344. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  345. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  346. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  347. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  348. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  349. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  350. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  351. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  352. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  353. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  354. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  355. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  356. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  357. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  358. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  359. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  360. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  361. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  362. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  363. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  364. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  365. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  366. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  367. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  368. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  369. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  370. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  371. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  372. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  373. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  374. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  375. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  376. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  377. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  378. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  379. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  380. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  381. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  382. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  383. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  384. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  385. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  386. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  387. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  388. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  389. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  390. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  391. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  392. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  393. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  394. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  395. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  396. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  397. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  398. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  399. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  400. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  401. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  402. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  403. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  404. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  405. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  406. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  407. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  408. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  409. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  410. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  411. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  412. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  413. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  414. mindspore/ops/_primitive_cache.py +1 -1
  415. mindspore/ops/_tracefunc.py +241 -0
  416. mindspore/ops/_utils/utils.py +10 -2
  417. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  418. mindspore/ops/_vmap/vmap_base.py +5 -4
  419. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  420. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  421. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  422. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  423. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  424. mindspore/ops/arg_dtype_cast.py +54 -0
  425. mindspore/ops/composite/__init__.py +7 -5
  426. mindspore/ops/composite/base.py +78 -34
  427. mindspore/ops/composite/math_ops.py +5 -695
  428. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  429. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  430. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  431. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  432. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  433. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  434. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  436. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  437. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  438. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  440. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  441. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  442. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  443. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  444. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  445. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  446. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  447. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  448. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  449. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  450. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  451. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  452. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  453. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  454. mindspore/ops/deprecated.py +304 -0
  455. mindspore/ops/function/__init__.py +41 -4
  456. mindspore/ops/function/array_func.py +1108 -467
  457. mindspore/ops/function/clip_func.py +94 -27
  458. mindspore/ops/function/debug_func.py +3 -1
  459. mindspore/ops/function/grad/grad_func.py +82 -73
  460. mindspore/ops/function/image_func.py +28 -12
  461. mindspore/ops/function/linalg_func.py +135 -39
  462. mindspore/ops/function/math_func.py +3779 -894
  463. mindspore/ops/function/nn_func.py +1584 -657
  464. mindspore/ops/function/parameter_func.py +13 -3
  465. mindspore/ops/function/random_func.py +247 -153
  466. mindspore/ops/function/sparse_func.py +14 -11
  467. mindspore/ops/function/sparse_unary_func.py +173 -47
  468. mindspore/ops/function/spectral_func.py +8 -4
  469. mindspore/ops/function/vmap_func.py +8 -7
  470. mindspore/ops/functional.py +47 -16
  471. mindspore/ops/op_info_register.py +346 -86
  472. mindspore/ops/operations/__init__.py +38 -22
  473. mindspore/ops/operations/_grad_ops.py +145 -149
  474. mindspore/ops/operations/_inner_ops.py +298 -56
  475. mindspore/ops/operations/_ms_kernel.py +3 -3
  476. mindspore/ops/operations/_quant_ops.py +24 -28
  477. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  478. mindspore/ops/operations/_scalar_ops.py +115 -0
  479. mindspore/ops/operations/_sequence_ops.py +148 -10
  480. mindspore/ops/operations/_tensor_array.py +1 -1
  481. mindspore/ops/operations/_thor_ops.py +2 -2
  482. mindspore/ops/operations/array_ops.py +1239 -561
  483. mindspore/ops/operations/comm_ops.py +166 -90
  484. mindspore/ops/operations/control_ops.py +3 -3
  485. mindspore/ops/operations/custom_ops.py +124 -102
  486. mindspore/ops/operations/debug_ops.py +24 -11
  487. mindspore/ops/operations/image_ops.py +86 -71
  488. mindspore/ops/operations/inner_ops.py +18 -13
  489. mindspore/ops/operations/linalg_ops.py +30 -11
  490. mindspore/ops/operations/math_ops.py +1730 -435
  491. mindspore/ops/operations/nn_ops.py +1953 -943
  492. mindspore/ops/operations/other_ops.py +65 -43
  493. mindspore/ops/operations/random_ops.py +258 -98
  494. mindspore/ops/operations/rl_ops.py +4 -36
  495. mindspore/ops/operations/sparse_ops.py +38 -33
  496. mindspore/ops/operations/spectral_ops.py +8 -4
  497. mindspore/ops/primitive.py +66 -44
  498. mindspore/ops/signature.py +5 -5
  499. mindspore/parallel/_auto_parallel_context.py +80 -19
  500. mindspore/parallel/_cost_model_context.py +42 -0
  501. mindspore/parallel/_offload_context.py +162 -72
  502. mindspore/parallel/_parallel_serialization.py +2 -2
  503. mindspore/parallel/_ps_context.py +16 -4
  504. mindspore/parallel/_recovery_context.py +2 -1
  505. mindspore/parallel/_tensor.py +15 -13
  506. mindspore/parallel/_transformer/layers.py +8 -6
  507. mindspore/parallel/_transformer/loss.py +1 -0
  508. mindspore/parallel/_transformer/moe.py +7 -7
  509. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  510. mindspore/parallel/_transformer/transformer.py +34 -14
  511. mindspore/parallel/_utils.py +36 -14
  512. mindspore/parallel/algo_parameter_config.py +114 -20
  513. mindspore/parallel/checkpoint_transform.py +16 -18
  514. mindspore/parallel/shard.py +16 -13
  515. mindspore/profiler/__init__.py +1 -1
  516. mindspore/profiler/common/struct_type.py +3 -3
  517. mindspore/profiler/common/util.py +3 -2
  518. mindspore/profiler/envprofiling.py +11 -4
  519. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  520. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  521. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  522. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  523. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  524. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  525. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  526. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  527. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  528. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  529. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  530. mindspore/profiler/parser/flops_parser.py +15 -11
  531. mindspore/profiler/parser/framework_parser.py +92 -73
  532. mindspore/profiler/parser/hccl_parser.py +16 -12
  533. mindspore/profiler/parser/integrator.py +22 -11
  534. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  535. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  536. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  537. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  538. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  539. mindspore/profiler/parser/optime_parser.py +1 -1
  540. mindspore/profiler/parser/profiler_info.py +4 -5
  541. mindspore/profiler/parser/step_trace_parser.py +11 -14
  542. mindspore/profiler/profiling.py +678 -377
  543. mindspore/rewrite/api/node.py +211 -54
  544. mindspore/rewrite/api/node_type.py +5 -0
  545. mindspore/rewrite/api/pattern_engine.py +22 -23
  546. mindspore/rewrite/api/scoped_value.py +20 -17
  547. mindspore/rewrite/api/symbol_tree.py +252 -106
  548. mindspore/rewrite/api/tree_node_helper.py +3 -0
  549. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  550. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  551. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  552. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  553. mindspore/rewrite/common/rewrite_elog.py +5 -1
  554. mindspore/rewrite/namer.py +51 -51
  555. mindspore/rewrite/namespace.py +14 -5
  556. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  557. mindspore/rewrite/node/call_function.py +79 -0
  558. mindspore/rewrite/node/cell_container.py +135 -0
  559. mindspore/rewrite/node/control_flow.py +88 -0
  560. mindspore/rewrite/{node.py → node/node.py} +313 -247
  561. mindspore/rewrite/node/node_manager.py +254 -0
  562. mindspore/rewrite/node/node_topological_manager.py +243 -0
  563. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  564. mindspore/rewrite/parsers/assign_parser.py +225 -239
  565. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  566. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  567. mindspore/rewrite/parsers/constant_parser.py +9 -6
  568. mindspore/rewrite/parsers/container_parser.py +9 -7
  569. mindspore/rewrite/parsers/for_parser.py +36 -15
  570. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  571. mindspore/rewrite/parsers/if_parser.py +28 -24
  572. mindspore/rewrite/parsers/module_parser.py +202 -25
  573. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  574. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  575. mindspore/rewrite/parsers/return_parser.py +6 -6
  576. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  577. mindspore/rewrite/sparsify/sparsify.py +4 -1
  578. mindspore/rewrite/sparsify/utils.py +11 -5
  579. mindspore/rewrite/symbol_tree.py +577 -732
  580. mindspore/rewrite/symbol_tree_builder.py +9 -175
  581. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  582. mindspore/run_check/_check_version.py +46 -39
  583. mindspore/run_check/run_check.py +3 -2
  584. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  585. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  586. mindspore/scipy/__init__.py +1 -1
  587. mindspore/scipy/linalg.py +67 -61
  588. mindspore/scipy/ops.py +5 -41
  589. mindspore/scipy/ops_grad.py +3 -2
  590. mindspore/scipy/ops_wrapper.py +5 -5
  591. mindspore/scipy/optimize/line_search.py +8 -8
  592. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  593. mindspore/scipy/optimize/minimize.py +16 -12
  594. mindspore/scipy/utils.py +1 -52
  595. mindspore/scipy/utils_const.py +4 -4
  596. mindspore/train/__init__.py +4 -4
  597. mindspore/train/_utils.py +13 -5
  598. mindspore/train/amp.py +410 -148
  599. mindspore/train/anf_ir_pb2.py +16 -4
  600. mindspore/train/callback/_backup_and_restore.py +8 -11
  601. mindspore/train/callback/_callback.py +80 -3
  602. mindspore/train/callback/_checkpoint.py +82 -51
  603. mindspore/train/callback/_early_stop.py +12 -15
  604. mindspore/train/callback/_history.py +1 -1
  605. mindspore/train/callback/_lambda_callback.py +13 -13
  606. mindspore/train/callback/_landscape.py +21 -17
  607. mindspore/train/callback/_loss_monitor.py +9 -10
  608. mindspore/train/callback/_on_request_exit.py +16 -33
  609. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  610. mindspore/train/callback/_summary_collector.py +44 -30
  611. mindspore/train/callback/_time_monitor.py +62 -12
  612. mindspore/train/data_sink.py +10 -16
  613. mindspore/train/dataset_helper.py +154 -86
  614. mindspore/train/loss_scale_manager.py +14 -9
  615. mindspore/train/metrics/__init__.py +10 -2
  616. mindspore/train/metrics/accuracy.py +1 -1
  617. mindspore/train/metrics/auc.py +1 -1
  618. mindspore/train/metrics/bleu_score.py +2 -2
  619. mindspore/train/metrics/confusion_matrix.py +14 -14
  620. mindspore/train/metrics/cosine_similarity.py +3 -3
  621. mindspore/train/metrics/dice.py +1 -1
  622. mindspore/train/metrics/fbeta.py +1 -1
  623. mindspore/train/metrics/hausdorff_distance.py +8 -6
  624. mindspore/train/metrics/mean_surface_distance.py +5 -4
  625. mindspore/train/metrics/metric.py +49 -17
  626. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  627. mindspore/train/metrics/perplexity.py +1 -1
  628. mindspore/train/metrics/precision.py +2 -2
  629. mindspore/train/metrics/recall.py +2 -3
  630. mindspore/train/metrics/roc.py +7 -7
  631. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  632. mindspore/train/metrics/topk.py +7 -4
  633. mindspore/train/mind_ir_pb2.py +193 -48
  634. mindspore/train/model.py +377 -133
  635. mindspore/train/serialization.py +697 -245
  636. mindspore/train/summary/_summary_adapter.py +5 -2
  637. mindspore/train/summary/_writer_pool.py +4 -3
  638. mindspore/train/summary/summary_record.py +25 -23
  639. mindspore/train/train_thor/convert_utils.py +39 -23
  640. mindspore/train/train_thor/dataset_helper.py +4 -3
  641. mindspore/train/train_thor/model_thor.py +8 -8
  642. mindspore/version.py +1 -1
  643. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  644. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
  645. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  646. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  647. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  648. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  649. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  650. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  651. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  652. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  653. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  654. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  655. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  656. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  657. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  658. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  659. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  660. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  661. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  662. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  663. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  664. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  665. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  666. mindspore/_extends/graph_kernel/expander.py +0 -80
  667. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  668. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  669. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  670. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  671. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  672. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  673. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  674. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  675. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  676. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  677. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  678. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  679. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  680. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  681. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  682. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  683. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  684. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  685. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  686. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  687. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  688. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  689. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  690. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  691. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  692. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  693. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  694. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  695. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  696. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  697. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  698. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  699. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  700. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  701. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  702. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  703. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  704. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  705. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  706. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  707. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  708. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  709. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  710. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  711. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  712. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  713. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  714. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  715. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  716. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  717. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  718. mindspore/dataset/engine/graphdata.py +0 -1586
  719. mindspore/include/api/net.h +0 -142
  720. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  721. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  722. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  723. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  724. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  725. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  726. mindspore/ops/_grad/grad_other_ops.py +0 -89
  727. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  728. mindspore/ops/_grad/grad_sparse.py +0 -323
  729. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  730. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  731. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  732. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  733. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  734. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  735. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  736. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  737. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  738. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  739. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  740. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  741. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  742. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  743. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  744. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  745. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  746. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  747. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  748. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  749. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  751. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  752. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  753. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  754. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  755. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  756. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  758. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  759. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  760. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  761. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  762. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  763. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  764. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  765. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  766. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  767. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  768. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  769. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  770. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  771. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  772. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  774. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  775. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  776. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  777. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  778. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  779. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  780. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  784. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  785. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  786. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  787. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  789. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  790. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  791. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  792. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  793. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  794. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  796. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  797. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  798. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  799. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  800. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  801. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  802. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  803. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  804. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  805. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  806. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  807. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  808. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  809. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  810. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  811. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  812. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  814. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  815. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  816. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  817. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  818. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  819. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  820. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  821. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  822. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  823. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  824. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  825. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  826. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  827. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  828. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  829. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  830. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  831. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  832. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  833. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  834. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  835. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  836. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  837. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  839. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  840. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  841. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  842. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  843. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  844. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  845. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  846. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  847. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  848. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  849. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  850. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  851. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  852. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  853. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  854. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  855. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  856. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  857. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  858. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  859. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  860. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  861. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  862. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  863. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  864. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  865. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  866. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  867. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  868. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  869. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  870. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  871. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  872. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  873. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  874. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  875. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  876. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  877. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  878. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  879. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  880. mindspore/rewrite/node_visitor.py +0 -44
  881. mindspore/rewrite/topological_manager.py +0 -203
  882. mindspore/scipy/sparse/linalg.py +0 -192
  883. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  884. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -41,8 +41,8 @@ def check_dense_inputs_same_shape(input1, input2, prim_name=None):
41
41
  @constexpr(check=False)
42
42
  def _check_is_tensor(param_name, input_data, cls_name):
43
43
  """Internal function, used to check whether the input data is Tensor."""
44
- if input_data is not None and not isinstance(P.typeof(input_data), mstype.tensor_type):
45
- raise TypeError(f"For '{cls_name}', the '{param_name}' must be '{mstype.tensor_type}', "
44
+ if input_data is not None and not isinstance(P.typeof(input_data), mstype.TensorType):
45
+ raise TypeError(f"For '{cls_name}', the '{param_name}' must be '{mstype.TensorType}', "
46
46
  f"but got '{P.typeof(input_data)}'")
47
47
 
48
48
 
@@ -66,17 +66,18 @@ class BiDense(Cell):
66
66
  where :math:`x_{1}` is the first input tensor, :math:`x_{2}` is the second input tensor
67
67
  , :math:`A` is a weight matrix with the same data type as the :math:`x_{*}` created by the layer
68
68
  , and :math:`b` is a bias vector with the same data type as the :math:`x_{*}` created by the layer
69
- (only if has_bias is True).
69
+ (only if has_bias is ``True`` ).
70
70
 
71
71
  Args:
72
72
  in1_channels (int): The number of channels in the input1 space.
73
73
  in2_channels (int): The number of channels in the input2 space.
74
74
  out_channels (int): The number of channels in the output space.
75
75
  weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter.
76
- The values of str refer to the function `initializer`. Default: None.
76
+ The values of str refer to the function `initializer`. Default: ``None`` .
77
77
  bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter.
78
- The values of str refer to the function `initializer`. Default: None.
79
- has_bias (bool): Specifies whether the layer uses :math:`\text{bias}` vector. Default: True.
78
+ The values of str refer to the function `initializer`. Default: ``None`` .
79
+ has_bias (bool): Specifies whether the layer uses :math:`\text{bias}` vector. Default: ``True`` .
80
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
80
81
 
81
82
  Shape:
82
83
  - **input1** - :math:`(*, H_{in1})` where :math:`H_{in1}=\text{in1_channels}` and
@@ -90,17 +91,17 @@ class BiDense(Cell):
90
91
  are the same shape as the inputs.
91
92
 
92
93
  Dtype:
93
- - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2**.
94
- - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input1**.
94
+ - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2** .
95
+ - **input2** (Tensor) - The dtype must be float16 or float32 and be same as **input1** .
95
96
  - **output** (Tensor) - With the same dtype as the inputs.
96
97
 
97
98
  Weights:
98
99
  - **weight** (Parameter) - The learnable weights with shape
99
100
  :math:`(\text{out_channels}, \text{in1_channels}, \text{in2_channels})`.
100
- When `weight_init` is `None`, the values are initialized from
101
+ When `weight_init` is ``None`` , the values are initialized from
101
102
  :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where :math:`k = \frac{1}{\text{in1_channels}}`.
102
103
  - **bias** (Parameter) - The learnable bias of shape :math:`(\text{out_channels})`.
103
- If `has_bias` is `True` and `bias_init` is `None`, the values are initialized from
104
+ If `has_bias` is ``True`` and `bias_init` is ``None`` , the values are initialized from
104
105
  :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where :math:`k = \frac{1}{\text{in1_channels}}`.
105
106
 
106
107
  Raises:
@@ -116,6 +117,9 @@ class BiDense(Cell):
116
117
  ``Ascend`` ``GPU`` ``CPU``
117
118
 
118
119
  Examples:
120
+ >>> import mindspore
121
+ >>> from mindspore import Tensor, nn
122
+ >>> import numpy as np
119
123
  >>> x1 = Tensor(np.random.randn(128, 20), mindspore.float32)
120
124
  >>> x2 = Tensor(np.random.randn(128, 30), mindspore.float32)
121
125
  >>> net = nn.BiDense(20, 30, 40)
@@ -130,7 +134,8 @@ class BiDense(Cell):
130
134
  out_channels,
131
135
  weight_init=None,
132
136
  bias_init=None,
133
- has_bias=True):
137
+ has_bias=True,
138
+ dtype=mstype.float32):
134
139
  super().__init__()
135
140
  self.in_channels = Validator.check_positive_int(in1_channels, "in1_channels", self.cls_name)
136
141
  self.in_channels = Validator.check_positive_int(in2_channels, "in2_channels", self.cls_name)
@@ -153,7 +158,8 @@ class BiDense(Cell):
153
158
  f"equal to 'in2_channels'. But got 'weight_init': {weight_init}, "
154
159
  f"'out_channels': {out_channels}, 'in_channels': {in1_channels}, "
155
160
  f"'in2_channels': {in2_channels}")
156
- self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels)), 'weight')
161
+ self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels), dtype=dtype),
162
+ 'weight')
157
163
 
158
164
  if self.has_bias:
159
165
  if bias_init is None:
@@ -163,7 +169,7 @@ class BiDense(Cell):
163
169
  raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
164
170
  f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
165
171
  f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
166
- self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
172
+ self.bias = Parameter(initializer(bias_init, [out_channels], dtype=dtype), name="bias")
167
173
  self.bias_add = P.BiasAdd()
168
174
  self.matmul = P.MatMul()
169
175
 
@@ -62,13 +62,15 @@ class Embedding(Cell):
62
62
  Args:
63
63
  vocab_size (int): Size of the dictionary of embeddings.
64
64
  embedding_size (int): The size of each embedding vector.
65
- use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: False.
65
+ use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: ``False`` .
66
66
  embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
67
- Refer to class `initializer` for the values of string when a string
68
- is specified. Default: 'normal'.
69
- dtype (:class:`mindspore.dtype`): Data type of `x`. Default: mindspore.float32.
67
+ Refer to class `mindspore.common.initializer
68
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
69
+ for the values of string when a string is specified. Default: ``'normal'`` .
70
+ dtype (:class:`mindspore.dtype`): Data type of `x`. Default: ``mstype.float32`` .
70
71
  padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
71
- will be initialized to zero. Default: None. The feature is inactivated.
72
+ will be initialized to zero. Default: ``None`` . The feature is inactivated.
73
+
72
74
  Inputs:
73
75
  - **x** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{x_length})`. The elements of
74
76
  the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
@@ -86,6 +88,9 @@ class Embedding(Cell):
86
88
  ``Ascend`` ``GPU`` ``CPU``
87
89
 
88
90
  Examples:
91
+ >>> import mindspore
92
+ >>> from mindspore import Tensor, nn
93
+ >>> import numpy as np
89
94
  >>> net = nn.Embedding(20000, 768, True)
90
95
  >>> x = Tensor(np.ones([8, 128]), mindspore.int32)
91
96
  >>> # Maps the input word IDs to word embedding.
@@ -126,13 +131,10 @@ class Embedding(Cell):
126
131
  self.array_mul = P.MatMul()
127
132
  self.reshape = P.Reshape()
128
133
  self.get_shp = P.Shape()
129
- self.get_tensor_shp = P.TensorShape()
130
134
  self.concat = P.Concat()
131
135
 
132
136
  def construct(self, ids):
133
137
  out_shape = self.get_shp(ids) + (self.embedding_size,)
134
- if F.is_sequence_value_unknown(self.get_shp(ids)):
135
- out_shape = self.concat((self.get_tensor_shp(ids), Tensor([self.embedding_size])))
136
138
  flat_ids = self.reshape_flat(ids, self.shp_flat)
137
139
 
138
140
  if self.use_one_hot:
@@ -145,12 +147,11 @@ class Embedding(Cell):
145
147
  return output
146
148
 
147
149
  def extend_repr(self):
148
- s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
149
- self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
150
- return s
150
+ return f'vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, ' \
151
+ f'embedding_table={self.embedding_table}, dtype={self.dtype}, padding_idx={self.padding_idx}'
151
152
 
152
153
 
153
- @constexpr
154
+ @_primexpr
154
155
  def _make_axis_range(start, end):
155
156
  axis = tuple(range(start, end))
156
157
  return axis
@@ -177,19 +178,20 @@ class EmbeddingLookup(Cell):
177
178
  embedding_size (int): The size of each embedding vector.
178
179
  param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
179
180
  Refer to class `initializer` for the values of string when a string
180
- is specified. Default: 'normal'.
181
+ is specified. Default: ``'normal'`` .
181
182
  target (str): Specifies the target where the op is executed. The value must in
182
- ['DEVICE', 'CPU']. Default: 'CPU'.
183
+ [ ``'DEVICE'`` , ``'CPU'`` ]. Default: ``'CPU'`` .
183
184
  slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
184
- :class:`mindspore.nn.EmbeddingLookup`. Default: 'nn.EmbeddingLookup.BATCH_SLICE'.
185
- manual_shapes (tuple): The accompaniment array in field slice mode. Default: None.
185
+ :class:`mindspore.nn.EmbeddingLookup`. Default: ``'batch_slice'`` .
186
+ manual_shapes (tuple): The accompaniment array in field slice mode. Default: ``None`` .
186
187
  max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
187
- or None. Default: None
188
- sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
189
- vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in
188
+ or None. Default: ``None`` .
189
+ sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: ``True`` .
190
+ vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: ``0`` . It is valid only in
190
191
  parameter server training mode and 'DEVICE' target. And the moment parameter of corresponding
191
192
  optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
192
193
  memory, so it is recommended to set a reasonable value to avoid insufficient memory.
194
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
193
195
 
194
196
  Inputs:
195
197
  - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@@ -216,6 +218,9 @@ class EmbeddingLookup(Cell):
216
218
  ``Ascend`` ``GPU`` ``CPU``
217
219
 
218
220
  Examples:
221
+ >>> import mindspore
222
+ >>> from mindspore import Tensor, nn
223
+ >>> import numpy as np
219
224
  >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
220
225
  >>> result = nn.EmbeddingLookup(4,2)(input_indices)
221
226
  >>> print(result.shape)
@@ -228,7 +233,7 @@ class EmbeddingLookup(Cell):
228
233
 
229
234
  def __init__(self, vocab_size, embedding_size, param_init='normal',
230
235
  target='CPU', slice_mode='batch_slice', manual_shapes=None,
231
- max_norm=None, sparse=True, vocab_cache_size=0):
236
+ max_norm=None, sparse=True, vocab_cache_size=0, dtype=mstype.float32):
232
237
  """Initialize EmbeddingLookup."""
233
238
  super(EmbeddingLookup, self).__init__()
234
239
  Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
@@ -252,8 +257,8 @@ class EmbeddingLookup(Cell):
252
257
  if enable_ps:
253
258
  self._process_vocab_cache(slice_mode)
254
259
  self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
255
- self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
256
- name='embedding_table')
260
+ self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
261
+ dtype=dtype), name='embedding_table')
257
262
  parallel_mode = _get_parallel_mode()
258
263
  is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
259
264
  self.gather_revert = P.Gather()
@@ -264,7 +269,7 @@ class EmbeddingLookup(Cell):
264
269
  if is_auto_parallel:
265
270
  self.unique = P.Unique().shard(((1,),))
266
271
  if self.cache_enable and enable_ps:
267
- self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init)
272
+ self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init, dtype=dtype)
268
273
  if is_auto_parallel:
269
274
  self.unique.add_prim_attr('cache_enable', True)
270
275
  indices_shape_size = 2
@@ -300,15 +305,15 @@ class EmbeddingLookup(Cell):
300
305
  self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
301
306
  elif slice_mode == "batch_slice" and is_auto_parallel:
302
307
  indices_strategy = [get_group_size()]
303
- indices_strategy.extend([1]*(indices_shape_size - 1))
308
+ indices_strategy.extend([1] * (indices_shape_size - 1))
304
309
  indices_strategy = tuple(indices_strategy)
305
310
  self.gatherv2.shard(((1, 1), indices_strategy))
306
311
  self.embeddinglookup.shard(((1, 1), indices_strategy))
307
312
  else:
308
313
  if is_auto_parallel:
309
314
  support_mode = ["field_slice", "table_row_slice", "table_column_slice", "batch_slice"]
310
- raise ValueError("For '{}', the 'slice_mode' must be in {}, "
311
- "but got \"{}\".".format(self.cls_name, support_mode, slice_mode))
315
+ raise ValueError(f"For '{self.cls_name}', the 'slice_mode' must be in {support_mode}, "
316
+ f"but got \"{slice_mode}\".")
312
317
  if self.cache_enable and not enable_ps:
313
318
  raise ValueError(f"For '{self.cls_name}', haven't supported cache enable for not ps mode.")
314
319
  self.embedding_table.unique = self.forward_unique
@@ -351,7 +356,8 @@ class EmbeddingLookup(Cell):
351
356
  if _is_role_worker():
352
357
  self.vocab_size = self.vocab_cache_size
353
358
 
354
- def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init):
359
+ def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init,
360
+ dtype=mstype.float32):
355
361
  """PS embeddingLookup cache enable set."""
356
362
  if self.sparse:
357
363
  self.forward_unique = True
@@ -365,10 +371,10 @@ class EmbeddingLookup(Cell):
365
371
  if _enable_distributed_mindrt():
366
372
  self.rank_id = get_rank()
367
373
  if self.is_ps_server:
368
- self._slice_pserver_embeddings("zeros")
374
+ self._slice_pserver_embeddings("zeros", dtype=dtype)
369
375
  self._set_cache_enable_and_key_for_pserver(param_key)
370
376
 
371
- def _slice_pserver_embeddings(self, param_init):
377
+ def _slice_pserver_embeddings(self, param_init, dtype=mstype.float32):
372
378
  '''
373
379
  Method to slice embedding tables on Parameter Servers.
374
380
  It helps to train with a large scale embedding table and is used only in Parameter Server training mode.
@@ -396,7 +402,7 @@ class EmbeddingLookup(Cell):
396
402
  for i in range(server_num):
397
403
  self.embedding_table_list.append(Parameter(initializer(param_init,
398
404
  [self.embedding_table_vocab_dim_list[i],
399
- self.embedding_size]),
405
+ self.embedding_size], dtype=dtype),
400
406
  name="embedding_table_server_" + str(i)))
401
407
 
402
408
  self.embedding_offset.append(offset)
@@ -495,17 +501,20 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
495
501
  field_size (int): The field size of the final outputs.
496
502
  param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
497
503
  Refer to class `initializer` for the values of string when a string
498
- is specified. Default: 'normal'.
504
+ is specified. Default: ``'normal'`` .
499
505
  target (str): Specifies the target where the op is executed. The value must in
500
- ['DEVICE', 'CPU']. Default: 'CPU'.
506
+ [ ``'DEVICE'`` , ``'CPU'`` ]. Default: ``'CPU'`` .
501
507
  slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
502
- :class:`mindspore.nn.EmbeddingLookup`. Default: 'nn.EmbeddingLookup.BATCH_SLICE'.
503
- feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently. Default: None.
504
- max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
505
- or None. Default: None
506
- sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
507
- operator (str): The pooling method for the features in one field. Support 'SUM', 'MEAN' and 'MAX'.
508
- Default: 'SUM'.
508
+ :class:`mindspore.nn.EmbeddingLookup`. Default: ``'batch_slice'``.
509
+ feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
510
+ Default: ``None`` .
511
+ max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32.
512
+ Default: ``None`` .
513
+ sparse (bool): Using sparse mode. When 'target' is set to ``'CPU'`` , 'sparse' has to be true.
514
+ Default: ``True`` .
515
+ operator (str): The pooling method for the features in one field. Support ``'SUM'`` , ``'MEAN'`` and
516
+ ``'MAX'`` . Default: ``'SUM'`` .
517
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
509
518
 
510
519
  Inputs:
511
520
  - **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
@@ -524,17 +533,19 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
524
533
  TypeError: If `vocab_size` or `embedding_size` or `field_size` is not an int.
525
534
  TypeError: If `sparse` is not a bool or `feature_num_list` is not a tuple.
526
535
  ValueError: If `vocab_size` or `embedding_size` or `field_size` is less than 1.
527
- ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
528
- ValueError: If `slice_mode` is not one of 'batch_slice', 'field_slice', 'table_row_slice',
529
- 'table_column_slice'.
530
- ValueError: If `sparse` is False and `target` is 'CPU'.
531
- ValueError: If `slice_mode` is 'field_slice' and `feature_num_list` is None.
532
- ValueError: If `operator` is not one of 'SUM', 'MAX', 'MEAN'.
536
+ ValueError: If `target` is neither ``'CPU'`` nor ``'DEVICE'``.
537
+ ValueError: If `slice_mode` is not one of ``'batch_slice'``, ``'field_slice'``, ``'table_row_slice'``,
538
+ ``'table_column_slice'`` .
539
+ ValueError: If `sparse` is False and `target` is ``'CPU'`` .
540
+ ValueError: If `slice_mode` is ``'field_slice'`` and `feature_num_list` is None.
541
+ ValueError: If `operator` is not one of ``'SUM'``, ``'MAX'``, ``'MEAN'`` .
533
542
 
534
543
  Supported Platforms:
535
544
  ``Ascend`` ``GPU``
536
545
 
537
546
  Examples:
547
+ >>> import mindspore
548
+ >>> from mindspore import Tensor, nn
538
549
  >>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
539
550
  >>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
540
551
  >>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
@@ -548,10 +559,11 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
548
559
  OPERATOR_MAX = 'MAX'
549
560
 
550
561
  def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
551
- slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
562
+ slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM',
563
+ dtype=mstype.float32):
552
564
  """Initialize MultiFieldEmbeddingLookup."""
553
565
  super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
554
- slice_mode, feature_num_list, max_norm, sparse)
566
+ slice_mode, feature_num_list, max_norm, sparse, dtype=dtype)
555
567
  self.field_size = Validator.check_positive_int(field_size, 'field_size', self.cls_name)
556
568
  self.operator = operator
557
569
 
@@ -615,8 +627,9 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
615
627
  self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
616
628
  else:
617
629
  if is_auto_parallel:
618
- raise ValueError("For '{}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice' and \
619
- 'table_column_slice'], but got {}".format(self.cls_name, str(slice_mode)))
630
+ raise ValueError(
631
+ f"For '{self.cls_name}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice' "
632
+ f"and 'table_column_slice'], but got {str(slice_mode)}.")
620
633
 
621
634
  # Min value for fp32
622
635
  self.negative_inf_value = -3.402823466E+38
@@ -0,0 +1,264 @@
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ A FlashAttention Layer.
17
+ """
18
+ import math
19
+
20
+ import mindspore.common.dtype as mstype
21
+ from mindspore.common.tensor import Tensor
22
+ from mindspore import ops
23
+ from mindspore.nn.cell import Cell
24
+ from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention
25
+ from mindspore.ops.operations.nn_ops import FlashAttentionScore
26
+ from mindspore._c_expression import MSContext
27
+
28
+ __all__ = ['FlashAttention']
29
+
30
+
31
+ class FlashAttention(Cell):
32
+ """Flash Attention Layer.
33
+
34
+ This function contains the flash attention primitives used in FlashAttention (see paper)
35
+ `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/pdf/2205.14135.pdf>`_ .
36
+
37
+ Specifically, it includes the following:
38
+
39
+ 1. An interface for calling flashattention operation.
40
+ 2. Two configuration parameters for enabling local block sparse of flashattention.
41
+
42
+ Args:
43
+ head_dim(int): The hidden size of input.
44
+ dropout_rate(float): The dropout rate of the attention score. Default 0.0.
45
+ prev_block_num(int): An integer to define the number of blocks to look ahead for local block sparse attention.
46
+ Default 65536.
47
+ next_block_num(int): An integer to define the number of blocks to look behind for local block sparse attention.
48
+ Default 65536.
49
+ tiling_stgy_name(str): A str to define tiling strategy of flash attention.
50
+ dp(int): data parallel.
51
+ Default 1.
52
+ mp(int): model parallel.
53
+ Default 1.
54
+ high_precision(bool): This mode has higher precision but some performance loss.
55
+ Default False.
56
+ have_attention_mask_batch(bool): indicates whether attention_mask contains the batch dimension.
57
+ Default True
58
+ alibi(bool): This parameter indicates whether the flashattention supports the Alibi.
59
+ Default: False
60
+
61
+
62
+ Inputs:
63
+ - **query** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
64
+ - **key** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
65
+ - **value** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
66
+ - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` [batch_size, seq_length,
67
+ seq_length]): A matrix to pass masked information.
68
+
69
+ Outputs:
70
+ A Tensor. The output of the attention with shape [batch_size, head_num, seq_length, head_dim]
71
+
72
+ Supported Platforms:
73
+ ``Ascend``
74
+
75
+ Examples:
76
+ >>> import numpy as np
77
+ >>> from mindspore import dtype as mstype
78
+ >>> from mindspore.nn.layer.flash_attention import FlashAttention
79
+ >>> from mindspore import Tensor
80
+ >>> model = FlashAttention(head_dim=128, head_num=16,
81
+ ... dropout_rate=0.1,
82
+ ... prev_block_num=7,
83
+ ... next_block_num=0
84
+ ... )
85
+ >>> query = Tensor(np.ones((2, 16, 4096, 128)), mstype.float16)
86
+ >>> key = Tensor(np.ones((2, 16, 4096, 128)), mstype.float16)
87
+ >>> value = Tensor(np.ones((2, 16, 4096, 128)), mstype.float16)
88
+ >>> attention_mask = Tensor(np.ones((2, 4096, 4096)), mstype.float16)
89
+ >>> output = model(query, key, value, attention_mask)
90
+ >>> print(output.shape)
91
+ (2, 16, 4096, 128)
92
+ """
93
+
94
+ def __init__(self,
95
+ head_dim,
96
+ head_num,
97
+ dropout_rate=0.0,
98
+ prev_block_num=65536,
99
+ next_block_num=65536,
100
+ tiling_stgy_name="sparse",
101
+ dp=1,
102
+ mp=1,
103
+ high_precision=False,
104
+ have_attention_mask_batch=True,
105
+ alibi=False
106
+ ):
107
+ super(FlashAttention, self).__init__()
108
+
109
+ scaling_constant = math.sqrt(head_dim)
110
+ if scaling_constant == 0:
111
+ raise ValueError("the scaling constant must not be 0.")
112
+ self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
113
+
114
+ self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "Ascend910"
115
+ if self.is_910A:
116
+ self.flash_attention = get_flash_attention(
117
+ prev_block_num=prev_block_num,
118
+ next_block_num=next_block_num,
119
+ tiling_stgy_name=tiling_stgy_name,
120
+ high_precision=high_precision
121
+ )
122
+ self.flash_attention.add_prim_attr("primitive_target", "Ascend")
123
+ else:
124
+ if alibi:
125
+ raise ValueError(f"When soc_version is not Ascend910A, alibi must be False")
126
+ self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
127
+ self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
128
+ self.reshape = ops.Reshape()
129
+ self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
130
+ self.zeros = ops.Zeros()
131
+ self.attn_expand_dims = ops.ExpandDims().shard(((dp, 1, 1),))
132
+ fa_strategies = ((dp, 1, mp),
133
+ (dp, 1, mp),
134
+ (dp, 1, mp),
135
+ (dp, 1, 1, 1))
136
+ if dropout_rate > 1e-5:
137
+ fa_strategies += ((dp, mp, 1, 1),)
138
+ self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
139
+ next_tokens=next_block_num,
140
+ keep_prob=1 - dropout_rate,
141
+ scale_value=1.0,
142
+ inner_precise=0 if high_precision else 1).shard(fa_strategies)
143
+
144
+ self.ones = ops.Ones()
145
+ self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
146
+ self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
147
+ self.dropout_rate = dropout_rate
148
+ self.have_attention_mask_batch = have_attention_mask_batch
149
+ self.alibi = alibi
150
+ if self.dropout_rate > 1e-5:
151
+ self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
152
+ self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))
153
+ self.tensor_one = Tensor(1.0, mstype.float16)
154
+ self.drop_gen_mask = ops.DropoutGenMask()
155
+ self.do_dropout = ops.DropoutDoMask().shard(((dp, mp, 1, 1),))
156
+ self.depend = ops.Depend()
157
+
158
+ def shard(self, in_strategy=None, out_strategy=None):
159
+ """Distributed configuration of FlashAttention
160
+ :param in_strategy: Describe the split strategy of operator input. Default: None.
161
+ :param out_strategy: Describe the split strategy of operator output, it is only for certain operators,
162
+ such as MatMul. Default: None.
163
+ :return:
164
+ """
165
+ if in_strategy is None:
166
+ # default: dp=1, mp=1, construct inputs only contain query, key, value
167
+ in_strategy = (
168
+ (1, 1, 1, 1),
169
+ (1, 1, 1, 1),
170
+ (1, 1, 1, 1),
171
+ )
172
+ self.flash_attention.shard(in_strategy)
173
+ dp = in_strategy[0][0]
174
+ mp = in_strategy[0][1]
175
+ self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
176
+ inputs_tensor_map = [
177
+ [3, 2, 1, 0],
178
+ [3, 2, 1, 0],
179
+ [3, 2, 1, 0],
180
+ ]
181
+ if self.have_attention_mask_batch:
182
+ inputs_tensor_map.append([3, 1, 0])
183
+ else:
184
+ inputs_tensor_map.append([-1, 1, 0])
185
+
186
+ input_empty_args_num = 2
187
+ # dropout_mask
188
+ if self.dropout_rate > 1e-5:
189
+ input_empty_args_num -= 1
190
+ inputs_tensor_map.append([3, 2, 1, 0])
191
+
192
+ if self.alibi:
193
+ input_empty_args_num -= 1
194
+ inputs_tensor_map.append([3, 2, 1, 0])
195
+
196
+ self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
197
+
198
+ self.flash_attention.add_prim_attr("outputs_tensor_map", [
199
+ [3, 2, 1, 0], # O
200
+ [3, 2, 1], # L
201
+ [3, 2, 1] # M
202
+ ])
203
+ self.flash_attention.add_prim_attr("as_loss_divisor", 0)
204
+ self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)
205
+
206
+ def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
207
+ """FlashAttention forward
208
+ :param query: [bsz, head_num, seq_len, head_dim]
209
+ :param key: [bsz, head_num, seq_len, head_dim]
210
+ :param value: [bsz, head_num, seq_len, head_dim]
211
+ :param attn_mask: [1 or bsz, seq_len, seq_len], if not None
212
+ :param alibi_mask: [bsz, head_num, 1, seq_len], if not None
213
+ :return: output [bsz, head_num, seq_len, head_dim]
214
+ """
215
+ query = self.scale_mul(query, self.scale_factor)
216
+ bsz, head_num, seq_len, head_dim = query.shape
217
+ _, k_head_num, k_seq_len, _ = key.shape
218
+ _, v_head_num, v_seq_len, _ = value.shape
219
+ if head_num != k_head_num or head_num != v_head_num:
220
+ raise ValueError(
221
+ "the head_num of query, key and value must be the same, "
222
+ "If different head_num are used, users need to change themselves to be same by tile.")
223
+ if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
224
+ raise ValueError(
225
+ "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
226
+
227
+ if head_dim > 304:
228
+ raise ValueError(
229
+ "the head_dim must be less than 304, otherwise the ub would be OOM.")
230
+
231
+ if self.is_910A:
232
+ # 910A -- FlashAttentionPrimitive
233
+ if self.dropout_rate > 1e-5:
234
+ drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
235
+ tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
236
+ ones = self.fill_v2(tensor_shape, self.tensor_one)
237
+ ones = self.depend(ones, query)
238
+ drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
239
+ else:
240
+ drop_mask = None
241
+ output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
242
+ else:
243
+ # FlashAttentionScore
244
+ # Useless input, just for binary calls.
245
+ if self.dropout_rate > 1e-5:
246
+ drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
247
+ (bsz, head_num, seq_len, seq_len // 8))
248
+ else:
249
+ drop_mask_bits = None
250
+ # (B, N, S, D) -> (B, S, H)
251
+ query = self.reshape(self.transpose_4d_pre(query, (0, 2, 1, 3)), (bsz, seq_len, -1))
252
+ key = self.reshape(self.transpose_4d_pre(key, (0, 2, 1, 3)), (bsz, seq_len, -1))
253
+ value = self.reshape(self.transpose_4d_pre(value, (0, 2, 1, 3)), (bsz, seq_len, -1))
254
+ attn_mask = self.attn_expand_dims(attn_mask, 1)
255
+ output, _, _ = self.flash_attention(query,
256
+ key,
257
+ value,
258
+ attn_mask,
259
+ drop_mask_bits,
260
+ None,
261
+ None)
262
+ output = self.transpose_4d_post(self.reshape(output, (bsz, seq_len, head_num, head_dim)), (0, 2, 1, 3))
263
+
264
+ return output