mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (884) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  201. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  202. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  203. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  204. mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
  205. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  206. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  207. mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
  208. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  209. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  210. mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
  211. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  212. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  213. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  214. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  215. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  216. mindspore/log.py +9 -6
  217. mindspore/mindrecord/filereader.py +33 -4
  218. mindspore/mindrecord/filewriter.py +70 -35
  219. mindspore/mindrecord/mindpage.py +40 -34
  220. mindspore/mindrecord/shardreader.py +1 -1
  221. mindspore/mindrecord/shardsegment.py +1 -1
  222. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  223. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  224. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  225. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  226. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  227. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  228. mindspore/nn/cell.py +463 -169
  229. mindspore/nn/dynamic_lr.py +47 -43
  230. mindspore/nn/layer/activation.py +225 -82
  231. mindspore/nn/layer/basic.py +121 -79
  232. mindspore/nn/layer/channel_shuffle.py +21 -21
  233. mindspore/nn/layer/combined.py +33 -26
  234. mindspore/nn/layer/container.py +277 -22
  235. mindspore/nn/layer/conv.py +441 -304
  236. mindspore/nn/layer/dense.py +19 -13
  237. mindspore/nn/layer/embedding.py +62 -49
  238. mindspore/nn/layer/flash_attention.py +264 -0
  239. mindspore/nn/layer/image.py +50 -39
  240. mindspore/nn/layer/math.py +62 -51
  241. mindspore/nn/layer/normalization.py +219 -167
  242. mindspore/nn/layer/padding.py +58 -70
  243. mindspore/nn/layer/pooling.py +334 -287
  244. mindspore/nn/layer/rnn_cells.py +53 -38
  245. mindspore/nn/layer/rnns.py +59 -56
  246. mindspore/nn/layer/thor_layer.py +52 -44
  247. mindspore/nn/layer/timedistributed.py +6 -4
  248. mindspore/nn/layer/transformer.py +284 -164
  249. mindspore/nn/learning_rate_schedule.py +34 -25
  250. mindspore/nn/loss/__init__.py +3 -2
  251. mindspore/nn/loss/loss.py +554 -311
  252. mindspore/nn/optim/ada_grad.py +12 -9
  253. mindspore/nn/optim/adadelta.py +14 -11
  254. mindspore/nn/optim/adafactor.py +19 -16
  255. mindspore/nn/optim/adam.py +62 -47
  256. mindspore/nn/optim/adamax.py +13 -10
  257. mindspore/nn/optim/adasum.py +12 -8
  258. mindspore/nn/optim/asgd.py +10 -9
  259. mindspore/nn/optim/ftrl.py +20 -17
  260. mindspore/nn/optim/lamb.py +16 -12
  261. mindspore/nn/optim/lars.py +8 -6
  262. mindspore/nn/optim/lazyadam.py +25 -20
  263. mindspore/nn/optim/momentum.py +10 -7
  264. mindspore/nn/optim/optimizer.py +61 -9
  265. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  266. mindspore/nn/optim/rmsprop.py +17 -13
  267. mindspore/nn/optim/rprop.py +30 -17
  268. mindspore/nn/optim/sgd.py +40 -23
  269. mindspore/nn/optim/thor.py +24 -26
  270. mindspore/nn/probability/bijector/bijector.py +11 -11
  271. mindspore/nn/probability/bijector/exp.py +1 -1
  272. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  273. mindspore/nn/probability/bijector/invert.py +1 -1
  274. mindspore/nn/probability/bijector/power_transform.py +29 -29
  275. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  276. mindspore/nn/probability/bijector/softplus.py +5 -5
  277. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  278. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  279. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  280. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  281. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  282. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  283. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  284. mindspore/nn/probability/distribution/beta.py +8 -8
  285. mindspore/nn/probability/distribution/categorical.py +23 -15
  286. mindspore/nn/probability/distribution/cauchy.py +5 -6
  287. mindspore/nn/probability/distribution/distribution.py +3 -3
  288. mindspore/nn/probability/distribution/exponential.py +4 -4
  289. mindspore/nn/probability/distribution/gamma.py +10 -10
  290. mindspore/nn/probability/distribution/geometric.py +8 -8
  291. mindspore/nn/probability/distribution/gumbel.py +8 -9
  292. mindspore/nn/probability/distribution/half_normal.py +5 -5
  293. mindspore/nn/probability/distribution/laplace.py +5 -5
  294. mindspore/nn/probability/distribution/log_normal.py +12 -11
  295. mindspore/nn/probability/distribution/logistic.py +8 -8
  296. mindspore/nn/probability/distribution/normal.py +6 -5
  297. mindspore/nn/probability/distribution/poisson.py +10 -11
  298. mindspore/nn/probability/distribution/student_t.py +8 -9
  299. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  300. mindspore/nn/probability/distribution/uniform.py +11 -11
  301. mindspore/nn/reinforcement/tensor_array.py +2 -2
  302. mindspore/nn/sparse/sparse.py +9 -9
  303. mindspore/nn/wrap/cell_wrapper.py +188 -63
  304. mindspore/nn/wrap/grad_reducer.py +21 -12
  305. mindspore/nn/wrap/loss_scale.py +136 -49
  306. mindspore/numpy/__init__.py +4 -4
  307. mindspore/numpy/array_creations.py +55 -56
  308. mindspore/numpy/array_ops.py +134 -35
  309. mindspore/numpy/logic_ops.py +66 -20
  310. mindspore/numpy/math_ops.py +142 -139
  311. mindspore/numpy/utils_const.py +2 -2
  312. mindspore/offline_debug/convert_async.py +2 -2
  313. mindspore/ops/_grad_experimental/__init__.py +7 -5
  314. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  315. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  316. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  317. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  318. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  319. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  320. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  321. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  322. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  323. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  324. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  325. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  326. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  327. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  328. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  329. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  330. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  331. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  332. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  333. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  334. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  335. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  336. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  337. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  338. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  339. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  340. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  341. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  342. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  343. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  344. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  345. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  346. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  347. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  348. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  349. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  350. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  351. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  352. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  353. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  354. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  355. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  356. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  357. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  358. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  359. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  360. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  361. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  362. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  363. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  364. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  365. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  366. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  367. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  368. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  369. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  370. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  371. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  372. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  373. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  374. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  375. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  376. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  377. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  378. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  379. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  380. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  381. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  382. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  383. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  384. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  385. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  386. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  387. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  388. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  389. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  390. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  391. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  392. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  393. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  394. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  395. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  396. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  397. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  398. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  399. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  400. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  401. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  402. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  403. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  404. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  405. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  406. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  407. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  408. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  409. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  410. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  411. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  412. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  413. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  414. mindspore/ops/_primitive_cache.py +1 -1
  415. mindspore/ops/_tracefunc.py +241 -0
  416. mindspore/ops/_utils/utils.py +10 -2
  417. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  418. mindspore/ops/_vmap/vmap_base.py +5 -4
  419. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  420. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  421. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  422. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  423. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  424. mindspore/ops/arg_dtype_cast.py +54 -0
  425. mindspore/ops/composite/__init__.py +7 -5
  426. mindspore/ops/composite/base.py +78 -34
  427. mindspore/ops/composite/math_ops.py +5 -695
  428. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  429. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  430. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  431. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  432. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  433. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  434. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  436. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  437. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  438. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  440. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  441. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  442. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  443. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  444. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  445. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  446. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  447. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  448. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  449. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  450. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  451. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  452. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  453. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  454. mindspore/ops/deprecated.py +304 -0
  455. mindspore/ops/function/__init__.py +41 -4
  456. mindspore/ops/function/array_func.py +1108 -467
  457. mindspore/ops/function/clip_func.py +94 -27
  458. mindspore/ops/function/debug_func.py +3 -1
  459. mindspore/ops/function/grad/grad_func.py +82 -73
  460. mindspore/ops/function/image_func.py +28 -12
  461. mindspore/ops/function/linalg_func.py +135 -39
  462. mindspore/ops/function/math_func.py +3779 -894
  463. mindspore/ops/function/nn_func.py +1584 -657
  464. mindspore/ops/function/parameter_func.py +13 -3
  465. mindspore/ops/function/random_func.py +247 -153
  466. mindspore/ops/function/sparse_func.py +14 -11
  467. mindspore/ops/function/sparse_unary_func.py +173 -47
  468. mindspore/ops/function/spectral_func.py +8 -4
  469. mindspore/ops/function/vmap_func.py +8 -7
  470. mindspore/ops/functional.py +47 -16
  471. mindspore/ops/op_info_register.py +346 -86
  472. mindspore/ops/operations/__init__.py +38 -22
  473. mindspore/ops/operations/_grad_ops.py +145 -149
  474. mindspore/ops/operations/_inner_ops.py +298 -56
  475. mindspore/ops/operations/_ms_kernel.py +3 -3
  476. mindspore/ops/operations/_quant_ops.py +24 -28
  477. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  478. mindspore/ops/operations/_scalar_ops.py +115 -0
  479. mindspore/ops/operations/_sequence_ops.py +148 -10
  480. mindspore/ops/operations/_tensor_array.py +1 -1
  481. mindspore/ops/operations/_thor_ops.py +2 -2
  482. mindspore/ops/operations/array_ops.py +1239 -561
  483. mindspore/ops/operations/comm_ops.py +166 -90
  484. mindspore/ops/operations/control_ops.py +3 -3
  485. mindspore/ops/operations/custom_ops.py +124 -102
  486. mindspore/ops/operations/debug_ops.py +24 -11
  487. mindspore/ops/operations/image_ops.py +86 -71
  488. mindspore/ops/operations/inner_ops.py +18 -13
  489. mindspore/ops/operations/linalg_ops.py +30 -11
  490. mindspore/ops/operations/math_ops.py +1730 -435
  491. mindspore/ops/operations/nn_ops.py +1953 -943
  492. mindspore/ops/operations/other_ops.py +65 -43
  493. mindspore/ops/operations/random_ops.py +258 -98
  494. mindspore/ops/operations/rl_ops.py +4 -36
  495. mindspore/ops/operations/sparse_ops.py +38 -33
  496. mindspore/ops/operations/spectral_ops.py +8 -4
  497. mindspore/ops/primitive.py +66 -44
  498. mindspore/ops/signature.py +5 -5
  499. mindspore/parallel/_auto_parallel_context.py +80 -19
  500. mindspore/parallel/_cost_model_context.py +42 -0
  501. mindspore/parallel/_offload_context.py +162 -72
  502. mindspore/parallel/_parallel_serialization.py +2 -2
  503. mindspore/parallel/_ps_context.py +16 -4
  504. mindspore/parallel/_recovery_context.py +2 -1
  505. mindspore/parallel/_tensor.py +15 -13
  506. mindspore/parallel/_transformer/layers.py +8 -6
  507. mindspore/parallel/_transformer/loss.py +1 -0
  508. mindspore/parallel/_transformer/moe.py +7 -7
  509. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  510. mindspore/parallel/_transformer/transformer.py +34 -14
  511. mindspore/parallel/_utils.py +36 -14
  512. mindspore/parallel/algo_parameter_config.py +114 -20
  513. mindspore/parallel/checkpoint_transform.py +16 -18
  514. mindspore/parallel/shard.py +16 -13
  515. mindspore/profiler/__init__.py +1 -1
  516. mindspore/profiler/common/struct_type.py +3 -3
  517. mindspore/profiler/common/util.py +3 -2
  518. mindspore/profiler/envprofiling.py +11 -4
  519. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  520. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  521. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  522. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  523. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  524. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  525. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  526. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  527. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  528. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  529. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  530. mindspore/profiler/parser/flops_parser.py +15 -11
  531. mindspore/profiler/parser/framework_parser.py +92 -73
  532. mindspore/profiler/parser/hccl_parser.py +16 -12
  533. mindspore/profiler/parser/integrator.py +22 -11
  534. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  535. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  536. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  537. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  538. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  539. mindspore/profiler/parser/optime_parser.py +1 -1
  540. mindspore/profiler/parser/profiler_info.py +4 -5
  541. mindspore/profiler/parser/step_trace_parser.py +11 -14
  542. mindspore/profiler/profiling.py +678 -377
  543. mindspore/rewrite/api/node.py +211 -54
  544. mindspore/rewrite/api/node_type.py +5 -0
  545. mindspore/rewrite/api/pattern_engine.py +22 -23
  546. mindspore/rewrite/api/scoped_value.py +20 -17
  547. mindspore/rewrite/api/symbol_tree.py +252 -106
  548. mindspore/rewrite/api/tree_node_helper.py +3 -0
  549. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  550. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  551. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  552. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  553. mindspore/rewrite/common/rewrite_elog.py +5 -1
  554. mindspore/rewrite/namer.py +51 -51
  555. mindspore/rewrite/namespace.py +14 -5
  556. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  557. mindspore/rewrite/node/call_function.py +79 -0
  558. mindspore/rewrite/node/cell_container.py +135 -0
  559. mindspore/rewrite/node/control_flow.py +88 -0
  560. mindspore/rewrite/{node.py → node/node.py} +313 -247
  561. mindspore/rewrite/node/node_manager.py +254 -0
  562. mindspore/rewrite/node/node_topological_manager.py +243 -0
  563. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  564. mindspore/rewrite/parsers/assign_parser.py +225 -239
  565. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  566. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  567. mindspore/rewrite/parsers/constant_parser.py +9 -6
  568. mindspore/rewrite/parsers/container_parser.py +9 -7
  569. mindspore/rewrite/parsers/for_parser.py +36 -15
  570. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  571. mindspore/rewrite/parsers/if_parser.py +28 -24
  572. mindspore/rewrite/parsers/module_parser.py +202 -25
  573. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  574. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  575. mindspore/rewrite/parsers/return_parser.py +6 -6
  576. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  577. mindspore/rewrite/sparsify/sparsify.py +4 -1
  578. mindspore/rewrite/sparsify/utils.py +11 -5
  579. mindspore/rewrite/symbol_tree.py +577 -732
  580. mindspore/rewrite/symbol_tree_builder.py +9 -175
  581. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  582. mindspore/run_check/_check_version.py +46 -39
  583. mindspore/run_check/run_check.py +3 -2
  584. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  585. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  586. mindspore/scipy/__init__.py +1 -1
  587. mindspore/scipy/linalg.py +67 -61
  588. mindspore/scipy/ops.py +5 -41
  589. mindspore/scipy/ops_grad.py +3 -2
  590. mindspore/scipy/ops_wrapper.py +5 -5
  591. mindspore/scipy/optimize/line_search.py +8 -8
  592. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  593. mindspore/scipy/optimize/minimize.py +16 -12
  594. mindspore/scipy/utils.py +1 -52
  595. mindspore/scipy/utils_const.py +4 -4
  596. mindspore/train/__init__.py +4 -4
  597. mindspore/train/_utils.py +13 -5
  598. mindspore/train/amp.py +410 -148
  599. mindspore/train/anf_ir_pb2.py +16 -4
  600. mindspore/train/callback/_backup_and_restore.py +8 -11
  601. mindspore/train/callback/_callback.py +80 -3
  602. mindspore/train/callback/_checkpoint.py +82 -51
  603. mindspore/train/callback/_early_stop.py +12 -15
  604. mindspore/train/callback/_history.py +1 -1
  605. mindspore/train/callback/_lambda_callback.py +13 -13
  606. mindspore/train/callback/_landscape.py +21 -17
  607. mindspore/train/callback/_loss_monitor.py +9 -10
  608. mindspore/train/callback/_on_request_exit.py +16 -33
  609. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  610. mindspore/train/callback/_summary_collector.py +44 -30
  611. mindspore/train/callback/_time_monitor.py +62 -12
  612. mindspore/train/data_sink.py +10 -16
  613. mindspore/train/dataset_helper.py +154 -86
  614. mindspore/train/loss_scale_manager.py +14 -9
  615. mindspore/train/metrics/__init__.py +10 -2
  616. mindspore/train/metrics/accuracy.py +1 -1
  617. mindspore/train/metrics/auc.py +1 -1
  618. mindspore/train/metrics/bleu_score.py +2 -2
  619. mindspore/train/metrics/confusion_matrix.py +14 -14
  620. mindspore/train/metrics/cosine_similarity.py +3 -3
  621. mindspore/train/metrics/dice.py +1 -1
  622. mindspore/train/metrics/fbeta.py +1 -1
  623. mindspore/train/metrics/hausdorff_distance.py +8 -6
  624. mindspore/train/metrics/mean_surface_distance.py +5 -4
  625. mindspore/train/metrics/metric.py +49 -17
  626. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  627. mindspore/train/metrics/perplexity.py +1 -1
  628. mindspore/train/metrics/precision.py +2 -2
  629. mindspore/train/metrics/recall.py +2 -3
  630. mindspore/train/metrics/roc.py +7 -7
  631. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  632. mindspore/train/metrics/topk.py +7 -4
  633. mindspore/train/mind_ir_pb2.py +193 -48
  634. mindspore/train/model.py +377 -133
  635. mindspore/train/serialization.py +697 -245
  636. mindspore/train/summary/_summary_adapter.py +5 -2
  637. mindspore/train/summary/_writer_pool.py +4 -3
  638. mindspore/train/summary/summary_record.py +25 -23
  639. mindspore/train/train_thor/convert_utils.py +39 -23
  640. mindspore/train/train_thor/dataset_helper.py +4 -3
  641. mindspore/train/train_thor/model_thor.py +8 -8
  642. mindspore/version.py +1 -1
  643. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  644. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
  645. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  646. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  647. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  648. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  649. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  650. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  651. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  652. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  653. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  654. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  655. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  656. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  657. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  658. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  659. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  660. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  661. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  662. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  663. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  664. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  665. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  666. mindspore/_extends/graph_kernel/expander.py +0 -80
  667. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  668. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  669. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  670. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  671. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  672. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  673. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  674. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  675. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  676. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  677. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  678. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  679. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  680. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  681. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  682. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  683. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  684. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  685. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  686. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  687. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  688. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  689. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  690. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  691. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  692. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  693. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  694. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  695. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  696. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  697. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  698. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  699. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  700. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  701. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  702. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  703. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  704. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  705. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  706. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  707. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  708. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  709. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  710. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  711. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  712. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  713. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  714. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  715. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  716. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  717. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  718. mindspore/dataset/engine/graphdata.py +0 -1586
  719. mindspore/include/api/net.h +0 -142
  720. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  721. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  722. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  723. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  724. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  725. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  726. mindspore/ops/_grad/grad_other_ops.py +0 -89
  727. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  728. mindspore/ops/_grad/grad_sparse.py +0 -323
  729. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  730. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  731. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  732. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  733. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  734. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  735. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  736. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  737. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  738. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  739. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  740. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  741. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  742. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  743. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  744. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  745. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  746. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  747. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  748. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  749. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  751. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  752. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  753. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  754. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  755. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  756. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  758. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  759. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  760. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  761. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  762. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  763. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  764. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  765. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  766. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  767. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  768. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  769. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  770. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  771. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  772. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  774. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  775. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  776. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  777. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  778. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  779. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  780. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  784. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  785. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  786. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  787. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  789. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  790. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  791. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  792. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  793. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  794. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  796. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  797. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  798. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  799. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  800. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  801. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  802. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  803. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  804. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  805. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  806. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  807. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  808. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  809. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  810. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  811. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  812. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  814. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  815. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  816. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  817. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  818. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  819. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  820. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  821. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  822. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  823. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  824. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  825. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  826. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  827. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  828. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  829. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  830. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  831. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  832. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  833. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  834. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  835. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  836. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  837. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  839. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  840. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  841. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  842. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  843. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  844. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  845. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  846. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  847. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  848. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  849. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  850. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  851. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  852. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  853. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  854. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  855. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  856. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  857. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  858. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  859. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  860. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  861. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  862. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  863. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  864. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  865. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  866. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  867. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  868. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  869. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  870. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  871. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  872. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  873. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  874. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  875. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  876. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  877. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  878. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  879. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  880. mindspore/rewrite/node_visitor.py +0 -44
  881. mindspore/rewrite/topological_manager.py +0 -203
  882. mindspore/scipy/sparse/linalg.py +0 -192
  883. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  884. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -13,699 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """math Operations."""
16
- from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
17
- from mindspore.common import dtype as mstype
18
- from mindspore import _checkparam as validator
19
- from mindspore.ops.primitive import constexpr, _primexpr
16
+ import mindspore.ops as ops
20
17
  from mindspore.ops import functional as F
21
18
  from mindspore.ops.function.math_func import cummin as cummin_
22
- from mindspore.ops import operations as P
23
-
24
-
25
- @_primexpr
26
- def _check_validate_axis(axis, name):
27
- def _check(axis):
28
- if isinstance(axis, (tuple, list)):
29
- for idx, item in enumerate(axis):
30
- validator.check_value_type("axis[%d]" % idx, item, [int], name)
31
- _check(axis)
32
- axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
33
- return axis
34
-
35
-
36
- @constexpr
37
- def _check_validate_keepdims(keep_dims, name):
38
- keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
39
- return keep_dims
40
-
41
-
42
- @constexpr
43
- def is_const(x):
44
- return x is not None
45
-
46
-
47
- def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
48
- r"""
49
- Count number of nonzero elements across axis of input tensor.
50
-
51
- Args:
52
- x (Tensor): Input data is used to count non-zero numbers. With shape
53
- :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
54
- axis (Union[int, tuple(int), list(int)], optional): The dimensions to reduce.
55
- Default: (), reduce all dimensions.
56
- keep_dims (bool, optional): Whether to maintain dimensions specified by `axis`.
57
- If true, keep these reduced dimensions and the length is 1.
58
- If false, don't keep these dimensions. Default: False.
59
- dtype (Union[Number, mindspore.bool\_], optional): The data type of the output tensor.
60
- Default: mindspore.int32.
61
-
62
- Returns:
63
- Tensor, number of nonzero element across axis specified by `axis`.
64
- The data type is specified by `dtype`.
65
-
66
- Raises:
67
- TypeError: If `axis` is not int, tuple or list.
68
- ValueError: If any value in `axis` is not in range [-x.ndim, x.ndim).
69
-
70
- Supported Platforms:
71
- ``Ascend`` ``GPU`` ``CPU``
72
-
73
- Examples:
74
- >>> from mindspore import Tensor, ops
75
- >>> import numpy as np
76
- >>> # case 1: each value specified.
77
- >>> x = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
78
- >>> nonzero_num = ops.count_nonzero(x=x, axis=[0, 1], keep_dims=True, dtype=mindspore.int32)
79
- >>> print(nonzero_num)
80
- [[3]]
81
- >>> # case 2: all value is default.
82
- >>> nonzero_num = ops.count_nonzero(x=x)
83
- >>> print(nonzero_num)
84
- 3
85
- >>> # case 3: axis value was specified 0.
86
- >>> nonzero_num = ops.count_nonzero(x=x, axis=[0,])
87
- >>> print(nonzero_num)
88
- [1 2 0]
89
- >>> # case 4: axis value was specified 1.
90
- >>> nonzero_num = ops.count_nonzero(x=x, axis=[1,])
91
- >>> print(nonzero_num)
92
- [1 2]
93
- >>> # case 5: keep_dims value was specified.
94
- >>> nonzero_num = ops.count_nonzero(x=x, keep_dims=True)
95
- >>> print(nonzero_num)
96
- [[3]]
97
- >>> # case 6: keep_dims and axis value was specified.
98
- >>> nonzero_num = ops.count_nonzero(x=x, axis=[0,], keep_dims=True)
99
- >>> print(nonzero_num)
100
- [[1 2 0]]
101
- """
102
-
103
- const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
104
- axis = _check_validate_axis(axis, "count_nonzero")
105
- keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero")
106
- const_utils.check_type_valid(dtype, mstype.number_type + (mstype.bool_,), 'dtype')
107
-
108
- not_equal = P.NotEqual()
109
- cast = P.Cast()
110
- reduce_sum = P.ReduceSum(keep_dims)
111
- zeros = P.Zeros()
112
- tensor_0 = zeros(x.shape, x.dtype)
113
- nonzero_bool = not_equal(x, tensor_0)
114
- # ReduceSum only support float16 or float32 tensor.
115
- nonzero_val = cast(nonzero_bool, mstype.float32)
116
- nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype)
117
-
118
- return nonzero_num
119
-
120
-
121
- @_primexpr
122
- def _int_to_tuple_conv(axes):
123
- """
124
- Converts ints to tuples in input axes, expected by most validation checks.
125
- """
126
- for x in [0, 1]:
127
- if isinstance(axes[x], int):
128
- axes[x] = (axes[x],)
129
- return axes
130
-
131
-
132
- @_primexpr
133
- def _check_axes(axes, prim_name=None):
134
- """
135
- Check for validity and type of axes passed to function.
136
- """
137
- msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
138
- validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot")
139
- if not isinstance(axes, int):
140
- axes = list(axes) # to avoid immutability issues
141
- if len(axes) != 2:
142
- raise ValueError(f"{msg_prefix} dimension of 'axes' must be 2, but got 'axes': {axes}.")
143
- axes = _int_to_tuple_conv(axes) # convert before length checks
144
- if len(axes[0]) != len(axes[1]):
145
- raise ValueError(f"{msg_prefix} first and second dim of 'axes' have to be the same size/length, "
146
- f"but got 'axes': {axes}.")
147
- if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])):
148
- raise ValueError(f"{msg_prefix} 'axes' cannot have duplicating values, but got {axes}.")
149
- return axes
150
-
151
-
152
- @constexpr
153
- def _typecheck_input(x1_type, x2_type, prim_name=None):
154
- """
155
- Check input tensor types to be valid and confirm they are the same type.
156
- """
157
- msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
158
- const_utils.check_type_valid(x1_type, [mstype.float32, mstype.float16], 'x1')
159
- const_utils.check_type_valid(x2_type, [mstype.float32, mstype.float16], 'x2')
160
- if x1_type != x2_type:
161
- raise TypeError(f"{msg_prefix} inputs must be the same type, but got x1_type: {x1_type} "
162
- f"and x2_type: {x2_type}.")
163
-
164
-
165
- @_primexpr
166
- def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
167
- """
168
- Convert from single int axes to 2d tuple if required
169
- """
170
- msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
171
-
172
- def _check_lt_zero(axes):
173
- if axes < 0:
174
- raise ValueError(f"{msg_prefix} 'axes' must be at least 0, but got {axes}.")
175
-
176
- def _check_len(axes, x1_shape, x2_shape):
177
- if axes > len(x1_shape) or axes > len(x2_shape):
178
- raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
179
- f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
180
-
181
-
182
- if isinstance(axes, int):
183
- _check_lt_zero(axes)
184
- if axes == 0:
185
- # outer product, no input validation required
186
- return [], []
187
- _check_len(axes, x1_shape, x2_shape)
188
- x1_ind = tuple(range(len(x1_shape))[-1 * axes:])
189
- x2_ind = tuple(range(len(x2_shape))[:axes])
190
- axes = tuple((x1_ind, x2_ind))
191
- axes = _int_to_tuple_conv(axes)
192
- return axes
193
-
194
-
195
- @_primexpr
196
- def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
197
- """
198
- Checks for axes having the correct length according to input, for any value in axis
199
- being out of range with given shape and also checking for compatible axes values
200
- with given inputs.
201
- """
202
- msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
203
-
204
- def _check_len(axes_len, shape_dim_len, x_axes):
205
- if axes_len > shape_dim_len:
206
- raise ValueError(f"{msg_prefix} length of element {x_axes} in 'axes' must be less than or equal to "
207
- f"{shape_dim_len}, but got {axes_len}.")
208
-
209
- def _check_value(x_axes, min_val, max_val):
210
- for _, x_value in enumerate(x_axes):
211
- if x_value > max_val or x_value < min_val:
212
- raise ValueError(f"{msg_prefix} value in 'axes' must be in range: [{min_val}, {max_val}], "
213
- f"but got {x_value}.")
214
-
215
- shapes = [x1_shape, x2_shape]
216
-
217
- # axis length check
218
- for ix_input, x_axes in enumerate(axes):
219
- axes_len = len(x_axes)
220
- shape_dim_len = len(shapes[ix_input])
221
- _check_len(axes_len, shape_dim_len, x_axes)
222
-
223
- # axis values range check
224
- for ix_input, x_axes in enumerate(axes):
225
- comp_shape = shapes[ix_input]
226
- max_val = len(comp_shape) - 1
227
- min_val = -1 * len(comp_shape)
228
- _check_value(x_axes, min_val, max_val)
229
-
230
- # check axis value with input shape - both ways for axis valid
231
- invalid_a = False
232
- invalid_b = False
233
- for i in range(len(axes[0])): # sizes already validated
234
- if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]:
235
- invalid_a = True
236
- if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0]) - 1 - i]]:
237
- invalid_b = True
238
-
239
- def _check(invalid_a, invalid_b, x1_shape, x2_shape, axes):
240
- if invalid_a and invalid_b:
241
- raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
242
- f"'x2_shape[axes[1][i]]' or 'x2_shape[axes[1][len(axes[0])-1-i]]', but got "
243
- f"'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}, 'axes': {axes}.")
244
-
245
- _check(invalid_a, invalid_b, x1_shape, x2_shape, axes)
246
-
247
-
248
- @_primexpr
249
- def _calc_new_shape(shape, axes, position=0):
250
- """
251
- Calculate transpose and reshape parameters for input transformations,
252
- 'position' refers to whether tensor is first or second in the op.
253
- """
254
- contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
255
- prod_contraction = 1
256
- for i in contraction_axes:
257
- prod_contraction *= shape[i]
258
- free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
259
- free_dims = tuple(shape[i] if shape[i] is not None else -1 for i in free_axes)
260
- prod_free = 1
261
- for free_dim in free_dims:
262
- prod_free *= free_dim
263
-
264
- transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
265
- new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
266
- return new_shape, transpose_perm, free_dims
267
-
268
-
269
def tensor_dot(x1, x2, axes):
    """
    Computation of Tensor contraction on arbitrary axes between tensors `x1` and `x2`.

    Contraction allows for the summation of products of elements of `x1` and `x2` on specified axes.
    The same number of axes must be specified for both `x1` and `x2`, and values must be within range
    of number of dims of both `x1` and `x2`.

    Selected dims in both inputs must also match.

    axes = 0 leads to outer product.
    axes = 1 leads to normal matrix multiplication when inputs both 2D.
    axes = 1 is the same as axes = ((1,),(0,)) where both `x1` and `x2` are 2D.
    axes = 2 is the same as axes = ((1,2),(0,1)) where both `x1` and `x2` are 3D.

    Args:
        x1 (Tensor): First tensor in tensor_dot with datatype float16 or float32
        x2 (Tensor): Second tensor in tensor_dot with datatype float16 or float32
        axes (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]): Single value or
            tuple/list of length 2 with dimensions specified for `x1` and `x2` each. If single value `N` passed,
            automatically picks up last N dims from `x1` input shape and first N dims from `x2` input shape in order
            as axes for each respectively.

    Returns:
        Tensor, whose shape is the free (non-contracted) dims of `x1` followed by the free
        (non-contracted) dims of `x2`.

    Raises:
        TypeError: If `x1` or `x2` is not a Tensor.
        TypeError: If `axes` is not one of the following: int, tuple, list.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import Tensor, ops
        >>> import mindspore
        >>> import numpy as np
        >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
        >>> output = ops.tensor_dot(input_x1, input_x2, ((0,1),(1,2)))
        >>> print(output)
        [[2. 2. 2.]
         [2. 2. 2.]
         [2. 2. 2.]]
    """
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    transpose_op = P.Transpose()
    matmul_op = P.MatMul(False, False)
    # input validity checks
    x1_shape = shape_op(x1)
    x2_shape = shape_op(x2)
    axes = _check_axes(axes, 'tensor_dot')
    # input compatibility check & axes format update
    axes = _axes_int_check(x1_shape, x2_shape, axes, 'tensor_dot')
    _validate_axes(x1_shape, x2_shape, axes, 'tensor_dot')
    # _calc_new_shape puts the contracted axes last for x1 (position 0) and first for x2
    # (position 1), and flattens each operand to 2-D, so one MatMul does the contraction.
    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1)
    output_shape = x1_ret + x2_ret  # combine free axes from both inputs
    # run tensor_dot op
    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd)
    x2_reshaped = reshape_op(x2_transposed, x2_reshape_fwd)
    mul_result = matmul_op(x1_reshaped, x2_reshaped)
    # Un-flatten the 2-D product back into the free dims of both inputs.
    final_result = reshape_op(mul_result, output_shape)
    return final_result
337
-
338
-
339
@_primexpr
def _check_invalid_input(x1_shape, x2_shape, prim_name=None):
    """
    Check that both input shapes have rank >= 2.

    Args:
        x1_shape (tuple): Shape of the first input.
        x2_shape (tuple): Shape of the second input.
        prim_name (str): Operator name used to prefix the error message. Default: None.

    Raises:
        ValueError: If either shape has fewer than 2 dimensions.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(x1_shape) < 2 or len(x2_shape) < 2:
        # The trailing space keeps "...'dimension >= 2', but got ..." readable.
        raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
                         f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")
345
-
346
-
347
@constexpr
def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
    """
    Validate the dtypes of the two inputs of `dot`.

    Each dtype is checked against float16/float32 with ``const_utils.check_type_valid``,
    and the two dtypes must be identical.

    Args:
        x1_type: Dtype of the first input.
        x2_type: Dtype of the second input.
        prim_name (str): Operator name used to prefix the error message. Default: None.

    Raises:
        TypeError: If the two dtypes differ. Errors for an unsupported dtype are raised by
            ``const_utils.check_type_valid`` (type not visible here).
    """
    for dtype, arg_name in ((x1_type, 'x1'), (x2_type, 'x2')):
        const_utils.check_type_valid(dtype, [mstype.float16, mstype.float32], arg_name)
    if x1_type != x2_type:
        prefix = f"For '{prim_name}', the" if prim_name else "The"
        raise TypeError(f"{prefix} inputs must be the same type, but got "
                        f"x1_type: {x1_type} and x2_type: {x2_type}.")
358
-
359
-
360
@_primexpr
def _get_transpose_shape(x2_shape):
    """
    Build a permutation that moves the second-to-last axis to the front.

    For a rank-n shape the result is ``(n-2, 0, 1, ..., n-3, n-1)``. The last axis stays in place.
    """
    dims = list(range(len(x2_shape)))
    second_to_last = dims[-2:-1]
    leading = dims[:-2]
    last = dims[-1:]
    return tuple(second_to_last + leading + last)
365
-
366
-
367
def dot(input, other):
    """
    Computes a dot product between samples in two tensors.

    When both inputs are 2-D this is an ordinary matrix product. Otherwise the last axis of
    `input` is contracted with the second-to-last axis of `other`, and the result has shape
    ``input.shape[:-1] + other.shape[:-2] + other.shape[-1:]``.

    Args:
        input (Tensor): First tensor in Dot op with datatype float16 or float32.
            The rank must be greater than or equal to 2.
        other (Tensor): Second tensor in Dot op with datatype float16 or float32.
            The rank must be greater than or equal to 2.

    Returns:
        Tensor, dot product of input and other.

    Raises:
        TypeError: If type of input and other are not the same.
        TypeError: If dtype of input or other is not float16 or float32.
        ValueError: If rank of input or other less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, ops
        >>> input = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
        >>> other = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input, other)
        >>> print(output)
        [[[3. 3.]]
         [[3. 3.]]]
        >>> print(output.shape)
        (2, 1, 2)
        >>> input = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> other = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input, other)
        >>> print(output)
        [[[[3. 3.]]
          [[3. 3.]]]]
        >>> print(output.shape)
        (1, 2, 1, 2)
        >>> input = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> other = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input, other)
        >>> print(output)
        [[[[3. 3.]
           [3. 3.]]
          [[3. 3.]
           [3. 3.]]]]
        >>> print(output.shape)
        (1, 2, 2, 2)
        >>> input = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
        >>> other = Tensor(np.ones(shape=[2, 1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input, other)
        >>> print(output)
        [[[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]
         [[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]
         [[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]]
        >>> print(output.shape)
        (3, 2, 2, 1, 2)
    """
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    transpose_op = P.Transpose()
    matmul_op = P.MatMul(False, False)

    lhs_shape, rhs_shape = shape_op(input), shape_op(other)
    _typecheck_input_dot(F.dtype(input), F.dtype(other), 'dot')
    _check_invalid_input(lhs_shape, rhs_shape, 'dot')

    if len(lhs_shape) <= 2 and len(rhs_shape) <= 2:
        # Both operands are plain matrices: nothing to flatten.
        return matmul_op(input, other)

    # Bring the contraction axis (-2) of `other` to the front, then flatten both operands
    # to 2-D so a single MatMul produces every pairwise product.
    rhs_transposed = transpose_op(other, _get_transpose_shape(rhs_shape))
    lhs_2d = reshape_op(input, (-1, lhs_shape[-1]))
    rhs_2d = reshape_op(rhs_transposed, (rhs_shape[-2], -1))
    product = matmul_op(lhs_2d, rhs_2d)

    # Restore the free axes; the leading dim is inferred (-1) by Reshape.
    merged_shape = lhs_shape[:-1] + rhs_shape[:-2] + rhs_shape[-1:]
    return reshape_op(product, (-1,) + merged_shape[1:])
458
-
459
-
460
@_primexpr
def _get_batch_size(x1_shape, x2_shape, prim_name=None):
    """
    Return the leading dimension of each input shape, after checking both have rank >= 2.

    Args:
        x1_shape (tuple): Shape of the first input.
        x2_shape (tuple): Shape of the second input.
        prim_name (str): Operator name used to prefix the error message. Default: None.

    Returns:
        tuple, ``(x1_shape[0], x2_shape[0])``.

    Raises:
        ValueError: If either shape has fewer than 2 dimensions.
    """
    rank1, rank2 = len(x1_shape), len(x2_shape)
    if rank1 >= 2 and rank2 >= 2:
        return x1_shape[0], x2_shape[0]
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
                     f"but got 'len(x1_shape)': ({rank1}) and 'len(x2_shape)': ({rank2}).")
472
-
473
-
474
@constexpr
def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
    """
    Validate the dtypes of the two inputs of `batch_dot`.

    Each dtype is checked against float32 with ``const_utils.check_type_valid``, and the
    two dtypes must be identical.

    Args:
        x1_type: Dtype of the first input.
        x2_type: Dtype of the second input.
        prim_name (str): Operator name used to prefix the error message. Default: None.

    Raises:
        TypeError: If the two dtypes differ. Errors for an unsupported dtype are raised by
            ``const_utils.check_type_valid`` (type not visible here).
    """
    for dtype, arg_name in ((x1_type, 'x1'), (x2_type, 'x2')):
        const_utils.check_type_valid(dtype, [mstype.float32], arg_name)
    if x1_type != x2_type:
        prefix = f"For '{prim_name}', the" if prim_name else "The"
        raise TypeError(f"{prefix} inputs must be the same type, but got x1_type: {x1_type} and "
                        f"x2_type: {x2_type}.")
485
-
486
-
487
@_primexpr
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
    """
    Validate `axes` for batch_dot and normalize it to a new list of two non-negative ints.

    Args:
        x1_shape (tuple): Shape of the first input.
        x2_shape (tuple): Shape of the second input.
        axes (Union[None, int, tuple(int), list(int)]): Requested contraction axes.
            ``None`` selects the last axis of `x1` and, for `x2`, its last axis when `x2` is 2-D
            or its second-to-last axis otherwise. An int ``N`` selects axis ``N`` of both inputs
            (negative values are resolved against each input's own rank).
        prim_name (str): Operator name used to prefix error messages. Default: None.

    Returns:
        list[int], a freshly built ``[axis_of_x1, axis_of_x2]`` with negative values resolved.
        The caller's `axes` object is never modified.

    Raises:
        ValueError: If `axes` contains 0 or resolves to axis 0 (the batch dimension).
        ValueError: If a tuple/list `axes` does not have length 2.
        ValueError: If an axis is out of range for its input.
        ValueError: If `axes` is not None, an int, a tuple or a list.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"

    def _check_1(axes):
        if 0 in axes:
            raise ValueError(f"{msg_prefix} 'axes' cannot contain 0, but got axes: {axes}.")
        if len(axes) != 2:
            raise ValueError(f"{msg_prefix} length of 'axes' must be equal to 2, but got {len(axes)}.")

    def _check_2(axes, x1_shape, x2_shape):
        if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
            raise ValueError(f"{msg_prefix} axes[0] must be less than or equal to len(x1_shape), "
                             f"and axes[1] must be less than or equal to len(x2_shape). "
                             f"But got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")

    def _check_3(axes, x1_shape, x2_shape):
        if axes == 0:
            raise ValueError(f"{msg_prefix} 'axes' should not be equal to 0, but got {axes}.")

        if axes > len(x1_shape) or axes > len(x2_shape):
            raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
                             f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")

    def _check_not_batch_axis(axes):
        # `_check_1` / `_check_3` only see the raw input. A negative axis such as -len(shape)
        # resolves to 0, the batch dimension, which must not be contracted.
        if axes[0] == 0 or axes[1] == 0:
            raise ValueError(f"{msg_prefix} 'axes' cannot refer to the batch dimension 0, "
                             f"but got reversed axes: {axes}.")

    if axes is None:
        if len(x2_shape) == 2:
            axes = [len(x1_shape) - 1, len(x2_shape) - 1]
        else:
            axes = [len(x1_shape) - 1, len(x2_shape) - 2]

    if isinstance(axes, (list, tuple)):
        _check_1(axes)
        # Always copy, even for lists: the in-place updates below must not touch the caller's object.
        axes = list(axes)
        validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
        validator.check_value_type('axes[1]', axes[1], [int], 'batch_dot')
        # Reverse if axis < 0
        if axes[0] < 0:
            axes[0] += len(x1_shape)
        if axes[1] < 0:
            axes[1] += len(x2_shape)
        validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot')
        validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot')
        _check_2(axes, x1_shape, x2_shape)
    elif isinstance(axes, int):
        _check_3(axes, x1_shape, x2_shape)
        if axes < 0:
            axes = [axes + len(x1_shape), axes + len(x2_shape)]
            validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot')
            # x2 may have a lower rank than x1, so its resolved axis needs its own check.
            validator.check_non_negative_int(axes[1], 'reversed axes', 'batch_dot')
        else:
            axes = [axes, axes]
    else:
        raise ValueError(f"{msg_prefix} type of 'axes' must be one of those: int, tuple(int), list(int), "
                         f"but got {type(axes).__name__}.")
    _check_not_batch_axis(axes)
    return axes
545
-
546
-
547
@_primexpr
def _calc_new_shape_batchdot(shape, axes, position=0):
    """
    Compute the transpose permutation and 3-D reshape target for one batch_dot operand.

    Axis 0 is kept in front. The contraction axis ``axes[position]`` is moved last for the first
    operand (``position`` falsy) or directly after axis 0 for the second operand (``position``
    truthy). All remaining axes are flattened into a single dimension.

    Returns:
        tuple, ``(new_shape, transpose_perm, free_dims)``.
        - new_shape is ``(shape[0], free, contraction)`` or ``(shape[0], contraction, free)``.
        - transpose_perm is the matching axis permutation.
        - free_dims holds the sizes of the non-contracted, non-leading axes.
    """
    contract_dim = axes[position]
    contraction_axes = (contract_dim,)
    free_axes = tuple(idx for idx in range(1, len(shape)) if idx != contract_dim)
    free_dims = tuple(shape[idx] for idx in free_axes)

    prod_contraction = 1
    prod_contraction *= shape[contract_dim]

    prod_free = 1
    for dim in free_dims:
        prod_free = prod_free * dim

    if position:
        transpose_perm = (0,) + contraction_axes + free_axes
        new_shape = (shape[0], prod_contraction, prod_free)
    else:
        transpose_perm = (0,) + free_axes + contraction_axes
        new_shape = (shape[0], prod_free, prod_contraction)
    return new_shape, transpose_perm, free_dims
569
-
570
-
571
@_primexpr
def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
    """
    Ensure both inputs have the same batch size.

    Raises:
        ValueError: If `x1_batch_size` and `x2_batch_size` differ.
    """
    if x1_batch_size == x2_batch_size:
        return
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    raise ValueError(f"{prefix} inputs 'x1', 'x2' should have the same batch sizes, but got "
                     f"'x1_batch_size': {x1_batch_size} and 'x2_batch_size': {x2_batch_size}.")
580
-
581
-
582
@_primexpr
def _get_output_shape(batch_size, x1_ret, x2_ret):
    """
    Build the batch_dot output shape.

    The shape is the batch size followed by the free dims of `x1`, then the free dims of `x2`.
    """
    return (batch_size,) + x1_ret + x2_ret
589
-
590
-
591
def batch_dot(x1, x2, axes=None):
    """
    Computation of batch dot product between samples in two tensors containing batch dims.

    .. math::
        output = x1[batch, :] * x2[batch, :]

    Args:
        x1 (Tensor): First tensor in Batch Dot op with datatype float32 and the rank of `x1` must be greater
            than or equal to 2.
        x2 (Tensor): Second tensor in Batch Dot op with datatype float32. The datatype of `x2` should
            be same as `x1` and the rank of `x2` must be greater than or equal to 2.
        axes (Union[int, tuple(int), list(int)]): The axis of `x1` and the axis of `x2` to contract.
            A tuple/list must have length 2 and gives the axis for `x1` and `x2` respectively.
            A single int `N` selects axis `N` of both inputs; a negative value counts from the end of
            each input's own shape. If None, the last axis of `x1` is used, and for `x2` its last axis
            when `x2` is 2-D or its second-to-last axis otherwise. Axis 0 (the batch axis) cannot be
            used. Default: None.

    Returns:
        Tensor, batch dot product of `x1` and `x2`. For example, the Shape of output
        for input `x1` shapes (batch, d1, axes, d2) and `x2` shapes (batch, d3, axes, d4) is (batch, d1, d2, d3, d4),
        where d1 to d4 are arbitrary sizes.

    Raises:
        TypeError: If type of x1 and x2 are not the same.
        TypeError: If dtype of x1 or x2 is not float32.
        ValueError: If rank of x1 or x2 less than 2.
        ValueError: If batch dim used in axes.
        ValueError: If `axes` is a tuple/list whose length is not 2.
        ValueError: If axes is not one of those: None, int, (int, int).
        ValueError: If axes reversed from negative int is too low for dimensions of input arrays.
        ValueError: If axes value is too high for dimensions of input arrays.
        ValueError: If batch size of x1 and x2 are not the same.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, ops
        >>> import numpy as np
        >>> x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (-1, -2)
        >>> output = ops.batch_dot(x1, x2, axes)
        >>> print(output)
        [[[3. 3.]
          [3. 3.]]
         [[3. 3.]
          [3. 3.]]]
        >>> x1 = Tensor(np.ones(shape=[2, 2]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (1, 2)
        >>> output = ops.batch_dot(x1, x2, axes)
        >>> print(output)
        [[2. 2. 2.]
         [2. 2. 2.]]
        >>> print(output.shape)
        (2, 3)
        >>> x1 = Tensor(np.ones(shape=[6, 2, 3, 4]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[6, 5, 4, 8]), mindspore.float32)
        >>> output = ops.batch_dot(x1, x2)
        >>> print(output.shape)
        (6, 2, 3, 5, 8)
        >>> x1 = Tensor(np.ones(shape=[2, 2, 4]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 5, 4, 5]), mindspore.float32)
        >>> output = ops.batch_dot(x1, x2)
        >>> print(output.shape)
        (2, 2, 5, 5)
    """

    transpose_op = P.Transpose()
    batch_matmul_op = P.BatchMatMul()
    squeeze_one_op = P.Squeeze(1)
    squeeze_minus_one_op = P.Squeeze(-1)
    # input validity checks
    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)
    x1_dim_num = len(x1_shape)
    x2_dim_num = len(x2_shape)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)

    x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape, 'batch_dot')

    _typecheck_input_batch_dot(x1_type, x2_type, 'batch_dot')
    _check_batch_size(x1_batch_size, x2_batch_size, 'batch_dot')
    axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes, 'batch_dot')

    # 2-D inputs get an extra size-1 axis here; it is squeezed out again at the end.
    if x1_dim_num == 2:
        # The new axis is inserted at index 1, which shifts x1's contraction axis by one.
        x1 = F.expand_dims(x1, 1)
        # Build a new list instead of `axes[0] += 1` so a list owned by the caller is never modified.
        axes = [axes[0] + 1, axes[1]]
    if x2_dim_num == 2:
        # The new axis is appended at index 2, so axes[1] is unchanged.
        x2 = F.expand_dims(x2, 2)

    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)

    # Each operand becomes 3-D: (batch, free, contraction) for x1 and
    # (batch, contraction, free) for x2, so a single BatchMatMul performs the contraction.
    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(x2_shape, axes, 1)
    output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret)

    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd)
    x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd)

    # Batch matmul op part
    mul_result = batch_matmul_op(x1_reshaped, x2_reshaped)

    final_result = F.reshape(mul_result, output_shape)

    # If the original dims were expanded, restore them from 3 to 2. Only one axis is squeezed,
    # so two 2-D inputs yield an output of shape (batch, 1).
    if x1_dim_num == 2:
        final_result = squeeze_one_op(final_result)
    elif x2_dim_num == 2:
        final_result = squeeze_minus_one_op(final_result)

    return final_result
19
+ from mindspore.ops._primitive_cache import _get_cache_prim
709
20
 
710
21
 
711
22
  def matmul(x1, x2, dtype=None):
@@ -808,10 +119,9 @@ def mm(input, mat2):
808
119
  >>> print(out.shape)
809
120
  (2, 4)
810
121
  """
811
- if input.ndim != 2 or mat2.ndim != 2:
812
- raise ValueError(f"For mm, the input tensor must be a matrix, "
813
- f"but got mat1.ndim:{input.ndim}, mat2.ndim:{mat2.ndim}")
814
- return matmul(input, mat2)
122
+ _matmul = _get_cache_prim(ops.MatMul)()
123
+ out = _matmul(input, mat2)
124
+ return out
815
125
 
816
126
 
817
127
  def cummin(x, axis):