mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.11__cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (589)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +61 -95
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/__init__.py +4 -2
  268. mindspore/nn/wrap/cell_wrapper.py +87 -34
  269. mindspore/nn/wrap/grad_reducer.py +8 -5
  270. mindspore/nn/wrap/loss_scale.py +105 -42
  271. mindspore/numpy/array_creations.py +1 -2
  272. mindspore/numpy/array_ops.py +3 -2
  273. mindspore/numpy/utils_const.py +5 -5
  274. mindspore/offline_debug/convert_async.py +2 -2
  275. mindspore/ops/_grad_experimental/__init__.py +0 -5
  276. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  277. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  278. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  279. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  280. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  281. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  282. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  283. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  284. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  285. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  286. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  287. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  288. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  289. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  290. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  291. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  292. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  293. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  294. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  295. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  296. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  297. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  298. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  299. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  300. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  301. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  302. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  303. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  304. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  305. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  306. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  307. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  309. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  310. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  311. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  312. mindspore/ops/_primitive_cache.py +1 -1
  313. mindspore/ops/_tracefunc.py +45 -13
  314. mindspore/ops/_utils/utils.py +6 -1
  315. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  316. mindspore/ops/_vmap/vmap_base.py +3 -3
  317. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  318. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  319. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  320. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  321. mindspore/ops/arg_dtype_cast.py +54 -0
  322. mindspore/ops/composite/base.py +37 -10
  323. mindspore/ops/composite/math_ops.py +5 -4
  324. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  325. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  326. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  327. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  328. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  330. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  331. mindspore/ops/deprecated.py +304 -0
  332. mindspore/ops/function/__init__.py +4 -1
  333. mindspore/ops/function/array_func.py +174 -193
  334. mindspore/ops/function/clip_func.py +81 -13
  335. mindspore/ops/function/debug_func.py +1 -1
  336. mindspore/ops/function/grad/grad_func.py +18 -9
  337. mindspore/ops/function/image_func.py +10 -4
  338. mindspore/ops/function/linalg_func.py +5 -5
  339. mindspore/ops/function/math_func.py +575 -386
  340. mindspore/ops/function/nn_func.py +568 -260
  341. mindspore/ops/function/random_func.py +88 -57
  342. mindspore/ops/function/sparse_func.py +1 -1
  343. mindspore/ops/function/sparse_unary_func.py +14 -12
  344. mindspore/ops/function/vmap_func.py +6 -5
  345. mindspore/ops/functional.py +15 -10
  346. mindspore/ops/op_info_register.py +244 -25
  347. mindspore/ops/operations/__init__.py +31 -19
  348. mindspore/ops/operations/_grad_ops.py +71 -7
  349. mindspore/ops/operations/_inner_ops.py +350 -17
  350. mindspore/ops/operations/_quant_ops.py +4 -8
  351. mindspore/ops/operations/_sequence_ops.py +42 -0
  352. mindspore/ops/operations/array_ops.py +68 -282
  353. mindspore/ops/operations/comm_ops.py +107 -59
  354. mindspore/ops/operations/custom_ops.py +94 -70
  355. mindspore/ops/operations/debug_ops.py +8 -4
  356. mindspore/ops/operations/image_ops.py +18 -12
  357. mindspore/ops/operations/inner_ops.py +26 -3
  358. mindspore/ops/operations/math_ops.py +192 -144
  359. mindspore/ops/operations/nn_ops.py +857 -489
  360. mindspore/ops/operations/other_ops.py +0 -22
  361. mindspore/ops/operations/random_ops.py +53 -111
  362. mindspore/ops/operations/sparse_ops.py +3 -1
  363. mindspore/ops/primitive.py +24 -18
  364. mindspore/parallel/_auto_parallel_context.py +68 -8
  365. mindspore/parallel/_cost_model_context.py +2 -2
  366. mindspore/parallel/_offload_context.py +17 -3
  367. mindspore/parallel/_parallel_serialization.py +12 -5
  368. mindspore/parallel/_ps_context.py +12 -0
  369. mindspore/parallel/_tensor.py +18 -13
  370. mindspore/parallel/_transformer/layers.py +5 -3
  371. mindspore/parallel/_transformer/loss.py +1 -0
  372. mindspore/parallel/_transformer/moe.py +2 -2
  373. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  374. mindspore/parallel/_transformer/transformer.py +23 -3
  375. mindspore/parallel/_utils.py +11 -7
  376. mindspore/parallel/algo_parameter_config.py +85 -5
  377. mindspore/parallel/checkpoint_transform.py +19 -12
  378. mindspore/parallel/shard.py +21 -14
  379. mindspore/profiler/common/struct_type.py +3 -3
  380. mindspore/profiler/common/util.py +4 -2
  381. mindspore/profiler/envprofiling.py +1 -1
  382. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  383. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  384. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  385. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  386. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  387. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  388. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  389. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  390. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  391. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  392. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  393. mindspore/profiler/parser/flops_parser.py +15 -11
  394. mindspore/profiler/parser/framework_parser.py +38 -22
  395. mindspore/profiler/parser/hccl_parser.py +16 -12
  396. mindspore/profiler/parser/integrator.py +22 -11
  397. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  398. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  399. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  400. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  401. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  402. mindspore/profiler/parser/optime_parser.py +1 -1
  403. mindspore/profiler/parser/profiler_info.py +21 -2
  404. mindspore/profiler/parser/step_trace_parser.py +11 -14
  405. mindspore/profiler/profiling.py +179 -89
  406. mindspore/rewrite/api/node.py +102 -19
  407. mindspore/rewrite/api/node_type.py +5 -1
  408. mindspore/rewrite/api/pattern_engine.py +1 -1
  409. mindspore/rewrite/api/scoped_value.py +9 -17
  410. mindspore/rewrite/api/symbol_tree.py +131 -47
  411. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  412. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  413. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  414. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  415. mindspore/rewrite/common/rewrite_elog.py +5 -1
  416. mindspore/rewrite/namer.py +33 -24
  417. mindspore/rewrite/namespace.py +14 -5
  418. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  419. mindspore/rewrite/node/call_function.py +79 -0
  420. mindspore/rewrite/node/cell_container.py +135 -0
  421. mindspore/rewrite/node/control_flow.py +88 -0
  422. mindspore/rewrite/{node.py → node/node.py} +273 -234
  423. mindspore/rewrite/node/node_manager.py +254 -0
  424. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  425. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  426. mindspore/rewrite/parsers/assign_parser.py +216 -221
  427. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  428. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  429. mindspore/rewrite/parsers/constant_parser.py +9 -6
  430. mindspore/rewrite/parsers/container_parser.py +9 -7
  431. mindspore/rewrite/parsers/for_parser.py +42 -21
  432. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  433. mindspore/rewrite/parsers/if_parser.py +28 -24
  434. mindspore/rewrite/parsers/module_parser.py +196 -25
  435. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  436. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  437. mindspore/rewrite/parsers/return_parser.py +6 -6
  438. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  439. mindspore/rewrite/sparsify/utils.py +1 -1
  440. mindspore/rewrite/symbol_tree.py +523 -578
  441. mindspore/rewrite/symbol_tree_builder.py +9 -193
  442. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  443. mindspore/run_check/_check_version.py +6 -4
  444. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  445. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  446. mindspore/scipy/linalg.py +1 -1
  447. mindspore/scipy/ops.py +55 -5
  448. mindspore/scipy/optimize/__init__.py +3 -2
  449. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  450. mindspore/scipy/optimize/minimize.py +7 -3
  451. mindspore/train/_utils.py +7 -3
  452. mindspore/train/amp.py +323 -123
  453. mindspore/train/anf_ir_pb2.py +14 -2
  454. mindspore/train/callback/_backup_and_restore.py +2 -12
  455. mindspore/train/callback/_callback.py +29 -4
  456. mindspore/train/callback/_checkpoint.py +23 -8
  457. mindspore/train/callback/_early_stop.py +2 -2
  458. mindspore/train/callback/_landscape.py +4 -4
  459. mindspore/train/callback/_loss_monitor.py +2 -2
  460. mindspore/train/callback/_on_request_exit.py +2 -2
  461. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  462. mindspore/train/callback/_summary_collector.py +15 -8
  463. mindspore/train/callback/_time_monitor.py +58 -5
  464. mindspore/train/data_sink.py +5 -11
  465. mindspore/train/dataset_helper.py +84 -57
  466. mindspore/train/loss_scale_manager.py +2 -2
  467. mindspore/train/metrics/__init__.py +3 -3
  468. mindspore/train/metrics/cosine_similarity.py +1 -1
  469. mindspore/train/metrics/hausdorff_distance.py +3 -2
  470. mindspore/train/metrics/mean_surface_distance.py +3 -2
  471. mindspore/train/metrics/metric.py +39 -19
  472. mindspore/train/metrics/roc.py +2 -2
  473. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  474. mindspore/train/mind_ir_pb2.py +85 -36
  475. mindspore/train/model.py +187 -47
  476. mindspore/train/serialization.py +487 -161
  477. mindspore/train/summary/_summary_adapter.py +1 -1
  478. mindspore/train/summary/_writer_pool.py +3 -2
  479. mindspore/train/summary/summary_record.py +37 -17
  480. mindspore/train/train_thor/convert_utils.py +3 -3
  481. mindspore/train/train_thor/dataset_helper.py +1 -1
  482. mindspore/version.py +1 -1
  483. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  486. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  487. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  489. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  490. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  491. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  492. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  493. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  494. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  495. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  497. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  498. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  499. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  500. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  501. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  502. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  503. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  504. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  505. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  506. mindspore/_extends/graph_kernel/expander.py +0 -80
  507. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  508. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  509. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  510. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  511. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  512. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  513. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  514. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  516. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  517. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  518. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  519. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  520. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  521. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  522. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  523. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  524. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  525. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  526. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  527. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  528. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  529. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  530. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  531. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  532. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  534. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  535. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  536. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  537. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  538. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  540. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  543. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  544. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  545. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  546. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  547. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  548. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  549. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  550. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  551. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  552. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  553. mindspore/dataset/datapreprocess/__init__.py +0 -20
  554. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  555. mindspore/include/api/net.h +0 -142
  556. mindspore/nn/lr_scheduler.py +0 -262
  557. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  558. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  559. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  560. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  561. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  562. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  563. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  564. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  565. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  566. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  567. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  568. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  569. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  570. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  571. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  575. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  578. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  579. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  580. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  582. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  583. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  585. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  586. mindspore/rewrite/node_visitor.py +0 -44
  587. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  588. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  589. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -77,6 +77,7 @@ class BiDense(Cell):
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter.
             The values of str refer to the function `initializer`. Default: ``None`` .
         has_bias (bool): Specifies whether the layer uses :math:`\text{bias}` vector. Default: ``True`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Shape:
         - **input1** - :math:`(*, H_{in1})` where :math:`H_{in1}=\text{in1_channels}` and
@@ -90,8 +91,8 @@ class BiDense(Cell):
           are the same shape as the inputs.
 
     Dtype:
-        - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2**.
-        - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input1**.
+        - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2** .
+        - **input2** (Tensor) - The dtype must be float16 or float32 and be same as **input1** .
         - **output** (Tensor) - With the same dtype as the inputs.
 
     Weights:
@@ -133,7 +134,8 @@ class BiDense(Cell):
                  out_channels,
                  weight_init=None,
                  bias_init=None,
-                 has_bias=True):
+                 has_bias=True,
+                 dtype=mstype.float32):
         super().__init__()
         self.in_channels = Validator.check_positive_int(in1_channels, "in1_channels", self.cls_name)
         self.in_channels = Validator.check_positive_int(in2_channels, "in2_channels", self.cls_name)
@@ -156,7 +158,8 @@ class BiDense(Cell):
                              f"equal to 'in2_channels'. But got 'weight_init': {weight_init}, "
                              f"'out_channels': {out_channels}, 'in_channels': {in1_channels}, "
                              f"'in2_channels': {in2_channels}")
-        self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels)), 'weight')
+        self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels), dtype=dtype),
+                                'weight')
 
         if self.has_bias:
             if bias_init is None:
@@ -166,7 +169,7 @@ class BiDense(Cell):
                 raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
                                  f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
                                  f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
-        self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
+        self.bias = Parameter(initializer(bias_init, [out_channels], dtype=dtype), name="bias")
         self.bias_add = P.BiasAdd()
         self.matmul = P.MatMul()
 
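With the new `dtype` argument threaded into both `initializer` calls above, BiDense can create its weight and bias directly in a lower precision. A minimal usage sketch against the 2.2 signature shown in this hunk (shapes illustrative, not from the diff):

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    # dtype now controls the precision of the 'weight' and 'bias' Parameters.
    layer = nn.BiDense(in1_channels=20, in2_channels=30, out_channels=40, dtype=ms.float16)
    x1 = Tensor(np.random.randn(128, 20), ms.float16)
    x2 = Tensor(np.random.randn(128, 30), ms.float16)
    out = layer(x1, x2)  # (128, 40), parameters and activations in float16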
@@ -64,11 +64,13 @@ class Embedding(Cell):
         embedding_size (int): The size of each embedding vector.
         use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: ``False`` .
         embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
-            Refer to class `initializer` for the values of string when a string
-            is specified. Default: ``'normal'`` .
+            Refer to class `mindspore.common.initializer
+            <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
+            for the values of string when a string is specified. Default: ``'normal'`` .
         dtype (:class:`mindspore.dtype`): Data type of `x`. Default: ``mstype.float32`` .
         padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
             will be initialized to zero. Default: ``None`` . The feature is inactivated.
+
     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{x_length})`. The elements of
           the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
@@ -145,9 +147,8 @@ class Embedding(Cell):
         return output
 
     def extend_repr(self):
-        s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
-            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
-        return s
+        return f'vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, ' \
+               f'embedding_table={self.embedding_table}, dtype={self.dtype}, padding_idx={self.padding_idx}'
 
 
 @_primexpr
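The rewritten docstring points string values of `embedding_table` at `mindspore.common.initializer`, and `extend_repr` is now a single f-string. A short sketch of what that looks like in practice (values illustrative):

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    # 'normal' is resolved through mindspore.common.initializer, as the new link documents.
    emb = nn.Embedding(vocab_size=2000, embedding_size=128, embedding_table='normal')
    ids = Tensor(np.array([[1, 5, 7]]), ms.int32)
    vectors = emb(ids)  # (1, 3, 128)
    print(emb)          # repr assembled by the f-string extend_repr above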
@@ -190,6 +191,7 @@ class EmbeddingLookup(Cell):
             parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
             optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
             memory, so suggests setting a reasonable value to avoid insufficient memory.
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@@ -231,7 +233,7 @@ class EmbeddingLookup(Cell):
 
     def __init__(self, vocab_size, embedding_size, param_init='normal',
                  target='CPU', slice_mode='batch_slice', manual_shapes=None,
-                 max_norm=None, sparse=True, vocab_cache_size=0):
+                 max_norm=None, sparse=True, vocab_cache_size=0, dtype=mstype.float32):
         """Initialize EmbeddingLookup."""
         super(EmbeddingLookup, self).__init__()
         Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
@@ -255,8 +257,8 @@ class EmbeddingLookup(Cell):
         if enable_ps:
             self._process_vocab_cache(slice_mode)
         self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
-        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
-                                         name='embedding_table')
+        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
+                                                     dtype=dtype), name='embedding_table')
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.gather_revert = P.Gather()
@@ -267,7 +269,7 @@ class EmbeddingLookup(Cell):
             if is_auto_parallel:
                 self.unique = P.Unique().shard(((1,),))
             if self.cache_enable and enable_ps:
-                self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init)
+                self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init, dtype=dtype)
                 if is_auto_parallel:
                     self.unique.add_prim_attr('cache_enable', True)
         indices_shape_size = 2
@@ -310,8 +312,8 @@ class EmbeddingLookup(Cell):
         else:
             if is_auto_parallel:
                 support_mode = ["field_slice", "table_row_slice", "table_column_slice", "batch_slice"]
-                raise ValueError("For '{}', the 'slice_mode' must be in {}, "
-                                 "but got \"{}\".".format(self.cls_name, support_mode, slice_mode))
+                raise ValueError(f"For '{self.cls_name}', the 'slice_mode' must be in {support_mode}, "
+                                 f"but got \"{slice_mode}\".")
             if self.cache_enable and not enable_ps:
                 raise ValueError(f"For '{self.cls_name}', haven't supported cache enable for not ps mode.")
         self.embedding_table.unique = self.forward_unique
@@ -354,7 +356,8 @@ class EmbeddingLookup(Cell):
             if _is_role_worker():
                 self.vocab_size = self.vocab_cache_size
 
-    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init):
+    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init,
+                                       dtype=mstype.float32):
         """PS embeddingLookup cache enable set."""
         if self.sparse:
             self.forward_unique = True
@@ -368,10 +371,10 @@ class EmbeddingLookup(Cell):
         if _enable_distributed_mindrt():
             self.rank_id = get_rank()
             if self.is_ps_server:
-                self._slice_pserver_embeddings("zeros")
+                self._slice_pserver_embeddings("zeros", dtype=dtype)
                 self._set_cache_enable_and_key_for_pserver(param_key)
 
-    def _slice_pserver_embeddings(self, param_init):
+    def _slice_pserver_embeddings(self, param_init, dtype=mstype.float32):
         '''
         Method to slice embedding tables on Parameter Servers.
         It helps to train with a large scale embedding table and is used only in Parameter Server training mode.
@@ -399,7 +402,7 @@ class EmbeddingLookup(Cell):
         for i in range(server_num):
             self.embedding_table_list.append(Parameter(initializer(param_init,
                                                                    [self.embedding_table_vocab_dim_list[i],
-                                                                    self.embedding_size]),
+                                                                    self.embedding_size], dtype=dtype),
                                                        name="embedding_table_server_" + str(i)))
 
         self.embedding_offset.append(offset)
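Taken together, these hunks thread one `dtype` argument from the public constructor through `_set_voacb_cache_enable_for_ps` and `_slice_pserver_embeddings`, so the main table and every per-server slice are created in the same precision. A hedged sketch of the plain (non-parameter-server) path, with illustrative sizes:

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    # dtype now reaches the embedding_table Parameter (and, in parameter-server
    # mode, each embedding_table_server_* slice as well).
    lookup = nn.EmbeddingLookup(vocab_size=2000, embedding_size=128,
                                target='CPU', sparse=True, dtype=ms.float16)
    indices = Tensor(np.array([[1, 3], [4, 6]]), ms.int32)
    out = lookup(indices)  # (2, 2, 128)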
@@ -505,12 +508,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
             :class:`mindspore.nn.EmbeddingLookup`. Default: ``'batch_slice'``.
         feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
             Default: ``None`` .
-        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
-            or None. Default: ``None`` .
+        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32.
+            Default: ``None`` .
         sparse (bool): Using sparse mode. When 'target' is set to ``'CPU'`` , 'sparse' has to be true.
             Default: ``True`` .
         operator (str): The pooling method for the features in one field. Support ``'SUM'`` , ``'MEAN'`` and
             ``'MAX'`` . Default: ``'SUM'`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
@@ -529,12 +533,12 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
         TypeError: If `vocab_size` or `embedding_size` or `field_size` is not an int.
         TypeError: If `sparse` is not a bool or `feature_num_list` is not a tuple.
         ValueError: If `vocab_size` or `embedding_size` or `field_size` is less than 1.
-        ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
-        ValueError: If `slice_mode` is not one of 'batch_slice', 'field_slice', 'table_row_slice',
-            'table_column_slice'.
-        ValueError: If `sparse` is False and `target` is 'CPU'.
-        ValueError: If `slice_mode` is 'field_slice' and `feature_num_list` is None.
-        ValueError: If `operator` is not one of 'SUM', 'MAX', 'MEAN'.
+        ValueError: If `target` is neither ``'CPU'`` nor ``'DEVICE'``.
+        ValueError: If `slice_mode` is not one of ``'batch_slice'``, ``'field_slice'``, ``'table_row_slice'``,
+            ``'table_column_slice'`` .
+        ValueError: If `sparse` is False and `target` is ``'CPU'`` .
+        ValueError: If `slice_mode` is ``'field_slice'`` and `feature_num_list` is None.
+        ValueError: If `operator` is not one of ``'SUM'``, ``'MAX'``, ``'MEAN'`` .
 
     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -555,10 +559,11 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
     OPERATOR_MAX = 'MAX'
 
     def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
-                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
+                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM',
+                 dtype=mstype.float32):
         """Initialize MultiFieldEmbeddingLookup."""
         super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
-                                                        slice_mode, feature_num_list, max_norm, sparse)
+                                                        slice_mode, feature_num_list, max_norm, sparse, dtype=dtype)
         self.field_size = Validator.check_positive_int(field_size, 'field_size', self.cls_name)
         self.operator = operator
 
@@ -622,8 +627,9 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
                 self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
         else:
             if is_auto_parallel:
-                raise ValueError("For '{}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice' and \
-                    'table_column_slice'], but got {}".format(self.cls_name, str(slice_mode)))
+                raise ValueError(
+                    f"For '{self.cls_name}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice' "
+                    f"and 'table_column_slice'], but got {str(slice_mode)}.")
 
         # Min value for fp32
         self.negative_inf_value = -3.402823466E+38
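As with the plain lookup, the new `dtype` is simply forwarded through `super().__init__` to the shared table. A construction sketch (sizes illustrative):

    import mindspore as ms
    from mindspore import nn

    # dtype flows through MultiFieldEmbeddingLookup.__init__ into EmbeddingLookup.
    mf = nn.MultiFieldEmbeddingLookup(vocab_size=32, embedding_size=8, field_size=2,
                                      operator='SUM', dtype=ms.float16)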
@@ -17,12 +17,11 @@ A FlashAttention Layer.
 """
 import math
 
-import mindspore.numpy as mnp
-from mindspore import ops
-from mindspore.common import dtype as mstype
+import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
+from mindspore import ops
 from mindspore.nn.cell import Cell
-from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention
+from mindspore.ops.operations.nn_ops import FlashAttentionScore
 
 __all__ = ['FlashAttention']
 
@@ -45,25 +44,25 @@ class FlashAttention(Cell):
             Default 65536.
         next_block_num(int): A integer to define the number of blocks to look behind for local block sparse attention.
             Default 65536.
-        tiling_stgy_name(str): A str to define tiling strategy of flash attention.
         dp(int): data parallel.
             Default 1.
         mp(int): model parallel.
             Default 1.
-        high_precision(bool): This mode has higher precision but some performance loss.
+        high_precision(bool): This mode has higher precision but some performance loss. Only take effect on Ascend910A.
             Default False.
         have_attention_mask_batch(bool): indicates whether attention_mask contains the batch dimension.
             Default True
         alibi(bool): This parameter indicates whether the flashattention supports the Alibi.
             Default: False
+        use_mqa(bool): Using MQA if True, only take effect under 910B. Default: False.
 
 
     Inputs:
         - **query** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
         - **key** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
         - **value** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
-        - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` [batch_size, seq_length,
-          seq_length]): A matrix to pass masked information.
+        - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` `mstype.uint8`
+          [batch_size, seq_length, seq_length]): A matrix to pass masked information.
 
     Outputs:
         A Tensor. The output of the attention with shape [batch_size, head_num, seq_length, head_dim]
@@ -92,35 +91,55 @@ class FlashAttention(Cell):
 
     def __init__(self,
                  head_dim,
+                 head_num,
                  dropout_rate=0.0,
                  prev_block_num=65536,
                  next_block_num=65536,
-                 tiling_stgy_name="sparse",
                  dp=1,
                  mp=1,
                  high_precision=False,
                  have_attention_mask_batch=True,
-                 alibi=False
+                 alibi=False,
+                 use_mqa=False
                  ):
         super(FlashAttention, self).__init__()
 
-        self.flash_attention = get_flash_attention(
-            prev_block_num=prev_block_num,
-            next_block_num=next_block_num,
-            tiling_stgy_name=tiling_stgy_name,
-            high_precision=high_precision
-        )
-        self.flash_attention.add_prim_attr("primitive_target", "Ascend")
         scaling_constant = math.sqrt(head_dim)
-        if scaling_constant != 0:
-            self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
-        else:
+        if scaling_constant == 0:
             raise ValueError("the scaling constant must not be 0.")
-        self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
-        self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
         self.dropout_rate = dropout_rate
-        self.have_attention_mask_batch = have_attention_mask_batch
         self.alibi = alibi
+        self.have_attention_mask_batch = have_attention_mask_batch
+
+        self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
+        self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
+        self.reshape = ops.Reshape()
+        self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
+        self.zeros = ops.Zeros()
+        self.attn_cast = ops.Cast()
+        if use_mqa:
+            fa_strategies = ((dp, mp, 1, 1),
+                             (dp, 1, 1, 1),
+                             (dp, 1, 1, 1))
+        else:
+            fa_strategies = ((dp, mp, 1, 1),
+                             (dp, mp, 1, 1),
+                             (dp, mp, 1, 1))
+        if self.alibi:
+            self.alibi_rescale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
+            self.alibi_rescale_factor = Tensor([scaling_constant], dtype=mstype.float16)
+            fa_strategies += ((dp, mp, 1, 1),)
+        if dropout_rate > 1e-5:
+            fa_strategies += ((dp, mp, 1, 1),)
+        fa_strategies += ((dp, 1, 1, 1),)
+        self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
+                                                   next_tokens=next_block_num,
+                                                   keep_prob=1 - dropout_rate,
+                                                   scale_value=1. / scaling_constant,
+                                                   inner_precise=0,
+                                                   input_layout="BNSD").shard(fa_strategies)
+
+        self.dropout_rate = dropout_rate
         if self.dropout_rate > 1e-5:
             self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
             self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))
@@ -136,49 +155,7 @@ class FlashAttention(Cell):
                 such as MatMul. Default: None.
         :return:
         """
-        if in_strategy is not None:
-            shard_stgy = list(in_strategy)
-            shard_stgy.insert(3, (1,))  # dim_mask
-            shard_stgy = tuple(shard_stgy)
-        else:
-            # default: dp=1, mp=1, construct inputs only contain query, key, value
-            shard_stgy = (
-                (1, 1, 1, 1),
-                (1, 1, 1, 1),
-                (1, 1, 1, 1),
-                (1,),  # dim_mask
-            )
-        self.flash_attention.shard(shard_stgy)
-        dp = shard_stgy[0][0]
-        mp = shard_stgy[0][1]
-        self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
-        inputs_tensor_map = [
-            [3, 2, 1, 0],
-            [3, 2, 1, 0],
-            [3, 2, 1, 0],
-            [-1]
-        ]
-        if self.have_attention_mask_batch:
-            inputs_tensor_map.append([3, 1, 0])
-        else:
-            inputs_tensor_map.append([-1, 1, 0])
-
-        # dropout_mask
-        if self.dropout_rate > 1e-5:
-            inputs_tensor_map.append([3, 2, 1, 0])
-
-        if self.alibi:
-            inputs_tensor_map.append([3, 2, 1, 0])
-
-        self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
-
-        self.flash_attention.add_prim_attr("outputs_tensor_map", [
-            [3, 2, 1, 0],  # O
-            [3, 2, 1],  # L
-            [3, 2, 1]  # M
-        ])
-        self.flash_attention.add_prim_attr("as_loss_divisor", 0)
-        self.flash_attention.add_prim_attr("empty_mirror_ops", 1)
+        self.flash_attention.shard(in_strategy)
 
     def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
         """FlashAttention forward
@@ -189,35 +166,24 @@ class FlashAttention(Cell):
189
166
  :param alibi_mask: [bsz, head_num, 1, seq_len], if not None
190
167
  :return: output [bsz, head_num, seq_len, head_dim]
191
168
  """
192
- query = self.scale_mul(query, self.scale_factor)
193
- bsz, head_num, seq_len, head_dim = query.shape
194
- _, k_head_num, k_seq_len, _ = key.shape
195
- _, v_head_num, v_seq_len, _ = value.shape
196
- if head_num != k_head_num or head_num != v_head_num:
197
- raise ValueError(
198
- "the head_num of query, key and value must be the same, "
199
- "If different head_num are used, users need to change themselves to be same by tile.")
200
- if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
201
- raise ValueError(
202
- "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
169
+ bsz, head_num, seq_len, _ = query.shape
170
+ # 910B -- FlashAttentionScore
203
171
  if self.dropout_rate > 1e-5:
204
- drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
205
- tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
206
- ones = self.fill_v2(tensor_shape, self.tensor_one)
207
- ones = self.depend(ones, query)
208
- drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
209
- else:
210
- drop_mask = None
211
- if head_dim > 304:
212
- raise ValueError(
213
- "the head_dim must be less than 304, otherwise the ub would be OOM.")
214
- if head_dim % 16 != 0:
215
- padding_size = 16 - head_dim % 16
216
- query = mnp.pad(query, ((0, 0), (0, 0), (0, 0), (0, padding_size)), constant_values=0)
217
- key = mnp.pad(key, ((0, 0), (0, 0), (0, 0), (0, padding_size)), constant_values=0)
218
- value = mnp.pad(value, ((0, 0), (0, 0), (0, 0), (0, padding_size)), constant_values=0)
219
- output, _, _ = self.flash_attention(query, key, value, self.dim_mask, attn_mask, drop_mask, alibi_mask)
220
- output = ops.slice(output, [0, 0, 0, 0], [bsz, head_num, seq_len, head_dim])
172
+ drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
173
+ (bsz, head_num, seq_len, seq_len // 8))
221
174
  else:
222
- output, _, _ = self.flash_attention(query, key, value, self.dim_mask, attn_mask, drop_mask, alibi_mask)
175
+ drop_mask_bits = None
176
+ if self.alibi:
177
+ alibi_mask = self.alibi_rescale_mul(alibi_mask, self.cast(self.alibi_rescale_factor, alibi_mask.dtype))
178
+ # (B, S, S) -> (B, 1, S, S)
179
+ if self.have_attention_mask_batch:
180
+ attn_mask = self.cast(self.reshape(attn_mask, (bsz, 1, seq_len, seq_len)), mstype.uint8)
181
+ _, _, _, output = self.flash_attention(query,
182
+ key,
183
+ value,
184
+ alibi_mask,
185
+ drop_mask_bits,
186
+ None,
187
+ attn_mask,
188
+ None)
223
189
  return output
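The new `drop_mask_bits` reshape only works out if the generated mask is bit-packed: a `(bsz, head_num, seq_len, seq_len)` Bernoulli mask occupies one bit per attention score, i.e. `seq_len // 8` uint8 bytes along the last axis. A quick arithmetic check (the sizes are illustrative, and the 8-bits-per-byte packing of `DropoutGenMask` is stated here as an assumption rather than spelled out in the diff):

```python
# Illustrative sizes; any seq_len divisible by 8 behaves the same way.
bsz, head_num, seq_len = 2, 16, 1024
total_bits = bsz * head_num * seq_len * seq_len   # one mask bit per attention score
total_bytes = total_bits // 8                     # 8 bits packed per uint8 byte
assert total_bytes == bsz * head_num * seq_len * (seq_len // 8)
# which is exactly the reshape target (bsz, head_num, seq_len, seq_len // 8)
```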
@@ -83,17 +83,17 @@ class ImageGradients(Cell):
  _check_input_4d(F.shape(images), "images", self.cls_name)
  batch_size, depth, height, width = P.Shape()(images)
  if height == 1:
-     dy = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
+     dy = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
  else:
      dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
-     dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
+     dy_last = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
      dy = P.Concat(2)((dy, dy_last))

  if width == 1:
-     dx = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
+     dx = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
  else:
      dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
-     dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
+     dx_last = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
      dx = P.Concat(3)((dx, dx_last))
  return dy, dx
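This hunk only swaps the `P.Fill()` primitive for the functional `F.fill`; the gradient math is unchanged: forward differences along H and W, with a zero row/column appended so each gradient keeps the input's NCHW shape. A NumPy sketch of the same computation:

```python
import numpy as np

# 1x1x4x4 ramp image; dy/dx are forward differences, zero-padded at the end
# so each gradient has the same shape as the input.
images = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
dy = np.zeros_like(images)
dy[:, :, :-1, :] = images[:, :, 1:, :] - images[:, :, :-1, :]   # all 4s here
dx = np.zeros_like(images)
dx[:, :, :, :-1] = images[:, :, :, 1:] - images[:, :, :, :-1]   # all 1s here
```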
 
@@ -571,7 +571,8 @@ class PixelShuffle(Cell):
  <https://arxiv.org/abs/1609.05158>`_ .

  Typically, the input is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
- :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor and * is zero or more batch dimensions.
+ :math:`(*, C, H \times r, W \times r)`,
+ where :math:`r` is an upscale factor and :math:`*` is zero or more batch dimensions.

  Note:
      The dimension of input Tensor on Ascend should be less than 7.
@@ -621,7 +622,8 @@ class PixelUnshuffle(Cell):
  <https://arxiv.org/abs/1609.05158>`_ .

  Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape
- :math:`(*, C \times r^2, H, W)` , where r is a downscale factor and * is zero or more batch dimensions.
+ :math:`(*, C \times r^2, H, W)` ,
+ where :math:`r` is a downscale factor and :math:`*` is zero or more batch dimensions.

  Args:
      downscale_factor (int): factor to unshuffle the input, and is a positive integer.
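Both docstring hunks only wrap `r` and `*` in math roles, but the shape contract they describe is easy to sanity-check. A small sketch covering both cells (the constructor keywords follow the documented `upscale_factor`/`downscale_factor` parameters):

```python
import numpy as np
import mindspore as ms
from mindspore import nn

x = ms.Tensor(np.zeros((1, 9, 4, 4), np.float32))
up = nn.PixelShuffle(upscale_factor=3)(x)         # (1, 1, 12, 12): C/r^2, H*r, W*r
down = nn.PixelUnshuffle(downscale_factor=3)(up)  # back to (1, 9, 4, 4)
```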
@@ -223,7 +223,6 @@ class LGamma(Cell):
      self.abs = P.Abs()
      self.shape = P.Shape()
      self.dtype = P.DType()
-     self.fill = P.Fill()
      self.floor = P.Floor()
      self.equal = P.Equal()
      self.greater = P.Greater()
@@ -240,7 +239,7 @@ class LGamma(Cell):
      if F.is_sequence_value_unknown(self.shape(x)):
          infinity = self.ones_like(x) * F.cast(self.inf, input_dtype)
      else:
-         infinity = self.fill(input_dtype, self.shape(x), self.inf)
+         infinity = F.fill(input_dtype, self.shape(x), self.inf)

      need_to_reflect = self.less(x, 0.5)
      neg_input = -x
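This is the migration applied throughout the rest of the file: the stateful `P.Fill()` primitive held as an attribute gives way to the functional `mindspore.ops.fill`, called with the same `(dtype, shape, value)` arguments. A minimal sketch of the two styles side by side (the import aliases are assumed to match the file's own `P`/`F` convention):

```python
from mindspore import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P

# 2.1 style: instantiate the primitive, keep it around, call it later.
fill_op = P.Fill()
ones_old = fill_op(mstype.float32, (2, 3), 1)

# 2.2 style: call the functional form directly; no operator object to carry.
ones_new = F.fill(mstype.float32, (2, 3), 1)
```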
@@ -335,7 +334,6 @@ class DiGamma(Cell):
      self.abs = P.Abs()
      self.shape = P.Shape()
      self.dtype = P.DType()
-     self.fill = P.Fill()
      self.floor = P.Floor()
      self.equal = P.Equal()
      self.less = P.Less()
@@ -371,15 +369,12 @@ class DiGamma(Cell):
      reduced_input = x + self.abs(self.floor(x + 0.5))
      reflection = y - self.pi * self.cos(self.pi * reduced_input) / self.sin(self.pi * reduced_input)
      real_result = self.select(need_to_reflect, reflection, y)
-     nan = self.fill(self.dtype(x), self.shape(x), np.nan)
+     nan = F.fill(self.dtype(x), self.shape(x), np.nan)

      return self.select(self.logicaland(self.less(x, 0), self.equal(x, self.floor(x))),
                         nan, real_result)


- eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
-
-
  def _while_helper_func(cond, body, vals):
      while cond(vals).any():
          vals = body(vals)
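`_while_helper_func` is the element-wise loop driver used by both incomplete-gamma kernels below: `cond` returns a boolean tensor of lanes that still need iterating, and `body` advances only those lanes while freezing the rest. A plain-NumPy sketch of the same pattern (the names and payload are mine, not from the source):

```python
import numpy as np

def cond(vals):
    enabled, _ = vals
    return enabled                          # lanes still iterating

def body(vals):
    enabled, acc = vals
    acc = np.where(enabled, acc + 1, acc)   # advance only enabled lanes
    return (enabled & (acc < 5), acc)

vals = (np.array([True, True]), np.array([0, 3]))
while cond(vals).any():                     # same loop shape as _while_helper_func
    vals = body(vals)
print(vals[1])                              # [5 5]
```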
@@ -391,13 +386,12 @@ def _igamma_series(ax, x, a, enabled):

  logicaland = P.LogicalAnd()
  greater = P.Greater()
- fill = P.Fill()
  shape = P.Shape()
  dtype = P.DType()
  select = P.Select()

  # If more data types are supported, this epsilon need to be selected.
- epsilon = eps_fp32
+ epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)

  def cond(vals):
      enabled = vals[0]
@@ -424,8 +418,8 @@ def _igamma_series(ax, x, a, enabled):
      select(enabled, x, vals[4]), select(enabled, dc_da, vals[5]),
      select(enabled, dans_da, vals[6]))

- ones = fill(dtype(a), shape(a), 1)
- zeros = fill(dtype(a), shape(a), 0)
+ ones = F.fill(dtype(a), shape(a), 1)
+ zeros = F.fill(dtype(a), shape(a), 0)
  vals = (enabled, a, ones, ones, x, zeros, zeros)

  vals = _while_helper_func(cond, body, vals)
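For orientation, the loop this kernel vectorizes is the classic power series for the regularized lower incomplete gamma, `P(a, x) = ax * sum_k x^k / (a*(a+1)*...*(a+k))`, where the caller supplies `ax = exp(a*log(x) - x - lgamma(a))` and the float32 epsilon above is the stopping tolerance. A scalar reference sketch of the same recurrence (my own code, not the kernel):

```python
import math

def igamma_series(a, x, eps=2.0 ** -23):   # eps == np.finfo(np.float32).eps
    ax = math.exp(a * math.log(x) - x - math.lgamma(a))
    r, c, ans = a, 1.0, 1.0
    while c / ans > eps:                   # same relative tolerance as the kernel
        r += 1.0
        c *= x / r                         # next series term
        ans += c
    return ans * ax / a                    # P(a, x)

print(igamma_series(2.5, 1.3))             # ~0.2386
```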
@@ -441,13 +435,12 @@ def _igammac_continued_fraction(ax, x, a, enabled):
  greater = P.Greater()
  less = P.Less()
  notequal = P.NotEqual()
- fill = P.Fill()
  shape = P.Shape()
  dtype = P.DType()
  select = P.Select()

  # If more data types are supported, this epsilon need to be selected.
- epsilon = eps_fp32
+ epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)

  def cond(vals):
      enabled = vals[0]
@@ -482,7 +475,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
      qk_is_nonzero = notequal(qk, 0)
      r = pk / qk

-     t = select(qk_is_nonzero, abs_x((ans - r) / r), fill(dtype(t), shape(t), 1))
+     t = select(qk_is_nonzero, abs_x((ans - r) / r), F.fill(dtype(t), shape(t), 1))
      ans = select(qk_is_nonzero, r, ans)

      dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c
@@ -490,7 +483,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
      dans_da_new = select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da)
      grad_conditional = select(qk_is_nonzero,
                                abs_x(dans_da_new - dans_da),
-                               fill(dtype(dans_da), shape(dans_da), 1))
+                               F.fill(dtype(dans_da), shape(dans_da), 1))

      pkm2 = pkm1
      pkm1 = pk
@@ -525,16 +518,16 @@ def _igammac_continued_fraction(ax, x, a, enabled):

  y = 1 - a
  z = x + y + 1
- c = fill(dtype(x), shape(x), 0)
- pkm2 = fill(dtype(x), shape(x), 1)
+ c = F.fill(dtype(x), shape(x), 0)
+ pkm2 = F.fill(dtype(x), shape(x), 1)
  qkm2 = x
  pkm1 = x + 1
  qkm1 = z * x
  ans = pkm1 / qkm1
- t = fill(dtype(x), shape(x), 1)
- dpkm2_da = fill(dtype(x), shape(x), 0)
- dqkm2_da = fill(dtype(x), shape(x), 0)
- dpkm1_da = fill(dtype(x), shape(x), 0)
+ t = F.fill(dtype(x), shape(x), 1)
+ dpkm2_da = F.fill(dtype(x), shape(x), 0)
+ dqkm2_da = F.fill(dtype(x), shape(x), 0)
+ dpkm1_da = F.fill(dtype(x), shape(x), 0)
  dqkm1_da = -x
  dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
  vals = (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)
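The initialization above (`y = 1 - a`, `z = x + y + 1`, the pk/qk recurrences) is the Cephes-style continued fraction for the regularized upper incomplete gamma `Q(a, x)`; the extra `d*_da` terms carry the derivative with respect to `a`. A scalar sketch of just the value recurrence (my own reference code; the kernel's gradient bookkeeping and any big-value rescaling are omitted):

```python
import math

def igammac_cf(a, x, eps=2.0 ** -23, max_iter=200):
    ax = math.exp(a * math.log(x) - x - math.lgamma(a))
    y = 1.0 - a
    z = x + y + 1.0
    c = 0.0
    pkm2, qkm2, pkm1, qkm1 = 1.0, x, x + 1.0, z * x
    ans = pkm1 / qkm1
    for _ in range(max_iter):
        c += 1.0
        y += 1.0
        z += 2.0
        yc = y * c
        pk = pkm1 * z - pkm2 * yc
        qk = qkm1 * z - qkm2 * yc
        if qk != 0.0:
            r = pk / qk
            t = abs((ans - r) / r)          # same convergence test as the kernel
            ans = r
        else:
            t = 1.0
        pkm2, pkm1 = pkm1, pk
        qkm2, qkm1 = qkm1, qk
        if t <= eps:
            break
    return ans * ax                         # Q(a, x)

print(1.0 - igammac_cf(2.5, 2.8))           # lower tail via 1 - Q, as IGamma does
```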
@@ -606,7 +599,6 @@ class IGamma(Cell):
      self.exp = P.Exp()
      self.select = P.Select()
      self.zeroslike = P.ZerosLike()
-     self.fill = P.Fill()
      self.shape = P.Shape()
      self.dtype = P.DType()
      self.lgamma = LGamma()
@@ -625,15 +617,14 @@ class IGamma(Cell):
      x = F.broadcast_to(x, para_shape)
      a = F.broadcast_to(a, para_shape)
      x_is_zero = self.equal(x, 0)
-     log_maxfloat = self.log_maxfloat32
-     underflow = self.less(ax, self.neg(log_maxfloat))
+     underflow = self.less(ax, self.neg(self.log_maxfloat32))
      ax = self.exp(ax)
      enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
      output = self.select(use_igammac,
                           1 - _igammac_continued_fraction(ax, x, a, self.logicaland(enabled, use_igammac)),
                           _igamma_series(ax, x, a, self.logicaland(enabled, self.logicalnot(use_igammac))))
      output = self.select(x_is_zero, self.zeroslike(output), output)
-     output = self.select(domain_error, self.fill(self.dtype(a), self.shape(a), np.nan), output)
+     output = self.select(domain_error, F.fill(self.dtype(a), self.shape(a), np.nan), output)
      return output

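Since every hunk in these math cells swaps operator plumbing without touching the numerics, an external cross-check still applies. A sanity-check sketch against SciPy's regularized lower incomplete gamma; the import path and the (a, x) argument order are assumptions based on where this diff places the class, and SciPy is an outside reference, not a project dependency:

```python
import mindspore as ms
from mindspore.nn.layer.math import IGamma   # assumed path for the class above
from scipy.special import gammainc

a, x = 2.5, 1.3
out = IGamma()(ms.Tensor(a, ms.float32), ms.Tensor(x, ms.float32))  # (a, x) order assumed
print(float(out), gammainc(a, x))            # both should be ~0.2386
```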