mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.11__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (589) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +61 -95
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/__init__.py +4 -2
  268. mindspore/nn/wrap/cell_wrapper.py +87 -34
  269. mindspore/nn/wrap/grad_reducer.py +8 -5
  270. mindspore/nn/wrap/loss_scale.py +105 -42
  271. mindspore/numpy/array_creations.py +1 -2
  272. mindspore/numpy/array_ops.py +3 -2
  273. mindspore/numpy/utils_const.py +5 -5
  274. mindspore/offline_debug/convert_async.py +2 -2
  275. mindspore/ops/_grad_experimental/__init__.py +0 -5
  276. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  277. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  278. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  279. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  280. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  281. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  282. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  283. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  284. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  285. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  286. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  287. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  288. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  289. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  290. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  291. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  292. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  293. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  294. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  295. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  296. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  297. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  298. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  299. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  300. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  301. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  302. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  303. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  304. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  305. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  306. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  307. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  309. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  310. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  311. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  312. mindspore/ops/_primitive_cache.py +1 -1
  313. mindspore/ops/_tracefunc.py +45 -13
  314. mindspore/ops/_utils/utils.py +6 -1
  315. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  316. mindspore/ops/_vmap/vmap_base.py +3 -3
  317. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  318. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  319. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  320. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  321. mindspore/ops/arg_dtype_cast.py +54 -0
  322. mindspore/ops/composite/base.py +37 -10
  323. mindspore/ops/composite/math_ops.py +5 -4
  324. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  325. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  326. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  327. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  328. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  330. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  331. mindspore/ops/deprecated.py +304 -0
  332. mindspore/ops/function/__init__.py +4 -1
  333. mindspore/ops/function/array_func.py +174 -193
  334. mindspore/ops/function/clip_func.py +81 -13
  335. mindspore/ops/function/debug_func.py +1 -1
  336. mindspore/ops/function/grad/grad_func.py +18 -9
  337. mindspore/ops/function/image_func.py +10 -4
  338. mindspore/ops/function/linalg_func.py +5 -5
  339. mindspore/ops/function/math_func.py +575 -386
  340. mindspore/ops/function/nn_func.py +568 -260
  341. mindspore/ops/function/random_func.py +88 -57
  342. mindspore/ops/function/sparse_func.py +1 -1
  343. mindspore/ops/function/sparse_unary_func.py +14 -12
  344. mindspore/ops/function/vmap_func.py +6 -5
  345. mindspore/ops/functional.py +15 -10
  346. mindspore/ops/op_info_register.py +244 -25
  347. mindspore/ops/operations/__init__.py +31 -19
  348. mindspore/ops/operations/_grad_ops.py +71 -7
  349. mindspore/ops/operations/_inner_ops.py +350 -17
  350. mindspore/ops/operations/_quant_ops.py +4 -8
  351. mindspore/ops/operations/_sequence_ops.py +42 -0
  352. mindspore/ops/operations/array_ops.py +68 -282
  353. mindspore/ops/operations/comm_ops.py +107 -59
  354. mindspore/ops/operations/custom_ops.py +94 -70
  355. mindspore/ops/operations/debug_ops.py +8 -4
  356. mindspore/ops/operations/image_ops.py +18 -12
  357. mindspore/ops/operations/inner_ops.py +26 -3
  358. mindspore/ops/operations/math_ops.py +192 -144
  359. mindspore/ops/operations/nn_ops.py +857 -489
  360. mindspore/ops/operations/other_ops.py +0 -22
  361. mindspore/ops/operations/random_ops.py +53 -111
  362. mindspore/ops/operations/sparse_ops.py +3 -1
  363. mindspore/ops/primitive.py +24 -18
  364. mindspore/parallel/_auto_parallel_context.py +68 -8
  365. mindspore/parallel/_cost_model_context.py +2 -2
  366. mindspore/parallel/_offload_context.py +17 -3
  367. mindspore/parallel/_parallel_serialization.py +12 -5
  368. mindspore/parallel/_ps_context.py +12 -0
  369. mindspore/parallel/_tensor.py +18 -13
  370. mindspore/parallel/_transformer/layers.py +5 -3
  371. mindspore/parallel/_transformer/loss.py +1 -0
  372. mindspore/parallel/_transformer/moe.py +2 -2
  373. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  374. mindspore/parallel/_transformer/transformer.py +23 -3
  375. mindspore/parallel/_utils.py +11 -7
  376. mindspore/parallel/algo_parameter_config.py +85 -5
  377. mindspore/parallel/checkpoint_transform.py +19 -12
  378. mindspore/parallel/shard.py +21 -14
  379. mindspore/profiler/common/struct_type.py +3 -3
  380. mindspore/profiler/common/util.py +4 -2
  381. mindspore/profiler/envprofiling.py +1 -1
  382. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  383. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  384. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  385. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  386. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  387. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  388. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  389. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  390. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  391. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  392. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  393. mindspore/profiler/parser/flops_parser.py +15 -11
  394. mindspore/profiler/parser/framework_parser.py +38 -22
  395. mindspore/profiler/parser/hccl_parser.py +16 -12
  396. mindspore/profiler/parser/integrator.py +22 -11
  397. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  398. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  399. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  400. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  401. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  402. mindspore/profiler/parser/optime_parser.py +1 -1
  403. mindspore/profiler/parser/profiler_info.py +21 -2
  404. mindspore/profiler/parser/step_trace_parser.py +11 -14
  405. mindspore/profiler/profiling.py +179 -89
  406. mindspore/rewrite/api/node.py +102 -19
  407. mindspore/rewrite/api/node_type.py +5 -1
  408. mindspore/rewrite/api/pattern_engine.py +1 -1
  409. mindspore/rewrite/api/scoped_value.py +9 -17
  410. mindspore/rewrite/api/symbol_tree.py +131 -47
  411. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  412. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  413. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  414. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  415. mindspore/rewrite/common/rewrite_elog.py +5 -1
  416. mindspore/rewrite/namer.py +33 -24
  417. mindspore/rewrite/namespace.py +14 -5
  418. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  419. mindspore/rewrite/node/call_function.py +79 -0
  420. mindspore/rewrite/node/cell_container.py +135 -0
  421. mindspore/rewrite/node/control_flow.py +88 -0
  422. mindspore/rewrite/{node.py → node/node.py} +273 -234
  423. mindspore/rewrite/node/node_manager.py +254 -0
  424. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  425. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  426. mindspore/rewrite/parsers/assign_parser.py +216 -221
  427. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  428. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  429. mindspore/rewrite/parsers/constant_parser.py +9 -6
  430. mindspore/rewrite/parsers/container_parser.py +9 -7
  431. mindspore/rewrite/parsers/for_parser.py +42 -21
  432. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  433. mindspore/rewrite/parsers/if_parser.py +28 -24
  434. mindspore/rewrite/parsers/module_parser.py +196 -25
  435. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  436. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  437. mindspore/rewrite/parsers/return_parser.py +6 -6
  438. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  439. mindspore/rewrite/sparsify/utils.py +1 -1
  440. mindspore/rewrite/symbol_tree.py +523 -578
  441. mindspore/rewrite/symbol_tree_builder.py +9 -193
  442. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  443. mindspore/run_check/_check_version.py +6 -4
  444. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  445. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  446. mindspore/scipy/linalg.py +1 -1
  447. mindspore/scipy/ops.py +55 -5
  448. mindspore/scipy/optimize/__init__.py +3 -2
  449. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  450. mindspore/scipy/optimize/minimize.py +7 -3
  451. mindspore/train/_utils.py +7 -3
  452. mindspore/train/amp.py +323 -123
  453. mindspore/train/anf_ir_pb2.py +14 -2
  454. mindspore/train/callback/_backup_and_restore.py +2 -12
  455. mindspore/train/callback/_callback.py +29 -4
  456. mindspore/train/callback/_checkpoint.py +23 -8
  457. mindspore/train/callback/_early_stop.py +2 -2
  458. mindspore/train/callback/_landscape.py +4 -4
  459. mindspore/train/callback/_loss_monitor.py +2 -2
  460. mindspore/train/callback/_on_request_exit.py +2 -2
  461. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  462. mindspore/train/callback/_summary_collector.py +15 -8
  463. mindspore/train/callback/_time_monitor.py +58 -5
  464. mindspore/train/data_sink.py +5 -11
  465. mindspore/train/dataset_helper.py +84 -57
  466. mindspore/train/loss_scale_manager.py +2 -2
  467. mindspore/train/metrics/__init__.py +3 -3
  468. mindspore/train/metrics/cosine_similarity.py +1 -1
  469. mindspore/train/metrics/hausdorff_distance.py +3 -2
  470. mindspore/train/metrics/mean_surface_distance.py +3 -2
  471. mindspore/train/metrics/metric.py +39 -19
  472. mindspore/train/metrics/roc.py +2 -2
  473. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  474. mindspore/train/mind_ir_pb2.py +85 -36
  475. mindspore/train/model.py +187 -47
  476. mindspore/train/serialization.py +487 -161
  477. mindspore/train/summary/_summary_adapter.py +1 -1
  478. mindspore/train/summary/_writer_pool.py +3 -2
  479. mindspore/train/summary/summary_record.py +37 -17
  480. mindspore/train/train_thor/convert_utils.py +3 -3
  481. mindspore/train/train_thor/dataset_helper.py +1 -1
  482. mindspore/version.py +1 -1
  483. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  486. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  487. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  489. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  490. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  491. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  492. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  493. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  494. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  495. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  497. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  498. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  499. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  500. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  501. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  502. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  503. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  504. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  505. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  506. mindspore/_extends/graph_kernel/expander.py +0 -80
  507. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  508. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  509. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  510. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  511. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  512. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  513. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  514. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  516. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  517. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  518. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  519. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  520. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  521. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  522. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  523. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  524. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  525. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  526. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  527. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  528. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  529. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  530. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  531. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  532. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  534. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  535. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  536. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  537. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  538. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  540. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  543. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  544. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  545. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  546. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  547. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  548. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  549. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  550. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  551. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  552. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  553. mindspore/dataset/datapreprocess/__init__.py +0 -20
  554. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  555. mindspore/include/api/net.h +0 -142
  556. mindspore/nn/lr_scheduler.py +0 -262
  557. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  558. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  559. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  560. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  561. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  562. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  563. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  564. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  565. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  566. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  567. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  568. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  569. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  570. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  571. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  575. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  578. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  579. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  580. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  582. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  583. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  585. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  586. mindspore/rewrite/node_visitor.py +0 -44
  587. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  588. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  589. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -28,6 +28,8 @@ trans_data_op_info = TBERegOp("TransData") \
28
28
  "DefaultFormat, NC1HWC0, FRACTAL_Z, FRACTAL_NZ, HWCN, C1HWNCoC0, NDHWC, NHWC") \
29
29
  .attr("dst_format", "required", "str",
30
30
  "DefaultFormat, NC1HWC0, FRACTAL_Z, FRACTAL_NZ, HWCN, C1HWNCoC0, NDHWC, NHWC") \
31
+ .attr("src_subformat", "optional", "int", "all", "1") \
32
+ .attr("dst_subformat", "optional", "int", "all", "1") \
31
33
  .attr("groups", "optional", "int", "all", "1") \
32
34
  .input(0, "src", False, "required", "all") \
33
35
  .output(0, "dst", False, "required", "all") \
@@ -85,6 +85,6 @@ def _get_cache_prim(cls: Primitive) -> Primitive:
85
85
  _PRIM_CACHE[key] = prim
86
86
  return _PRIM_CACHE.get(key)
87
87
 
88
- if _is_need_compile(_temp_func):
88
+ if _is_need_compile(_temp_func): # @jit.cond: True
89
89
  return _new_prim_for_graph
90
90
  return _get_cache_prim_for_pynative
@@ -17,12 +17,13 @@ import functools
17
17
  import types
18
18
  import textwrap
19
19
  import inspect
20
+ import os
20
21
  from mindspore.common.tensor import Tensor
21
22
  from mindspore.ops.primitive import _RunOpHook, Primitive
22
23
  from mindspore._c_expression import PackExpander, PackNode
23
24
  from mindspore.common._stub_tensor import StubTensor
24
25
  from mindspore.common._register_for_tensor import tensor_operator_registry
25
- from mindspore.common.api import _handle_func_args
26
+ from mindspore.common.api import _handle_func_args, _pynative_executor
26
27
 
27
28
 
28
29
  class _PackTensor(StubTensor):
@@ -64,6 +65,7 @@ class PackFunc(Primitive):
64
65
  """pack function with lazy expander"""
65
66
 
66
67
  expander = PackExpander.get_instance()
68
+ current = None
67
69
 
68
70
  def __init__(self, fun, unique_key, cell_obj, is_pynative_mode=False):
69
71
  super(PackFunc, self).__init__(self.__class__.__name__)
@@ -79,19 +81,29 @@ class PackFunc(Primitive):
79
81
  args = (self.cell_obj, *args)
80
82
  return self.func(*args, **kwargs)
81
83
  self.kwargs = kwargs
82
- return super().__call__(*args)
84
+ output = super().__call__(*args)
85
+ if self.is_pynative_mode and self.grad_attach_num > 0:
86
+ output_num = len(output) - self.grad_attach_num
87
+ if output_num == 1:
88
+ return output[0]
89
+ return output[:output_num]
90
+ return output
83
91
 
84
92
  def __expand__(self, args):
93
+ old = PackFunc.current
94
+ PackFunc.current = self
85
95
  if self.cell_obj:
86
96
  args = (self.cell_obj, *args)
87
97
  with _SetMixedPrecision(self.cell_obj):
88
98
  ret = self._run_op(args)
89
- return ret
90
- return self._run_op(args)
99
+ else:
100
+ ret = self._run_op(args)
101
+ PackFunc.current = old
102
+ return ret
91
103
 
92
104
  @staticmethod
93
105
  def is_tracing():
94
- return _RunOpHook.current and _RunOpHook.current.hook is PackFunc._trace_run_op
106
+ return PackFunc.current is not None
95
107
 
96
108
  @staticmethod
97
109
  def _trace_run_op(obj, args):
@@ -197,13 +209,33 @@ def trace(fn):
197
209
 
198
210
  @functools.wraps(fn)
199
211
  def _trace_wrap(*args, **kwargs):
200
- args, kwargs = _handle_func_args(fn, *args, **kwargs)
201
- obj = None
202
-
203
- if args and not isinstance(args[0], Tensor) and hasattr(args[0], fn.__name__):
204
- obj, args = args[0], args[1:]
205
- key = f"{id(obj)}_{id(fn)}"
206
-
207
- return PackFunc(fn, key, obj, True)(*args, **kwargs)
212
+ pynative_grad_flag = _pynative_executor.grad_flag()
213
+ grad_flag_expr = "1" if pynative_grad_flag else "0"
214
+ if _trace_wrap.is_method is None:
215
+ if args and not isinstance(args[0], Tensor) and hasattr(args[0], fn.__name__):
216
+ _trace_wrap.is_method = False
217
+ else:
218
+ _trace_wrap.is_method = True
219
+ if _trace_wrap.is_method:
220
+ # Similar processing has been done in the __call__ of Cell,
221
+ # so only when obj is None, there is need to do `_handle_func_args`.
222
+ args, kwargs = _handle_func_args(fn, *args, **kwargs)
223
+ pack_func_name = "pack" + grad_flag_expr
224
+ pack_func = getattr(fn, pack_func_name, None)
225
+ if pack_func is None:
226
+ pack_func = PackFunc(fn, f"{id(fn)}_{grad_flag_expr}", None, True)
227
+ setattr(fn, pack_func_name, pack_func)
228
+ return pack_func(*args, **kwargs)
229
+ obj, args = args[0], args[1:]
230
+ pack_func_name = "".join((fn.__name__, "pack", grad_flag_expr))
231
+ pack_func = getattr(obj, pack_func_name, None)
232
+ if pack_func is None:
233
+ pack_func = PackFunc(fn, f"{id(obj)}_{id(fn)}_{grad_flag_expr}", obj, True)
234
+ setattr(obj, pack_func_name, pack_func)
235
+ return pack_func(*args, **kwargs)
236
+
237
+ if "MS_DEV_DISABLE_TRACE" in os.environ and os.environ["MS_DEV_DISABLE_TRACE"] == "on":
238
+ return fn
208
239
  _trace_wrap.pack_fn = fn
240
+ _trace_wrap.is_method = None
209
241
  return _trace_wrap
@@ -78,6 +78,11 @@ def get_broadcast_shape(x_shape, y_shape, prim_name, arg_name1="x", arg_name2="y
78
78
  return broadcast_shape
79
79
 
80
80
 
81
+ def dim_not_equal(dim1, dim2):
82
+ """Compare dim in shape"""
83
+ return dim1 != dim2 and dim1 >= 0 and dim2 >= 0
84
+
85
+
81
86
  def get_concat_offset(x_shp, x_type, axis, prim_name):
82
87
  """for concat and concatoffset check args and compute offset"""
83
88
  validator.check_value_type("shape", x_shp, [tuple, list], prim_name)
@@ -98,7 +103,7 @@ def get_concat_offset(x_shp, x_type, axis, prim_name):
98
103
  for i in range(1, len(x_shp)):
99
104
  v = x_shp[i]
100
105
  for j in range(rank_base):
101
- if j != axis and v[j] != x_shp[0][j] and v[j] >= 0 and x_shp[0][j] >= 0:
106
+ if j != axis and dim_not_equal(v[j], x_shp[0][j]):
102
107
  raise ValueError(f"The shape of the two input elements of the Concat operator do not match:"
103
108
  f"shape[0] = {x_shp[0]} and shape[{i}] = {x_shp[i]}.")
104
109
  offset.append(all_shp)
@@ -155,7 +155,7 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
155
155
  _check(indices_shape)
156
156
  indices_len = len(indices_shape)
157
157
  if indices_len == 1:
158
- prefix = P.Range()(Tensor(0, indices_dtype), P.Fill()(
158
+ prefix = P.Range()(Tensor(0, indices_dtype), F.fill(
159
159
  indices_dtype, (), axis_size), Tensor(1, indices_dtype))
160
160
  return prefix
161
161
 
@@ -850,9 +850,9 @@ def get_fill_vmap_rule(prim, axis_size):
850
850
 
851
851
 
852
852
  @constexpr
853
- def to_tensor_with_type(x, type):
853
+ def to_tensor_with_type(x, dtype):
854
854
  """x to Tensor with type"""
855
- return Tensor(x, type)
855
+ return Tensor(x, dtype)
856
856
 
857
857
 
858
858
  @vmap_rules_getters.register(P.FillV2)
@@ -250,7 +250,7 @@ def vmap_monad_rule(prim, axis_size):
250
250
  def _bdim_at_any(x, src, dst, axis_size):
251
251
  """
252
252
  Moves source axes of an array to the dst axis, and other axes remain in their original order. If the source axes
253
- is 'None', broadcasts the array at dst axis with axis_size.
253
+ is ``None``, broadcasts the array at dst axis with axis_size.
254
254
 
255
255
  Args:
256
256
  x (Tensor or Scalar): The input tensor or scalar. The data type should be one of the following types: float16,
@@ -272,7 +272,7 @@ def _bdim_at_any(x, src, dst, axis_size):
272
272
  def _bdim_at_front(x, src, axis_size):
273
273
  """
274
274
  Moves source axes of an array to the foremost, and other axes remain in their original order. If the source axes
275
- is 'None', broadcasts the array at foremost axis with axis_size.
275
+ is ``None``, broadcasts the array at foremost axis with axis_size.
276
276
 
277
277
  Args:
278
278
  x (Tensor or Scalar): The input tensor or scalar. The data type should be one of the following types: float16,
@@ -289,7 +289,7 @@ def _bdim_at_front(x, src, axis_size):
289
289
  def _bdim_at_back(x, src, axis_size):
290
290
  """
291
291
  Moves source axes of an array to the last, and other axes remain in their original order. If the source axes
292
- is 'None', broadcasts the array at foremost axis with axis_size.
292
+ is ``None``, broadcasts the array at foremost axis with axis_size.
293
293
 
294
294
  Args:
295
295
  x (Tensor or Scalar): The input tensor or scalar. The data type should be one of the following types: float16,
@@ -190,8 +190,8 @@ def _reshape_expand_dims(src_dim, dst_size, target, prim_name):
190
190
  @_primexpr
191
191
  def _get_new_size_by_index(input_size, batch_size, index):
192
192
  """Get the new size of input_size by multiplying input_size[index] by batch_size."""
193
- new_size = ()
194
193
  if input_size is None:
194
+ new_size = ()
195
195
  return new_size
196
196
  new_size = list(input_size)
197
197
  new_size[index] *= batch_size
@@ -62,8 +62,9 @@ def get_broadcast_binary_op_grad_vmap_rule(prim, axis_size):
62
62
  y_shape = F.shape(y)
63
63
  g_shape = F.shape(g)
64
64
 
65
- if x_dim == y_dim and x_dim == g_dim and \
66
- x_shape == y_shape and x_shape == g_shape:
65
+ is_dim_ok = x_dim == y_dim and x_dim == g_dim
66
+ is_shape_ok = x_shape == y_shape and x_shape == g_shape
67
+ if is_dim_ok and is_shape_ok:
67
68
  dx, dy = prim(x, y, g)
68
69
  return (dx, x_dim), (dy, y_dim)
69
70
 
@@ -113,8 +114,9 @@ def get_broadcast_grad_grad_vmap_rule(prim, axis_size):
113
114
  dx1_shape = F.shape(dx1)
114
115
  dx2_shape = F.shape(dx2)
115
116
 
116
- if x1_dim == x2_dim and dx1_dim == dx2_dim and x1_dim == dx1_dim \
117
- and x1_shape == x2_shape and dx1_shape == dx2_shape:
117
+ is_dim_ok = x1_dim == x2_dim and dx1_dim == dx2_dim and x1_dim == dx1_dim
118
+ is_shape_ok = x1_shape == x2_shape and dx1_shape == dx2_shape
119
+ if is_dim_ok and is_shape_ok:
118
120
  sopd_x1, sopd_x2, sopd_grad = prim(x1, x2, dx1, dx2)
119
121
  return (sopd_x1, x1_dim), (sopd_x2, x1_dim), (sopd_grad, x1_dim)
120
122
 
@@ -66,6 +66,7 @@ def _broadcast_shape(nd, x_ndim, x_shape):
66
66
  @vmap_rules_getters.register(P.ApproximateEqual)
67
67
  @vmap_rules_getters.register(P.TruncateDiv)
68
68
  @vmap_rules_getters.register(P.TruncateMod)
69
+
69
70
  def get_broadcast_binary_op_vmap_rule(prim, axis_size):
70
71
  """VmapRule for binary operations with broadcasting, such as `Add` and `Sub`."""
71
72
 
@@ -216,8 +217,9 @@ def get_lerp_vamp_rule(prim, axis_size):
216
217
  # Both broadcast end and weight to start.
217
218
  else:
218
219
  weight_shape = F.shape(weight)
219
- if (start_dim == end_dim and start_dim == weight_dim) and (
220
- start_shape == end_shape and start_shape == weight_shape):
220
+ is_dim_ok = start_dim == end_dim and start_dim == weight_dim
221
+ is_shape_ok = start_shape == end_shape and start_shape == weight_shape
222
+ if is_dim_ok and is_shape_ok:
221
223
  out = prim(start, end, weight)
222
224
  return out, start_dim
223
225
  start, end = broadcast_a_b_shape(start_bdim, end_bdim)
@@ -900,3 +902,4 @@ get_unop_vmap_rule = vmap_rules_getters.register(BesselK1e)(get_unop_vmap_rule)
900
902
  get_unop_vmap_rule = vmap_rules_getters.register(P.Trunc)(get_unop_vmap_rule)
901
903
  get_unop_vmap_rule = vmap_rules_getters.register(P.PopulationCount)(get_unop_vmap_rule)
902
904
  get_unop_vmap_rule = vmap_rules_getters.register(P.Square)(get_unop_vmap_rule)
905
+ get_unop_vmap_rule = vmap_rules_getters.register(P.Eps)(get_unop_vmap_rule)
@@ -325,9 +325,10 @@ def get_bce_with_logits_loss_vamp_rule(prim, axis_size):
325
325
  # If rank is larger than 1, we need to reduce result when reduction != 'none'
326
326
  if max_rank > 1:
327
327
  reduce_indexes = tuple(range(1, max_rank))
328
- if logits_dim == label_dim and F.shape(logits) == F.shape(label) \
329
- and logits_dim == weight_dim and F.shape(logits) == F.shape(weight) \
330
- and logits_dim == pos_weight_dim and F.shape(logits) == F.shape(pos_weight):
328
+ logits_dim_ok = logits_dim == label_dim and logits_dim == weight_dim and logits_dim == pos_weight_dim
329
+ shape = F.shape(logits)
330
+ shape_ok = shape == F.shape(label) and shape == F.shape(weight) and shape == F.shape(pos_weight)
331
+ if logits_dim_ok and shape_ok:
331
332
  if prim_reduction == 'none':
332
333
  output = prim(logits, label, weight, pos_weight)
333
334
  elif prim_reduction in ('mean', 'sum'):
@@ -798,7 +799,8 @@ def get_instance_norm_rule(prim, axis_size):
798
799
  output_x, updated_moving_mean, updated_moving_variance = prim(input_x, gamma, beta, mean, variance, u_monad)
799
800
  return (output_x, None), (updated_moving_mean, None), (updated_moving_variance, None)
800
801
 
801
- if gamma_dim != 0 or beta_dim != gamma_dim or mean_dim != gamma_dim or variance_dim != gamma_dim:
802
+ precondition = gamma_dim != 0 or beta_dim != gamma_dim or mean_dim != gamma_dim or variance_dim != gamma_dim
803
+ if precondition:
802
804
  # pylint: disable=too-many-format-args
803
805
  raise ValueError(
804
806
  "For `{}`, the source axis of `var` must be equal to `accum` and `accum_update`, and not equal to 0, "
@@ -1679,7 +1681,8 @@ def get_rmsprop_vmap_rule(prim, axis_size):
1679
1681
  res = prim(var, mean_square, moment, lr, grad, decay, momentum, epsilon,
1680
1682
  u_monad) # low dimensional operator;
1681
1683
  return (res, None)
1682
- if var_dim != 0 or var_dim != mean_square_dim or var_dim != moment_dim or var_dim != grad_dim:
1684
+ precondition = var_dim != 0 or var_dim != mean_square_dim or var_dim != moment_dim or var_dim != grad_dim
1685
+ if precondition:
1683
1686
  raise ValueError(
1684
1687
  f"For '{prim_name}', the source axis of 'var' must be equal to 'mean_square_dim' "
1685
1688
  f"and 'moment_dim' and 'grad_dim' and not equal to 0, "
@@ -1735,8 +1738,8 @@ def get_apply_centered_rmsprop_vmap_rule(prim, axis_size):
1735
1738
  var = prim(var, mean_grad, mean_square,
1736
1739
  mom, grad, lr, rho, momentum, eps, u_monad)
1737
1740
  return (var, None)
1738
-
1739
- if var_dim != 0 or var_dim != mean_grad_dim or var_dim != mean_square_dim or var_dim != mom_dim:
1741
+ precondition = var_dim != 0 or var_dim != mean_grad_dim or var_dim != mean_square_dim or var_dim != mom_dim
1742
+ if precondition:
1740
1743
  raise ValueError(
1741
1744
  f"For '{prim_name}', the source axis of 'var' must be equal to 'mean_grad_dim' "
1742
1745
  f"and 'mean_square_dim' and 'mom_dim' and not equal to 0, "
@@ -2000,6 +2003,57 @@ def get_sparse_apply_ftrl_vmap_rule(prim, axis_size):
2000
2003
  return vmap_rule
2001
2004
 
2002
2005
 
2006
+ @vmap_rules_getters.register(P.Dense)
2007
+ def get_dense_vmap_rule(prim, axis_size):
2008
+ """VmapRule for `Dense` operation."""
2009
+ if isinstance(prim, str):
2010
+ prim = Primitive(prim)
2011
+
2012
+ batch_matmul = P.BatchMatMul(transpose_b=True)
2013
+
2014
+ @_primexpr
2015
+ def get_start_mid_end(x_shape):
2016
+ start = x_shape[0]
2017
+ mid = 1
2018
+ for shp in x_shape[1:-1]:
2019
+ mid *= shp
2020
+ end = x_shape[-1]
2021
+ return start, mid, end
2022
+
2023
+ def vmap_rule(x_bdim, w_bdim, b_bdim):
2024
+ is_all_none, result = vmap_general_preprocess(prim, x_bdim, w_bdim, b_bdim)
2025
+ if is_all_none:
2026
+ return result
2027
+
2028
+ x, x_dim = x_bdim
2029
+ w, w_dim = w_bdim
2030
+ b, b_dim = b_bdim
2031
+ x = _bdim_at_front(x, x_dim, axis_size)
2032
+ w = _bdim_at_front(w, w_dim, axis_size)
2033
+ if b is not None:
2034
+ b = _bdim_at_front(b, b_dim, axis_size)
2035
+
2036
+ x_shape = x.shape
2037
+ start, mid, end = get_start_mid_end(x_shape)
2038
+
2039
+ x = x.reshape(start, mid, end)
2040
+
2041
+ out = batch_matmul(x, w)
2042
+ out_shape = tuple(x_shape[:-1]) + (out.shape[-1],)
2043
+ out = out.reshape(out_shape)
2044
+
2045
+ if b is not None:
2046
+ b_shape = b.shape
2047
+ b_shape = (start,) + (1,) * (len(out_shape) - 2) + (b_shape[-1],)
2048
+ b = b.reshape(b_shape)
2049
+
2050
+ out = out + b
2051
+
2052
+ return out, 0
2053
+
2054
+ return vmap_rule
2055
+
2056
+
2003
2057
  # Unary vmap
2004
2058
  get_unop_vmap_rule = vmap_rules_getters.register(P.Elu)(get_unop_vmap_rule)
2005
2059
  get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU)(get_unop_vmap_rule)
@@ -0,0 +1,54 @@
1
+ # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
2
+ #
3
+ # Copyright 2023-2024 Huawei Technologies Co., Ltd
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ============================================================================
17
+ """Operator argument data type cast function."""
18
+ from enum import Enum
19
+
20
+
21
+ class TypeCastKind(Enum):
22
+ INT_TO_TUPLE = 1
23
+ INT_OR_TUPLE_TO_LIST = 2
24
+
25
+
26
+ def type_it(src_data, cast_type):
27
+ """
28
+ cast operator argument data type.
29
+ """
30
+ if cast_type == TypeCastKind.INT_TO_TUPLE:
31
+ if isinstance(src_data, tuple):
32
+ return src_data
33
+
34
+ if isinstance(src_data, int):
35
+ return (src_data,)
36
+
37
+ raise TypeError(f'{src_data} is the wrong data type.')
38
+
39
+ if cast_type == TypeCastKind.INT_OR_TUPLE_TO_LIST:
40
+ if isinstance(src_data, list):
41
+ return src_data
42
+
43
+ if isinstance(src_data, int):
44
+ return [
45
+ src_data,
46
+ ]
47
+
48
+ if isinstance(src_data, tuple):
49
+ dst_list = [item for item in src_data]
50
+ return dst_list
51
+
52
+ raise TypeError(f'{src_data} is the wrong data type.')
53
+
54
+ raise TypeError("Unsupported type cast")
@@ -20,6 +20,7 @@ from __future__ import absolute_import
20
20
  from functools import partial
21
21
 
22
22
  from types import FunctionType, MethodType
23
+ import numpy as np
23
24
  import mindspore as ms
24
25
  from mindspore import context
25
26
  from mindspore.common.parameter import Parameter, ParameterTuple
@@ -28,7 +29,8 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
28
29
  TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
29
30
  SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
30
31
  ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
31
- ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_
32
+ ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_, DictSetItem_, \
33
+ HandleBoolTensor_, HandleEmptySlice_, PreSetitemByTuple_, HandleScalarTensorIndex_
32
34
  from mindspore.common import dtype as mstype
33
35
  from mindspore.common.api import jit, _pynative_executor, _wrap_func
34
36
  from mindspore.common.api import _add_flags, _core
@@ -36,7 +38,8 @@ from mindspore.ops.primitive import Primitive
36
38
  from mindspore.ops import signature as sig
37
39
 
38
40
  __all__ = [TupleAdd_, ListAdd_, UnpackCall_, TupleGetItemTensor_, SequenceSliceGetItem_,
39
- ListSliceSetItem_, ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_]
41
+ ListSliceSetItem_, ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_,
42
+ HandleBoolTensor_, HandleEmptySlice_, PreSetitemByTuple_, HandleScalarTensorIndex_]
40
43
 
41
44
 
42
45
  def add_flags(fn=None, **flags):
@@ -334,7 +337,7 @@ class GradOperation(GradOperation_):
334
337
  self.get_all = get_all
335
338
  self.get_by_list = get_by_list
336
339
  self.sens_param = sens_param
337
- GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False, False, False, False)
340
+ GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False, False, False, False, False)
338
341
  self.grad_fn = None
339
342
  self.fn = None
340
343
  self.weights_id = None
@@ -511,8 +514,8 @@ class _Grad(GradOperation_):
511
514
  A higher-order function which is used to generate the gradient function by position for the input function.
512
515
  """
513
516
 
514
- def __init__(self, get_by_list=False, sens_param=False, get_by_position=False, has_aux=False, get_value=False,
515
- return_ids=False):
517
+ def __init__(self, get_all=False, get_by_list=False, sens_param=False, get_by_position=False, has_aux=False,
518
+ get_value=False, return_ids=False, merge_forward=False):
516
519
  """Initialize _Grad."""
517
520
  if not isinstance(get_by_position, bool):
518
521
  raise TypeError(f"For '_Grad', the 'get_by_position' should be bool, "
@@ -532,14 +535,16 @@ class _Grad(GradOperation_):
532
535
  if not isinstance(return_ids, bool):
533
536
  raise TypeError(f"For '_Grad', the 'return_ids' should be bool, "
534
537
  f"but got {type(return_ids).__name__}")
538
+ self.get_all = get_all
535
539
  self.get_by_position = get_by_position
536
540
  self.get_by_list = get_by_list
537
541
  self.sens_param = sens_param
538
542
  self.has_aux = has_aux
539
543
  self.get_value = get_value
540
544
  self.return_ids = return_ids
541
- GradOperation_.__init__(self, 'grad', False, get_by_list, sens_param, get_by_position, has_aux, get_value,
542
- return_ids)
545
+ self.merge_forward = merge_forward
546
+ GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, get_by_position, has_aux, get_value,
547
+ return_ids, merge_forward)
543
548
  self.grad_fn = None
544
549
  self.fn = None
545
550
  self.pynative_ = False
@@ -562,8 +567,8 @@ class _Grad(GradOperation_):
562
567
  res += (stop_gradient(item),)
563
568
  return res
564
569
 
565
- grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position, self.has_aux, self.get_value,
566
- self.return_ids)
570
+ grad_ = _Grad(self.get_all, self.get_by_list, self.sens_param, self.get_by_position, self.has_aux,
571
+ self.get_value, self.return_ids, self.merge_forward)
567
572
  # If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
568
573
  # If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
569
574
  # In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
@@ -738,6 +743,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
738
743
  sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),))
739
744
 
740
745
  def __call__(self, *args):
746
+ for arg in args:
747
+ if isinstance(arg, np.ndarray):
748
+ raise TypeError("For 'MultitypeFuncGraph', the input can not be numpy.ndarray")
741
749
  if len(self.entries) == 1:
742
750
  output = self.entries[0][1](*args)
743
751
  return output
@@ -890,7 +898,7 @@ class Map(Map_):
890
898
  If `ops` is `None`, the first input is the operation, and the other is inputs.
891
899
 
892
900
  Outputs:
893
- Sequence, the sequence of output after applying the function. e.g. `operation(args[0][i], args[1][i])`.
901
+ Sequence, the sequence of output after applying the ops function. e.g. `ops(args[0][i], args[1][i])`.
894
902
 
895
903
  Supported Platforms:
896
904
  ``Ascend`` ``GPU`` ``CPU``
@@ -1046,6 +1054,25 @@ class _ListExtend(ListExtend_):
1046
1054
  _extend = _ListExtend("extend")
1047
1055
 
1048
1056
 
1057
+ class _DictSetItem(DictSetItem_):
1058
+ """
1059
+ A metafuncgraph class that setitem for the dict.
1060
+
1061
+ Args:
1062
+ name (str): The name of the metafuncgraph object.
1063
+ """
1064
+
1065
+ def __init__(self, name):
1066
+ """Initialize _DictClear."""
1067
+ DictSetItem_.__init__(self, name)
1068
+
1069
+ def __call__(self, *args):
1070
+ pass
1071
+
1072
+
1073
+ _dict_setitem = _DictSetItem("setitem")
1074
+
1075
+
1049
1076
  class _DictClear(DictClear_):
1050
1077
  """
1051
1078
  A metafuncgraph class that clear the dict.
@@ -13,8 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """math Operations."""
16
+ import mindspore.ops as ops
16
17
  from mindspore.ops import functional as F
17
18
  from mindspore.ops.function.math_func import cummin as cummin_
19
+ from mindspore.ops._primitive_cache import _get_cache_prim
18
20
 
19
21
 
20
22
  def matmul(x1, x2, dtype=None):
@@ -117,10 +119,9 @@ def mm(input, mat2):
117
119
  >>> print(out.shape)
118
120
  (2, 4)
119
121
  """
120
- if input.ndim != 2 or mat2.ndim != 2:
121
- raise ValueError(f"For mm, the input tensor must be a matrix, "
122
- f"but got mat1.ndim:{input.ndim}, mat2.ndim:{mat2.ndim}")
123
- return matmul(input, mat2)
122
+ _matmul = _get_cache_prim(ops.MatMul)()
123
+ out = _matmul(input, mat2)
124
+ return out
124
125
 
125
126
 
126
127
  def cummin(x, axis):